Extend async task system with database integration and retry management

- Add AsyncTaskConfig to config structure
  - Create database operations for task state management (async_task_extended.go)
  - Add configuration middleware for Gin context
  - Extract task worker initialization to separate file (initializer.go)
  - Implement retry strategies with exponential backoff (retry_manager.go)
  - Add retry queue for failed task scheduling (retry_queue.go)
  - Enhance worker metrics with detailed per-task-type tracking
  - Integrate database operations into task worker for status updates
  - Add comprehensive metrics logging system
This commit is contained in:
douxu 2026-04-03 10:07:43 +08:00
parent 9e4c35794c
commit f8c0951a13
9 changed files with 1048 additions and 48 deletions

View File

@ -3,6 +3,7 @@ package config
import (
"fmt"
"time"
"github.com/spf13/viper"
)
@ -91,6 +92,16 @@ type DataRTConfig struct {
Method string `mapstructure:"polling_api_method"`
}
// AsyncTaskConfig defines the configuration of the asynchronous task system.
type AsyncTaskConfig struct {
	WorkerPoolSize      int           `mapstructure:"worker_pool_size"`      // size of the task worker goroutine pool
	QueueConsumerCount  int           `mapstructure:"queue_consumer_count"`  // number of concurrent queue consumers
	MaxRetryCount       int           `mapstructure:"max_retry_count"`       // maximum retry attempts for a failed task
	RetryInitialDelay   time.Duration `mapstructure:"retry_initial_delay"`   // first retry delay (grows with backoff)
	RetryMaxDelay       time.Duration `mapstructure:"retry_max_delay"`       // upper bound on the retry delay
	HealthCheckInterval time.Duration `mapstructure:"health_check_interval"` // interval between worker health checks
}
// ModelRTConfig define config struct of model runtime server
type ModelRTConfig struct {
BaseConfig `mapstructure:"base"`
@ -103,6 +114,7 @@ type ModelRTConfig struct {
DataRTConfig `mapstructure:"dataRT"`
LockerRedisConfig RedisConfig `mapstructure:"locker_redis"`
StorageRedisConfig RedisConfig `mapstructure:"storage_redis"`
AsyncTaskConfig AsyncTaskConfig `mapstructure:"async_task"`
PostgresDBURI string `mapstructure:"-"`
}

View File

@ -0,0 +1,227 @@
// Package database define database operation functions
package database
import (
"context"
"time"
"modelRT/orm"
"github.com/gofrs/uuid"
"gorm.io/gorm"
)
// UpdateTaskStarted updates task start time and status to running.
func UpdateTaskStarted(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, startedAt int64) error {
	dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	// Flip the task to running and record when execution began.
	return tx.WithContext(dbCtx).
		Model(&orm.AsyncTask{}).
		Where("task_id = ?", taskID).
		Updates(map[string]any{
			"status":     orm.AsyncTaskStatusRunning,
			"started_at": startedAt,
		}).Error
}
// UpdateTaskRetryInfo updates task retry information.
// A non-positive nextRetryTime clears the schedule (stored as SQL NULL).
func UpdateTaskRetryInfo(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, retryCount int, nextRetryTime int64) error {
	dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	fields := map[string]any{"retry_count": retryCount}
	if nextRetryTime > 0 {
		fields["next_retry_time"] = nextRetryTime
	} else {
		fields["next_retry_time"] = nil
	}
	return tx.WithContext(dbCtx).
		Model(&orm.AsyncTask{}).
		Where("task_id = ?", taskID).
		Updates(fields).Error
}
// UpdateTaskErrorInfo updates task error information (failure reason and
// the captured stack trace, if any).
func UpdateTaskErrorInfo(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, errorMsg, stackTrace string) error {
	dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	return tx.WithContext(dbCtx).
		Model(&orm.AsyncTask{}).
		Where("task_id = ?", taskID).
		Updates(map[string]any{
			"failure_reason": errorMsg,
			"stack_trace":    stackTrace,
		}).Error
}
// UpdateTaskExecutionTime updates task execution time.
func UpdateTaskExecutionTime(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, executionTime int64) error {
	dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	return tx.WithContext(dbCtx).
		Model(&orm.AsyncTask{}).
		Where("task_id = ?", taskID).
		Update("execution_time", executionTime).
		Error
}
// UpdateTaskWorkerID updates the worker ID that is processing the task.
func UpdateTaskWorkerID(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, workerID string) error {
	dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	return tx.WithContext(dbCtx).
		Model(&orm.AsyncTask{}).
		Where("task_id = ?", taskID).
		Update("worker_id", workerID).
		Error
}
// UpdateTaskPriority updates task priority.
func UpdateTaskPriority(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, priority int) error {
	dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	return tx.WithContext(dbCtx).
		Model(&orm.AsyncTask{}).
		Where("task_id = ?", taskID).
		Update("priority", priority).
		Error
}
// UpdateTaskQueueName updates task queue name.
func UpdateTaskQueueName(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, queueName string) error {
	dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	return tx.WithContext(dbCtx).
		Model(&orm.AsyncTask{}).
		Where("task_id = ?", taskID).
		Update("queue_name", queueName).
		Error
}
// UpdateTaskCreatedBy updates task creator information.
func UpdateTaskCreatedBy(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, createdBy string) error {
	dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	return tx.WithContext(dbCtx).
		Model(&orm.AsyncTask{}).
		Where("task_id = ?", taskID).
		Update("created_by", createdBy).
		Error
}
// UpdateTaskResultWithMetrics updates task result with execution metrics.
// memoryUsage and cpuUsage may be nil when those measurements are unavailable.
// Note that this targets the AsyncTaskResult table, not AsyncTask.
func UpdateTaskResultWithMetrics(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, executionTime int64, memoryUsage *int64, cpuUsage *float64, retryCount int, completedAt int64) error {
	dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	fields := map[string]any{
		"execution_time": executionTime,
		"memory_usage":   memoryUsage,
		"cpu_usage":      cpuUsage,
		"retry_count":    retryCount,
		"completed_at":   completedAt,
	}
	return tx.WithContext(dbCtx).
		Model(&orm.AsyncTaskResult{}).
		Where("task_id = ?", taskID).
		Updates(fields).Error
}
// GetTasksForRetry retrieves tasks that are due for retry: failed tasks
// whose next_retry_time is set and has already passed, soonest first.
func GetTasksForRetry(ctx context.Context, tx *gorm.DB, limit int) ([]orm.AsyncTask, error) {
	dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	cutoff := time.Now().Unix()
	var due []orm.AsyncTask
	err := tx.WithContext(dbCtx).
		Where("status = ? AND next_retry_time IS NOT NULL AND next_retry_time <= ?", orm.AsyncTaskStatusFailed, cutoff).
		Order("next_retry_time ASC").
		Limit(limit).
		Find(&due).Error
	if err != nil {
		return nil, err
	}
	return due, nil
}
// GetTasksByPriority retrieves tasks by priority order (highest priority
// first, oldest first within the same priority).
func GetTasksByPriority(ctx context.Context, tx *gorm.DB, status orm.AsyncTaskStatus, limit int) ([]orm.AsyncTask, error) {
	dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	var matched []orm.AsyncTask
	err := tx.WithContext(dbCtx).
		Where("status = ?", status).
		Order("priority DESC, created_at ASC").
		Limit(limit).
		Find(&matched).Error
	if err != nil {
		return nil, err
	}
	return matched, nil
}
// GetTasksByWorkerID retrieves the running tasks currently owned by a
// specific worker.
func GetTasksByWorkerID(ctx context.Context, tx *gorm.DB, workerID string) ([]orm.AsyncTask, error) {
	dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	var owned []orm.AsyncTask
	err := tx.WithContext(dbCtx).
		Where("worker_id = ? AND status = ?", workerID, orm.AsyncTaskStatusRunning).
		Find(&owned).Error
	if err != nil {
		return nil, err
	}
	return owned, nil
}
// CleanupStaleTasks marks tasks as failed if they have been running for too
// long (started more than timeoutSeconds ago). It returns the number of rows
// updated. A longer 30s DB timeout is used since this can touch many rows.
func CleanupStaleTasks(ctx context.Context, tx *gorm.DB, timeoutSeconds int64) (int64, error) {
	dbCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	staleBefore := time.Now().Unix() - timeoutSeconds
	res := tx.WithContext(dbCtx).
		Model(&orm.AsyncTask{}).
		Where("status = ? AND started_at IS NOT NULL AND started_at < ?", orm.AsyncTaskStatusRunning, staleBefore).
		Updates(map[string]any{
			"status":         orm.AsyncTaskStatusFailed,
			"failure_reason": "task timeout",
			"finished_at":    time.Now().Unix(),
		})
	return res.RowsAffected, res.Error
}

30
main.go
View File

@ -71,34 +71,6 @@ var (
//
// @host localhost:8080
// @BasePath /api/v1
// initTaskWorker builds the asynchronous task worker from the application
// configuration: it sizes the worker pool from AsyncTaskConfig, registers the
// default task handlers, and connects the worker to RabbitMQ.
// NOTE(review): this logic also exists as task.InitTaskWorker in
// task/initializer.go (used by main at the call site below); this local copy
// is the one being removed by that refactor.
func initTaskWorker(ctx context.Context, config config.ModelRTConfig, db *gorm.DB) (*task.TaskWorker, error) {
	// Create worker configuration
	workerCfg := task.WorkerConfig{
		PoolSize:           config.AsyncTaskConfig.WorkerPoolSize,
		PreAlloc:           true,
		MaxBlockingTasks:   100,
		QueueConsumerCount: config.AsyncTaskConfig.QueueConsumerCount,
		PollingInterval:    config.AsyncTaskConfig.HealthCheckInterval,
	}
	// Create task handler factory
	handlerFactory := task.NewHandlerFactory()
	handlerFactory.CreateDefaultHandlers()
	handler := task.DefaultCompositeHandler()
	// Create task worker
	worker, err := task.NewTaskWorker(ctx, workerCfg, db, config.RabbitMQConfig, handler)
	if err != nil {
		return nil, fmt.Errorf("failed to create task worker: %w", err)
	}
	logger.Info(ctx, "Task worker initialized",
		"worker_pool_size", workerCfg.PoolSize,
		"queue_consumers", workerCfg.QueueConsumerCount,
	)
	return worker, nil
}
func main() {
flag.Parse()
@ -188,7 +160,7 @@ func main() {
mq.InitRabbitProxy(ctx, modelRTConfig.RabbitMQConfig)
// init async task worker
taskWorker, err := initTaskWorker(ctx, modelRTConfig, postgresDBClient)
taskWorker, err := task.InitTaskWorker(ctx, modelRTConfig, postgresDBClient)
if err != nil {
logger.Error(ctx, "Failed to initialize task worker", "error", err)
// Continue without task worker, but log warning

View File

@ -0,0 +1,15 @@
package middleware
import (
"modelRT/config"
"github.com/gin-gonic/gin"
)
// ConfigMiddleware injects the global modelRT configuration into the Gin
// request context under the key "config", so downstream handlers can fetch
// it via c.Get("config").
func ConfigMiddleware(modelRTConfig config.ModelRTConfig) gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Set("config", modelRTConfig)
		c.Next()
	}
}

41
task/initializer.go Normal file
View File

@ -0,0 +1,41 @@
// Package task provides asynchronous task processing with worker pools
package task
import (
"context"
"fmt"
"modelRT/config"
"modelRT/logger"
"gorm.io/gorm"
)
// InitTaskWorker initializes a task worker with the given configuration and
// database connection: it sizes the worker pool from AsyncTaskConfig,
// registers the default task handlers, and connects the worker to RabbitMQ.
// Returns the started worker or an error if the worker could not be created.
//
// The configuration parameter is named modelRTCfg (not "config") so it does
// not shadow the imported config package.
func InitTaskWorker(ctx context.Context, modelRTCfg config.ModelRTConfig, db *gorm.DB) (*TaskWorker, error) {
	// Build the pool configuration from the async task section of the config.
	workerCfg := WorkerConfig{
		PoolSize:           modelRTCfg.AsyncTaskConfig.WorkerPoolSize,
		PreAlloc:           true,
		MaxBlockingTasks:   100,
		QueueConsumerCount: modelRTCfg.AsyncTaskConfig.QueueConsumerCount,
		PollingInterval:    modelRTCfg.AsyncTaskConfig.HealthCheckInterval,
	}
	// Register the default task handlers behind a single composite handler.
	handlerFactory := NewHandlerFactory()
	handlerFactory.CreateDefaultHandlers()
	handler := DefaultCompositeHandler()
	// Create the worker, wiring it to the database and RabbitMQ.
	worker, err := NewTaskWorker(ctx, workerCfg, db, modelRTCfg.RabbitMQConfig, handler)
	if err != nil {
		return nil, fmt.Errorf("failed to create task worker: %w", err)
	}
	logger.Info(ctx, "Task worker initialized",
		"worker_pool_size", workerCfg.PoolSize,
		"queue_consumers", workerCfg.QueueConsumerCount,
	)
	return worker, nil
}

157
task/metrics_logger.go Normal file
View File

@ -0,0 +1,157 @@
// Package task provides metrics logging for asynchronous task system
package task
import (
"context"
"runtime"
"time"
"modelRT/logger"
)
// MetricsLogger logs task system metrics using the existing logging system.
// NOTE(review): it stores a context.Context in the struct, which Go guidance
// advises against; here it is only used as the logging context for every
// emitted record, but worth revisiting if per-call contexts are needed.
type MetricsLogger struct {
	ctx context.Context // logging context captured at construction time
}

// NewMetricsLogger creates a new MetricsLogger bound to ctx.
func NewMetricsLogger(ctx context.Context) *MetricsLogger {
	return &MetricsLogger{ctx: ctx}
}
// LogTaskMetrics records task processing metrics for a single task:
// its type, final status, wall-clock processing time, and retry count.
func (m *MetricsLogger) LogTaskMetrics(taskType TaskType, status string, processingTime time.Duration, retryCount int) {
	fields := []any{
		"task_type", taskType,
		"status", status,
		"processing_time_ms", processingTime.Milliseconds(),
		"retry_count", retryCount,
		"metric_type", "task_processing",
		"timestamp", time.Now().Unix(),
	}
	logger.Info(m.ctx, "Task metrics", fields...)
}
// LogQueueMetrics records queue metrics: current depth and observed latency.
func (m *MetricsLogger) LogQueueMetrics(queueDepth int, queueLatency time.Duration) {
	fields := []any{
		"queue_depth", queueDepth,
		"queue_latency_ms", queueLatency.Milliseconds(),
		"metric_type", "queue",
		"timestamp", time.Now().Unix(),
	}
	logger.Info(m.ctx, "Queue metrics", fields...)
}
// LogWorkerMetrics records worker pool metrics: active/idle/total worker
// counts plus memory (reported in MiB) and CPU load.
func (m *MetricsLogger) LogWorkerMetrics(activeWorkers, idleWorkers, totalWorkers int, memoryUsage uint64, cpuLoad float64) {
	fields := []any{
		"active_workers", activeWorkers,
		"idle_workers", idleWorkers,
		"total_workers", totalWorkers,
		"memory_usage_mb", memoryUsage / (1024 * 1024),
		"cpu_load_percent", cpuLoad,
		"metric_type", "worker",
		"timestamp", time.Now().Unix(),
	}
	logger.Info(m.ctx, "Worker metrics", fields...)
}
// LogRetryMetrics records retry metrics for one retry decision: the task
// type, attempt number, whether the retry succeeded, and the applied delay.
func (m *MetricsLogger) LogRetryMetrics(taskType TaskType, retryCount int, success bool, delay time.Duration) {
	fields := []any{
		"task_type", taskType,
		"retry_count", retryCount,
		"retry_success", success,
		"retry_delay_ms", delay.Milliseconds(),
		"metric_type", "retry",
		"timestamp", time.Now().Unix(),
	}
	logger.Info(m.ctx, "Retry metrics", fields...)
}
// LogSystemMetrics records system-level metrics (memory, GC, goroutines)
// sampled from the Go runtime at the moment of the call. Memory figures are
// reported in MiB; GC pause totals in nanoseconds.
func (m *MetricsLogger) LogSystemMetrics() {
	var memStats runtime.MemStats
	// ReadMemStats snapshots the runtime allocator/GC counters.
	runtime.ReadMemStats(&memStats)
	logger.Info(m.ctx, "System metrics",
		"metric_type", "system",
		"timestamp", time.Now().Unix(),
		"goroutines", runtime.NumGoroutine(),
		"memory_alloc_mb", memStats.Alloc/(1024*1024),
		"memory_total_alloc_mb", memStats.TotalAlloc/(1024*1024),
		"memory_sys_mb", memStats.Sys/(1024*1024),
		"memory_heap_alloc_mb", memStats.HeapAlloc/(1024*1024),
		"memory_heap_sys_mb", memStats.HeapSys/(1024*1024),
		"memory_heap_inuse_mb", memStats.HeapInuse/(1024*1024),
		"gc_pause_total_ns", memStats.PauseTotalNs,
		"num_gc", memStats.NumGC,
	)
}
// LogTaskCompletionMetrics records detailed completion metrics for one task:
// identity, final status, start/end timestamps, computed duration, retry
// count, and the error message (empty when the task succeeded).
func (m *MetricsLogger) LogTaskCompletionMetrics(taskID, taskType, status string, startTime, endTime time.Time, retryCount int, errorMsg string) {
	elapsed := endTime.Sub(startTime)
	fields := []any{
		"metric_type", "task_completion",
		"timestamp", time.Now().Unix(),
		"task_id", taskID,
		"task_type", taskType,
		"status", status,
		"duration_ms", elapsed.Milliseconds(),
		"start_time", startTime.Unix(),
		"end_time", endTime.Unix(),
		"retry_count", retryCount,
		"has_error", errorMsg != "",
		"error_msg", errorMsg,
	}
	logger.Info(m.ctx, "Task completion metrics", fields...)
}
// LogHealthCheckMetrics records health check metrics: the overall verdict,
// how long the check took, and a per-component health map.
func (m *MetricsLogger) LogHealthCheckMetrics(healthy bool, checkDuration time.Duration, components map[string]bool) {
	fields := []any{
		"metric_type", "health_check",
		"timestamp", time.Now().Unix(),
		"healthy", healthy,
		"check_duration_ms", checkDuration.Milliseconds(),
		"components", components,
	}
	logger.Info(m.ctx, "Health check metrics", fields...)
}
// PeriodicMetricsLogger periodically logs system and worker metrics.
type PeriodicMetricsLogger struct {
	ctx        context.Context // logging context; its cancellation also stops the loop
	interval   time.Duration   // delay between metric snapshots
	stopChan   chan struct{}   // closed by Stop to end the logging goroutine
	metricsLog *MetricsLogger  // sink used to emit the snapshots
}

// NewPeriodicMetricsLogger creates a new PeriodicMetricsLogger that will emit
// a system-metrics record every interval once Start is called.
func NewPeriodicMetricsLogger(ctx context.Context, interval time.Duration) *PeriodicMetricsLogger {
	return &PeriodicMetricsLogger{
		ctx:        ctx,
		interval:   interval,
		stopChan:   make(chan struct{}),
		metricsLog: NewMetricsLogger(ctx),
	}
}

// Start begins periodic metrics logging in a background goroutine. The
// goroutine exits when Stop is called or when the context passed to
// NewPeriodicMetricsLogger is cancelled (the original honored only Stop,
// leaking the goroutine on context cancellation).
func (p *PeriodicMetricsLogger) Start() {
	go func() {
		ticker := time.NewTicker(p.interval)
		defer ticker.Stop()
		for {
			select {
			case <-p.stopChan:
				return
			case <-p.ctx.Done():
				return
			case <-ticker.C:
				p.metricsLog.LogSystemMetrics()
			}
		}
	}()
}

// Stop stops periodic metrics logging. It must be called at most once: a
// second call would panic on closing an already-closed channel.
func (p *PeriodicMetricsLogger) Stop() {
	close(p.stopChan)
}

219
task/retry_manager.go Normal file
View File

@ -0,0 +1,219 @@
// Package task provides retry strategies for failed asynchronous tasks
package task
import (
"context"
"math"
"math/rand"
"strings"
"time"
"modelRT/logger"
)
// RetryStrategy defines the interface for task retry strategies.
type RetryStrategy interface {
	// ShouldRetry determines if a task should be retried and, when it should,
	// returns the delay to wait before the next attempt. When the first
	// result is false the duration is meaningless and callers should treat
	// the task as permanently failed.
	ShouldRetry(ctx context.Context, taskID string, retryCount int, lastError error) (bool, time.Duration)
	// GetMaxRetries returns the maximum number of retry attempts.
	GetMaxRetries() int
}
// ExponentialBackoffRetry implements exponential backoff with jitter retry strategy
type ExponentialBackoffRetry struct {
MaxRetries int
InitialDelay time.Duration
MaxDelay time.Duration
RandomFactor float64 // Jitter factor to avoid thundering herd problem
}
// NewExponentialBackoffRetry creates a new exponential backoff retry strategy
func NewExponentialBackoffRetry(maxRetries int, initialDelay, maxDelay time.Duration, randomFactor float64) *ExponentialBackoffRetry {
if maxRetries < 0 {
maxRetries = 0
}
if initialDelay <= 0 {
initialDelay = 1 * time.Second
}
if maxDelay <= 0 {
maxDelay = 5 * time.Minute
}
if randomFactor < 0 {
randomFactor = 0
}
if randomFactor > 1 {
randomFactor = 1
}
return &ExponentialBackoffRetry{
MaxRetries: maxRetries,
InitialDelay: initialDelay,
MaxDelay: maxDelay,
RandomFactor: randomFactor,
}
}
// ShouldRetry implements exponential backoff with jitter: the delay is
// InitialDelay * 2^retryCount, capped at MaxDelay, with up to RandomFactor of
// symmetric jitter applied, and never less than InitialDelay. Returns false
// (no further retries) once retryCount reaches MaxRetries.
func (s *ExponentialBackoffRetry) ShouldRetry(ctx context.Context, taskID string, retryCount int, lastError error) (bool, time.Duration) {
	if retryCount >= s.MaxRetries {
		logger.Info(ctx, "Task reached maximum retry count",
			"task_id", taskID,
			"retry_count", retryCount,
			"max_retries", s.MaxRetries,
			"last_error", lastError,
		)
		return false, 0
	}
	// Calculate exponential backoff: initialDelay * 2^retryCount
	delay := s.InitialDelay * time.Duration(math.Pow(2, float64(retryCount)))
	// Apply maximum delay cap
	if delay > s.MaxDelay {
		delay = s.MaxDelay
	}
	// Add jitter to avoid thundering herd
	if s.RandomFactor > 0 {
		jitter := rand.Float64() * s.RandomFactor * float64(delay)
		// Randomly add or subtract jitter
		if rand.Intn(2) == 0 {
			delay += time.Duration(jitter)
		} else {
			delay -= time.Duration(jitter)
		}
		// Ensure delay doesn't go below initial delay
		if delay < s.InitialDelay {
			delay = s.InitialDelay
		}
	}
	logger.Info(ctx, "Task will be retried",
		"task_id", taskID,
		"retry_count", retryCount,
		"next_retry_in", delay,
		"max_retries", s.MaxRetries,
	)
	return true, delay
}

// GetMaxRetries returns the maximum number of retry attempts.
func (s *ExponentialBackoffRetry) GetMaxRetries() int {
	return s.MaxRetries
}
// FixedDelayRetry implements fixed delay retry strategy
type FixedDelayRetry struct {
MaxRetries int
Delay time.Duration
RandomFactor float64
}
// NewFixedDelayRetry creates a new fixed delay retry strategy
func NewFixedDelayRetry(maxRetries int, delay time.Duration, randomFactor float64) *FixedDelayRetry {
if maxRetries < 0 {
maxRetries = 0
}
if delay <= 0 {
delay = 5 * time.Second
}
return &FixedDelayRetry{
MaxRetries: maxRetries,
Delay: delay,
RandomFactor: randomFactor,
}
}
// ShouldRetry implements fixed delay with optional jitter
func (s *FixedDelayRetry) ShouldRetry(ctx context.Context, taskID string, retryCount int, lastError error) (bool, time.Duration) {
if retryCount >= s.MaxRetries {
return false, 0
}
delay := s.Delay
// Add jitter if random factor is specified
if s.RandomFactor > 0 {
jitter := rand.Float64() * s.RandomFactor * float64(delay)
if rand.Intn(2) == 0 {
delay += time.Duration(jitter)
} else {
delay -= time.Duration(jitter)
}
// Ensure positive delay
if delay <= 0 {
delay = s.Delay
}
}
return true, delay
}
// GetMaxRetries returns the maximum number of retry attempts
func (s *FixedDelayRetry) GetMaxRetries() int {
return s.MaxRetries
}
// NoRetryStrategy implements a strategy that never retries
type NoRetryStrategy struct{}
// NewNoRetryStrategy creates a new no-retry strategy
func NewNoRetryStrategy() *NoRetryStrategy {
return &NoRetryStrategy{}
}
// ShouldRetry always returns false
func (s *NoRetryStrategy) ShouldRetry(ctx context.Context, taskID string, retryCount int, lastError error) (bool, time.Duration) {
return false, 0
}
// GetMaxRetries returns 0
func (s *NoRetryStrategy) GetMaxRetries() int {
return 0
}
// DefaultRetryStrategy returns the default retry strategy: exponential
// backoff with 3 retries, a 1s initial delay, a 5min cap, and 10% jitter.
func DefaultRetryStrategy() RetryStrategy {
	const (
		defaultMaxRetries   = 3
		defaultRandomFactor = 0.1 // 10% jitter
	)
	return NewExponentialBackoffRetry(defaultMaxRetries, 1*time.Second, 5*time.Minute, defaultRandomFactor)
}
// IsRetryableError checks if an error is retryable based on common patterns
func IsRetryableError(err error) bool {
if err == nil {
return false
}
errorMsg := err.Error()
// Check for transient errors that are typically retryable
retryablePatterns := []string{
"timeout",
"deadline exceeded",
"temporary",
"busy",
"connection refused",
"connection reset",
"network",
"too many connections",
"resource temporarily unavailable",
"rate limit",
"throttle",
"server unavailable",
"service unavailable",
}
for _, pattern := range retryablePatterns {
if strings.Contains(strings.ToLower(errorMsg), pattern) {
return true
}
}
return false
}

187
task/retry_queue.go Normal file
View File

@ -0,0 +1,187 @@
// Package task provides retry queue management for failed asynchronous tasks
package task
import (
"context"
"time"
"modelRT/database"
"modelRT/logger"
"github.com/gofrs/uuid"
"gorm.io/gorm"
)
// RetryQueue manages scheduling and execution of task retries: it persists
// retry bookkeeping through the database and re-publishes due tasks via the
// queue producer.
type RetryQueue struct {
	db       *gorm.DB
	producer *QueueProducer
	strategy RetryStrategy
}

// NewRetryQueue creates a new RetryQueue instance. A nil strategy selects
// DefaultRetryStrategy (exponential backoff with jitter).
func NewRetryQueue(db *gorm.DB, producer *QueueProducer, strategy RetryStrategy) *RetryQueue {
	q := &RetryQueue{db: db, producer: producer, strategy: strategy}
	if q.strategy == nil {
		q.strategy = DefaultRetryStrategy()
	}
	return q
}
// ScheduleRetry schedules a failed task for retry based on the configured
// retry strategy. If the strategy declines (retry budget exhausted) the task
// is marked permanently failed. Otherwise the retry count and next_retry_time
// are persisted in one transaction; ProcessRetryQueue later picks the task up
// once next_retry_time passes.
func (q *RetryQueue) ScheduleRetry(ctx context.Context, taskID uuid.UUID, taskType TaskType, retryCount int, lastError error) error {
	// Check if task should be retried
	shouldRetry, delay := q.strategy.ShouldRetry(ctx, taskID.String(), retryCount, lastError)
	if !shouldRetry {
		// Mark task as permanently failed
		logger.Info(ctx, "Task will not be retried, marking as failed",
			"task_id", taskID,
			"retry_count", retryCount,
			"max_retries", q.strategy.GetMaxRetries(),
			"last_error", lastError,
		)
		return database.FailAsyncTask(ctx, q.db, taskID, time.Now().Unix())
	}
	// Calculate next retry time
	nextRetryTime := time.Now().Add(delay).Unix()
	// Update task retry information in database
	err := q.db.Transaction(func(tx *gorm.DB) error {
		// Bump the retry count and record when the retry becomes due.
		if err := database.UpdateTaskRetryInfo(ctx, tx, taskID, retryCount+1, nextRetryTime); err != nil {
			return err
		}
		// Update error information
		errorMsg := ""
		if lastError != nil {
			errorMsg = lastError.Error()
		}
		if err := database.UpdateTaskErrorInfo(ctx, tx, taskID, errorMsg, ""); err != nil {
			// Log but don't fail the whole retry scheduling
			logger.Warn(ctx, "Failed to update task error info",
				"task_id", taskID,
				"error", err,
			)
		}
		// Task will be picked up by ProcessRetryQueue when next_retry_time is reached
		return nil
	})
	if err != nil {
		logger.Error(ctx, "Failed to schedule task retry",
			"task_id", taskID,
			"task_type", taskType,
			"retry_count", retryCount,
			"delay", delay,
			"error", err,
		)
		return err
	}
	logger.Info(ctx, "Task scheduled for retry",
		"task_id", taskID,
		"task_type", taskType,
		"retry_count", retryCount+1,
		"next_retry_in", delay,
		"next_retry_time", time.Unix(nextRetryTime, 0).Format(time.RFC3339),
	)
	return nil
}
// ProcessRetryQueue processes tasks that are due for retry: each due task is
// re-published to the work queue, flipped back to SUBMITTED, and has its
// next_retry_time cleared. Failures on individual tasks are logged and the
// loop continues with the remaining tasks; only fetch errors or context
// cancellation abort the batch.
func (q *RetryQueue) ProcessRetryQueue(ctx context.Context, batchSize int) error {
	// Get tasks due for retry
	tasks, err := database.GetTasksForRetry(ctx, q.db, batchSize)
	if err != nil {
		logger.Error(ctx, "Failed to get tasks for retry", "error", err)
		return err
	}
	if len(tasks) == 0 {
		return nil
	}
	logger.Info(ctx, "Processing retry queue",
		"task_count", len(tasks),
		"batch_size", batchSize,
	)
	// Loop variable is named "t" — the original used "task", which shadowed
	// the enclosing package name.
	for _, t := range tasks {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
			// Publish task to queue for immediate processing
			taskType := TaskType(t.TaskType)
			if err := q.producer.PublishTask(ctx, t.TaskID, taskType, t.Priority); err != nil {
				logger.Error(ctx, "Failed to publish retry task to queue",
					"task_id", t.TaskID,
					"task_type", taskType,
					"error", err,
				)
				// Continue with other tasks
				continue
			}
			// Update task status back to submitted
			if err := database.UpdateAsyncTaskStatus(ctx, q.db, t.TaskID, "SUBMITTED"); err != nil {
				logger.Warn(ctx, "Failed to update retry task status",
					"task_id", t.TaskID,
					"error", err,
				)
			}
			// Clear next retry time since task is being retried now
			if err := database.UpdateTaskRetryInfo(ctx, q.db, t.TaskID, t.RetryCount, 0); err != nil {
				logger.Warn(ctx, "Failed to clear next retry time",
					"task_id", t.TaskID,
					"error", err,
				)
			}
			logger.Info(ctx, "Retry task resubmitted",
				"task_id", t.TaskID,
				"task_type", taskType,
				"retry_count", t.RetryCount,
			)
		}
	}
	return nil
}
// StartRetryScheduler starts a background goroutine that processes the retry
// queue every interval, handing at most batchSize tasks to ProcessRetryQueue
// per tick. The goroutine exits when ctx is cancelled.
func (q *RetryQueue) StartRetryScheduler(ctx context.Context, interval time.Duration, batchSize int) {
	go func() {
		tick := time.NewTicker(interval)
		defer tick.Stop()
		for {
			select {
			case <-tick.C:
				if err := q.ProcessRetryQueue(ctx, batchSize); err != nil {
					logger.Error(ctx, "Error processing retry queue", "error", err)
				}
			case <-ctx.Done():
				logger.Info(ctx, "Retry scheduler stopping")
				return
			}
		}
	}()
}
// GetRetryStats returns statistics about the retry queue — currently the
// number of tasks due for retry.
// NOTE(review): this fetches up to 1000 full rows just to count them, so
// counts above 1000 are truncated and the query is heavier than needed; a
// dedicated COUNT query would be cheaper — verify before relying on this for
// monitoring dashboards.
func (q *RetryQueue) GetRetryStats(ctx context.Context) (int, error) {
	tasks, err := database.GetTasksForRetry(ctx, q.db, 1000) // Large limit to count
	if err != nil {
		return 0, err
	}
	return len(tasks), nil
}

View File

@ -9,8 +9,10 @@ import (
"time"
"modelRT/config"
"modelRT/database"
"modelRT/logger"
"modelRT/mq"
"modelRT/orm"
"github.com/gofrs/uuid"
"github.com/panjf2000/ants/v2"
@ -51,6 +53,7 @@ type TaskWorker struct {
conn *amqp.Connection
ch *amqp.Channel
handler TaskHandler
retryQueue *RetryQueue
stopChan chan struct{}
wg sync.WaitGroup
ctx context.Context
@ -60,12 +63,34 @@ type TaskWorker struct {
// WorkerMetrics holds metrics for the worker pool
type WorkerMetrics struct {
TasksProcessed int64
TasksFailed int64
// Task statistics by type
TasksProcessed map[TaskType]int64
TasksFailed map[TaskType]int64
TasksSuccess map[TaskType]int64
ProcessingTime map[TaskType]time.Duration
// Aggregate counters (maintained for backward compatibility)
TotalProcessed int64
TotalFailed int64
TotalSuccess int64
TasksInProgress int32
// Queue and latency metrics
QueueDepth int
QueueLatency time.Duration
// Worker resource metrics
WorkersActive int
WorkersIdle int
MemoryUsage uint64
CPULoad float64
// Time window metrics
LastMinuteRate float64
Last5MinutesRate float64
LastHourRate float64
// Health and timing
LastHealthCheck time.Time
mu sync.RWMutex
}
@ -133,6 +158,10 @@ func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg
ctx: ctxWithCancel,
cancel: cancel,
metrics: &WorkerMetrics{
TasksProcessed: make(map[TaskType]int64),
TasksFailed: make(map[TaskType]int64),
TasksSuccess: make(map[TaskType]int64),
ProcessingTime: make(map[TaskType]time.Duration),
LastHealthCheck: time.Now(),
},
}
@ -232,7 +261,7 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) {
logger.Error(ctx, "Failed to unmarshal task message", "error", err)
msg.Nack(false, false) // Reject without requeue
w.metrics.mu.Lock()
w.metrics.TasksFailed++
w.metrics.TotalFailed++
w.metrics.mu.Unlock()
return
}
@ -245,7 +274,9 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) {
)
msg.Nack(false, false) // Reject without requeue
w.metrics.mu.Lock()
w.metrics.TasksFailed++
w.metrics.TotalFailed++
// Also update per-task-type failure count
w.metrics.TasksFailed[taskMsg.TaskType]++
w.metrics.mu.Unlock()
return
}
@ -261,7 +292,8 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) {
logger.Error(ctx, "Failed to update task status", "error", err)
msg.Nack(false, true) // Reject with requeue
w.metrics.mu.Lock()
w.metrics.TasksFailed++
w.metrics.TotalFailed++
w.metrics.TasksFailed[taskMsg.TaskType]++
w.metrics.mu.Unlock()
return
}
@ -287,7 +319,8 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) {
// Ack message even if task failed (we don't want to retry indefinitely)
msg.Ack(false)
w.metrics.mu.Lock()
w.metrics.TasksFailed++
w.metrics.TotalFailed++
w.metrics.TasksFailed[taskMsg.TaskType]++
w.metrics.mu.Unlock()
return
}
@ -308,18 +341,74 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) {
)
w.metrics.mu.Lock()
w.metrics.TasksProcessed++
w.metrics.TotalProcessed++
w.metrics.TasksProcessed[taskMsg.TaskType]++
w.metrics.TasksSuccess[taskMsg.TaskType]++
w.metrics.ProcessingTime[taskMsg.TaskType] += processingTime
w.metrics.mu.Unlock()
}
// updateTaskStatus updates the status of a task in the database
func (w *TaskWorker) updateTaskStatus(ctx context.Context, taskID uuid.UUID, status TaskStatus) error {
// This is a simplified version. In a real implementation, you would:
// 1. Have a proper task table/model
// 2. Update the task record with status and timestamps
// Convert TaskStatus to orm.AsyncTaskStatus
var ormStatus orm.AsyncTaskStatus
switch status {
case StatusPending:
ormStatus = orm.AsyncTaskStatusSubmitted
case StatusRunning:
ormStatus = orm.AsyncTaskStatusRunning
case StatusCompleted:
ormStatus = orm.AsyncTaskStatusCompleted
case StatusFailed:
ormStatus = orm.AsyncTaskStatusFailed
default:
return fmt.Errorf("unknown task status: %s", status)
}
// For now, we'll log the update
logger.Debug(ctx, "Updating task status",
// Update task status in database
err := database.UpdateAsyncTaskStatus(ctx, w.db, taskID, ormStatus)
if err != nil {
logger.Error(ctx, "Failed to update task status in database",
"task_id", taskID,
"status", status,
"error", err,
)
return err
}
// If status is running, update started_at timestamp
if status == StatusRunning {
startedAt := time.Now().Unix()
if err := database.UpdateTaskStarted(ctx, w.db, taskID, startedAt); err != nil {
logger.Warn(ctx, "Failed to update task start time",
"task_id", taskID,
"error", err,
)
// Continue despite error
}
}
// If status is completed or failed, update finished_at timestamp
if status == StatusCompleted || status == StatusFailed {
finishedAt := time.Now().Unix()
if status == StatusCompleted {
if err := database.CompleteAsyncTask(ctx, w.db, taskID, finishedAt); err != nil {
logger.Warn(ctx, "Failed to mark task as completed",
"task_id", taskID,
"error", err,
)
}
} else {
if err := database.FailAsyncTask(ctx, w.db, taskID, finishedAt); err != nil {
logger.Warn(ctx, "Failed to mark task as failed",
"task_id", taskID,
"error", err,
)
}
}
}
logger.Debug(ctx, "Task status updated",
"task_id", taskID,
"status", status,
)
@ -328,10 +417,24 @@ func (w *TaskWorker) updateTaskStatus(ctx context.Context, taskID uuid.UUID, sta
// updateTaskWithError updates a task with error information
func (w *TaskWorker) updateTaskWithError(ctx context.Context, taskID uuid.UUID, err error) error {
logger.Debug(ctx, "Updating task with error",
// Update task error information in database
errorMsg := err.Error()
stackTrace := fmt.Sprintf("%+v", err)
updateErr := database.UpdateTaskErrorInfo(ctx, w.db, taskID, errorMsg, stackTrace)
if updateErr != nil {
logger.Error(ctx, "Failed to update task error info",
"task_id", taskID,
"error", updateErr,
)
return updateErr
}
logger.Warn(ctx, "Task failed with error",
"task_id", taskID,
"error", err.Error(),
"error", errorMsg,
)
return nil
}
@ -379,12 +482,16 @@ func (w *TaskWorker) checkHealth() {
w.metrics.LastHealthCheck = time.Now()
logger.Info(w.ctx, "Worker health check",
"tasks_processed", w.metrics.TasksProcessed,
"tasks_failed", w.metrics.TasksFailed,
"tasks_processed", w.metrics.TotalProcessed,
"tasks_failed", w.metrics.TotalFailed,
"tasks_success", w.metrics.TotalSuccess,
"tasks_in_progress", w.metrics.TasksInProgress,
"queue_depth", w.metrics.QueueDepth,
"queue_latency_ms", w.metrics.QueueLatency.Milliseconds(),
"workers_active", w.metrics.WorkersActive,
"workers_idle", w.metrics.WorkersIdle,
"memory_usage_mb", w.metrics.MemoryUsage/(1024*1024),
"cpu_load_percent", w.metrics.CPULoad,
"pool_capacity", w.pool.Cap(),
)
}
@ -418,14 +525,47 @@ func (w *TaskWorker) Stop() error {
func (w *TaskWorker) GetMetrics() *WorkerMetrics {
w.metrics.mu.RLock()
defer w.metrics.mu.RUnlock()
// Deep copy maps to avoid data races
tasksProcessedCopy := make(map[TaskType]int64)
for k, v := range w.metrics.TasksProcessed {
tasksProcessedCopy[k] = v
}
tasksFailedCopy := make(map[TaskType]int64)
for k, v := range w.metrics.TasksFailed {
tasksFailedCopy[k] = v
}
tasksSuccessCopy := make(map[TaskType]int64)
for k, v := range w.metrics.TasksSuccess {
tasksSuccessCopy[k] = v
}
processingTimeCopy := make(map[TaskType]time.Duration)
for k, v := range w.metrics.ProcessingTime {
processingTimeCopy[k] = v
}
// Create a copy without the mutex to avoid copylocks warning
return &WorkerMetrics{
TasksProcessed: w.metrics.TasksProcessed,
TasksFailed: w.metrics.TasksFailed,
TasksProcessed: tasksProcessedCopy,
TasksFailed: tasksFailedCopy,
TasksSuccess: tasksSuccessCopy,
ProcessingTime: processingTimeCopy,
TotalProcessed: w.metrics.TotalProcessed,
TotalFailed: w.metrics.TotalFailed,
TotalSuccess: w.metrics.TotalSuccess,
TasksInProgress: w.metrics.TasksInProgress,
QueueDepth: w.metrics.QueueDepth,
QueueLatency: w.metrics.QueueLatency,
WorkersActive: w.metrics.WorkersActive,
WorkersIdle: w.metrics.WorkersIdle,
MemoryUsage: w.metrics.MemoryUsage,
CPULoad: w.metrics.CPULoad,
LastMinuteRate: w.metrics.LastMinuteRate,
Last5MinutesRate: w.metrics.Last5MinutesRate,
LastHourRate: w.metrics.LastHourRate,
LastHealthCheck: w.metrics.LastHealthCheck,
// Mutex is intentionally omitted
}
@ -438,4 +578,34 @@ func (w *TaskWorker) IsHealthy() bool {
// Consider unhealthy if last health check was too long ago
return time.Since(w.metrics.LastHealthCheck) < 2*w.cfg.PollingInterval
}
}
// RecordMetrics periodically records worker metrics to the logging system.
// A background goroutine emits a worker and a queue snapshot every interval
// until the worker's stop channel is closed.
func (w *TaskWorker) RecordMetrics(interval time.Duration) {
	go func() {
		ml := NewMetricsLogger(w.ctx)
		tick := time.NewTicker(interval)
		defer tick.Stop()
		for {
			select {
			case <-tick.C:
				// Snapshot under the read lock so counters stay consistent.
				w.metrics.mu.RLock()
				ml.LogWorkerMetrics(
					w.metrics.WorkersActive,
					w.metrics.WorkersIdle,
					w.pool.Cap(),
					w.metrics.MemoryUsage,
					w.metrics.CPULoad,
				)
				ml.LogQueueMetrics(w.metrics.QueueDepth, w.metrics.QueueLatency)
				w.metrics.mu.RUnlock()
			case <-w.stopChan:
				return
			}
		}
	}()
}