175 lines
5.7 KiB
Go
175 lines
5.7 KiB
Go
// Package handler provides HTTP handlers for various endpoints.
|
||
package handler
|
||
|
||
import (
|
||
"modelRT/config"
|
||
"modelRT/constants"
|
||
"modelRT/database"
|
||
"modelRT/logger"
|
||
"modelRT/network"
|
||
"modelRT/orm"
|
||
"modelRT/task"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// AsyncTaskCreateHandler handles creation of asynchronous tasks
|
||
// @Summary 创建异步任务
|
||
// @Description 创建新的异步任务并返回任务ID,任务将被提交到队列等待处理
|
||
// @Tags AsyncTask
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param request body network.AsyncTaskCreateRequest true "任务创建请求"
|
||
// @Success 200 {object} network.SuccessResponse{payload=network.AsyncTaskCreateResponse} "任务创建成功"
|
||
// @Failure 400 {object} network.FailureResponse "请求参数错误"
|
||
// @Failure 500 {object} network.FailureResponse "服务器内部错误"
|
||
// @Router /task/async [post]
|
||
func AsyncTaskCreateHandler(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
var request network.AsyncTaskCreateRequest
|
||
|
||
if err := c.ShouldBindJSON(&request); err != nil {
|
||
logger.Error(ctx, "unmarshal async task create request failed", "error", err)
|
||
renderRespFailure(c, constants.RespCodeInvalidParams, "invalid request parameters", nil)
|
||
return
|
||
}
|
||
|
||
// validate task type
|
||
if !orm.IsValidAsyncTaskType(request.TaskType) {
|
||
logger.Error(ctx, "check task type invalid", "task_type", request.TaskType)
|
||
renderRespFailure(c, constants.RespCodeInvalidParams, "invalid task type", nil)
|
||
return
|
||
}
|
||
|
||
// validate task parameters based on task type
|
||
if !validateTaskParams(request.TaskType, request.Params) {
|
||
logger.Error(ctx, "check task parameters invalid", "task_type", request.TaskType, "params", request.Params)
|
||
renderRespFailure(c, constants.RespCodeInvalidParams, "invalid task parameters", nil)
|
||
return
|
||
}
|
||
|
||
pgClient := database.GetPostgresDBClient()
|
||
if pgClient == nil {
|
||
logger.Error(ctx, "database connection not found in context")
|
||
renderRespFailure(c, constants.RespCodeServerError, "database connection error", nil)
|
||
return
|
||
}
|
||
|
||
// create task in database
|
||
taskType := orm.AsyncTaskType(request.TaskType)
|
||
params := orm.JSONMap(request.Params)
|
||
|
||
asyncTask, err := database.CreateAsyncTask(ctx, pgClient, taskType, params)
|
||
if err != nil {
|
||
logger.Error(ctx, "create async task in database failed", "error", err)
|
||
renderRespFailure(c, constants.RespCodeServerError, "failed to create task", nil)
|
||
return
|
||
}
|
||
|
||
// send task to message queue
|
||
cfg, exists := c.Get("config")
|
||
if !exists {
|
||
logger.Warn(ctx, "Configuration not found in context, skipping queue publishing")
|
||
} else {
|
||
modelRTConfig := cfg.(config.ModelRTConfig)
|
||
ctx := c.Request.Context()
|
||
|
||
// create queue producer
|
||
// TODO 像实时计算一样使用 channel 代替
|
||
producer, err := task.NewQueueProducer(ctx, modelRTConfig.RabbitMQConfig)
|
||
if err != nil {
|
||
logger.Error(ctx, "create rabbitMQ queue producer failed", "error", err)
|
||
renderRespFailure(c, constants.RespCodeServerError, "create rabbitMQ queue producer failed", nil)
|
||
return
|
||
}
|
||
defer producer.Close()
|
||
|
||
// publish task to queue
|
||
taskType := task.TaskType(request.TaskType)
|
||
priority := 5 // Default priority
|
||
|
||
if err := producer.PublishTaskWithRetry(ctx, asyncTask.TaskID, taskType, priority, 3); err != nil {
|
||
logger.Error(ctx, "publish task to rabbitMQ queue failed",
|
||
"task_id", asyncTask.TaskID, "error", err)
|
||
renderRespFailure(c, constants.RespCodeServerError, "publish task to rabbitMQ queue failed", nil)
|
||
return
|
||
}
|
||
logger.Info(ctx, "published task to rabbitMQ queue successfully",
|
||
"task_id", asyncTask.TaskID, "queue", constants.TaskQueueName)
|
||
|
||
}
|
||
|
||
logger.Info(ctx, "async task created success", "task_id", asyncTask.TaskID, "task_type", request.TaskType)
|
||
|
||
// return success response
|
||
payload := genAsyncTaskCreatePayload(asyncTask.TaskID.String())
|
||
renderRespSuccess(c, constants.RespCodeSuccess, "task created successfully", payload)
|
||
}
|
||
|
||
func validateTaskParams(taskType string, params map[string]any) bool {
|
||
switch taskType {
|
||
case string(orm.AsyncTaskTypeTopologyAnalysis):
|
||
return validateTopologyAnalysisParams(params)
|
||
case string(orm.AsyncTaskTypePerformanceAnalysis):
|
||
return validatePerformanceAnalysisParams(params)
|
||
case string(orm.AsyncTaskTypeEventAnalysis):
|
||
return validateEventAnalysisParams(params)
|
||
case string(orm.AsyncTaskTypeBatchImport):
|
||
return validateBatchImportParams(params)
|
||
case string(orm.AsyncTaskTypeTest):
|
||
return validateTestTaskParams(params)
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
// validateTopologyAnalysisParams reports whether params contains non-empty
// string values for both "start_uuid" and "end_uuid".
//
// The values arrive as `any` from the decoded JSON body. They are
// type-asserted to string so that JSON null, numbers, booleans, etc. are
// rejected rather than silently accepted. Comparing an `any` against ""
// only catches the empty-string case. A nil map yields false.
func validateTopologyAnalysisParams(params map[string]any) bool {
	for _, key := range []string{"start_uuid", "end_uuid"} {
		if value, ok := params[key].(string); !ok || value == "" {
			return false
		}
	}
	return true
}
|
||
|
||
// validatePerformanceAnalysisParams reports whether params["component_ids"]
// is a non-empty []any, the shape encoding/json produces for a JSON array.
// A missing key, a non-slice value, or an empty slice all yield false.
func validatePerformanceAnalysisParams(params map[string]any) bool {
	componentIDs, _ := params["component_ids"].([]any)
	return len(componentIDs) > 0
}
|
||
|
||
// validateEventAnalysisParams reports whether params contains a non-empty
// string under "event_type".
//
// The value is type-asserted to string. Comparing the `any` value with ""
// would let JSON null or non-string values through as valid.
func validateEventAnalysisParams(params map[string]any) bool {
	eventType, ok := params["event_type"].(string)
	return ok && eventType != ""
}
|
||
|
||
// validateBatchImportParams reports whether params contains a non-empty
// string under "file_path".
//
// The value is type-asserted to string. Comparing the `any` value with ""
// would let JSON null or non-string values through as valid.
//
// NOTE(review): file_path comes straight from the HTTP request body and is not
// restricted to any directory here. Confirm that whatever consumes it
// confines reads to an allowed root.
func validateBatchImportParams(params map[string]any) bool {
	filePath, ok := params["file_path"].(string)
	return ok && filePath != ""
}
|
||
|
||
// validateTestTaskParams accepts any parameter set for the test task type;
// every parameter is optional, so the map is not inspected.
//
// NOTE(review): the original comment says sleep_duration defaults to 60
// seconds when omitted. That default is applied elsewhere, not here, so
// confirm it against the worker.
func validateTestTaskParams(_ map[string]any) bool {
	const allParamsOptional = true
	return allParamsOptional
}
|
||
|
||
// genAsyncTaskCreatePayload builds the success payload for a newly created
// task: a single-entry map holding the task ID under "task_id".
func genAsyncTaskCreatePayload(taskID string) map[string]any {
	return map[string]any{"task_id": taskID}
}
|