Refactor: extract task constants to dedicated constants package

- Add constants/task.go with centralized task-related constants
    - Task priority levels (default, high, low)
    - Task queue configuration (exchange, queue, routing key)
    - Task message settings (max priority, TTL)
    - Task retry settings (max retries, delays)
    - Test task settings (sleep duration, max limit)

  - Update task-related files to use constants from constants package:
    - handler/async_task_create_handler.go
    - task/queue_message.go
    - task/queue_producer.go
    - task/retry_manager.go
    - task/test_task.go
    - task/types.go (add TypeTest)
    - task/worker.go
This commit is contained in:
douxu 2026-04-22 17:20:26 +08:00
parent 4a3f7a65bc
commit 809e1cd87d
9 changed files with 210 additions and 201 deletions

54
constants/task.go Normal file
View File

@ -0,0 +1,54 @@
// Package constants defines task-related constants for the async task system
package constants
import "time"
// Task priority levels
const (
// TaskPriorityDefault is the default priority level for tasks
TaskPriorityDefault = 5
// TaskPriorityHigh represents high priority tasks
TaskPriorityHigh = 10
// TaskPriorityLow represents low priority tasks
TaskPriorityLow = 1
)
// Task queue configuration
const (
// TaskExchangeName is the name of the exchange for task routing
TaskExchangeName = "modelrt.tasks.exchange"
// TaskQueueName is the name of the main task queue
TaskQueueName = "modelrt.tasks.queue"
// TaskRoutingKey is the routing key for task messages
TaskRoutingKey = "modelrt.task"
)
// Task message settings
const (
// TaskMaxPriority is the maximum priority level for tasks (0-10)
TaskMaxPriority = 10
// TaskDefaultMessageTTL is the default time-to-live for task messages (24 hours)
TaskDefaultMessageTTL = 24 * time.Hour
)
// Task retry settings
const (
// TaskRetryMaxDefault is the default maximum number of retry attempts
TaskRetryMaxDefault = 3
// TaskRetryInitialDelayDefault is the default initial delay for exponential backoff
TaskRetryInitialDelayDefault = 1 * time.Second
// TaskRetryMaxDelayDefault is the default maximum delay for exponential backoff
TaskRetryMaxDelayDefault = 5 * time.Minute
// TaskRetryRandomFactorDefault is the default random factor for jitter (10%)
TaskRetryRandomFactorDefault = 0.1
// TaskRetryFixedDelayDefault is the default delay for fixed retry strategy
TaskRetryFixedDelayDefault = 5 * time.Second
)
// Test task settings
const (
// TestTaskSleepDurationDefault is the default sleep duration for test tasks (60 seconds)
TestTaskSleepDurationDefault = 60
// TestTaskSleepDurationMax is the maximum allowed sleep duration for test tasks (1 hour)
TestTaskSleepDurationMax = 3600
)

View File

@ -2,10 +2,8 @@
package handler
import (
"net/http"
"strings"
"modelRT/config"
"modelRT/constants"
"modelRT/database"
"modelRT/logger"
"modelRT/network"
@ -13,7 +11,6 @@ import (
"modelRT/task"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// AsyncTaskCreateHandler handles creation of asynchronous tasks
@ -32,59 +29,44 @@ func AsyncTaskCreateHandler(c *gin.Context) {
var request network.AsyncTaskCreateRequest
if err := c.ShouldBindJSON(&request); err != nil {
logger.Error(ctx, "failed to unmarshal async task create request", "error", err)
c.JSON(http.StatusOK, network.FailureResponse{
Code: http.StatusBadRequest,
Msg: "invalid request parameters",
})
logger.Error(ctx, "unmarshal async task create request failed", "error", err)
renderRespFailure(c, constants.RespCodeInvalidParams, "invalid request parameters", nil)
return
}
// Validate task type
// validate task type
if !orm.IsValidAsyncTaskType(request.TaskType) {
logger.Error(ctx, "invalid task type", "task_type", request.TaskType)
c.JSON(http.StatusOK, network.FailureResponse{
Code: http.StatusBadRequest,
Msg: "invalid task type",
})
logger.Error(ctx, "check task type invalid", "task_type", request.TaskType)
renderRespFailure(c, constants.RespCodeInvalidParams, "invalid task type", nil)
return
}
// Validate task parameters based on task type
// validate task parameters based on task type
if !validateTaskParams(request.TaskType, request.Params) {
logger.Error(ctx, "invalid task parameters", "task_type", request.TaskType, "params", request.Params)
c.JSON(http.StatusOK, network.FailureResponse{
Code: http.StatusBadRequest,
Msg: "invalid task parameters",
})
logger.Error(ctx, "check task parameters invalid", "task_type", request.TaskType, "params", request.Params)
renderRespFailure(c, constants.RespCodeInvalidParams, "invalid task parameters", nil)
return
}
pgClient := database.GetPostgresDBClient()
if pgClient == nil {
logger.Error(ctx, "database connection not found in context")
c.JSON(http.StatusOK, network.FailureResponse{
Code: http.StatusInternalServerError,
Msg: "database connection error",
})
renderRespFailure(c, constants.RespCodeServerError, "database connection error", nil)
return
}
// Create task in database
// create task in database
taskType := orm.AsyncTaskType(request.TaskType)
params := orm.JSONMap(request.Params)
asyncTask, err := database.CreateAsyncTask(ctx, pgClient, taskType, params)
if err != nil {
logger.Error(ctx, "failed to create async task in database", "error", err)
c.JSON(http.StatusOK, network.FailureResponse{
Code: http.StatusInternalServerError,
Msg: "failed to create task",
})
logger.Error(ctx, "create async task in database failed", "error", err)
renderRespFailure(c, constants.RespCodeServerError, "failed to create task", nil)
return
}
// Send task to message queue
// send task to message queue
cfg, exists := c.Get("config")
if !exists {
logger.Warn(ctx, "Configuration not found in context, skipping queue publishing")
@ -92,41 +74,36 @@ func AsyncTaskCreateHandler(c *gin.Context) {
modelRTConfig := cfg.(config.ModelRTConfig)
ctx := c.Request.Context()
// Create queue producer
// create queue producer
// TODO 像实时计算一样使用 channel 代替
producer, err := task.NewQueueProducer(ctx, modelRTConfig.RabbitMQConfig)
if err != nil {
logger.Error(ctx, "Failed to create queue producer", "error", err)
// Continue without queue publishing
} else {
defer producer.Close()
// Publish task to queue
taskType := task.TaskType(request.TaskType)
priority := 5 // Default priority
if err := producer.PublishTaskWithRetry(ctx, asyncTask.TaskID, taskType, priority, 3); err != nil {
logger.Error(ctx, "Failed to publish task to queue",
"task_id", asyncTask.TaskID,
"error", err)
// Log error but don't affect task creation response
} else {
logger.Info(ctx, "Task published to queue successfully",
"task_id", asyncTask.TaskID,
"queue", task.TaskQueueName)
}
logger.Error(ctx, "create rabbitMQ queue producer failed", "error", err)
renderRespFailure(c, constants.RespCodeServerError, "create rabbitMQ queue producer failed", nil)
return
}
defer producer.Close()
// publish task to queue
taskType := task.TaskType(request.TaskType)
priority := 5 // Default priority
if err := producer.PublishTaskWithRetry(ctx, asyncTask.TaskID, taskType, priority, 3); err != nil {
logger.Error(ctx, "publish task to rabbitMQ queue failed",
"task_id", asyncTask.TaskID, "error", err)
renderRespFailure(c, constants.RespCodeServerError, "publish task to rabbitMQ queue failed", nil)
return
}
logger.Info(ctx, "published task to rabbitMQ queue successfully",
"task_id", asyncTask.TaskID, "queue", constants.TaskQueueName)
}
logger.Info(ctx, "async task created successfully", "task_id", asyncTask.TaskID, "task_type", request.TaskType)
logger.Info(ctx, "async task created success", "task_id", asyncTask.TaskID, "task_type", request.TaskType)
// Return success response
c.JSON(http.StatusOK, network.SuccessResponse{
Code: 2000,
Msg: "task created successfully",
Payload: network.AsyncTaskCreateResponse{
TaskID: asyncTask.TaskID,
},
})
// return success response
payload := genAsyncTaskCreatePayload(asyncTask.TaskID.String())
renderRespSuccess(c, constants.RespCodeSuccess, "task created successfully", payload)
}
func validateTaskParams(taskType string, params map[string]any) bool {
@ -189,54 +166,9 @@ func validateTestTaskParams(params map[string]any) bool {
return true
}
func splitCommaSeparated(s string) []string {
var result []string
var current strings.Builder
inQuotes := false
escape := false
for _, ch := range s {
if escape {
current.WriteRune(ch)
escape = false
continue
}
switch ch {
case '\\':
escape = true
case '"':
inQuotes = !inQuotes
case ',':
if !inQuotes {
result = append(result, strings.TrimSpace(current.String()))
current.Reset()
} else {
current.WriteRune(ch)
}
default:
current.WriteRune(ch)
}
func genAsyncTaskCreatePayload(taskID string) map[string]any {
payload := map[string]any{
"task_id": taskID,
}
if current.Len() > 0 {
result = append(result, strings.TrimSpace(current.String()))
}
return result
}
func getDBFromContext(c *gin.Context) *gorm.DB {
// Try to get database connection from context
// This should be set by middleware
if db, exists := c.Get("db"); exists {
if gormDB, ok := db.(*gorm.DB); ok {
return gormDB
}
}
// Fallback to global database connection
// This should be implemented based on your application's database setup
// For now, return nil - actual implementation should retrieve from application context
return nil
return payload
}

View File

@ -3,6 +3,7 @@ package handler
import (
"net/http"
"strings"
"modelRT/database"
"modelRT/logger"
@ -142,3 +143,40 @@ func AsyncTaskResultQueryHandler(c *gin.Context) {
},
})
}
func splitCommaSeparated(s string) []string {
var result []string
var current strings.Builder
inQuotes := false
escape := false
for _, ch := range s {
if escape {
current.WriteRune(ch)
escape = false
continue
}
switch ch {
case '\\':
escape = true
case '"':
inQuotes = !inQuotes
case ',':
if !inQuotes {
result = append(result, strings.TrimSpace(current.String()))
current.Reset()
} else {
current.WriteRune(ch)
}
default:
current.WriteRune(ch)
}
}
if current.Len() > 0 {
result = append(result, strings.TrimSpace(current.String()))
}
return result
}

View File

@ -3,24 +3,17 @@ package task
import (
"encoding/json"
"modelRT/constants"
"github.com/gofrs/uuid"
)
// DefaultPriority is the default task priority
const DefaultPriority = 5
// HighPriority represents high priority tasks
const HighPriority = 10
// LowPriority represents low priority tasks
const LowPriority = 1
// TaskQueueMessage defines minimal message structure for RabbitMQ/Redis queue dispatch
// This struct is designed to be lightweight for efficient message transport
type TaskQueueMessage struct {
TaskID uuid.UUID `json:"task_id"`
TaskType TaskType `json:"task_type"`
Priority int `json:"priority,omitempty"` // Optional, defaults to DefaultPriority
Priority int `json:"priority,omitempty"` // Optional, defaults to constants.TaskPriorityDefault
}
// NewTaskQueueMessage creates a new TaskQueueMessage with default priority
@ -28,7 +21,7 @@ func NewTaskQueueMessage(taskID uuid.UUID, taskType TaskType) *TaskQueueMessage
return &TaskQueueMessage{
TaskID: taskID,
TaskType: taskType,
Priority: DefaultPriority,
Priority: constants.TaskPriorityDefault,
}
}
@ -41,12 +34,12 @@ func NewTaskQueueMessageWithPriority(taskID uuid.UUID, taskType TaskType, priori
}
}
// ToJSON converts the TaskQueueMessage to JSON bytes
// ToJSON converts TaskQueueMessage to JSON bytes
func (m *TaskQueueMessage) ToJSON() ([]byte, error) {
return json.Marshal(m)
}
// Validate checks if the TaskQueueMessage is valid
// Validate checks if TaskQueueMessage is valid
func (m *TaskQueueMessage) Validate() bool {
// Check if TaskID is valid (not nil UUID)
if m.TaskID == uuid.Nil {
@ -55,25 +48,25 @@ func (m *TaskQueueMessage) Validate() bool {
// Check if TaskType is valid
switch m.TaskType {
case TypeTopologyAnalysis, TypeEventAnalysis, TypeBatchImport:
case TypeTopologyAnalysis, TypeEventAnalysis, TypeBatchImport, TypeTest:
return true
default:
return false
}
}
// SetPriority sets the priority of the task queue message with validation
// SetPriority sets priority of task queue message with validation
func (m *TaskQueueMessage) SetPriority(priority int) {
if priority < LowPriority {
priority = LowPriority
if priority < constants.TaskPriorityLow {
priority = constants.TaskPriorityLow
}
if priority > HighPriority {
priority = HighPriority
if priority > constants.TaskPriorityHigh {
priority = constants.TaskPriorityHigh
}
m.Priority = priority
}
// GetPriority returns the priority of the task queue message
// GetPriority returns priority of task queue message
func (m *TaskQueueMessage) GetPriority() int {
return m.Priority
}

View File

@ -8,6 +8,7 @@ import (
"time"
"modelRT/config"
"modelRT/constants"
"modelRT/logger"
"modelRT/mq"
@ -15,19 +16,6 @@ import (
amqp "github.com/rabbitmq/amqp091-go"
)
const (
// TaskExchangeName is the name of the exchange for task routing
TaskExchangeName = "modelrt.tasks.exchange"
// TaskQueueName is the name of the main task queue
TaskQueueName = "modelrt.tasks.queue"
// TaskRoutingKey is the routing key for task messages
TaskRoutingKey = "modelrt.task"
// MaxPriority is the maximum priority level for tasks (0-10)
MaxPriority = 10
// DefaultMessageTTL is the default time-to-live for task messages (24 hours)
DefaultMessageTTL = 24 * time.Hour
)
// QueueProducer handles publishing tasks to RabbitMQ
type QueueProducer struct {
conn *amqp.Connection
@ -67,13 +55,13 @@ func NewQueueProducer(ctx context.Context, cfg config.RabbitMQConfig) (*QueuePro
func (p *QueueProducer) declareInfrastructure() error {
// Declare durable direct exchange
err := p.ch.ExchangeDeclare(
TaskExchangeName, // name
"direct", // type
true, // durable
false, // auto-deleted
false, // internal
false, // no-wait
nil, // arguments
constants.TaskExchangeName, // name
"direct", // type
true, // durable
false, // auto-deleted
false, // internal
false, // no-wait
nil, // arguments
)
if err != nil {
return fmt.Errorf("failed to declare exchange: %w", err)
@ -81,14 +69,14 @@ func (p *QueueProducer) declareInfrastructure() error {
// Declare durable queue with priority support and message TTL
_, err = p.ch.QueueDeclare(
TaskQueueName, // name
true, // durable
false, // delete when unused
false, // exclusive
false, // no-wait
constants.TaskQueueName, // name
true, // durable
false, // delete when unused
false, // exclusive
false, // no-wait
amqp.Table{
"x-max-priority": MaxPriority, // support priority levels 0-10
"x-message-ttl": DefaultMessageTTL.Milliseconds(), // message TTL
"x-max-priority": constants.TaskMaxPriority, // support priority levels 0-10
"x-message-ttl": constants.TaskDefaultMessageTTL.Milliseconds(), // message TTL
},
)
if err != nil {
@ -97,11 +85,11 @@ func (p *QueueProducer) declareInfrastructure() error {
// Bind queue to exchange
err = p.ch.QueueBind(
TaskQueueName, // queue name
TaskRoutingKey, // routing key
TaskExchangeName, // exchange name
false, // no-wait
nil, // arguments
constants.TaskQueueName, // queue name
constants.TaskRoutingKey, // routing key
constants.TaskExchangeName, // exchange name
false, // no-wait
nil, // arguments
)
if err != nil {
return fmt.Errorf("failed to bind queue: %w", err)
@ -141,10 +129,10 @@ func (p *QueueProducer) PublishTask(ctx context.Context, taskID uuid.UUID, taskT
// Publish to exchange
err = p.ch.PublishWithContext(
ctx,
TaskExchangeName, // exchange
TaskRoutingKey, // routing key
false, // mandatory
false, // immediate
constants.TaskExchangeName, // exchange
constants.TaskRoutingKey, // routing key
false, // mandatory
false, // immediate
publishing,
)
if err != nil {
@ -155,7 +143,7 @@ func (p *QueueProducer) PublishTask(ctx context.Context, taskID uuid.UUID, taskT
"task_id", taskID.String(),
"task_type", taskType,
"priority", priority,
"queue", TaskQueueName,
"queue", constants.TaskQueueName,
)
return nil
@ -205,14 +193,14 @@ func (p *QueueProducer) Close() error {
// GetQueueInfo returns information about the task queue
func (p *QueueProducer) GetQueueInfo() (*amqp.Queue, error) {
queue, err := p.ch.QueueDeclarePassive(
TaskQueueName, // name
constants.TaskQueueName, // name
true, // durable
false, // delete when unused
false, // exclusive
false, // no-wait
amqp.Table{
"x-max-priority": MaxPriority,
"x-message-ttl": DefaultMessageTTL.Milliseconds(),
"x-max-priority": constants.TaskMaxPriority,
"x-message-ttl": constants.TaskDefaultMessageTTL.Milliseconds(),
},
)
if err != nil {
@ -223,5 +211,5 @@ func (p *QueueProducer) GetQueueInfo() (*amqp.Queue, error) {
// PurgeQueue removes all messages from the task queue
func (p *QueueProducer) PurgeQueue() (int, error) {
return p.ch.QueuePurge(TaskQueueName, false)
return p.ch.QueuePurge(constants.TaskQueueName, false)
}

View File

@ -8,12 +8,13 @@ import (
"strings"
"time"
"modelRT/constants"
"modelRT/logger"
)
// RetryStrategy defines the interface for task retry strategies
type RetryStrategy interface {
// ShouldRetry determines if a task should be retried and returns the delay before next retry
// ShouldRetry determines if a task should be retried and returns delay before next retry
ShouldRetry(ctx context.Context, taskID string, retryCount int, lastError error) (bool, time.Duration)
// GetMaxRetries returns the maximum number of retry attempts
GetMaxRetries() int
@ -98,7 +99,7 @@ func (s *ExponentialBackoffRetry) ShouldRetry(ctx context.Context, taskID string
return true, delay
}
// GetMaxRetries returns the maximum number of retry attempts
// GetMaxRetries returns maximum number of retry attempts
func (s *ExponentialBackoffRetry) GetMaxRetries() int {
return s.MaxRetries
}
@ -151,7 +152,7 @@ func (s *FixedDelayRetry) ShouldRetry(ctx context.Context, taskID string, retryC
return true, delay
}
// GetMaxRetries returns the maximum number of retry attempts
// GetMaxRetries returns maximum number of retry attempts
func (s *FixedDelayRetry) GetMaxRetries() int {
return s.MaxRetries
}
@ -177,10 +178,10 @@ func (s *NoRetryStrategy) GetMaxRetries() int {
// DefaultRetryStrategy returns the default retry strategy (exponential backoff)
func DefaultRetryStrategy() RetryStrategy {
return NewExponentialBackoffRetry(
3, // max retries
1*time.Second, // initial delay
5*time.Minute, // max delay
0.1, // random factor (10% jitter)
constants.TaskRetryMaxDefault, // max retries
constants.TaskRetryInitialDelayDefault, // initial delay
constants.TaskRetryMaxDelayDefault, // max delay
constants.TaskRetryRandomFactorDefault, // random factor (10% jitter)
)
}
@ -216,4 +217,4 @@ func IsRetryableError(err error) bool {
}
return false
}
}

View File

@ -6,6 +6,7 @@ import (
"fmt"
"time"
"modelRT/constants"
"modelRT/database"
"modelRT/logger"
"modelRT/orm"
@ -17,7 +18,7 @@ import (
// TestTaskParams defines parameters for test task
type TestTaskParams struct {
// SleepDuration specifies how long the task should sleep (in seconds)
// Default is 60 seconds as per requirement
// Default is constants.TestTaskSleepDurationDefault seconds as per requirement
SleepDuration int `json:"sleep_duration"`
// Message is a custom message to include in the result
Message string `json:"message,omitempty"`
@ -25,14 +26,14 @@ type TestTaskParams struct {
// Validate checks if test task parameters are valid
func (p *TestTaskParams) Validate() error {
// Default to 60 seconds if not specified
// Default to constants.TestTaskSleepDurationDefault seconds if not specified
if p.SleepDuration <= 0 {
p.SleepDuration = 60
p.SleepDuration = constants.TestTaskSleepDurationDefault
}
// Validate max duration (max 1 hour)
if p.SleepDuration > 3600 {
return fmt.Errorf("sleep duration cannot exceed 3600 seconds (1 hour)")
if p.SleepDuration > constants.TestTaskSleepDurationMax {
return fmt.Errorf("sleep duration cannot exceed %d seconds (1 hour)", constants.TestTaskSleepDurationMax)
}
return nil
@ -90,7 +91,7 @@ func (t *TestTask) Execute(ctx context.Context, taskID uuid.UUID, db *gorm.DB) e
return fmt.Errorf("invalid parameter type for TestTask")
}
logger.Info(ctx, "Starting test task execution",
logger.Info(ctx, "Starting test task executionser",
"task_id", taskID,
"sleep_duration_seconds", params.SleepDuration,
"message", params.Message,
@ -149,7 +150,7 @@ func (h *TestTaskHandler) Execute(ctx context.Context, taskID uuid.UUID, taskTyp
// Fetch task parameters from database
asyncTask, err := database.GetAsyncTaskByID(ctx, db, taskID)
if err != nil {
return fmt.Errorf("failed to fetch task: %w", err)
return fmt.Errorf("failed toser fetch task: %w", err)
}
// Convert params map to TestTaskParams

View File

@ -20,6 +20,7 @@ const (
TypeTopologyAnalysis TaskType = "TOPOLOGY_ANALYSIS"
TypeEventAnalysis TaskType = "EVENT_ANALYSIS"
TypeBatchImport TaskType = "BATCH_IMPORT"
TypeTest TaskType = "TEST"
)
type Task struct {

View File

@ -9,6 +9,7 @@ import (
"time"
"modelRT/config"
"modelRT/constants"
"modelRT/database"
"modelRT/logger"
"modelRT/mq"
@ -112,14 +113,14 @@ func NewTaskWorker(ctx context.Context, cfg WorkerConfig, db *gorm.DB, rabbitCfg
// Declare queue (ensure it exists with proper arguments)
_, err = ch.QueueDeclare(
TaskQueueName, // name
true, // durable
false, // delete when unused
false, // exclusive
false, // no-wait
constants.TaskQueueName, // name
true, // durable
false, // delete when unused
false, // exclusive
false, // no-wait
amqp.Table{
"x-max-priority": MaxPriority,
"x-message-ttl": DefaultMessageTTL.Milliseconds(),
"x-max-priority": constants.TaskMaxPriority,
"x-message-ttl": constants.TaskDefaultMessageTTL.Milliseconds(),
},
)
if err != nil {
@ -198,7 +199,7 @@ func (w *TaskWorker) consumerLoop(consumerID int) {
// Consume messages from the queue
msgs, err := w.ch.Consume(
TaskQueueName, // queue
constants.TaskQueueName, // queue
fmt.Sprintf("worker-%d", consumerID), // consumer tag
false, // auto-ack
false, // exclusive
@ -462,14 +463,14 @@ func (w *TaskWorker) checkHealth() {
// Update queue depth
queue, err := w.ch.QueueDeclarePassive(
TaskQueueName, // name
true, // durable
false, // delete when unused
false, // exclusive
false, // no-wait
constants.TaskQueueName, // name
true, // durable
false, // delete when unused
false, // exclusive
false, // no-wait
amqp.Table{
"x-max-priority": MaxPriority,
"x-message-ttl": DefaultMessageTTL.Milliseconds(),
"x-max-priority": constants.TaskMaxPriority,
"x-message-ttl": constants.TaskDefaultMessageTTL.Milliseconds(),
},
)
if err == nil {