feat: implement end-to-end distributed tracing for HTTP and async tasks

- introduce typed traceCtxKey to prevent context key collisions (staticcheck fix)
  - inject B3 trace values into c.Request.Context() in StartTrace middleware
    so handlers using c.Request.Context() carry trace info
  - create startup trace context in main.go, replacing context.TODO()
  - propagate HTTP traceID/spanID through TaskQueueMessage into RabbitMQ
    worker, linking HTTP request → publish → execution on the same traceID
  - fix GORM logger null traceID by binding ctx to AutoMigrate and queries
    via db.WithContext(ctx)
  - thread ctx through handler factory to fix null traceID in startup logs
  - replace per-request RabbitMQ producer with channel-based
    PushTaskToRabbitMQ goroutine; restrict Swagger to non-production
This commit is contained in:
douxu 2026-04-23 16:48:32 +08:00
parent 809e1cd87d
commit 03bd058558
16 changed files with 142 additions and 79 deletions

View File

@ -7,3 +7,13 @@ const (
HeaderSpanID = "X-B3-SpanId"
HeaderParentSpanID = "X-B3-ParentSpanId"
)
// traceCtxKey is an unexported type for context keys to avoid collisions with other packages.
type traceCtxKey string
// Typed context keys for trace values — use these with context.WithValue / ctx.Value.
var (
CtxKeyTraceID = traceCtxKey(HeaderTraceID)
CtxKeySpanID = traceCtxKey(HeaderSpanID)
CtxKeyParentSpanID = traceCtxKey(HeaderParentSpanID)
)

View File

@ -2,6 +2,7 @@
package database
import (
"context"
"sync"
"modelRT/logger"
@ -22,22 +23,22 @@ func GetPostgresDBClient() *gorm.DB {
}
// InitPostgresDBInstance return instance of PostgresDB client
func InitPostgresDBInstance(PostgresDBURI string) *gorm.DB {
func InitPostgresDBInstance(ctx context.Context, PostgresDBURI string) *gorm.DB {
postgresOnce.Do(func() {
_globalPostgresClient = initPostgresDBClient(PostgresDBURI)
_globalPostgresClient = initPostgresDBClient(ctx, PostgresDBURI)
})
return _globalPostgresClient
}
// initPostgresDBClient return successfully initialized PostgresDB client
func initPostgresDBClient(PostgresDBURI string) *gorm.DB {
func initPostgresDBClient(ctx context.Context, PostgresDBURI string) *gorm.DB {
db, err := gorm.Open(postgres.Open(PostgresDBURI), &gorm.Config{Logger: logger.NewGormLogger()})
if err != nil {
panic(err)
}
// Auto migrate async task tables
err = db.AutoMigrate(
err = db.WithContext(ctx).AutoMigrate(
&orm.AsyncTask{},
&orm.AsyncTaskResult{},
)

View File

@ -2,7 +2,6 @@
package handler
import (
"modelRT/config"
"modelRT/constants"
"modelRT/database"
"modelRT/logger"
@ -66,38 +65,17 @@ func AsyncTaskCreateHandler(c *gin.Context) {
return
}
// send task to message queue
cfg, exists := c.Get("config")
if !exists {
logger.Warn(ctx, "Configuration not found in context, skipping queue publishing")
} else {
modelRTConfig := cfg.(config.ModelRTConfig)
ctx := c.Request.Context()
// create queue producer
// TODO 像实时计算一样使用 channel 代替
producer, err := task.NewQueueProducer(ctx, modelRTConfig.RabbitMQConfig)
if err != nil {
logger.Error(ctx, "create rabbitMQ queue producer failed", "error", err)
renderRespFailure(c, constants.RespCodeServerError, "create rabbitMQ queue producer failed", nil)
return
}
defer producer.Close()
// publish task to queue
taskType := task.TaskType(request.TaskType)
priority := 5 // Default priority
if err := producer.PublishTaskWithRetry(ctx, asyncTask.TaskID, taskType, priority, 3); err != nil {
logger.Error(ctx, "publish task to rabbitMQ queue failed",
"task_id", asyncTask.TaskID, "error", err)
renderRespFailure(c, constants.RespCodeServerError, "publish task to rabbitMQ queue failed", nil)
return
}
logger.Info(ctx, "published task to rabbitMQ queue successfully",
"task_id", asyncTask.TaskID, "queue", constants.TaskQueueName)
// enqueue task to channel for async publishing to RabbitMQ
msg := task.NewTaskQueueMessageWithPriority(asyncTask.TaskID, task.TaskType(request.TaskType), 5)
// propagate HTTP request trace so the async chain stays on the same traceID
if v, _ := ctx.Value(constants.CtxKeyTraceID).(string); v != "" {
msg.TraceID = v
}
if v, _ := ctx.Value(constants.CtxKeySpanID).(string); v != "" {
msg.SpanID = v
}
task.TaskMsgChan <- msg
logger.Info(ctx, "task enqueued to channel", "task_id", asyncTask.TaskID, "queue", constants.TaskQueueName)
logger.Info(ctx, "async task created success", "task_id", asyncTask.TaskID, "task_type", request.TaskType)

View File

@ -12,6 +12,14 @@ import (
"go.uber.org/zap/zapcore"
)
// Logger is the interface returned by New for structured, trace-aware logging.
type Logger interface {
Debug(msg string, kv ...any)
Info(msg string, kv ...any)
Warn(msg string, kv ...any)
Error(msg string, kv ...any)
}
type logger struct {
ctx context.Context
traceID string
@ -48,7 +56,10 @@ func makeLogFields(ctx context.Context, kv ...any) []zap.Field {
kv = append(kv, "unknown")
}
kv = append(kv, "traceID", ctx.Value(constants.HeaderTraceID), "spanID", ctx.Value(constants.HeaderSpanID), "parentSpanID", ctx.Value(constants.HeaderParentSpanID))
traceID, _ := ctx.Value(constants.CtxKeyTraceID).(string)
spanID, _ := ctx.Value(constants.CtxKeySpanID).(string)
parentSpanID, _ := ctx.Value(constants.CtxKeyParentSpanID).(string)
kv = append(kv, "traceID", traceID, "spanID", spanID, "parentSpanID", parentSpanID)
funcName, file, line := getLoggerCallerInfo()
kv = append(kv, "func", funcName, "file", file, "line", line)
@ -89,16 +100,18 @@ func getLoggerCallerInfo() (funcName, file string, line int) {
return
}
func New(ctx context.Context) *logger {
// New returns a logger bound to ctx. Trace fields (traceID, spanID, parentSpanID)
// are extracted from ctx using typed keys, and are included in every log entry.
func New(ctx context.Context) Logger {
var traceID, spanID, pSpanID string
if ctx.Value("traceID") != nil {
traceID = ctx.Value("traceID").(string)
if v, _ := ctx.Value(constants.CtxKeyTraceID).(string); v != "" {
traceID = v
}
if ctx.Value("spanID") != nil {
spanID = ctx.Value("spanID").(string)
if v, _ := ctx.Value(constants.CtxKeySpanID).(string); v != "" {
spanID = v
}
if ctx.Value("psapnID") != nil {
pSpanID = ctx.Value("pspanID").(string)
if v, _ := ctx.Value(constants.CtxKeyParentSpanID).(string); v != "" {
pSpanID = v
}
return &logger{

28
main.go
View File

@ -19,7 +19,6 @@ import (
"modelRT/database"
"modelRT/diagram"
"modelRT/logger"
"modelRT/middleware"
"modelRT/model"
"modelRT/mq"
"modelRT/pool"
@ -74,7 +73,9 @@ var (
func main() {
flag.Parse()
ctx := context.TODO()
startupSpanID := util.GenerateSpanID("startup")
ctx := context.WithValue(context.Background(), constants.CtxKeyTraceID, startupSpanID)
ctx = context.WithValue(ctx, constants.CtxKeySpanID, startupSpanID)
configPath := filepath.Join(*modelRTConfigDir, *modelRTConfigName+"."+*modelRTConfigType)
if _, err := os.Stat(configPath); os.IsNotExist(err) {
@ -113,7 +114,7 @@ func main() {
}
// init postgresDBClient
postgresDBClient = database.InitPostgresDBInstance(modelRTConfig.PostgresDBURI)
postgresDBClient = database.InitPostgresDBInstance(ctx, modelRTConfig.PostgresDBURI)
defer func() {
sqlDB, err := postgresDBClient.DB()
@ -171,8 +172,10 @@ func main() {
// async push event to rabbitMQ
go mq.PushUpDownLimitEventToRabbitMQ(ctx, mq.MsgChan)
// async push task message to rabbitMQ
go task.PushTaskToRabbitMQ(ctx, modelRTConfig.RabbitMQConfig, task.TaskMsgChan)
postgresDBClient.Transaction(func(tx *gorm.DB) error {
postgresDBClient.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// load circuit diagram from postgres
// componentTypeMap, err := database.QueryCircuitDiagramComponentFromDB(cancelCtx, tx, parsePool)
// if err != nil {
@ -246,22 +249,11 @@ func main() {
AllowCredentials: true,
MaxAge: 12 * time.Hour,
}))
// Register configuration middleware
engine.Use(middleware.ConfigMiddleware(modelRTConfig))
router.RegisterRoutes(engine, serviceToken)
// Swagger UI
engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
// 注册 Swagger UI 路由
// docs.SwaggerInfo.BasePath = "/model"
// v1 := engine.Group("/api/v1")
// {
// eg := v1.Group("/example")
// {
// eg.GET("/helloworld", Helloworld)
// }
// }
if modelRTConfig.DeployEnv != constants.ProductionDeployMode {
engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
}
server := http.Server{
Addr: modelRTConfig.ServiceAddr,

View File

@ -1,3 +1,4 @@
// Package middleware define gin framework middlewares
package middleware
import (

View File

@ -1,3 +1,4 @@
// Package middleware define gin framework middlewares
package middleware
import (

View File

@ -1,3 +1,4 @@
// Package middleware define gin framework middlewares
package middleware
import (

View File

@ -1,3 +1,4 @@
// Package middleware define gin framework middlewares
package middleware
import "github.com/gin-gonic/gin"

View File

@ -1,7 +1,9 @@
// Package middleware define gin framework middlewares
package middleware
import (
"bytes"
"context"
"io"
"strings"
"time"
@ -27,6 +29,14 @@ func StartTrace() gin.HandlerFunc {
c.Set(constants.HeaderTraceID, traceID)
c.Set(constants.HeaderSpanID, spanID)
c.Set(constants.HeaderParentSpanID, parentSpanID)
// also inject into request context so c.Request.Context() carries trace values
reqCtx := c.Request.Context()
reqCtx = context.WithValue(reqCtx, constants.CtxKeyTraceID, traceID)
reqCtx = context.WithValue(reqCtx, constants.CtxKeySpanID, spanID)
reqCtx = context.WithValue(reqCtx, constants.CtxKeyParentSpanID, parentSpanID)
c.Request = c.Request.WithContext(reqCtx)
c.Next()
}
}
@ -78,7 +88,6 @@ func LogAccess() gin.HandlerFunc {
accessLog(c, "access_end", time.Since(start), reqBody, responseLogging)
}()
c.Next()
return
}
}

View File

@ -22,7 +22,7 @@ func GetNSpathToIsLocalMap(ctx context.Context, db *gorm.DB) (map[string]bool, e
var results []ComponentStationRelation
nspathMap := make(map[string]bool)
err := db.Table("component").
err := db.WithContext(ctx).Table("component").
Select("component.nspath, station.is_local").
Joins("join station on component.station_id = station.id").
Scan(&results).Error

View File

@ -36,12 +36,12 @@ func NewHandlerFactory() *HandlerFactory {
}
// RegisterHandler registers a handler for a specific task type
func (f *HandlerFactory) RegisterHandler(taskType TaskType, handler TaskHandler) {
func (f *HandlerFactory) RegisterHandler(ctx context.Context, taskType TaskType, handler TaskHandler) {
f.mu.Lock()
defer f.mu.Unlock()
f.handlers[taskType] = handler
logger.Info(context.Background(), "Handler registered",
logger.Info(ctx, "Handler registered",
"task_type", taskType,
"handler_name", handler.Name(),
)
@ -61,11 +61,11 @@ func (f *HandlerFactory) GetHandler(taskType TaskType) (TaskHandler, error) {
}
// CreateDefaultHandlers registers all default task handlers
func (f *HandlerFactory) CreateDefaultHandlers() {
f.RegisterHandler(TypeTopologyAnalysis, &TopologyAnalysisHandler{})
f.RegisterHandler(TypeEventAnalysis, &EventAnalysisHandler{})
f.RegisterHandler(TypeBatchImport, &BatchImportHandler{})
f.RegisterHandler(TaskType(TaskTypeTest), NewTestTaskHandler())
func (f *HandlerFactory) CreateDefaultHandlers(ctx context.Context) {
f.RegisterHandler(ctx, TypeTopologyAnalysis, &TopologyAnalysisHandler{})
f.RegisterHandler(ctx, TypeEventAnalysis, &EventAnalysisHandler{})
f.RegisterHandler(ctx, TypeBatchImport, &BatchImportHandler{})
f.RegisterHandler(ctx, TaskType(TaskTypeTest), NewTestTaskHandler())
}
// BaseHandler provides common functionality for all task handlers
@ -235,14 +235,14 @@ func (h *CompositeHandler) Name() string {
}
// DefaultHandlerFactory returns a HandlerFactory with all default handlers registered
func DefaultHandlerFactory() *HandlerFactory {
func DefaultHandlerFactory(ctx context.Context) *HandlerFactory {
factory := NewHandlerFactory()
factory.CreateDefaultHandlers()
factory.CreateDefaultHandlers(ctx)
return factory
}
// DefaultCompositeHandler returns a CompositeHandler with all default handlers
func DefaultCompositeHandler() TaskHandler {
factory := DefaultHandlerFactory()
func DefaultCompositeHandler(ctx context.Context) TaskHandler {
factory := DefaultHandlerFactory(ctx)
return NewCompositeHandler(factory)
}

View File

@ -23,8 +23,8 @@ func InitTaskWorker(ctx context.Context, config config.ModelRTConfig, db *gorm.D
// Create task handler factory
handlerFactory := NewHandlerFactory()
handlerFactory.CreateDefaultHandlers()
handler := DefaultCompositeHandler()
handlerFactory.CreateDefaultHandlers(ctx)
handler := DefaultCompositeHandler(ctx)
// Create task worker
worker, err := NewTaskWorker(ctx, workerCfg, db, config.RabbitMQConfig, handler)

View File

@ -14,6 +14,8 @@ type TaskQueueMessage struct {
TaskID uuid.UUID `json:"task_id"`
TaskType TaskType `json:"task_type"`
Priority int `json:"priority,omitempty"` // Optional, defaults to constants.TaskPriorityDefault
TraceID string `json:"trace_id,omitempty"` // propagated from the originating HTTP request
SpanID string `json:"span_id,omitempty"` // spanID of the step that enqueued this message
}
// NewTaskQueueMessage creates a new TaskQueueMessage with default priority

View File

@ -11,11 +11,19 @@ import (
"modelRT/constants"
"modelRT/logger"
"modelRT/mq"
"modelRT/util"
"github.com/gofrs/uuid"
amqp "github.com/rabbitmq/amqp091-go"
)
// TaskMsgChan buffers task messages to be published to RabbitMQ asynchronously
var TaskMsgChan chan *TaskQueueMessage
func init() {
TaskMsgChan = make(chan *TaskQueueMessage, 10000)
}
// QueueProducer handles publishing tasks to RabbitMQ
type QueueProducer struct {
conn *amqp.Connection
@ -213,3 +221,38 @@ func (p *QueueProducer) GetQueueInfo() (*amqp.Queue, error) {
func (p *QueueProducer) PurgeQueue() (int, error) {
return p.ch.QueuePurge(constants.TaskQueueName, false)
}
// PushTaskToRabbitMQ reads from taskChan and publishes to RabbitMQ.
// Must be run as a goroutine; blocks until ctx is cancelled or taskChan is closed.
func PushTaskToRabbitMQ(ctx context.Context, cfg config.RabbitMQConfig, taskChan chan *TaskQueueMessage) {
producer, err := NewQueueProducer(ctx, cfg)
if err != nil {
logger.Error(ctx, "init task queue producer failed", "error", err)
return
}
defer producer.Close()
for {
select {
case <-ctx.Done():
logger.Info(ctx, "push task to RabbitMQ stopped by context cancel")
return
case msg, ok := <-taskChan:
if !ok {
logger.Info(ctx, "task channel closed, exiting push loop")
return
}
traceID := msg.TraceID
if traceID == "" {
traceID = msg.TaskID.String() // fallback when no HTTP trace was propagated
}
taskCtx := context.WithValue(ctx, constants.CtxKeyTraceID, traceID)
taskCtx = context.WithValue(taskCtx, constants.CtxKeySpanID, util.GenerateSpanID("task-publish"))
taskCtx = context.WithValue(taskCtx, constants.CtxKeyParentSpanID, msg.SpanID)
if err := producer.PublishTaskWithRetry(taskCtx, msg.TaskID, msg.TaskType, msg.Priority, 3); err != nil {
logger.Error(taskCtx, "publish task to RabbitMQ failed",
"task_id", msg.TaskID, "error", err)
}
}
}
}

View File

@ -14,6 +14,7 @@ import (
"modelRT/logger"
"modelRT/mq"
"modelRT/orm"
"modelRT/util"
"github.com/gofrs/uuid"
"github.com/panjf2000/ants/v2"
@ -282,6 +283,16 @@ func (w *TaskWorker) handleMessage(msg amqp.Delivery) {
return
}
// derive a per-task context carrying the trace propagated from the originating HTTP request
traceID := taskMsg.TraceID
if traceID == "" {
traceID = taskMsg.TaskID.String() // fallback when message carries no trace
}
taskCtx := context.WithValue(ctx, constants.CtxKeyTraceID, traceID)
taskCtx = context.WithValue(taskCtx, constants.CtxKeySpanID, util.GenerateSpanID("task-worker"))
taskCtx = context.WithValue(taskCtx, constants.CtxKeyParentSpanID, taskMsg.SpanID)
ctx = taskCtx
logger.Info(ctx, "Processing task",
"task_id", taskMsg.TaskID,
"task_type", taskMsg.TaskType,