add route of async task system
This commit is contained in:
parent
adcc8c6c91
commit
de5f976c31
|
|
@ -28,3 +28,15 @@ go.work
|
|||
# Shield config files in the configs folder
|
||||
/configs/**/*.yaml
|
||||
/configs/**/*.pem
|
||||
|
||||
# ai config
|
||||
.cursor/
|
||||
.claude/
|
||||
.cursorrules
|
||||
.copilot/
|
||||
.chatgpt/
|
||||
.ai_history/
|
||||
.vector_cache/
|
||||
ai-debug.log
|
||||
*.patch
|
||||
*.diff
|
||||
|
|
@ -0,0 +1,677 @@
|
|||
// Package handler provides HTTP handlers for various endpoints.
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"modelRT/database"
|
||||
"modelRT/logger"
|
||||
"modelRT/network"
|
||||
"modelRT/orm"
|
||||
_ "modelRT/task"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AsyncTaskCreateHandler handles creation of asynchronous tasks
|
||||
// @Summary 创建异步任务
|
||||
// @Description 创建新的异步任务并返回任务ID,任务将被提交到队列等待处理
|
||||
// @Tags AsyncTask
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body network.AsyncTaskCreateRequest true "任务创建请求"
|
||||
// @Success 200 {object} network.SuccessResponse{payload=network.AsyncTaskCreateResponse} "任务创建成功"
|
||||
// @Failure 400 {object} network.FailureResponse "请求参数错误"
|
||||
// @Failure 500 {object} network.FailureResponse "服务器内部错误"
|
||||
// @Router /task/async [post]
|
||||
func AsyncTaskCreateHandler(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
var request network.AsyncTaskCreateRequest
|
||||
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
logger.Error(ctx, "failed to unmarshal async task create request", "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "invalid request parameters",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate task type
|
||||
if !orm.IsValidAsyncTaskType(request.TaskType) {
|
||||
logger.Error(ctx, "invalid task type", "task_type", request.TaskType)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "invalid task type",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate task parameters based on task type
|
||||
if !validateTaskParams(request.TaskType, request.Params) {
|
||||
logger.Error(ctx, "invalid task parameters", "task_type", request.TaskType, "params", request.Params)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "invalid task parameters",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get database connection from context or use default
|
||||
db := getDBFromContext(c)
|
||||
if db == nil {
|
||||
logger.Error(ctx, "database connection not found in context")
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "database connection error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create task in database
|
||||
taskType := orm.AsyncTaskType(request.TaskType)
|
||||
params := orm.JSONMap(request.Params)
|
||||
|
||||
asyncTask, err := database.CreateAsyncTask(ctx, db, taskType, params)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to create async task in database", "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "failed to create task",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create task queue message
|
||||
// taskQueueMsg := task.NewTaskQueueMessage(asyncTask.TaskID, task.TaskType(request.TaskType))
|
||||
|
||||
// TODO: Send task to message queue (RabbitMQ/Redis)
|
||||
// This should be implemented when message queue integration is ready
|
||||
// For now, we'll just log the task creation
|
||||
logger.Info(ctx, "async task created successfully", "task_id", asyncTask.TaskID, "task_type", request.TaskType)
|
||||
|
||||
// Return success response
|
||||
c.JSON(http.StatusOK, network.SuccessResponse{
|
||||
Code: 2000,
|
||||
Msg: "task created successfully",
|
||||
Payload: network.AsyncTaskCreateResponse{
|
||||
TaskID: asyncTask.TaskID,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// AsyncTaskResultQueryHandler handles querying of asynchronous task results
|
||||
// @Summary 查询异步任务结果
|
||||
// @Description 根据任务ID列表查询异步任务的状态和结果
|
||||
// @Tags AsyncTask
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param task_ids query string true "任务ID列表,用逗号分隔"
|
||||
// @Success 200 {object} network.SuccessResponse{payload=network.AsyncTaskResultQueryResponse} "查询成功"
|
||||
// @Failure 400 {object} network.FailureResponse "请求参数错误"
|
||||
// @Failure 500 {object} network.FailureResponse "服务器内部错误"
|
||||
// @Router /task/async/results [get]
|
||||
func AsyncTaskResultQueryHandler(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Parse task IDs from query parameter
|
||||
taskIDsParam := c.Query("task_ids")
|
||||
if taskIDsParam == "" {
|
||||
logger.Error(ctx, "task_ids parameter is required")
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "task_ids parameter is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse comma-separated task IDs
|
||||
var taskIDs []uuid.UUID
|
||||
taskIDStrs := splitCommaSeparated(taskIDsParam)
|
||||
for _, taskIDStr := range taskIDStrs {
|
||||
taskID, err := uuid.FromString(taskIDStr)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "invalid task ID format", "task_id", taskIDStr, "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "invalid task ID format",
|
||||
})
|
||||
return
|
||||
}
|
||||
taskIDs = append(taskIDs, taskID)
|
||||
}
|
||||
|
||||
if len(taskIDs) == 0 {
|
||||
logger.Error(ctx, "no valid task IDs provided")
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "no valid task IDs provided",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get database connection from context or use default
|
||||
db := getDBFromContext(c)
|
||||
if db == nil {
|
||||
logger.Error(ctx, "database connection not found in context")
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "database connection error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Query tasks from database
|
||||
asyncTasks, err := database.GetAsyncTasksByIDs(ctx, db, taskIDs)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to query async tasks from database", "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "failed to query tasks",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Query task results from database
|
||||
taskResults, err := database.GetAsyncTaskResults(ctx, db, taskIDs)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to query async task results from database", "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "failed to query task results",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create a map of task results for easy lookup
|
||||
taskResultMap := make(map[uuid.UUID]orm.AsyncTaskResult)
|
||||
for _, result := range taskResults {
|
||||
taskResultMap[result.TaskID] = result
|
||||
}
|
||||
|
||||
// Convert to response format
|
||||
var responseTasks []network.AsyncTaskResult
|
||||
for _, asyncTask := range asyncTasks {
|
||||
taskResult := network.AsyncTaskResult{
|
||||
TaskID: asyncTask.TaskID,
|
||||
TaskType: string(asyncTask.TaskType),
|
||||
Status: string(asyncTask.Status),
|
||||
CreatedAt: asyncTask.CreatedAt,
|
||||
FinishedAt: asyncTask.FinishedAt,
|
||||
Progress: asyncTask.Progress,
|
||||
}
|
||||
|
||||
// Add result or error information if available
|
||||
if result, exists := taskResultMap[asyncTask.TaskID]; exists {
|
||||
if result.Result != nil {
|
||||
taskResult.Result = map[string]any(result.Result)
|
||||
}
|
||||
if result.ErrorCode != nil {
|
||||
taskResult.ErrorCode = result.ErrorCode
|
||||
}
|
||||
if result.ErrorMessage != nil {
|
||||
taskResult.ErrorMessage = result.ErrorMessage
|
||||
}
|
||||
if result.ErrorDetail != nil {
|
||||
taskResult.ErrorDetail = map[string]any(result.ErrorDetail)
|
||||
}
|
||||
}
|
||||
|
||||
responseTasks = append(responseTasks, taskResult)
|
||||
}
|
||||
|
||||
// Return success response
|
||||
c.JSON(http.StatusOK, network.SuccessResponse{
|
||||
Code: 2000,
|
||||
Msg: "query completed",
|
||||
Payload: network.AsyncTaskResultQueryResponse{
|
||||
Total: len(responseTasks),
|
||||
Tasks: responseTasks,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// AsyncTaskProgressUpdateHandler handles updating task progress (internal use, not exposed via API)
|
||||
func AsyncTaskProgressUpdateHandler(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
var request network.AsyncTaskProgressUpdate
|
||||
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
logger.Error(ctx, "failed to unmarshal async task progress update request", "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "invalid request parameters",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get database connection from context or use default
|
||||
db := getDBFromContext(c)
|
||||
if db == nil {
|
||||
logger.Error(ctx, "database connection not found in context")
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "database connection error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Update task progress
|
||||
err := database.UpdateAsyncTaskProgress(ctx, db, request.TaskID, request.Progress)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to update async task progress", "task_id", request.TaskID, "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "failed to update task progress",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, network.SuccessResponse{
|
||||
Code: 2000,
|
||||
Msg: "task progress updated successfully",
|
||||
Payload: nil,
|
||||
})
|
||||
}
|
||||
|
||||
// AsyncTaskStatusUpdateHandler handles updating task status (internal use, not exposed via API)
|
||||
func AsyncTaskStatusUpdateHandler(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
var request network.AsyncTaskStatusUpdate
|
||||
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
logger.Error(ctx, "failed to unmarshal async task status update request", "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "invalid request parameters",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate status
|
||||
validStatus := map[string]bool{
|
||||
string(orm.AsyncTaskStatusSubmitted): true,
|
||||
string(orm.AsyncTaskStatusRunning): true,
|
||||
string(orm.AsyncTaskStatusCompleted): true,
|
||||
string(orm.AsyncTaskStatusFailed): true,
|
||||
}
|
||||
|
||||
if !validStatus[request.Status] {
|
||||
logger.Error(ctx, "invalid task status", "status", request.Status)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "invalid task status",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get database connection from context or use default
|
||||
db := getDBFromContext(c)
|
||||
if db == nil {
|
||||
logger.Error(ctx, "database connection not found in context")
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "database connection error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Update task status
|
||||
status := orm.AsyncTaskStatus(request.Status)
|
||||
err := database.UpdateAsyncTaskStatus(ctx, db, request.TaskID, status)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to update async task status", "task_id", request.TaskID, "status", request.Status, "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "failed to update task status",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// If task is completed or failed, update finished_at timestamp
|
||||
if request.Status == string(orm.AsyncTaskStatusCompleted) {
|
||||
err = database.CompleteAsyncTask(ctx, db, request.TaskID, request.Timestamp)
|
||||
} else if request.Status == string(orm.AsyncTaskStatusFailed) {
|
||||
err = database.FailAsyncTask(ctx, db, request.TaskID, request.Timestamp)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to update async task completion timestamp", "task_id", request.TaskID, "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "failed to update task completion timestamp",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, network.SuccessResponse{
|
||||
Code: 2000,
|
||||
Msg: "task status updated successfully",
|
||||
Payload: nil,
|
||||
})
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func validateTaskParams(taskType string, params map[string]any) bool {
|
||||
switch taskType {
|
||||
case string(orm.AsyncTaskTypeTopologyAnalysis):
|
||||
return validateTopologyAnalysisParams(params)
|
||||
case string(orm.AsyncTaskTypePerformanceAnalysis):
|
||||
return validatePerformanceAnalysisParams(params)
|
||||
case string(orm.AsyncTaskTypeEventAnalysis):
|
||||
return validateEventAnalysisParams(params)
|
||||
case string(orm.AsyncTaskTypeBatchImport):
|
||||
return validateBatchImportParams(params)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func validateTopologyAnalysisParams(params map[string]any) bool {
|
||||
// Check required parameters for topology analysis
|
||||
if startUUID, ok := params["start_uuid"]; !ok || startUUID == "" {
|
||||
return false
|
||||
}
|
||||
if endUUID, ok := params["end_uuid"]; !ok || endUUID == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func validatePerformanceAnalysisParams(params map[string]any) bool {
|
||||
// Check required parameters for performance analysis
|
||||
if componentIDs, ok := params["component_ids"]; !ok {
|
||||
return false
|
||||
} else if ids, isSlice := componentIDs.([]interface{}); !isSlice || len(ids) == 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func validateEventAnalysisParams(params map[string]any) bool {
|
||||
// Check required parameters for event analysis
|
||||
if eventType, ok := params["event_type"]; !ok || eventType == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func validateBatchImportParams(params map[string]any) bool {
|
||||
// Check required parameters for batch import
|
||||
if filePath, ok := params["file_path"]; !ok || filePath == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func splitCommaSeparated(s string) []string {
|
||||
var result []string
|
||||
var current strings.Builder
|
||||
inQuotes := false
|
||||
escape := false
|
||||
|
||||
for _, ch := range s {
|
||||
if escape {
|
||||
current.WriteRune(ch)
|
||||
escape = false
|
||||
continue
|
||||
}
|
||||
|
||||
switch ch {
|
||||
case '\\':
|
||||
escape = true
|
||||
case '"':
|
||||
inQuotes = !inQuotes
|
||||
case ',':
|
||||
if !inQuotes {
|
||||
result = append(result, strings.TrimSpace(current.String()))
|
||||
current.Reset()
|
||||
} else {
|
||||
current.WriteRune(ch)
|
||||
}
|
||||
default:
|
||||
current.WriteRune(ch)
|
||||
}
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
result = append(result, strings.TrimSpace(current.String()))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func getDBFromContext(c *gin.Context) *gorm.DB {
|
||||
// Try to get database connection from context
|
||||
// This should be set by middleware
|
||||
if db, exists := c.Get("db"); exists {
|
||||
if gormDB, ok := db.(*gorm.DB); ok {
|
||||
return gormDB
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to global database connection
|
||||
// This should be implemented based on your application's database setup
|
||||
// For now, return nil - actual implementation should retrieve from application context
|
||||
return nil
|
||||
}
|
||||
|
||||
// AsyncTaskResultDetailHandler handles detailed query of a single async task result
|
||||
// @Summary 查询异步任务详情
|
||||
// @Description 根据任务ID查询异步任务的详细状态和结果
|
||||
// @Tags AsyncTask
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param task_id path string true "任务ID"
|
||||
// @Success 200 {object} network.SuccessResponse{payload=network.AsyncTaskResult} "查询成功"
|
||||
// @Failure 400 {object} network.FailureResponse "请求参数错误"
|
||||
// @Failure 404 {object} network.FailureResponse "任务不存在"
|
||||
// @Failure 500 {object} network.FailureResponse "服务器内部错误"
|
||||
// @Router /task/async/{task_id} [get]
|
||||
func AsyncTaskResultDetailHandler(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Parse task ID from path parameter
|
||||
taskIDStr := c.Param("task_id")
|
||||
if taskIDStr == "" {
|
||||
logger.Error(ctx, "task_id parameter is required")
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "task_id parameter is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
taskID, err := uuid.FromString(taskIDStr)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "invalid task ID format", "task_id", taskIDStr, "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "invalid task ID format",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get database connection from context or use default
|
||||
db := getDBFromContext(c)
|
||||
if db == nil {
|
||||
logger.Error(ctx, "database connection not found in context")
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "database connection error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Query task from database
|
||||
asyncTask, err := database.GetAsyncTaskByID(ctx, db, taskID)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
logger.Error(ctx, "async task not found", "task_id", taskID)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusNotFound,
|
||||
Msg: "task not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
logger.Error(ctx, "failed to query async task from database", "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "failed to query task",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Query task result from database
|
||||
taskResult, err := database.GetAsyncTaskResult(ctx, db, taskID)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to query async task result from database", "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "failed to query task result",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to response format
|
||||
responseTask := network.AsyncTaskResult{
|
||||
TaskID: asyncTask.TaskID,
|
||||
TaskType: string(asyncTask.TaskType),
|
||||
Status: string(asyncTask.Status),
|
||||
CreatedAt: asyncTask.CreatedAt,
|
||||
FinishedAt: asyncTask.FinishedAt,
|
||||
Progress: asyncTask.Progress,
|
||||
}
|
||||
|
||||
// Add result or error information if available
|
||||
if taskResult != nil {
|
||||
if taskResult.Result != nil {
|
||||
responseTask.Result = map[string]any(taskResult.Result)
|
||||
}
|
||||
if taskResult.ErrorCode != nil {
|
||||
responseTask.ErrorCode = taskResult.ErrorCode
|
||||
}
|
||||
if taskResult.ErrorMessage != nil {
|
||||
responseTask.ErrorMessage = taskResult.ErrorMessage
|
||||
}
|
||||
if taskResult.ErrorDetail != nil {
|
||||
responseTask.ErrorDetail = map[string]any(taskResult.ErrorDetail)
|
||||
}
|
||||
}
|
||||
|
||||
// Return success response
|
||||
c.JSON(http.StatusOK, network.SuccessResponse{
|
||||
Code: 2000,
|
||||
Msg: "query completed",
|
||||
Payload: responseTask,
|
||||
})
|
||||
}
|
||||
|
||||
// AsyncTaskCancelHandler handles cancellation of an async task
|
||||
// @Summary 取消异步任务
|
||||
// @Description 取消指定ID的异步任务(如果任务尚未开始执行)
|
||||
// @Tags AsyncTask
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param task_id path string true "任务ID"
|
||||
// @Success 200 {object} network.SuccessResponse "任务取消成功"
|
||||
// @Failure 400 {object} network.FailureResponse "请求参数错误或任务无法取消"
|
||||
// @Failure 404 {object} network.FailureResponse "任务不存在"
|
||||
// @Failure 500 {object} network.FailureResponse "服务器内部错误"
|
||||
// @Router /task/async/{task_id}/cancel [post]
|
||||
func AsyncTaskCancelHandler(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Parse task ID from path parameter
|
||||
taskIDStr := c.Param("task_id")
|
||||
if taskIDStr == "" {
|
||||
logger.Error(ctx, "task_id parameter is required")
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "task_id parameter is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
taskID, err := uuid.FromString(taskIDStr)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "invalid task ID format", "task_id", taskIDStr, "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "invalid task ID format",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get database connection from context or use default
|
||||
db := getDBFromContext(c)
|
||||
if db == nil {
|
||||
logger.Error(ctx, "database connection not found in context")
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "database connection error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Query task from database
|
||||
asyncTask, err := database.GetAsyncTaskByID(ctx, db, taskID)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
logger.Error(ctx, "async task not found", "task_id", taskID)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusNotFound,
|
||||
Msg: "task not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
logger.Error(ctx, "failed to query async task from database", "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "failed to query task",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if task can be cancelled (only SUBMITTED tasks can be cancelled)
|
||||
if asyncTask.Status != orm.AsyncTaskStatusSubmitted {
|
||||
logger.Error(ctx, "task cannot be cancelled", "task_id", taskID, "status", asyncTask.Status)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "task cannot be cancelled (already running or completed)",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Update task status to failed with cancellation reason
|
||||
timestamp := time.Now().Unix()
|
||||
err = database.FailAsyncTask(ctx, db, taskID, timestamp)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to cancel async task", "task_id", taskID, "error", err)
|
||||
c.JSON(http.StatusOK, network.FailureResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
Msg: "failed to cancel task",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Update task result with cancellation error
|
||||
err = database.UpdateAsyncTaskResultWithError(ctx, db, taskID, 40003, "task cancelled by user", orm.JSONMap{
|
||||
"cancelled_at": timestamp,
|
||||
"cancelled_by": "user",
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to update task result with cancellation error", "task_id", taskID, "error", err)
|
||||
// Continue anyway since task is already marked as failed
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, network.SuccessResponse{
|
||||
Code: 2000,
|
||||
Msg: "task cancelled successfully",
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
// Package router provides router config
|
||||
package router
|
||||
|
||||
import (
|
||||
"modelRT/handler"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// registerAsyncTaskRoutes define func of register async task routes
|
||||
func registerAsyncTaskRoutes(rg *gin.RouterGroup, middlewares ...gin.HandlerFunc) {
|
||||
g := rg.Group("/task/")
|
||||
g.Use(middlewares...)
|
||||
|
||||
// Async task creation
|
||||
g.POST("async", handler.AsyncTaskCreateHandler)
|
||||
|
||||
// Async task result query
|
||||
g.GET("async/results", handler.AsyncTaskResultQueryHandler)
|
||||
|
||||
// Async task detail query
|
||||
g.GET("async/:task_id", handler.AsyncTaskResultDetailHandler)
|
||||
|
||||
// Async task cancellation
|
||||
g.POST("async/:task_id/cancel", handler.AsyncTaskCancelHandler)
|
||||
|
||||
// Internal APIs for worker updates (not exposed to external users)
|
||||
internal := g.Group("internal/")
|
||||
internal.Use(middlewares...)
|
||||
internal.POST("async/progress", handler.AsyncTaskProgressUpdateHandler)
|
||||
internal.POST("async/status", handler.AsyncTaskStatusUpdateHandler)
|
||||
}
|
||||
|
|
@ -27,4 +27,5 @@ func RegisterRoutes(engine *gin.Engine, clientToken string) {
|
|||
registerDataRoutes(routeGroup)
|
||||
registerMonitorRoutes(routeGroup)
|
||||
registerComponentRoutes(routeGroup, middleware.SetTokenMiddleware(clientToken))
|
||||
registerAsyncTaskRoutes(routeGroup, middleware.SetTokenMiddleware(clientToken))
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue