diff --git a/CHANGELOG.md b/CHANGELOG.md index bb16492d7..e2a66533d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ consecutive messages exceeds the timeout. [#14837](https://github.com/influxdata/telegraf/pull/14828) sets the timeout to infinite (i.e zero) as this is the expected behavior. +- With correctly sanitizing PostgreSQL addresses ([PR #14829](https://github.com/influxdata/telegraf/pull/14829)) + the `server` tag value for a URI-format address might change in case it + contains spaces, backslashes or single-quotes in non-redacted parameters. ## v1.29.5 [2024-02-20] diff --git a/plugins/inputs/postgresql/postgresql_test.go b/plugins/inputs/postgresql/postgresql_test.go index fa54096ed..dfb5da0d2 100644 --- a/plugins/inputs/postgresql/postgresql_test.go +++ b/plugins/inputs/postgresql/postgresql_test.go @@ -2,6 +2,7 @@ package postgresql import ( "fmt" + "strings" "testing" "github.com/docker/go-connections/nat" @@ -316,3 +317,239 @@ func TestPostgresqlDatabaseBlacklistTestIntegration(t *testing.T) { require.False(t, foundTemplate0) require.True(t, foundTemplate1) } + +func TestURIParsing(t *testing.T) { + tests := []struct { + name string + uri string + expected string + }{ + { + name: "short", + uri: `postgres://localhost`, + expected: "host=localhost", + }, + { + name: "with port", + uri: `postgres://localhost:5432`, + expected: "host=localhost port=5432", + }, + { + name: "with database", + uri: `postgres://localhost/mydb`, + expected: "dbname=mydb host=localhost", + }, + { + name: "with additional parameters", + uri: `postgres://localhost/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5`, + expected: "application_name=pgxtest connect_timeout=5 dbname=mydb host=localhost search_path=myschema", + }, + { + name: "with database setting in params", + uri: `postgres://localhost:5432/?database=mydb`, + expected: "database=mydb host=localhost port=5432", + }, + { + name: "with authentication", + uri: 
`postgres://jack:secret@localhost:5432/mydb?sslmode=prefer`, + expected: "dbname=mydb host=localhost password=secret port=5432 sslmode=prefer user=jack", + }, + { + name: "with spaces", + uri: `postgres://jack%20hunter:secret@localhost/mydb?application_name=pgx%20test`, + expected: "application_name='pgx test' dbname=mydb host=localhost password=secret user='jack hunter'", + }, + { + name: "with equal signs", + uri: `postgres://jack%20hunter:secret@localhost/mydb?application_name=pgx%3Dtest`, + expected: "application_name='pgx=test' dbname=mydb host=localhost password=secret user='jack hunter'", + }, + { + name: "multiple hosts", + uri: `postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable`, + expected: "dbname=mydb host=foo,bar,baz password=secret port=1,2,3 sslmode=disable user=jack", + }, + { + name: "multiple hosts without ports", + uri: `postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable`, + expected: "dbname=mydb host=foo,bar,baz password=secret sslmode=disable user=jack", + }, + } + + for _, tt := range tests { + // Key value without spaces around equal sign + t.Run(tt.name, func(t *testing.T) { + actual, err := toKeyValue(tt.uri) + require.NoError(t, err) + require.Equalf(t, tt.expected, actual, "initial: %s", tt.uri) + }) + } +} + +func TestSanitizeAddressKeyValue(t *testing.T) { + keys := []string{"password", "sslcert", "sslkey", "sslmode", "sslrootcert"} + tests := []struct { + name string + value string + }{ + { + name: "simple text", + value: `foo`, + }, + { + name: "empty values", + value: `''`, + }, + { + name: "space in value", + value: `'foo bar'`, + }, + { + name: "equal sign in value", + value: `'foo=bar'`, + }, + { + name: "escaped quote", + value: `'foo\'s bar'`, + }, + { + name: "escaped quote no space", + value: `\'foobar\'s\'`, + }, + { + name: "escaped backslash", + value: `'foo bar\\'`, + }, + { + name: "escaped quote and backslash", + value: `'foo\\\'s bar'`, + }, + { + name: "two escaped backslashes", + value: `'foo 
bar\\\\'`, + }, + { + name: "multiple inline spaces", + value: "'foo \t bar'", + }, + { + name: "leading space", + value: `' foo bar'`, + }, + { + name: "trailing space", + value: `'foo bar '`, + }, + { + name: "multiple equal signs", + value: `'foo===bar'`, + }, + { + name: "leading equal sign", + value: `'=foo bar'`, + }, + { + name: "trailing equal sign", + value: `'foo bar='`, + }, + { + name: "mix of equal signs and spaces", + value: "'foo = a\t===\tbar'", + }, + } + + for _, tt := range tests { + // Key value without spaces around equal sign + t.Run(tt.name, func(t *testing.T) { + // Generate the DSN from the given keys and value + parts := make([]string, 0, len(keys)) + for _, k := range keys { + parts = append(parts, k+"="+tt.value) + } + dsn := strings.Join(parts, " canary=ok ") + + plugin := &Postgresql{ + Service: Service{ + Address: config.NewSecret([]byte(dsn)), + }, + } + + expected := strings.Join(make([]string, len(keys)), "canary=ok ") + expected = strings.TrimSpace(expected) + actual, err := plugin.SanitizedAddress() + require.NoError(t, err) + require.Equalf(t, expected, actual, "initial: %s", dsn) + }) + + // Key value with spaces around equal sign + t.Run("spaced "+tt.name, func(t *testing.T) { + // Generate the DSN from the given keys and value + parts := make([]string, 0, len(keys)) + for _, k := range keys { + parts = append(parts, k+" = "+tt.value) + } + dsn := strings.Join(parts, " canary=ok ") + + plugin := &Postgresql{ + Service: Service{ + Address: config.NewSecret([]byte(dsn)), + }, + } + + expected := strings.Join(make([]string, len(keys)), "canary=ok ") + expected = strings.TrimSpace(expected) + actual, err := plugin.SanitizedAddress() + require.NoError(t, err) + require.Equalf(t, expected, actual, "initial: %s", dsn) + }) + } +} + +func TestSanitizeAddressURI(t *testing.T) { + keys := []string{"password", "sslcert", "sslkey", "sslmode", "sslrootcert"} + tests := []struct { + name string + value string + }{ + { + name: "simple text", 
+ value: `foo`, + }, + { + name: "empty values", + value: ``, + }, + { + name: "space in value", + value: `foo bar`, + }, + { + name: "equal sign in value", + value: `foo=bar`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Generate the DSN from the given keys and value + value := strings.ReplaceAll(tt.value, "=", "%3D") + value = strings.ReplaceAll(value, " ", "%20") + parts := make([]string, 0, len(keys)) + for _, k := range keys { + parts = append(parts, k+"="+value) + } + dsn := "postgresql://user:passwd@localhost:5432/db?" + strings.Join(parts, "&") + + plugin := &Postgresql{ + Service: Service{ + Address: config.NewSecret([]byte(dsn)), + }, + } + + expected := "dbname=db host=localhost port=5432 user=user" + actual, err := plugin.SanitizedAddress() + require.NoError(t, err) + require.Equalf(t, expected, actual, "initial: %s", dsn) + }) + } +} diff --git a/plugins/inputs/postgresql/service.go b/plugins/inputs/postgresql/service.go index 6c76d76f3..d9447b4d3 100644 --- a/plugins/inputs/postgresql/service.go +++ b/plugins/inputs/postgresql/service.go @@ -17,72 +17,76 @@ import ( "github.com/influxdata/telegraf/config" ) -// pulled from lib/pq -// ParseURL no longer needs to be used by clients of this library since supplying a URL as a -// connection string to sql.Open() is now supported: -// -// sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full") -// -// It remains exported here for backwards-compatibility. -// -// ParseURL converts a url to a connection string for driver.Open. 
// toKeyValue converts a URI-formatted PostgreSQL connection string
// ("postgres://..." or "postgresql://...") into the space-separated
// "key=value" format with the settings sorted alphabetically.
//
// Values that are empty or contain spaces, equal signs, single quotes or
// backslashes are wrapped in single quotes, with embedded backslashes and
// single quotes escaped by a backslash. Multiple comma-separated "host:port"
// elements are split into one "host" and one "port" list. Repeated query
// parameters are joined with commas.
//
// An error is returned if the URI cannot be parsed, uses another scheme or
// contains a malformed host element.
//
// Based on parseURLSettings() at https://github.com/jackc/pgx/blob/master/pgconn/config.go
func toKeyValue(uri string) (string, error) {
	u, err := url.Parse(uri)
	if err != nil {
		return "", fmt.Errorf("parsing URI failed: %w", err)
	}

	// Check the protocol
	if u.Scheme != "postgres" && u.Scheme != "postgresql" {
		return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
	}

	quoteIfNecessary := func(v string) string {
		// An empty value must be quoted explicitly. Otherwise the parser
		// skips the blank after the equal sign and takes the following
		// "key=value" token as the value of this setting.
		if v == "" {
			return "''"
		}
		if !strings.ContainsAny(v, ` ='\`) {
			return v
		}
		r := strings.ReplaceAll(v, `\`, `\\`)
		r = strings.ReplaceAll(r, `'`, `\'`)
		return "'" + r + "'"
	}

	// Extract the parameters; decode the query only once
	query := u.Query()
	parts := make([]string, 0, len(query)+5)
	if u.User != nil {
		parts = append(parts, "user="+quoteIfNecessary(u.User.Username()))
		if password, found := u.User.Password(); found {
			parts = append(parts, "password="+quoteIfNecessary(password))
		}
	}

	// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
	hostParts := strings.Split(u.Host, ",")
	hosts := make([]string, 0, len(hostParts))
	ports := make([]string, 0, len(hostParts))
	var anyPortSet bool
	for _, host := range hostParts {
		if host == "" {
			continue
		}

		h, p, err := net.SplitHostPort(host)
		if err != nil {
			if !strings.Contains(err.Error(), "missing port") {
				return "", fmt.Errorf("failed to process host %q: %w", host, err)
			}
			// Without a port the whole element is the host. Strip the
			// brackets of an IPv6 literal to get the same host value as
			// for the variant with a port.
			h = host
			if strings.HasPrefix(h, "[") && strings.HasSuffix(h, "]") {
				h = h[1 : len(h)-1]
			}
		}
		// A trailing colon without a port number ("host:") does not count
		// as a port; this avoids emitting a "port" setting without value.
		if p != "" {
			anyPortSet = true
		}
		hosts = append(hosts, h)
		// Keep the ports aligned with the hosts, empty entries mark hosts
		// without an explicit port.
		ports = append(ports, p)
	}
	if len(hosts) > 0 {
		parts = append(parts, "host="+quoteIfNecessary(strings.Join(hosts, ",")))
	}
	if anyPortSet {
		parts = append(parts, "port="+strings.Join(ports, ","))
	}

	database := strings.TrimLeft(u.Path, "/")
	if database != "" {
		parts = append(parts, "dbname="+quoteIfNecessary(database))
	}

	for k, v := range query {
		parts = append(parts, k+"="+quoteIfNecessary(strings.Join(v, ",")))
	}

	// Required to produce a repeatable output e.g. for tags or testing
	sort.Strings(parts)
	return strings.Join(parts, " "), nil
}

// sanitizer matches the sensitive settings (password and SSL options) of a
// key/value-formatted connection string including the preceding whitespace.
// Any amount of whitespace is accepted around the equal sign. The value is
// either a single-quoted string with backslash escapes or an unquoted token in
// which a backslash escapes the following character, so that an escaped space
// does not terminate the value.
var sanitizer = regexp.MustCompile(`(\s|^)((?:password|sslcert|sslkey|sslmode|sslrootcert)\s*=\s*(?:(?:'(?:[^'\\]|\\.)*')|(?:(?:\\.|\S)+)))`)

// SanitizedAddress utility function to strip sensitive information from the connection string.
-func (p *Service) SanitizedAddress() (sanitizedAddress string, err error) { +func (p *Service) SanitizedAddress() (string, error) { if p.OutputAddress != "" { return p.OutputAddress, nil } - addr, err := p.Address.Get() + // Get the address + addrSecret, err := p.Address.Get() if err != nil { - return sanitizedAddress, fmt.Errorf("getting address for sanitization failed: %w", err) + return "", fmt.Errorf("getting address for sanitization failed: %w", err) } - defer addr.Destroy() + defer addrSecret.Destroy() - var canonicalizedAddress string - if strings.HasPrefix(addr.TemporaryString(), "postgres://") || strings.HasPrefix(addr.TemporaryString(), "postgresql://") { - if canonicalizedAddress, err = parseURL(addr.String()); err != nil { - return sanitizedAddress, err + // Make sure we convert URI-formatted strings into key-values + addr := addrSecret.TemporaryString() + if strings.HasPrefix(addr, "postgres://") || strings.HasPrefix(addr, "postgresql://") { + if addr, err = toKeyValue(addr); err != nil { + return "", err } - } else { - canonicalizedAddress = addr.String() } - return kvMatcher.ReplaceAllString(canonicalizedAddress, ""), nil + // Sanitize the string using a regular expression + sanitized := sanitizer.ReplaceAllString(addr, "") + return strings.TrimSpace(sanitized), nil } // GetConnectDatabase utility function for getting the database to which the connection was made