// Package distributedlock provides a Redis-backed distributed lock whose
// acquire, refresh and release steps are executed through Lua scripts.
package distributedlock

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"

	"modelRT/distributedlock/constant"
	luascript "modelRT/distributedlock/luascript"
	"modelRT/logger"

	uuid "github.com/gofrs/uuid"
	"github.com/redis/go-redis/v9"
	"go.uber.org/zap"
)

const (
	// internalLockLeaseTime is the default lock lease time, in seconds, used
	// by GetLocker when the config does not specify one.
	internalLockLeaseTime = uint64(30)
	// unlockMessage is passed to the unlock script as ARGV[1].
	unlockMessage = 0
)

// RedissionLockConfig define redission lock config
type RedissionLockConfig struct {
	// LockLeaseTime is the lock lease time in seconds; 0 selects the default.
	LockLeaseTime uint64
	// Token identifies the lock owner; a random UUID is generated when empty.
	Token string
	// Prefix is prepended to Key to build the Redis lock key.
	Prefix string
	// ChanPrefix is prepended to Key to build the unlock notification channel.
	ChanPrefix string
	// TimeoutPrefix is not read anywhere in this file.
	TimeoutPrefix string
	// Key is the caller-chosen lock name.
	Key string
	// NeedRefresh enables the background lease refresh after a successful Lock.
	NeedRefresh bool
}

// redissionLocker is a single lock handle bound to one key and one token.
// Its fields are not guarded by a mutex; do not share one instance between
// goroutines that call Lock/UnLock concurrently.
type redissionLocker struct {
	lockLeaseTime uint64
	token         string
	key           string
	waitChanKey   string
	needRefresh   bool
	// exit is closed on unlock to stop background goroutines.
	exit   chan struct{}
	client *redis.Client
	// once ensures at most one refresh goroutine per lock/unlock cycle.
	once   *sync.Once
	logger *zap.Logger
}

// Lock tries to acquire the lock.
//
// Without a positive timeout the call makes one attempt and returns an error
// if the lock is not obtained. With a positive timeout[0] it subscribes to the
// unlock notification channel and retries on every notification until the
// lock is obtained, the timeout elapses, or ctx is done.
//
// When needRefresh is set, a successful acquisition starts a background
// goroutine that periodically extends the lease. That goroutine uses ctx, so
// it stops once ctx is cancelled.
func (rl *redissionLocker) Lock(ctx context.Context, timeout ...time.Duration) error {
	rl.ensureExitChan()

	result := rl.tryLock(ctx).(*constant.RedisResult)
	if result.Code == constant.UnknownInternalError {
		rl.logger.Error(result.OutputResultMessage())
		return fmt.Errorf("get lock failed:%w", result)
	}
	if result.Code == constant.LockSuccess {
		// Success is success regardless of needRefresh; only the refresh
		// goroutine is conditional.
		rl.startRefresh(ctx)
		return nil
	}

	// No waiting requested: fail fast without opening a subscription.
	if len(timeout) == 0 || timeout[0] <= 0 {
		return fmt.Errorf("lock the redis lock failed:%w", result)
	}

	// subCtx stops the subscriber goroutine as soon as Lock returns.
	subCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	sub := rl.client.Subscribe(ctx, rl.waitChanKey)
	defer sub.Close()

	// subMsg is intentionally never closed: the subscriber goroutine is the
	// sender and may still be running when Lock returns.
	subMsg := make(chan struct{}, 1)
	// Prime one retry so a release that happened between the first attempt
	// and the subscription is not missed.
	subMsg <- struct{}{}
	go rl.subscribeLock(subCtx, sub, subMsg)

	acquireTimer := time.NewTimer(timeout[0])
	defer acquireTimer.Stop()

	for {
		select {
		case <-subMsg:
			retry := rl.tryLock(ctx).(*constant.RedisResult)
			switch retry.Code {
			case constant.LockSuccess:
				rl.logger.Info(retry.OutputResultMessage())
				rl.startRefresh(ctx)
				return nil
			case constant.LockFailure, constant.UnknownInternalError:
				// Keep waiting for the next unlock notification.
				rl.logger.Info(retry.OutputResultMessage())
			}
		case <-acquireTimer.C:
			err := errors.New("the waiting time for obtaining the lock operation has timed out")
			rl.logger.Info("the waiting time for obtaining the lock operation has timed out")
			return err
		case <-ctx.Done():
			return fmt.Errorf("lock the redis lock failed:%w", ctx.Err())
		}
	}
}

// ensureExitChan makes sure rl.exit is an open channel. A channel closed by a
// previous unlock is replaced so that a re-acquired lock can be refreshed.
func (rl *redissionLocker) ensureExitChan() {
	if rl.exit != nil {
		select {
		case <-rl.exit:
			// Nothing ever sends on exit, so a successful receive means it
			// was closed; fall through and recreate it.
		default:
			return
		}
	}
	rl.exit = make(chan struct{})
}

// startRefresh launches the lease refresh goroutine at most once per
// lock/unlock cycle. It is a no-op when needRefresh is false.
func (rl *redissionLocker) startRefresh(ctx context.Context) {
	if !rl.needRefresh {
		return
	}
	if rl.once == nil {
		rl.once = &sync.Once{}
	}
	rl.once.Do(func() {
		// async refresh lock timeout until receive exit signal
		go rl.refreshLockTimeout(ctx)
	})
}

// subscribeLock forwards every data message received on sub as a signal on
// out. It returns when rl.exit is closed, ctx is done, or sub is closed.
// Callers must not close out while this goroutine may still be running.
//
// TODO: optimize the subscription flow (e.g. evaluate sub.Channel()).
func (rl *redissionLocker) subscribeLock(ctx context.Context, sub *redis.PubSub, out chan struct{}) {
	if sub == nil || out == nil {
		return
	}
	rl.logger.Info("lock: enter sub routine", zap.String("token", rl.token))

	// Capture the channel once instead of re-reading the field every loop.
	exit := rl.exit

	// done tells the receiver goroutine that nobody reads receiveChan anymore.
	done := make(chan struct{})
	defer close(done)

	receiveChan := make(chan any, 2)
	go func() {
		defer close(receiveChan)
		for {
			msg, err := sub.Receive(ctx)
			if err != nil {
				if errors.Is(err, redis.ErrClosed) || ctx.Err() != nil {
					return
				}
				rl.logger.Error("sub receive message failed", zap.Error(err))
				continue
			}
			rl.logger.Info("sub receive message", zap.Any("msg", msg))
			select {
			case receiveChan <- msg:
			case <-done:
				return
			}
		}
	}()

	for {
		select {
		case <-exit:
			return
		case <-ctx.Done():
			return
		case msg, ok := <-receiveChan:
			if !ok {
				// Receiver goroutine ended (subscription closed).
				return
			}
			// Subscription confirmations, pongs and unknown types are ignored.
			if _, isMessage := msg.(*redis.Message); !isMessage {
				continue
			}
			select {
			case out <- struct{}{}:
			case <-exit:
				return
			case <-ctx.Done():
				return
			}
		}
	}
}

// refreshLockTimeout periodically extends the lock lease until rl.exit is
// closed, ctx is done, or a refresh attempt fails. The interval is one third
// of the lease time.
//
// Script arguments:
// KEYS[1]: the lock key, usually the unique identifier of the lock.
// ARGV[1]: the lock lease time (lockLeaseTime), in seconds.
// ARGV[2]: the unique token of the current client, used to tell clients apart.
func (rl *redissionLocker) refreshLockTimeout(ctx context.Context) {
	rl.logger.Info("lock refresh by key and token", zap.String("token", rl.token), zap.String("key", rl.key))

	// Divide after converting to a Duration so leases shorter than 3s do not
	// truncate to a zero interval.
	interval := time.Duration(rl.lockLeaseTime) * time.Second / 3
	if interval <= 0 {
		rl.logger.Error("lock refresh skipped, lock lease time is not positive", zap.String("token", rl.token), zap.String("key", rl.key))
		return
	}

	exit := rl.exit
	timer := time.NewTimer(interval)
	defer timer.Stop()

	for {
		select {
		case <-timer.C:
			// extend key lease time
			res := rl.client.Eval(ctx, luascript.RefreshLockScript, []string{rl.key}, rl.lockLeaseTime, rl.token)
			val, err := res.Int()
			if err != nil && !errors.Is(err, redis.Nil) {
				rl.logger.Info("lock refresh failed", zap.String("token", rl.token), zap.String("key", rl.key), zap.Error(err))
				return
			}
			switch constant.RedisCode(val) {
			case constant.RefreshLockFailure:
				// The lock is no longer held with this token; stop refreshing.
				rl.logger.Error("lock refreash failed,can not find the lock by key and token", zap.String("token", rl.token), zap.String("key", rl.key))
				return
			case constant.RefreshLockSuccess:
				rl.logger.Info("lock refresh success by key and token", zap.String("token", rl.token), zap.String("key", rl.key))
			}
			timer.Reset(interval)
		case <-exit:
			return
		case <-ctx.Done():
			return
		}
	}
}

// cancelRefreshLockTime signals background goroutines to stop by closing
// rl.exit and re-arms the refresh guard for the next Lock. Calling it more
// than once is safe.
func (rl *redissionLocker) cancelRefreshLockTime() {
	if rl.exit == nil {
		return
	}
	select {
	case <-rl.exit:
		// Already closed by an earlier call.
	default:
		close(rl.exit)
	}
	rl.once = &sync.Once{}
}

// tryLock runs the lock script once and reports the outcome.
//
// The returned error is always a non-nil *constant.RedisResult, including on
// success; callers inspect its Code rather than comparing against nil.
//
// Script arguments:
// KEYS[1]: the lock key, usually the unique identifier of the lock.
// ARGV[1]: the lock lease time (lockLeaseTime), in seconds.
// ARGV[2]: the unique token of the current client, used to tell clients apart.
func (rl *redissionLocker) tryLock(ctx context.Context) error {
	lockType := constant.LockType
	res := rl.client.Eval(ctx, luascript.LockScript, []string{rl.key}, rl.lockLeaseTime, rl.token)
	val, err := res.Int()
	if err != nil && !errors.Is(err, redis.Nil) {
		return constant.NewRedisResult(constant.UnknownInternalError, lockType, err.Error())
	}
	return constant.NewRedisResult(constant.RedisCode(val), lockType, "")
}

// UnLock releases the lock held with this locker's token and, when
// needRefresh is set, stops the refresh goroutine.
//
// It returns an error when the script call fails or when the script reports
// UnLocakFailureWithLockOccupancy. Any other result code returns nil.
//
// Script arguments:
// KEYS[1]: the lock key, usually the unique identifier of the lock.
// KEYS[2]: the release notification channel (chankey), used to notify other clients that the lock was released.
// ARGV[1]: the unlock message (unlockMessage) published to other clients.
// ARGV[2]: the unique token of the current client, used to tell clients apart.
func (rl *redissionLocker) UnLock(ctx context.Context) error {
	res := rl.client.Eval(ctx, luascript.UnLockScript, []string{rl.key, rl.waitChanKey}, unlockMessage, rl.token)
	val, err := res.Int()
	if err != nil && !errors.Is(err, redis.Nil) {
		rl.logger.Info("unlock lock failed", zap.String("token", rl.token), zap.String("key", rl.key), zap.Error(err))
		return fmt.Errorf("unlock lock failed:%w", constant.NewRedisResult(constant.UnknownInternalError, constant.UnLockType, err.Error()))
	}

	switch constant.RedisCode(val) {
	case constant.UnLockSuccess:
		if rl.needRefresh {
			rl.cancelRefreshLockTime()
		}
		rl.logger.Info("unlock lock success", zap.String("token", rl.token), zap.String("key", rl.key))
		return nil
	case constant.UnLocakFailureWithLockOccupancy:
		rl.logger.Info("unlock lock failed", zap.String("token", rl.token), zap.String("key", rl.key))
		return fmt.Errorf("unlock lock failed:%w", constant.NewRedisResult(constant.UnLocakFailureWithLockOccupancy, constant.UnLockType, ""))
	default:
		// NOTE(review): other codes are still treated as success to keep the
		// existing contract; confirm against the unlock script.
		rl.logger.Warn("unlock lock returned unexpected code", zap.String("token", rl.token), zap.String("key", rl.key), zap.Int("code", val))
		return nil
	}
}

// GetLocker builds a locker for ops.Key on the given client.
//
// Missing config values are filled with defaults and written back into ops,
// so the caller can read the generated Token afterwards. A nil ops is treated
// as an empty config.
//
// TODO: avoid panicking when UUID generation fails.
func GetLocker(client *redis.Client, ops *RedissionLockConfig) *redissionLocker {
	if ops == nil {
		ops = &RedissionLockConfig{}
	}

	if ops.Token == "" {
		token, err := uuid.NewV4()
		if err != nil {
			panic(err)
		}
		ops.Token = token.String()
	}
	if ops.Prefix == "" {
		ops.Prefix = "redission-lock"
	}
	if ops.ChanPrefix == "" {
		ops.ChanPrefix = "redission-lock-channel"
	}
	if ops.LockLeaseTime == 0 {
		ops.LockLeaseTime = internalLockLeaseTime
	}

	return &redissionLocker{
		// The lease time must reach the scripts; leaving it at zero sends a
		// zero lease to Redis.
		lockLeaseTime: ops.LockLeaseTime,
		token:         ops.Token,
		key:           strings.Join([]string{ops.Prefix, ops.Key}, ":"),
		waitChanKey:   strings.Join([]string{ops.ChanPrefix, ops.Key, "wait"}, ":"),
		needRefresh:   ops.NeedRefresh,
		client:        client,
		exit:          make(chan struct{}),
		once:          &sync.Once{},
		logger:        logger.GetLoggerInstance(),
	}
}