Auto recovery connection publishers and consumers (#22)

* Closes: #4
*. Closes: #5
* Add auto-reconnection for connection, producers and consumers 

---------

Signed-off-by: Gabriele Santomaggio <G.santomaggio@gmail.com>
This commit is contained in:
Gabriele Santomaggio 2025-02-07 11:00:14 +01:00 committed by GitHub
parent 89c4dd74a4
commit 707fe72c3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 1258 additions and 248 deletions

View File

@ -1,13 +1,13 @@
all: format vet test
all: test
format:
go fmt ./...
vet:
go vet ./rabbitmq_amqp
go vet ./pkg/rabbitmq_amqp
test:
cd rabbitmq_amqp && go run -mod=mod github.com/onsi/ginkgo/v2/ginkgo \
test: format vet
cd ./pkg/rabbitmq_amqp && go run -mod=mod github.com/onsi/ginkgo/v2/ginkgo \
--randomize-all --randomize-suites \
--cover --coverprofile=coverage.txt --covermode=atomic \
--race

View File

@ -5,7 +5,7 @@ import (
"errors"
"fmt"
"github.com/Azure/go-amqp"
"github.com/rabbitmq/rabbitmq-amqp-go-client/rabbitmq_amqp"
"github.com/rabbitmq/rabbitmq-amqp-go-client/pkg/rabbitmq_amqp"
"time"
)
@ -20,7 +20,7 @@ func main() {
stateChanged := make(chan *rabbitmq_amqp.StateChanged, 1)
go func(ch chan *rabbitmq_amqp.StateChanged) {
for statusChanged := range ch {
rabbitmq_amqp.Info("[Connection]", "Status changed", statusChanged)
rabbitmq_amqp.Info("[connection]", "Status changed", statusChanged)
}
}(stateChanged)
@ -33,7 +33,7 @@ func main() {
// Register the channel to receive status change notifications
amqpConnection.NotifyStatusChange(stateChanged)
fmt.Printf("AMQP Connection opened.\n")
fmt.Printf("AMQP connection opened.\n")
// Create the management interface for the connection
// so we can declare exchanges, queues, and bindings
management := amqpConnection.Management()
@ -86,16 +86,16 @@ func main() {
deliveryContext, err := consumer.Receive(ctx)
if errors.Is(err, context.Canceled) {
// The consumer was closed correctly
rabbitmq_amqp.Info("[Consumer]", "consumer closed. Context", err)
rabbitmq_amqp.Info("[NewConsumer]", "consumer closed. Context", err)
return
}
if err != nil {
// An error occurred receiving the message
rabbitmq_amqp.Error("[Consumer]", "Error receiving message", err)
rabbitmq_amqp.Error("[NewConsumer]", "Error receiving message", err)
return
}
rabbitmq_amqp.Info("[Consumer]", "Received message",
rabbitmq_amqp.Info("[NewConsumer]", "Received message",
fmt.Sprintf("%s", deliveryContext.Message().Data))
err = deliveryContext.Accept(context.Background())
@ -115,26 +115,26 @@ func main() {
return
}
for i := 0; i < 10; i++ {
for i := 0; i < 1_000; i++ {
// Publish a message to the exchange
publishResult, err := publisher.Publish(context.Background(), amqp.NewMessage([]byte("Hello, World!"+fmt.Sprintf("%d", i))))
if err != nil {
rabbitmq_amqp.Error("Error publishing message", err)
return
rabbitmq_amqp.Error("Error publishing message", "error", err)
time.Sleep(1 * time.Second)
continue
}
switch publishResult.Outcome.(type) {
case *amqp.StateAccepted:
rabbitmq_amqp.Info("[Publisher]", "Message accepted", publishResult.Message.Data[0])
rabbitmq_amqp.Info("[NewPublisher]", "Message accepted", publishResult.Message.Data[0])
break
case *amqp.StateReleased:
rabbitmq_amqp.Warn("[Publisher]", "Message was not routed", publishResult.Message.Data[0])
rabbitmq_amqp.Warn("[NewPublisher]", "Message was not routed", publishResult.Message.Data[0])
break
case *amqp.StateRejected:
rabbitmq_amqp.Warn("[Publisher]", "Message rejected", publishResult.Message.Data[0])
rabbitmq_amqp.Warn("[NewPublisher]", "Message rejected", publishResult.Message.Data[0])
stateType := publishResult.Outcome.(*amqp.StateRejected)
if stateType.Error != nil {
rabbitmq_amqp.Warn("[Publisher]", "Message rejected with error: %v", stateType.Error)
rabbitmq_amqp.Warn("[NewPublisher]", "Message rejected with error: %v", stateType.Error)
}
break
default:
@ -153,13 +153,13 @@ func main() {
//Close the consumer
err = consumer.Close(context.Background())
if err != nil {
rabbitmq_amqp.Error("[Consumer]", err)
rabbitmq_amqp.Error("[NewConsumer]", err)
return
}
// Close the publisher
err = publisher.Close(context.Background())
if err != nil {
rabbitmq_amqp.Error("[Publisher]", err)
rabbitmq_amqp.Error("[NewPublisher]", err)
return
}
@ -197,7 +197,7 @@ func main() {
return
}
fmt.Printf("AMQP Connection closed.\n")
fmt.Printf("AMQP connection closed.\n")
// not necessary. It waits for the status change to be printed
time.Sleep(100 * time.Millisecond)
close(stateChanged)

View File

@ -0,0 +1,213 @@
package main
import (
"context"
"errors"
"fmt"
"github.com/Azure/go-amqp"
"github.com/rabbitmq/rabbitmq-amqp-go-client/pkg/rabbitmq_amqp"
"sync"
"sync/atomic"
"time"
)
func main() {
queueName := "reliable-amqp10-go-queue"
var stateAccepted int32
var stateReleased int32
var stateRejected int32
var received int32
var failed int32
startTime := time.Now()
go func() {
for {
time.Sleep(5 * time.Second)
total := stateAccepted + stateReleased + stateRejected
messagesPerSecond := float64(total) / time.Since(startTime).Seconds()
rabbitmq_amqp.Info("[Stats]", "sent", total, "received", received, "failed", failed, "messagesPerSecond", messagesPerSecond)
}
}()
rabbitmq_amqp.Info("How to deal with network disconnections")
signalBlock := sync.Cond{L: &sync.Mutex{}}
/// Create a channel to receive state change notifications
stateChanged := make(chan *rabbitmq_amqp.StateChanged, 1)
go func(ch chan *rabbitmq_amqp.StateChanged) {
for statusChanged := range ch {
rabbitmq_amqp.Info("[connection]", "Status changed", statusChanged)
switch statusChanged.To.(type) {
case *rabbitmq_amqp.StateOpen:
signalBlock.Broadcast()
}
}
}(stateChanged)
// Open a connection to the AMQP 1.0 server
amqpConnection, err := rabbitmq_amqp.Dial(context.Background(), []string{"amqp://"}, &rabbitmq_amqp.AmqpConnOptions{
SASLType: amqp.SASLTypeAnonymous(),
ContainerID: "reliable-amqp10-go",
RecoveryConfiguration: &rabbitmq_amqp.RecoveryConfiguration{
ActiveRecovery: true,
BackOffReconnectInterval: 2 * time.Second, // we reduce the reconnect interval to speed up the test. The default is 5 seconds
// In production, you should avoid BackOffReconnectInterval with low values since it can cause a high number of reconnection attempts
MaxReconnectAttempts: 5,
},
})
if err != nil {
rabbitmq_amqp.Error("Error opening connection", err)
return
}
// Register the channel to receive status change notifications
amqpConnection.NotifyStatusChange(stateChanged)
fmt.Printf("AMQP connection opened.\n")
// Create the management interface for the connection
// so we can declare exchanges, queues, and bindings
management := amqpConnection.Management()
// Declare a Quorum queue
queueInfo, err := management.DeclareQueue(context.TODO(), &rabbitmq_amqp.QuorumQueueSpecification{
Name: queueName,
})
if err != nil {
rabbitmq_amqp.Error("Error declaring queue", err)
return
}
consumer, err := amqpConnection.NewConsumer(context.Background(), &rabbitmq_amqp.QueueAddress{
Queue: queueName,
}, "reliable-consumer")
if err != nil {
rabbitmq_amqp.Error("Error creating consumer", err)
return
}
consumerContext, cancel := context.WithCancel(context.Background())
// Consume messages from the queue
go func(ctx context.Context) {
for {
deliveryContext, err := consumer.Receive(ctx)
if errors.Is(err, context.Canceled) {
// The consumer was closed correctly
return
}
if err != nil {
// An error occurred receiving the message
// here the consumer could be disconnected from the server due to a network error
signalBlock.L.Lock()
rabbitmq_amqp.Info("[Consumer]", "Consumer is blocked, queue", queueName, "error", err)
signalBlock.Wait()
rabbitmq_amqp.Info("[Consumer]", "Consumer is unblocked, queue", queueName)
signalBlock.L.Unlock()
continue
}
atomic.AddInt32(&received, 1)
err = deliveryContext.Accept(context.Background())
if err != nil {
// same here the delivery could not be accepted due to a network error
// we wait for 2_500 ms and try again
time.Sleep(2500 * time.Millisecond)
continue
}
}
}(consumerContext)
publisher, err := amqpConnection.NewPublisher(context.Background(), &rabbitmq_amqp.QueueAddress{
Queue: queueName,
}, "reliable-publisher")
if err != nil {
rabbitmq_amqp.Error("Error creating publisher", err)
return
}
wg := &sync.WaitGroup{}
for i := 0; i < 1; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 500_000; i++ {
publishResult, err := publisher.Publish(context.Background(), amqp.NewMessage([]byte("Hello, World!"+fmt.Sprintf("%d", i))))
if err != nil {
// here you need to deal with the error. You can store the message in a local in memory/persistent storage
// then retry to send the message as soon as the connection is reestablished
atomic.AddInt32(&failed, 1)
// block signalBlock until the connection is reestablished
signalBlock.L.Lock()
rabbitmq_amqp.Info("[Publisher]", "Publisher is blocked, queue", queueName, "error", err)
signalBlock.Wait()
rabbitmq_amqp.Info("[Publisher]", "Publisher is unblocked, queue", queueName)
signalBlock.L.Unlock()
} else {
switch publishResult.Outcome.(type) {
case *amqp.StateAccepted:
atomic.AddInt32(&stateAccepted, 1)
break
case *amqp.StateReleased:
atomic.AddInt32(&stateReleased, 1)
break
case *amqp.StateRejected:
atomic.AddInt32(&stateRejected, 1)
break
default:
// these status are not supported. Leave it for AMQP 1.0 compatibility
// see: https://www.rabbitmq.com/docs/next/amqp#outcomes
rabbitmq_amqp.Warn("Message state: %v", publishResult.Outcome)
}
}
}
}()
}
wg.Wait()
println("press any key to close the connection")
var input string
_, _ = fmt.Scanln(&input)
cancel()
//Close the consumer
err = consumer.Close(context.Background())
if err != nil {
rabbitmq_amqp.Error("[NewConsumer]", err)
return
}
// Close the publisher
err = publisher.Close(context.Background())
if err != nil {
rabbitmq_amqp.Error("[NewPublisher]", err)
return
}
// Purge the queue
purged, err := management.PurgeQueue(context.TODO(), queueInfo.Name())
if err != nil {
fmt.Printf("Error purging queue: %v\n", err)
return
}
fmt.Printf("Purged %d messages from the queue.\n", purged)
err = management.DeleteQueue(context.TODO(), queueInfo.Name())
if err != nil {
fmt.Printf("Error deleting queue: %v\n", err)
return
}
err = amqpConnection.Close(context.Background())
if err != nil {
fmt.Printf("Error closing connection: %v\n", err)
return
}
fmt.Printf("AMQP connection closed.\n")
// not necessary. It waits for the status change to be printed
time.Sleep(100 * time.Millisecond)
close(stateChanged)
}

View File

@ -0,0 +1,349 @@
package rabbitmq_amqp
import (
"context"
"crypto/tls"
"fmt"
"github.com/Azure/go-amqp"
"github.com/google/uuid"
"math/rand"
"sync"
"sync/atomic"
"time"
)
//func (c *ConnUrlHelper) UseSsl(value bool) {
// c.UseSsl = value
// if value {
// c.Scheme = "amqps"
// } else {
// c.Scheme = "amqp"
// }
//}
type AmqpConnOptions struct {
// wrapper for amqp.ConnOptions
ContainerID string
// wrapper for amqp.ConnOptions
HostName string
// wrapper for amqp.ConnOptions
IdleTimeout time.Duration
// wrapper for amqp.ConnOptions
MaxFrameSize uint32
// wrapper for amqp.ConnOptions
MaxSessions uint16
// wrapper for amqp.ConnOptions
Properties map[string]any
// wrapper for amqp.ConnOptions
SASLType amqp.SASLType
// wrapper for amqp.ConnOptions
TLSConfig *tls.Config
// wrapper for amqp.ConnOptions
WriteTimeout time.Duration
// RecoveryConfiguration is used to configure the recovery behavior of the connection.
// when the connection is closed unexpectedly.
RecoveryConfiguration *RecoveryConfiguration
// copy the addresses for reconnection
addresses []string
}
type AmqpConnection struct {
azureConnection *amqp.Conn
id string
management *AmqpManagement
lifeCycle *LifeCycle
amqpConnOptions *AmqpConnOptions
session *amqp.Session
refMap *sync.Map
entitiesTracker *entitiesTracker
}
// NewPublisher creates a new Publisher that sends messages to the provided destination.
// The destination is a TargetAddress that can be a Queue or an Exchange with a routing key.
// See QueueAddress and ExchangeAddress for more information.
func (a *AmqpConnection) NewPublisher(ctx context.Context, destination TargetAddress, linkName string) (*Publisher, error) {
destinationAdd := ""
err := error(nil)
if destination != nil {
destinationAdd, err = destination.toAddress()
if err != nil {
return nil, err
}
err = validateAddress(destinationAdd)
if err != nil {
return nil, err
}
}
return newPublisher(ctx, a, destinationAdd, linkName)
}
// NewConsumer creates a new Consumer that listens to the provided destination. Destination is a QueueAddress.
func (a *AmqpConnection) NewConsumer(ctx context.Context, destination *QueueAddress, linkName string) (*Consumer, error) {
destinationAdd, err := destination.toAddress()
if err != nil {
return nil, err
}
err = validateAddress(destinationAdd)
return newConsumer(ctx, a, destinationAdd, linkName)
}
// Dial connect to the AMQP 1.0 server using the provided connectionSettings
// Returns a pointer to the new AmqpConnection if successful else an error.
// addresses is a list of addresses to connect to. It picks one randomly.
// It is enough that one of the addresses is reachable.
func Dial(ctx context.Context, addresses []string, connOptions *AmqpConnOptions, args ...string) (*AmqpConnection, error) {
if connOptions == nil {
connOptions = &AmqpConnOptions{
// RabbitMQ requires SASL security layer
// to be enabled for AMQP 1.0 connections.
// So this is mandatory and default in case not defined.
SASLType: amqp.SASLTypeAnonymous(),
}
}
if connOptions.RecoveryConfiguration == nil {
connOptions.RecoveryConfiguration = NewRecoveryConfiguration()
}
// validate the RecoveryConfiguration options
if connOptions.RecoveryConfiguration.MaxReconnectAttempts <= 0 && connOptions.RecoveryConfiguration.ActiveRecovery {
return nil, fmt.Errorf("MaxReconnectAttempts should be greater than 0")
}
if connOptions.RecoveryConfiguration.BackOffReconnectInterval <= 1*time.Second && connOptions.RecoveryConfiguration.ActiveRecovery {
return nil, fmt.Errorf("BackOffReconnectInterval should be greater than 1 second")
}
// create the connection
conn := &AmqpConnection{
management: NewAmqpManagement(),
lifeCycle: NewLifeCycle(),
amqpConnOptions: connOptions,
entitiesTracker: newEntitiesTracker(),
}
tmp := make([]string, len(addresses))
copy(tmp, addresses)
err := conn.open(ctx, addresses, connOptions, args...)
if err != nil {
return nil, err
}
conn.amqpConnOptions = connOptions
conn.amqpConnOptions.addresses = addresses
conn.lifeCycle.SetState(&StateOpen{})
return conn, nil
}
// Open opens a connection to the AMQP 1.0 server.
// using the provided connectionSettings and the AMQPLite library.
// Setups the connection and the management interface.
func (a *AmqpConnection) open(ctx context.Context, addresses []string, connOptions *AmqpConnOptions, args ...string) error {
amqpLiteConnOptions := &amqp.ConnOptions{
ContainerID: connOptions.ContainerID,
HostName: connOptions.HostName,
IdleTimeout: connOptions.IdleTimeout,
MaxFrameSize: connOptions.MaxFrameSize,
MaxSessions: connOptions.MaxSessions,
Properties: connOptions.Properties,
SASLType: connOptions.SASLType,
TLSConfig: connOptions.TLSConfig,
WriteTimeout: connOptions.WriteTimeout,
}
tmp := make([]string, len(addresses))
copy(tmp, addresses)
// random pick and extract one address to use for connection
var azureConnection *amqp.Conn
for len(tmp) > 0 {
idx := random(len(tmp))
addr := tmp[idx]
//connOptions.HostName is the way to set the virtual host
// so we need to pre-parse the URI to get the virtual host
// the PARSE is copied from go-amqp091 library
// the URI will be parsed is parsed again in the amqp lite library
uri, err := ParseURI(addr)
if err != nil {
return err
}
connOptions.HostName = fmt.Sprintf("vhost:%s", uri.Vhost)
// remove the index from the tmp list
tmp = append(tmp[:idx], tmp[idx+1:]...)
azureConnection, err = amqp.Dial(ctx, addr, amqpLiteConnOptions)
if err != nil {
Error("Failed to open connection", ExtractWithoutPassword(addr), err)
continue
}
Debug("Connected to", ExtractWithoutPassword(addr))
break
}
if azureConnection == nil {
return fmt.Errorf("failed to connect to any of the provided addresses")
}
if len(args) > 0 {
a.id = args[0]
} else {
a.id = uuid.New().String()
}
a.azureConnection = azureConnection
var err error
a.session, err = a.azureConnection.NewSession(ctx, nil)
go func() {
select {
case <-azureConnection.Done():
{
a.lifeCycle.SetState(&StateClosed{error: azureConnection.Err()})
if azureConnection.Err() != nil {
Error("connection closed unexpectedly", "error", azureConnection.Err())
a.maybeReconnect()
return
}
Debug("connection closed successfully")
}
}
}()
if err != nil {
return err
}
err = a.management.Open(ctx, a)
if err != nil {
// TODO close connection?
return err
}
return nil
}
func (a *AmqpConnection) maybeReconnect() {
if !a.amqpConnOptions.RecoveryConfiguration.ActiveRecovery {
Info("Recovery is disabled, closing connection")
return
}
a.lifeCycle.SetState(&StateReconnecting{})
numberOfAttempts := 1
waitTime := a.amqpConnOptions.RecoveryConfiguration.BackOffReconnectInterval
reconnected := false
for numberOfAttempts <= a.amqpConnOptions.RecoveryConfiguration.MaxReconnectAttempts {
///wait for before reconnecting
// add some random milliseconds to the wait time to avoid thundering herd
// the random time is between 0 and 500 milliseconds
waitTime = waitTime + time.Duration(rand.Intn(500))*time.Millisecond
Info("Waiting before reconnecting", "in", waitTime, "attempt", numberOfAttempts)
time.Sleep(waitTime)
// context with timeout
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
// try to createSender
err := a.open(ctx, a.amqpConnOptions.addresses, a.amqpConnOptions)
cancel()
if err != nil {
numberOfAttempts++
waitTime = waitTime * 2
Error("Failed to connection. ", "id", a.Id(), "error", err)
} else {
reconnected = true
break
}
}
if reconnected {
var fails int32
Info("Reconnected successfully, restarting publishers and consumers")
a.entitiesTracker.publishers.Range(func(key, value any) bool {
publisher := value.(*Publisher)
// try to createSender
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
err := publisher.createSender(ctx)
if err != nil {
atomic.AddInt32(&fails, 1)
Error("Failed to createSender publisher", "ID", publisher.Id(), "error", err)
}
cancel()
return true
})
Info("Restarted publishers", "number of fails", fails)
fails = 0
a.entitiesTracker.consumers.Range(func(key, value any) bool {
consumer := value.(*Consumer)
// try to createSender
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
err := consumer.createReceiver(ctx)
if err != nil {
atomic.AddInt32(&fails, 1)
Error("Failed to createReceiver consumer", "ID", consumer.Id(), "error", err)
}
cancel()
return true
})
Info("Restarted consumers", "number of fails", fails)
a.lifeCycle.SetState(&StateOpen{})
}
}
func (a *AmqpConnection) close() {
if a.refMap != nil {
a.refMap.Delete(a.Id())
}
a.entitiesTracker.CleanUp()
}
/*
Close closes the connection to the AMQP 1.0 server and the management interface.
All the publishers and consumers are closed as well.
*/
func (a *AmqpConnection) Close(ctx context.Context) error {
// the status closed (lifeCycle.SetState(&StateClosed{error: nil})) is not set here
// it is set in the connection.Done() channel
// the channel is called anyway
// see the open(...) function with a.lifeCycle.SetState(&StateClosed{error: connection.Err()})
err := a.management.Close(ctx)
if err != nil {
Error("Failed to close management", "error:", err)
}
err = a.azureConnection.Close()
a.close()
return err
}
// NotifyStatusChange registers a channel to receive getState change notifications
// from the connection.
func (a *AmqpConnection) NotifyStatusChange(channel chan *StateChanged) {
a.lifeCycle.chStatusChanged = channel
}
func (a *AmqpConnection) State() LifeCycleState {
return a.lifeCycle.State()
}
func (a *AmqpConnection) Id() string {
return a.id
}
// *** management section ***
// Management returns the management interface for the connection.
// The management interface is used to declare and delete exchanges, queues, and bindings.
func (a *AmqpConnection) Management() *AmqpManagement {
return a.management
}
//*** end management section ***

View File

@ -0,0 +1,93 @@
package rabbitmq_amqp
import (
"sync"
"time"
)
type RecoveryConfiguration struct {
/*
ActiveRecovery Define if the recovery is activated.
If is not activated the connection will not try to createSender.
*/
ActiveRecovery bool
/*
BackOffReconnectInterval The time to wait before trying to createSender after a connection is closed.
time will be increased exponentially with each attempt.
Default is 5 seconds, each attempt will double the time.
The minimum value is 1 second. Avoid setting a value low values since it can cause a high
number of reconnection attempts.
*/
BackOffReconnectInterval time.Duration
/*
MaxReconnectAttempts The maximum number of reconnection attempts.
Default is 5.
The minimum value is 1.
*/
MaxReconnectAttempts int
}
func NewRecoveryConfiguration() *RecoveryConfiguration {
return &RecoveryConfiguration{
ActiveRecovery: true,
BackOffReconnectInterval: 5 * time.Second,
MaxReconnectAttempts: 5,
}
}
type entitiesTracker struct {
publishers sync.Map
consumers sync.Map
}
func newEntitiesTracker() *entitiesTracker {
return &entitiesTracker{
publishers: sync.Map{},
consumers: sync.Map{},
}
}
func (e *entitiesTracker) storeOrReplaceProducer(entity entityIdentifier) {
e.publishers.Store(entity.Id(), entity)
}
func (e *entitiesTracker) getProducer(id string) (*Publisher, bool) {
producer, ok := e.publishers.Load(id)
if !ok {
return nil, false
}
return producer.(*Publisher), true
}
func (e *entitiesTracker) removeProducer(entity entityIdentifier) {
e.publishers.Delete(entity.Id())
}
func (e *entitiesTracker) storeOrReplaceConsumer(entity entityIdentifier) {
e.consumers.Store(entity.Id(), entity)
}
func (e *entitiesTracker) getConsumer(id string) (*Consumer, bool) {
consumer, ok := e.consumers.Load(id)
if !ok {
return nil, false
}
return consumer.(*Consumer), true
}
func (e *entitiesTracker) removeConsumer(entity entityIdentifier) {
e.consumers.Delete(entity.Id())
}
func (e *entitiesTracker) CleanUp() {
e.publishers.Range(func(key, value interface{}) bool {
e.publishers.Delete(key)
return true
})
e.consumers.Range(func(key, value interface{}) bool {
e.consumers.Delete(key)
return true
})
}

View File

@ -0,0 +1,183 @@
package rabbitmq_amqp
import (
"context"
"github.com/Azure/go-amqp"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
testhelper "github.com/rabbitmq/rabbitmq-amqp-go-client/pkg/test-helper"
"time"
)
var _ = Describe("Recovery connection test", func() {
It("connection should reconnect producers and consumers if dropped by via REST API", func() {
/*
The test is a bit complex since it requires to drop the connection by REST API
Then wait for the connection to be reconnected.
The scope of the test is to verify that the connection is reconnected and the
producers and consumers are able to send and receive messages.
It is more like an integration test.
This kind of the tests requires time in terms of execution it has to wait for the
connection to be reconnected, so to speed up the test I aggregated the tests in one.
*/
name := "connection should reconnect producers and consumers if dropped by via REST API"
connection, err := Dial(context.Background(), []string{"amqp://"}, &AmqpConnOptions{
SASLType: amqp.SASLTypeAnonymous(),
ContainerID: name,
// reduced the reconnect interval to speed up the test
RecoveryConfiguration: &RecoveryConfiguration{
ActiveRecovery: true,
BackOffReconnectInterval: 2 * time.Second,
MaxReconnectAttempts: 5,
},
})
Expect(err).To(BeNil())
ch := make(chan *StateChanged, 1)
connection.NotifyStatusChange(ch)
qName := generateName(name)
queueInfo, err := connection.Management().DeclareQueue(context.Background(), &QuorumQueueSpecification{
Name: qName,
})
Expect(err).To(BeNil())
Expect(queueInfo).NotTo(BeNil())
consumer, err := connection.NewConsumer(context.Background(), &QueueAddress{
Queue: qName,
}, "test")
publisher, err := connection.NewPublisher(context.Background(), &QueueAddress{
Queue: qName,
}, "test")
Expect(err).To(BeNil())
Expect(publisher).NotTo(BeNil())
for i := 0; i < 5; i++ {
publishResult, err := publisher.Publish(context.Background(), amqp.NewMessage([]byte("Hello")))
Expect(err).To(BeNil())
Expect(publishResult).NotTo(BeNil())
Expect(publishResult.Outcome).To(Equal(&amqp.StateAccepted{}))
}
Eventually(func() bool {
err := testhelper.DropConnectionContainerID(name)
return err == nil
}).WithTimeout(5 * time.Second).WithPolling(400 * time.Millisecond).Should(BeTrue())
st1 := <-ch
Expect(st1.From).To(Equal(&StateOpen{}))
Expect(st1.To).To(BeAssignableToTypeOf(&StateClosed{}))
/// Closed state should have an error
// Since it is forced closed by the REST API
err = st1.To.(*StateClosed).GetError()
Expect(err).NotTo(BeNil())
Expect(err.Error()).To(ContainSubstring("Connection forced"))
time.Sleep(1 * time.Second)
Eventually(func() bool {
conn, err := testhelper.GetConnectionByContainerID(name)
return err == nil && conn != nil
}).WithTimeout(5 * time.Second).WithPolling(400 * time.Millisecond).Should(BeTrue())
st2 := <-ch
Expect(st2.From).To(BeAssignableToTypeOf(&StateClosed{}))
Expect(st2.To).To(Equal(&StateReconnecting{}))
st3 := <-ch
Expect(st3.From).To(BeAssignableToTypeOf(&StateReconnecting{}))
Expect(st3.To).To(Equal(&StateOpen{}))
for i := 0; i < 5; i++ {
publishResult, err := publisher.Publish(context.Background(), amqp.NewMessage([]byte("Hello")))
Expect(err).To(BeNil())
Expect(publishResult).NotTo(BeNil())
Expect(publishResult.Outcome).To(Equal(&amqp.StateAccepted{}))
}
/// after the connection is reconnected the consumer should be able to receive the messages
for i := 0; i < 10; i++ {
deliveryContext, err := consumer.Receive(context.Background())
Expect(err).To(BeNil())
Expect(deliveryContext).NotTo(BeNil())
}
Expect(connection.Management().DeleteQueue(context.Background(), qName)).To(BeNil())
err = connection.Close(context.Background())
Expect(err).To(BeNil())
st4 := <-ch
Expect(st4.From).To(Equal(&StateOpen{}))
Expect(st4.To).To(BeAssignableToTypeOf(&StateClosed{}))
err = st4.To.(*StateClosed).GetError()
// the flow status should be:
// from open to closed (with error)
// from closed to reconnecting
// from reconnecting to open
// from open to closed (without error)
Expect(err).To(BeNil())
})
It("connection should not reconnect producers and consumers if the auto-recovery is disabled", func() {
name := "connection should reconnect producers and consumers if dropped by via REST API"
connection, err := Dial(context.Background(), []string{"amqp://"}, &AmqpConnOptions{
SASLType: amqp.SASLTypeAnonymous(),
ContainerID: name,
// reduced the reconnect interval to speed up the test
RecoveryConfiguration: &RecoveryConfiguration{
ActiveRecovery: false, // disabled
},
})
Expect(err).To(BeNil())
ch := make(chan *StateChanged, 1)
connection.NotifyStatusChange(ch)
Eventually(func() bool {
err := testhelper.DropConnectionContainerID(name)
return err == nil
}).WithTimeout(5 * time.Second).WithPolling(400 * time.Millisecond).Should(BeTrue())
st1 := <-ch
Expect(st1.From).To(Equal(&StateOpen{}))
Expect(st1.To).To(BeAssignableToTypeOf(&StateClosed{}))
err = st1.To.(*StateClosed).GetError()
Expect(err).NotTo(BeNil())
Expect(err.Error()).To(ContainSubstring("Connection forced"))
time.Sleep(1 * time.Second)
// the connection should not be reconnected
Consistently(func() bool {
conn, err := testhelper.GetConnectionByContainerID(name)
return err == nil && conn != nil
}).WithTimeout(5 * time.Second).WithPolling(400 * time.Millisecond).Should(BeFalse())
err = connection.Close(context.Background())
Expect(err).NotTo(BeNil())
})
It("validate the Recovery connection parameters", func() {
_, err := Dial(context.Background(), []string{"amqp://"}, &AmqpConnOptions{
SASLType: amqp.SASLTypeAnonymous(),
// reduced the reconnect interval to speed up the test
RecoveryConfiguration: &RecoveryConfiguration{
ActiveRecovery: true,
BackOffReconnectInterval: 500 * time.Millisecond,
MaxReconnectAttempts: 5,
},
})
Expect(err).NotTo(BeNil())
Expect(err.Error()).To(ContainSubstring("BackOffReconnectInterval should be greater than"))
_, err = Dial(context.Background(), []string{"amqp://"}, &AmqpConnOptions{
SASLType: amqp.SASLTypeAnonymous(),
RecoveryConfiguration: &RecoveryConfiguration{
ActiveRecovery: true,
MaxReconnectAttempts: 0,
},
})
Expect(err).NotTo(BeNil())
Expect(err.Error()).To(ContainSubstring("MaxReconnectAttempts should be greater than"))
})
})

View File

@ -8,19 +8,19 @@ import (
"time"
)
var _ = Describe("AMQP Connection Test", func() {
It("AMQP SASLTypeAnonymous Connection should succeed", func() {
var _ = Describe("AMQP connection Test", func() {
It("AMQP SASLTypeAnonymous connection should succeed", func() {
connection, err := Dial(context.Background(), []string{"amqp://"}, &amqp.ConnOptions{
connection, err := Dial(context.Background(), []string{"amqp://"}, &AmqpConnOptions{
SASLType: amqp.SASLTypeAnonymous()})
Expect(err).To(BeNil())
err = connection.Close(context.Background())
Expect(err).To(BeNil())
})
It("AMQP SASLTypePlain Connection should succeed", func() {
It("AMQP SASLTypePlain connection should succeed", func() {
connection, err := Dial(context.Background(), []string{"amqp://"}, &amqp.ConnOptions{
connection, err := Dial(context.Background(), []string{"amqp://"}, &AmqpConnOptions{
SASLType: amqp.SASLTypePlain("guest", "guest")})
Expect(err).To(BeNil())
@ -28,35 +28,35 @@ var _ = Describe("AMQP Connection Test", func() {
Expect(err).To(BeNil())
})
It("AMQP Connection connect to the one correct uri and fails the others", func() {
It("AMQP connection connect to the one correct uri and fails the others", func() {
conn, err := Dial(context.Background(), []string{"amqp://localhost:1234", "amqp://nohost:555", "amqp://"}, nil)
Expect(err).To(BeNil())
Expect(conn.Close(context.Background()))
})
It("AMQP Connection should fail due of wrong Port", func() {
It("AMQP connection should fail due of wrong Port", func() {
_, err := Dial(context.Background(), []string{"amqp://localhost:1234"}, nil)
Expect(err).NotTo(BeNil())
})
It("AMQP Connection should fail due of wrong Host", func() {
It("AMQP connection should fail due of wrong Host", func() {
_, err := Dial(context.Background(), []string{"amqp://wrong_host:5672"}, nil)
Expect(err).NotTo(BeNil())
})
It("AMQP Connection should fails with all the wrong uris", func() {
It("AMQP connection should fails with all the wrong uris", func() {
_, err := Dial(context.Background(), []string{"amqp://localhost:1234", "amqp://nohost:555", "amqp://nono"}, nil)
Expect(err).NotTo(BeNil())
})
It("AMQP Connection should fail due to context cancellation", func() {
It("AMQP connection should fail due to context cancellation", func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
cancel()
_, err := Dial(ctx, []string{"amqp://"}, nil)
Expect(err).NotTo(BeNil())
})
It("AMQP Connection should receive events", func() {
It("AMQP connection should receive events", func() {
ch := make(chan *StateChanged, 1)
connection, err := Dial(context.Background(), []string{"amqp://"}, nil)
Expect(err).To(BeNil())
@ -70,7 +70,7 @@ var _ = Describe("AMQP Connection Test", func() {
Expect(recv.To).To(Equal(&StateClosed{}))
})
//It("AMQP TLS Connection should success with SASLTypeAnonymous ", func() {
//It("AMQP TLS connection should success with SASLTypeAnonymous ", func() {
// amqpConnection := NewAmqpConnection()
// Expect(amqpConnection).NotTo(BeNil())
// Expect(amqpConnection).To(BeAssignableToTypeOf(&AmqpConnection{}))

View File

@ -2,7 +2,10 @@ package rabbitmq_amqp
import (
"context"
"fmt"
"github.com/Azure/go-amqp"
"github.com/google/uuid"
"sync/atomic"
)
type DeliveryContext struct {
@ -28,8 +31,8 @@ func (dc *DeliveryContext) DiscardWithAnnotations(ctx context.Context, annotatio
}
// copy the rabbitmq annotations to amqp annotations
destination := make(amqp.Annotations)
for key, value := range annotations {
destination[key] = value
for keyA, value := range annotations {
destination[keyA] = value
}
@ -62,21 +65,49 @@ func (dc *DeliveryContext) RequeueWithAnnotations(ctx context.Context, annotatio
}
type Consumer struct {
receiver *amqp.Receiver
receiver atomic.Pointer[amqp.Receiver]
connection *AmqpConnection
linkName string
destinationAdd string
id string
}
func newConsumer(receiver *amqp.Receiver) *Consumer {
return &Consumer{receiver: receiver}
func (c *Consumer) Id() string {
return c.id
}
func (c *Consumer) Receive(ctx context.Context) (*DeliveryContext, error) {
msg, err := c.receiver.Receive(ctx, nil)
func newConsumer(ctx context.Context, connection *AmqpConnection, destinationAdd string, linkName string, args ...string) (*Consumer, error) {
id := fmt.Sprintf("consumer-%s", uuid.New().String())
if len(args) > 0 {
id = args[0]
}
r := &Consumer{connection: connection, linkName: linkName, destinationAdd: destinationAdd, id: id}
connection.entitiesTracker.storeOrReplaceConsumer(r)
err := r.createReceiver(ctx)
if err != nil {
return nil, err
}
return &DeliveryContext{receiver: c.receiver, message: msg}, nil
return r, nil
}
func (c *Consumer) createReceiver(ctx context.Context) error {
receiver, err := c.connection.session.NewReceiver(ctx, c.destinationAdd, createReceiverLinkOptions(c.destinationAdd, c.linkName, AtLeastOnce))
if err != nil {
return err
}
c.receiver.Swap(receiver)
return nil
}
func (c *Consumer) Receive(ctx context.Context) (*DeliveryContext, error) {
msg, err := c.receiver.Load().Receive(ctx, nil)
if err != nil {
return nil, err
}
return &DeliveryContext{receiver: c.receiver.Load(), message: msg}, nil
}
func (c *Consumer) Close(ctx context.Context) error {
return c.receiver.Close(ctx)
return c.receiver.Load().Close(ctx)
}

View File

@ -0,0 +1,67 @@
package rabbitmq_amqp
import (
"context"
"fmt"
"sync"
)
type Environment struct {
connections sync.Map
addresses []string
connOptions *AmqpConnOptions
}
func NewEnvironment(addresses []string, connOptions *AmqpConnOptions) *Environment {
return &Environment{
connections: sync.Map{},
addresses: addresses,
connOptions: connOptions,
}
}
// NewConnection get a new connection from the environment.
// If the connection id is provided, it will be used as the connection id.
// If the connection id is not provided, a new connection id will be generated.
// The connection id is unique in the environment.
// The Environment will keep track of the connection and close it when the environment is closed.
func (e *Environment) NewConnection(ctx context.Context, args ...string) (*AmqpConnection, error) {
if len(args) > 0 && len(args[0]) > 0 {
// check if connection already exists
if _, ok := e.connections.Load(args[0]); ok {
return nil, fmt.Errorf("connection with id %s already exists", args[0])
}
}
connection, err := Dial(ctx, e.addresses, e.connOptions, args...)
if err != nil {
return nil, err
}
e.connections.Store(connection.Id(), connection)
connection.refMap = &e.connections
return connection, nil
}
// Connections gets the active connections in the environment
func (e *Environment) Connections() []*AmqpConnection {
connections := make([]*AmqpConnection, 0)
e.connections.Range(func(key, value interface{}) bool {
connections = append(connections, value.(*AmqpConnection))
return true
})
return connections
}
// CloseConnections closes all the connections in the environment with all the publishers and consumers.
func (e *Environment) CloseConnections(ctx context.Context) error {
var err error
e.connections.Range(func(key, value any) bool {
connection := value.(*AmqpConnection)
if cerr := connection.Close(ctx); cerr != nil {
err = cerr
}
return true
})
return err
}

View File

@ -0,0 +1,57 @@
package rabbitmq_amqp
import (
"context"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("AMQP Environment Test", func() {
It("AMQP Environment connection should succeed", func() {
env := NewEnvironment([]string{"amqp://"}, nil)
Expect(env).NotTo(BeNil())
Expect(env.Connections()).NotTo(BeNil())
Expect(len(env.Connections())).To(Equal(0))
connection, err := env.NewConnection(context.Background())
Expect(err).To(BeNil())
Expect(connection).NotTo(BeNil())
Expect(len(env.Connections())).To(Equal(1))
Expect(connection.Close(context.Background())).To(BeNil())
Expect(len(env.Connections())).To(Equal(0))
})
It("AMQP Environment CloseConnections should remove all the elements form the list", func() {
env := NewEnvironment([]string{"amqp://"}, nil)
Expect(env).NotTo(BeNil())
Expect(env.Connections()).NotTo(BeNil())
Expect(len(env.Connections())).To(Equal(0))
connection, err := env.NewConnection(context.Background())
Expect(err).To(BeNil())
Expect(connection).NotTo(BeNil())
Expect(len(env.Connections())).To(Equal(1))
Expect(env.CloseConnections(context.Background())).To(BeNil())
Expect(len(env.Connections())).To(Equal(0))
})
It("AMQP Environment connection ID should be unique", func() {
env := NewEnvironment([]string{"amqp://"}, nil)
Expect(env).NotTo(BeNil())
Expect(env.Connections()).NotTo(BeNil())
Expect(len(env.Connections())).To(Equal(0))
connection, err := env.NewConnection(context.Background(), "myConnectionId")
Expect(err).To(BeNil())
Expect(connection).NotTo(BeNil())
Expect(len(env.Connections())).To(Equal(1))
connectionShouldBeNil, err := env.NewConnection(context.Background(), "myConnectionId")
Expect(err).NotTo(BeNil())
Expect(err.Error()).To(ContainSubstring("connection with id myConnectionId already exists"))
Expect(connectionShouldBeNil).To(BeNil())
Expect(len(env.Connections())).To(Equal(1))
Expect(connection.Close(context.Background())).To(BeNil())
Expect(len(env.Connections())).To(Equal(0))
})
})

View File

@ -13,6 +13,10 @@ import (
var ErrPreconditionFailed = errors.New("precondition Failed")
var ErrDoesNotExist = errors.New("does not exist")
/*
AmqpManagement is the interface to the RabbitMQ /management endpoint
The management interface is used to declare/delete exchanges, queues, and bindings
*/
type AmqpManagement struct {
session *amqp.Session
sender *amqp.Sender
@ -28,34 +32,28 @@ func NewAmqpManagement() *AmqpManagement {
}
func (a *AmqpManagement) ensureReceiverLink(ctx context.Context) error {
if a.receiver == nil {
opts := createReceiverLinkOptions(managementNodeAddress, linkPairName, AtMostOnce)
receiver, err := a.session.NewReceiver(ctx, managementNodeAddress, opts)
if err != nil {
return err
}
a.receiver = receiver
return nil
opts := createReceiverLinkOptions(managementNodeAddress, linkPairName, AtMostOnce)
receiver, err := a.session.NewReceiver(ctx, managementNodeAddress, opts)
if err != nil {
return err
}
a.receiver = receiver
return nil
}
func (a *AmqpManagement) ensureSenderLink(ctx context.Context) error {
if a.sender == nil {
sender, err := a.session.NewSender(ctx, managementNodeAddress,
createSenderLinkOptions(managementNodeAddress, linkPairName, AtMostOnce))
if err != nil {
return err
}
a.sender = sender
return nil
sender, err := a.session.NewSender(ctx, managementNodeAddress,
createSenderLinkOptions(managementNodeAddress, linkPairName, AtMostOnce))
if err != nil {
return err
}
a.sender = sender
return nil
}
func (a *AmqpManagement) Open(ctx context.Context, connection *AmqpConnection) error {
session, err := connection.Connection.NewSession(ctx, nil)
session, err := connection.azureConnection.NewSession(ctx, nil)
if err != nil {
return err
}
@ -89,6 +87,11 @@ func (a *AmqpManagement) Close(ctx context.Context) error {
return err
}
/*
Request sends a request to the /management endpoint.
It is a generic method that can be used to send any request to the management endpoint.
In most of the cases you don't need to use this method directly, instead use the standard methods
*/
func (a *AmqpManagement) Request(ctx context.Context, body any, path string, method string,
expectedResponseCodes []int) (map[string]any, error) {
return a.request(ctx, uuid.New().String(), body, path, method, expectedResponseCodes)

View File

@ -22,8 +22,13 @@ var _ = Describe("Management tests", func() {
})
It("AMQP Management should receive events", func() {
ch := make(chan *StateChanged, 1)
connection, err := Dial(context.Background(), []string{"amqp://"}, nil)
ch := make(chan *StateChanged, 2)
connection, err := Dial(context.Background(), []string{"amqp://"}, &AmqpConnOptions{
SASLType: amqp.SASLTypeAnonymous(),
RecoveryConfiguration: &RecoveryConfiguration{
ActiveRecovery: false,
},
})
Expect(err).To(BeNil())
connection.NotifyStatusChange(ch)
err = connection.Close(context.Background())

View File

@ -4,6 +4,8 @@ import (
"context"
"fmt"
"github.com/Azure/go-amqp"
"github.com/google/uuid"
"sync/atomic"
)
type PublishResult struct {
@ -13,12 +15,39 @@ type PublishResult struct {
// Publisher is a publisher that sends messages to a specific destination address.
type Publisher struct {
sender *amqp.Sender
staticTargetAddress bool
sender atomic.Pointer[amqp.Sender]
connection *AmqpConnection
linkName string
destinationAdd string
id string
}
func newPublisher(sender *amqp.Sender, staticTargetAddress bool) *Publisher {
return &Publisher{sender: sender, staticTargetAddress: staticTargetAddress}
func (m *Publisher) Id() string {
return m.id
}
func newPublisher(ctx context.Context, connection *AmqpConnection, destinationAdd string, linkName string, args ...string) (*Publisher, error) {
id := fmt.Sprintf("publisher-%s", uuid.New().String())
if len(args) > 0 {
id = args[0]
}
r := &Publisher{connection: connection, linkName: linkName, destinationAdd: destinationAdd, id: id}
connection.entitiesTracker.storeOrReplaceProducer(r)
err := r.createSender(ctx)
if err != nil {
return nil, err
}
return r, nil
}
func (m *Publisher) createSender(ctx context.Context) error {
sender, err := m.connection.session.NewSender(ctx, m.destinationAdd, createSenderLinkOptions(m.destinationAdd, m.linkName, AtLeastOnce))
if err != nil {
return err
}
m.sender.Swap(sender)
return nil
}
/*
@ -58,7 +87,7 @@ Create a new publisher that sends messages based on message destination address:
</code>
*/
func (m *Publisher) Publish(ctx context.Context, message *amqp.Message) (*PublishResult, error) {
if !m.staticTargetAddress {
if m.destinationAdd == "" {
if message.Properties == nil || message.Properties.To == nil {
return nil, fmt.Errorf("message properties TO is required to send a message to a dynamic target address")
}
@ -68,7 +97,7 @@ func (m *Publisher) Publish(ctx context.Context, message *amqp.Message) (*Publis
return nil, err
}
}
r, err := m.sender.SendWithReceipt(ctx, message, nil)
r, err := m.sender.Load().SendWithReceipt(ctx, message, nil)
if err != nil {
return nil, err
}
@ -76,14 +105,14 @@ func (m *Publisher) Publish(ctx context.Context, message *amqp.Message) (*Publis
if err != nil {
return nil, err
}
publishResult := &PublishResult{
return &PublishResult{
Message: message,
Outcome: state,
}
return publishResult, err
}, err
}
// Close closes the publisher.
func (m *Publisher) Close(ctx context.Context) error {
return m.sender.Close(ctx)
m.connection.entitiesTracker.removeProducer(m)
return m.sender.Load().Close(ctx)
}

View File

@ -95,14 +95,14 @@ var _ = Describe("AMQP publisher ", func() {
Expect(connection.Close(context.Background()))
})
It("Multi Targets Publisher should fail with StateReleased when the destination does not exist", func() {
It("Multi Targets NewPublisher should fail with StateReleased when the destination does not exist", func() {
connection, err := Dial(context.Background(), []string{"amqp://"}, nil)
Expect(err).To(BeNil())
Expect(connection).NotTo(BeNil())
publisher, err := connection.NewPublisher(context.Background(), nil, "test")
Expect(err).To(BeNil())
Expect(publisher).NotTo(BeNil())
qName := generateNameWithDateTime("Targets Publisher should fail when the destination does not exist")
qName := generateNameWithDateTime("Targets NewPublisher should fail when the destination does not exist")
msg := amqp.NewMessage([]byte("hello"))
Expect(MessageToAddressHelper(msg, &QueueAddress{Queue: qName})).To(BeNil())
@ -113,7 +113,7 @@ var _ = Describe("AMQP publisher ", func() {
Expect(connection.Close(context.Background())).To(BeNil())
})
It("Multi Targets Publisher should success with StateReceived when the destination exists", func() {
It("Multi Targets NewPublisher should success with StateReceived when the destination exists", func() {
connection, err := Dial(context.Background(), []string{"amqp://"}, nil)
Expect(err).To(BeNil())
Expect(connection).NotTo(BeNil())
@ -121,7 +121,7 @@ var _ = Describe("AMQP publisher ", func() {
publisher, err := connection.NewPublisher(context.Background(), nil, "test")
Expect(err).To(BeNil())
Expect(publisher).NotTo(BeNil())
name := generateNameWithDateTime("Targets Publisher should success with StateReceived when the destination exists")
name := generateNameWithDateTime("Targets NewPublisher should success with StateReceived when the destination exists")
_, err = connection.Management().DeclareQueue(context.Background(), &QuorumQueueSpecification{
Name: name,
})
@ -167,7 +167,7 @@ var _ = Describe("AMQP publisher ", func() {
Expect(connection.Close(context.Background())).To(BeNil())
})
It("Multi Targets Publisher should fail it TO is not set or not valid", func() {
It("Multi Targets NewPublisher should fail it TO is not set or not valid", func() {
connection, err := Dial(context.Background(), []string{"amqp://"}, nil)
Expect(err).To(BeNil())
Expect(connection).NotTo(BeNil())

View File

@ -1,5 +1,9 @@
package rabbitmq_amqp
type entityIdentifier interface {
Id() string
}
type TQueueType string
const (
@ -16,6 +20,9 @@ func (e QueueType) String() string {
return string(e.Type)
}
/*
QueueSpecification represents the specification of a queue
*/
type QueueSpecification interface {
name() string
isAutoDelete() bool
@ -24,8 +31,6 @@ type QueueSpecification interface {
buildArguments() map[string]any
}
// QuorumQueueSpecification represents the specification of the quorum queue
type OverflowStrategy interface {
overflowStrategy() string
}
@ -69,6 +74,10 @@ func (r *ClientLocalLeaderLocator) leaderLocator() string {
return "client-local"
}
/*
QuorumQueueSpecification represents the specification of the quorum queue
*/
type QuorumQueueSpecification struct {
Name string
AutoExpire int64
@ -150,7 +159,9 @@ func (q *QuorumQueueSpecification) buildArguments() map[string]any {
return result
}
// ClassicQueueSpecification represents the specification of the classic queue
/*
ClassicQueueSpecification represents the specification of the classic queue
*/
type ClassicQueueSpecification struct {
Name string
IsAutoDelete bool
@ -231,6 +242,11 @@ func (q *ClassicQueueSpecification) buildArguments() map[string]any {
return result
}
/*
AutoGeneratedQueueSpecification represents the specification of the auto-generated queue.
It is a classic queue with auto-generated name.
It is useful in context like RPC or when you need a temporary queue.
*/
type AutoGeneratedQueueSpecification struct {
IsAutoDelete bool
IsExclusive bool

View File

@ -31,6 +31,11 @@ func (c *StateClosing) getState() int {
}
type StateClosed struct {
error error
}
func (c *StateClosed) GetError() error {
return c.error
}
func (c *StateClosed) getState() int {
@ -65,7 +70,18 @@ type StateChanged struct {
}
func (s StateChanged) String() string {
switch s.From.(type) {
case *StateClosed:
}
switch s.To.(type) {
case *StateClosed:
return fmt.Sprintf("From: %s, To: %s, Error: %s", statusToString(s.From), statusToString(s.To), s.To.(*StateClosed).error)
}
return fmt.Sprintf("From: %s, To: %s", statusToString(s.From), statusToString(s.To))
}
type LifeCycle struct {
@ -100,6 +116,7 @@ func (l *LifeCycle) SetState(value LifeCycleState) {
if l.chStatusChanged == nil {
return
}
l.chStatusChanged <- &StateChanged{
From: oldState,
To: value,

View File

@ -0,0 +1,115 @@
package test_helper
import (
"encoding/json"
"errors"
"io"
"net/http"
"strconv"
)
type Connection struct {
Name string `json:"name"`
ContainerId string `json:"container_id"`
}
func Connections() ([]Connection, error) {
bodyString, err := httpGet("http://localhost:15672/api/connections/", "guest", "guest")
if err != nil {
return nil, err
}
var data []Connection
err = json.Unmarshal([]byte(bodyString), &data)
if err != nil {
return nil, err
}
return data, nil
}
func GetConnectionByContainerID(Id string) (*Connection, error) {
connections, err := Connections()
if err != nil {
return nil, err
}
for _, conn := range connections {
if conn.ContainerId == Id {
return &conn, nil
}
}
return nil, errors.New("connection not found")
}
func DropConnectionContainerID(Id string) error {
connections, err := Connections()
if err != nil {
return err
}
connectionToDrop := ""
for _, conn := range connections {
if conn.ContainerId == Id {
connectionToDrop = conn.Name
break
}
}
if connectionToDrop == "" {
return errors.New("connection not found")
}
err = DropConnection(connectionToDrop, "15672")
if err != nil {
return err
}
return nil
}
func DropConnection(name string, port string) error {
_, err := httpDelete("http://localhost:"+port+"/api/connections/"+name, "guest", "guest")
if err != nil {
return err
}
return nil
}
func httpGet(url, username, password string) (string, error) {
return baseCall(url, username, password, "GET")
}
func httpDelete(url, username, password string) (string, error) {
return baseCall(url, username, password, "DELETE")
}
func baseCall(url, username, password string, method string) (string, error) {
var client http.Client
req, err := http.NewRequest(method, url, nil)
if err != nil {
return "", err
}
req.SetBasicAuth(username, password)
resp, err3 := client.Do(req)
if err3 != nil {
return "", err3
}
defer resp.Body.Close()
if resp.StatusCode == 200 { // OK
bodyBytes, err2 := io.ReadAll(resp.Body)
if err2 != nil {
return "", err2
}
return string(bodyBytes), nil
}
if resp.StatusCode == 204 { // No Content
return "", nil
}
return "", errors.New(strconv.Itoa(resp.StatusCode))
}

View File

@ -1,168 +0,0 @@
package rabbitmq_amqp
import (
"context"
"fmt"
"github.com/Azure/go-amqp"
)
//func (c *ConnUrlHelper) UseSsl(value bool) {
// c.UseSsl = value
// if value {
// c.Scheme = "amqps"
// } else {
// c.Scheme = "amqp"
// }
//}
type AmqpConnection struct {
Connection *amqp.Conn
management *AmqpManagement
lifeCycle *LifeCycle
session *amqp.Session
}
// NewPublisher creates a new Publisher that sends messages to the provided destination.
// The destination is a TargetAddress that can be a Queue or an Exchange with a routing key.
// See QueueAddress and ExchangeAddress for more information.
func (a *AmqpConnection) NewPublisher(ctx context.Context, destination TargetAddress, linkName string) (*Publisher, error) {
destinationAdd := ""
err := error(nil)
if destination != nil {
destinationAdd, err = destination.toAddress()
if err != nil {
return nil, err
}
err = validateAddress(destinationAdd)
if err != nil {
return nil, err
}
}
sender, err := a.session.NewSender(ctx, destinationAdd, createSenderLinkOptions(destinationAdd, linkName, AtLeastOnce))
if err != nil {
return nil, err
}
return newPublisher(sender, destinationAdd != ""), nil
}
// NewConsumer creates a new Consumer that listens to the provided destination. Destination is a QueueAddress.
func (a *AmqpConnection) NewConsumer(ctx context.Context, destination *QueueAddress, linkName string) (*Consumer, error) {
destinationAdd, err := destination.toAddress()
if err != nil {
return nil, err
}
err = validateAddress(destinationAdd)
if err != nil {
return nil, err
}
receiver, err := a.session.NewReceiver(ctx, destinationAdd, createReceiverLinkOptions(destinationAdd, linkName, AtLeastOnce))
if err != nil {
return nil, err
}
return newConsumer(receiver), nil
}
// Dial connect to the AMQP 1.0 server using the provided connectionSettings
// Returns a pointer to the new AmqpConnection if successful else an error.
// addresses is a list of addresses to connect to. It picks one randomly.
// It is enough that one of the addresses is reachable.
func Dial(ctx context.Context, addresses []string, connOptions *amqp.ConnOptions) (*AmqpConnection, error) {
conn := &AmqpConnection{
management: NewAmqpManagement(),
lifeCycle: NewLifeCycle(),
}
tmp := make([]string, len(addresses))
copy(tmp, addresses)
// random pick and extract one address to use for connection
for len(tmp) > 0 {
idx := random(len(tmp))
addr := tmp[idx]
// remove the index from the tmp list
tmp = append(tmp[:idx], tmp[idx+1:]...)
err := conn.open(ctx, addr, connOptions)
if err != nil {
Error("Failed to open connection", ExtractWithoutPassword(addr), err)
continue
}
Debug("Connected to", ExtractWithoutPassword(addr))
return conn, nil
}
return nil, fmt.Errorf("no address to connect to")
}
// Open opens a connection to the AMQP 1.0 server.
// using the provided connectionSettings and the AMQPLite library.
// Setups the connection and the management interface.
func (a *AmqpConnection) open(ctx context.Context, addr string, connOptions *amqp.ConnOptions) error {
if connOptions == nil {
connOptions = &amqp.ConnOptions{
// RabbitMQ requires SASL security layer
// to be enabled for AMQP 1.0 connections.
// So this is mandatory and default in case not defined.
SASLType: amqp.SASLTypeAnonymous(),
}
}
//connOptions.HostName is the way to set the virtual host
// so we need to pre-parse the URI to get the virtual host
// the PARSE is copied from go-amqp091 library
// the URI will be parsed is parsed again in the amqp lite library
uri, err := ParseURI(addr)
if err != nil {
return err
}
connOptions.HostName = fmt.Sprintf("vhost:%s", uri.Vhost)
conn, err := amqp.Dial(ctx, addr, connOptions)
if err != nil {
return err
}
a.Connection = conn
a.session, err = a.Connection.NewSession(ctx, nil)
if err != nil {
return err
}
err = a.management.Open(ctx, a)
if err != nil {
// TODO close connection?
return err
}
a.lifeCycle.SetState(&StateOpen{})
return nil
}
func (a *AmqpConnection) Close(ctx context.Context) error {
err := a.management.Close(ctx)
if err != nil {
return err
}
err = a.Connection.Close()
a.lifeCycle.SetState(&StateClosed{})
return err
}
// NotifyStatusChange registers a channel to receive getState change notifications
// from the connection.
func (a *AmqpConnection) NotifyStatusChange(channel chan *StateChanged) {
a.lifeCycle.chStatusChanged = channel
}
func (a *AmqpConnection) State() LifeCycleState {
return a.lifeCycle.State()
}
// *** management section ***
// Management returns the management interface for the connection.
// The management interface is used to declare and delete exchanges, queues, and bindings.
func (a *AmqpConnection) Management() *AmqpManagement {
return a.management
}
//*** end management section ***