add rlock lock&unlock test and rlock reentrant test

This commit is contained in:
douxu 2025-03-12 16:24:28 +08:00
parent 9381e547b6
commit d962462c42
5 changed files with 119 additions and 44 deletions

View File

@ -30,7 +30,7 @@ if (mode == 'read') then
if (redis.call('exists', KEYS[1], ARGV[2]) == 1) then if (redis.call('exists', KEYS[1], ARGV[2]) == 1) then
redis.call('hincrby', KEYS[1], lockKey, '1'); redis.call('hincrby', KEYS[1], lockKey, '1');
local remainTime = redis.call('httl', KEYS[1], 'fields', '1', lockKey); local remainTime = redis.call('httl', KEYS[1], 'fields', '1', lockKey);
redis.call('hexpire', KEYS[1], math.max(remainTime, ARGV[1]), 'fields', '1', lockKey); redis.call('hexpire', KEYS[1], math.max(tonumber(remainTime[1]), ARGV[1]), 'fields', '1', lockKey);
else else
redis.call('hset', KEYS[1], lockKey, '1'); redis.call('hset', KEYS[1], lockKey, '1');
redis.call('hexpire', KEYS[1], ARGV[1], 'fields', '1', lockKey); redis.call('hexpire', KEYS[1], ARGV[1], 'fields', '1', lockKey);
@ -222,7 +222,7 @@ local mode = redis.call('hget', KEYS[1], 'mode')
local maxRemainTime = tonumber(ARGV[1]); local maxRemainTime = tonumber(ARGV[1]);
if (lockExists == 1) then if (lockExists == 1) then
redis.call('hexpire', KEYS[1], ARGV[1], 'fields', '1', lockKey); redis.call('hexpire', KEYS[1], ARGV[1], 'fields', '1', lockKey);
if (mode == 'read' ) then if (mode == 'read') then
local cursor = 0; local cursor = 0;
local pattern = KEYS[2] .. ':*'; local pattern = KEYS[2] .. ':*';
repeat repeat

View File

@ -24,10 +24,12 @@ const (
// RedissionLockConfig define redission lock config // RedissionLockConfig define redission lock config
type RedissionLockConfig struct { type RedissionLockConfig struct {
LockLeaseTime uint64 LockLeaseTime uint64
Token string
Prefix string Prefix string
ChanPrefix string ChanPrefix string
TimeoutPrefix string TimeoutPrefix string
Key string Key string
NeedRefresh bool
} }
type redissionLocker struct { type redissionLocker struct {
@ -70,7 +72,6 @@ func (rl *redissionLocker) Lock(timeout ...time.Duration) error {
acquireTimer := time.NewTimer(timeout[0]) acquireTimer := time.NewTimer(timeout[0])
for { for {
select { select {
case _, ok := <-subMsg: case _, ok := <-subMsg:
if !ok { if !ok {
err := errors.New("failed to read the lock waiting for for the channel message") err := errors.New("failed to read the lock waiting for for the channel message")
@ -113,7 +114,7 @@ func (rl *redissionLocker) subscribeLock(sub *redis.PubSub, out chan struct{}) {
select { select {
case <-rl.exit: case <-rl.exit:
break return
default: default:
switch msg.(type) { switch msg.(type) {
case *redis.Subscription: case *redis.Subscription:
@ -161,7 +162,7 @@ func (rl *redissionLocker) refreshLockTimeout() {
} }
timer.Reset(lockTime) timer.Reset(lockTime)
case <-rl.exit: case <-rl.exit:
break return
} }
} }
} }
@ -219,12 +220,8 @@ func (rl *redissionLocker) UnLock() error {
} }
func GetLocker(client *redis.Client, ops *RedissionLockConfig) *redissionLocker { func GetLocker(client *redis.Client, ops *RedissionLockConfig) *redissionLocker {
r := &redissionLocker{ if ops.Token == "" {
token: uuid.New().String(), ops.Token = uuid.New().String()
needRefresh: true,
client: client,
exit: make(chan struct{}),
logger: logger.GetLoggerInstance(),
} }
if len(ops.Prefix) <= 0 { if len(ops.Prefix) <= 0 {
@ -236,9 +233,17 @@ func GetLocker(client *redis.Client, ops *RedissionLockConfig) *redissionLocker
} }
if ops.LockLeaseTime == 0 { if ops.LockLeaseTime == 0 {
r.lockLeaseTime = internalLockLeaseTime ops.LockLeaseTime = internalLockLeaseTime
}
r := &redissionLocker{
token: ops.Token,
key: strings.Join([]string{ops.Prefix, ops.Key}, ":"),
waitChanKey: strings.Join([]string{ops.ChanPrefix, ops.Key, "wait"}, ":"),
needRefresh: ops.NeedRefresh,
client: client,
exit: make(chan struct{}),
logger: logger.GetLoggerInstance(),
} }
r.key = strings.Join([]string{ops.Prefix, ops.Key}, ":")
r.waitChanKey = strings.Join([]string{ops.ChanPrefix, ops.Key, "wait"}, ":")
return r return r
} }

View File

@ -32,11 +32,14 @@ func (rl *RedissionRWLocker) RLock(timeout ...time.Duration) error {
return fmt.Errorf("get read lock failed:%w", result) return fmt.Errorf("get read lock failed:%w", result)
} }
if (result.Code == constant.LockSuccess) && rl.needRefresh { if result.Code == constant.LockSuccess {
rl.once.Do(func() { if rl.needRefresh {
// async refresh lock timeout unitl receive exit singal rl.once.Do(func() {
go rl.refreshLockTimeout() // async refresh lock timeout unitl receive exit singal
}) go rl.refreshLockTimeout()
})
}
rl.logger.Info("success get the read by key and token", zap.String("key", rl.key), zap.String("token", rl.token))
return nil return nil
} }
@ -50,7 +53,6 @@ func (rl *RedissionRWLocker) RLock(timeout ...time.Duration) error {
acquireTimer := time.NewTimer(timeout[0]) acquireTimer := time.NewTimer(timeout[0])
for { for {
select { select {
case _, ok := <-subMsg: case _, ok := <-subMsg:
if !ok { if !ok {
err := errors.New("failed to read the read lock waiting for for the channel message") err := errors.New("failed to read the read lock waiting for for the channel message")
@ -117,12 +119,13 @@ func (rl *RedissionRWLocker) refreshLockTimeout() {
} }
timer.Reset(lockTime) timer.Reset(lockTime)
case <-rl.exit: case <-rl.exit:
break return
} }
} }
} }
func (rl *RedissionRWLocker) UnRLock() error { func (rl *RedissionRWLocker) UnRLock() error {
rl.logger.Info("unlock RLock by key and token", zap.String("key", rl.key), zap.String("token", rl.token))
res := rl.client.Eval(luascript.UnRLockScript, []string{rl.key, rl.rwTokenTimeoutPrefix, rl.waitChanKey}, unlockMessage, rl.token) res := rl.client.Eval(luascript.UnRLockScript, []string{rl.key, rl.rwTokenTimeoutPrefix, rl.waitChanKey}, unlockMessage, rl.token)
val, err := res.Int() val, err := res.Int()
if err != redis.Nil && err != nil { if err != redis.Nil && err != nil {
@ -237,24 +240,19 @@ func (rl *RedissionRWLocker) UnWLock() error {
} }
func GetRWLocker(client *redis.Client, ops *RedissionLockConfig) *RedissionRWLocker { func GetRWLocker(client *redis.Client, ops *RedissionLockConfig) *RedissionRWLocker {
r := &redissionLocker{ if ops.Token == "" {
token: uuid.New().String(), ops.Token = uuid.New().String()
needRefresh: true,
client: client,
exit: make(chan struct{}),
once: &sync.Once{},
logger: logger.GetLoggerInstance(),
} }
if len(ops.Prefix) <= 0 { if ops.Prefix == "" {
ops.Prefix = "redission-rwlock" ops.Prefix = "redission-rwlock"
} }
if len(ops.TimeoutPrefix) <= 0 { if ops.TimeoutPrefix == "" {
ops.TimeoutPrefix = "rwlock_timeout" ops.TimeoutPrefix = "rwlock_timeout"
} }
if len(ops.ChanPrefix) <= 0 { if ops.ChanPrefix == "" {
ops.ChanPrefix = "redission-rwlock-channel" ops.ChanPrefix = "redission-rwlock-channel"
} }
@ -262,9 +260,17 @@ func GetRWLocker(client *redis.Client, ops *RedissionLockConfig) *RedissionRWLoc
ops.LockLeaseTime = internalLockLeaseTime ops.LockLeaseTime = internalLockLeaseTime
} }
r.key = strings.Join([]string{ops.Prefix, ops.Key}, ":") r := &redissionLocker{
r.lockLeaseTime = ops.LockLeaseTime token: ops.Token,
r.waitChanKey = strings.Join([]string{ops.ChanPrefix, ops.Key, "write"}, ":") key: strings.Join([]string{ops.Prefix, ops.Key}, ":"),
needRefresh: ops.NeedRefresh,
lockLeaseTime: ops.LockLeaseTime,
waitChanKey: strings.Join([]string{ops.ChanPrefix, ops.Key, "write"}, ":"),
client: client,
exit: make(chan struct{}),
once: &sync.Once{},
logger: logger.GetLoggerInstance(),
}
rwLocker := &RedissionRWLocker{ rwLocker := &RedissionRWLocker{
redissionLocker: *r, redissionLocker: *r,

View File

@ -1,10 +1,12 @@
package distributed_lock package distributed_lock
import ( import (
"strings"
"testing" "testing"
"time" "time"
"github.com/go-redis/redis" "github.com/go-redis/redis"
"github.com/stretchr/testify/assert"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -14,6 +16,44 @@ func init() {
log = zap.Must(zap.NewDevelopment()) log = zap.Must(zap.NewDevelopment())
} }
func TestRWLockRLockAndUnRLock(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Network: "tcp",
Addr: "192.168.2.103:6379",
Password: "cnstar",
PoolSize: 50,
DialTimeout: 10 * time.Second,
})
rwLocker := GetRWLocker(rdb, &RedissionLockConfig{
LockLeaseTime: 120,
NeedRefresh: true,
Key: "component",
Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a",
})
rwLocker.logger = log
duration := 10 * time.Second
// 第一次加读锁
err := rwLocker.RLock(duration)
assert.Equal(t, nil, err)
tokenKey := strings.Join([]string{rwLocker.rwTokenTimeoutPrefix, rwLocker.token}, ":")
num, err := rdb.HGet(rwLocker.key, tokenKey).Int()
assert.Equal(t, nil, err)
assert.Equal(t, 1, num)
err = rwLocker.UnRLock()
assert.Equal(t, nil, err)
num, err = rdb.HGet(rwLocker.key, tokenKey).Int()
assert.Equal(t, redis.Nil, err)
assert.Equal(t, 0, num)
t.Log("test success")
return
}
// TODO 实现可重入读锁测试
func TestRWLockReentrantLock(t *testing.T) { func TestRWLockReentrantLock(t *testing.T) {
rdb := redis.NewClient(&redis.Options{ rdb := redis.NewClient(&redis.Options{
Network: "tcp", Network: "tcp",
@ -25,22 +65,43 @@ func TestRWLockReentrantLock(t *testing.T) {
rwLocker := GetRWLocker(rdb, &RedissionLockConfig{ rwLocker := GetRWLocker(rdb, &RedissionLockConfig{
LockLeaseTime: 120, LockLeaseTime: 120,
NeedRefresh: true,
Key: "component", Key: "component",
Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a",
}) })
rwLocker.logger = log rwLocker.logger = log
t.Logf("%+v\n", rwLocker)
duration := 10 * time.Second duration := 10 * time.Second
// 第一次加读锁 // 第一次加读锁
err := rwLocker.RLock(duration) err := rwLocker.RLock(duration)
t.Logf("err:%+v\n", err) assert.Equal(t, nil, err)
// TODO 实现可重入读锁测试
// rwLocker.UnRLock() tokenKey := strings.Join([]string{rwLocker.rwTokenTimeoutPrefix, rwLocker.token}, ":")
// // 第二次加读锁 num, err := rdb.HGet(rwLocker.key, tokenKey).Int()
// rwLocker.RLock(duration) assert.Equal(t, nil, err)
// // 查看 redis 中相关 key 的值 assert.Equal(t, 1, num)
// rwLocker.UnRLock()
// 第二次加读锁
err = rwLocker.RLock(duration)
assert.Equal(t, nil, err)
num, err = rdb.HGet(rwLocker.key, tokenKey).Int()
assert.Equal(t, nil, err)
assert.Equal(t, 2, num)
err = rwLocker.UnRLock()
assert.Equal(t, nil, err)
num, err = rdb.HGet(rwLocker.key, tokenKey).Int()
assert.Equal(t, redis.Nil, err)
assert.Equal(t, 1, num)
err = rwLocker.UnRLock()
assert.Equal(t, nil, err)
num, err = rdb.HGet(rwLocker.key, tokenKey).Int()
assert.Equal(t, redis.Nil, err)
assert.Equal(t, 0, num)
t.Log("test success") t.Log("test success")
select {} return
} }

3
go.mod
View File

@ -14,6 +14,7 @@ require (
github.com/natefinch/lumberjack v2.0.0+incompatible github.com/natefinch/lumberjack v2.0.0+incompatible
github.com/panjf2000/ants/v2 v2.10.0 github.com/panjf2000/ants/v2 v2.10.0
github.com/spf13/viper v1.19.0 github.com/spf13/viper v1.19.0
github.com/stretchr/testify v1.9.0
github.com/swaggo/files v1.0.1 github.com/swaggo/files v1.0.1
github.com/swaggo/gin-swagger v1.6.0 github.com/swaggo/gin-swagger v1.6.0
github.com/swaggo/swag v1.16.4 github.com/swaggo/swag v1.16.4
@ -30,6 +31,7 @@ require (
github.com/bytedance/sonic/loader v0.2.1 // indirect github.com/bytedance/sonic/loader v0.2.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.7 // indirect github.com/gabriel-vasile/mimetype v1.4.7 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
@ -60,6 +62,7 @@ require (
github.com/onsi/ginkgo v1.16.5 // indirect github.com/onsi/ginkgo v1.16.5 // indirect
github.com/onsi/gomega v1.18.1 // indirect github.com/onsi/gomega v1.18.1 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect