From cc94587f11d08d9138677dcab5952891b4df5090 Mon Sep 17 00:00:00 2001 From: Sven Rebhan <36194019+srebhan@users.noreply.github.com> Date: Tue, 1 Nov 2022 12:18:14 +0100 Subject: [PATCH] chore(inputs.socket_listener): Reorganize plugin code (#12031) --- .../inputs/socket_listener/packet_listener.go | 172 ++++++++ .../inputs/socket_listener/socket_listener.go | 417 ++++-------------- .../socket_listener/socket_listener_test.go | 71 ++- .../inputs/socket_listener/stream_listener.go | 246 +++++++++++ 4 files changed, 540 insertions(+), 366 deletions(-) create mode 100644 plugins/inputs/socket_listener/packet_listener.go create mode 100644 plugins/inputs/socket_listener/stream_listener.go diff --git a/plugins/inputs/socket_listener/packet_listener.go b/plugins/inputs/socket_listener/packet_listener.go new file mode 100644 index 000000000..037c7c646 --- /dev/null +++ b/plugins/inputs/socket_listener/packet_listener.go @@ -0,0 +1,172 @@ +package socket_listener + +import ( + "errors" + "fmt" + "net" + "net/url" + "os" + "strconv" + "strings" + + "github.com/influxdata/telegraf" + "github.com/influxdata/telegraf/internal" +) + +type packetListener struct { + Encoding string + SocketMode string + ReadBufferSize int + Parser telegraf.Parser + Log telegraf.Logger + + conn net.PacketConn + decoder internal.ContentDecoder + path string +} + +func (l *packetListener) listen(acc telegraf.Accumulator) { + buf := make([]byte, 64*1024) // 64kb - maximum size of IP packet + for { + n, _, err := l.conn.ReadFrom(buf) + if err != nil { + if !strings.HasSuffix(err.Error(), ": use of closed network connection") { + acc.AddError(err) + } + break + } + + body, err := l.decoder.Decode(buf[:n]) + if err != nil { + acc.AddError(fmt.Errorf("unable to decode incoming packet: %w", err)) + } + + metrics, err := l.Parser.Parse(body) + if err != nil { + acc.AddError(fmt.Errorf("unable to parse incoming packet: %w", err)) + // TODO rate limit + continue + } + for _, m := range metrics { + acc.AddMetric(m) + } + } +} + +func (l *packetListener) setupUnixgram(u *url.URL, socketMode string) error { + err := os.Remove(u.Path) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("removing socket failed: %w", err) + } + + conn, err := net.ListenPacket(u.Scheme, u.Path) + if err != nil { + return fmt.Errorf("listening (unixgram) failed: %w", err) + } + l.path = u.Path + l.conn = conn + + // Set permissions on socket + if socketMode != "" { + // Convert from octal in string to int + i, err := strconv.ParseUint(socketMode, 8, 32) + if err != nil { + return fmt.Errorf("converting socket mode failed: %w", err) + } + + perm := os.FileMode(uint32(i)) + if err := os.Chmod(u.Path, perm); err != nil { + return fmt.Errorf("changing socket permissions failed: %w", err) + } + } + + // Create a decoder for the given encoding + decoder, err := internal.NewContentDecoder(l.Encoding) + if err != nil { + return fmt.Errorf("creating decoder failed: %w", err) + } + l.decoder = decoder + + return nil +} + +func (l *packetListener) setupUDP(u *url.URL, ifname string, bufferSize int) error { + var conn *net.UDPConn + + addr, err := net.ResolveUDPAddr(u.Scheme, u.Host) + if err != nil { + return fmt.Errorf("resolving UDP address failed: %w", err) + } + if addr.IP.IsMulticast() { + var iface *net.Interface + if ifname != "" { + var err error + iface, err = net.InterfaceByName(ifname) + if err != nil { + return fmt.Errorf("resolving address of %q failed: %w", ifname, err) + } + } + conn, err = net.ListenMulticastUDP(u.Scheme, iface, addr) + if err != nil { + return fmt.Errorf("listening (udp multicast) failed: %w", err) + } + } else { + conn, err = net.ListenUDP(u.Scheme, addr) + if err != nil { + return fmt.Errorf("listening (udp) failed: %w", err) + } + } + + if bufferSize > 0 { + if err := conn.SetReadBuffer(bufferSize); err != nil { + l.Log.Warnf("Setting read buffer on %s socket failed: %v", u.Scheme, err) + } + } + l.conn = conn + + // Create a decoder for the given encoding + decoder, err := internal.NewContentDecoder(l.Encoding) + if err != nil { + return fmt.Errorf("creating decoder failed: %w", err) + } + l.decoder = decoder + + return nil +} + +func (l *packetListener) setupIP(u *url.URL) error { + conn, err := net.ListenPacket(u.Scheme, u.Host) + if err != nil { + return fmt.Errorf("listening (ip) failed: %w", err) + } + l.conn = conn + + // Create a decoder for the given encoding + decoder, err := internal.NewContentDecoder(l.Encoding) + if err != nil { + return fmt.Errorf("creating decoder failed: %w", err) + } + l.decoder = decoder + + return nil +} + +func (l *packetListener) addr() net.Addr { + return l.conn.LocalAddr() +} + +func (l *packetListener) close() error { + if err := l.conn.Close(); err != nil { + return err + } + + if l.path != "" { + err := os.Remove(l.path) + if err != nil && !errors.Is(err, os.ErrNotExist) { + // Ignore file-not-exists errors when removing the socket + return err + } + } + + return nil +} diff --git a/plugins/inputs/socket_listener/socket_listener.go b/plugins/inputs/socket_listener/socket_listener.go index b67c4780d..3355d07fd 100644 --- a/plugins/inputs/socket_listener/socket_listener.go +++ b/plugins/inputs/socket_listener/socket_listener.go @@ -2,21 +2,16 @@ package socket_listener import ( - "bufio" - "crypto/tls" _ "embed" "fmt" - "io" "net" - "os" - "strconv" + "net/url" + "regexp" "strings" "sync" - "time" "github.com/influxdata/telegraf" "github.com/influxdata/telegraf/config" - "github.com/influxdata/telegraf/internal" tlsint "github.com/influxdata/telegraf/plugins/common/tls" "github.com/influxdata/telegraf/plugins/inputs" "github.com/influxdata/telegraf/plugins/parsers" @@ -25,196 +20,27 @@ import ( //go:embed sample.conf var sampleConfig string -type setReadBufferer interface { - SetReadBuffer(bytes int) error -} - -type streamSocketListener struct { - net.Listener - *SocketListener - - sockType string - - connections map[string]net.Conn - connectionsMtx sync.Mutex -} - -func (ssl *streamSocketListener) listen() { - ssl.connections = map[string]net.Conn{} - - wg := sync.WaitGroup{} - - for { - c, err := ssl.Accept() - if err != nil { - if !strings.HasSuffix(err.Error(), ": use of closed network connection") { - ssl.Log.Error(err.Error()) - } - break - } - - if ssl.ReadBufferSize > 0 { - if srb, ok := c.(setReadBufferer); ok { - if err := srb.SetReadBuffer(int(ssl.ReadBufferSize)); err != nil { - ssl.Log.Error(err.Error()) - break - } - } else { - ssl.Log.Warnf("Unable to set read buffer on a %s socket", ssl.sockType) - } - } - - ssl.connectionsMtx.Lock() - if ssl.MaxConnections > 0 && len(ssl.connections) >= ssl.MaxConnections { - ssl.connectionsMtx.Unlock() - // Ignore the returned error as we cannot do anything about it anyway - //nolint:errcheck,revive - c.Close() - continue - } - ssl.connections[c.RemoteAddr().String()] = c - ssl.connectionsMtx.Unlock() - - if err := ssl.setKeepAlive(c); err != nil { - ssl.Log.Errorf("Unable to configure keep alive %q: %s", ssl.ServiceAddress, err.Error()) - } - - wg.Add(1) - go func() { - defer wg.Done() - ssl.read(c) - }() - } - - ssl.connectionsMtx.Lock() - for _, c := range ssl.connections { - // Ignore the returned error as we cannot do anything about it anyway - //nolint:errcheck,revive - c.Close() - } - ssl.connectionsMtx.Unlock() - - wg.Wait() -} - -func (ssl *streamSocketListener) setKeepAlive(c net.Conn) error { - if ssl.KeepAlivePeriod == nil { - return nil - } - tcpc, ok := c.(*net.TCPConn) - if !ok { - return fmt.Errorf("cannot set keep alive on a %s socket", strings.SplitN(ssl.ServiceAddress, "://", 2)[0]) - } - if *ssl.KeepAlivePeriod == 0 { - return tcpc.SetKeepAlive(false) - } - if err := tcpc.SetKeepAlive(true); err != nil { - return err - } - return tcpc.SetKeepAlivePeriod(time.Duration(*ssl.KeepAlivePeriod)) -} - -func (ssl *streamSocketListener) removeConnection(c net.Conn) { - ssl.connectionsMtx.Lock() - delete(ssl.connections, c.RemoteAddr().String()) - ssl.connectionsMtx.Unlock() -} - -func (ssl *streamSocketListener) read(c net.Conn) { - defer ssl.removeConnection(c) - defer c.Close() - - decoder, err := internal.NewStreamContentDecoder(ssl.ContentEncoding, c) - if err != nil { - ssl.Log.Error("Read error: %v", err) - return - } - - scnr := bufio.NewScanner(decoder) - for { - if ssl.ReadTimeout != nil && *ssl.ReadTimeout > 0 { - if err := c.SetReadDeadline(time.Now().Add(time.Duration(*ssl.ReadTimeout))); err != nil { - ssl.Log.Error("setting read deadline failed: %v", err) - return - } - } - if !scnr.Scan() { - break - } - - body := scnr.Bytes() - - metrics, err := ssl.Parse(body) - if err != nil { - ssl.Log.Errorf("Unable to parse incoming line: %s", err.Error()) - // TODO rate limit - continue - } - for _, m := range metrics { - ssl.AddMetric(m) - } - } - - if err := scnr.Err(); err != nil { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - ssl.Log.Debugf("Timeout in plugin: %s", err.Error()) - } else if netErr != nil && !strings.HasSuffix(err.Error(), ": use of closed network connection") { - ssl.Log.Error(err.Error()) - } - } -} - -type packetSocketListener struct { - net.PacketConn - *SocketListener - decoder internal.ContentDecoder -} - -func (psl *packetSocketListener) listen() { - buf := make([]byte, 64*1024) // 64kb - maximum size of IP packet - for { - n, _, err := psl.ReadFrom(buf) - if err != nil { - if !strings.HasSuffix(err.Error(), ": use of closed network connection") { - psl.Log.Error(err.Error()) - } - break - } - - body, err := psl.decoder.Decode(buf[:n]) - if err != nil { - psl.Log.Errorf("Unable to decode incoming packet: %s", err.Error()) - } - - metrics, err := psl.Parse(body) - if err != nil { - psl.Log.Errorf("Unable to parse incoming packet: %s", err.Error()) - // TODO rate limit - continue - } - for _, m := range metrics { - psl.AddMetric(m) - } - } +type listener interface { + listen(acc telegraf.Accumulator) + addr() net.Addr + close() error } type SocketListener struct { ServiceAddress string `toml:"service_address"` MaxConnections int `toml:"max_connections"` ReadBufferSize config.Size `toml:"read_buffer_size"` - ReadTimeout *config.Duration `toml:"read_timeout"` + ReadTimeout config.Duration `toml:"read_timeout"` KeepAlivePeriod *config.Duration `toml:"keep_alive_period"` SocketMode string `toml:"socket_mode"` ContentEncoding string `toml:"content_encoding"` + Log telegraf.Logger `toml:"-"` tlsint.ServerConfig - wg sync.WaitGroup + wg sync.WaitGroup + parser parsers.Parser - Log telegraf.Logger - - parsers.Parser - telegraf.Accumulator - io.Closer + listener listener } func (*SocketListener) SampleConfig() string { @@ -225,181 +51,116 @@ func (sl *SocketListener) Gather(_ telegraf.Accumulator) error { return nil } -func (sl *SocketListener) SetParser(parser parsers.Parser) { - sl.Parser = parser +func (sl *SocketListener) SetParser(parser telegraf.Parser) { + sl.parser = parser } func (sl *SocketListener) Start(acc telegraf.Accumulator) error { - sl.Accumulator = acc - spl := strings.SplitN(sl.ServiceAddress, "://", 2) - if len(spl) != 2 { - return fmt.Errorf("invalid service address: %s", sl.ServiceAddress) + // Resolve the interface to an address if any given + var ifname string + ifregex := regexp.MustCompile(`%([\w\.]+)`) + if matches := ifregex.FindStringSubmatch(sl.ServiceAddress); len(matches) == 2 { + ifname := matches[1] + sl.ServiceAddress = strings.Replace(sl.ServiceAddress, "%"+ifname, "", 1) } - protocol := spl[0] - addr := spl[1] - - if protocol == "unix" || protocol == "unixpacket" || protocol == "unixgram" { - // no good way of testing for "file does not exist". - // Instead just ignore error and blow up when we try to listen, which will - // indicate "address already in use" if file existed and we couldn't remove. - //nolint:errcheck,revive - os.Remove(addr) + // Preparing TLS configuration + tlsCfg, err := sl.ServerConfig.TLSConfig() + if err != nil { + return fmt.Errorf("getting TLS config failed: %w", err) } - switch protocol { - case "tcp", "tcp4", "tcp6", "unix", "unixpacket": - tlsCfg, err := sl.ServerConfig.TLSConfig() - if err != nil { + // Setup the network connection + u, err := url.Parse(sl.ServiceAddress) + if err != nil { + return fmt.Errorf("parsing address failed: %w", err) + } + + switch u.Scheme { + case "tcp", "tcp4", "tcp6": + ssl := &streamListener{ + ReadBufferSize: int(sl.ReadBufferSize), + ReadTimeout: sl.ReadTimeout, + KeepAlivePeriod: sl.KeepAlivePeriod, + MaxConnections: sl.MaxConnections, + Encoding: sl.ContentEncoding, + Parser: sl.parser, + Log: sl.Log, + } + + if err := ssl.setupTCP(u, tlsCfg); err != nil { return err } - - var l net.Listener - if tlsCfg == nil { - l, err = net.Listen(protocol, addr) - } else { - l, err = tls.Listen(protocol, addr, tlsCfg) + sl.listener = ssl + case "unix", "unixpacket": + ssl := &streamListener{ + ReadBufferSize: int(sl.ReadBufferSize), + ReadTimeout: sl.ReadTimeout, + KeepAlivePeriod: sl.KeepAlivePeriod, + MaxConnections: sl.MaxConnections, + Encoding: sl.ContentEncoding, + Parser: sl.parser, + Log: sl.Log, } - if err != nil { + + if err := ssl.setupUnix(u, tlsCfg, sl.SocketMode); err != nil { return err } + sl.listener = ssl - sl.Log.Infof("Listening on %s://%s", protocol, l.Addr()) - - // Set permissions on socket - if (spl[0] == "unix" || spl[0] == "unixpacket") && sl.SocketMode != "" { - // Convert from octal in string to int - i, err := strconv.ParseUint(sl.SocketMode, 8, 32) - if err != nil { - return err - } - - if err := os.Chmod(spl[1], os.FileMode(uint32(i))); err != nil { - return err - } + case "udp", "udp4", "udp6": + psl := &packetListener{ + Encoding: sl.ContentEncoding, + Parser: sl.parser, } - - ssl := &streamSocketListener{ - Listener: l, - SocketListener: sl, - sockType: spl[0], - } - - sl.Closer = ssl - sl.wg = sync.WaitGroup{} - sl.wg.Add(1) - go func() { - defer sl.wg.Done() - ssl.listen() - }() - case "udp", "udp4", "udp6", "ip", "ip4", "ip6", "unixgram": - decoder, err := internal.NewContentDecoder(sl.ContentEncoding) - if err != nil { + if err := psl.setupUDP(u, ifname, int(sl.ReadBufferSize)); err != nil { return err } - - pc, err := udpListen(protocol, addr) - if err != nil { + sl.listener = psl + case "ip", "ip4", "ip6": + psl := &packetListener{ + Encoding: sl.ContentEncoding, + Parser: sl.parser, + } + if err := psl.setupIP(u); err != nil { return err } - - // Set permissions on socket - if spl[0] == "unixgram" && sl.SocketMode != "" { - // Convert from octal in string to int - i, err := strconv.ParseUint(sl.SocketMode, 8, 32) - if err != nil { - return err - } - - if err := os.Chmod(spl[1], os.FileMode(uint32(i))); err != nil { - return err - } + sl.listener = psl + case "unixgram": + psl := &packetListener{ + Encoding: sl.ContentEncoding, + Parser: sl.parser, } - - if sl.ReadBufferSize > 0 { - if srb, ok := pc.(setReadBufferer); ok { - if err := srb.SetReadBuffer(int(sl.ReadBufferSize)); err != nil { - sl.Log.Warnf("Setting read buffer on a %s socket failed: %v", protocol, err) - } - } else { - sl.Log.Warnf("Unable to set read buffer on a %s socket", protocol) - } + if err := psl.setupUnixgram(u, sl.SocketMode); err != nil { + return err } - - sl.Log.Infof("Listening on %s://%s", protocol, pc.LocalAddr()) - - psl := &packetSocketListener{ - PacketConn: pc, - SocketListener: sl, - decoder: decoder, - } - - sl.Closer = psl - sl.wg = sync.WaitGroup{} - sl.wg.Add(1) - go func() { - defer sl.wg.Done() - psl.listen() - }() + sl.listener = psl default: - return fmt.Errorf("unknown protocol '%s' in '%s'", protocol, sl.ServiceAddress) + return fmt.Errorf("unknown protocol %q in %q", u.Scheme, sl.ServiceAddress) } - if protocol == "unix" || protocol == "unixpacket" || protocol == "unixgram" { - sl.Closer = unixCloser{path: spl[1], closer: sl.Closer} - } + sl.Log.Infof("Listening on %s://%s", u.Scheme, sl.listener.addr()) + + sl.wg.Add(1) + go func() { + defer sl.wg.Done() + sl.listener.listen(acc) + }() return nil } -func udpListen(network string, address string) (net.PacketConn, error) { - switch network { - case "udp", "udp4", "udp6": - var addr *net.UDPAddr - var err error - var ifi *net.Interface - if spl := strings.SplitN(address, "%", 2); len(spl) == 2 { - address = spl[0] - ifi, err = net.InterfaceByName(spl[1]) - if err != nil { - return nil, err - } - } - addr, err = net.ResolveUDPAddr(network, address) - if err != nil { - return nil, err - } - if addr.IP.IsMulticast() { - return net.ListenMulticastUDP(network, ifi, addr) - } - return net.ListenUDP(network, addr) - } - return net.ListenPacket(network, address) -} - func (sl *SocketListener) Stop() { - if sl.Closer != nil { + if sl.listener != nil { // Ignore the returned error as we cannot do anything about it anyway - //nolint:errcheck,revive - sl.Close() - sl.Closer = nil + _ = sl.listener.close() + sl.listener = nil } sl.wg.Wait() } -type unixCloser struct { - path string - closer io.Closer -} - -func (uc unixCloser) Close() error { - err := uc.closer.Close() - // Ignore the error if e.g. the file does not exist - //nolint:errcheck,revive - os.Remove(uc.path) - return err -} - func init() { - inputs.Add("socket_listener", func() telegraf.Input { return &SocketListener{} }) + inputs.Add("socket_listener", func() telegraf.Input { + return &SocketListener{} + }) } diff --git a/plugins/inputs/socket_listener/socket_listener_test.go b/plugins/inputs/socket_listener/socket_listener_test.go index 8e2b6f681..f0c277a7a 100644 --- a/plugins/inputs/socket_listener/socket_listener_test.go +++ b/plugins/inputs/socket_listener/socket_listener_test.go @@ -2,6 +2,7 @@ package socket_listener import ( "crypto/tls" + "fmt" "net" "os" "runtime" @@ -100,12 +101,17 @@ func TestSocketListener(t *testing.T) { }, } + serverTLS := pki.TLSServerConfig() + clientTLS := pki.TLSClientConfig() + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { proto := strings.TrimSuffix(tt.schema, "+tls") // Prepare the address and socket if needed var serverAddr string + var tlsCfg *tls.Config + switch proto { case "tcp", "udp": serverAddr = "127.0.0.1:0" @@ -130,7 +136,10 @@ func TestSocketListener(t *testing.T) { ReadBufferSize: tt.buffersize, } if strings.HasSuffix(tt.schema, "tls") { - plugin.ServerConfig = *pki.TLSServerConfig() + plugin.ServerConfig = *serverTLS + var err error + tlsCfg, err = clientTLS.TLSConfig() + require.NoError(t, err) } parser := &influx.Parser{} require.NoError(t, parser.Init()) @@ -142,41 +151,9 @@ func TestSocketListener(t *testing.T) { defer plugin.Stop() // Setup the client for submitting data - var client net.Conn - switch tt.schema { - case "tcp": - var err error - addr := plugin.Closer.(net.Listener).Addr().String() - client, err = net.Dial("tcp", addr) - require.NoError(t, err) - case "tcp+tls": - addr := plugin.Closer.(net.Listener).Addr().String() - tlscfg, err := pki.TLSClientConfig().TLSConfig() - require.NoError(t, err) - client, err = tls.Dial("tcp", addr, tlscfg) - require.NoError(t, err) - case "udp": - var err error - addr := plugin.Closer.(net.PacketConn).LocalAddr().String() - client, err = net.Dial("udp", addr) - require.NoError(t, err) - case "unix": - var err error - client, err = net.Dial("unix", serverAddr) - require.NoError(t, err) - case "unix+tls": - tlscfg, err := pki.TLSClientConfig().TLSConfig() - require.NoError(t, err) - tlscfg.InsecureSkipVerify = true - client, err = tls.Dial("unix", serverAddr, tlscfg) - require.NoError(t, err) - case "unixgram": - var err error - client, err = net.Dial("unixgram", serverAddr) - require.NoError(t, err) - default: - require.Failf(t, "schema %q not supported in test", tt.schema) - } + addr := plugin.listener.addr() + client, err := createClient(plugin.ServiceAddress, addr, tlsCfg) + require.NoError(t, err) // Send the data with the correct encoding encoder, err := internal.NewContentEncoder(tt.encoding) @@ -190,13 +167,31 @@ func TestSocketListener(t *testing.T) { } // Test the resulting metrics and compare against expected results - require.Eventually(t, func() bool { + require.Eventuallyf(t, func() bool { acc.Lock() defer acc.Unlock() return acc.NMetrics() >= uint64(len(expected)) - }, time.Second, 100*time.Millisecond, "did not receive metrics") + }, time.Second, 100*time.Millisecond, "did not receive metrics (%d)", acc.NMetrics()) actual := acc.GetTelegrafMetrics() testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics()) }) } } + +func createClient(endpoint string, addr net.Addr, tlsCfg *tls.Config) (net.Conn, error) { + // Determine the protocol in a crude fashion + parts := strings.SplitN(endpoint, "://", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid endpoint %q", endpoint) + } + protocol := parts[0] + + if tlsCfg == nil { + return net.Dial(protocol, addr.String()) + } + + if protocol == "unix" { + tlsCfg.InsecureSkipVerify = true + } + return tls.Dial(protocol, addr.String(), tlsCfg) +} diff --git a/plugins/inputs/socket_listener/stream_listener.go b/plugins/inputs/socket_listener/stream_listener.go new file mode 100644 index 000000000..90ff4637c --- /dev/null +++ b/plugins/inputs/socket_listener/stream_listener.go @@ -0,0 +1,246 @@ +package socket_listener + +import ( + "bufio" + "crypto/tls" + "errors" + "fmt" + "net" + "net/url" + "os" + "strconv" + "sync" + "time" + + "github.com/influxdata/telegraf" + "github.com/influxdata/telegraf/config" + "github.com/influxdata/telegraf/internal" +) + +type hasSetReadBuffer interface { + SetReadBuffer(bytes int) error +} + +type streamListener struct { + Encoding string + ReadBufferSize int + MaxConnections int + ReadTimeout config.Duration + KeepAlivePeriod *config.Duration + Parser telegraf.Parser + Log telegraf.Logger + + listener net.Listener + connections map[string]net.Conn + path string + + wg sync.WaitGroup + sync.Mutex +} + +func (l *streamListener) setupTCP(u *url.URL, tlsCfg *tls.Config) error { + var err error + if tlsCfg == nil { + l.listener, err = net.Listen(u.Scheme, u.Host) + } else { + l.listener, err = tls.Listen(u.Scheme, u.Host, tlsCfg) + } + return err +} + +func (l *streamListener) setupUnix(u *url.URL, tlsCfg *tls.Config, socketMode string) error { + err := os.Remove(u.Path) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("removing socket failed: %w", err) + } + + if tlsCfg == nil { + l.listener, err = net.Listen(u.Scheme, u.Path) + } else { + l.listener, err = tls.Listen(u.Scheme, u.Path, tlsCfg) + } + if err != nil { + return err + } + l.path = u.Path + + // Set permissions on socket + if socketMode != "" { + // Convert from octal in string to int + i, err := strconv.ParseUint(socketMode, 8, 32) + if err != nil { + return fmt.Errorf("converting socket mode failed: %w", err) + } + + perm := os.FileMode(uint32(i)) + if err := os.Chmod(u.Path, perm); err != nil { + return fmt.Errorf("changing socket permissions failed: %w", err) + } + } + return nil +} + +func (l *streamListener) setupConnection(conn net.Conn) error { + if l.ReadBufferSize > 0 { + if rb, ok := conn.(hasSetReadBuffer); ok { + if err := rb.SetReadBuffer(l.ReadBufferSize); err != nil { + l.Log.Warnf("Setting read buffer on socket failed: %v", err) + } + } else { + l.Log.Warn("Cannot set read buffer on socket of this type") + } + } + + addr := conn.RemoteAddr().String() + if l.MaxConnections > 0 && len(l.connections) >= l.MaxConnections { + // Ignore the returned error as we cannot do anything about it anyway + _ = conn.Close() + l.Log.Infof("unable to accept connection from %q: too many connections", addr) + return nil + } + + // Set keep alive handlings + if l.KeepAlivePeriod != nil { + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return fmt.Errorf("connection not a TCP connection (%T)", conn) + } + if *l.KeepAlivePeriod == 0 { + if err := tcpConn.SetKeepAlive(false); err != nil { + return fmt.Errorf("cannot set keep-alive: %w", err) + } + } else { + if err := tcpConn.SetKeepAlive(true); err != nil { + return fmt.Errorf("cannot set keep-alive: %w", err) + } + err := tcpConn.SetKeepAlivePeriod(time.Duration(*l.KeepAlivePeriod)) + if err != nil { + return fmt.Errorf("cannot set keep-alive period: %w", err) + } + } + } + + // Store the connection mapped to its address + l.Lock() + defer l.Unlock() + l.connections[addr] = conn + + return nil +} + +func (l *streamListener) closeConnection(conn net.Conn) { + l.Lock() + defer l.Unlock() + addr := conn.RemoteAddr().String() + if err := conn.Close(); err != nil { + l.Log.Errorf("Cannot close connection to %q: %v", addr, err) + } + delete(l.connections, addr) +} + +func (l *streamListener) addr() net.Addr { + return l.listener.Addr() +} + +func (l *streamListener) close() error { + if err := l.listener.Close(); err != nil { + return err + } + + for _, conn := range l.connections { + l.closeConnection(conn) + } + l.wg.Wait() + + if l.path != "" { + err := os.Remove(l.path) + if err != nil && !errors.Is(err, os.ErrNotExist) { + // Ignore file-not-exists errors when removing the socket + return err + } + } + return nil +} + +func (l *streamListener) listen(acc telegraf.Accumulator) { + l.connections = make(map[string]net.Conn) + + l.wg.Add(1) + defer l.wg.Done() + + var wg sync.WaitGroup + for { + conn, err := l.listener.Accept() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + acc.AddError(err) + } + break + } + + if err := l.setupConnection(conn); err != nil { + acc.AddError(err) + } + + wg.Add(1) + go func() { + defer wg.Done() + if err := l.read(acc, conn); err != nil { + acc.AddError(err) + } + }() + } + wg.Wait() +} + +func (l *streamListener) read(acc telegraf.Accumulator, conn net.Conn) error { + decoder, err := internal.NewStreamContentDecoder(l.Encoding, conn) + if err != nil { + return fmt.Errorf("creating decoder failed: %w", err) + } + + timeout := time.Duration(l.ReadTimeout) + + scanner := bufio.NewScanner(decoder) + for { + // Set the read deadline, if any, then start reading. The read + // will accept the deadline and return if no or insufficient data + // arrived in time. We need to set the deadline in every cycle as + // it is an ABSOLUTE time and not a timeout. + if timeout > 0 { + deadline := time.Now().Add(timeout) + if err := conn.SetReadDeadline(deadline); err != nil { + return fmt.Errorf("setting read deadline failed: %w", err) + } + } + if !scanner.Scan() { + // Exit if no data arrived e.g. due to timeout or closed connection + break + } + + data := scanner.Bytes() + metrics, err := l.Parser.Parse(data) + if err != nil { + acc.AddError(fmt.Errorf("parsing error: %w", err)) + l.Log.Debugf("invalid data for parser: %v", data) + continue + } + for _, m := range metrics { + acc.AddMetric(m) + } + } + + if err := scanner.Err(); err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + // Ignore the timeout and silently close the connection + l.Log.Debug(err) + return nil + } + if errors.Is(err, net.ErrClosed) { + // Ignore the connection closing of the remote side + return nil + } + return err + } + return nil +}