fix(inputs.socket_listener): Fix loss of connection tracking (#13056)

This commit is contained in:
Patrick Hemmer 2023-05-17 14:34:53 -04:00 committed by GitHub
parent f0dc15fd9c
commit f098e5f9f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 20 deletions

View File

@ -16,7 +16,6 @@ import (
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/influxdata/telegraf" "github.com/influxdata/telegraf"
@ -184,6 +183,10 @@ func TestSocketListener(t *testing.T) {
actual := acc.GetTelegrafMetrics() actual := acc.GetTelegrafMetrics()
testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics()) testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
if sl, ok := plugin.listener.(*streamListener); ok {
require.NotEmpty(t, sl.connections)
}
plugin.Stop() plugin.Stop()
if _, ok := plugin.listener.(*streamListener); ok { if _, ok := plugin.listener.(*streamListener); ok {
@ -191,7 +194,7 @@ func TestSocketListener(t *testing.T) {
_ = client.SetReadDeadline(time.Now().Add(time.Second)) _ = client.SetReadDeadline(time.Now().Add(time.Second))
buf := []byte{1} buf := []byte{1}
_, err = client.Read(buf) _, err = client.Read(buf)
assert.Equal(t, err, io.EOF) require.Equal(t, err, io.EOF)
} }
}) })
} }

View File

@ -86,6 +86,17 @@ func (l *streamListener) setupConnection(conn net.Conn) error {
conn = c.NetConn() conn = c.NetConn()
} }
addr := conn.RemoteAddr().String()
l.Lock()
if l.MaxConnections > 0 && len(l.connections) >= l.MaxConnections {
l.Unlock()
// Ignore the returned error as we cannot do anything about it anyway
_ = conn.Close()
return fmt.Errorf("unable to accept connection from %q: too many connections", addr)
}
l.connections[conn] = struct{}{}
l.Unlock()
if l.ReadBufferSize > 0 { if l.ReadBufferSize > 0 {
if rb, ok := conn.(hasSetReadBuffer); ok { if rb, ok := conn.(hasSetReadBuffer); ok {
if err := rb.SetReadBuffer(l.ReadBufferSize); err != nil { if err := rb.SetReadBuffer(l.ReadBufferSize); err != nil {
@ -96,47 +107,34 @@ func (l *streamListener) setupConnection(conn net.Conn) error {
} }
} }
addr := conn.RemoteAddr().String()
if l.MaxConnections > 0 && len(l.connections) >= l.MaxConnections {
// Ignore the returned error as we cannot do anything about it anyway
_ = conn.Close()
l.Log.Infof("unable to accept connection from %q: too many connections", addr)
return nil
}
// Set keep alive handlings // Set keep alive handlings
if l.KeepAlivePeriod != nil { if l.KeepAlivePeriod != nil {
tcpConn, ok := conn.(*net.TCPConn) tcpConn, ok := conn.(*net.TCPConn)
if !ok { if !ok {
return fmt.Errorf("cannot set keep-alive: not a TCP connection (%T)", conn) l.Log.Warnf("connection not a TCP connection (%T)", conn)
} }
if *l.KeepAlivePeriod == 0 { if *l.KeepAlivePeriod == 0 {
if err := tcpConn.SetKeepAlive(false); err != nil { if err := tcpConn.SetKeepAlive(false); err != nil {
return fmt.Errorf("cannot set keep-alive: %w", err) l.Log.Warnf("Cannot set keep-alive: %w", err)
} }
} else { } else {
if err := tcpConn.SetKeepAlive(true); err != nil { if err := tcpConn.SetKeepAlive(true); err != nil {
return fmt.Errorf("cannot set keep-alive: %w", err) l.Log.Warnf("Cannot set keep-alive: %w", err)
} }
err := tcpConn.SetKeepAlivePeriod(time.Duration(*l.KeepAlivePeriod)) err := tcpConn.SetKeepAlivePeriod(time.Duration(*l.KeepAlivePeriod))
if err != nil { if err != nil {
return fmt.Errorf("cannot set keep-alive period: %w", err) l.Log.Warnf("Cannot set keep-alive period: %w", err)
} }
} }
} }
// Store the connection mapped to its address
l.Lock()
l.connections[conn] = struct{}{}
l.Unlock()
return nil return nil
} }
func (l *streamListener) closeConnection(conn net.Conn) { func (l *streamListener) closeConnection(conn net.Conn) {
addr := conn.RemoteAddr().String() addr := conn.RemoteAddr().String()
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
l.Log.Errorf("Cannot close connection to %q: %v", addr, err) l.Log.Warnf("Cannot close connection to %q: %v", addr, err)
} }
delete(l.connections, conn) delete(l.connections, conn)
} }
@ -185,6 +183,7 @@ func (l *streamListener) listen(acc telegraf.Accumulator) {
if err := l.setupConnection(conn); err != nil { if err := l.setupConnection(conn); err != nil {
acc.AddError(err) acc.AddError(err)
continue
} }
wg.Add(1) wg.Add(1)