From 33f7d758e55eea18ca28cc68f24ea2aa5e4e209a Mon Sep 17 00:00:00 2001 From: douxu Date: Mon, 27 Apr 2026 17:55:38 +0800 Subject: [PATCH] refactor: overhaul async task handler routing and fix data consistency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - fix params lost in RabbitMQ transit by threading them through PublishTask/PublishTaskWithRetry - fix UpdateTaskErrorInfo not setting status=FAILED on async_task - fix UpdateAsyncTaskResultWithError silently skipping when no result row exists (UPDATE → upsert) - sync task failure to async_task_result in updateTaskWithError - remove taskType from AsyncTaskHandler.Execute interface; rename TaskHandler → AsyncTaskHandler - replace CompositeHandler with direct factory.GetHandler dispatch via worker.dispatch() - use constructors (NewXxxHandler) for handler registration instead of zero-value literals - consolidate TaskType/TaskStatus/UnifiedTaskType into task/types.go; delete types_v2.go - extract BaseTask/TaskParams into task/base_task.go --- database/async_task_extended.go | 1 + database/async_task_operations.go | 16 +-- task/base_task.go | 61 +++++++++++ task/handler_factory.go | 80 +++++--------- task/initializer.go | 6 +- task/queue_message.go | 20 +--- task/queue_producer.go | 9 +- task/retry_queue.go | 2 +- task/test_task.go | 19 ++-- task/types.go | 60 ++++------- task/types_v2.go | 138 ------------------------ task/worker.go | 171 ++++++++++++++++-------------- 12 files changed, 226 insertions(+), 357 deletions(-) create mode 100644 task/base_task.go delete mode 100644 task/types_v2.go diff --git a/database/async_task_extended.go b/database/async_task_extended.go index ca94b42..8d5849c 100644 --- a/database/async_task_extended.go +++ b/database/async_task_extended.go @@ -60,6 +60,7 @@ func UpdateTaskErrorInfo(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, err Updates(map[string]any{ "failure_reason": errorMsg, "stack_trace": stackTrace, + "status": orm.AsyncTaskStatusFailed, }) return result.Error diff --git a/database/async_task_operations.go b/database/async_task_operations.go index 991e77b..fce150b 100644 --- a/database/async_task_operations.go +++ b/database/async_task_operations.go @@ -161,14 +161,18 @@ func CreateAsyncTaskResult(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, r return resultOp.Error } -// UpdateAsyncTaskResultWithError updates a task result with error information +// UpdateAsyncTaskResultWithError upserts a task result with error information. func UpdateAsyncTaskResultWithError(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, code int, message string, detail orm.JSONMap) error { - // ctx timeout judgment cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - // Update with error information - result := tx.WithContext(cancelCtx). + if err := tx.WithContext(cancelCtx). + Where("task_id = ?", taskID). + FirstOrCreate(&orm.AsyncTaskResult{TaskID: taskID}).Error; err != nil { + return err + } + + return tx.WithContext(cancelCtx). Model(&orm.AsyncTaskResult{}). Where("task_id = ?", taskID). Updates(map[string]any{ @@ -176,9 +180,7 @@ func UpdateAsyncTaskResultWithError(ctx context.Context, tx *gorm.DB, taskID uui "error_message": message, "error_detail": detail, "result": nil, - }) - - return result.Error + }).Error } // UpdateAsyncTaskResultWithSuccess updates a task result with success information diff --git a/task/base_task.go b/task/base_task.go new file mode 100644 index 0000000..271d043 --- /dev/null +++ b/task/base_task.go @@ -0,0 +1,61 @@ +// Package task provides unified task type definitions and interfaces +package task + +import ( + "context" + "fmt" + + "github.com/gofrs/uuid" + "gorm.io/gorm" +) + +// TaskParams defines the interface for task-specific parameters +type TaskParams interface { + Validate() error + GetType() UnifiedTaskType + ToMap() map[string]interface{} + FromMap(params map[string]interface{}) error +} + +// BaseTask provides common functionality for all task implementations +type BaseTask struct { + taskType UnifiedTaskType + params TaskParams + name string +} + +// NewBaseTask creates a new BaseTask instance +func NewBaseTask(taskType UnifiedTaskType, params TaskParams, name string) *BaseTask { + return &BaseTask{ + taskType: taskType, + params: params, + name: name, + } +} + +func (t *BaseTask) GetType() UnifiedTaskType { + return t.taskType +} + +func (t *BaseTask) GetParams() TaskParams { + return t.params +} + +func (t *BaseTask) GetName() string { + return t.name +} + +func (t *BaseTask) Validate() error { + if t.params == nil { + return fmt.Errorf("task parameters cannot be nil") + } + if t.taskType != t.params.GetType() { + return fmt.Errorf("task type mismatch: expected %s, got %s", t.taskType, t.params.GetType()) + } + return t.params.Validate() +} + +// Execute is a placeholder; concrete task types override this via embedding. +func (t *BaseTask) Execute(_ context.Context, _ uuid.UUID, _ *gorm.DB) error { + return fmt.Errorf("Execute not implemented for task type %s", t.taskType) +} diff --git a/task/handler_factory.go b/task/handler_factory.go index ae1d951..7fcc174 100644 --- a/task/handler_factory.go +++ b/task/handler_factory.go @@ -15,10 +15,10 @@ import ( "gorm.io/gorm" ) -// TaskHandler defines the interface for task processors -type TaskHandler interface { - // Execute processes a task with the given ID, type, and params from the MQ message - Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, params map[string]any, db *gorm.DB) error +// AsyncTaskHandler defines the interface for task processors +type AsyncTaskHandler interface { + // Execute processes a task with the given ID and params from the MQ message + Execute(ctx context.Context, taskID uuid.UUID, params map[string]any, db *gorm.DB) error // CanHandle returns true if this handler can process the given task type CanHandle(taskType TaskType) bool // Name returns the name of the handler for logging and metrics @@ -27,19 +27,19 @@ type TaskHandler interface { // HandlerFactory creates task handlers based on task type type HandlerFactory struct { - handlers map[TaskType]TaskHandler + handlers map[TaskType]AsyncTaskHandler mu sync.RWMutex } // NewHandlerFactory creates a new HandlerFactory func NewHandlerFactory() *HandlerFactory { return &HandlerFactory{ - handlers: make(map[TaskType]TaskHandler), + handlers: make(map[TaskType]AsyncTaskHandler), } } // RegisterHandler registers a handler for a specific task type -func (f *HandlerFactory) RegisterHandler(ctx context.Context, taskType TaskType, handler TaskHandler) { +func (f *HandlerFactory) RegisterHandler(ctx context.Context, taskType TaskType, handler AsyncTaskHandler) { f.mu.Lock() defer f.mu.Unlock() @@ -51,7 +51,7 @@ func (f *HandlerFactory) RegisterHandler(ctx context.Context, taskType TaskType, } // GetHandler returns a handler for the given task type -func (f *HandlerFactory) GetHandler(taskType TaskType) (TaskHandler, error) { +func (f *HandlerFactory) GetHandler(taskType TaskType) (AsyncTaskHandler, error) { f.mu.RLock() handler, exists := f.handlers[taskType] f.mu.RUnlock() @@ -65,10 +65,10 @@ func (f *HandlerFactory) GetHandler(taskType TaskType) (TaskHandler, error) { // CreateDefaultHandlers registers all default task handlers func (f *HandlerFactory) CreateDefaultHandlers(ctx context.Context) { - f.RegisterHandler(ctx, TypeTopologyAnalysis, &TopologyAnalysisHandler{}) - f.RegisterHandler(ctx, TypeEventAnalysis, &EventAnalysisHandler{}) - f.RegisterHandler(ctx, TypeBatchImport, &BatchImportHandler{}) - f.RegisterHandler(ctx, TaskType(TaskTypeTest), NewTestTaskHandler()) + f.RegisterHandler(ctx, TypeTopologyAnalysis, NewTopologyAnalysisHandler()) + f.RegisterHandler(ctx, TypeEventAnalysis, NewEventAnalysisHandler()) + f.RegisterHandler(ctx, TypeBatchImport, NewBatchImportHandler()) + f.RegisterHandler(ctx, TypeTest, NewTestTaskHandler()) } // BaseHandler provides common functionality for all task handlers @@ -103,7 +103,7 @@ func NewTopologyAnalysisHandler() *TopologyAnalysisHandler { // - start_component_uuid (string, required): BFS origin // - end_component_uuid (string, required): reachability target // - check_in_service (bool, optional, default true): skip out-of-service components -func (h *TopologyAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, params map[string]any, db *gorm.DB) error { +func (h *TopologyAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID, params map[string]any, db *gorm.DB) error { logger.Info(ctx, "topology analysis started", "task_id", taskID) // Phase 1: parse params from MQ message @@ -158,6 +158,7 @@ func (h *TopologyAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID, // check the start node itself before BFS if !inServiceMap[startComponentUUID] { + fmt.Println(11111) return persistTopologyResult(ctx, db, taskID, startComponentUUID, endComponentUUID, checkInService, false, nil, &startComponentUUID) } @@ -220,6 +221,7 @@ func (h *TopologyAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID, // parseTopologyAnalysisParams extracts and validates the three required fields. // check_in_service defaults to true when absent. func parseTopologyAnalysisParams(params map[string]any) (startID, endID uuid.UUID, checkInService bool, err error) { + fmt.Printf("params:%+v\n", params) startStr, ok := params["start_component_uuid"].(string) if !ok || startStr == "" { err = fmt.Errorf("missing or invalid start_component_uuid") @@ -318,10 +320,10 @@ func NewEventAnalysisHandler() *EventAnalysisHandler { } // Execute processes an event analysis task -func (h *EventAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, params map[string]any, db *gorm.DB) error { +func (h *EventAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID, params map[string]any, db *gorm.DB) error { logger.Info(ctx, "Starting event analysis", "task_id", taskID, - "task_type", taskType, + "task_params", params, ) // TODO: Implement actual event analysis logic @@ -334,7 +336,8 @@ func (h *EventAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID, ta // Simulate work logger.Info(ctx, "Event analysis completed", "task_id", taskID, - "task_type", taskType, + "task_params", params, + "db", db, ) return nil @@ -358,10 +361,11 @@ func NewBatchImportHandler() *BatchImportHandler { } // Execute processes a batch import task -func (h *BatchImportHandler) Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, params map[string]any, db *gorm.DB) error { +func (h *BatchImportHandler) Execute(ctx context.Context, taskID uuid.UUID, params map[string]any, db *gorm.DB) error { logger.Info(ctx, "Starting batch import", "task_id", taskID, - "task_type", taskType, + "task_params", params, + "db", db, ) // TODO: Implement actual batch import logic @@ -374,7 +378,8 @@ func (h *BatchImportHandler) Execute(ctx context.Context, taskID uuid.UUID, task // Simulate work logger.Info(ctx, "Batch import completed", "task_id", taskID, - "task_type", taskType, + "task_params", params, + "db", db, ) return nil @@ -385,46 +390,9 @@ func (h *BatchImportHandler) CanHandle(taskType TaskType) bool { return taskType == TypeBatchImport } -// CompositeHandler can handle multiple task types by delegating to appropriate handlers -type CompositeHandler struct { - factory *HandlerFactory -} - -// NewCompositeHandler creates a new CompositeHandler -func NewCompositeHandler(factory *HandlerFactory) *CompositeHandler { - return &CompositeHandler{factory: factory} -} - -// Execute delegates task execution to the appropriate handler -func (h *CompositeHandler) Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, params map[string]any, db *gorm.DB) error { - handler, err := h.factory.GetHandler(taskType) - if err != nil { - return fmt.Errorf("failed to get handler for task type %s: %w", taskType, err) - } - - return handler.Execute(ctx, taskID, taskType, params, db) -} - -// CanHandle returns true if any registered handler can handle the task type -func (h *CompositeHandler) CanHandle(taskType TaskType) bool { - _, err := h.factory.GetHandler(taskType) - return err == nil -} - -// Name returns the composite handler name -func (h *CompositeHandler) Name() string { - return "composite_handler" -} - // DefaultHandlerFactory returns a HandlerFactory with all default handlers registered func DefaultHandlerFactory(ctx context.Context) *HandlerFactory { factory := NewHandlerFactory() factory.CreateDefaultHandlers(ctx) return factory } - -// DefaultCompositeHandler returns a CompositeHandler with all default handlers -func DefaultCompositeHandler(ctx context.Context) TaskHandler { - factory := DefaultHandlerFactory(ctx) - return NewCompositeHandler(factory) -} \ No newline at end of file diff --git a/task/initializer.go b/task/initializer.go index 1122a05..008d721 100644 --- a/task/initializer.go +++ b/task/initializer.go @@ -22,12 +22,10 @@ func InitTaskWorker(ctx context.Context, config config.ModelRTConfig, db *gorm.D } // Create task handler factory - handlerFactory := NewHandlerFactory() - handlerFactory.CreateDefaultHandlers(ctx) - handler := DefaultCompositeHandler(ctx) + handlerFactory := DefaultHandlerFactory(ctx) // Create task worker - worker, err := NewTaskWorker(ctx, workerCfg, db, config.RabbitMQConfig, handler) + worker, err := NewTaskWorker(ctx, workerCfg, db, config.RabbitMQConfig, handlerFactory) if err != nil { return nil, fmt.Errorf("failed to create task worker: %w", err) } diff --git a/task/queue_message.go b/task/queue_message.go index 9cc5ebf..c87c02c 100644 --- a/task/queue_message.go +++ b/task/queue_message.go @@ -9,16 +9,14 @@ import ( ) // TaskQueueMessage defines minimal message structure for RabbitMQ/Redis queue dispatch -// This struct is designed to be lightweight for efficient message transport type TaskQueueMessage struct { - TaskID uuid.UUID `json:"task_id"` - TaskType TaskType `json:"task_type"` - Priority int `json:"priority,omitempty"` // Optional, defaults to constants.TaskPriorityDefault - TraceCarrier map[string]string `json:"trace_carrier,omitempty"` // OTel propagation carrier (B3 headers) - Params map[string]any `json:"params,omitempty"` // Task-specific parameters, set by the HTTP handler + TaskID uuid.UUID `json:"task_id"` + TaskType TaskType `json:"task_type"` + Priority int `json:"priority,omitempty"` + TraceCarrier map[string]string `json:"trace_carrier,omitempty"` + Params map[string]any `json:"params,omitempty"` } -// NewTaskQueueMessage creates a new TaskQueueMessage with default priority func NewTaskQueueMessage(taskID uuid.UUID, taskType TaskType) *TaskQueueMessage { return &TaskQueueMessage{ TaskID: taskID, @@ -27,7 +25,6 @@ func NewTaskQueueMessage(taskID uuid.UUID, taskType TaskType) *TaskQueueMessage } } -// NewTaskQueueMessageWithPriority creates a new TaskQueueMessage with specified priority func NewTaskQueueMessageWithPriority(taskID uuid.UUID, taskType TaskType, priority int) *TaskQueueMessage { return &TaskQueueMessage{ TaskID: taskID, @@ -36,19 +33,14 @@ func NewTaskQueueMessageWithPriority(taskID uuid.UUID, taskType TaskType, priori } } -// ToJSON converts TaskQueueMessage to JSON bytes func (m *TaskQueueMessage) ToJSON() ([]byte, error) { return json.Marshal(m) } -// Validate checks if TaskQueueMessage is valid func (m *TaskQueueMessage) Validate() bool { - // Check if TaskID is valid (not nil UUID) if m.TaskID == uuid.Nil { return false } - - // Check if TaskType is valid switch m.TaskType { case TypeTopologyAnalysis, TypeEventAnalysis, TypeBatchImport, TypeTest: return true @@ -57,7 +49,6 @@ func (m *TaskQueueMessage) Validate() bool { } } -// SetPriority sets priority of task queue message with validation func (m *TaskQueueMessage) SetPriority(priority int) { if priority < constants.TaskPriorityLow { priority = constants.TaskPriorityLow @@ -68,7 +59,6 @@ func (m *TaskQueueMessage) SetPriority(priority int) { m.Priority = priority } -// GetPriority returns priority of task queue message func (m *TaskQueueMessage) GetPriority() int { return m.Priority } diff --git a/task/queue_producer.go b/task/queue_producer.go index bcaaffc..3ae6185 100644 --- a/task/queue_producer.go +++ b/task/queue_producer.go @@ -110,8 +110,9 @@ func (p *QueueProducer) declareInfrastructure() error { } // PublishTask publishes a task message to RabbitMQ -func (p *QueueProducer) PublishTask(ctx context.Context, taskID uuid.UUID, taskType TaskType, priority int) error { +func (p *QueueProducer) PublishTask(ctx context.Context, taskID uuid.UUID, taskType TaskType, priority int, params map[string]any) error { message := NewTaskQueueMessageWithPriority(taskID, taskType, priority) + message.Params = params // Validate message if !message.Validate() { @@ -166,10 +167,10 @@ func (p *QueueProducer) PublishTask(ctx context.Context, taskID uuid.UUID, taskT } // PublishTaskWithRetry publishes a task with retry logic -func (p *QueueProducer) PublishTaskWithRetry(ctx context.Context, taskID uuid.UUID, taskType TaskType, priority int, maxRetries int) error { +func (p *QueueProducer) PublishTaskWithRetry(ctx context.Context, taskID uuid.UUID, taskType TaskType, priority int, params map[string]any, maxRetries int) error { var lastErr error for i := range maxRetries { - err := p.PublishTask(ctx, taskID, taskType, priority) + err := p.PublishTask(ctx, taskID, taskType, priority, params) if err == nil { return nil } @@ -255,7 +256,7 @@ func PushTaskToRabbitMQ(ctx context.Context, cfg config.RabbitMQConfig, taskChan taskCtx, pubSpan := otel.Tracer("modelRT/task").Start(taskCtx, "task.publish", oteltrace.WithAttributes(attribute.String("task_id", msg.TaskID.String())), ) - if err := producer.PublishTaskWithRetry(taskCtx, msg.TaskID, msg.TaskType, msg.Priority, 3); err != nil { + if err := producer.PublishTaskWithRetry(taskCtx, msg.TaskID, msg.TaskType, msg.Priority, msg.Params, 3); err != nil { pubSpan.RecordError(err) logger.Error(taskCtx, "publish task to RabbitMQ failed", "task_id", msg.TaskID, "error", err) diff --git a/task/retry_queue.go b/task/retry_queue.go index c602f66..34499bd 100644 --- a/task/retry_queue.go +++ b/task/retry_queue.go @@ -120,7 +120,7 @@ func (q *RetryQueue) ProcessRetryQueue(ctx context.Context, batchSize int) error default: // Publish task to queue for immediate processing taskType := TaskType(task.TaskType) - if err := q.producer.PublishTask(ctx, task.TaskID, taskType, task.Priority); err != nil { + if err := q.producer.PublishTask(ctx, task.TaskID, taskType, task.Priority, map[string]any(task.Params)); err != nil { logger.Error(ctx, "Failed to publish retry task to queue", "task_id", task.TaskID, "task_type", taskType, diff --git a/task/test_task.go b/task/test_task.go index d85b682..580d1a2 100644 --- a/task/test_task.go +++ b/task/test_task.go @@ -104,11 +104,11 @@ func (t *TestTask) Execute(ctx context.Context, taskID uuid.UUID, db *gorm.DB) e // Build result result := map[string]interface{}{ - "status": "completed", - "sleep_duration": params.SleepDuration, - "message": params.Message, - "executed_at": time.Now().Unix(), - "task_id": taskID.String(), + "status": "completed", + "sleep_duration": params.SleepDuration, + "message": params.Message, + "executed_at": time.Now().Unix(), + "task_id": taskID.String(), } // Save result to database @@ -130,21 +130,22 @@ func (t *TestTask) Execute(ctx context.Context, taskID uuid.UUID, db *gorm.DB) e // TestTaskHandler handles test task execution type TestTaskHandler struct { - *BaseHandler + BaseHandler } // NewTestTaskHandler creates a new TestTaskHandler func NewTestTaskHandler() *TestTaskHandler { return &TestTaskHandler{ - BaseHandler: NewBaseHandler("test_task_handler"), + BaseHandler: *NewBaseHandler("test_task_handler"), } } // Execute processes a test task using the unified task interface -func (h *TestTaskHandler) Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, params map[string]any, db *gorm.DB) error { +func (h *TestTaskHandler) Execute(ctx context.Context, taskID uuid.UUID, params map[string]any, db *gorm.DB) error { logger.Info(ctx, "Executing test task", "task_id", taskID, - "task_type", taskType, + "task_params", params, + "db", db, ) // Convert params from MQ message to TestTaskParams diff --git a/task/types.go b/task/types.go index be1f4f5..ef59f6c 100644 --- a/task/types.go +++ b/task/types.go @@ -1,19 +1,6 @@ package task -import ( - "time" -) - -type TaskStatus string - -const ( - StatusPending TaskStatus = "PENDING" - StatusRunning TaskStatus = "RUNNING" - StatusCompleted TaskStatus = "COMPLETED" - StatusFailed TaskStatus = "FAILED" -) - -// TaskType 定义异步任务的具体业务类型 +// TaskType defines the business type of an async task type TaskType string const ( @@ -23,34 +10,23 @@ const ( TypeTest TaskType = "TEST" ) -type Task struct { - ID string `bson:"_id" json:"id"` - Type TaskType `bson:"type" json:"type"` - Status TaskStatus `bson:"status" json:"status"` - Priority int `bson:"priority" json:"priority"` +// TaskStatus defines the lifecycle status of an async task +type TaskStatus string - Params map[string]interface{} `bson:"params" json:"params"` - Result map[string]interface{} `bson:"result,omitempty" json:"result"` - ErrorMsg string `bson:"error_msg,omitempty" json:"error_msg"` +const ( + StatusPending TaskStatus = "PENDING" + StatusRunning TaskStatus = "RUNNING" + StatusCompleted TaskStatus = "COMPLETED" + StatusFailed TaskStatus = "FAILED" +) - CreatedAt time.Time `bson:"created_at" json:"created_at"` - StartedAt time.Time `bson:"started_at,omitempty" json:"started_at"` - CompletedAt time.Time `bson:"completed_at,omitempty" json:"completed_at"` -} +// UnifiedTaskType defines all async task types in a single location +type UnifiedTaskType string -type TopologyParams struct { - CheckIsland bool `json:"check_island"` - CheckShort bool `json:"check_short"` - BaseModelIDs []string `json:"base_model_ids"` -} - -type EventAnalysisParams struct { - MotorID string `json:"motor_id"` - TriggerID string `json:"trigger_id"` - DurationMS int `json:"duration_ms"` -} - -type BatchImportParams struct { - FileName string `json:"file_name"` - FilePath string `json:"file_path"` -} +const ( + TaskTypeTopologyAnalysis UnifiedTaskType = "TOPOLOGY_ANALYSIS" + TaskTypePerformanceAnalysis UnifiedTaskType = "PERFORMANCE_ANALYSIS" + TaskTypeEventAnalysis UnifiedTaskType = "EVENT_ANALYSIS" + TaskTypeBatchImport UnifiedTaskType = "BATCH_IMPORT" + TaskTypeTest UnifiedTaskType = "TEST" +) diff --git a/task/types_v2.go b/task/types_v2.go deleted file mode 100644 index aed3558..0000000 --- a/task/types_v2.go +++ /dev/null @@ -1,138 +0,0 @@ -// Package task provides unified task type definitions and interfaces -package task - -import ( - "context" - "fmt" - - "github.com/gofrs/uuid" - "gorm.io/gorm" -) - -// UnifiedTaskType defines all async task types in a single location -type UnifiedTaskType string - -const ( - // TaskTypeTopologyAnalysis represents topology analysis task - TaskTypeTopologyAnalysis UnifiedTaskType = "TOPOLOGY_ANALYSIS" - // TaskTypePerformanceAnalysis represents performance analysis task - TaskTypePerformanceAnalysis UnifiedTaskType = "PERFORMANCE_ANALYSIS" - // TaskTypeEventAnalysis represents event analysis task - TaskTypeEventAnalysis UnifiedTaskType = "EVENT_ANALYSIS" - // TaskTypeBatchImport represents batch import task - TaskTypeBatchImport UnifiedTaskType = "BATCH_IMPORT" - // TaskTypeTest represents test task for system verification - TaskTypeTest UnifiedTaskType = "TEST" -) - -// UnifiedTaskStatus defines task status constants -type UnifiedTaskStatus string - -const ( - // TaskStatusPending represents task waiting to be processed - TaskStatusPending UnifiedTaskStatus = "PENDING" - // TaskStatusRunning represents task currently executing - TaskStatusRunning UnifiedTaskStatus = "RUNNING" - // TaskStatusCompleted represents task finished successfully - TaskStatusCompleted UnifiedTaskStatus = "COMPLETED" - // TaskStatusFailed represents task failed with error - TaskStatusFailed UnifiedTaskStatus = "FAILED" -) - -// TaskParams defines the interface for task-specific parameters -// All task types must implement this interface to provide their parameter structure -type TaskParams interface { - // Validate checks if the parameters are valid for this task type - Validate() error - // GetType returns the task type these parameters are for - GetType() UnifiedTaskType - // ToMap converts parameters to map for database storage - ToMap() map[string]interface{} - // FromMap populates parameters from map (for database retrieval) - FromMap(params map[string]interface{}) error -} - -// UnifiedTask defines the base interface that all tasks must implement -// This provides a clean abstraction for task execution and management -type UnifiedTask interface { - // GetType returns the task type - GetType() UnifiedTaskType - - // GetParams returns the task parameters - GetParams() TaskParams - - // Execute performs the actual task logic - Execute(ctx context.Context, taskID uuid.UUID, db *gorm.DB) error - - // GetName returns a human-readable task name for logging - GetName() string - - // Validate checks if the task is valid before execution - Validate() error -} - -// BaseTask provides common functionality for all task implementations -type BaseTask struct { - taskType UnifiedTaskType - params TaskParams - name string -} - -// NewBaseTask creates a new BaseTask instance -func NewBaseTask(taskType UnifiedTaskType, params TaskParams, name string) *BaseTask { - return &BaseTask{ - taskType: taskType, - params: params, - name: name, - } -} - -// GetType returns the task type -func (t *BaseTask) GetType() UnifiedTaskType { - return t.taskType -} - -// GetParams returns the task parameters -func (t *BaseTask) GetParams() TaskParams { - return t.params -} - -// GetName returns the task name -func (t *BaseTask) GetName() string { - return t.name -} - -// Validate checks if the task is valid -func (t *BaseTask) Validate() error { - if t.params == nil { - return fmt.Errorf("task parameters cannot be nil") - } - - if t.taskType != t.params.GetType() { - return fmt.Errorf("task type mismatch: expected %s, got %s", t.taskType, t.params.GetType()) - } - - return t.params.Validate() -} - -// IsTaskType checks if a task type string is valid -func IsTaskType(taskType string) bool { - switch UnifiedTaskType(taskType) { - case TaskTypeTopologyAnalysis, TaskTypePerformanceAnalysis, - TaskTypeEventAnalysis, TaskTypeBatchImport, TaskTypeTest: - return true - default: - return false - } -} - -// GetTaskTypes returns all registered task types -func GetTaskTypes() []UnifiedTaskType { - return []UnifiedTaskType{ - TaskTypeTopologyAnalysis, - TaskTypePerformanceAnalysis, - TaskTypeEventAnalysis, - TaskTypeBatchImport, - TaskTypeTest, - } -} diff --git a/task/worker.go b/task/worker.go index 280b217..20795cf 100644 --- a/task/worker.go +++ b/task/worker.go @@ -52,56 +52,56 @@ func DefaultWorkerConfig() WorkerConfig { // TaskWorker manages a pool of workers for processing asynchronous tasks type TaskWorker struct { - cfg WorkerConfig - db *gorm.DB - pool *ants.Pool - conn *amqp.Connection - ch *amqp.Channel - handler TaskHandler - retryQueue *RetryQueue - stopChan chan struct{} - wg sync.WaitGroup - ctx context.Context - cancel context.CancelFunc - metrics *WorkerMetrics + cfg WorkerConfig + db *gorm.DB + pool *ants.Pool + conn *amqp.Connection + ch *amqp.Channel + factory *HandlerFactory + retryQueue *RetryQueue + stopChan chan struct{} + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + metrics *WorkerMetrics } // WorkerMetrics holds metrics for the worker pool type WorkerMetrics struct { // Task statistics by type - TasksProcessed map[TaskType]int64 - TasksFailed map[TaskType]int64 - TasksSuccess map[TaskType]int64 - ProcessingTime map[TaskType]time.Duration + TasksProcessed map[TaskType]int64 + TasksFailed map[TaskType]int64 + TasksSuccess map[TaskType]int64 + ProcessingTime map[TaskType]time.Duration // Aggregate counters (maintained for backward compatibility) - TotalProcessed int64 - TotalFailed int64 - TotalSuccess int64 - TasksInProgress int32 + TotalProcessed int64 + TotalFailed int64 + TotalSuccess int64 + TasksInProgress int32 // Queue and latency metrics - QueueDepth int - QueueLatency time.Duration + QueueDepth int + QueueLatency time.Duration // Worker resource metrics - WorkersActive int - WorkersIdle int - MemoryUsage uint64 - CPULoad float64 + WorkersActive int + WorkersIdle int + MemoryUsage uint64 + CPULoad float64 // Time window metrics - LastMinuteRate float64 - Last5MinutesRate float64 - LastHourRate float64 + LastMinuteRate float64 + Last5MinutesRate float64 + LastHourRate float64 // Health and timing - LastHealthCheck time.Time - mu sync.RWMutex + LastHealthCheck time.Time + mu sync.RWMutex } // NewTaskWorker creates a new TaskWorker instance -func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg config.RabbitMQConfig, handler TaskHandler) (*TaskWorker, error) { +func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg config.RabbitMQConfig, factory *HandlerFactory) (*TaskWorker, error) { // Initialize RabbitMQ connection mq.InitRabbitProxy(ctx, rabbitCfg) conn := mq.GetConn() @@ -118,10 +118,10 @@ func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg // Declare queue (ensure it exists with proper arguments) _, err = ch.QueueDeclare( constants.TaskQueueName, // name - true, // durable - false, // delete when unused - false, // exclusive - false, // no-wait + true, // durable + false, // delete when unused + false, // exclusive + false, // no-wait amqp.Table{ "x-max-priority": constants.TaskMaxPriority, "x-message-ttl": constants.TaskDefaultMessageTTL.Milliseconds(), @@ -158,15 +158,15 @@ func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg pool: pool, conn: conn, ch: ch, - handler: handler, + factory: factory, stopChan: make(chan struct{}), ctx: ctxWithCancel, cancel: cancel, metrics: &WorkerMetrics{ - TasksProcessed: make(map[TaskType]int64), - TasksFailed: make(map[TaskType]int64), - TasksSuccess: make(map[TaskType]int64), - ProcessingTime: make(map[TaskType]time.Duration), + TasksProcessed: make(map[TaskType]int64), + TasksFailed: make(map[TaskType]int64), + TasksSuccess: make(map[TaskType]int64), + ProcessingTime: make(map[TaskType]time.Duration), LastHealthCheck: time.Now(), }, } @@ -203,13 +203,13 @@ func (w *TaskWorker) consumerLoop(consumerID int) { // Consume messages from the queue msgs, err := w.ch.Consume( - constants.TaskQueueName, // queue + constants.TaskQueueName, // queue fmt.Sprintf("worker-%d", consumerID), // consumer tag - false, // auto-ack - false, // exclusive - false, // no-local - false, // no-wait - nil, // args + false, // auto-ack + false, // exclusive + false, // no-local + false, // no-wait + nil, // args ) if err != nil { logger.Error(w.ctx, "Failed to start consumer", @@ -316,7 +316,7 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) { // Execute task using handler startTime := time.Now() - err := w.handler.Execute(ctx, taskMsg.TaskID, taskMsg.TaskType, taskMsg.Params, w.db) + err := w.dispatch(ctx, taskMsg.TaskType, taskMsg.TaskID, taskMsg.Params, &msg) processingTime := time.Since(startTime) if err != nil { @@ -431,29 +431,37 @@ func (w *TaskWorker) updateTaskStatus(ctx context.Context, taskID uuid.UUID, sta return nil } -// updateTaskWithError updates a task with error information +// updateTaskWithError updates a task with error information in both async_task and async_task_result. func (w *TaskWorker) updateTaskWithError(ctx context.Context, taskID uuid.UUID, err error) error { - // Update task error information in database errorMsg := err.Error() stackTrace := fmt.Sprintf("%+v", err) - updateErr := database.UpdateTaskErrorInfo(ctx, w.db, taskID, errorMsg, stackTrace) - if updateErr != nil { - logger.Error(ctx, "Failed to update task error info", - "task_id", taskID, - "error", updateErr, - ) + if updateErr := database.UpdateTaskErrorInfo(ctx, w.db, taskID, errorMsg, stackTrace); updateErr != nil { + logger.Error(ctx, "Failed to update task error info", "task_id", taskID, "error", updateErr) return updateErr } - logger.Warn(ctx, "Task failed with error", - "task_id", taskID, - "error", errorMsg, - ) + if updateErr := database.UpdateAsyncTaskResultWithError(ctx, w.db, taskID, 500, errorMsg, nil); updateErr != nil { + logger.Error(ctx, "Failed to update task result with error", "task_id", taskID, "error", updateErr) + return updateErr + } + logger.Warn(ctx, "Task failed with error", "task_id", taskID, "error", errorMsg) return nil } +// dispatch routes a task message to the appropriate handler and executes it. +// Nacks the message and returns an error if no handler is registered for the task type. +func (w *TaskWorker) dispatch(ctx context.Context, taskType TaskType, taskID uuid.UUID, params map[string]any, msg *amqp.Delivery) error { + handler, err := w.factory.GetHandler(taskType) + if err != nil { + logger.Error(ctx, "No handler for task type", "task_type", taskType) + msg.Nack(false, false) + return err + } + return handler.Execute(ctx, taskID, params, w.db) +} + // healthCheckLoop periodically checks worker health and metrics func (w *TaskWorker) healthCheckLoop() { defer w.wg.Done() @@ -479,10 +487,10 @@ func (w *TaskWorker) checkHealth() { // Update queue depth queue, err := w.ch.QueueDeclarePassive( constants.TaskQueueName, // name - true, // durable - false, // delete when unused - false, // exclusive - false, // no-wait + true, // durable + false, // delete when unused + false, // exclusive + false, // no-wait amqp.Table{ "x-max-priority": constants.TaskMaxPriority, "x-message-ttl": constants.TaskDefaultMessageTTL.Milliseconds(), @@ -565,24 +573,24 @@ func (w *TaskWorker) GetMetrics() *WorkerMetrics { // Create a copy without the mutex to avoid copylocks warning return &WorkerMetrics{ - TasksProcessed: tasksProcessedCopy, - TasksFailed: tasksFailedCopy, - TasksSuccess: tasksSuccessCopy, - ProcessingTime: processingTimeCopy, - TotalProcessed: w.metrics.TotalProcessed, - TotalFailed: w.metrics.TotalFailed, - TotalSuccess: w.metrics.TotalSuccess, - TasksInProgress: w.metrics.TasksInProgress, - QueueDepth: w.metrics.QueueDepth, - QueueLatency: w.metrics.QueueLatency, - WorkersActive: w.metrics.WorkersActive, - WorkersIdle: w.metrics.WorkersIdle, - MemoryUsage: w.metrics.MemoryUsage, - CPULoad: w.metrics.CPULoad, - LastMinuteRate: w.metrics.LastMinuteRate, + TasksProcessed: tasksProcessedCopy, + TasksFailed: tasksFailedCopy, + TasksSuccess: tasksSuccessCopy, + ProcessingTime: processingTimeCopy, + TotalProcessed: w.metrics.TotalProcessed, + TotalFailed: w.metrics.TotalFailed, + TotalSuccess: w.metrics.TotalSuccess, + TasksInProgress: w.metrics.TasksInProgress, + QueueDepth: w.metrics.QueueDepth, + QueueLatency: w.metrics.QueueLatency, + WorkersActive: w.metrics.WorkersActive, + WorkersIdle: w.metrics.WorkersIdle, + MemoryUsage: w.metrics.MemoryUsage, + CPULoad: w.metrics.CPULoad, + LastMinuteRate: w.metrics.LastMinuteRate, Last5MinutesRate: w.metrics.Last5MinutesRate, - LastHourRate: w.metrics.LastHourRate, - LastHealthCheck: w.metrics.LastHealthCheck, + LastHourRate: w.metrics.LastHourRate, + LastHealthCheck: w.metrics.LastHealthCheck, // Mutex is intentionally omitted } } @@ -595,6 +603,7 @@ func (w *TaskWorker) IsHealthy() bool { // Consider unhealthy if last health check was too long ago return time.Since(w.metrics.LastHealthCheck) < 2*w.cfg.PollingInterval } + // RecordMetrics periodically records worker metrics to the logging system func (w *TaskWorker) RecordMetrics(interval time.Duration) { go func() {