diff --git a/.gitignore b/.gitignore index adf8f72..6487bf7 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ # Go workspace file go.work +.vscode \ No newline at end of file diff --git a/README.md b/README.md index 097d95d..5c7a8c6 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,3 @@ # ModelRT -[![Build Status](http://192.168.46.100:4080/api/badges/CL-Softwares/modelRT/status.svg)](http://192.168.46.100:4080/CL-Softwares/modelRT) \ No newline at end of file +[![Build Status](http://192.168.46.100:4080/api/badges/CL-Softwares/modelRT/status.svg)](http://192.168.46.100:4080/CL-Softwares/modelRT) diff --git a/alert/init.go b/alert/init.go index f56de9a..e6e7ba3 100644 --- a/alert/init.go +++ b/alert/init.go @@ -5,7 +5,7 @@ import ( "sort" "sync" - "modelRT/constant" + constants "modelRT/constant" ) var ( @@ -18,7 +18,7 @@ var ( type Event struct { ComponentID int64 AnchorName string - Level constant.AlertLevel + Level constants.AlertLevel Message string StartTime int64 } @@ -26,7 +26,7 @@ type Event struct { // EventManager define store and manager alert event struct type EventManager struct { mu sync.RWMutex - events map[constant.AlertLevel][]Event + events map[constants.AlertLevel][]Event } // EventSet define alert event set implement sort.Interface @@ -53,7 +53,7 @@ func (am *EventManager) AddEvent(event Event) { } // GetEventsByLevel define get alert event by alert level -func (am *EventManager) GetEventsByLevel(level constant.AlertLevel) []Event { +func (am *EventManager) GetEventsByLevel(level constants.AlertLevel) []Event { am.mu.Lock() defer am.mu.Unlock() @@ -61,7 +61,7 @@ func (am *EventManager) GetEventsByLevel(level constant.AlertLevel) []Event { } // GetRangeEventsByLevel define get range alert event by alert level -func (am *EventManager) GetRangeEventsByLevel(targetLevel constant.AlertLevel) []Event { +func (am *EventManager) GetRangeEventsByLevel(targetLevel constants.AlertLevel) []Event { var targetEvents []Event am.mu.Lock() @@ -79,7 +79,7 @@ func (am *EventManager) GetRangeEventsByLevel(targetLevel constant.AlertLevel) [ // InitAlertEventManager define new alert event manager func InitAlertEventManager() *EventManager { return &EventManager{ - events: make(map[constant.AlertLevel][]Event), + events: make(map[constants.AlertLevel][]Event), } } diff --git a/common/errcode/code.go b/common/errcode/code.go new file mode 100644 index 0000000..abb4d09 --- /dev/null +++ b/common/errcode/code.go @@ -0,0 +1,72 @@ +package errcode + +import ( + "net/http" +) + +// 此处为公共的错误码, 预留 10000000 ~ 10000099 间的 100 个错误码 +var ( + Success = newError(0, "success") + ErrServer = newError(10000000, "服务器内部错误") + ErrParams = newError(10000001, "参数错误, 请检查") + ErrNotFound = newError(10000002, "资源未找到") + ErrPanic = newError(10000003, "(*^__^*)系统开小差了,请稍后重试") // 无预期的panic错误 + ErrToken = newError(10000004, "Token无效") + ErrForbidden = newError(10000005, "未授权") // 访问一些未授权的资源时的错误 + ErrTooManyRequests = newError(10000006, "请求过多") + ErrCoverData = newError(10000007, "ConvertDataError") // 数据转换错误 +) + +// 各个业务模块自定义的错误码, 从 10000100 开始, 可以按照不同的业务模块划分不同的号段 +// Example: +//var ( +// ErrOrderClosed = NewError(10000100, "订单已关闭") +//) + +// 用户模块相关错误码 10000100 ~ 1000199 +var ( + ErrUserInvalid = newError(10000101, "用户异常") + ErrUserNameOccupied = newError(10000102, "用户名已被占用") + ErrUserNotRight = newError(10000103, "用户名或密码不正确") +) + +// 商品模块相关错误码 10000200 ~ 1000299 +var ( + ErrCommodityNotExists = newError(10000200, "商品不存在") + ErrCommodityStockOut = newError(10000201, "库存不足") +) + +// 购物车模块相关错误码 10000300 ~ 1000399 +var ( + ErrCartItemParam = newError(10000300, "购物项参数异常") + ErrCartWrongUser = newError(10000301, "用户购物信息不匹配") +) + +// 订单模块相关错误码 10000500 ~ 10000599 +var ( + ErrOrderParams = newError(10000500, "订单参数异常") + ErrOrderCanNotBeChanged = newError(10000501, "订单不可修改") + ErrOrderUnsupportedPayScene = newError(10000502, "支付场景暂不支持") +) + +func (e *AppError) HttpStatusCode() int { + switch e.Code() { + case Success.Code(): + return http.StatusOK + case ErrServer.Code(), ErrPanic.Code(): + return http.StatusInternalServerError + case ErrParams.Code(), ErrUserInvalid.Code(), ErrUserNameOccupied.Code(), ErrUserNotRight.Code(), + ErrCommodityNotExists.Code(), ErrCommodityStockOut.Code(), ErrCartItemParam.Code(), ErrOrderParams.Code(): + return http.StatusBadRequest + case ErrNotFound.Code(): + return http.StatusNotFound + case ErrTooManyRequests.Code(): + return http.StatusTooManyRequests + case ErrToken.Code(): + return http.StatusUnauthorized + case ErrForbidden.Code(), ErrCartWrongUser.Code(), ErrOrderCanNotBeChanged.Code(): + return http.StatusForbidden + default: + return http.StatusInternalServerError + } +} diff --git a/common/errcode/dao_error.go b/common/errcode/dao_error.go new file mode 100644 index 0000000..df30ff8 --- /dev/null +++ b/common/errcode/dao_error.go @@ -0,0 +1,21 @@ +package errcode + +import "errors" + +// Database layer error +var ( + // ErrUUIDChangeType define error of check uuid from value failed in uuid from change type + ErrUUIDChangeType = errors.New("undefined uuid change type") + + // ErrUpdateRowZero define error of update affected row zero + ErrUpdateRowZero = errors.New("update affected rows is zero") + + // ErrDeleteRowZero define error of delete affected row zero + ErrDeleteRowZero = errors.New("delete affected rows is zero") + + // ErrQueryRowZero define error of query affected row zero + ErrQueryRowZero = errors.New("query affected rows is zero") + + // ErrInsertRowUnexpected define error of insert affected row not reach expected number + ErrInsertRowUnexpected = errors.New("the number of inserted data rows don't reach the expected value") +) diff --git a/common/errcode/error.go b/common/errcode/error.go new file mode 100644 index 0000000..7b5dd71 --- /dev/null +++ b/common/errcode/error.go @@ -0,0 +1,156 @@ +package errcode + +import ( + "encoding/json" + "fmt" + "path" + "runtime" +) + +var codes = map[int]struct{}{} + +// AppError define struct of internal error +type AppError struct { + code int + msg string + cause error + occurred string // 保存由底层错误导致AppErr发生时的位置 +} + +func (e *AppError) Error() string { + if e == nil { + return "" + } + errBytes, err := json.Marshal(e.toStructuredError()) + if err != nil { + return fmt.Sprintf("Error() is error: json marshal error: %v", err) + } + return string(errBytes) +} + +func (e *AppError) String() string { + return e.Error() +} + +// Code define func return error code +func (e *AppError) Code() int { + return e.code +} + +// Msg define func return error msg +func (e *AppError) Msg() string { + return e.msg +} + +// Cause define func return base error +func (e *AppError) Cause() error { + return e.cause +} + +// WithCause define func return top level predefined errors,where the cause field contains the underlying base error +// 在逻辑执行中出现错误, 比如dao层返回的数据库查询错误 +// 可以在领域层返回预定义的错误前附加上导致错误的基础错误。 +// 如果业务模块预定义的错误码比较详细, 可以使用这个方法, 反之错误码定义的比较笼统建议使用Wrap方法包装底层错误生成项目自定义Error +// 并将其记录到日志后再使用预定义错误码返回接口响应 +func (e *AppError) WithCause(err error) *AppError { + newErr := e.Clone() + newErr.cause = err + newErr.occurred = getAppErrOccurredInfo() + return newErr +} + +// Wrap define func packaging information and errors returned by the underlying logic +// 用于逻辑中包装底层函数返回的error 和 WithCause 一样都是为了记录错误链条 +// 该方法生成的error 用于日志记录, 返回响应请使用预定义好的error +func Wrap(msg string, err error) *AppError { + if err == nil { + return nil + } + appErr := &AppError{code: -1, msg: msg, cause: err} + appErr.occurred = getAppErrOccurredInfo() + return appErr +} + +// UnWrap define func return the error wrapped in structure +func (e *AppError) UnWrap() error { + return e.cause +} + +// Is define func return result of whether any error in err's tree matches target. implemented to support errors.Is(err, target) +func (e *AppError) Is(target error) bool { + targetErr, ok := target.(*AppError) + if !ok { + return false + } + return targetErr.Code() == e.Code() +} + +// Clone define func return a new AppError with source AppError's code, msg, cause, occurred +func (e *AppError) Clone() *AppError { + return &AppError{ + code: e.code, + msg: e.msg, + cause: e.cause, + occurred: e.occurred, + } +} + +func newError(code int, msg string) *AppError { + if code > -1 { + if _, duplicated := codes[code]; duplicated { + panic(fmt.Sprintf("预定义错误码 %d 不能重复, 请检查后更换", code)) + } + codes[code] = struct{}{} + } + + return &AppError{code: code, msg: msg} +} + +// getAppErrOccurredInfo 获取项目中调用Wrap或者WithCause方法时的程序位置, 方便排查问题 +func getAppErrOccurredInfo() string { + pc, file, line, ok := runtime.Caller(2) + if !ok { + return "" + } + file = path.Base(file) + funcName := runtime.FuncForPC(pc).Name() + triggerInfo := fmt.Sprintf("func: %s, file: %s, line: %d", funcName, file, line) + return triggerInfo +} + +// AppendMsg define func append a message to the existing error message +func (e *AppError) AppendMsg(msg string) *AppError { + n := e.Clone() + n.msg = fmt.Sprintf("%s, %s", e.msg, msg) + return n +} + +// SetMsg define func set error message into specify field +func (e *AppError) SetMsg(msg string) *AppError { + n := e.Clone() + n.msg = msg + return n +} + +type formattedErr struct { + Code int `json:"code"` + Msg string `json:"msg"` + Cause interface{} `json:"cause"` + Occurred string `json:"occurred"` +} + +// toStructuredError 在JSON Encode 前把Error进行格式化 +func (e *AppError) toStructuredError() *formattedErr { + fe := new(formattedErr) + fe.Code = e.Code() + fe.Msg = e.Msg() + fe.Occurred = e.occurred + if e.cause != nil { + if appErr, ok := e.cause.(*AppError); ok { + fe.Cause = appErr.toStructuredError() + } else { + fe.Cause = e.cause.Error() + } + } + return fe +} diff --git a/config/anchor_param_config.go b/config/anchor_param_config.go index 989b870..e0a9c15 100644 --- a/config/anchor_param_config.go +++ b/config/anchor_param_config.go @@ -2,7 +2,7 @@ package config import ( - "modelRT/constant" + constants "modelRT/constant" ) // AnchorParamListConfig define anchor params list config struct @@ -43,7 +43,7 @@ var baseCurrentFunc = func(archorValue float64, args ...float64) float64 { // SelectAnchorCalculateFuncAndParams define select anchor func and anchor calculate value by component type 、 anchor name and component data func SelectAnchorCalculateFuncAndParams(componentType int, anchorName string, componentData map[string]interface{}) (func(archorValue float64, args ...float64) float64, []float64) { - if componentType == constant.DemoType { + if componentType == constants.DemoType { if anchorName == "voltage" { resistance := componentData["resistance"].(float64) return baseVoltageFunc, []float64{resistance} diff --git a/config/config.go b/config/config.go index 57e917b..9a3bfb8 100644 --- a/config/config.go +++ b/config/config.go @@ -41,6 +41,16 @@ type LoggerConfig struct { MaxSize int `mapstructure:"maxsize"` MaxBackups int `mapstructure:"maxbackups"` MaxAge int `mapstructure:"maxage"` + Compress bool `mapstructure:"compress"` +} + +// RedisConfig define config stuct of redis config +type RedisConfig struct { + Addr string `mapstructure:"addr"` + Password string `mapstructure:"password"` + DB int `mapstructure:"db"` + PoolSize int `mapstructure:"poolsize"` + Timeout int `mapstructure:"timeout"` } // AntsConfig define config stuct of ants pool config @@ -59,13 +69,15 @@ type DataRTConfig struct { // ModelRTConfig define config stuct of model runtime server type ModelRTConfig struct { - BaseConfig `mapstructure:"base"` - PostgresConfig `mapstructure:"postgres"` - KafkaConfig `mapstructure:"kafka"` - LoggerConfig `mapstructure:"logger"` - AntsConfig `mapstructure:"ants"` - DataRTConfig `mapstructure:"dataRT"` - PostgresDBURI string `mapstructure:"-"` + BaseConfig `mapstructure:"base"` + PostgresConfig `mapstructure:"postgres"` + KafkaConfig `mapstructure:"kafka"` + LoggerConfig `mapstructure:"logger"` + AntsConfig `mapstructure:"ants"` + DataRTConfig `mapstructure:"dataRT"` + LockerRedisConfig RedisConfig `mapstructure:"locker_redis"` + StorageRedisConfig RedisConfig `mapstructure:"storage_redis"` + PostgresDBURI string `mapstructure:"-"` } // ReadAndInitConfig return modelRT project config struct diff --git a/config/config.yaml b/config/config.yaml index 9df949b..57ac79c 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -29,12 +29,28 @@ logger: maxsize: 1 maxbackups: 5 maxage: 30 + compress: false # ants config ants: parse_concurrent_quantity: 10 rtd_receive_concurrent_quantity: 10 +# redis config +locker_redis: + addr: "192.168.2.104:6379" + password: "" + db: 1 + poolsize: 50 + timeout: 10 + +storage_redis: + addr: "192.168.2.104:6379" + password: "" + db: 0 + poolsize: 50 + timeout: 10 + # modelRT base config base: grid_id: 1 diff --git a/config/model_config.go b/config/model_config.go index dedc93c..e626070 100644 --- a/config/model_config.go +++ b/config/model_config.go @@ -9,6 +9,6 @@ import ( type ModelParseConfig struct { ComponentInfo orm.Component - Context context.Context + Ctx context.Context AnchorChan chan AnchorParamConfig } diff --git a/constant/alert.go b/constant/alert.go index 5b6ed6f..6f6d793 100644 --- a/constant/alert.go +++ b/constant/alert.go @@ -1,5 +1,5 @@ -// Package constant define alert level constant -package constant +// Package constants define constant variable +package constants // AlertLevel define alert level type type AlertLevel int diff --git a/constant/busbar_section.go b/constant/busbar_section.go index aa66cb3..7a6f86a 100644 --- a/constant/busbar_section.go +++ b/constant/busbar_section.go @@ -1,4 +1,5 @@ -package constant +// Package constants define constant variable +package constants const ( // 母线服役属性 diff --git a/constant/electrical_components.go b/constant/electrical_components.go index 2f7da65..077eea3 100644 --- a/constant/electrical_components.go +++ b/constant/electrical_components.go @@ -1,5 +1,5 @@ -// Package constant define constant value -package constant +// Package constants define constant variable +package constants const ( // NullableType 空类型类型 diff --git a/constant/error.go b/constant/error.go index cb48f93..d4c6242 100644 --- a/constant/error.go +++ b/constant/error.go @@ -1,22 +1,8 @@ -package constant +// Package constants define constant variable +package constants import "errors" -// ErrUUIDChangeType define error of check uuid from value failed in uuid from change type -var ErrUUIDChangeType = errors.New("undefined uuid change type") - -// ErrUpdateRowZero define error of update affected row zero -var ErrUpdateRowZero = errors.New("update affected rows is zero") - -// ErrDeleteRowZero define error of delete affected row zero -var ErrDeleteRowZero = errors.New("delete affected rows is zero") - -// ErrQueryRowZero define error of query affected row zero -var ErrQueryRowZero = errors.New("query affected rows is zero") - -// ErrInsertRowUnexpected define error of insert affected row not reach expected number -var ErrInsertRowUnexpected = errors.New("the number of inserted data rows don't reach the expected value") - var ( // ErrUUIDFromCheckT1 define error of check uuid from value failed in uuid from change type ErrUUIDFromCheckT1 = errors.New("in uuid from change type, value of new uuid_from is equal value of old uuid_from") diff --git a/constant/log_mode.go b/constant/log_mode.go index bb5ace7..faecb47 100644 --- a/constant/log_mode.go +++ b/constant/log_mode.go @@ -1,5 +1,5 @@ -// Package constant define constant value -package constant +// Package constants define constant variable +package constants const ( // DevelopmentLogMode define development operator environment for modelRT project diff --git a/constant/time.go b/constant/time.go index e10be69..a7b4d84 100644 --- a/constant/time.go +++ b/constant/time.go @@ -1,5 +1,5 @@ -// Package constant define constant value -package constant +// Package constants define constant variable +package constants const ( // LogTimeFormate define time format for log file name diff --git a/constant/togologic.go b/constant/togologic.go index 871ffc5..a5bc57c 100644 --- a/constant/togologic.go +++ b/constant/togologic.go @@ -1,4 +1,7 @@ -package constant +// Package constants define constant variable +package constants + +import "github.com/gofrs/uuid" const ( // UUIDErrChangeType 拓扑信息错误改变类型 @@ -10,3 +13,11 @@ const ( // UUIDAddChangeType 拓扑信息新增类型 UUIDAddChangeType ) + +const ( + // UUIDNilStr 拓扑信息中开始节点与结束节点字符串形式 + UUIDNilStr = "00000000-0000-0000-0000-000000000000" +) + +// UUIDNil 拓扑信息中开始节点与结束节点 UUID 格式 +var UUIDNil = uuid.FromStringOrNil(UUIDNilStr) diff --git a/database/create_component.go b/database/create_component.go index cdfd99e..e4c1644 100644 --- a/database/create_component.go +++ b/database/create_component.go @@ -7,7 +7,7 @@ import ( "strconv" "time" - "modelRT/constant" + "modelRT/common/errcode" "modelRT/network" "modelRT/orm" @@ -43,7 +43,7 @@ func CreateComponentIntoDB(ctx context.Context, tx *gorm.DB, componentInfo netwo if result.Error != nil || result.RowsAffected == 0 { err := result.Error if result.RowsAffected == 0 { - err = fmt.Errorf("%w:please check insert component slice", constant.ErrInsertRowUnexpected) + err = fmt.Errorf("%w:please check insert component slice", errcode.ErrInsertRowUnexpected) } return -1, fmt.Errorf("insert component info failed:%w", err) } diff --git a/database/create_model_info.go b/database/create_model_info.go index 4a4e41d..0f41172 100644 --- a/database/create_model_info.go +++ b/database/create_model_info.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "modelRT/constant" + "modelRT/common/errcode" "modelRT/model" jsoniter "github.com/json-iterator/go" @@ -28,7 +28,7 @@ func CreateModelIntoDB(ctx context.Context, tx *gorm.DB, componentID int64, comp if result.Error != nil || result.RowsAffected == 0 { err := result.Error if result.RowsAffected == 0 { - err = fmt.Errorf("%w:please check insert model params", constant.ErrInsertRowUnexpected) + err = fmt.Errorf("%w:please check insert model params", errcode.ErrInsertRowUnexpected) } return fmt.Errorf("insert component model params into table %s failed:%w", modelStruct.ReturnTableName(), err) } diff --git a/database/create_topologic.go b/database/create_topologic.go index 5af336c..98d4d20 100644 --- a/database/create_topologic.go +++ b/database/create_topologic.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "modelRT/constant" + "modelRT/common/errcode" "modelRT/network" "modelRT/orm" @@ -21,7 +21,6 @@ func CreateTopologicIntoDB(ctx context.Context, tx *gorm.DB, pageID int64, topol var topologicSlice []orm.Topologic for _, info := range topologicInfos { topologicInfo := orm.Topologic{ - PageID: pageID, UUIDFrom: info.UUIDFrom, UUIDTo: info.UUIDTo, Flag: info.Flag, @@ -35,7 +34,7 @@ func CreateTopologicIntoDB(ctx context.Context, tx *gorm.DB, pageID int64, topol if result.Error != nil || result.RowsAffected != int64(len(topologicSlice)) { err := result.Error if result.RowsAffected != int64(len(topologicSlice)) { - err = fmt.Errorf("%w:please check insert topologic slice", constant.ErrInsertRowUnexpected) + err = fmt.Errorf("%w:please check insert topologic slice", errcode.ErrInsertRowUnexpected) } return fmt.Errorf("insert topologic link failed:%w", err) } diff --git a/database/delete_topologic.go b/database/delete_topologic.go index 21e0264..52f4d97 100644 --- a/database/delete_topologic.go +++ b/database/delete_topologic.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "modelRT/constant" + "modelRT/common/errcode" "modelRT/network" "modelRT/orm" @@ -23,7 +23,7 @@ func DeleteTopologicIntoDB(ctx context.Context, tx *gorm.DB, pageID int64, delIn if result.Error != nil || result.RowsAffected == 0 { err := result.Error if result.RowsAffected == 0 { - err = fmt.Errorf("%w:please check delete topologic where conditions", constant.ErrDeleteRowZero) + err = fmt.Errorf("%w:please check delete topologic where conditions", errcode.ErrDeleteRowZero) } return fmt.Errorf("delete topologic link failed:%w", err) } diff --git a/database/postgres_init.go b/database/postgres_init.go index d4edcc8..a80d9fc 100644 --- a/database/postgres_init.go +++ b/database/postgres_init.go @@ -6,6 +6,8 @@ import ( "sync" "time" + "modelRT/logger" + "gorm.io/driver/postgres" "gorm.io/gorm" ) @@ -36,7 +38,7 @@ func InitPostgresDBInstance(ctx context.Context, PostgresDBURI string) *gorm.DB func initPostgresDBClient(ctx context.Context, PostgresDBURI string) *gorm.DB { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - db, err := gorm.Open(postgres.Open(PostgresDBURI), &gorm.Config{}) + db, err := gorm.Open(postgres.Open(PostgresDBURI), &gorm.Config{Logger: logger.NewGormLogger()}) if err != nil { panic(err) } diff --git a/database/query_component.go b/database/query_component.go index 18471a9..4c8245f 100644 --- a/database/query_component.go +++ b/database/query_component.go @@ -6,17 +6,17 @@ import ( "time" "modelRT/config" + "modelRT/logger" "modelRT/orm" "github.com/gofrs/uuid" "github.com/panjf2000/ants/v2" - "go.uber.org/zap" "gorm.io/gorm" "gorm.io/gorm/clause" ) // QueryCircuitDiagramComponentFromDB return the result of query circuit diagram component info order by page id from postgresDB -func QueryCircuitDiagramComponentFromDB(ctx context.Context, tx *gorm.DB, pool *ants.PoolWithFunc, logger *zap.Logger) error { +func QueryCircuitDiagramComponentFromDB(ctx context.Context, tx *gorm.DB, pool *ants.PoolWithFunc) (map[uuid.UUID]int, error) { var components []orm.Component // ctx超时判断 cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) @@ -24,17 +24,22 @@ func QueryCircuitDiagramComponentFromDB(ctx context.Context, tx *gorm.DB, pool * result := tx.WithContext(cancelCtx).Clauses(clause.Locking{Strength: "UPDATE"}).Find(&components) if result.Error != nil { - logger.Error("query circuit diagram component info failed", zap.Error(result.Error)) - return result.Error + logger.Error(ctx, "query circuit diagram component info failed", "error", result.Error) + return nil, result.Error } + // TODO 优化componentTypeMap输出 + componentTypeMap := make(map[uuid.UUID]int, len(components)) + for _, component := range components { pool.Invoke(config.ModelParseConfig{ ComponentInfo: component, - Context: ctx, + Ctx: ctx, }) + + componentTypeMap[component.GlobalUUID] = component.ComponentType } - return nil + return componentTypeMap, nil } // QueryComponentByUUID return the result of query circuit diagram component info by uuid from postgresDB @@ -50,3 +55,17 @@ func QueryComponentByUUID(ctx context.Context, tx *gorm.DB, uuid uuid.UUID) (orm } return component, nil } + +// QueryComponentByPageID return the result of query circuit diagram component info by page id from postgresDB +func QueryComponentByPageID(ctx context.Context, tx *gorm.DB, uuid uuid.UUID) (orm.Component, error) { + var component orm.Component + // ctx超时判断 + cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + result := tx.WithContext(cancelCtx).Where("page_id = ? ", uuid).Clauses(clause.Locking{Strength: "UPDATE"}).Find(&component) + if result.Error != nil { + return orm.Component{}, result.Error + } + return component, nil +} diff --git a/database/query_page.go b/database/query_page.go index a974020..6b069b5 100644 --- a/database/query_page.go +++ b/database/query_page.go @@ -5,26 +5,25 @@ import ( "context" "time" + "modelRT/logger" "modelRT/orm" - "go.uber.org/zap" "gorm.io/gorm" "gorm.io/gorm/clause" ) // QueryAllPages return the all page info of the circuit diagram query by grid_id and zone_id and station_id -func QueryAllPages(ctx context.Context, tx *gorm.DB, logger *zap.Logger, gridID, zoneID, stationID int64) ([]orm.Page, error) { +func QueryAllPages(ctx context.Context, tx *gorm.DB, gridID, zoneID, stationID int64) ([]orm.Page, error) { var pages []orm.Page - // ctx超时判断 + // ctx timeout judgment cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() result := tx.Model(&orm.Page{}).WithContext(cancelCtx).Clauses(clause.Locking{Strength: "UPDATE"}).Select(`"page".id, "page".Name, "page".status,"page".context`).Joins(`inner join "station" on "station".id = "page".station_id`).Joins(`inner join "zone" on "zone".id = "station".zone_id`).Joins(`inner join "grid" on "grid".id = "zone".grid_id`).Where(`"grid".id = ? and "zone".id = ? and "station".id = ?`, gridID, zoneID, stationID).Scan(&pages) if result.Error != nil { - logger.Error("query circuit diagram pages by gridID and zoneID and stationID failed", zap.Int64("grid_id", gridID), zap.Int64("zone_id", zoneID), zap.Int64("station_id", stationID), zap.Error(result.Error)) + logger.Error(ctx, "query circuit diagram pages by gridID and zoneID and stationID failed", "grid_id", gridID, "zone_id", zoneID, "station_id", stationID, "error", result.Error) return nil, result.Error } - return pages, nil } diff --git a/database/query_topologic.go b/database/query_topologic.go index c21ba86..9a72426 100644 --- a/database/query_topologic.go +++ b/database/query_topologic.go @@ -3,77 +3,177 @@ package database import ( "context" + "fmt" "time" + constants "modelRT/constant" "modelRT/diagram" + "modelRT/logger" "modelRT/orm" "modelRT/sql" "github.com/gofrs/uuid" - "go.uber.org/zap" "gorm.io/gorm" "gorm.io/gorm/clause" ) -// QueryTopologicByPageID return the topologic info of the circuit diagram query by pageID -func QueryTopologicByPageID(ctx context.Context, tx *gorm.DB, logger *zap.Logger, pageID int64) ([]orm.Topologic, error) { +// QueryTopologic return the topologic info of the circuit diagram +func QueryTopologic(ctx context.Context, tx *gorm.DB) ([]orm.Topologic, error) { var topologics []orm.Topologic // ctx超时判断 cancelCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - result := tx.WithContext(cancelCtx).Clauses(clause.Locking{Strength: "UPDATE"}).Raw(sql.RecursiveSQL, pageID).Scan(&topologics) + result := tx.WithContext(cancelCtx).Clauses(clause.Locking{Strength: "UPDATE"}).Raw(sql.RecursiveSQL, constants.UUIDNilStr).Scan(&topologics) if result.Error != nil { - logger.Error("query circuit diagram topologic info by pageID failed", zap.Int64("pageID", pageID), zap.Error(result.Error)) + logger.Error(ctx, "query circuit diagram topologic info by start node uuid failed", "start_node_uuid", constants.UUIDNilStr, "error", result.Error) return nil, result.Error } return topologics, nil } -// QueryTopologicFromDB return the result of query topologic info from postgresDB -func QueryTopologicFromDB(ctx context.Context, tx *gorm.DB, logger *zap.Logger, gridID, zoneID, stationID int64) error { - allPages, err := QueryAllPages(ctx, tx, logger, gridID, zoneID, stationID) +// QueryTopologicFromDB return the result of query topologic info from DB +func QueryTopologicFromDB(ctx context.Context, tx *gorm.DB, componentTypeMap map[uuid.UUID]int) (*diagram.MultiBranchTreeNode, error) { + topologicInfos, err := QueryTopologic(ctx, tx) if err != nil { - logger.Error("query all pages info failed", zap.Int64("gridID", gridID), zap.Int64("zoneID", zoneID), zap.Int64("stationID", stationID), zap.Error(err)) - return err + logger.Error(ctx, "query topologic info failed", "error", err) + return nil, err } - for _, page := range allPages { - topologicInfos, err := QueryTopologicByPageID(ctx, tx, logger, page.ID) - if err != nil { - logger.Error("query topologic info by pageID failed", zap.Int64("pageID", page.ID), zap.Error(err)) - return err - } - - err = InitCircuitDiagramTopologic(page.ID, topologicInfos) - if err != nil { - logger.Error("init topologic failed", zap.Error(err)) - return err - } + tree, err := BuildMultiBranchTree(topologicInfos, componentTypeMap) + if err != nil { + logger.Error(ctx, "init topologic failed", "error", err) + return nil, err } - return nil + return tree, nil } // InitCircuitDiagramTopologic return circuit diagram topologic info from postgres -func InitCircuitDiagramTopologic(pageID int64, topologicNodes []orm.Topologic) error { - var rootVertex uuid.UUID - +func InitCircuitDiagramTopologic(topologicNodes []orm.Topologic, componentTypeMap map[uuid.UUID]int) error { + var rootVertex *diagram.MultiBranchTreeNode for _, node := range topologicNodes { - if node.UUIDFrom.IsNil() { - rootVertex = node.UUIDTo + if node.UUIDFrom == constants.UUIDNil { + // rootVertex = node.UUIDTo + var componentType int + componentType, ok := componentTypeMap[node.UUIDFrom] + if !ok { + return fmt.Errorf("can not get component type by uuid: %s", node.UUIDFrom) + } + rootVertex = diagram.NewMultiBranchTree(node.UUIDFrom, componentType) break } } - topologicSet := diagram.NewGraph(rootVertex) + if rootVertex == nil { + return fmt.Errorf("root vertex is nil") + } for _, node := range topologicNodes { - if node.UUIDFrom.IsNil() { - continue + if node.UUIDFrom == constants.UUIDNil { + var componentType int + componentType, ok := componentTypeMap[node.UUIDTo] + if !ok { + return fmt.Errorf("can not get component type by uuid: %s", node.UUIDTo) + } + nodeVertex := diagram.NewMultiBranchTree(node.UUIDTo, componentType) + + rootVertex.AddChild(nodeVertex) } - // TODO 增加对 node.flag值的判断 - topologicSet.AddEdge(node.UUIDFrom, node.UUIDTo) } - diagram.StoreGraphMap(pageID, topologicSet) + + node := rootVertex + for _, nodeVertex := range node.Children { + nextVertexs := make([]*diagram.MultiBranchTreeNode, 0) + nextVertexs = append(nextVertexs, nodeVertex) + } return nil } + +// TODO 电流互感器不单独划分间隔,以母线、浇筑母线、变压器为间隔原件 +func IntervalBoundaryDetermine(uuid uuid.UUID) bool { + fmt.Println(uuid) + var componentID int64 + diagram.GetComponentMap(componentID) + // TODO 判断 component 的类型是否为间隔 + // TODO 0xA1B2C3D4,高四位表示可以成为间隔的compoent类型的值为FFFF,普通 component 类型的值为 0000。低四位中前二位表示component的一级类型,例如母线 PT、母联/母分、进线等,低四位中后二位表示一级类型中包含的具体类型,例如母线 PT中包含的电压互感器、隔离开关、接地开关、避雷器、带电显示器等。 + num := uint32(0xA1B2C3D4) // 八位16进制数 + high16 := uint16(num >> 16) + fmt.Printf("原始值: 0x%X\n", num) // 输出: 0xA1B2C3D4 + fmt.Printf("高十六位: 0x%X\n", high16) // 输出: 0xA1B2 + return true +} + +// BuildMultiBranchTree return the multi branch tree by topologic info and component type map +func BuildMultiBranchTree(topologics []orm.Topologic, componentTypeMap map[uuid.UUID]int) (*diagram.MultiBranchTreeNode, error) { + nodeMap := make(map[uuid.UUID]*diagram.MultiBranchTreeNode, len(topologics)*2) + + for _, topo := range topologics { + if _, exists := nodeMap[topo.UUIDFrom]; !exists { + // skip special uuid + if topo.UUIDTo != constants.UUIDNil { + var ok bool + componentType, ok := componentTypeMap[topo.UUIDFrom] + if !ok { + return nil, fmt.Errorf("can not get component type by uuid: %s", topo.UUIDFrom) + } + + nodeMap[topo.UUIDFrom] = &diagram.MultiBranchTreeNode{ + ID: topo.UUIDFrom, + NodeComponentType: componentType, + Children: make([]*diagram.MultiBranchTreeNode, 0), + } + } + } + + if _, exists := nodeMap[topo.UUIDTo]; !exists { + // skip special uuid + if topo.UUIDTo != constants.UUIDNil { + var ok bool + componentType, ok := componentTypeMap[topo.UUIDTo] + if !ok { + return nil, fmt.Errorf("can not get component type by uuid: %s", topo.UUIDTo) + } + + nodeMap[topo.UUIDTo] = &diagram.MultiBranchTreeNode{ + ID: topo.UUIDTo, + NodeComponentType: componentType, + Children: make([]*diagram.MultiBranchTreeNode, 0), + } + } + } + } + + for _, topo := range topologics { + var parent *diagram.MultiBranchTreeNode + if topo.UUIDFrom == constants.UUIDNil { + var componentType int + parent = &diagram.MultiBranchTreeNode{ + ID: constants.UUIDNil, + NodeComponentType: componentType, + } + nodeMap[constants.UUIDNil] = parent + } else { + parent = nodeMap[topo.UUIDFrom] + } + + var child *diagram.MultiBranchTreeNode + if topo.UUIDTo == constants.UUIDNil { + var componentType int + child = &diagram.MultiBranchTreeNode{ + ID: topo.UUIDTo, + NodeComponentType: componentType, + } + } else { + child = nodeMap[topo.UUIDTo] + } + child.Parent = parent + parent.Children = append(parent.Children, child) + } + + // return root vertex + root, exists := nodeMap[constants.UUIDNil] + if !exists { + return nil, fmt.Errorf("root node not found") + } + return root, nil +} diff --git a/database/update_component.go b/database/update_component.go index 22db5b9..2032f9d 100644 --- a/database/update_component.go +++ b/database/update_component.go @@ -7,7 +7,7 @@ import ( "strconv" "time" - "modelRT/constant" + "modelRT/common/errcode" "modelRT/network" "modelRT/orm" @@ -30,7 +30,7 @@ func UpdateComponentIntoDB(ctx context.Context, tx *gorm.DB, componentInfo netwo if result.Error != nil || result.RowsAffected == 0 { err := result.Error if result.RowsAffected == 0 { - err = fmt.Errorf("%w:please check update component conditions", constant.ErrUpdateRowZero) + err = fmt.Errorf("%w:please check update component conditions", errcode.ErrUpdateRowZero) } return -1, fmt.Errorf("query component info failed:%w", err) } @@ -54,7 +54,7 @@ func UpdateComponentIntoDB(ctx context.Context, tx *gorm.DB, componentInfo netwo if result.Error != nil || result.RowsAffected == 0 { err := result.Error if result.RowsAffected == 0 { - err = fmt.Errorf("%w:please check update component conditions", constant.ErrUpdateRowZero) + err = fmt.Errorf("%w:please check update component conditions", errcode.ErrUpdateRowZero) } return -1, fmt.Errorf("update component info failed:%w", err) } diff --git a/database/update_model_info.go b/database/update_model_info.go index 89fe21a..627f081 100644 --- a/database/update_model_info.go +++ b/database/update_model_info.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "modelRT/constant" + "modelRT/common/errcode" "modelRT/model" jsoniter "github.com/json-iterator/go" @@ -33,7 +33,7 @@ func UpdateModelIntoDB(ctx context.Context, tx *gorm.DB, componentID int64, comp if result.Error != nil || result.RowsAffected == 0 { err := result.Error if result.RowsAffected == 0 { - err = fmt.Errorf("%w:please check where conditions", constant.ErrUpdateRowZero) + err = fmt.Errorf("%w:please check where conditions", errcode.ErrUpdateRowZero) } return err } diff --git a/database/update_topologic.go b/database/update_topologic.go index d64a674..009b3ce 100644 --- a/database/update_topologic.go +++ b/database/update_topologic.go @@ -6,7 +6,8 @@ import ( "fmt" "time" - "modelRT/constant" + "modelRT/common/errcode" + constants "modelRT/constant" "modelRT/network" "modelRT/orm" @@ -21,9 +22,9 @@ func UpdateTopologicIntoDB(ctx context.Context, tx *gorm.DB, pageID int64, chang defer cancel() switch changeInfo.ChangeType { - case constant.UUIDFromChangeType: + case constants.UUIDFromChangeType: result = tx.WithContext(cancelCtx).Model(&orm.Topologic{}).Where("page_id = ? and uuid_from = ? and uuid_to = ?", pageID, changeInfo.OldUUIDFrom, changeInfo.OldUUIDTo).Updates(orm.Topologic{UUIDFrom: changeInfo.NewUUIDFrom}) - case constant.UUIDToChangeType: + case constants.UUIDToChangeType: var delTopologic orm.Topologic result = tx.WithContext(cancelCtx).Model(&orm.Topologic{}).Where("page_id = ? and uuid_to = ?", pageID, changeInfo.NewUUIDTo).Find(&delTopologic) @@ -38,16 +39,15 @@ func UpdateTopologicIntoDB(ctx context.Context, tx *gorm.DB, pageID int64, chang if result.Error != nil || result.RowsAffected == 0 { err := result.Error if result.RowsAffected == 0 { - err = fmt.Errorf("%w:please check delete topologic where conditions", constant.ErrDeleteRowZero) + err = fmt.Errorf("%w:please check delete topologic where conditions", errcode.ErrDeleteRowZero) } return fmt.Errorf("del old topologic link by new_uuid_to failed:%w", err) } } result = tx.WithContext(cancelCtx).Model(&orm.Topologic{}).Where("page_id = ? and uuid_from = ? and uuid_to = ?", pageID, changeInfo.OldUUIDFrom, changeInfo.OldUUIDTo).Updates(&orm.Topologic{UUIDTo: changeInfo.NewUUIDTo}) - case constant.UUIDAddChangeType: + case constants.UUIDAddChangeType: topologic := orm.Topologic{ - PageID: pageID, Flag: changeInfo.Flag, UUIDFrom: changeInfo.NewUUIDFrom, UUIDTo: changeInfo.NewUUIDTo, @@ -61,7 +61,7 @@ func UpdateTopologicIntoDB(ctx context.Context, tx *gorm.DB, pageID int64, chang if result.Error != nil || result.RowsAffected == 0 { err := result.Error if result.RowsAffected == 0 { - err = fmt.Errorf("%w:please check update topologic where conditions", constant.ErrUpdateRowZero) + err = fmt.Errorf("%w:please check update topologic where conditions", errcode.ErrUpdateRowZero) } return fmt.Errorf("insert or update topologic link failed:%w", err) } diff --git a/diagram/graph.go b/diagram/graph.go index 3affcc9..e540e33 100644 --- a/diagram/graph.go +++ b/diagram/graph.go @@ -5,7 +5,7 @@ import ( "fmt" "sync" - "modelRT/constant" + constants "modelRT/constant" "modelRT/network" "github.com/gofrs/uuid" @@ -148,7 +148,7 @@ func (g *Graph) PrintGraph() { // UpdateEdge update edge link info between two verticeLinks func (g *Graph) UpdateEdge(changeInfo network.TopologicUUIDChangeInfos) error { - if changeInfo.ChangeType == constant.UUIDFromChangeType || changeInfo.ChangeType == constant.UUIDToChangeType { + if changeInfo.ChangeType == constants.UUIDFromChangeType || changeInfo.ChangeType == constants.UUIDToChangeType { g.DelEdge(changeInfo.OldUUIDFrom, changeInfo.OldUUIDTo) g.AddEdge(changeInfo.NewUUIDFrom, changeInfo.NewUUIDTo) } else { diff --git a/diagram/hash_test.go b/diagram/hash_test.go new file mode 100644 index 0000000..ed320f3 --- /dev/null +++ b/diagram/hash_test.go @@ -0,0 +1,33 @@ +package diagram + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/redis/go-redis/v9" +) + +func TestHMSet(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Network: "tcp", + Addr: "192.168.2.104:6379", + Password: "cnstar", + PoolSize: 50, + DialTimeout: 10 * time.Second, + }) + params := map[string]interface{}{ + "field1": "Hello1", + "field2": "World1", + "field3": 11, + } + + ctx := context.Background() + res, err := rdb.HSet(ctx, "myhash", params).Result() + if err != nil { + fmt.Printf("err:%v\n", err) + } + fmt.Printf("res:%v\n", res) + return +} diff --git a/diagram/multi_branch_tree.go b/diagram/multi_branch_tree.go new file mode 100644 index 0000000..04337e3 --- /dev/null +++ b/diagram/multi_branch_tree.go @@ -0,0 +1,66 @@ +package diagram + +import ( + "fmt" + + "github.com/gofrs/uuid" +) + +var GlobalTree *MultiBranchTreeNode + +// MultiBranchTreeNode represents a topological structure using an multi branch tree +type MultiBranchTreeNode struct { + ID uuid.UUID // 节点唯一标识 + NodeComponentType int // 节点组件类型 + Parent *MultiBranchTreeNode // 指向父节点的指针 + Children []*MultiBranchTreeNode // 指向所有子节点的指针切片 +} + +func NewMultiBranchTree(id uuid.UUID, componentType int) *MultiBranchTreeNode { + return &MultiBranchTreeNode{ + ID: id, + NodeComponentType: componentType, + Children: make([]*MultiBranchTreeNode, 0), + } +} + +func (n *MultiBranchTreeNode) AddChild(child *MultiBranchTreeNode) { + child.Parent = n + n.Children = append(n.Children, child) +} + +func (n *MultiBranchTreeNode) RemoveChild(childID uuid.UUID) bool { + for i, child := range n.Children { + if child.ID == childID { + n.Children = append(n.Children[:i], n.Children[i+1:]...) + child.Parent = nil + return true + } + } + return false +} + +func (n *MultiBranchTreeNode) FindNodeByID(id uuid.UUID) *MultiBranchTreeNode { + if n.ID == id { + return n + } + + for _, child := range n.Children { + if found := child.FindNodeByID(id); found != nil { + return found + } + } + return nil +} + +func (n *MultiBranchTreeNode) PrintTree(level int) { + for i := 0; i < level; i++ { + fmt.Print(" ") + } + + fmt.Printf("- ComponentType:%d,(ID: %s)\n", n.NodeComponentType, n.ID) + + for _, child := range n.Children { + child.PrintTree(level + 1) + } +} diff --git a/diagram/redis_hash.go b/diagram/redis_hash.go new file mode 100644 index 0000000..e0724f0 --- /dev/null +++ b/diagram/redis_hash.go @@ -0,0 +1,94 @@ +package diagram + +import ( + "context" + + locker "modelRT/distributedlock" + "modelRT/logger" + + "github.com/redis/go-redis/v9" +) + +// RedisHash defines the encapsulation struct of redis hash type +type RedisHash struct { + ctx context.Context + rwLocker *locker.RedissionRWLocker + storageClient *redis.Client +} + +// NewRedisHash define func of new redis hash instance +func NewRedisHash(ctx context.Context, hashKey string, token string, lockLeaseTime uint64, needRefresh bool) *RedisHash { + return &RedisHash{ + ctx: ctx, + rwLocker: locker.InitRWLocker(hashKey, token, lockLeaseTime, needRefresh), + storageClient: GetRedisClientInstance(), + } +} + +// SetRedisHashByMap define func of set redis hash by map struct +func (rh *RedisHash) SetRedisHashByMap(hashKey string, fields map[string]interface{}) error { + err := rh.rwLocker.WLock(rh.ctx) + if err != nil { + logger.Error(rh.ctx, "lock wLock by hash_key failed", "hash_key", hashKey, "error", err) + return err + } + defer rh.rwLocker.UnWLock(rh.ctx) + + err = rh.storageClient.HSet(rh.ctx, hashKey, fields).Err() + if err != nil { + logger.Error(rh.ctx, "set hash by map failed", "hash_key", hashKey, "fields", fields, "error", err) + return err + } + return nil +} + +// SetRedisHashByKV define func of set redis hash by kv struct +func (rh *RedisHash) SetRedisHashByKV(hashKey string, field string, value interface{}) error { + err := rh.rwLocker.WLock(rh.ctx) + if err != nil { + logger.Error(rh.ctx, "lock wLock by hash_key failed", "hash_key", hashKey, "error", err) + return err + } + defer rh.rwLocker.UnWLock(rh.ctx) + + err = rh.storageClient.HSet(rh.ctx, hashKey, field, value).Err() + if err != nil { + logger.Error(rh.ctx, "set hash by kv failed", "hash_key", hashKey, "field", field, "value", value, "error", err) + return err + } + return nil +} + +// HGet define func of get specified field value from redis hash by key and field name +func (rh *RedisHash) HGet(hashKey string, field string) (string, error) { + err := rh.rwLocker.RLock(rh.ctx) + if err != nil { + logger.Error(rh.ctx, "lock rLock by hash_key failed", "hash_key", hashKey, "error", err) + return "", err + } + defer rh.rwLocker.UnRLock(rh.ctx) + + result, err := rh.storageClient.HGet(rh.ctx, hashKey, field).Result() + if err != nil { + logger.Error(rh.ctx, "set hash by kv failed", "hash_key", hashKey, "field", field, "error", err) + return "", err + } + return result, nil +} + +// HGetAll define func of get all filelds from redis hash by key +func (rh *RedisHash) HGetAll(hashKey string) (map[string]string, error) { + err := rh.rwLocker.RLock(rh.ctx) + if err != nil { + logger.Error(rh.ctx, "lock rLock by hash_key failed", "hash_key", hashKey, "error", err) + return nil, err + } + defer rh.rwLocker.UnRLock(rh.ctx) + + result, err := rh.storageClient.HGetAll(rh.ctx, hashKey).Result() + if err != nil { + logger.Error(rh.ctx, "get all hash field by hash key failed", "hash_key", hashKey, "error", err) + return nil, err + } + return result, nil +} diff --git a/diagram/redis_init.go b/diagram/redis_init.go new file mode 100644 index 0000000..3aa0fa5 --- /dev/null +++ b/diagram/redis_init.go @@ -0,0 +1,45 @@ +package diagram + +import ( + "sync" + "time" + + "modelRT/config" + "modelRT/util" + + "github.com/redis/go-redis/v9" +) + +var ( + _globalStorageClient *redis.Client + once sync.Once +) + +// initClient define func of return successfully initialized redis client +func initClient(rCfg config.RedisConfig) *redis.Client { + client, err := util.NewRedisClient( + rCfg.Addr, + util.WithPassword(rCfg.Password), + util.WithDB(rCfg.DB), + util.WithPoolSize(rCfg.PoolSize), + util.WithTimeout(time.Duration(rCfg.Timeout)*time.Second), + ) + if err != nil { + panic(err) + } + return client +} + +// InitClientInstance define func of return instance of redis client +func InitClientInstance(rCfg config.RedisConfig) *redis.Client { + once.Do(func() { + _globalStorageClient = initClient(rCfg) + }) + return _globalStorageClient +} + +// GetRedisClientInstance define func of get redis client instance +func GetRedisClientInstance() *redis.Client { + client := _globalStorageClient + return client +} diff --git a/diagram/redis_set.go b/diagram/redis_set.go new file mode 100644 index 0000000..cf0ba5a --- /dev/null +++ b/diagram/redis_set.go @@ -0,0 +1,92 @@ +package diagram + +import ( + "context" + "fmt" + + locker "modelRT/distributedlock" + "modelRT/logger" + + "github.com/redis/go-redis/v9" + "go.uber.org/zap" +) + +// RedisSet defines the encapsulation struct of redis hash type +type RedisSet struct { + ctx context.Context + rwLocker *locker.RedissionRWLocker + storageClient *redis.Client + logger *zap.Logger +} + +// NewRedisSet define func of new redis set instance +func NewRedisSet(ctx context.Context, hashKey string, token string, lockLeaseTime uint64, needRefresh bool) *RedisSet { + return &RedisSet{ + ctx: ctx, + rwLocker: locker.InitRWLocker(hashKey, token, lockLeaseTime, needRefresh), + storageClient: GetRedisClientInstance(), + logger: logger.GetLoggerInstance(), + } +} + +// SADD define func of add redis set by members +func (rs *RedisSet) SADD(setKey string, members ...interface{}) error { + err := rs.rwLocker.WLock(rs.ctx) + if err != nil { + logger.Error(rs.ctx, "lock wLock by setKey failed", "set_key", setKey, "error", err) + return err + } + defer rs.rwLocker.UnWLock(rs.ctx) + + err = rs.storageClient.SAdd(rs.ctx, setKey, members).Err() + if err != nil { + logger.Error(rs.ctx, "add set by memebers failed", "set_key", setKey, "members", members, "error", err) + return err + } + return nil +} + +// SREM define func of remove the specified members from redis set by key +func (rs *RedisSet) SREM(setKey string, members ...interface{}) error { + err := rs.rwLocker.WLock(rs.ctx) + if err != nil { + logger.Error(rs.ctx, "lock wLock by setKey failed", "set_key", setKey, "error", err) + return err + } + defer rs.rwLocker.UnWLock(rs.ctx) + + count, err := rs.storageClient.SRem(rs.ctx, setKey, members).Result() + if err != nil || count != int64(len(members)) { + logger.Error(rs.ctx, "rem members from set failed", "set_key", setKey, "members", members, "error", err) + + return fmt.Errorf("rem members from set failed:%w", err) + } + return nil +} + +// SMembers define func of get all memebers from redis set by key +func (rs *RedisSet) SMembers(setKey string) ([]string, error) { + err := rs.rwLocker.RLock(rs.ctx) + if err != nil { + logger.Error(rs.ctx, "lock rLock by setKey failed", "set_key", setKey, "error", err) + return nil, err + } + defer rs.rwLocker.UnRLock(rs.ctx) + + result, err := rs.storageClient.SMembers(rs.ctx, setKey).Result() + if err != nil { + logger.Error(rs.ctx, "get all set field by hash key failed", "set_key", setKey, "error", err) + return nil, err + } + return result, nil +} + +// SIsMember define func of determine whether an member is in set by key +func (rs *RedisSet) SIsMember(setKey string, member interface{}) (bool, error) { + result, err := rs.storageClient.SIsMember(rs.ctx, setKey, member).Result() + if err != nil { + logger.Error(rs.ctx, "get all set field by hash key failed", "set_key", setKey, "error", err) + return false, err + } + return result, nil +} diff --git a/distributedlock/constant/lock_err.go b/distributedlock/constant/lock_err.go new file mode 100644 index 0000000..f8aebf0 --- /dev/null +++ b/distributedlock/constant/lock_err.go @@ -0,0 +1,6 @@ +package constants + +import "errors" + +// AcquireTimeoutErr define error of get lock timeout +var AcquireTimeoutErr = errors.New("the waiting time for obtaining the lock operation has timed out") diff --git a/distributedlock/constant/redis_result.go b/distributedlock/constant/redis_result.go new file mode 100644 index 0000000..d942faf --- /dev/null +++ b/distributedlock/constant/redis_result.go @@ -0,0 +1,136 @@ +package constants + +import ( + "fmt" +) + +type RedisCode int + +const ( + LockSuccess = RedisCode(1) + UnLockSuccess = RedisCode(1) + RefreshLockSuccess = RedisCode(1) + UnRLockSuccess = RedisCode(0) + UnWLockSuccess = RedisCode(0) + RLockFailureWithWLockOccupancy = RedisCode(-1) + UnRLockFailureWithWLockOccupancy = RedisCode(-2) + WLockFailureWithRLockOccupancy = RedisCode(-3) + WLockFailureWithWLockOccupancy = RedisCode(-4) + UnWLockFailureWithRLockOccupancy = RedisCode(-5) + UnWLockFailureWithWLockOccupancy = RedisCode(-6) + WLockFailureWithNotFirstPriority = RedisCode(-7) + RefreshLockFailure = RedisCode(-8) + LockFailure = RedisCode(-9) + UnLocakFailureWithLockOccupancy = RedisCode(-10) + UnknownInternalError = RedisCode(-99) +) + +type RedisLockType int + +const ( + LockType = RedisLockType(iota) + UnRLockType + UnWLockType + UnLockType + RefreshLockType +) + +type RedisResult struct { + Code RedisCode + Message string +} + +func (e *RedisResult) Error() string { + return fmt.Sprintf("redis execution code:%d,message:%s\n", e.Code, e.Message) +} + +func (e *RedisResult) OutputResultMessage() string { + return e.Message +} + +func (e *RedisResult) OutputResultCode() int { + return int(e.Code) +} + +func NewRedisResult(res RedisCode, lockType RedisLockType, redisMsg string) error { + resInt := int(res) + switch resInt { + case 1: + if lockType == LockType { + return &RedisResult{Code: res, Message: "redis lock success"} + } else if (lockType == UnRLockType) || (lockType == UnWLockType) || (lockType == UnLockType) { + return &RedisResult{Code: res, Message: "redis unlock success"} + } else { + return &RedisResult{Code: res, Message: "redis refresh lock success"} + } + case 0: + if lockType == UnRLockType { + return &RedisResult{Code: res, Message: "redis unlock read lock success, the lock is still occupied by other processes read lock"} + } else { + return &RedisResult{Code: res, Message: "redis unlock write lock success, the lock is still occupied by other processes write lock"} + } + case -1: + return &RedisResult{Code: res, Message: "redis lock read lock failure,the lock is already occupied by another processes write lock"} + case -2: + return &RedisResult{Code: res, Message: "redis un lock read lock failure,the lock is already occupied by another processes write lock"} + case -3: + return &RedisResult{Code: res, Message: "redis lock write lock failure,the lock is already occupied by anthor processes read lock"} + case -4: + return &RedisResult{Code: res, Message: "redis lock write lock failure,the lock is already occupied by anthor processes write lock"} + case -5: + return &RedisResult{Code: res, Message: "redis unlock write lock failure,the lock is already occupied by another processes read lock"} + case -6: + return &RedisResult{Code: res, Message: "redis unlock write lock failure,the lock is already occupied by another processes write lock"} + case -7: + return &RedisResult{Code: res, Message: "redis lock write lock failure,the first priority in the current process non-waiting queue"} + case -8: + return &RedisResult{Code: res, Message: "redis refresh lock failure,the lock not exist"} + case -9: + return &RedisResult{Code: res, Message: "redis lock failure,the lock is already occupied by another processes lock"} + case -99: + return &RedisResult{Code: res, Message: fmt.Sprintf("redis internal execution error:%v\n", redisMsg)} + default: + msg := "unkown redis execution result" + if redisMsg != "" { + msg = fmt.Sprintf("%s:%s\n", msg, redisMsg) + } + return &RedisResult{Code: res, Message: msg} + } +} + +func TranslateResultToStr(res RedisCode, lockType RedisLockType) string { + resInt := int(res) + switch resInt { + case 1: + if lockType == LockType { + return "redis lock success" + } else if (lockType == UnRLockType) || (lockType == UnWLockType) || (lockType == UnLockType) { + return "redis unlock success" + } else { + return "redis refresh lock success" + } + case 0: + if lockType == UnRLockType { + return "redis unlock read lock success, the lock is still occupied by other processes read lock" + } else { + return "redis unlock write lock success, the lock is still occupied by other processes write lock" + } + case -1: + return "redis lock read lock failure,the lock is already occupied by another processes write lock" + case -2: + return "redis un lock read lock failure,the lock is already occupied by another processes write lock" + case -3: + return "redis lock write lock failure,the lock is already occupied by anthor processes read lock" + case -4: + return "redis lock write lock failure,the lock is already occupied by anthor processes write lock" + case -5: + return "redis un lock write lock failure,the lock is already occupied by another processes read lock" + case -6: + return "redis un lock write lock failure,the lock is already occupied by another processes write lock" + case -7: + return "redis lock write lock failure,the first priority in the current process non-waiting queue" + case -8: + return "redis refresh lock failure,the lock not exist" + } + return "unkown redis execution result" +} diff --git a/distributedlock/locker_init.go b/distributedlock/locker_init.go new file mode 100644 index 0000000..35ecc78 --- /dev/null +++ b/distributedlock/locker_init.go @@ -0,0 +1,45 @@ +package distributedlock + +import ( + "sync" + "time" + + "modelRT/config" + "modelRT/util" + + "github.com/redis/go-redis/v9" +) + +var ( + _globalLockerClient *redis.Client + once sync.Once +) + +// initClient define func of return successfully initialized redis client +func initClient(rCfg config.RedisConfig) *redis.Client { + client, err := util.NewRedisClient( + rCfg.Addr, + util.WithPassword(rCfg.Password), + util.WithDB(rCfg.DB), + util.WithPoolSize(rCfg.PoolSize), + util.WithTimeout(time.Duration(rCfg.Timeout)*time.Second), + ) + if err != nil { + panic(err) + } + return client +} + +// InitClientInstance define func of return instance of redis client +func InitClientInstance(rCfg config.RedisConfig) *redis.Client { + once.Do(func() { + _globalLockerClient = initClient(rCfg) + }) + return _globalLockerClient +} + +// GetRedisClientInstance define func of get redis client instance +func GetRedisClientInstance() *redis.Client { + client := _globalLockerClient + return client +} diff --git a/distributedlock/luascript/lock_script.go b/distributedlock/luascript/lock_script.go new file mode 100644 index 0000000..64a6be8 --- /dev/null +++ b/distributedlock/luascript/lock_script.go @@ -0,0 +1,62 @@ +package luascript + +/* +KEYS[1]:锁的键名(key),通常是锁的唯一标识。 +ARGV[1]:锁的过期时间(lockLeaseTime),单位为秒。 +ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 +*/ +var LockScript = ` +-- 锁不存在的情况下加锁 +if (redis.call('exists', KEYS[1]) == 0) then + redis.call('hset', KEYS[1], ARGV[2], 1); + redis.call('expire', KEYS[1], ARGV[1]); + return 1; +end; +-- 重入锁逻辑 +if (redis.call('hexists', KEYS[1], ARGV[2]) == 1) then + redis.call('hincrby', KEYS[1], ARGV[2], 1); + redis.call('expire', KEYS[1], ARGV[1]); + return 1; +end; +-- 持有锁的 token 不是当前客户端的 token,返回加锁失败 +return -9; +` + +/* +KEYS[1]:锁的键名(key),通常是锁的唯一标识。 +ARGV[1]:锁的过期时间(lockLeaseTime),单位为秒。 +ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 +*/ +var RefreshLockScript = ` +if (redis.call('hexists', KEYS[1], ARGV[2]) == 1) then + redis.call('expire', KEYS[1], ARGV[1]); + return 1; +end; +return -8; +` + +/* +KEYS[1]:锁的键名(key),通常是锁的唯一标识。 +KEYS[2]:锁的释放通知频道(chankey),用于通知其他客户端锁已释放。 +ARGV[1]:解锁消息(unlockMessage),用于通知其他客户端锁已释放。 +ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 +*/ +var UnLockScript = ` +if (redis.call('exists', KEYS[1]) == 0) then + redis.call('publish', KEYS[2], ARGV[1]); + return 1; +end; +if (redis.call('hexists', KEYS[1], ARGV[2]) == 0) then + return 1; +end; +local counter = redis.call('hincrby', KEYS[1], ARGV[2], -1); +if (counter > 0) then + return 1; +else + redis.call('del', KEYS[1]); + redis.call('publish', KEYS[2], ARGV[1]); + return 1; +end; +-- 持有锁的 token 不是当前客户端的 token,返回解锁失败 +return -10; +` diff --git a/distributedlock/luascript/rwlock_script.go b/distributedlock/luascript/rwlock_script.go new file mode 100644 index 0000000..b26f557 --- /dev/null +++ b/distributedlock/luascript/rwlock_script.go @@ -0,0 +1,263 @@ +// Package luascript defines the lua script used for redis distributed lock +package luascript + +// RLockScript is the lua script for the lock read lock command +/* +KEYS[1]:锁的键名(key),通常是锁的唯一标识。 +KEYS[2]:锁的超时键名前缀(rwTimeoutPrefix),用于存储每个读锁的超时键。 +ARGV[1]:锁的过期时间(lockLeaseTime),单位为秒。 +ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 +*/ +var RLockScript = ` +local mode = redis.call('hget', KEYS[1], 'mode'); +local lockKey = KEYS[2] .. ':' .. ARGV[2]; +if (mode == false) then + redis.call('hset', KEYS[1], 'mode', 'read'); + redis.call('hset', KEYS[1], lockKey, '1'); + redis.call('hpexpire', KEYS[1], ARGV[1], 'fields', '1', lockKey); + redis.call('pexpire', KEYS[1], ARGV[1]); + return 1; +end; + +if (mode == 'write') then + -- 放到 list 中等待写锁释放后再次尝试加锁并且订阅写锁释放的消息 + local waitKey = KEYS[1] .. ':read'; + redis.call('rpush', waitKey, ARGV[2]); + return -1; +end; + +if (mode == 'read') then + if (redis.call('exists', KEYS[1], ARGV[2]) == 1) then + redis.call('hincrby', KEYS[1], lockKey, '1'); + local remainTime = redis.call('hpttl', KEYS[1], 'fields', '1', lockKey); + redis.call('hpexpire', KEYS[1], math.max(tonumber(remainTime[1]), ARGV[1]), 'fields', '1', lockKey); + else + redis.call('hset', KEYS[1], lockKey, '1'); + redis.call('hpexpire', KEYS[1], ARGV[1], 'fields', '1', lockKey); + end; + local cursor = 0; + local maxRemainTime = tonumber(ARGV[1]); + local pattern = KEYS[2] .. ':*'; + repeat + local hscanResult = redis.call('hscan', KEYS[1], cursor, 'match', pattern, 'count', '100'); + cursor = tonumber(hscanResult[1]); + local fields = hscanResult[2]; + + for i = 1, #fields,2 do + local field = fields[i]; + local remainTime = redis.call('hpttl', KEYS[1], 'fields', '1', field); + maxRemainTime = math.max(tonumber(remainTime[1]), maxRemainTime); + end; + until cursor == 0; + + local remainTime = redis.call('pttl', KEYS[1]); + redis.call('pexpire', KEYS[1], math.max(tonumber(remainTime),maxRemainTime)); + return 1; +end; +` + +// UnRLockScript is the lua script for the unlock read lock command +/* +KEYS[1]:锁的键名(key),通常是锁的唯一标识。 +KEYS[2]:锁的超时键名前缀(rwTimeoutPrefix),用于存储每个读锁的超时键。 +KEYS[3]:锁的释放通知写频道(chankey),用于通知其他写等待客户端锁已释放。 +ARGV[1]:解锁消息(unlockMessage),用于通知其他客户端锁已释放。 +ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 +*/ +var UnRLockScript = ` +local lockKey = KEYS[2] .. ':' .. ARGV[2]; +local mode = redis.call('hget', KEYS[1], 'mode'); +if (mode == false) then + local writeWait = KEYS[1] .. ':write'; + -- 优先写锁加锁 + local counter = redis.call('llen',writeWait); + if (counter >= 1) then + redis.call('publish', KEYS[3], ARGV[1]); + end; + return 1; +elseif (mode == 'write') then + return -2; +end; + +-- 判断当前的确是读模式但是当前 token 并没有加读锁的情况,返回 0 +local lockExists = redis.call('hexists', KEYS[1], lockKey); +if ((mode == 'read') and (lockExists == 0)) then + return 0; +end; + +local counter = redis.call('hincrby', KEYS[1], lockKey, -1); +local delTTLs = redis.call('hpttl', KEYS[1], 'fields', '1', lockKey); +local delTTL = tonumber(delTTLs[1]); +if (counter == 0) then + redis.call('hdel', KEYS[1], lockKey); +end; + +if (redis.call('hlen', KEYS[1]) > 1) then + local cursor = 0; + local maxRemainTime = 0; + local pattern = KEYS[2] .. ':*'; + repeat + local hscanResult = redis.call('hscan', KEYS[1], cursor, 'match', pattern, 'count', '100'); + cursor = tonumber(hscanResult[1]); + local fields = hscanResult[2]; + + for i = 1, #fields,2 do + local field = fields[i]; + local remainTime = redis.call('hpttl', KEYS[1], 'fields', '1', field); + maxRemainTime = math.max(tonumber(remainTime[1]), maxRemainTime); + end; + until cursor == 0; + + if (maxRemainTime > 0) then + if (delTTL > maxRemainTime) then + redis.call('pexpire', KEYS[1], maxRemainTime); + else + local remainTime = redis.call('pttl', KEYS[1]); + redis.call('pexpire', KEYS[1], math.max(tonumber(remainTime),maxRemainTime)); + end; + end; +else + redis.call('del', KEYS[1]); + local writeWait = KEYS[1] .. ':write'; + -- 优先写锁加锁 + local counter = redis.call('llen',writeWait); + if (counter >= 1) then + redis.call('publish', KEYS[3], ARGV[1]); + end; + return 1; +end; +` + +// WLockScript is the lua script for the lock write lock command +/* +KEYS[1]:锁的键名(key),通常是锁的唯一标识。 +KEYS[2]:锁的超时键名前缀(rwTimeoutPrefix),用于存储每个读锁的超时键。 +ARGV[1]:锁的过期时间(lockLeaseTime),单位为秒。 +ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 +*/ +var WLockScript = ` +local mode = redis.call('hget', KEYS[1], 'mode'); +local lockKey = KEYS[2] .. ':' .. ARGV[2]; +local waitKey = KEYS[1] .. ':write'; +if (mode == false) then + local waitListLen = redis.call('llen', waitKey); + if (waitListLen > 0) then + local firstToken = redis.call('lindex', waitKey,'0'); + if (firstToken ~= ARGV[2]) then + return -7; + end; + end; + redis.call('hset', KEYS[1], 'mode', 'write'); + redis.call('hset', KEYS[1], lockKey, 1); + redis.call('hpexpire', KEYS[1], ARGV[1], 'fields', '1', lockKey); + redis.call('pexpire', KEYS[1], ARGV[1]); + redis.call('lpop', waitKey, '1'); + return 1; +elseif (mode == 'read') then + -- 放到 list 中等待读锁释放后再次尝试加锁并且订阅读锁释放的消息 + redis.call('rpush', waitKey, ARGV[2]); + return -3; +else + -- 可重入写锁逻辑 + local lockKey = KEYS[2] .. ':' .. ARGV[2]; + local lockExists = redis.call('hexists', KEYS[1], lockKey); + if (lockExists == 1) then + redis.call('hincrby', KEYS[1], lockKey, 1); + redis.call('hpexpire', KEYS[1], ARGV[1], 'fields', '1', lockKey); + redis.call('pexpire', KEYS[1], ARGV[1]); + return 1; + end; + -- 放到 list 中等待写锁释放后再次尝试加锁并且订阅写锁释放的消息 + local key = KEYS[1] .. ':write'; + redis.call('rpush', key, ARGV[2]); + return -4; +end; +` + +// UnWLockScript is the lua script for the unlock write lock command +/* +KEYS[1]:锁的键名(key),通常是锁的唯一标识。 +KEYS[2]:锁的超时键名前缀(rwTimeoutPrefix),用于存储每个读锁的超时键。 +KEYS[3]:锁的释放通知写频道(writeChankey),用于通知其他写等待客户端锁已释放。 +KEYS[4]:锁的释放通知读频道(readChankey),用于通知其他读等待客户端锁已释放。 +ARGV[1]:解锁消息(unlockMessage),用于通知其他客户端锁已释放。 +ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 +*/ +var UnWLockScript = ` +local mode = redis.call('hget', KEYS[1], 'mode'); +local writeWait = KEYS[1] .. ':write'; +if (mode == false) then + -- 优先写锁加锁,无写锁的情况通知读锁加锁 + local counter = redis.call('llen',writeWait); + if (counter >= 1) then + redis.call('publish', KEYS[3], ARGV[1]); + else + redis.call('publish', KEYS[4], ARGV[1]); + end; + return 1; +elseif (mode == 'read') then + return -5; +else + local lockKey = KEYS[2] .. ':' .. ARGV[2]; + local lockExists = redis.call('hexists', KEYS[1], lockKey); + if (lockExists >= 1) then + -- 可重入写锁逻辑 + local incrRes = redis.call('hincrby', KEYS[1], lockKey, -1); + if (incrRes == 0) then + redis.call('del', KEYS[1]); + local counter = redis.call('llen',writeWait); + if (counter >= 1) then + redis.call('publish', KEYS[3], ARGV[1]); + else + redis.call('publish', KEYS[4], ARGV[1]); + end; + return 1; + end; + return 0; + else + return -6; + end; +end; +` + +// RefreshRWLockScript is the lua script for the refresh lock command +/* +KEYS[1]:锁的键名(key),通常是锁的唯一标识。 +KEYS[2]:锁的超时键名前缀(rwTimeoutPrefix),用于存储每个读锁的超时键。 +ARGV[1]:锁的过期时间(lockLeaseTime),单位为秒。 +ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 +*/ +var RefreshRWLockScript = ` +local lockKey = KEYS[2] .. ':' .. ARGV[2]; +local lockExists = redis.call('hexists', KEYS[1], lockKey); +local mode = redis.call('hget', KEYS[1], 'mode'); +local maxRemainTime = tonumber(ARGV[1]); +if (lockExists == 1) then + redis.call('hpexpire', KEYS[1], ARGV[1], 'fields', '1', lockKey); + if (mode == 'read') then + local cursor = 0; + local pattern = KEYS[2] .. ':*'; + repeat + local hscanResult = redis.call('hscan', KEYS[1], cursor, 'match', pattern, 'count', '100'); + cursor = tonumber(hscanResult[1]); + local fields = hscanResult[2]; + + for i = 1, #fields,2 do + local field = fields[i]; + local remainTime = redis.call('hpttl', KEYS[1], 'fields', '1', field); + maxRemainTime = math.max(tonumber(remainTime[1]), maxRemainTime); + end; + until cursor == 0; + + if (maxRemainTime > 0) then + local remainTime = redis.call('pttl', KEYS[1]); + redis.call('pexpire', KEYS[1], math.max(tonumber(remainTime),maxRemainTime)); + end; + elseif (mode == 'write') then + redis.call('pexpire', KEYS[1], ARGV[1]); + end; + -- return redis.call('pttl',KEYS[1]); + return 1; +end; +return -8; +` diff --git a/distributedlock/redis_lock.go b/distributedlock/redis_lock.go new file mode 100644 index 0000000..2234574 --- /dev/null +++ b/distributedlock/redis_lock.go @@ -0,0 +1,256 @@ +package distributedlock + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + constants "modelRT/distributedlock/constant" + luascript "modelRT/distributedlock/luascript" + "modelRT/logger" + + uuid "github.com/gofrs/uuid" + "github.com/redis/go-redis/v9" + "go.uber.org/zap" +) + +const ( + internalLockLeaseTime = uint64(30 * 1000) + unlockMessage = 0 +) + +// RedissionLockConfig define redission lock config +type RedissionLockConfig struct { + LockLeaseTime uint64 + Token string + Prefix string + ChanPrefix string + TimeoutPrefix string + Key string + NeedRefresh bool +} + +type redissionLocker struct { + lockLeaseTime uint64 + Token string + Key string + waitChanKey string + needRefresh bool + refreshExitChan chan struct{} + subExitChan chan struct{} + client *redis.Client + refreshOnce *sync.Once +} + +func (rl *redissionLocker) Lock(ctx context.Context, timeout ...time.Duration) error { + if rl.refreshExitChan == nil { + rl.refreshExitChan = make(chan struct{}) + } + result := rl.tryLock(ctx).(*constants.RedisResult) + if result.Code == constants.UnknownInternalError { + logger.Error(ctx, result.OutputResultMessage()) + return fmt.Errorf("get lock failed:%w", result) + } + + if (result.Code == constants.LockSuccess) && rl.needRefresh { + rl.refreshOnce.Do(func() { + // async refresh lock timeout unitl receive exit singal + go rl.refreshLockTimeout(ctx) + }) + return nil + } + + subMsg := make(chan struct{}, 1) + defer close(subMsg) + sub := rl.client.Subscribe(ctx, rl.waitChanKey) + defer sub.Close() + go rl.subscribeLock(ctx, sub, subMsg) + + if len(timeout) > 0 && timeout[0] > 0 { + acquireTimer := time.NewTimer(timeout[0]) + for { + select { + case _, ok := <-subMsg: + if !ok { + err := errors.New("failed to read the lock waiting for for the channel message") + logger.Error(ctx, "failed to read the lock waiting for for the channel message") + return err + } + + resultErr := rl.tryLock(ctx).(*constants.RedisResult) + if (resultErr.Code == constants.LockFailure) || (resultErr.Code == constants.UnknownInternalError) { + logger.Info(ctx, resultErr.OutputResultMessage()) + continue + } + + if resultErr.Code == constants.LockSuccess { + logger.Info(ctx, resultErr.OutputResultMessage()) + return nil + } + case <-acquireTimer.C: + err := errors.New("the waiting time for obtaining the lock operation has timed out") + logger.Info(ctx, "the waiting time for obtaining the lock operation has timed out") + return err + } + } + } + return fmt.Errorf("lock the redis lock failed:%w", result) +} + +func (rl *redissionLocker) subscribeLock(ctx context.Context, sub *redis.PubSub, subMsgChan chan struct{}) { + if sub == nil || subMsgChan == nil { + return + } + logger.Info(ctx, "lock: enter sub routine", zap.String("token", rl.Token)) + + for { + select { + case <-rl.subExitChan: + close(subMsgChan) + return + case <-sub.Channel(): + // 这里只会收到真正的数据消息 + subMsgChan <- struct{}{} + default: + } + } +} + +/* +KEYS[1]:锁的键名(key),通常是锁的唯一标识。 +ARGV[1]:锁的过期时间(lockLeaseTime),单位为秒。 +ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 +*/ +func (rl *redissionLocker) refreshLockTimeout(ctx context.Context) { + logger.Info(ctx, "lock refresh by key and token", zap.String("token", rl.Token), zap.String("key", rl.Key)) + + lockTime := time.Duration(rl.lockLeaseTime/3) * time.Millisecond + timer := time.NewTimer(lockTime) + defer timer.Stop() + + for { + select { + case <-timer.C: + // extend key lease time + res := rl.client.Eval(ctx, luascript.RefreshLockScript, []string{rl.Key}, rl.lockLeaseTime, rl.Token) + val, err := res.Int() + if err != redis.Nil && err != nil { + logger.Info(ctx, "lock refresh failed", "token", rl.Token, "key", rl.Key, "error", err) + return + } + + if constants.RedisCode(val) == constants.RefreshLockFailure { + logger.Error(ctx, "lock refreash failed,can not find the lock by key and token", "token", rl.Token, "key", rl.Key) + break + } + + if constants.RedisCode(val) == constants.RefreshLockSuccess { + logger.Info(ctx, "lock refresh success by key and token", "token", rl.Token, "key", rl.Key) + } + timer.Reset(lockTime) + case <-rl.refreshExitChan: + return + } + } +} + +func (rl *redissionLocker) cancelRefreshLockTime() { + if rl.refreshExitChan != nil { + close(rl.refreshExitChan) + rl.refreshOnce = &sync.Once{} + } +} + +func (rl *redissionLocker) closeSub(ctx context.Context, sub *redis.PubSub, noticeChan chan struct{}) { + if sub != nil { + err := sub.Close() + if err != nil { + logger.Error(ctx, "close sub failed", "token", rl.Token, "key", rl.Key, "error", err) + } + } + + if noticeChan != nil { + close(noticeChan) + } +} + +/* +KEYS[1]:锁的键名(key),通常是锁的唯一标识。 +ARGV[1]:锁的过期时间(lockLeaseTime),单位为秒。 +ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 +*/ +func (rl *redissionLocker) tryLock(ctx context.Context) error { + lockType := constants.LockType + res := rl.client.Eval(ctx, luascript.LockScript, []string{rl.Key}, rl.lockLeaseTime, rl.Token) + val, err := res.Int() + if err != redis.Nil && err != nil { + return constants.NewRedisResult(constants.UnknownInternalError, lockType, err.Error()) + } + return constants.NewRedisResult(constants.RedisCode(val), lockType, "") +} + +/* +KEYS[1]:锁的键名(key),通常是锁的唯一标识。 +KEYS[2]:锁的释放通知频道(chankey),用于通知其他客户端锁已释放。 +ARGV[1]:解锁消息(unlockMessage),用于通知其他客户端锁已释放。 +ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 +*/ +func (rl *redissionLocker) UnLock(ctx context.Context) error { + res := rl.client.Eval(ctx, luascript.UnLockScript, []string{rl.Key, rl.waitChanKey}, unlockMessage, rl.Token) + val, err := res.Int() + if err != redis.Nil && err != nil { + logger.Info(ctx, "unlock lock failed", zap.String("token", rl.Token), zap.String("key", rl.Key), zap.Error(err)) + return fmt.Errorf("unlock lock failed:%w", constants.NewRedisResult(constants.UnknownInternalError, constants.UnLockType, err.Error())) + } + + if constants.RedisCode(val) == constants.UnLockSuccess { + if rl.needRefresh { + rl.cancelRefreshLockTime() + } + + logger.Info(ctx, "unlock lock success", zap.String("token", rl.Token), zap.String("key", rl.Key)) + return nil + } + + if constants.RedisCode(val) == constants.UnLocakFailureWithLockOccupancy { + logger.Info(ctx, "unlock lock failed", zap.String("token", rl.Token), zap.String("key", rl.Key)) + return fmt.Errorf("unlock lock failed:%w", constants.NewRedisResult(constants.UnLocakFailureWithLockOccupancy, constants.UnLockType, "")) + } + return nil +} + +// TODO 优化 panic +func GetLocker(client *redis.Client, ops *RedissionLockConfig) *redissionLocker { + if ops.Token == "" { + token, err := uuid.NewV4() + if err != nil { + panic(err) + } + ops.Token = token.String() + } + + if len(ops.Prefix) <= 0 { + ops.Prefix = "redission-lock" + } + + if len(ops.ChanPrefix) <= 0 { + ops.ChanPrefix = "redission-lock-channel" + } + + if ops.LockLeaseTime == 0 { + ops.LockLeaseTime = internalLockLeaseTime + } + + r := &redissionLocker{ + Token: ops.Token, + Key: strings.Join([]string{ops.Prefix, ops.Key}, ":"), + waitChanKey: strings.Join([]string{ops.ChanPrefix, ops.Key, "wait"}, ":"), + needRefresh: ops.NeedRefresh, + client: client, + refreshExitChan: make(chan struct{}), + } + return r +} diff --git a/distributedlock/redis_rwlock.go b/distributedlock/redis_rwlock.go new file mode 100644 index 0000000..3dbfdc6 --- /dev/null +++ b/distributedlock/redis_rwlock.go @@ -0,0 +1,329 @@ +package distributedlock + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + constants "modelRT/distributedlock/constant" + "modelRT/distributedlock/luascript" + "modelRT/logger" + + uuid "github.com/gofrs/uuid" + "github.com/redis/go-redis/v9" +) + +type RedissionRWLocker struct { + redissionLocker + writeWaitChanKey string + readWaitChanKey string + RWTokenTimeoutPrefix string +} + +func (rl *RedissionRWLocker) RLock(ctx context.Context, timeout ...time.Duration) error { + result := rl.tryRLock(ctx).(*constants.RedisResult) + if result.Code == constants.UnknownInternalError { + logger.Error(ctx, result.OutputResultMessage()) + return fmt.Errorf("get read lock failed:%w", result) + } + + if result.Code == constants.LockSuccess { + if rl.needRefresh { + rl.refreshOnce.Do(func() { + if rl.refreshExitChan == nil { + rl.refreshExitChan = make(chan struct{}) + } + + // async refresh lock timeout unitl receive exit singal + go rl.refreshLockTimeout(ctx) + }) + } + logger.Info(ctx, "success get the read lock by key and token", "key", rl.Key, "token", rl.Token) + return nil + } + + if len(timeout) > 0 && timeout[0] > 0 { + if rl.subExitChan == nil { + rl.subExitChan = make(chan struct{}) + } + + subMsgChan := make(chan struct{}, 1) + sub := rl.client.Subscribe(ctx, rl.readWaitChanKey) + go rl.subscribeLock(ctx, sub, subMsgChan) + + acquireTimer := time.NewTimer(timeout[0]) + for { + select { + case _, ok := <-subMsgChan: + if !ok { + err := errors.New("failed to read the read lock waiting for for the channel message") + logger.Error(ctx, "failed to read the read lock waiting for for the channel message") + return err + } + + result := rl.tryRLock(ctx).(*constants.RedisResult) + if (result.Code == constants.RLockFailureWithWLockOccupancy) || (result.Code == constants.UnknownInternalError) { + logger.Info(ctx, result.OutputResultMessage()) + continue + } + + if result.Code == constants.LockSuccess { + logger.Info(ctx, result.OutputResultMessage()) + rl.closeSub(ctx, sub, rl.subExitChan) + + if rl.needRefresh { + rl.refreshOnce.Do(func() { + if rl.refreshExitChan == nil { + rl.refreshExitChan = make(chan struct{}) + } + + // async refresh lock timeout unitl receive exit singal + go rl.refreshLockTimeout(ctx) + }) + } + return nil + } + case <-acquireTimer.C: + logger.Info(ctx, "the waiting time for obtaining the read lock operation has timed out") + rl.closeSub(ctx, sub, rl.subExitChan) + // after acquire lock timeout,notice the sub channel to close + return constants.AcquireTimeoutErr + } + } + } + return fmt.Errorf("lock the redis read lock failed:%w", result) +} + +func (rl *RedissionRWLocker) tryRLock(ctx context.Context) error { + lockType := constants.LockType + + res := rl.client.Eval(ctx, luascript.RLockScript, []string{rl.Key, rl.RWTokenTimeoutPrefix}, rl.lockLeaseTime, rl.Token) + val, err := res.Int() + if err != redis.Nil && err != nil { + return constants.NewRedisResult(constants.UnknownInternalError, lockType, err.Error()) + } + return constants.NewRedisResult(constants.RedisCode(val), lockType, "") +} + +func (rl *RedissionRWLocker) refreshLockTimeout(ctx context.Context) { + logger.Info(ctx, "lock refresh by key and token", "token", rl.Token, "key", rl.Key) + + lockTime := time.Duration(rl.lockLeaseTime/3) * time.Millisecond + timer := time.NewTimer(lockTime) + defer timer.Stop() + + for { + select { + case <-timer.C: + // extend key lease time + res := rl.client.Eval(ctx, luascript.RefreshRWLockScript, []string{rl.Key, rl.RWTokenTimeoutPrefix}, rl.lockLeaseTime, rl.Token) + val, err := res.Int() + if err != redis.Nil && err != nil { + logger.Info(ctx, "lock refresh failed", "token", rl.Token, "key", rl.Key, "error", err) + return + } + + if constants.RedisCode(val) == constants.RefreshLockFailure { + logger.Error(ctx, "lock refreash failed,can not find the read lock by key and token", "rwTokenPrefix", rl.RWTokenTimeoutPrefix, "token", rl.Token, "key", rl.Key) + return + } + + if constants.RedisCode(val) == constants.RefreshLockSuccess { + logger.Info(ctx, "lock refresh success by key and token", "token", rl.Token, "key", rl.Key) + } + timer.Reset(lockTime) + case <-rl.refreshExitChan: + return + } + } +} + +func (rl *RedissionRWLocker) UnRLock(ctx context.Context) error { + logger.Info(ctx, "unlock RLock by key and token", "key", rl.Key, "token", rl.Token) + res := rl.client.Eval(ctx, luascript.UnRLockScript, []string{rl.Key, rl.RWTokenTimeoutPrefix, rl.writeWaitChanKey}, unlockMessage, rl.Token) + val, err := res.Int() + if err != redis.Nil && err != nil { + logger.Info(ctx, "unlock read lock failed", "token", rl.Token, "key", rl.Key, "error", err) + return fmt.Errorf("unlock read lock failed:%w", constants.NewRedisResult(constants.UnknownInternalError, constants.UnRLockType, err.Error())) + } + + if (constants.RedisCode(val) == constants.UnLockSuccess) || (constants.RedisCode(val) == constants.UnRLockSuccess) { + if rl.needRefresh && (constants.RedisCode(val) == constants.UnLockSuccess) { + rl.cancelRefreshLockTime() + } + + logger.Info(ctx, "unlock read lock success", "token", rl.Token, "key", rl.Key) + return nil + } + + if constants.RedisCode(val) == constants.UnRLockFailureWithWLockOccupancy { + logger.Info(ctx, "unlock read lock failed", "token", rl.Token, "key", rl.Key) + return fmt.Errorf("unlock read lock failed:%w", constants.NewRedisResult(constants.UnRLockFailureWithWLockOccupancy, constants.UnRLockType, "")) + } + return nil +} + +func (rl *RedissionRWLocker) WLock(ctx context.Context, timeout ...time.Duration) error { + result := rl.tryWLock(ctx).(*constants.RedisResult) + if result.Code == constants.UnknownInternalError { + logger.Error(ctx, result.OutputResultMessage()) + return fmt.Errorf("get write lock failed:%w", result) + } + + if result.Code == constants.LockSuccess { + if rl.needRefresh { + rl.refreshOnce.Do(func() { + if rl.refreshExitChan == nil { + rl.refreshExitChan = make(chan struct{}) + } + + // async refresh lock timeout unitl receive exit singal + go rl.refreshLockTimeout(ctx) + }) + } + logger.Info(ctx, "success get the write lock by key and token", "key", rl.Key, "token", rl.Token) + return nil + } + + if len(timeout) > 0 && timeout[0] > 0 { + if rl.subExitChan == nil { + rl.subExitChan = make(chan struct{}) + } + + subMsgChan := make(chan struct{}, 1) + sub := rl.client.Subscribe(ctx, rl.writeWaitChanKey) + go rl.subscribeLock(ctx, sub, subMsgChan) + + acquireTimer := time.NewTimer(timeout[0]) + for { + select { + case _, ok := <-subMsgChan: + if !ok { + err := errors.New("failed to read the write lock waiting for for the channel message") + logger.Error(ctx, "failed to read the read lock waiting for for the channel message") + return err + } + + result := rl.tryWLock(ctx).(*constants.RedisResult) + if (result.Code == constants.UnknownInternalError) || (result.Code == constants.WLockFailureWithRLockOccupancy) || (result.Code == constants.WLockFailureWithWLockOccupancy) || (result.Code == constants.WLockFailureWithNotFirstPriority) { + logger.Info(ctx, result.OutputResultMessage()) + continue + } + + if result.Code == constants.LockSuccess { + logger.Info(ctx, result.OutputResultMessage()) + rl.closeSub(ctx, sub, rl.subExitChan) + + if rl.needRefresh { + rl.refreshOnce.Do(func() { + if rl.refreshExitChan == nil { + rl.refreshExitChan = make(chan struct{}) + } + + // async refresh lock timeout unitl receive exit singal + go rl.refreshLockTimeout(ctx) + }) + } + return nil + } + case <-acquireTimer.C: + logger.Info(ctx, "the waiting time for obtaining the write lock operation has timed out") + rl.closeSub(ctx, sub, rl.subExitChan) + // after acquire lock timeout,notice the sub channel to close + return constants.AcquireTimeoutErr + } + } + } + return fmt.Errorf("lock write lock failed:%w", result) +} + +func (rl *RedissionRWLocker) tryWLock(ctx context.Context) error { + lockType := constants.LockType + + res := rl.client.Eval(ctx, luascript.WLockScript, []string{rl.Key, rl.RWTokenTimeoutPrefix}, rl.lockLeaseTime, rl.Token) + val, err := res.Int() + if err != redis.Nil && err != nil { + return constants.NewRedisResult(constants.UnknownInternalError, lockType, err.Error()) + } + return constants.NewRedisResult(constants.RedisCode(val), lockType, "") +} + +func (rl *RedissionRWLocker) UnWLock(ctx context.Context) error { + res := rl.client.Eval(ctx, luascript.UnWLockScript, []string{rl.Key, rl.RWTokenTimeoutPrefix, rl.writeWaitChanKey, rl.readWaitChanKey}, unlockMessage, rl.Token) + val, err := res.Int() + if err != redis.Nil && err != nil { + logger.Error(ctx, "unlock write lock failed", "token", rl.Token, "key", rl.Key, "error", err) + return fmt.Errorf("unlock write lock failed:%w", constants.NewRedisResult(constants.UnknownInternalError, constants.UnWLockType, err.Error())) + } + + if (constants.RedisCode(val) == constants.UnLockSuccess) || constants.RedisCode(val) == constants.UnWLockSuccess { + if rl.needRefresh && (constants.RedisCode(val) == constants.UnLockSuccess) { + rl.cancelRefreshLockTime() + } + logger.Info(ctx, "unlock write lock success", "token", rl.Token, "key", rl.Key) + return nil + } + + if (constants.RedisCode(val) == constants.UnWLockFailureWithRLockOccupancy) || (constants.RedisCode(val) == constants.UnWLockFailureWithWLockOccupancy) { + logger.Info(ctx, "unlock write lock failed", "token", rl.Token, "key", rl.Key) + return fmt.Errorf("unlock write lock failed:%w", constants.NewRedisResult(constants.RedisCode(val), constants.UnWLockType, "")) + } + return nil +} + +// TODO 优化 panic +func GetRWLocker(client *redis.Client, conf *RedissionLockConfig) *RedissionRWLocker { + if conf.Token == "" { + token, err := uuid.NewV4() + if err != nil { + panic(err) + } + conf.Token = token.String() + } + + if conf.Prefix == "" { + conf.Prefix = "redission-rwlock" + } + + if conf.TimeoutPrefix == "" { + conf.TimeoutPrefix = "rwlock_timeout" + } + + if conf.ChanPrefix == "" { + conf.ChanPrefix = "redission-rwlock-channel" + } + + if conf.LockLeaseTime == 0 { + conf.LockLeaseTime = internalLockLeaseTime + } + + r := &redissionLocker{ + Token: conf.Token, + Key: strings.Join([]string{conf.Prefix, conf.Key}, ":"), + needRefresh: conf.NeedRefresh, + lockLeaseTime: conf.LockLeaseTime, + client: client, + refreshOnce: &sync.Once{}, + } + + rwLocker := &RedissionRWLocker{ + redissionLocker: *r, + writeWaitChanKey: strings.Join([]string{conf.ChanPrefix, conf.Key, "write"}, ":"), + readWaitChanKey: strings.Join([]string{conf.ChanPrefix, conf.Key, "read"}, ":"), + RWTokenTimeoutPrefix: conf.TimeoutPrefix, + } + return rwLocker +} + +func InitRWLocker(key string, token string, lockLeaseTime uint64, needRefresh bool) *RedissionRWLocker { + conf := &RedissionLockConfig{ + Key: key, + Token: token, + LockLeaseTime: lockLeaseTime, + NeedRefresh: needRefresh, + } + return GetRWLocker(GetRedisClientInstance(), conf) +} diff --git a/go.mod b/go.mod index e5f2b8c..bfab534 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module modelRT go 1.22.5 require ( + github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/bitly/go-simplejson v0.5.1 github.com/confluentinc/confluent-kafka-go v1.9.2 github.com/gin-gonic/gin v1.10.0 @@ -11,11 +12,15 @@ require ( github.com/json-iterator/go v1.1.12 github.com/natefinch/lumberjack v2.0.0+incompatible github.com/panjf2000/ants/v2 v2.10.0 + github.com/redis/go-redis/v9 v9.7.3 github.com/spf13/viper v1.19.0 + github.com/stretchr/testify v1.9.0 github.com/swaggo/files v1.0.1 github.com/swaggo/gin-swagger v1.6.0 github.com/swaggo/swag v1.16.4 go.uber.org/zap v1.27.0 + golang.org/x/sys v0.28.0 + gorm.io/driver/mysql v1.5.7 gorm.io/driver/postgres v1.5.9 gorm.io/gorm v1.25.12 ) @@ -25,8 +30,11 @@ require ( github.com/KyleBanks/depth v1.2.1 // indirect github.com/bytedance/sonic v1.12.5 // indirect github.com/bytedance/sonic/loader v0.2.1 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gabriel-vasile/mimetype v1.4.7 // indirect github.com/gin-contrib/sse v0.1.0 // indirect @@ -37,6 +45,7 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.23.0 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -55,6 +64,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect @@ -70,7 +80,6 @@ require ( golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/net v0.32.0 // indirect golang.org/x/sync v0.10.0 // indirect - golang.org/x/sys v0.28.0 // indirect golang.org/x/text v0.21.0 // indirect golang.org/x/tools v0.28.0 // indirect google.golang.org/protobuf v1.35.2 // indirect diff --git a/go.sum b/go.sum index da45600..e0dd779 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc= github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE= github.com/actgardner/gogen-avro/v10 v10.1.0/go.mod h1:o+ybmVjEa27AAr35FRqU98DJu1fXES56uXniYFv4yDA= @@ -11,6 +13,10 @@ github.com/actgardner/gogen-avro/v9 v9.1.0/go.mod h1:nyTj6wPqDJoxM3qdnjcLv+EnMDS github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/bitly/go-simplejson v0.5.1 h1:xgwPbetQScXt1gh9BmoJ6j9JMr3TElvuIyjR8pgdoow= github.com/bitly/go-simplejson v0.5.1/go.mod h1:YOPVLzCfwK14b4Sff3oP1AmGhI9T9Vsg84etUnlyp+Q= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/bytedance/sonic v1.12.5 h1:hoZxY8uW+mT+OpkcUWw4k0fDINtOcVavEsGfzwzFU/w= github.com/bytedance/sonic v1.12.5/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= @@ -18,6 +24,8 @@ github.com/bytedance/sonic/loader v0.2.1 h1:1GgorWTqf12TA8mma4DDSbaQigE2wOgQo7iC github.com/bytedance/sonic/loader v0.2.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -39,6 +47,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -79,6 +89,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.23.0 h1:/PwmTwZhS0dPkav3cdK9kV1FsAmrL8sThn8IHr/sO+o= github.com/go-playground/validator/v10 v10.23.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= @@ -149,6 +161,7 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/juju/qthttptest v0.1.1/go.mod h1:aTlAv8TYaflIiTDIQYzxnl1QdPjAg8Q8qJMErpKy6A4= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= @@ -195,6 +208,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= +github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/rogpeppe/clock v0.0.0-20190514195947-2896927a307a/go.mod h1:4r5QyqhjIWCcK8DO4KMclc5Iknq5qVBAlbYYzAbUScQ= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= @@ -395,8 +410,11 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= +gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= gorm.io/driver/postgres v1.5.9 h1:DkegyItji119OlcaLjqN11kHoUgZ/j13E0jkJZgD6A8= gorm.io/driver/postgres v1.5.9/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/handler/alert_event_query.go b/handler/alert_event_query.go index a652758..1ddadc6 100644 --- a/handler/alert_event_query.go +++ b/handler/alert_event_query.go @@ -6,25 +6,22 @@ import ( "strconv" "modelRT/alert" - "modelRT/constant" + constants "modelRT/constant" "modelRT/logger" "modelRT/network" "github.com/gin-gonic/gin" - "go.uber.org/zap" ) // QueryAlertEventHandler define query alert event process API func QueryAlertEventHandler(c *gin.Context) { - var targetLevel constant.AlertLevel + var targetLevel constants.AlertLevel - logger := logger.GetLoggerInstance() alertManger := alert.GetAlertMangerInstance() - levelStr := c.Query("level") level, err := strconv.Atoi(levelStr) if err != nil { - logger.Error("convert alert level string to int failed", zap.Error(err)) + logger.Error(c, "convert alert level string to int failed", "error", err) resp := network.FailureResponse{ Code: -1, @@ -32,7 +29,7 @@ func QueryAlertEventHandler(c *gin.Context) { } c.JSON(http.StatusOK, resp) } - targetLevel = constant.AlertLevel(level) + targetLevel = constants.AlertLevel(level) events := alertManger.GetRangeEventsByLevel(targetLevel) resp := network.SuccessResponse{ diff --git a/handler/anchor_point_replace.go b/handler/anchor_point_replace.go index ba110e2..2b3956e 100644 --- a/handler/anchor_point_replace.go +++ b/handler/anchor_point_replace.go @@ -7,7 +7,8 @@ import ( "net/http" "time" - "modelRT/constant" + "modelRT/common/errcode" + constants "modelRT/constant" "modelRT/database" "modelRT/diagram" "modelRT/logger" @@ -16,21 +17,19 @@ import ( "modelRT/orm" "github.com/gin-gonic/gin" - "go.uber.org/zap" ) // ComponentAnchorReplaceHandler define component anchor point replace process API func ComponentAnchorReplaceHandler(c *gin.Context) { var uuid, anchorName string - logger := logger.GetLoggerInstance() - pgClient := database.GetPostgresDBClient() + pgClient := database.GetPostgresDBClient() cancelCtx, cancel := context.WithTimeout(c, 5*time.Second) defer cancel() var request network.ComponetAnchorReplaceRequest if err := c.ShouldBindJSON(&request); err != nil { - logger.Error("unmarshal component anchor point replace info failed", zap.Error(err)) + logger.Error(c, "unmarshal component anchor point replace info failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -45,7 +44,7 @@ func ComponentAnchorReplaceHandler(c *gin.Context) { var componentInfo orm.Component result := pgClient.WithContext(cancelCtx).Model(&orm.Component{}).Where("global_uuid = ?", uuid).Find(&componentInfo) if result.Error != nil { - logger.Error("query component detail info failed", zap.Error(result.Error)) + logger.Error(c, "query component detail info failed", "error", result.Error) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -56,8 +55,8 @@ func ComponentAnchorReplaceHandler(c *gin.Context) { } if result.RowsAffected == 0 { - err := fmt.Errorf("query component detail info by uuid failed:%w", constant.ErrQueryRowZero) - logger.Error("query component detail info from table is empty", zap.String("table_name", "component")) + err := fmt.Errorf("query component detail info by uuid failed:%w", errcode.ErrQueryRowZero) + logger.Error(c, "query component detail info from table is empty", "table_name", "component") resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -73,7 +72,7 @@ func ComponentAnchorReplaceHandler(c *gin.Context) { tableName := model.SelectModelNameByType(componentInfo.ComponentType) result = pgClient.WithContext(cancelCtx).Table(tableName).Where("global_uuid = ?", uuid).Find(&unmarshalMap) if result.Error != nil { - logger.Error("query model detail info failed", zap.Error(result.Error)) + logger.Error(c, "query model detail info failed", "error", result.Error) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -84,8 +83,8 @@ func ComponentAnchorReplaceHandler(c *gin.Context) { } if unmarshalMap == nil { - err := fmt.Errorf("query model detail info by uuid failed:%w", constant.ErrQueryRowZero) - logger.Error("query model detail info from table is empty", zap.String("table_name", tableName)) + err := fmt.Errorf("query model detail info by uuid failed:%w", errcode.ErrQueryRowZero) + logger.Error(c, "query model detail info from table is empty", "table_name", tableName) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -96,8 +95,8 @@ func ComponentAnchorReplaceHandler(c *gin.Context) { } componentType := unmarshalMap["component_type"].(int) - if componentType != constant.DemoType { - logger.Error("can not process real time data of component type not equal DemoType", zap.Int64("component_id", componentInfo.ID)) + if componentType != constants.DemoType { + logger.Error(c, "can not process real time data of component type not equal DemoType", "component_id", componentInfo.ID) } diagram.UpdateAnchorValue(componentInfo.ID, anchorName) diff --git a/handler/circuit_diagram_create.go b/handler/circuit_diagram_create.go index 633cb8c..d66968f 100644 --- a/handler/circuit_diagram_create.go +++ b/handler/circuit_diagram_create.go @@ -13,17 +13,15 @@ import ( "github.com/bitly/go-simplejson" "github.com/gin-gonic/gin" "github.com/gofrs/uuid" - "go.uber.org/zap" ) // CircuitDiagramCreateHandler define circuit diagram create process API func CircuitDiagramCreateHandler(c *gin.Context) { - logger := logger.GetLoggerInstance() pgClient := database.GetPostgresDBClient() var request network.CircuitDiagramCreateRequest if err := c.ShouldBindJSON(&request); err != nil { - logger.Error("unmarshal circuit diagram create info failed", zap.Error(err)) + logger.Error(c, "unmarshal circuit diagram create info failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -35,7 +33,7 @@ func CircuitDiagramCreateHandler(c *gin.Context) { graph, err := diagram.GetGraphMap(request.PageID) if err != nil { - logger.Error("get topologic data from set by pageID failed", zap.Error(err)) + logger.Error(c, "get topologic data from set by pageID failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -63,7 +61,7 @@ func CircuitDiagramCreateHandler(c *gin.Context) { err = fmt.Errorf("convert uuid from string failed:%w:%w", err1, err2) } - logger.Error("format uuid from string failed", zap.Error(err)) + logger.Error(c, "format uuid from string failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -87,7 +85,7 @@ func CircuitDiagramCreateHandler(c *gin.Context) { if err != nil { tx.Rollback() - logger.Error("create topologic info into DB failed", zap.Any("topologic_info", topologicCreateInfos), zap.Error(err)) + logger.Error(c, "create topologic info into DB failed", "topologic_info", topologicCreateInfos, "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -109,7 +107,7 @@ func CircuitDiagramCreateHandler(c *gin.Context) { if err != nil { tx.Rollback() - logger.Error("insert component info into DB failed", zap.Error(err)) + logger.Error(c, "insert component info into DB failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -127,7 +125,7 @@ func CircuitDiagramCreateHandler(c *gin.Context) { if err != nil { tx.Rollback() - logger.Error("create component model into DB failed", zap.Any("component_infos", request.ComponentInfos), zap.Error(err)) + logger.Error(c, "create component model into DB failed", "component_infos", request.ComponentInfos, "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -147,7 +145,7 @@ func CircuitDiagramCreateHandler(c *gin.Context) { if err != nil { tx.Rollback() - logger.Error("unmarshal component params info failed", zap.String("component_params", componentInfo.Params), zap.Error(err)) + logger.Error(c, "unmarshal component params info failed", "component_params", componentInfo.Params, "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -165,7 +163,7 @@ func CircuitDiagramCreateHandler(c *gin.Context) { if err != nil { tx.Rollback() - logger.Error("format params json info to map failed", zap.Error(err)) + logger.Error(c, "format params json info to map failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, diff --git a/handler/circuit_diagram_delete.go b/handler/circuit_diagram_delete.go index 4f8797c..73c3151 100644 --- a/handler/circuit_diagram_delete.go +++ b/handler/circuit_diagram_delete.go @@ -7,7 +7,7 @@ import ( "net/http" "time" - "modelRT/constant" + "modelRT/common/errcode" "modelRT/database" "modelRT/diagram" "modelRT/logger" @@ -17,18 +17,16 @@ import ( "github.com/gin-gonic/gin" "github.com/gofrs/uuid" - "go.uber.org/zap" "gorm.io/gorm/clause" ) // CircuitDiagramDeleteHandler define circuit diagram delete process API func CircuitDiagramDeleteHandler(c *gin.Context) { - logger := logger.GetLoggerInstance() pgClient := database.GetPostgresDBClient() var request network.CircuitDiagramDeleteRequest if err := c.ShouldBindJSON(&request); err != nil { - logger.Error("unmarshal circuit diagram del info failed", zap.Error(err)) + logger.Error(c, "unmarshal circuit diagram del info failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -40,7 +38,7 @@ func CircuitDiagramDeleteHandler(c *gin.Context) { graph, err := diagram.GetGraphMap(request.PageID) if err != nil { - logger.Error("get topologic data from set by pageID failed", zap.Error(err)) + logger.Error(c, "get topologic data from set by pageID failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -68,7 +66,7 @@ func CircuitDiagramDeleteHandler(c *gin.Context) { err = fmt.Errorf("convert uuid from string failed:%w:%w", err1, err2) } - logger.Error("format uuid from string failed", zap.Error(err)) + logger.Error(c, "format uuid from string failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -93,7 +91,7 @@ func CircuitDiagramDeleteHandler(c *gin.Context) { if err != nil { tx.Rollback() - logger.Error("delete topologic info into DB failed", zap.Any("topologic_info", topologicDelInfo), zap.Error(err)) + logger.Error(c, "delete topologic info into DB failed", "topologic_info", topologicDelInfo, "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -110,7 +108,7 @@ func CircuitDiagramDeleteHandler(c *gin.Context) { if err != nil { tx.Rollback() - logger.Error("delete topologic info failed", zap.Any("topologic_info", topologicDelInfo), zap.Error(err)) + logger.Error(c, "delete topologic info failed", "topologic_info", topologicDelInfo, "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -136,7 +134,7 @@ func CircuitDiagramDeleteHandler(c *gin.Context) { if err != nil { tx.Rollback() - logger.Error("format uuid from string failed", zap.Error(err)) + logger.Error(c, "format uuid from string failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -157,10 +155,10 @@ func CircuitDiagramDeleteHandler(c *gin.Context) { err := result.Error if result.RowsAffected == 0 { - err = fmt.Errorf("%w:please check uuid conditions", constant.ErrDeleteRowZero) + err = fmt.Errorf("%w:please check uuid conditions", errcode.ErrDeleteRowZero) } - logger.Error("query component info into postgresDB failed", zap.String("component_global_uuid", componentInfo.UUID), zap.Error(err)) + logger.Error(c, "query component info into postgresDB failed", "component_global_uuid", componentInfo.UUID, "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -179,10 +177,10 @@ func CircuitDiagramDeleteHandler(c *gin.Context) { err := result.Error if result.RowsAffected == 0 { - err = fmt.Errorf("%w:please check uuid conditions", constant.ErrDeleteRowZero) + err = fmt.Errorf("%w:please check uuid conditions", errcode.ErrDeleteRowZero) } - logger.Error("delete component info into postgresDB failed", zap.String("component_global_uuid", componentInfo.UUID), zap.Error(err)) + logger.Error(c, "delete component info into postgresDB failed", "component_global_uuid", componentInfo.UUID, "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -203,11 +201,11 @@ func CircuitDiagramDeleteHandler(c *gin.Context) { err := result.Error if result.RowsAffected == 0 { - err = fmt.Errorf("%w:please check uuid conditions", constant.ErrDeleteRowZero) + err = fmt.Errorf("%w:please check uuid conditions", errcode.ErrDeleteRowZero) } msg := fmt.Sprintf("delete component info from table %s failed", modelStruct.ReturnTableName()) - logger.Error(msg, zap.String("component_global_uuid", componentInfo.UUID), zap.Error(err)) + logger.Error(c, msg, "component_global_uuid", componentInfo.UUID, "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, diff --git a/handler/circuit_diagram_load.go b/handler/circuit_diagram_load.go index 9174b6e..763acbd 100644 --- a/handler/circuit_diagram_load.go +++ b/handler/circuit_diagram_load.go @@ -11,7 +11,6 @@ import ( "modelRT/network" "github.com/gin-gonic/gin" - "go.uber.org/zap" ) // CircuitDiagramLoadHandler define circuit diagram load process API @@ -25,12 +24,11 @@ import ( // @Failure 400 {object} network.FailureResponse "request process failed" // @Router /model/diagram_load/{page_id} [get] func CircuitDiagramLoadHandler(c *gin.Context) { - logger := logger.GetLoggerInstance() pgClient := database.GetPostgresDBClient() pageID, err := strconv.ParseInt(c.Query("page_id"), 10, 64) if err != nil { - logger.Error("get pageID from url param failed", zap.Error(err)) + logger.Error(c, "get pageID from url param failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -45,7 +43,7 @@ func CircuitDiagramLoadHandler(c *gin.Context) { topologicInfo, err := diagram.GetGraphMap(pageID) if err != nil { - logger.Error("get topologic data from set by pageID failed", zap.Error(err)) + logger.Error(c, "get topologic data from set by pageID failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -66,7 +64,7 @@ func CircuitDiagramLoadHandler(c *gin.Context) { for _, componentUUID := range VerticeLink { component, err := database.QueryComponentByUUID(c, pgClient, componentUUID) if err != nil { - logger.Error("get component id info from DB by uuid failed", zap.Error(err)) + logger.Error(c, "get component id info from DB by uuid failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -81,7 +79,7 @@ func CircuitDiagramLoadHandler(c *gin.Context) { componentParams, err := diagram.GetComponentMap(component.ID) if err != nil { - logger.Error("get component data from set by uuid failed", zap.Error(err)) + logger.Error(c, "get component data from set by uuid failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -100,7 +98,7 @@ func CircuitDiagramLoadHandler(c *gin.Context) { rootVertexUUID := topologicInfo.RootVertex.String() rootComponent, err := database.QueryComponentByUUID(c, pgClient, topologicInfo.RootVertex) if err != nil { - logger.Error("get component id info from DB by uuid failed", zap.Error(err)) + logger.Error(c, "get component id info from DB by uuid failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -115,7 +113,7 @@ func CircuitDiagramLoadHandler(c *gin.Context) { rootComponentParam, err := diagram.GetComponentMap(rootComponent.ID) if err != nil { - logger.Error("get component data from set by uuid failed", zap.Error(err)) + logger.Error(c, "get component data from set by uuid failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, diff --git a/handler/circuit_diagram_update.go b/handler/circuit_diagram_update.go index cd08bdb..2ab945f 100644 --- a/handler/circuit_diagram_update.go +++ b/handler/circuit_diagram_update.go @@ -11,17 +11,15 @@ import ( "github.com/bitly/go-simplejson" "github.com/gin-gonic/gin" - "go.uber.org/zap" ) // CircuitDiagramUpdateHandler define circuit diagram update process API func CircuitDiagramUpdateHandler(c *gin.Context) { - logger := logger.GetLoggerInstance() pgClient := database.GetPostgresDBClient() var request network.CircuitDiagramUpdateRequest if err := c.ShouldBindJSON(&request); err != nil { - logger.Error("unmarshal circuit diagram update info failed", zap.Error(err)) + logger.Error(c, "unmarshal circuit diagram update info failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -33,7 +31,7 @@ func CircuitDiagramUpdateHandler(c *gin.Context) { graph, err := diagram.GetGraphMap(request.PageID) if err != nil { - logger.Error("get topologic data from set by pageID failed", zap.Error(err)) + logger.Error(c, "get topologic data from set by pageID failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -50,7 +48,7 @@ func CircuitDiagramUpdateHandler(c *gin.Context) { for _, topologicLink := range request.TopologicLinks { changeInfo, err := network.ParseUUID(topologicLink) if err != nil { - logger.Error("format uuid from string failed", zap.Error(err)) + logger.Error(c, "format uuid from string failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -73,7 +71,7 @@ func CircuitDiagramUpdateHandler(c *gin.Context) { if err != nil { tx.Rollback() - logger.Error("update topologic info into DB failed", zap.Any("topologic_info", topologicChangeInfo), zap.Error(err)) + logger.Error(c, "update topologic info into DB failed", "topologic_info", topologicChangeInfo, "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -90,7 +88,7 @@ func CircuitDiagramUpdateHandler(c *gin.Context) { if err != nil { tx.Rollback() - logger.Error("update topologic info failed", zap.Any("topologic_info", topologicChangeInfo), zap.Error(err)) + logger.Error(c, "update topologic info failed", "topologic_info", topologicChangeInfo, "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -107,7 +105,7 @@ func CircuitDiagramUpdateHandler(c *gin.Context) { for index, componentInfo := range request.ComponentInfos { componentID, err := database.UpdateComponentIntoDB(c, tx, componentInfo) if err != nil { - logger.Error("udpate component info into DB failed", zap.Error(err)) + logger.Error(c, "udpate component info into DB failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -125,7 +123,7 @@ func CircuitDiagramUpdateHandler(c *gin.Context) { err = database.UpdateModelIntoDB(c, tx, componentID, componentInfo.ComponentType, componentInfo.Params) if err != nil { - logger.Error("udpate component model info into DB failed", zap.Error(err)) + logger.Error(c, "udpate component model info into DB failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -143,7 +141,7 @@ func CircuitDiagramUpdateHandler(c *gin.Context) { for _, componentInfo := range request.ComponentInfos { paramsJSON, err := simplejson.NewJson([]byte(componentInfo.Params)) if err != nil { - logger.Error("unmarshal component info by concurrent map failed", zap.String("component_params", componentInfo.Params), zap.Error(err)) + logger.Error(c, "unmarshal component info by concurrent map failed", "component_params", componentInfo.Params, "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -159,7 +157,7 @@ func CircuitDiagramUpdateHandler(c *gin.Context) { componentMap, err := paramsJSON.Map() if err != nil { - logger.Error("format params json info to map failed", zap.Error(err)) + logger.Error(c, "format params json info to map failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, diff --git a/handler/real_time_data_query.go b/handler/real_time_data_query.go index a8fc12f..5c4ed3c 100644 --- a/handler/real_time_data_query.go +++ b/handler/real_time_data_query.go @@ -6,25 +6,23 @@ import ( "strconv" "modelRT/alert" - "modelRT/constant" + constants "modelRT/constant" "modelRT/logger" "modelRT/network" "github.com/gin-gonic/gin" - "go.uber.org/zap" ) // QueryRealTimeDataHandler define query real time data process API func QueryRealTimeDataHandler(c *gin.Context) { - var targetLevel constant.AlertLevel + var targetLevel constants.AlertLevel - logger := logger.GetLoggerInstance() alertManger := alert.GetAlertMangerInstance() levelStr := c.Query("level") level, err := strconv.Atoi(levelStr) if err != nil { - logger.Error("convert alert level string to int failed", zap.Error(err)) + logger.Error(c, "convert alert level string to int failed", "error", err) resp := network.FailureResponse{ Code: http.StatusBadRequest, @@ -32,7 +30,7 @@ func QueryRealTimeDataHandler(c *gin.Context) { } c.JSON(http.StatusOK, resp) } - targetLevel = constant.AlertLevel(level) + targetLevel = constants.AlertLevel(level) events := alertManger.GetRangeEventsByLevel(targetLevel) resp := network.SuccessResponse{ diff --git a/handler/real_time_data_receive.go b/handler/real_time_data_receive.go index d84f534..b2412fb 100644 --- a/handler/real_time_data_receive.go +++ b/handler/real_time_data_receive.go @@ -3,13 +3,11 @@ package handler import ( "modelRT/logger" "modelRT/network" + realtimedata "modelRT/real-time-data" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" jsoniter "github.com/json-iterator/go" - "go.uber.org/zap" - - realtimedata "modelRT/real-time-data" ) var upgrader = websocket.Upgrader{ @@ -19,11 +17,9 @@ var upgrader = websocket.Upgrader{ // RealTimeDataReceivehandler define real time data receive and process API func RealTimeDataReceivehandler(c *gin.Context) { - logger := logger.GetLoggerInstance() - conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { - logger.Error("upgrade http protocol to websocket protocal failed", zap.Error(err)) + logger.Error(c, "upgrade http protocol to websocket protocal failed", "error", err) return } defer conn.Close() @@ -31,17 +27,17 @@ func RealTimeDataReceivehandler(c *gin.Context) { for { messageType, p, err := conn.ReadMessage() if err != nil { - logger.Error("read message from websocket connection failed", zap.Error(err)) + logger.Error(c, "read message from websocket connection failed", "error", err) respByte := processResponse(-1, "read message from websocket connection failed", nil) if len(respByte) == 0 { - logger.Error("process message from byte failed", zap.Error(err)) + logger.Error(c, "process message from byte failed", "error", err) continue } err = conn.WriteMessage(messageType, respByte) if err != nil { - logger.Error("write message to websocket connection failed", zap.Error(err)) + logger.Error(c, "write message to websocket connection failed", "error", err) continue } continue @@ -50,17 +46,17 @@ func RealTimeDataReceivehandler(c *gin.Context) { var request network.RealTimeDataReceiveRequest err = jsoniter.Unmarshal([]byte(p), &request) if err != nil { - logger.Error("unmarshal message from byte failed", zap.Error(err)) + logger.Error(c, "unmarshal message from byte failed", "error", err) respByte := processResponse(-1, "unmarshal message from byte failed", nil) if len(respByte) == 0 { - logger.Error("process message from byte failed", zap.Error(err)) + logger.Error(c, "process message from byte failed", "error", err) continue } err = conn.WriteMessage(messageType, respByte) if err != nil { - logger.Error("write message to websocket connection failed", zap.Error(err)) + logger.Error(c, "write message to websocket connection failed", "error", err) continue } continue @@ -74,13 +70,13 @@ func RealTimeDataReceivehandler(c *gin.Context) { } respByte := processResponse(0, "success", payload) if len(respByte) == 0 { - logger.Error("process message from byte failed", zap.Error(err)) + logger.Error(c, "process message from byte failed", "error", err) continue } err = conn.WriteMessage(messageType, respByte) if err != nil { - logger.Error("write message to websocket connection failed", zap.Error(err)) + logger.Error(c, "write message to websocket connection failed", "error", err) continue } } diff --git a/logger/facede.go b/logger/facede.go new file mode 100644 index 0000000..b9a5b54 --- /dev/null +++ b/logger/facede.go @@ -0,0 +1,54 @@ +// Package logger define log struct of modelRT project +package logger + +import ( + "context" + "sync" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +var ( + f *facade + fOnce sync.Once +) + +type facade struct { + _logger *zap.Logger +} + +// Debug define facade func of debug level log +func Debug(ctx context.Context, msg string, kv ...any) { + logFacade().log(ctx, zapcore.DebugLevel, msg, kv...) +} + +// Info define facade func of info level log +func Info(ctx context.Context, msg string, kv ...any) { + logFacade().log(ctx, zapcore.InfoLevel, msg, kv...) +} + +// Warn define facade func of warn level log +func Warn(ctx context.Context, msg string, kv ...any) { + logFacade().log(ctx, zapcore.WarnLevel, msg, kv...) +} + +// Error define facade func of error level log +func Error(ctx context.Context, msg string, kv ...any) { + logFacade().log(ctx, zapcore.ErrorLevel, msg, kv...) +} + +func (f *facade) log(ctx context.Context, lvl zapcore.Level, msg string, kv ...any) { + fields := makeLogFields(ctx, kv...) + ce := f._logger.Check(lvl, msg) + ce.Write(fields...) +} + +func logFacade() *facade { + fOnce.Do(func() { + f = &facade{ + _logger: GetLoggerInstance(), + } + }) + return f +} diff --git a/logger/gorm_logger.go b/logger/gorm_logger.go new file mode 100644 index 0000000..511d677 --- /dev/null +++ b/logger/gorm_logger.go @@ -0,0 +1,61 @@ +// Package logger define log struct of modelRT project +package logger + +import ( + "context" + "errors" + "time" + + "gorm.io/gorm" + gormLogger "gorm.io/gorm/logger" +) + +// GormLogger define struct for implementing gormLogger.Interface +type GormLogger struct { + SlowThreshold time.Duration +} + +// NewGormLogger define func for init GormLogger +func NewGormLogger() *GormLogger { + return &GormLogger{ + SlowThreshold: 500 * time.Millisecond, + } +} + +// LogMode define func for implementing gormLogger.Interface +func (l *GormLogger) LogMode(_ gormLogger.LogLevel) gormLogger.Interface { + return &GormLogger{} +} + +// Info define func for implementing gormLogger.Interface +func (l *GormLogger) Info(ctx context.Context, msg string, data ...any) { + Info(ctx, msg, "data", data) +} + +// Warn define func for implementing gormLogger.Interface +func (l *GormLogger) Warn(ctx context.Context, msg string, data ...any) { + Warn(ctx, msg, "data", data) +} + +// Error define func for implementing gormLogger.Interface +func (l *GormLogger) Error(ctx context.Context, msg string, data ...any) { + Error(ctx, msg, "data", data) +} + +// Trace define func for implementing gormLogger.Interface +func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + // get SQL running time + duration := time.Since(begin).Milliseconds() + // get gorm exec sql and rows affected + sql, rows := fc() + // gorm error judgment + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + Error(ctx, "SQL ERROR", "sql", sql, "rows", rows, "dur(ms)", duration) + } + // slow query judgment + if duration > l.SlowThreshold.Milliseconds() { + Warn(ctx, "SQL SLOW", "sql", sql, "rows", rows, "dur(ms)", duration) + } else { + Debug(ctx, "SQL DEBUG", "sql", sql, "rows", rows, "dur(ms)", duration) + } +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..eb97838 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,109 @@ +// Package logger define log struct of modelRT project +package logger + +import ( + "context" + "path" + "runtime" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type logger struct { + ctx context.Context + traceID string + spanID string + pSpanID string + _logger *zap.Logger +} + +func (l *logger) Debug(msg string, kv ...any) { + l.log(zapcore.DebugLevel, msg, kv...) +} + +func (l *logger) Info(msg string, kv ...any) { + l.log(zapcore.InfoLevel, msg, kv...) +} + +func (l *logger) Warn(msg string, kv ...any) { + l.log(zapcore.WarnLevel, msg, kv...) +} + +func (l *logger) Error(msg string, kv ...any) { + l.log(zapcore.ErrorLevel, msg, kv...) +} + +func (l *logger) log(lvl zapcore.Level, msg string, kv ...any) { + fields := makeLogFields(l.ctx, kv...) + ce := l._logger.Check(lvl, msg) + ce.Write(fields...) +} + +func makeLogFields(ctx context.Context, kv ...any) []zap.Field { + // Ensure that log information appears in pairs in the form of key-value pairs + if len(kv)%2 != 0 { + kv = append(kv, "unknown") + } + + kv = append(kv, "traceID", ctx.Value("traceID"), "spanID", ctx.Value("spanID"), "pspanID", ctx.Value("pspanID")) + + funcName, file, line := getLoggerCallerInfo() + kv = append(kv, "func", funcName, "file", file, "line", line) + fields := make([]zap.Field, 0, len(kv)/2) + for i := 0; i < len(kv); i += 2 { + key := kv[i].(string) + value := kv[i+1] + switch v := value.(type) { + case string: + fields = append(fields, zap.String(key, v)) + case int: + fields = append(fields, zap.Int(key, v)) + case int64: + fields = append(fields, zap.Int64(key, v)) + case float32: + fields = append(fields, zap.Float32(key, v)) + case float64: + fields = append(fields, zap.Float64(key, v)) + case bool: + fields = append(fields, zap.Bool(key, v)) + case error: + fields = append(fields, zap.Error(v)) + default: + fields = append(fields, zap.Any(key, v)) + } + } + return fields +} + +// getLoggerCallerInfo define func of return log caller information、method name、file name、line number +func getLoggerCallerInfo() (funcName, file string, line int) { + pc, file, line, ok := runtime.Caller(4) + if !ok { + return + } + file = path.Base(file) + funcName = runtime.FuncForPC(pc).Name() + return +} + +func New(ctx context.Context) *logger { + var traceID, spanID, pSpanID string + if ctx.Value("traceID") != nil { + traceID = ctx.Value("traceID").(string) + } + if ctx.Value("spanID") != nil { + spanID = ctx.Value("spanID").(string) + } + if ctx.Value("psapnID") != nil { + pSpanID = ctx.Value("pspanID").(string) + } + + return &logger{ + ctx: ctx, + traceID: traceID, + spanID: spanID, + pSpanID: pSpanID, + _logger: GetLoggerInstance(), + } +} diff --git a/logger/init.go b/logger/zap.go similarity index 76% rename from logger/init.go rename to logger/zap.go index 63818f9..f46c315 100644 --- a/logger/init.go +++ b/logger/zap.go @@ -6,7 +6,7 @@ import ( "sync" "modelRT/config" - "modelRT/constant" + constants "modelRT/constant" "github.com/natefinch/lumberjack" "go.uber.org/zap" @@ -31,17 +31,17 @@ func getEncoder() zapcore.Encoder { } // getLogWriter responsible for setting the location of log storage -func getLogWriter(mode, filename string, maxsize, maxBackup, maxAge int) zapcore.WriteSyncer { +func getLogWriter(mode, filename string, maxsize, maxBackup, maxAge int, compress bool) zapcore.WriteSyncer { lumberJackLogger := &lumberjack.Logger{ Filename: filename, // log file position MaxSize: maxsize, // log file maxsize MaxAge: maxAge, // maximum number of day files retained MaxBackups: maxBackup, // maximum number of old files retained - Compress: false, // whether to compress + Compress: compress, // whether to compress } syncConsole := zapcore.AddSync(os.Stderr) - if mode == constant.DevelopmentLogMode { + if mode == constants.DevelopmentLogMode { return syncConsole } @@ -51,7 +51,7 @@ func getLogWriter(mode, filename string, maxsize, maxBackup, maxAge int) zapcore // initLogger return successfully initialized zap logger func initLogger(lCfg config.LoggerConfig) *zap.Logger { - writeSyncer := getLogWriter(lCfg.Mode, lCfg.FilePath, lCfg.MaxSize, lCfg.MaxBackups, lCfg.MaxAge) + writeSyncer := getLogWriter(lCfg.Mode, lCfg.FilePath, lCfg.MaxSize, lCfg.MaxBackups, lCfg.MaxAge, lCfg.Compress) encoder := getEncoder() l := new(zapcore.Level) @@ -61,21 +61,22 @@ func initLogger(lCfg config.LoggerConfig) *zap.Logger { } core := zapcore.NewCore(encoder, writeSyncer, l) - _globalLogger = zap.New(core, zap.AddCaller()) - zap.ReplaceGlobals(_globalLogger) + logger := zap.New(core, zap.AddCaller()) - return _globalLogger + // 替换全局日志实例 + zap.ReplaceGlobals(logger) + return logger } -// InitLoggerInstance return instance of zap logger -func InitLoggerInstance(lCfg config.LoggerConfig) *zap.Logger { +// InitLoggerInstance define func of return instance of zap logger +func InitLoggerInstance(lCfg config.LoggerConfig) { once.Do(func() { _globalLogger = initLogger(lCfg) }) - return _globalLogger + defer _globalLogger.Sync() } -// GetLoggerInstance returns the global logger instance It's safe for concurrent use. +// GetLoggerInstance define func of returns the global logger instance It's safe for concurrent use. func GetLoggerInstance() *zap.Logger { _globalLoggerMu.RLock() logger := _globalLogger diff --git a/main.go b/main.go index 1bb84e0..c2c890c 100644 --- a/main.go +++ b/main.go @@ -4,25 +4,30 @@ package main import ( "context" "flag" + "net/http" + "os" + "os/signal" + "syscall" "time" "modelRT/alert" "modelRT/config" "modelRT/database" + "modelRT/diagram" + locker "modelRT/distributedlock" _ "modelRT/docs" "modelRT/handler" "modelRT/logger" "modelRT/middleware" "modelRT/pool" - - swaggerFiles "github.com/swaggo/files" - ginSwagger "github.com/swaggo/gin-swagger" + "modelRT/router" realtimedata "modelRT/real-time-data" "github.com/gin-gonic/gin" "github.com/panjf2000/ants/v2" - "go.uber.org/zap" + swaggerFiles "github.com/swaggo/files" + ginSwagger "github.com/swaggo/gin-swagger" "gorm.io/gorm" ) @@ -41,7 +46,6 @@ var ( var ( modelRTConfig config.ModelRTConfig postgresDBClient *gorm.DB - zapLogger *zap.Logger alertManager *alert.EventManager ) @@ -50,6 +54,9 @@ func main() { flag.Parse() ctx := context.TODO() + // init logger + logger.InitLoggerInstance(modelRTConfig.LoggerConfig) + modelRTConfig = config.ReadAndInitConfig(*modelRTConfigDir, *modelRTConfigName, *modelRTConfigType) // init postgresDBClient postgresDBClient = database.InitPostgresDBInstance(ctx, modelRTConfig.PostgresDBURI) @@ -62,25 +69,27 @@ func main() { sqlDB.Close() }() - // init logger - zapLogger = logger.InitLoggerInstance(modelRTConfig.LoggerConfig) - defer zapLogger.Sync() - // init alert manager _ = alert.InitAlertEventManager() // init model parse ants pool parsePool, err := ants.NewPoolWithFunc(modelRTConfig.ParseConcurrentQuantity, pool.ParseFunc) if err != nil { - zapLogger.Error("init concurrent parse task pool failed", zap.Error(err)) + logger.Error(ctx, "init concurrent parse task pool failed", "error", err) panic(err) } defer parsePool.Release() + storageClient := diagram.InitClientInstance(modelRTConfig.StorageRedisConfig) + defer storageClient.Close() + + lockerClient := locker.InitClientInstance(modelRTConfig.LockerRedisConfig) + defer lockerClient.Close() + // init anchor param ants pool anchorRealTimePool, err := pool.AnchorPoolInit(modelRTConfig.RTDReceiveConcurrentQuantity) if err != nil { - zapLogger.Error("init concurrent anchor param task pool failed", zap.Error(err)) + logger.Error(ctx, "init concurrent anchor param task pool failed", "error", err) panic(err) } defer anchorRealTimePool.Release() @@ -93,18 +102,19 @@ func main() { postgresDBClient.Transaction(func(tx *gorm.DB) error { // load circuit diagram from postgres - err := database.QueryCircuitDiagramComponentFromDB(cancelCtx, tx, parsePool, zapLogger) + componentTypeMap, err := database.QueryCircuitDiagramComponentFromDB(cancelCtx, tx, parsePool) if err != nil { - zapLogger.Error("load circuit diagrams from postgres failed", zap.Error(err)) + logger.Error(ctx, "load circuit diagrams from postgres failed", "error", err) panic(err) } // TODO 暂时屏蔽完成 swagger 启动测试 - err = database.QueryTopologicFromDB(ctx, tx, zapLogger, modelRTConfig.GridID, modelRTConfig.ZoneID, modelRTConfig.StationID) + tree, err := database.QueryTopologicFromDB(ctx, tx, componentTypeMap) if err != nil { - zapLogger.Error("load topologic info from postgres failed", zap.Error(err)) + logger.Error(ctx, "load topologic info from postgres failed", "error", err) panic(err) } + diagram.GlobalTree = tree return nil }) @@ -112,14 +122,34 @@ func main() { // TODO 暂时屏蔽完成 swagger 启动测试 // go realtimedata.RealTimeDataComputer(ctx, nil, []string{}, "") - engine := gin.Default() - engine.Use(limiter.Middleware) + engine := gin.New() + router.RegisterRoutes(engine) + server := http.Server{ + Addr: ":8080", + Handler: engine, + } - // diagram api - engine.GET("/model/diagram_load", handler.CircuitDiagramLoadHandler) - engine.POST("/model/diagram_create", handler.CircuitDiagramCreateHandler) - engine.POST("/model/diagram_update", handler.CircuitDiagramUpdateHandler) - engine.POST("/model/diagram_delete", handler.CircuitDiagramDeleteHandler) + // creating a System Signal Receiver + done := make(chan os.Signal, 10) + signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-done + if err := server.Shutdown(context.Background()); err != nil { + logger.Error(ctx, "ShutdownServerError", "err", err) + } + }() + + logger.Info(ctx, "Starting ModelRT server...") + err = server.ListenAndServe() + if err != nil { + if err == http.ErrServerClosed { + // the service receives the shutdown signal normally and then closes + logger.Info(ctx, "Server closed under request") + } else { + // abnormal shutdown of service + logger.Error(ctx, "Server closed unexpected", "err", err) + } + } // real time data api engine.GET("/ws/rtdatas", handler.RealTimeDataReceivehandler) diff --git a/middleware/trace.go b/middleware/trace.go new file mode 100644 index 0000000..ef2dad9 --- /dev/null +++ b/middleware/trace.go @@ -0,0 +1,97 @@ +package middleware + +import ( + "bytes" + "io" + "strings" + "time" + + "modelRT/logger" + "modelRT/util" + + "github.com/gin-gonic/gin" +) + +// StartTrace define func of set trace info from request header +func StartTrace() gin.HandlerFunc { + return func(c *gin.Context) { + traceID := c.Request.Header.Get("traceid") + pSpanID := c.Request.Header.Get("spanid") + spanID := util.GenerateSpanID(c.Request.RemoteAddr) + if traceID == "" { // 如果traceId 为空,证明是链路的发端,把它设置成此次的spanId,发端的spanId是root spanId + traceID = spanID // trace 标识整个请求的链路, span则标识链路中的不同服务 + } + c.Set("traceid", traceID) + c.Set("spanid", spanID) + c.Set("pspanid", pSpanID) + c.Next() + } +} + +type bodyLogWriter struct { + gin.ResponseWriter + body *bytes.Buffer +} + +// 包装一下 gin.ResponseWriter,通过这种方式拦截写响应 +// 让gin写响应的时候先写到 bodyLogWriter 再写gin.ResponseWriter , +// 这样利用中间件里输出访问日志时就能拿到响应了 +// https://stackoverflow.com/questions/38501325/how-to-log-response-body-in-gin +func (w bodyLogWriter) Write(b []byte) (int, error) { + w.body.Write(b) + return w.ResponseWriter.Write(b) +} + +func LogAccess() gin.HandlerFunc { + return func(c *gin.Context) { + // 保存body + var reqBody []byte + contentType := c.GetHeader("Content-Type") + // multipart/form-data 文件上传请求, 不在日志里记录body + if !strings.Contains(contentType, "multipart/form-data") { + reqBody, _ = io.ReadAll(c.Request.Body) + c.Request.Body = io.NopCloser(bytes.NewReader(reqBody)) + + // var request map[string]interface{} + // if err := c.ShouldBindBodyWith(&request, binding.JSON); err != nil { + // c.JSON(400, gin.H{"error": err.Error()}) + // return + // } + } + start := time.Now() + blw := &bodyLogWriter{body: bytes.NewBufferString(""), ResponseWriter: c.Writer} + c.Writer = blw + + accessLog(c, "access_start", time.Since(start), reqBody, nil) + defer func() { + var responseLogging string + if c.Writer.Size() > 10*1024 { // 响应大于10KB 不记录 + responseLogging = "Response data size is too Large to log" + } else { + responseLogging = blw.body.String() + } + accessLog(c, "access_end", time.Since(start), reqBody, responseLogging) + }() + c.Next() + return + } +} + +func accessLog(c *gin.Context, accessType string, dur time.Duration, body []byte, dataOut interface{}) { + req := c.Request + bodyStr := string(body) + query := req.URL.RawQuery + path := req.URL.Path + // TODO: 实现Token认证后再把访问日志里也加上token记录 + // token := c.Request.Header.Get("token") + logger.New(c).Info("AccessLog", + "type", accessType, + "ip", c.ClientIP(), + //"token", token, + "method", req.Method, + "path", path, + "query", query, + "body", bodyStr, + "output", dataOut, + "time(ms)", int64(dur/time.Millisecond)) +} diff --git a/model/model_select.go b/model/model_select.go index b860a61..d6eb7ef 100644 --- a/model/model_select.go +++ b/model/model_select.go @@ -2,17 +2,17 @@ package model import ( - "modelRT/constant" + constants "modelRT/constant" "modelRT/orm" ) // SelectModelByType define select the data structure for parsing based on the input model type func SelectModelByType(modelType int) BasicModelInterface { - if modelType == constant.BusbarType { + if modelType == constants.BusbarType { return &orm.BusbarSection{} - } else if modelType == constant.AsyncMotorType { + } else if modelType == constants.AsyncMotorType { return &orm.AsyncMotor{} - } else if modelType == constant.DemoType { + } else if modelType == constants.DemoType { return &orm.Demo{} } return nil diff --git a/network/circuit_diagram_update_request.go b/network/circuit_diagram_update_request.go index 2b3e3ba..1a6804c 100644 --- a/network/circuit_diagram_update_request.go +++ b/network/circuit_diagram_update_request.go @@ -4,7 +4,8 @@ package network import ( "fmt" - "modelRT/constant" + "modelRT/common/errcode" + constants "modelRT/constant" "github.com/gofrs/uuid" ) @@ -61,12 +62,12 @@ func ParseUUID(info TopologicChangeInfo) (TopologicUUIDChangeInfos, error) { UUIDChangeInfo.ChangeType = info.ChangeType switch info.ChangeType { - case constant.UUIDFromChangeType: + case constants.UUIDFromChangeType: if info.NewUUIDFrom == info.OldUUIDFrom { - return UUIDChangeInfo, fmt.Errorf("topologic change data check failed:%w", constant.ErrUUIDFromCheckT1) + return UUIDChangeInfo, fmt.Errorf("topologic change data check failed:%w", constants.ErrUUIDFromCheckT1) } if info.NewUUIDTo != info.OldUUIDTo { - return UUIDChangeInfo, fmt.Errorf("topologic change data check failed:%w", constant.ErrUUIDToCheckT1) + return UUIDChangeInfo, fmt.Errorf("topologic change data check failed:%w", constants.ErrUUIDToCheckT1) } oldUUIDFrom, err := uuid.FromString(info.OldUUIDFrom) @@ -87,12 +88,12 @@ func ParseUUID(info TopologicChangeInfo) (TopologicUUIDChangeInfos, error) { } UUIDChangeInfo.OldUUIDTo = OldUUIDTo UUIDChangeInfo.NewUUIDTo = OldUUIDTo - case constant.UUIDToChangeType: + case constants.UUIDToChangeType: if info.NewUUIDFrom != info.OldUUIDFrom { - return UUIDChangeInfo, fmt.Errorf("topologic change data check failed:%w", constant.ErrUUIDFromCheckT2) + return UUIDChangeInfo, fmt.Errorf("topologic change data check failed:%w", constants.ErrUUIDFromCheckT2) } if info.NewUUIDTo == info.OldUUIDTo { - return UUIDChangeInfo, fmt.Errorf("topologic change data check failed:%w", constant.ErrUUIDToCheckT2) + return UUIDChangeInfo, fmt.Errorf("topologic change data check failed:%w", constants.ErrUUIDToCheckT2) } oldUUIDFrom, err := uuid.FromString(info.OldUUIDFrom) @@ -113,12 +114,12 @@ func ParseUUID(info TopologicChangeInfo) (TopologicUUIDChangeInfos, error) { return UUIDChangeInfo, fmt.Errorf("convert data from string type to uuid type failed,new uuid_to value:%s", info.NewUUIDTo) } UUIDChangeInfo.NewUUIDTo = newUUIDTo - case constant.UUIDAddChangeType: + case constants.UUIDAddChangeType: if info.OldUUIDFrom != "" { - return UUIDChangeInfo, fmt.Errorf("topologic change data check failed:%w", constant.ErrUUIDFromCheckT3) + return UUIDChangeInfo, fmt.Errorf("topologic change data check failed:%w", constants.ErrUUIDFromCheckT3) } if info.OldUUIDTo != "" { - return UUIDChangeInfo, fmt.Errorf("topologic change data check failed:%w", constant.ErrUUIDToCheckT3) + return UUIDChangeInfo, fmt.Errorf("topologic change data check failed:%w", constants.ErrUUIDToCheckT3) } newUUIDFrom, err := uuid.FromString(info.NewUUIDFrom) @@ -133,7 +134,7 @@ func ParseUUID(info TopologicChangeInfo) (TopologicUUIDChangeInfos, error) { } UUIDChangeInfo.NewUUIDTo = newUUIDTo default: - return UUIDChangeInfo, constant.ErrUUIDChangeType + return UUIDChangeInfo, errcode.ErrUUIDChangeType } UUIDChangeInfo.Flag = info.Flag UUIDChangeInfo.Comment = info.Comment diff --git a/orm/circuit_diagram_topologic.go b/orm/circuit_diagram_topologic.go index 5b4b5b0..08bd4b3 100644 --- a/orm/circuit_diagram_topologic.go +++ b/orm/circuit_diagram_topologic.go @@ -6,7 +6,6 @@ import "github.com/gofrs/uuid" // Topologic structure define topologic info set of circuit diagram type Topologic struct { ID int64 `gorm:"column:id"` - PageID int64 `gorm:"column:page_id"` Flag int `gorm:"column:flag"` UUIDFrom uuid.UUID `gorm:"column:uuid_from"` UUIDTo uuid.UUID `gorm:"column:uuid_to"` diff --git a/pool/concurrency_anchor_parse.go b/pool/concurrency_anchor_parse.go index 8e2cb68..59411e3 100644 --- a/pool/concurrency_anchor_parse.go +++ b/pool/concurrency_anchor_parse.go @@ -7,12 +7,11 @@ import ( "modelRT/alert" "modelRT/config" - "modelRT/constant" + constants "modelRT/constant" "modelRT/diagram" "modelRT/logger" "github.com/panjf2000/ants/v2" - "go.uber.org/zap" ) // AnchorRealTimePool define anchor param pool of real time data @@ -31,12 +30,11 @@ func AnchorPoolInit(concurrentQuantity int) (pool *ants.PoolWithFunc, err error) // AnchorFunc defines func that process the real time data of component anchor params var AnchorFunc = func(poolConfig interface{}) { var firstStart bool - logger := logger.GetLoggerInstance() alertManager := alert.GetAlertMangerInstance() anchorChanConfig, ok := poolConfig.(config.AnchorChanConfig) if !ok { - logger.Error("conversion component anchor chan type failed") + logger.Error(anchorChanConfig.Ctx, "conversion component anchor chan type failed") return } @@ -56,12 +54,12 @@ var AnchorFunc = func(poolConfig interface{}) { for _, value := range anchorRealTimeDatas { anchorName, err := diagram.GetAnchorValue(componentID) if err != nil { - logger.Error("can not get anchor value from map by uuid", zap.Int64("component_id", componentID), zap.Error(err)) + logger.Error(anchorChanConfig.Ctx, "can not get anchor value from map by uuid", "component_id", componentID, "error", err) continue } if anchorName != anchorParaConfig.AnchorName { - logger.Error("anchor name not equal param config anchor value", zap.String("map_anchor_name", anchorName), zap.String("param_anchor_name", anchorParaConfig.AnchorName)) + logger.Error(anchorChanConfig.Ctx, "anchor name not equal param config anchor value", "map_anchor_name", anchorName, "param_anchor_name", anchorParaConfig.AnchorName) continue } @@ -74,7 +72,7 @@ var AnchorFunc = func(poolConfig interface{}) { event := alert.Event{ ComponentID: componentID, AnchorName: anchorName, - Level: constant.InfoAlertLevel, + Level: constants.InfoAlertLevel, Message: message, StartTime: time.Now().Unix(), } diff --git a/pool/concurrency_model_parse.go b/pool/concurrency_model_parse.go index b0a4fa6..797d197 100644 --- a/pool/concurrency_model_parse.go +++ b/pool/concurrency_model_parse.go @@ -10,21 +10,17 @@ import ( "modelRT/diagram" "modelRT/logger" "modelRT/model" - - "go.uber.org/zap" ) // ParseFunc defines func that parses the model data from postgres var ParseFunc = func(parseConfig interface{}) { - logger := logger.GetLoggerInstance() - modelParseConfig, ok := parseConfig.(config.ModelParseConfig) if !ok { - logger.Error("conversion model parse config type failed") + logger.Error(modelParseConfig.Ctx, "conversion model parse config type failed") return } - cancelCtx, cancel := context.WithTimeout(modelParseConfig.Context, 5*time.Second) + cancelCtx, cancel := context.WithTimeout(modelParseConfig.Ctx, 5*time.Second) defer cancel() pgClient := database.GetPostgresDBClient() @@ -33,10 +29,10 @@ var ParseFunc = func(parseConfig interface{}) { result := pgClient.WithContext(cancelCtx).Table(tableName).Where("component_id = ?", modelParseConfig.ComponentInfo.ID).Find(&unmarshalMap) if result.Error != nil { - logger.Error("query component detail info failed", zap.Error(result.Error)) + logger.Error(modelParseConfig.Ctx, "query component detail info failed", "error", result.Error) return } else if result.RowsAffected == 0 { - logger.Error("query component detail info from table is empty", zap.String("table_name", tableName)) + logger.Error(modelParseConfig.Ctx, "query component detail info from table is empty", "table_name", tableName) return } @@ -48,7 +44,7 @@ var ParseFunc = func(parseConfig interface{}) { } diagram.StoreAnchorValue(modelParseConfig.ComponentInfo.ID, anchorName) - GetComponentChan(modelParseConfig.Context, modelParseConfig.ComponentInfo.ID) + GetComponentChan(modelParseConfig.Ctx, modelParseConfig.ComponentInfo.ID) uuid := modelParseConfig.ComponentInfo.GlobalUUID.String() unmarshalMap["id"] = modelParseConfig.ComponentInfo.ID diff --git a/real-time-data/kafka.go b/real-time-data/kafka.go index 1cfa31f..d6b857c 100644 --- a/real-time-data/kafka.go +++ b/real-time-data/kafka.go @@ -8,7 +8,6 @@ import ( "modelRT/logger" "github.com/confluentinc/confluent-kafka-go/kafka" - "go.uber.org/zap" ) // RealTimeDataComputer continuously processing real-time data from Kafka specified topics @@ -17,9 +16,6 @@ func RealTimeDataComputer(ctx context.Context, consumerConfig kafka.ConfigMap, t ctx, cancel := context.WithCancel(ctx) defer cancel() - // get a logger - logger := logger.GetLoggerInstance() - // setup a channel to listen for interrupt signals // TODO 将中断信号放到入参中 interrupt := make(chan struct{}, 1) @@ -30,13 +26,13 @@ func RealTimeDataComputer(ctx context.Context, consumerConfig kafka.ConfigMap, t // create a new consumer consumer, err := kafka.NewConsumer(&consumerConfig) if err != nil { - logger.Error("init kafka consume by config failed", zap.Any("config", consumerConfig), zap.Error(err)) + logger.Error(ctx, "init kafka consume by config failed", "config", consumerConfig, "error", err) } // subscribe to the topic err = consumer.SubscribeTopics(topics, nil) if err != nil { - logger.Error("subscribe to the topic failed", zap.Strings("topic", topics), zap.Error(err)) + logger.Error(ctx, "subscribe to the topic failed", "topic", topics, "error", err) } // start a goroutine to handle shutdown @@ -51,17 +47,17 @@ func RealTimeDataComputer(ctx context.Context, consumerConfig kafka.ConfigMap, t msg, err := consumer.ReadMessage(timeoutDuration) if err != nil { if ctx.Err() == context.Canceled { - logger.Info("context canceled, stopping read loop") + logger.Info(ctx, "context canceled, stopping read loop") break } - logger.Error("consumer read message failed", zap.Error(err)) + logger.Error(ctx, "consumer read message failed", "error", err) continue } // TODO 使用 ants.pool处理 kafka 的订阅数据 _, err = consumer.CommitMessage(msg) if err != nil { - logger.Error("manual submission information failed", zap.Any("message", msg), zap.Error(err)) + logger.Error(ctx, "manual submission information failed", "message", msg, "error", err) } } } diff --git a/real-time-data/real_time_data_receive.go b/real-time-data/real_time_data_receive.go index 2d70b5f..61bc2f5 100644 --- a/real-time-data/real_time_data_receive.go +++ b/real-time-data/real_time_data_receive.go @@ -5,13 +5,11 @@ import ( "context" "modelRT/config" - "modelRT/constant" + constants "modelRT/constant" "modelRT/diagram" "modelRT/logger" "modelRT/network" "modelRT/pool" - - "go.uber.org/zap" ) // RealTimeDataChan define channel of real time data receive @@ -23,8 +21,6 @@ func init() { // ReceiveChan define func of real time data receive and process func ReceiveChan(ctx context.Context) { - logger := logger.GetLoggerInstance() - for { select { case <-ctx.Done(): @@ -34,13 +30,13 @@ func ReceiveChan(ctx context.Context) { componentID := realTimeData.PayLoad.ComponentID component, err := diagram.GetComponentMap(componentID) if err != nil { - logger.Error("query component info from diagram map by componet id failed", zap.Int64("component_id", componentID), zap.Error(err)) + logger.Error(ctx, "query component info from diagram map by componet id failed", "component_id", componentID, "error", err) continue } componentType := component["component_type"].(int) - if componentType != constant.DemoType { - logger.Error("can not process real time data of component type not equal DemoType", zap.Int64("component_id", componentID)) + if componentType != constants.DemoType { + logger.Error(ctx, "can not process real time data of component type not equal DemoType", "component_id", componentID) continue } diff --git a/router/diagram.go b/router/diagram.go new file mode 100644 index 0000000..6bf9451 --- /dev/null +++ b/router/diagram.go @@ -0,0 +1,17 @@ +package router + +import ( + "modelRT/handler" + + "github.com/gin-gonic/gin" +) + +// RegisterRoutes define func of register diagram routes +func registerDiagramRoutes(rg *gin.RouterGroup) { + g := rg.Group("/diagram/") + // TODO add diagram middleware + g.GET("load", handler.CircuitDiagramLoadHandler) + g.POST("create", handler.CircuitDiagramCreateHandler) + g.POST("update", handler.CircuitDiagramUpdateHandler) + g.POST("delete", handler.CircuitDiagramDeleteHandler) +} diff --git a/router/router.go b/router/router.go new file mode 100644 index 0000000..3e9b3dc --- /dev/null +++ b/router/router.go @@ -0,0 +1,22 @@ +package router + +import ( + "time" + + "modelRT/middleware" + + "github.com/gin-gonic/gin" +) + +var limiter *middleware.Limiter + +func init() { + limiter = middleware.NewLimiter(10, 1*time.Minute) // 设置限流器,允许每分钟最多请求10次 +} + +func RegisterRoutes(engine *gin.Engine) { + // use global middlewares + engine.Use(middleware.StartTrace(), limiter.Middleware) + routeGroup := engine.Group("") + registerDiagramRoutes(routeGroup) +} diff --git a/sharememory/share_memeory.go b/sharememory/share_memeory.go new file mode 100644 index 0000000..6667cfe --- /dev/null +++ b/sharememory/share_memeory.go @@ -0,0 +1,98 @@ +package sharememory + +import ( + "fmt" + "unsafe" + + "modelRT/orm" + + "golang.org/x/sys/unix" +) + +// CreateShareMemory defines a function to create a shared memory +func CreateShareMemory(key uintptr, structSize uintptr) (uintptr, error) { + // logger := logger.GetLoggerInstance() + // create shared memory + shmID, _, err := unix.Syscall(unix.SYS_SHMGET, key, structSize, unix.IPC_CREAT|0o666) + if err != 0 { + // logger.Error(fmt.Sprintf("create shared memory by key %v failed:", key), zap.Error(err)) + return 0, fmt.Errorf("create shared memory failed:%w", err) + } + + // attach shared memory + shmAddr, _, err := unix.Syscall(unix.SYS_SHMAT, shmID, 0, 0) + if err != 0 { + // logger.Error(fmt.Sprintf("attach shared memory by shmID %v failed:", shmID), zap.Error(err)) + return 0, fmt.Errorf("attach shared memory failed:%w", err) + } + return shmAddr, nil +} + +// ReadComponentFromShareMemory defines a function to read component value from shared memory +func ReadComponentFromShareMemory(key uintptr, componentInfo *orm.Component) error { + structSize := unsafe.Sizeof(orm.Component{}) + shmID, _, err := unix.Syscall(unix.SYS_SHMGET, key, uintptr(int(structSize)), 0o666) + if err != 0 { + return fmt.Errorf("get shared memory failed:%w", err) + } + + shmAddr, _, err := unix.Syscall(unix.SYS_SHMAT, shmID, 0, 0) + if err != 0 { + return fmt.Errorf("attach shared memory failed:%w", err) + } + + // 读取共享内存中的数据 + componentInfo = (*orm.Component)(unsafe.Pointer(shmAddr + structSize)) + + // Detach shared memory + unix.Syscall(unix.SYS_SHMDT, shmAddr, 0, 0) + return nil +} + +func WriteComponentInShareMemory(key uintptr, componentInfo *orm.Component) error { + structSize := unsafe.Sizeof(orm.Component{}) + shmID, _, err := unix.Syscall(unix.SYS_SHMGET, key, uintptr(int(structSize)), 0o666) + if err != 0 { + return fmt.Errorf("get shared memory failed:%w", err) + } + + shmAddr, _, err := unix.Syscall(unix.SYS_SHMAT, shmID, 0, 0) + if err != 0 { + return fmt.Errorf("attach shared memory failed:%w", err) + } + + obj := (*orm.Component)(unsafe.Pointer(shmAddr + unsafe.Sizeof(structSize))) + fmt.Println(obj) + obj.ComponentType = componentInfo.ComponentType + + // id integer NOT NULL DEFAULT nextval('component_id_seq'::regclass), + // global_uuid uuid NOT NULL DEFAULT gen_random_uuid(), + // nspath character varying(32) COLLATE pg_catalog."default", + // tag character varying(32) COLLATE pg_catalog."default" NOT NULL, + // name character varying(64) COLLATE pg_catalog."default" NOT NULL, + // description character varying(512) COLLATE pg_catalog."default" NOT NULL DEFAULT ''::character varying, + // grid character varying(64) COLLATE pg_catalog."default" NOT NULL, + // zone character varying(64) COLLATE pg_catalog."default" NOT NULL, + // station character varying(64) COLLATE pg_catalog."default" NOT NULL, + // type integer NOT NULL, + // in_service boolean DEFAULT false, + // state integer NOT NULL DEFAULT 0, + // connected_bus jsonb NOT NULL DEFAULT '{}'::jsonb, + // label jsonb NOT NULL DEFAULT '{}'::jsonb, + // context jsonb NOT NULL DEFAULT '{}'::jsonb, + // page_id integer NOT NULL, + // op integer NOT NULL DEFAULT '-1'::integer, + // ts timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP, + + unix.Syscall(unix.SYS_SHMDT, shmAddr, 0, 0) + return nil +} + +// DeleteShareMemory defines a function to delete shared memory +func DeleteShareMemory(key uintptr) error { + _, _, err := unix.Syscall(unix.SYS_SHM_UNLINK, key, 0, 0o666) + if err != 0 { + return fmt.Errorf("get shared memory failed:%w", err) + } + return nil +} diff --git a/sql/topologic.go b/sql/topologic.go index 58be916..95f292d 100644 --- a/sql/topologic.go +++ b/sql/topologic.go @@ -3,15 +3,12 @@ package sql // RecursiveSQL define Topologic table recursive query statement var RecursiveSQL = `WITH RECURSIVE recursive_tree as ( - SELECT uuid_from,uuid_to,page_id,flag + SELECT uuid_from,uuid_to,flag FROM "Topologic" - WHERE uuid_from is null and page_id = ? + WHERE uuid_from = ? UNION ALL SELECT t.uuid_from,t.uuid_to,t.page_id,t.flag FROM "Topologic" t JOIN recursive_tree rt ON t.uuid_from = rt.uuid_to ) SELECT * FROM recursive_tree;` - -// TODO 为 Topologic 表增加唯一索引 -// CREATE UNIQUE INDEX uuid_from_to_page_id_idx ON public."Topologic"(uuid_from,uuid_to,page_id); diff --git a/test/distributedlock/rwlock_test.go b/test/distributedlock/rwlock_test.go new file mode 100644 index 0000000..e81936c --- /dev/null +++ b/test/distributedlock/rwlock_test.go @@ -0,0 +1,563 @@ +package distributedlock_test + +import ( + "context" + "strings" + "testing" + "time" + + dl "modelRT/distributedlock" + constants "modelRT/distributedlock/constant" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" +) + +var rdb *redis.Client + +func init() { + rdb = redis.NewClient(&redis.Options{ + Network: "tcp", + Addr: "192.168.2.104:30001", + // pool config + PoolSize: 100, // max connections + PoolFIFO: true, + PoolTimeout: 4 * time.Second, + MinIdleConns: 10, // min idle connections + MaxIdleConns: 20, // max idle connections + // tiemout config + DialTimeout: 5 * time.Second, + ReadTimeout: 3 * time.Second, + WriteTimeout: 3 * time.Second, + }) +} + +func TestRWLockRLockAndUnRLock(t *testing.T) { + ctx := context.TODO() + + rwLocker := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 120, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a", + }) + + duration := 10 * time.Second + // 第一次加读锁 + err := rwLocker.RLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey := strings.Join([]string{rwLocker.RWTokenTimeoutPrefix, rwLocker.Token}, ":") + num, err := rdb.HGet(ctx, rwLocker.Key, tokenKey).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + err = rwLocker.UnRLock(ctx) + assert.Equal(t, nil, err) + + num, err = rdb.HGet(ctx, rwLocker.Key, tokenKey).Int() + assert.Equal(t, redis.Nil, err) + assert.Equal(t, 0, num) + t.Log("rwLock rlock and unrlock test success") + return +} + +func TestRWLockReentrantRLock(t *testing.T) { + ctx := context.TODO() + + rwLocker := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 120, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a", + }) + + duration := 10 * time.Second + // 第一次加读锁 + err := rwLocker.RLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey := strings.Join([]string{rwLocker.RWTokenTimeoutPrefix, rwLocker.Token}, ":") + num, err := rdb.HGet(ctx, rwLocker.Key, tokenKey).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + // 第二次加读锁 + err = rwLocker.RLock(ctx, duration) + assert.Equal(t, nil, err) + + num, err = rdb.HGet(ctx, rwLocker.Key, tokenKey).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 2, num) + + // 第一次解读锁 + err = rwLocker.UnRLock(ctx) + assert.Equal(t, nil, err) + + num, err = rdb.HGet(ctx, rwLocker.Key, tokenKey).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + // 第二次解读锁 + err = rwLocker.UnRLock(ctx) + assert.Equal(t, nil, err) + + num, err = rdb.HGet(ctx, rwLocker.Key, tokenKey).Int() + assert.Equal(t, redis.Nil, err) + assert.Equal(t, 0, num) + t.Log("rwLock reentrant lock test success") + return +} + +func TestRWLockRefreshRLock(t *testing.T) { + ctx := context.TODO() + + rwLocker := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 10, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a", + }) + + duration := 10 * time.Second + // 第一次加读锁 + err := rwLocker.RLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey := strings.Join([]string{rwLocker.RWTokenTimeoutPrefix, rwLocker.Token}, ":") + num, err := rdb.HGet(ctx, rwLocker.Key, tokenKey).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + time.Sleep(10 * time.Second) + script := `return redis.call('httl', KEYS[1], 'fields', '1', ARGV[1]);` + result, err := rdb.Eval(ctx, script, []string{rwLocker.Key}, tokenKey).Result() + assert.Equal(t, nil, err) + // ttls, ok := result.([]interface{}) + ttls, ok := result.([]any) + assert.Equal(t, true, ok) + ttl, ok := ttls[0].(int64) + assert.Equal(t, true, ok) + compareValue := int64(8) + assert.Greater(t, ttl, compareValue) + + err = rwLocker.UnRLock(ctx) + assert.Equal(t, nil, err) + + num, err = rdb.HGet(ctx, rwLocker.Key, tokenKey).Int() + assert.Equal(t, redis.Nil, err) + assert.Equal(t, 0, num) + t.Log("rwLock refresh lock test success") + return +} + +func TestRWLock2ClientRLock(t *testing.T) { + ctx := context.TODO() + + rwLocker1 := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 120, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a", + }) + + rwLocker2 := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 120, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc5577", + }) + + duration := 10 * time.Second + // locker1加读锁 + err := rwLocker1.RLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey1 := strings.Join([]string{rwLocker1.RWTokenTimeoutPrefix, rwLocker1.Token}, ":") + num, err := rdb.HGet(ctx, rwLocker1.Key, tokenKey1).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + // locker2加读锁 + err = rwLocker2.RLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey2 := strings.Join([]string{rwLocker2.RWTokenTimeoutPrefix, rwLocker2.Token}, ":") + num, err = rdb.HGet(ctx, rwLocker2.Key, tokenKey2).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + err = rdb.HLen(ctx, rwLocker1.Key).Err() + assert.Equal(t, nil, err) + hLen := rdb.HLen(ctx, rwLocker1.Key).Val() + assert.Equal(t, int64(3), hLen) + + // locker1解读锁 + err = rwLocker1.UnRLock(ctx) + assert.Equal(t, nil, err) + + // locker2解读锁 + err = rwLocker2.UnRLock(ctx) + assert.Equal(t, nil, err) + + err = rdb.Exists(ctx, rwLocker1.Key).Err() + assert.Equal(t, nil, err) + existNum := rdb.Exists(ctx, rwLocker1.Key).Val() + assert.Equal(t, int64(0), existNum) + t.Log("rwLock 2 client lock test success") + return +} + +func TestRWLock2CWith2DifTimeRLock(t *testing.T) { + ctx := context.TODO() + + rwLocker1 := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 120, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a", + }) + + rwLocker2 := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 30, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc5577", + }) + + duration := 10 * time.Second + // locker1加读锁 + err := rwLocker1.RLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey1 := strings.Join([]string{rwLocker1.RWTokenTimeoutPrefix, rwLocker1.Token}, ":") + num, err := rdb.HGet(ctx, rwLocker1.Key, tokenKey1).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + // locker2加读锁 + err = rwLocker2.RLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey2 := strings.Join([]string{rwLocker2.RWTokenTimeoutPrefix, rwLocker2.Token}, ":") + num, err = rdb.HGet(ctx, rwLocker2.Key, tokenKey2).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + err = rdb.HLen(ctx, rwLocker1.Key).Err() + assert.Equal(t, nil, err) + hLen := rdb.HLen(ctx, rwLocker1.Key).Val() + assert.Equal(t, int64(3), hLen) + + script := `return redis.call('httl', KEYS[1], 'fields', '1', ARGV[1]);` + result, err := rdb.Eval(ctx, script, []string{rwLocker1.Key}, tokenKey1).Result() + assert.Equal(t, nil, err) + // ttls, ok := result.([]interface{}) + ttls, ok := result.([]any) + assert.Equal(t, true, ok) + ttl, ok := ttls[0].(int64) + assert.Equal(t, true, ok) + compareValue := int64(110) + assert.Greater(t, ttl, compareValue) + + // locker1解读锁 + err = rwLocker1.UnRLock(ctx) + assert.Equal(t, nil, err) + + hashTTL := rdb.TTL(ctx, rwLocker1.Key).Val().Seconds() + assert.Greater(t, hashTTL, float64(20)) + + // locker2解读锁 + err = rwLocker2.UnRLock(ctx) + assert.Equal(t, nil, err) + + err = rdb.Exists(ctx, rwLocker1.Key).Err() + assert.Equal(t, nil, err) + existNum := rdb.Exists(ctx, rwLocker1.Key).Val() + assert.Equal(t, int64(0), existNum) + t.Log("rwLock 2 client lock test success") + return +} + +func TestRWLock2CWithTimeTransformRLock(t *testing.T) { + ctx := context.TODO() + + rwLocker1 := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 30, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a", + }) + + rwLocker2 := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 120, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc5577", + }) + + duration := 10 * time.Second + // locker1加读锁 + err := rwLocker1.RLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey1 := strings.Join([]string{rwLocker1.RWTokenTimeoutPrefix, rwLocker1.Token}, ":") + num, err := rdb.HGet(ctx, rwLocker1.Key, tokenKey1).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + // locker2加读锁 + err = rwLocker2.RLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey2 := strings.Join([]string{rwLocker2.RWTokenTimeoutPrefix, rwLocker2.Token}, ":") + num, err = rdb.HGet(ctx, rwLocker2.Key, tokenKey2).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + err = rdb.HLen(ctx, rwLocker1.Key).Err() + assert.Equal(t, nil, err) + hLen := rdb.HLen(ctx, rwLocker1.Key).Val() + assert.Equal(t, int64(3), hLen) + + hashTTL := rdb.TTL(ctx, rwLocker2.Key).Val().Seconds() + assert.Greater(t, hashTTL, float64(100)) + + // locker2解读锁 + err = rwLocker2.UnRLock(ctx) + assert.Equal(t, nil, err) + + time.Sleep(10 * time.Second) + hashTTL = rdb.TTL(ctx, rwLocker1.Key).Val().Seconds() + assert.Greater(t, hashTTL, float64(15)) + + // locker1解读锁 + err = rwLocker1.UnRLock(ctx) + assert.Equal(t, nil, err) + + err = rdb.Exists(ctx, rwLocker1.Key).Err() + assert.Equal(t, nil, err) + existNum := rdb.Exists(ctx, rwLocker1.Key).Val() + assert.Equal(t, int64(0), existNum) + t.Log("rwLock 2 client lock test success") + return +} + +func TestRWLockWLockAndUnWLock(t *testing.T) { + ctx := context.TODO() + + rwLocker := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 120, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a", + }) + + duration := 10 * time.Second + // 第一次加读锁 + err := rwLocker.WLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey := strings.Join([]string{rwLocker.RWTokenTimeoutPrefix, rwLocker.Token}, ":") + num, err := rdb.HGet(ctx, rwLocker.Key, tokenKey).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + err = rwLocker.UnWLock(ctx) + assert.Equal(t, nil, err) + + num, err = rdb.HGet(ctx, rwLocker.Key, tokenKey).Int() + assert.Equal(t, redis.Nil, err) + assert.Equal(t, 0, num) + t.Log("rwLock rlock and unrlock test success") + return +} + +func TestRWLockReentrantWLock(t *testing.T) { + ctx := context.TODO() + + rwLocker := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 120, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a", + }) + + duration := 10 * time.Second + // 第一次加写锁 + err := rwLocker.WLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey := strings.Join([]string{rwLocker.RWTokenTimeoutPrefix, rwLocker.Token}, ":") + num, err := rdb.HGet(ctx, rwLocker.Key, tokenKey).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + // 第二次加写锁 + err = rwLocker.WLock(ctx, duration) + assert.Equal(t, nil, err) + + num, err = rdb.HGet(ctx, rwLocker.Key, tokenKey).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 2, num) + + // 第一次解写锁 + err = rwLocker.UnWLock(ctx) + assert.Equal(t, nil, err) + + num, err = rdb.HGet(ctx, rwLocker.Key, tokenKey).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + // 第二次解写锁 + err = rwLocker.UnWLock(ctx) + assert.Equal(t, nil, err) + + num, err = rdb.HGet(ctx, rwLocker.Key, tokenKey).Int() + assert.Equal(t, redis.Nil, err) + assert.Equal(t, 0, num) + t.Log("rwLock reentrant lock test success") + return +} + +func TestRWLock2CWithRLockAndWLockFailed(t *testing.T) { + ctx := context.TODO() + + rwLocker1 := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 120, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a", + }) + + rwLocker2 := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 30, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc5577", + }) + + duration := 10 * time.Second + // locker1加读锁 + err := rwLocker1.RLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey1 := strings.Join([]string{rwLocker1.RWTokenTimeoutPrefix, rwLocker1.Token}, ":") + num, err := rdb.HGet(ctx, rwLocker1.Key, tokenKey1).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + // locker2加写锁锁 + duration = 10 * time.Second + err = rwLocker2.WLock(ctx, duration) + assert.Equal(t, constants.AcquireTimeoutErr, err) + + err = rwLocker1.UnRLock(ctx) + assert.Equal(t, nil, err) + + t.Log("rwLock 2 client lock test success") + return +} + +func TestRWLock2CWithRLockAndWLockSucceed(t *testing.T) { + ctx := context.TODO() + rwLocker1 := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 120, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a", + }) + + rwLocker2 := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 120, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc5577", + }) + + duration := 10 * time.Second + // locker1加读锁 + err := rwLocker1.RLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey1 := strings.Join([]string{rwLocker1.RWTokenTimeoutPrefix, rwLocker1.Token}, ":") + num, err := rdb.HGet(ctx, rwLocker1.Key, tokenKey1).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + go func() { + // locker1解写锁 + time.Sleep(10 * time.Second) + err = rwLocker1.UnRLock(ctx) + assert.Equal(t, nil, err) + }() + + // locker2加写锁 + duration = 30 * time.Second + err = rwLocker2.WLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey2 := strings.Join([]string{rwLocker2.RWTokenTimeoutPrefix, rwLocker2.Token}, ":") + num, err = rdb.HGet(ctx, rwLocker2.Key, tokenKey2).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + // locker2解写锁 + err = rwLocker2.UnWLock(ctx) + assert.Equal(t, nil, err) + + t.Log("rwLock 2 client lock test success") + return +} + +func TestRWLock2CWithWLockAndRLock(t *testing.T) { + ctx := context.TODO() + + rwLocker1 := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 120, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a", + }) + + rwLocker2 := dl.GetRWLocker(rdb, &dl.RedissionLockConfig{ + LockLeaseTime: 30, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc5577", + }) + + duration := 10 * time.Second + // locker1加写锁 + err := rwLocker1.WLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey1 := strings.Join([]string{rwLocker1.RWTokenTimeoutPrefix, rwLocker1.Token}, ":") + num, err := rdb.HGet(ctx, rwLocker1.Key, tokenKey1).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + go func() { + // locker1解写锁 + time.Sleep(10 * time.Second) + err = rwLocker1.UnWLock(ctx) + assert.Equal(t, nil, err) + }() + + // locker2加读锁 + duration = 30 * time.Second + err = rwLocker2.RLock(ctx, duration) + assert.Equal(t, nil, err) + + tokenKey2 := strings.Join([]string{rwLocker2.RWTokenTimeoutPrefix, rwLocker2.Token}, ":") + num, err = rdb.HGet(ctx, rwLocker2.Key, tokenKey2).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + // locker2解读锁 + err = rwLocker2.UnRLock(ctx) + assert.Equal(t, nil, err) + + t.Log("rwLock 2 client lock test success") + return +} diff --git a/test/orm/topologic_test.go b/test/orm/topologic_test.go new file mode 100644 index 0000000..f5443fc --- /dev/null +++ b/test/orm/topologic_test.go @@ -0,0 +1,68 @@ +package orm_test + +import ( + "context" + "database/sql" + "os" + "regexp" + "testing" + + "modelRT/database" + "modelRT/network" + "modelRT/orm" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +var ( + mock sqlmock.Sqlmock + err error + baseDB *sql.DB + pgClient *gorm.DB +) + +func TestMain(m *testing.M) { + baseDB, mock, err = sqlmock.New() + if err != nil { + panic(err) + } + // 把项目使用的DB连接换成sqlmock的DB连接 + pgClient, _ = gorm.Open(mysql.New(mysql.Config{ + Conn: baseDB, + SkipInitializeWithVersion: true, + DefaultStringSize: 0, + })) + os.Exit(m.Run()) +} + +func TestUserDao_CreateUser(t *testing.T) { + topologicInfo := &orm.Topologic{ + UUIDFrom: uuid.FromStringOrNil("70c190f2-8a60-42a9-b143-ec5f87e0aa6b"), + UUIDTo: uuid.FromStringOrNil("70c190f2-8a75-42a9-b166-ec5f87e0aa6b"), + Comment: "test", + Flag: 1, + } + + // ud := dao2.NewUserDao(context.TODO()) + mock.ExpectBegin() + mock.ExpectExec(regexp.QuoteMeta("INSERT INTO `Topologic`")). + WithArgs(topologicInfo.Flag, topologicInfo.UUIDFrom, topologicInfo.UUIDTo, topologicInfo.Comment). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + err := database.CreateTopologicIntoDB(context.TODO(), pgClient, 1, []network.TopologicUUIDCreateInfo{ + { + UUIDFrom: uuid.FromStringOrNil("70c190f2-8a60-42a9-b143-ec5f87e0aa6b"), + UUIDTo: uuid.FromStringOrNil("70c190f2-8a75-42a9-b166-ec5f87e0aa6b"), + Comment: "test", + Flag: 1, + }, + }) + assert.Nil(t, err) + err = mock.ExpectationsWereMet() + assert.Nil(t, err) +} diff --git a/util/redis_init.go b/util/redis_init.go new file mode 100644 index 0000000..441f4c4 --- /dev/null +++ b/util/redis_init.go @@ -0,0 +1,36 @@ +package util + +import ( + "context" + "fmt" + + "github.com/redis/go-redis/v9" +) + +// NewRedisClient define func of initialize the Redis client +func NewRedisClient(addr string, opts ...RedisOption) (*redis.Client, error) { + // default options + options := RedisOptions{ + redisOptions: &redis.Options{ + Addr: addr, + }, + } + + // Apply configuration options from config + for _, opt := range opts { + opt(&options) + } + + // create redis client + client := redis.NewClient(options.redisOptions) + + if options.timeout > 0 { + // check if the connection is successful + ctx, cancel := context.WithTimeout(context.Background(), options.timeout) + defer cancel() + if err := client.Ping(ctx).Err(); err != nil { + return nil, fmt.Errorf("can not connect redis:%v", err) + } + } + return client, nil +} diff --git a/util/redis_options.go b/util/redis_options.go new file mode 100644 index 0000000..23ed270 --- /dev/null +++ b/util/redis_options.go @@ -0,0 +1,59 @@ +package util + +import ( + "errors" + "time" + + "github.com/redis/go-redis/v9" +) + +type RedisOptions struct { + redisOptions *redis.Options + timeout time.Duration +} + +type RedisOption func(*RedisOptions) error + +// WithPassword define func of configure redis password options +func WithPassword(password string) RedisOption { + return func(o *RedisOptions) error { + if password == "" { + return errors.New("password is empty") + } + o.redisOptions.Password = password + return nil + } +} + +// WithTimeout define func of configure redis timeout options +func WithTimeout(timeout time.Duration) RedisOption { + return func(o *RedisOptions) error { + if timeout < 0 { + return errors.New("timeout can not be negative") + } + o.timeout = timeout + return nil + } +} + +// WithDB define func of configure redis db options +func WithDB(db int) RedisOption { + return func(o *RedisOptions) error { + if db < 0 { + return errors.New("db can not be negative") + } + o.redisOptions.DB = db + return nil + } +} + +// WithPoolSize define func of configure pool size options +func WithPoolSize(poolSize int) RedisOption { + return func(o *RedisOptions) error { + if poolSize <= 0 { + return errors.New("pool size must be greater than 0") + } + o.redisOptions.PoolSize = poolSize + return nil + } +} diff --git a/util/trace.go b/util/trace.go new file mode 100644 index 0000000..eefe769 --- /dev/null +++ b/util/trace.go @@ -0,0 +1,45 @@ +package util + +import ( + "context" + "encoding/binary" + "math/rand" + "net" + "strconv" + "strings" + "time" +) + +// GenerateSpanID define func of generate spanID +func GenerateSpanID(addr string) string { + strAddr := strings.Split(addr, ":") + ip := strAddr[0] + ipLong, _ := IP2Long(ip) + times := uint64(time.Now().UnixNano()) + rand.NewSource(time.Now().UnixNano()) + spanID := ((times ^ uint64(ipLong)) << 32) | uint64(rand.Int31()) + return strconv.FormatUint(spanID, 16) +} + +// IP2Long define func of convert ip to unit32 type +func IP2Long(ip string) (uint32, error) { + ipAddr, err := net.ResolveIPAddr("ip", ip) + if err != nil { + return 0, err + } + return binary.BigEndian.Uint32(ipAddr.IP.To4()), nil +} + +// GetTraceInfoFromCtx define func of get trace info from context +func GetTraceInfoFromCtx(ctx context.Context) (traceID, spanID, pSpanID string) { + if ctx.Value("traceid") != nil { + traceID = ctx.Value("traceid").(string) + } + if ctx.Value("spanid") != nil { + spanID = ctx.Value("spanid").(string) + } + if ctx.Value("pspanid") != nil { + pSpanID = ctx.Value("pspanid").(string) + } + return +}