From f8c0951a1399961938e53f9dbc50e8463fc2925f Mon Sep 17 00:00:00 2001 From: douxu Date: Fri, 3 Apr 2026 10:07:43 +0800 Subject: [PATCH] Extend async task system with database integration and retry management - Add AsyncTaskConfig to config structure - Create database operations for task state management (async_task_extended.go) - Add configuration middleware for Gin context - Extract task worker initialization to separate file (initializer.go) - Implement retry strategies with exponential backoff (retry_manager.go) - Add retry queue for failed task scheduling (retry_queue.go) - Enhance worker metrics with detailed per-task-type tracking - Integrate database operations into task worker for status updates - Add comprehensive metrics logging system --- config/config.go | 12 ++ database/async_task_extended.go | 227 ++++++++++++++++++++++++++++++++ main.go | 30 +---- middleware/config_middleware.go | 15 +++ task/initializer.go | 41 ++++++ task/metrics_logger.go | 157 ++++++++++++++++++++++ task/retry_manager.go | 219 ++++++++++++++++++++++++++++++ task/retry_queue.go | 187 ++++++++++++++++++++++++++ task/worker.go | 208 ++++++++++++++++++++++++++--- 9 files changed, 1048 insertions(+), 48 deletions(-) create mode 100644 database/async_task_extended.go create mode 100644 middleware/config_middleware.go create mode 100644 task/initializer.go create mode 100644 task/metrics_logger.go create mode 100644 task/retry_manager.go create mode 100644 task/retry_queue.go diff --git a/config/config.go b/config/config.go index 3c09187..1425f45 100644 --- a/config/config.go +++ b/config/config.go @@ -3,6 +3,7 @@ package config import ( "fmt" + "time" "github.com/spf13/viper" ) @@ -91,6 +92,16 @@ type DataRTConfig struct { Method string `mapstructure:"polling_api_method"` } +// AsyncTaskConfig define config struct of asynchronous task system +type AsyncTaskConfig struct { + WorkerPoolSize int `mapstructure:"worker_pool_size"` + QueueConsumerCount int 
`mapstructure:"queue_consumer_count"` + MaxRetryCount int `mapstructure:"max_retry_count"` + RetryInitialDelay time.Duration `mapstructure:"retry_initial_delay"` + RetryMaxDelay time.Duration `mapstructure:"retry_max_delay"` + HealthCheckInterval time.Duration `mapstructure:"health_check_interval"` +} + // ModelRTConfig define config struct of model runtime server type ModelRTConfig struct { BaseConfig `mapstructure:"base"` @@ -103,6 +114,7 @@ type ModelRTConfig struct { DataRTConfig `mapstructure:"dataRT"` LockerRedisConfig RedisConfig `mapstructure:"locker_redis"` StorageRedisConfig RedisConfig `mapstructure:"storage_redis"` + AsyncTaskConfig AsyncTaskConfig `mapstructure:"async_task"` PostgresDBURI string `mapstructure:"-"` } diff --git a/database/async_task_extended.go b/database/async_task_extended.go new file mode 100644 index 0000000..ca94b42 --- /dev/null +++ b/database/async_task_extended.go @@ -0,0 +1,227 @@ +// Package database define database operation functions +package database + +import ( + "context" + "time" + + "modelRT/orm" + + "github.com/gofrs/uuid" + "gorm.io/gorm" +) + +// UpdateTaskStarted updates task start time and status to running +func UpdateTaskStarted(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, startedAt int64) error { + cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + result := tx.WithContext(cancelCtx). + Model(&orm.AsyncTask{}). + Where("task_id = ?", taskID). 
+ Updates(map[string]any{ + "status": orm.AsyncTaskStatusRunning, + "started_at": startedAt, + }) + + return result.Error +} + +// UpdateTaskRetryInfo updates task retry information +func UpdateTaskRetryInfo(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, retryCount int, nextRetryTime int64) error { + cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + updateData := map[string]any{ + "retry_count": retryCount, + } + if nextRetryTime <= 0 { + updateData["next_retry_time"] = nil + } else { + updateData["next_retry_time"] = nextRetryTime + } + + result := tx.WithContext(cancelCtx). + Model(&orm.AsyncTask{}). + Where("task_id = ?", taskID). + Updates(updateData) + + return result.Error +} + +// UpdateTaskErrorInfo updates task error information +func UpdateTaskErrorInfo(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, errorMsg, stackTrace string) error { + cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + result := tx.WithContext(cancelCtx). + Model(&orm.AsyncTask{}). + Where("task_id = ?", taskID). + Updates(map[string]any{ + "failure_reason": errorMsg, + "stack_trace": stackTrace, + }) + + return result.Error +} + +// UpdateTaskExecutionTime updates task execution time +func UpdateTaskExecutionTime(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, executionTime int64) error { + cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + result := tx.WithContext(cancelCtx). + Model(&orm.AsyncTask{}). + Where("task_id = ?", taskID). + Update("execution_time", executionTime) + + return result.Error +} + +// UpdateTaskWorkerID updates the worker ID that is processing the task +func UpdateTaskWorkerID(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, workerID string) error { + cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + result := tx.WithContext(cancelCtx). + Model(&orm.AsyncTask{}). + Where("task_id = ?", taskID). 
+ Update("worker_id", workerID) + + return result.Error +} + +// UpdateTaskPriority updates task priority +func UpdateTaskPriority(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, priority int) error { + cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + result := tx.WithContext(cancelCtx). + Model(&orm.AsyncTask{}). + Where("task_id = ?", taskID). + Update("priority", priority) + + return result.Error +} + +// UpdateTaskQueueName updates task queue name +func UpdateTaskQueueName(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, queueName string) error { + cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + result := tx.WithContext(cancelCtx). + Model(&orm.AsyncTask{}). + Where("task_id = ?", taskID). + Update("queue_name", queueName) + + return result.Error +} + +// UpdateTaskCreatedBy updates task creator information +func UpdateTaskCreatedBy(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, createdBy string) error { + cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + result := tx.WithContext(cancelCtx). + Model(&orm.AsyncTask{}). + Where("task_id = ?", taskID). + Update("created_by", createdBy) + + return result.Error +} + +// UpdateTaskResultWithMetrics updates task result with execution metrics +func UpdateTaskResultWithMetrics(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, executionTime int64, memoryUsage *int64, cpuUsage *float64, retryCount int, completedAt int64) error { + cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + result := tx.WithContext(cancelCtx). + Model(&orm.AsyncTaskResult{}). + Where("task_id = ?", taskID). 
+ Updates(map[string]any{ + "execution_time": executionTime, + "memory_usage": memoryUsage, + "cpu_usage": cpuUsage, + "retry_count": retryCount, + "completed_at": completedAt, + }) + + return result.Error +} + +// GetTasksForRetry retrieves tasks that are due for retry +func GetTasksForRetry(ctx context.Context, tx *gorm.DB, limit int) ([]orm.AsyncTask, error) { + var tasks []orm.AsyncTask + + cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + now := time.Now().Unix() + result := tx.WithContext(cancelCtx). + Where("status = ? AND next_retry_time IS NOT NULL AND next_retry_time <= ?", orm.AsyncTaskStatusFailed, now). + Order("next_retry_time ASC"). + Limit(limit). + Find(&tasks) + + if result.Error != nil { + return nil, result.Error + } + + return tasks, nil +} + +// GetTasksByPriority retrieves tasks by priority order +func GetTasksByPriority(ctx context.Context, tx *gorm.DB, status orm.AsyncTaskStatus, limit int) ([]orm.AsyncTask, error) { + var tasks []orm.AsyncTask + + cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + result := tx.WithContext(cancelCtx). + Where("status = ?", status). + Order("priority DESC, created_at ASC"). + Limit(limit). + Find(&tasks) + + if result.Error != nil { + return nil, result.Error + } + + return tasks, nil +} + +// GetTasksByWorkerID retrieves tasks being processed by a specific worker +func GetTasksByWorkerID(ctx context.Context, tx *gorm.DB, workerID string) ([]orm.AsyncTask, error) { + var tasks []orm.AsyncTask + + cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + result := tx.WithContext(cancelCtx). + Where("worker_id = ? AND status = ?", workerID, orm.AsyncTaskStatusRunning). 
+ Find(&tasks) + + if result.Error != nil { + return nil, result.Error + } + + return tasks, nil +} + +// CleanupStaleTasks marks tasks as failed if they have been running for too long +func CleanupStaleTasks(ctx context.Context, tx *gorm.DB, timeoutSeconds int64) (int64, error) { + cancelCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + threshold := time.Now().Unix() - timeoutSeconds + result := tx.WithContext(cancelCtx). + Model(&orm.AsyncTask{}). + Where("status = ? AND started_at IS NOT NULL AND started_at < ?", orm.AsyncTaskStatusRunning, threshold). + Updates(map[string]any{ + "status": orm.AsyncTaskStatusFailed, + "failure_reason": "task timeout", + "finished_at": time.Now().Unix(), + }) + + return result.RowsAffected, result.Error +} \ No newline at end of file diff --git a/main.go b/main.go index e199889..198bec0 100644 --- a/main.go +++ b/main.go @@ -71,34 +71,6 @@ var ( // // @host localhost:8080 // @BasePath /api/v1 -func initTaskWorker(ctx context.Context, config config.ModelRTConfig, db *gorm.DB) (*task.TaskWorker, error) { - // Create worker configuration - workerCfg := task.WorkerConfig{ - PoolSize: config.AsyncTaskConfig.WorkerPoolSize, - PreAlloc: true, - MaxBlockingTasks: 100, - QueueConsumerCount: config.AsyncTaskConfig.QueueConsumerCount, - PollingInterval: config.AsyncTaskConfig.HealthCheckInterval, - } - - // Create task handler factory - handlerFactory := task.NewHandlerFactory() - handlerFactory.CreateDefaultHandlers() - handler := task.DefaultCompositeHandler() - - // Create task worker - worker, err := task.NewTaskWorker(ctx, workerCfg, db, config.RabbitMQConfig, handler) - if err != nil { - return nil, fmt.Errorf("failed to create task worker: %w", err) - } - - logger.Info(ctx, "Task worker initialized", - "worker_pool_size", workerCfg.PoolSize, - "queue_consumers", workerCfg.QueueConsumerCount, - ) - - return worker, nil -} func main() { flag.Parse() @@ -188,7 +160,7 @@ func main() { mq.InitRabbitProxy(ctx, 
modelRTConfig.RabbitMQConfig) // init async task worker - taskWorker, err := initTaskWorker(ctx, modelRTConfig, postgresDBClient) + taskWorker, err := task.InitTaskWorker(ctx, modelRTConfig, postgresDBClient) if err != nil { logger.Error(ctx, "Failed to initialize task worker", "error", err) // Continue without task worker, but log warning diff --git a/middleware/config_middleware.go b/middleware/config_middleware.go new file mode 100644 index 0000000..ff56995 --- /dev/null +++ b/middleware/config_middleware.go @@ -0,0 +1,15 @@ +package middleware + +import ( + "modelRT/config" + + "github.com/gin-gonic/gin" +) + +// ConfigMiddleware injects the global configuration into the Gin context +func ConfigMiddleware(modelRTConfig config.ModelRTConfig) gin.HandlerFunc { + return func(c *gin.Context) { + c.Set("config", modelRTConfig) + c.Next() + } +} \ No newline at end of file diff --git a/task/initializer.go b/task/initializer.go new file mode 100644 index 0000000..de75cea --- /dev/null +++ b/task/initializer.go @@ -0,0 +1,41 @@ +// Package task provides asynchronous task processing with worker pools +package task + +import ( + "context" + "fmt" + + "modelRT/config" + "modelRT/logger" + "gorm.io/gorm" +) + +// InitTaskWorker initializes a task worker with the given configuration and database connection +func InitTaskWorker(ctx context.Context, config config.ModelRTConfig, db *gorm.DB) (*TaskWorker, error) { + // Create worker configuration + workerCfg := WorkerConfig{ + PoolSize: config.AsyncTaskConfig.WorkerPoolSize, + PreAlloc: true, + MaxBlockingTasks: 100, + QueueConsumerCount: config.AsyncTaskConfig.QueueConsumerCount, + PollingInterval: config.AsyncTaskConfig.HealthCheckInterval, + } + + // Create task handler factory + handlerFactory := NewHandlerFactory() + handlerFactory.CreateDefaultHandlers() + handler := DefaultCompositeHandler() + + // Create task worker + worker, err := NewTaskWorker(ctx, workerCfg, db, config.RabbitMQConfig, handler) + if err != nil { + return nil, fmt.Errorf("failed to 
create task worker: %w", err) + } + + logger.Info(ctx, "Task worker initialized", + "worker_pool_size", workerCfg.PoolSize, + "queue_consumers", workerCfg.QueueConsumerCount, + ) + + return worker, nil +} \ No newline at end of file diff --git a/task/metrics_logger.go b/task/metrics_logger.go new file mode 100644 index 0000000..d8c0675 --- /dev/null +++ b/task/metrics_logger.go @@ -0,0 +1,157 @@ +// Package task provides metrics logging for asynchronous task system +package task + +import ( + "context" + "runtime" + "time" + + "modelRT/logger" +) + +// MetricsLogger logs task system metrics using the existing logging system +type MetricsLogger struct { + ctx context.Context +} + +// NewMetricsLogger creates a new MetricsLogger +func NewMetricsLogger(ctx context.Context) *MetricsLogger { + return &MetricsLogger{ctx: ctx} +} + +// LogTaskMetrics records task processing metrics +func (m *MetricsLogger) LogTaskMetrics(taskType TaskType, status string, processingTime time.Duration, retryCount int) { + logger.Info(m.ctx, "Task metrics", + "task_type", taskType, + "status", status, + "processing_time_ms", processingTime.Milliseconds(), + "retry_count", retryCount, + "metric_type", "task_processing", + "timestamp", time.Now().Unix(), + ) +} + +// LogQueueMetrics records queue metrics +func (m *MetricsLogger) LogQueueMetrics(queueDepth int, queueLatency time.Duration) { + logger.Info(m.ctx, "Queue metrics", + "queue_depth", queueDepth, + "queue_latency_ms", queueLatency.Milliseconds(), + "metric_type", "queue", + "timestamp", time.Now().Unix(), + ) +} + +// LogWorkerMetrics records worker metrics +func (m *MetricsLogger) LogWorkerMetrics(activeWorkers, idleWorkers, totalWorkers int, memoryUsage uint64, cpuLoad float64) { + logger.Info(m.ctx, "Worker metrics", + "active_workers", activeWorkers, + "idle_workers", idleWorkers, + "total_workers", totalWorkers, + "memory_usage_mb", memoryUsage/(1024*1024), + "cpu_load_percent", cpuLoad, + "metric_type", "worker", + "timestamp", 
time.Now().Unix(), + ) +} + +// LogRetryMetrics records retry metrics +func (m *MetricsLogger) LogRetryMetrics(taskType TaskType, retryCount int, success bool, delay time.Duration) { + logger.Info(m.ctx, "Retry metrics", + "task_type", taskType, + "retry_count", retryCount, + "retry_success", success, + "retry_delay_ms", delay.Milliseconds(), + "metric_type", "retry", + "timestamp", time.Now().Unix(), + ) +} + +// LogSystemMetrics records system-level metrics (memory, CPU, goroutines) +func (m *MetricsLogger) LogSystemMetrics() { + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + logger.Info(m.ctx, "System metrics", + "metric_type", "system", + "timestamp", time.Now().Unix(), + "goroutines", runtime.NumGoroutine(), + "memory_alloc_mb", memStats.Alloc/(1024*1024), + "memory_total_alloc_mb", memStats.TotalAlloc/(1024*1024), + "memory_sys_mb", memStats.Sys/(1024*1024), + "memory_heap_alloc_mb", memStats.HeapAlloc/(1024*1024), + "memory_heap_sys_mb", memStats.HeapSys/(1024*1024), + "memory_heap_inuse_mb", memStats.HeapInuse/(1024*1024), + "gc_pause_total_ns", memStats.PauseTotalNs, + "num_gc", memStats.NumGC, + ) +} + +// LogTaskCompletionMetrics records detailed task completion metrics +func (m *MetricsLogger) LogTaskCompletionMetrics(taskID, taskType, status string, startTime, endTime time.Time, retryCount int, errorMsg string) { + duration := endTime.Sub(startTime) + + logger.Info(m.ctx, "Task completion metrics", + "metric_type", "task_completion", + "timestamp", time.Now().Unix(), + "task_id", taskID, + "task_type", taskType, + "status", status, + "duration_ms", duration.Milliseconds(), + "start_time", startTime.Unix(), + "end_time", endTime.Unix(), + "retry_count", retryCount, + "has_error", errorMsg != "", + "error_msg", errorMsg, + ) +} + +// LogHealthCheckMetrics records health check metrics +func (m *MetricsLogger) LogHealthCheckMetrics(healthy bool, checkDuration time.Duration, components map[string]bool) { + logger.Info(m.ctx, "Health 
check metrics", + "metric_type", "health_check", + "timestamp", time.Now().Unix(), + "healthy", healthy, + "check_duration_ms", checkDuration.Milliseconds(), + "components", components, + ) +} + +// PeriodicMetricsLogger periodically logs system and worker metrics +type PeriodicMetricsLogger struct { + ctx context.Context + interval time.Duration + stopChan chan struct{} + metricsLog *MetricsLogger +} + +// NewPeriodicMetricsLogger creates a new PeriodicMetricsLogger +func NewPeriodicMetricsLogger(ctx context.Context, interval time.Duration) *PeriodicMetricsLogger { + return &PeriodicMetricsLogger{ + ctx: ctx, + interval: interval, + stopChan: make(chan struct{}), + metricsLog: NewMetricsLogger(ctx), + } +} + +// Start begins periodic metrics logging +func (p *PeriodicMetricsLogger) Start() { + go func() { + ticker := time.NewTicker(p.interval) + defer ticker.Stop() + + for { + select { + case <-p.stopChan: + return + case <-ticker.C: + p.metricsLog.LogSystemMetrics() + } + } + }() +} + +// Stop stops periodic metrics logging +func (p *PeriodicMetricsLogger) Stop() { + close(p.stopChan) +} \ No newline at end of file diff --git a/task/retry_manager.go b/task/retry_manager.go new file mode 100644 index 0000000..ddd7db5 --- /dev/null +++ b/task/retry_manager.go @@ -0,0 +1,219 @@ +// Package task provides retry strategies for failed asynchronous tasks +package task + +import ( + "context" + "math" + "math/rand" + "strings" + "time" + + "modelRT/logger" +) + +// RetryStrategy defines the interface for task retry strategies +type RetryStrategy interface { + // ShouldRetry determines if a task should be retried and returns the delay before next retry + ShouldRetry(ctx context.Context, taskID string, retryCount int, lastError error) (bool, time.Duration) + // GetMaxRetries returns the maximum number of retry attempts + GetMaxRetries() int +} + +// ExponentialBackoffRetry implements exponential backoff with jitter retry strategy +type ExponentialBackoffRetry struct { + 
MaxRetries int + InitialDelay time.Duration + MaxDelay time.Duration + RandomFactor float64 // Jitter factor to avoid thundering herd problem +} + +// NewExponentialBackoffRetry creates a new exponential backoff retry strategy +func NewExponentialBackoffRetry(maxRetries int, initialDelay, maxDelay time.Duration, randomFactor float64) *ExponentialBackoffRetry { + if maxRetries < 0 { + maxRetries = 0 + } + if initialDelay <= 0 { + initialDelay = 1 * time.Second + } + if maxDelay <= 0 { + maxDelay = 5 * time.Minute + } + if randomFactor < 0 { + randomFactor = 0 + } + if randomFactor > 1 { + randomFactor = 1 + } + + return &ExponentialBackoffRetry{ + MaxRetries: maxRetries, + InitialDelay: initialDelay, + MaxDelay: maxDelay, + RandomFactor: randomFactor, + } +} + +// ShouldRetry implements exponential backoff with jitter +func (s *ExponentialBackoffRetry) ShouldRetry(ctx context.Context, taskID string, retryCount int, lastError error) (bool, time.Duration) { + if retryCount >= s.MaxRetries { + logger.Info(ctx, "Task reached maximum retry count", + "task_id", taskID, + "retry_count", retryCount, + "max_retries", s.MaxRetries, + "last_error", lastError, + ) + return false, 0 + } + + // Calculate exponential backoff: initialDelay * 2^retryCount + delay := s.InitialDelay * time.Duration(math.Pow(2, float64(retryCount))) + + // Apply maximum delay cap + if delay > s.MaxDelay { + delay = s.MaxDelay + } + + // Add jitter to avoid thundering herd + if s.RandomFactor > 0 { + jitter := rand.Float64() * s.RandomFactor * float64(delay) + // Randomly add or subtract jitter + if rand.Intn(2) == 0 { + delay += time.Duration(jitter) + } else { + delay -= time.Duration(jitter) + } + // Ensure delay doesn't go below initial delay + if delay < s.InitialDelay { + delay = s.InitialDelay + } + } + + logger.Info(ctx, "Task will be retried", + "task_id", taskID, + "retry_count", retryCount, + "next_retry_in", delay, + "max_retries", s.MaxRetries, + ) + + return true, delay +} + +// 
GetMaxRetries returns the maximum number of retry attempts +func (s *ExponentialBackoffRetry) GetMaxRetries() int { + return s.MaxRetries +} + +// FixedDelayRetry implements fixed delay retry strategy +type FixedDelayRetry struct { + MaxRetries int + Delay time.Duration + RandomFactor float64 +} + +// NewFixedDelayRetry creates a new fixed delay retry strategy +func NewFixedDelayRetry(maxRetries int, delay time.Duration, randomFactor float64) *FixedDelayRetry { + if maxRetries < 0 { + maxRetries = 0 + } + if delay <= 0 { + delay = 5 * time.Second + } + + return &FixedDelayRetry{ + MaxRetries: maxRetries, + Delay: delay, + RandomFactor: randomFactor, + } +} + +// ShouldRetry implements fixed delay with optional jitter +func (s *FixedDelayRetry) ShouldRetry(ctx context.Context, taskID string, retryCount int, lastError error) (bool, time.Duration) { + if retryCount >= s.MaxRetries { + return false, 0 + } + + delay := s.Delay + + // Add jitter if random factor is specified + if s.RandomFactor > 0 { + jitter := rand.Float64() * s.RandomFactor * float64(delay) + if rand.Intn(2) == 0 { + delay += time.Duration(jitter) + } else { + delay -= time.Duration(jitter) + } + // Ensure positive delay + if delay <= 0 { + delay = s.Delay + } + } + + return true, delay +} + +// GetMaxRetries returns the maximum number of retry attempts +func (s *FixedDelayRetry) GetMaxRetries() int { + return s.MaxRetries +} + +// NoRetryStrategy implements a strategy that never retries +type NoRetryStrategy struct{} + +// NewNoRetryStrategy creates a new no-retry strategy +func NewNoRetryStrategy() *NoRetryStrategy { + return &NoRetryStrategy{} +} + +// ShouldRetry always returns false +func (s *NoRetryStrategy) ShouldRetry(ctx context.Context, taskID string, retryCount int, lastError error) (bool, time.Duration) { + return false, 0 +} + +// GetMaxRetries returns 0 +func (s *NoRetryStrategy) GetMaxRetries() int { + return 0 +} + +// DefaultRetryStrategy returns the default retry strategy 
(exponential backoff) +func DefaultRetryStrategy() RetryStrategy { + return NewExponentialBackoffRetry( + 3, // max retries + 1*time.Second, // initial delay + 5*time.Minute, // max delay + 0.1, // random factor (10% jitter) + ) +} + +// IsRetryableError checks if an error is retryable based on common patterns +func IsRetryableError(err error) bool { + if err == nil { + return false + } + + errorMsg := err.Error() + + // Check for transient errors that are typically retryable + retryablePatterns := []string{ + "timeout", + "deadline exceeded", + "temporary", + "busy", + "connection refused", + "connection reset", + "network", + "too many connections", + "resource temporarily unavailable", + "rate limit", + "throttle", + "server unavailable", + "service unavailable", + } + + for _, pattern := range retryablePatterns { + if strings.Contains(strings.ToLower(errorMsg), pattern) { + return true + } + } + + return false +} \ No newline at end of file diff --git a/task/retry_queue.go b/task/retry_queue.go new file mode 100644 index 0000000..c602f66 --- /dev/null +++ b/task/retry_queue.go @@ -0,0 +1,187 @@ +// Package task provides retry queue management for failed asynchronous tasks +package task + +import ( + "context" + "time" + + "modelRT/database" + "modelRT/logger" + + "github.com/gofrs/uuid" + "gorm.io/gorm" +) + +// RetryQueue manages scheduling and execution of task retries +type RetryQueue struct { + db *gorm.DB + producer *QueueProducer + strategy RetryStrategy +} + +// NewRetryQueue creates a new RetryQueue instance +func NewRetryQueue(db *gorm.DB, producer *QueueProducer, strategy RetryStrategy) *RetryQueue { + if strategy == nil { + strategy = DefaultRetryStrategy() + } + + return &RetryQueue{ + db: db, + producer: producer, + strategy: strategy, + } +} + +// ScheduleRetry schedules a failed task for retry based on retry strategy +func (q *RetryQueue) ScheduleRetry(ctx context.Context, taskID uuid.UUID, taskType TaskType, retryCount int, lastError error) 
error { + // Check if task should be retried + shouldRetry, delay := q.strategy.ShouldRetry(ctx, taskID.String(), retryCount, lastError) + if !shouldRetry { + // Mark task as permanently failed + logger.Info(ctx, "Task will not be retried, marking as failed", + "task_id", taskID, + "retry_count", retryCount, + "max_retries", q.strategy.GetMaxRetries(), + "last_error", lastError, + ) + return database.FailAsyncTask(ctx, q.db, taskID, time.Now().Unix()) + } + + // Calculate next retry time + nextRetryTime := time.Now().Add(delay).Unix() + + // Update task retry information in database + err := q.db.Transaction(func(tx *gorm.DB) error { + if err := database.UpdateTaskRetryInfo(ctx, tx, taskID, retryCount+1, nextRetryTime); err != nil { + return err + } + + // Update error information + errorMsg := "" + if lastError != nil { + errorMsg = lastError.Error() + } + if err := database.UpdateTaskErrorInfo(ctx, tx, taskID, errorMsg, ""); err != nil { + // Log but don't fail the whole retry scheduling + logger.Warn(ctx, "Failed to update task error info", + "task_id", taskID, + "error", err, + ) + } + + // Task will be picked up by ProcessRetryQueue when next_retry_time is reached + return nil + }) + + if err != nil { + logger.Error(ctx, "Failed to schedule task retry", + "task_id", taskID, + "task_type", taskType, + "retry_count", retryCount, + "delay", delay, + "error", err, + ) + return err + } + + logger.Info(ctx, "Task scheduled for retry", + "task_id", taskID, + "task_type", taskType, + "retry_count", retryCount+1, + "next_retry_in", delay, + "next_retry_time", time.Unix(nextRetryTime, 0).Format(time.RFC3339), + ) + + return nil +} + +// ProcessRetryQueue processes tasks that are due for retry +func (q *RetryQueue) ProcessRetryQueue(ctx context.Context, batchSize int) error { + // Get tasks due for retry + tasks, err := database.GetTasksForRetry(ctx, q.db, batchSize) + if err != nil { + logger.Error(ctx, "Failed to get tasks for retry", "error", err) + return err + } + + 
if len(tasks) == 0 { + return nil + } + + logger.Info(ctx, "Processing retry queue", + "task_count", len(tasks), + "batch_size", batchSize, + ) + + for _, task := range tasks { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // Publish task to queue for immediate processing + taskType := TaskType(task.TaskType) + if err := q.producer.PublishTask(ctx, task.TaskID, taskType, task.Priority); err != nil { + logger.Error(ctx, "Failed to publish retry task to queue", + "task_id", task.TaskID, + "task_type", taskType, + "error", err, + ) + // Continue with other tasks + continue + } + + // Update task status back to submitted + if err := database.UpdateAsyncTaskStatus(ctx, q.db, task.TaskID, "SUBMITTED"); err != nil { + logger.Warn(ctx, "Failed to update retry task status", + "task_id", task.TaskID, + "error", err, + ) + } + + // Clear next retry time since task is being retried now + if err := database.UpdateTaskRetryInfo(ctx, q.db, task.TaskID, task.RetryCount, 0); err != nil { + logger.Warn(ctx, "Failed to clear next retry time", + "task_id", task.TaskID, + "error", err, + ) + } + + logger.Info(ctx, "Retry task resubmitted", + "task_id", task.TaskID, + "task_type", taskType, + "retry_count", task.RetryCount, + ) + } + } + + return nil +} + +// StartRetryScheduler starts a background goroutine to periodically process retry queue +func (q *RetryQueue) StartRetryScheduler(ctx context.Context, interval time.Duration, batchSize int) { + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + logger.Info(ctx, "Retry scheduler stopping") + return + case <-ticker.C: + if err := q.ProcessRetryQueue(ctx, batchSize); err != nil { + logger.Error(ctx, "Error processing retry queue", "error", err) + } + } + } + }() +} + +// GetRetryStats returns statistics about retry queue +func (q *RetryQueue) GetRetryStats(ctx context.Context) (int, error) { + tasks, err := database.GetTasksForRetry(ctx, q.db, 1000) // 
Large limit to count + if err != nil { + return 0, err + } + return len(tasks), nil +} \ No newline at end of file diff --git a/task/worker.go b/task/worker.go index c33198a..8388f8e 100644 --- a/task/worker.go +++ b/task/worker.go @@ -9,8 +9,10 @@ import ( "time" "modelRT/config" + "modelRT/database" "modelRT/logger" "modelRT/mq" + "modelRT/orm" "github.com/gofrs/uuid" "github.com/panjf2000/ants/v2" @@ -51,6 +53,7 @@ type TaskWorker struct { conn *amqp.Connection ch *amqp.Channel handler TaskHandler + retryQueue *RetryQueue stopChan chan struct{} wg sync.WaitGroup ctx context.Context @@ -60,12 +63,34 @@ type TaskWorker struct { // WorkerMetrics holds metrics for the worker pool type WorkerMetrics struct { - TasksProcessed int64 - TasksFailed int64 + // Task statistics by type + TasksProcessed map[TaskType]int64 + TasksFailed map[TaskType]int64 + TasksSuccess map[TaskType]int64 + ProcessingTime map[TaskType]time.Duration + + // Aggregate counters (maintained for backward compatibility) + TotalProcessed int64 + TotalFailed int64 + TotalSuccess int64 TasksInProgress int32 + + // Queue and latency metrics QueueDepth int + QueueLatency time.Duration + + // Worker resource metrics WorkersActive int WorkersIdle int + MemoryUsage uint64 + CPULoad float64 + + // Time window metrics + LastMinuteRate float64 + Last5MinutesRate float64 + LastHourRate float64 + + // Health and timing LastHealthCheck time.Time mu sync.RWMutex } @@ -133,6 +158,10 @@ func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg ctx: ctxWithCancel, cancel: cancel, metrics: &WorkerMetrics{ + TasksProcessed: make(map[TaskType]int64), + TasksFailed: make(map[TaskType]int64), + TasksSuccess: make(map[TaskType]int64), + ProcessingTime: make(map[TaskType]time.Duration), LastHealthCheck: time.Now(), }, } @@ -232,7 +261,7 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) { logger.Error(ctx, "Failed to unmarshal task message", "error", err) msg.Nack(false, false) // Reject without 
requeue w.metrics.mu.Lock() - w.metrics.TasksFailed++ + w.metrics.TotalFailed++ w.metrics.mu.Unlock() return } @@ -245,7 +274,9 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) { ) msg.Nack(false, false) // Reject without requeue w.metrics.mu.Lock() - w.metrics.TasksFailed++ + w.metrics.TotalFailed++ + // Also update per-task-type failure count + w.metrics.TasksFailed[taskMsg.TaskType]++ w.metrics.mu.Unlock() return } @@ -261,7 +292,8 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) { logger.Error(ctx, "Failed to update task status", "error", err) msg.Nack(false, true) // Reject with requeue w.metrics.mu.Lock() - w.metrics.TasksFailed++ + w.metrics.TotalFailed++ + w.metrics.TasksFailed[taskMsg.TaskType]++ w.metrics.mu.Unlock() return } @@ -287,7 +319,8 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) { // Ack message even if task failed (we don't want to retry indefinitely) msg.Ack(false) w.metrics.mu.Lock() - w.metrics.TasksFailed++ + w.metrics.TotalFailed++ + w.metrics.TasksFailed[taskMsg.TaskType]++ w.metrics.mu.Unlock() return } @@ -308,18 +341,74 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) { ) w.metrics.mu.Lock() - w.metrics.TasksProcessed++ + w.metrics.TotalProcessed++ + w.metrics.TasksProcessed[taskMsg.TaskType]++ + w.metrics.TasksSuccess[taskMsg.TaskType]++ + w.metrics.ProcessingTime[taskMsg.TaskType] += processingTime w.metrics.mu.Unlock() } // updateTaskStatus updates the status of a task in the database func (w *TaskWorker) updateTaskStatus(ctx context.Context, taskID uuid.UUID, status TaskStatus) error { - // This is a simplified version. In a real implementation, you would: - // 1. Have a proper task table/model - // 2. 
Update the task record with status and timestamps + // Convert TaskStatus to orm.AsyncTaskStatus + var ormStatus orm.AsyncTaskStatus + switch status { + case StatusPending: + ormStatus = orm.AsyncTaskStatusSubmitted + case StatusRunning: + ormStatus = orm.AsyncTaskStatusRunning + case StatusCompleted: + ormStatus = orm.AsyncTaskStatusCompleted + case StatusFailed: + ormStatus = orm.AsyncTaskStatusFailed + default: + return fmt.Errorf("unknown task status: %s", status) + } - // For now, we'll log the update - logger.Debug(ctx, "Updating task status", + // Update task status in database + err := database.UpdateAsyncTaskStatus(ctx, w.db, taskID, ormStatus) + if err != nil { + logger.Error(ctx, "Failed to update task status in database", + "task_id", taskID, + "status", status, + "error", err, + ) + return err + } + + // If status is running, update started_at timestamp + if status == StatusRunning { + startedAt := time.Now().Unix() + if err := database.UpdateTaskStarted(ctx, w.db, taskID, startedAt); err != nil { + logger.Warn(ctx, "Failed to update task start time", + "task_id", taskID, + "error", err, + ) + // Continue despite error + } + } + + // If status is completed or failed, update finished_at timestamp + if status == StatusCompleted || status == StatusFailed { + finishedAt := time.Now().Unix() + if status == StatusCompleted { + if err := database.CompleteAsyncTask(ctx, w.db, taskID, finishedAt); err != nil { + logger.Warn(ctx, "Failed to mark task as completed", + "task_id", taskID, + "error", err, + ) + } + } else { + if err := database.FailAsyncTask(ctx, w.db, taskID, finishedAt); err != nil { + logger.Warn(ctx, "Failed to mark task as failed", + "task_id", taskID, + "error", err, + ) + } + } + } + + logger.Debug(ctx, "Task status updated", "task_id", taskID, "status", status, ) @@ -328,10 +417,24 @@ func (w *TaskWorker) updateTaskStatus(ctx context.Context, taskID uuid.UUID, sta // updateTaskWithError updates a task with error information func (w 
*TaskWorker) updateTaskWithError(ctx context.Context, taskID uuid.UUID, err error) error { - logger.Debug(ctx, "Updating task with error", + // Update task error information in database + errorMsg := err.Error() + stackTrace := fmt.Sprintf("%+v", err) + + updateErr := database.UpdateTaskErrorInfo(ctx, w.db, taskID, errorMsg, stackTrace) + if updateErr != nil { + logger.Error(ctx, "Failed to update task error info", + "task_id", taskID, + "error", updateErr, + ) + return updateErr + } + + logger.Warn(ctx, "Task failed with error", "task_id", taskID, - "error", err.Error(), + "error", errorMsg, ) + return nil } @@ -379,12 +482,16 @@ func (w *TaskWorker) checkHealth() { w.metrics.LastHealthCheck = time.Now() logger.Info(w.ctx, "Worker health check", - "tasks_processed", w.metrics.TasksProcessed, - "tasks_failed", w.metrics.TasksFailed, + "tasks_processed", w.metrics.TotalProcessed, + "tasks_failed", w.metrics.TotalFailed, + "tasks_success", w.metrics.TotalSuccess, "tasks_in_progress", w.metrics.TasksInProgress, "queue_depth", w.metrics.QueueDepth, + "queue_latency_ms", w.metrics.QueueLatency.Milliseconds(), "workers_active", w.metrics.WorkersActive, "workers_idle", w.metrics.WorkersIdle, + "memory_usage_mb", w.metrics.MemoryUsage/(1024*1024), + "cpu_load_percent", w.metrics.CPULoad, "pool_capacity", w.pool.Cap(), ) } @@ -418,14 +525,47 @@ func (w *TaskWorker) Stop() error { func (w *TaskWorker) GetMetrics() *WorkerMetrics { w.metrics.mu.RLock() defer w.metrics.mu.RUnlock() + + // Deep copy maps to avoid data races + tasksProcessedCopy := make(map[TaskType]int64) + for k, v := range w.metrics.TasksProcessed { + tasksProcessedCopy[k] = v + } + + tasksFailedCopy := make(map[TaskType]int64) + for k, v := range w.metrics.TasksFailed { + tasksFailedCopy[k] = v + } + + tasksSuccessCopy := make(map[TaskType]int64) + for k, v := range w.metrics.TasksSuccess { + tasksSuccessCopy[k] = v + } + + processingTimeCopy := make(map[TaskType]time.Duration) + for k, v := range 
w.metrics.ProcessingTime { + processingTimeCopy[k] = v + } + // Create a copy without the mutex to avoid copylocks warning return &WorkerMetrics{ - TasksProcessed: w.metrics.TasksProcessed, - TasksFailed: w.metrics.TasksFailed, + TasksProcessed: tasksProcessedCopy, + TasksFailed: tasksFailedCopy, + TasksSuccess: tasksSuccessCopy, + ProcessingTime: processingTimeCopy, + TotalProcessed: w.metrics.TotalProcessed, + TotalFailed: w.metrics.TotalFailed, + TotalSuccess: w.metrics.TotalSuccess, TasksInProgress: w.metrics.TasksInProgress, QueueDepth: w.metrics.QueueDepth, + QueueLatency: w.metrics.QueueLatency, WorkersActive: w.metrics.WorkersActive, WorkersIdle: w.metrics.WorkersIdle, + MemoryUsage: w.metrics.MemoryUsage, + CPULoad: w.metrics.CPULoad, + LastMinuteRate: w.metrics.LastMinuteRate, + Last5MinutesRate: w.metrics.Last5MinutesRate, + LastHourRate: w.metrics.LastHourRate, LastHealthCheck: w.metrics.LastHealthCheck, // Mutex is intentionally omitted } @@ -438,4 +578,34 @@ func (w *TaskWorker) IsHealthy() bool { // Consider unhealthy if last health check was too long ago return time.Since(w.metrics.LastHealthCheck) < 2*w.cfg.PollingInterval -} \ No newline at end of file +} +// RecordMetrics periodically records worker metrics to the logging system +func (w *TaskWorker) RecordMetrics(interval time.Duration) { + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + metricsLogger := NewMetricsLogger(w.ctx) + + for { + select { + case <-w.stopChan: + return + case <-ticker.C: + w.metrics.mu.RLock() + metricsLogger.LogWorkerMetrics( + w.metrics.WorkersActive, + w.metrics.WorkersIdle, + w.pool.Cap(), + w.metrics.MemoryUsage, + w.metrics.CPULoad, + ) + metricsLogger.LogQueueMetrics( + w.metrics.QueueDepth, + w.metrics.QueueLatency, + ) + w.metrics.mu.RUnlock() + } + } + }() +}