diff --git a/plugins/inputs/postgresql/service.go b/plugins/common/postgresql/config.go similarity index 72% rename from plugins/inputs/postgresql/service.go rename to plugins/common/postgresql/config.go index d9447b4d3..d3a0b387f 100644 --- a/plugins/inputs/postgresql/service.go +++ b/plugins/common/postgresql/config.go @@ -1,7 +1,6 @@ package postgresql import ( - "database/sql" "fmt" "net" "net/url" @@ -13,10 +12,105 @@ import ( "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/stdlib" - "github.com/influxdata/telegraf" "github.com/influxdata/telegraf/config" ) +var socketRegexp = regexp.MustCompile(`/\.s\.PGSQL\.\d+$`) +var sanitizer = regexp.MustCompile(`(\s|^)((?:password|sslcert|sslkey|sslmode|sslrootcert)\s?=\s?(?:(?:'(?:[^'\\]|\\.)*')|(?:\S+)))`) + +type Config struct { + Address config.Secret `toml:"address"` + OutputAddress string `toml:"outputaddress"` + MaxIdle int `toml:"max_idle"` + MaxOpen int `toml:"max_open"` + MaxLifetime config.Duration `toml:"max_lifetime"` + IsPgBouncer bool `toml:"-"` +} + +func (c *Config) CreateService() (*Service, error) { + addrSecret, err := c.Address.Get() + if err != nil { + return nil, fmt.Errorf("getting address failed: %w", err) + } + addr := addrSecret.String() + defer addrSecret.Destroy() + + if c.Address.Empty() || addr == "localhost" { + addr = "host=localhost sslmode=disable" + if err := c.Address.Set([]byte(addr)); err != nil { + return nil, err + } + } + + connConfig, err := pgx.ParseConfig(addr) + if err != nil { + return nil, err + } + // Remove the socket name from the path + connConfig.Host = socketRegexp.ReplaceAllLiteralString(connConfig.Host, "") + + // Specific support to make it work with PgBouncer too + // See https://github.com/influxdata/telegraf/issues/3253#issuecomment-357505343 + if c.IsPgBouncer { + // Remove DriveConfig and revert it by the ParseConfig method + // See https://github.com/influxdata/telegraf/issues/9134 + connConfig.PreferSimpleProtocol = true + } + + // Provide the connection string without sensitive information for use as + // tag or other output properties + sanitizedAddr, err := c.sanitizedAddress() + if err != nil { + return nil, err + } + + return &Service{ + SanitizedAddress: sanitizedAddr, + ConnectionDatabase: connectionDatabase(sanitizedAddr), + maxIdle: c.MaxIdle, + maxOpen: c.MaxOpen, + maxLifetime: time.Duration(c.MaxLifetime), + dsn: stdlib.RegisterConnConfig(connConfig), + }, nil +} + +// connectionDatabase determines the database to which the connection was made +func connectionDatabase(sanitizedAddr string) string { + connConfig, err := pgx.ParseConfig(sanitizedAddr) + if err != nil || connConfig.Database == "" { + return "postgres" + } + + return connConfig.Database +} + +// sanitizedAddress strips sensitive information from the connection string. +// If the user set the output address use that before parsing anything else. +func (c *Config) sanitizedAddress() (string, error) { + if c.OutputAddress != "" { + return c.OutputAddress, nil + } + + // Get the address + addrSecret, err := c.Address.Get() + if err != nil { + return "", fmt.Errorf("getting address for sanitization failed: %w", err) + } + defer addrSecret.Destroy() + + // Make sure we convert URI-formatted strings into key-values + addr := addrSecret.TemporaryString() + if strings.HasPrefix(addr, "postgres://") || strings.HasPrefix(addr, "postgresql://") { + if addr, err = toKeyValue(addr); err != nil { + return "", err + } + } + + // Sanitize the string using a regular expression + sanitized := sanitizer.ReplaceAllString(addr, "") + return strings.TrimSpace(sanitized), nil +} + // Based on parseURLSettings() at https://github.com/jackc/pgx/blob/master/pgconn/config.go func toKeyValue(uri string) (string, error) { u, err := url.Parse(uri) @@ -88,105 +182,3 @@ func toKeyValue(uri string) (string, error) { sort.Strings(parts) return strings.Join(parts, " "), nil } - -// Service common functionality shared between the postgresql and postgresql_extensible -// packages. -type Service struct { - Address config.Secret `toml:"address"` - OutputAddress string `toml:"outputaddress"` - MaxIdle int `toml:"max_idle"` - MaxOpen int `toml:"max_open"` - MaxLifetime config.Duration `toml:"max_lifetime"` - IsPgBouncer bool `toml:"-"` - DB *sql.DB -} - -var socketRegexp = regexp.MustCompile(`/\.s\.PGSQL\.\d+$`) - -// Start starts the ServiceInput's service, whatever that may be -func (p *Service) Start(telegraf.Accumulator) (err error) { - addrSecret, err := p.Address.Get() - if err != nil { - return fmt.Errorf("getting address failed: %w", err) - } - addr := addrSecret.String() - defer addrSecret.Destroy() - - if p.Address.Empty() || addr == "localhost" { - addr = "host=localhost sslmode=disable" - if err := p.Address.Set([]byte(addr)); err != nil { - return err - } - } - - connConfig, err := pgx.ParseConfig(addr) - if err != nil { - return err - } - - // Remove the socket name from the path - connConfig.Host = socketRegexp.ReplaceAllLiteralString(connConfig.Host, "") - - // Specific support to make it work with PgBouncer too - // See https://github.com/influxdata/telegraf/issues/3253#issuecomment-357505343 - if p.IsPgBouncer { - // Remove DriveConfig and revert it by the ParseConfig method - // See https://github.com/influxdata/telegraf/issues/9134 - connConfig.PreferSimpleProtocol = true - } - - connectionString := stdlib.RegisterConnConfig(connConfig) - if p.DB, err = sql.Open("pgx", connectionString); err != nil { - return err - } - - p.DB.SetMaxOpenConns(p.MaxOpen) - p.DB.SetMaxIdleConns(p.MaxIdle) - p.DB.SetConnMaxLifetime(time.Duration(p.MaxLifetime)) - - return nil -} - -// Stop stops the services and closes any necessary channels and connections -func (p *Service) Stop() { - p.DB.Close() -} - -var sanitizer = regexp.MustCompile(`(\s|^)((?:password|sslcert|sslkey|sslmode|sslrootcert)\s?=\s?(?:(?:'(?:[^'\\]|\\.)*')|(?:\S+)))`) - -// SanitizedAddress utility function to strip sensitive information from the connection string. -func (p *Service) SanitizedAddress() (string, error) { - if p.OutputAddress != "" { - return p.OutputAddress, nil - } - - // Get the address - addrSecret, err := p.Address.Get() - if err != nil { - return "", fmt.Errorf("getting address for sanitization failed: %w", err) - } - defer addrSecret.Destroy() - - // Make sure we convert URI-formatted strings into key-values - addr := addrSecret.TemporaryString() - if strings.HasPrefix(addr, "postgres://") || strings.HasPrefix(addr, "postgresql://") { - if addr, err = toKeyValue(addr); err != nil { - return "", err - } - } - - // Sanitize the string using a regular expression - sanitized := sanitizer.ReplaceAllString(addr, "") - return strings.TrimSpace(sanitized), nil -} - -// GetConnectDatabase utility function for getting the database to which the connection was made -// If the user set the output address use that before parsing anything else. -func (p *Service) GetConnectDatabase(connectionString string) (string, error) { - connConfig, err := pgx.ParseConfig(connectionString) - if err == nil && len(connConfig.Database) != 0 { - return connConfig.Database, nil - } - - return "postgres", nil -} diff --git a/plugins/common/postgresql/config_test.go b/plugins/common/postgresql/config_test.go new file mode 100644 index 000000000..696b5361d --- /dev/null +++ b/plugins/common/postgresql/config_test.go @@ -0,0 +1,239 @@ +package postgresql + +import ( + "strings" + "testing" + + "github.com/influxdata/telegraf/config" + "github.com/stretchr/testify/require" +) + +func TestURIParsing(t *testing.T) { + tests := []struct { + name string + uri string + expected string + }{ + { + name: "short", + uri: `postgres://localhost`, + expected: "host=localhost", + }, + { + name: "with port", + uri: `postgres://localhost:5432`, + expected: "host=localhost port=5432", + }, + { + name: "with database", + uri: `postgres://localhost/mydb`, + expected: "dbname=mydb host=localhost", + }, + { + name: "with additional parameters", + uri: `postgres://localhost/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5`, + expected: "application_name=pgxtest connect_timeout=5 dbname=mydb host=localhost search_path=myschema", + }, + { + name: "with database setting in params", + uri: `postgres://localhost:5432/?database=mydb`, + expected: "database=mydb host=localhost port=5432", + }, + { + name: "with authentication", + uri: `postgres://jack:secret@localhost:5432/mydb?sslmode=prefer`, + expected: "dbname=mydb host=localhost password=secret port=5432 sslmode=prefer user=jack", + }, + { + name: "with spaces", + uri: `postgres://jack%20hunter:secret@localhost/mydb?application_name=pgx%20test`, + expected: "application_name='pgx test' dbname=mydb host=localhost password=secret user='jack hunter'", + }, + { + name: "with equal signs", + uri: `postgres://jack%20hunter:secret@localhost/mydb?application_name=pgx%3Dtest`, + expected: "application_name='pgx=test' dbname=mydb host=localhost password=secret user='jack hunter'", + }, + { + name: "multiple hosts", + uri: `postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable`, + expected: "dbname=mydb host=foo,bar,baz password=secret port=1,2,3 sslmode=disable user=jack", + }, + { + name: "multiple hosts without ports", + uri: `postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable`, + expected: "dbname=mydb host=foo,bar,baz password=secret sslmode=disable user=jack", + }, + } + + for _, tt := range tests { + // Key value without spaces around equal sign + t.Run(tt.name, func(t *testing.T) { + actual, err := toKeyValue(tt.uri) + require.NoError(t, err) + require.Equalf(t, tt.expected, actual, "initial: %s", tt.uri) + }) + } +} + +func TestSanitizeAddressKeyValue(t *testing.T) { + keys := []string{"password", "sslcert", "sslkey", "sslmode", "sslrootcert"} + tests := []struct { + name string + value string + }{ + { + name: "simple text", + value: `foo`, + }, + { + name: "empty values", + value: `''`, + }, + { + name: "space in value", + value: `'foo bar'`, + }, + { + name: "equal sign in value", + value: `'foo=bar'`, + }, + { + name: "escaped quote", + value: `'foo\'s bar'`, + }, + { + name: "escaped quote no space", + value: `\'foobar\'s\'`, + }, + { + name: "escaped backslash", + value: `'foo bar\\'`, + }, + { + name: "escaped quote and backslash", + value: `'foo\\\'s bar'`, + }, + { + name: "two escaped backslashes", + value: `'foo bar\\\\'`, + }, + { + name: "multiple inline spaces", + value: "'foo \t bar'", + }, + { + name: "leading space", + value: `' foo bar'`, + }, + { + name: "trailing space", + value: `'foo bar '`, + }, + { + name: "multiple equal signs", + value: `'foo===bar'`, + }, + { + name: "leading equal sign", + value: `'=foo bar'`, + }, + { + name: "trailing equal sign", + value: `'foo bar='`, + }, + { + name: "mix of equal signs and spaces", + value: "'foo = a\t===\tbar'", + }, + } + + for _, tt := range tests { + // Key value without spaces around equal sign + t.Run(tt.name, func(t *testing.T) { + // Generate the DSN from the given keys and value + parts := make([]string, 0, len(keys)) + for _, k := range keys { + parts = append(parts, k+"="+tt.value) + } + dsn := strings.Join(parts, " canary=ok ") + + cfg := &Config{ + Address: config.NewSecret([]byte(dsn)), + } + + expected := strings.Join(make([]string, len(keys)), "canary=ok ") + expected = strings.TrimSpace(expected) + actual, err := cfg.sanitizedAddress() + require.NoError(t, err) + require.Equalf(t, expected, actual, "initial: %s", dsn) + }) + + // Key value with spaces around equal sign + t.Run("spaced "+tt.name, func(t *testing.T) { + // Generate the DSN from the given keys and value + parts := make([]string, 0, len(keys)) + for _, k := range keys { + parts = append(parts, k+" = "+tt.value) + } + dsn := strings.Join(parts, " canary=ok ") + + cfg := &Config{ + Address: config.NewSecret([]byte(dsn)), + } + + expected := strings.Join(make([]string, len(keys)), "canary=ok ") + expected = strings.TrimSpace(expected) + actual, err := cfg.sanitizedAddress() + require.NoError(t, err) + require.Equalf(t, expected, actual, "initial: %s", dsn) + }) + } +} + +func TestSanitizeAddressURI(t *testing.T) { + keys := []string{"password", "sslcert", "sslkey", "sslmode", "sslrootcert"} + tests := []struct { + name string + value string + }{ + { + name: "simple text", + value: `foo`, + }, + { + name: "empty values", + value: ``, + }, + { + name: "space in value", + value: `foo bar`, + }, + { + name: "equal sign in value", + value: `foo=bar`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Generate the DSN from the given keys and value + value := strings.ReplaceAll(tt.value, "=", "%3D") + value = strings.ReplaceAll(value, " ", "%20") + parts := make([]string, 0, len(keys)) + for _, k := range keys { + parts = append(parts, k+"="+value) + } + dsn := "postgresql://user:passwd@localhost:5432/db?" + strings.Join(parts, "&") + + cfg := &Config{ + Address: config.NewSecret([]byte(dsn)), + } + + expected := "dbname=db host=localhost port=5432 user=user" + actual, err := cfg.sanitizedAddress() + require.NoError(t, err) + require.Equalf(t, expected, actual, "initial: %s", dsn) + }) + } +} diff --git a/plugins/common/postgresql/service.go b/plugins/common/postgresql/service.go new file mode 100644 index 000000000..abba6da05 --- /dev/null +++ b/plugins/common/postgresql/service.go @@ -0,0 +1,42 @@ +package postgresql + +import ( + "database/sql" + "time" + + // Blank import required to register driver + _ "github.com/jackc/pgx/v4/stdlib" +) + +// Service common functionality shared between the postgresql and postgresql_extensible +// packages. +type Service struct { + DB *sql.DB + SanitizedAddress string + ConnectionDatabase string + + dsn string + maxIdle int + maxOpen int + maxLifetime time.Duration +} + +func (p *Service) Start() error { + db, err := sql.Open("pgx", p.dsn) + if err != nil { + return err + } + p.DB = db + + p.DB.SetMaxOpenConns(p.maxOpen) + p.DB.SetMaxIdleConns(p.maxIdle) + p.DB.SetConnMaxLifetime(p.maxLifetime) + + return nil +} + +func (p *Service) Stop() { + if p.DB != nil { + p.DB.Close() + } +} diff --git a/plugins/inputs/pgbouncer/pgbouncer.go b/plugins/inputs/pgbouncer/pgbouncer.go index 2ece50020..4d079e173 100644 --- a/plugins/inputs/pgbouncer/pgbouncer.go +++ b/plugins/inputs/pgbouncer/pgbouncer.go @@ -3,25 +3,24 @@ package pgbouncer import ( "bytes" + "database/sql" _ "embed" "fmt" "strconv" - // Required for SQL framework driver - _ "github.com/jackc/pgx/v4/stdlib" - "github.com/influxdata/telegraf" - "github.com/influxdata/telegraf/config" + "github.com/influxdata/telegraf/plugins/common/postgresql" "github.com/influxdata/telegraf/plugins/inputs" - "github.com/influxdata/telegraf/plugins/inputs/postgresql" ) //go:embed sample.conf var sampleConfig string type PgBouncer struct { - postgresql.Service ShowCommands []string `toml:"show_commands"` + postgresql.Config + + service *postgresql.Service } var ignoredColumns = map[string]bool{"user": true, "database": true, "pool_mode": true, @@ -33,34 +32,54 @@ func (*PgBouncer) SampleConfig() string { return sampleConfig } -func (p *PgBouncer) Gather(acc telegraf.Accumulator) error { +func (p *PgBouncer) Init() error { + // Set defaults and check settings if len(p.ShowCommands) == 0 { - if err := p.showStats(acc); err != nil { - return err + p.ShowCommands = []string{"stats", "pools"} + } + for _, cmd := range p.ShowCommands { + switch cmd { + case "stats", "pools", "lists", "databases": + default: + return fmt.Errorf("invalid setting %q for 'show_command'", cmd) } + } - if err := p.showPools(acc); err != nil { - return err - } - } else { - for _, cmd := range p.ShowCommands { - switch { - case cmd == "stats": - if err := p.showStats(acc); err != nil { - return err - } - case cmd == "pools": - if err := p.showPools(acc); err != nil { - return err - } - case cmd == "lists": - if err := p.showLists(acc); err != nil { - return err - } - case cmd == "databases": - if err := p.showDatabase(acc); err != nil { - return err - } + // Create a postgres service for the queries + service, err := p.Config.CreateService() + if err != nil { + return err + } + p.service = service + return nil +} + +func (p *PgBouncer) Start(_ telegraf.Accumulator) error { + return p.service.Start() +} + +func (p *PgBouncer) Stop() { + p.service.Stop() +} + +func (p *PgBouncer) Gather(acc telegraf.Accumulator) error { + for _, cmd := range p.ShowCommands { + switch cmd { + case "stats": + if err := p.showStats(acc); err != nil { + return err + } + case "pools": + if err := p.showPools(acc); err != nil { + return err + } + case "lists": + if err := p.showLists(acc); err != nil { + return err + } + case "databases": + if err := p.showDatabase(acc); err != nil { + return err } } } @@ -68,11 +87,7 @@ func (p *PgBouncer) Gather(acc telegraf.Accumulator) error { return nil } -type scanner interface { - Scan(dest ...interface{}) error -} - -func (p *PgBouncer) accRow(row scanner, columns []string) (map[string]string, map[string]*interface{}, error) { +func (p *PgBouncer) accRow(row *sql.Rows, columns []string) (map[string]string, map[string]*interface{}, error) { var dbname bytes.Buffer // this is where we'll store the column name with its *interface{} @@ -103,19 +118,13 @@ func (p *PgBouncer) accRow(row scanner, columns []string) (map[string]string, ma dbname.WriteString("pgbouncer") } - var tagAddress string - tagAddress, err = p.SanitizedAddress() - if err != nil { - return nil, nil, fmt.Errorf("couldn't get connection data: %w", err) - } - // Return basic tags and the mapped columns - return map[string]string{"server": tagAddress, "db": dbname.String()}, columnMap, nil + return map[string]string{"server": p.service.SanitizedAddress, "db": dbname.String()}, columnMap, nil } func (p *PgBouncer) showStats(acc telegraf.Accumulator) error { // STATS - rows, err := p.DB.Query(`SHOW STATS`) + rows, err := p.service.DB.Query(`SHOW STATS`) if err != nil { return fmt.Errorf("execution error 'show stats': %w", err) } @@ -163,7 +172,7 @@ func (p *PgBouncer) showStats(acc telegraf.Accumulator) error { func (p *PgBouncer) showPools(acc telegraf.Accumulator) error { // POOLS - poolRows, err := p.DB.Query(`SHOW POOLS`) + poolRows, err := p.service.DB.Query(`SHOW POOLS`) if err != nil { return fmt.Errorf("execution error 'show pools': %w", err) } @@ -209,7 +218,7 @@ func (p *PgBouncer) showPools(acc telegraf.Accumulator) error { func (p *PgBouncer) showLists(acc telegraf.Accumulator) error { // LISTS - rows, err := p.DB.Query(`SHOW LISTS`) + rows, err := p.service.DB.Query(`SHOW LISTS`) if err != nil { return fmt.Errorf("execution error 'show lists': %w", err) } @@ -250,7 +259,7 @@ func (p *PgBouncer) showLists(acc telegraf.Accumulator) error { func (p *PgBouncer) showDatabase(acc telegraf.Accumulator) error { // DATABASES - rows, err := p.DB.Query(`SHOW DATABASES`) + rows, err := p.service.DB.Query(`SHOW DATABASES`) if err != nil { return fmt.Errorf("execution error 'show database': %w", err) } @@ -298,10 +307,9 @@ func (p *PgBouncer) showDatabase(acc telegraf.Accumulator) error { func init() { inputs.Add("pgbouncer", func() telegraf.Input { return &PgBouncer{ - Service: postgresql.Service{ + Config: postgresql.Config{ MaxIdle: 1, MaxOpen: 1, - MaxLifetime: config.Duration(0), IsPgBouncer: true, }, } diff --git a/plugins/inputs/pgbouncer/pgbouncer_test.go b/plugins/inputs/pgbouncer/pgbouncer_test.go index a681d25e8..4a39a1e3e 100644 --- a/plugins/inputs/pgbouncer/pgbouncer_test.go +++ b/plugins/inputs/pgbouncer/pgbouncer_test.go @@ -9,7 +9,7 @@ import ( "github.com/testcontainers/testcontainers-go/wait" "github.com/influxdata/telegraf/config" - "github.com/influxdata/telegraf/plugins/inputs/postgresql" + "github.com/influxdata/telegraf/plugins/common/postgresql" "github.com/influxdata/telegraf/testutil" ) @@ -29,8 +29,7 @@ func TestPgBouncerGeneratesMetricsIntegration(t *testing.T) { }, WaitingFor: wait.ForLog("database system is ready to accept connections").WithOccurrence(2), } - err := backend.Start() - require.NoError(t, err, "failed to start container") + require.NoError(t, backend.Start(), "failed to start container") defer backend.Terminate() container := testutil.Container{ @@ -45,8 +44,7 @@ func TestPgBouncerGeneratesMetricsIntegration(t *testing.T) { wait.ForLog("LOG process up"), ), } - err = container.Start() - require.NoError(t, err, "failed to start container") + require.NoError(t, container.Start(), "failed to start container") defer container.Terminate() addr := fmt.Sprintf( @@ -56,14 +54,16 @@ func TestPgBouncerGeneratesMetricsIntegration(t *testing.T) { ) p := &PgBouncer{ - Service: postgresql.Service{ + Config: postgresql.Config{ Address: config.NewSecret([]byte(addr)), IsPgBouncer: true, }, } + require.NoError(t, p.Init()) var acc testutil.Accumulator require.NoError(t, p.Start(&acc)) + defer p.Stop() require.NoError(t, p.Gather(&acc)) intMetricsPgBouncer := []string{ @@ -145,15 +145,17 @@ func TestPgBouncerGeneratesMetricsIntegrationShowCommands(t *testing.T) { ) p := &PgBouncer{ - Service: postgresql.Service{ + Config: postgresql.Config{ Address: config.NewSecret([]byte(addr)), IsPgBouncer: true, }, ShowCommands: []string{"pools", "lists", "databases"}, } + require.NoError(t, p.Init()) var acc testutil.Accumulator require.NoError(t, p.Start(&acc)) + defer p.Stop() require.NoError(t, p.Gather(&acc)) intMetricsPgBouncerPools := []string{ diff --git a/plugins/inputs/postgresql/postgresql.go b/plugins/inputs/postgresql/postgresql.go index ac13adce4..dc8f37ca8 100644 --- a/plugins/inputs/postgresql/postgresql.go +++ b/plugins/inputs/postgresql/postgresql.go @@ -3,15 +3,13 @@ package postgresql import ( "bytes" + "database/sql" _ "embed" "fmt" "strings" - // Blank import required to register driver - _ "github.com/jackc/pgx/v4/stdlib" - "github.com/influxdata/telegraf" - "github.com/influxdata/telegraf/config" + "github.com/influxdata/telegraf/plugins/common/postgresql" "github.com/influxdata/telegraf/plugins/inputs" ) @@ -19,10 +17,12 @@ import ( var sampleConfig string type Postgresql struct { - Service Databases []string `toml:"databases"` IgnoredDatabases []string `toml:"ignored_databases"` PreparedStatements bool `toml:"prepared_statements"` + postgresql.Config + + service *postgresql.Service } var ignoredColumns = map[string]bool{"stats_reset": true} @@ -31,22 +31,28 @@ func (*Postgresql) SampleConfig() string { return sampleConfig } -func (p *Postgresql) IgnoredColumns() map[string]bool { - return ignoredColumns -} - func (p *Postgresql) Init() error { - p.Service.IsPgBouncer = !p.PreparedStatements + p.IsPgBouncer = !p.PreparedStatements + + service, err := p.Config.CreateService() + if err != nil { + return err + } + p.service = service + return nil } -func (p *Postgresql) Gather(acc telegraf.Accumulator) error { - var ( - err error - query string - columns []string - ) +func (p *Postgresql) Start(_ telegraf.Accumulator) error { + return p.service.Start() +} +func (p *Postgresql) Stop() { + p.service.Stop() +} + +func (p *Postgresql) Gather(acc telegraf.Accumulator) error { + var query string if len(p.Databases) == 0 && len(p.IgnoredDatabases) == 0 { query = `SELECT * FROM pg_stat_database` } else if len(p.IgnoredDatabases) != 0 { @@ -57,7 +63,7 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error { strings.Join(p.Databases, "','")) } - rows, err := p.DB.Query(query) + rows, err := p.service.DB.Query(query) if err != nil { return err } @@ -65,7 +71,8 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error { defer rows.Close() // grab the column information from the result - if columns, err = rows.Columns(); err != nil { + columns, err := rows.Columns() + if err != nil { return err } @@ -78,7 +85,7 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error { query = `SELECT * FROM pg_stat_bgwriter` - bgWriterRow, err := p.DB.Query(query) + bgWriterRow, err := p.service.DB.Query(query) if err != nil { return err } @@ -91,8 +98,7 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error { } for bgWriterRow.Next() { - err = p.accRow(bgWriterRow, acc, columns) - if err != nil { + if err := p.accRow(bgWriterRow, acc, columns); err != nil { return err } } @@ -100,11 +106,7 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error { return bgWriterRow.Err() } -type scanner interface { - Scan(dest ...interface{}) error -} - -func (p *Postgresql) accRow(row scanner, acc telegraf.Accumulator, columns []string) error { +func (p *Postgresql) accRow(row *sql.Rows, acc telegraf.Accumulator, columns []string) error { var dbname bytes.Buffer // this is where we'll store the column name with its *interface{} @@ -120,15 +122,8 @@ func (p *Postgresql) accRow(row scanner, acc telegraf.Accumulator, columns []str columnVars = append(columnVars, columnMap[columns[i]]) } - tagAddress, err := p.SanitizedAddress() - if err != nil { - return err - } - // deconstruct array of variables and send to Scan - err = row.Scan(columnVars...) - - if err != nil { + if err := row.Scan(columnVars...); err != nil { return err } if columnMap["datname"] != nil { @@ -140,13 +135,10 @@ func (p *Postgresql) accRow(row scanner, acc telegraf.Accumulator, columns []str dbname.WriteString("postgres_global") } } else { - database, err := p.GetConnectDatabase(tagAddress) - if err != nil { - return err - } - dbname.WriteString(database) + dbname.WriteString(p.service.ConnectionDatabase) } + tagAddress := p.service.SanitizedAddress tags := map[string]string{"server": tagAddress, "db": dbname.String()} fields := make(map[string]interface{}) @@ -164,10 +156,9 @@ func (p *Postgresql) accRow(row scanner, acc telegraf.Accumulator, columns []str func init() { inputs.Add("postgresql", func() telegraf.Input { return &Postgresql{ - Service: Service{ - MaxIdle: 1, - MaxOpen: 1, - MaxLifetime: config.Duration(0), + Config: postgresql.Config{ + MaxIdle: 1, + MaxOpen: 1, }, PreparedStatements: true, } diff --git a/plugins/inputs/postgresql/postgresql_test.go b/plugins/inputs/postgresql/postgresql_test.go index dfb5da0d2..60ec08bde 100644 --- a/plugins/inputs/postgresql/postgresql_test.go +++ b/plugins/inputs/postgresql/postgresql_test.go @@ -2,7 +2,6 @@ package postgresql import ( "fmt" - "strings" "testing" "github.com/docker/go-connections/nat" @@ -10,6 +9,7 @@ import ( "github.com/testcontainers/testcontainers-go/wait" "github.com/influxdata/telegraf/config" + "github.com/influxdata/telegraf/plugins/common/postgresql" "github.com/influxdata/telegraf/testutil" ) @@ -51,15 +51,17 @@ func TestPostgresqlGeneratesMetricsIntegration(t *testing.T) { ) p := &Postgresql{ - Service: Service{ + Config: postgresql.Config{ Address: config.NewSecret([]byte(addr)), IsPgBouncer: false, }, Databases: []string{"postgres"}, } + require.NoError(t, p.Init()) var acc testutil.Accumulator require.NoError(t, p.Start(&acc)) + defer p.Stop() require.NoError(t, p.Gather(&acc)) intMetrics := []string{ @@ -142,15 +144,16 @@ func TestPostgresqlTagsMetricsWithDatabaseNameIntegration(t *testing.T) { ) p := &Postgresql{ - Service: Service{ + Config: postgresql.Config{ Address: config.NewSecret([]byte(addr)), }, Databases: []string{"postgres"}, } + require.NoError(t, p.Init()) var acc testutil.Accumulator - require.NoError(t, p.Start(&acc)) + defer p.Stop() require.NoError(t, p.Gather(&acc)) point, ok := acc.Get("postgresql") @@ -174,14 +177,15 @@ func TestPostgresqlDefaultsToAllDatabasesIntegration(t *testing.T) { ) p := &Postgresql{ - Service: Service{ + Config: postgresql.Config{ Address: config.NewSecret([]byte(addr)), }, } + require.NoError(t, p.Init()) var acc testutil.Accumulator - require.NoError(t, p.Start(&acc)) + defer p.Stop() require.NoError(t, p.Gather(&acc)) var found bool @@ -213,16 +217,18 @@ func TestPostgresqlIgnoresUnwantedColumnsIntegration(t *testing.T) { ) p := &Postgresql{ - Service: Service{ + Config: postgresql.Config{ Address: config.NewSecret([]byte(addr)), }, } + require.NoError(t, p.Init()) var acc testutil.Accumulator require.NoError(t, p.Start(&acc)) + defer p.Stop() require.NoError(t, p.Gather(&acc)) - for col := range p.IgnoredColumns() { + for col := range ignoredColumns { require.False(t, acc.HasMeasurement(col)) } } @@ -242,15 +248,16 @@ func TestPostgresqlDatabaseWhitelistTestIntegration(t *testing.T) { ) p := &Postgresql{ - Service: Service{ + Config: postgresql.Config{ Address: config.NewSecret([]byte(addr)), }, Databases: []string{"template0"}, } + require.NoError(t, p.Init()) var acc testutil.Accumulator - require.NoError(t, p.Start(&acc)) + defer p.Stop() require.NoError(t, p.Gather(&acc)) var foundTemplate0 = false @@ -288,14 +295,16 @@ func TestPostgresqlDatabaseBlacklistTestIntegration(t *testing.T) { ) p := &Postgresql{ - Service: Service{ + Config: postgresql.Config{ Address: config.NewSecret([]byte(addr)), }, IgnoredDatabases: []string{"template0"}, } + require.NoError(t, p.Init()) var acc testutil.Accumulator require.NoError(t, p.Start(&acc)) + defer p.Stop() require.NoError(t, p.Gather(&acc)) var foundTemplate0 = false @@ -317,239 +326,3 @@ func TestPostgresqlDatabaseBlacklistTestIntegration(t *testing.T) { require.False(t, foundTemplate0) require.True(t, foundTemplate1) } - -func TestURIParsing(t *testing.T) { - tests := []struct { - name string - uri string - expected string - }{ - { - name: "short", - uri: `postgres://localhost`, - expected: "host=localhost", - }, - { - name: "with port", - uri: `postgres://localhost:5432`, - expected: "host=localhost port=5432", - }, - { - name: "with database", - uri: `postgres://localhost/mydb`, - expected: "dbname=mydb host=localhost", - }, - { - name: "with additional parameters", - uri: `postgres://localhost/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5`, - expected: "application_name=pgxtest connect_timeout=5 dbname=mydb host=localhost search_path=myschema", - }, - { - name: "with database setting in params", - uri: `postgres://localhost:5432/?database=mydb`, - expected: "database=mydb host=localhost port=5432", - }, - { - name: "with authentication", - uri: `postgres://jack:secret@localhost:5432/mydb?sslmode=prefer`, - expected: "dbname=mydb host=localhost password=secret port=5432 sslmode=prefer user=jack", - }, - { - name: "with spaces", - uri: `postgres://jack%20hunter:secret@localhost/mydb?application_name=pgx%20test`, - expected: "application_name='pgx test' dbname=mydb host=localhost password=secret user='jack hunter'", - }, - { - name: "with equal signs", - uri: `postgres://jack%20hunter:secret@localhost/mydb?application_name=pgx%3Dtest`, - expected: "application_name='pgx=test' dbname=mydb host=localhost password=secret user='jack hunter'", - }, - { - name: "multiple hosts", - uri: `postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable`, - expected: "dbname=mydb host=foo,bar,baz password=secret port=1,2,3 sslmode=disable user=jack", - }, - { - name: "multiple hosts without ports", - uri: `postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable`, - expected: "dbname=mydb host=foo,bar,baz password=secret sslmode=disable user=jack", - }, - } - - for _, tt := range tests { - // Key value without spaces around equal sign - t.Run(tt.name, func(t *testing.T) { - actual, err := toKeyValue(tt.uri) - require.NoError(t, err) - require.Equalf(t, tt.expected, actual, "initial: %s", tt.uri) - }) - } -} - -func TestSanitizeAddressKeyValue(t *testing.T) { - keys := []string{"password", "sslcert", "sslkey", "sslmode", "sslrootcert"} - tests := []struct { - name string - value string - }{ - { - name: "simple text", - value: `foo`, - }, - { - name: "empty values", - value: `''`, - }, - { - name: "space in value", - value: `'foo bar'`, - }, - { - name: "equal sign in value", - value: `'foo=bar'`, - }, - { - name: "escaped quote", - value: `'foo\'s bar'`, - }, - { - name: "escaped quote no space", - value: `\'foobar\'s\'`, - }, - { - name: "escaped backslash", - value: `'foo bar\\'`, - }, - { - name: "escaped quote and backslash", - value: `'foo\\\'s bar'`, - }, - { - name: "two escaped backslashes", - value: `'foo bar\\\\'`, - }, - { - name: "multiple inline spaces", - value: "'foo \t bar'", - }, - { - name: "leading space", - value: `' foo bar'`, - }, - { - name: "trailing space", - value: `'foo bar '`, - }, - { - name: "multiple equal signs", - value: `'foo===bar'`, - }, - { - name: "leading equal sign", - value: `'=foo bar'`, - }, - { - name: "trailing equal sign", - value: `'foo bar='`, - }, - { - name: "mix of equal signs and spaces", - value: "'foo = a\t===\tbar'", - }, - } - - for _, tt := range tests { - // Key value without spaces around equal sign - t.Run(tt.name, func(t *testing.T) { - // Generate the DSN from the given keys and value - parts := make([]string, 0, len(keys)) - for _, k := range keys { - parts = append(parts, k+"="+tt.value) - } - dsn := strings.Join(parts, " canary=ok ") - - plugin := &Postgresql{ - Service: Service{ - Address: config.NewSecret([]byte(dsn)), - }, - } - - expected := strings.Join(make([]string, len(keys)), "canary=ok ") - expected = strings.TrimSpace(expected) - actual, err := plugin.SanitizedAddress() - require.NoError(t, err) - require.Equalf(t, expected, actual, "initial: %s", dsn) - }) - - // Key value with spaces around equal sign - t.Run("spaced "+tt.name, func(t *testing.T) { - // Generate the DSN from the given keys and value - parts := make([]string, 0, len(keys)) - for _, k := range keys { - parts = append(parts, k+" = "+tt.value) - } - dsn := strings.Join(parts, " canary=ok ") - - plugin := &Postgresql{ - Service: Service{ - Address: config.NewSecret([]byte(dsn)), - }, - } - - expected := strings.Join(make([]string, len(keys)), "canary=ok ") - expected = strings.TrimSpace(expected) - actual, err := plugin.SanitizedAddress() - require.NoError(t, err) - require.Equalf(t, expected, actual, "initial: %s", dsn) - }) - } -} - -func TestSanitizeAddressURI(t *testing.T) { - keys := []string{"password", "sslcert", "sslkey", "sslmode", "sslrootcert"} - tests := []struct { - name string - value string - }{ - { - name: "simple text", - value: `foo`, - }, - { - name: "empty values", - value: ``, - }, - { - name: "space in value", - value: `foo bar`, - }, - { - name: "equal sign in value", - value: `foo=bar`, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Generate the DSN from the given keys and value - value := strings.ReplaceAll(tt.value, "=", "%3D") - value = strings.ReplaceAll(value, " ", "%20") - parts := make([]string, 0, len(keys)) - for _, k := range keys { - parts = append(parts, k+"="+value) - } - dsn := "postgresql://user:passwd@localhost:5432/db?" + strings.Join(parts, "&") - - plugin := &Postgresql{ - Service: Service{ - Address: config.NewSecret([]byte(dsn)), - }, - } - - expected := "dbname=db host=localhost port=5432 user=user" - actual, err := plugin.SanitizedAddress() - require.NoError(t, err) - require.Equalf(t, expected, actual, "initial: %s", dsn) - }) - } -} diff --git a/plugins/inputs/postgresql_extensible/postgresql_extensible.go b/plugins/inputs/postgresql_extensible/postgresql_extensible.go index 39dfea904..64eff6936 100644 --- a/plugins/inputs/postgresql_extensible/postgresql_extensible.go +++ b/plugins/inputs/postgresql_extensible/postgresql_extensible.go @@ -5,9 +5,7 @@ import ( "bytes" _ "embed" "fmt" - "io" "os" - "strconv" "strings" "time" @@ -15,36 +13,36 @@ import ( _ "github.com/jackc/pgx/v4/stdlib" "github.com/influxdata/telegraf" - "github.com/influxdata/telegraf/config" + "github.com/influxdata/telegraf/internal" + "github.com/influxdata/telegraf/plugins/common/postgresql" "github.com/influxdata/telegraf/plugins/inputs" - "github.com/influxdata/telegraf/plugins/inputs/postgresql" ) //go:embed sample.conf var sampleConfig string type Postgresql struct { - postgresql.Service - Databases []string `deprecated:"1.22.4;use the sqlquery option to specify database to use"` - AdditionalTags []string - Timestamp string - Query query - Debug bool - PreparedStatements bool `toml:"prepared_statements"` + Databases []string `deprecated:"1.22.4;use the sqlquery option to specify database to use"` + Query []query `toml:"query"` + PreparedStatements bool `toml:"prepared_statements"` + Log telegraf.Logger `toml:"-"` + postgresql.Config - Log telegraf.Logger + service *postgresql.Service } -type query []struct { - Sqlquery string - Script string - Version int `deprecated:"1.28.0;use minVersion to specify minimal DB version this query supports"` - MinVersion int `toml:"min_version"` - MaxVersion int `toml:"max_version"` - Withdbname bool `deprecated:"1.22.4;use the sqlquery option to specify database to use"` - Tagvalue string - Measurement string - Timestamp string +type query struct { + Sqlquery string `toml:"sqlquery"` + Script string `toml:"script"` + Version int `deprecated:"1.28.0;use minVersion to specify minimal DB version this query supports"` + MinVersion int `toml:"min_version"` + MaxVersion int `toml:"max_version"` + Withdbname bool `deprecated:"1.22.4;use the sqlquery option to specify database to use"` + Tagvalue string `toml:"tagvalue"` + Measurement string `toml:"measurement"` + Timestamp string `toml:"timestamp"` + + additionalTags map[string]bool } var ignoredColumns = map[string]bool{"stats_reset": true} @@ -54,133 +52,105 @@ func (*Postgresql) SampleConfig() string { } func (p *Postgresql) Init() error { - var err error - for i := range p.Query { - if p.Query[i].Sqlquery == "" { - p.Query[i].Sqlquery, err = ReadQueryFromFile(p.Query[i].Script) + // Set defaults for the queries + for i, q := range p.Query { + if q.Sqlquery == "" { + query, err := os.ReadFile(q.Script) if err != nil { return err } + q.Sqlquery = string(query) } - if p.Query[i].MinVersion == 0 { - p.Query[i].MinVersion = p.Query[i].Version + if q.MinVersion == 0 { + q.MinVersion = q.Version } - } - p.Service.IsPgBouncer = !p.PreparedStatements - return nil -} - -func (p *Postgresql) IgnoredColumns() map[string]bool { - return ignoredColumns -} - -func ReadQueryFromFile(filePath string) (string, error) { - file, err := os.Open(filePath) - if err != nil { - return "", err - } - defer file.Close() - - query, err := io.ReadAll(file) - if err != nil { - return "", err - } - return string(query), err -} - -func (p *Postgresql) Gather(acc telegraf.Accumulator) error { - var ( - err error - sqlQuery string - queryAddon string - dbVersion int - query string - measName string - ) - - // Retrieving the database version - query = `SELECT setting::integer / 100 AS version FROM pg_settings WHERE name = 'server_version_num'` - if err = p.DB.QueryRow(query).Scan(&dbVersion); err != nil { - dbVersion = 0 - } - - // We loop in order to process each query - // Query is not run if Database version does not match the query version. - for i := range p.Query { - sqlQuery = p.Query[i].Sqlquery - - if p.Query[i].Measurement != "" { - measName = p.Query[i].Measurement - } else { - measName = "postgresql" + if q.Measurement == "" { + q.Measurement = "postgresql" } - if p.Query[i].Withdbname { + var queryAddon string + if q.Withdbname { if len(p.Databases) != 0 { queryAddon = fmt.Sprintf(` IN ('%s')`, strings.Join(p.Databases, "','")) } else { queryAddon = " is not null" } - } else { - queryAddon = "" } - sqlQuery += queryAddon + q.Sqlquery += queryAddon - maxVer := p.Query[i].MaxVersion + q.additionalTags = make(map[string]bool) + if q.Tagvalue != "" { + for _, tag := range strings.Split(q.Tagvalue, ",") { + q.additionalTags[tag] = true + } + } + p.Query[i] = q + } + p.Config.IsPgBouncer = !p.PreparedStatements - if p.Query[i].MinVersion <= dbVersion && (maxVer == 0 || maxVer > dbVersion) { - p.gatherMetricsFromQuery(acc, sqlQuery, p.Query[i].Tagvalue, p.Query[i].Timestamp, measName) + // Create a service to access the PostgreSQL server + service, err := p.Config.CreateService() + if err != nil { + return err + } + p.service = service + + return nil +} + +func (p *Postgresql) Start(_ telegraf.Accumulator) error { + return p.service.Start() +} + +func (p *Postgresql) Stop() { + p.service.Stop() +} + +func (p *Postgresql) Gather(acc telegraf.Accumulator) error { + // Retrieving the database version + query := `SELECT setting::integer / 100 AS version FROM pg_settings WHERE name = 'server_version_num'` + var dbVersion int + if err := p.service.DB.QueryRow(query).Scan(&dbVersion); err != nil { + dbVersion = 0 + } + + // We loop in order to process each query + // Query is not run if Database version does not match the query version. + for _, q := range p.Query { + if q.MinVersion <= dbVersion && (q.MaxVersion == 0 || q.MaxVersion > dbVersion) { + acc.AddError(p.gatherMetricsFromQuery(acc, q)) } } return nil } -func (p *Postgresql) gatherMetricsFromQuery(acc telegraf.Accumulator, sqlQuery string, tagValue string, timestamp string, measName string) { - var columns []string - - rows, err := p.DB.Query(sqlQuery) +func (p *Postgresql) gatherMetricsFromQuery(acc telegraf.Accumulator, q query) error { + rows, err := p.service.DB.Query(q.Sqlquery) if err != nil { - acc.AddError(err) - return + return err } defer rows.Close() // grab the column information from the result - if columns, err = rows.Columns(); err != nil { - acc.AddError(err) - return + columns, err := rows.Columns() + if err != nil { + return err } - p.AdditionalTags = nil - if tagValue != "" { - tagList := strings.Split(tagValue, ",") - p.AdditionalTags = append(p.AdditionalTags, tagList...) - } - - p.Timestamp = timestamp - for rows.Next() { - err = p.accRow(measName, rows, acc, columns) - if err != nil { - acc.AddError(err) - break + if err := p.accRow(acc, rows, columns, q); err != nil { + return err } } + return nil } type scanner interface { Scan(dest ...interface{}) error } -func (p *Postgresql) accRow(measName string, row scanner, acc telegraf.Accumulator, columns []string) error { - var ( - err error - dbname bytes.Buffer - tagAddress string - timestamp time.Time - ) - +func (p *Postgresql) accRow(acc telegraf.Accumulator, row scanner, columns []string, q query) error { // this is where we'll store the column name with its *interface{} columnMap := make(map[string]*interface{}) @@ -194,46 +164,34 @@ func (p *Postgresql) accRow(measName string, row scanner, acc telegraf.Accumulat columnVars = append(columnVars, columnMap[columns[i]]) } - if tagAddress, err = p.SanitizedAddress(); err != nil { - return err - } - // deconstruct array of variables and send to Scan if err := row.Scan(columnVars...); err != nil { return err } + var dbname bytes.Buffer if c, ok := columnMap["datname"]; ok && *c != nil { // extract the database name from the column map switch datname := (*c).(type) { case string: dbname.WriteString(datname) default: - database, err := p.GetConnectDatabase(tagAddress) - if err != nil { - return err - } - dbname.WriteString(database) + dbname.WriteString(p.service.ConnectionDatabase) } } else { - database, err := p.GetConnectDatabase(tagAddress) - if err != nil { - return err - } - dbname.WriteString(database) + dbname.WriteString(p.service.ConnectionDatabase) } // Process the additional tags tags := map[string]string{ - "server": tagAddress, + "server": p.service.SanitizedAddress, "db": dbname.String(), } // set default timestamp to Now - timestamp = time.Now() + timestamp := time.Now() fields := make(map[string]interface{}) -COLUMN: for col, val := range columnMap { p.Log.Debugf("Column: %s = %T: %v\n", col, *val, *val) _, ignore := ignoredColumns[col] @@ -241,30 +199,21 @@ COLUMN: continue } - if col == p.Timestamp { + if col == q.Timestamp { if v, ok := (*val).(time.Time); ok { timestamp = v } continue } - for _, tag := range p.AdditionalTags { - if col != tag { - continue - } - switch v := (*val).(type) { - case string: + if q.additionalTags[col] { + v, err := internal.ToString(*val) + if err != nil { + p.Log.Debugf("Failed to add %q as additional tag: %v", col, err) + } else { tags[col] = v - case []byte: - tags[col] = string(v) - case int64, int32, int: - tags[col] = fmt.Sprintf("%d", v) - case bool: - tags[col] = strconv.FormatBool(v) - default: - p.Log.Debugf("Failed to add %q as additional tag", col) } - continue COLUMN + continue } if v, ok := (*val).([]byte); ok { @@ -273,18 +222,16 @@ COLUMN: fields[col] = *val } } - acc.AddFields(measName, fields, tags, timestamp) + acc.AddFields(q.Measurement, fields, tags, timestamp) return nil } func init() { inputs.Add("postgresql_extensible", func() telegraf.Input { return &Postgresql{ - Service: postgresql.Service{ - MaxIdle: 1, - MaxOpen: 1, - MaxLifetime: config.Duration(0), - IsPgBouncer: false, + Config: postgresql.Config{ + MaxIdle: 1, + MaxOpen: 1, }, PreparedStatements: true, } diff --git a/plugins/inputs/postgresql_extensible/postgresql_extensible_test.go b/plugins/inputs/postgresql_extensible/postgresql_extensible_test.go index a26d177b8..3b590da2e 100644 --- a/plugins/inputs/postgresql_extensible/postgresql_extensible_test.go +++ b/plugins/inputs/postgresql_extensible/postgresql_extensible_test.go @@ -11,11 +11,11 @@ import ( "github.com/testcontainers/testcontainers-go/wait" "github.com/influxdata/telegraf/config" - "github.com/influxdata/telegraf/plugins/inputs/postgresql" + "github.com/influxdata/telegraf/plugins/common/postgresql" "github.com/influxdata/telegraf/testutil" ) -func queryRunner(t *testing.T, q query) *testutil.Accumulator { +func queryRunner(t *testing.T, q []query) *testutil.Accumulator { servicePort := "5432" container := testutil.Container{ Image: "postgres:alpine", @@ -29,8 +29,7 @@ func queryRunner(t *testing.T, q query) *testutil.Accumulator { ), } - err := container.Start() - require.NoError(t, err, "failed to start container") + require.NoError(t, container.Start(), "failed to start container") defer container.Terminate() addr := fmt.Sprintf( @@ -41,18 +40,20 @@ func queryRunner(t *testing.T, q query) *testutil.Accumulator { p := &Postgresql{ Log: testutil.Logger{}, - Service: postgresql.Service{ + Config: postgresql.Config{ Address: config.NewSecret([]byte(addr)), IsPgBouncer: false, }, Databases: []string{"postgres"}, Query: q, } + require.NoError(t, p.Init()) var acc testutil.Accumulator - require.NoError(t, p.Init()) require.NoError(t, p.Start(&acc)) + defer p.Stop() require.NoError(t, acc.GatherError(p.Gather)) + return &acc } @@ -61,12 +62,13 @@ func TestPostgresqlGeneratesMetricsIntegration(t *testing.T) { t.Skip("Skipping integration test in short mode") } - acc := queryRunner(t, query{{ + acc := queryRunner(t, []query{{ Sqlquery: "select * from pg_stat_database", MinVersion: 901, Withdbname: false, Tagvalue: "", }}) + testutil.PrintMetrics(acc.GetTelegrafMetrics()) intMetrics := []string{ "xact_commit", @@ -161,7 +163,7 @@ func TestPostgresqlQueryOutputTestsIntegration(t *testing.T) { } for q, assertions := range examples { - acc := queryRunner(t, query{{ + acc := queryRunner(t, []query{{ Sqlquery: q, MinVersion: 901, Withdbname: false, @@ -178,7 +180,7 @@ func TestPostgresqlFieldOutputIntegration(t *testing.T) { t.Skip("Skipping integration test in short mode") } - acc := queryRunner(t, query{{ + acc := queryRunner(t, []query{{ Sqlquery: "select * from pg_stat_database", MinVersion: 901, Withdbname: false, @@ -236,7 +238,7 @@ func TestPostgresqlFieldOutputIntegration(t *testing.T) { } func TestPostgresqlSqlScript(t *testing.T) { - q := query{{ + q := []query{{ Script: "testdata/test.sql", MinVersion: 901, Withdbname: false, @@ -250,17 +252,18 @@ func TestPostgresqlSqlScript(t *testing.T) { p := &Postgresql{ Log: testutil.Logger{}, - Service: postgresql.Service{ + Config: postgresql.Config{ Address: config.NewSecret([]byte(addr)), IsPgBouncer: false, }, Databases: []string{"postgres"}, Query: q, } - var acc testutil.Accumulator require.NoError(t, p.Init()) - require.NoError(t, p.Start(&acc)) + var acc testutil.Accumulator + require.NoError(t, p.Start(&acc)) + defer p.Stop() require.NoError(t, acc.GatherError(p.Gather)) } @@ -276,17 +279,19 @@ func TestPostgresqlIgnoresUnwantedColumnsIntegration(t *testing.T) { p := &Postgresql{ Log: testutil.Logger{}, - Service: postgresql.Service{ + Config: postgresql.Config{ Address: config.NewSecret([]byte(addr)), }, } + require.NoError(t, p.Init()) var acc testutil.Accumulator - require.NoError(t, p.Start(&acc)) + defer p.Stop() require.NoError(t, acc.GatherError(p.Gather)) - require.NotEmpty(t, p.IgnoredColumns()) - for col := range p.IgnoredColumns() { + + require.NotEmpty(t, ignoredColumns) + for col := range ignoredColumns { require.False(t, acc.HasMeasurement(col)) } } @@ -294,10 +299,12 @@ func TestPostgresqlIgnoresUnwantedColumnsIntegration(t *testing.T) { func TestAccRow(t *testing.T) { p := Postgresql{ Log: testutil.Logger{}, - Service: postgresql.Service{ + Config: postgresql.Config{ + Address: config.NewSecret(nil), OutputAddress: "server", }, } + require.NoError(t, p.Init()) var acc testutil.Accumulator columns := []string{"datname", "cat"} @@ -330,7 +337,8 @@ func TestAccRow(t *testing.T) { }, } for _, tt := range tests { - require.NoError(t, p.accRow("pgTEST", tt.fields, &acc, columns)) + q := query{Measurement: "pgTEST", additionalTags: make(map[string]bool)} + require.NoError(t, p.accRow(&acc, tt.fields, columns, q)) require.Len(t, acc.Metrics, 1) metric := acc.Metrics[0] require.Equal(t, tt.dbName, metric.Tags["db"])