feat(inputs.dns_query): Add IP field(s) (#12519)

This commit is contained in:
Sven Rebhan 2023-01-20 16:40:43 +01:00 committed by GitHub
parent bfb26a8af6
commit 410226051d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 294 additions and 207 deletions

View File

@ -33,8 +33,14 @@ See the [CONFIGURATION.md][CONFIGURATION.md] for more details.
## Dns server port. ## Dns server port.
# port = 53 # port = 53
## Query timeout in seconds. ## Query timeout
# timeout = 2 # timeout = "2s"
## Include the specified additional properties in the resulting metric.
## The following values are supported:
## "first_ip" -- return IP of the first A and AAAA answer
## "all_ips" -- return IPs of all A and AAAA answers
# include_fields = []
``` ```
## Metrics ## Metrics

View File

@ -12,6 +12,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/influxdata/telegraf" "github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/plugins/inputs" "github.com/influxdata/telegraf/plugins/inputs"
) )
@ -27,76 +28,39 @@ const (
) )
type DNSQuery struct { type DNSQuery struct {
// Domains or subdomains to query Domains []string `toml:"domains"`
Domains []string Network string `toml:"network"`
Servers []string `toml:"servers"`
// Network protocol name
Network string
// Server to query
Servers []string
// Record type
RecordType string `toml:"record_type"` RecordType string `toml:"record_type"`
Port int `toml:"port"`
Timeout config.Duration `toml:"timeout"`
IncludeFields []string `toml:"include_fields"`
// DNS server port number fieldEnabled map[string]bool
Port int
// Dns query timeout in seconds. 0 means no timeout
Timeout int
} }
func (*DNSQuery) SampleConfig() string { func (*DNSQuery) SampleConfig() string {
return sampleConfig return sampleConfig
} }
func (d *DNSQuery) Gather(acc telegraf.Accumulator) error { func (d *DNSQuery) Init() error {
var wg sync.WaitGroup // Convert the included fields into a lookup-table
d.setDefaultValues() d.fieldEnabled = make(map[string]bool, len(d.IncludeFields))
for _, f := range d.IncludeFields {
for _, domain := range d.Domains { switch f {
for _, server := range d.Servers { case "first_ip", "all_ips":
wg.Add(1) default:
go func(domain, server string) { return fmt.Errorf("invalid field %q included", f)
fields := make(map[string]interface{}, 2) }
tags := map[string]string{ d.fieldEnabled[f] = true
"server": server,
"domain": domain,
"record_type": d.RecordType,
} }
dnsQueryTime, rcode, err := d.getDNSQueryTime(domain, server) // Set defaults
if rcode >= 0 {
tags["rcode"] = dns.RcodeToString[rcode]
fields["rcode_value"] = rcode
}
if err == nil {
setResult(Success, fields, tags)
fields["query_time_ms"] = dnsQueryTime
} else if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
setResult(Timeout, fields, tags)
} else if err != nil {
setResult(Error, fields, tags)
acc.AddError(err)
}
acc.AddFields("dns_query", fields, tags)
wg.Done()
}(domain, server)
}
}
wg.Wait()
return nil
}
func (d *DNSQuery) setDefaultValues() {
if d.Network == "" { if d.Network == "" {
d.Network = "udp" d.Network = "udp"
} }
if len(d.RecordType) == 0 { if d.RecordType == "" {
d.RecordType = "NS" d.RecordType = "NS"
} }
@ -105,39 +69,106 @@ func (d *DNSQuery) setDefaultValues() {
d.RecordType = "NS" d.RecordType = "NS"
} }
if d.Port == 0 { if d.Port < 1 {
d.Port = 53 d.Port = 53
} }
if d.Timeout == 0 { return nil
d.Timeout = 2
}
} }
func (d *DNSQuery) getDNSQueryTime(domain string, server string) (float64, int, error) { func (d *DNSQuery) Gather(acc telegraf.Accumulator) error {
dnsQueryTime := float64(0) var wg sync.WaitGroup
c := new(dns.Client) for _, domain := range d.Domains {
c.ReadTimeout = time.Duration(d.Timeout) * time.Second for _, server := range d.Servers {
c.Net = d.Network wg.Add(1)
go func(domain, server string) {
defer wg.Done()
fields, tags, err := d.query(domain, server)
if err != nil {
if opErr, ok := err.(*net.OpError); !ok || !opErr.Timeout() {
acc.AddError(err)
}
}
acc.AddFields("dns_query", fields, tags)
}(domain, server)
}
}
wg.Wait()
return nil
}
func (d *DNSQuery) query(domain string, server string) (map[string]interface{}, map[string]string, error) {
tags := map[string]string{
"server": server,
"domain": domain,
"record_type": d.RecordType,
"result": "error",
}
fields := map[string]interface{}{
"query_time_ms": float64(0),
"result_code": uint64(Error),
}
c := dns.Client{
ReadTimeout: time.Duration(d.Timeout),
Net: d.Network,
}
m := new(dns.Msg)
recordType, err := d.parseRecordType() recordType, err := d.parseRecordType()
if err != nil { if err != nil {
return dnsQueryTime, -1, err return fields, tags, err
} }
m.SetQuestion(dns.Fqdn(domain), recordType)
m.RecursionDesired = true
r, rtt, err := c.Exchange(m, net.JoinHostPort(server, strconv.Itoa(d.Port))) var msg dns.Msg
msg.SetQuestion(dns.Fqdn(domain), recordType)
msg.RecursionDesired = true
addr := net.JoinHostPort(server, strconv.Itoa(d.Port))
r, rtt, err := c.Exchange(&msg, addr)
if err != nil { if err != nil {
return dnsQueryTime, -1, err if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
tags["result"] = "timeout"
fields["result_code"] = uint64(Timeout)
return fields, tags, err
} }
return fields, tags, err
}
// Fill valid fields
tags["rcode"] = dns.RcodeToString[r.Rcode]
fields["rcode_value"] = r.Rcode
fields["query_time_ms"] = float64(rtt.Nanoseconds()) / 1e6
// Handle the failure case
if r.Rcode != dns.RcodeSuccess { if r.Rcode != dns.RcodeSuccess {
return dnsQueryTime, r.Rcode, fmt.Errorf("Invalid answer (%s) from %s after %s query for %s", dns.RcodeToString[r.Rcode], server, d.RecordType, domain) return fields, tags, fmt.Errorf("invalid answer (%s) from %s after %s query for %s", dns.RcodeToString[r.Rcode], server, d.RecordType, domain)
} }
dnsQueryTime = float64(rtt.Nanoseconds()) / 1e6
return dnsQueryTime, r.Rcode, nil // Success
tags["result"] = "success"
fields["result_code"] = uint64(Success)
if d.fieldEnabled["first_ip"] {
for _, record := range r.Answer {
if ip, found := extractIP(record); found {
fields["ip"] = ip
break
}
}
}
if d.fieldEnabled["all_ips"] {
for i, record := range r.Answer {
if ip, found := extractIP(record); found {
fields["ip_"+strconv.Itoa(i)] = ip
}
}
}
return fields, tags, nil
} }
func (d *DNSQuery) parseRecordType() (uint16, error) { func (d *DNSQuery) parseRecordType() (uint16, error) {
@ -168,29 +199,26 @@ func (d *DNSQuery) parseRecordType() (uint16, error) {
case "TXT": case "TXT":
recordType = dns.TypeTXT recordType = dns.TypeTXT
default: default:
err = fmt.Errorf("Record type %s not recognized", d.RecordType) err = fmt.Errorf("record type %s not recognized", d.RecordType)
} }
return recordType, err return recordType, err
} }
func setResult(result ResultType, fields map[string]interface{}, tags map[string]string) { func extractIP(record dns.RR) (string, bool) {
var tag string if r, ok := record.(*dns.A); ok {
switch result { return r.A.String(), true
case Success:
tag = "success"
case Timeout:
tag = "timeout"
case Error:
tag = "error"
} }
if r, ok := record.(*dns.AAAA); ok {
tags["result"] = tag return r.AAAA.String(), true
fields["result_code"] = uint64(result) }
return "", false
} }
func init() { func init() {
inputs.Add("dns_query", func() telegraf.Input { inputs.Add("dns_query", func() telegraf.Input {
return &DNSQuery{} return &DNSQuery{
Timeout: config.Duration(2 * time.Second),
}
}) })
} }

View File

@ -7,6 +7,9 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/metric"
"github.com/influxdata/telegraf/testutil" "github.com/influxdata/telegraf/testutil"
) )
@ -17,116 +20,141 @@ func TestGathering(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("Skipping network-dependent test in short mode.") t.Skip("Skipping network-dependent test in short mode.")
} }
var dnsConfig = DNSQuery{
dnsConfig := DNSQuery{
Servers: servers, Servers: servers,
Domains: domains, Domains: domains,
Timeout: config.Duration(2 * time.Second),
} }
var acc testutil.Accumulator
err := acc.GatherError(dnsConfig.Gather) var acc testutil.Accumulator
require.NoError(t, err) require.NoError(t, dnsConfig.Init())
require.NoError(t, acc.GatherError(dnsConfig.Gather))
metric, ok := acc.Get("dns_query") metric, ok := acc.Get("dns_query")
require.True(t, ok) require.True(t, ok)
queryTime, _ := metric.Fields["query_time_ms"].(float64) queryTime, ok := metric.Fields["query_time_ms"].(float64)
require.True(t, ok)
require.NotEqual(t, 0, queryTime) require.NotEqual(t, float64(0), queryTime)
} }
func TestGatheringMxRecord(t *testing.T) { func TestGatheringMxRecord(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("Skipping network-dependent test in short mode.") t.Skip("Skipping network-dependent test in short mode.")
} }
var dnsConfig = DNSQuery{
dnsConfig := DNSQuery{
Servers: servers, Servers: servers,
Domains: domains, Domains: domains,
RecordType: "MX",
Timeout: config.Duration(2 * time.Second),
} }
var acc testutil.Accumulator var acc testutil.Accumulator
dnsConfig.RecordType = "MX"
err := acc.GatherError(dnsConfig.Gather) require.NoError(t, dnsConfig.Init())
require.NoError(t, err) require.NoError(t, acc.GatherError(dnsConfig.Gather))
metric, ok := acc.Get("dns_query") metric, ok := acc.Get("dns_query")
require.True(t, ok) require.True(t, ok)
queryTime, _ := metric.Fields["query_time_ms"].(float64) queryTime, ok := metric.Fields["query_time_ms"].(float64)
require.True(t, ok)
require.NotEqual(t, 0, queryTime) require.NotEqual(t, float64(0), queryTime)
} }
func TestGatheringRootDomain(t *testing.T) { func TestGatheringRootDomain(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("Skipping network-dependent test in short mode.") t.Skip("Skipping network-dependent test in short mode.")
} }
var dnsConfig = DNSQuery{
dnsConfig := DNSQuery{
Servers: servers, Servers: servers,
Domains: []string{"."}, Domains: []string{"."},
RecordType: "MX", RecordType: "MX",
Timeout: config.Duration(2 * time.Second),
} }
require.NoError(t, dnsConfig.Init())
var acc testutil.Accumulator var acc testutil.Accumulator
tags := map[string]string{ require.NoError(t, acc.GatherError(dnsConfig.Gather))
m, ok := acc.Get("dns_query")
require.True(t, ok)
queryTime, ok := m.Fields["query_time_ms"].(float64)
require.True(t, ok)
expected := []telegraf.Metric{
metric.New(
"dns_query",
map[string]string{
"server": "8.8.8.8", "server": "8.8.8.8",
"domain": ".", "domain": ".",
"record_type": "MX", "record_type": "MX",
"rcode": "NOERROR", "rcode": "NOERROR",
"result": "success", "result": "success",
} },
fields := map[string]interface{}{ map[string]interface{}{
"rcode_value": 0, "rcode_value": 0,
"result_code": uint64(0), "result_code": uint64(0),
"query_time_ms": queryTime,
},
time.Unix(0, 0),
),
} }
testutil.RequireMetricsEqual(t, expected, acc.GetTelegrafMetrics(), testutil.IgnoreTime())
err := acc.GatherError(dnsConfig.Gather)
require.NoError(t, err)
metric, ok := acc.Get("dns_query")
require.True(t, ok)
queryTime, _ := metric.Fields["query_time_ms"].(float64)
fields["query_time_ms"] = queryTime
acc.AssertContainsTaggedFields(t, "dns_query", fields, tags)
} }
func TestMetricContainsServerAndDomainAndRecordTypeTags(t *testing.T) { func TestMetricContainsServerAndDomainAndRecordTypeTags(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("Skipping network-dependent test in short mode.") t.Skip("Skipping network-dependent test in short mode.")
} }
var dnsConfig = DNSQuery{
dnsConfig := DNSQuery{
Servers: servers, Servers: servers,
Domains: domains, Domains: domains,
Timeout: config.Duration(2 * time.Second),
} }
require.NoError(t, dnsConfig.Init())
var acc testutil.Accumulator var acc testutil.Accumulator
tags := map[string]string{ require.NoError(t, acc.GatherError(dnsConfig.Gather))
m, ok := acc.Get("dns_query")
require.True(t, ok)
queryTime, ok := m.Fields["query_time_ms"].(float64)
require.True(t, ok)
expected := []telegraf.Metric{
metric.New(
"dns_query",
map[string]string{
"server": "8.8.8.8", "server": "8.8.8.8",
"domain": "google.com", "domain": "google.com",
"record_type": "NS", "record_type": "NS",
"rcode": "NOERROR", "rcode": "NOERROR",
"result": "success", "result": "success",
} },
fields := map[string]interface{}{ map[string]interface{}{
"rcode_value": 0, "rcode_value": 0,
"result_code": uint64(0), "result_code": uint64(0),
"query_time_ms": queryTime,
},
time.Unix(0, 0),
),
} }
testutil.RequireMetricsEqual(t, expected, acc.GetTelegrafMetrics(), testutil.IgnoreTime())
err := acc.GatherError(dnsConfig.Gather)
require.NoError(t, err)
metric, ok := acc.Get("dns_query")
require.True(t, ok)
queryTime, _ := metric.Fields["query_time_ms"].(float64)
fields["query_time_ms"] = queryTime
acc.AssertContainsTaggedFields(t, "dns_query", fields, tags)
} }
func TestGatheringTimeout(t *testing.T) { func TestGatheringTimeout(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("Skipping network-dependent test in short mode.") t.Skip("Skipping network-dependent test in short mode.")
} }
var dnsConfig = DNSQuery{
dnsConfig := DNSQuery{
Servers: servers, Servers: servers,
Domains: domains, Domains: domains,
Timeout: config.Duration(1 * time.Second),
Port: 60054,
} }
var acc testutil.Accumulator require.NoError(t, dnsConfig.Init())
dnsConfig.Port = 60054
dnsConfig.Timeout = 1
var acc testutil.Accumulator
channel := make(chan error, 1) channel := make(chan error, 1)
go func() { go func() {
channel <- acc.GatherError(dnsConfig.Gather) channel <- acc.GatherError(dnsConfig.Gather)
@ -140,76 +168,95 @@ func TestGatheringTimeout(t *testing.T) {
} }
func TestSettingDefaultValues(t *testing.T) { func TestSettingDefaultValues(t *testing.T) {
dnsConfig := DNSQuery{} dnsConfig := DNSQuery{
Timeout: config.Duration(2 * time.Second),
dnsConfig.setDefaultValues() }
require.NoError(t, dnsConfig.Init())
require.Equal(t, []string{"."}, dnsConfig.Domains, "Default domain not equal \".\"") require.Equal(t, []string{"."}, dnsConfig.Domains, "Default domain not equal \".\"")
require.Equal(t, "NS", dnsConfig.RecordType, "Default record type not equal 'NS'") require.Equal(t, "NS", dnsConfig.RecordType, "Default record type not equal 'NS'")
require.Equal(t, 53, dnsConfig.Port, "Default port number not equal 53") require.Equal(t, 53, dnsConfig.Port, "Default port number not equal 53")
require.Equal(t, 2, dnsConfig.Timeout, "Default timeout not equal 2") require.Equal(t, config.Duration(2*time.Second), dnsConfig.Timeout, "Default timeout not equal 2s")
dnsConfig = DNSQuery{Domains: []string{"."}}
dnsConfig.setDefaultValues()
dnsConfig = DNSQuery{
Domains: []string{"."},
Timeout: config.Duration(2 * time.Second),
}
require.NoError(t, dnsConfig.Init())
require.Equal(t, "NS", dnsConfig.RecordType, "Default record type not equal 'NS'") require.Equal(t, "NS", dnsConfig.RecordType, "Default record type not equal 'NS'")
} }
func TestRecordTypeParser(t *testing.T) { func TestRecordTypeParser(t *testing.T) {
var dnsConfig = DNSQuery{} tests := []struct {
var recordType uint16 record string
expected uint16
}{
{
record: "A",
expected: dns.TypeA,
},
{
record: "AAAA",
expected: dns.TypeAAAA,
},
{
record: "ANY",
expected: dns.TypeANY,
},
{
record: "CNAME",
expected: dns.TypeCNAME,
},
{
record: "MX",
expected: dns.TypeMX,
},
{
record: "NS",
expected: dns.TypeNS,
},
{
record: "PTR",
expected: dns.TypePTR,
},
{
record: "SOA",
expected: dns.TypeSOA,
},
{
record: "SPF",
expected: dns.TypeSPF,
},
{
record: "SRV",
expected: dns.TypeSRV,
},
{
record: "TXT",
expected: dns.TypeTXT,
},
}
dnsConfig.RecordType = "A" for _, tt := range tests {
recordType, _ = dnsConfig.parseRecordType() t.Run(tt.record, func(t *testing.T) {
require.Equal(t, dns.TypeA, recordType) plugin := DNSQuery{
Timeout: config.Duration(2 * time.Second),
dnsConfig.RecordType = "AAAA" Domains: []string{"example.com"},
recordType, _ = dnsConfig.parseRecordType() RecordType: tt.record,
require.Equal(t, dns.TypeAAAA, recordType) }
require.NoError(t, plugin.Init())
dnsConfig.RecordType = "ANY" recordType, err := plugin.parseRecordType()
recordType, _ = dnsConfig.parseRecordType() require.NoError(t, err)
require.Equal(t, dns.TypeANY, recordType) require.Equal(t, tt.expected, recordType)
})
dnsConfig.RecordType = "CNAME" }
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeCNAME, recordType)
dnsConfig.RecordType = "MX"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeMX, recordType)
dnsConfig.RecordType = "NS"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeNS, recordType)
dnsConfig.RecordType = "PTR"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypePTR, recordType)
dnsConfig.RecordType = "SOA"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeSOA, recordType)
dnsConfig.RecordType = "SPF"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeSPF, recordType)
dnsConfig.RecordType = "SRV"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeSRV, recordType)
dnsConfig.RecordType = "TXT"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeTXT, recordType)
} }
func TestRecordTypeParserError(t *testing.T) { func TestRecordTypeParserError(t *testing.T) {
var dnsConfig = DNSQuery{} plugin := DNSQuery{
var err error Timeout: config.Duration(2 * time.Second),
RecordType: "nil",
}
dnsConfig.RecordType = "nil" _, err := plugin.parseRecordType()
_, err = dnsConfig.parseRecordType()
require.Error(t, err) require.Error(t, err)
} }

View File

@ -16,5 +16,11 @@
## Dns server port. ## Dns server port.
# port = 53 # port = 53
## Query timeout in seconds. ## Query timeout
# timeout = 2 # timeout = "2s"
## Include the specified additional properties in the resulting metric.
## The following values are supported:
## "first_ip" -- return IP of the first A and AAAA answer
## "all_ips" -- return IPs of all A and AAAA answers
# include_fields = []