fix: migrate aws/credentials.go to use NewSession, same functionality but now supports error (#9878)

This commit is contained in:
Sebastian Spaink 2021-10-07 15:47:56 -05:00 committed by GitHub
parent da5727e34c
commit fde637464a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 83 additions and 51 deletions

View File

@ -21,7 +21,7 @@ type CredentialConfig struct {
WebIdentityTokenFile string `toml:"web_identity_token_file"` WebIdentityTokenFile string `toml:"web_identity_token_file"`
} }
func (c *CredentialConfig) Credentials() client.ConfigProvider { func (c *CredentialConfig) Credentials() (client.ConfigProvider, error) {
if c.RoleARN != "" { if c.RoleARN != "" {
return c.assumeCredentials() return c.assumeCredentials()
} }
@ -29,7 +29,7 @@ func (c *CredentialConfig) Credentials() client.ConfigProvider {
return c.rootCredentials() return c.rootCredentials()
} }
func (c *CredentialConfig) rootCredentials() client.ConfigProvider { func (c *CredentialConfig) rootCredentials() (client.ConfigProvider, error) {
config := &aws.Config{ config := &aws.Config{
Region: aws.String(c.Region), Region: aws.String(c.Region),
} }
@ -42,11 +42,14 @@ func (c *CredentialConfig) rootCredentials() client.ConfigProvider {
config.Credentials = credentials.NewSharedCredentials(c.Filename, c.Profile) config.Credentials = credentials.NewSharedCredentials(c.Filename, c.Profile)
} }
return session.New(config) return session.NewSession(config)
} }
func (c *CredentialConfig) assumeCredentials() client.ConfigProvider { func (c *CredentialConfig) assumeCredentials() (client.ConfigProvider, error) {
rootCredentials := c.rootCredentials() rootCredentials, err := c.rootCredentials()
if err != nil {
return nil, err
}
config := &aws.Config{ config := &aws.Config{
Region: aws.String(c.Region), Region: aws.String(c.Region),
Endpoint: &c.EndpointURL, Endpoint: &c.EndpointURL,
@ -58,5 +61,5 @@ func (c *CredentialConfig) assumeCredentials() client.ConfigProvider {
config.Credentials = stscreds.NewCredentials(rootCredentials, c.RoleARN) config.Credentials = stscreds.NewCredentials(rootCredentials, c.RoleARN)
} }
return session.New(config) return session.NewSession(config)
} }

View File

@ -288,7 +288,11 @@ func (c *CloudWatch) initializeCloudWatch() error {
} }
loglevel := aws.LogOff loglevel := aws.LogOff
c.client = cwClient.New(c.CredentialConfig.Credentials(), cfg.WithLogLevel(loglevel)) p, err := c.CredentialConfig.Credentials()
if err != nil {
return err
}
c.client = cwClient.New(p, cfg.WithLogLevel(loglevel))
// Initialize regex matchers for each Dimension value. // Initialize regex matchers for each Dimension value.
for _, m := range c.Metrics { for _, m := range c.Metrics {

View File

@ -153,15 +153,15 @@ func (k *KinesisConsumer) SetParser(parser parsers.Parser) {
} }
func (k *KinesisConsumer) connect(ac telegraf.Accumulator) error { func (k *KinesisConsumer) connect(ac telegraf.Accumulator) error {
client := kinesis.New(k.CredentialConfig.Credentials()) p, err := k.CredentialConfig.Credentials()
if err != nil {
return err
}
client := kinesis.New(p)
k.checkpoint = &noopCheckpoint{} k.checkpoint = &noopCheckpoint{}
if k.DynamoDB != nil { if k.DynamoDB != nil {
var err error p, err := (&internalaws.CredentialConfig{
k.checkpoint, err = ddb.New(
k.DynamoDB.AppName,
k.DynamoDB.TableName,
ddb.WithDynamoClient(dynamodb.New((&internalaws.CredentialConfig{
Region: k.Region, Region: k.Region,
AccessKey: k.AccessKey, AccessKey: k.AccessKey,
SecretKey: k.SecretKey, SecretKey: k.SecretKey,
@ -170,7 +170,14 @@ func (k *KinesisConsumer) connect(ac telegraf.Accumulator) error {
Filename: k.Filename, Filename: k.Filename,
Token: k.Token, Token: k.Token,
EndpointURL: k.EndpointURL, EndpointURL: k.EndpointURL,
}).Credentials())), }).Credentials()
if err != nil {
return err
}
k.checkpoint, err = ddb.New(
k.DynamoDB.AppName,
k.DynamoDB.TableName,
ddb.WithDynamoClient(dynamodb.New(p)),
ddb.WithMaxInterval(time.Second*10), ddb.WithMaxInterval(time.Second*10),
) )
if err != nil { if err != nil {

View File

@ -198,7 +198,11 @@ func (c *CloudWatch) Description() string {
} }
func (c *CloudWatch) Connect() error { func (c *CloudWatch) Connect() error {
c.svc = cloudwatch.New(c.CredentialConfig.Credentials()) p, err := c.CredentialConfig.Credentials()
if err != nil {
return err
}
c.svc = cloudwatch.New(p)
return nil return nil
} }

View File

@ -187,7 +187,11 @@ func (c *CloudWatchLogs) Connect() error {
var logGroupsOutput = &cloudwatchlogs.DescribeLogGroupsOutput{NextToken: &dummyToken} var logGroupsOutput = &cloudwatchlogs.DescribeLogGroupsOutput{NextToken: &dummyToken}
var err error var err error
c.svc = cloudwatchlogs.New(c.CredentialConfig.Credentials()) p, err := c.CredentialConfig.Credentials()
if err != nil {
return err
}
c.svc = cloudwatchlogs.New(p)
if c.svc == nil { if c.svc == nil {
return fmt.Errorf("can't create cloudwatch logs service endpoint") return fmt.Errorf("can't create cloudwatch logs service endpoint")
} }

View File

@ -126,9 +126,13 @@ func (k *KinesisOutput) Connect() error {
k.Log.Infof("Establishing a connection to Kinesis in %s", k.Region) k.Log.Infof("Establishing a connection to Kinesis in %s", k.Region)
} }
svc := kinesis.New(k.CredentialConfig.Credentials()) p, err := k.CredentialConfig.Credentials()
if err != nil {
return err
}
svc := kinesis.New(p)
_, err := svc.DescribeStreamSummary(&kinesis.DescribeStreamSummaryInput{ _, err = svc.DescribeStreamSummary(&kinesis.DescribeStreamSummaryInput{
StreamName: aws.String(k.StreamName), StreamName: aws.String(k.StreamName),
}) })
k.svc = svc k.svc = svc

View File

@ -169,9 +169,12 @@ var sampleConfig = `
` `
// WriteFactory function provides a way to mock the client instantiation for testing purposes. // WriteFactory function provides a way to mock the client instantiation for testing purposes.
var WriteFactory = func(credentialConfig *internalaws.CredentialConfig) WriteClient { var WriteFactory = func(credentialConfig *internalaws.CredentialConfig) (WriteClient, error) {
configProvider := credentialConfig.Credentials() configProvider, err := credentialConfig.Credentials()
return timestreamwrite.New(configProvider) if err != nil {
return nil, err
}
return timestreamwrite.New(configProvider), nil
} }
func (t *Timestream) Connect() error { func (t *Timestream) Connect() error {
@ -221,7 +224,10 @@ func (t *Timestream) Connect() error {
t.Log.Infof("Constructing Timestream client for '%s' mode", t.MappingMode) t.Log.Infof("Constructing Timestream client for '%s' mode", t.MappingMode)
svc := WriteFactory(&t.CredentialConfig) svc, err := WriteFactory(&t.CredentialConfig)
if err != nil {
return err
}
if t.DescribeDatabaseOnStart { if t.DescribeDatabaseOnStart {
t.Log.Infof("Describing database '%s' in region '%s'", t.DatabaseName, t.Region) t.Log.Infof("Describing database '%s' in region '%s'", t.DatabaseName, t.Region)

View File

@ -2,7 +2,6 @@ package timestream_test
import ( import (
"fmt" "fmt"
"github.com/aws/aws-sdk-go/aws/awserr"
"reflect" "reflect"
"sort" "sort"
"strconv" "strconv"
@ -10,6 +9,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/timestreamwrite" "github.com/aws/aws-sdk-go/service/timestreamwrite"
"github.com/influxdata/telegraf" "github.com/influxdata/telegraf"
@ -53,10 +54,9 @@ func (m *mockTimestreamClient) DescribeDatabase(*timestreamwrite.DescribeDatabas
func TestConnectValidatesConfigParameters(t *testing.T) { func TestConnectValidatesConfigParameters(t *testing.T) {
assertions := assert.New(t) assertions := assert.New(t)
ts.WriteFactory = func(credentialConfig *internalaws.CredentialConfig) ts.WriteClient { ts.WriteFactory = func(credentialConfig *internalaws.CredentialConfig) (ts.WriteClient, error) {
return &mockTimestreamClient{} return &mockTimestreamClient{}, nil
} }
// checking base arguments // checking base arguments
noDatabaseName := ts.Timestream{Log: testutil.Logger{}} noDatabaseName := ts.Timestream{Log: testutil.Logger{}}
assertions.Contains(noDatabaseName.Connect().Error(), "DatabaseName") assertions.Contains(noDatabaseName.Connect().Error(), "DatabaseName")
@ -182,11 +182,11 @@ func (m *mockTimestreamErrorClient) DescribeDatabase(*timestreamwrite.DescribeDa
func TestThrottlingErrorIsReturnedToTelegraf(t *testing.T) { func TestThrottlingErrorIsReturnedToTelegraf(t *testing.T) {
assertions := assert.New(t) assertions := assert.New(t)
ts.WriteFactory = func(credentialConfig *internalaws.CredentialConfig) ts.WriteClient { ts.WriteFactory = func(credentialConfig *internalaws.CredentialConfig) (ts.WriteClient, error) {
return &mockTimestreamErrorClient{ return &mockTimestreamErrorClient{
awserr.New(timestreamwrite.ErrCodeThrottlingException, awserr.New(timestreamwrite.ErrCodeThrottlingException,
"Throttling Test", nil), "Throttling Test", nil),
} }, nil
} }
plugin := ts.Timestream{ plugin := ts.Timestream{
MappingMode: ts.MappingModeMultiTable, MappingMode: ts.MappingModeMultiTable,
@ -210,11 +210,11 @@ func TestThrottlingErrorIsReturnedToTelegraf(t *testing.T) {
func TestRejectedRecordsErrorResultsInMetricsBeingSkipped(t *testing.T) { func TestRejectedRecordsErrorResultsInMetricsBeingSkipped(t *testing.T) {
assertions := assert.New(t) assertions := assert.New(t)
ts.WriteFactory = func(credentialConfig *internalaws.CredentialConfig) ts.WriteClient { ts.WriteFactory = func(credentialConfig *internalaws.CredentialConfig) (ts.WriteClient, error) {
return &mockTimestreamErrorClient{ return &mockTimestreamErrorClient{
awserr.New(timestreamwrite.ErrCodeRejectedRecordsException, awserr.New(timestreamwrite.ErrCodeRejectedRecordsException,
"RejectedRecords Test", nil), "RejectedRecords Test", nil),
} }, nil
} }
plugin := ts.Timestream{ plugin := ts.Timestream{
MappingMode: ts.MappingModeMultiTable, MappingMode: ts.MappingModeMultiTable,