rabbitmq-amqp-go-client/pkg/rabbitmqamqp/amqp_connection.go

353 lines
10 KiB
Go

package rabbitmqamqp
import (
"context"
"crypto/tls"
"fmt"
"github.com/Azure/go-amqp"
"github.com/google/uuid"
"math/rand"
"sync"
"sync/atomic"
"time"
)
//func (c *ConnUrlHelper) UseSsl(value bool) {
// c.UseSsl = value
// if value {
// c.Scheme = "amqps"
// } else {
// c.Scheme = "amqp"
// }
//}
// AmqpConnOptions carries the settings used to open a connection.
// Most fields are copied one-to-one into amqp.ConnOptions when the
// connection is opened.
type AmqpConnOptions struct {
	// ContainerID is copied into amqp.ConnOptions.ContainerID.
	ContainerID string
	// HostName is copied into amqp.ConnOptions.HostName.
	// While a connection is being opened this field is overwritten with
	// "vhost:<vhost>", where <vhost> is parsed from the address being dialed.
	HostName string
	// IdleTimeout is copied into amqp.ConnOptions.IdleTimeout.
	IdleTimeout time.Duration
	// MaxFrameSize is copied into amqp.ConnOptions.MaxFrameSize.
	MaxFrameSize uint32
	// MaxSessions is copied into amqp.ConnOptions.MaxSessions.
	MaxSessions uint16
	// Properties is copied into amqp.ConnOptions.Properties.
	Properties map[string]any
	// SASLType is copied into amqp.ConnOptions.SASLType.
	// Dial uses amqp.SASLTypeAnonymous() when it is called with nil options.
	SASLType amqp.SASLType
	// TLSConfig is copied into amqp.ConnOptions.TLSConfig.
	TLSConfig *tls.Config
	// WriteTimeout is copied into amqp.ConnOptions.WriteTimeout.
	WriteTimeout time.Duration
	// RecoveryConfiguration configures the recovery behavior applied when
	// the connection is closed unexpectedly. Dial fills it in with
	// NewRecoveryConfiguration() when it is nil.
	RecoveryConfiguration *RecoveryConfiguration
	// addresses is the address list given to Dial, kept for reconnection.
	addresses []string
}
// AmqpConnection is a connection to an AMQP 1.0 server together with its
// session, management interface, life-cycle state and tracked entities.
type AmqpConnection struct {
	// azureConnection is the underlying go-amqp connection.
	azureConnection *amqp.Conn
	// id identifies the connection; it is returned by Id().
	id string
	// management is the management interface returned by Management().
	management *AmqpManagement
	// lifeCycle holds the connection state reported by State().
	lifeCycle *LifeCycle
	// amqpConnOptions are the options the connection was dialed with; they
	// are reused when reconnecting.
	amqpConnOptions *AmqpConnOptions
	// session is the session created on azureConnection when it is opened.
	session *amqp.Session
	// refMap, when non-nil, is a map from which this connection's id is
	// deleted on close.
	refMap *sync.Map
	// entitiesTracker tracks the publishers and consumers that are
	// re-created after a successful reconnection.
	entitiesTracker *entitiesTracker
}
// NewPublisher builds a Publisher bound to this connection.
//
// destination may be a Queue or an Exchange with a routing key (see
// QueueAddress and ExchangeAddress). When destination is nil the publisher is
// created with an empty address. Otherwise the destination is converted to an
// address string and validated; any error from either step is returned and no
// publisher is created.
func (a *AmqpConnection) NewPublisher(ctx context.Context, destination TargetAddress, linkName string) (*Publisher, error) {
	// No destination given: create the publisher with an empty address.
	if destination == nil {
		return newPublisher(ctx, a, "", linkName)
	}

	address, err := destination.toAddress()
	if err != nil {
		return nil, err
	}
	if err := validateAddress(address); err != nil {
		return nil, err
	}
	return newPublisher(ctx, a, address, linkName)
}
// NewConsumer builds a Consumer bound to this connection that reads from the
// queue named queueName. The queue name is turned into an address through
// QueueAddress; an error from that conversion is returned as-is.
func (a *AmqpConnection) NewConsumer(ctx context.Context, queueName string, options ConsumerOptions) (*Consumer, error) {
	address, err := (&QueueAddress{Queue: queueName}).toAddress()
	if err != nil {
		return nil, err
	}
	return newConsumer(ctx, a, address, options)
}
// Dial connects to an AMQP 1.0 server and returns the new AmqpConnection, or
// an error if the connection could not be established.
//
// addresses is the list of candidate server addresses. One is picked at
// random and the others are tried if it fails; it is enough that one of them
// is reachable. The list is copied, so the caller may modify its slice after
// Dial returns without affecting later reconnections.
//
// connOptions may be nil, in which case anonymous SASL is used (RabbitMQ
// requires the SASL security layer for AMQP 1.0 connections). When
// connOptions.RecoveryConfiguration is nil it is set to
// NewRecoveryConfiguration(). With ActiveRecovery enabled,
// MaxReconnectAttempts must be greater than 0 and BackOffReconnectInterval
// must be greater than 1 second, otherwise an error is returned.
//
// args[0], when present, is used as the connection id.
func Dial(ctx context.Context, addresses []string, connOptions *AmqpConnOptions, args ...string) (*AmqpConnection, error) {
	if connOptions == nil {
		connOptions = &AmqpConnOptions{
			// RabbitMQ requires SASL security layer
			// to be enabled for AMQP 1.0 connections.
			// So this is mandatory and default in case not defined.
			SASLType: amqp.SASLTypeAnonymous(),
		}
	}
	if connOptions.RecoveryConfiguration == nil {
		connOptions.RecoveryConfiguration = NewRecoveryConfiguration()
	}

	// validate the RecoveryConfiguration options
	recovery := connOptions.RecoveryConfiguration
	if recovery.ActiveRecovery {
		if recovery.MaxReconnectAttempts <= 0 {
			return nil, fmt.Errorf("MaxReconnectAttempts should be greater than 0")
		}
		if recovery.BackOffReconnectInterval <= 1*time.Second {
			return nil, fmt.Errorf("BackOffReconnectInterval should be greater than 1 second")
		}
	}

	// Keep a private copy of the addresses for reconnection so that later
	// changes to the caller's slice cannot alter the reconnect targets.
	// It is stored before opening the connection so that the reconnect
	// path, which reads this list, never observes it unset.
	ownAddresses := make([]string, len(addresses))
	copy(ownAddresses, addresses)
	connOptions.addresses = ownAddresses

	// create the connection
	conn := &AmqpConnection{
		management:      NewAmqpManagement(),
		lifeCycle:       NewLifeCycle(),
		amqpConnOptions: connOptions,
		entitiesTracker: newEntitiesTracker(),
	}
	if err := conn.open(ctx, ownAddresses, connOptions, args...); err != nil {
		return nil, err
	}
	conn.lifeCycle.SetState(&StateOpen{})
	return conn, nil
}
// open dials one of the given addresses, tried in random order, and wires the
// resulting connection into a: it stores the connection, creates a session,
// opens the management interface and starts a goroutine that reacts to the
// connection being closed.
//
// addresses is not modified; a private copy is consumed while picking.
// connOptions supplies the values copied into amqp.ConnOptions. Its HostName
// is overwritten with "vhost:<vhost>" parsed from the address being dialed,
// and that same value is used for the dial.
// args[0], when present, becomes the connection id. Otherwise an id is
// generated only if the connection does not have one yet, so the id stays the
// same when open is called again to reconnect.
//
// An error is returned when an address cannot be parsed, when none of the
// addresses can be dialed, or when the session or the management interface
// cannot be opened. In the last two cases the freshly dialed connection is
// closed before returning.
func (a *AmqpConnection) open(ctx context.Context, addresses []string, connOptions *AmqpConnOptions, args ...string) error {
	amqpLiteConnOptions := &amqp.ConnOptions{
		ContainerID:  connOptions.ContainerID,
		HostName:     connOptions.HostName,
		IdleTimeout:  connOptions.IdleTimeout,
		MaxFrameSize: connOptions.MaxFrameSize,
		MaxSessions:  connOptions.MaxSessions,
		Properties:   connOptions.Properties,
		SASLType:     connOptions.SASLType,
		TLSConfig:    connOptions.TLSConfig,
		WriteTimeout: connOptions.WriteTimeout,
	}

	candidates := make([]string, len(addresses))
	copy(candidates, addresses)

	// random pick and extract one address to use for connection
	var azureConnection *amqp.Conn
	for len(candidates) > 0 {
		idx := random(len(candidates))
		addr := candidates[idx]
		// remove the picked address so it is not tried twice
		candidates = append(candidates[:idx], candidates[idx+1:]...)

		// HostName is the way to set the virtual host, so the URI is
		// pre-parsed here to extract it (the parser is copied from the
		// go-amqp091 library). The URI is parsed again by the AMQP library.
		uri, err := ParseURI(addr)
		if err != nil {
			return err
		}
		connOptions.HostName = fmt.Sprintf("vhost:%s", uri.Vhost)
		// The dial options were built before the vhost was known; update
		// them so the vhost of THIS address is the one actually sent.
		amqpLiteConnOptions.HostName = connOptions.HostName

		dialed, err := amqp.Dial(ctx, addr, amqpLiteConnOptions)
		if err != nil {
			Error("Failed to open connection", ExtractWithoutPassword(addr), err)
			continue
		}
		Debug("Connected to", ExtractWithoutPassword(addr))
		azureConnection = dialed
		break
	}
	if azureConnection == nil {
		return fmt.Errorf("failed to connect to any of the provided addresses")
	}

	if len(args) > 0 {
		a.id = args[0]
	} else if a.id == "" {
		// Generate an id only once: reconnections call open without args
		// and must not change the id the connection is known by.
		a.id = uuid.New().String()
	}

	a.azureConnection = azureConnection
	session, err := azureConnection.NewSession(ctx, nil)
	if err != nil {
		// Do not leak the dialed connection; the session error is the one
		// worth reporting, so the close error is deliberately dropped.
		_ = azureConnection.Close()
		return err
	}
	a.session = session

	if err := a.management.Open(ctx, a); err != nil {
		// Same as above: release the connection that cannot be used.
		_ = azureConnection.Close()
		return err
	}

	// Watch the connection only once it is fully set up. If it was closed in
	// the meantime, Done() is already closed and the goroutine runs at once.
	go func() {
		<-azureConnection.Done()
		closeErr := azureConnection.Err()
		a.lifeCycle.SetState(&StateClosed{error: closeErr})
		if closeErr != nil {
			Error("connection closed unexpectedly", "error", closeErr)
			a.maybeReconnect()
			return
		}
		Debug("connection closed successfully")
	}()
	return nil
}
// maybeReconnect tries to re-establish the connection after it was closed
// unexpectedly.
//
// It does nothing when ActiveRecovery is disabled. Otherwise the state is set
// to StateReconnecting and up to MaxReconnectAttempts attempts are made. Before
// each attempt it sleeps for the current back-off interval plus 0-500 ms of
// random jitter (to avoid a thundering herd); the back-off interval starts at
// BackOffReconnectInterval and doubles after every failed attempt.
//
// On success every tracked publisher and consumer is re-created (failures are
// logged and counted, not fatal) and the state becomes StateOpen. When all
// attempts fail the state becomes StateClosed carrying an error.
func (a *AmqpConnection) maybeReconnect() {
	recovery := a.amqpConnOptions.RecoveryConfiguration
	if !recovery.ActiveRecovery {
		Info("Recovery is disabled, closing connection")
		return
	}
	a.lifeCycle.SetState(&StateReconnecting{})

	backOff := recovery.BackOffReconnectInterval
	reconnected := false
	for attempt := 1; attempt <= recovery.MaxReconnectAttempts; attempt++ {
		// The jitter is applied to this wait only; it is not folded into
		// the back-off base, so it never gets doubled along with it.
		waitTime := backOff + time.Duration(rand.Intn(500))*time.Millisecond
		Info("Waiting before reconnecting", "in", waitTime, "attempt", attempt)
		time.Sleep(waitTime)

		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		err := a.open(ctx, a.amqpConnOptions.addresses, a.amqpConnOptions)
		cancel()
		if err == nil {
			reconnected = true
			break
		}
		Error("Failed to connection. ", "id", a.Id(), "error", err)
		backOff *= 2
	}

	if !reconnected {
		// Do not leave the connection in StateReconnecting forever: no
		// further attempt will be made, so report it as closed.
		a.lifeCycle.SetState(&StateClosed{
			error: fmt.Errorf("reconnection failed after %d attempts", recovery.MaxReconnectAttempts),
		})
		return
	}

	var fails int32
	Info("Reconnected successfully, restarting publishers and consumers")
	a.entitiesTracker.publishers.Range(func(key, value any) bool {
		publisher := value.(*Publisher)
		// re-create the sender on the new connection
		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		err := publisher.createSender(ctx)
		if err != nil {
			atomic.AddInt32(&fails, 1)
			Error("Failed to createSender publisher", "ID", publisher.Id(), "error", err)
		}
		cancel()
		return true
	})
	Info("Restarted publishers", "number of fails", fails)

	fails = 0
	a.entitiesTracker.consumers.Range(func(key, value any) bool {
		consumer := value.(*Consumer)
		// re-create the receiver on the new connection
		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		err := consumer.createReceiver(ctx)
		if err != nil {
			atomic.AddInt32(&fails, 1)
			Error("Failed to createReceiver consumer", "ID", consumer.Id(), "error", err)
		}
		cancel()
		return true
	})
	Info("Restarted consumers", "number of fails", fails)

	a.lifeCycle.SetState(&StateOpen{})
}
// close releases the bookkeeping attached to the connection: it deletes the
// connection's id from refMap, when one is set, and cleans up the entities
// tracker. It does not close the underlying AMQP connection.
func (a *AmqpConnection) close() {
	if refs := a.refMap; refs != nil {
		refs.Delete(a.Id())
	}
	a.entitiesTracker.CleanUp()
}
// Close closes the management interface and the connection to the AMQP 1.0
// server, then releases the tracked publishers and consumers.
//
// A failure to close the management interface is only logged. The returned
// error is the one from closing the underlying connection.
func (a *AmqpConnection) Close(ctx context.Context) error {
	// The closed state (lifeCycle.SetState(&StateClosed{error: nil})) is not
	// set here: it is set by the goroutine started in open(...), which waits
	// on the connection's Done() channel and runs in any case.
	if mgmtErr := a.management.Close(ctx); mgmtErr != nil {
		Error("Failed to close management", "error:", mgmtErr)
	}

	connErr := a.azureConnection.Close()
	a.close()
	return connErr
}
// NotifyStatusChange registers channel as the channel on which state-change
// notifications for this connection are delivered. A previously registered
// channel is replaced.
func (a *AmqpConnection) NotifyStatusChange(channel chan *StateChanged) {
	lifeCycle := a.lifeCycle
	lifeCycle.chStatusChanged = channel
}
// State returns the current life-cycle state of the connection.
func (a *AmqpConnection) State() LifeCycleState { return a.lifeCycle.State() }
// Id returns the identifier of the connection.
func (a *AmqpConnection) Id() string { return a.id }
// *** management section ***
// Management gives access to the connection's management interface, which is
// used to declare and delete exchanges, queues, and bindings.
func (a *AmqpConnection) Management() *AmqpManagement { return a.management }
//*** end management section ***