fix(inputs.mysql): avoid side-effects for TLS between plugin instances (#12576)

This commit is contained in:
Sven Rebhan 2023-01-31 19:29:45 +01:00 committed by GitHub
parent f82f2fdb16
commit e6655d534e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 100 additions and 110 deletions

View File

@ -11,8 +11,10 @@ import (
"time" "time"
"github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
"github.com/gofrs/uuid"
"github.com/influxdata/telegraf" "github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/plugins/common/tls" "github.com/influxdata/telegraf/plugins/common/tls"
"github.com/influxdata/telegraf/plugins/inputs" "github.com/influxdata/telegraf/plugins/inputs"
v1 "github.com/influxdata/telegraf/plugins/inputs/mysql/v1" v1 "github.com/influxdata/telegraf/plugins/inputs/mysql/v1"
@ -23,38 +25,36 @@ import (
var sampleConfig string var sampleConfig string
type Mysql struct { type Mysql struct {
Servers []string `toml:"servers"` Servers []string `toml:"servers"`
PerfEventsStatementsDigestTextLimit int64 `toml:"perf_events_statements_digest_text_limit"` PerfEventsStatementsDigestTextLimit int64 `toml:"perf_events_statements_digest_text_limit"`
PerfEventsStatementsLimit int64 `toml:"perf_events_statements_limit"` PerfEventsStatementsLimit int64 `toml:"perf_events_statements_limit"`
PerfEventsStatementsTimeLimit int64 `toml:"perf_events_statements_time_limit"` PerfEventsStatementsTimeLimit int64 `toml:"perf_events_statements_time_limit"`
TableSchemaDatabases []string `toml:"table_schema_databases"` TableSchemaDatabases []string `toml:"table_schema_databases"`
GatherProcessList bool `toml:"gather_process_list"` GatherProcessList bool `toml:"gather_process_list"`
GatherUserStatistics bool `toml:"gather_user_statistics"` GatherUserStatistics bool `toml:"gather_user_statistics"`
GatherInfoSchemaAutoInc bool `toml:"gather_info_schema_auto_inc"` GatherInfoSchemaAutoInc bool `toml:"gather_info_schema_auto_inc"`
GatherInnoDBMetrics bool `toml:"gather_innodb_metrics"` GatherInnoDBMetrics bool `toml:"gather_innodb_metrics"`
GatherSlaveStatus bool `toml:"gather_slave_status"` GatherSlaveStatus bool `toml:"gather_slave_status"`
GatherAllSlaveChannels bool `toml:"gather_all_slave_channels"` GatherAllSlaveChannels bool `toml:"gather_all_slave_channels"`
MariadbDialect bool `toml:"mariadb_dialect"` MariadbDialect bool `toml:"mariadb_dialect"`
GatherBinaryLogs bool `toml:"gather_binary_logs"` GatherBinaryLogs bool `toml:"gather_binary_logs"`
GatherTableIOWaits bool `toml:"gather_table_io_waits"` GatherTableIOWaits bool `toml:"gather_table_io_waits"`
GatherTableLockWaits bool `toml:"gather_table_lock_waits"` GatherTableLockWaits bool `toml:"gather_table_lock_waits"`
GatherIndexIOWaits bool `toml:"gather_index_io_waits"` GatherIndexIOWaits bool `toml:"gather_index_io_waits"`
GatherEventWaits bool `toml:"gather_event_waits"` GatherEventWaits bool `toml:"gather_event_waits"`
GatherTableSchema bool `toml:"gather_table_schema"` GatherTableSchema bool `toml:"gather_table_schema"`
GatherFileEventsStats bool `toml:"gather_file_events_stats"` GatherFileEventsStats bool `toml:"gather_file_events_stats"`
GatherPerfEventsStatements bool `toml:"gather_perf_events_statements"` GatherPerfEventsStatements bool `toml:"gather_perf_events_statements"`
GatherGlobalVars bool `toml:"gather_global_variables"` GatherGlobalVars bool `toml:"gather_global_variables"`
GatherPerfSummaryPerAccountPerEvent bool `toml:"gather_perf_sum_per_acc_per_event"` GatherPerfSummaryPerAccountPerEvent bool `toml:"gather_perf_sum_per_acc_per_event"`
PerfSummaryEvents []string `toml:"perf_summary_events"` PerfSummaryEvents []string `toml:"perf_summary_events"`
IntervalSlow string `toml:"interval_slow"` IntervalSlow config.Duration `toml:"interval_slow"`
MetricVersion int `toml:"metric_version"` MetricVersion int `toml:"metric_version"`
Log telegraf.Logger `toml:"-"` Log telegraf.Logger `toml:"-"`
tls.ClientConfig tls.ClientConfig
lastT time.Time lastT time.Time
initDone bool getStatusQuery string
scanIntervalSlow uint32
getStatusQuery string
} }
const ( const (
@ -70,42 +70,61 @@ func (*Mysql) SampleConfig() string {
return sampleConfig return sampleConfig
} }
func (m *Mysql) InitMysql() { func (m *Mysql) Init() error {
if len(m.IntervalSlow) > 0 {
interval, err := time.ParseDuration(m.IntervalSlow)
if err == nil && interval.Seconds() >= 1.0 {
m.scanIntervalSlow = uint32(interval.Seconds())
}
}
if m.MariadbDialect { if m.MariadbDialect {
m.getStatusQuery = slaveStatusQueryMariadb m.getStatusQuery = slaveStatusQueryMariadb
} else { } else {
m.getStatusQuery = slaveStatusQuery m.getStatusQuery = slaveStatusQuery
} }
m.initDone = true
}
func (m *Mysql) Gather(acc telegraf.Accumulator) error { // Default to localhost if nothing specified.
if len(m.Servers) == 0 { if len(m.Servers) == 0 {
// default to localhost if nothing specified. m.Servers = append(m.Servers, localhost)
return m.gatherServer(localhost, acc)
}
// Initialise additional query intervals
if !m.initDone {
m.InitMysql()
} }
// Register the TLS configuration. Due to the registry being a global
// one for the mysql package, we need to define unique IDs to avoid
// side effects and races between different plugin instances. Therefore,
// we decorate the "custom" naming of the "tls" parameter with an UUID.
tlsuuid, err := uuid.NewV7()
if err != nil {
return fmt.Errorf("cannot create UUID: %w", err)
}
tlsid := "custom-" + tlsuuid.String()
tlsConfig, err := m.ClientConfig.TLSConfig() tlsConfig, err := m.ClientConfig.TLSConfig()
if err != nil { if err != nil {
return fmt.Errorf("registering TLS config: %s", err) return fmt.Errorf("registering TLS config: %s", err)
} }
if tlsConfig != nil { if tlsConfig != nil {
if err := mysql.RegisterTLSConfig("custom", tlsConfig); err != nil { if err := mysql.RegisterTLSConfig(tlsid, tlsConfig); err != nil {
return err return err
} }
} }
// Adapt the DSN string
for i, dsn := range m.Servers {
conf, err := mysql.ParseDSN(dsn)
if err != nil {
return fmt.Errorf("parsing %q failed: %w", dsn, err)
}
// Set the default timeout if none specified
if conf.Timeout == 0 {
conf.Timeout = time.Second * 5
}
// Reference the custom TLS config of _THIS_ plugin instance
if conf.TLSConfig == "custom" {
conf.TLSConfig = tlsid
}
m.Servers[i] = conf.FormatDSN()
}
return nil
}
func (m *Mysql) Gather(acc telegraf.Accumulator) error {
var wg sync.WaitGroup var wg sync.WaitGroup
// Loop through each server and collect metrics // Loop through each server and collect metrics
@ -385,16 +404,10 @@ const (
) )
func (m *Mysql) gatherServer(serv string, acc telegraf.Accumulator) error { func (m *Mysql) gatherServer(serv string, acc telegraf.Accumulator) error {
serv, err := dsnAddTimeout(serv)
if err != nil {
return err
}
db, err := sql.Open("mysql", serv) db, err := sql.Open("mysql", serv)
if err != nil { if err != nil {
return err return err
} }
defer db.Close() defer db.Close()
err = m.gatherGlobalStatuses(db, serv, acc) err = m.gatherGlobalStatuses(db, serv, acc)
@ -404,14 +417,12 @@ func (m *Mysql) gatherServer(serv string, acc telegraf.Accumulator) error {
if m.GatherGlobalVars { if m.GatherGlobalVars {
// Global Variables may be gathered less often // Global Variables may be gathered less often
if len(m.IntervalSlow) > 0 { interval := time.Duration(m.IntervalSlow)
if uint32(time.Since(m.lastT).Seconds()) >= m.scanIntervalSlow { if interval >= time.Second && time.Since(m.lastT) >= interval {
err = m.gatherGlobalVariables(db, serv, acc) if err := m.gatherGlobalVariables(db, serv, acc); err != nil {
if err != nil { return err
return err
}
m.lastT = time.Now()
} }
m.lastT = time.Now()
} }
} }
@ -1930,19 +1941,6 @@ func copyTags(in map[string]string) map[string]string {
return out return out
} }
func dsnAddTimeout(dsn string) (string, error) {
conf, err := mysql.ParseDSN(dsn)
if err != nil {
return "", err
}
if conf.Timeout == 0 {
conf.Timeout = time.Second * 5
}
return conf.FormatDSN(), nil
}
func getDSNTag(dsn string) string { func getDSNTag(dsn string) string {
conf, err := mysql.ParseDSN(dsn) conf, err := mysql.ParseDSN(dsn)
if err != nil { if err != nil {

View File

@ -3,12 +3,14 @@ package mysql
import ( import (
"fmt" "fmt"
"testing" "testing"
"time"
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
"github.com/docker/go-connections/nat" "github.com/docker/go-connections/nat"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go/wait" "github.com/testcontainers/testcontainers-go/wait"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/testutil" "github.com/influxdata/telegraf/testutil"
) )
@ -31,17 +33,16 @@ func TestMysqlDefaultsToLocalIntegration(t *testing.T) {
), ),
} }
err := container.Start() require.NoError(t, container.Start(), "failed to start container")
require.NoError(t, err, "failed to start container")
defer container.Terminate() defer container.Terminate()
m := &Mysql{ m := &Mysql{
Servers: []string{fmt.Sprintf("root@tcp(%s:%s)/", container.Address, container.Ports[servicePort])}, Servers: []string{fmt.Sprintf("root@tcp(%s:%s)/", container.Address, container.Ports[servicePort])},
} }
require.NoError(t, m.Init())
var acc testutil.Accumulator var acc testutil.Accumulator
err = m.Gather(&acc) require.NoError(t, m.Gather(&acc))
require.NoError(t, err)
require.Empty(t, acc.Errors) require.Empty(t, acc.Errors)
require.True(t, acc.HasMeasurement("mysql")) require.True(t, acc.HasMeasurement("mysql"))
@ -66,21 +67,20 @@ func TestMysqlMultipleInstancesIntegration(t *testing.T) {
), ),
} }
err := container.Start() require.NoError(t, container.Start(), "failed to start container")
require.NoError(t, err, "failed to start container")
defer container.Terminate() defer container.Terminate()
testServer := fmt.Sprintf("root@tcp(%s:%s)/?tls=false", container.Address, container.Ports[servicePort]) testServer := fmt.Sprintf("root@tcp(%s:%s)/?tls=false", container.Address, container.Ports[servicePort])
m := &Mysql{ m := &Mysql{
Servers: []string{testServer}, Servers: []string{testServer},
IntervalSlow: "30s", IntervalSlow: config.Duration(30 * time.Second),
GatherGlobalVars: true, GatherGlobalVars: true,
MetricVersion: 2, MetricVersion: 2,
} }
require.NoError(t, m.Init())
var acc, acc2 testutil.Accumulator var acc testutil.Accumulator
err = m.Gather(&acc) require.NoError(t, m.Gather(&acc))
require.NoError(t, err)
require.Empty(t, acc.Errors) require.Empty(t, acc.Errors)
require.True(t, acc.HasMeasurement("mysql")) require.True(t, acc.HasMeasurement("mysql"))
// acc should have global variables // acc should have global variables
@ -90,33 +90,16 @@ func TestMysqlMultipleInstancesIntegration(t *testing.T) {
Servers: []string{testServer}, Servers: []string{testServer},
MetricVersion: 2, MetricVersion: 2,
} }
err = m2.Gather(&acc2) require.NoError(t, m2.Init())
require.NoError(t, err)
var acc2 testutil.Accumulator
require.NoError(t, m2.Gather(&acc2))
require.Empty(t, acc.Errors) require.Empty(t, acc.Errors)
require.True(t, acc2.HasMeasurement("mysql")) require.True(t, acc2.HasMeasurement("mysql"))
// acc2 should not have global variables // acc2 should not have global variables
require.False(t, acc2.HasMeasurement("mysql_variables")) require.False(t, acc2.HasMeasurement("mysql_variables"))
} }
func TestMysqlMultipleInits(t *testing.T) {
m := &Mysql{
IntervalSlow: "30s",
}
m2 := &Mysql{}
m.InitMysql()
require.True(t, m.initDone)
require.False(t, m2.initDone)
require.Equal(t, m.scanIntervalSlow, uint32(30))
require.Equal(t, m2.scanIntervalSlow, uint32(0))
m2.InitMysql()
require.True(t, m.initDone)
require.True(t, m2.initDone)
require.Equal(t, m.scanIntervalSlow, uint32(30))
require.Equal(t, m2.scanIntervalSlow, uint32(0))
}
func TestMysqlGetDSNTag(t *testing.T) { func TestMysqlGetDSNTag(t *testing.T) {
tests := []struct { tests := []struct {
input string input string
@ -178,44 +161,53 @@ func TestMysqlGetDSNTag(t *testing.T) {
func TestMysqlDNSAddTimeout(t *testing.T) { func TestMysqlDNSAddTimeout(t *testing.T) {
tests := []struct { tests := []struct {
name string
input string input string
output string output string
}{ }{
{ {
"empty",
"", "",
"tcp(127.0.0.1:3306)/?timeout=5s", "tcp(127.0.0.1:3306)/?timeout=5s",
}, },
{ {
"no timeout",
"tcp(192.168.1.1:3306)/", "tcp(192.168.1.1:3306)/",
"tcp(192.168.1.1:3306)/?timeout=5s", "tcp(192.168.1.1:3306)/?timeout=5s",
}, },
{ {
"no timeout with credentials",
"root:passwd@tcp(192.168.1.1:3306)/?tls=false", "root:passwd@tcp(192.168.1.1:3306)/?tls=false",
"root:passwd@tcp(192.168.1.1:3306)/?timeout=5s&tls=false", "root:passwd@tcp(192.168.1.1:3306)/?timeout=5s&tls=false",
}, },
{ {
"with timeout and credentials",
"root:passwd@tcp(192.168.1.1:3306)/?tls=false&timeout=10s", "root:passwd@tcp(192.168.1.1:3306)/?tls=false&timeout=10s",
"root:passwd@tcp(192.168.1.1:3306)/?timeout=10s&tls=false", "root:passwd@tcp(192.168.1.1:3306)/?timeout=10s&tls=false",
}, },
{ {
"no timeout different IP",
"tcp(10.150.1.123:3306)/", "tcp(10.150.1.123:3306)/",
"tcp(10.150.1.123:3306)/?timeout=5s", "tcp(10.150.1.123:3306)/?timeout=5s",
}, },
{ {
"no timeout with bracket credentials",
"root:@!~(*&$#%(&@#(@&#Password@tcp(10.150.1.123:3306)/", "root:@!~(*&$#%(&@#(@&#Password@tcp(10.150.1.123:3306)/",
"root:@!~(*&$#%(&@#(@&#Password@tcp(10.150.1.123:3306)/?timeout=5s", "root:@!~(*&$#%(&@#(@&#Password@tcp(10.150.1.123:3306)/?timeout=5s",
}, },
{ {
"no timeout with strange credentials",
"root:Test3a#@!@tcp(10.150.1.123:3306)/", "root:Test3a#@!@tcp(10.150.1.123:3306)/",
"root:Test3a#@!@tcp(10.150.1.123:3306)/?timeout=5s", "root:Test3a#@!@tcp(10.150.1.123:3306)/?timeout=5s",
}, },
} }
for _, test := range tests { for _, tt := range tests {
output, _ := dsnAddTimeout(test.input) t.Run(tt.name, func(t *testing.T) {
if output != test.output { m := &Mysql{Servers: []string{tt.input}}
t.Errorf("Expected %s, got %s\n", test.output, output) require.NoError(t, m.Init())
} require.Equal(t, tt.output, m.Servers[0])
})
} }
} }
@ -228,7 +220,7 @@ func TestGatherGlobalVariables(t *testing.T) {
Log: testutil.Logger{}, Log: testutil.Logger{},
MetricVersion: 2, MetricVersion: 2,
} }
m.InitMysql() require.NoError(t, m.Init())
columns := []string{"Variable_name", "Value"} columns := []string{"Variable_name", "Value"}
measurement := "mysql_variables" measurement := "mysql_variables"