441 lines
11 KiB
Go
441 lines
11 KiB
Go
// Package task provides asynchronous task processing with worker pools
|
|
package task
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"sync"
	"time"

	"modelRT/config"
	"modelRT/logger"
	"modelRT/mq"

	"github.com/gofrs/uuid"
	"github.com/panjf2000/ants/v2"
	amqp "github.com/rabbitmq/amqp091-go"
	"gorm.io/gorm"
)
|
|
|
|
// WorkerConfig holds configuration for the task worker.
type WorkerConfig struct {
	// PoolSize is the number of worker goroutines in the ants pool.
	PoolSize int
	// PreAlloc indicates whether to pre-allocate memory for the pool.
	PreAlloc bool
	// MaxBlockingTasks is the maximum number of tasks waiting in queue.
	// NOTE(review): this field is not passed to ants.NewPool in NewTaskWorker
	// as written — confirm whether it should be wired via
	// ants.WithMaxBlockingTasks.
	MaxBlockingTasks int
	// QueueConsumerCount is the number of concurrent RabbitMQ consumers.
	QueueConsumerCount int
	// PollingInterval is the interval between health checks.
	PollingInterval time.Duration
}
|
|
|
|
// DefaultWorkerConfig returns the default worker configuration
|
|
func DefaultWorkerConfig() WorkerConfig {
|
|
return WorkerConfig{
|
|
PoolSize: 10,
|
|
PreAlloc: true,
|
|
MaxBlockingTasks: 100,
|
|
QueueConsumerCount: 2,
|
|
PollingInterval: 30 * time.Second,
|
|
}
|
|
}
|
|
|
|
// TaskWorker manages a pool of workers for processing asynchronous tasks
// consumed from a RabbitMQ queue.
type TaskWorker struct {
	cfg      WorkerConfig       // worker configuration
	db       *gorm.DB           // database handle passed through to the task handler
	pool     *ants.Pool         // goroutine pool that executes task handlers
	conn     *amqp.Connection   // RabbitMQ connection obtained from the mq package
	ch       *amqp.Channel      // dedicated AMQP channel; closed in Stop
	handler  TaskHandler        // user-supplied task execution logic
	stopChan chan struct{}      // closed by Stop to signal all goroutines to exit
	wg       sync.WaitGroup     // tracks consumer and health-check goroutines
	ctx      context.Context    // derived from the constructor ctx; cancelled in Stop
	cancel   context.CancelFunc // cancels ctx
	metrics  *WorkerMetrics     // runtime metrics, guarded by metrics.mu
}
|
|
|
|
// WorkerMetrics holds metrics for the worker pool.
// All fields are guarded by mu; the struct must not be copied while the
// mutex is held (GetMetrics returns a mutex-free snapshot instead).
type WorkerMetrics struct {
	TasksProcessed  int64     // total tasks completed successfully
	TasksFailed     int64     // total tasks that failed (decode, validate, status update, or handler error)
	TasksInProgress int32     // tasks currently being handled
	QueueDepth      int       // last observed RabbitMQ queue depth (updated by checkHealth)
	WorkersActive   int       // pool goroutines currently running a task
	WorkersIdle     int       // pool goroutines currently free
	LastHealthCheck time.Time // when checkHealth last ran; used by IsHealthy
	mu              sync.RWMutex
}
|
|
|
|
// NewTaskWorker creates a new TaskWorker instance
|
|
func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg config.RabbitMQConfig, handler TaskHandler) (*TaskWorker, error) {
|
|
// Initialize RabbitMQ connection
|
|
mq.InitRabbitProxy(ctx, rabbitCfg)
|
|
conn := mq.GetConn()
|
|
if conn == nil {
|
|
return nil, fmt.Errorf("failed to get RabbitMQ connection")
|
|
}
|
|
|
|
// Create channel
|
|
ch, err := conn.Channel()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open channel: %w", err)
|
|
}
|
|
|
|
// Declare queue (ensure it exists with proper arguments)
|
|
_, err = ch.QueueDeclare(
|
|
TaskQueueName, // name
|
|
true, // durable
|
|
false, // delete when unused
|
|
false, // exclusive
|
|
false, // no-wait
|
|
amqp.Table{
|
|
"x-max-priority": MaxPriority,
|
|
"x-message-ttl": DefaultMessageTTL.Milliseconds(),
|
|
},
|
|
)
|
|
if err != nil {
|
|
ch.Close()
|
|
return nil, fmt.Errorf("failed to declare queue: %w", err)
|
|
}
|
|
|
|
// Set QoS (quality of service) for fair dispatch
|
|
err = ch.Qos(
|
|
1, // prefetch count
|
|
0, // prefetch size
|
|
false, // global
|
|
)
|
|
if err != nil {
|
|
ch.Close()
|
|
return nil, fmt.Errorf("failed to set QoS: %w", err)
|
|
}
|
|
|
|
// Create ants pool
|
|
pool, err := ants.NewPool(cfg.PoolSize, ants.WithPreAlloc(cfg.PreAlloc))
|
|
if err != nil {
|
|
ch.Close()
|
|
return nil, fmt.Errorf("failed to create worker pool: %w", err)
|
|
}
|
|
|
|
ctxWithCancel, cancel := context.WithCancel(ctx)
|
|
|
|
worker := &TaskWorker{
|
|
cfg: cfg,
|
|
db: db,
|
|
pool: pool,
|
|
conn: conn,
|
|
ch: ch,
|
|
handler: handler,
|
|
stopChan: make(chan struct{}),
|
|
ctx: ctxWithCancel,
|
|
cancel: cancel,
|
|
metrics: &WorkerMetrics{
|
|
LastHealthCheck: time.Now(),
|
|
},
|
|
}
|
|
|
|
return worker, nil
|
|
}
|
|
|
|
// Start begins consuming tasks from the queue
|
|
func (w *TaskWorker) Start() error {
|
|
logger.Info(w.ctx, "Starting task worker",
|
|
"pool_size", w.cfg.PoolSize,
|
|
"queue_consumers", w.cfg.QueueConsumerCount,
|
|
)
|
|
|
|
// Start multiple consumers for better throughput
|
|
for i := 0; i < w.cfg.QueueConsumerCount; i++ {
|
|
w.wg.Add(1)
|
|
go w.consumerLoop(i)
|
|
}
|
|
|
|
// Start health check goroutine
|
|
w.wg.Add(1)
|
|
go w.healthCheckLoop()
|
|
|
|
logger.Info(w.ctx, "Task worker started successfully")
|
|
return nil
|
|
}
|
|
|
|
// consumerLoop runs a single RabbitMQ consumer
|
|
func (w *TaskWorker) consumerLoop(consumerID int) {
|
|
defer w.wg.Done()
|
|
|
|
logger.Info(w.ctx, "Starting consumer", "consumer_id", consumerID)
|
|
|
|
// Consume messages from the queue
|
|
msgs, err := w.ch.Consume(
|
|
TaskQueueName, // queue
|
|
fmt.Sprintf("worker-%d", consumerID), // consumer tag
|
|
false, // auto-ack
|
|
false, // exclusive
|
|
false, // no-local
|
|
false, // no-wait
|
|
nil, // args
|
|
)
|
|
if err != nil {
|
|
logger.Error(w.ctx, "Failed to start consumer",
|
|
"consumer_id", consumerID,
|
|
"error", err,
|
|
)
|
|
return
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-w.stopChan:
|
|
logger.Info(w.ctx, "Consumer stopping", "consumer_id", consumerID)
|
|
return
|
|
case msg, ok := <-msgs:
|
|
if !ok {
|
|
logger.Warn(w.ctx, "Consumer channel closed", "consumer_id", consumerID)
|
|
return
|
|
}
|
|
|
|
// Process message in worker pool
|
|
err := w.pool.Submit(func() {
|
|
w.handleMessage(msg)
|
|
})
|
|
if err != nil {
|
|
logger.Error(w.ctx, "Failed to submit task to pool",
|
|
"consumer_id", consumerID,
|
|
"error", err,
|
|
)
|
|
// Reject message and requeue
|
|
msg.Nack(false, true)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleMessage processes a single RabbitMQ message
|
|
func (w *TaskWorker) handleMessage(msg amqp.Delivery) {
|
|
w.metrics.mu.Lock()
|
|
w.metrics.TasksInProgress++
|
|
w.metrics.mu.Unlock()
|
|
|
|
defer func() {
|
|
w.metrics.mu.Lock()
|
|
w.metrics.TasksInProgress--
|
|
w.metrics.mu.Unlock()
|
|
}()
|
|
|
|
ctx := w.ctx
|
|
|
|
// Parse task message
|
|
var taskMsg TaskQueueMessage
|
|
if err := json.Unmarshal(msg.Body, &taskMsg); err != nil {
|
|
logger.Error(ctx, "Failed to unmarshal task message", "error", err)
|
|
msg.Nack(false, false) // Reject without requeue
|
|
w.metrics.mu.Lock()
|
|
w.metrics.TasksFailed++
|
|
w.metrics.mu.Unlock()
|
|
return
|
|
}
|
|
|
|
// Validate message
|
|
if !taskMsg.Validate() {
|
|
logger.Error(ctx, "Invalid task message",
|
|
"task_id", taskMsg.TaskID,
|
|
"task_type", taskMsg.TaskType,
|
|
)
|
|
msg.Nack(false, false) // Reject without requeue
|
|
w.metrics.mu.Lock()
|
|
w.metrics.TasksFailed++
|
|
w.metrics.mu.Unlock()
|
|
return
|
|
}
|
|
|
|
logger.Info(ctx, "Processing task",
|
|
"task_id", taskMsg.TaskID,
|
|
"task_type", taskMsg.TaskType,
|
|
"priority", taskMsg.Priority,
|
|
)
|
|
|
|
// Update task status to RUNNING in database
|
|
if err := w.updateTaskStatus(ctx, taskMsg.TaskID, StatusRunning); err != nil {
|
|
logger.Error(ctx, "Failed to update task status", "error", err)
|
|
msg.Nack(false, true) // Reject with requeue
|
|
w.metrics.mu.Lock()
|
|
w.metrics.TasksFailed++
|
|
w.metrics.mu.Unlock()
|
|
return
|
|
}
|
|
|
|
// Execute task using handler
|
|
startTime := time.Now()
|
|
err := w.handler.Execute(ctx, taskMsg.TaskID, taskMsg.TaskType, w.db)
|
|
processingTime := time.Since(startTime)
|
|
|
|
if err != nil {
|
|
logger.Error(ctx, "Task execution failed",
|
|
"task_id", taskMsg.TaskID,
|
|
"task_type", taskMsg.TaskType,
|
|
"processing_time", processingTime,
|
|
"error", err,
|
|
)
|
|
|
|
// Update task status to FAILED
|
|
if updateErr := w.updateTaskWithError(ctx, taskMsg.TaskID, err); updateErr != nil {
|
|
logger.Error(ctx, "Failed to update task with error", "error", updateErr)
|
|
}
|
|
|
|
// Ack message even if task failed (we don't want to retry indefinitely)
|
|
msg.Ack(false)
|
|
w.metrics.mu.Lock()
|
|
w.metrics.TasksFailed++
|
|
w.metrics.mu.Unlock()
|
|
return
|
|
}
|
|
|
|
// Update task status to COMPLETED
|
|
if err := w.updateTaskStatus(ctx, taskMsg.TaskID, StatusCompleted); err != nil {
|
|
logger.Error(ctx, "Failed to update task status to completed", "error", err)
|
|
// Still ack the message since task was processed successfully
|
|
}
|
|
|
|
// Acknowledge message
|
|
msg.Ack(false)
|
|
|
|
logger.Info(ctx, "Task completed successfully",
|
|
"task_id", taskMsg.TaskID,
|
|
"task_type", taskMsg.TaskType,
|
|
"processing_time", processingTime,
|
|
)
|
|
|
|
w.metrics.mu.Lock()
|
|
w.metrics.TasksProcessed++
|
|
w.metrics.mu.Unlock()
|
|
}
|
|
|
|
// updateTaskStatus updates the status of a task in the database
|
|
func (w *TaskWorker) updateTaskStatus(ctx context.Context, taskID uuid.UUID, status TaskStatus) error {
|
|
// This is a simplified version. In a real implementation, you would:
|
|
// 1. Have a proper task table/model
|
|
// 2. Update the task record with status and timestamps
|
|
|
|
// For now, we'll log the update
|
|
logger.Debug(ctx, "Updating task status",
|
|
"task_id", taskID,
|
|
"status", status,
|
|
)
|
|
return nil
|
|
}
|
|
|
|
// updateTaskWithError updates a task with error information
|
|
func (w *TaskWorker) updateTaskWithError(ctx context.Context, taskID uuid.UUID, err error) error {
|
|
logger.Debug(ctx, "Updating task with error",
|
|
"task_id", taskID,
|
|
"error", err.Error(),
|
|
)
|
|
return nil
|
|
}
|
|
|
|
// healthCheckLoop periodically checks worker health and metrics
|
|
func (w *TaskWorker) healthCheckLoop() {
|
|
defer w.wg.Done()
|
|
|
|
ticker := time.NewTicker(w.cfg.PollingInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-w.stopChan:
|
|
return
|
|
case <-ticker.C:
|
|
w.checkHealth()
|
|
}
|
|
}
|
|
}
|
|
|
|
// checkHealth performs health checks and logs metrics
|
|
func (w *TaskWorker) checkHealth() {
|
|
w.metrics.mu.Lock()
|
|
defer w.metrics.mu.Unlock()
|
|
|
|
// Update queue depth
|
|
queue, err := w.ch.QueueDeclarePassive(
|
|
TaskQueueName, // name
|
|
true, // durable
|
|
false, // delete when unused
|
|
false, // exclusive
|
|
false, // no-wait
|
|
amqp.Table{
|
|
"x-max-priority": MaxPriority,
|
|
"x-message-ttl": DefaultMessageTTL.Milliseconds(),
|
|
},
|
|
)
|
|
if err == nil {
|
|
w.metrics.QueueDepth = queue.Messages
|
|
}
|
|
|
|
// Update worker pool stats
|
|
w.metrics.WorkersActive = w.pool.Running()
|
|
w.metrics.WorkersIdle = w.pool.Free()
|
|
w.metrics.LastHealthCheck = time.Now()
|
|
|
|
logger.Info(w.ctx, "Worker health check",
|
|
"tasks_processed", w.metrics.TasksProcessed,
|
|
"tasks_failed", w.metrics.TasksFailed,
|
|
"tasks_in_progress", w.metrics.TasksInProgress,
|
|
"queue_depth", w.metrics.QueueDepth,
|
|
"workers_active", w.metrics.WorkersActive,
|
|
"workers_idle", w.metrics.WorkersIdle,
|
|
"pool_capacity", w.pool.Cap(),
|
|
)
|
|
}
|
|
|
|
// Stop gracefully stops the task worker
|
|
func (w *TaskWorker) Stop() error {
|
|
logger.Info(w.ctx, "Stopping task worker")
|
|
|
|
// Signal all goroutines to stop
|
|
close(w.stopChan)
|
|
w.cancel()
|
|
|
|
// Wait for all goroutines to finish
|
|
w.wg.Wait()
|
|
|
|
// Release worker pool
|
|
w.pool.Release()
|
|
|
|
// Close channel
|
|
if w.ch != nil {
|
|
if err := w.ch.Close(); err != nil {
|
|
logger.Error(w.ctx, "Failed to close channel", "error", err)
|
|
}
|
|
}
|
|
|
|
logger.Info(w.ctx, "Task worker stopped")
|
|
return nil
|
|
}
|
|
|
|
// GetMetrics returns current worker metrics
|
|
func (w *TaskWorker) GetMetrics() *WorkerMetrics {
|
|
w.metrics.mu.RLock()
|
|
defer w.metrics.mu.RUnlock()
|
|
// Create a copy without the mutex to avoid copylocks warning
|
|
return &WorkerMetrics{
|
|
TasksProcessed: w.metrics.TasksProcessed,
|
|
TasksFailed: w.metrics.TasksFailed,
|
|
TasksInProgress: w.metrics.TasksInProgress,
|
|
QueueDepth: w.metrics.QueueDepth,
|
|
WorkersActive: w.metrics.WorkersActive,
|
|
WorkersIdle: w.metrics.WorkersIdle,
|
|
LastHealthCheck: w.metrics.LastHealthCheck,
|
|
// Mutex is intentionally omitted
|
|
}
|
|
}
|
|
|
|
// IsHealthy returns true if the worker is healthy
|
|
func (w *TaskWorker) IsHealthy() bool {
|
|
w.metrics.mu.RLock()
|
|
defer w.metrics.mu.RUnlock()
|
|
|
|
// Consider unhealthy if last health check was too long ago
|
|
return time.Since(w.metrics.LastHealthCheck) < 2*w.cfg.PollingInterval
|
|
} |