// Package task provides asynchronous task processing with worker pools package task import ( "context" "encoding/json" "fmt" "sync" "time" "modelRT/config" "modelRT/logger" "modelRT/mq" "github.com/gofrs/uuid" "github.com/panjf2000/ants/v2" amqp "github.com/rabbitmq/amqp091-go" "gorm.io/gorm" ) // WorkerConfig holds configuration for the task worker type WorkerConfig struct { // PoolSize is the number of worker goroutines in the pool PoolSize int // PreAlloc indicates whether to pre-allocate memory for the pool PreAlloc bool // MaxBlockingTasks is the maximum number of tasks waiting in queue MaxBlockingTasks int // QueueConsumerCount is the number of concurrent RabbitMQ consumers QueueConsumerCount int // PollingInterval is the interval between health checks PollingInterval time.Duration } // DefaultWorkerConfig returns the default worker configuration func DefaultWorkerConfig() WorkerConfig { return WorkerConfig{ PoolSize: 10, PreAlloc: true, MaxBlockingTasks: 100, QueueConsumerCount: 2, PollingInterval: 30 * time.Second, } } // TaskWorker manages a pool of workers for processing asynchronous tasks type TaskWorker struct { cfg WorkerConfig db *gorm.DB pool *ants.Pool conn *amqp.Connection ch *amqp.Channel handler TaskHandler stopChan chan struct{} wg sync.WaitGroup ctx context.Context cancel context.CancelFunc metrics *WorkerMetrics } // WorkerMetrics holds metrics for the worker pool type WorkerMetrics struct { TasksProcessed int64 TasksFailed int64 TasksInProgress int32 QueueDepth int WorkersActive int WorkersIdle int LastHealthCheck time.Time mu sync.RWMutex } // NewTaskWorker creates a new TaskWorker instance func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg config.RabbitMQConfig, handler TaskHandler) (*TaskWorker, error) { // Initialize RabbitMQ connection mq.InitRabbitProxy(ctx, rabbitCfg) conn := mq.GetConn() if conn == nil { return nil, fmt.Errorf("failed to get RabbitMQ connection") } // Create channel ch, err := conn.Channel() if err != nil { return nil, fmt.Errorf("failed to open channel: %w", err) } // Declare queue (ensure it exists with proper arguments) _, err = ch.QueueDeclare( TaskQueueName, // name true, // durable false, // delete when unused false, // exclusive false, // no-wait amqp.Table{ "x-max-priority": MaxPriority, "x-message-ttl": DefaultMessageTTL.Milliseconds(), }, ) if err != nil { ch.Close() return nil, fmt.Errorf("failed to declare queue: %w", err) } // Set QoS (quality of service) for fair dispatch err = ch.Qos( 1, // prefetch count 0, // prefetch size false, // global ) if err != nil { ch.Close() return nil, fmt.Errorf("failed to set QoS: %w", err) } // Create ants pool pool, err := ants.NewPool(cfg.PoolSize, ants.WithPreAlloc(cfg.PreAlloc)) if err != nil { ch.Close() return nil, fmt.Errorf("failed to create worker pool: %w", err) } ctxWithCancel, cancel := context.WithCancel(ctx) worker := &TaskWorker{ cfg: cfg, db: db, pool: pool, conn: conn, ch: ch, handler: handler, stopChan: make(chan struct{}), ctx: ctxWithCancel, cancel: cancel, metrics: &WorkerMetrics{ LastHealthCheck: time.Now(), }, } return worker, nil } // Start begins consuming tasks from the queue func (w *TaskWorker) Start() error { logger.Info(w.ctx, "Starting task worker", "pool_size", w.cfg.PoolSize, "queue_consumers", w.cfg.QueueConsumerCount, ) // Start multiple consumers for better throughput for i := 0; i < w.cfg.QueueConsumerCount; i++ { w.wg.Add(1) go w.consumerLoop(i) } // Start health check goroutine w.wg.Add(1) go w.healthCheckLoop() logger.Info(w.ctx, "Task worker started successfully") return nil } // consumerLoop runs a single RabbitMQ consumer func (w *TaskWorker) consumerLoop(consumerID int) { defer w.wg.Done() logger.Info(w.ctx, "Starting consumer", "consumer_id", consumerID) // Consume messages from the queue msgs, err := w.ch.Consume( TaskQueueName, // queue fmt.Sprintf("worker-%d", consumerID), // consumer tag false, // auto-ack false, // exclusive false, // no-local false, // no-wait nil, // args ) if err != nil { logger.Error(w.ctx, "Failed to start consumer", "consumer_id", consumerID, "error", err, ) return } for { select { case <-w.stopChan: logger.Info(w.ctx, "Consumer stopping", "consumer_id", consumerID) return case msg, ok := <-msgs: if !ok { logger.Warn(w.ctx, "Consumer channel closed", "consumer_id", consumerID) return } // Process message in worker pool err := w.pool.Submit(func() { w.handleMessage(msg) }) if err != nil { logger.Error(w.ctx, "Failed to submit task to pool", "consumer_id", consumerID, "error", err, ) // Reject message and requeue msg.Nack(false, true) } } } } // handleMessage processes a single RabbitMQ message func (w *TaskWorker) handleMessage(msg amqp.Delivery) { w.metrics.mu.Lock() w.metrics.TasksInProgress++ w.metrics.mu.Unlock() defer func() { w.metrics.mu.Lock() w.metrics.TasksInProgress-- w.metrics.mu.Unlock() }() ctx := w.ctx // Parse task message var taskMsg TaskQueueMessage if err := json.Unmarshal(msg.Body, &taskMsg); err != nil { logger.Error(ctx, "Failed to unmarshal task message", "error", err) msg.Nack(false, false) // Reject without requeue w.metrics.mu.Lock() w.metrics.TasksFailed++ w.metrics.mu.Unlock() return } // Validate message if !taskMsg.Validate() { logger.Error(ctx, "Invalid task message", "task_id", taskMsg.TaskID, "task_type", taskMsg.TaskType, ) msg.Nack(false, false) // Reject without requeue w.metrics.mu.Lock() w.metrics.TasksFailed++ w.metrics.mu.Unlock() return } logger.Info(ctx, "Processing task", "task_id", taskMsg.TaskID, "task_type", taskMsg.TaskType, "priority", taskMsg.Priority, ) // Update task status to RUNNING in database if err := w.updateTaskStatus(ctx, taskMsg.TaskID, StatusRunning); err != nil { logger.Error(ctx, "Failed to update task status", "error", err) msg.Nack(false, true) // Reject with requeue w.metrics.mu.Lock() w.metrics.TasksFailed++ w.metrics.mu.Unlock() return } // Execute task using handler startTime := time.Now() err := w.handler.Execute(ctx, taskMsg.TaskID, taskMsg.TaskType, w.db) processingTime := time.Since(startTime) if err != nil { logger.Error(ctx, "Task execution failed", "task_id", taskMsg.TaskID, "task_type", taskMsg.TaskType, "processing_time", processingTime, "error", err, ) // Update task status to FAILED if updateErr := w.updateTaskWithError(ctx, taskMsg.TaskID, err); updateErr != nil { logger.Error(ctx, "Failed to update task with error", "error", updateErr) } // Ack message even if task failed (we don't want to retry indefinitely) msg.Ack(false) w.metrics.mu.Lock() w.metrics.TasksFailed++ w.metrics.mu.Unlock() return } // Update task status to COMPLETED if err := w.updateTaskStatus(ctx, taskMsg.TaskID, StatusCompleted); err != nil { logger.Error(ctx, "Failed to update task status to completed", "error", err) // Still ack the message since task was processed successfully } // Acknowledge message msg.Ack(false) logger.Info(ctx, "Task completed successfully", "task_id", taskMsg.TaskID, "task_type", taskMsg.TaskType, "processing_time", processingTime, ) w.metrics.mu.Lock() w.metrics.TasksProcessed++ w.metrics.mu.Unlock() } // updateTaskStatus updates the status of a task in the database func (w *TaskWorker) updateTaskStatus(ctx context.Context, taskID uuid.UUID, status TaskStatus) error { // This is a simplified version. In a real implementation, you would: // 1. Have a proper task table/model // 2. Update the task record with status and timestamps // For now, we'll log the update logger.Debug(ctx, "Updating task status", "task_id", taskID, "status", status, ) return nil } // updateTaskWithError updates a task with error information func (w *TaskWorker) updateTaskWithError(ctx context.Context, taskID uuid.UUID, err error) error { logger.Debug(ctx, "Updating task with error", "task_id", taskID, "error", err.Error(), ) return nil } // healthCheckLoop periodically checks worker health and metrics func (w *TaskWorker) healthCheckLoop() { defer w.wg.Done() ticker := time.NewTicker(w.cfg.PollingInterval) defer ticker.Stop() for { select { case <-w.stopChan: return case <-ticker.C: w.checkHealth() } } } // checkHealth performs health checks and logs metrics func (w *TaskWorker) checkHealth() { w.metrics.mu.Lock() defer w.metrics.mu.Unlock() // Update queue depth queue, err := w.ch.QueueDeclarePassive( TaskQueueName, // name true, // durable false, // delete when unused false, // exclusive false, // no-wait amqp.Table{ "x-max-priority": MaxPriority, "x-message-ttl": DefaultMessageTTL.Milliseconds(), }, ) if err == nil { w.metrics.QueueDepth = queue.Messages } // Update worker pool stats w.metrics.WorkersActive = w.pool.Running() w.metrics.WorkersIdle = w.pool.Free() w.metrics.LastHealthCheck = time.Now() logger.Info(w.ctx, "Worker health check", "tasks_processed", w.metrics.TasksProcessed, "tasks_failed", w.metrics.TasksFailed, "tasks_in_progress", w.metrics.TasksInProgress, "queue_depth", w.metrics.QueueDepth, "workers_active", w.metrics.WorkersActive, "workers_idle", w.metrics.WorkersIdle, "pool_capacity", w.pool.Cap(), ) } // Stop gracefully stops the task worker func (w *TaskWorker) Stop() error { logger.Info(w.ctx, "Stopping task worker") // Signal all goroutines to stop close(w.stopChan) w.cancel() // Wait for all goroutines to finish w.wg.Wait() // Release worker pool w.pool.Release() // Close channel if w.ch != nil { if err := w.ch.Close(); err != nil { logger.Error(w.ctx, "Failed to close channel", "error", err) } } logger.Info(w.ctx, "Task worker stopped") return nil } // GetMetrics returns current worker metrics func (w *TaskWorker) GetMetrics() *WorkerMetrics { w.metrics.mu.RLock() defer w.metrics.mu.RUnlock() // Create a copy without the mutex to avoid copylocks warning return &WorkerMetrics{ TasksProcessed: w.metrics.TasksProcessed, TasksFailed: w.metrics.TasksFailed, TasksInProgress: w.metrics.TasksInProgress, QueueDepth: w.metrics.QueueDepth, WorkersActive: w.metrics.WorkersActive, WorkersIdle: w.metrics.WorkersIdle, LastHealthCheck: w.metrics.LastHealthCheck, // Mutex is intentionally omitted } } // IsHealthy returns true if the worker is healthy func (w *TaskWorker) IsHealthy() bool { w.metrics.mu.RLock() defer w.metrics.mu.RUnlock() // Consider unhealthy if last health check was too long ago return time.Since(w.metrics.LastHealthCheck) < 2*w.cfg.PollingInterval }