fix(inputs.mysql): avoid side-effects for TLS between plugin instances (#12576)

This commit is contained in:
Sven Rebhan 2023-01-31 19:29:45 +01:00 committed by GitHub
parent f82f2fdb16
commit e6655d534e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 100 additions and 110 deletions

View File

@ -11,8 +11,10 @@ import (
"time"
"github.com/go-sql-driver/mysql"
"github.com/gofrs/uuid"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/plugins/common/tls"
"github.com/influxdata/telegraf/plugins/inputs"
v1 "github.com/influxdata/telegraf/plugins/inputs/mysql/v1"
@ -23,38 +25,36 @@ import (
var sampleConfig string
type Mysql struct {
Servers []string `toml:"servers"`
PerfEventsStatementsDigestTextLimit int64 `toml:"perf_events_statements_digest_text_limit"`
PerfEventsStatementsLimit int64 `toml:"perf_events_statements_limit"`
PerfEventsStatementsTimeLimit int64 `toml:"perf_events_statements_time_limit"`
TableSchemaDatabases []string `toml:"table_schema_databases"`
GatherProcessList bool `toml:"gather_process_list"`
GatherUserStatistics bool `toml:"gather_user_statistics"`
GatherInfoSchemaAutoInc bool `toml:"gather_info_schema_auto_inc"`
GatherInnoDBMetrics bool `toml:"gather_innodb_metrics"`
GatherSlaveStatus bool `toml:"gather_slave_status"`
GatherAllSlaveChannels bool `toml:"gather_all_slave_channels"`
MariadbDialect bool `toml:"mariadb_dialect"`
GatherBinaryLogs bool `toml:"gather_binary_logs"`
GatherTableIOWaits bool `toml:"gather_table_io_waits"`
GatherTableLockWaits bool `toml:"gather_table_lock_waits"`
GatherIndexIOWaits bool `toml:"gather_index_io_waits"`
GatherEventWaits bool `toml:"gather_event_waits"`
GatherTableSchema bool `toml:"gather_table_schema"`
GatherFileEventsStats bool `toml:"gather_file_events_stats"`
GatherPerfEventsStatements bool `toml:"gather_perf_events_statements"`
GatherGlobalVars bool `toml:"gather_global_variables"`
GatherPerfSummaryPerAccountPerEvent bool `toml:"gather_perf_sum_per_acc_per_event"`
PerfSummaryEvents []string `toml:"perf_summary_events"`
IntervalSlow string `toml:"interval_slow"`
MetricVersion int `toml:"metric_version"`
Servers []string `toml:"servers"`
PerfEventsStatementsDigestTextLimit int64 `toml:"perf_events_statements_digest_text_limit"`
PerfEventsStatementsLimit int64 `toml:"perf_events_statements_limit"`
PerfEventsStatementsTimeLimit int64 `toml:"perf_events_statements_time_limit"`
TableSchemaDatabases []string `toml:"table_schema_databases"`
GatherProcessList bool `toml:"gather_process_list"`
GatherUserStatistics bool `toml:"gather_user_statistics"`
GatherInfoSchemaAutoInc bool `toml:"gather_info_schema_auto_inc"`
GatherInnoDBMetrics bool `toml:"gather_innodb_metrics"`
GatherSlaveStatus bool `toml:"gather_slave_status"`
GatherAllSlaveChannels bool `toml:"gather_all_slave_channels"`
MariadbDialect bool `toml:"mariadb_dialect"`
GatherBinaryLogs bool `toml:"gather_binary_logs"`
GatherTableIOWaits bool `toml:"gather_table_io_waits"`
GatherTableLockWaits bool `toml:"gather_table_lock_waits"`
GatherIndexIOWaits bool `toml:"gather_index_io_waits"`
GatherEventWaits bool `toml:"gather_event_waits"`
GatherTableSchema bool `toml:"gather_table_schema"`
GatherFileEventsStats bool `toml:"gather_file_events_stats"`
GatherPerfEventsStatements bool `toml:"gather_perf_events_statements"`
GatherGlobalVars bool `toml:"gather_global_variables"`
GatherPerfSummaryPerAccountPerEvent bool `toml:"gather_perf_sum_per_acc_per_event"`
PerfSummaryEvents []string `toml:"perf_summary_events"`
IntervalSlow config.Duration `toml:"interval_slow"`
MetricVersion int `toml:"metric_version"`
Log telegraf.Logger `toml:"-"`
tls.ClientConfig
lastT time.Time
initDone bool
scanIntervalSlow uint32
getStatusQuery string
lastT time.Time
getStatusQuery string
}
const (
@ -70,42 +70,61 @@ func (*Mysql) SampleConfig() string {
return sampleConfig
}
func (m *Mysql) InitMysql() {
if len(m.IntervalSlow) > 0 {
interval, err := time.ParseDuration(m.IntervalSlow)
if err == nil && interval.Seconds() >= 1.0 {
m.scanIntervalSlow = uint32(interval.Seconds())
}
}
func (m *Mysql) Init() error {
if m.MariadbDialect {
m.getStatusQuery = slaveStatusQueryMariadb
} else {
m.getStatusQuery = slaveStatusQuery
}
m.initDone = true
}
func (m *Mysql) Gather(acc telegraf.Accumulator) error {
// Default to localhost if nothing specified.
if len(m.Servers) == 0 {
// default to localhost if nothing specified.
return m.gatherServer(localhost, acc)
}
// Initialise additional query intervals
if !m.initDone {
m.InitMysql()
m.Servers = append(m.Servers, localhost)
}
// Register the TLS configuration. Due to the registry being a global
// one for the mysql package, we need to define unique IDs to avoid
// side effects and races between different plugin instances. Therefore,
// we decorate the "custom" naming of the "tls" parameter with a UUID.
tlsuuid, err := uuid.NewV7()
if err != nil {
return fmt.Errorf("cannot create UUID: %w", err)
}
tlsid := "custom-" + tlsuuid.String()
tlsConfig, err := m.ClientConfig.TLSConfig()
if err != nil {
return fmt.Errorf("registering TLS config: %s", err)
}
if tlsConfig != nil {
if err := mysql.RegisterTLSConfig("custom", tlsConfig); err != nil {
if err := mysql.RegisterTLSConfig(tlsid, tlsConfig); err != nil {
return err
}
}
// Adapt the DSN string
for i, dsn := range m.Servers {
conf, err := mysql.ParseDSN(dsn)
if err != nil {
return fmt.Errorf("parsing %q failed: %w", dsn, err)
}
// Set the default timeout if none specified
if conf.Timeout == 0 {
conf.Timeout = time.Second * 5
}
// Reference the custom TLS config of _THIS_ plugin instance
if conf.TLSConfig == "custom" {
conf.TLSConfig = tlsid
}
m.Servers[i] = conf.FormatDSN()
}
return nil
}
func (m *Mysql) Gather(acc telegraf.Accumulator) error {
var wg sync.WaitGroup
// Loop through each server and collect metrics
@ -385,16 +404,10 @@ const (
)
func (m *Mysql) gatherServer(serv string, acc telegraf.Accumulator) error {
serv, err := dsnAddTimeout(serv)
if err != nil {
return err
}
db, err := sql.Open("mysql", serv)
if err != nil {
return err
}
defer db.Close()
err = m.gatherGlobalStatuses(db, serv, acc)
@ -404,14 +417,12 @@ func (m *Mysql) gatherServer(serv string, acc telegraf.Accumulator) error {
if m.GatherGlobalVars {
// Global Variables may be gathered less often
if len(m.IntervalSlow) > 0 {
if uint32(time.Since(m.lastT).Seconds()) >= m.scanIntervalSlow {
err = m.gatherGlobalVariables(db, serv, acc)
if err != nil {
return err
}
m.lastT = time.Now()
interval := time.Duration(m.IntervalSlow)
if interval >= time.Second && time.Since(m.lastT) >= interval {
if err := m.gatherGlobalVariables(db, serv, acc); err != nil {
return err
}
m.lastT = time.Now()
}
}
@ -1930,19 +1941,6 @@ func copyTags(in map[string]string) map[string]string {
return out
}
// dsnAddTimeout parses dsn and returns it re-formatted by the driver,
// filling in a five-second timeout when the DSN does not carry one.
// A DSN that cannot be parsed yields an empty string and the parse error.
func dsnAddTimeout(dsn string) (string, error) {
	// Timeout applied when the DSN leaves it at the zero value.
	const defaultTimeout = 5 * time.Second

	parsed, err := mysql.ParseDSN(dsn)
	if err != nil {
		return "", err
	}

	if parsed.Timeout == 0 {
		parsed.Timeout = defaultTimeout
	}

	return parsed.FormatDSN(), nil
}
func getDSNTag(dsn string) string {
conf, err := mysql.ParseDSN(dsn)
if err != nil {

View File

@ -3,12 +3,14 @@ package mysql
import (
"fmt"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/docker/go-connections/nat"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go/wait"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/testutil"
)
@ -31,17 +33,16 @@ func TestMysqlDefaultsToLocalIntegration(t *testing.T) {
),
}
err := container.Start()
require.NoError(t, err, "failed to start container")
require.NoError(t, container.Start(), "failed to start container")
defer container.Terminate()
m := &Mysql{
Servers: []string{fmt.Sprintf("root@tcp(%s:%s)/", container.Address, container.Ports[servicePort])},
}
require.NoError(t, m.Init())
var acc testutil.Accumulator
err = m.Gather(&acc)
require.NoError(t, err)
require.NoError(t, m.Gather(&acc))
require.Empty(t, acc.Errors)
require.True(t, acc.HasMeasurement("mysql"))
@ -66,21 +67,20 @@ func TestMysqlMultipleInstancesIntegration(t *testing.T) {
),
}
err := container.Start()
require.NoError(t, err, "failed to start container")
require.NoError(t, container.Start(), "failed to start container")
defer container.Terminate()
testServer := fmt.Sprintf("root@tcp(%s:%s)/?tls=false", container.Address, container.Ports[servicePort])
m := &Mysql{
Servers: []string{testServer},
IntervalSlow: "30s",
IntervalSlow: config.Duration(30 * time.Second),
GatherGlobalVars: true,
MetricVersion: 2,
}
require.NoError(t, m.Init())
var acc, acc2 testutil.Accumulator
err = m.Gather(&acc)
require.NoError(t, err)
var acc testutil.Accumulator
require.NoError(t, m.Gather(&acc))
require.Empty(t, acc.Errors)
require.True(t, acc.HasMeasurement("mysql"))
// acc should have global variables
@ -90,33 +90,16 @@ func TestMysqlMultipleInstancesIntegration(t *testing.T) {
Servers: []string{testServer},
MetricVersion: 2,
}
err = m2.Gather(&acc2)
require.NoError(t, err)
require.NoError(t, m2.Init())
var acc2 testutil.Accumulator
require.NoError(t, m2.Gather(&acc2))
require.Empty(t, acc.Errors)
require.True(t, acc2.HasMeasurement("mysql"))
// acc2 should not have global variables
require.False(t, acc2.HasMeasurement("mysql_variables"))
}
// TestMysqlMultipleInits verifies that InitMysql only affects the plugin
// instance it is called on: initialising one instance must not change the
// init flag or the slow-scan interval of another.
func TestMysqlMultipleInits(t *testing.T) {
	m := &Mysql{
		IntervalSlow: "30s",
	}
	m2 := &Mysql{}

	// Initialise only the first instance; the second must stay untouched.
	m.InitMysql()
	require.True(t, m.initDone)
	require.False(t, m2.initDone)
	// testify expects (expected, actual); keeping that order makes the
	// failure output label the values correctly.
	require.Equal(t, uint32(30), m.scanIntervalSlow)
	require.Equal(t, uint32(0), m2.scanIntervalSlow)

	// Initialising the second instance must not alter the first one.
	m2.InitMysql()
	require.True(t, m.initDone)
	require.True(t, m2.initDone)
	require.Equal(t, uint32(30), m.scanIntervalSlow)
	require.Equal(t, uint32(0), m2.scanIntervalSlow)
}
func TestMysqlGetDSNTag(t *testing.T) {
tests := []struct {
input string
@ -178,44 +161,53 @@ func TestMysqlGetDSNTag(t *testing.T) {
func TestMysqlDNSAddTimeout(t *testing.T) {
tests := []struct {
name string
input string
output string
}{
{
"empty",
"",
"tcp(127.0.0.1:3306)/?timeout=5s",
},
{
"no timeout",
"tcp(192.168.1.1:3306)/",
"tcp(192.168.1.1:3306)/?timeout=5s",
},
{
"no timeout with credentials",
"root:passwd@tcp(192.168.1.1:3306)/?tls=false",
"root:passwd@tcp(192.168.1.1:3306)/?timeout=5s&tls=false",
},
{
"with timeout and credentials",
"root:passwd@tcp(192.168.1.1:3306)/?tls=false&timeout=10s",
"root:passwd@tcp(192.168.1.1:3306)/?timeout=10s&tls=false",
},
{
"no timeout different IP",
"tcp(10.150.1.123:3306)/",
"tcp(10.150.1.123:3306)/?timeout=5s",
},
{
"no timeout with bracket credentials",
"root:@!~(*&$#%(&@#(@&#Password@tcp(10.150.1.123:3306)/",
"root:@!~(*&$#%(&@#(@&#Password@tcp(10.150.1.123:3306)/?timeout=5s",
},
{
"no timeout with strange credentials",
"root:Test3a#@!@tcp(10.150.1.123:3306)/",
"root:Test3a#@!@tcp(10.150.1.123:3306)/?timeout=5s",
},
}
for _, test := range tests {
output, _ := dsnAddTimeout(test.input)
if output != test.output {
t.Errorf("Expected %s, got %s\n", test.output, output)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := &Mysql{Servers: []string{tt.input}}
require.NoError(t, m.Init())
require.Equal(t, tt.output, m.Servers[0])
})
}
}
@ -228,7 +220,7 @@ func TestGatherGlobalVariables(t *testing.T) {
Log: testutil.Logger{},
MetricVersion: 2,
}
m.InitMysql()
require.NoError(t, m.Init())
columns := []string{"Variable_name", "Value"}
measurement := "mysql_variables"