fix(http_listener_v2): fix panic on close (#10132)

This commit is contained in:
Patryk Małek 2021-12-10 21:14:16 +01:00 committed by GitHub
parent 039c9683fd
commit 1b9572085b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 29 deletions

View File

@ -4,6 +4,7 @@ import (
"compress/gzip"
"crypto/subtle"
"crypto/tls"
"errors"
"io"
"net"
"net/http"
@ -50,12 +51,15 @@ type HTTPListenerV2 struct {
BasicUsername string `toml:"basic_username"`
BasicPassword string `toml:"basic_password"`
HTTPHeaderTags map[string]string `toml:"http_header_tags"`
tlsint.ServerConfig
tlsConf *tls.Config
TimeFunc
Log telegraf.Logger
wg sync.WaitGroup
wg sync.WaitGroup
close chan struct{}
listener net.Listener
@ -154,17 +158,48 @@ func (h *HTTPListenerV2) Start(acc telegraf.Accumulator) error {
h.acc = acc
tlsConf, err := h.ServerConfig.TLSConfig()
if err != nil {
return err
}
server := h.createHTTPServer()
server := &http.Server{
h.wg.Add(1)
go func() {
defer h.wg.Done()
if err := server.Serve(h.listener); err != nil {
if !errors.Is(err, net.ErrClosed) {
h.Log.Errorf("Serve failed: %v", err)
}
close(h.close)
}
}()
h.Log.Infof("Listening on %s", h.listener.Addr().String())
return nil
}
func (h *HTTPListenerV2) createHTTPServer() *http.Server {
return &http.Server{
Addr: h.ServiceAddress,
Handler: h,
ReadTimeout: time.Duration(h.ReadTimeout),
WriteTimeout: time.Duration(h.WriteTimeout),
TLSConfig: tlsConf,
TLSConfig: h.tlsConf,
}
}
// Stop cleans up all resources
func (h *HTTPListenerV2) Stop() {
if h.listener != nil {
// Ignore the returned error as we cannot do anything about it anyway
//nolint:errcheck,revive
h.listener.Close()
}
h.wg.Wait()
}
func (h *HTTPListenerV2) Init() error {
tlsConf, err := h.ServerConfig.TLSConfig()
if err != nil {
return err
}
var listener net.Listener
@ -176,32 +211,13 @@ func (h *HTTPListenerV2) Start(acc telegraf.Accumulator) error {
if err != nil {
return err
}
h.tlsConf = tlsConf
h.listener = listener
h.Port = listener.Addr().(*net.TCPAddr).Port
h.wg.Add(1)
go func() {
defer h.wg.Done()
if err := server.Serve(h.listener); err != nil {
h.Log.Errorf("Serve failed: %v", err)
}
}()
h.Log.Infof("Listening on %s", listener.Addr().String())
return nil
}
// Stop cleans up all resources
func (h *HTTPListenerV2) Stop() {
if h.listener != nil {
// Ignore the returned error as we cannot do anything about it anyway
//nolint:errcheck,revive
h.listener.Close()
}
h.wg.Wait()
}
func (h *HTTPListenerV2) ServeHTTP(res http.ResponseWriter, req *http.Request) {
handler := h.serveWrite
@ -213,6 +229,13 @@ func (h *HTTPListenerV2) ServeHTTP(res http.ResponseWriter, req *http.Request) {
}
func (h *HTTPListenerV2) serveWrite(res http.ResponseWriter, req *http.Request) {
select {
case <-h.close:
res.WriteHeader(http.StatusGone)
return
default:
}
// Check that the content length is not too large for us to handle.
if req.ContentLength > int64(h.MaxBodySize) {
if err := tooLarge(res); err != nil {
@ -393,6 +416,7 @@ func init() {
Paths: []string{"/telegraf"},
Methods: []string{"POST", "PUT"},
DataSource: body,
close: make(chan struct{}),
}
})
}

View File

@ -56,6 +56,7 @@ func newTestHTTPListenerV2() *HTTPListenerV2 {
TimeFunc: time.Now,
MaxBodySize: config.Size(70000),
DataSource: "body",
close: make(chan struct{}),
}
return listener
}
@ -78,6 +79,7 @@ func newTestHTTPSListenerV2() *HTTPListenerV2 {
Parser: parser,
ServerConfig: *pki.TLSServerConfig(),
TimeFunc: time.Now,
close: make(chan struct{}),
}
return listener
@ -117,10 +119,10 @@ func TestInvalidListenerConfig(t *testing.T) {
TimeFunc: time.Now,
MaxBodySize: config.Size(70000),
DataSource: "body",
close: make(chan struct{}),
}
acc := &testutil.Accumulator{}
require.Error(t, listener.Start(acc))
require.Error(t, listener.Init())
// Stop is called when any ServiceInput fails to start; it must succeed regardless of state
listener.Stop()
@ -131,6 +133,7 @@ func TestWriteHTTPSNoClientAuth(t *testing.T) {
listener.TLSAllowedCACerts = nil
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -155,6 +158,7 @@ func TestWriteHTTPSWithClientAuth(t *testing.T) {
listener := newTestHTTPSListenerV2()
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -169,6 +173,7 @@ func TestWriteHTTPBasicAuth(t *testing.T) {
listener := newTestHTTPAuthListener()
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -187,6 +192,7 @@ func TestWriteHTTP(t *testing.T) {
listener := newTestHTTPListenerV2()
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -237,6 +243,7 @@ func TestWriteHTTPWithPathTag(t *testing.T) {
listener.PathTag = true
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -260,6 +267,7 @@ func TestWriteHTTPWithMultiplePaths(t *testing.T) {
listener.PathTag = true
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -292,6 +300,7 @@ func TestWriteHTTPNoNewline(t *testing.T) {
listener := newTestHTTPListenerV2()
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -319,9 +328,11 @@ func TestWriteHTTPExactMaxBodySize(t *testing.T) {
Parser: parser,
MaxBodySize: config.Size(len(hugeMetric)),
TimeFunc: time.Now,
close: make(chan struct{}),
}
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -342,9 +353,11 @@ func TestWriteHTTPVerySmallMaxBody(t *testing.T) {
Parser: parser,
MaxBodySize: config.Size(4096),
TimeFunc: time.Now,
close: make(chan struct{}),
}
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -359,6 +372,8 @@ func TestWriteHTTPGzippedData(t *testing.T) {
listener := newTestHTTPListenerV2()
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -391,6 +406,7 @@ func TestWriteHTTPSnappyData(t *testing.T) {
listener := newTestHTTPListenerV2()
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -429,6 +445,7 @@ func TestWriteHTTPHighTraffic(t *testing.T) {
listener := newTestHTTPListenerV2()
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -464,6 +481,7 @@ func TestReceive404ForInvalidEndpoint(t *testing.T) {
listener := newTestHTTPListenerV2()
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -478,6 +496,7 @@ func TestWriteHTTPInvalid(t *testing.T) {
listener := newTestHTTPListenerV2()
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -492,6 +511,7 @@ func TestWriteHTTPEmpty(t *testing.T) {
listener := newTestHTTPListenerV2()
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -507,6 +527,7 @@ func TestWriteHTTPTransformHeaderValuesToTagsSingleWrite(t *testing.T) {
listener.HTTPHeaderTags = map[string]string{"Present_http_header_1": "presentMeasurementKey1", "present_http_header_2": "presentMeasurementKey2", "NOT_PRESENT_HEADER": "notPresentMeasurementKey"}
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -545,6 +566,7 @@ func TestWriteHTTPTransformHeaderValuesToTagsBulkWrite(t *testing.T) {
listener.HTTPHeaderTags = map[string]string{"Present_http_header_1": "presentMeasurementKey1", "Present_http_header_2": "presentMeasurementKey2", "NOT_PRESENT_HEADER": "notPresentMeasurementKey"}
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -576,6 +598,7 @@ func TestWriteHTTPQueryParams(t *testing.T) {
listener.Parser = parser
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()
@ -597,6 +620,7 @@ func TestWriteHTTPFormData(t *testing.T) {
listener.Parser = parser
acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()