package distributedlock import ( "context" "errors" "fmt" "strings" "sync" "time" constants "modelRT/distributedlock/constant" luascript "modelRT/distributedlock/luascript" "modelRT/logger" uuid "github.com/gofrs/uuid" "github.com/redis/go-redis/v9" "go.uber.org/zap" ) const ( internalLockLeaseTime = uint64(30 * 1000) unlockMessage = 0 ) // RedissionLockConfig define redission lock config type RedissionLockConfig struct { LockLeaseTime uint64 Token string Prefix string ChanPrefix string TimeoutPrefix string Key string NeedRefresh bool } type redissionLocker struct { lockLeaseTime uint64 Token string Key string waitChanKey string needRefresh bool refreshExitChan chan struct{} subExitChan chan struct{} client *redis.Client refreshOnce *sync.Once } func (rl *redissionLocker) Lock(ctx context.Context, timeout ...time.Duration) error { if rl.refreshExitChan == nil { rl.refreshExitChan = make(chan struct{}) } result := rl.tryLock(ctx).(*constants.RedisResult) if result.Code == constants.UnknownInternalError { logger.Error(ctx, result.OutputResultMessage()) return fmt.Errorf("get lock failed:%w", result) } if (result.Code == constants.LockSuccess) && rl.needRefresh { rl.refreshOnce.Do(func() { // async refresh lock timeout unitl receive exit singal go rl.refreshLockTimeout(ctx) }) return nil } subMsg := make(chan struct{}, 1) defer close(subMsg) sub := rl.client.Subscribe(ctx, rl.waitChanKey) defer sub.Close() go rl.subscribeLock(ctx, sub, subMsg) if len(timeout) > 0 && timeout[0] > 0 { acquireTimer := time.NewTimer(timeout[0]) for { select { case _, ok := <-subMsg: if !ok { err := errors.New("failed to read the lock waiting for for the channel message") logger.Error(ctx, "failed to read the lock waiting for for the channel message") return err } resultErr := rl.tryLock(ctx).(*constants.RedisResult) if (resultErr.Code == constants.LockFailure) || (resultErr.Code == constants.UnknownInternalError) { logger.Info(ctx, resultErr.OutputResultMessage()) continue } if resultErr.Code == constants.LockSuccess { logger.Info(ctx, resultErr.OutputResultMessage()) return nil } case <-acquireTimer.C: err := errors.New("the waiting time for obtaining the lock operation has timed out") logger.Info(ctx, "the waiting time for obtaining the lock operation has timed out") return err } } } return fmt.Errorf("lock the redis lock failed:%w", result) } func (rl *redissionLocker) subscribeLock(ctx context.Context, sub *redis.PubSub, subMsgChan chan struct{}) { if sub == nil || subMsgChan == nil { return } logger.Info(ctx, "lock: enter sub routine", zap.String("token", rl.Token)) for { select { case <-rl.subExitChan: close(subMsgChan) return case <-sub.Channel(): // 这里只会收到真正的数据消息 subMsgChan <- struct{}{} default: } } } /* KEYS[1]:锁的键名(key),通常是锁的唯一标识。 ARGV[1]:锁的过期时间(lockLeaseTime),单位为秒。 ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 */ func (rl *redissionLocker) refreshLockTimeout(ctx context.Context) { logger.Info(ctx, "lock refresh by key and token", zap.String("token", rl.Token), zap.String("key", rl.Key)) lockTime := time.Duration(rl.lockLeaseTime/3) * time.Millisecond timer := time.NewTimer(lockTime) defer timer.Stop() for { select { case <-timer.C: // extend key lease time res := rl.client.Eval(ctx, luascript.RefreshLockScript, []string{rl.Key}, rl.lockLeaseTime, rl.Token) val, err := res.Int() if err != redis.Nil && err != nil { logger.Info(ctx, "lock refresh failed", "token", rl.Token, "key", rl.Key, "error", err) return } if constants.RedisCode(val) == constants.RefreshLockFailure { logger.Error(ctx, "lock refreash failed,can not find the lock by key and token", "token", rl.Token, "key", rl.Key) break } if constants.RedisCode(val) == constants.RefreshLockSuccess { logger.Info(ctx, "lock refresh success by key and token", "token", rl.Token, "key", rl.Key) } timer.Reset(lockTime) case <-rl.refreshExitChan: return } } } func (rl *redissionLocker) cancelRefreshLockTime() { if rl.refreshExitChan != nil { close(rl.refreshExitChan) rl.refreshOnce = &sync.Once{} } } func (rl *redissionLocker) closeSub(ctx context.Context, sub *redis.PubSub, noticeChan chan struct{}) { if sub != nil { err := sub.Close() if err != nil { logger.Error(ctx, "close sub failed", "token", rl.Token, "key", rl.Key, "error", err) } } if noticeChan != nil { close(noticeChan) } } /* KEYS[1]:锁的键名(key),通常是锁的唯一标识。 ARGV[1]:锁的过期时间(lockLeaseTime),单位为秒。 ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 */ func (rl *redissionLocker) tryLock(ctx context.Context) error { lockType := constants.LockType res := rl.client.Eval(ctx, luascript.LockScript, []string{rl.Key}, rl.lockLeaseTime, rl.Token) val, err := res.Int() if err != redis.Nil && err != nil { return constants.NewRedisResult(constants.UnknownInternalError, lockType, err.Error()) } return constants.NewRedisResult(constants.RedisCode(val), lockType, "") } /* KEYS[1]:锁的键名(key),通常是锁的唯一标识。 KEYS[2]:锁的释放通知频道(chankey),用于通知其他客户端锁已释放。 ARGV[1]:解锁消息(unlockMessage),用于通知其他客户端锁已释放。 ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 */ func (rl *redissionLocker) UnLock(ctx context.Context) error { res := rl.client.Eval(ctx, luascript.UnLockScript, []string{rl.Key, rl.waitChanKey}, unlockMessage, rl.Token) val, err := res.Int() if err != redis.Nil && err != nil { logger.Info(ctx, "unlock lock failed", zap.String("token", rl.Token), zap.String("key", rl.Key), zap.Error(err)) return fmt.Errorf("unlock lock failed:%w", constants.NewRedisResult(constants.UnknownInternalError, constants.UnLockType, err.Error())) } if constants.RedisCode(val) == constants.UnLockSuccess { if rl.needRefresh { rl.cancelRefreshLockTime() } logger.Info(ctx, "unlock lock success", zap.String("token", rl.Token), zap.String("key", rl.Key)) return nil } if constants.RedisCode(val) == constants.UnLocakFailureWithLockOccupancy { logger.Info(ctx, "unlock lock failed", zap.String("token", rl.Token), zap.String("key", rl.Key)) return fmt.Errorf("unlock lock failed:%w", constants.NewRedisResult(constants.UnLocakFailureWithLockOccupancy, constants.UnLockType, "")) } return nil } // TODO 优化 panic func GetLocker(client *redis.Client, ops *RedissionLockConfig) *redissionLocker { if ops.Token == "" { token, err := uuid.NewV4() if err != nil { panic(err) } ops.Token = token.String() } if len(ops.Prefix) <= 0 { ops.Prefix = "redission-lock" } if len(ops.ChanPrefix) <= 0 { ops.ChanPrefix = "redission-lock-channel" } if ops.LockLeaseTime == 0 { ops.LockLeaseTime = internalLockLeaseTime } r := &redissionLocker{ Token: ops.Token, Key: strings.Join([]string{ops.Prefix, ops.Key}, ":"), waitChanKey: strings.Join([]string{ops.ChanPrefix, ops.Key, "wait"}, ":"), needRefresh: ops.NeedRefresh, client: client, refreshExitChan: make(chan struct{}), } return r }