259 lines
7.5 KiB
Go
259 lines
7.5 KiB
Go
package distributedlock
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"modelRT/distributedlock/constant"
|
||
luascript "modelRT/distributedlock/luascript"
|
||
"modelRT/logger"
|
||
|
||
uuid "github.com/gofrs/uuid"
|
||
"github.com/redis/go-redis/v9"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// Package-level defaults shared by the locker implementation.
const (
	// internalLockLeaseTime is the lease time, in seconds, that GetLocker
	// applies when RedissionLockConfig.LockLeaseTime is left at zero.
	internalLockLeaseTime uint64 = 30

	// unlockMessage is passed as ARGV[1] to the unlock script, which uses it
	// to notify other clients that the lock was released.
	unlockMessage = 0
)
|
||
|
||
// RedissionLockConfig holds the options GetLocker uses to build a locker.
// Zero-valued fields are replaced with defaults by GetLocker, which writes
// those defaults back into the struct it was given.
type RedissionLockConfig struct {
	// LockLeaseTime is the lock lease in seconds; 0 selects the package default.
	LockLeaseTime uint64
	// Token identifies the lock owner; an empty value is replaced with a random UUIDv4.
	Token string
	// Prefix is joined with Key to form the Redis lock key; defaults to "redission-lock".
	Prefix string
	// ChanPrefix is joined with Key to form the pub/sub channel used for unlock
	// notifications; defaults to "redission-lock-channel".
	ChanPrefix string
	// TimeoutPrefix is not read by GetLocker.
	// NOTE(review): purpose unclear from this file — confirm whether it is still needed.
	TimeoutPrefix string
	// Key names the lock; it is part of both the lock key and the wait channel name.
	Key string
	// NeedRefresh enables a background goroutine that keeps extending the lease
	// while the lock is held.
	NeedRefresh bool
}
|
||
|
||
type redissionLocker struct {
|
||
lockLeaseTime uint64
|
||
token string
|
||
key string
|
||
waitChanKey string
|
||
needRefresh bool
|
||
refreshExitChan chan struct{}
|
||
subExitChan chan struct{}
|
||
client *redis.Client
|
||
refreshOnce *sync.Once
|
||
logger *zap.Logger
|
||
}
|
||
|
||
func (rl *redissionLocker) Lock(ctx context.Context, timeout ...time.Duration) error {
|
||
if rl.refreshExitChan == nil {
|
||
rl.refreshExitChan = make(chan struct{})
|
||
}
|
||
result := rl.tryLock(ctx).(*constant.RedisResult)
|
||
if result.Code == constant.UnknownInternalError {
|
||
rl.logger.Error(result.OutputResultMessage())
|
||
return fmt.Errorf("get lock failed:%w", result)
|
||
}
|
||
|
||
if (result.Code == constant.LockSuccess) && rl.needRefresh {
|
||
rl.refreshOnce.Do(func() {
|
||
// async refresh lock timeout unitl receive exit singal
|
||
go rl.refreshLockTimeout(ctx)
|
||
})
|
||
return nil
|
||
}
|
||
|
||
subMsg := make(chan struct{}, 1)
|
||
defer close(subMsg)
|
||
sub := rl.client.Subscribe(ctx, rl.waitChanKey)
|
||
defer sub.Close()
|
||
go rl.subscribeLock(sub, subMsg)
|
||
|
||
if len(timeout) > 0 && timeout[0] > 0 {
|
||
acquireTimer := time.NewTimer(timeout[0])
|
||
for {
|
||
select {
|
||
case _, ok := <-subMsg:
|
||
if !ok {
|
||
err := errors.New("failed to read the lock waiting for for the channel message")
|
||
rl.logger.Error("failed to read the lock waiting for for the channel message")
|
||
return err
|
||
}
|
||
|
||
resultErr := rl.tryLock(ctx).(*constant.RedisResult)
|
||
if (resultErr.Code == constant.LockFailure) || (resultErr.Code == constant.UnknownInternalError) {
|
||
rl.logger.Info(resultErr.OutputResultMessage())
|
||
continue
|
||
}
|
||
|
||
if resultErr.Code == constant.LockSuccess {
|
||
rl.logger.Info(resultErr.OutputResultMessage())
|
||
return nil
|
||
}
|
||
case <-acquireTimer.C:
|
||
err := errors.New("the waiting time for obtaining the lock operation has timed out")
|
||
rl.logger.Info("the waiting time for obtaining the lock operation has timed out")
|
||
return err
|
||
}
|
||
}
|
||
}
|
||
return fmt.Errorf("lock the redis lock failed:%w", result)
|
||
}
|
||
|
||
func (rl *redissionLocker) subscribeLock(sub *redis.PubSub, subMsgChan chan struct{}) {
|
||
if sub == nil || subMsgChan == nil {
|
||
return
|
||
}
|
||
rl.logger.Info("lock: enter sub routine", zap.String("token", rl.token))
|
||
|
||
for {
|
||
select {
|
||
case <-rl.subExitChan:
|
||
close(subMsgChan)
|
||
return
|
||
case <-sub.Channel():
|
||
// 这里只会收到真正的数据消息
|
||
subMsgChan <- struct{}{}
|
||
default:
|
||
}
|
||
}
|
||
}
|
||
|
||
/*
|
||
KEYS[1]:锁的键名(key),通常是锁的唯一标识。
|
||
ARGV[1]:锁的过期时间(lockLeaseTime),单位为秒。
|
||
ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。
|
||
*/
|
||
func (rl *redissionLocker) refreshLockTimeout(ctx context.Context) {
|
||
rl.logger.Info("lock refresh by key and token", zap.String("token", rl.token), zap.String("key", rl.key))
|
||
|
||
lockTime := time.Duration(rl.lockLeaseTime/3) * time.Second
|
||
timer := time.NewTimer(lockTime)
|
||
defer timer.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-timer.C:
|
||
// extend key lease time
|
||
res := rl.client.Eval(ctx, luascript.RefreshLockScript, []string{rl.key}, rl.lockLeaseTime, rl.token)
|
||
val, err := res.Int()
|
||
if err != redis.Nil && err != nil {
|
||
rl.logger.Info("lock refresh failed", zap.String("token", rl.token), zap.String("key", rl.key), zap.Error(err))
|
||
return
|
||
}
|
||
|
||
if constant.RedisCode(val) == constant.RefreshLockFailure {
|
||
rl.logger.Error("lock refreash failed,can not find the lock by key and token", zap.String("token", rl.token), zap.String("key", rl.key))
|
||
break
|
||
}
|
||
|
||
if constant.RedisCode(val) == constant.RefreshLockSuccess {
|
||
rl.logger.Info("lock refresh success by key and token", zap.String("token", rl.token), zap.String("key", rl.key))
|
||
}
|
||
timer.Reset(lockTime)
|
||
case <-rl.refreshExitChan:
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
func (rl *redissionLocker) cancelRefreshLockTime() {
|
||
if rl.refreshExitChan != nil {
|
||
close(rl.refreshExitChan)
|
||
rl.refreshOnce = &sync.Once{}
|
||
}
|
||
}
|
||
|
||
func (rl *redissionLocker) closeSub(sub *redis.PubSub, noticeChan chan struct{}) {
|
||
if sub != nil {
|
||
err := sub.Close()
|
||
if err != nil {
|
||
rl.logger.Error("close sub failed", zap.String("token", rl.token), zap.String("key", rl.key), zap.Error(err))
|
||
}
|
||
}
|
||
|
||
if noticeChan != nil {
|
||
close(noticeChan)
|
||
}
|
||
}
|
||
|
||
/*
|
||
KEYS[1]:锁的键名(key),通常是锁的唯一标识。
|
||
ARGV[1]:锁的过期时间(lockLeaseTime),单位为秒。
|
||
ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。
|
||
*/
|
||
func (rl *redissionLocker) tryLock(ctx context.Context) error {
|
||
lockType := constant.LockType
|
||
res := rl.client.Eval(ctx, luascript.LockScript, []string{rl.key}, rl.lockLeaseTime, rl.token)
|
||
val, err := res.Int()
|
||
if err != redis.Nil && err != nil {
|
||
return constant.NewRedisResult(constant.UnknownInternalError, lockType, err.Error())
|
||
}
|
||
return constant.NewRedisResult(constant.RedisCode(val), lockType, "")
|
||
}
|
||
|
||
/*
|
||
KEYS[1]:锁的键名(key),通常是锁的唯一标识。
|
||
KEYS[2]:锁的释放通知频道(chankey),用于通知其他客户端锁已释放。
|
||
ARGV[1]:解锁消息(unlockMessage),用于通知其他客户端锁已释放。
|
||
ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。
|
||
*/
|
||
func (rl *redissionLocker) UnLock(ctx context.Context) error {
|
||
res := rl.client.Eval(ctx, luascript.UnLockScript, []string{rl.key, rl.waitChanKey}, unlockMessage, rl.token)
|
||
val, err := res.Int()
|
||
if err != redis.Nil && err != nil {
|
||
rl.logger.Info("unlock lock failed", zap.String("token", rl.token), zap.String("key", rl.key), zap.Error(err))
|
||
return fmt.Errorf("unlock lock failed:%w", constant.NewRedisResult(constant.UnknownInternalError, constant.UnLockType, err.Error()))
|
||
}
|
||
|
||
if constant.RedisCode(val) == constant.UnLockSuccess {
|
||
if rl.needRefresh {
|
||
rl.cancelRefreshLockTime()
|
||
}
|
||
|
||
rl.logger.Info("unlock lock success", zap.String("token", rl.token), zap.String("key", rl.key))
|
||
return nil
|
||
}
|
||
|
||
if constant.RedisCode(val) == constant.UnLocakFailureWithLockOccupancy {
|
||
rl.logger.Info("unlock lock failed", zap.String("token", rl.token), zap.String("key", rl.key))
|
||
return fmt.Errorf("unlock lock failed:%w", constant.NewRedisResult(constant.UnLocakFailureWithLockOccupancy, constant.UnLockType, ""))
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// TODO 优化 panic
|
||
func GetLocker(client *redis.Client, ops *RedissionLockConfig) *redissionLocker {
|
||
if ops.Token == "" {
|
||
token, err := uuid.NewV4()
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
ops.Token = token.String()
|
||
}
|
||
|
||
if len(ops.Prefix) <= 0 {
|
||
ops.Prefix = "redission-lock"
|
||
}
|
||
|
||
if len(ops.ChanPrefix) <= 0 {
|
||
ops.ChanPrefix = "redission-lock-channel"
|
||
}
|
||
|
||
if ops.LockLeaseTime == 0 {
|
||
ops.LockLeaseTime = internalLockLeaseTime
|
||
}
|
||
|
||
r := &redissionLocker{
|
||
token: ops.Token,
|
||
key: strings.Join([]string{ops.Prefix, ops.Key}, ":"),
|
||
waitChanKey: strings.Join([]string{ops.ChanPrefix, ops.Key, "wait"}, ":"),
|
||
needRefresh: ops.NeedRefresh,
|
||
client: client,
|
||
refreshExitChan: make(chan struct{}),
|
||
logger: logger.GetLoggerInstance(),
|
||
}
|
||
return r
|
||
}
|