fix: race condition in cookie test (#9659)
This commit is contained in:
parent
b8ff3e9c56
commit
167b6e0075
|
|
@ -1,12 +1,14 @@
|
||||||
package cookie
|
package cookie
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/cookiejar"
|
"net/http/cookiejar"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
clockutil "github.com/benbjohnson/clock"
|
clockutil "github.com/benbjohnson/clock"
|
||||||
|
|
@ -26,9 +28,25 @@ type CookieAuthConfig struct {
|
||||||
Renewal config.Duration `toml:"cookie_auth_renewal"`
|
Renewal config.Duration `toml:"cookie_auth_renewal"`
|
||||||
|
|
||||||
client *http.Client
|
client *http.Client
|
||||||
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CookieAuthConfig) Start(client *http.Client, log telegraf.Logger, clock clockutil.Clock) (err error) {
|
func (c *CookieAuthConfig) Start(client *http.Client, log telegraf.Logger, clock clockutil.Clock) (err error) {
|
||||||
|
if err = c.initializeClient(client); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// continual auth renewal if set
|
||||||
|
if c.Renewal > 0 {
|
||||||
|
ticker := clock.Ticker(time.Duration(c.Renewal))
|
||||||
|
// this context is used in the tests only, it is to cancel the goroutine
|
||||||
|
go c.authRenewal(context.Background(), ticker, log)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CookieAuthConfig) initializeClient(client *http.Client) (err error) {
|
||||||
c.client = client
|
c.client = client
|
||||||
|
|
||||||
if c.Method == "" {
|
if c.Method == "" {
|
||||||
|
|
@ -40,23 +58,21 @@ func (c *CookieAuthConfig) Start(client *http.Client, log telegraf.Logger, clock
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = c.auth(); err != nil {
|
return c.auth()
|
||||||
return err
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// continual auth renewal if set
|
func (c *CookieAuthConfig) authRenewal(ctx context.Context, ticker *clockutil.Ticker, log telegraf.Logger) {
|
||||||
if c.Renewal > 0 {
|
for {
|
||||||
ticker := clock.Ticker(time.Duration(c.Renewal))
|
select {
|
||||||
go func() {
|
case <-ctx.Done():
|
||||||
for range ticker.C {
|
c.wg.Done()
|
||||||
if err := c.auth(); err != nil && log != nil {
|
return
|
||||||
log.Errorf("renewal failed for %q: %v", c.URL, err)
|
case <-ticker.C:
|
||||||
}
|
if err := c.auth(); err != nil && log != nil {
|
||||||
|
log.Errorf("renewal failed for %q: %v", c.URL, err)
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CookieAuthConfig) auth() error {
|
func (c *CookieAuthConfig) auth() error {
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package cookie_test
|
package cookie
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
@ -12,7 +13,6 @@ import (
|
||||||
clockutil "github.com/benbjohnson/clock"
|
clockutil "github.com/benbjohnson/clock"
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/influxdata/telegraf/config"
|
"github.com/influxdata/telegraf/config"
|
||||||
"github.com/influxdata/telegraf/plugins/common/cookie"
|
|
||||||
"github.com/influxdata/telegraf/testutil"
|
"github.com/influxdata/telegraf/testutil"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
@ -118,44 +118,25 @@ func TestAuthConfig_Start(t *testing.T) {
|
||||||
endpoint string
|
endpoint string
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
fields fields
|
fields fields
|
||||||
args args
|
args args
|
||||||
wantErr error
|
wantErr error
|
||||||
assert func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock)
|
firstAuthCount int32
|
||||||
|
lastAuthCount int32
|
||||||
|
firstHTTPResponse int
|
||||||
|
lastHTTPResponse int
|
||||||
}{
|
}{
|
||||||
{
|
|
||||||
name: "zero renewal does not renew",
|
|
||||||
args: args{
|
|
||||||
renewal: 0,
|
|
||||||
endpoint: authEndpointNoCreds,
|
|
||||||
},
|
|
||||||
assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) {
|
|
||||||
// should have Cookie Authed once
|
|
||||||
srv.checkAuthCount(t, 1)
|
|
||||||
srv.checkResp(t, http.StatusOK)
|
|
||||||
mock.Add(renewalCheck)
|
|
||||||
srv.checkAuthCount(t, 1)
|
|
||||||
srv.checkResp(t, http.StatusOK)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "success no creds, no body, default method",
|
name: "success no creds, no body, default method",
|
||||||
args: args{
|
args: args{
|
||||||
renewal: renewal,
|
renewal: renewal,
|
||||||
endpoint: authEndpointNoCreds,
|
endpoint: authEndpointNoCreds,
|
||||||
},
|
},
|
||||||
assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) {
|
firstAuthCount: 1,
|
||||||
// should have Cookie Authed once
|
lastAuthCount: 3,
|
||||||
srv.checkAuthCount(t, 1)
|
firstHTTPResponse: http.StatusOK,
|
||||||
// default method set
|
lastHTTPResponse: http.StatusOK,
|
||||||
require.Equal(t, http.MethodPost, c.Method)
|
|
||||||
srv.checkResp(t, http.StatusOK)
|
|
||||||
mock.Add(renewalCheck)
|
|
||||||
// should have Cookie Authed at least twice more
|
|
||||||
srv.checkAuthCount(t, 3)
|
|
||||||
srv.checkResp(t, http.StatusOK)
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "success with creds, no body",
|
name: "success with creds, no body",
|
||||||
|
|
@ -168,15 +149,10 @@ func TestAuthConfig_Start(t *testing.T) {
|
||||||
renewal: renewal,
|
renewal: renewal,
|
||||||
endpoint: authEndpointWithBasicAuth,
|
endpoint: authEndpointWithBasicAuth,
|
||||||
},
|
},
|
||||||
assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) {
|
firstAuthCount: 1,
|
||||||
// should have Cookie Authed once
|
lastAuthCount: 3,
|
||||||
srv.checkAuthCount(t, 1)
|
firstHTTPResponse: http.StatusOK,
|
||||||
srv.checkResp(t, http.StatusOK)
|
lastHTTPResponse: http.StatusOK,
|
||||||
mock.Add(renewalCheck)
|
|
||||||
// should have Cookie Authed at least twice more
|
|
||||||
srv.checkAuthCount(t, 3)
|
|
||||||
srv.checkResp(t, http.StatusOK)
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "failure with bad creds",
|
name: "failure with bad creds",
|
||||||
|
|
@ -189,16 +165,11 @@ func TestAuthConfig_Start(t *testing.T) {
|
||||||
renewal: renewal,
|
renewal: renewal,
|
||||||
endpoint: authEndpointWithBasicAuth,
|
endpoint: authEndpointWithBasicAuth,
|
||||||
},
|
},
|
||||||
wantErr: fmt.Errorf("cookie auth renewal received status code: 401 (Unauthorized)"),
|
wantErr: fmt.Errorf("cookie auth renewal received status code: 401 (Unauthorized)"),
|
||||||
assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) {
|
firstAuthCount: 0,
|
||||||
// should have never Cookie Authed
|
lastAuthCount: 0,
|
||||||
srv.checkAuthCount(t, 0)
|
firstHTTPResponse: http.StatusForbidden,
|
||||||
srv.checkResp(t, http.StatusForbidden)
|
lastHTTPResponse: http.StatusForbidden,
|
||||||
mock.Add(renewalCheck)
|
|
||||||
// should have still never Cookie Authed
|
|
||||||
srv.checkAuthCount(t, 0)
|
|
||||||
srv.checkResp(t, http.StatusForbidden)
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "success with no creds, with good body",
|
name: "success with no creds, with good body",
|
||||||
|
|
@ -210,15 +181,10 @@ func TestAuthConfig_Start(t *testing.T) {
|
||||||
renewal: renewal,
|
renewal: renewal,
|
||||||
endpoint: authEndpointWithBody,
|
endpoint: authEndpointWithBody,
|
||||||
},
|
},
|
||||||
assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) {
|
firstAuthCount: 1,
|
||||||
// should have Cookie Authed once
|
lastAuthCount: 3,
|
||||||
srv.checkAuthCount(t, 1)
|
firstHTTPResponse: http.StatusOK,
|
||||||
srv.checkResp(t, http.StatusOK)
|
lastHTTPResponse: http.StatusOK,
|
||||||
mock.Add(renewalCheck)
|
|
||||||
// should have Cookie Authed at least twice more
|
|
||||||
srv.checkAuthCount(t, 3)
|
|
||||||
srv.checkResp(t, http.StatusOK)
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "failure with bad body",
|
name: "failure with bad body",
|
||||||
|
|
@ -230,23 +196,18 @@ func TestAuthConfig_Start(t *testing.T) {
|
||||||
renewal: renewal,
|
renewal: renewal,
|
||||||
endpoint: authEndpointWithBody,
|
endpoint: authEndpointWithBody,
|
||||||
},
|
},
|
||||||
wantErr: fmt.Errorf("cookie auth renewal received status code: 401 (Unauthorized)"),
|
wantErr: fmt.Errorf("cookie auth renewal received status code: 401 (Unauthorized)"),
|
||||||
assert: func(t *testing.T, c *cookie.CookieAuthConfig, srv fakeServer, mock *clockutil.Mock) {
|
firstAuthCount: 0,
|
||||||
// should have never Cookie Authed
|
lastAuthCount: 0,
|
||||||
srv.checkAuthCount(t, 0)
|
firstHTTPResponse: http.StatusForbidden,
|
||||||
srv.checkResp(t, http.StatusForbidden)
|
lastHTTPResponse: http.StatusForbidden,
|
||||||
mock.Add(renewalCheck)
|
|
||||||
// should have still never Cookie Authed
|
|
||||||
srv.checkAuthCount(t, 0)
|
|
||||||
srv.checkResp(t, http.StatusForbidden)
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
tt := tt
|
tt := tt
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
srv := newFakeServer(t)
|
srv := newFakeServer(t)
|
||||||
c := &cookie.CookieAuthConfig{
|
c := &CookieAuthConfig{
|
||||||
URL: srv.URL + tt.args.endpoint,
|
URL: srv.URL + tt.args.endpoint,
|
||||||
Method: tt.fields.Method,
|
Method: tt.fields.Method,
|
||||||
Username: tt.fields.Username,
|
Username: tt.fields.Username,
|
||||||
|
|
@ -254,17 +215,28 @@ func TestAuthConfig_Start(t *testing.T) {
|
||||||
Body: tt.fields.Body,
|
Body: tt.fields.Body,
|
||||||
Renewal: config.Duration(tt.args.renewal),
|
Renewal: config.Duration(tt.args.renewal),
|
||||||
}
|
}
|
||||||
|
if err := c.initializeClient(srv.Client()); tt.wantErr != nil {
|
||||||
mock := clockutil.NewMock()
|
|
||||||
if err := c.Start(srv.Client(), testutil.Logger{Name: "cookie_auth"}, mock); tt.wantErr != nil {
|
|
||||||
require.EqualError(t, err, tt.wantErr.Error())
|
require.EqualError(t, err, tt.wantErr.Error())
|
||||||
} else {
|
} else {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
mock := clockutil.NewMock()
|
||||||
|
ticker := mock.Ticker(time.Duration(c.Renewal))
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
c.wg.Add(1)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go c.authRenewal(ctx, ticker, testutil.Logger{Name: "cookie_auth"})
|
||||||
|
|
||||||
|
srv.checkAuthCount(t, tt.firstAuthCount)
|
||||||
|
srv.checkResp(t, tt.firstHTTPResponse)
|
||||||
|
mock.Add(renewalCheck)
|
||||||
|
// Ensure that the auth renewal goroutine has completed
|
||||||
|
cancel()
|
||||||
|
c.wg.Wait()
|
||||||
|
srv.checkAuthCount(t, tt.lastAuthCount)
|
||||||
|
srv.checkResp(t, tt.lastHTTPResponse)
|
||||||
|
|
||||||
if tt.assert != nil {
|
|
||||||
tt.assert(t, c, srv, mock)
|
|
||||||
}
|
|
||||||
srv.Close()
|
srv.Close()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue