feat(inputs.dns_query): Add IP field(s) (#12519)

This commit is contained in:
Sven Rebhan 2023-01-20 16:40:43 +01:00 committed by GitHub
parent bfb26a8af6
commit 410226051d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 294 additions and 207 deletions

View File

@ -33,8 +33,14 @@ See the [CONFIGURATION.md][CONFIGURATION.md] for more details.
## Dns server port.
# port = 53
## Query timeout in seconds.
# timeout = 2
## Query timeout
# timeout = "2s"
## Include the specified additional properties in the resulting metric.
## The following values are supported:
## "first_ip" -- return IP of the first A and AAAA answer
## "all_ips" -- return IPs of all A and AAAA answers
# include_fields = []
```
## Metrics

View File

@ -12,6 +12,7 @@ import (
"github.com/miekg/dns"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/plugins/inputs"
)
@ -27,76 +28,39 @@ const (
)
type DNSQuery struct {
// Domains or subdomains to query
Domains []string
Domains []string `toml:"domains"`
Network string `toml:"network"`
Servers []string `toml:"servers"`
RecordType string `toml:"record_type"`
Port int `toml:"port"`
Timeout config.Duration `toml:"timeout"`
IncludeFields []string `toml:"include_fields"`
// Network protocol name
Network string
// Server to query
Servers []string
// Record type
RecordType string `toml:"record_type"`
// DNS server port number
Port int
// Dns query timeout in seconds. 0 means no timeout
Timeout int
fieldEnabled map[string]bool
}
func (*DNSQuery) SampleConfig() string {
return sampleConfig
}
func (d *DNSQuery) Gather(acc telegraf.Accumulator) error {
var wg sync.WaitGroup
d.setDefaultValues()
for _, domain := range d.Domains {
for _, server := range d.Servers {
wg.Add(1)
go func(domain, server string) {
fields := make(map[string]interface{}, 2)
tags := map[string]string{
"server": server,
"domain": domain,
"record_type": d.RecordType,
}
dnsQueryTime, rcode, err := d.getDNSQueryTime(domain, server)
if rcode >= 0 {
tags["rcode"] = dns.RcodeToString[rcode]
fields["rcode_value"] = rcode
}
if err == nil {
setResult(Success, fields, tags)
fields["query_time_ms"] = dnsQueryTime
} else if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
setResult(Timeout, fields, tags)
} else if err != nil {
setResult(Error, fields, tags)
acc.AddError(err)
}
acc.AddFields("dns_query", fields, tags)
wg.Done()
}(domain, server)
func (d *DNSQuery) Init() error {
// Convert the included fields into a lookup-table
d.fieldEnabled = make(map[string]bool, len(d.IncludeFields))
for _, f := range d.IncludeFields {
switch f {
case "first_ip", "all_ips":
default:
return fmt.Errorf("invalid field %q included", f)
}
d.fieldEnabled[f] = true
}
wg.Wait()
return nil
}
func (d *DNSQuery) setDefaultValues() {
// Set defaults
if d.Network == "" {
d.Network = "udp"
}
if len(d.RecordType) == 0 {
if d.RecordType == "" {
d.RecordType = "NS"
}
@ -105,39 +69,106 @@ func (d *DNSQuery) setDefaultValues() {
d.RecordType = "NS"
}
if d.Port == 0 {
if d.Port < 1 {
d.Port = 53
}
if d.Timeout == 0 {
d.Timeout = 2
}
return nil
}
func (d *DNSQuery) getDNSQueryTime(domain string, server string) (float64, int, error) {
dnsQueryTime := float64(0)
func (d *DNSQuery) Gather(acc telegraf.Accumulator) error {
var wg sync.WaitGroup
c := new(dns.Client)
c.ReadTimeout = time.Duration(d.Timeout) * time.Second
c.Net = d.Network
for _, domain := range d.Domains {
for _, server := range d.Servers {
wg.Add(1)
go func(domain, server string) {
defer wg.Done()
fields, tags, err := d.query(domain, server)
if err != nil {
if opErr, ok := err.(*net.OpError); !ok || !opErr.Timeout() {
acc.AddError(err)
}
}
acc.AddFields("dns_query", fields, tags)
}(domain, server)
}
}
wg.Wait()
return nil
}
func (d *DNSQuery) query(domain string, server string) (map[string]interface{}, map[string]string, error) {
tags := map[string]string{
"server": server,
"domain": domain,
"record_type": d.RecordType,
"result": "error",
}
fields := map[string]interface{}{
"query_time_ms": float64(0),
"result_code": uint64(Error),
}
c := dns.Client{
ReadTimeout: time.Duration(d.Timeout),
Net: d.Network,
}
m := new(dns.Msg)
recordType, err := d.parseRecordType()
if err != nil {
return dnsQueryTime, -1, err
return fields, tags, err
}
m.SetQuestion(dns.Fqdn(domain), recordType)
m.RecursionDesired = true
r, rtt, err := c.Exchange(m, net.JoinHostPort(server, strconv.Itoa(d.Port)))
var msg dns.Msg
msg.SetQuestion(dns.Fqdn(domain), recordType)
msg.RecursionDesired = true
addr := net.JoinHostPort(server, strconv.Itoa(d.Port))
r, rtt, err := c.Exchange(&msg, addr)
if err != nil {
return dnsQueryTime, -1, err
if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
tags["result"] = "timeout"
fields["result_code"] = uint64(Timeout)
return fields, tags, err
}
return fields, tags, err
}
// Fill valid fields
tags["rcode"] = dns.RcodeToString[r.Rcode]
fields["rcode_value"] = r.Rcode
fields["query_time_ms"] = float64(rtt.Nanoseconds()) / 1e6
// Handle the failure case
if r.Rcode != dns.RcodeSuccess {
return dnsQueryTime, r.Rcode, fmt.Errorf("Invalid answer (%s) from %s after %s query for %s", dns.RcodeToString[r.Rcode], server, d.RecordType, domain)
return fields, tags, fmt.Errorf("invalid answer (%s) from %s after %s query for %s", dns.RcodeToString[r.Rcode], server, d.RecordType, domain)
}
dnsQueryTime = float64(rtt.Nanoseconds()) / 1e6
return dnsQueryTime, r.Rcode, nil
// Success
tags["result"] = "success"
fields["result_code"] = uint64(Success)
if d.fieldEnabled["first_ip"] {
for _, record := range r.Answer {
if ip, found := extractIP(record); found {
fields["ip"] = ip
break
}
}
}
if d.fieldEnabled["all_ips"] {
for i, record := range r.Answer {
if ip, found := extractIP(record); found {
fields["ip_"+strconv.Itoa(i)] = ip
}
}
}
return fields, tags, nil
}
func (d *DNSQuery) parseRecordType() (uint16, error) {
@ -168,29 +199,26 @@ func (d *DNSQuery) parseRecordType() (uint16, error) {
case "TXT":
recordType = dns.TypeTXT
default:
err = fmt.Errorf("Record type %s not recognized", d.RecordType)
err = fmt.Errorf("record type %s not recognized", d.RecordType)
}
return recordType, err
}
func setResult(result ResultType, fields map[string]interface{}, tags map[string]string) {
var tag string
switch result {
case Success:
tag = "success"
case Timeout:
tag = "timeout"
case Error:
tag = "error"
func extractIP(record dns.RR) (string, bool) {
if r, ok := record.(*dns.A); ok {
return r.A.String(), true
}
tags["result"] = tag
fields["result_code"] = uint64(result)
if r, ok := record.(*dns.AAAA); ok {
return r.AAAA.String(), true
}
return "", false
}
func init() {
inputs.Add("dns_query", func() telegraf.Input {
return &DNSQuery{}
return &DNSQuery{
Timeout: config.Duration(2 * time.Second),
}
})
}

View File

@ -7,6 +7,9 @@ import (
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/metric"
"github.com/influxdata/telegraf/testutil"
)
@ -17,116 +20,141 @@ func TestGathering(t *testing.T) {
if testing.Short() {
t.Skip("Skipping network-dependent test in short mode.")
}
var dnsConfig = DNSQuery{
dnsConfig := DNSQuery{
Servers: servers,
Domains: domains,
Timeout: config.Duration(2 * time.Second),
}
var acc testutil.Accumulator
err := acc.GatherError(dnsConfig.Gather)
require.NoError(t, err)
var acc testutil.Accumulator
require.NoError(t, dnsConfig.Init())
require.NoError(t, acc.GatherError(dnsConfig.Gather))
metric, ok := acc.Get("dns_query")
require.True(t, ok)
queryTime, _ := metric.Fields["query_time_ms"].(float64)
require.NotEqual(t, 0, queryTime)
queryTime, ok := metric.Fields["query_time_ms"].(float64)
require.True(t, ok)
require.NotEqual(t, float64(0), queryTime)
}
func TestGatheringMxRecord(t *testing.T) {
if testing.Short() {
t.Skip("Skipping network-dependent test in short mode.")
}
var dnsConfig = DNSQuery{
Servers: servers,
Domains: domains,
dnsConfig := DNSQuery{
Servers: servers,
Domains: domains,
RecordType: "MX",
Timeout: config.Duration(2 * time.Second),
}
var acc testutil.Accumulator
dnsConfig.RecordType = "MX"
err := acc.GatherError(dnsConfig.Gather)
require.NoError(t, err)
require.NoError(t, dnsConfig.Init())
require.NoError(t, acc.GatherError(dnsConfig.Gather))
metric, ok := acc.Get("dns_query")
require.True(t, ok)
queryTime, _ := metric.Fields["query_time_ms"].(float64)
require.NotEqual(t, 0, queryTime)
queryTime, ok := metric.Fields["query_time_ms"].(float64)
require.True(t, ok)
require.NotEqual(t, float64(0), queryTime)
}
func TestGatheringRootDomain(t *testing.T) {
if testing.Short() {
t.Skip("Skipping network-dependent test in short mode.")
}
var dnsConfig = DNSQuery{
dnsConfig := DNSQuery{
Servers: servers,
Domains: []string{"."},
RecordType: "MX",
Timeout: config.Duration(2 * time.Second),
}
require.NoError(t, dnsConfig.Init())
var acc testutil.Accumulator
tags := map[string]string{
"server": "8.8.8.8",
"domain": ".",
"record_type": "MX",
"rcode": "NOERROR",
"result": "success",
}
fields := map[string]interface{}{
"rcode_value": 0,
"result_code": uint64(0),
}
require.NoError(t, acc.GatherError(dnsConfig.Gather))
err := acc.GatherError(dnsConfig.Gather)
require.NoError(t, err)
metric, ok := acc.Get("dns_query")
m, ok := acc.Get("dns_query")
require.True(t, ok)
queryTime, ok := m.Fields["query_time_ms"].(float64)
require.True(t, ok)
queryTime, _ := metric.Fields["query_time_ms"].(float64)
fields["query_time_ms"] = queryTime
acc.AssertContainsTaggedFields(t, "dns_query", fields, tags)
expected := []telegraf.Metric{
metric.New(
"dns_query",
map[string]string{
"server": "8.8.8.8",
"domain": ".",
"record_type": "MX",
"rcode": "NOERROR",
"result": "success",
},
map[string]interface{}{
"rcode_value": 0,
"result_code": uint64(0),
"query_time_ms": queryTime,
},
time.Unix(0, 0),
),
}
testutil.RequireMetricsEqual(t, expected, acc.GetTelegrafMetrics(), testutil.IgnoreTime())
}
func TestMetricContainsServerAndDomainAndRecordTypeTags(t *testing.T) {
if testing.Short() {
t.Skip("Skipping network-dependent test in short mode.")
}
var dnsConfig = DNSQuery{
dnsConfig := DNSQuery{
Servers: servers,
Domains: domains,
Timeout: config.Duration(2 * time.Second),
}
require.NoError(t, dnsConfig.Init())
var acc testutil.Accumulator
tags := map[string]string{
"server": "8.8.8.8",
"domain": "google.com",
"record_type": "NS",
"rcode": "NOERROR",
"result": "success",
}
fields := map[string]interface{}{
"rcode_value": 0,
"result_code": uint64(0),
}
require.NoError(t, acc.GatherError(dnsConfig.Gather))
err := acc.GatherError(dnsConfig.Gather)
require.NoError(t, err)
metric, ok := acc.Get("dns_query")
m, ok := acc.Get("dns_query")
require.True(t, ok)
queryTime, _ := metric.Fields["query_time_ms"].(float64)
fields["query_time_ms"] = queryTime
acc.AssertContainsTaggedFields(t, "dns_query", fields, tags)
queryTime, ok := m.Fields["query_time_ms"].(float64)
require.True(t, ok)
expected := []telegraf.Metric{
metric.New(
"dns_query",
map[string]string{
"server": "8.8.8.8",
"domain": "google.com",
"record_type": "NS",
"rcode": "NOERROR",
"result": "success",
},
map[string]interface{}{
"rcode_value": 0,
"result_code": uint64(0),
"query_time_ms": queryTime,
},
time.Unix(0, 0),
),
}
testutil.RequireMetricsEqual(t, expected, acc.GetTelegrafMetrics(), testutil.IgnoreTime())
}
func TestGatheringTimeout(t *testing.T) {
if testing.Short() {
t.Skip("Skipping network-dependent test in short mode.")
}
var dnsConfig = DNSQuery{
dnsConfig := DNSQuery{
Servers: servers,
Domains: domains,
Timeout: config.Duration(1 * time.Second),
Port: 60054,
}
var acc testutil.Accumulator
dnsConfig.Port = 60054
dnsConfig.Timeout = 1
require.NoError(t, dnsConfig.Init())
var acc testutil.Accumulator
channel := make(chan error, 1)
go func() {
channel <- acc.GatherError(dnsConfig.Gather)
@ -140,76 +168,95 @@ func TestGatheringTimeout(t *testing.T) {
}
func TestSettingDefaultValues(t *testing.T) {
dnsConfig := DNSQuery{}
dnsConfig.setDefaultValues()
dnsConfig := DNSQuery{
Timeout: config.Duration(2 * time.Second),
}
require.NoError(t, dnsConfig.Init())
require.Equal(t, []string{"."}, dnsConfig.Domains, "Default domain not equal \".\"")
require.Equal(t, "NS", dnsConfig.RecordType, "Default record type not equal 'NS'")
require.Equal(t, 53, dnsConfig.Port, "Default port number not equal 53")
require.Equal(t, 2, dnsConfig.Timeout, "Default timeout not equal 2")
dnsConfig = DNSQuery{Domains: []string{"."}}
dnsConfig.setDefaultValues()
require.Equal(t, config.Duration(2*time.Second), dnsConfig.Timeout, "Default timeout not equal 2s")
dnsConfig = DNSQuery{
Domains: []string{"."},
Timeout: config.Duration(2 * time.Second),
}
require.NoError(t, dnsConfig.Init())
require.Equal(t, "NS", dnsConfig.RecordType, "Default record type not equal 'NS'")
}
func TestRecordTypeParser(t *testing.T) {
var dnsConfig = DNSQuery{}
var recordType uint16
tests := []struct {
record string
expected uint16
}{
{
record: "A",
expected: dns.TypeA,
},
{
record: "AAAA",
expected: dns.TypeAAAA,
},
{
record: "ANY",
expected: dns.TypeANY,
},
{
record: "CNAME",
expected: dns.TypeCNAME,
},
{
record: "MX",
expected: dns.TypeMX,
},
{
record: "NS",
expected: dns.TypeNS,
},
{
record: "PTR",
expected: dns.TypePTR,
},
{
record: "SOA",
expected: dns.TypeSOA,
},
{
record: "SPF",
expected: dns.TypeSPF,
},
{
record: "SRV",
expected: dns.TypeSRV,
},
{
record: "TXT",
expected: dns.TypeTXT,
},
}
dnsConfig.RecordType = "A"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeA, recordType)
dnsConfig.RecordType = "AAAA"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeAAAA, recordType)
dnsConfig.RecordType = "ANY"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeANY, recordType)
dnsConfig.RecordType = "CNAME"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeCNAME, recordType)
dnsConfig.RecordType = "MX"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeMX, recordType)
dnsConfig.RecordType = "NS"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeNS, recordType)
dnsConfig.RecordType = "PTR"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypePTR, recordType)
dnsConfig.RecordType = "SOA"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeSOA, recordType)
dnsConfig.RecordType = "SPF"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeSPF, recordType)
dnsConfig.RecordType = "SRV"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeSRV, recordType)
dnsConfig.RecordType = "TXT"
recordType, _ = dnsConfig.parseRecordType()
require.Equal(t, dns.TypeTXT, recordType)
for _, tt := range tests {
t.Run(tt.record, func(t *testing.T) {
plugin := DNSQuery{
Timeout: config.Duration(2 * time.Second),
Domains: []string{"example.com"},
RecordType: tt.record,
}
require.NoError(t, plugin.Init())
recordType, err := plugin.parseRecordType()
require.NoError(t, err)
require.Equal(t, tt.expected, recordType)
})
}
}
func TestRecordTypeParserError(t *testing.T) {
var dnsConfig = DNSQuery{}
var err error
plugin := DNSQuery{
Timeout: config.Duration(2 * time.Second),
RecordType: "nil",
}
dnsConfig.RecordType = "nil"
_, err = dnsConfig.parseRecordType()
_, err := plugin.parseRecordType()
require.Error(t, err)
}

View File

@ -16,5 +16,11 @@
## Dns server port.
# port = 53
## Query timeout in seconds.
# timeout = 2
## Query timeout
# timeout = "2s"
## Include the specified additional properties in the resulting metric.
## The following values are supported:
## "first_ip" -- return IP of the first A and AAAA answer
## "all_ips" -- return IPs of all A and AAAA answers
# include_fields = []