feat: implement end-to-end distributed tracing for HTTP and async tasks

- introduce typed traceCtxKey to prevent context key collisions (staticcheck fix)
  - inject B3 trace values into c.Request.Context() in StartTrace middleware
    so handlers using c.Request.Context() carry trace info
  - create startup trace context in main.go, replacing context.TODO()
  - propagate HTTP traceID/spanID through TaskQueueMessage into RabbitMQ
    worker, linking HTTP request → publish → execution on the same traceID
  - fix GORM logger null traceID by binding ctx to AutoMigrate and queries
    via db.WithContext(ctx)
  - thread ctx through handler factory to fix null traceID in startup logs
  - replace per-request RabbitMQ producer with channel-based
    PushTaskToRabbitMQ goroutine; restrict Swagger to non-production
This commit is contained in:
douxu 2026-04-23 16:48:32 +08:00
parent 809e1cd87d
commit 03bd058558
16 changed files with 142 additions and 79 deletions

View File

@ -7,3 +7,13 @@ const (
HeaderSpanID = "X-B3-SpanId" HeaderSpanID = "X-B3-SpanId"
HeaderParentSpanID = "X-B3-ParentSpanId" HeaderParentSpanID = "X-B3-ParentSpanId"
) )
// traceCtxKey is an unexported type for context keys to avoid collisions with other packages.
type traceCtxKey string
// Typed context keys for trace values — use these with context.WithValue / ctx.Value.
var (
CtxKeyTraceID = traceCtxKey(HeaderTraceID)
CtxKeySpanID = traceCtxKey(HeaderSpanID)
CtxKeyParentSpanID = traceCtxKey(HeaderParentSpanID)
)

View File

@ -2,6 +2,7 @@
package database package database
import ( import (
"context"
"sync" "sync"
"modelRT/logger" "modelRT/logger"
@ -22,22 +23,22 @@ func GetPostgresDBClient() *gorm.DB {
} }
// InitPostgresDBInstance return instance of PostgresDB client // InitPostgresDBInstance return instance of PostgresDB client
func InitPostgresDBInstance(PostgresDBURI string) *gorm.DB { func InitPostgresDBInstance(ctx context.Context, PostgresDBURI string) *gorm.DB {
postgresOnce.Do(func() { postgresOnce.Do(func() {
_globalPostgresClient = initPostgresDBClient(PostgresDBURI) _globalPostgresClient = initPostgresDBClient(ctx, PostgresDBURI)
}) })
return _globalPostgresClient return _globalPostgresClient
} }
// initPostgresDBClient return successfully initialized PostgresDB client // initPostgresDBClient return successfully initialized PostgresDB client
func initPostgresDBClient(PostgresDBURI string) *gorm.DB { func initPostgresDBClient(ctx context.Context, PostgresDBURI string) *gorm.DB {
db, err := gorm.Open(postgres.Open(PostgresDBURI), &gorm.Config{Logger: logger.NewGormLogger()}) db, err := gorm.Open(postgres.Open(PostgresDBURI), &gorm.Config{Logger: logger.NewGormLogger()})
if err != nil { if err != nil {
panic(err) panic(err)
} }
// Auto migrate async task tables // Auto migrate async task tables
err = db.AutoMigrate( err = db.WithContext(ctx).AutoMigrate(
&orm.AsyncTask{}, &orm.AsyncTask{},
&orm.AsyncTaskResult{}, &orm.AsyncTaskResult{},
) )

View File

@ -2,7 +2,6 @@
package handler package handler
import ( import (
"modelRT/config"
"modelRT/constants" "modelRT/constants"
"modelRT/database" "modelRT/database"
"modelRT/logger" "modelRT/logger"
@ -66,38 +65,17 @@ func AsyncTaskCreateHandler(c *gin.Context) {
return return
} }
// send task to message queue // enqueue task to channel for async publishing to RabbitMQ
cfg, exists := c.Get("config") msg := task.NewTaskQueueMessageWithPriority(asyncTask.TaskID, task.TaskType(request.TaskType), 5)
if !exists { // propagate HTTP request trace so the async chain stays on the same traceID
logger.Warn(ctx, "Configuration not found in context, skipping queue publishing") if v, _ := ctx.Value(constants.CtxKeyTraceID).(string); v != "" {
} else { msg.TraceID = v
modelRTConfig := cfg.(config.ModelRTConfig)
ctx := c.Request.Context()
// create queue producer
// TODO 像实时计算一样使用 channel 代替
producer, err := task.NewQueueProducer(ctx, modelRTConfig.RabbitMQConfig)
if err != nil {
logger.Error(ctx, "create rabbitMQ queue producer failed", "error", err)
renderRespFailure(c, constants.RespCodeServerError, "create rabbitMQ queue producer failed", nil)
return
}
defer producer.Close()
// publish task to queue
taskType := task.TaskType(request.TaskType)
priority := 5 // Default priority
if err := producer.PublishTaskWithRetry(ctx, asyncTask.TaskID, taskType, priority, 3); err != nil {
logger.Error(ctx, "publish task to rabbitMQ queue failed",
"task_id", asyncTask.TaskID, "error", err)
renderRespFailure(c, constants.RespCodeServerError, "publish task to rabbitMQ queue failed", nil)
return
}
logger.Info(ctx, "published task to rabbitMQ queue successfully",
"task_id", asyncTask.TaskID, "queue", constants.TaskQueueName)
} }
if v, _ := ctx.Value(constants.CtxKeySpanID).(string); v != "" {
msg.SpanID = v
}
task.TaskMsgChan <- msg
logger.Info(ctx, "task enqueued to channel", "task_id", asyncTask.TaskID, "queue", constants.TaskQueueName)
logger.Info(ctx, "async task created success", "task_id", asyncTask.TaskID, "task_type", request.TaskType) logger.Info(ctx, "async task created success", "task_id", asyncTask.TaskID, "task_type", request.TaskType)

View File

@ -12,6 +12,14 @@ import (
"go.uber.org/zap/zapcore" "go.uber.org/zap/zapcore"
) )
// Logger is the interface returned by New for structured, trace-aware logging.
type Logger interface {
Debug(msg string, kv ...any)
Info(msg string, kv ...any)
Warn(msg string, kv ...any)
Error(msg string, kv ...any)
}
type logger struct { type logger struct {
ctx context.Context ctx context.Context
traceID string traceID string
@ -48,7 +56,10 @@ func makeLogFields(ctx context.Context, kv ...any) []zap.Field {
kv = append(kv, "unknown") kv = append(kv, "unknown")
} }
kv = append(kv, "traceID", ctx.Value(constants.HeaderTraceID), "spanID", ctx.Value(constants.HeaderSpanID), "parentSpanID", ctx.Value(constants.HeaderParentSpanID)) traceID, _ := ctx.Value(constants.CtxKeyTraceID).(string)
spanID, _ := ctx.Value(constants.CtxKeySpanID).(string)
parentSpanID, _ := ctx.Value(constants.CtxKeyParentSpanID).(string)
kv = append(kv, "traceID", traceID, "spanID", spanID, "parentSpanID", parentSpanID)
funcName, file, line := getLoggerCallerInfo() funcName, file, line := getLoggerCallerInfo()
kv = append(kv, "func", funcName, "file", file, "line", line) kv = append(kv, "func", funcName, "file", file, "line", line)
@ -89,16 +100,18 @@ func getLoggerCallerInfo() (funcName, file string, line int) {
return return
} }
func New(ctx context.Context) *logger { // New returns a logger bound to ctx. Trace fields (traceID, spanID, parentSpanID)
// are extracted from ctx using typed keys, and are included in every log entry.
func New(ctx context.Context) Logger {
var traceID, spanID, pSpanID string var traceID, spanID, pSpanID string
if ctx.Value("traceID") != nil { if v, _ := ctx.Value(constants.CtxKeyTraceID).(string); v != "" {
traceID = ctx.Value("traceID").(string) traceID = v
} }
if ctx.Value("spanID") != nil { if v, _ := ctx.Value(constants.CtxKeySpanID).(string); v != "" {
spanID = ctx.Value("spanID").(string) spanID = v
} }
if ctx.Value("psapnID") != nil { if v, _ := ctx.Value(constants.CtxKeyParentSpanID).(string); v != "" {
pSpanID = ctx.Value("pspanID").(string) pSpanID = v
} }
return &logger{ return &logger{

28
main.go
View File

@ -19,7 +19,6 @@ import (
"modelRT/database" "modelRT/database"
"modelRT/diagram" "modelRT/diagram"
"modelRT/logger" "modelRT/logger"
"modelRT/middleware"
"modelRT/model" "modelRT/model"
"modelRT/mq" "modelRT/mq"
"modelRT/pool" "modelRT/pool"
@ -74,7 +73,9 @@ var (
func main() { func main() {
flag.Parse() flag.Parse()
ctx := context.TODO() startupSpanID := util.GenerateSpanID("startup")
ctx := context.WithValue(context.Background(), constants.CtxKeyTraceID, startupSpanID)
ctx = context.WithValue(ctx, constants.CtxKeySpanID, startupSpanID)
configPath := filepath.Join(*modelRTConfigDir, *modelRTConfigName+"."+*modelRTConfigType) configPath := filepath.Join(*modelRTConfigDir, *modelRTConfigName+"."+*modelRTConfigType)
if _, err := os.Stat(configPath); os.IsNotExist(err) { if _, err := os.Stat(configPath); os.IsNotExist(err) {
@ -113,7 +114,7 @@ func main() {
} }
// init postgresDBClient // init postgresDBClient
postgresDBClient = database.InitPostgresDBInstance(modelRTConfig.PostgresDBURI) postgresDBClient = database.InitPostgresDBInstance(ctx, modelRTConfig.PostgresDBURI)
defer func() { defer func() {
sqlDB, err := postgresDBClient.DB() sqlDB, err := postgresDBClient.DB()
@ -171,8 +172,10 @@ func main() {
// async push event to rabbitMQ // async push event to rabbitMQ
go mq.PushUpDownLimitEventToRabbitMQ(ctx, mq.MsgChan) go mq.PushUpDownLimitEventToRabbitMQ(ctx, mq.MsgChan)
// async push task message to rabbitMQ
go task.PushTaskToRabbitMQ(ctx, modelRTConfig.RabbitMQConfig, task.TaskMsgChan)
postgresDBClient.Transaction(func(tx *gorm.DB) error { postgresDBClient.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// load circuit diagram from postgres // load circuit diagram from postgres
// componentTypeMap, err := database.QueryCircuitDiagramComponentFromDB(cancelCtx, tx, parsePool) // componentTypeMap, err := database.QueryCircuitDiagramComponentFromDB(cancelCtx, tx, parsePool)
// if err != nil { // if err != nil {
@ -246,22 +249,11 @@ func main() {
AllowCredentials: true, AllowCredentials: true,
MaxAge: 12 * time.Hour, MaxAge: 12 * time.Hour,
})) }))
// Register configuration middleware
engine.Use(middleware.ConfigMiddleware(modelRTConfig))
router.RegisterRoutes(engine, serviceToken) router.RegisterRoutes(engine, serviceToken)
// Swagger UI if modelRTConfig.DeployEnv != constants.ProductionDeployMode {
engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
}
// 注册 Swagger UI 路由
// docs.SwaggerInfo.BasePath = "/model"
// v1 := engine.Group("/api/v1")
// {
// eg := v1.Group("/example")
// {
// eg.GET("/helloworld", Helloworld)
// }
// }
server := http.Server{ server := http.Server{
Addr: modelRTConfig.ServiceAddr, Addr: modelRTConfig.ServiceAddr,

View File

@ -1,3 +1,4 @@
// Package middleware define gin framework middlewares
package middleware package middleware
import ( import (
@ -12,4 +13,4 @@ func ConfigMiddleware(modelRTConfig config.ModelRTConfig) gin.HandlerFunc {
c.Set("config", modelRTConfig) c.Set("config", modelRTConfig)
c.Next() c.Next()
} }
} }

View File

@ -1,3 +1,4 @@
// Package middleware define gin framework middlewares
package middleware package middleware
import ( import (

View File

@ -1,3 +1,4 @@
// Package middleware define gin framework middlewares
package middleware package middleware
import ( import (

View File

@ -1,3 +1,4 @@
// Package middleware define gin framework middlewares
package middleware package middleware
import "github.com/gin-gonic/gin" import "github.com/gin-gonic/gin"

View File

@ -1,7 +1,9 @@
// Package middleware define gin framework middlewares
package middleware package middleware
import ( import (
"bytes" "bytes"
"context"
"io" "io"
"strings" "strings"
"time" "time"
@ -27,6 +29,14 @@ func StartTrace() gin.HandlerFunc {
c.Set(constants.HeaderTraceID, traceID) c.Set(constants.HeaderTraceID, traceID)
c.Set(constants.HeaderSpanID, spanID) c.Set(constants.HeaderSpanID, spanID)
c.Set(constants.HeaderParentSpanID, parentSpanID) c.Set(constants.HeaderParentSpanID, parentSpanID)
// also inject into request context so c.Request.Context() carries trace values
reqCtx := c.Request.Context()
reqCtx = context.WithValue(reqCtx, constants.CtxKeyTraceID, traceID)
reqCtx = context.WithValue(reqCtx, constants.CtxKeySpanID, spanID)
reqCtx = context.WithValue(reqCtx, constants.CtxKeyParentSpanID, parentSpanID)
c.Request = c.Request.WithContext(reqCtx)
c.Next() c.Next()
} }
} }
@ -78,7 +88,6 @@ func LogAccess() gin.HandlerFunc {
accessLog(c, "access_end", time.Since(start), reqBody, responseLogging) accessLog(c, "access_end", time.Since(start), reqBody, responseLogging)
}() }()
c.Next() c.Next()
return
} }
} }

View File

@ -22,7 +22,7 @@ func GetNSpathToIsLocalMap(ctx context.Context, db *gorm.DB) (map[string]bool, e
var results []ComponentStationRelation var results []ComponentStationRelation
nspathMap := make(map[string]bool) nspathMap := make(map[string]bool)
err := db.Table("component"). err := db.WithContext(ctx).Table("component").
Select("component.nspath, station.is_local"). Select("component.nspath, station.is_local").
Joins("join station on component.station_id = station.id"). Joins("join station on component.station_id = station.id").
Scan(&results).Error Scan(&results).Error

View File

@ -36,12 +36,12 @@ func NewHandlerFactory() *HandlerFactory {
} }
// RegisterHandler registers a handler for a specific task type // RegisterHandler registers a handler for a specific task type
func (f *HandlerFactory) RegisterHandler(taskType TaskType, handler TaskHandler) { func (f *HandlerFactory) RegisterHandler(ctx context.Context, taskType TaskType, handler TaskHandler) {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()
f.handlers[taskType] = handler f.handlers[taskType] = handler
logger.Info(context.Background(), "Handler registered", logger.Info(ctx, "Handler registered",
"task_type", taskType, "task_type", taskType,
"handler_name", handler.Name(), "handler_name", handler.Name(),
) )
@ -61,11 +61,11 @@ func (f *HandlerFactory) GetHandler(taskType TaskType) (TaskHandler, error) {
} }
// CreateDefaultHandlers registers all default task handlers // CreateDefaultHandlers registers all default task handlers
func (f *HandlerFactory) CreateDefaultHandlers() { func (f *HandlerFactory) CreateDefaultHandlers(ctx context.Context) {
f.RegisterHandler(TypeTopologyAnalysis, &TopologyAnalysisHandler{}) f.RegisterHandler(ctx, TypeTopologyAnalysis, &TopologyAnalysisHandler{})
f.RegisterHandler(TypeEventAnalysis, &EventAnalysisHandler{}) f.RegisterHandler(ctx, TypeEventAnalysis, &EventAnalysisHandler{})
f.RegisterHandler(TypeBatchImport, &BatchImportHandler{}) f.RegisterHandler(ctx, TypeBatchImport, &BatchImportHandler{})
f.RegisterHandler(TaskType(TaskTypeTest), NewTestTaskHandler()) f.RegisterHandler(ctx, TaskType(TaskTypeTest), NewTestTaskHandler())
} }
// BaseHandler provides common functionality for all task handlers // BaseHandler provides common functionality for all task handlers
@ -235,14 +235,14 @@ func (h *CompositeHandler) Name() string {
} }
// DefaultHandlerFactory returns a HandlerFactory with all default handlers registered // DefaultHandlerFactory returns a HandlerFactory with all default handlers registered
func DefaultHandlerFactory() *HandlerFactory { func DefaultHandlerFactory(ctx context.Context) *HandlerFactory {
factory := NewHandlerFactory() factory := NewHandlerFactory()
factory.CreateDefaultHandlers() factory.CreateDefaultHandlers(ctx)
return factory return factory
} }
// DefaultCompositeHandler returns a CompositeHandler with all default handlers // DefaultCompositeHandler returns a CompositeHandler with all default handlers
func DefaultCompositeHandler() TaskHandler { func DefaultCompositeHandler(ctx context.Context) TaskHandler {
factory := DefaultHandlerFactory() factory := DefaultHandlerFactory(ctx)
return NewCompositeHandler(factory) return NewCompositeHandler(factory)
} }

View File

@ -23,8 +23,8 @@ func InitTaskWorker(ctx context.Context, config config.ModelRTConfig, db *gorm.D
// Create task handler factory // Create task handler factory
handlerFactory := NewHandlerFactory() handlerFactory := NewHandlerFactory()
handlerFactory.CreateDefaultHandlers() handlerFactory.CreateDefaultHandlers(ctx)
handler := DefaultCompositeHandler() handler := DefaultCompositeHandler(ctx)
// Create task worker // Create task worker
worker, err := NewTaskWorker(ctx, workerCfg, db, config.RabbitMQConfig, handler) worker, err := NewTaskWorker(ctx, workerCfg, db, config.RabbitMQConfig, handler)
@ -38,4 +38,4 @@ func InitTaskWorker(ctx context.Context, config config.ModelRTConfig, db *gorm.D
) )
return worker, nil return worker, nil
} }

View File

@ -14,6 +14,8 @@ type TaskQueueMessage struct {
TaskID uuid.UUID `json:"task_id"` TaskID uuid.UUID `json:"task_id"`
TaskType TaskType `json:"task_type"` TaskType TaskType `json:"task_type"`
Priority int `json:"priority,omitempty"` // Optional, defaults to constants.TaskPriorityDefault Priority int `json:"priority,omitempty"` // Optional, defaults to constants.TaskPriorityDefault
TraceID string `json:"trace_id,omitempty"` // propagated from the originating HTTP request
SpanID string `json:"span_id,omitempty"` // spanID of the step that enqueued this message
} }
// NewTaskQueueMessage creates a new TaskQueueMessage with default priority // NewTaskQueueMessage creates a new TaskQueueMessage with default priority

View File

@ -11,11 +11,19 @@ import (
"modelRT/constants" "modelRT/constants"
"modelRT/logger" "modelRT/logger"
"modelRT/mq" "modelRT/mq"
"modelRT/util"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
amqp "github.com/rabbitmq/amqp091-go" amqp "github.com/rabbitmq/amqp091-go"
) )
// TaskMsgChan buffers task messages to be published to RabbitMQ asynchronously
var TaskMsgChan chan *TaskQueueMessage
func init() {
TaskMsgChan = make(chan *TaskQueueMessage, 10000)
}
// QueueProducer handles publishing tasks to RabbitMQ // QueueProducer handles publishing tasks to RabbitMQ
type QueueProducer struct { type QueueProducer struct {
conn *amqp.Connection conn *amqp.Connection
@ -212,4 +220,39 @@ func (p *QueueProducer) GetQueueInfo() (*amqp.Queue, error) {
// PurgeQueue removes all messages from the task queue // PurgeQueue removes all messages from the task queue
func (p *QueueProducer) PurgeQueue() (int, error) { func (p *QueueProducer) PurgeQueue() (int, error) {
return p.ch.QueuePurge(constants.TaskQueueName, false) return p.ch.QueuePurge(constants.TaskQueueName, false)
}
// PushTaskToRabbitMQ reads from taskChan and publishes to RabbitMQ.
// Must be run as a goroutine; blocks until ctx is cancelled or taskChan is closed.
func PushTaskToRabbitMQ(ctx context.Context, cfg config.RabbitMQConfig, taskChan chan *TaskQueueMessage) {
producer, err := NewQueueProducer(ctx, cfg)
if err != nil {
logger.Error(ctx, "init task queue producer failed", "error", err)
return
}
defer producer.Close()
for {
select {
case <-ctx.Done():
logger.Info(ctx, "push task to RabbitMQ stopped by context cancel")
return
case msg, ok := <-taskChan:
if !ok {
logger.Info(ctx, "task channel closed, exiting push loop")
return
}
traceID := msg.TraceID
if traceID == "" {
traceID = msg.TaskID.String() // fallback when no HTTP trace was propagated
}
taskCtx := context.WithValue(ctx, constants.CtxKeyTraceID, traceID)
taskCtx = context.WithValue(taskCtx, constants.CtxKeySpanID, util.GenerateSpanID("task-publish"))
taskCtx = context.WithValue(taskCtx, constants.CtxKeyParentSpanID, msg.SpanID)
if err := producer.PublishTaskWithRetry(taskCtx, msg.TaskID, msg.TaskType, msg.Priority, 3); err != nil {
logger.Error(taskCtx, "publish task to RabbitMQ failed",
"task_id", msg.TaskID, "error", err)
}
}
}
} }

View File

@ -14,6 +14,7 @@ import (
"modelRT/logger" "modelRT/logger"
"modelRT/mq" "modelRT/mq"
"modelRT/orm" "modelRT/orm"
"modelRT/util"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
"github.com/panjf2000/ants/v2" "github.com/panjf2000/ants/v2"
@ -282,6 +283,16 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) {
return return
} }
// derive a per-task context carrying the trace propagated from the originating HTTP request
traceID := taskMsg.TraceID
if traceID == "" {
traceID = taskMsg.TaskID.String() // fallback when message carries no trace
}
taskCtx := context.WithValue(ctx, constants.CtxKeyTraceID, traceID)
taskCtx = context.WithValue(taskCtx, constants.CtxKeySpanID, util.GenerateSpanID("task-worker"))
taskCtx = context.WithValue(taskCtx, constants.CtxKeyParentSpanID, taskMsg.SpanID)
ctx = taskCtx
logger.Info(ctx, "Processing task", logger.Info(ctx, "Processing task",
"task_id", taskMsg.TaskID, "task_id", taskMsg.TaskID,
"task_type", taskMsg.TaskType, "task_type", taskMsg.TaskType,