From 3895416ea8d8ce406891d376d662b28bd84b594a Mon Sep 17 00:00:00 2001 From: Sven Rebhan <36194019+srebhan@users.noreply.github.com> Date: Wed, 23 Apr 2025 20:01:19 +0200 Subject: [PATCH] chore(inputs.snmp_trap): Cleanup code (#16810) --- plugins/inputs/snmp_trap/logger.go | 15 + plugins/inputs/snmp_trap/snmp_trap.go | 465 ++++++++++----------- plugins/inputs/snmp_trap/snmp_trap_test.go | 24 +- 3 files changed, 235 insertions(+), 269 deletions(-) create mode 100644 plugins/inputs/snmp_trap/logger.go diff --git a/plugins/inputs/snmp_trap/logger.go b/plugins/inputs/snmp_trap/logger.go new file mode 100644 index 000000000..397da7a5f --- /dev/null +++ b/plugins/inputs/snmp_trap/logger.go @@ -0,0 +1,15 @@ +package snmp_trap + +import "github.com/influxdata/telegraf" + +type logger struct { + telegraf.Logger +} + +func (l logger) Printf(format string, args ...interface{}) { + l.Tracef(format, args...) +} + +func (l logger) Print(args ...interface{}) { + l.Trace(args...) +} diff --git a/plugins/inputs/snmp_trap/snmp_trap.go b/plugins/inputs/snmp_trap/snmp_trap.go index d6e09f8b0..2d8d6c0e4 100644 --- a/plugins/inputs/snmp_trap/snmp_trap.go +++ b/plugins/inputs/snmp_trap/snmp_trap.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "net" + "net/url" "strconv" "strings" "time" @@ -30,15 +31,12 @@ type SnmpTrap struct { Timeout config.Duration `toml:"timeout"` Version string `toml:"version"` Path []string `toml:"path"` - // Settings for version 3 - // Values: "noAuthNoPriv", "authNoPriv", "authPriv" - SecLevel string `toml:"sec_level"` - SecName config.Secret `toml:"sec_name"` - // Values: "MD5", "SHA", "". Default: "" + // Settings for version 3 security + SecLevel string `toml:"sec_level"` + SecName config.Secret `toml:"sec_name"` AuthProtocol string `toml:"auth_protocol"` AuthPassword config.Secret `toml:"auth_password"` - // Values: "DES", "AES", "". Default: "" PrivProtocol string `toml:"priv_protocol"` PrivPassword config.Secret `toml:"priv_password"` @@ -47,15 +45,8 @@ type SnmpTrap struct { acc telegraf.Accumulator listener *gosnmp.TrapListener - timeFunc func() time.Time - errCh chan error - makeHandlerWrapper func(gosnmp.TrapHandlerFunc) gosnmp.TrapHandlerFunc - transl translator -} - -type wrapLog struct { - telegraf.Logger + transl translator } type translator interface { @@ -71,164 +62,160 @@ func (s *SnmpTrap) SetTranslator(name string) { } func (s *SnmpTrap) Init() error { - var err error + // Set defaults + if s.ServiceAddress == "" { + s.ServiceAddress = "udp://:162" + } + + if len(s.Path) == 0 { + s.Path = []string{"/usr/share/snmp/mibs"} + } + + // Check input parameters switch s.Translator { case "gosmi": - s.transl, err = newGosmiTranslator(s.Path, s.Log) + t, err := newGosmiTranslator(s.Path, s.Log) if err != nil { return err } + s.transl = t case "netsnmp": s.transl = newNetsnmpTranslator(s.Timeout) default: - return errors.New("invalid translator value") + // Ignore the translator for testing if an instance was set + if s.transl == nil { + return errors.New("invalid translator value") + } } - if err != nil { - s.Log.Errorf("Could not get path %v", err) + // Setup the SNMP parameters + params := *gosnmp.Default + switch s.Version { + case "1": + params.Version = gosnmp.Version1 + case "", "2c": + params.Version = gosnmp.Version2c + case "3": + params.Version = gosnmp.Version3 + + // Setup the security for v3 + var security gosnmp.UsmSecurityParameters + params.SecurityModel = gosnmp.UserSecurityModel + + // Set security mechanisms + switch strings.ToLower(s.SecLevel) { + case "noauthnopriv", "": + params.MsgFlags = gosnmp.NoAuthNoPriv + case "authnopriv": + params.MsgFlags = gosnmp.AuthNoPriv + case "authpriv": + params.MsgFlags = gosnmp.AuthPriv + default: + return fmt.Errorf("unknown security level %q", s.SecLevel) + } + + // Set authentication + switch strings.ToLower(s.AuthProtocol) { + case "": + security.AuthenticationProtocol = gosnmp.NoAuth + case "md5": + security.AuthenticationProtocol = gosnmp.MD5 + case "sha": + security.AuthenticationProtocol = gosnmp.SHA + case "sha224": + security.AuthenticationProtocol = gosnmp.SHA224 + case "sha256": + security.AuthenticationProtocol = gosnmp.SHA256 + case "sha384": + security.AuthenticationProtocol = gosnmp.SHA384 + case "sha512": + security.AuthenticationProtocol = gosnmp.SHA512 + default: + return fmt.Errorf("unknown authentication protocol %q", s.AuthProtocol) + } + + // Set privacy + switch strings.ToLower(s.PrivProtocol) { + case "": + security.PrivacyProtocol = gosnmp.NoPriv + case "aes": + security.PrivacyProtocol = gosnmp.AES + case "des": + security.PrivacyProtocol = gosnmp.DES + case "aes192": + security.PrivacyProtocol = gosnmp.AES192 + case "aes192c": + security.PrivacyProtocol = gosnmp.AES192C + case "aes256": + security.PrivacyProtocol = gosnmp.AES256 + case "aes256c": + security.PrivacyProtocol = gosnmp.AES256C + default: + return fmt.Errorf("unknown privacy protocol %q", s.PrivProtocol) + } + + // Set credentials + secnameSecret, err := s.SecName.Get() + if err != nil { + return fmt.Errorf("getting secname failed: %w", err) + } + security.UserName = secnameSecret.String() + secnameSecret.Destroy() + + privPasswdSecret, err := s.PrivPassword.Get() + if err != nil { + return fmt.Errorf("getting priv-password failed: %w", err) + } + security.PrivacyPassphrase = privPasswdSecret.String() + privPasswdSecret.Destroy() + + authPasswdSecret, err := s.AuthPassword.Get() + if err != nil { + return fmt.Errorf("getting auth-password failed: %w", err) + } + security.AuthenticationPassphrase = authPasswdSecret.String() + authPasswdSecret.Destroy() + + params.SecurityParameters = &security + default: + return fmt.Errorf("unknown version %q", s.Version) } + if s.Log.Level().Includes(telegraf.Trace) { + params.Logger = gosnmp.NewLogger(&logger{s.Log}) + } + + // Initialize the listener + s.listener = gosnmp.NewTrapListener() + s.listener.OnNewTrap = s.handler + s.listener.Params = ¶ms + return nil } func (s *SnmpTrap) Start(acc telegraf.Accumulator) error { s.acc = acc - s.listener = gosnmp.NewTrapListener() - s.listener.OnNewTrap = makeTrapHandler(s) - // gosnmp.Default is a pointer, using this more than once - // has side effects - defaults := *gosnmp.Default - s.listener.Params = &defaults - s.listener.Params.Logger = gosnmp.NewLogger(wrapLog{s.Log}) - - switch s.Version { - case "3": - s.listener.Params.Version = gosnmp.Version3 - case "2c": - s.listener.Params.Version = gosnmp.Version2c - case "1": - s.listener.Params.Version = gosnmp.Version1 - default: - s.listener.Params.Version = gosnmp.Version2c - } - - if s.listener.Params.Version == gosnmp.Version3 { - s.listener.Params.SecurityModel = gosnmp.UserSecurityModel - - switch strings.ToLower(s.SecLevel) { - case "noauthnopriv", "": - s.listener.Params.MsgFlags = gosnmp.NoAuthNoPriv - case "authnopriv": - s.listener.Params.MsgFlags = gosnmp.AuthNoPriv - case "authpriv": - s.listener.Params.MsgFlags = gosnmp.AuthPriv - default: - return fmt.Errorf("unknown security level %q", s.SecLevel) - } - - var authenticationProtocol gosnmp.SnmpV3AuthProtocol - switch strings.ToLower(s.AuthProtocol) { - case "md5": - authenticationProtocol = gosnmp.MD5 - case "sha": - authenticationProtocol = gosnmp.SHA - case "sha224": - authenticationProtocol = gosnmp.SHA224 - case "sha256": - authenticationProtocol = gosnmp.SHA256 - case "sha384": - authenticationProtocol = gosnmp.SHA384 - case "sha512": - authenticationProtocol = gosnmp.SHA512 - case "": - authenticationProtocol = gosnmp.NoAuth - default: - return fmt.Errorf("unknown authentication protocol %q", s.AuthProtocol) - } - - var privacyProtocol gosnmp.SnmpV3PrivProtocol - switch strings.ToLower(s.PrivProtocol) { - case "aes": - privacyProtocol = gosnmp.AES - case "des": - privacyProtocol = gosnmp.DES - case "aes192": - privacyProtocol = gosnmp.AES192 - case "aes192c": - privacyProtocol = gosnmp.AES192C - case "aes256": - privacyProtocol = gosnmp.AES256 - case "aes256c": - privacyProtocol = gosnmp.AES256C - case "": - privacyProtocol = gosnmp.NoPriv - default: - return fmt.Errorf("unknown privacy protocol %q", s.PrivProtocol) - } - - secnameSecret, err := s.SecName.Get() - if err != nil { - return fmt.Errorf("getting secname failed: %w", err) - } - secname := secnameSecret.String() - secnameSecret.Destroy() - - privPasswdSecret, err := s.PrivPassword.Get() - if err != nil { - return fmt.Errorf("getting secname failed: %w", err) - } - privPasswd := privPasswdSecret.String() - privPasswdSecret.Destroy() - - authPasswdSecret, err := s.AuthPassword.Get() - if err != nil { - return fmt.Errorf("getting secname failed: %w", err) - } - authPasswd := authPasswdSecret.String() - authPasswdSecret.Destroy() - - s.listener.Params.SecurityParameters = &gosnmp.UsmSecurityParameters{ - UserName: secname, - PrivacyProtocol: privacyProtocol, - PrivacyPassphrase: privPasswd, - AuthenticationPassphrase: authPasswd, - AuthenticationProtocol: authenticationProtocol, - } - } - - // wrap the handler, used in unit tests - if nil != s.makeHandlerWrapper { - s.listener.OnNewTrap = s.makeHandlerWrapper(s.listener.OnNewTrap) - } - - split := strings.SplitN(s.ServiceAddress, "://", 2) - if len(split) != 2 { + u, err := url.Parse(s.ServiceAddress) + if err != nil { return fmt.Errorf("invalid service address: %s", s.ServiceAddress) } - protocol := split[0] - addr := split[1] - - // gosnmp.TrapListener currently supports udp only. For forward - // compatibility, require udp in the service address - if protocol != "udp" { - return fmt.Errorf("unknown protocol %q in %q", protocol, s.ServiceAddress) + // The gosnmp package currently only supports UDP + if u.Scheme != "udp" { + return fmt.Errorf("unknown protocol for service address %q", s.ServiceAddress) } - // If (*TrapListener).Listen immediately returns an error we need - // to return it from this function. Use a channel to get it here - // from the goroutine. Buffer one in case Listen returns after - // Listening but before our Close is called. - s.errCh = make(chan error, 1) + // If the listener immediately returns an error we need to return it + errCh := make(chan error, 1) go func() { - s.errCh <- s.listener.Listen(addr) + errCh <- s.listener.Listen(u.Host) }() select { case <-s.listener.Listening(): s.Log.Infof("Listening on %s", s.ServiceAddress) - case err := <-s.errCh: - return err + case err := <-errCh: + return fmt.Errorf("listening failed: %w", err) } return nil @@ -240,10 +227,6 @@ func (*SnmpTrap) Gather(telegraf.Accumulator) error { func (s *SnmpTrap) Stop() { s.listener.Close() - err := <-s.errCh - if nil != err { - s.Log.Errorf("Error stopping trap listener %v", err) - } } func setTrapOid(tags map[string]string, oid string, e snmp.MibEntry) { @@ -252,131 +235,115 @@ func setTrapOid(tags map[string]string, oid string, e snmp.MibEntry) { tags["mib"] = e.MibName } -func makeTrapHandler(s *SnmpTrap) gosnmp.TrapHandlerFunc { - return func(packet *gosnmp.SnmpPacket, addr *net.UDPAddr) { - tm := s.timeFunc() - fields := make(map[string]interface{}, len(packet.Variables)+1) - tags := map[string]string{ - "version": packet.Version.String(), - "source": addr.IP.String(), +func (s *SnmpTrap) handler(packet *gosnmp.SnmpPacket, addr *net.UDPAddr) { + tm := time.Now() + fields := make(map[string]interface{}, len(packet.Variables)+1) + tags := map[string]string{ + "version": packet.Version.String(), + "source": addr.IP.String(), + } + + if packet.Version == gosnmp.Version1 { + // Follow the procedure described in RFC 2576 3.1 to + // translate a v1 trap to v2. + var trapOid string + + if packet.GenericTrap >= 0 && packet.GenericTrap < 6 { + trapOid = ".1.3.6.1.6.3.1.1.5." + strconv.Itoa(packet.GenericTrap+1) + } else if packet.GenericTrap == 6 { + trapOid = packet.Enterprise + ".0." + strconv.Itoa(packet.SpecificTrap) } - if packet.Version == gosnmp.Version1 { - // Follow the procedure described in RFC 2576 3.1 to - // translate a v1 trap to v2. - var trapOid string - - if packet.GenericTrap >= 0 && packet.GenericTrap < 6 { - trapOid = ".1.3.6.1.6.3.1.1.5." + strconv.Itoa(packet.GenericTrap+1) - } else if packet.GenericTrap == 6 { - trapOid = packet.Enterprise + ".0." + strconv.Itoa(packet.SpecificTrap) + if trapOid != "" { + e, err := s.transl.lookup(trapOid) + if err != nil { + s.Log.Errorf("Error resolving V1 OID, oid=%s, source=%s: %v", trapOid, tags["source"], err) + return } - - if trapOid != "" { - e, err := s.transl.lookup(trapOid) - if err != nil { - s.Log.Errorf("Error resolving V1 OID, oid=%s, source=%s: %v", trapOid, tags["source"], err) - return - } - setTrapOid(tags, trapOid, e) - } - - if packet.AgentAddress != "" { - tags["agent_address"] = packet.AgentAddress - } - - fields["sysUpTimeInstance"] = packet.Timestamp + setTrapOid(tags, trapOid, e) } - for _, v := range packet.Variables { - // Use system mibs to resolve oids. Don't fall back to - // numeric oid because it's not useful enough to the end - // user and can be difficult to translate or remove from - // the database later. + if packet.AgentAddress != "" { + tags["agent_address"] = packet.AgentAddress + } - var value interface{} + fields["sysUpTimeInstance"] = packet.Timestamp + } - // todo: format the pdu value based on its snmp type and - // the mib's textual convention. The snmp input plugin - // only handles textual convention for ip and mac - // addresses + for _, v := range packet.Variables { + var value interface{} - switch v.Type { - case gosnmp.ObjectIdentifier: - val, ok := v.Value.(string) - if !ok { - s.Log.Errorf("Error getting value OID") - return - } - - var e snmp.MibEntry - var err error - e, err = s.transl.lookup(val) - if nil != err { - s.Log.Errorf("Error resolving value OID, oid=%s, source=%s: %v", val, tags["source"], err) - return - } - - value = e.OidText - - // 1.3.6.1.6.3.1.1.4.1.0 is SNMPv2-MIB::snmpTrapOID.0. - // If v.Name is this oid, set a tag of the trap name. - if v.Name == ".1.3.6.1.6.3.1.1.4.1.0" { - setTrapOid(tags, val, e) - continue - } - case gosnmp.OctetString: - // OctetStrings may contain hex data that needs its own conversion - if !utf8.Valid(v.Value.([]byte)[:]) { - value = hex.EncodeToString(v.Value.([]byte)) - } else { - value = v.Value - } - default: - value = v.Value - } - - e, err := s.transl.lookup(v.Name) - if nil != err { - s.Log.Errorf("Error resolving OID oid=%s, source=%s: %v", v.Name, tags["source"], err) + // Use system mibs to resolve oids. Don't fall back to numeric oid + // because it's not useful enough to the end user and can be difficult + // to translate or remove from the database later. + // + // TODO: format the pdu value based on its snmp type and the mib's + // textual convention. The snmp input plugin only handles textual + // convention for ip and mac addresses + switch v.Type { + case gosnmp.ObjectIdentifier: + val, ok := v.Value.(string) + if !ok { + s.Log.Errorf("Error getting value OID") return } - name := e.OidText + var e snmp.MibEntry + var err error + e, err = s.transl.lookup(val) + if nil != err { + s.Log.Errorf("Error resolving value OID, oid=%s, source=%s: %v", val, tags["source"], err) + return + } - fields[name] = value + value = e.OidText + + // 1.3.6.1.6.3.1.1.4.1.0 is SNMPv2-MIB::snmpTrapOID.0. + // If v.Name is this oid, set a tag of the trap name. + if v.Name == ".1.3.6.1.6.3.1.1.4.1.0" { + setTrapOid(tags, val, e) + continue + } + case gosnmp.OctetString: + // OctetStrings may contain hex data that needs its own conversion + if !utf8.Valid(v.Value.([]byte)[:]) { + value = hex.EncodeToString(v.Value.([]byte)) + } else { + value = v.Value + } + default: + value = v.Value } - if packet.Version == gosnmp.Version3 { - if packet.ContextName != "" { - tags["context_name"] = packet.ContextName - } - if packet.ContextEngineID != "" { - // SNMP RFCs like 3411 and 5343 show engine ID as a hex string - tags["engine_id"] = fmt.Sprintf("%x", packet.ContextEngineID) - } - } else { - if packet.Community != "" { - tags["community"] = packet.Community - } + e, err := s.transl.lookup(v.Name) + if nil != err { + s.Log.Errorf("Error resolving OID oid=%s, source=%s: %v", v.Name, tags["source"], err) + return } - s.acc.AddFields("snmp_trap", fields, tags, tm) + fields[e.OidText] = value } -} -func (l wrapLog) Printf(format string, args ...interface{}) { - l.Debugf(format, args...) -} + if packet.Version == gosnmp.Version3 { + if packet.ContextName != "" { + tags["context_name"] = packet.ContextName + } + if packet.ContextEngineID != "" { + // SNMP RFCs like 3411 and 5343 show engine ID as a hex string + tags["engine_id"] = fmt.Sprintf("%x", packet.ContextEngineID) + } + } else { + if packet.Community != "" { + tags["community"] = packet.Community + } + } -func (l wrapLog) Print(args ...interface{}) { - l.Debug(args...) + s.acc.AddFields("snmp_trap", fields, tags, tm) } func init() { inputs.Add("snmp_trap", func() telegraf.Input { return &SnmpTrap{ - timeFunc: time.Now, ServiceAddress: "udp://:162", Timeout: defaultTimeout, Path: []string{"/usr/share/snmp/mibs"}, diff --git a/plugins/inputs/snmp_trap/snmp_trap_test.go b/plugins/inputs/snmp_trap/snmp_trap_test.go index c71643df6..dcc158b2c 100644 --- a/plugins/inputs/snmp_trap/snmp_trap_test.go +++ b/plugins/inputs/snmp_trap/snmp_trap_test.go @@ -165,15 +165,11 @@ func TestReceiveTrapV1(t *testing.T) { plugin := &SnmpTrap{ ServiceAddress: "udp://:" + strconv.Itoa(port), Version: "1", - Translator: "netsnmp", Log: testutil.Logger{}, - timeFunc: time.Now, + transl: &testTranslator{entries: tt.entries}, } require.NoError(t, plugin.Init()) - // inject test translator - plugin.transl = &testTranslator{entries: tt.entries} - // Start the plugin var acc testutil.Accumulator require.NoError(t, plugin.Start(&acc)) @@ -294,15 +290,11 @@ func TestReceiveTrapV2c(t *testing.T) { plugin := &SnmpTrap{ ServiceAddress: "udp://:" + strconv.Itoa(port), Version: "2c", - Translator: "netsnmp", Log: testutil.Logger{}, - timeFunc: time.Now, + transl: &testTranslator{entries: tt.entries}, } require.NoError(t, plugin.Init()) - // inject test translator - plugin.transl = &testTranslator{entries: tt.entries} - var acc testutil.Accumulator require.NoError(t, plugin.Start(&acc)) defer plugin.Stop() @@ -1244,7 +1236,6 @@ func TestReceiveTrapV3(t *testing.T) { plugin := &SnmpTrap{ ServiceAddress: "udp://:" + strconv.Itoa(port), Version: "3", - Translator: "netsnmp", SecName: config.NewSecret([]byte(tt.secName)), SecLevel: tt.secLevel, AuthProtocol: tt.authProto, @@ -1252,13 +1243,10 @@ func TestReceiveTrapV3(t *testing.T) { PrivProtocol: tt.privProto, PrivPassword: config.NewSecret([]byte(tt.privPass)), Log: testutil.Logger{}, - timeFunc: time.Now, + transl: &testTranslator{entries: tt.entries}, } require.NoError(t, plugin.Init()) - // inject test translator - plugin.transl = &testTranslator{entries: tt.entries} - var acc testutil.Accumulator require.NoError(t, plugin.Start(&acc)) defer plugin.Stop() @@ -1342,15 +1330,11 @@ func TestOidLookupFail(t *testing.T) { plugin := &SnmpTrap{ ServiceAddress: "udp://:" + strconv.Itoa(port), Version: "2c", - Translator: "netsnmp", Log: logger, - timeFunc: time.Now, + transl: &testTranslator{fail: fail}, } require.NoError(t, plugin.Init()) - // inject test translator - plugin.transl = &testTranslator{fail: fail} - var acc testutil.Accumulator require.NoError(t, plugin.Start(&acc)) defer plugin.Stop()