Add code for the async task system
This commit is contained in:
parent
de5f976c31
commit
7ea66e48af
|
|
@ -0,0 +1,247 @@
|
||||||
|
// Package task provides asynchronous task processing with handler factory pattern
|
||||||
|
package task
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"modelRT/logger"
|
||||||
|
|
||||||
|
"github.com/gofrs/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TaskHandler defines the interface for task processors.
//
// Implementations are registered with a HandlerFactory and invoked by the
// task worker to process a single queued task.
type TaskHandler interface {
	// Execute processes a task with the given ID and type.
	// The db handle is provided for persistence; returning a non-nil error
	// causes the worker to record the task as failed.
	Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, db *gorm.DB) error

	// CanHandle returns true if this handler can process the given task type.
	CanHandle(taskType TaskType) bool

	// Name returns the name of the handler for logging and metrics.
	Name() string
}
|
||||||
|
|
||||||
|
// HandlerFactory creates task handlers based on task type
|
||||||
|
type HandlerFactory struct {
|
||||||
|
handlers map[TaskType]TaskHandler
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandlerFactory creates a new HandlerFactory
|
||||||
|
func NewHandlerFactory() *HandlerFactory {
|
||||||
|
return &HandlerFactory{
|
||||||
|
handlers: make(map[TaskType]TaskHandler),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterHandler registers a handler for a specific task type
|
||||||
|
func (f *HandlerFactory) RegisterHandler(taskType TaskType, handler TaskHandler) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
|
||||||
|
f.handlers[taskType] = handler
|
||||||
|
logger.Info(context.Background(), "Handler registered",
|
||||||
|
"task_type", taskType,
|
||||||
|
"handler_name", handler.Name(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHandler returns a handler for the given task type
|
||||||
|
func (f *HandlerFactory) GetHandler(taskType TaskType) (TaskHandler, error) {
|
||||||
|
f.mu.RLock()
|
||||||
|
handler, exists := f.handlers[taskType]
|
||||||
|
f.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("no handler registered for task type: %s", taskType)
|
||||||
|
}
|
||||||
|
|
||||||
|
return handler, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateDefaultHandlers registers all default task handlers
|
||||||
|
func (f *HandlerFactory) CreateDefaultHandlers() {
|
||||||
|
f.RegisterHandler(TypeTopologyAnalysis, &TopologyAnalysisHandler{})
|
||||||
|
f.RegisterHandler(TypeEventAnalysis, &EventAnalysisHandler{})
|
||||||
|
f.RegisterHandler(TypeBatchImport, &BatchImportHandler{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BaseHandler provides common functionality for all task handlers.
//
// Concrete handlers embed it to inherit Name().
type BaseHandler struct {
	name string
}

// NewBaseHandler creates a new BaseHandler carrying the given name.
func NewBaseHandler(name string) *BaseHandler {
	base := BaseHandler{name: name}
	return &base
}

// Name returns the handler name.
func (h *BaseHandler) Name() string {
	return h.name
}
|
||||||
|
|
||||||
|
// TopologyAnalysisHandler handles topology analysis tasks
|
||||||
|
type TopologyAnalysisHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTopologyAnalysisHandler creates a new TopologyAnalysisHandler
|
||||||
|
func NewTopologyAnalysisHandler() *TopologyAnalysisHandler {
|
||||||
|
return &TopologyAnalysisHandler{
|
||||||
|
BaseHandler: *NewBaseHandler("topology_analysis_handler"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute processes a topology analysis task
|
||||||
|
func (h *TopologyAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, db *gorm.DB) error {
|
||||||
|
logger.Info(ctx, "Starting topology analysis",
|
||||||
|
"task_id", taskID,
|
||||||
|
"task_type", taskType,
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: Implement actual topology analysis logic
|
||||||
|
// This would typically involve:
|
||||||
|
// 1. Fetching task parameters from database
|
||||||
|
// 2. Performing topology analysis (checking for islands, shorts, etc.)
|
||||||
|
// 3. Storing results in database
|
||||||
|
// 4. Updating task status
|
||||||
|
|
||||||
|
// Simulate work
|
||||||
|
logger.Info(ctx, "Topology analysis completed",
|
||||||
|
"task_id", taskID,
|
||||||
|
"task_type", taskType,
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanHandle returns true for topology analysis tasks
|
||||||
|
func (h *TopologyAnalysisHandler) CanHandle(taskType TaskType) bool {
|
||||||
|
return taskType == TypeTopologyAnalysis
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventAnalysisHandler handles event analysis tasks
|
||||||
|
type EventAnalysisHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEventAnalysisHandler creates a new EventAnalysisHandler
|
||||||
|
func NewEventAnalysisHandler() *EventAnalysisHandler {
|
||||||
|
return &EventAnalysisHandler{
|
||||||
|
BaseHandler: *NewBaseHandler("event_analysis_handler"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute processes an event analysis task
|
||||||
|
func (h *EventAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, db *gorm.DB) error {
|
||||||
|
logger.Info(ctx, "Starting event analysis",
|
||||||
|
"task_id", taskID,
|
||||||
|
"task_type", taskType,
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: Implement actual event analysis logic
|
||||||
|
// This would typically involve:
|
||||||
|
// 1. Fetching motor and trigger information
|
||||||
|
// 2. Analyzing events within the specified duration
|
||||||
|
// 3. Generating analysis report
|
||||||
|
// 4. Storing results in database
|
||||||
|
|
||||||
|
// Simulate work
|
||||||
|
logger.Info(ctx, "Event analysis completed",
|
||||||
|
"task_id", taskID,
|
||||||
|
"task_type", taskType,
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanHandle returns true for event analysis tasks
|
||||||
|
func (h *EventAnalysisHandler) CanHandle(taskType TaskType) bool {
|
||||||
|
return taskType == TypeEventAnalysis
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchImportHandler handles batch import tasks
|
||||||
|
type BatchImportHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBatchImportHandler creates a new BatchImportHandler
|
||||||
|
func NewBatchImportHandler() *BatchImportHandler {
|
||||||
|
return &BatchImportHandler{
|
||||||
|
BaseHandler: *NewBaseHandler("batch_import_handler"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute processes a batch import task
|
||||||
|
func (h *BatchImportHandler) Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, db *gorm.DB) error {
|
||||||
|
logger.Info(ctx, "Starting batch import",
|
||||||
|
"task_id", taskID,
|
||||||
|
"task_type", taskType,
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: Implement actual batch import logic
|
||||||
|
// This would typically involve:
|
||||||
|
// 1. Reading file from specified path
|
||||||
|
// 2. Parsing file content (CSV, Excel, etc.)
|
||||||
|
// 3. Validating and importing data into database
|
||||||
|
// 4. Generating import report
|
||||||
|
|
||||||
|
// Simulate work
|
||||||
|
logger.Info(ctx, "Batch import completed",
|
||||||
|
"task_id", taskID,
|
||||||
|
"task_type", taskType,
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanHandle returns true for batch import tasks
|
||||||
|
func (h *BatchImportHandler) CanHandle(taskType TaskType) bool {
|
||||||
|
return taskType == TypeBatchImport
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompositeHandler can handle multiple task types by delegating to appropriate handlers
|
||||||
|
type CompositeHandler struct {
|
||||||
|
factory *HandlerFactory
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCompositeHandler creates a new CompositeHandler
|
||||||
|
func NewCompositeHandler(factory *HandlerFactory) *CompositeHandler {
|
||||||
|
return &CompositeHandler{factory: factory}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute delegates task execution to the appropriate handler
|
||||||
|
func (h *CompositeHandler) Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, db *gorm.DB) error {
|
||||||
|
handler, err := h.factory.GetHandler(taskType)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get handler for task type %s: %w", taskType, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return handler.Execute(ctx, taskID, taskType, db)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanHandle returns true if any registered handler can handle the task type
|
||||||
|
func (h *CompositeHandler) CanHandle(taskType TaskType) bool {
|
||||||
|
_, err := h.factory.GetHandler(taskType)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the composite handler name
|
||||||
|
func (h *CompositeHandler) Name() string {
|
||||||
|
return "composite_handler"
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultHandlerFactory returns a HandlerFactory with all default handlers registered
|
||||||
|
func DefaultHandlerFactory() *HandlerFactory {
|
||||||
|
factory := NewHandlerFactory()
|
||||||
|
factory.CreateDefaultHandlers()
|
||||||
|
return factory
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultCompositeHandler returns a CompositeHandler with all default handlers
|
||||||
|
func DefaultCompositeHandler() TaskHandler {
|
||||||
|
factory := DefaultHandlerFactory()
|
||||||
|
return NewCompositeHandler(factory)
|
||||||
|
}
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
package task
|
package task
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/gofrs/uuid"
|
"github.com/gofrs/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -41,7 +43,7 @@ func NewTaskQueueMessageWithPriority(taskID uuid.UUID, taskType TaskType, priori
|
||||||
|
|
||||||
// ToJSON converts the TaskQueueMessage to JSON bytes
|
// ToJSON converts the TaskQueueMessage to JSON bytes
|
||||||
func (m *TaskQueueMessage) ToJSON() ([]byte, error) {
|
func (m *TaskQueueMessage) ToJSON() ([]byte, error) {
|
||||||
return []byte{}, nil // Placeholder - actual implementation would use json.Marshal
|
return json.Marshal(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate checks if the TaskQueueMessage is valid
|
// Validate checks if the TaskQueueMessage is valid
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,227 @@
|
||||||
|
// Package task provides asynchronous task processing with RabbitMQ integration
|
||||||
|
package task
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"modelRT/config"
|
||||||
|
"modelRT/logger"
|
||||||
|
"modelRT/mq"
|
||||||
|
|
||||||
|
"github.com/gofrs/uuid"
|
||||||
|
amqp "github.com/rabbitmq/amqp091-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AMQP topology and message defaults shared by the producer and worker.
const (
	// TaskExchangeName is the name of the exchange for task routing
	TaskExchangeName = "modelrt.tasks.exchange"
	// TaskQueueName is the name of the main task queue
	TaskQueueName = "modelrt.tasks.queue"
	// TaskRoutingKey is the routing key for task messages
	TaskRoutingKey = "modelrt.task"
	// MaxPriority is the maximum priority level for tasks (0-10)
	MaxPriority = 10
	// DefaultMessageTTL is the default time-to-live for task messages (24 hours)
	DefaultMessageTTL = 24 * time.Hour
)
|
||||||
|
|
||||||
|
// QueueProducer handles publishing tasks to RabbitMQ
type QueueProducer struct {
	conn *amqp.Connection // shared connection from mq.GetConn(); not owned, never closed here
	ch   *amqp.Channel    // dedicated publishing channel; released by Close()
}
|
||||||
|
|
||||||
|
// NewQueueProducer creates a new QueueProducer instance
|
||||||
|
func NewQueueProducer(ctx context.Context, cfg config.RabbitMQConfig) (*QueueProducer, error) {
|
||||||
|
// Initialize RabbitMQ connection if not already initialized
|
||||||
|
mq.InitRabbitProxy(ctx, cfg)
|
||||||
|
|
||||||
|
conn := mq.GetConn()
|
||||||
|
if conn == nil {
|
||||||
|
return nil, fmt.Errorf("failed to get RabbitMQ connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
ch, err := conn.Channel()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open channel: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
producer := &QueueProducer{
|
||||||
|
conn: conn,
|
||||||
|
ch: ch,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Declare exchange and queue
|
||||||
|
if err := producer.declareInfrastructure(); err != nil {
|
||||||
|
ch.Close()
|
||||||
|
return nil, fmt.Errorf("failed to declare infrastructure: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return producer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// declareInfrastructure declares the exchange, queue, and binds them.
//
// Declarations use fixed arguments; NOTE(review): if the queue already
// exists on the broker with different arguments, QueueDeclare will fail —
// confirm against deployed broker state.
func (p *QueueProducer) declareInfrastructure() error {
	// Declare durable direct exchange
	err := p.ch.ExchangeDeclare(
		TaskExchangeName, // name
		"direct",         // type
		true,             // durable
		false,            // auto-deleted
		false,            // internal
		false,            // no-wait
		nil,              // arguments
	)
	if err != nil {
		return fmt.Errorf("failed to declare exchange: %w", err)
	}

	// Declare durable queue with priority support and message TTL
	_, err = p.ch.QueueDeclare(
		TaskQueueName, // name
		true,          // durable
		false,         // delete when unused
		false,         // exclusive
		false,         // no-wait
		amqp.Table{
			"x-max-priority": MaxPriority,                      // support priority levels 0-10
			"x-message-ttl":  DefaultMessageTTL.Milliseconds(), // message TTL
		},
	)
	if err != nil {
		return fmt.Errorf("failed to declare queue: %w", err)
	}

	// Bind queue to exchange
	err = p.ch.QueueBind(
		TaskQueueName,    // queue name
		TaskRoutingKey,   // routing key
		TaskExchangeName, // exchange name
		false,            // no-wait
		nil,              // arguments
	)
	if err != nil {
		return fmt.Errorf("failed to bind queue: %w", err)
	}

	return nil
}
|
||||||
|
|
||||||
|
// PublishTask publishes a task message to RabbitMQ.
//
// The message is validated, JSON-encoded, and published persistently with
// the given priority to the task exchange. Returns an error if validation,
// marshaling, or publishing fails.
//
// NOTE(review): priority is converted to uint8 without clamping to
// [0, MaxPriority]; negative or >255 values would wrap — confirm callers
// only pass 0-10.
func (p *QueueProducer) PublishTask(ctx context.Context, taskID uuid.UUID, taskType TaskType, priority int) error {
	message := NewTaskQueueMessageWithPriority(taskID, taskType, priority)

	// Validate message
	if !message.Validate() {
		return fmt.Errorf("invalid task message: taskID=%s, taskType=%s", taskID, taskType)
	}

	// Convert message to JSON
	body, err := json.Marshal(message)
	if err != nil {
		return fmt.Errorf("failed to marshal task message: %w", err)
	}

	// Prepare publishing options
	publishing := amqp.Publishing{
		ContentType:  "application/json",
		Body:         body,
		DeliveryMode: amqp.Persistent, // Persistent messages survive broker restart
		Timestamp:    time.Now(),
		Priority:     uint8(priority),
		Headers: amqp.Table{
			"task_id":   taskID.String(),
			"task_type": string(taskType),
		},
	}

	// Publish to exchange
	err = p.ch.PublishWithContext(
		ctx,
		TaskExchangeName, // exchange
		TaskRoutingKey,   // routing key
		false,            // mandatory
		false,            // immediate
		publishing,
	)
	if err != nil {
		return fmt.Errorf("failed to publish task message: %w", err)
	}

	logger.Info(ctx, "Task published to queue",
		"task_id", taskID.String(),
		"task_type", taskType,
		"priority", priority,
		"queue", TaskQueueName,
	)

	return nil
}
|
||||||
|
|
||||||
|
// PublishTaskWithRetry publishes a task with retry logic
|
||||||
|
func (p *QueueProducer) PublishTaskWithRetry(ctx context.Context, taskID uuid.UUID, taskType TaskType, priority int, maxRetries int) error {
|
||||||
|
var lastErr error
|
||||||
|
for i := range maxRetries {
|
||||||
|
err := p.PublishTask(ctx, taskID, taskType, priority)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
|
||||||
|
// Exponential backoff
|
||||||
|
backoff := time.Duration(1<<uint(i)) * time.Second
|
||||||
|
backoff = min(backoff, 10*time.Second)
|
||||||
|
|
||||||
|
logger.Warn(ctx, "Failed to publish task, retrying",
|
||||||
|
"task_id", taskID.String(),
|
||||||
|
"attempt", i+1,
|
||||||
|
"max_retries", maxRetries,
|
||||||
|
"backoff", backoff,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(backoff):
|
||||||
|
continue
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("failed to publish task after %d retries: %w", maxRetries, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the producer's channel
|
||||||
|
func (p *QueueProducer) Close() error {
|
||||||
|
if p.ch != nil {
|
||||||
|
return p.ch.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQueueInfo returns information about the task queue
|
||||||
|
func (p *QueueProducer) GetQueueInfo() (*amqp.Queue, error) {
|
||||||
|
queue, err := p.ch.QueueDeclarePassive(
|
||||||
|
TaskQueueName, // name
|
||||||
|
true, // durable
|
||||||
|
false, // delete when unused
|
||||||
|
false, // exclusive
|
||||||
|
false, // no-wait
|
||||||
|
amqp.Table{
|
||||||
|
"x-max-priority": MaxPriority,
|
||||||
|
"x-message-ttl": DefaultMessageTTL.Milliseconds(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to inspect queue: %w", err)
|
||||||
|
}
|
||||||
|
return &queue, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PurgeQueue removes all messages from the task queue
|
||||||
|
func (p *QueueProducer) PurgeQueue() (int, error) {
|
||||||
|
return p.ch.QueuePurge(TaskQueueName, false)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,441 @@
|
||||||
|
// Package task provides asynchronous task processing with worker pools
|
||||||
|
package task
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"modelRT/config"
|
||||||
|
"modelRT/logger"
|
||||||
|
"modelRT/mq"
|
||||||
|
|
||||||
|
"github.com/gofrs/uuid"
|
||||||
|
"github.com/panjf2000/ants/v2"
|
||||||
|
amqp "github.com/rabbitmq/amqp091-go"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WorkerConfig holds configuration for the task worker
|
||||||
|
type WorkerConfig struct {
|
||||||
|
// PoolSize is the number of worker goroutines in the pool
|
||||||
|
PoolSize int
|
||||||
|
// PreAlloc indicates whether to pre-allocate memory for the pool
|
||||||
|
PreAlloc bool
|
||||||
|
// MaxBlockingTasks is the maximum number of tasks waiting in queue
|
||||||
|
MaxBlockingTasks int
|
||||||
|
// QueueConsumerCount is the number of concurrent RabbitMQ consumers
|
||||||
|
QueueConsumerCount int
|
||||||
|
// PollingInterval is the interval between health checks
|
||||||
|
PollingInterval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultWorkerConfig returns the default worker configuration
|
||||||
|
func DefaultWorkerConfig() WorkerConfig {
|
||||||
|
return WorkerConfig{
|
||||||
|
PoolSize: 10,
|
||||||
|
PreAlloc: true,
|
||||||
|
MaxBlockingTasks: 100,
|
||||||
|
QueueConsumerCount: 2,
|
||||||
|
PollingInterval: 30 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskWorker manages a pool of workers for processing asynchronous tasks
//
// NOTE(review): storing ctx/cancel in the struct is generally discouraged
// (contexts should flow through call parameters); here it is used to give
// long-lived consumer goroutines a shared cancellation signal.
type TaskWorker struct {
	cfg      WorkerConfig       // pool sizing, consumer count, polling interval
	db       *gorm.DB           // handed to the TaskHandler on each Execute
	pool     *ants.Pool         // goroutine pool that runs handleMessage
	conn     *amqp.Connection   // shared connection from mq.GetConn(); not closed by Stop
	ch       *amqp.Channel      // consuming channel; closed by Stop
	handler  TaskHandler        // processes each decoded task message
	stopChan chan struct{}      // closed by Stop to end consumer/health loops
	wg       sync.WaitGroup     // tracks consumer and health-check goroutines
	ctx      context.Context    // worker-scoped context derived from the Start caller's ctx
	cancel   context.CancelFunc // cancels ctx during Stop
	metrics  *WorkerMetrics     // counters guarded by metrics.mu
}
|
||||||
|
|
||||||
|
// WorkerMetrics holds metrics for the worker pool
//
// All fields are guarded by mu. The struct embeds a mutex and therefore
// must not be copied; GetMetrics returns a field-by-field snapshot instead.
type WorkerMetrics struct {
	TasksProcessed  int64     // total tasks completed successfully
	TasksFailed     int64     // total tasks that failed (decode, validate, status update, or Execute)
	TasksInProgress int32     // tasks currently inside handleMessage
	QueueDepth      int       // last observed message count from a passive queue declare
	WorkersActive   int       // pool.Running() at last health check
	WorkersIdle     int       // pool.Free() at last health check
	LastHealthCheck time.Time // when checkHealth last ran; drives IsHealthy
	mu              sync.RWMutex
}
|
||||||
|
|
||||||
|
// NewTaskWorker creates a new TaskWorker instance.
//
// It initializes the shared RabbitMQ connection, opens a consuming channel,
// declares the task queue (same arguments as the producer), sets per-consumer
// QoS, and builds the ants goroutine pool. On any failure after the channel
// is opened, the channel is closed before returning the error. The returned
// worker is idle until Start is called.
func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg config.RabbitMQConfig, handler TaskHandler) (*TaskWorker, error) {
	// Initialize RabbitMQ connection
	mq.InitRabbitProxy(ctx, rabbitCfg)
	conn := mq.GetConn()
	if conn == nil {
		return nil, fmt.Errorf("failed to get RabbitMQ connection")
	}

	// Create channel
	ch, err := conn.Channel()
	if err != nil {
		return nil, fmt.Errorf("failed to open channel: %w", err)
	}

	// Declare queue (ensure it exists with proper arguments)
	_, err = ch.QueueDeclare(
		TaskQueueName, // name
		true,          // durable
		false,         // delete when unused
		false,         // exclusive
		false,         // no-wait
		amqp.Table{
			"x-max-priority": MaxPriority,
			"x-message-ttl":  DefaultMessageTTL.Milliseconds(),
		},
	)
	if err != nil {
		ch.Close()
		return nil, fmt.Errorf("failed to declare queue: %w", err)
	}

	// Set QoS (quality of service) for fair dispatch: at most one
	// unacknowledged message per consumer at a time.
	err = ch.Qos(
		1,     // prefetch count
		0,     // prefetch size
		false, // global
	)
	if err != nil {
		ch.Close()
		return nil, fmt.Errorf("failed to set QoS: %w", err)
	}

	// Create ants pool
	pool, err := ants.NewPool(cfg.PoolSize, ants.WithPreAlloc(cfg.PreAlloc))
	if err != nil {
		ch.Close()
		return nil, fmt.Errorf("failed to create worker pool: %w", err)
	}

	ctxWithCancel, cancel := context.WithCancel(ctx)

	worker := &TaskWorker{
		cfg:      cfg,
		db:       db,
		pool:     pool,
		conn:     conn,
		ch:       ch,
		handler:  handler,
		stopChan: make(chan struct{}),
		ctx:      ctxWithCancel,
		cancel:   cancel,
		metrics: &WorkerMetrics{
			LastHealthCheck: time.Now(),
		},
	}

	return worker, nil
}
|
||||||
|
|
||||||
|
// Start begins consuming tasks from the queue
|
||||||
|
func (w *TaskWorker) Start() error {
|
||||||
|
logger.Info(w.ctx, "Starting task worker",
|
||||||
|
"pool_size", w.cfg.PoolSize,
|
||||||
|
"queue_consumers", w.cfg.QueueConsumerCount,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Start multiple consumers for better throughput
|
||||||
|
for i := 0; i < w.cfg.QueueConsumerCount; i++ {
|
||||||
|
w.wg.Add(1)
|
||||||
|
go w.consumerLoop(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start health check goroutine
|
||||||
|
w.wg.Add(1)
|
||||||
|
go w.healthCheckLoop()
|
||||||
|
|
||||||
|
logger.Info(w.ctx, "Task worker started successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// consumerLoop runs a single RabbitMQ consumer
//
// Messages are consumed with manual acknowledgement (auto-ack=false) so
// handleMessage controls ack/nack. Each delivery is dispatched to the ants
// pool; if the pool rejects the submission the message is nacked with
// requeue so another consumer can pick it up. The loop exits when stopChan
// is closed or the delivery channel closes.
func (w *TaskWorker) consumerLoop(consumerID int) {
	defer w.wg.Done()

	logger.Info(w.ctx, "Starting consumer", "consumer_id", consumerID)

	// Consume messages from the queue
	msgs, err := w.ch.Consume(
		TaskQueueName,                        // queue
		fmt.Sprintf("worker-%d", consumerID), // consumer tag
		false,                                // auto-ack
		false,                                // exclusive
		false,                                // no-local
		false,                                // no-wait
		nil,                                  // args
	)
	if err != nil {
		logger.Error(w.ctx, "Failed to start consumer",
			"consumer_id", consumerID,
			"error", err,
		)
		return
	}

	for {
		select {
		case <-w.stopChan:
			logger.Info(w.ctx, "Consumer stopping", "consumer_id", consumerID)
			return
		case msg, ok := <-msgs:
			if !ok {
				logger.Warn(w.ctx, "Consumer channel closed", "consumer_id", consumerID)
				return
			}

			// Process message in worker pool. The closure captures this
			// iteration's msg (per-iteration loop variables, Go 1.22+).
			err := w.pool.Submit(func() {
				w.handleMessage(msg)
			})
			if err != nil {
				logger.Error(w.ctx, "Failed to submit task to pool",
					"consumer_id", consumerID,
					"error", err,
				)
				// Reject message and requeue
				msg.Nack(false, true)
			}
		}
	}
}
|
||||||
|
|
||||||
|
// handleMessage processes a single RabbitMQ message
//
// Flow: decode -> validate -> mark RUNNING -> Execute -> mark COMPLETED.
// Acknowledgement policy:
//   - undecodable or invalid messages are nacked WITHOUT requeue (they can
//     never succeed);
//   - a failure to mark the task RUNNING is nacked WITH requeue (transient);
//   - once Execute has run, the message is acked whether or not Execute
//     succeeded, so a buggy task cannot be redelivered forever.
//
// TasksInProgress is incremented on entry and decremented via defer;
// TasksProcessed/TasksFailed are bumped on each terminal path.
func (w *TaskWorker) handleMessage(msg amqp.Delivery) {
	w.metrics.mu.Lock()
	w.metrics.TasksInProgress++
	w.metrics.mu.Unlock()

	defer func() {
		w.metrics.mu.Lock()
		w.metrics.TasksInProgress--
		w.metrics.mu.Unlock()
	}()

	ctx := w.ctx

	// Parse task message
	var taskMsg TaskQueueMessage
	if err := json.Unmarshal(msg.Body, &taskMsg); err != nil {
		logger.Error(ctx, "Failed to unmarshal task message", "error", err)
		msg.Nack(false, false) // Reject without requeue
		w.metrics.mu.Lock()
		w.metrics.TasksFailed++
		w.metrics.mu.Unlock()
		return
	}

	// Validate message
	if !taskMsg.Validate() {
		logger.Error(ctx, "Invalid task message",
			"task_id", taskMsg.TaskID,
			"task_type", taskMsg.TaskType,
		)
		msg.Nack(false, false) // Reject without requeue
		w.metrics.mu.Lock()
		w.metrics.TasksFailed++
		w.metrics.mu.Unlock()
		return
	}

	logger.Info(ctx, "Processing task",
		"task_id", taskMsg.TaskID,
		"task_type", taskMsg.TaskType,
		"priority", taskMsg.Priority,
	)

	// Update task status to RUNNING in database
	if err := w.updateTaskStatus(ctx, taskMsg.TaskID, StatusRunning); err != nil {
		logger.Error(ctx, "Failed to update task status", "error", err)
		msg.Nack(false, true) // Reject with requeue
		w.metrics.mu.Lock()
		w.metrics.TasksFailed++
		w.metrics.mu.Unlock()
		return
	}

	// Execute task using handler
	startTime := time.Now()
	err := w.handler.Execute(ctx, taskMsg.TaskID, taskMsg.TaskType, w.db)
	processingTime := time.Since(startTime)

	if err != nil {
		logger.Error(ctx, "Task execution failed",
			"task_id", taskMsg.TaskID,
			"task_type", taskMsg.TaskType,
			"processing_time", processingTime,
			"error", err,
		)

		// Update task status to FAILED
		if updateErr := w.updateTaskWithError(ctx, taskMsg.TaskID, err); updateErr != nil {
			logger.Error(ctx, "Failed to update task with error", "error", updateErr)
		}

		// Ack message even if task failed (we don't want to retry indefinitely)
		msg.Ack(false)
		w.metrics.mu.Lock()
		w.metrics.TasksFailed++
		w.metrics.mu.Unlock()
		return
	}

	// Update task status to COMPLETED
	if err := w.updateTaskStatus(ctx, taskMsg.TaskID, StatusCompleted); err != nil {
		logger.Error(ctx, "Failed to update task status to completed", "error", err)
		// Still ack the message since task was processed successfully
	}

	// Acknowledge message
	msg.Ack(false)

	logger.Info(ctx, "Task completed successfully",
		"task_id", taskMsg.TaskID,
		"task_type", taskMsg.TaskType,
		"processing_time", processingTime,
	)

	w.metrics.mu.Lock()
	w.metrics.TasksProcessed++
	w.metrics.mu.Unlock()
}
|
||||||
|
|
||||||
|
// updateTaskStatus updates the status of a task in the database
|
||||||
|
func (w *TaskWorker) updateTaskStatus(ctx context.Context, taskID uuid.UUID, status TaskStatus) error {
|
||||||
|
// This is a simplified version. In a real implementation, you would:
|
||||||
|
// 1. Have a proper task table/model
|
||||||
|
// 2. Update the task record with status and timestamps
|
||||||
|
|
||||||
|
// For now, we'll log the update
|
||||||
|
logger.Debug(ctx, "Updating task status",
|
||||||
|
"task_id", taskID,
|
||||||
|
"status", status,
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateTaskWithError updates a task with error information
|
||||||
|
func (w *TaskWorker) updateTaskWithError(ctx context.Context, taskID uuid.UUID, err error) error {
|
||||||
|
logger.Debug(ctx, "Updating task with error",
|
||||||
|
"task_id", taskID,
|
||||||
|
"error", err.Error(),
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// healthCheckLoop periodically checks worker health and metrics
|
||||||
|
func (w *TaskWorker) healthCheckLoop() {
|
||||||
|
defer w.wg.Done()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(w.cfg.PollingInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-w.stopChan:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
w.checkHealth()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkHealth performs health checks and logs metrics
|
||||||
|
func (w *TaskWorker) checkHealth() {
|
||||||
|
w.metrics.mu.Lock()
|
||||||
|
defer w.metrics.mu.Unlock()
|
||||||
|
|
||||||
|
// Update queue depth
|
||||||
|
queue, err := w.ch.QueueDeclarePassive(
|
||||||
|
TaskQueueName, // name
|
||||||
|
true, // durable
|
||||||
|
false, // delete when unused
|
||||||
|
false, // exclusive
|
||||||
|
false, // no-wait
|
||||||
|
amqp.Table{
|
||||||
|
"x-max-priority": MaxPriority,
|
||||||
|
"x-message-ttl": DefaultMessageTTL.Milliseconds(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
w.metrics.QueueDepth = queue.Messages
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update worker pool stats
|
||||||
|
w.metrics.WorkersActive = w.pool.Running()
|
||||||
|
w.metrics.WorkersIdle = w.pool.Free()
|
||||||
|
w.metrics.LastHealthCheck = time.Now()
|
||||||
|
|
||||||
|
logger.Info(w.ctx, "Worker health check",
|
||||||
|
"tasks_processed", w.metrics.TasksProcessed,
|
||||||
|
"tasks_failed", w.metrics.TasksFailed,
|
||||||
|
"tasks_in_progress", w.metrics.TasksInProgress,
|
||||||
|
"queue_depth", w.metrics.QueueDepth,
|
||||||
|
"workers_active", w.metrics.WorkersActive,
|
||||||
|
"workers_idle", w.metrics.WorkersIdle,
|
||||||
|
"pool_capacity", w.pool.Cap(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully stops the task worker
//
// Shutdown order: signal goroutines (close stopChan, cancel ctx), wait for
// consumers and the health-check loop, release the ants pool, then close
// the AMQP channel. The shared connection is not closed here.
//
// NOTE(review): Stop is not idempotent — calling it twice panics on the
// second close(w.stopChan); guard with sync.Once if double-stop is possible.
func (w *TaskWorker) Stop() error {
	logger.Info(w.ctx, "Stopping task worker")

	// Signal all goroutines to stop
	close(w.stopChan)
	w.cancel()

	// Wait for all goroutines to finish
	w.wg.Wait()

	// Release worker pool
	w.pool.Release()

	// Close channel
	if w.ch != nil {
		if err := w.ch.Close(); err != nil {
			logger.Error(w.ctx, "Failed to close channel", "error", err)
		}
	}

	logger.Info(w.ctx, "Task worker stopped")
	return nil
}
|
||||||
|
|
||||||
|
// GetMetrics returns current worker metrics
//
// Returns a point-in-time snapshot taken under the read lock. The copy is
// built field-by-field because WorkerMetrics embeds a mutex and must not be
// copied wholesale (go vet copylocks).
func (w *TaskWorker) GetMetrics() *WorkerMetrics {
	w.metrics.mu.RLock()
	defer w.metrics.mu.RUnlock()
	// Create a copy without the mutex to avoid copylocks warning
	return &WorkerMetrics{
		TasksProcessed:  w.metrics.TasksProcessed,
		TasksFailed:     w.metrics.TasksFailed,
		TasksInProgress: w.metrics.TasksInProgress,
		QueueDepth:      w.metrics.QueueDepth,
		WorkersActive:   w.metrics.WorkersActive,
		WorkersIdle:     w.metrics.WorkersIdle,
		LastHealthCheck: w.metrics.LastHealthCheck,
		// Mutex is intentionally omitted
	}
}
|
||||||
|
|
||||||
|
// IsHealthy returns true if the worker is healthy
|
||||||
|
func (w *TaskWorker) IsHealthy() bool {
|
||||||
|
w.metrics.mu.RLock()
|
||||||
|
defer w.metrics.mu.RUnlock()
|
||||||
|
|
||||||
|
// Consider unhealthy if last health check was too long ago
|
||||||
|
return time.Since(w.metrics.LastHealthCheck) < 2*w.cfg.PollingInterval
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue