// Package distributed_lock implements a re-entrant distributed lock backed by
// a Redis hash, three Lua scripts and a pub/sub channel used to wake waiters.
package distributed_lock

import (
	"context"
	"fmt"
	"strings"
	"sync"
	"time"

	// Retained as a blank import: the only reference to this package was a
	// debug print that has been removed. Drop it if nothing else needs it.
	_ "modelRT/distributedlock/luascript"

	"github.com/go-redis/redis"
	uuid "github.com/google/uuid"
	"go.uber.org/zap"
)

// lockScript tries to take (or re-enter) the lock.
//
//	KEYS[1] lock key
//	ARGV[1] lease time in milliseconds
//	ARGV[2] owner token
//
// If the key does not exist it creates the hash with the token's counter set
// to 1; if the hash already has the token it increments the counter. In both
// cases the lease is (re)applied and nil is returned. Otherwise the script
// returns the key's remaining PTTL.
var lockScript = strings.Join([]string{
	"if (redis.call('exists', KEYS[1]) == 0) then ",
	"redis.call('hset', KEYS[1], ARGV[2], 1); ",
	"redis.call('pexpire', KEYS[1], ARGV[1]); ",
	"return nil; ",
	"end; ",
	"if (redis.call('hexists', KEYS[1], ARGV[2]) == 1) then ",
	"redis.call('hincrby', KEYS[1], ARGV[2], 1); ",
	"redis.call('pexpire', KEYS[1], ARGV[1]); ",
	"return nil; ",
	"end; ",
	"return redis.call('pttl', KEYS[1]);",
}, "")

// refreshLockScript re-applies the lease (ARGV[1], milliseconds) to KEYS[1]
// when the hash still contains the owner token ARGV[2]. It returns 1 when the
// lease was extended and 0 when the token is no longer present.
var refreshLockScript = strings.Join([]string{
	"if (redis.call('hexists', KEYS[1], ARGV[2]) == 1) then ",
	"redis.call('pexpire', KEYS[1], ARGV[1]); ",
	"return 1; ",
	"end; ",
	"return 0;",
}, "")

// unlockScript releases one level of the lock.
//
//	KEYS[1] lock key
//	KEYS[2] pub/sub channel
//	ARGV[1] message to publish
//	ARGV[2] lease time in milliseconds
//	ARGV[3] owner token
//
// Returns 1 when the key is gone after the call (it was already missing, or
// the counter reached zero and the key was deleted; a message is published in
// both cases), 0 when the counter is still positive (the lease is re-applied),
// and nil when the hash exists but does not contain the token.
var unlockScript = strings.Join([]string{
	"if (redis.call('exists', KEYS[1]) == 0) then ",
	"redis.call('publish', KEYS[2], ARGV[1]); ",
	"return 1; ",
	"end;",
	"if (redis.call('hexists', KEYS[1], ARGV[3]) == 0) then ",
	"return nil;",
	"end; ",
	"local counter = redis.call('hincrby', KEYS[1], ARGV[3], -1); ",
	"if (counter > 0) then ",
	"redis.call('pexpire', KEYS[1], ARGV[2]); ",
	"return 0; ",
	"else ",
	"redis.call('del', KEYS[1]); ",
	"redis.call('publish', KEYS[2], ARGV[1]); ",
	"return 1; ",
	"end; ",
	"return nil;",
}, "")

const (
	// internalLockLeaseTime is the default lease, in milliseconds.
	internalLockLeaseTime = uint64(30) * 1000
	// unlockMessage is the payload published on the channel by unlockScript.
	unlockMessage = 0
	// fallbackRetryInterval is how long a waiter sleeps before retrying when
	// the lock script reports a non-positive remaining TTL for a key that is
	// held by someone else (for example a key with no expiry).
	fallbackRetryInterval = 100 * time.Millisecond
)

// RedissionLockConfig configures a locker created by GetLocker.
type RedissionLockConfig struct {
	// LockLeaseTime is the lease applied to the lock key. Values below one
	// millisecond (including zero) select the built-in default.
	LockLeaseTime time.Duration
	// Prefix is joined with Key by ":" to form the lock key.
	// Empty selects "redission-lock".
	Prefix string
	// ChanPrefix is joined with Key by ":" to form the pub/sub channel name.
	// Empty selects "redission-lock-channel".
	ChanPrefix string
	// Key identifies the resource being locked.
	Key string
}

// redissionLocker is one lock owner, identified by token, for a single key.
type redissionLocker struct {
	token         string        // owner token stored as the hash field
	key           string        // Redis key of the lock hash
	chankey       string        // pub/sub channel that unlock publishes to
	exit          chan struct{} // closed to stop the background goroutines
	lockLeaseTime uint64        // lease in milliseconds, passed to the scripts
	client        *redis.Client
	once          *sync.Once // guards starting the lease-refresh goroutine
	logger        *zap.Logger
}

// Lock blocks until the lock is acquired, retrying whenever an unlock message
// arrives on the channel or the holder's remaining TTL elapses.
//
// If a positive timeout is passed as the first variadic argument, Lock gives
// up once it elapses and returns WITHOUT holding the lock.
// NOTE(review): the method has no return value, so callers cannot tell that
// case apart from success — consider adding a result in a follow-up.
//
// Lock panics if the lock script fails or if ctx is done while waiting.
func (rl *redissionLocker) Lock(ctx context.Context, timeout ...time.Duration) {
	if rl.exit == nil {
		rl.exit = make(chan struct{})
	}

	ttl, acquired, err := rl.tryLock()
	if err != nil {
		panic(err)
	}
	if acquired {
		rl.startRefresh()
		return
	}

	// The subscriber goroutine is the only sender on submsg, so the channel is
	// deliberately never closed here; closing the subscription below is what
	// makes that goroutine exit.
	submsg := make(chan struct{}, 1)
	sub := rl.client.Subscribe(rl.chankey)
	defer sub.Close()
	go rl.subscribeLock(sub, submsg)

	timer := time.NewTimer(retryDelay(ttl))
	defer timer.Stop()

	// giveUp stays nil when no timeout was requested; receiving from a nil
	// channel blocks forever, so that select case is simply never chosen.
	var giveUp <-chan time.Time
	if len(timeout) > 0 && timeout[0] > 0 {
		outimer := time.NewTimer(timeout[0])
		defer outimer.Stop()
		giveUp = outimer.C
	}

	for {
		ttl, acquired, err = rl.tryLock()
		if err != nil {
			panic(err)
		}
		if acquired {
			rl.startRefresh()
			return
		}

		select {
		case <-submsg:
			// The holder released the lock: retry right away and re-arm the
			// timer with the most recent TTL.
			if !timer.Stop() {
				// Non-blocking drain so a pending tick cannot leak into the
				// next wait.
				select {
				case <-timer.C:
				default:
				}
			}
			timer.Reset(retryDelay(ttl))
		case <-ctx.Done():
			panic("lock context already release")
		case <-timer.C:
			timer.Reset(retryDelay(ttl))
		case <-giveUp:
			return
		}
	}
}

// startRefresh launches the lease-refresh goroutine at most once per
// acquisition cycle (the Once is replaced by cancelRefreshLockTime).
func (rl *redissionLocker) startRefresh() {
	rl.once.Do(func() {
		go rl.refreshLockTimeout()
	})
}

// retryDelay converts the TTL reported by the lock script into a wait time,
// substituting fallbackRetryInterval when the TTL is not positive.
func retryDelay(ttl time.Duration) time.Duration {
	if ttl <= 0 {
		return fallbackRetryInterval
	}
	return ttl
}

// subscribeLock forwards messages received on sub to out as wake-up signals.
// It returns when Receive fails (for example after the subscription is
// closed) or when rl.exit is closed. A signal is dropped if out already holds
// one.
func (rl *redissionLocker) subscribeLock(sub *redis.PubSub, out chan struct{}) {
	defer func() {
		if r := recover(); r != nil {
			// The recovered value is not guaranteed to be an error.
			err, ok := r.(error)
			if !ok {
				err = fmt.Errorf("%v", r)
			}
			rl.logger.Error("subscribeLock catch error", zap.Error(err))
		}
	}()
	if sub == nil || out == nil {
		return
	}
	rl.logger.Debug("lock enter sub routine", zap.String("token", rl.token))

LOOP:
	for {
		msg, err := sub.Receive()
		if err != nil {
			rl.logger.Info("sub receive message", zap.Error(err))
			break LOOP
		}

		select {
		case <-rl.exit:
			break LOOP
		default:
		}

		// Subscription confirmations and pongs carry no unlock information.
		if _, isMessage := msg.(*redis.Message); !isMessage {
			continue
		}
		select {
		case out <- struct{}{}:
		default:
			// A wake-up is already pending; one is enough.
			rl.logger.Debug("drop message when channel if full")
		}
	}
	rl.logger.Debug("lock sub routine release", zap.String("token", rl.token))
}

// refreshLockTimeout re-applies the lease every third of the lease time until
// rl.exit is closed or the refresh script reports that the token no longer
// owns the key.
//
// NOTE(review): an error from the refresh script panics inside this
// goroutine, which cannot be recovered by the caller of Lock.
func (rl *redissionLocker) refreshLockTimeout() {
	rl.logger.Debug("lock", zap.String("token", rl.token), zap.String("lock key", rl.key))

	lockTime := time.Duration(rl.lockLeaseTime/3) * time.Millisecond
	if lockTime <= 0 {
		// Leases shorter than 3ms would otherwise yield a zero interval and
		// spin the loop.
		lockTime = time.Millisecond
	}
	timer := time.NewTimer(lockTime)
	defer timer.Stop()

LOOP:
	for {
		select {
		case <-timer.C:
			timer.Reset(lockTime)
			// Extend the key's expiry while we still own it.
			res := rl.client.Eval(refreshLockScript, []string{rl.key}, rl.lockLeaseTime, rl.token)
			val, err := res.Int()
			if err != nil {
				panic(err)
			}
			if val == 0 {
				rl.logger.Debug("not find the lock key of self")
				break LOOP
			}
		case <-rl.exit:
			break LOOP
		}
	}
	rl.logger.Debug("refresh routine release", zap.String("token", rl.token))
}

// cancelRefreshLockTime stops the background goroutines by closing rl.exit and
// resets the Once so that the next acquisition starts a new refresher.
func (rl *redissionLocker) cancelRefreshLockTime() {
	if rl.exit != nil {
		close(rl.exit)
		rl.exit = nil
		rl.once = &sync.Once{}
	}
}

// tryLock runs the lock script once.
//
// acquired is true when the script returned nil, i.e. the lock was taken or
// re-entered. Otherwise ttl is the remaining lifetime reported for the key;
// the script returns PTTL, which Redis expresses in milliseconds, and which
// may be non-positive.
func (rl *redissionLocker) tryLock() (ttl time.Duration, acquired bool, err error) {
	res := rl.client.Eval(lockScript, []string{rl.key}, rl.lockLeaseTime, rl.token)
	v, err := res.Result()
	if err != nil && err != redis.Nil {
		return 0, false, err
	}
	if v == nil {
		return 0, true, nil
	}
	ms, ok := v.(int64)
	if !ok {
		return 0, false, fmt.Errorf("lock script returned unexpected type %T", v)
	}
	return time.Duration(ms) * time.Millisecond, false, nil
}

// UnLock releases one level of the lock. When the unlock script reports that
// the key is gone, the background goroutines are stopped as well.
//
// UnLock panics if the script fails or if the lock hash exists but is not
// held by this locker's token.
func (rl *redissionLocker) UnLock() {
	res := rl.client.Eval(unlockScript, []string{rl.key, rl.chankey}, unlockMessage, rl.lockLeaseTime, rl.token)
	val, err := res.Result()
	if err != redis.Nil && err != nil {
		panic(err)
	}
	if val == nil {
		panic("attempt to unlock lock, not locked by current routine by lock id:" + rl.token)
	}
	rl.logger.Debug("unlock", zap.String("token", rl.token), zap.String("key", rl.key))
	if val.(int64) == 1 {
		rl.cancelRefreshLockTime()
	}
}

// GetLocker builds a locker for ops.Key with a freshly generated owner token.
//
// Empty prefixes and a lease below one millisecond fall back to the package
// defaults. ops itself is not modified. The locker logs through zap's global
// logger.
func GetLocker(client *redis.Client, ops *RedissionLockConfig) *redissionLocker {
	prefix := ops.Prefix
	if prefix == "" {
		prefix = "redission-lock"
	}
	chanPrefix := ops.ChanPrefix
	if chanPrefix == "" {
		chanPrefix = "redission-lock-channel"
	}

	// The scripts take the lease in milliseconds. A zero lease would make
	// PEXPIRE delete the key immediately, so anything below 1ms uses the
	// default instead.
	lease := internalLockLeaseTime
	if ms := ops.LockLeaseTime / time.Millisecond; ms > 0 {
		lease = uint64(ms)
	}

	return &redissionLocker{
		token:         uuid.New().String(),
		key:           strings.Join([]string{prefix, ops.Key}, ":"),
		chankey:       strings.Join([]string{chanPrefix, ops.Key}, ":"),
		exit:          make(chan struct{}),
		lockLeaseTime: lease,
		client:        client,
		once:          &sync.Once{},
		logger:        zap.L(),
	}
}