fix: race condition in cookie test (#9659)

This commit is contained in:
Sebastian Spaink 2021-09-01 22:21:53 -07:00 committed by GitHub
parent b8ff3e9c56
commit 167b6e0075
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 80 additions and 92 deletions

View File

@ -1,12 +1,14 @@
package cookie
import (
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/cookiejar"
"strings"
"sync"
"time"
clockutil "github.com/benbjohnson/clock"
@ -26,9 +28,25 @@ type CookieAuthConfig struct {
Renewal config.Duration `toml:"cookie_auth_renewal"`
client *http.Client
wg sync.WaitGroup
}
func (c *CookieAuthConfig) Start(client *http.Client, log telegraf.Logger, clock clockutil.Clock) (err error) {
if err = c.initializeClient(client); err != nil {
return err
}
// continual auth renewal if set
if c.Renewal > 0 {
ticker := clock.Ticker(time.Duration(c.Renewal))
// this context is used in the tests only, it is to cancel the goroutine
go c.authRenewal(context.Background(), ticker, log)
}
return nil
}
func (c *CookieAuthConfig) initializeClient(client *http.Client) (err error) {
c.client = client
if c.Method == "" {
@ -40,23 +58,21 @@ func (c *CookieAuthConfig) Start(client *http.Client, log telegraf.Logger, clock
return err
}
if err = c.auth(); err != nil {
return err
}
return c.auth()
}
// continual auth renewal if set
if c.Renewal > 0 {
ticker := clock.Ticker(time.Duration(c.Renewal))
go func() {
for range ticker.C {
if err := c.auth(); err != nil && log != nil {
log.Errorf("renewal failed for %q: %v", c.URL, err)
}
func (c *CookieAuthConfig) authRenewal(ctx context.Context, ticker *clockutil.Ticker, log telegraf.Logger) {
for {
select {
case <-ctx.Done():
c.wg.Done()
return
case <-ticker.C:
if err := c.auth(); err != nil && log != nil {
log.Errorf("renewal failed for %q: %v", c.URL, err)
}
}()
}
}
return nil
}
func (c *CookieAuthConfig) auth() error {

View File

@ -1,6 +1,7 @@
package cookie_test
package cookie
import (
"context"
"fmt"
"io/ioutil"
"net/http"
@ -12,7 +13,6 @@ import (
clockutil "github.com/benbjohnson/clock"
"github.com/google/go-cmp/cmp"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/plugins/common/cookie"
"github.com/influxdata/telegraf/testutil"
"github.com/stretchr/testify/require"
)
@ -118,44 +118,25 @@ func TestAuthConfig_Start(t *testing.T) {
endpoint string
}
tests := []struct {
name string
fields fields
args args
wantErr error
assert func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock)
name string
fields fields
args args
wantErr error
firstAuthCount int32
lastAuthCount int32
firstHTTPResponse int
lastHTTPResponse int
}{
{
name: "zero renewal does not renew",
args: args{
renewal: 0,
endpoint: authEndpointNoCreds,
},
assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) {
// should have Cookie Authed once
srv.checkAuthCount(t, 1)
srv.checkResp(t, http.StatusOK)
mock.Add(renewalCheck)
srv.checkAuthCount(t, 1)
srv.checkResp(t, http.StatusOK)
},
},
{
name: "success no creds, no body, default method",
args: args{
renewal: renewal,
endpoint: authEndpointNoCreds,
},
assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) {
// should have Cookie Authed once
srv.checkAuthCount(t, 1)
// default method set
require.Equal(t, http.MethodPost, c.Method)
srv.checkResp(t, http.StatusOK)
mock.Add(renewalCheck)
// should have Cookie Authed at least twice more
srv.checkAuthCount(t, 3)
srv.checkResp(t, http.StatusOK)
},
firstAuthCount: 1,
lastAuthCount: 3,
firstHTTPResponse: http.StatusOK,
lastHTTPResponse: http.StatusOK,
},
{
name: "success with creds, no body",
@ -168,15 +149,10 @@ func TestAuthConfig_Start(t *testing.T) {
renewal: renewal,
endpoint: authEndpointWithBasicAuth,
},
assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) {
// should have Cookie Authed once
srv.checkAuthCount(t, 1)
srv.checkResp(t, http.StatusOK)
mock.Add(renewalCheck)
// should have Cookie Authed at least twice more
srv.checkAuthCount(t, 3)
srv.checkResp(t, http.StatusOK)
},
firstAuthCount: 1,
lastAuthCount: 3,
firstHTTPResponse: http.StatusOK,
lastHTTPResponse: http.StatusOK,
},
{
name: "failure with bad creds",
@ -189,16 +165,11 @@ func TestAuthConfig_Start(t *testing.T) {
renewal: renewal,
endpoint: authEndpointWithBasicAuth,
},
wantErr: fmt.Errorf("cookie auth renewal received status code: 401 (Unauthorized)"),
assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) {
// should have never Cookie Authed
srv.checkAuthCount(t, 0)
srv.checkResp(t, http.StatusForbidden)
mock.Add(renewalCheck)
// should have still never Cookie Authed
srv.checkAuthCount(t, 0)
srv.checkResp(t, http.StatusForbidden)
},
wantErr: fmt.Errorf("cookie auth renewal received status code: 401 (Unauthorized)"),
firstAuthCount: 0,
lastAuthCount: 0,
firstHTTPResponse: http.StatusForbidden,
lastHTTPResponse: http.StatusForbidden,
},
{
name: "success with no creds, with good body",
@ -210,15 +181,10 @@ func TestAuthConfig_Start(t *testing.T) {
renewal: renewal,
endpoint: authEndpointWithBody,
},
assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) {
// should have Cookie Authed once
srv.checkAuthCount(t, 1)
srv.checkResp(t, http.StatusOK)
mock.Add(renewalCheck)
// should have Cookie Authed at least twice more
srv.checkAuthCount(t, 3)
srv.checkResp(t, http.StatusOK)
},
firstAuthCount: 1,
lastAuthCount: 3,
firstHTTPResponse: http.StatusOK,
lastHTTPResponse: http.StatusOK,
},
{
name: "failure with bad body",
@ -230,23 +196,18 @@ func TestAuthConfig_Start(t *testing.T) {
renewal: renewal,
endpoint: authEndpointWithBody,
},
wantErr: fmt.Errorf("cookie auth renewal received status code: 401 (Unauthorized)"),
assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) {
// should have never Cookie Authed
srv.checkAuthCount(t, 0)
srv.checkResp(t, http.StatusForbidden)
mock.Add(renewalCheck)
// should have still never Cookie Authed
srv.checkAuthCount(t, 0)
srv.checkResp(t, http.StatusForbidden)
},
wantErr: fmt.Errorf("cookie auth renewal received status code: 401 (Unauthorized)"),
firstAuthCount: 0,
lastAuthCount: 0,
firstHTTPResponse: http.StatusForbidden,
lastHTTPResponse: http.StatusForbidden,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
srv := newFakeServer(t)
c := &cookie.CookieAuthConfig{
c := &CookieAuthConfig{
URL: srv.URL + tt.args.endpoint,
Method: tt.fields.Method,
Username: tt.fields.Username,
@ -254,17 +215,28 @@ func TestAuthConfig_Start(t *testing.T) {
Body: tt.fields.Body,
Renewal: config.Duration(tt.args.renewal),
}
mock := clockutil.NewMock()
if err := c.Start(srv.Client(), testutil.Logger{Name: "cookie_auth"}, mock); tt.wantErr != nil {
if err := c.initializeClient(srv.Client()); tt.wantErr != nil {
require.EqualError(t, err, tt.wantErr.Error())
} else {
require.NoError(t, err)
}
mock := clockutil.NewMock()
ticker := mock.Ticker(time.Duration(c.Renewal))
defer ticker.Stop()
c.wg.Add(1)
ctx, cancel := context.WithCancel(context.Background())
go c.authRenewal(ctx, ticker, testutil.Logger{Name: "cookie_auth"})
srv.checkAuthCount(t, tt.firstAuthCount)
srv.checkResp(t, tt.firstHTTPResponse)
mock.Add(renewalCheck)
// Ensure that the auth renewal goroutine has completed
cancel()
c.wg.Wait()
srv.checkAuthCount(t, tt.lastAuthCount)
srv.checkResp(t, tt.lastHTTPResponse)
if tt.assert != nil {
tt.assert(t, c, srv, mock)
}
srv.Close()
})
}