// Package task provides asynchronous task processing with worker pools package task import ( "context" "encoding/json" "fmt" "sync" "time" "modelRT/config" "modelRT/database" "modelRT/logger" "modelRT/mq" "modelRT/orm" "github.com/gofrs/uuid" "github.com/panjf2000/ants/v2" amqp "github.com/rabbitmq/amqp091-go" "gorm.io/gorm" ) // WorkerConfig holds configuration for the task worker type WorkerConfig struct { // PoolSize is the number of worker goroutines in the pool PoolSize int // PreAlloc indicates whether to pre-allocate memory for the pool PreAlloc bool // MaxBlockingTasks is the maximum number of tasks waiting in queue MaxBlockingTasks int // QueueConsumerCount is the number of concurrent RabbitMQ consumers QueueConsumerCount int // PollingInterval is the interval between health checks PollingInterval time.Duration } // DefaultWorkerConfig returns the default worker configuration func DefaultWorkerConfig() WorkerConfig { return WorkerConfig{ PoolSize: 10, PreAlloc: true, MaxBlockingTasks: 100, QueueConsumerCount: 2, PollingInterval: 30 * time.Second, } } // TaskWorker manages a pool of workers for processing asynchronous tasks type TaskWorker struct { cfg WorkerConfig db *gorm.DB pool *ants.Pool conn *amqp.Connection ch *amqp.Channel handler TaskHandler retryQueue *RetryQueue stopChan chan struct{} wg sync.WaitGroup ctx context.Context cancel context.CancelFunc metrics *WorkerMetrics } // WorkerMetrics holds metrics for the worker pool type WorkerMetrics struct { // Task statistics by type TasksProcessed map[TaskType]int64 TasksFailed map[TaskType]int64 TasksSuccess map[TaskType]int64 ProcessingTime map[TaskType]time.Duration // Aggregate counters (maintained for backward compatibility) TotalProcessed int64 TotalFailed int64 TotalSuccess int64 TasksInProgress int32 // Queue and latency metrics QueueDepth int QueueLatency time.Duration // Worker resource metrics WorkersActive int WorkersIdle int MemoryUsage uint64 CPULoad float64 // Time window metrics LastMinuteRate float64 Last5MinutesRate float64 LastHourRate float64 // Health and timing LastHealthCheck time.Time mu sync.RWMutex } // NewTaskWorker creates a new TaskWorker instance func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg config.RabbitMQConfig, handler TaskHandler) (*TaskWorker, error) { // Initialize RabbitMQ connection mq.InitRabbitProxy(ctx, rabbitCfg) conn := mq.GetConn() if conn == nil { return nil, fmt.Errorf("failed to get RabbitMQ connection") } // Create channel ch, err := conn.Channel() if err != nil { return nil, fmt.Errorf("failed to open channel: %w", err) } // Declare queue (ensure it exists with proper arguments) _, err = ch.QueueDeclare( TaskQueueName, // name true, // durable false, // delete when unused false, // exclusive false, // no-wait amqp.Table{ "x-max-priority": MaxPriority, "x-message-ttl": DefaultMessageTTL.Milliseconds(), }, ) if err != nil { ch.Close() return nil, fmt.Errorf("failed to declare queue: %w", err) } // Set QoS (quality of service) for fair dispatch err = ch.Qos( 1, // prefetch count 0, // prefetch size false, // global ) if err != nil { ch.Close() return nil, fmt.Errorf("failed to set QoS: %w", err) } // Create ants pool pool, err := ants.NewPool(cfg.PoolSize, ants.WithPreAlloc(cfg.PreAlloc)) if err != nil { ch.Close() return nil, fmt.Errorf("failed to create worker pool: %w", err) } ctxWithCancel, cancel := context.WithCancel(ctx) worker := &TaskWorker{ cfg: cfg, db: db, pool: pool, conn: conn, ch: ch, handler: handler, stopChan: make(chan struct{}), ctx: ctxWithCancel, cancel: cancel, metrics: &WorkerMetrics{ TasksProcessed: make(map[TaskType]int64), TasksFailed: make(map[TaskType]int64), TasksSuccess: make(map[TaskType]int64), ProcessingTime: make(map[TaskType]time.Duration), LastHealthCheck: time.Now(), }, } return worker, nil } // Start begins consuming tasks from the queue func (w *TaskWorker) Start() error { logger.Info(w.ctx, "Starting task worker", "pool_size", w.cfg.PoolSize, "queue_consumers", w.cfg.QueueConsumerCount, ) // Start multiple consumers for better throughput for i := 0; i < w.cfg.QueueConsumerCount; i++ { w.wg.Add(1) go w.consumerLoop(i) } // Start health check goroutine w.wg.Add(1) go w.healthCheckLoop() logger.Info(w.ctx, "Task worker started successfully") return nil } // consumerLoop runs a single RabbitMQ consumer func (w *TaskWorker) consumerLoop(consumerID int) { defer w.wg.Done() logger.Info(w.ctx, "Starting consumer", "consumer_id", consumerID) // Consume messages from the queue msgs, err := w.ch.Consume( TaskQueueName, // queue fmt.Sprintf("worker-%d", consumerID), // consumer tag false, // auto-ack false, // exclusive false, // no-local false, // no-wait nil, // args ) if err != nil { logger.Error(w.ctx, "Failed to start consumer", "consumer_id", consumerID, "error", err, ) return } for { select { case <-w.stopChan: logger.Info(w.ctx, "Consumer stopping", "consumer_id", consumerID) return case msg, ok := <-msgs: if !ok { logger.Warn(w.ctx, "Consumer channel closed", "consumer_id", consumerID) return } // Process message in worker pool err := w.pool.Submit(func() { w.handleMessage(msg) }) if err != nil { logger.Error(w.ctx, "Failed to submit task to pool", "consumer_id", consumerID, "error", err, ) // Reject message and requeue msg.Nack(false, true) } } } } // handleMessage processes a single RabbitMQ message func (w *TaskWorker) handleMessage(msg amqp.Delivery) { w.metrics.mu.Lock() w.metrics.TasksInProgress++ w.metrics.mu.Unlock() defer func() { w.metrics.mu.Lock() w.metrics.TasksInProgress-- w.metrics.mu.Unlock() }() ctx := w.ctx // Parse task message var taskMsg TaskQueueMessage if err := json.Unmarshal(msg.Body, &taskMsg); err != nil { logger.Error(ctx, "Failed to unmarshal task message", "error", err) msg.Nack(false, false) // Reject without requeue w.metrics.mu.Lock() w.metrics.TotalFailed++ w.metrics.mu.Unlock() return } // Validate message if !taskMsg.Validate() { logger.Error(ctx, "Invalid task message", "task_id", taskMsg.TaskID, "task_type", taskMsg.TaskType, ) msg.Nack(false, false) // Reject without requeue w.metrics.mu.Lock() w.metrics.TotalFailed++ // Also update per-task-type failure count w.metrics.TasksFailed[taskMsg.TaskType]++ w.metrics.mu.Unlock() return } logger.Info(ctx, "Processing task", "task_id", taskMsg.TaskID, "task_type", taskMsg.TaskType, "priority", taskMsg.Priority, ) // Update task status to RUNNING in database if err := w.updateTaskStatus(ctx, taskMsg.TaskID, StatusRunning); err != nil { logger.Error(ctx, "Failed to update task status", "error", err) msg.Nack(false, true) // Reject with requeue w.metrics.mu.Lock() w.metrics.TotalFailed++ w.metrics.TasksFailed[taskMsg.TaskType]++ w.metrics.mu.Unlock() return } // Execute task using handler startTime := time.Now() err := w.handler.Execute(ctx, taskMsg.TaskID, taskMsg.TaskType, w.db) processingTime := time.Since(startTime) if err != nil { logger.Error(ctx, "Task execution failed", "task_id", taskMsg.TaskID, "task_type", taskMsg.TaskType, "processing_time", processingTime, "error", err, ) // Update task status to FAILED if updateErr := w.updateTaskWithError(ctx, taskMsg.TaskID, err); updateErr != nil { logger.Error(ctx, "Failed to update task with error", "error", updateErr) } // Ack message even if task failed (we don't want to retry indefinitely) msg.Ack(false) w.metrics.mu.Lock() w.metrics.TotalFailed++ w.metrics.TasksFailed[taskMsg.TaskType]++ w.metrics.mu.Unlock() return } // Update task status to COMPLETED if err := w.updateTaskStatus(ctx, taskMsg.TaskID, StatusCompleted); err != nil { logger.Error(ctx, "Failed to update task status to completed", "error", err) // Still ack the message since task was processed successfully } // Acknowledge message msg.Ack(false) logger.Info(ctx, "Task completed successfully", "task_id", taskMsg.TaskID, "task_type", taskMsg.TaskType, "processing_time", processingTime, ) w.metrics.mu.Lock() w.metrics.TotalProcessed++ w.metrics.TasksProcessed[taskMsg.TaskType]++ w.metrics.TasksSuccess[taskMsg.TaskType]++ w.metrics.ProcessingTime[taskMsg.TaskType] += processingTime w.metrics.mu.Unlock() } // updateTaskStatus updates the status of a task in the database func (w *TaskWorker) updateTaskStatus(ctx context.Context, taskID uuid.UUID, status TaskStatus) error { // Convert TaskStatus to orm.AsyncTaskStatus var ormStatus orm.AsyncTaskStatus switch status { case StatusPending: ormStatus = orm.AsyncTaskStatusSubmitted case StatusRunning: ormStatus = orm.AsyncTaskStatusRunning case StatusCompleted: ormStatus = orm.AsyncTaskStatusCompleted case StatusFailed: ormStatus = orm.AsyncTaskStatusFailed default: return fmt.Errorf("unknown task status: %s", status) } // Update task status in database err := database.UpdateAsyncTaskStatus(ctx, w.db, taskID, ormStatus) if err != nil { logger.Error(ctx, "Failed to update task status in database", "task_id", taskID, "status", status, "error", err, ) return err } // If status is running, update started_at timestamp if status == StatusRunning { startedAt := time.Now().Unix() if err := database.UpdateTaskStarted(ctx, w.db, taskID, startedAt); err != nil { logger.Warn(ctx, "Failed to update task start time", "task_id", taskID, "error", err, ) // Continue despite error } } // If status is completed or failed, update finished_at timestamp if status == StatusCompleted || status == StatusFailed { finishedAt := time.Now().Unix() if status == StatusCompleted { if err := database.CompleteAsyncTask(ctx, w.db, taskID, finishedAt); err != nil { logger.Warn(ctx, "Failed to mark task as completed", "task_id", taskID, "error", err, ) } } else { if err := database.FailAsyncTask(ctx, w.db, taskID, finishedAt); err != nil { logger.Warn(ctx, "Failed to mark task as failed", "task_id", taskID, "error", err, ) } } } logger.Debug(ctx, "Task status updated", "task_id", taskID, "status", status, ) return nil } // updateTaskWithError updates a task with error information func (w *TaskWorker) updateTaskWithError(ctx context.Context, taskID uuid.UUID, err error) error { // Update task error information in database errorMsg := err.Error() stackTrace := fmt.Sprintf("%+v", err) updateErr := database.UpdateTaskErrorInfo(ctx, w.db, taskID, errorMsg, stackTrace) if updateErr != nil { logger.Error(ctx, "Failed to update task error info", "task_id", taskID, "error", updateErr, ) return updateErr } logger.Warn(ctx, "Task failed with error", "task_id", taskID, "error", errorMsg, ) return nil } // healthCheckLoop periodically checks worker health and metrics func (w *TaskWorker) healthCheckLoop() { defer w.wg.Done() ticker := time.NewTicker(w.cfg.PollingInterval) defer ticker.Stop() for { select { case <-w.stopChan: return case <-ticker.C: w.checkHealth() } } } // checkHealth performs health checks and logs metrics func (w *TaskWorker) checkHealth() { w.metrics.mu.Lock() defer w.metrics.mu.Unlock() // Update queue depth queue, err := w.ch.QueueDeclarePassive( TaskQueueName, // name true, // durable false, // delete when unused false, // exclusive false, // no-wait amqp.Table{ "x-max-priority": MaxPriority, "x-message-ttl": DefaultMessageTTL.Milliseconds(), }, ) if err == nil { w.metrics.QueueDepth = queue.Messages } // Update worker pool stats w.metrics.WorkersActive = w.pool.Running() w.metrics.WorkersIdle = w.pool.Free() w.metrics.LastHealthCheck = time.Now() logger.Info(w.ctx, "Worker health check", "tasks_processed", w.metrics.TotalProcessed, "tasks_failed", w.metrics.TotalFailed, "tasks_success", w.metrics.TotalSuccess, "tasks_in_progress", w.metrics.TasksInProgress, "queue_depth", w.metrics.QueueDepth, "queue_latency_ms", w.metrics.QueueLatency.Milliseconds(), "workers_active", w.metrics.WorkersActive, "workers_idle", w.metrics.WorkersIdle, "memory_usage_mb", w.metrics.MemoryUsage/(1024*1024), "cpu_load_percent", w.metrics.CPULoad, "pool_capacity", w.pool.Cap(), ) } // Stop gracefully stops the task worker func (w *TaskWorker) Stop() error { logger.Info(w.ctx, "Stopping task worker") // Signal all goroutines to stop close(w.stopChan) w.cancel() // Wait for all goroutines to finish w.wg.Wait() // Release worker pool w.pool.Release() // Close channel if w.ch != nil { if err := w.ch.Close(); err != nil { logger.Error(w.ctx, "Failed to close channel", "error", err) } } logger.Info(w.ctx, "Task worker stopped") return nil } // GetMetrics returns current worker metrics func (w *TaskWorker) GetMetrics() *WorkerMetrics { w.metrics.mu.RLock() defer w.metrics.mu.RUnlock() // Deep copy maps to avoid data races tasksProcessedCopy := make(map[TaskType]int64) for k, v := range w.metrics.TasksProcessed { tasksProcessedCopy[k] = v } tasksFailedCopy := make(map[TaskType]int64) for k, v := range w.metrics.TasksFailed { tasksFailedCopy[k] = v } tasksSuccessCopy := make(map[TaskType]int64) for k, v := range w.metrics.TasksSuccess { tasksSuccessCopy[k] = v } processingTimeCopy := make(map[TaskType]time.Duration) for k, v := range w.metrics.ProcessingTime { processingTimeCopy[k] = v } // Create a copy without the mutex to avoid copylocks warning return &WorkerMetrics{ TasksProcessed: tasksProcessedCopy, TasksFailed: tasksFailedCopy, TasksSuccess: tasksSuccessCopy, ProcessingTime: processingTimeCopy, TotalProcessed: w.metrics.TotalProcessed, TotalFailed: w.metrics.TotalFailed, TotalSuccess: w.metrics.TotalSuccess, TasksInProgress: w.metrics.TasksInProgress, QueueDepth: w.metrics.QueueDepth, QueueLatency: w.metrics.QueueLatency, WorkersActive: w.metrics.WorkersActive, WorkersIdle: w.metrics.WorkersIdle, MemoryUsage: w.metrics.MemoryUsage, CPULoad: w.metrics.CPULoad, LastMinuteRate: w.metrics.LastMinuteRate, Last5MinutesRate: w.metrics.Last5MinutesRate, LastHourRate: w.metrics.LastHourRate, LastHealthCheck: w.metrics.LastHealthCheck, // Mutex is intentionally omitted } } // IsHealthy returns true if the worker is healthy func (w *TaskWorker) IsHealthy() bool { w.metrics.mu.RLock() defer w.metrics.mu.RUnlock() // Consider unhealthy if last health check was too long ago return time.Since(w.metrics.LastHealthCheck) < 2*w.cfg.PollingInterval } // RecordMetrics periodically records worker metrics to the logging system func (w *TaskWorker) RecordMetrics(interval time.Duration) { go func() { ticker := time.NewTicker(interval) defer ticker.Stop() metricsLogger := NewMetricsLogger(w.ctx) for { select { case <-w.stopChan: return case <-ticker.C: w.metrics.mu.RLock() metricsLogger.LogWorkerMetrics( w.metrics.WorkersActive, w.metrics.WorkersIdle, w.pool.Cap(), w.metrics.MemoryUsage, w.metrics.CPULoad, ) metricsLogger.LogQueueMetrics( w.metrics.QueueDepth, w.metrics.QueueLatency, ) w.metrics.mu.RUnlock() } } }() }