modelRT/distributedlock/redis_lock.go

251 lines
7.1 KiB
Go
Raw Normal View History

package distributed_lock
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"
"modelRT/distributedlock/constant"
luascript "modelRT/distributedlock/luascript"
"modelRT/logger"
uuid "github.com/google/uuid"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
)
const (
	// internalLockLeaseTime is the default lock lease, in seconds, that
	// GetLocker applies when RedissionLockConfig.LockLeaseTime is zero.
	internalLockLeaseTime uint64 = 30
	// unlockMessage is the value UnLock passes to the unlock script (ARGV[1])
	// as the notification for clients waiting on the lock's channel.
	unlockMessage = 0
)
// RedissionLockConfig configures a lock created by GetLocker.
type RedissionLockConfig struct {
	// LockLeaseTime is the lock lease in seconds; GetLocker replaces 0 with
	// internalLockLeaseTime (30).
	LockLeaseTime uint64
	// Token identifies the lock holder; GetLocker generates a random UUID when
	// it is empty.
	Token string
	// Prefix (default "redission-lock") prefixes the lock key and ChanPrefix
	// (default "redission-lock-channel") prefixes the unlock notification
	// channel. TimeoutPrefix is not read by GetLocker.
	Prefix, ChanPrefix, TimeoutPrefix string
	// Key is the lock name, joined with the prefixes using ":".
	Key string
	// NeedRefresh makes Lock start a background goroutine that keeps
	// extending the lease while the lock is held.
	NeedRefresh bool
}
// redissionLocker is a Redis-backed distributed lock bound to one lock key.
type redissionLocker struct {
	// lockLeaseTime is passed to the lock and refresh scripts as the lease,
	// in seconds.
	lockLeaseTime uint64

	// token identifies this holder, key is the Redis lock key and waitChanKey
	// is the channel on which unlock notifications are published.
	token, key, waitChanKey string

	needRefresh bool          // whether Lock starts the lease-refresh goroutine
	exit        chan struct{} // closed by cancelRefreshLockTime to stop background goroutines
	client      *redis.Client
	once        *sync.Once // guards starting the refresh goroutine once per hold
	logger      *zap.Logger
}
// Lock acquires the distributed lock.
//
// It makes one immediate attempt via tryLock. If that succeeds, Lock returns
// nil and, when needRefresh is set, starts the lease-refresh goroutine.
//
// If the attempt fails and a positive timeout is given, Lock subscribes to the
// unlock notification channel (rl.waitChanKey) and retries after every
// notification until the lock is acquired, the timeout elapses, or ctx is
// done. Without a positive timeout it returns an error wrapping the failed
// result immediately.
func (rl *redissionLocker) Lock(ctx context.Context, timeout ...time.Duration) error {
	// exit may be nil, or already closed by a previous UnLock; a closed channel
	// would make a freshly started refresh goroutine exit at once.
	if rl.exit == nil {
		rl.exit = make(chan struct{})
	} else {
		select {
		case <-rl.exit: // only ever closed, never sent on
			rl.exit = make(chan struct{})
		default:
		}
	}

	result := rl.tryLock(ctx).(*constant.RedisResult)
	if result.Code == constant.UnknownInternalError {
		rl.logger.Error(result.OutputResultMessage())
		return fmt.Errorf("get lock failed:%w", result)
	}
	if result.Code == constant.LockSuccess {
		// Success must return regardless of needRefresh; previously a held lock
		// without refresh fell through and was reported as a failure.
		rl.startRefreshOnce(ctx)
		return nil
	}
	if len(timeout) == 0 || timeout[0] <= 0 {
		// Non-blocking mode: subscribing to the wait channel would be wasted.
		return fmt.Errorf("lock the redis lock failed:%w", result)
	}
	return rl.waitForLock(ctx, timeout[0])
}

// startRefreshOnce launches the lease-refresh goroutine when needRefresh is
// enabled, at most once until cancelRefreshLockTime resets rl.once.
func (rl *redissionLocker) startRefreshOnce(ctx context.Context) {
	if !rl.needRefresh {
		return
	}
	if rl.once == nil {
		// A nil *sync.Once would panic in Do.
		rl.once = &sync.Once{}
	}
	rl.once.Do(func() {
		// async refresh lock timeout until receive exit signal
		go rl.refreshLockTimeout(ctx)
	})
}

// waitForLock subscribes to the unlock channel and retries tryLock after each
// notification until it succeeds, the timeout elapses, or ctx is done.
func (rl *redissionLocker) waitForLock(ctx context.Context, timeout time.Duration) error {
	// subMsg is deliberately not closed here: subscribeLock is its only sender,
	// and closing it from the receiving side could panic that goroutine.
	subMsg := make(chan struct{}, 1)
	sub := rl.client.Subscribe(ctx, rl.waitChanKey)
	// Closing the subscription makes subscribeLock's Receive return an error
	// once we stop waiting.
	defer sub.Close()
	go rl.subscribeLock(ctx, sub, subMsg)

	acquireTimer := time.NewTimer(timeout)
	defer acquireTimer.Stop()
	for {
		select {
		case _, ok := <-subMsg:
			if !ok {
				err := errors.New("failed to read the lock waiting for for the channel message")
				rl.logger.Error("failed to read the lock waiting for for the channel message")
				return err
			}
			result := rl.tryLock(ctx).(*constant.RedisResult)
			if (result.Code == constant.LockFailure) || (result.Code == constant.UnknownInternalError) {
				rl.logger.Info(result.OutputResultMessage())
				continue
			}
			if result.Code == constant.LockSuccess {
				rl.logger.Info(result.OutputResultMessage())
				// A lock obtained after waiting needs its lease refreshed too.
				rl.startRefreshOnce(ctx)
				return nil
			}
		case <-acquireTimer.C:
			err := errors.New("the waiting time for obtaining the lock operation has timed out")
			rl.logger.Info("the waiting time for obtaining the lock operation has timed out")
			return err
		case <-ctx.Done():
			return fmt.Errorf("waiting for the redis lock: %w", ctx.Err())
		}
	}
}
// subscribeLock forwards unlock notifications received on sub to out.
//
// Only *redis.Message payloads produce a signal; subscription confirmations,
// pongs and other message kinds are ignored. The send is non-blocking: with a
// buffered out, one pending wake-up is enough for the waiter to retry, and the
// goroutine can never block on a waiter that has already returned.
//
// The goroutine returns when rl.exit is closed, when ctx is done, or once sub
// has been closed (Receive reports redis.ErrClosed). Other receive errors are
// logged and retried.
func (rl *redissionLocker) subscribeLock(ctx context.Context, sub *redis.PubSub, out chan struct{}) {
	if sub == nil || out == nil {
		return
	}
	rl.logger.Info("lock: enter sub routine", zap.String("token", rl.token))
	for {
		msg, err := sub.Receive(ctx)
		if err != nil {
			// A closed subscription or a cancelled ctx makes every later
			// Receive fail too; retrying would spin forever and leak the goroutine.
			if ctx.Err() != nil || errors.Is(err, redis.ErrClosed) {
				return
			}
			rl.logger.Info("sub receive message failed", zap.Error(err))
			continue
		}

		select {
		case <-rl.exit:
			return
		default:
		}

		if _, ok := msg.(*redis.Message); !ok {
			continue
		}
		select {
		case out <- struct{}{}:
		default:
			// A wake-up is already pending; dropping this one loses nothing.
		}
	}
}
// refreshLockTimeout periodically re-runs RefreshLockScript to extend the
// lease of rl.key while this client holds the lock.
//
// Script arguments:
//   - KEYS[1]: the lock key (rl.key), the unique identifier of the lock.
//   - ARGV[1]: the lease time (rl.lockLeaseTime), in seconds.
//   - ARGV[2]: the current client's unique token (rl.token), used to tell
//     clients apart.
//
// The lease is extended every lockLeaseTime/3. The goroutine stops when one of
// the following happens:
//   - the exit channel captured at start-up is closed (see cancelRefreshLockTime);
//   - ctx is done;
//   - the refresh command returns an error;
//   - the script reports RefreshLockFailure, meaning the key/token pair could
//     not be found.
func (rl *redissionLocker) refreshLockTimeout(ctx context.Context) {
	rl.logger.Info("lock refresh by key and token", zap.String("token", rl.token), zap.String("key", rl.key))
	// Divide the full duration rather than whole seconds so leases shorter
	// than 3s still get a non-zero interval.
	lockTime := time.Duration(rl.lockLeaseTime) * time.Second / 3
	if lockTime <= 0 {
		// A zero interval would turn the loop below into a busy loop.
		rl.logger.Error("lock refresh skipped, lease time is zero", zap.String("token", rl.token), zap.String("key", rl.key))
		return
	}
	// Capture the channel once: rl.exit may be replaced after it is closed,
	// and this goroutine must keep watching the channel of the hold it serves.
	exit := rl.exit

	timer := time.NewTimer(lockTime)
	defer timer.Stop()
	for {
		select {
		case <-exit:
			return
		case <-ctx.Done():
			return
		case <-timer.C:
		}

		// extend key lease time
		res := rl.client.Eval(ctx, luascript.RefreshLockScript, []string{rl.key}, rl.lockLeaseTime, rl.token)
		val, err := res.Int()
		if err != redis.Nil && err != nil {
			rl.logger.Info("lock refresh failed", zap.String("token", rl.token), zap.String("key", rl.key), zap.Error(err))
			return
		}
		switch constant.RedisCode(val) {
		case constant.RefreshLockFailure:
			rl.logger.Error("lock refreash failed,can not find the lock by key and token", zap.String("token", rl.token), zap.String("key", rl.key))
			// The lock is gone; stop instead of refreshing forever. (A `break`
			// here would only leave the switch/select, not the loop.)
			return
		case constant.RefreshLockSuccess:
			rl.logger.Info("lock refresh success by key and token", zap.String("token", rl.token), zap.String("key", rl.key))
		}
		timer.Reset(lockTime)
	}
}
// cancelRefreshLockTime signals the background goroutines watching rl.exit
// (the lease refresher and the unlock subscriber) to stop, and resets rl.once
// so the next successful Lock can start a new refresh goroutine.
func (rl *redissionLocker) cancelRefreshLockTime() {
	if rl.exit != nil {
		close(rl.exit)
		// Replace the closed channel. Otherwise the next lock hold's refresh
		// goroutine would see it closed and stop at once, and the next UnLock
		// would close it again and panic.
		rl.exit = make(chan struct{})
		rl.once = &sync.Once{}
	}
}
// tryLock makes a single, non-blocking attempt to acquire the lock by running
// LockScript.
//
// Script arguments:
//   - KEYS[1]: the lock key (rl.key), usually the unique identifier of the lock.
//   - ARGV[1]: the lock expiry (rl.lockLeaseTime), in seconds.
//   - ARGV[2]: the current client's unique token (rl.token), used to tell
//     clients apart.
//
// The script's integer reply is returned as a constant.LockType result. A
// redis.Nil reply is reported as code 0. Any other error becomes an
// UnknownInternalError result that carries the error text.
func (rl *redissionLocker) tryLock(ctx context.Context) error {
	keys := []string{rl.key}
	code, err := rl.client.Eval(ctx, luascript.LockScript, keys, rl.lockLeaseTime, rl.token).Int()
	if err == nil || err == redis.Nil {
		return constant.NewRedisResult(constant.RedisCode(code), constant.LockType, "")
	}
	return constant.NewRedisResult(constant.UnknownInternalError, constant.LockType, err.Error())
}
// UnLock releases the lock held under rl.token by running UnLockScript.
//
// Script arguments:
//   - KEYS[1]: the lock key (rl.key), usually the unique identifier of the lock.
//   - KEYS[2]: the release notification channel (rl.waitChanKey), used to tell
//     other clients that the lock has been released.
//   - ARGV[1]: the unlock message (unlockMessage) sent to those clients.
//   - ARGV[2]: the current client's unique token (rl.token), used to tell
//     clients apart.
//
// The outcome depends on the script result:
//   - UnLockSuccess: UnLock stops the lease-refresh goroutine (when
//     needRefresh is set) and returns nil.
//   - UnLocakFailureWithLockOccupancy, or a failed script call: UnLock returns
//     an error wrapping a constant.NewRedisResult value.
//   - Any other result: UnLock returns nil.
func (rl *redissionLocker) UnLock(ctx context.Context) error {
	keys := []string{rl.key, rl.waitChanKey}
	code, err := rl.client.Eval(ctx, luascript.UnLockScript, keys, unlockMessage, rl.token).Int()
	if err != nil && err != redis.Nil {
		rl.logger.Info("unlock lock failed", zap.String("token", rl.token), zap.String("key", rl.key), zap.Error(err))
		return fmt.Errorf("unlock lock failed:%w", constant.NewRedisResult(constant.UnknownInternalError, constant.UnLockType, err.Error()))
	}

	switch constant.RedisCode(code) {
	case constant.UnLockSuccess:
		if rl.needRefresh {
			rl.cancelRefreshLockTime()
		}
		rl.logger.Info("unlock lock success", zap.String("token", rl.token), zap.String("key", rl.key))
		return nil
	case constant.UnLocakFailureWithLockOccupancy:
		rl.logger.Info("unlock lock failed", zap.String("token", rl.token), zap.String("key", rl.key))
		return fmt.Errorf("unlock lock failed:%w", constant.NewRedisResult(constant.UnLocakFailureWithLockOccupancy, constant.UnLockType, ""))
	default:
		return nil
	}
}
// GetLocker builds a Redis-backed distributed locker from ops.
//
// Zero-valued fields of ops are filled with defaults in place:
//   - Token: a random UUID;
//   - Prefix: "redission-lock";
//   - ChanPrefix: "redission-lock-channel";
//   - LockLeaseTime: internalLockLeaseTime (seconds).
//
// A nil ops is treated as an empty config. The lock key is "<Prefix>:<Key>"
// and the unlock notification channel is "<ChanPrefix>:<Key>:wait".
func GetLocker(client *redis.Client, ops *RedissionLockConfig) *redissionLocker {
	if ops == nil {
		ops = &RedissionLockConfig{}
	}
	if ops.Token == "" {
		ops.Token = uuid.New().String()
	}
	if ops.Prefix == "" {
		ops.Prefix = "redission-lock"
	}
	if ops.ChanPrefix == "" {
		ops.ChanPrefix = "redission-lock-channel"
	}
	if ops.LockLeaseTime == 0 {
		ops.LockLeaseTime = internalLockLeaseTime
	}
	return &redissionLocker{
		// The lock and refresh scripts read the lease from here. Without it,
		// every lock used a 0s lease and the refresh interval was zero.
		lockLeaseTime: ops.LockLeaseTime,
		token:         ops.Token,
		key:           strings.Join([]string{ops.Prefix, ops.Key}, ":"),
		waitChanKey:   strings.Join([]string{ops.ChanPrefix, ops.Key, "wait"}, ":"),
		needRefresh:   ops.NeedRefresh,
		client:        client,
		exit:          make(chan struct{}),
		// Lock calls once.Do; a nil *sync.Once would panic there.
		once:   &sync.Once{},
		logger: logger.GetLoggerInstance(),
	}
}