diff --git a/plugins/inputs/http_listener_v2/http_listener_v2.go b/plugins/inputs/http_listener_v2/http_listener_v2.go index d2a2e5f35..85dbf89f1 100644 --- a/plugins/inputs/http_listener_v2/http_listener_v2.go +++ b/plugins/inputs/http_listener_v2/http_listener_v2.go @@ -4,6 +4,7 @@ import ( "compress/gzip" "crypto/subtle" "crypto/tls" + "errors" "io" "net" "net/http" @@ -50,12 +51,15 @@ type HTTPListenerV2 struct { BasicUsername string `toml:"basic_username"` BasicPassword string `toml:"basic_password"` HTTPHeaderTags map[string]string `toml:"http_header_tags"` + tlsint.ServerConfig + tlsConf *tls.Config TimeFunc Log telegraf.Logger - wg sync.WaitGroup + wg sync.WaitGroup + close chan struct{} listener net.Listener @@ -154,17 +158,48 @@ func (h *HTTPListenerV2) Start(acc telegraf.Accumulator) error { h.acc = acc - tlsConf, err := h.ServerConfig.TLSConfig() - if err != nil { - return err - } + server := h.createHTTPServer() - server := &http.Server{ + h.wg.Add(1) + go func() { + defer h.wg.Done() + if err := server.Serve(h.listener); err != nil { + if !errors.Is(err, net.ErrClosed) { + h.Log.Errorf("Serve failed: %v", err) + } + close(h.close) + } + }() + + h.Log.Infof("Listening on %s", h.listener.Addr().String()) + + return nil +} + +func (h *HTTPListenerV2) createHTTPServer() *http.Server { + return &http.Server{ Addr: h.ServiceAddress, Handler: h, ReadTimeout: time.Duration(h.ReadTimeout), WriteTimeout: time.Duration(h.WriteTimeout), - TLSConfig: tlsConf, + TLSConfig: h.tlsConf, + } +} + +// Stop cleans up all resources +func (h *HTTPListenerV2) Stop() { + if h.listener != nil { + // Ignore the returned error as we cannot do anything about it anyway + //nolint:errcheck,revive + h.listener.Close() + } + h.wg.Wait() +} + +func (h *HTTPListenerV2) Init() error { + tlsConf, err := h.ServerConfig.TLSConfig() + if err != nil { + return err } var listener net.Listener @@ -176,32 +211,13 @@ func (h *HTTPListenerV2) Start(acc telegraf.Accumulator) error { if err != nil { return err } + h.tlsConf = tlsConf h.listener = listener h.Port = listener.Addr().(*net.TCPAddr).Port - h.wg.Add(1) - go func() { - defer h.wg.Done() - if err := server.Serve(h.listener); err != nil { - h.Log.Errorf("Serve failed: %v", err) - } - }() - - h.Log.Infof("Listening on %s", listener.Addr().String()) - return nil } -// Stop cleans up all resources -func (h *HTTPListenerV2) Stop() { - if h.listener != nil { - // Ignore the returned error as we cannot do anything about it anyway - //nolint:errcheck,revive - h.listener.Close() - } - h.wg.Wait() -} - func (h *HTTPListenerV2) ServeHTTP(res http.ResponseWriter, req *http.Request) { handler := h.serveWrite @@ -213,6 +229,13 @@ func (h *HTTPListenerV2) ServeHTTP(res http.ResponseWriter, req *http.Request) { } func (h *HTTPListenerV2) serveWrite(res http.ResponseWriter, req *http.Request) { + select { + case <-h.close: + res.WriteHeader(http.StatusGone) + return + default: + } + // Check that the content length is not too large for us to handle. if req.ContentLength > int64(h.MaxBodySize) { if err := tooLarge(res); err != nil { @@ -393,6 +416,7 @@ func init() { Paths: []string{"/telegraf"}, Methods: []string{"POST", "PUT"}, DataSource: body, + close: make(chan struct{}), } }) } diff --git a/plugins/inputs/http_listener_v2/http_listener_v2_test.go b/plugins/inputs/http_listener_v2/http_listener_v2_test.go index bf320d6f0..ddbb5be64 100644 --- a/plugins/inputs/http_listener_v2/http_listener_v2_test.go +++ b/plugins/inputs/http_listener_v2/http_listener_v2_test.go @@ -56,6 +56,7 @@ func newTestHTTPListenerV2() *HTTPListenerV2 { TimeFunc: time.Now, MaxBodySize: config.Size(70000), DataSource: "body", + close: make(chan struct{}), } return listener } @@ -78,6 +79,7 @@ func newTestHTTPSListenerV2() *HTTPListenerV2 { Parser: parser, ServerConfig: *pki.TLSServerConfig(), TimeFunc: time.Now, + close: make(chan struct{}), } return listener @@ -117,10 +119,10 @@ func TestInvalidListenerConfig(t *testing.T) { TimeFunc: time.Now, MaxBodySize: config.Size(70000), DataSource: "body", + close: make(chan struct{}), } - acc := &testutil.Accumulator{} - require.Error(t, listener.Start(acc)) + require.Error(t, listener.Init()) // Stop is called when any ServiceInput fails to start; it must succeed regardless of state listener.Stop() @@ -131,6 +133,7 @@ func TestWriteHTTPSNoClientAuth(t *testing.T) { listener.TLSAllowedCACerts = nil acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -155,6 +158,7 @@ func TestWriteHTTPSWithClientAuth(t *testing.T) { listener := newTestHTTPSListenerV2() acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -169,6 +173,7 @@ func TestWriteHTTPBasicAuth(t *testing.T) { listener := newTestHTTPAuthListener() acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -187,6 +192,7 @@ func TestWriteHTTP(t *testing.T) { listener := newTestHTTPListenerV2() acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -237,6 +243,7 @@ func TestWriteHTTPWithPathTag(t *testing.T) { listener.PathTag = true acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -260,6 +267,7 @@ func TestWriteHTTPWithMultiplePaths(t *testing.T) { listener.PathTag = true acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -292,6 +300,7 @@ func TestWriteHTTPNoNewline(t *testing.T) { listener := newTestHTTPListenerV2() acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -319,9 +328,11 @@ func TestWriteHTTPExactMaxBodySize(t *testing.T) { Parser: parser, MaxBodySize: config.Size(len(hugeMetric)), TimeFunc: time.Now, + close: make(chan struct{}), } acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -342,9 +353,11 @@ func TestWriteHTTPVerySmallMaxBody(t *testing.T) { Parser: parser, MaxBodySize: config.Size(4096), TimeFunc: time.Now, + close: make(chan struct{}), } acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -359,6 +372,8 @@ func TestWriteHTTPGzippedData(t *testing.T) { listener := newTestHTTPListenerV2() acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -391,6 +406,7 @@ func TestWriteHTTPSnappyData(t *testing.T) { listener := newTestHTTPListenerV2() acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -429,6 +445,7 @@ func TestWriteHTTPHighTraffic(t *testing.T) { listener := newTestHTTPListenerV2() acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -464,6 +481,7 @@ func TestReceive404ForInvalidEndpoint(t *testing.T) { listener := newTestHTTPListenerV2() acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -478,6 +496,7 @@ func TestWriteHTTPInvalid(t *testing.T) { listener := newTestHTTPListenerV2() acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -492,6 +511,7 @@ func TestWriteHTTPEmpty(t *testing.T) { listener := newTestHTTPListenerV2() acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -507,6 +527,7 @@ func TestWriteHTTPTransformHeaderValuesToTagsSingleWrite(t *testing.T) { listener.HTTPHeaderTags = map[string]string{"Present_http_header_1": "presentMeasurementKey1", "present_http_header_2": "presentMeasurementKey2", "NOT_PRESENT_HEADER": "notPresentMeasurementKey"} acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -545,6 +566,7 @@ func TestWriteHTTPTransformHeaderValuesToTagsBulkWrite(t *testing.T) { listener.HTTPHeaderTags = map[string]string{"Present_http_header_1": "presentMeasurementKey1", "Present_http_header_2": "presentMeasurementKey2", "NOT_PRESENT_HEADER": "notPresentMeasurementKey"} acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -576,6 +598,7 @@ func TestWriteHTTPQueryParams(t *testing.T) { listener.Parser = parser acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop() @@ -597,6 +620,7 @@ func TestWriteHTTPFormData(t *testing.T) { listener.Parser = parser acc := &testutil.Accumulator{} + require.NoError(t, listener.Init()) require.NoError(t, listener.Start(acc)) defer listener.Stop()