rabbitmq-amqp-go-client/pkg/rabbitmqamqp/oauth2_test.go

208 lines
6.8 KiB
Go
Raw Permalink Normal View History

package rabbitmqamqp
// Tests for OAuth2 (JWT) authentication: connecting, token expiry, token refresh and reconnection.
import (
"context"
"encoding/base64"
"github.com/golang-jwt/jwt/v5"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
testhelper "github.com/rabbitmq/rabbitmq-amqp-go-client/pkg/test-helper"
"math/rand"
"time"
)
// Fixtures shared by the OAuth2 specs and the token helper.
const (
	// Base64Key is the base64-encoded HMAC secret used to sign test JWTs.
	// NOTE(review): presumably this must match the signing key configured on
	// the broker used by these tests — confirm against the test broker config.
	Base64Key = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGH"

	// AUDIENCE is the value written into the "aud" claim of test JWTs.
	AUDIENCE = "rabbitmq"
)
// randomString returns a string of exactly length bytes, each picked from
// ASCII letters and digits via the package-level math/rand source.
// It is not suitable for anything security-sensitive.
func randomString(length int) string {
	const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	out := make([]byte, 0, length)
	for n := 0; n < length; n++ {
		pick := rand.Intn(len(alphabet))
		out = append(out, alphabet[pick])
	}
	return string(out)
}
// OAuth2 specs. They dial a live broker ("amqp://localhost:5672" / "amqp://")
// and cover: connecting with a JWT, the connection being closed with a
// "credential expired" error once the token lapses, refreshing the token on a
// live connection, reconnecting with the refreshed token, configuring OAuth2
// through an Environment, and rejecting RefreshToken when OAuth2 is not set.
var _ = Describe("OAuth2 Tests", func() {
	It("OAuth2 Connection should success", func() {
		tokenString := token(time.Now().Add(time.Duration(2500) * time.Millisecond))
		Expect(tokenString).NotTo(BeEmpty())
		conn, err := Dial(context.TODO(), "amqp://localhost:5672",
			&AmqpConnOptions{
				ContainerID: "oAuth2Test",
				OAuth2Options: &OAuth2Options{
					Token: tokenString,
				},
			})
		Expect(err).To(BeNil())
		Expect(conn).NotTo(BeNil())
		qName := generateName("OAuth2 Connection should success")
		_, err = conn.Management().DeclareQueue(context.Background(), &QuorumQueueSpecification{
			Name: qName,
		})
		Expect(err).To(BeNil())
		Expect(conn.Management().DeleteQueue(context.Background(), qName)).To(BeNil())
		Expect(conn.Close(context.Background())).To(BeNil())
	})

	It("OAuth2 Connection should disconnect after the timeout", func() {
		tokenString := token(time.Now().Add(time.Duration(1_000) * time.Millisecond))
		Expect(tokenString).NotTo(BeEmpty())
		conn, err := Dial(context.TODO(), "amqp://localhost:5672",
			&AmqpConnOptions{
				ContainerID: "oAuth2TestTimeout",
				OAuth2Options: &OAuth2Options{
					Token: tokenString,
				},
				RecoveryConfiguration: &RecoveryConfiguration{
					ActiveRecovery: false,
				},
			})
		Expect(err).To(BeNil())
		Expect(conn).NotTo(BeNil())
		ch := make(chan *StateChanged, 1)
		conn.NotifyStatusChange(ch)
		// Wait (bounded) for the state change caused by token expiry and assert
		// on it in the spec itself: asserting from a background goroutine lets
		// the spec return before anything is observed, and the goroutine would
		// never exit because ch is never closed.
		var stateChanged *StateChanged
		Eventually(ch).WithTimeout(5 * time.Second).Should(Receive(&stateChanged))
		Expect(stateChanged.To).To(BeAssignableToTypeOf(&StateClosed{}))
		closed := stateChanged.To.(*StateClosed)
		Expect(closed.GetError()).NotTo(BeNil())
		Expect(closed.GetError().Error()).To(ContainSubstring("credential expired"))
	})

	It("OAuth2 Connection should be alive after token refresh", func() {
		tokenString := token(time.Now().Add(time.Duration(1) * time.Second))
		Expect(tokenString).NotTo(BeEmpty())
		conn, err := Dial(context.TODO(), "amqp://localhost:5672",
			&AmqpConnOptions{
				ContainerID: "oAuth2Test",
				OAuth2Options: &OAuth2Options{
					Token: tokenString,
				},
				RecoveryConfiguration: &RecoveryConfiguration{
					ActiveRecovery: false,
				},
			})
		Expect(err).To(BeNil())
		Expect(conn).NotTo(BeNil())
		time.Sleep(100 * time.Millisecond)
		err = conn.RefreshToken(context.Background(), token(time.Now().Add(time.Duration(2500)*time.Millisecond)))
		Expect(err).To(BeNil())
		// Sleep past the original token's 1s lifetime so that Close below runs
		// while only the refreshed token is still valid.
		time.Sleep(1 * time.Second)
		Expect(conn.Close(context.Background())).To(BeNil())
	})

	// This spec is timing-sensitive and can fail if the connection is not
	// dropped in time; consider marking it with Ginkgo's FlakeAttempts decorator.
	It("OAuth2 Connection should use the new token to reconnect", func() {
		name := "oAuth2TestReconnect_" + time.Now().String()
		startToken := token(time.Now().Add(time.Duration(1) * time.Second))
		connection, err := Dial(context.Background(), "amqp://", &AmqpConnOptions{
			OAuth2Options: &OAuth2Options{
				Token: startToken,
			},
			ContainerID: name,
			// reduced the reconnect interval to speed up the test
			RecoveryConfiguration: &RecoveryConfiguration{
				ActiveRecovery:           true,
				BackOffReconnectInterval: 1100 * time.Millisecond,
				MaxReconnectAttempts:     5,
			},
		})
		Expect(err).To(BeNil())
		Expect(connection).NotTo(BeNil())
		ch := make(chan *StateChanged, 1)
		connection.NotifyStatusChange(ch)
		newToken := token(time.Now().Add(time.Duration(10) * time.Second))
		Expect(connection.RefreshToken(context.Background(), newToken)).To(BeNil())
		time.Sleep(1 * time.Second)
		// Here the token used to open the connection (startToken) has expired,
		// so a successful reconnect proves the client used the refreshed token.
		Eventually(func() bool {
			err := testhelper.DropConnectionContainerID(name)
			return err == nil
		}).WithTimeout(5 * time.Second).WithPolling(400 * time.Millisecond).Should(BeTrue())
		// Bounded receive: a bare `<-ch` would hang the suite if no state change arrives.
		var st1 *StateChanged
		Eventually(ch).WithTimeout(5 * time.Second).Should(Receive(&st1))
		Expect(st1.From).To(Equal(&StateOpen{}))
		Expect(st1.To).To(BeAssignableToTypeOf(&StateClosed{}))
		time.Sleep(1 * time.Second)
		// The client should have reconnected (with the refreshed token), so the
		// broker must report a connection with this container ID again.
		Eventually(func() bool {
			conn, err := testhelper.GetConnectionByContainerID(name)
			return err == nil && conn != nil
		}).WithTimeout(5 * time.Second).WithPolling(400 * time.Millisecond).Should(BeTrue())
		Expect(connection.Close(context.Background())).To(BeNil())
	})

	It("Setting OAuth2 on the Environment should work", func() {
		env := NewClusterEnvironment([]Endpoint{
			{Address: "amqp://", Options: &AmqpConnOptions{
				OAuth2Options: &OAuth2Options{
					Token: token(time.Now().Add(time.Duration(10) * time.Second)),
				},
			},
			}})
		Expect(env).NotTo(BeNil())
		Expect(env.Connections()).NotTo(BeNil())
		Expect(len(env.Connections())).To(Equal(0))
		connection, err := env.NewConnection(context.Background())
		Expect(err).To(BeNil())
		Expect(connection).NotTo(BeNil())
		Expect(len(env.Connections())).To(Equal(1))
		Expect(connection.Close(context.Background())).To(BeNil())
		Expect(len(env.Connections())).To(Equal(0))
	})

	It("Can't use refresh token if not OAuth2 is enabled ", func() {
		connection, err := Dial(context.Background(), "amqp://", nil)
		Expect(err).To(BeNil())
		Expect(connection).NotTo(BeNil())
		err = connection.RefreshToken(context.Background(), token(time.Now().Add(time.Duration(10)*time.Second)))
		Expect(err).NotTo(BeNil())
		Expect(err.Error()).To(ContainSubstring("is not configured to use OAuth2 token"))
		Expect(connection.Close(context.Background())).To(BeNil())
	})
})
// token returns an HS256-signed JWT that expires at expiration.
//
// The signing secret is Base64Key decoded from standard base64. The claims
// are:
//   - issuer "unit_test"
//   - audience AUDIENCE
//   - configure/write/read scopes on "*/*"
//   - a random 6-character "random" claim
//
// The header carries kid "token-key". The function fails via Gomega's
// Expect, so it should be called from within a running spec.
func token(expiration time.Time) string {
	decodedKey, err := base64.StdEncoding.DecodeString(Base64Key)
	Expect(err).To(BeNil())
	claims := jwt.MapClaims{
		"iss":   "unit_test",
		"aud":   AUDIENCE,
		"exp":   jwt.NewNumericDate(expiration),
		"scope": []string{"rabbitmq.configure:*/*", "rabbitmq.write:*/*", "rabbitmq.read:*/*"},
		// NOTE(review): presumably here so tokens minted with the same
		// expiration still differ from one another — confirm.
		"random": randomString(6),
	}
	jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	// NOTE(review): presumably must match the key id configured for the
	// broker's OAuth2 signing keys — confirm against the test broker config.
	jwtToken.Header["kid"] = "token-key"
	tokenString, err := jwtToken.SignedString(decodedKey)
	Expect(err).To(BeNil())
	return tokenString
}