fix(common.socket): Switch to context to simplify closing (#15589)
This commit is contained in:
parent
90fdcfcdd7
commit
3d9562ba91
|
|
@ -27,7 +27,7 @@ type listener interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
MaxConnections int `toml:"max_connections"`
|
MaxConnections uint64 `toml:"max_connections"`
|
||||||
ReadBufferSize config.Size `toml:"read_buffer_size"`
|
ReadBufferSize config.Size `toml:"read_buffer_size"`
|
||||||
ReadTimeout config.Duration `toml:"read_timeout"`
|
ReadTimeout config.Duration `toml:"read_timeout"`
|
||||||
KeepAlivePeriod *config.Duration `toml:"keep_alive_period"`
|
KeepAlivePeriod *config.Duration `toml:"keep_alive_period"`
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,14 @@ package socket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -481,7 +483,7 @@ func TestClosingConnections(t *testing.T) {
|
||||||
listener, ok := sock.listener.(*streamListener)
|
listener, ok := sock.listener.(*streamListener)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
listener.Lock()
|
listener.Lock()
|
||||||
conns := len(listener.connections)
|
conns := listener.connections
|
||||||
listener.Unlock()
|
listener.Unlock()
|
||||||
require.NotZero(t, conns)
|
require.NotZero(t, conns)
|
||||||
|
|
||||||
|
|
@ -496,6 +498,130 @@ func TestClosingConnections(t *testing.T) {
|
||||||
require.Empty(t, logger.Errors())
|
require.Empty(t, logger.Errors())
|
||||||
require.Empty(t, logger.Warnings())
|
require.Empty(t, logger.Warnings())
|
||||||
}
|
}
|
||||||
|
func TestMaxConnections(t *testing.T) {
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
t.Skip("Skipping on darwin due to missing socket options")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup the configuration
|
||||||
|
period := config.Duration(10 * time.Millisecond)
|
||||||
|
cfg := &Config{
|
||||||
|
MaxConnections: 5,
|
||||||
|
KeepAlivePeriod: &period,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the socket
|
||||||
|
serviceAddress := "tcp://127.0.0.1:0"
|
||||||
|
sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create callback
|
||||||
|
var errs []error
|
||||||
|
var mu sync.Mutex
|
||||||
|
onData := func(_ net.Addr, _ []byte) {}
|
||||||
|
onError := func(err error) {
|
||||||
|
mu.Lock()
|
||||||
|
errs = append(errs, err)
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the listener
|
||||||
|
require.NoError(t, sock.Setup())
|
||||||
|
sock.Listen(onData, onError)
|
||||||
|
defer sock.Close()
|
||||||
|
|
||||||
|
addr := sock.Address()
|
||||||
|
|
||||||
|
// Create maximum number of connections and write some data. All of this
|
||||||
|
// should succeed...
|
||||||
|
clients := make([]*net.TCPConn, 0, cfg.MaxConnections)
|
||||||
|
for i := 0; i < int(cfg.MaxConnections); i++ {
|
||||||
|
c, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, c.SetWriteBuffer(0))
|
||||||
|
require.NoError(t, c.SetNoDelay(true))
|
||||||
|
clients = append(clients, c)
|
||||||
|
|
||||||
|
_, err = c.Write([]byte("test value=42i\n"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func() {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
require.Empty(t, errs)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Create another client. This should fail because we already reached the
|
||||||
|
// connection limit and the connection should be closed...
|
||||||
|
client, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, client.SetWriteBuffer(0))
|
||||||
|
require.NoError(t, client.SetNoDelay(true))
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
return len(errs) > 0
|
||||||
|
}, 3*time.Second, 100*time.Millisecond)
|
||||||
|
func() {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
require.Len(t, errs, 1)
|
||||||
|
require.ErrorContains(t, errs[0], "too many connections")
|
||||||
|
errs = make([]error, 0)
|
||||||
|
}()
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
_, err := client.Write([]byte("fail\n"))
|
||||||
|
return err != nil
|
||||||
|
}, 3*time.Second, 100*time.Millisecond)
|
||||||
|
_, err = client.Write([]byte("test\n"))
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
// Check other connections are still good
|
||||||
|
for _, c := range clients {
|
||||||
|
_, err := c.Write([]byte("test\n"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
func() {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
require.Empty(t, errs)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Close the first client and check if we can connect now
|
||||||
|
require.NoError(t, clients[0].Close())
|
||||||
|
client, err = net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, client.SetWriteBuffer(0))
|
||||||
|
require.NoError(t, client.SetNoDelay(true))
|
||||||
|
_, err = client.Write([]byte("success\n"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Close all connections
|
||||||
|
require.NoError(t, client.Close())
|
||||||
|
for _, c := range clients[1:] {
|
||||||
|
require.NoError(t, c.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the clients and check the connection counter
|
||||||
|
listener, ok := sock.listener.(*streamListener)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
listener.Lock()
|
||||||
|
conns := listener.connections
|
||||||
|
listener.Unlock()
|
||||||
|
return conns == 0
|
||||||
|
}, 3*time.Second, 100*time.Millisecond)
|
||||||
|
|
||||||
|
// Close the socket and check again...
|
||||||
|
sock.Close()
|
||||||
|
listener.Lock()
|
||||||
|
conns := listener.connections
|
||||||
|
listener.Unlock()
|
||||||
|
require.Zero(t, conns)
|
||||||
|
}
|
||||||
|
|
||||||
func TestNoSplitter(t *testing.T) {
|
func TestNoSplitter(t *testing.T) {
|
||||||
messages := [][]byte{
|
messages := [][]byte{
|
||||||
|
|
@ -605,6 +731,101 @@ func TestNoSplitter(t *testing.T) {
|
||||||
testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
|
testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTLSMemLeak(t *testing.T) {
|
||||||
|
// For issue https://github.com/influxdata/telegraf/issues/15509
|
||||||
|
|
||||||
|
// Prepare the address and socket if needed
|
||||||
|
serviceAddress := "tcp://127.0.0.1:0"
|
||||||
|
|
||||||
|
// Setup a TLS socket to trigger the issue
|
||||||
|
cfg := &Config{
|
||||||
|
ServerConfig: *pki.TLSServerConfig(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the socket
|
||||||
|
sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create callbacks
|
||||||
|
onConnection := func(_ net.Addr, reader io.ReadCloser) {
|
||||||
|
//nolint:errcheck // We are not interested in the data so ignore all errors
|
||||||
|
io.Copy(io.Discard, reader)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the listener
|
||||||
|
require.NoError(t, sock.Setup())
|
||||||
|
sock.ListenConnection(onConnection, nil)
|
||||||
|
defer sock.Close()
|
||||||
|
|
||||||
|
addr := sock.Address()
|
||||||
|
|
||||||
|
// Setup the client side TLS
|
||||||
|
tlsCfg, err := pki.TLSClientConfig().TLSConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Define a single client write sequence
|
||||||
|
data := []byte("test value=42i")
|
||||||
|
write := func() error {
|
||||||
|
conn, err := tls.Dial("tcp", addr.String(), tlsCfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
_, err = conn.Write(data)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define a test with the given number of connections
|
||||||
|
maxConcurrency := runtime.GOMAXPROCS(0)
|
||||||
|
testCycle := func(connections int) (uint64, error) {
|
||||||
|
var mu sync.Mutex
|
||||||
|
var errs []error
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for count := 1; count < connections; count++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := write(); err != nil {
|
||||||
|
mu.Lock()
|
||||||
|
errs = append(errs, err)
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if count%maxConcurrency == 0 {
|
||||||
|
wg.Wait()
|
||||||
|
mu.Lock()
|
||||||
|
if len(errs) > 0 {
|
||||||
|
mu.Unlock()
|
||||||
|
return 0, errors.Join(errs...)
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//nolint:revive // We need to actively run the garbage collector to get reliable measurements
|
||||||
|
runtime.GC()
|
||||||
|
|
||||||
|
var stats runtime.MemStats
|
||||||
|
runtime.ReadMemStats(&stats)
|
||||||
|
return stats.HeapObjects, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Measure the memory usage after a short warmup and after some time.
|
||||||
|
// The final number of heap objects should not exceed the number of
|
||||||
|
// runs by a save margin
|
||||||
|
|
||||||
|
// Warmup, do a low number of runs to initialize all data structures
|
||||||
|
// taking them out of the equation.
|
||||||
|
initial, err := testCycle(100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Do some more runs and make sure the memory growth is bound
|
||||||
|
final, err := testCycle(2000)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Less(t, final, 2*initial)
|
||||||
|
}
|
||||||
|
|
||||||
func createClient(endpoint string, addr net.Addr, tlsCfg *tls.Config) (net.Conn, error) {
|
func createClient(endpoint string, addr net.Addr, tlsCfg *tls.Config) (net.Conn, error) {
|
||||||
// Determine the protocol in a crude fashion
|
// Determine the protocol in a crude fashion
|
||||||
parts := strings.SplitN(endpoint, "://", 2)
|
parts := strings.SplitN(endpoint, "://", 2)
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ package socket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
@ -32,15 +33,16 @@ type hasSetReadBuffer interface {
|
||||||
type streamListener struct {
|
type streamListener struct {
|
||||||
Encoding string
|
Encoding string
|
||||||
ReadBufferSize int
|
ReadBufferSize int
|
||||||
MaxConnections int
|
MaxConnections uint64
|
||||||
ReadTimeout config.Duration
|
ReadTimeout config.Duration
|
||||||
KeepAlivePeriod *config.Duration
|
KeepAlivePeriod *config.Duration
|
||||||
Splitter bufio.SplitFunc
|
Splitter bufio.SplitFunc
|
||||||
Log telegraf.Logger
|
Log telegraf.Logger
|
||||||
|
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
connections map[net.Conn]bool
|
connections uint64
|
||||||
path string
|
path string
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
@ -122,19 +124,15 @@ func (l *streamListener) setupVsock(u *url.URL) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *streamListener) setupConnection(conn net.Conn) error {
|
func (l *streamListener) setupConnection(conn net.Conn) error {
|
||||||
if c, ok := conn.(*tls.Conn); ok {
|
|
||||||
conn = c.NetConn()
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := conn.RemoteAddr().String()
|
addr := conn.RemoteAddr().String()
|
||||||
l.Lock()
|
l.Lock()
|
||||||
if l.MaxConnections > 0 && len(l.connections) >= l.MaxConnections {
|
if l.MaxConnections > 0 && l.connections >= l.MaxConnections {
|
||||||
l.Unlock()
|
l.Unlock()
|
||||||
// Ignore the returned error as we cannot do anything about it anyway
|
// Ignore the returned error as we cannot do anything about it anyway
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
return fmt.Errorf("unable to accept connection from %q: too many connections", addr)
|
return fmt.Errorf("unable to accept connection from %q: too many connections", addr)
|
||||||
}
|
}
|
||||||
l.connections[conn] = true
|
l.connections++
|
||||||
l.Unlock()
|
l.Unlock()
|
||||||
|
|
||||||
if l.ReadBufferSize > 0 {
|
if l.ReadBufferSize > 0 {
|
||||||
|
|
@ -149,6 +147,9 @@ func (l *streamListener) setupConnection(conn net.Conn) error {
|
||||||
|
|
||||||
// Set keep alive handlings
|
// Set keep alive handlings
|
||||||
if l.KeepAlivePeriod != nil {
|
if l.KeepAlivePeriod != nil {
|
||||||
|
if c, ok := conn.(*tls.Conn); ok {
|
||||||
|
conn = c.NetConn()
|
||||||
|
}
|
||||||
tcpConn, ok := conn.(*net.TCPConn)
|
tcpConn, ok := conn.(*net.TCPConn)
|
||||||
if !ok {
|
if !ok {
|
||||||
l.Log.Warnf("connection not a TCP connection (%T)", conn)
|
l.Log.Warnf("connection not a TCP connection (%T)", conn)
|
||||||
|
|
@ -172,11 +173,18 @@ func (l *streamListener) setupConnection(conn net.Conn) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *streamListener) closeConnection(conn net.Conn) {
|
func (l *streamListener) closeConnection(conn net.Conn) {
|
||||||
|
// Fallback to enforce blocked reads on connections to end immediately
|
||||||
|
//nolint:errcheck // Ignore errors as this is a fallback only
|
||||||
|
conn.SetReadDeadline(time.Now())
|
||||||
|
|
||||||
addr := conn.RemoteAddr().String()
|
addr := conn.RemoteAddr().String()
|
||||||
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, syscall.EPIPE) {
|
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, syscall.EPIPE) {
|
||||||
l.Log.Warnf("Cannot close connection to %q: %v", addr, err)
|
l.Log.Warnf("Cannot close connection to %q: %v", addr, err)
|
||||||
|
} else {
|
||||||
|
l.Lock()
|
||||||
|
l.connections--
|
||||||
|
l.Unlock()
|
||||||
}
|
}
|
||||||
delete(l.connections, conn)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *streamListener) address() net.Addr {
|
func (l *streamListener) address() net.Addr {
|
||||||
|
|
@ -184,15 +192,18 @@ func (l *streamListener) address() net.Addr {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *streamListener) close() error {
|
func (l *streamListener) close() error {
|
||||||
if err := l.listener.Close(); err != nil {
|
if l.listener != nil {
|
||||||
return err
|
// Continue even if we cannot close the listener in order to at least
|
||||||
|
// close all active connections
|
||||||
|
if err := l.listener.Close(); err != nil {
|
||||||
|
l.Log.Errorf("Cannot close listener: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
l.Lock()
|
if l.cancel != nil {
|
||||||
for conn := range l.connections {
|
l.cancel()
|
||||||
l.closeConnection(conn)
|
l.cancel = nil
|
||||||
}
|
}
|
||||||
l.Unlock()
|
|
||||||
l.wg.Wait()
|
l.wg.Wait()
|
||||||
|
|
||||||
if l.path != "" {
|
if l.path != "" {
|
||||||
|
|
@ -200,8 +211,8 @@ func (l *streamListener) close() error {
|
||||||
if runtime.GOOS == "windows" && strings.Contains(fn, ":") {
|
if runtime.GOOS == "windows" && strings.Contains(fn, ":") {
|
||||||
fn = strings.TrimPrefix(fn, `\`)
|
fn = strings.TrimPrefix(fn, `\`)
|
||||||
}
|
}
|
||||||
|
// Ignore file-not-exists errors when removing the socket
|
||||||
if err := os.Remove(fn); err != nil && !errors.Is(err, os.ErrNotExist) {
|
if err := os.Remove(fn); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
// Ignore file-not-exists errors when removing the socket
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -209,13 +220,13 @@ func (l *streamListener) close() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *streamListener) listenData(onData CallbackData, onError CallbackError) {
|
func (l *streamListener) listenData(onData CallbackData, onError CallbackError) {
|
||||||
l.connections = make(map[net.Conn]bool)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
l.cancel = cancel
|
||||||
|
|
||||||
l.wg.Add(1)
|
l.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer l.wg.Done()
|
defer l.wg.Done()
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for {
|
for {
|
||||||
conn, err := l.listener.Accept()
|
conn, err := l.listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -230,40 +241,42 @@ func (l *streamListener) listenData(onData CallbackData, onError CallbackError)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
l.wg.Add(1)
|
||||||
go func(c net.Conn) {
|
go l.handleReaderConn(ctx, conn, onData, onError)
|
||||||
defer wg.Done()
|
|
||||||
defer func() {
|
|
||||||
l.Lock()
|
|
||||||
l.closeConnection(conn)
|
|
||||||
l.Unlock()
|
|
||||||
}()
|
|
||||||
|
|
||||||
reader := l.read
|
|
||||||
if l.Splitter == nil {
|
|
||||||
reader = l.readAll
|
|
||||||
}
|
|
||||||
if err := reader(c, onData); err != nil {
|
|
||||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
|
|
||||||
if onError != nil {
|
|
||||||
onError(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}(conn)
|
|
||||||
}
|
}
|
||||||
wg.Wait()
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *streamListener) handleReaderConn(ctx context.Context, conn net.Conn, onData CallbackData, onError CallbackError) {
|
||||||
|
defer l.wg.Done()
|
||||||
|
|
||||||
|
localCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
defer l.closeConnection(conn)
|
||||||
|
stopFunc := context.AfterFunc(localCtx, func() { l.closeConnection(conn) })
|
||||||
|
defer stopFunc()
|
||||||
|
|
||||||
|
reader := l.read
|
||||||
|
if l.Splitter == nil {
|
||||||
|
reader = l.readAll
|
||||||
|
}
|
||||||
|
if err := reader(conn, onData); err != nil {
|
||||||
|
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
|
||||||
|
if onError != nil {
|
||||||
|
onError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (l *streamListener) listenConnection(onConnection CallbackConnection, onError CallbackError) {
|
func (l *streamListener) listenConnection(onConnection CallbackConnection, onError CallbackError) {
|
||||||
l.connections = make(map[net.Conn]bool)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
l.cancel = cancel
|
||||||
|
|
||||||
l.wg.Add(1)
|
l.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer l.wg.Done()
|
defer l.wg.Done()
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for {
|
for {
|
||||||
conn, err := l.listener.Accept()
|
conn, err := l.listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -272,28 +285,22 @@ func (l *streamListener) listenConnection(onConnection CallbackConnection, onErr
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := l.setupConnection(conn); err != nil && onError != nil {
|
if err := l.setupConnection(conn); err != nil && onError != nil {
|
||||||
onError(err)
|
onError(err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
l.wg.Add(1)
|
||||||
go func(c net.Conn) {
|
go func(c net.Conn) {
|
||||||
defer wg.Done()
|
if err := l.handleConnection(ctx, c, onConnection); err != nil {
|
||||||
if err := l.handleConnection(c, onConnection); err != nil {
|
|
||||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
|
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
|
||||||
if onError != nil {
|
if onError != nil {
|
||||||
onError(err)
|
onError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
l.Lock()
|
|
||||||
l.closeConnection(conn)
|
|
||||||
l.Unlock()
|
|
||||||
}(conn)
|
}(conn)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -368,7 +375,6 @@ func (l *streamListener) readAll(conn net.Conn, onData CallbackData) error {
|
||||||
return fmt.Errorf("setting read deadline failed: %w", err)
|
return fmt.Errorf("setting read deadline failed: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
buf, err := io.ReadAll(decoder)
|
buf, err := io.ReadAll(decoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("read on %s failed: %w", src, err)
|
return fmt.Errorf("read on %s failed: %w", src, err)
|
||||||
|
|
@ -378,7 +384,15 @@ func (l *streamListener) readAll(conn net.Conn, onData CallbackData) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *streamListener) handleConnection(conn net.Conn, onConnection CallbackConnection) error {
|
func (l *streamListener) handleConnection(ctx context.Context, conn net.Conn, onConnection CallbackConnection) error {
|
||||||
|
defer l.wg.Done()
|
||||||
|
|
||||||
|
localCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
defer l.closeConnection(conn)
|
||||||
|
stopFunc := context.AfterFunc(localCtx, func() { l.closeConnection(conn) })
|
||||||
|
defer stopFunc()
|
||||||
|
|
||||||
// Prepare the data decoder for the connection
|
// Prepare the data decoder for the connection
|
||||||
decoder, err := internal.NewStreamContentDecoder(l.Encoding, conn)
|
decoder, err := internal.NewStreamContentDecoder(l.Encoding, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue