300 lines
6.5 KiB
Go
300 lines
6.5 KiB
Go
package distributed_lock
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
luascript "modelRT/distributedlock/luascript"
|
|
|
|
"github.com/go-redis/redis"
|
|
uuid "github.com/google/uuid"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// lockScript is the Lua script evaluated by tryLock to acquire the lock.
//
// KEYS[1] is the lock hash key, ARGV[1] the lease time in milliseconds and
// ARGV[2] the owner token. The script:
//   - creates the hash with the token's counter set to 1 when the key does
//     not exist, applies the lease and returns nil;
//   - increments the token's counter when the token already holds the lock,
//     re-applies the lease and returns nil;
//   - otherwise returns the key's remaining time to live (PTTL).
var lockScript = strings.Join([]string{
	// Key is free: take it.
	"if (redis.call('exists', KEYS[1]) == 0) then",
	"redis.call('hset', KEYS[1], ARGV[2], 1);",
	"redis.call('pexpire', KEYS[1], ARGV[1]);",
	"return nil;",
	"end;",
	// Key is already held by this token: bump the counter.
	"if (redis.call('hexists', KEYS[1], ARGV[2]) == 1) then",
	"redis.call('hincrby', KEYS[1], ARGV[2], 1);",
	"redis.call('pexpire', KEYS[1], ARGV[1]);",
	"return nil;",
	"end;",
	// Key is held by another token: report how long it still lives.
	"return redis.call('pttl', KEYS[1]);",
}, " ")
|
|
|
|
// refreshLockScript is the Lua script evaluated by refreshLockTimeout to
// extend the lease.
//
// KEYS[1] is the lock hash key, ARGV[1] the lease time in milliseconds and
// ARGV[2] the owner token. It re-applies the lease and returns 1 when the
// token is a field of the hash, and returns 0 otherwise.
var refreshLockScript = strings.Join([]string{
	"if (redis.call('hexists', KEYS[1], ARGV[2]) == 1) then",
	"redis.call('pexpire', KEYS[1], ARGV[1]);",
	"return 1;",
	"end;",
	"return 0;",
}, " ")
|
|
|
|
// unlockScript is the Lua script evaluated by UnLock to release the lock.
//
// KEYS[1] is the lock hash key and KEYS[2] the notification channel.
// ARGV[1] is the message to publish, ARGV[2] the lease time in milliseconds
// and ARGV[3] the owner token. The script:
//   - publishes ARGV[1] and returns 1 when the key does not exist;
//   - returns nil when the token is not a field of the hash;
//   - otherwise decrements the token's counter; while the counter is still
//     positive it re-applies the lease and returns 0, else it deletes the
//     key, publishes ARGV[1] and returns 1.
//
// Each element below is one Lua statement; the separators between them are
// part of the literals, so the pieces are joined with the empty string.
var unlockScript = strings.Join([]string{
	"if (redis.call('exists', KEYS[1]) == 0) then redis.call('publish', KEYS[2], ARGV[1]); return 1; end;",
	"if (redis.call('hexists', KEYS[1], ARGV[3]) == 0) then return nil;end; ",
	"local counter = redis.call('hincrby', KEYS[1], ARGV[3], -1); ",
	"if (counter > 0) then redis.call('pexpire', KEYS[1], ARGV[2]); return 0; ",
	"else redis.call('del', KEYS[1]); redis.call('publish', KEYS[2], ARGV[1]); return 1; end; ",
	"return nil;",
}, "")
|
|
|
|
const (
	// internalLockLeaseTime is the default lease, in milliseconds, that
	// GetLocker applies when the config leaves LockLeaseTime at zero.
	internalLockLeaseTime uint64 = 30 * 1000

	// unlockMessage is the payload UnLock passes to unlockScript for
	// publishing on the lock's channel.
	unlockMessage = 0
)
|
|
|
|
// RedissionLockConfig carries the options GetLocker reads when building a
// locker.
type RedissionLockConfig struct {
	// LockLeaseTime is the requested lease; GetLocker falls back to
	// internalLockLeaseTime when it is zero.
	LockLeaseTime time.Duration

	// Prefix and ChanPrefix are joined with Key (separated by ":") to form
	// the Redis lock key and the pub/sub channel name respectively. Empty
	// prefixes are replaced with defaults by GetLocker.
	Prefix, ChanPrefix, Key string
}
|
|
|
|
type redissionLocker struct {
|
|
token string
|
|
key string
|
|
chankey string
|
|
exit chan struct{}
|
|
lockLeaseTime uint64
|
|
client *redis.Client
|
|
once *sync.Once
|
|
logger *zap.Logger
|
|
}
|
|
|
|
func (rl *redissionLocker) Lock(ctx context.Context, timeout ...time.Duration) {
|
|
fmt.Println(luascript.RlockScript)
|
|
if rl.exit == nil {
|
|
rl.exit = make(chan struct{})
|
|
}
|
|
ttl, err := rl.tryLock()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
if ttl <= 0 {
|
|
rl.once.Do(func() {
|
|
go rl.refreshLockTimeout()
|
|
})
|
|
return
|
|
}
|
|
|
|
submsg := make(chan struct{}, 1)
|
|
defer close(submsg)
|
|
sub := rl.client.Subscribe(rl.chankey)
|
|
defer sub.Close()
|
|
go rl.subscribeLock(sub, submsg)
|
|
// listen := rl.listenManager.Subscribe(rl.key, rl.token)
|
|
// defer rl.listenManager.UnSubscribe(rl.key, rl.token)
|
|
|
|
timer := time.NewTimer(ttl)
|
|
defer timer.Stop()
|
|
// outimer 的作用理解为如果超过多长时间无法获得这个锁,那么就直接放弃
|
|
var outimer *time.Timer
|
|
if len(timeout) > 0 && timeout[0] > 0 {
|
|
outimer = time.NewTimer(timeout[0])
|
|
}
|
|
LOOP:
|
|
for {
|
|
ttl, err = rl.tryLock()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
if ttl <= 0 {
|
|
rl.once.Do(func() {
|
|
go rl.refreshLockTimeout()
|
|
})
|
|
return
|
|
}
|
|
if outimer != nil {
|
|
select {
|
|
case _, ok := <-submsg:
|
|
if !timer.Stop() {
|
|
<-timer.C
|
|
}
|
|
|
|
if !ok {
|
|
panic("lock listen release")
|
|
}
|
|
|
|
timer.Reset(ttl)
|
|
case <-ctx.Done():
|
|
// break LOOP
|
|
panic("lock context already release")
|
|
case <-timer.C:
|
|
timer.Reset(ttl)
|
|
case <-outimer.C:
|
|
if !timer.Stop() {
|
|
<-timer.C
|
|
}
|
|
break LOOP
|
|
}
|
|
} else {
|
|
select {
|
|
case _, ok := <-submsg:
|
|
if !timer.Stop() {
|
|
<-timer.C
|
|
}
|
|
|
|
if !ok {
|
|
panic("lock listen release")
|
|
}
|
|
|
|
timer.Reset(ttl)
|
|
case <-ctx.Done():
|
|
// break LOOP
|
|
panic("lock context already release")
|
|
case <-timer.C:
|
|
timer.Reset(ttl)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (rl *redissionLocker) subscribeLock(sub *redis.PubSub, out chan struct{}) {
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
rl.logger.Error("subscribeLock catch error", zap.Error(err.(error)))
|
|
}
|
|
}()
|
|
if sub == nil || out == nil {
|
|
return
|
|
}
|
|
rl.logger.Debug("lock:%s enter sub routine", zap.String("token", rl.token))
|
|
LOOP:
|
|
for {
|
|
msg, err := sub.Receive()
|
|
if err != nil {
|
|
rl.logger.Info("sub receive message", zap.Error(err))
|
|
break LOOP
|
|
}
|
|
|
|
select {
|
|
case <-rl.exit:
|
|
break LOOP
|
|
default:
|
|
if len(out) > 0 {
|
|
// if channel hava msg. drop it
|
|
rl.logger.Debug("drop message when channel if full")
|
|
continue
|
|
}
|
|
|
|
switch msg.(type) {
|
|
case *redis.Subscription:
|
|
// Ignore.
|
|
case *redis.Pong:
|
|
// Ignore.
|
|
case *redis.Message:
|
|
out <- struct{}{}
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
rl.logger.Debug("lock sub routine release", zap.String("token", rl.token))
|
|
}
|
|
|
|
func (rl *redissionLocker) refreshLockTimeout() {
|
|
rl.logger.Debug("lock", zap.String("token", rl.token), zap.String("lock key", rl.key))
|
|
lockTime := time.Duration(rl.lockLeaseTime/3) * time.Millisecond
|
|
timer := time.NewTimer(lockTime)
|
|
defer timer.Stop()
|
|
LOOP:
|
|
for {
|
|
select {
|
|
case <-timer.C:
|
|
timer.Reset(lockTime)
|
|
// update key expire time
|
|
res := rl.client.Eval(refreshLockScript, []string{rl.key}, rl.lockLeaseTime, rl.token)
|
|
val, err := res.Int()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if val == 0 {
|
|
rl.logger.Debug("not find the lock key of self")
|
|
break LOOP
|
|
}
|
|
case <-rl.exit:
|
|
break LOOP
|
|
|
|
}
|
|
}
|
|
rl.logger.Debug("refresh routine release", zap.String("token", rl.token))
|
|
}
|
|
|
|
func (rl *redissionLocker) cancelRefreshLockTime() {
|
|
if rl.exit != nil {
|
|
close(rl.exit)
|
|
rl.exit = nil
|
|
rl.once = &sync.Once{}
|
|
}
|
|
}
|
|
|
|
func (rl *redissionLocker) tryLock() (time.Duration, error) {
|
|
res := rl.client.Eval(lockScript, []string{rl.key}, rl.lockLeaseTime, rl.token)
|
|
v, err := res.Result()
|
|
if err != redis.Nil && err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if v == nil {
|
|
return 0, nil
|
|
}
|
|
|
|
return time.Duration(v.(int64)), nil
|
|
}
|
|
|
|
func (rl *redissionLocker) UnLock() {
|
|
res := rl.client.Eval(unlockScript, []string{rl.key, rl.chankey}, unlockMessage, rl.lockLeaseTime, rl.token)
|
|
val, err := res.Result()
|
|
if err != redis.Nil && err != nil {
|
|
panic(err)
|
|
}
|
|
if val == nil {
|
|
panic("attempt to unlock lock, not locked by current routine by lock id:" + rl.token)
|
|
}
|
|
rl.logger.Debug("unlock", zap.String("token", rl.token), zap.String("key", rl.key))
|
|
if val.(int64) == 1 {
|
|
rl.cancelRefreshLockTime()
|
|
}
|
|
}
|
|
|
|
func GetLocker(client *redis.Client, ops *RedissionLockConfig) *redissionLocker {
|
|
r := &redissionLocker{
|
|
token: uuid.New().String(),
|
|
client: client,
|
|
exit: make(chan struct{}),
|
|
once: &sync.Once{},
|
|
}
|
|
|
|
if len(ops.Prefix) <= 0 {
|
|
ops.Prefix = "redission-lock"
|
|
}
|
|
if len(ops.ChanPrefix) <= 0 {
|
|
ops.ChanPrefix = "redission-lock-channel"
|
|
}
|
|
if ops.LockLeaseTime == 0 {
|
|
r.lockLeaseTime = internalLockLeaseTime
|
|
}
|
|
r.key = strings.Join([]string{ops.Prefix, ops.Key}, ":")
|
|
r.chankey = strings.Join([]string{ops.ChanPrefix, ops.Key}, ":")
|
|
return r
|
|
}
|