chore(inputs.snmp_trap): Cleanup code (#16810)

This commit is contained in:
Sven Rebhan 2025-04-23 20:01:19 +02:00 committed by GitHub
parent 403199fb46
commit 3895416ea8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 235 additions and 269 deletions

View File

@ -0,0 +1,15 @@
package snmp_trap
import "github.com/influxdata/telegraf"
type logger struct {
telegraf.Logger
}
func (l logger) Printf(format string, args ...interface{}) {
l.Tracef(format, args...)
}
func (l logger) Print(args ...interface{}) {
l.Trace(args...)
}

View File

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -30,15 +31,12 @@ type SnmpTrap struct {
Timeout config.Duration `toml:"timeout"` Timeout config.Duration `toml:"timeout"`
Version string `toml:"version"` Version string `toml:"version"`
Path []string `toml:"path"` Path []string `toml:"path"`
// Settings for version 3
// Values: "noAuthNoPriv", "authNoPriv", "authPriv"
SecLevel string `toml:"sec_level"`
SecName config.Secret `toml:"sec_name"` // Settings for version 3 security
// Values: "MD5", "SHA", "". Default: "" SecLevel string `toml:"sec_level"`
SecName config.Secret `toml:"sec_name"`
AuthProtocol string `toml:"auth_protocol"` AuthProtocol string `toml:"auth_protocol"`
AuthPassword config.Secret `toml:"auth_password"` AuthPassword config.Secret `toml:"auth_password"`
// Values: "DES", "AES", "". Default: ""
PrivProtocol string `toml:"priv_protocol"` PrivProtocol string `toml:"priv_protocol"`
PrivPassword config.Secret `toml:"priv_password"` PrivPassword config.Secret `toml:"priv_password"`
@ -47,15 +45,8 @@ type SnmpTrap struct {
acc telegraf.Accumulator acc telegraf.Accumulator
listener *gosnmp.TrapListener listener *gosnmp.TrapListener
timeFunc func() time.Time
errCh chan error
makeHandlerWrapper func(gosnmp.TrapHandlerFunc) gosnmp.TrapHandlerFunc transl translator
transl translator
}
type wrapLog struct {
telegraf.Logger
} }
type translator interface { type translator interface {
@ -71,164 +62,160 @@ func (s *SnmpTrap) SetTranslator(name string) {
} }
func (s *SnmpTrap) Init() error { func (s *SnmpTrap) Init() error {
var err error // Set defaults
if s.ServiceAddress == "" {
s.ServiceAddress = "udp://:162"
}
if len(s.Path) == 0 {
s.Path = []string{"/usr/share/snmp/mibs"}
}
// Check input parameters
switch s.Translator { switch s.Translator {
case "gosmi": case "gosmi":
s.transl, err = newGosmiTranslator(s.Path, s.Log) t, err := newGosmiTranslator(s.Path, s.Log)
if err != nil { if err != nil {
return err return err
} }
s.transl = t
case "netsnmp": case "netsnmp":
s.transl = newNetsnmpTranslator(s.Timeout) s.transl = newNetsnmpTranslator(s.Timeout)
default: default:
return errors.New("invalid translator value") // Ignore the translator for testing if an instance was set
if s.transl == nil {
return errors.New("invalid translator value")
}
} }
if err != nil { // Setup the SNMP parameters
s.Log.Errorf("Could not get path %v", err) params := *gosnmp.Default
switch s.Version {
case "1":
params.Version = gosnmp.Version1
case "", "2c":
params.Version = gosnmp.Version2c
case "3":
params.Version = gosnmp.Version3
// Setup the security for v3
var security gosnmp.UsmSecurityParameters
params.SecurityModel = gosnmp.UserSecurityModel
// Set security mechanisms
switch strings.ToLower(s.SecLevel) {
case "noauthnopriv", "":
params.MsgFlags = gosnmp.NoAuthNoPriv
case "authnopriv":
params.MsgFlags = gosnmp.AuthNoPriv
case "authpriv":
params.MsgFlags = gosnmp.AuthPriv
default:
return fmt.Errorf("unknown security level %q", s.SecLevel)
}
// Set authentication
switch strings.ToLower(s.AuthProtocol) {
case "":
security.AuthenticationProtocol = gosnmp.NoAuth
case "md5":
security.AuthenticationProtocol = gosnmp.MD5
case "sha":
security.AuthenticationProtocol = gosnmp.SHA
case "sha224":
security.AuthenticationProtocol = gosnmp.SHA224
case "sha256":
security.AuthenticationProtocol = gosnmp.SHA256
case "sha384":
security.AuthenticationProtocol = gosnmp.SHA384
case "sha512":
security.AuthenticationProtocol = gosnmp.SHA512
default:
return fmt.Errorf("unknown authentication protocol %q", s.AuthProtocol)
}
// Set privacy
switch strings.ToLower(s.PrivProtocol) {
case "":
security.PrivacyProtocol = gosnmp.NoPriv
case "aes":
security.PrivacyProtocol = gosnmp.AES
case "des":
security.PrivacyProtocol = gosnmp.DES
case "aes192":
security.PrivacyProtocol = gosnmp.AES192
case "aes192c":
security.PrivacyProtocol = gosnmp.AES192C
case "aes256":
security.PrivacyProtocol = gosnmp.AES256
case "aes256c":
security.PrivacyProtocol = gosnmp.AES256C
default:
return fmt.Errorf("unknown privacy protocol %q", s.PrivProtocol)
}
// Set credentials
secnameSecret, err := s.SecName.Get()
if err != nil {
return fmt.Errorf("getting secname failed: %w", err)
}
security.UserName = secnameSecret.String()
secnameSecret.Destroy()
privPasswdSecret, err := s.PrivPassword.Get()
if err != nil {
return fmt.Errorf("getting priv-password failed: %w", err)
}
security.PrivacyPassphrase = privPasswdSecret.String()
privPasswdSecret.Destroy()
authPasswdSecret, err := s.AuthPassword.Get()
if err != nil {
return fmt.Errorf("getting auth-password failed: %w", err)
}
security.AuthenticationPassphrase = authPasswdSecret.String()
authPasswdSecret.Destroy()
params.SecurityParameters = &security
default:
return fmt.Errorf("unknown version %q", s.Version)
} }
if s.Log.Level().Includes(telegraf.Trace) {
params.Logger = gosnmp.NewLogger(&logger{s.Log})
}
// Initialize the listener
s.listener = gosnmp.NewTrapListener()
s.listener.OnNewTrap = s.handler
s.listener.Params = &params
return nil return nil
} }
func (s *SnmpTrap) Start(acc telegraf.Accumulator) error { func (s *SnmpTrap) Start(acc telegraf.Accumulator) error {
s.acc = acc s.acc = acc
s.listener = gosnmp.NewTrapListener()
s.listener.OnNewTrap = makeTrapHandler(s)
// gosnmp.Default is a pointer, using this more than once u, err := url.Parse(s.ServiceAddress)
// has side effects if err != nil {
defaults := *gosnmp.Default
s.listener.Params = &defaults
s.listener.Params.Logger = gosnmp.NewLogger(wrapLog{s.Log})
switch s.Version {
case "3":
s.listener.Params.Version = gosnmp.Version3
case "2c":
s.listener.Params.Version = gosnmp.Version2c
case "1":
s.listener.Params.Version = gosnmp.Version1
default:
s.listener.Params.Version = gosnmp.Version2c
}
if s.listener.Params.Version == gosnmp.Version3 {
s.listener.Params.SecurityModel = gosnmp.UserSecurityModel
switch strings.ToLower(s.SecLevel) {
case "noauthnopriv", "":
s.listener.Params.MsgFlags = gosnmp.NoAuthNoPriv
case "authnopriv":
s.listener.Params.MsgFlags = gosnmp.AuthNoPriv
case "authpriv":
s.listener.Params.MsgFlags = gosnmp.AuthPriv
default:
return fmt.Errorf("unknown security level %q", s.SecLevel)
}
var authenticationProtocol gosnmp.SnmpV3AuthProtocol
switch strings.ToLower(s.AuthProtocol) {
case "md5":
authenticationProtocol = gosnmp.MD5
case "sha":
authenticationProtocol = gosnmp.SHA
case "sha224":
authenticationProtocol = gosnmp.SHA224
case "sha256":
authenticationProtocol = gosnmp.SHA256
case "sha384":
authenticationProtocol = gosnmp.SHA384
case "sha512":
authenticationProtocol = gosnmp.SHA512
case "":
authenticationProtocol = gosnmp.NoAuth
default:
return fmt.Errorf("unknown authentication protocol %q", s.AuthProtocol)
}
var privacyProtocol gosnmp.SnmpV3PrivProtocol
switch strings.ToLower(s.PrivProtocol) {
case "aes":
privacyProtocol = gosnmp.AES
case "des":
privacyProtocol = gosnmp.DES
case "aes192":
privacyProtocol = gosnmp.AES192
case "aes192c":
privacyProtocol = gosnmp.AES192C
case "aes256":
privacyProtocol = gosnmp.AES256
case "aes256c":
privacyProtocol = gosnmp.AES256C
case "":
privacyProtocol = gosnmp.NoPriv
default:
return fmt.Errorf("unknown privacy protocol %q", s.PrivProtocol)
}
secnameSecret, err := s.SecName.Get()
if err != nil {
return fmt.Errorf("getting secname failed: %w", err)
}
secname := secnameSecret.String()
secnameSecret.Destroy()
privPasswdSecret, err := s.PrivPassword.Get()
if err != nil {
return fmt.Errorf("getting secname failed: %w", err)
}
privPasswd := privPasswdSecret.String()
privPasswdSecret.Destroy()
authPasswdSecret, err := s.AuthPassword.Get()
if err != nil {
return fmt.Errorf("getting secname failed: %w", err)
}
authPasswd := authPasswdSecret.String()
authPasswdSecret.Destroy()
s.listener.Params.SecurityParameters = &gosnmp.UsmSecurityParameters{
UserName: secname,
PrivacyProtocol: privacyProtocol,
PrivacyPassphrase: privPasswd,
AuthenticationPassphrase: authPasswd,
AuthenticationProtocol: authenticationProtocol,
}
}
// wrap the handler, used in unit tests
if nil != s.makeHandlerWrapper {
s.listener.OnNewTrap = s.makeHandlerWrapper(s.listener.OnNewTrap)
}
split := strings.SplitN(s.ServiceAddress, "://", 2)
if len(split) != 2 {
return fmt.Errorf("invalid service address: %s", s.ServiceAddress) return fmt.Errorf("invalid service address: %s", s.ServiceAddress)
} }
protocol := split[0] // The gosnmp package currently only supports UDP
addr := split[1] if u.Scheme != "udp" {
return fmt.Errorf("unknown protocol for service address %q", s.ServiceAddress)
// gosnmp.TrapListener currently supports udp only. For forward
// compatibility, require udp in the service address
if protocol != "udp" {
return fmt.Errorf("unknown protocol %q in %q", protocol, s.ServiceAddress)
} }
// If (*TrapListener).Listen immediately returns an error we need // If the listener immediately returns an error we need to return it
// to return it from this function. Use a channel to get it here errCh := make(chan error, 1)
// from the goroutine. Buffer one in case Listen returns after
// Listening but before our Close is called.
s.errCh = make(chan error, 1)
go func() { go func() {
s.errCh <- s.listener.Listen(addr) errCh <- s.listener.Listen(u.Host)
}() }()
select { select {
case <-s.listener.Listening(): case <-s.listener.Listening():
s.Log.Infof("Listening on %s", s.ServiceAddress) s.Log.Infof("Listening on %s", s.ServiceAddress)
case err := <-s.errCh: case err := <-errCh:
return err return fmt.Errorf("listening failed: %w", err)
} }
return nil return nil
@ -240,10 +227,6 @@ func (*SnmpTrap) Gather(telegraf.Accumulator) error {
func (s *SnmpTrap) Stop() { func (s *SnmpTrap) Stop() {
s.listener.Close() s.listener.Close()
err := <-s.errCh
if nil != err {
s.Log.Errorf("Error stopping trap listener %v", err)
}
} }
func setTrapOid(tags map[string]string, oid string, e snmp.MibEntry) { func setTrapOid(tags map[string]string, oid string, e snmp.MibEntry) {
@ -252,131 +235,115 @@ func setTrapOid(tags map[string]string, oid string, e snmp.MibEntry) {
tags["mib"] = e.MibName tags["mib"] = e.MibName
} }
func makeTrapHandler(s *SnmpTrap) gosnmp.TrapHandlerFunc { func (s *SnmpTrap) handler(packet *gosnmp.SnmpPacket, addr *net.UDPAddr) {
return func(packet *gosnmp.SnmpPacket, addr *net.UDPAddr) { tm := time.Now()
tm := s.timeFunc() fields := make(map[string]interface{}, len(packet.Variables)+1)
fields := make(map[string]interface{}, len(packet.Variables)+1) tags := map[string]string{
tags := map[string]string{ "version": packet.Version.String(),
"version": packet.Version.String(), "source": addr.IP.String(),
"source": addr.IP.String(), }
if packet.Version == gosnmp.Version1 {
// Follow the procedure described in RFC 2576 3.1 to
// translate a v1 trap to v2.
var trapOid string
if packet.GenericTrap >= 0 && packet.GenericTrap < 6 {
trapOid = ".1.3.6.1.6.3.1.1.5." + strconv.Itoa(packet.GenericTrap+1)
} else if packet.GenericTrap == 6 {
trapOid = packet.Enterprise + ".0." + strconv.Itoa(packet.SpecificTrap)
} }
if packet.Version == gosnmp.Version1 { if trapOid != "" {
// Follow the procedure described in RFC 2576 3.1 to e, err := s.transl.lookup(trapOid)
// translate a v1 trap to v2. if err != nil {
var trapOid string s.Log.Errorf("Error resolving V1 OID, oid=%s, source=%s: %v", trapOid, tags["source"], err)
return
if packet.GenericTrap >= 0 && packet.GenericTrap < 6 {
trapOid = ".1.3.6.1.6.3.1.1.5." + strconv.Itoa(packet.GenericTrap+1)
} else if packet.GenericTrap == 6 {
trapOid = packet.Enterprise + ".0." + strconv.Itoa(packet.SpecificTrap)
} }
setTrapOid(tags, trapOid, e)
if trapOid != "" {
e, err := s.transl.lookup(trapOid)
if err != nil {
s.Log.Errorf("Error resolving V1 OID, oid=%s, source=%s: %v", trapOid, tags["source"], err)
return
}
setTrapOid(tags, trapOid, e)
}
if packet.AgentAddress != "" {
tags["agent_address"] = packet.AgentAddress
}
fields["sysUpTimeInstance"] = packet.Timestamp
} }
for _, v := range packet.Variables { if packet.AgentAddress != "" {
// Use system mibs to resolve oids. Don't fall back to tags["agent_address"] = packet.AgentAddress
// numeric oid because it's not useful enough to the end }
// user and can be difficult to translate or remove from
// the database later.
var value interface{} fields["sysUpTimeInstance"] = packet.Timestamp
}
// todo: format the pdu value based on its snmp type and for _, v := range packet.Variables {
// the mib's textual convention. The snmp input plugin var value interface{}
// only handles textual convention for ip and mac
// addresses
switch v.Type { // Use system mibs to resolve oids. Don't fall back to numeric oid
case gosnmp.ObjectIdentifier: // because it's not useful enough to the end user and can be difficult
val, ok := v.Value.(string) // to translate or remove from the database later.
if !ok { //
s.Log.Errorf("Error getting value OID") // TODO: format the pdu value based on its snmp type and the mib's
return // textual convention. The snmp input plugin only handles textual
} // convention for ip and mac addresses
switch v.Type {
var e snmp.MibEntry case gosnmp.ObjectIdentifier:
var err error val, ok := v.Value.(string)
e, err = s.transl.lookup(val) if !ok {
if nil != err { s.Log.Errorf("Error getting value OID")
s.Log.Errorf("Error resolving value OID, oid=%s, source=%s: %v", val, tags["source"], err)
return
}
value = e.OidText
// 1.3.6.1.6.3.1.1.4.1.0 is SNMPv2-MIB::snmpTrapOID.0.
// If v.Name is this oid, set a tag of the trap name.
if v.Name == ".1.3.6.1.6.3.1.1.4.1.0" {
setTrapOid(tags, val, e)
continue
}
case gosnmp.OctetString:
// OctetStrings may contain hex data that needs its own conversion
if !utf8.Valid(v.Value.([]byte)[:]) {
value = hex.EncodeToString(v.Value.([]byte))
} else {
value = v.Value
}
default:
value = v.Value
}
e, err := s.transl.lookup(v.Name)
if nil != err {
s.Log.Errorf("Error resolving OID oid=%s, source=%s: %v", v.Name, tags["source"], err)
return return
} }
name := e.OidText var e snmp.MibEntry
var err error
e, err = s.transl.lookup(val)
if nil != err {
s.Log.Errorf("Error resolving value OID, oid=%s, source=%s: %v", val, tags["source"], err)
return
}
fields[name] = value value = e.OidText
// 1.3.6.1.6.3.1.1.4.1.0 is SNMPv2-MIB::snmpTrapOID.0.
// If v.Name is this oid, set a tag of the trap name.
if v.Name == ".1.3.6.1.6.3.1.1.4.1.0" {
setTrapOid(tags, val, e)
continue
}
case gosnmp.OctetString:
// OctetStrings may contain hex data that needs its own conversion
if !utf8.Valid(v.Value.([]byte)[:]) {
value = hex.EncodeToString(v.Value.([]byte))
} else {
value = v.Value
}
default:
value = v.Value
} }
if packet.Version == gosnmp.Version3 { e, err := s.transl.lookup(v.Name)
if packet.ContextName != "" { if nil != err {
tags["context_name"] = packet.ContextName s.Log.Errorf("Error resolving OID oid=%s, source=%s: %v", v.Name, tags["source"], err)
} return
if packet.ContextEngineID != "" {
// SNMP RFCs like 3411 and 5343 show engine ID as a hex string
tags["engine_id"] = fmt.Sprintf("%x", packet.ContextEngineID)
}
} else {
if packet.Community != "" {
tags["community"] = packet.Community
}
} }
s.acc.AddFields("snmp_trap", fields, tags, tm) fields[e.OidText] = value
} }
}
func (l wrapLog) Printf(format string, args ...interface{}) { if packet.Version == gosnmp.Version3 {
l.Debugf(format, args...) if packet.ContextName != "" {
} tags["context_name"] = packet.ContextName
}
if packet.ContextEngineID != "" {
// SNMP RFCs like 3411 and 5343 show engine ID as a hex string
tags["engine_id"] = fmt.Sprintf("%x", packet.ContextEngineID)
}
} else {
if packet.Community != "" {
tags["community"] = packet.Community
}
}
func (l wrapLog) Print(args ...interface{}) { s.acc.AddFields("snmp_trap", fields, tags, tm)
l.Debug(args...)
} }
func init() { func init() {
inputs.Add("snmp_trap", func() telegraf.Input { inputs.Add("snmp_trap", func() telegraf.Input {
return &SnmpTrap{ return &SnmpTrap{
timeFunc: time.Now,
ServiceAddress: "udp://:162", ServiceAddress: "udp://:162",
Timeout: defaultTimeout, Timeout: defaultTimeout,
Path: []string{"/usr/share/snmp/mibs"}, Path: []string{"/usr/share/snmp/mibs"},

View File

@ -165,15 +165,11 @@ func TestReceiveTrapV1(t *testing.T) {
plugin := &SnmpTrap{ plugin := &SnmpTrap{
ServiceAddress: "udp://:" + strconv.Itoa(port), ServiceAddress: "udp://:" + strconv.Itoa(port),
Version: "1", Version: "1",
Translator: "netsnmp",
Log: testutil.Logger{}, Log: testutil.Logger{},
timeFunc: time.Now, transl: &testTranslator{entries: tt.entries},
} }
require.NoError(t, plugin.Init()) require.NoError(t, plugin.Init())
// inject test translator
plugin.transl = &testTranslator{entries: tt.entries}
// Start the plugin // Start the plugin
var acc testutil.Accumulator var acc testutil.Accumulator
require.NoError(t, plugin.Start(&acc)) require.NoError(t, plugin.Start(&acc))
@ -294,15 +290,11 @@ func TestReceiveTrapV2c(t *testing.T) {
plugin := &SnmpTrap{ plugin := &SnmpTrap{
ServiceAddress: "udp://:" + strconv.Itoa(port), ServiceAddress: "udp://:" + strconv.Itoa(port),
Version: "2c", Version: "2c",
Translator: "netsnmp",
Log: testutil.Logger{}, Log: testutil.Logger{},
timeFunc: time.Now, transl: &testTranslator{entries: tt.entries},
} }
require.NoError(t, plugin.Init()) require.NoError(t, plugin.Init())
// inject test translator
plugin.transl = &testTranslator{entries: tt.entries}
var acc testutil.Accumulator var acc testutil.Accumulator
require.NoError(t, plugin.Start(&acc)) require.NoError(t, plugin.Start(&acc))
defer plugin.Stop() defer plugin.Stop()
@ -1244,7 +1236,6 @@ func TestReceiveTrapV3(t *testing.T) {
plugin := &SnmpTrap{ plugin := &SnmpTrap{
ServiceAddress: "udp://:" + strconv.Itoa(port), ServiceAddress: "udp://:" + strconv.Itoa(port),
Version: "3", Version: "3",
Translator: "netsnmp",
SecName: config.NewSecret([]byte(tt.secName)), SecName: config.NewSecret([]byte(tt.secName)),
SecLevel: tt.secLevel, SecLevel: tt.secLevel,
AuthProtocol: tt.authProto, AuthProtocol: tt.authProto,
@ -1252,13 +1243,10 @@ func TestReceiveTrapV3(t *testing.T) {
PrivProtocol: tt.privProto, PrivProtocol: tt.privProto,
PrivPassword: config.NewSecret([]byte(tt.privPass)), PrivPassword: config.NewSecret([]byte(tt.privPass)),
Log: testutil.Logger{}, Log: testutil.Logger{},
timeFunc: time.Now, transl: &testTranslator{entries: tt.entries},
} }
require.NoError(t, plugin.Init()) require.NoError(t, plugin.Init())
// inject test translator
plugin.transl = &testTranslator{entries: tt.entries}
var acc testutil.Accumulator var acc testutil.Accumulator
require.NoError(t, plugin.Start(&acc)) require.NoError(t, plugin.Start(&acc))
defer plugin.Stop() defer plugin.Stop()
@ -1342,15 +1330,11 @@ func TestOidLookupFail(t *testing.T) {
plugin := &SnmpTrap{ plugin := &SnmpTrap{
ServiceAddress: "udp://:" + strconv.Itoa(port), ServiceAddress: "udp://:" + strconv.Itoa(port),
Version: "2c", Version: "2c",
Translator: "netsnmp",
Log: logger, Log: logger,
timeFunc: time.Now, transl: &testTranslator{fail: fail},
} }
require.NoError(t, plugin.Init()) require.NoError(t, plugin.Init())
// inject test translator
plugin.transl = &testTranslator{fail: fail}
var acc testutil.Accumulator var acc testutil.Accumulator
require.NoError(t, plugin.Start(&acc)) require.NoError(t, plugin.Start(&acc))
defer plugin.Stop() defer plugin.Stop()