diff --git a/constants/trace.go b/constants/trace.go index 14f52ed..ab5c5df 100644 --- a/constants/trace.go +++ b/constants/trace.go @@ -7,3 +7,13 @@ const ( HeaderSpanID = "X-B3-SpanId" HeaderParentSpanID = "X-B3-ParentSpanId" ) + +// traceCtxKey is an unexported type for context keys to avoid collisions with other packages. +type traceCtxKey string + +// Typed context keys for trace values — use these with context.WithValue / ctx.Value. +var ( + CtxKeyTraceID = traceCtxKey(HeaderTraceID) + CtxKeySpanID = traceCtxKey(HeaderSpanID) + CtxKeyParentSpanID = traceCtxKey(HeaderParentSpanID) +) diff --git a/database/postgres_init.go b/database/postgres_init.go index d2babd4..c2e0ed0 100644 --- a/database/postgres_init.go +++ b/database/postgres_init.go @@ -2,6 +2,7 @@ package database import ( + "context" "sync" "modelRT/logger" @@ -22,22 +23,22 @@ func GetPostgresDBClient() *gorm.DB { } // InitPostgresDBInstance return instance of PostgresDB client -func InitPostgresDBInstance(PostgresDBURI string) *gorm.DB { +func InitPostgresDBInstance(ctx context.Context, PostgresDBURI string) *gorm.DB { postgresOnce.Do(func() { - _globalPostgresClient = initPostgresDBClient(PostgresDBURI) + _globalPostgresClient = initPostgresDBClient(ctx, PostgresDBURI) }) return _globalPostgresClient } // initPostgresDBClient return successfully initialized PostgresDB client -func initPostgresDBClient(PostgresDBURI string) *gorm.DB { +func initPostgresDBClient(ctx context.Context, PostgresDBURI string) *gorm.DB { db, err := gorm.Open(postgres.Open(PostgresDBURI), &gorm.Config{Logger: logger.NewGormLogger()}) if err != nil { panic(err) } // Auto migrate async task tables - err = db.AutoMigrate( + err = db.WithContext(ctx).AutoMigrate( &orm.AsyncTask{}, &orm.AsyncTaskResult{}, ) diff --git a/handler/async_task_create_handler.go b/handler/async_task_create_handler.go index 2726ce5..6bf8a72 100644 --- a/handler/async_task_create_handler.go +++ b/handler/async_task_create_handler.go @@ -2,7 +2,6 @@ package handler import ( - "modelRT/config" "modelRT/constants" "modelRT/database" "modelRT/logger" @@ -66,38 +65,17 @@ func AsyncTaskCreateHandler(c *gin.Context) { return } - // send task to message queue - cfg, exists := c.Get("config") - if !exists { - logger.Warn(ctx, "Configuration not found in context, skipping queue publishing") - } else { - modelRTConfig := cfg.(config.ModelRTConfig) - ctx := c.Request.Context() - - // create queue producer - // TODO 像实时计算一样使用 channel 代替 - producer, err := task.NewQueueProducer(ctx, modelRTConfig.RabbitMQConfig) - if err != nil { - logger.Error(ctx, "create rabbitMQ queue producer failed", "error", err) - renderRespFailure(c, constants.RespCodeServerError, "create rabbitMQ queue producer failed", nil) - return - } - defer producer.Close() - - // publish task to queue - taskType := task.TaskType(request.TaskType) - priority := 5 // Default priority - - if err := producer.PublishTaskWithRetry(ctx, asyncTask.TaskID, taskType, priority, 3); err != nil { - logger.Error(ctx, "publish task to rabbitMQ queue failed", - "task_id", asyncTask.TaskID, "error", err) - renderRespFailure(c, constants.RespCodeServerError, "publish task to rabbitMQ queue failed", nil) - return - } - logger.Info(ctx, "published task to rabbitMQ queue successfully", - "task_id", asyncTask.TaskID, "queue", constants.TaskQueueName) - + // enqueue task to channel for async publishing to RabbitMQ + msg := task.NewTaskQueueMessageWithPriority(asyncTask.TaskID, task.TaskType(request.TaskType), 5) + // propagate HTTP request trace so the async chain stays on the same traceID + if v, _ := ctx.Value(constants.CtxKeyTraceID).(string); v != "" { + msg.TraceID = v } + if v, _ := ctx.Value(constants.CtxKeySpanID).(string); v != "" { + msg.SpanID = v + } + task.TaskMsgChan <- msg + logger.Info(ctx, "task enqueued to channel", "task_id", asyncTask.TaskID, "queue", constants.TaskQueueName) logger.Info(ctx, "async task created success", "task_id", asyncTask.TaskID, "task_type", request.TaskType) diff --git a/logger/logger.go b/logger/logger.go index 116e3b1..0d4317a 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -12,6 +12,14 @@ import ( "go.uber.org/zap/zapcore" ) +// Logger is the interface returned by New for structured, trace-aware logging. +type Logger interface { + Debug(msg string, kv ...any) + Info(msg string, kv ...any) + Warn(msg string, kv ...any) + Error(msg string, kv ...any) +} + type logger struct { ctx context.Context traceID string @@ -48,7 +56,10 @@ func makeLogFields(ctx context.Context, kv ...any) []zap.Field { kv = append(kv, "unknown") } - kv = append(kv, "traceID", ctx.Value(constants.HeaderTraceID), "spanID", ctx.Value(constants.HeaderSpanID), "parentSpanID", ctx.Value(constants.HeaderParentSpanID)) + traceID, _ := ctx.Value(constants.CtxKeyTraceID).(string) + spanID, _ := ctx.Value(constants.CtxKeySpanID).(string) + parentSpanID, _ := ctx.Value(constants.CtxKeyParentSpanID).(string) + kv = append(kv, "traceID", traceID, "spanID", spanID, "parentSpanID", parentSpanID) funcName, file, line := getLoggerCallerInfo() kv = append(kv, "func", funcName, "file", file, "line", line) @@ -89,16 +100,18 @@ func getLoggerCallerInfo() (funcName, file string, line int) { return } -func New(ctx context.Context) *logger { +// New returns a logger bound to ctx. Trace fields (traceID, spanID, parentSpanID) +// are extracted from ctx using typed keys, and are included in every log entry. +func New(ctx context.Context) Logger { var traceID, spanID, pSpanID string - if ctx.Value("traceID") != nil { - traceID = ctx.Value("traceID").(string) + if v, _ := ctx.Value(constants.CtxKeyTraceID).(string); v != "" { + traceID = v } - if ctx.Value("spanID") != nil { - spanID = ctx.Value("spanID").(string) + if v, _ := ctx.Value(constants.CtxKeySpanID).(string); v != "" { + spanID = v } - if ctx.Value("psapnID") != nil { - pSpanID = ctx.Value("pspanID").(string) + if v, _ := ctx.Value(constants.CtxKeyParentSpanID).(string); v != "" { + pSpanID = v } return &logger{ diff --git a/main.go b/main.go index 198bec0..103eca9 100644 --- a/main.go +++ b/main.go @@ -19,7 +19,6 @@ import ( "modelRT/database" "modelRT/diagram" "modelRT/logger" - "modelRT/middleware" "modelRT/model" "modelRT/mq" "modelRT/pool" @@ -74,7 +73,9 @@ var ( func main() { flag.Parse() - ctx := context.TODO() + startupSpanID := util.GenerateSpanID("startup") + ctx := context.WithValue(context.Background(), constants.CtxKeyTraceID, startupSpanID) + ctx = context.WithValue(ctx, constants.CtxKeySpanID, startupSpanID) configPath := filepath.Join(*modelRTConfigDir, *modelRTConfigName+"."+*modelRTConfigType) if _, err := os.Stat(configPath); os.IsNotExist(err) { @@ -113,7 +114,7 @@ func main() { } // init postgresDBClient - postgresDBClient = database.InitPostgresDBInstance(modelRTConfig.PostgresDBURI) + postgresDBClient = database.InitPostgresDBInstance(ctx, modelRTConfig.PostgresDBURI) defer func() { sqlDB, err := postgresDBClient.DB() @@ -171,8 +172,10 @@ func main() { // async push event to rabbitMQ go mq.PushUpDownLimitEventToRabbitMQ(ctx, mq.MsgChan) + // async push task message to rabbitMQ + go task.PushTaskToRabbitMQ(ctx, modelRTConfig.RabbitMQConfig, task.TaskMsgChan) - postgresDBClient.Transaction(func(tx *gorm.DB) error { + postgresDBClient.WithContext(ctx).Transaction(func(tx *gorm.DB) error { // load circuit diagram from postgres // componentTypeMap, err := database.QueryCircuitDiagramComponentFromDB(cancelCtx, tx, parsePool) // if err != nil { @@ -246,22 +249,11 @@ func main() { AllowCredentials: true, MaxAge: 12 * time.Hour, })) - // Register configuration middleware - engine.Use(middleware.ConfigMiddleware(modelRTConfig)) router.RegisterRoutes(engine, serviceToken) - // Swagger UI - engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) - - // 注册 Swagger UI 路由 - // docs.SwaggerInfo.BasePath = "/model" - // v1 := engine.Group("/api/v1") - // { - // eg := v1.Group("/example") - // { - // eg.GET("/helloworld", Helloworld) - // } - // } + if modelRTConfig.DeployEnv != constants.ProductionDeployMode { + engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) + } server := http.Server{ Addr: modelRTConfig.ServiceAddr, diff --git a/middleware/config_middleware.go b/middleware/config_middleware.go index ff56995..9ff5436 100644 --- a/middleware/config_middleware.go +++ b/middleware/config_middleware.go @@ -1,3 +1,4 @@ +// Package middleware define gin framework middlewares package middleware import ( @@ -12,4 +13,4 @@ func ConfigMiddleware(modelRTConfig config.ModelRTConfig) gin.HandlerFunc { c.Set("config", modelRTConfig) c.Next() } -} \ No newline at end of file +} diff --git a/middleware/limiter.go b/middleware/limiter.go index ca8dc53..d1debe1 100644 --- a/middleware/limiter.go +++ b/middleware/limiter.go @@ -1,3 +1,4 @@ +// Package middleware define gin framework middlewares package middleware import ( diff --git a/middleware/panic_recover.go b/middleware/panic_recover.go index e7aa3ac..b8c101f 100644 --- a/middleware/panic_recover.go +++ b/middleware/panic_recover.go @@ -1,3 +1,4 @@ +// Package middleware define gin framework middlewares package middleware import ( diff --git a/middleware/token.go b/middleware/token.go index 06ffa79..6759f40 100644 --- a/middleware/token.go +++ b/middleware/token.go @@ -1,3 +1,4 @@ +// Package middleware define gin framework middlewares package middleware import "github.com/gin-gonic/gin" diff --git a/middleware/trace.go b/middleware/trace.go index 1bb27ca..604a08c 100644 --- a/middleware/trace.go +++ b/middleware/trace.go @@ -1,7 +1,9 @@ +// Package middleware define gin framework middlewares package middleware import ( "bytes" + "context" "io" "strings" "time" @@ -27,6 +29,14 @@ func StartTrace() gin.HandlerFunc { c.Set(constants.HeaderTraceID, traceID) c.Set(constants.HeaderSpanID, spanID) c.Set(constants.HeaderParentSpanID, parentSpanID) + + // also inject into request context so c.Request.Context() carries trace values + reqCtx := c.Request.Context() + reqCtx = context.WithValue(reqCtx, constants.CtxKeyTraceID, traceID) + reqCtx = context.WithValue(reqCtx, constants.CtxKeySpanID, spanID) + reqCtx = context.WithValue(reqCtx, constants.CtxKeyParentSpanID, parentSpanID) + c.Request = c.Request.WithContext(reqCtx) + c.Next() } } @@ -78,7 +88,6 @@ func LogAccess() gin.HandlerFunc { accessLog(c, "access_end", time.Since(start), reqBody, responseLogging) }() c.Next() - return } } diff --git a/model/recommend_islocal_cache.go b/model/recommend_islocal_cache.go index 8bb7cda..a9f54b5 100644 --- a/model/recommend_islocal_cache.go +++ b/model/recommend_islocal_cache.go @@ -22,7 +22,7 @@ func GetNSpathToIsLocalMap(ctx context.Context, db *gorm.DB) (map[string]bool, e var results []ComponentStationRelation nspathMap := make(map[string]bool) - err := db.Table("component"). + err := db.WithContext(ctx).Table("component"). Select("component.nspath, station.is_local"). Joins("join station on component.station_id = station.id"). Scan(&results).Error diff --git a/task/handler_factory.go b/task/handler_factory.go index c472fef..1efdf65 100644 --- a/task/handler_factory.go +++ b/task/handler_factory.go @@ -36,12 +36,12 @@ func NewHandlerFactory() *HandlerFactory { } // RegisterHandler registers a handler for a specific task type -func (f *HandlerFactory) RegisterHandler(taskType TaskType, handler TaskHandler) { +func (f *HandlerFactory) RegisterHandler(ctx context.Context, taskType TaskType, handler TaskHandler) { f.mu.Lock() defer f.mu.Unlock() f.handlers[taskType] = handler - logger.Info(context.Background(), "Handler registered", + logger.Info(ctx, "Handler registered", "task_type", taskType, "handler_name", handler.Name(), ) @@ -61,11 +61,11 @@ func (f *HandlerFactory) GetHandler(taskType TaskType) (TaskHandler, error) { } // CreateDefaultHandlers registers all default task handlers -func (f *HandlerFactory) CreateDefaultHandlers() { - f.RegisterHandler(TypeTopologyAnalysis, &TopologyAnalysisHandler{}) - f.RegisterHandler(TypeEventAnalysis, &EventAnalysisHandler{}) - f.RegisterHandler(TypeBatchImport, &BatchImportHandler{}) - f.RegisterHandler(TaskType(TaskTypeTest), NewTestTaskHandler()) +func (f *HandlerFactory) CreateDefaultHandlers(ctx context.Context) { + f.RegisterHandler(ctx, TypeTopologyAnalysis, &TopologyAnalysisHandler{}) + f.RegisterHandler(ctx, TypeEventAnalysis, &EventAnalysisHandler{}) + f.RegisterHandler(ctx, TypeBatchImport, &BatchImportHandler{}) + f.RegisterHandler(ctx, TaskType(TaskTypeTest), NewTestTaskHandler()) } // BaseHandler provides common functionality for all task handlers @@ -235,14 +235,14 @@ func (h *CompositeHandler) Name() string { } // DefaultHandlerFactory returns a HandlerFactory with all default handlers registered -func DefaultHandlerFactory() *HandlerFactory { +func DefaultHandlerFactory(ctx context.Context) *HandlerFactory { factory := NewHandlerFactory() - factory.CreateDefaultHandlers() + factory.CreateDefaultHandlers(ctx) return factory } // DefaultCompositeHandler returns a CompositeHandler with all default handlers -func DefaultCompositeHandler() TaskHandler { - factory := DefaultHandlerFactory() +func DefaultCompositeHandler(ctx context.Context) TaskHandler { + factory := DefaultHandlerFactory(ctx) return NewCompositeHandler(factory) } \ No newline at end of file diff --git a/task/initializer.go b/task/initializer.go index de75cea..1122a05 100644 --- a/task/initializer.go +++ b/task/initializer.go @@ -23,8 +23,8 @@ func InitTaskWorker(ctx context.Context, config config.ModelRTConfig, db *gorm.D // Create task handler factory handlerFactory := NewHandlerFactory() - handlerFactory.CreateDefaultHandlers() - handler := DefaultCompositeHandler() + handlerFactory.CreateDefaultHandlers(ctx) + handler := DefaultCompositeHandler(ctx) // Create task worker worker, err := NewTaskWorker(ctx, workerCfg, db, config.RabbitMQConfig, handler) @@ -38,4 +38,4 @@ func InitTaskWorker(ctx context.Context, config config.ModelRTConfig, db *gorm.D ) return worker, nil -} \ No newline at end of file +} diff --git a/task/queue_message.go b/task/queue_message.go index 6ce8028..95b717a 100644 --- a/task/queue_message.go +++ b/task/queue_message.go @@ -14,6 +14,8 @@ type TaskQueueMessage struct { TaskID uuid.UUID `json:"task_id"` TaskType TaskType `json:"task_type"` Priority int `json:"priority,omitempty"` // Optional, defaults to constants.TaskPriorityDefault + TraceID string `json:"trace_id,omitempty"` // propagated from the originating HTTP request + SpanID string `json:"span_id,omitempty"` // spanID of the step that enqueued this message } // NewTaskQueueMessage creates a new TaskQueueMessage with default priority diff --git a/task/queue_producer.go b/task/queue_producer.go index 650a2bf..97e6da7 100644 --- a/task/queue_producer.go +++ b/task/queue_producer.go @@ -11,11 +11,19 @@ import ( "modelRT/constants" "modelRT/logger" "modelRT/mq" + "modelRT/util" "github.com/gofrs/uuid" amqp "github.com/rabbitmq/amqp091-go" ) +// TaskMsgChan buffers task messages to be published to RabbitMQ asynchronously +var TaskMsgChan chan *TaskQueueMessage + +func init() { + TaskMsgChan = make(chan *TaskQueueMessage, 10000) +} + // QueueProducer handles publishing tasks to RabbitMQ type QueueProducer struct { conn *amqp.Connection @@ -212,4 +220,39 @@ func (p *QueueProducer) GetQueueInfo() (*amqp.Queue, error) { // PurgeQueue removes all messages from the task queue func (p *QueueProducer) PurgeQueue() (int, error) { return p.ch.QueuePurge(constants.TaskQueueName, false) +} + +// PushTaskToRabbitMQ reads from taskChan and publishes to RabbitMQ. +// Must be run as a goroutine; blocks until ctx is cancelled or taskChan is closed. +func PushTaskToRabbitMQ(ctx context.Context, cfg config.RabbitMQConfig, taskChan chan *TaskQueueMessage) { + producer, err := NewQueueProducer(ctx, cfg) + if err != nil { + logger.Error(ctx, "init task queue producer failed", "error", err) + return + } + defer producer.Close() + + for { + select { + case <-ctx.Done(): + logger.Info(ctx, "push task to RabbitMQ stopped by context cancel") + return + case msg, ok := <-taskChan: + if !ok { + logger.Info(ctx, "task channel closed, exiting push loop") + return + } + traceID := msg.TraceID + if traceID == "" { + traceID = msg.TaskID.String() // fallback when no HTTP trace was propagated + } + taskCtx := context.WithValue(ctx, constants.CtxKeyTraceID, traceID) + taskCtx = context.WithValue(taskCtx, constants.CtxKeySpanID, util.GenerateSpanID("task-publish")) + taskCtx = context.WithValue(taskCtx, constants.CtxKeyParentSpanID, msg.SpanID) + if err := producer.PublishTaskWithRetry(taskCtx, msg.TaskID, msg.TaskType, msg.Priority, 3); err != nil { + logger.Error(taskCtx, "publish task to RabbitMQ failed", + "task_id", msg.TaskID, "error", err) + } + } + } } \ No newline at end of file diff --git a/task/worker.go b/task/worker.go index 03f3728..6a47b56 100644 --- a/task/worker.go +++ b/task/worker.go @@ -14,6 +14,7 @@ import ( "modelRT/logger" "modelRT/mq" "modelRT/orm" + "modelRT/util" "github.com/gofrs/uuid" "github.com/panjf2000/ants/v2" @@ -282,6 +283,16 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) { return } + // derive a per-task context carrying the trace propagated from the originating HTTP request + traceID := taskMsg.TraceID + if traceID == "" { + traceID = taskMsg.TaskID.String() // fallback when message carries no trace + } + taskCtx := context.WithValue(ctx, constants.CtxKeyTraceID, traceID) + taskCtx = context.WithValue(taskCtx, constants.CtxKeySpanID, util.GenerateSpanID("task-worker")) + taskCtx = context.WithValue(taskCtx, constants.CtxKeyParentSpanID, taskMsg.SpanID) + ctx = taskCtx + logger.Info(ctx, "Processing task", "task_id", taskMsg.TaskID, "task_type", taskMsg.TaskType,