678 lines
20 KiB
Go
678 lines
20 KiB
Go
// Package handler provides HTTP handlers for various endpoints.
|
||
package handler
|
||
|
||
import (
	"errors"
	"net/http"
	"strings"
	"time"

	"modelRT/database"
	"modelRT/logger"
	"modelRT/network"
	"modelRT/orm"
	_ "modelRT/task"

	"github.com/gin-gonic/gin"
	"github.com/gofrs/uuid"
	"gorm.io/gorm"
)
|
||
|
||
// AsyncTaskCreateHandler handles creation of asynchronous tasks
|
||
// @Summary 创建异步任务
|
||
// @Description 创建新的异步任务并返回任务ID,任务将被提交到队列等待处理
|
||
// @Tags AsyncTask
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param request body network.AsyncTaskCreateRequest true "任务创建请求"
|
||
// @Success 200 {object} network.SuccessResponse{payload=network.AsyncTaskCreateResponse} "任务创建成功"
|
||
// @Failure 400 {object} network.FailureResponse "请求参数错误"
|
||
// @Failure 500 {object} network.FailureResponse "服务器内部错误"
|
||
// @Router /task/async [post]
|
||
func AsyncTaskCreateHandler(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
var request network.AsyncTaskCreateRequest
|
||
|
||
if err := c.ShouldBindJSON(&request); err != nil {
|
||
logger.Error(ctx, "failed to unmarshal async task create request", "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusBadRequest,
|
||
Msg: "invalid request parameters",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Validate task type
|
||
if !orm.IsValidAsyncTaskType(request.TaskType) {
|
||
logger.Error(ctx, "invalid task type", "task_type", request.TaskType)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusBadRequest,
|
||
Msg: "invalid task type",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Validate task parameters based on task type
|
||
if !validateTaskParams(request.TaskType, request.Params) {
|
||
logger.Error(ctx, "invalid task parameters", "task_type", request.TaskType, "params", request.Params)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusBadRequest,
|
||
Msg: "invalid task parameters",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Get database connection from context or use default
|
||
db := getDBFromContext(c)
|
||
if db == nil {
|
||
logger.Error(ctx, "database connection not found in context")
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "database connection error",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Create task in database
|
||
taskType := orm.AsyncTaskType(request.TaskType)
|
||
params := orm.JSONMap(request.Params)
|
||
|
||
asyncTask, err := database.CreateAsyncTask(ctx, db, taskType, params)
|
||
if err != nil {
|
||
logger.Error(ctx, "failed to create async task in database", "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "failed to create task",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Create task queue message
|
||
// taskQueueMsg := task.NewTaskQueueMessage(asyncTask.TaskID, task.TaskType(request.TaskType))
|
||
|
||
// TODO: Send task to message queue (RabbitMQ/Redis)
|
||
// This should be implemented when message queue integration is ready
|
||
// For now, we'll just log the task creation
|
||
logger.Info(ctx, "async task created successfully", "task_id", asyncTask.TaskID, "task_type", request.TaskType)
|
||
|
||
// Return success response
|
||
c.JSON(http.StatusOK, network.SuccessResponse{
|
||
Code: 2000,
|
||
Msg: "task created successfully",
|
||
Payload: network.AsyncTaskCreateResponse{
|
||
TaskID: asyncTask.TaskID,
|
||
},
|
||
})
|
||
}
|
||
|
||
// AsyncTaskResultQueryHandler handles querying of asynchronous task results
|
||
// @Summary 查询异步任务结果
|
||
// @Description 根据任务ID列表查询异步任务的状态和结果
|
||
// @Tags AsyncTask
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param task_ids query string true "任务ID列表,用逗号分隔"
|
||
// @Success 200 {object} network.SuccessResponse{payload=network.AsyncTaskResultQueryResponse} "查询成功"
|
||
// @Failure 400 {object} network.FailureResponse "请求参数错误"
|
||
// @Failure 500 {object} network.FailureResponse "服务器内部错误"
|
||
// @Router /task/async/results [get]
|
||
func AsyncTaskResultQueryHandler(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
// Parse task IDs from query parameter
|
||
taskIDsParam := c.Query("task_ids")
|
||
if taskIDsParam == "" {
|
||
logger.Error(ctx, "task_ids parameter is required")
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusBadRequest,
|
||
Msg: "task_ids parameter is required",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Parse comma-separated task IDs
|
||
var taskIDs []uuid.UUID
|
||
taskIDStrs := splitCommaSeparated(taskIDsParam)
|
||
for _, taskIDStr := range taskIDStrs {
|
||
taskID, err := uuid.FromString(taskIDStr)
|
||
if err != nil {
|
||
logger.Error(ctx, "invalid task ID format", "task_id", taskIDStr, "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusBadRequest,
|
||
Msg: "invalid task ID format",
|
||
})
|
||
return
|
||
}
|
||
taskIDs = append(taskIDs, taskID)
|
||
}
|
||
|
||
if len(taskIDs) == 0 {
|
||
logger.Error(ctx, "no valid task IDs provided")
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusBadRequest,
|
||
Msg: "no valid task IDs provided",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Get database connection from context or use default
|
||
db := getDBFromContext(c)
|
||
if db == nil {
|
||
logger.Error(ctx, "database connection not found in context")
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "database connection error",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Query tasks from database
|
||
asyncTasks, err := database.GetAsyncTasksByIDs(ctx, db, taskIDs)
|
||
if err != nil {
|
||
logger.Error(ctx, "failed to query async tasks from database", "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "failed to query tasks",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Query task results from database
|
||
taskResults, err := database.GetAsyncTaskResults(ctx, db, taskIDs)
|
||
if err != nil {
|
||
logger.Error(ctx, "failed to query async task results from database", "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "failed to query task results",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Create a map of task results for easy lookup
|
||
taskResultMap := make(map[uuid.UUID]orm.AsyncTaskResult)
|
||
for _, result := range taskResults {
|
||
taskResultMap[result.TaskID] = result
|
||
}
|
||
|
||
// Convert to response format
|
||
var responseTasks []network.AsyncTaskResult
|
||
for _, asyncTask := range asyncTasks {
|
||
taskResult := network.AsyncTaskResult{
|
||
TaskID: asyncTask.TaskID,
|
||
TaskType: string(asyncTask.TaskType),
|
||
Status: string(asyncTask.Status),
|
||
CreatedAt: asyncTask.CreatedAt,
|
||
FinishedAt: asyncTask.FinishedAt,
|
||
Progress: asyncTask.Progress,
|
||
}
|
||
|
||
// Add result or error information if available
|
||
if result, exists := taskResultMap[asyncTask.TaskID]; exists {
|
||
if result.Result != nil {
|
||
taskResult.Result = map[string]any(result.Result)
|
||
}
|
||
if result.ErrorCode != nil {
|
||
taskResult.ErrorCode = result.ErrorCode
|
||
}
|
||
if result.ErrorMessage != nil {
|
||
taskResult.ErrorMessage = result.ErrorMessage
|
||
}
|
||
if result.ErrorDetail != nil {
|
||
taskResult.ErrorDetail = map[string]any(result.ErrorDetail)
|
||
}
|
||
}
|
||
|
||
responseTasks = append(responseTasks, taskResult)
|
||
}
|
||
|
||
// Return success response
|
||
c.JSON(http.StatusOK, network.SuccessResponse{
|
||
Code: 2000,
|
||
Msg: "query completed",
|
||
Payload: network.AsyncTaskResultQueryResponse{
|
||
Total: len(responseTasks),
|
||
Tasks: responseTasks,
|
||
},
|
||
})
|
||
}
|
||
|
||
// AsyncTaskProgressUpdateHandler handles updating task progress (internal use, not exposed via API)
|
||
func AsyncTaskProgressUpdateHandler(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
var request network.AsyncTaskProgressUpdate
|
||
|
||
if err := c.ShouldBindJSON(&request); err != nil {
|
||
logger.Error(ctx, "failed to unmarshal async task progress update request", "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusBadRequest,
|
||
Msg: "invalid request parameters",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Get database connection from context or use default
|
||
db := getDBFromContext(c)
|
||
if db == nil {
|
||
logger.Error(ctx, "database connection not found in context")
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "database connection error",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Update task progress
|
||
err := database.UpdateAsyncTaskProgress(ctx, db, request.TaskID, request.Progress)
|
||
if err != nil {
|
||
logger.Error(ctx, "failed to update async task progress", "task_id", request.TaskID, "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "failed to update task progress",
|
||
})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, network.SuccessResponse{
|
||
Code: 2000,
|
||
Msg: "task progress updated successfully",
|
||
Payload: nil,
|
||
})
|
||
}
|
||
|
||
// AsyncTaskStatusUpdateHandler handles updating task status (internal use, not exposed via API)
|
||
func AsyncTaskStatusUpdateHandler(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
var request network.AsyncTaskStatusUpdate
|
||
|
||
if err := c.ShouldBindJSON(&request); err != nil {
|
||
logger.Error(ctx, "failed to unmarshal async task status update request", "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusBadRequest,
|
||
Msg: "invalid request parameters",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Validate status
|
||
validStatus := map[string]bool{
|
||
string(orm.AsyncTaskStatusSubmitted): true,
|
||
string(orm.AsyncTaskStatusRunning): true,
|
||
string(orm.AsyncTaskStatusCompleted): true,
|
||
string(orm.AsyncTaskStatusFailed): true,
|
||
}
|
||
|
||
if !validStatus[request.Status] {
|
||
logger.Error(ctx, "invalid task status", "status", request.Status)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusBadRequest,
|
||
Msg: "invalid task status",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Get database connection from context or use default
|
||
db := getDBFromContext(c)
|
||
if db == nil {
|
||
logger.Error(ctx, "database connection not found in context")
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "database connection error",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Update task status
|
||
status := orm.AsyncTaskStatus(request.Status)
|
||
err := database.UpdateAsyncTaskStatus(ctx, db, request.TaskID, status)
|
||
if err != nil {
|
||
logger.Error(ctx, "failed to update async task status", "task_id", request.TaskID, "status", request.Status, "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "failed to update task status",
|
||
})
|
||
return
|
||
}
|
||
|
||
// If task is completed or failed, update finished_at timestamp
|
||
if request.Status == string(orm.AsyncTaskStatusCompleted) {
|
||
err = database.CompleteAsyncTask(ctx, db, request.TaskID, request.Timestamp)
|
||
} else if request.Status == string(orm.AsyncTaskStatusFailed) {
|
||
err = database.FailAsyncTask(ctx, db, request.TaskID, request.Timestamp)
|
||
}
|
||
|
||
if err != nil {
|
||
logger.Error(ctx, "failed to update async task completion timestamp", "task_id", request.TaskID, "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "failed to update task completion timestamp",
|
||
})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, network.SuccessResponse{
|
||
Code: 2000,
|
||
Msg: "task status updated successfully",
|
||
Payload: nil,
|
||
})
|
||
}
|
||
|
||
// Helper functions
|
||
|
||
func validateTaskParams(taskType string, params map[string]any) bool {
|
||
switch taskType {
|
||
case string(orm.AsyncTaskTypeTopologyAnalysis):
|
||
return validateTopologyAnalysisParams(params)
|
||
case string(orm.AsyncTaskTypePerformanceAnalysis):
|
||
return validatePerformanceAnalysisParams(params)
|
||
case string(orm.AsyncTaskTypeEventAnalysis):
|
||
return validateEventAnalysisParams(params)
|
||
case string(orm.AsyncTaskTypeBatchImport):
|
||
return validateBatchImportParams(params)
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
// validateTopologyAnalysisParams reports whether params carry both endpoints
// of a topology analysis. "start_uuid" and "end_uuid" must each be present, be
// strings, and contain non-whitespace characters.
//
// The previous `value == ""` comparison on an `any` accepted JSON null,
// numbers and other non-string values. The values are still not checked for
// UUID syntax here.
func validateTopologyAnalysisParams(params map[string]any) bool {
	for _, key := range [...]string{"start_uuid", "end_uuid"} {
		value, ok := params[key].(string)
		if !ok || strings.TrimSpace(value) == "" {
			return false
		}
	}
	return true
}
|
||
|
||
// validatePerformanceAnalysisParams reports whether params carry a non-empty
// "component_ids" list.
//
// The value must be a []any, the shape encoding/json produces for a JSON
// array. It must have at least one element, and no element may be nil
// (JSON null). The element type is otherwise not constrained.
func validatePerformanceAnalysisParams(params map[string]any) bool {
	ids, ok := params["component_ids"].([]any)
	if !ok || len(ids) == 0 {
		return false
	}
	for _, id := range ids {
		if id == nil {
			return false
		}
	}
	return true
}
|
||
|
||
// validateEventAnalysisParams reports whether params carry a usable
// "event_type". The value must be present, be a string, and contain
// non-whitespace characters. JSON null and non-string values are rejected.
func validateEventAnalysisParams(params map[string]any) bool {
	eventType, ok := params["event_type"].(string)
	return ok && strings.TrimSpace(eventType) != ""
}
|
||
|
||
// validateBatchImportParams reports whether params carry a usable "file_path".
// The value must be present, be a string, and contain non-whitespace
// characters. JSON null and non-string values are rejected.
//
// SECURITY NOTE(review): file_path comes straight from the HTTP request body.
// This check does not confine it to an import directory or reject ".."
// traversal. Whatever consumes the task must restrict the path before opening
// it, or this check should be tightened once the allowed root is known.
func validateBatchImportParams(params map[string]any) bool {
	filePath, ok := params["file_path"].(string)
	return ok && strings.TrimSpace(filePath) != ""
}
|
||
|
||
// splitCommaSeparated splits s on commas that are not inside double quotes.
// Each segment has surrounding whitespace trimmed, and segments that are empty
// after trimming are discarded. Previously "a," dropped its empty tail but
// "a, " kept it, and ",a" / "a,,b" produced empty entries.
//
// Double quotes toggle quoted mode and are not copied to the output. A
// backslash copies the following rune verbatim, so `\,` and `\"` yield a
// literal comma or quote. A trailing lone backslash is dropped. An input with
// no non-blank segment returns a nil slice.
func splitCommaSeparated(s string) []string {
	var (
		result   []string
		current  strings.Builder
		inQuotes bool
		escape   bool
	)

	// flush appends the pending segment if it is non-blank, then starts a new one.
	flush := func() {
		if segment := strings.TrimSpace(current.String()); segment != "" {
			result = append(result, segment)
		}
		current.Reset()
	}

	for _, ch := range s {
		switch {
		case escape:
			current.WriteRune(ch)
			escape = false
		case ch == '\\':
			escape = true
		case ch == '"':
			inQuotes = !inQuotes
		case ch == ',' && !inQuotes:
			flush()
		default:
			current.WriteRune(ch)
		}
	}
	flush()

	return result
}
|
||
|
||
func getDBFromContext(c *gin.Context) *gorm.DB {
|
||
// Try to get database connection from context
|
||
// This should be set by middleware
|
||
if db, exists := c.Get("db"); exists {
|
||
if gormDB, ok := db.(*gorm.DB); ok {
|
||
return gormDB
|
||
}
|
||
}
|
||
|
||
// Fallback to global database connection
|
||
// This should be implemented based on your application's database setup
|
||
// For now, return nil - actual implementation should retrieve from application context
|
||
return nil
|
||
}
|
||
|
||
// AsyncTaskResultDetailHandler handles detailed query of a single async task result
|
||
// @Summary 查询异步任务详情
|
||
// @Description 根据任务ID查询异步任务的详细状态和结果
|
||
// @Tags AsyncTask
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param task_id path string true "任务ID"
|
||
// @Success 200 {object} network.SuccessResponse{payload=network.AsyncTaskResult} "查询成功"
|
||
// @Failure 400 {object} network.FailureResponse "请求参数错误"
|
||
// @Failure 404 {object} network.FailureResponse "任务不存在"
|
||
// @Failure 500 {object} network.FailureResponse "服务器内部错误"
|
||
// @Router /task/async/{task_id} [get]
|
||
func AsyncTaskResultDetailHandler(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
// Parse task ID from path parameter
|
||
taskIDStr := c.Param("task_id")
|
||
if taskIDStr == "" {
|
||
logger.Error(ctx, "task_id parameter is required")
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusBadRequest,
|
||
Msg: "task_id parameter is required",
|
||
})
|
||
return
|
||
}
|
||
|
||
taskID, err := uuid.FromString(taskIDStr)
|
||
if err != nil {
|
||
logger.Error(ctx, "invalid task ID format", "task_id", taskIDStr, "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusBadRequest,
|
||
Msg: "invalid task ID format",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Get database connection from context or use default
|
||
db := getDBFromContext(c)
|
||
if db == nil {
|
||
logger.Error(ctx, "database connection not found in context")
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "database connection error",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Query task from database
|
||
asyncTask, err := database.GetAsyncTaskByID(ctx, db, taskID)
|
||
if err != nil {
|
||
if err == gorm.ErrRecordNotFound {
|
||
logger.Error(ctx, "async task not found", "task_id", taskID)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusNotFound,
|
||
Msg: "task not found",
|
||
})
|
||
return
|
||
}
|
||
logger.Error(ctx, "failed to query async task from database", "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "failed to query task",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Query task result from database
|
||
taskResult, err := database.GetAsyncTaskResult(ctx, db, taskID)
|
||
if err != nil {
|
||
logger.Error(ctx, "failed to query async task result from database", "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "failed to query task result",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Convert to response format
|
||
responseTask := network.AsyncTaskResult{
|
||
TaskID: asyncTask.TaskID,
|
||
TaskType: string(asyncTask.TaskType),
|
||
Status: string(asyncTask.Status),
|
||
CreatedAt: asyncTask.CreatedAt,
|
||
FinishedAt: asyncTask.FinishedAt,
|
||
Progress: asyncTask.Progress,
|
||
}
|
||
|
||
// Add result or error information if available
|
||
if taskResult != nil {
|
||
if taskResult.Result != nil {
|
||
responseTask.Result = map[string]any(taskResult.Result)
|
||
}
|
||
if taskResult.ErrorCode != nil {
|
||
responseTask.ErrorCode = taskResult.ErrorCode
|
||
}
|
||
if taskResult.ErrorMessage != nil {
|
||
responseTask.ErrorMessage = taskResult.ErrorMessage
|
||
}
|
||
if taskResult.ErrorDetail != nil {
|
||
responseTask.ErrorDetail = map[string]any(taskResult.ErrorDetail)
|
||
}
|
||
}
|
||
|
||
// Return success response
|
||
c.JSON(http.StatusOK, network.SuccessResponse{
|
||
Code: 2000,
|
||
Msg: "query completed",
|
||
Payload: responseTask,
|
||
})
|
||
}
|
||
|
||
// AsyncTaskCancelHandler handles cancellation of an async task
|
||
// @Summary 取消异步任务
|
||
// @Description 取消指定ID的异步任务(如果任务尚未开始执行)
|
||
// @Tags AsyncTask
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param task_id path string true "任务ID"
|
||
// @Success 200 {object} network.SuccessResponse "任务取消成功"
|
||
// @Failure 400 {object} network.FailureResponse "请求参数错误或任务无法取消"
|
||
// @Failure 404 {object} network.FailureResponse "任务不存在"
|
||
// @Failure 500 {object} network.FailureResponse "服务器内部错误"
|
||
// @Router /task/async/{task_id}/cancel [post]
|
||
func AsyncTaskCancelHandler(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
// Parse task ID from path parameter
|
||
taskIDStr := c.Param("task_id")
|
||
if taskIDStr == "" {
|
||
logger.Error(ctx, "task_id parameter is required")
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusBadRequest,
|
||
Msg: "task_id parameter is required",
|
||
})
|
||
return
|
||
}
|
||
|
||
taskID, err := uuid.FromString(taskIDStr)
|
||
if err != nil {
|
||
logger.Error(ctx, "invalid task ID format", "task_id", taskIDStr, "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusBadRequest,
|
||
Msg: "invalid task ID format",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Get database connection from context or use default
|
||
db := getDBFromContext(c)
|
||
if db == nil {
|
||
logger.Error(ctx, "database connection not found in context")
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "database connection error",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Query task from database
|
||
asyncTask, err := database.GetAsyncTaskByID(ctx, db, taskID)
|
||
if err != nil {
|
||
if err == gorm.ErrRecordNotFound {
|
||
logger.Error(ctx, "async task not found", "task_id", taskID)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusNotFound,
|
||
Msg: "task not found",
|
||
})
|
||
return
|
||
}
|
||
logger.Error(ctx, "failed to query async task from database", "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "failed to query task",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Check if task can be cancelled (only SUBMITTED tasks can be cancelled)
|
||
if asyncTask.Status != orm.AsyncTaskStatusSubmitted {
|
||
logger.Error(ctx, "task cannot be cancelled", "task_id", taskID, "status", asyncTask.Status)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusBadRequest,
|
||
Msg: "task cannot be cancelled (already running or completed)",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Update task status to failed with cancellation reason
|
||
timestamp := time.Now().Unix()
|
||
err = database.FailAsyncTask(ctx, db, taskID, timestamp)
|
||
if err != nil {
|
||
logger.Error(ctx, "failed to cancel async task", "task_id", taskID, "error", err)
|
||
c.JSON(http.StatusOK, network.FailureResponse{
|
||
Code: http.StatusInternalServerError,
|
||
Msg: "failed to cancel task",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Update task result with cancellation error
|
||
err = database.UpdateAsyncTaskResultWithError(ctx, db, taskID, 40003, "task cancelled by user", orm.JSONMap{
|
||
"cancelled_at": timestamp,
|
||
"cancelled_by": "user",
|
||
})
|
||
if err != nil {
|
||
logger.Error(ctx, "failed to update task result with cancellation error", "task_id", taskID, "error", err)
|
||
// Continue anyway since task is already marked as failed
|
||
}
|
||
|
||
c.JSON(http.StatusOK, network.SuccessResponse{
|
||
Code: 2000,
|
||
Msg: "task cancelled successfully",
|
||
})
|
||
}
|