chore(inputs.postgresql): Factor out common code and cleanup (#15103)

This commit is contained in:
Sven Rebhan 2024-04-04 17:40:43 -04:00 committed by GitHub
parent c5e915e32b
commit aa1091aba8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 622 additions and 620 deletions

View File

@ -1,7 +1,6 @@
package postgresql package postgresql
import ( import (
"database/sql"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
@ -13,10 +12,105 @@ import (
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib" "github.com/jackc/pgx/v4/stdlib"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config" "github.com/influxdata/telegraf/config"
) )
var socketRegexp = regexp.MustCompile(`/\.s\.PGSQL\.\d+$`)
var sanitizer = regexp.MustCompile(`(\s|^)((?:password|sslcert|sslkey|sslmode|sslrootcert)\s?=\s?(?:(?:'(?:[^'\\]|\\.)*')|(?:\S+)))`)
type Config struct {
Address config.Secret `toml:"address"`
OutputAddress string `toml:"outputaddress"`
MaxIdle int `toml:"max_idle"`
MaxOpen int `toml:"max_open"`
MaxLifetime config.Duration `toml:"max_lifetime"`
IsPgBouncer bool `toml:"-"`
}
func (c *Config) CreateService() (*Service, error) {
addrSecret, err := c.Address.Get()
if err != nil {
return nil, fmt.Errorf("getting address failed: %w", err)
}
addr := addrSecret.String()
defer addrSecret.Destroy()
if c.Address.Empty() || addr == "localhost" {
addr = "host=localhost sslmode=disable"
if err := c.Address.Set([]byte(addr)); err != nil {
return nil, err
}
}
connConfig, err := pgx.ParseConfig(addr)
if err != nil {
return nil, err
}
// Remove the socket name from the path
connConfig.Host = socketRegexp.ReplaceAllLiteralString(connConfig.Host, "")
// Specific support to make it work with PgBouncer too
// See https://github.com/influxdata/telegraf/issues/3253#issuecomment-357505343
if c.IsPgBouncer {
// Remove DriveConfig and revert it by the ParseConfig method
// See https://github.com/influxdata/telegraf/issues/9134
connConfig.PreferSimpleProtocol = true
}
// Provide the connection string without sensitive information for use as
// tag or other output properties
sanitizedAddr, err := c.sanitizedAddress()
if err != nil {
return nil, err
}
return &Service{
SanitizedAddress: sanitizedAddr,
ConnectionDatabase: connectionDatabase(sanitizedAddr),
maxIdle: c.MaxIdle,
maxOpen: c.MaxOpen,
maxLifetime: time.Duration(c.MaxLifetime),
dsn: stdlib.RegisterConnConfig(connConfig),
}, nil
}
// connectionDatabase determines the database to which the connection was made
func connectionDatabase(sanitizedAddr string) string {
connConfig, err := pgx.ParseConfig(sanitizedAddr)
if err != nil || connConfig.Database == "" {
return "postgres"
}
return connConfig.Database
}
// sanitizedAddress strips sensitive information from the connection string.
// If the user set the output address use that before parsing anything else.
func (c *Config) sanitizedAddress() (string, error) {
if c.OutputAddress != "" {
return c.OutputAddress, nil
}
// Get the address
addrSecret, err := c.Address.Get()
if err != nil {
return "", fmt.Errorf("getting address for sanitization failed: %w", err)
}
defer addrSecret.Destroy()
// Make sure we convert URI-formatted strings into key-values
addr := addrSecret.TemporaryString()
if strings.HasPrefix(addr, "postgres://") || strings.HasPrefix(addr, "postgresql://") {
if addr, err = toKeyValue(addr); err != nil {
return "", err
}
}
// Sanitize the string using a regular expression
sanitized := sanitizer.ReplaceAllString(addr, "")
return strings.TrimSpace(sanitized), nil
}
// Based on parseURLSettings() at https://github.com/jackc/pgx/blob/master/pgconn/config.go // Based on parseURLSettings() at https://github.com/jackc/pgx/blob/master/pgconn/config.go
func toKeyValue(uri string) (string, error) { func toKeyValue(uri string) (string, error) {
u, err := url.Parse(uri) u, err := url.Parse(uri)
@ -88,105 +182,3 @@ func toKeyValue(uri string) (string, error) {
sort.Strings(parts) sort.Strings(parts)
return strings.Join(parts, " "), nil return strings.Join(parts, " "), nil
} }
// Service common functionality shared between the postgresql and postgresql_extensible
// packages.
type Service struct {
Address config.Secret `toml:"address"`
OutputAddress string `toml:"outputaddress"`
MaxIdle int `toml:"max_idle"`
MaxOpen int `toml:"max_open"`
MaxLifetime config.Duration `toml:"max_lifetime"`
IsPgBouncer bool `toml:"-"`
DB *sql.DB
}
var socketRegexp = regexp.MustCompile(`/\.s\.PGSQL\.\d+$`)
// Start starts the ServiceInput's service, whatever that may be
func (p *Service) Start(telegraf.Accumulator) (err error) {
addrSecret, err := p.Address.Get()
if err != nil {
return fmt.Errorf("getting address failed: %w", err)
}
addr := addrSecret.String()
defer addrSecret.Destroy()
if p.Address.Empty() || addr == "localhost" {
addr = "host=localhost sslmode=disable"
if err := p.Address.Set([]byte(addr)); err != nil {
return err
}
}
connConfig, err := pgx.ParseConfig(addr)
if err != nil {
return err
}
// Remove the socket name from the path
connConfig.Host = socketRegexp.ReplaceAllLiteralString(connConfig.Host, "")
// Specific support to make it work with PgBouncer too
// See https://github.com/influxdata/telegraf/issues/3253#issuecomment-357505343
if p.IsPgBouncer {
// Remove DriveConfig and revert it by the ParseConfig method
// See https://github.com/influxdata/telegraf/issues/9134
connConfig.PreferSimpleProtocol = true
}
connectionString := stdlib.RegisterConnConfig(connConfig)
if p.DB, err = sql.Open("pgx", connectionString); err != nil {
return err
}
p.DB.SetMaxOpenConns(p.MaxOpen)
p.DB.SetMaxIdleConns(p.MaxIdle)
p.DB.SetConnMaxLifetime(time.Duration(p.MaxLifetime))
return nil
}
// Stop stops the services and closes any necessary channels and connections
func (p *Service) Stop() {
p.DB.Close()
}
var sanitizer = regexp.MustCompile(`(\s|^)((?:password|sslcert|sslkey|sslmode|sslrootcert)\s?=\s?(?:(?:'(?:[^'\\]|\\.)*')|(?:\S+)))`)
// SanitizedAddress utility function to strip sensitive information from the connection string.
func (p *Service) SanitizedAddress() (string, error) {
if p.OutputAddress != "" {
return p.OutputAddress, nil
}
// Get the address
addrSecret, err := p.Address.Get()
if err != nil {
return "", fmt.Errorf("getting address for sanitization failed: %w", err)
}
defer addrSecret.Destroy()
// Make sure we convert URI-formatted strings into key-values
addr := addrSecret.TemporaryString()
if strings.HasPrefix(addr, "postgres://") || strings.HasPrefix(addr, "postgresql://") {
if addr, err = toKeyValue(addr); err != nil {
return "", err
}
}
// Sanitize the string using a regular expression
sanitized := sanitizer.ReplaceAllString(addr, "")
return strings.TrimSpace(sanitized), nil
}
// GetConnectDatabase utility function for getting the database to which the connection was made
// If the user set the output address use that before parsing anything else.
func (p *Service) GetConnectDatabase(connectionString string) (string, error) {
connConfig, err := pgx.ParseConfig(connectionString)
if err == nil && len(connConfig.Database) != 0 {
return connConfig.Database, nil
}
return "postgres", nil
}

View File

@ -0,0 +1,239 @@
package postgresql
import (
"strings"
"testing"
"github.com/influxdata/telegraf/config"
"github.com/stretchr/testify/require"
)
func TestURIParsing(t *testing.T) {
tests := []struct {
name string
uri string
expected string
}{
{
name: "short",
uri: `postgres://localhost`,
expected: "host=localhost",
},
{
name: "with port",
uri: `postgres://localhost:5432`,
expected: "host=localhost port=5432",
},
{
name: "with database",
uri: `postgres://localhost/mydb`,
expected: "dbname=mydb host=localhost",
},
{
name: "with additional parameters",
uri: `postgres://localhost/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5`,
expected: "application_name=pgxtest connect_timeout=5 dbname=mydb host=localhost search_path=myschema",
},
{
name: "with database setting in params",
uri: `postgres://localhost:5432/?database=mydb`,
expected: "database=mydb host=localhost port=5432",
},
{
name: "with authentication",
uri: `postgres://jack:secret@localhost:5432/mydb?sslmode=prefer`,
expected: "dbname=mydb host=localhost password=secret port=5432 sslmode=prefer user=jack",
},
{
name: "with spaces",
uri: `postgres://jack%20hunter:secret@localhost/mydb?application_name=pgx%20test`,
expected: "application_name='pgx test' dbname=mydb host=localhost password=secret user='jack hunter'",
},
{
name: "with equal signs",
uri: `postgres://jack%20hunter:secret@localhost/mydb?application_name=pgx%3Dtest`,
expected: "application_name='pgx=test' dbname=mydb host=localhost password=secret user='jack hunter'",
},
{
name: "multiple hosts",
uri: `postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable`,
expected: "dbname=mydb host=foo,bar,baz password=secret port=1,2,3 sslmode=disable user=jack",
},
{
name: "multiple hosts without ports",
uri: `postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable`,
expected: "dbname=mydb host=foo,bar,baz password=secret sslmode=disable user=jack",
},
}
for _, tt := range tests {
// Key value without spaces around equal sign
t.Run(tt.name, func(t *testing.T) {
actual, err := toKeyValue(tt.uri)
require.NoError(t, err)
require.Equalf(t, tt.expected, actual, "initial: %s", tt.uri)
})
}
}
func TestSanitizeAddressKeyValue(t *testing.T) {
keys := []string{"password", "sslcert", "sslkey", "sslmode", "sslrootcert"}
tests := []struct {
name string
value string
}{
{
name: "simple text",
value: `foo`,
},
{
name: "empty values",
value: `''`,
},
{
name: "space in value",
value: `'foo bar'`,
},
{
name: "equal sign in value",
value: `'foo=bar'`,
},
{
name: "escaped quote",
value: `'foo\'s bar'`,
},
{
name: "escaped quote no space",
value: `\'foobar\'s\'`,
},
{
name: "escaped backslash",
value: `'foo bar\\'`,
},
{
name: "escaped quote and backslash",
value: `'foo\\\'s bar'`,
},
{
name: "two escaped backslashes",
value: `'foo bar\\\\'`,
},
{
name: "multiple inline spaces",
value: "'foo \t bar'",
},
{
name: "leading space",
value: `' foo bar'`,
},
{
name: "trailing space",
value: `'foo bar '`,
},
{
name: "multiple equal signs",
value: `'foo===bar'`,
},
{
name: "leading equal sign",
value: `'=foo bar'`,
},
{
name: "trailing equal sign",
value: `'foo bar='`,
},
{
name: "mix of equal signs and spaces",
value: "'foo = a\t===\tbar'",
},
}
for _, tt := range tests {
// Key value without spaces around equal sign
t.Run(tt.name, func(t *testing.T) {
// Generate the DSN from the given keys and value
parts := make([]string, 0, len(keys))
for _, k := range keys {
parts = append(parts, k+"="+tt.value)
}
dsn := strings.Join(parts, " canary=ok ")
cfg := &Config{
Address: config.NewSecret([]byte(dsn)),
}
expected := strings.Join(make([]string, len(keys)), "canary=ok ")
expected = strings.TrimSpace(expected)
actual, err := cfg.sanitizedAddress()
require.NoError(t, err)
require.Equalf(t, expected, actual, "initial: %s", dsn)
})
// Key value with spaces around equal sign
t.Run("spaced "+tt.name, func(t *testing.T) {
// Generate the DSN from the given keys and value
parts := make([]string, 0, len(keys))
for _, k := range keys {
parts = append(parts, k+" = "+tt.value)
}
dsn := strings.Join(parts, " canary=ok ")
cfg := &Config{
Address: config.NewSecret([]byte(dsn)),
}
expected := strings.Join(make([]string, len(keys)), "canary=ok ")
expected = strings.TrimSpace(expected)
actual, err := cfg.sanitizedAddress()
require.NoError(t, err)
require.Equalf(t, expected, actual, "initial: %s", dsn)
})
}
}
func TestSanitizeAddressURI(t *testing.T) {
keys := []string{"password", "sslcert", "sslkey", "sslmode", "sslrootcert"}
tests := []struct {
name string
value string
}{
{
name: "simple text",
value: `foo`,
},
{
name: "empty values",
value: ``,
},
{
name: "space in value",
value: `foo bar`,
},
{
name: "equal sign in value",
value: `foo=bar`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Generate the DSN from the given keys and value
value := strings.ReplaceAll(tt.value, "=", "%3D")
value = strings.ReplaceAll(value, " ", "%20")
parts := make([]string, 0, len(keys))
for _, k := range keys {
parts = append(parts, k+"="+value)
}
dsn := "postgresql://user:passwd@localhost:5432/db?" + strings.Join(parts, "&")
cfg := &Config{
Address: config.NewSecret([]byte(dsn)),
}
expected := "dbname=db host=localhost port=5432 user=user"
actual, err := cfg.sanitizedAddress()
require.NoError(t, err)
require.Equalf(t, expected, actual, "initial: %s", dsn)
})
}
}

View File

@ -0,0 +1,42 @@
package postgresql
import (
"database/sql"
"time"
// Blank import required to register driver
_ "github.com/jackc/pgx/v4/stdlib"
)
// Service common functionality shared between the postgresql and postgresql_extensible
// packages.
type Service struct {
DB *sql.DB
SanitizedAddress string
ConnectionDatabase string
dsn string
maxIdle int
maxOpen int
maxLifetime time.Duration
}
func (p *Service) Start() error {
db, err := sql.Open("pgx", p.dsn)
if err != nil {
return err
}
p.DB = db
p.DB.SetMaxOpenConns(p.maxOpen)
p.DB.SetMaxIdleConns(p.maxIdle)
p.DB.SetConnMaxLifetime(p.maxLifetime)
return nil
}
func (p *Service) Stop() {
if p.DB != nil {
p.DB.Close()
}
}

View File

@ -3,25 +3,24 @@ package pgbouncer
import ( import (
"bytes" "bytes"
"database/sql"
_ "embed" _ "embed"
"fmt" "fmt"
"strconv" "strconv"
// Required for SQL framework driver
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/influxdata/telegraf" "github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config" "github.com/influxdata/telegraf/plugins/common/postgresql"
"github.com/influxdata/telegraf/plugins/inputs" "github.com/influxdata/telegraf/plugins/inputs"
"github.com/influxdata/telegraf/plugins/inputs/postgresql"
) )
//go:embed sample.conf //go:embed sample.conf
var sampleConfig string var sampleConfig string
type PgBouncer struct { type PgBouncer struct {
postgresql.Service
ShowCommands []string `toml:"show_commands"` ShowCommands []string `toml:"show_commands"`
postgresql.Config
service *postgresql.Service
} }
var ignoredColumns = map[string]bool{"user": true, "database": true, "pool_mode": true, var ignoredColumns = map[string]bool{"user": true, "database": true, "pool_mode": true,
@ -33,34 +32,54 @@ func (*PgBouncer) SampleConfig() string {
return sampleConfig return sampleConfig
} }
func (p *PgBouncer) Gather(acc telegraf.Accumulator) error { func (p *PgBouncer) Init() error {
// Set defaults and check settings
if len(p.ShowCommands) == 0 { if len(p.ShowCommands) == 0 {
if err := p.showStats(acc); err != nil { p.ShowCommands = []string{"stats", "pools"}
return err }
for _, cmd := range p.ShowCommands {
switch cmd {
case "stats", "pools", "lists", "databases":
default:
return fmt.Errorf("invalid setting %q for 'show_command'", cmd)
} }
}
if err := p.showPools(acc); err != nil { // Create a postgres service for the queries
return err service, err := p.Config.CreateService()
} if err != nil {
} else { return err
for _, cmd := range p.ShowCommands { }
switch { p.service = service
case cmd == "stats": return nil
if err := p.showStats(acc); err != nil { }
return err
} func (p *PgBouncer) Start(_ telegraf.Accumulator) error {
case cmd == "pools": return p.service.Start()
if err := p.showPools(acc); err != nil { }
return err
} func (p *PgBouncer) Stop() {
case cmd == "lists": p.service.Stop()
if err := p.showLists(acc); err != nil { }
return err
} func (p *PgBouncer) Gather(acc telegraf.Accumulator) error {
case cmd == "databases": for _, cmd := range p.ShowCommands {
if err := p.showDatabase(acc); err != nil { switch cmd {
return err case "stats":
} if err := p.showStats(acc); err != nil {
return err
}
case "pools":
if err := p.showPools(acc); err != nil {
return err
}
case "lists":
if err := p.showLists(acc); err != nil {
return err
}
case "databases":
if err := p.showDatabase(acc); err != nil {
return err
} }
} }
} }
@ -68,11 +87,7 @@ func (p *PgBouncer) Gather(acc telegraf.Accumulator) error {
return nil return nil
} }
type scanner interface { func (p *PgBouncer) accRow(row *sql.Rows, columns []string) (map[string]string, map[string]*interface{}, error) {
Scan(dest ...interface{}) error
}
func (p *PgBouncer) accRow(row scanner, columns []string) (map[string]string, map[string]*interface{}, error) {
var dbname bytes.Buffer var dbname bytes.Buffer
// this is where we'll store the column name with its *interface{} // this is where we'll store the column name with its *interface{}
@ -103,19 +118,13 @@ func (p *PgBouncer) accRow(row scanner, columns []string) (map[string]string, ma
dbname.WriteString("pgbouncer") dbname.WriteString("pgbouncer")
} }
var tagAddress string
tagAddress, err = p.SanitizedAddress()
if err != nil {
return nil, nil, fmt.Errorf("couldn't get connection data: %w", err)
}
// Return basic tags and the mapped columns // Return basic tags and the mapped columns
return map[string]string{"server": tagAddress, "db": dbname.String()}, columnMap, nil return map[string]string{"server": p.service.SanitizedAddress, "db": dbname.String()}, columnMap, nil
} }
func (p *PgBouncer) showStats(acc telegraf.Accumulator) error { func (p *PgBouncer) showStats(acc telegraf.Accumulator) error {
// STATS // STATS
rows, err := p.DB.Query(`SHOW STATS`) rows, err := p.service.DB.Query(`SHOW STATS`)
if err != nil { if err != nil {
return fmt.Errorf("execution error 'show stats': %w", err) return fmt.Errorf("execution error 'show stats': %w", err)
} }
@ -163,7 +172,7 @@ func (p *PgBouncer) showStats(acc telegraf.Accumulator) error {
func (p *PgBouncer) showPools(acc telegraf.Accumulator) error { func (p *PgBouncer) showPools(acc telegraf.Accumulator) error {
// POOLS // POOLS
poolRows, err := p.DB.Query(`SHOW POOLS`) poolRows, err := p.service.DB.Query(`SHOW POOLS`)
if err != nil { if err != nil {
return fmt.Errorf("execution error 'show pools': %w", err) return fmt.Errorf("execution error 'show pools': %w", err)
} }
@ -209,7 +218,7 @@ func (p *PgBouncer) showPools(acc telegraf.Accumulator) error {
func (p *PgBouncer) showLists(acc telegraf.Accumulator) error { func (p *PgBouncer) showLists(acc telegraf.Accumulator) error {
// LISTS // LISTS
rows, err := p.DB.Query(`SHOW LISTS`) rows, err := p.service.DB.Query(`SHOW LISTS`)
if err != nil { if err != nil {
return fmt.Errorf("execution error 'show lists': %w", err) return fmt.Errorf("execution error 'show lists': %w", err)
} }
@ -250,7 +259,7 @@ func (p *PgBouncer) showLists(acc telegraf.Accumulator) error {
func (p *PgBouncer) showDatabase(acc telegraf.Accumulator) error { func (p *PgBouncer) showDatabase(acc telegraf.Accumulator) error {
// DATABASES // DATABASES
rows, err := p.DB.Query(`SHOW DATABASES`) rows, err := p.service.DB.Query(`SHOW DATABASES`)
if err != nil { if err != nil {
return fmt.Errorf("execution error 'show database': %w", err) return fmt.Errorf("execution error 'show database': %w", err)
} }
@ -298,10 +307,9 @@ func (p *PgBouncer) showDatabase(acc telegraf.Accumulator) error {
func init() { func init() {
inputs.Add("pgbouncer", func() telegraf.Input { inputs.Add("pgbouncer", func() telegraf.Input {
return &PgBouncer{ return &PgBouncer{
Service: postgresql.Service{ Config: postgresql.Config{
MaxIdle: 1, MaxIdle: 1,
MaxOpen: 1, MaxOpen: 1,
MaxLifetime: config.Duration(0),
IsPgBouncer: true, IsPgBouncer: true,
}, },
} }

View File

@ -9,7 +9,7 @@ import (
"github.com/testcontainers/testcontainers-go/wait" "github.com/testcontainers/testcontainers-go/wait"
"github.com/influxdata/telegraf/config" "github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/plugins/inputs/postgresql" "github.com/influxdata/telegraf/plugins/common/postgresql"
"github.com/influxdata/telegraf/testutil" "github.com/influxdata/telegraf/testutil"
) )
@ -29,8 +29,7 @@ func TestPgBouncerGeneratesMetricsIntegration(t *testing.T) {
}, },
WaitingFor: wait.ForLog("database system is ready to accept connections").WithOccurrence(2), WaitingFor: wait.ForLog("database system is ready to accept connections").WithOccurrence(2),
} }
err := backend.Start() require.NoError(t, backend.Start(), "failed to start container")
require.NoError(t, err, "failed to start container")
defer backend.Terminate() defer backend.Terminate()
container := testutil.Container{ container := testutil.Container{
@ -45,8 +44,7 @@ func TestPgBouncerGeneratesMetricsIntegration(t *testing.T) {
wait.ForLog("LOG process up"), wait.ForLog("LOG process up"),
), ),
} }
err = container.Start() require.NoError(t, container.Start(), "failed to start container")
require.NoError(t, err, "failed to start container")
defer container.Terminate() defer container.Terminate()
addr := fmt.Sprintf( addr := fmt.Sprintf(
@ -56,14 +54,16 @@ func TestPgBouncerGeneratesMetricsIntegration(t *testing.T) {
) )
p := &PgBouncer{ p := &PgBouncer{
Service: postgresql.Service{ Config: postgresql.Config{
Address: config.NewSecret([]byte(addr)), Address: config.NewSecret([]byte(addr)),
IsPgBouncer: true, IsPgBouncer: true,
}, },
} }
require.NoError(t, p.Init())
var acc testutil.Accumulator var acc testutil.Accumulator
require.NoError(t, p.Start(&acc)) require.NoError(t, p.Start(&acc))
defer p.Stop()
require.NoError(t, p.Gather(&acc)) require.NoError(t, p.Gather(&acc))
intMetricsPgBouncer := []string{ intMetricsPgBouncer := []string{
@ -145,15 +145,17 @@ func TestPgBouncerGeneratesMetricsIntegrationShowCommands(t *testing.T) {
) )
p := &PgBouncer{ p := &PgBouncer{
Service: postgresql.Service{ Config: postgresql.Config{
Address: config.NewSecret([]byte(addr)), Address: config.NewSecret([]byte(addr)),
IsPgBouncer: true, IsPgBouncer: true,
}, },
ShowCommands: []string{"pools", "lists", "databases"}, ShowCommands: []string{"pools", "lists", "databases"},
} }
require.NoError(t, p.Init())
var acc testutil.Accumulator var acc testutil.Accumulator
require.NoError(t, p.Start(&acc)) require.NoError(t, p.Start(&acc))
defer p.Stop()
require.NoError(t, p.Gather(&acc)) require.NoError(t, p.Gather(&acc))
intMetricsPgBouncerPools := []string{ intMetricsPgBouncerPools := []string{

View File

@ -3,15 +3,13 @@ package postgresql
import ( import (
"bytes" "bytes"
"database/sql"
_ "embed" _ "embed"
"fmt" "fmt"
"strings" "strings"
// Blank import required to register driver
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/influxdata/telegraf" "github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config" "github.com/influxdata/telegraf/plugins/common/postgresql"
"github.com/influxdata/telegraf/plugins/inputs" "github.com/influxdata/telegraf/plugins/inputs"
) )
@ -19,10 +17,12 @@ import (
var sampleConfig string var sampleConfig string
type Postgresql struct { type Postgresql struct {
Service
Databases []string `toml:"databases"` Databases []string `toml:"databases"`
IgnoredDatabases []string `toml:"ignored_databases"` IgnoredDatabases []string `toml:"ignored_databases"`
PreparedStatements bool `toml:"prepared_statements"` PreparedStatements bool `toml:"prepared_statements"`
postgresql.Config
service *postgresql.Service
} }
var ignoredColumns = map[string]bool{"stats_reset": true} var ignoredColumns = map[string]bool{"stats_reset": true}
@ -31,22 +31,28 @@ func (*Postgresql) SampleConfig() string {
return sampleConfig return sampleConfig
} }
func (p *Postgresql) IgnoredColumns() map[string]bool {
return ignoredColumns
}
func (p *Postgresql) Init() error { func (p *Postgresql) Init() error {
p.Service.IsPgBouncer = !p.PreparedStatements p.IsPgBouncer = !p.PreparedStatements
service, err := p.Config.CreateService()
if err != nil {
return err
}
p.service = service
return nil return nil
} }
func (p *Postgresql) Gather(acc telegraf.Accumulator) error { func (p *Postgresql) Start(_ telegraf.Accumulator) error {
var ( return p.service.Start()
err error }
query string
columns []string
)
func (p *Postgresql) Stop() {
p.service.Stop()
}
func (p *Postgresql) Gather(acc telegraf.Accumulator) error {
var query string
if len(p.Databases) == 0 && len(p.IgnoredDatabases) == 0 { if len(p.Databases) == 0 && len(p.IgnoredDatabases) == 0 {
query = `SELECT * FROM pg_stat_database` query = `SELECT * FROM pg_stat_database`
} else if len(p.IgnoredDatabases) != 0 { } else if len(p.IgnoredDatabases) != 0 {
@ -57,7 +63,7 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error {
strings.Join(p.Databases, "','")) strings.Join(p.Databases, "','"))
} }
rows, err := p.DB.Query(query) rows, err := p.service.DB.Query(query)
if err != nil { if err != nil {
return err return err
} }
@ -65,7 +71,8 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error {
defer rows.Close() defer rows.Close()
// grab the column information from the result // grab the column information from the result
if columns, err = rows.Columns(); err != nil { columns, err := rows.Columns()
if err != nil {
return err return err
} }
@ -78,7 +85,7 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error {
query = `SELECT * FROM pg_stat_bgwriter` query = `SELECT * FROM pg_stat_bgwriter`
bgWriterRow, err := p.DB.Query(query) bgWriterRow, err := p.service.DB.Query(query)
if err != nil { if err != nil {
return err return err
} }
@ -91,8 +98,7 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error {
} }
for bgWriterRow.Next() { for bgWriterRow.Next() {
err = p.accRow(bgWriterRow, acc, columns) if err := p.accRow(bgWriterRow, acc, columns); err != nil {
if err != nil {
return err return err
} }
} }
@ -100,11 +106,7 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error {
return bgWriterRow.Err() return bgWriterRow.Err()
} }
type scanner interface { func (p *Postgresql) accRow(row *sql.Rows, acc telegraf.Accumulator, columns []string) error {
Scan(dest ...interface{}) error
}
func (p *Postgresql) accRow(row scanner, acc telegraf.Accumulator, columns []string) error {
var dbname bytes.Buffer var dbname bytes.Buffer
// this is where we'll store the column name with its *interface{} // this is where we'll store the column name with its *interface{}
@ -120,15 +122,8 @@ func (p *Postgresql) accRow(row scanner, acc telegraf.Accumulator, columns []str
columnVars = append(columnVars, columnMap[columns[i]]) columnVars = append(columnVars, columnMap[columns[i]])
} }
tagAddress, err := p.SanitizedAddress()
if err != nil {
return err
}
// deconstruct array of variables and send to Scan // deconstruct array of variables and send to Scan
err = row.Scan(columnVars...) if err := row.Scan(columnVars...); err != nil {
if err != nil {
return err return err
} }
if columnMap["datname"] != nil { if columnMap["datname"] != nil {
@ -140,13 +135,10 @@ func (p *Postgresql) accRow(row scanner, acc telegraf.Accumulator, columns []str
dbname.WriteString("postgres_global") dbname.WriteString("postgres_global")
} }
} else { } else {
database, err := p.GetConnectDatabase(tagAddress) dbname.WriteString(p.service.ConnectionDatabase)
if err != nil {
return err
}
dbname.WriteString(database)
} }
tagAddress := p.service.SanitizedAddress
tags := map[string]string{"server": tagAddress, "db": dbname.String()} tags := map[string]string{"server": tagAddress, "db": dbname.String()}
fields := make(map[string]interface{}) fields := make(map[string]interface{})
@ -164,10 +156,9 @@ func (p *Postgresql) accRow(row scanner, acc telegraf.Accumulator, columns []str
func init() { func init() {
inputs.Add("postgresql", func() telegraf.Input { inputs.Add("postgresql", func() telegraf.Input {
return &Postgresql{ return &Postgresql{
Service: Service{ Config: postgresql.Config{
MaxIdle: 1, MaxIdle: 1,
MaxOpen: 1, MaxOpen: 1,
MaxLifetime: config.Duration(0),
}, },
PreparedStatements: true, PreparedStatements: true,
} }

View File

@ -2,7 +2,6 @@ package postgresql
import ( import (
"fmt" "fmt"
"strings"
"testing" "testing"
"github.com/docker/go-connections/nat" "github.com/docker/go-connections/nat"
@ -10,6 +9,7 @@ import (
"github.com/testcontainers/testcontainers-go/wait" "github.com/testcontainers/testcontainers-go/wait"
"github.com/influxdata/telegraf/config" "github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/plugins/common/postgresql"
"github.com/influxdata/telegraf/testutil" "github.com/influxdata/telegraf/testutil"
) )
@ -51,15 +51,17 @@ func TestPostgresqlGeneratesMetricsIntegration(t *testing.T) {
) )
p := &Postgresql{ p := &Postgresql{
Service: Service{ Config: postgresql.Config{
Address: config.NewSecret([]byte(addr)), Address: config.NewSecret([]byte(addr)),
IsPgBouncer: false, IsPgBouncer: false,
}, },
Databases: []string{"postgres"}, Databases: []string{"postgres"},
} }
require.NoError(t, p.Init())
var acc testutil.Accumulator var acc testutil.Accumulator
require.NoError(t, p.Start(&acc)) require.NoError(t, p.Start(&acc))
defer p.Stop()
require.NoError(t, p.Gather(&acc)) require.NoError(t, p.Gather(&acc))
intMetrics := []string{ intMetrics := []string{
@ -142,15 +144,16 @@ func TestPostgresqlTagsMetricsWithDatabaseNameIntegration(t *testing.T) {
) )
p := &Postgresql{ p := &Postgresql{
Service: Service{ Config: postgresql.Config{
Address: config.NewSecret([]byte(addr)), Address: config.NewSecret([]byte(addr)),
}, },
Databases: []string{"postgres"}, Databases: []string{"postgres"},
} }
require.NoError(t, p.Init())
var acc testutil.Accumulator var acc testutil.Accumulator
require.NoError(t, p.Start(&acc)) require.NoError(t, p.Start(&acc))
defer p.Stop()
require.NoError(t, p.Gather(&acc)) require.NoError(t, p.Gather(&acc))
point, ok := acc.Get("postgresql") point, ok := acc.Get("postgresql")
@ -174,14 +177,15 @@ func TestPostgresqlDefaultsToAllDatabasesIntegration(t *testing.T) {
) )
p := &Postgresql{ p := &Postgresql{
Service: Service{ Config: postgresql.Config{
Address: config.NewSecret([]byte(addr)), Address: config.NewSecret([]byte(addr)),
}, },
} }
require.NoError(t, p.Init())
var acc testutil.Accumulator var acc testutil.Accumulator
require.NoError(t, p.Start(&acc)) require.NoError(t, p.Start(&acc))
defer p.Stop()
require.NoError(t, p.Gather(&acc)) require.NoError(t, p.Gather(&acc))
var found bool var found bool
@ -213,16 +217,18 @@ func TestPostgresqlIgnoresUnwantedColumnsIntegration(t *testing.T) {
) )
p := &Postgresql{ p := &Postgresql{
Service: Service{ Config: postgresql.Config{
Address: config.NewSecret([]byte(addr)), Address: config.NewSecret([]byte(addr)),
}, },
} }
require.NoError(t, p.Init())
var acc testutil.Accumulator var acc testutil.Accumulator
require.NoError(t, p.Start(&acc)) require.NoError(t, p.Start(&acc))
defer p.Stop()
require.NoError(t, p.Gather(&acc)) require.NoError(t, p.Gather(&acc))
for col := range p.IgnoredColumns() { for col := range ignoredColumns {
require.False(t, acc.HasMeasurement(col)) require.False(t, acc.HasMeasurement(col))
} }
} }
@ -242,15 +248,16 @@ func TestPostgresqlDatabaseWhitelistTestIntegration(t *testing.T) {
) )
p := &Postgresql{ p := &Postgresql{
Service: Service{ Config: postgresql.Config{
Address: config.NewSecret([]byte(addr)), Address: config.NewSecret([]byte(addr)),
}, },
Databases: []string{"template0"}, Databases: []string{"template0"},
} }
require.NoError(t, p.Init())
var acc testutil.Accumulator var acc testutil.Accumulator
require.NoError(t, p.Start(&acc)) require.NoError(t, p.Start(&acc))
defer p.Stop()
require.NoError(t, p.Gather(&acc)) require.NoError(t, p.Gather(&acc))
var foundTemplate0 = false var foundTemplate0 = false
@ -288,14 +295,16 @@ func TestPostgresqlDatabaseBlacklistTestIntegration(t *testing.T) {
) )
p := &Postgresql{ p := &Postgresql{
Service: Service{ Config: postgresql.Config{
Address: config.NewSecret([]byte(addr)), Address: config.NewSecret([]byte(addr)),
}, },
IgnoredDatabases: []string{"template0"}, IgnoredDatabases: []string{"template0"},
} }
require.NoError(t, p.Init())
var acc testutil.Accumulator var acc testutil.Accumulator
require.NoError(t, p.Start(&acc)) require.NoError(t, p.Start(&acc))
defer p.Stop()
require.NoError(t, p.Gather(&acc)) require.NoError(t, p.Gather(&acc))
var foundTemplate0 = false var foundTemplate0 = false
@ -317,239 +326,3 @@ func TestPostgresqlDatabaseBlacklistTestIntegration(t *testing.T) {
require.False(t, foundTemplate0) require.False(t, foundTemplate0)
require.True(t, foundTemplate1) require.True(t, foundTemplate1)
} }
func TestURIParsing(t *testing.T) {
tests := []struct {
name string
uri string
expected string
}{
{
name: "short",
uri: `postgres://localhost`,
expected: "host=localhost",
},
{
name: "with port",
uri: `postgres://localhost:5432`,
expected: "host=localhost port=5432",
},
{
name: "with database",
uri: `postgres://localhost/mydb`,
expected: "dbname=mydb host=localhost",
},
{
name: "with additional parameters",
uri: `postgres://localhost/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5`,
expected: "application_name=pgxtest connect_timeout=5 dbname=mydb host=localhost search_path=myschema",
},
{
name: "with database setting in params",
uri: `postgres://localhost:5432/?database=mydb`,
expected: "database=mydb host=localhost port=5432",
},
{
name: "with authentication",
uri: `postgres://jack:secret@localhost:5432/mydb?sslmode=prefer`,
expected: "dbname=mydb host=localhost password=secret port=5432 sslmode=prefer user=jack",
},
{
name: "with spaces",
uri: `postgres://jack%20hunter:secret@localhost/mydb?application_name=pgx%20test`,
expected: "application_name='pgx test' dbname=mydb host=localhost password=secret user='jack hunter'",
},
{
name: "with equal signs",
uri: `postgres://jack%20hunter:secret@localhost/mydb?application_name=pgx%3Dtest`,
expected: "application_name='pgx=test' dbname=mydb host=localhost password=secret user='jack hunter'",
},
{
name: "multiple hosts",
uri: `postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable`,
expected: "dbname=mydb host=foo,bar,baz password=secret port=1,2,3 sslmode=disable user=jack",
},
{
name: "multiple hosts without ports",
uri: `postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable`,
expected: "dbname=mydb host=foo,bar,baz password=secret sslmode=disable user=jack",
},
}
for _, tt := range tests {
// Key value without spaces around equal sign
t.Run(tt.name, func(t *testing.T) {
actual, err := toKeyValue(tt.uri)
require.NoError(t, err)
require.Equalf(t, tt.expected, actual, "initial: %s", tt.uri)
})
}
}
func TestSanitizeAddressKeyValue(t *testing.T) {
keys := []string{"password", "sslcert", "sslkey", "sslmode", "sslrootcert"}
tests := []struct {
name string
value string
}{
{
name: "simple text",
value: `foo`,
},
{
name: "empty values",
value: `''`,
},
{
name: "space in value",
value: `'foo bar'`,
},
{
name: "equal sign in value",
value: `'foo=bar'`,
},
{
name: "escaped quote",
value: `'foo\'s bar'`,
},
{
name: "escaped quote no space",
value: `\'foobar\'s\'`,
},
{
name: "escaped backslash",
value: `'foo bar\\'`,
},
{
name: "escaped quote and backslash",
value: `'foo\\\'s bar'`,
},
{
name: "two escaped backslashes",
value: `'foo bar\\\\'`,
},
{
name: "multiple inline spaces",
value: "'foo \t bar'",
},
{
name: "leading space",
value: `' foo bar'`,
},
{
name: "trailing space",
value: `'foo bar '`,
},
{
name: "multiple equal signs",
value: `'foo===bar'`,
},
{
name: "leading equal sign",
value: `'=foo bar'`,
},
{
name: "trailing equal sign",
value: `'foo bar='`,
},
{
name: "mix of equal signs and spaces",
value: "'foo = a\t===\tbar'",
},
}
for _, tt := range tests {
// Key value without spaces around equal sign
t.Run(tt.name, func(t *testing.T) {
// Generate the DSN from the given keys and value
parts := make([]string, 0, len(keys))
for _, k := range keys {
parts = append(parts, k+"="+tt.value)
}
dsn := strings.Join(parts, " canary=ok ")
plugin := &Postgresql{
Service: Service{
Address: config.NewSecret([]byte(dsn)),
},
}
expected := strings.Join(make([]string, len(keys)), "canary=ok ")
expected = strings.TrimSpace(expected)
actual, err := plugin.SanitizedAddress()
require.NoError(t, err)
require.Equalf(t, expected, actual, "initial: %s", dsn)
})
// Key value with spaces around equal sign
t.Run("spaced "+tt.name, func(t *testing.T) {
// Generate the DSN from the given keys and value
parts := make([]string, 0, len(keys))
for _, k := range keys {
parts = append(parts, k+" = "+tt.value)
}
dsn := strings.Join(parts, " canary=ok ")
plugin := &Postgresql{
Service: Service{
Address: config.NewSecret([]byte(dsn)),
},
}
expected := strings.Join(make([]string, len(keys)), "canary=ok ")
expected = strings.TrimSpace(expected)
actual, err := plugin.SanitizedAddress()
require.NoError(t, err)
require.Equalf(t, expected, actual, "initial: %s", dsn)
})
}
}
func TestSanitizeAddressURI(t *testing.T) {
keys := []string{"password", "sslcert", "sslkey", "sslmode", "sslrootcert"}
tests := []struct {
name string
value string
}{
{
name: "simple text",
value: `foo`,
},
{
name: "empty values",
value: ``,
},
{
name: "space in value",
value: `foo bar`,
},
{
name: "equal sign in value",
value: `foo=bar`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Generate the DSN from the given keys and value
value := strings.ReplaceAll(tt.value, "=", "%3D")
value = strings.ReplaceAll(value, " ", "%20")
parts := make([]string, 0, len(keys))
for _, k := range keys {
parts = append(parts, k+"="+value)
}
dsn := "postgresql://user:passwd@localhost:5432/db?" + strings.Join(parts, "&")
plugin := &Postgresql{
Service: Service{
Address: config.NewSecret([]byte(dsn)),
},
}
expected := "dbname=db host=localhost port=5432 user=user"
actual, err := plugin.SanitizedAddress()
require.NoError(t, err)
require.Equalf(t, expected, actual, "initial: %s", dsn)
})
}
}

View File

@ -5,9 +5,7 @@ import (
"bytes" "bytes"
_ "embed" _ "embed"
"fmt" "fmt"
"io"
"os" "os"
"strconv"
"strings" "strings"
"time" "time"
@ -15,36 +13,36 @@ import (
_ "github.com/jackc/pgx/v4/stdlib" _ "github.com/jackc/pgx/v4/stdlib"
"github.com/influxdata/telegraf" "github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config" "github.com/influxdata/telegraf/internal"
"github.com/influxdata/telegraf/plugins/common/postgresql"
"github.com/influxdata/telegraf/plugins/inputs" "github.com/influxdata/telegraf/plugins/inputs"
"github.com/influxdata/telegraf/plugins/inputs/postgresql"
) )
//go:embed sample.conf //go:embed sample.conf
var sampleConfig string var sampleConfig string
type Postgresql struct { type Postgresql struct {
postgresql.Service Databases []string `deprecated:"1.22.4;use the sqlquery option to specify database to use"`
Databases []string `deprecated:"1.22.4;use the sqlquery option to specify database to use"` Query []query `toml:"query"`
AdditionalTags []string PreparedStatements bool `toml:"prepared_statements"`
Timestamp string Log telegraf.Logger `toml:"-"`
Query query postgresql.Config
Debug bool
PreparedStatements bool `toml:"prepared_statements"`
Log telegraf.Logger service *postgresql.Service
} }
type query []struct { type query struct {
Sqlquery string Sqlquery string `toml:"sqlquery"`
Script string Script string `toml:"script"`
Version int `deprecated:"1.28.0;use minVersion to specify minimal DB version this query supports"` Version int `deprecated:"1.28.0;use minVersion to specify minimal DB version this query supports"`
MinVersion int `toml:"min_version"` MinVersion int `toml:"min_version"`
MaxVersion int `toml:"max_version"` MaxVersion int `toml:"max_version"`
Withdbname bool `deprecated:"1.22.4;use the sqlquery option to specify database to use"` Withdbname bool `deprecated:"1.22.4;use the sqlquery option to specify database to use"`
Tagvalue string Tagvalue string `toml:"tagvalue"`
Measurement string Measurement string `toml:"measurement"`
Timestamp string Timestamp string `toml:"timestamp"`
additionalTags map[string]bool
} }
var ignoredColumns = map[string]bool{"stats_reset": true} var ignoredColumns = map[string]bool{"stats_reset": true}
@ -54,133 +52,105 @@ func (*Postgresql) SampleConfig() string {
} }
func (p *Postgresql) Init() error { func (p *Postgresql) Init() error {
var err error // Set defaults for the queries
for i := range p.Query { for i, q := range p.Query {
if p.Query[i].Sqlquery == "" { if q.Sqlquery == "" {
p.Query[i].Sqlquery, err = ReadQueryFromFile(p.Query[i].Script) query, err := os.ReadFile(q.Script)
if err != nil { if err != nil {
return err return err
} }
q.Sqlquery = string(query)
} }
if p.Query[i].MinVersion == 0 { if q.MinVersion == 0 {
p.Query[i].MinVersion = p.Query[i].Version q.MinVersion = q.Version
} }
} if q.Measurement == "" {
p.Service.IsPgBouncer = !p.PreparedStatements q.Measurement = "postgresql"
return nil
}
func (p *Postgresql) IgnoredColumns() map[string]bool {
return ignoredColumns
}
func ReadQueryFromFile(filePath string) (string, error) {
file, err := os.Open(filePath)
if err != nil {
return "", err
}
defer file.Close()
query, err := io.ReadAll(file)
if err != nil {
return "", err
}
return string(query), err
}
func (p *Postgresql) Gather(acc telegraf.Accumulator) error {
var (
err error
sqlQuery string
queryAddon string
dbVersion int
query string
measName string
)
// Retrieving the database version
query = `SELECT setting::integer / 100 AS version FROM pg_settings WHERE name = 'server_version_num'`
if err = p.DB.QueryRow(query).Scan(&dbVersion); err != nil {
dbVersion = 0
}
// We loop in order to process each query
// Query is not run if Database version does not match the query version.
for i := range p.Query {
sqlQuery = p.Query[i].Sqlquery
if p.Query[i].Measurement != "" {
measName = p.Query[i].Measurement
} else {
measName = "postgresql"
} }
if p.Query[i].Withdbname { var queryAddon string
if q.Withdbname {
if len(p.Databases) != 0 { if len(p.Databases) != 0 {
queryAddon = fmt.Sprintf(` IN ('%s')`, strings.Join(p.Databases, "','")) queryAddon = fmt.Sprintf(` IN ('%s')`, strings.Join(p.Databases, "','"))
} else { } else {
queryAddon = " is not null" queryAddon = " is not null"
} }
} else {
queryAddon = ""
} }
sqlQuery += queryAddon q.Sqlquery += queryAddon
maxVer := p.Query[i].MaxVersion q.additionalTags = make(map[string]bool)
if q.Tagvalue != "" {
for _, tag := range strings.Split(q.Tagvalue, ",") {
q.additionalTags[tag] = true
}
}
p.Query[i] = q
}
p.Config.IsPgBouncer = !p.PreparedStatements
if p.Query[i].MinVersion <= dbVersion && (maxVer == 0 || maxVer > dbVersion) { // Create a service to access the PostgreSQL server
p.gatherMetricsFromQuery(acc, sqlQuery, p.Query[i].Tagvalue, p.Query[i].Timestamp, measName) service, err := p.Config.CreateService()
if err != nil {
return err
}
p.service = service
return nil
}
func (p *Postgresql) Start(_ telegraf.Accumulator) error {
return p.service.Start()
}
func (p *Postgresql) Stop() {
p.service.Stop()
}
func (p *Postgresql) Gather(acc telegraf.Accumulator) error {
// Retrieving the database version
query := `SELECT setting::integer / 100 AS version FROM pg_settings WHERE name = 'server_version_num'`
var dbVersion int
if err := p.service.DB.QueryRow(query).Scan(&dbVersion); err != nil {
dbVersion = 0
}
// We loop in order to process each query
// Query is not run if Database version does not match the query version.
for _, q := range p.Query {
if q.MinVersion <= dbVersion && (q.MaxVersion == 0 || q.MaxVersion > dbVersion) {
acc.AddError(p.gatherMetricsFromQuery(acc, q))
} }
} }
return nil return nil
} }
func (p *Postgresql) gatherMetricsFromQuery(acc telegraf.Accumulator, sqlQuery string, tagValue string, timestamp string, measName string) { func (p *Postgresql) gatherMetricsFromQuery(acc telegraf.Accumulator, q query) error {
var columns []string rows, err := p.service.DB.Query(q.Sqlquery)
rows, err := p.DB.Query(sqlQuery)
if err != nil { if err != nil {
acc.AddError(err) return err
return
} }
defer rows.Close() defer rows.Close()
// grab the column information from the result // grab the column information from the result
if columns, err = rows.Columns(); err != nil { columns, err := rows.Columns()
acc.AddError(err) if err != nil {
return return err
} }
p.AdditionalTags = nil
if tagValue != "" {
tagList := strings.Split(tagValue, ",")
p.AdditionalTags = append(p.AdditionalTags, tagList...)
}
p.Timestamp = timestamp
for rows.Next() { for rows.Next() {
err = p.accRow(measName, rows, acc, columns) if err := p.accRow(acc, rows, columns, q); err != nil {
if err != nil { return err
acc.AddError(err)
break
} }
} }
return nil
} }
type scanner interface { type scanner interface {
Scan(dest ...interface{}) error Scan(dest ...interface{}) error
} }
func (p *Postgresql) accRow(measName string, row scanner, acc telegraf.Accumulator, columns []string) error { func (p *Postgresql) accRow(acc telegraf.Accumulator, row scanner, columns []string, q query) error {
var (
err error
dbname bytes.Buffer
tagAddress string
timestamp time.Time
)
// this is where we'll store the column name with its *interface{} // this is where we'll store the column name with its *interface{}
columnMap := make(map[string]*interface{}) columnMap := make(map[string]*interface{})
@ -194,46 +164,34 @@ func (p *Postgresql) accRow(measName string, row scanner, acc telegraf.Accumulat
columnVars = append(columnVars, columnMap[columns[i]]) columnVars = append(columnVars, columnMap[columns[i]])
} }
if tagAddress, err = p.SanitizedAddress(); err != nil {
return err
}
// deconstruct array of variables and send to Scan // deconstruct array of variables and send to Scan
if err := row.Scan(columnVars...); err != nil { if err := row.Scan(columnVars...); err != nil {
return err return err
} }
var dbname bytes.Buffer
if c, ok := columnMap["datname"]; ok && *c != nil { if c, ok := columnMap["datname"]; ok && *c != nil {
// extract the database name from the column map // extract the database name from the column map
switch datname := (*c).(type) { switch datname := (*c).(type) {
case string: case string:
dbname.WriteString(datname) dbname.WriteString(datname)
default: default:
database, err := p.GetConnectDatabase(tagAddress) dbname.WriteString(p.service.ConnectionDatabase)
if err != nil {
return err
}
dbname.WriteString(database)
} }
} else { } else {
database, err := p.GetConnectDatabase(tagAddress) dbname.WriteString(p.service.ConnectionDatabase)
if err != nil {
return err
}
dbname.WriteString(database)
} }
// Process the additional tags // Process the additional tags
tags := map[string]string{ tags := map[string]string{
"server": tagAddress, "server": p.service.SanitizedAddress,
"db": dbname.String(), "db": dbname.String(),
} }
// set default timestamp to Now // set default timestamp to Now
timestamp = time.Now() timestamp := time.Now()
fields := make(map[string]interface{}) fields := make(map[string]interface{})
COLUMN:
for col, val := range columnMap { for col, val := range columnMap {
p.Log.Debugf("Column: %s = %T: %v\n", col, *val, *val) p.Log.Debugf("Column: %s = %T: %v\n", col, *val, *val)
_, ignore := ignoredColumns[col] _, ignore := ignoredColumns[col]
@ -241,30 +199,21 @@ COLUMN:
continue continue
} }
if col == p.Timestamp { if col == q.Timestamp {
if v, ok := (*val).(time.Time); ok { if v, ok := (*val).(time.Time); ok {
timestamp = v timestamp = v
} }
continue continue
} }
for _, tag := range p.AdditionalTags { if q.additionalTags[col] {
if col != tag { v, err := internal.ToString(*val)
continue if err != nil {
} p.Log.Debugf("Failed to add %q as additional tag: %v", col, err)
switch v := (*val).(type) { } else {
case string:
tags[col] = v tags[col] = v
case []byte:
tags[col] = string(v)
case int64, int32, int:
tags[col] = fmt.Sprintf("%d", v)
case bool:
tags[col] = strconv.FormatBool(v)
default:
p.Log.Debugf("Failed to add %q as additional tag", col)
} }
continue COLUMN continue
} }
if v, ok := (*val).([]byte); ok { if v, ok := (*val).([]byte); ok {
@ -273,18 +222,16 @@ COLUMN:
fields[col] = *val fields[col] = *val
} }
} }
acc.AddFields(measName, fields, tags, timestamp) acc.AddFields(q.Measurement, fields, tags, timestamp)
return nil return nil
} }
func init() { func init() {
inputs.Add("postgresql_extensible", func() telegraf.Input { inputs.Add("postgresql_extensible", func() telegraf.Input {
return &Postgresql{ return &Postgresql{
Service: postgresql.Service{ Config: postgresql.Config{
MaxIdle: 1, MaxIdle: 1,
MaxOpen: 1, MaxOpen: 1,
MaxLifetime: config.Duration(0),
IsPgBouncer: false,
}, },
PreparedStatements: true, PreparedStatements: true,
} }

View File

@ -11,11 +11,11 @@ import (
"github.com/testcontainers/testcontainers-go/wait" "github.com/testcontainers/testcontainers-go/wait"
"github.com/influxdata/telegraf/config" "github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/plugins/inputs/postgresql" "github.com/influxdata/telegraf/plugins/common/postgresql"
"github.com/influxdata/telegraf/testutil" "github.com/influxdata/telegraf/testutil"
) )
func queryRunner(t *testing.T, q query) *testutil.Accumulator { func queryRunner(t *testing.T, q []query) *testutil.Accumulator {
servicePort := "5432" servicePort := "5432"
container := testutil.Container{ container := testutil.Container{
Image: "postgres:alpine", Image: "postgres:alpine",
@ -29,8 +29,7 @@ func queryRunner(t *testing.T, q query) *testutil.Accumulator {
), ),
} }
err := container.Start() require.NoError(t, container.Start(), "failed to start container")
require.NoError(t, err, "failed to start container")
defer container.Terminate() defer container.Terminate()
addr := fmt.Sprintf( addr := fmt.Sprintf(
@ -41,18 +40,20 @@ func queryRunner(t *testing.T, q query) *testutil.Accumulator {
p := &Postgresql{ p := &Postgresql{
Log: testutil.Logger{}, Log: testutil.Logger{},
Service: postgresql.Service{ Config: postgresql.Config{
Address: config.NewSecret([]byte(addr)), Address: config.NewSecret([]byte(addr)),
IsPgBouncer: false, IsPgBouncer: false,
}, },
Databases: []string{"postgres"}, Databases: []string{"postgres"},
Query: q, Query: q,
} }
require.NoError(t, p.Init())
var acc testutil.Accumulator var acc testutil.Accumulator
require.NoError(t, p.Init())
require.NoError(t, p.Start(&acc)) require.NoError(t, p.Start(&acc))
defer p.Stop()
require.NoError(t, acc.GatherError(p.Gather)) require.NoError(t, acc.GatherError(p.Gather))
return &acc return &acc
} }
@ -61,12 +62,13 @@ func TestPostgresqlGeneratesMetricsIntegration(t *testing.T) {
t.Skip("Skipping integration test in short mode") t.Skip("Skipping integration test in short mode")
} }
acc := queryRunner(t, query{{ acc := queryRunner(t, []query{{
Sqlquery: "select * from pg_stat_database", Sqlquery: "select * from pg_stat_database",
MinVersion: 901, MinVersion: 901,
Withdbname: false, Withdbname: false,
Tagvalue: "", Tagvalue: "",
}}) }})
testutil.PrintMetrics(acc.GetTelegrafMetrics())
intMetrics := []string{ intMetrics := []string{
"xact_commit", "xact_commit",
@ -161,7 +163,7 @@ func TestPostgresqlQueryOutputTestsIntegration(t *testing.T) {
} }
for q, assertions := range examples { for q, assertions := range examples {
acc := queryRunner(t, query{{ acc := queryRunner(t, []query{{
Sqlquery: q, Sqlquery: q,
MinVersion: 901, MinVersion: 901,
Withdbname: false, Withdbname: false,
@ -178,7 +180,7 @@ func TestPostgresqlFieldOutputIntegration(t *testing.T) {
t.Skip("Skipping integration test in short mode") t.Skip("Skipping integration test in short mode")
} }
acc := queryRunner(t, query{{ acc := queryRunner(t, []query{{
Sqlquery: "select * from pg_stat_database", Sqlquery: "select * from pg_stat_database",
MinVersion: 901, MinVersion: 901,
Withdbname: false, Withdbname: false,
@ -236,7 +238,7 @@ func TestPostgresqlFieldOutputIntegration(t *testing.T) {
} }
func TestPostgresqlSqlScript(t *testing.T) { func TestPostgresqlSqlScript(t *testing.T) {
q := query{{ q := []query{{
Script: "testdata/test.sql", Script: "testdata/test.sql",
MinVersion: 901, MinVersion: 901,
Withdbname: false, Withdbname: false,
@ -250,17 +252,18 @@ func TestPostgresqlSqlScript(t *testing.T) {
p := &Postgresql{ p := &Postgresql{
Log: testutil.Logger{}, Log: testutil.Logger{},
Service: postgresql.Service{ Config: postgresql.Config{
Address: config.NewSecret([]byte(addr)), Address: config.NewSecret([]byte(addr)),
IsPgBouncer: false, IsPgBouncer: false,
}, },
Databases: []string{"postgres"}, Databases: []string{"postgres"},
Query: q, Query: q,
} }
var acc testutil.Accumulator
require.NoError(t, p.Init()) require.NoError(t, p.Init())
require.NoError(t, p.Start(&acc))
var acc testutil.Accumulator
require.NoError(t, p.Start(&acc))
defer p.Stop()
require.NoError(t, acc.GatherError(p.Gather)) require.NoError(t, acc.GatherError(p.Gather))
} }
@ -276,17 +279,19 @@ func TestPostgresqlIgnoresUnwantedColumnsIntegration(t *testing.T) {
p := &Postgresql{ p := &Postgresql{
Log: testutil.Logger{}, Log: testutil.Logger{},
Service: postgresql.Service{ Config: postgresql.Config{
Address: config.NewSecret([]byte(addr)), Address: config.NewSecret([]byte(addr)),
}, },
} }
require.NoError(t, p.Init())
var acc testutil.Accumulator var acc testutil.Accumulator
require.NoError(t, p.Start(&acc)) require.NoError(t, p.Start(&acc))
defer p.Stop()
require.NoError(t, acc.GatherError(p.Gather)) require.NoError(t, acc.GatherError(p.Gather))
require.NotEmpty(t, p.IgnoredColumns())
for col := range p.IgnoredColumns() { require.NotEmpty(t, ignoredColumns)
for col := range ignoredColumns {
require.False(t, acc.HasMeasurement(col)) require.False(t, acc.HasMeasurement(col))
} }
} }
@ -294,10 +299,12 @@ func TestPostgresqlIgnoresUnwantedColumnsIntegration(t *testing.T) {
func TestAccRow(t *testing.T) { func TestAccRow(t *testing.T) {
p := Postgresql{ p := Postgresql{
Log: testutil.Logger{}, Log: testutil.Logger{},
Service: postgresql.Service{ Config: postgresql.Config{
Address: config.NewSecret(nil),
OutputAddress: "server", OutputAddress: "server",
}, },
} }
require.NoError(t, p.Init())
var acc testutil.Accumulator var acc testutil.Accumulator
columns := []string{"datname", "cat"} columns := []string{"datname", "cat"}
@ -330,7 +337,8 @@ func TestAccRow(t *testing.T) {
}, },
} }
for _, tt := range tests { for _, tt := range tests {
require.NoError(t, p.accRow("pgTEST", tt.fields, &acc, columns)) q := query{Measurement: "pgTEST", additionalTags: make(map[string]bool)}
require.NoError(t, p.accRow(&acc, tt.fields, columns, q))
require.Len(t, acc.Metrics, 1) require.Len(t, acc.Metrics, 1)
metric := acc.Metrics[0] metric := acc.Metrics[0]
require.Equal(t, tt.dbName, metric.Tags["db"]) require.Equal(t, tt.dbName, metric.Tags["db"])