From de5f976c317411d53c523c52c01d0f2c4c2b33c7 Mon Sep 17 00:00:00 2001 From: douxu Date: Tue, 17 Mar 2026 16:08:46 +0800 Subject: [PATCH] add route of async task system --- .gitignore | 12 + handler/async_task_handler.go | 677 ++++++++++++++++++++++++++++++++++ router/async_task.go | 32 ++ router/router.go | 1 + 4 files changed, 722 insertions(+) create mode 100644 handler/async_task_handler.go create mode 100644 router/async_task.go diff --git a/.gitignore b/.gitignore index 673582e..b338b88 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,15 @@ go.work # Shield config files in the configs folder /configs/**/*.yaml /configs/**/*.pem + +# ai config +.cursor/ +.claude/ +.cursorrules +.copilot/ +.chatgpt/ +.ai_history/ +.vector_cache/ +ai-debug.log +*.patch +*.diff \ No newline at end of file diff --git a/handler/async_task_handler.go b/handler/async_task_handler.go new file mode 100644 index 0000000..9808e33 --- /dev/null +++ b/handler/async_task_handler.go @@ -0,0 +1,677 @@ +// Package handler provides HTTP handlers for various endpoints. +package handler + +import ( + "net/http" + "strings" + "time" + + "modelRT/database" + "modelRT/logger" + "modelRT/network" + "modelRT/orm" + _ "modelRT/task" + + "github.com/gin-gonic/gin" + "github.com/gofrs/uuid" + "gorm.io/gorm" +) + +// AsyncTaskCreateHandler handles creation of asynchronous tasks +// @Summary 创建异步任务 +// @Description 创建新的异步任务并返回任务ID,任务将被提交到队列等待处理 +// @Tags AsyncTask +// @Accept json +// @Produce json +// @Param request body network.AsyncTaskCreateRequest true "任务创建请求" +// @Success 200 {object} network.SuccessResponse{payload=network.AsyncTaskCreateResponse} "任务创建成功" +// @Failure 400 {object} network.FailureResponse "请求参数错误" +// @Failure 500 {object} network.FailureResponse "服务器内部错误" +// @Router /task/async [post] +func AsyncTaskCreateHandler(c *gin.Context) { + ctx := c.Request.Context() + var request network.AsyncTaskCreateRequest + + if err := c.ShouldBindJSON(&request); err != nil { + logger.Error(ctx, "failed to unmarshal async task create request", "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusBadRequest, + Msg: "invalid request parameters", + }) + return + } + + // Validate task type + if !orm.IsValidAsyncTaskType(request.TaskType) { + logger.Error(ctx, "invalid task type", "task_type", request.TaskType) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusBadRequest, + Msg: "invalid task type", + }) + return + } + + // Validate task parameters based on task type + if !validateTaskParams(request.TaskType, request.Params) { + logger.Error(ctx, "invalid task parameters", "task_type", request.TaskType, "params", request.Params) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusBadRequest, + Msg: "invalid task parameters", + }) + return + } + + // Get database connection from context or use default + db := getDBFromContext(c) + if db == nil { + logger.Error(ctx, "database connection not found in context") + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "database connection error", + }) + return + } + + // Create task in database + taskType := orm.AsyncTaskType(request.TaskType) + params := orm.JSONMap(request.Params) + + asyncTask, err := database.CreateAsyncTask(ctx, db, taskType, params) + if err != nil { + logger.Error(ctx, "failed to create async task in database", "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "failed to create task", + }) + return + } + + // Create task queue message + // taskQueueMsg := task.NewTaskQueueMessage(asyncTask.TaskID, task.TaskType(request.TaskType)) + + // TODO: Send task to message queue (RabbitMQ/Redis) + // This should be implemented when message queue integration is ready + // For now, we'll just log the task creation + logger.Info(ctx, "async task created successfully", "task_id", asyncTask.TaskID, "task_type", request.TaskType) + + // Return success response + c.JSON(http.StatusOK, network.SuccessResponse{ + Code: 2000, + Msg: "task created successfully", + Payload: network.AsyncTaskCreateResponse{ + TaskID: asyncTask.TaskID, + }, + }) +} + +// AsyncTaskResultQueryHandler handles querying of asynchronous task results +// @Summary 查询异步任务结果 +// @Description 根据任务ID列表查询异步任务的状态和结果 +// @Tags AsyncTask +// @Accept json +// @Produce json +// @Param task_ids query string true "任务ID列表,用逗号分隔" +// @Success 200 {object} network.SuccessResponse{payload=network.AsyncTaskResultQueryResponse} "查询成功" +// @Failure 400 {object} network.FailureResponse "请求参数错误" +// @Failure 500 {object} network.FailureResponse "服务器内部错误" +// @Router /task/async/results [get] +func AsyncTaskResultQueryHandler(c *gin.Context) { + ctx := c.Request.Context() + + // Parse task IDs from query parameter + taskIDsParam := c.Query("task_ids") + if taskIDsParam == "" { + logger.Error(ctx, "task_ids parameter is required") + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusBadRequest, + Msg: "task_ids parameter is required", + }) + return + } + + // Parse comma-separated task IDs + var taskIDs []uuid.UUID + taskIDStrs := splitCommaSeparated(taskIDsParam) + for _, taskIDStr := range taskIDStrs { + taskID, err := uuid.FromString(taskIDStr) + if err != nil { + logger.Error(ctx, "invalid task ID format", "task_id", taskIDStr, "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusBadRequest, + Msg: "invalid task ID format", + }) + return + } + taskIDs = append(taskIDs, taskID) + } + + if len(taskIDs) == 0 { + logger.Error(ctx, "no valid task IDs provided") + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusBadRequest, + Msg: "no valid task IDs provided", + }) + return + } + + // Get database connection from context or use default + db := getDBFromContext(c) + if db == nil { + logger.Error(ctx, "database connection not found in context") + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "database connection error", + }) + return + } + + // Query tasks from database + asyncTasks, err := database.GetAsyncTasksByIDs(ctx, db, taskIDs) + if err != nil { + logger.Error(ctx, "failed to query async tasks from database", "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "failed to query tasks", + }) + return + } + + // Query task results from database + taskResults, err := database.GetAsyncTaskResults(ctx, db, taskIDs) + if err != nil { + logger.Error(ctx, "failed to query async task results from database", "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "failed to query task results", + }) + return + } + + // Create a map of task results for easy lookup + taskResultMap := make(map[uuid.UUID]orm.AsyncTaskResult) + for _, result := range taskResults { + taskResultMap[result.TaskID] = result + } + + // Convert to response format + var responseTasks []network.AsyncTaskResult + for _, asyncTask := range asyncTasks { + taskResult := network.AsyncTaskResult{ + TaskID: asyncTask.TaskID, + TaskType: string(asyncTask.TaskType), + Status: string(asyncTask.Status), + CreatedAt: asyncTask.CreatedAt, + FinishedAt: asyncTask.FinishedAt, + Progress: asyncTask.Progress, + } + + // Add result or error information if available + if result, exists := taskResultMap[asyncTask.TaskID]; exists { + if result.Result != nil { + taskResult.Result = map[string]any(result.Result) + } + if result.ErrorCode != nil { + taskResult.ErrorCode = result.ErrorCode + } + if result.ErrorMessage != nil { + taskResult.ErrorMessage = result.ErrorMessage + } + if result.ErrorDetail != nil { + taskResult.ErrorDetail = map[string]any(result.ErrorDetail) + } + } + + responseTasks = append(responseTasks, taskResult) + } + + // Return success response + c.JSON(http.StatusOK, network.SuccessResponse{ + Code: 2000, + Msg: "query completed", + Payload: network.AsyncTaskResultQueryResponse{ + Total: len(responseTasks), + Tasks: responseTasks, + }, + }) +} + +// AsyncTaskProgressUpdateHandler handles updating task progress (internal use, not exposed via API) +func AsyncTaskProgressUpdateHandler(c *gin.Context) { + ctx := c.Request.Context() + var request network.AsyncTaskProgressUpdate + + if err := c.ShouldBindJSON(&request); err != nil { + logger.Error(ctx, "failed to unmarshal async task progress update request", "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusBadRequest, + Msg: "invalid request parameters", + }) + return + } + + // Get database connection from context or use default + db := getDBFromContext(c) + if db == nil { + logger.Error(ctx, "database connection not found in context") + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "database connection error", + }) + return + } + + // Update task progress + err := database.UpdateAsyncTaskProgress(ctx, db, request.TaskID, request.Progress) + if err != nil { + logger.Error(ctx, "failed to update async task progress", "task_id", request.TaskID, "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "failed to update task progress", + }) + return + } + + c.JSON(http.StatusOK, network.SuccessResponse{ + Code: 2000, + Msg: "task progress updated successfully", + Payload: nil, + }) +} + +// AsyncTaskStatusUpdateHandler handles updating task status (internal use, not exposed via API) +func AsyncTaskStatusUpdateHandler(c *gin.Context) { + ctx := c.Request.Context() + var request network.AsyncTaskStatusUpdate + + if err := c.ShouldBindJSON(&request); err != nil { + logger.Error(ctx, "failed to unmarshal async task status update request", "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusBadRequest, + Msg: "invalid request parameters", + }) + return + } + + // Validate status + validStatus := map[string]bool{ + string(orm.AsyncTaskStatusSubmitted): true, + string(orm.AsyncTaskStatusRunning): true, + string(orm.AsyncTaskStatusCompleted): true, + string(orm.AsyncTaskStatusFailed): true, + } + + if !validStatus[request.Status] { + logger.Error(ctx, "invalid task status", "status", request.Status) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusBadRequest, + Msg: "invalid task status", + }) + return + } + + // Get database connection from context or use default + db := getDBFromContext(c) + if db == nil { + logger.Error(ctx, "database connection not found in context") + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "database connection error", + }) + return + } + + // Update task status + status := orm.AsyncTaskStatus(request.Status) + err := database.UpdateAsyncTaskStatus(ctx, db, request.TaskID, status) + if err != nil { + logger.Error(ctx, "failed to update async task status", "task_id", request.TaskID, "status", request.Status, "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "failed to update task status", + }) + return + } + + // If task is completed or failed, update finished_at timestamp + if request.Status == string(orm.AsyncTaskStatusCompleted) { + err = database.CompleteAsyncTask(ctx, db, request.TaskID, request.Timestamp) + } else if request.Status == string(orm.AsyncTaskStatusFailed) { + err = database.FailAsyncTask(ctx, db, request.TaskID, request.Timestamp) + } + + if err != nil { + logger.Error(ctx, "failed to update async task completion timestamp", "task_id", request.TaskID, "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "failed to update task completion timestamp", + }) + return + } + + c.JSON(http.StatusOK, network.SuccessResponse{ + Code: 2000, + Msg: "task status updated successfully", + Payload: nil, + }) +} + +// Helper functions + +func validateTaskParams(taskType string, params map[string]any) bool { + switch taskType { + case string(orm.AsyncTaskTypeTopologyAnalysis): + return validateTopologyAnalysisParams(params) + case string(orm.AsyncTaskTypePerformanceAnalysis): + return validatePerformanceAnalysisParams(params) + case string(orm.AsyncTaskTypeEventAnalysis): + return validateEventAnalysisParams(params) + case string(orm.AsyncTaskTypeBatchImport): + return validateBatchImportParams(params) + default: + return false + } +} + +func validateTopologyAnalysisParams(params map[string]any) bool { + // Check required parameters for topology analysis + if startUUID, ok := params["start_uuid"]; !ok || startUUID == "" { + return false + } + if endUUID, ok := params["end_uuid"]; !ok || endUUID == "" { + return false + } + return true +} + +func validatePerformanceAnalysisParams(params map[string]any) bool { + // Check required parameters for performance analysis + if componentIDs, ok := params["component_ids"]; !ok { + return false + } else if ids, isSlice := componentIDs.([]interface{}); !isSlice || len(ids) == 0 { + return false + } + return true +} + +func validateEventAnalysisParams(params map[string]any) bool { + // Check required parameters for event analysis + if eventType, ok := params["event_type"]; !ok || eventType == "" { + return false + } + return true +} + +func validateBatchImportParams(params map[string]any) bool { + // Check required parameters for batch import + if filePath, ok := params["file_path"]; !ok || filePath == "" { + return false + } + return true +} + +func splitCommaSeparated(s string) []string { + var result []string + var current strings.Builder + inQuotes := false + escape := false + + for _, ch := range s { + if escape { + current.WriteRune(ch) + escape = false + continue + } + + switch ch { + case '\\': + escape = true + case '"': + inQuotes = !inQuotes + case ',': + if !inQuotes { + result = append(result, strings.TrimSpace(current.String())) + current.Reset() + } else { + current.WriteRune(ch) + } + default: + current.WriteRune(ch) + } + } + + if current.Len() > 0 { + result = append(result, strings.TrimSpace(current.String())) + } + + return result +} + +func getDBFromContext(c *gin.Context) *gorm.DB { + // Try to get database connection from context + // This should be set by middleware + if db, exists := c.Get("db"); exists { + if gormDB, ok := db.(*gorm.DB); ok { + return gormDB + } + } + + // Fallback to global database connection + // This should be implemented based on your application's database setup + // For now, return nil - actual implementation should retrieve from application context + return nil +} + +// AsyncTaskResultDetailHandler handles detailed query of a single async task result +// @Summary 查询异步任务详情 +// @Description 根据任务ID查询异步任务的详细状态和结果 +// @Tags AsyncTask +// @Accept json +// @Produce json +// @Param task_id path string true "任务ID" +// @Success 200 {object} network.SuccessResponse{payload=network.AsyncTaskResult} "查询成功" +// @Failure 400 {object} network.FailureResponse "请求参数错误" +// @Failure 404 {object} network.FailureResponse "任务不存在" +// @Failure 500 {object} network.FailureResponse "服务器内部错误" +// @Router /task/async/{task_id} [get] +func AsyncTaskResultDetailHandler(c *gin.Context) { + ctx := c.Request.Context() + + // Parse task ID from path parameter + taskIDStr := c.Param("task_id") + if taskIDStr == "" { + logger.Error(ctx, "task_id parameter is required") + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusBadRequest, + Msg: "task_id parameter is required", + }) + return + } + + taskID, err := uuid.FromString(taskIDStr) + if err != nil { + logger.Error(ctx, "invalid task ID format", "task_id", taskIDStr, "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusBadRequest, + Msg: "invalid task ID format", + }) + return + } + + // Get database connection from context or use default + db := getDBFromContext(c) + if db == nil { + logger.Error(ctx, "database connection not found in context") + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "database connection error", + }) + return + } + + // Query task from database + asyncTask, err := database.GetAsyncTaskByID(ctx, db, taskID) + if err != nil { + if err == gorm.ErrRecordNotFound { + logger.Error(ctx, "async task not found", "task_id", taskID) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusNotFound, + Msg: "task not found", + }) + return + } + logger.Error(ctx, "failed to query async task from database", "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "failed to query task", + }) + return + } + + // Query task result from database + taskResult, err := database.GetAsyncTaskResult(ctx, db, taskID) + if err != nil { + logger.Error(ctx, "failed to query async task result from database", "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "failed to query task result", + }) + return + } + + // Convert to response format + responseTask := network.AsyncTaskResult{ + TaskID: asyncTask.TaskID, + TaskType: string(asyncTask.TaskType), + Status: string(asyncTask.Status), + CreatedAt: asyncTask.CreatedAt, + FinishedAt: asyncTask.FinishedAt, + Progress: asyncTask.Progress, + } + + // Add result or error information if available + if taskResult != nil { + if taskResult.Result != nil { + responseTask.Result = map[string]any(taskResult.Result) + } + if taskResult.ErrorCode != nil { + responseTask.ErrorCode = taskResult.ErrorCode + } + if taskResult.ErrorMessage != nil { + responseTask.ErrorMessage = taskResult.ErrorMessage + } + if taskResult.ErrorDetail != nil { + responseTask.ErrorDetail = map[string]any(taskResult.ErrorDetail) + } + } + + // Return success response + c.JSON(http.StatusOK, network.SuccessResponse{ + Code: 2000, + Msg: "query completed", + Payload: responseTask, + }) +} + +// AsyncTaskCancelHandler handles cancellation of an async task +// @Summary 取消异步任务 +// @Description 取消指定ID的异步任务(如果任务尚未开始执行) +// @Tags AsyncTask +// @Accept json +// @Produce json +// @Param task_id path string true "任务ID" +// @Success 200 {object} network.SuccessResponse "任务取消成功" +// @Failure 400 {object} network.FailureResponse "请求参数错误或任务无法取消" +// @Failure 404 {object} network.FailureResponse "任务不存在" +// @Failure 500 {object} network.FailureResponse "服务器内部错误" +// @Router /task/async/{task_id}/cancel [post] +func AsyncTaskCancelHandler(c *gin.Context) { + ctx := c.Request.Context() + + // Parse task ID from path parameter + taskIDStr := c.Param("task_id") + if taskIDStr == "" { + logger.Error(ctx, "task_id parameter is required") + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusBadRequest, + Msg: "task_id parameter is required", + }) + return + } + + taskID, err := uuid.FromString(taskIDStr) + if err != nil { + logger.Error(ctx, "invalid task ID format", "task_id", taskIDStr, "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusBadRequest, + Msg: "invalid task ID format", + }) + return + } + + // Get database connection from context or use default + db := getDBFromContext(c) + if db == nil { + logger.Error(ctx, "database connection not found in context") + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "database connection error", + }) + return + } + + // Query task from database + asyncTask, err := database.GetAsyncTaskByID(ctx, db, taskID) + if err != nil { + if err == gorm.ErrRecordNotFound { + logger.Error(ctx, "async task not found", "task_id", taskID) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusNotFound, + Msg: "task not found", + }) + return + } + logger.Error(ctx, "failed to query async task from database", "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "failed to query task", + }) + return + } + + // Check if task can be cancelled (only SUBMITTED tasks can be cancelled) + if asyncTask.Status != orm.AsyncTaskStatusSubmitted { + logger.Error(ctx, "task cannot be cancelled", "task_id", taskID, "status", asyncTask.Status) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusBadRequest, + Msg: "task cannot be cancelled (already running or completed)", + }) + return + } + + // Update task status to failed with cancellation reason + timestamp := time.Now().Unix() + err = database.FailAsyncTask(ctx, db, taskID, timestamp) + if err != nil { + logger.Error(ctx, "failed to cancel async task", "task_id", taskID, "error", err) + c.JSON(http.StatusOK, network.FailureResponse{ + Code: http.StatusInternalServerError, + Msg: "failed to cancel task", + }) + return + } + + // Update task result with cancellation error + err = database.UpdateAsyncTaskResultWithError(ctx, db, taskID, 40003, "task cancelled by user", orm.JSONMap{ + "cancelled_at": timestamp, + "cancelled_by": "user", + }) + if err != nil { + logger.Error(ctx, "failed to update task result with cancellation error", "task_id", taskID, "error", err) + // Continue anyway since task is already marked as failed + } + + c.JSON(http.StatusOK, network.SuccessResponse{ + Code: 2000, + Msg: "task cancelled successfully", + }) +} diff --git a/router/async_task.go b/router/async_task.go new file mode 100644 index 0000000..4b0861c --- /dev/null +++ b/router/async_task.go @@ -0,0 +1,32 @@ +// Package router provides router config +package router + +import ( + "modelRT/handler" + + "github.com/gin-gonic/gin" +) + +// registerAsyncTaskRoutes define func of register async task routes +func registerAsyncTaskRoutes(rg *gin.RouterGroup, middlewares ...gin.HandlerFunc) { + g := rg.Group("/task/") + g.Use(middlewares...) + + // Async task creation + g.POST("async", handler.AsyncTaskCreateHandler) + + // Async task result query + g.GET("async/results", handler.AsyncTaskResultQueryHandler) + + // Async task detail query + g.GET("async/:task_id", handler.AsyncTaskResultDetailHandler) + + // Async task cancellation + g.POST("async/:task_id/cancel", handler.AsyncTaskCancelHandler) + + // Internal APIs for worker updates (not exposed to external users) + internal := g.Group("internal/") + internal.Use(middlewares...) + internal.POST("async/progress", handler.AsyncTaskProgressUpdateHandler) + internal.POST("async/status", handler.AsyncTaskStatusUpdateHandler) +} \ No newline at end of file diff --git a/router/router.go b/router/router.go index 6fbc113..f242cb9 100644 --- a/router/router.go +++ b/router/router.go @@ -27,4 +27,5 @@ func RegisterRoutes(engine *gin.Engine, clientToken string) { registerDataRoutes(routeGroup) registerMonitorRoutes(routeGroup) registerComponentRoutes(routeGroup, middleware.SetTokenMiddleware(clientToken)) + registerAsyncTaskRoutes(routeGroup, middleware.SetTokenMiddleware(clientToken)) }