modelRT/distributedlock/redis_rwlock.go

472 lines
11 KiB
Go

package distributed_lock
import (
"context"
"strings"
"sync"
"time"
"github.com/go-redis/redis"
uuid "github.com/google/uuid"
"go.uber.org/zap"
)
var rlockScript string = strings.Join([]string{
"local mode = redis.call('hget', KEYS[1], 'mode'); ",
"if (mode == false) then ",
"redis.call('hset', KEYS[1], 'mode', 'read'); ",
"redis.call('hset', KEYS[1], ARGV[2], 1); ",
"redis.call('set', KEYS[2] .. ':1', 1); ",
"redis.call('pexpire', KEYS[2] .. ':1', ARGV[1]); ",
"redis.call('pexpire', KEYS[1], ARGV[1]); ",
"return nil; ",
"end; ",
"if (mode == 'read') or (mode == 'write' and redis.call('hexists', KEYS[1], ARGV[3]) == 1) then ",
"local ind = redis.call('hincrby', KEYS[1], ARGV[2], 1); ",
"local key = KEYS[2] .. ':' .. ind;",
"redis.call('set', key, 1); ",
"redis.call('pexpire', key, ARGV[1]); ",
"local remainTime = redis.call('pttl', KEYS[1]); ",
"redis.call('pexpire', KEYS[1], math.max(remainTime, ARGV[1])); ",
"return nil; ",
"end;",
"return redis.call('pttl', KEYS[1]);",
}, "")
var runlockScript string = strings.Join([]string{
"local mode = redis.call('hget', KEYS[1], 'mode'); ",
"if (mode == false) then ",
"redis.call('publish', KEYS[2], ARGV[1]); ",
"return 1; ",
"end; ",
"local lockExists = redis.call('hexists', KEYS[1], ARGV[2]); ",
"if (lockExists == 0) then ",
"return nil;",
"end; ",
"local counter = redis.call('hincrby', KEYS[1], ARGV[2], -1); ",
"if (counter == 0) then ",
"redis.call('hdel', KEYS[1], ARGV[2]); ",
"end;",
"redis.call('del', KEYS[3] .. ':' .. (counter+1)); ",
"if (redis.call('hlen', KEYS[1]) > 1) then ",
"local maxRemainTime = -3; ",
"local keys = redis.call('hkeys', KEYS[1]); ",
"for n, key in ipairs(keys) do ",
"counter = tonumber(redis.call('hget', KEYS[1], key)); ",
"if type(counter) == 'number' then ",
"for i=counter, 1, -1 do ",
"local remainTime = redis.call('pttl', KEYS[4] .. ':' .. key .. ':rwlock_timeout:' .. i); ",
"maxRemainTime = math.max(remainTime, maxRemainTime);",
"end; ",
"end; ",
"end; ",
"if maxRemainTime > 0 then ",
"redis.call('pexpire', KEYS[1], maxRemainTime); ",
"return 0; ",
"end;",
"if mode == 'write' then ",
"return 0;",
"end; ",
"end; ",
"redis.call('del', KEYS[1]); ",
"redis.call('publish', KEYS[2], ARGV[1]); ",
"return 1; ",
}, "")
var rlockrefreshScript = strings.Join([]string{
"local counter = redis.call('hget', KEYS[1], ARGV[2]); ",
"if (counter ~= false) then ",
"redis.call('pexpire', KEYS[1], ARGV[1]); ",
"if (redis.call('hlen', KEYS[1]) > 1) then ",
"local keys = redis.call('hkeys', KEYS[1]); ",
"for n, key in ipairs(keys) do ",
"counter = tonumber(redis.call('hget', KEYS[1], key)); ",
"if type(counter) == 'number' then ",
"for i=counter, 1, -1 do ",
"redis.call('pexpire', KEYS[2] .. ':' .. key .. ':rwlock_timeout:' .. i, ARGV[1]); ",
"end; ",
"end; ",
"end; ",
"end; ",
"return 1; ",
"end; ",
"return 0;",
}, "")
var wlockScript string = strings.Join([]string{
"local mode = redis.call('hget', KEYS[1], 'mode'); ",
"if (mode == false) then ",
"redis.call('hset', KEYS[1], 'mode', 'write'); ",
"redis.call('hset', KEYS[1], ARGV[2], 1); ",
"redis.call('pexpire', KEYS[1], ARGV[1]); ",
"return nil; ",
"end; ",
"if (mode == 'write') then ",
"if (redis.call('hexists', KEYS[1], ARGV[2]) == 1) then ",
"redis.call('hincrby', KEYS[1], ARGV[2], 1); ",
"local currentExpire = redis.call('pttl', KEYS[1]); ",
"redis.call('pexpire', KEYS[1], currentExpire + ARGV[1]); ",
"return nil; ",
"end; ",
"end;",
"return redis.call('pttl', KEYS[1]);",
}, "")
var wunlockScript string = strings.Join([]string{
"local mode = redis.call('hget', KEYS[1], 'mode'); ",
"if (mode == false) then ",
"redis.call('publish', KEYS[2], ARGV[1]); ",
"return 1; ",
"end;",
"if (mode == 'write') then ",
"local lockExists = redis.call('hexists', KEYS[1], ARGV[3]); ",
"if (lockExists == 0) then ",
"return nil;",
"else ",
"local counter = redis.call('hincrby', KEYS[1], ARGV[3], -1); ",
"if (counter > 0) then ",
"redis.call('pexpire', KEYS[1], ARGV[2]); ",
"return 0; ",
"else ",
"redis.call('hdel', KEYS[1], ARGV[3]); ",
"if (redis.call('hlen', KEYS[1]) == 1) then ",
"redis.call('del', KEYS[1]); ",
"redis.call('publish', KEYS[2], ARGV[1]); ",
"else ",
// has unlocked read-locks
"redis.call('hset', KEYS[1], 'mode', 'read'); ",
"end; ",
"return 1; ",
"end; ",
"end; ",
"end; ",
"return nil;",
}, "")
type redissionReadLocker struct {
redissionLocker
rwTimeoutTokenPrefix string
prefixKey string
}
func (rl *redissionReadLocker) Lock(ctx context.Context, timeout ...time.Duration) {
if rl.exit == nil {
rl.exit = make(chan struct{})
}
ttl, err := rl.tryLock()
if err != nil {
panic(err)
}
if ttl <= 0 {
rl.once.Do(func() {
go rl.refreshLockTimeout()
})
return
}
submsg := make(chan struct{}, 1)
defer close(submsg)
sub := rl.client.Subscribe(rl.chankey)
defer sub.Close()
go rl.subscribeLock(sub, submsg)
// listen := rl.listenManager.Subscribe(rl.key, rl.token)
// defer rl.listenManager.UnSubscribe(rl.key, rl.token)
timer := time.NewTimer(ttl)
defer timer.Stop()
var outimer *time.Timer
if len(timeout) > 0 && timeout[0] > 0 {
outimer = time.NewTimer(timeout[0])
}
LOOP:
for {
ttl, err = rl.tryLock()
if err != nil {
panic(err)
}
if ttl <= 0 {
rl.once.Do(func() {
go rl.refreshLockTimeout()
})
return
}
if outimer != nil {
select {
case _, ok := <-submsg:
if !timer.Stop() {
<-timer.C
}
if !ok {
panic("lock listen release")
}
timer.Reset(ttl)
case <-ctx.Done():
// break LOOP
panic("lock context already release")
case <-timer.C:
timer.Reset(ttl)
case <-outimer.C:
if !timer.Stop() {
<-timer.C
}
break LOOP
}
} else {
select {
case _, ok := <-submsg:
if !timer.Stop() {
<-timer.C
}
if !ok {
panic("lock listen release")
}
timer.Reset(ttl)
case <-ctx.Done():
// break LOOP
panic("lock context already release")
case <-timer.C:
timer.Reset(ttl)
}
}
}
}
func (rl *redissionReadLocker) tryLock() (time.Duration, error) {
writeLockToken := strings.Join([]string{rl.token, "write"}, ":")
res := rl.client.Eval(rlockScript, []string{rl.key, rl.rwTimeoutTokenPrefix}, rl.lockLeaseTime, rl.token, writeLockToken)
v, err := res.Result()
if err != redis.Nil && err != nil {
return 0, err
}
if v == nil {
return 0, nil
}
return time.Duration(v.(int64)), nil
}
func (rl *redissionReadLocker) refreshLockTimeout() {
rl.logger.Debug("rlock: %s lock %s\n", zap.String("token", rl.token), zap.String("key", rl.key))
lockTime := time.Duration(rl.lockLeaseTime/3) * time.Millisecond
timer := time.NewTimer(lockTime)
defer timer.Stop()
LOOP:
for {
select {
case <-timer.C:
timer.Reset(lockTime)
// update key expire time
res := rl.client.Eval(rlockrefreshScript, []string{rl.key, rl.prefixKey}, rl.lockLeaseTime, rl.token)
val, err := res.Int()
if err != nil {
panic(err)
}
if val == 0 {
rl.logger.Debug("not find the rlock key of self")
break LOOP
}
case <-rl.exit:
break LOOP
}
}
rl.logger.Debug("rlock: refresh routine release", zap.String("token", rl.token))
}
func (rl *redissionReadLocker) UnLock() {
res := rl.client.Eval(runlockScript, []string{rl.key, rl.chankey, rl.rwTimeoutTokenPrefix, rl.prefixKey}, unlockMessage, rl.token)
val, err := res.Result()
if err != redis.Nil && err != nil {
panic(err)
}
if val == nil {
panic("attempt to unlock lock, not locked by current routine by lock id:" + rl.token)
}
rl.logger.Debug("lock: %s unlock %s\n", zap.String("token", rl.token), zap.String("key", rl.key))
if val.(int64) == 1 {
rl.cancelRefreshLockTime()
}
}
type redissionWriteLocker struct {
redissionLocker
}
func (rl *redissionWriteLocker) Lock(ctx context.Context, timeout ...time.Duration) {
if rl.exit == nil {
rl.exit = make(chan struct{})
}
ttl, err := rl.tryLock()
if err != nil {
panic(err)
}
if ttl <= 0 {
rl.once.Do(func() {
go rl.refreshLockTimeout()
})
return
}
submsg := make(chan struct{}, 1)
defer close(submsg)
sub := rl.client.Subscribe(rl.chankey)
defer sub.Close()
go rl.subscribeLock(sub, submsg)
// listen := rl.listenManager.Subscribe(rl.key, rl.token)
// defer rl.listenManager.UnSubscribe(rl.key, rl.token)
timer := time.NewTimer(ttl)
defer timer.Stop()
// outimer 理解为如果超过这个时间没有获取到锁,就直接放弃
var outimer *time.Timer
if len(timeout) > 0 && timeout[0] > 0 {
outimer = time.NewTimer(timeout[0])
}
LOOP:
for {
ttl, err = rl.tryLock()
if err != nil {
panic(err)
}
if ttl <= 0 {
rl.once.Do(func() {
go rl.refreshLockTimeout()
})
return
}
if outimer != nil {
select {
case _, ok := <-submsg:
if !timer.Stop() {
<-timer.C
}
if !ok {
panic("lock listen release")
}
timer.Reset(ttl)
case <-ctx.Done():
// break LOOP
panic("lock context already release")
case <-timer.C:
timer.Reset(ttl)
case <-outimer.C:
if !timer.Stop() {
<-timer.C
}
break LOOP
}
} else {
select {
case _, ok := <-submsg:
if !timer.Stop() {
<-timer.C
}
if !ok {
panic("lock listen release")
}
timer.Reset(ttl)
case <-ctx.Done():
// break LOOP
panic("lock context already release")
case <-timer.C:
timer.Reset(ttl)
}
}
}
}
func (rl *redissionWriteLocker) tryLock() (time.Duration, error) {
res := rl.client.Eval(wlockScript, []string{rl.key}, rl.lockLeaseTime, rl.token)
v, err := res.Result()
if err != redis.Nil && err != nil {
return 0, err
}
if v == nil {
return 0, nil
}
return time.Duration(v.(int64)), nil
}
func (rl *redissionWriteLocker) UnLock() {
res := rl.client.Eval(wunlockScript, []string{rl.key, rl.chankey}, unlockMessage, rl.lockLeaseTime, rl.token)
val, err := res.Result()
if err != redis.Nil && err != nil {
panic(err)
}
if val == nil {
panic("attempt to unlock lock, not locked by current routine by lock id:" + rl.token)
}
rl.logger.Debug("lock: unlock", zap.String("token", rl.token), zap.String("key", rl.key))
if val.(int64) == 1 {
rl.cancelRefreshLockTime()
}
}
func GetReadLocker(client *redis.Client, ops *RedissionLockConfig) *redissionReadLocker {
r := &redissionLocker{
token: uuid.New().String(),
client: client,
exit: make(chan struct{}),
once: &sync.Once{},
}
if len(ops.Prefix) <= 0 {
ops.Prefix = "redission-rwlock"
}
if len(ops.ChanPrefix) <= 0 {
ops.ChanPrefix = "redission-rwlock-channel"
}
if ops.LockLeaseTime == 0 {
r.lockLeaseTime = internalLockLeaseTime
}
r.key = strings.Join([]string{ops.Prefix, ops.Key}, ":")
r.chankey = strings.Join([]string{ops.ChanPrefix, ops.Key}, ":")
tkey := strings.Join([]string{"{", r.key, "}"}, "")
return &redissionReadLocker{redissionLocker: *r, rwTimeoutTokenPrefix: strings.Join([]string{tkey, r.token, "rwlock_timeout"}, ":"), prefixKey: tkey}
}
func GetWriteLocker(client *redis.Client, ops *RedissionLockConfig) *redissionWriteLocker {
r := &redissionLocker{
token: uuid.New().String(),
client: client,
exit: make(chan struct{}),
once: &sync.Once{},
}
if len(ops.Prefix) <= 0 {
ops.Prefix = "redission-rwlock"
}
if len(ops.ChanPrefix) <= 0 {
ops.ChanPrefix = "redission-rwlock-channel"
}
if ops.LockLeaseTime == 0 {
r.lockLeaseTime = internalLockLeaseTime
}
r.key = strings.Join([]string{ops.Prefix, ops.Key}, ":")
r.chankey = strings.Join([]string{ops.ChanPrefix, ops.Key}, ":")
return &redissionWriteLocker{redissionLocker: *r}
}