fix(common.socket): Switch to context to simplify closing (#15589)
This commit is contained in:
parent
90fdcfcdd7
commit
3d9562ba91
|
|
@ -27,7 +27,7 @@ type listener interface {
|
|||
}
|
||||
|
||||
type Config struct {
|
||||
MaxConnections int `toml:"max_connections"`
|
||||
MaxConnections uint64 `toml:"max_connections"`
|
||||
ReadBufferSize config.Size `toml:"read_buffer_size"`
|
||||
ReadTimeout config.Duration `toml:"read_timeout"`
|
||||
KeepAlivePeriod *config.Duration `toml:"keep_alive_period"`
|
||||
|
|
|
|||
|
|
@ -2,12 +2,14 @@ package socket
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -481,7 +483,7 @@ func TestClosingConnections(t *testing.T) {
|
|||
listener, ok := sock.listener.(*streamListener)
|
||||
require.True(t, ok)
|
||||
listener.Lock()
|
||||
conns := len(listener.connections)
|
||||
conns := listener.connections
|
||||
listener.Unlock()
|
||||
require.NotZero(t, conns)
|
||||
|
||||
|
|
@ -496,6 +498,130 @@ func TestClosingConnections(t *testing.T) {
|
|||
require.Empty(t, logger.Errors())
|
||||
require.Empty(t, logger.Warnings())
|
||||
}
|
||||
func TestMaxConnections(t *testing.T) {
|
||||
if runtime.GOOS == "darwin" {
|
||||
t.Skip("Skipping on darwin due to missing socket options")
|
||||
}
|
||||
|
||||
// Setup the configuration
|
||||
period := config.Duration(10 * time.Millisecond)
|
||||
cfg := &Config{
|
||||
MaxConnections: 5,
|
||||
KeepAlivePeriod: &period,
|
||||
}
|
||||
|
||||
// Create the socket
|
||||
serviceAddress := "tcp://127.0.0.1:0"
|
||||
sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create callback
|
||||
var errs []error
|
||||
var mu sync.Mutex
|
||||
onData := func(_ net.Addr, _ []byte) {}
|
||||
onError := func(err error) {
|
||||
mu.Lock()
|
||||
errs = append(errs, err)
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// Start the listener
|
||||
require.NoError(t, sock.Setup())
|
||||
sock.Listen(onData, onError)
|
||||
defer sock.Close()
|
||||
|
||||
addr := sock.Address()
|
||||
|
||||
// Create maximum number of connections and write some data. All of this
|
||||
// should succeed...
|
||||
clients := make([]*net.TCPConn, 0, cfg.MaxConnections)
|
||||
for i := 0; i < int(cfg.MaxConnections); i++ {
|
||||
c, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, c.SetWriteBuffer(0))
|
||||
require.NoError(t, c.SetNoDelay(true))
|
||||
clients = append(clients, c)
|
||||
|
||||
_, err = c.Write([]byte("test value=42i\n"))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.Empty(t, errs)
|
||||
}()
|
||||
|
||||
// Create another client. This should fail because we already reached the
|
||||
// connection limit and the connection should be closed...
|
||||
client, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.SetWriteBuffer(0))
|
||||
require.NoError(t, client.SetNoDelay(true))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return len(errs) > 0
|
||||
}, 3*time.Second, 100*time.Millisecond)
|
||||
func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.Len(t, errs, 1)
|
||||
require.ErrorContains(t, errs[0], "too many connections")
|
||||
errs = make([]error, 0)
|
||||
}()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := client.Write([]byte("fail\n"))
|
||||
return err != nil
|
||||
}, 3*time.Second, 100*time.Millisecond)
|
||||
_, err = client.Write([]byte("test\n"))
|
||||
require.Error(t, err)
|
||||
|
||||
// Check other connections are still good
|
||||
for _, c := range clients {
|
||||
_, err := c.Write([]byte("test\n"))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.Empty(t, errs)
|
||||
}()
|
||||
|
||||
// Close the first client and check if we can connect now
|
||||
require.NoError(t, clients[0].Close())
|
||||
client, err = net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.SetWriteBuffer(0))
|
||||
require.NoError(t, client.SetNoDelay(true))
|
||||
_, err = client.Write([]byte("success\n"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close all connections
|
||||
require.NoError(t, client.Close())
|
||||
for _, c := range clients[1:] {
|
||||
require.NoError(t, c.Close())
|
||||
}
|
||||
|
||||
// Close the clients and check the connection counter
|
||||
listener, ok := sock.listener.(*streamListener)
|
||||
require.True(t, ok)
|
||||
require.Eventually(t, func() bool {
|
||||
listener.Lock()
|
||||
conns := listener.connections
|
||||
listener.Unlock()
|
||||
return conns == 0
|
||||
}, 3*time.Second, 100*time.Millisecond)
|
||||
|
||||
// Close the socket and check again...
|
||||
sock.Close()
|
||||
listener.Lock()
|
||||
conns := listener.connections
|
||||
listener.Unlock()
|
||||
require.Zero(t, conns)
|
||||
}
|
||||
|
||||
func TestNoSplitter(t *testing.T) {
|
||||
messages := [][]byte{
|
||||
|
|
@ -605,6 +731,101 @@ func TestNoSplitter(t *testing.T) {
|
|||
testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
|
||||
}
|
||||
|
||||
func TestTLSMemLeak(t *testing.T) {
|
||||
// For issue https://github.com/influxdata/telegraf/issues/15509
|
||||
|
||||
// Prepare the address and socket if needed
|
||||
serviceAddress := "tcp://127.0.0.1:0"
|
||||
|
||||
// Setup a TLS socket to trigger the issue
|
||||
cfg := &Config{
|
||||
ServerConfig: *pki.TLSServerConfig(),
|
||||
}
|
||||
|
||||
// Create the socket
|
||||
sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create callbacks
|
||||
onConnection := func(_ net.Addr, reader io.ReadCloser) {
|
||||
//nolint:errcheck // We are not interested in the data so ignore all errors
|
||||
io.Copy(io.Discard, reader)
|
||||
}
|
||||
|
||||
// Start the listener
|
||||
require.NoError(t, sock.Setup())
|
||||
sock.ListenConnection(onConnection, nil)
|
||||
defer sock.Close()
|
||||
|
||||
addr := sock.Address()
|
||||
|
||||
// Setup the client side TLS
|
||||
tlsCfg, err := pki.TLSClientConfig().TLSConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Define a single client write sequence
|
||||
data := []byte("test value=42i")
|
||||
write := func() error {
|
||||
conn, err := tls.Dial("tcp", addr.String(), tlsCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
// Define a test with the given number of connections
|
||||
maxConcurrency := runtime.GOMAXPROCS(0)
|
||||
testCycle := func(connections int) (uint64, error) {
|
||||
var mu sync.Mutex
|
||||
var errs []error
|
||||
var wg sync.WaitGroup
|
||||
for count := 1; count < connections; count++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := write(); err != nil {
|
||||
mu.Lock()
|
||||
errs = append(errs, err)
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
if count%maxConcurrency == 0 {
|
||||
wg.Wait()
|
||||
mu.Lock()
|
||||
if len(errs) > 0 {
|
||||
mu.Unlock()
|
||||
return 0, errors.Join(errs...)
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
//nolint:revive // We need to actively run the garbage collector to get reliable measurements
|
||||
runtime.GC()
|
||||
|
||||
var stats runtime.MemStats
|
||||
runtime.ReadMemStats(&stats)
|
||||
return stats.HeapObjects, nil
|
||||
}
|
||||
|
||||
// Measure the memory usage after a short warmup and after some time.
|
||||
// The final number of heap objects should not exceed the number of
|
||||
// runs by a save margin
|
||||
|
||||
// Warmup, do a low number of runs to initialize all data structures
|
||||
// taking them out of the equation.
|
||||
initial, err := testCycle(100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Do some more runs and make sure the memory growth is bound
|
||||
final, err := testCycle(2000)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Less(t, final, 2*initial)
|
||||
}
|
||||
|
||||
func createClient(endpoint string, addr net.Addr, tlsCfg *tls.Config) (net.Conn, error) {
|
||||
// Determine the protocol in a crude fashion
|
||||
parts := strings.SplitN(endpoint, "://", 2)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package socket
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -32,15 +33,16 @@ type hasSetReadBuffer interface {
|
|||
type streamListener struct {
|
||||
Encoding string
|
||||
ReadBufferSize int
|
||||
MaxConnections int
|
||||
MaxConnections uint64
|
||||
ReadTimeout config.Duration
|
||||
KeepAlivePeriod *config.Duration
|
||||
Splitter bufio.SplitFunc
|
||||
Log telegraf.Logger
|
||||
|
||||
listener net.Listener
|
||||
connections map[net.Conn]bool
|
||||
connections uint64
|
||||
path string
|
||||
cancel context.CancelFunc
|
||||
|
||||
wg sync.WaitGroup
|
||||
sync.Mutex
|
||||
|
|
@ -122,19 +124,15 @@ func (l *streamListener) setupVsock(u *url.URL) error {
|
|||
}
|
||||
|
||||
func (l *streamListener) setupConnection(conn net.Conn) error {
|
||||
if c, ok := conn.(*tls.Conn); ok {
|
||||
conn = c.NetConn()
|
||||
}
|
||||
|
||||
addr := conn.RemoteAddr().String()
|
||||
l.Lock()
|
||||
if l.MaxConnections > 0 && len(l.connections) >= l.MaxConnections {
|
||||
if l.MaxConnections > 0 && l.connections >= l.MaxConnections {
|
||||
l.Unlock()
|
||||
// Ignore the returned error as we cannot do anything about it anyway
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("unable to accept connection from %q: too many connections", addr)
|
||||
}
|
||||
l.connections[conn] = true
|
||||
l.connections++
|
||||
l.Unlock()
|
||||
|
||||
if l.ReadBufferSize > 0 {
|
||||
|
|
@ -149,6 +147,9 @@ func (l *streamListener) setupConnection(conn net.Conn) error {
|
|||
|
||||
// Set keep alive handlings
|
||||
if l.KeepAlivePeriod != nil {
|
||||
if c, ok := conn.(*tls.Conn); ok {
|
||||
conn = c.NetConn()
|
||||
}
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
l.Log.Warnf("connection not a TCP connection (%T)", conn)
|
||||
|
|
@ -172,11 +173,18 @@ func (l *streamListener) setupConnection(conn net.Conn) error {
|
|||
}
|
||||
|
||||
func (l *streamListener) closeConnection(conn net.Conn) {
|
||||
// Fallback to enforce blocked reads on connections to end immediately
|
||||
//nolint:errcheck // Ignore errors as this is a fallback only
|
||||
conn.SetReadDeadline(time.Now())
|
||||
|
||||
addr := conn.RemoteAddr().String()
|
||||
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, syscall.EPIPE) {
|
||||
l.Log.Warnf("Cannot close connection to %q: %v", addr, err)
|
||||
} else {
|
||||
l.Lock()
|
||||
l.connections--
|
||||
l.Unlock()
|
||||
}
|
||||
delete(l.connections, conn)
|
||||
}
|
||||
|
||||
func (l *streamListener) address() net.Addr {
|
||||
|
|
@ -184,15 +192,18 @@ func (l *streamListener) address() net.Addr {
|
|||
}
|
||||
|
||||
func (l *streamListener) close() error {
|
||||
if err := l.listener.Close(); err != nil {
|
||||
return err
|
||||
if l.listener != nil {
|
||||
// Continue even if we cannot close the listener in order to at least
|
||||
// close all active connections
|
||||
if err := l.listener.Close(); err != nil {
|
||||
l.Log.Errorf("Cannot close listener: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
l.Lock()
|
||||
for conn := range l.connections {
|
||||
l.closeConnection(conn)
|
||||
if l.cancel != nil {
|
||||
l.cancel()
|
||||
l.cancel = nil
|
||||
}
|
||||
l.Unlock()
|
||||
l.wg.Wait()
|
||||
|
||||
if l.path != "" {
|
||||
|
|
@ -200,8 +211,8 @@ func (l *streamListener) close() error {
|
|||
if runtime.GOOS == "windows" && strings.Contains(fn, ":") {
|
||||
fn = strings.TrimPrefix(fn, `\`)
|
||||
}
|
||||
// Ignore file-not-exists errors when removing the socket
|
||||
if err := os.Remove(fn); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
// Ignore file-not-exists errors when removing the socket
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -209,13 +220,13 @@ func (l *streamListener) close() error {
|
|||
}
|
||||
|
||||
func (l *streamListener) listenData(onData CallbackData, onError CallbackError) {
|
||||
l.connections = make(map[net.Conn]bool)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
l.cancel = cancel
|
||||
|
||||
l.wg.Add(1)
|
||||
go func() {
|
||||
defer l.wg.Done()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for {
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
|
|
@ -230,40 +241,42 @@ func (l *streamListener) listenData(onData CallbackData, onError CallbackError)
|
|||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(c net.Conn) {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
l.Lock()
|
||||
l.closeConnection(conn)
|
||||
l.Unlock()
|
||||
}()
|
||||
|
||||
reader := l.read
|
||||
if l.Splitter == nil {
|
||||
reader = l.readAll
|
||||
}
|
||||
if err := reader(c, onData); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
|
||||
if onError != nil {
|
||||
onError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}(conn)
|
||||
l.wg.Add(1)
|
||||
go l.handleReaderConn(ctx, conn, onData, onError)
|
||||
}
|
||||
wg.Wait()
|
||||
}()
|
||||
}
|
||||
|
||||
func (l *streamListener) handleReaderConn(ctx context.Context, conn net.Conn, onData CallbackData, onError CallbackError) {
|
||||
defer l.wg.Done()
|
||||
|
||||
localCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
defer l.closeConnection(conn)
|
||||
stopFunc := context.AfterFunc(localCtx, func() { l.closeConnection(conn) })
|
||||
defer stopFunc()
|
||||
|
||||
reader := l.read
|
||||
if l.Splitter == nil {
|
||||
reader = l.readAll
|
||||
}
|
||||
if err := reader(conn, onData); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
|
||||
if onError != nil {
|
||||
onError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *streamListener) listenConnection(onConnection CallbackConnection, onError CallbackError) {
|
||||
l.connections = make(map[net.Conn]bool)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
l.cancel = cancel
|
||||
|
||||
l.wg.Add(1)
|
||||
go func() {
|
||||
defer l.wg.Done()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for {
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
|
|
@ -272,28 +285,22 @@ func (l *streamListener) listenConnection(onConnection CallbackConnection, onErr
|
|||
}
|
||||
break
|
||||
}
|
||||
|
||||
if err := l.setupConnection(conn); err != nil && onError != nil {
|
||||
onError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
l.wg.Add(1)
|
||||
go func(c net.Conn) {
|
||||
defer wg.Done()
|
||||
if err := l.handleConnection(c, onConnection); err != nil {
|
||||
if err := l.handleConnection(ctx, c, onConnection); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
|
||||
if onError != nil {
|
||||
onError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
l.Lock()
|
||||
l.closeConnection(conn)
|
||||
l.Unlock()
|
||||
}(conn)
|
||||
}
|
||||
wg.Wait()
|
||||
}()
|
||||
}
|
||||
|
||||
|
|
@ -368,7 +375,6 @@ func (l *streamListener) readAll(conn net.Conn, onData CallbackData) error {
|
|||
return fmt.Errorf("setting read deadline failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
buf, err := io.ReadAll(decoder)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read on %s failed: %w", src, err)
|
||||
|
|
@ -378,7 +384,15 @@ func (l *streamListener) readAll(conn net.Conn, onData CallbackData) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (l *streamListener) handleConnection(conn net.Conn, onConnection CallbackConnection) error {
|
||||
func (l *streamListener) handleConnection(ctx context.Context, conn net.Conn, onConnection CallbackConnection) error {
|
||||
defer l.wg.Done()
|
||||
|
||||
localCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
defer l.closeConnection(conn)
|
||||
stopFunc := context.AfterFunc(localCtx, func() { l.closeConnection(conn) })
|
||||
defer stopFunc()
|
||||
|
||||
// Prepare the data decoder for the connection
|
||||
decoder, err := internal.NewStreamContentDecoder(l.Encoding, conn)
|
||||
if err != nil {
|
||||
|
|
|
|||
Loading…
Reference in New Issue