fix(common.socket): Switch to context to simplify closing (#15589)

This commit is contained in:
Sven Rebhan 2024-07-17 16:21:03 +02:00 committed by GitHub
parent 90fdcfcdd7
commit 3d9562ba91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 289 additions and 54 deletions

View File

@ -27,7 +27,7 @@ type listener interface {
}
type Config struct {
MaxConnections int `toml:"max_connections"`
MaxConnections uint64 `toml:"max_connections"`
ReadBufferSize config.Size `toml:"read_buffer_size"`
ReadTimeout config.Duration `toml:"read_timeout"`
KeepAlivePeriod *config.Duration `toml:"keep_alive_period"`

View File

@ -2,12 +2,14 @@ package socket
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"os"
"runtime"
"strings"
"sync"
"testing"
"time"
@ -481,7 +483,7 @@ func TestClosingConnections(t *testing.T) {
listener, ok := sock.listener.(*streamListener)
require.True(t, ok)
listener.Lock()
conns := len(listener.connections)
conns := listener.connections
listener.Unlock()
require.NotZero(t, conns)
@ -496,6 +498,130 @@ func TestClosingConnections(t *testing.T) {
require.Empty(t, logger.Errors())
require.Empty(t, logger.Warnings())
}
func TestMaxConnections(t *testing.T) {
if runtime.GOOS == "darwin" {
t.Skip("Skipping on darwin due to missing socket options")
}
// Setup the configuration
period := config.Duration(10 * time.Millisecond)
cfg := &Config{
MaxConnections: 5,
KeepAlivePeriod: &period,
}
// Create the socket
serviceAddress := "tcp://127.0.0.1:0"
sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
require.NoError(t, err)
// Create callback
var errs []error
var mu sync.Mutex
onData := func(_ net.Addr, _ []byte) {}
onError := func(err error) {
mu.Lock()
errs = append(errs, err)
mu.Unlock()
}
// Start the listener
require.NoError(t, sock.Setup())
sock.Listen(onData, onError)
defer sock.Close()
addr := sock.Address()
// Create maximum number of connections and write some data. All of this
// should succeed...
clients := make([]*net.TCPConn, 0, cfg.MaxConnections)
for i := 0; i < int(cfg.MaxConnections); i++ {
c, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
require.NoError(t, err)
require.NoError(t, c.SetWriteBuffer(0))
require.NoError(t, c.SetNoDelay(true))
clients = append(clients, c)
_, err = c.Write([]byte("test value=42i\n"))
require.NoError(t, err)
}
func() {
mu.Lock()
defer mu.Unlock()
require.Empty(t, errs)
}()
// Create another client. This should fail because we already reached the
// connection limit and the connection should be closed...
client, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
require.NoError(t, err)
require.NoError(t, client.SetWriteBuffer(0))
require.NoError(t, client.SetNoDelay(true))
require.Eventually(t, func() bool {
mu.Lock()
defer mu.Unlock()
return len(errs) > 0
}, 3*time.Second, 100*time.Millisecond)
func() {
mu.Lock()
defer mu.Unlock()
require.Len(t, errs, 1)
require.ErrorContains(t, errs[0], "too many connections")
errs = make([]error, 0)
}()
require.Eventually(t, func() bool {
_, err := client.Write([]byte("fail\n"))
return err != nil
}, 3*time.Second, 100*time.Millisecond)
_, err = client.Write([]byte("test\n"))
require.Error(t, err)
// Check other connections are still good
for _, c := range clients {
_, err := c.Write([]byte("test\n"))
require.NoError(t, err)
}
func() {
mu.Lock()
defer mu.Unlock()
require.Empty(t, errs)
}()
// Close the first client and check if we can connect now
require.NoError(t, clients[0].Close())
client, err = net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
require.NoError(t, err)
require.NoError(t, client.SetWriteBuffer(0))
require.NoError(t, client.SetNoDelay(true))
_, err = client.Write([]byte("success\n"))
require.NoError(t, err)
// Close all connections
require.NoError(t, client.Close())
for _, c := range clients[1:] {
require.NoError(t, c.Close())
}
// Close the clients and check the connection counter
listener, ok := sock.listener.(*streamListener)
require.True(t, ok)
require.Eventually(t, func() bool {
listener.Lock()
conns := listener.connections
listener.Unlock()
return conns == 0
}, 3*time.Second, 100*time.Millisecond)
// Close the socket and check again...
sock.Close()
listener.Lock()
conns := listener.connections
listener.Unlock()
require.Zero(t, conns)
}
func TestNoSplitter(t *testing.T) {
messages := [][]byte{
@ -605,6 +731,101 @@ func TestNoSplitter(t *testing.T) {
testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
}
func TestTLSMemLeak(t *testing.T) {
// For issue https://github.com/influxdata/telegraf/issues/15509
// Prepare the address and socket if needed
serviceAddress := "tcp://127.0.0.1:0"
// Setup a TLS socket to trigger the issue
cfg := &Config{
ServerConfig: *pki.TLSServerConfig(),
}
// Create the socket
sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
require.NoError(t, err)
// Create callbacks
onConnection := func(_ net.Addr, reader io.ReadCloser) {
//nolint:errcheck // We are not interested in the data so ignore all errors
io.Copy(io.Discard, reader)
}
// Start the listener
require.NoError(t, sock.Setup())
sock.ListenConnection(onConnection, nil)
defer sock.Close()
addr := sock.Address()
// Setup the client side TLS
tlsCfg, err := pki.TLSClientConfig().TLSConfig()
require.NoError(t, err)
// Define a single client write sequence
data := []byte("test value=42i")
write := func() error {
conn, err := tls.Dial("tcp", addr.String(), tlsCfg)
if err != nil {
return err
}
defer conn.Close()
_, err = conn.Write(data)
return err
}
// Define a test with the given number of connections
maxConcurrency := runtime.GOMAXPROCS(0)
testCycle := func(connections int) (uint64, error) {
var mu sync.Mutex
var errs []error
var wg sync.WaitGroup
for count := 1; count < connections; count++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := write(); err != nil {
mu.Lock()
errs = append(errs, err)
mu.Unlock()
}
}()
if count%maxConcurrency == 0 {
wg.Wait()
mu.Lock()
if len(errs) > 0 {
mu.Unlock()
return 0, errors.Join(errs...)
}
mu.Unlock()
}
}
//nolint:revive // We need to actively run the garbage collector to get reliable measurements
runtime.GC()
var stats runtime.MemStats
runtime.ReadMemStats(&stats)
return stats.HeapObjects, nil
}
// Measure the memory usage after a short warmup and after some time.
// The final number of heap objects should not exceed the number of
// runs by a save margin
// Warmup, do a low number of runs to initialize all data structures
// taking them out of the equation.
initial, err := testCycle(100)
require.NoError(t, err)
// Do some more runs and make sure the memory growth is bound
final, err := testCycle(2000)
require.NoError(t, err)
require.Less(t, final, 2*initial)
}
func createClient(endpoint string, addr net.Addr, tlsCfg *tls.Config) (net.Conn, error) {
// Determine the protocol in a crude fashion
parts := strings.SplitN(endpoint, "://", 2)

View File

@ -2,6 +2,7 @@ package socket
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
@ -32,15 +33,16 @@ type hasSetReadBuffer interface {
type streamListener struct {
Encoding string
ReadBufferSize int
MaxConnections int
MaxConnections uint64
ReadTimeout config.Duration
KeepAlivePeriod *config.Duration
Splitter bufio.SplitFunc
Log telegraf.Logger
listener net.Listener
connections map[net.Conn]bool
connections uint64
path string
cancel context.CancelFunc
wg sync.WaitGroup
sync.Mutex
@ -122,19 +124,15 @@ func (l *streamListener) setupVsock(u *url.URL) error {
}
func (l *streamListener) setupConnection(conn net.Conn) error {
if c, ok := conn.(*tls.Conn); ok {
conn = c.NetConn()
}
addr := conn.RemoteAddr().String()
l.Lock()
if l.MaxConnections > 0 && len(l.connections) >= l.MaxConnections {
if l.MaxConnections > 0 && l.connections >= l.MaxConnections {
l.Unlock()
// Ignore the returned error as we cannot do anything about it anyway
_ = conn.Close()
return fmt.Errorf("unable to accept connection from %q: too many connections", addr)
}
l.connections[conn] = true
l.connections++
l.Unlock()
if l.ReadBufferSize > 0 {
@ -149,6 +147,9 @@ func (l *streamListener) setupConnection(conn net.Conn) error {
// Set keep alive handlings
if l.KeepAlivePeriod != nil {
if c, ok := conn.(*tls.Conn); ok {
conn = c.NetConn()
}
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
l.Log.Warnf("connection not a TCP connection (%T)", conn)
@ -172,11 +173,18 @@ func (l *streamListener) setupConnection(conn net.Conn) error {
}
func (l *streamListener) closeConnection(conn net.Conn) {
// Fallback to enforce blocked reads on connections to end immediately
//nolint:errcheck // Ignore errors as this is a fallback only
conn.SetReadDeadline(time.Now())
addr := conn.RemoteAddr().String()
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, syscall.EPIPE) {
l.Log.Warnf("Cannot close connection to %q: %v", addr, err)
} else {
l.Lock()
l.connections--
l.Unlock()
}
delete(l.connections, conn)
}
func (l *streamListener) address() net.Addr {
@ -184,15 +192,18 @@ func (l *streamListener) address() net.Addr {
}
func (l *streamListener) close() error {
if err := l.listener.Close(); err != nil {
return err
if l.listener != nil {
// Continue even if we cannot close the listener in order to at least
// close all active connections
if err := l.listener.Close(); err != nil {
l.Log.Errorf("Cannot close listener: %v", err)
}
}
l.Lock()
for conn := range l.connections {
l.closeConnection(conn)
if l.cancel != nil {
l.cancel()
l.cancel = nil
}
l.Unlock()
l.wg.Wait()
if l.path != "" {
@ -200,8 +211,8 @@ func (l *streamListener) close() error {
if runtime.GOOS == "windows" && strings.Contains(fn, ":") {
fn = strings.TrimPrefix(fn, `\`)
}
// Ignore file-not-exists errors when removing the socket
if err := os.Remove(fn); err != nil && !errors.Is(err, os.ErrNotExist) {
// Ignore file-not-exists errors when removing the socket
return err
}
}
@ -209,13 +220,13 @@ func (l *streamListener) close() error {
}
func (l *streamListener) listenData(onData CallbackData, onError CallbackError) {
l.connections = make(map[net.Conn]bool)
ctx, cancel := context.WithCancel(context.Background())
l.cancel = cancel
l.wg.Add(1)
go func() {
defer l.wg.Done()
var wg sync.WaitGroup
for {
conn, err := l.listener.Accept()
if err != nil {
@ -230,40 +241,42 @@ func (l *streamListener) listenData(onData CallbackData, onError CallbackError)
continue
}
wg.Add(1)
go func(c net.Conn) {
defer wg.Done()
defer func() {
l.Lock()
l.closeConnection(conn)
l.Unlock()
}()
reader := l.read
if l.Splitter == nil {
reader = l.readAll
}
if err := reader(c, onData); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
if onError != nil {
onError(err)
}
}
}
}(conn)
l.wg.Add(1)
go l.handleReaderConn(ctx, conn, onData, onError)
}
wg.Wait()
}()
}
func (l *streamListener) handleReaderConn(ctx context.Context, conn net.Conn, onData CallbackData, onError CallbackError) {
defer l.wg.Done()
localCtx, cancel := context.WithCancel(ctx)
defer cancel()
defer l.closeConnection(conn)
stopFunc := context.AfterFunc(localCtx, func() { l.closeConnection(conn) })
defer stopFunc()
reader := l.read
if l.Splitter == nil {
reader = l.readAll
}
if err := reader(conn, onData); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
if onError != nil {
onError(err)
}
}
}
}
func (l *streamListener) listenConnection(onConnection CallbackConnection, onError CallbackError) {
l.connections = make(map[net.Conn]bool)
ctx, cancel := context.WithCancel(context.Background())
l.cancel = cancel
l.wg.Add(1)
go func() {
defer l.wg.Done()
var wg sync.WaitGroup
for {
conn, err := l.listener.Accept()
if err != nil {
@ -272,28 +285,22 @@ func (l *streamListener) listenConnection(onConnection CallbackConnection, onErr
}
break
}
if err := l.setupConnection(conn); err != nil && onError != nil {
onError(err)
continue
}
wg.Add(1)
l.wg.Add(1)
go func(c net.Conn) {
defer wg.Done()
if err := l.handleConnection(c, onConnection); err != nil {
if err := l.handleConnection(ctx, c, onConnection); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
if onError != nil {
onError(err)
}
}
}
l.Lock()
l.closeConnection(conn)
l.Unlock()
}(conn)
}
wg.Wait()
}()
}
@ -368,7 +375,6 @@ func (l *streamListener) readAll(conn net.Conn, onData CallbackData) error {
return fmt.Errorf("setting read deadline failed: %w", err)
}
}
buf, err := io.ReadAll(decoder)
if err != nil {
return fmt.Errorf("read on %s failed: %w", src, err)
@ -378,7 +384,15 @@ func (l *streamListener) readAll(conn net.Conn, onData CallbackData) error {
return nil
}
func (l *streamListener) handleConnection(conn net.Conn, onConnection CallbackConnection) error {
func (l *streamListener) handleConnection(ctx context.Context, conn net.Conn, onConnection CallbackConnection) error {
defer l.wg.Done()
localCtx, cancel := context.WithCancel(ctx)
defer cancel()
defer l.closeConnection(conn)
stopFunc := context.AfterFunc(localCtx, func() { l.closeConnection(conn) })
defer stopFunc()
// Prepare the data decoder for the connection
decoder, err := internal.NewStreamContentDecoder(l.Encoding, conn)
if err != nil {