diff --git a/plugins/inputs/dns_query/README.md b/plugins/inputs/dns_query/README.md index c6a2d3aad..1c535088c 100644 --- a/plugins/inputs/dns_query/README.md +++ b/plugins/inputs/dns_query/README.md @@ -33,8 +33,14 @@ See the [CONFIGURATION.md][CONFIGURATION.md] for more details. ## Dns server port. # port = 53 - ## Query timeout in seconds. - # timeout = 2 + ## Query timeout + # timeout = "2s" + + ## Include the specified additional properties in the resulting metric. + ## The following values are supported: + ## "first_ip" -- return IP of the first A and AAAA answer + ## "all_ips" -- return IPs of all A and AAAA answers + # include_fields = [] ``` ## Metrics diff --git a/plugins/inputs/dns_query/dns_query.go b/plugins/inputs/dns_query/dns_query.go index f9bc0d05b..b987e5a99 100644 --- a/plugins/inputs/dns_query/dns_query.go +++ b/plugins/inputs/dns_query/dns_query.go @@ -12,6 +12,7 @@ import ( "github.com/miekg/dns" "github.com/influxdata/telegraf" + "github.com/influxdata/telegraf/config" "github.com/influxdata/telegraf/plugins/inputs" ) @@ -27,76 +28,39 @@ const ( ) type DNSQuery struct { - // Domains or subdomains to query - Domains []string + Domains []string `toml:"domains"` + Network string `toml:"network"` + Servers []string `toml:"servers"` + RecordType string `toml:"record_type"` + Port int `toml:"port"` + Timeout config.Duration `toml:"timeout"` + IncludeFields []string `toml:"include_fields"` - // Network protocol name - Network string - - // Server to query - Servers []string - - // Record type - RecordType string `toml:"record_type"` - - // DNS server port number - Port int - - // Dns query timeout in seconds. 0 means no timeout - Timeout int + fieldEnabled map[string]bool } func (*DNSQuery) SampleConfig() string { return sampleConfig } -func (d *DNSQuery) Gather(acc telegraf.Accumulator) error { - var wg sync.WaitGroup - d.setDefaultValues() - - for _, domain := range d.Domains { - for _, server := range d.Servers { - wg.Add(1) - go func(domain, server string) { - fields := make(map[string]interface{}, 2) - tags := map[string]string{ - "server": server, - "domain": domain, - "record_type": d.RecordType, - } - - dnsQueryTime, rcode, err := d.getDNSQueryTime(domain, server) - if rcode >= 0 { - tags["rcode"] = dns.RcodeToString[rcode] - fields["rcode_value"] = rcode - } - if err == nil { - setResult(Success, fields, tags) - fields["query_time_ms"] = dnsQueryTime - } else if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() { - setResult(Timeout, fields, tags) - } else if err != nil { - setResult(Error, fields, tags) - acc.AddError(err) - } - - acc.AddFields("dns_query", fields, tags) - - wg.Done() - }(domain, server) +func (d *DNSQuery) Init() error { + // Convert the included fields into a lookup-table + d.fieldEnabled = make(map[string]bool, len(d.IncludeFields)) + for _, f := range d.IncludeFields { + switch f { + case "first_ip", "all_ips": + default: + return fmt.Errorf("invalid field %q included", f) } + d.fieldEnabled[f] = true } - wg.Wait() - return nil -} - -func (d *DNSQuery) setDefaultValues() { + // Set defaults if d.Network == "" { d.Network = "udp" } - if len(d.RecordType) == 0 { + if d.RecordType == "" { d.RecordType = "NS" } @@ -105,39 +69,106 @@ func (d *DNSQuery) setDefaultValues() { d.RecordType = "NS" } - if d.Port == 0 { + if d.Port < 1 { d.Port = 53 } - if d.Timeout == 0 { - d.Timeout = 2 - } + return nil } -func (d *DNSQuery) getDNSQueryTime(domain string, server string) (float64, int, error) { - dnsQueryTime := float64(0) +func (d *DNSQuery) Gather(acc telegraf.Accumulator) error { + var wg sync.WaitGroup - c := new(dns.Client) - c.ReadTimeout = time.Duration(d.Timeout) * time.Second - c.Net = d.Network + for _, domain := range d.Domains { + for _, server := range d.Servers { + wg.Add(1) + go func(domain, server string) { + defer wg.Done() + + fields, tags, err := d.query(domain, server) + if err != nil { + if opErr, ok := err.(*net.OpError); !ok || !opErr.Timeout() { + acc.AddError(err) + } + } + acc.AddFields("dns_query", fields, tags) + }(domain, server) + } + } + wg.Wait() + + return nil +} + +func (d *DNSQuery) query(domain string, server string) (map[string]interface{}, map[string]string, error) { + tags := map[string]string{ + "server": server, + "domain": domain, + "record_type": d.RecordType, + "result": "error", + } + + fields := map[string]interface{}{ + "query_time_ms": float64(0), + "result_code": uint64(Error), + } + + c := dns.Client{ + ReadTimeout: time.Duration(d.Timeout), + Net: d.Network, + } - m := new(dns.Msg) recordType, err := d.parseRecordType() if err != nil { - return dnsQueryTime, -1, err + return fields, tags, err } - m.SetQuestion(dns.Fqdn(domain), recordType) - m.RecursionDesired = true - r, rtt, err := c.Exchange(m, net.JoinHostPort(server, strconv.Itoa(d.Port))) + var msg dns.Msg + msg.SetQuestion(dns.Fqdn(domain), recordType) + msg.RecursionDesired = true + + addr := net.JoinHostPort(server, strconv.Itoa(d.Port)) + r, rtt, err := c.Exchange(&msg, addr) if err != nil { - return dnsQueryTime, -1, err + if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() { + tags["result"] = "timeout" + fields["result_code"] = uint64(Timeout) + return fields, tags, err + } + return fields, tags, err } + + // Fill valid fields + tags["rcode"] = dns.RcodeToString[r.Rcode] + fields["rcode_value"] = r.Rcode + fields["query_time_ms"] = float64(rtt.Nanoseconds()) / 1e6 + + // Handle the failure case if r.Rcode != dns.RcodeSuccess { - return dnsQueryTime, r.Rcode, fmt.Errorf("Invalid answer (%s) from %s after %s query for %s", dns.RcodeToString[r.Rcode], server, d.RecordType, domain) + return fields, tags, fmt.Errorf("invalid answer (%s) from %s after %s query for %s", dns.RcodeToString[r.Rcode], server, d.RecordType, domain) } - dnsQueryTime = float64(rtt.Nanoseconds()) / 1e6 - return dnsQueryTime, r.Rcode, nil + + // Success + tags["result"] = "success" + fields["result_code"] = uint64(Success) + + if d.fieldEnabled["first_ip"] { + for _, record := range r.Answer { + if ip, found := extractIP(record); found { + fields["ip"] = ip + break + } + } + } + if d.fieldEnabled["all_ips"] { + for i, record := range r.Answer { + if ip, found := extractIP(record); found { + fields["ip_"+strconv.Itoa(i)] = ip + } + } + } + + return fields, tags, nil } func (d *DNSQuery) parseRecordType() (uint16, error) { @@ -168,29 +199,26 @@ func (d *DNSQuery) parseRecordType() (uint16, error) { case "TXT": recordType = dns.TypeTXT default: - err = fmt.Errorf("Record type %s not recognized", d.RecordType) + err = fmt.Errorf("record type %s not recognized", d.RecordType) } return recordType, err } -func setResult(result ResultType, fields map[string]interface{}, tags map[string]string) { - var tag string - switch result { - case Success: - tag = "success" - case Timeout: - tag = "timeout" - case Error: - tag = "error" +func extractIP(record dns.RR) (string, bool) { + if r, ok := record.(*dns.A); ok { + return r.A.String(), true } - - tags["result"] = tag - fields["result_code"] = uint64(result) + if r, ok := record.(*dns.AAAA); ok { + return r.AAAA.String(), true + } + return "", false } func init() { inputs.Add("dns_query", func() telegraf.Input { - return &DNSQuery{} + return &DNSQuery{ + Timeout: config.Duration(2 * time.Second), + } }) } diff --git a/plugins/inputs/dns_query/dns_query_test.go b/plugins/inputs/dns_query/dns_query_test.go index 2e57e2f7b..ed49c68fc 100644 --- a/plugins/inputs/dns_query/dns_query_test.go +++ b/plugins/inputs/dns_query/dns_query_test.go @@ -7,6 +7,9 @@ import ( "github.com/miekg/dns" "github.com/stretchr/testify/require" + "github.com/influxdata/telegraf" + "github.com/influxdata/telegraf/config" + "github.com/influxdata/telegraf/metric" "github.com/influxdata/telegraf/testutil" ) @@ -17,116 +20,141 @@ func TestGathering(t *testing.T) { if testing.Short() { t.Skip("Skipping network-dependent test in short mode.") } - var dnsConfig = DNSQuery{ + + dnsConfig := DNSQuery{ Servers: servers, Domains: domains, + Timeout: config.Duration(2 * time.Second), } - var acc testutil.Accumulator - err := acc.GatherError(dnsConfig.Gather) - require.NoError(t, err) + var acc testutil.Accumulator + require.NoError(t, dnsConfig.Init()) + require.NoError(t, acc.GatherError(dnsConfig.Gather)) metric, ok := acc.Get("dns_query") require.True(t, ok) - queryTime, _ := metric.Fields["query_time_ms"].(float64) - - require.NotEqual(t, 0, queryTime) + queryTime, ok := metric.Fields["query_time_ms"].(float64) + require.True(t, ok) + require.NotEqual(t, float64(0), queryTime) } func TestGatheringMxRecord(t *testing.T) { if testing.Short() { t.Skip("Skipping network-dependent test in short mode.") } - var dnsConfig = DNSQuery{ - Servers: servers, - Domains: domains, + + dnsConfig := DNSQuery{ + Servers: servers, + Domains: domains, + RecordType: "MX", + Timeout: config.Duration(2 * time.Second), } var acc testutil.Accumulator - dnsConfig.RecordType = "MX" - err := acc.GatherError(dnsConfig.Gather) - require.NoError(t, err) + require.NoError(t, dnsConfig.Init()) + require.NoError(t, acc.GatherError(dnsConfig.Gather)) metric, ok := acc.Get("dns_query") require.True(t, ok) - queryTime, _ := metric.Fields["query_time_ms"].(float64) - - require.NotEqual(t, 0, queryTime) + queryTime, ok := metric.Fields["query_time_ms"].(float64) + require.True(t, ok) + require.NotEqual(t, float64(0), queryTime) } func TestGatheringRootDomain(t *testing.T) { if testing.Short() { t.Skip("Skipping network-dependent test in short mode.") } - var dnsConfig = DNSQuery{ + + dnsConfig := DNSQuery{ Servers: servers, Domains: []string{"."}, RecordType: "MX", + Timeout: config.Duration(2 * time.Second), } + require.NoError(t, dnsConfig.Init()) + var acc testutil.Accumulator - tags := map[string]string{ - "server": "8.8.8.8", - "domain": ".", - "record_type": "MX", - "rcode": "NOERROR", - "result": "success", - } - fields := map[string]interface{}{ - "rcode_value": 0, - "result_code": uint64(0), - } + require.NoError(t, acc.GatherError(dnsConfig.Gather)) - err := acc.GatherError(dnsConfig.Gather) - require.NoError(t, err) - metric, ok := acc.Get("dns_query") + m, ok := acc.Get("dns_query") + require.True(t, ok) + queryTime, ok := m.Fields["query_time_ms"].(float64) require.True(t, ok) - queryTime, _ := metric.Fields["query_time_ms"].(float64) - fields["query_time_ms"] = queryTime - acc.AssertContainsTaggedFields(t, "dns_query", fields, tags) + expected := []telegraf.Metric{ + metric.New( + "dns_query", + map[string]string{ + "server": "8.8.8.8", + "domain": ".", + "record_type": "MX", + "rcode": "NOERROR", + "result": "success", + }, + map[string]interface{}{ + "rcode_value": 0, + "result_code": uint64(0), + "query_time_ms": queryTime, + }, + time.Unix(0, 0), + ), + } + testutil.RequireMetricsEqual(t, expected, acc.GetTelegrafMetrics(), testutil.IgnoreTime()) } func TestMetricContainsServerAndDomainAndRecordTypeTags(t *testing.T) { if testing.Short() { t.Skip("Skipping network-dependent test in short mode.") } - var dnsConfig = DNSQuery{ + + dnsConfig := DNSQuery{ Servers: servers, Domains: domains, + Timeout: config.Duration(2 * time.Second), } + require.NoError(t, dnsConfig.Init()) + var acc testutil.Accumulator - tags := map[string]string{ - "server": "8.8.8.8", - "domain": "google.com", - "record_type": "NS", - "rcode": "NOERROR", - "result": "success", - } - fields := map[string]interface{}{ - "rcode_value": 0, - "result_code": uint64(0), - } + require.NoError(t, acc.GatherError(dnsConfig.Gather)) - err := acc.GatherError(dnsConfig.Gather) - require.NoError(t, err) - metric, ok := acc.Get("dns_query") + m, ok := acc.Get("dns_query") require.True(t, ok) - queryTime, _ := metric.Fields["query_time_ms"].(float64) - - fields["query_time_ms"] = queryTime - acc.AssertContainsTaggedFields(t, "dns_query", fields, tags) + queryTime, ok := m.Fields["query_time_ms"].(float64) + require.True(t, ok) + expected := []telegraf.Metric{ + metric.New( + "dns_query", + map[string]string{ + "server": "8.8.8.8", + "domain": "google.com", + "record_type": "NS", + "rcode": "NOERROR", + "result": "success", + }, + map[string]interface{}{ + "rcode_value": 0, + "result_code": uint64(0), + "query_time_ms": queryTime, + }, + time.Unix(0, 0), + ), + } + testutil.RequireMetricsEqual(t, expected, acc.GetTelegrafMetrics(), testutil.IgnoreTime()) } func TestGatheringTimeout(t *testing.T) { if testing.Short() { t.Skip("Skipping network-dependent test in short mode.") } - var dnsConfig = DNSQuery{ + + dnsConfig := DNSQuery{ Servers: servers, Domains: domains, + Timeout: config.Duration(1 * time.Second), + Port: 60054, } - var acc testutil.Accumulator - dnsConfig.Port = 60054 - dnsConfig.Timeout = 1 + require.NoError(t, dnsConfig.Init()) + var acc testutil.Accumulator channel := make(chan error, 1) go func() { channel <- acc.GatherError(dnsConfig.Gather) @@ -140,76 +168,95 @@ func TestGatheringTimeout(t *testing.T) { } func TestSettingDefaultValues(t *testing.T) { - dnsConfig := DNSQuery{} - - dnsConfig.setDefaultValues() - + dnsConfig := DNSQuery{ + Timeout: config.Duration(2 * time.Second), + } + require.NoError(t, dnsConfig.Init()) require.Equal(t, []string{"."}, dnsConfig.Domains, "Default domain not equal \".\"") require.Equal(t, "NS", dnsConfig.RecordType, "Default record type not equal 'NS'") require.Equal(t, 53, dnsConfig.Port, "Default port number not equal 53") - require.Equal(t, 2, dnsConfig.Timeout, "Default timeout not equal 2") - - dnsConfig = DNSQuery{Domains: []string{"."}} - - dnsConfig.setDefaultValues() + require.Equal(t, config.Duration(2*time.Second), dnsConfig.Timeout, "Default timeout not equal 2s") + dnsConfig = DNSQuery{ + Domains: []string{"."}, + Timeout: config.Duration(2 * time.Second), + } + require.NoError(t, dnsConfig.Init()) require.Equal(t, "NS", dnsConfig.RecordType, "Default record type not equal 'NS'") } func TestRecordTypeParser(t *testing.T) { - var dnsConfig = DNSQuery{} - var recordType uint16 + tests := []struct { + record string + expected uint16 + }{ + { + record: "A", + expected: dns.TypeA, + }, + { + record: "AAAA", + expected: dns.TypeAAAA, + }, + { + record: "ANY", + expected: dns.TypeANY, + }, + { + record: "CNAME", + expected: dns.TypeCNAME, + }, + { + record: "MX", + expected: dns.TypeMX, + }, + { + record: "NS", + expected: dns.TypeNS, + }, + { + record: "PTR", + expected: dns.TypePTR, + }, + { + record: "SOA", + expected: dns.TypeSOA, + }, + { + record: "SPF", + expected: dns.TypeSPF, + }, + { + record: "SRV", + expected: dns.TypeSRV, + }, + { + record: "TXT", + expected: dns.TypeTXT, + }, + } - dnsConfig.RecordType = "A" - recordType, _ = dnsConfig.parseRecordType() - require.Equal(t, dns.TypeA, recordType) - - dnsConfig.RecordType = "AAAA" - recordType, _ = dnsConfig.parseRecordType() - require.Equal(t, dns.TypeAAAA, recordType) - - dnsConfig.RecordType = "ANY" - recordType, _ = dnsConfig.parseRecordType() - require.Equal(t, dns.TypeANY, recordType) - - dnsConfig.RecordType = "CNAME" - recordType, _ = dnsConfig.parseRecordType() - require.Equal(t, dns.TypeCNAME, recordType) - - dnsConfig.RecordType = "MX" - recordType, _ = dnsConfig.parseRecordType() - require.Equal(t, dns.TypeMX, recordType) - - dnsConfig.RecordType = "NS" - recordType, _ = dnsConfig.parseRecordType() - require.Equal(t, dns.TypeNS, recordType) - - dnsConfig.RecordType = "PTR" - recordType, _ = dnsConfig.parseRecordType() - require.Equal(t, dns.TypePTR, recordType) - - dnsConfig.RecordType = "SOA" - recordType, _ = dnsConfig.parseRecordType() - require.Equal(t, dns.TypeSOA, recordType) - - dnsConfig.RecordType = "SPF" - recordType, _ = dnsConfig.parseRecordType() - require.Equal(t, dns.TypeSPF, recordType) - - dnsConfig.RecordType = "SRV" - recordType, _ = dnsConfig.parseRecordType() - require.Equal(t, dns.TypeSRV, recordType) - - dnsConfig.RecordType = "TXT" - recordType, _ = dnsConfig.parseRecordType() - require.Equal(t, dns.TypeTXT, recordType) + for _, tt := range tests { + t.Run(tt.record, func(t *testing.T) { + plugin := DNSQuery{ + Timeout: config.Duration(2 * time.Second), + Domains: []string{"example.com"}, + RecordType: tt.record, + } + require.NoError(t, plugin.Init()) + recordType, err := plugin.parseRecordType() + require.NoError(t, err) + require.Equal(t, tt.expected, recordType) + }) + } } func TestRecordTypeParserError(t *testing.T) { - var dnsConfig = DNSQuery{} - var err error + plugin := DNSQuery{ + Timeout: config.Duration(2 * time.Second), + RecordType: "nil", + } - dnsConfig.RecordType = "nil" - _, err = dnsConfig.parseRecordType() + _, err := plugin.parseRecordType() require.Error(t, err) } diff --git a/plugins/inputs/dns_query/sample.conf b/plugins/inputs/dns_query/sample.conf index 60ac2cc02..ea8dfe20d 100644 --- a/plugins/inputs/dns_query/sample.conf +++ b/plugins/inputs/dns_query/sample.conf @@ -16,5 +16,11 @@ ## Dns server port. # port = 53 - ## Query timeout in seconds. - # timeout = 2 + ## Query timeout + # timeout = "2s" + + ## Include the specified additional properties in the resulting metric. + ## The following values are supported: + ## "first_ip" -- return IP of the first A and AAAA answer + ## "all_ips" -- return IPs of all A and AAAA answers + # include_fields = []