From 167b6e0075b5da04fd3c33a86e68c95b3727e485 Mon Sep 17 00:00:00 2001 From: Sebastian Spaink <3441183+sspaink@users.noreply.github.com> Date: Wed, 1 Sep 2021 22:21:53 -0700 Subject: [PATCH] fix: race condition in cookie test (#9659) --- plugins/common/cookie/cookie.go | 44 ++++++--- plugins/common/cookie/cookie_test.go | 128 +++++++++++---------------- 2 files changed, 80 insertions(+), 92 deletions(-) diff --git a/plugins/common/cookie/cookie.go b/plugins/common/cookie/cookie.go index 10213f78d..e452a50a4 100644 --- a/plugins/common/cookie/cookie.go +++ b/plugins/common/cookie/cookie.go @@ -1,12 +1,14 @@ package cookie import ( + "context" "fmt" "io" "io/ioutil" "net/http" "net/http/cookiejar" "strings" + "sync" "time" clockutil "github.com/benbjohnson/clock" @@ -26,9 +28,25 @@ type CookieAuthConfig struct { Renewal config.Duration `toml:"cookie_auth_renewal"` client *http.Client + wg sync.WaitGroup } func (c *CookieAuthConfig) Start(client *http.Client, log telegraf.Logger, clock clockutil.Clock) (err error) { + if err = c.initializeClient(client); err != nil { + return err + } + + // continual auth renewal if set + if c.Renewal > 0 { + ticker := clock.Ticker(time.Duration(c.Renewal)) + // this context is used in the tests only, it is to cancel the goroutine + go c.authRenewal(context.Background(), ticker, log) + } + + return nil +} + +func (c *CookieAuthConfig) initializeClient(client *http.Client) (err error) { c.client = client if c.Method == "" { @@ -40,23 +58,21 @@ func (c *CookieAuthConfig) Start(client *http.Client, log telegraf.Logger, clock return err } - if err = c.auth(); err != nil { - return err - } + return c.auth() +} - // continual auth renewal if set - if c.Renewal > 0 { - ticker := clock.Ticker(time.Duration(c.Renewal)) - go func() { - for range ticker.C { - if err := c.auth(); err != nil && log != nil { - log.Errorf("renewal failed for %q: %v", c.URL, err) - } +func (c *CookieAuthConfig) authRenewal(ctx context.Context, ticker *clockutil.Ticker, log telegraf.Logger) { + for { + select { + case <-ctx.Done(): + c.wg.Done() + return + case <-ticker.C: + if err := c.auth(); err != nil && log != nil { + log.Errorf("renewal failed for %q: %v", c.URL, err) } - }() + } } - - return nil } func (c *CookieAuthConfig) auth() error { diff --git a/plugins/common/cookie/cookie_test.go b/plugins/common/cookie/cookie_test.go index 036ca2b5b..99269c27c 100644 --- a/plugins/common/cookie/cookie_test.go +++ b/plugins/common/cookie/cookie_test.go @@ -1,6 +1,7 @@ -package cookie_test +package cookie import ( + "context" "fmt" "io/ioutil" "net/http" @@ -12,7 +13,6 @@ import ( clockutil "github.com/benbjohnson/clock" "github.com/google/go-cmp/cmp" "github.com/influxdata/telegraf/config" - "github.com/influxdata/telegraf/plugins/common/cookie" "github.com/influxdata/telegraf/testutil" "github.com/stretchr/testify/require" ) @@ -118,44 +118,25 @@ func TestAuthConfig_Start(t *testing.T) { endpoint string } tests := []struct { - name string - fields fields - args args - wantErr error - assert func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) + name string + fields fields + args args + wantErr error + firstAuthCount int32 + lastAuthCount int32 + firstHTTPResponse int + lastHTTPResponse int }{ - { - name: "zero renewal does not renew", - args: args{ - renewal: 0, - endpoint: authEndpointNoCreds, - }, - assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) { - // should have Cookie Authed once - srv.checkAuthCount(t, 1) - srv.checkResp(t, http.StatusOK) - mock.Add(renewalCheck) - srv.checkAuthCount(t, 1) - srv.checkResp(t, http.StatusOK) - }, - }, { name: "success no creds, no body, default method", args: args{ renewal: renewal, endpoint: authEndpointNoCreds, }, - assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) { - // should have Cookie Authed once - srv.checkAuthCount(t, 1) - // default method set - require.Equal(t, http.MethodPost, c.Method) - srv.checkResp(t, http.StatusOK) - mock.Add(renewalCheck) - // should have Cookie Authed at least twice more - srv.checkAuthCount(t, 3) - srv.checkResp(t, http.StatusOK) - }, + firstAuthCount: 1, + lastAuthCount: 3, + firstHTTPResponse: http.StatusOK, + lastHTTPResponse: http.StatusOK, }, { name: "success with creds, no body", @@ -168,15 +149,10 @@ func TestAuthConfig_Start(t *testing.T) { renewal: renewal, endpoint: authEndpointWithBasicAuth, }, - assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) { - // should have Cookie Authed once - srv.checkAuthCount(t, 1) - srv.checkResp(t, http.StatusOK) - mock.Add(renewalCheck) - // should have Cookie Authed at least twice more - srv.checkAuthCount(t, 3) - srv.checkResp(t, http.StatusOK) - }, + firstAuthCount: 1, + lastAuthCount: 3, + firstHTTPResponse: http.StatusOK, + lastHTTPResponse: http.StatusOK, }, { name: "failure with bad creds", @@ -189,16 +165,11 @@ func TestAuthConfig_Start(t *testing.T) { renewal: renewal, endpoint: authEndpointWithBasicAuth, }, - wantErr: fmt.Errorf("cookie auth renewal received status code: 401 (Unauthorized)"), - assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) { - // should have never Cookie Authed - srv.checkAuthCount(t, 0) - srv.checkResp(t, http.StatusForbidden) - mock.Add(renewalCheck) - // should have still never Cookie Authed - srv.checkAuthCount(t, 0) - srv.checkResp(t, http.StatusForbidden) - }, + wantErr: fmt.Errorf("cookie auth renewal received status code: 401 (Unauthorized)"), + firstAuthCount: 0, + lastAuthCount: 0, + firstHTTPResponse: http.StatusForbidden, + lastHTTPResponse: http.StatusForbidden, }, { name: "success with no creds, with good body", @@ -210,15 +181,10 @@ func TestAuthConfig_Start(t *testing.T) { renewal: renewal, endpoint: authEndpointWithBody, }, - assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) { - // should have Cookie Authed once - srv.checkAuthCount(t, 1) - srv.checkResp(t, http.StatusOK) - mock.Add(renewalCheck) - // should have Cookie Authed at least twice more - srv.checkAuthCount(t, 3) - srv.checkResp(t, http.StatusOK) - }, + firstAuthCount: 1, + lastAuthCount: 3, + firstHTTPResponse: http.StatusOK, + lastHTTPResponse: http.StatusOK, }, { name: "failure with bad body", @@ -230,23 +196,18 @@ func TestAuthConfig_Start(t *testing.T) { renewal: renewal, endpoint: authEndpointWithBody, }, - wantErr: fmt.Errorf("cookie auth renewal received status code: 401 (Unauthorized)"), - assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) { - // should have never Cookie Authed - srv.checkAuthCount(t, 0) - srv.checkResp(t, http.StatusForbidden) - mock.Add(renewalCheck) - // should have still never Cookie Authed - srv.checkAuthCount(t, 0) - srv.checkResp(t, http.StatusForbidden) - }, + wantErr: fmt.Errorf("cookie auth renewal received status code: 401 (Unauthorized)"), + firstAuthCount: 0, + lastAuthCount: 0, + firstHTTPResponse: http.StatusForbidden, + lastHTTPResponse: http.StatusForbidden, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { srv := newFakeServer(t) - c := &cookie.CookieAuthConfig{ + c := &CookieAuthConfig{ URL: srv.URL + tt.args.endpoint, Method: tt.fields.Method, Username: tt.fields.Username, @@ -254,17 +215,28 @@ func TestAuthConfig_Start(t *testing.T) { Body: tt.fields.Body, Renewal: config.Duration(tt.args.renewal), } - - mock := clockutil.NewMock() - if err := c.Start(srv.Client(), testutil.Logger{Name: "cookie_auth"}, mock); tt.wantErr != nil { + if err := c.initializeClient(srv.Client()); tt.wantErr != nil { require.EqualError(t, err, tt.wantErr.Error()) } else { require.NoError(t, err) } + mock := clockutil.NewMock() + ticker := mock.Ticker(time.Duration(c.Renewal)) + defer ticker.Stop() + + c.wg.Add(1) + ctx, cancel := context.WithCancel(context.Background()) + go c.authRenewal(ctx, ticker, testutil.Logger{Name: "cookie_auth"}) + + srv.checkAuthCount(t, tt.firstAuthCount) + srv.checkResp(t, tt.firstHTTPResponse) + mock.Add(renewalCheck) + // Ensure that the auth renewal goroutine has completed + cancel() + c.wg.Wait() + srv.checkAuthCount(t, tt.lastAuthCount) + srv.checkResp(t, tt.lastHTTPResponse) - if tt.assert != nil { - tt.assert(t, c, srv, mock) - } srv.Close() }) }