modelRT/task/worker.go

441 lines
11 KiB
Go

// Package task provides asynchronous task processing with worker pools
package task
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"modelRT/config"
"modelRT/logger"
"modelRT/mq"
"github.com/gofrs/uuid"
"github.com/panjf2000/ants/v2"
amqp "github.com/rabbitmq/amqp091-go"
"gorm.io/gorm"
)
// WorkerConfig holds configuration for the task worker.
type WorkerConfig struct {
	// PoolSize is the number of worker goroutines in the pool.
	PoolSize int
	// PreAlloc indicates whether to pre-allocate memory for the pool.
	PreAlloc bool
	// MaxBlockingTasks is the maximum number of tasks waiting in queue
	// (the cap on goroutines blocked on pool submission).
	MaxBlockingTasks int
	// QueueConsumerCount is the number of concurrent RabbitMQ consumers.
	QueueConsumerCount int
	// PollingInterval is the interval between health checks.
	PollingInterval time.Duration
}
// DefaultWorkerConfig returns a WorkerConfig populated with sensible
// defaults: 10 pool workers with pre-allocation, a 100-task blocking cap,
// 2 queue consumers, and a 30-second health-check interval.
func DefaultWorkerConfig() WorkerConfig {
	var cfg WorkerConfig
	cfg.PoolSize = 10
	cfg.PreAlloc = true
	cfg.MaxBlockingTasks = 100
	cfg.QueueConsumerCount = 2
	cfg.PollingInterval = 30 * time.Second
	return cfg
}
// TaskWorker manages a pool of workers for processing asynchronous tasks.
// Lifecycle: NewTaskWorker -> Start -> (background consumption) -> Stop.
type TaskWorker struct {
	cfg      WorkerConfig       // pool/consumer configuration
	db       *gorm.DB           // database handle passed to the task handler
	pool     *ants.Pool         // goroutine pool executing handleMessage
	conn     *amqp.Connection   // shared RabbitMQ connection (owned by mq package, not closed here)
	ch       *amqp.Channel      // dedicated channel used by all consumers and health checks
	handler  TaskHandler        // user-supplied task execution logic
	stopChan chan struct{}      // closed by Stop to signal all goroutines to exit
	wg       sync.WaitGroup     // tracks consumer and health-check goroutines
	ctx      context.Context    // derived (cancellable) context for logging and handler calls
	cancel   context.CancelFunc // cancels ctx during Stop
	metrics  *WorkerMetrics     // shared counters/gauges, guarded by metrics.mu
}
// WorkerMetrics holds metrics for the worker pool.
// All fields are guarded by mu; GetMetrics returns a mutex-free copy.
type WorkerMetrics struct {
	TasksProcessed  int64        // total tasks completed successfully
	TasksFailed     int64        // total tasks that failed at any stage (decode, validate, status update, execute)
	TasksInProgress int32        // tasks currently being handled
	QueueDepth      int          // last observed RabbitMQ queue depth (from health check)
	WorkersActive   int          // pool goroutines currently running tasks
	WorkersIdle     int          // pool goroutines currently free
	LastHealthCheck time.Time    // timestamp of the most recent health check
	mu              sync.RWMutex // guards all fields above
}
// NewTaskWorker creates a new TaskWorker instance.
//
// It initializes the shared RabbitMQ connection, opens a dedicated channel,
// declares the durable task queue (with priority and message-TTL arguments),
// applies per-consumer QoS, and builds the goroutine pool that will execute
// task handlers. The returned worker is idle; call Start to begin consuming.
//
// BUG FIX: cfg.MaxBlockingTasks was previously ignored — the pool was built
// without it, so the configured backlog cap never took effect. It is now
// wired through via ants.WithMaxBlockingTasks.
func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg config.RabbitMQConfig, handler TaskHandler) (*TaskWorker, error) {
	// Initialize (or reuse) the shared RabbitMQ connection.
	mq.InitRabbitProxy(ctx, rabbitCfg)
	conn := mq.GetConn()
	if conn == nil {
		return nil, fmt.Errorf("failed to get RabbitMQ connection")
	}
	// Open a dedicated channel for this worker.
	ch, err := conn.Channel()
	if err != nil {
		return nil, fmt.Errorf("failed to open channel: %w", err)
	}
	// Declare queue (ensure it exists with proper arguments).
	_, err = ch.QueueDeclare(
		TaskQueueName, // name
		true,          // durable
		false,         // delete when unused
		false,         // exclusive
		false,         // no-wait
		amqp.Table{
			"x-max-priority": MaxPriority,
			"x-message-ttl":  DefaultMessageTTL.Milliseconds(),
		},
	)
	if err != nil {
		ch.Close()
		return nil, fmt.Errorf("failed to declare queue: %w", err)
	}
	// Set QoS (quality of service) for fair dispatch across consumers.
	err = ch.Qos(
		1,     // prefetch count: one unacked message per consumer at a time
		0,     // prefetch size (unlimited bytes)
		false, // global: apply per-consumer, not per-channel
	)
	if err != nil {
		ch.Close()
		return nil, fmt.Errorf("failed to set QoS: %w", err)
	}
	// Create the ants pool, honoring the configured blocking-task cap.
	pool, err := ants.NewPool(
		cfg.PoolSize,
		ants.WithPreAlloc(cfg.PreAlloc),
		ants.WithMaxBlockingTasks(cfg.MaxBlockingTasks),
	)
	if err != nil {
		ch.Close()
		return nil, fmt.Errorf("failed to create worker pool: %w", err)
	}
	// Derive a cancellable context so Stop can interrupt in-flight work.
	ctxWithCancel, cancel := context.WithCancel(ctx)
	worker := &TaskWorker{
		cfg:      cfg,
		db:       db,
		pool:     pool,
		conn:     conn,
		ch:       ch,
		handler:  handler,
		stopChan: make(chan struct{}),
		ctx:      ctxWithCancel,
		cancel:   cancel,
		metrics: &WorkerMetrics{
			LastHealthCheck: time.Now(),
		},
	}
	return worker, nil
}
// Start launches the configured number of queue-consumer goroutines plus a
// periodic health-check goroutine, then returns immediately. Consumption
// continues in the background until Stop is called.
func (w *TaskWorker) Start() error {
	logger.Info(w.ctx, "Starting task worker",
		"pool_size", w.cfg.PoolSize,
		"queue_consumers", w.cfg.QueueConsumerCount,
	)
	// Spin up multiple consumers for better throughput.
	for id := 0; id < w.cfg.QueueConsumerCount; id++ {
		w.wg.Add(1)
		go w.consumerLoop(id)
	}
	// A single background goroutine handles periodic health reporting.
	w.wg.Add(1)
	go w.healthCheckLoop()
	logger.Info(w.ctx, "Task worker started successfully")
	return nil
}
// consumerLoop runs a single RabbitMQ consumer identified by consumerID.
// It pulls deliveries from the task queue and submits each one to the worker
// pool; it exits when the worker is stopped or the delivery channel closes.
//
// BUG FIX: the Nack error on the pool-submit failure path was previously
// discarded silently; it is now logged so broken channels are visible.
func (w *TaskWorker) consumerLoop(consumerID int) {
	defer w.wg.Done()
	logger.Info(w.ctx, "Starting consumer", "consumer_id", consumerID)
	// Register this consumer on the shared channel with a unique tag.
	msgs, err := w.ch.Consume(
		TaskQueueName,                        // queue
		fmt.Sprintf("worker-%d", consumerID), // consumer tag
		false,                                // auto-ack: false, we ack manually after processing
		false,                                // exclusive
		false,                                // no-local
		false,                                // no-wait
		nil,                                  // args
	)
	if err != nil {
		logger.Error(w.ctx, "Failed to start consumer",
			"consumer_id", consumerID,
			"error", err,
		)
		return
	}
	for {
		select {
		case <-w.stopChan:
			logger.Info(w.ctx, "Consumer stopping", "consumer_id", consumerID)
			return
		case msg, ok := <-msgs:
			if !ok {
				// Delivery channel closed (connection/channel failure).
				logger.Warn(w.ctx, "Consumer channel closed", "consumer_id", consumerID)
				return
			}
			// Hand the message off to the worker pool for processing.
			err := w.pool.Submit(func() {
				w.handleMessage(msg)
			})
			if err != nil {
				logger.Error(w.ctx, "Failed to submit task to pool",
					"consumer_id", consumerID,
					"error", err,
				)
				// Reject and requeue so the task can be retried later.
				if nackErr := msg.Nack(false, true); nackErr != nil {
					logger.Error(w.ctx, "Failed to nack message",
						"consumer_id", consumerID,
						"error", nackErr,
					)
				}
			}
		}
	}
}
// handleMessage processes a single RabbitMQ delivery end to end:
// decode -> validate -> mark RUNNING -> execute handler -> record outcome.
//
// Ack/Nack policy:
//   - undecodable or invalid messages are rejected without requeue (poison);
//   - a failed RUNNING-status update is nacked with requeue for retry;
//   - handler failures are acked (no indefinite retries) and marked FAILED.
//
// Improvements over the previous version: the repeated lock/increment
// boilerplate is factored into helpers, and Ack/Nack errors (previously
// discarded) are now logged.
func (w *TaskWorker) handleMessage(msg amqp.Delivery) {
	w.trackInProgress(1)
	defer w.trackInProgress(-1)

	ctx := w.ctx

	// Parse task message.
	var taskMsg TaskQueueMessage
	if err := json.Unmarshal(msg.Body, &taskMsg); err != nil {
		logger.Error(ctx, "Failed to unmarshal task message", "error", err)
		w.rejectMessage(ctx, msg, false) // poison message: drop, don't requeue
		return
	}

	// Validate message contents before doing any work.
	if !taskMsg.Validate() {
		logger.Error(ctx, "Invalid task message",
			"task_id", taskMsg.TaskID,
			"task_type", taskMsg.TaskType,
		)
		w.rejectMessage(ctx, msg, false)
		return
	}

	logger.Info(ctx, "Processing task",
		"task_id", taskMsg.TaskID,
		"task_type", taskMsg.TaskType,
		"priority", taskMsg.Priority,
	)

	// Mark the task RUNNING; if persistence fails, requeue for a later retry.
	if err := w.updateTaskStatus(ctx, taskMsg.TaskID, StatusRunning); err != nil {
		logger.Error(ctx, "Failed to update task status", "error", err)
		w.rejectMessage(ctx, msg, true)
		return
	}

	// Execute the task via the injected handler and time it.
	startTime := time.Now()
	err := w.handler.Execute(ctx, taskMsg.TaskID, taskMsg.TaskType, w.db)
	processingTime := time.Since(startTime)

	if err != nil {
		logger.Error(ctx, "Task execution failed",
			"task_id", taskMsg.TaskID,
			"task_type", taskMsg.TaskType,
			"processing_time", processingTime,
			"error", err,
		)
		// Record the failure on the task record (best effort).
		if updateErr := w.updateTaskWithError(ctx, taskMsg.TaskID, err); updateErr != nil {
			logger.Error(ctx, "Failed to update task with error", "error", updateErr)
		}
		// Ack message even if task failed (we don't want to retry indefinitely).
		if ackErr := msg.Ack(false); ackErr != nil {
			logger.Error(ctx, "Failed to ack message", "error", ackErr)
		}
		w.markTaskFailed()
		return
	}

	// The work itself succeeded, so we still ack even if this bookkeeping
	// write fails.
	if err := w.updateTaskStatus(ctx, taskMsg.TaskID, StatusCompleted); err != nil {
		logger.Error(ctx, "Failed to update task status to completed", "error", err)
	}
	if ackErr := msg.Ack(false); ackErr != nil {
		logger.Error(ctx, "Failed to ack message", "error", ackErr)
	}
	logger.Info(ctx, "Task completed successfully",
		"task_id", taskMsg.TaskID,
		"task_type", taskMsg.TaskType,
		"processing_time", processingTime,
	)
	w.markTaskProcessed()
}

// trackInProgress adjusts the in-progress gauge by delta under the metrics lock.
func (w *TaskWorker) trackInProgress(delta int32) {
	w.metrics.mu.Lock()
	w.metrics.TasksInProgress += delta
	w.metrics.mu.Unlock()
}

// markTaskFailed increments the failed-task counter under the metrics lock.
func (w *TaskWorker) markTaskFailed() {
	w.metrics.mu.Lock()
	w.metrics.TasksFailed++
	w.metrics.mu.Unlock()
}

// markTaskProcessed increments the processed-task counter under the metrics lock.
func (w *TaskWorker) markTaskProcessed() {
	w.metrics.mu.Lock()
	w.metrics.TasksProcessed++
	w.metrics.mu.Unlock()
}

// rejectMessage nacks a delivery (optionally requeueing it) and records a
// failure in the metrics; the Nack error, if any, is logged.
func (w *TaskWorker) rejectMessage(ctx context.Context, msg amqp.Delivery, requeue bool) {
	if err := msg.Nack(false, requeue); err != nil {
		logger.Error(ctx, "Failed to nack message", "error", err)
	}
	w.markTaskFailed()
}
// updateTaskStatus records a task status transition.
//
// NOTE: this is currently a stub — it only emits a debug log and always
// succeeds. A complete implementation would update the task record in the
// database with the new status and timestamps.
func (w *TaskWorker) updateTaskStatus(ctx context.Context, taskID uuid.UUID, status TaskStatus) error {
	logger.Debug(ctx, "Updating task status",
		"task_id", taskID,
		"status", status,
	)
	return nil
}
// updateTaskWithError records error details for a failed task.
//
// NOTE: like updateTaskStatus, this is a stub — it logs at debug level and
// always returns nil rather than persisting anything.
func (w *TaskWorker) updateTaskWithError(ctx context.Context, taskID uuid.UUID, err error) error {
	logger.Debug(ctx, "Updating task with error",
		"task_id", taskID,
		"error", err.Error(),
	)
	return nil
}
// healthCheckLoop invokes checkHealth once per PollingInterval until the
// worker is signalled to stop via stopChan.
func (w *TaskWorker) healthCheckLoop() {
	defer w.wg.Done()
	ticker := time.NewTicker(w.cfg.PollingInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			w.checkHealth()
		case <-w.stopChan:
			return
		}
	}
}
// checkHealth refreshes queue and pool metrics and logs a health summary.
//
// BUG FIX: the broker round-trip (QueueDeclarePassive) was previously
// performed while holding the metrics mutex, blocking every metrics
// reader/writer for the duration of a network call. The probe is now issued
// before the lock is taken.
func (w *TaskWorker) checkHealth() {
	// Probe the queue depth first, without holding the metrics lock.
	queue, queueErr := w.ch.QueueDeclarePassive(
		TaskQueueName, // name
		true,          // durable
		false,         // delete when unused
		false,         // exclusive
		false,         // no-wait
		amqp.Table{
			"x-max-priority": MaxPriority,
			"x-message-ttl":  DefaultMessageTTL.Milliseconds(),
		},
	)
	w.metrics.mu.Lock()
	defer w.metrics.mu.Unlock()
	// Only update the depth if the probe succeeded; keep the stale value otherwise.
	if queueErr == nil {
		w.metrics.QueueDepth = queue.Messages
	}
	// Update worker pool stats.
	w.metrics.WorkersActive = w.pool.Running()
	w.metrics.WorkersIdle = w.pool.Free()
	w.metrics.LastHealthCheck = time.Now()
	logger.Info(w.ctx, "Worker health check",
		"tasks_processed", w.metrics.TasksProcessed,
		"tasks_failed", w.metrics.TasksFailed,
		"tasks_in_progress", w.metrics.TasksInProgress,
		"queue_depth", w.metrics.QueueDepth,
		"workers_active", w.metrics.WorkersActive,
		"workers_idle", w.metrics.WorkersIdle,
		"pool_capacity", w.pool.Cap(),
	)
}
// Stop gracefully shuts the worker down: it signals every background
// goroutine, waits for them to exit, releases the goroutine pool, and
// closes the AMQP channel. The shared connection is left open.
//
// NOTE(review): Stop is not idempotent — calling it twice would close
// stopChan again and panic; callers should invoke it exactly once.
func (w *TaskWorker) Stop() error {
	logger.Info(w.ctx, "Stopping task worker")
	// Tell consumers and the health checker to exit, then wait for them.
	close(w.stopChan)
	w.cancel()
	w.wg.Wait()
	// Tear down the goroutine pool once nothing can submit to it.
	w.pool.Release()
	// Close the AMQP channel last; errors are logged, not returned.
	if w.ch != nil {
		if err := w.ch.Close(); err != nil {
			logger.Error(w.ctx, "Failed to close channel", "error", err)
		}
	}
	logger.Info(w.ctx, "Task worker stopped")
	return nil
}
// GetMetrics returns a point-in-time snapshot of the worker metrics.
// The snapshot deliberately omits the mutex (avoiding a copylocks issue),
// so callers may read it freely without locking.
func (w *TaskWorker) GetMetrics() *WorkerMetrics {
	w.metrics.mu.RLock()
	defer w.metrics.mu.RUnlock()
	snapshot := WorkerMetrics{
		TasksProcessed:  w.metrics.TasksProcessed,
		TasksFailed:     w.metrics.TasksFailed,
		TasksInProgress: w.metrics.TasksInProgress,
		QueueDepth:      w.metrics.QueueDepth,
		WorkersActive:   w.metrics.WorkersActive,
		WorkersIdle:     w.metrics.WorkersIdle,
		LastHealthCheck: w.metrics.LastHealthCheck,
	}
	return &snapshot
}
// IsHealthy reports whether the health-check loop appears alive: the worker
// is considered healthy while the most recent check happened within two
// polling intervals.
func (w *TaskWorker) IsHealthy() bool {
	w.metrics.mu.RLock()
	defer w.metrics.mu.RUnlock()
	staleAfter := 2 * w.cfg.PollingInterval
	return time.Since(w.metrics.LastHealthCheck) < staleAfter
}