refactor: overhaul async task handler routing and fix data consistency

- fix params lost in RabbitMQ transit by threading them through PublishTask/PublishTaskWithRetry
- fix UpdateTaskErrorInfo not setting status=FAILED on async_task
- fix UpdateAsyncTaskResultWithError silently skipping when no result row exists (UPDATE → upsert)
- sync task failure to async_task_result in updateTaskWithError
- remove taskType from AsyncTaskHandler.Execute interface; rename TaskHandler → AsyncTaskHandler
- replace CompositeHandler with direct factory.GetHandler dispatch via worker.dispatch()
- use constructors (NewXxxHandler) for handler registration instead of zero-value literals
- consolidate TaskType/TaskStatus/UnifiedTaskType into task/types.go; delete types_v2.go
- extract BaseTask/TaskParams into task/base_task.go
This commit is contained in:
douxu 2026-04-27 17:55:38 +08:00
parent 1b1f43db7f
commit 33f7d758e5
12 changed files with 226 additions and 357 deletions

View File

@ -60,6 +60,7 @@ func UpdateTaskErrorInfo(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, err
Updates(map[string]any{
"failure_reason": errorMsg,
"stack_trace": stackTrace,
"status": orm.AsyncTaskStatusFailed,
})
return result.Error

View File

@ -161,14 +161,18 @@ func CreateAsyncTaskResult(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, r
return resultOp.Error
}
// UpdateAsyncTaskResultWithError updates a task result with error information
// UpdateAsyncTaskResultWithError upserts a task result with error information.
func UpdateAsyncTaskResultWithError(ctx context.Context, tx *gorm.DB, taskID uuid.UUID, code int, message string, detail orm.JSONMap) error {
// ctx timeout judgment
cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// Update with error information
result := tx.WithContext(cancelCtx).
if err := tx.WithContext(cancelCtx).
Where("task_id = ?", taskID).
FirstOrCreate(&orm.AsyncTaskResult{TaskID: taskID}).Error; err != nil {
return err
}
return tx.WithContext(cancelCtx).
Model(&orm.AsyncTaskResult{}).
Where("task_id = ?", taskID).
Updates(map[string]any{
@ -176,9 +180,7 @@ func UpdateAsyncTaskResultWithError(ctx context.Context, tx *gorm.DB, taskID uui
"error_message": message,
"error_detail": detail,
"result": nil,
})
return result.Error
}).Error
}
// UpdateAsyncTaskResultWithSuccess updates a task result with success information

61
task/base_task.go Normal file
View File

@ -0,0 +1,61 @@
// Package task provides unified task type definitions and interfaces
package task
import (
"context"
"fmt"
"github.com/gofrs/uuid"
"gorm.io/gorm"
)
// TaskParams defines the interface for task-specific parameters
type TaskParams interface {
// Validate checks that the parameters are well-formed for their task type.
Validate() error
// GetType returns the task type these parameters belong to.
GetType() UnifiedTaskType
// ToMap converts the parameters to a map for storage/transport.
ToMap() map[string]interface{}
// FromMap populates the parameters from a map (inverse of ToMap).
FromMap(params map[string]interface{}) error
}
// BaseTask provides common functionality for all task implementations
type BaseTask struct {
taskType UnifiedTaskType // declared task type; Validate requires it to match params.GetType()
params TaskParams // task-specific parameters; nil is rejected by Validate
name string // human-readable task name used for logging
}
// NewBaseTask constructs a BaseTask with the given type, parameters, and
// human-readable name. No validation is performed here; callers should
// invoke Validate before executing the task.
func NewBaseTask(taskType UnifiedTaskType, params TaskParams, name string) *BaseTask {
	t := new(BaseTask)
	t.taskType = taskType
	t.params = params
	t.name = name
	return t
}
// GetType returns the declared task type.
func (t *BaseTask) GetType() UnifiedTaskType {
return t.taskType
}
// GetParams returns the task-specific parameters.
func (t *BaseTask) GetParams() TaskParams {
return t.params
}
// GetName returns the human-readable task name used for logging.
func (t *BaseTask) GetName() string {
return t.name
}
// Validate checks that the task is internally consistent: params must be
// non-nil and their declared type must match the task's own type; final
// validation is then delegated to the params themselves.
func (t *BaseTask) Validate() error {
if t.params == nil {
return fmt.Errorf("task parameters cannot be nil")
}
// Guard against wiring a params struct to the wrong task type.
if t.taskType != t.params.GetType() {
return fmt.Errorf("task type mismatch: expected %s, got %s", t.taskType, t.params.GetType())
}
return t.params.Validate()
}
// Execute is a placeholder; concrete task types override this via embedding.
// It always returns an error so that a task which forgot to provide its own
// Execute fails loudly at runtime instead of silently succeeding.
func (t *BaseTask) Execute(_ context.Context, _ uuid.UUID, _ *gorm.DB) error {
return fmt.Errorf("Execute not implemented for task type %s", t.taskType)
}

View File

@ -15,10 +15,10 @@ import (
"gorm.io/gorm"
)
// TaskHandler defines the interface for task processors
type TaskHandler interface {
// Execute processes a task with the given ID, type, and params from the MQ message
Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, params map[string]any, db *gorm.DB) error
// AsyncTaskHandler defines the interface for task processors
type AsyncTaskHandler interface {
// Execute processes a task with the given ID and params from the MQ message
Execute(ctx context.Context, taskID uuid.UUID, params map[string]any, db *gorm.DB) error
// CanHandle returns true if this handler can process the given task type
CanHandle(taskType TaskType) bool
// Name returns the name of the handler for logging and metrics
@ -27,19 +27,19 @@ type TaskHandler interface {
// HandlerFactory creates task handlers based on task type
type HandlerFactory struct {
handlers map[TaskType]TaskHandler
handlers map[TaskType]AsyncTaskHandler
mu sync.RWMutex
}
// NewHandlerFactory creates a new HandlerFactory
func NewHandlerFactory() *HandlerFactory {
return &HandlerFactory{
handlers: make(map[TaskType]TaskHandler),
handlers: make(map[TaskType]AsyncTaskHandler),
}
}
// RegisterHandler registers a handler for a specific task type
func (f *HandlerFactory) RegisterHandler(ctx context.Context, taskType TaskType, handler TaskHandler) {
func (f *HandlerFactory) RegisterHandler(ctx context.Context, taskType TaskType, handler AsyncTaskHandler) {
f.mu.Lock()
defer f.mu.Unlock()
@ -51,7 +51,7 @@ func (f *HandlerFactory) RegisterHandler(ctx context.Context, taskType TaskType,
}
// GetHandler returns a handler for the given task type
func (f *HandlerFactory) GetHandler(taskType TaskType) (TaskHandler, error) {
func (f *HandlerFactory) GetHandler(taskType TaskType) (AsyncTaskHandler, error) {
f.mu.RLock()
handler, exists := f.handlers[taskType]
f.mu.RUnlock()
@ -65,10 +65,10 @@ func (f *HandlerFactory) GetHandler(taskType TaskType) (TaskHandler, error) {
// CreateDefaultHandlers registers all default task handlers
func (f *HandlerFactory) CreateDefaultHandlers(ctx context.Context) {
f.RegisterHandler(ctx, TypeTopologyAnalysis, &TopologyAnalysisHandler{})
f.RegisterHandler(ctx, TypeEventAnalysis, &EventAnalysisHandler{})
f.RegisterHandler(ctx, TypeBatchImport, &BatchImportHandler{})
f.RegisterHandler(ctx, TaskType(TaskTypeTest), NewTestTaskHandler())
f.RegisterHandler(ctx, TypeTopologyAnalysis, NewTopologyAnalysisHandler())
f.RegisterHandler(ctx, TypeEventAnalysis, NewEventAnalysisHandler())
f.RegisterHandler(ctx, TypeBatchImport, NewBatchImportHandler())
f.RegisterHandler(ctx, TypeTest, NewTestTaskHandler())
}
// BaseHandler provides common functionality for all task handlers
@ -103,7 +103,7 @@ func NewTopologyAnalysisHandler() *TopologyAnalysisHandler {
// - start_component_uuid (string, required): BFS origin
// - end_component_uuid (string, required): reachability target
// - check_in_service (bool, optional, default true): skip out-of-service components
func (h *TopologyAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, params map[string]any, db *gorm.DB) error {
func (h *TopologyAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID, params map[string]any, db *gorm.DB) error {
logger.Info(ctx, "topology analysis started", "task_id", taskID)
// Phase 1: parse params from MQ message
@ -158,6 +158,7 @@ func (h *TopologyAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID,
// check the start node itself before BFS
if !inServiceMap[startComponentUUID] {
fmt.Println(11111)
return persistTopologyResult(ctx, db, taskID, startComponentUUID, endComponentUUID,
checkInService, false, nil, &startComponentUUID)
}
@ -220,6 +221,7 @@ func (h *TopologyAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID,
// parseTopologyAnalysisParams extracts and validates the three required fields.
// check_in_service defaults to true when absent.
func parseTopologyAnalysisParams(params map[string]any) (startID, endID uuid.UUID, checkInService bool, err error) {
fmt.Printf("params:%+v\n", params)
startStr, ok := params["start_component_uuid"].(string)
if !ok || startStr == "" {
err = fmt.Errorf("missing or invalid start_component_uuid")
@ -318,10 +320,10 @@ func NewEventAnalysisHandler() *EventAnalysisHandler {
}
// Execute processes an event analysis task
func (h *EventAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, params map[string]any, db *gorm.DB) error {
func (h *EventAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID, params map[string]any, db *gorm.DB) error {
logger.Info(ctx, "Starting event analysis",
"task_id", taskID,
"task_type", taskType,
"task_params", params,
)
// TODO: Implement actual event analysis logic
@ -334,7 +336,8 @@ func (h *EventAnalysisHandler) Execute(ctx context.Context, taskID uuid.UUID, ta
// Simulate work
logger.Info(ctx, "Event analysis completed",
"task_id", taskID,
"task_type", taskType,
"task_params", params,
"db", db,
)
return nil
@ -358,10 +361,11 @@ func NewBatchImportHandler() *BatchImportHandler {
}
// Execute processes a batch import task
func (h *BatchImportHandler) Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, params map[string]any, db *gorm.DB) error {
func (h *BatchImportHandler) Execute(ctx context.Context, taskID uuid.UUID, params map[string]any, db *gorm.DB) error {
logger.Info(ctx, "Starting batch import",
"task_id", taskID,
"task_type", taskType,
"task_params", params,
"db", db,
)
// TODO: Implement actual batch import logic
@ -374,7 +378,8 @@ func (h *BatchImportHandler) Execute(ctx context.Context, taskID uuid.UUID, task
// Simulate work
logger.Info(ctx, "Batch import completed",
"task_id", taskID,
"task_type", taskType,
"task_params", params,
"db", db,
)
return nil
@ -385,46 +390,9 @@ func (h *BatchImportHandler) CanHandle(taskType TaskType) bool {
return taskType == TypeBatchImport
}
// CompositeHandler can handle multiple task types by delegating to appropriate handlers
type CompositeHandler struct {
factory *HandlerFactory
}
// NewCompositeHandler creates a new CompositeHandler
func NewCompositeHandler(factory *HandlerFactory) *CompositeHandler {
return &CompositeHandler{factory: factory}
}
// Execute delegates task execution to the appropriate handler
func (h *CompositeHandler) Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, params map[string]any, db *gorm.DB) error {
handler, err := h.factory.GetHandler(taskType)
if err != nil {
return fmt.Errorf("failed to get handler for task type %s: %w", taskType, err)
}
return handler.Execute(ctx, taskID, taskType, params, db)
}
// CanHandle returns true if any registered handler can handle the task type
func (h *CompositeHandler) CanHandle(taskType TaskType) bool {
_, err := h.factory.GetHandler(taskType)
return err == nil
}
// Name returns the composite handler name
func (h *CompositeHandler) Name() string {
return "composite_handler"
}
// DefaultHandlerFactory returns a HandlerFactory with all default handlers registered
func DefaultHandlerFactory(ctx context.Context) *HandlerFactory {
factory := NewHandlerFactory()
factory.CreateDefaultHandlers(ctx)
return factory
}
// DefaultCompositeHandler returns a CompositeHandler with all default handlers
func DefaultCompositeHandler(ctx context.Context) TaskHandler {
factory := DefaultHandlerFactory(ctx)
return NewCompositeHandler(factory)
}

View File

@ -22,12 +22,10 @@ func InitTaskWorker(ctx context.Context, config config.ModelRTConfig, db *gorm.D
}
// Create task handler factory
handlerFactory := NewHandlerFactory()
handlerFactory.CreateDefaultHandlers(ctx)
handler := DefaultCompositeHandler(ctx)
handlerFactory := DefaultHandlerFactory(ctx)
// Create task worker
worker, err := NewTaskWorker(ctx, workerCfg, db, config.RabbitMQConfig, handler)
worker, err := NewTaskWorker(ctx, workerCfg, db, config.RabbitMQConfig, handlerFactory)
if err != nil {
return nil, fmt.Errorf("failed to create task worker: %w", err)
}

View File

@ -9,16 +9,14 @@ import (
)
// TaskQueueMessage defines minimal message structure for RabbitMQ/Redis queue dispatch
// This struct is designed to be lightweight for efficient message transport
type TaskQueueMessage struct {
TaskID uuid.UUID `json:"task_id"`
TaskType TaskType `json:"task_type"`
Priority int `json:"priority,omitempty"` // Optional, defaults to constants.TaskPriorityDefault
TraceCarrier map[string]string `json:"trace_carrier,omitempty"` // OTel propagation carrier (B3 headers)
Params map[string]any `json:"params,omitempty"` // Task-specific parameters, set by the HTTP handler
TaskID uuid.UUID `json:"task_id"`
TaskType TaskType `json:"task_type"`
Priority int `json:"priority,omitempty"`
TraceCarrier map[string]string `json:"trace_carrier,omitempty"`
Params map[string]any `json:"params,omitempty"`
}
// NewTaskQueueMessage creates a new TaskQueueMessage with default priority
func NewTaskQueueMessage(taskID uuid.UUID, taskType TaskType) *TaskQueueMessage {
return &TaskQueueMessage{
TaskID: taskID,
@ -27,7 +25,6 @@ func NewTaskQueueMessage(taskID uuid.UUID, taskType TaskType) *TaskQueueMessage
}
}
// NewTaskQueueMessageWithPriority creates a new TaskQueueMessage with specified priority
func NewTaskQueueMessageWithPriority(taskID uuid.UUID, taskType TaskType, priority int) *TaskQueueMessage {
return &TaskQueueMessage{
TaskID: taskID,
@ -36,19 +33,14 @@ func NewTaskQueueMessageWithPriority(taskID uuid.UUID, taskType TaskType, priori
}
}
// ToJSON converts TaskQueueMessage to JSON bytes
// ToJSON serializes the message to JSON bytes for queue transport.
func (m *TaskQueueMessage) ToJSON() ([]byte, error) {
return json.Marshal(m)
}
// Validate checks if TaskQueueMessage is valid
func (m *TaskQueueMessage) Validate() bool {
// Check if TaskID is valid (not nil UUID)
if m.TaskID == uuid.Nil {
return false
}
// Check if TaskType is valid
switch m.TaskType {
case TypeTopologyAnalysis, TypeEventAnalysis, TypeBatchImport, TypeTest:
return true
@ -57,7 +49,6 @@ func (m *TaskQueueMessage) Validate() bool {
}
}
// SetPriority sets priority of task queue message with validation
func (m *TaskQueueMessage) SetPriority(priority int) {
if priority < constants.TaskPriorityLow {
priority = constants.TaskPriorityLow
@ -68,7 +59,6 @@ func (m *TaskQueueMessage) SetPriority(priority int) {
m.Priority = priority
}
// GetPriority returns priority of task queue message
// GetPriority returns the priority of the task queue message.
func (m *TaskQueueMessage) GetPriority() int {
return m.Priority
}

View File

@ -110,8 +110,9 @@ func (p *QueueProducer) declareInfrastructure() error {
}
// PublishTask publishes a task message to RabbitMQ
func (p *QueueProducer) PublishTask(ctx context.Context, taskID uuid.UUID, taskType TaskType, priority int) error {
func (p *QueueProducer) PublishTask(ctx context.Context, taskID uuid.UUID, taskType TaskType, priority int, params map[string]any) error {
message := NewTaskQueueMessageWithPriority(taskID, taskType, priority)
message.Params = params
// Validate message
if !message.Validate() {
@ -166,10 +167,10 @@ func (p *QueueProducer) PublishTask(ctx context.Context, taskID uuid.UUID, taskT
}
// PublishTaskWithRetry publishes a task with retry logic
func (p *QueueProducer) PublishTaskWithRetry(ctx context.Context, taskID uuid.UUID, taskType TaskType, priority int, maxRetries int) error {
func (p *QueueProducer) PublishTaskWithRetry(ctx context.Context, taskID uuid.UUID, taskType TaskType, priority int, params map[string]any, maxRetries int) error {
var lastErr error
for i := range maxRetries {
err := p.PublishTask(ctx, taskID, taskType, priority)
err := p.PublishTask(ctx, taskID, taskType, priority, params)
if err == nil {
return nil
}
@ -255,7 +256,7 @@ func PushTaskToRabbitMQ(ctx context.Context, cfg config.RabbitMQConfig, taskChan
taskCtx, pubSpan := otel.Tracer("modelRT/task").Start(taskCtx, "task.publish",
oteltrace.WithAttributes(attribute.String("task_id", msg.TaskID.String())),
)
if err := producer.PublishTaskWithRetry(taskCtx, msg.TaskID, msg.TaskType, msg.Priority, 3); err != nil {
if err := producer.PublishTaskWithRetry(taskCtx, msg.TaskID, msg.TaskType, msg.Priority, msg.Params, 3); err != nil {
pubSpan.RecordError(err)
logger.Error(taskCtx, "publish task to RabbitMQ failed",
"task_id", msg.TaskID, "error", err)

View File

@ -120,7 +120,7 @@ func (q *RetryQueue) ProcessRetryQueue(ctx context.Context, batchSize int) error
default:
// Publish task to queue for immediate processing
taskType := TaskType(task.TaskType)
if err := q.producer.PublishTask(ctx, task.TaskID, taskType, task.Priority); err != nil {
if err := q.producer.PublishTask(ctx, task.TaskID, taskType, task.Priority, map[string]any(task.Params)); err != nil {
logger.Error(ctx, "Failed to publish retry task to queue",
"task_id", task.TaskID,
"task_type", taskType,

View File

@ -104,11 +104,11 @@ func (t *TestTask) Execute(ctx context.Context, taskID uuid.UUID, db *gorm.DB) e
// Build result
result := map[string]interface{}{
"status": "completed",
"sleep_duration": params.SleepDuration,
"message": params.Message,
"executed_at": time.Now().Unix(),
"task_id": taskID.String(),
"status": "completed",
"sleep_duration": params.SleepDuration,
"message": params.Message,
"executed_at": time.Now().Unix(),
"task_id": taskID.String(),
}
// Save result to database
@ -130,21 +130,22 @@ func (t *TestTask) Execute(ctx context.Context, taskID uuid.UUID, db *gorm.DB) e
// TestTaskHandler handles test task execution
type TestTaskHandler struct {
*BaseHandler
BaseHandler
}
// NewTestTaskHandler creates a new TestTaskHandler
func NewTestTaskHandler() *TestTaskHandler {
return &TestTaskHandler{
BaseHandler: NewBaseHandler("test_task_handler"),
BaseHandler: *NewBaseHandler("test_task_handler"),
}
}
// Execute processes a test task using the unified task interface
func (h *TestTaskHandler) Execute(ctx context.Context, taskID uuid.UUID, taskType TaskType, params map[string]any, db *gorm.DB) error {
func (h *TestTaskHandler) Execute(ctx context.Context, taskID uuid.UUID, params map[string]any, db *gorm.DB) error {
logger.Info(ctx, "Executing test task",
"task_id", taskID,
"task_type", taskType,
"task_params", params,
"db", db,
)
// Convert params from MQ message to TestTaskParams

View File

@ -1,19 +1,6 @@
package task
import (
"time"
)
type TaskStatus string
const (
StatusPending TaskStatus = "PENDING"
StatusRunning TaskStatus = "RUNNING"
StatusCompleted TaskStatus = "COMPLETED"
StatusFailed TaskStatus = "FAILED"
)
// TaskType 定义异步任务的具体业务类型
// TaskType defines the business type of an async task
type TaskType string
const (
@ -23,34 +10,23 @@ const (
TypeTest TaskType = "TEST"
)
type Task struct {
ID string `bson:"_id" json:"id"`
Type TaskType `bson:"type" json:"type"`
Status TaskStatus `bson:"status" json:"status"`
Priority int `bson:"priority" json:"priority"`
// TaskStatus defines the lifecycle status of an async task
type TaskStatus string
Params map[string]interface{} `bson:"params" json:"params"`
Result map[string]interface{} `bson:"result,omitempty" json:"result"`
ErrorMsg string `bson:"error_msg,omitempty" json:"error_msg"`
const (
StatusPending TaskStatus = "PENDING"
StatusRunning TaskStatus = "RUNNING"
StatusCompleted TaskStatus = "COMPLETED"
StatusFailed TaskStatus = "FAILED"
)
CreatedAt time.Time `bson:"created_at" json:"created_at"`
StartedAt time.Time `bson:"started_at,omitempty" json:"started_at"`
CompletedAt time.Time `bson:"completed_at,omitempty" json:"completed_at"`
}
// UnifiedTaskType defines all async task types in a single location
type UnifiedTaskType string
type TopologyParams struct {
CheckIsland bool `json:"check_island"`
CheckShort bool `json:"check_short"`
BaseModelIDs []string `json:"base_model_ids"`
}
type EventAnalysisParams struct {
MotorID string `json:"motor_id"`
TriggerID string `json:"trigger_id"`
DurationMS int `json:"duration_ms"`
}
type BatchImportParams struct {
FileName string `json:"file_name"`
FilePath string `json:"file_path"`
}
const (
TaskTypeTopologyAnalysis UnifiedTaskType = "TOPOLOGY_ANALYSIS"
TaskTypePerformanceAnalysis UnifiedTaskType = "PERFORMANCE_ANALYSIS"
TaskTypeEventAnalysis UnifiedTaskType = "EVENT_ANALYSIS"
TaskTypeBatchImport UnifiedTaskType = "BATCH_IMPORT"
TaskTypeTest UnifiedTaskType = "TEST"
)

View File

@ -1,138 +0,0 @@
// Package task provides unified task type definitions and interfaces
package task
import (
"context"
"fmt"
"github.com/gofrs/uuid"
"gorm.io/gorm"
)
// UnifiedTaskType defines all async task types in a single location
type UnifiedTaskType string
const (
// TaskTypeTopologyAnalysis represents topology analysis task
TaskTypeTopologyAnalysis UnifiedTaskType = "TOPOLOGY_ANALYSIS"
// TaskTypePerformanceAnalysis represents performance analysis task
TaskTypePerformanceAnalysis UnifiedTaskType = "PERFORMANCE_ANALYSIS"
// TaskTypeEventAnalysis represents event analysis task
TaskTypeEventAnalysis UnifiedTaskType = "EVENT_ANALYSIS"
// TaskTypeBatchImport represents batch import task
TaskTypeBatchImport UnifiedTaskType = "BATCH_IMPORT"
// TaskTypeTest represents test task for system verification
TaskTypeTest UnifiedTaskType = "TEST"
)
// UnifiedTaskStatus defines task status constants
type UnifiedTaskStatus string
const (
// TaskStatusPending represents task waiting to be processed
TaskStatusPending UnifiedTaskStatus = "PENDING"
// TaskStatusRunning represents task currently executing
TaskStatusRunning UnifiedTaskStatus = "RUNNING"
// TaskStatusCompleted represents task finished successfully
TaskStatusCompleted UnifiedTaskStatus = "COMPLETED"
// TaskStatusFailed represents task failed with error
TaskStatusFailed UnifiedTaskStatus = "FAILED"
)
// TaskParams defines the interface for task-specific parameters
// All task types must implement this interface to provide their parameter structure
type TaskParams interface {
// Validate checks if the parameters are valid for this task type
Validate() error
// GetType returns the task type these parameters are for
GetType() UnifiedTaskType
// ToMap converts parameters to map for database storage
ToMap() map[string]interface{}
// FromMap populates parameters from map (for database retrieval)
FromMap(params map[string]interface{}) error
}
// UnifiedTask defines the base interface that all tasks must implement
// This provides a clean abstraction for task execution and management
type UnifiedTask interface {
// GetType returns the task type
GetType() UnifiedTaskType
// GetParams returns the task parameters
GetParams() TaskParams
// Execute performs the actual task logic
Execute(ctx context.Context, taskID uuid.UUID, db *gorm.DB) error
// GetName returns a human-readable task name for logging
GetName() string
// Validate checks if the task is valid before execution
Validate() error
}
// BaseTask provides common functionality for all task implementations
type BaseTask struct {
taskType UnifiedTaskType
params TaskParams
name string
}
// NewBaseTask creates a new BaseTask instance
func NewBaseTask(taskType UnifiedTaskType, params TaskParams, name string) *BaseTask {
return &BaseTask{
taskType: taskType,
params: params,
name: name,
}
}
// GetType returns the task type
func (t *BaseTask) GetType() UnifiedTaskType {
return t.taskType
}
// GetParams returns the task parameters
func (t *BaseTask) GetParams() TaskParams {
return t.params
}
// GetName returns the task name
func (t *BaseTask) GetName() string {
return t.name
}
// Validate checks if the task is valid
func (t *BaseTask) Validate() error {
if t.params == nil {
return fmt.Errorf("task parameters cannot be nil")
}
if t.taskType != t.params.GetType() {
return fmt.Errorf("task type mismatch: expected %s, got %s", t.taskType, t.params.GetType())
}
return t.params.Validate()
}
// IsTaskType checks if a task type string is valid
func IsTaskType(taskType string) bool {
switch UnifiedTaskType(taskType) {
case TaskTypeTopologyAnalysis, TaskTypePerformanceAnalysis,
TaskTypeEventAnalysis, TaskTypeBatchImport, TaskTypeTest:
return true
default:
return false
}
}
// GetTaskTypes returns all registered task types
func GetTaskTypes() []UnifiedTaskType {
return []UnifiedTaskType{
TaskTypeTopologyAnalysis,
TaskTypePerformanceAnalysis,
TaskTypeEventAnalysis,
TaskTypeBatchImport,
TaskTypeTest,
}
}

View File

@ -52,56 +52,56 @@ func DefaultWorkerConfig() WorkerConfig {
// TaskWorker manages a pool of workers for processing asynchronous tasks
type TaskWorker struct {
cfg WorkerConfig
db *gorm.DB
pool *ants.Pool
conn *amqp.Connection
ch *amqp.Channel
handler TaskHandler
retryQueue *RetryQueue
stopChan chan struct{}
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
metrics *WorkerMetrics
cfg WorkerConfig
db *gorm.DB
pool *ants.Pool
conn *amqp.Connection
ch *amqp.Channel
factory *HandlerFactory
retryQueue *RetryQueue
stopChan chan struct{}
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
metrics *WorkerMetrics
}
// WorkerMetrics holds metrics for the worker pool
type WorkerMetrics struct {
// Task statistics by type
TasksProcessed map[TaskType]int64
TasksFailed map[TaskType]int64
TasksSuccess map[TaskType]int64
ProcessingTime map[TaskType]time.Duration
TasksProcessed map[TaskType]int64
TasksFailed map[TaskType]int64
TasksSuccess map[TaskType]int64
ProcessingTime map[TaskType]time.Duration
// Aggregate counters (maintained for backward compatibility)
TotalProcessed int64
TotalFailed int64
TotalSuccess int64
TasksInProgress int32
TotalProcessed int64
TotalFailed int64
TotalSuccess int64
TasksInProgress int32
// Queue and latency metrics
QueueDepth int
QueueLatency time.Duration
QueueDepth int
QueueLatency time.Duration
// Worker resource metrics
WorkersActive int
WorkersIdle int
MemoryUsage uint64
CPULoad float64
WorkersActive int
WorkersIdle int
MemoryUsage uint64
CPULoad float64
// Time window metrics
LastMinuteRate float64
Last5MinutesRate float64
LastHourRate float64
LastMinuteRate float64
Last5MinutesRate float64
LastHourRate float64
// Health and timing
LastHealthCheck time.Time
mu sync.RWMutex
LastHealthCheck time.Time
mu sync.RWMutex
}
// NewTaskWorker creates a new TaskWorker instance
func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg config.RabbitMQConfig, handler TaskHandler) (*TaskWorker, error) {
func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg config.RabbitMQConfig, factory *HandlerFactory) (*TaskWorker, error) {
// Initialize RabbitMQ connection
mq.InitRabbitProxy(ctx, rabbitCfg)
conn := mq.GetConn()
@ -118,10 +118,10 @@ func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg
// Declare queue (ensure it exists with proper arguments)
_, err = ch.QueueDeclare(
constants.TaskQueueName, // name
true, // durable
false, // delete when unused
false, // exclusive
false, // no-wait
true, // durable
false, // delete when unused
false, // exclusive
false, // no-wait
amqp.Table{
"x-max-priority": constants.TaskMaxPriority,
"x-message-ttl": constants.TaskDefaultMessageTTL.Milliseconds(),
@ -158,15 +158,15 @@ func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg
pool: pool,
conn: conn,
ch: ch,
handler: handler,
factory: factory,
stopChan: make(chan struct{}),
ctx: ctxWithCancel,
cancel: cancel,
metrics: &WorkerMetrics{
TasksProcessed: make(map[TaskType]int64),
TasksFailed: make(map[TaskType]int64),
TasksSuccess: make(map[TaskType]int64),
ProcessingTime: make(map[TaskType]time.Duration),
TasksProcessed: make(map[TaskType]int64),
TasksFailed: make(map[TaskType]int64),
TasksSuccess: make(map[TaskType]int64),
ProcessingTime: make(map[TaskType]time.Duration),
LastHealthCheck: time.Now(),
},
}
@ -203,13 +203,13 @@ func (w *TaskWorker) consumerLoop(consumerID int) {
// Consume messages from the queue
msgs, err := w.ch.Consume(
constants.TaskQueueName, // queue
constants.TaskQueueName, // queue
fmt.Sprintf("worker-%d", consumerID), // consumer tag
false, // auto-ack
false, // exclusive
false, // no-local
false, // no-wait
nil, // args
false, // auto-ack
false, // exclusive
false, // no-local
false, // no-wait
nil, // args
)
if err != nil {
logger.Error(w.ctx, "Failed to start consumer",
@ -316,7 +316,7 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) {
// Execute task using handler
startTime := time.Now()
err := w.handler.Execute(ctx, taskMsg.TaskID, taskMsg.TaskType, taskMsg.Params, w.db)
err := w.dispatch(ctx, taskMsg.TaskType, taskMsg.TaskID, taskMsg.Params, &msg)
processingTime := time.Since(startTime)
if err != nil {
@ -431,29 +431,37 @@ func (w *TaskWorker) updateTaskStatus(ctx context.Context, taskID uuid.UUID, sta
return nil
}
// updateTaskWithError updates a task with error information
// updateTaskWithError updates a task with error information in both async_task and async_task_result.
func (w *TaskWorker) updateTaskWithError(ctx context.Context, taskID uuid.UUID, err error) error {
// Update task error information in database
errorMsg := err.Error()
stackTrace := fmt.Sprintf("%+v", err)
updateErr := database.UpdateTaskErrorInfo(ctx, w.db, taskID, errorMsg, stackTrace)
if updateErr != nil {
logger.Error(ctx, "Failed to update task error info",
"task_id", taskID,
"error", updateErr,
)
if updateErr := database.UpdateTaskErrorInfo(ctx, w.db, taskID, errorMsg, stackTrace); updateErr != nil {
logger.Error(ctx, "Failed to update task error info", "task_id", taskID, "error", updateErr)
return updateErr
}
logger.Warn(ctx, "Task failed with error",
"task_id", taskID,
"error", errorMsg,
)
if updateErr := database.UpdateAsyncTaskResultWithError(ctx, w.db, taskID, 500, errorMsg, nil); updateErr != nil {
logger.Error(ctx, "Failed to update task result with error", "task_id", taskID, "error", updateErr)
return updateErr
}
logger.Warn(ctx, "Task failed with error", "task_id", taskID, "error", errorMsg)
return nil
}
// dispatch routes a task message to the appropriate handler and executes it.
// When no handler is registered for the task type, the message is nacked
// without requeue (redelivery cannot succeed for an unknown type) and the
// lookup error is returned.
func (w *TaskWorker) dispatch(ctx context.Context, taskType TaskType, taskID uuid.UUID, params map[string]any, msg *amqp.Delivery) error {
	handler, err := w.factory.GetHandler(taskType)
	if err != nil {
		logger.Error(ctx, "No handler for task type", "task_type", taskType, "task_id", taskID)
		// Nack can itself fail (e.g. closed channel); log rather than drop
		// the error silently, but still surface the original lookup failure.
		if nackErr := msg.Nack(false, false); nackErr != nil {
			logger.Error(ctx, "Failed to nack message", "task_id", taskID, "error", nackErr)
		}
		return err
	}
	return handler.Execute(ctx, taskID, params, w.db)
}
// healthCheckLoop periodically checks worker health and metrics
func (w *TaskWorker) healthCheckLoop() {
defer w.wg.Done()
@ -479,10 +487,10 @@ func (w *TaskWorker) checkHealth() {
// Update queue depth
queue, err := w.ch.QueueDeclarePassive(
constants.TaskQueueName, // name
true, // durable
false, // delete when unused
false, // exclusive
false, // no-wait
true, // durable
false, // delete when unused
false, // exclusive
false, // no-wait
amqp.Table{
"x-max-priority": constants.TaskMaxPriority,
"x-message-ttl": constants.TaskDefaultMessageTTL.Milliseconds(),
@ -565,24 +573,24 @@ func (w *TaskWorker) GetMetrics() *WorkerMetrics {
// Create a copy without the mutex to avoid copylocks warning
return &WorkerMetrics{
TasksProcessed: tasksProcessedCopy,
TasksFailed: tasksFailedCopy,
TasksSuccess: tasksSuccessCopy,
ProcessingTime: processingTimeCopy,
TotalProcessed: w.metrics.TotalProcessed,
TotalFailed: w.metrics.TotalFailed,
TotalSuccess: w.metrics.TotalSuccess,
TasksInProgress: w.metrics.TasksInProgress,
QueueDepth: w.metrics.QueueDepth,
QueueLatency: w.metrics.QueueLatency,
WorkersActive: w.metrics.WorkersActive,
WorkersIdle: w.metrics.WorkersIdle,
MemoryUsage: w.metrics.MemoryUsage,
CPULoad: w.metrics.CPULoad,
LastMinuteRate: w.metrics.LastMinuteRate,
TasksProcessed: tasksProcessedCopy,
TasksFailed: tasksFailedCopy,
TasksSuccess: tasksSuccessCopy,
ProcessingTime: processingTimeCopy,
TotalProcessed: w.metrics.TotalProcessed,
TotalFailed: w.metrics.TotalFailed,
TotalSuccess: w.metrics.TotalSuccess,
TasksInProgress: w.metrics.TasksInProgress,
QueueDepth: w.metrics.QueueDepth,
QueueLatency: w.metrics.QueueLatency,
WorkersActive: w.metrics.WorkersActive,
WorkersIdle: w.metrics.WorkersIdle,
MemoryUsage: w.metrics.MemoryUsage,
CPULoad: w.metrics.CPULoad,
LastMinuteRate: w.metrics.LastMinuteRate,
Last5MinutesRate: w.metrics.Last5MinutesRate,
LastHourRate: w.metrics.LastHourRate,
LastHealthCheck: w.metrics.LastHealthCheck,
LastHourRate: w.metrics.LastHourRate,
LastHealthCheck: w.metrics.LastHealthCheck,
// Mutex is intentionally omitted
}
}
@ -595,6 +603,7 @@ func (w *TaskWorker) IsHealthy() bool {
// Consider unhealthy if last health check was too long ago
return time.Since(w.metrics.LastHealthCheck) < 2*w.cfg.PollingInterval
}
// RecordMetrics periodically records worker metrics to the logging system
func (w *TaskWorker) RecordMetrics(interval time.Duration) {
go func() {