diff --git a/distributedlock/luascript/rwlock_script.go b/distributedlock/luascript/rwlock_script.go index 2374e4b..956b909 100644 --- a/distributedlock/luascript/rwlock_script.go +++ b/distributedlock/luascript/rwlock_script.go @@ -30,7 +30,7 @@ if (mode == 'read') then if (redis.call('exists', KEYS[1], ARGV[2]) == 1) then redis.call('hincrby', KEYS[1], lockKey, '1'); local remainTime = redis.call('httl', KEYS[1], 'fields', '1', lockKey); - redis.call('hexpire', KEYS[1], math.max(remainTime, ARGV[1]), 'fields', '1', lockKey); + redis.call('hexpire', KEYS[1], math.max(tonumber(remainTime[1]), ARGV[1]), 'fields', '1', lockKey); else redis.call('hset', KEYS[1], lockKey, '1'); redis.call('hexpire', KEYS[1], ARGV[1], 'fields', '1', lockKey); @@ -222,7 +222,7 @@ local mode = redis.call('hget', KEYS[1], 'mode') local maxRemainTime = tonumber(ARGV[1]); if (lockExists == 1) then redis.call('hexpire', KEYS[1], ARGV[1], 'fields', '1', lockKey); - if (mode == 'read' ) then + if (mode == 'read') then local cursor = 0; local pattern = KEYS[2] .. ':*'; repeat diff --git a/distributedlock/redis_lock.go b/distributedlock/redis_lock.go index 8e55b29..651560a 100644 --- a/distributedlock/redis_lock.go +++ b/distributedlock/redis_lock.go @@ -24,10 +24,12 @@ const ( // RedissionLockConfig define redission lock config type RedissionLockConfig struct { LockLeaseTime uint64 + Token string Prefix string ChanPrefix string TimeoutPrefix string Key string + NeedRefresh bool } type redissionLocker struct { @@ -70,7 +72,6 @@ func (rl *redissionLocker) Lock(timeout ...time.Duration) error { acquireTimer := time.NewTimer(timeout[0]) for { select { - case _, ok := <-subMsg: if !ok { err := errors.New("failed to read the lock waiting for for the channel message") @@ -113,7 +114,7 @@ func (rl *redissionLocker) subscribeLock(sub *redis.PubSub, out chan struct{}) { select { case <-rl.exit: - break + return default: switch msg.(type) { case *redis.Subscription: @@ -161,7 +162,7 @@ func (rl *redissionLocker) refreshLockTimeout() { } timer.Reset(lockTime) case <-rl.exit: - break + return } } } @@ -219,12 +220,8 @@ func (rl *redissionLocker) UnLock() error { } func GetLocker(client *redis.Client, ops *RedissionLockConfig) *redissionLocker { - r := &redissionLocker{ - token: uuid.New().String(), - needRefresh: true, - client: client, - exit: make(chan struct{}), - logger: logger.GetLoggerInstance(), + if ops.Token == "" { + ops.Token = uuid.New().String() } if len(ops.Prefix) <= 0 { @@ -236,9 +233,17 @@ func GetLocker(client *redis.Client, ops *RedissionLockConfig) *redissionLocker } if ops.LockLeaseTime == 0 { - r.lockLeaseTime = internalLockLeaseTime + ops.LockLeaseTime = internalLockLeaseTime + } + + r := &redissionLocker{ + token: ops.Token, + key: strings.Join([]string{ops.Prefix, ops.Key}, ":"), + waitChanKey: strings.Join([]string{ops.ChanPrefix, ops.Key, "wait"}, ":"), + needRefresh: ops.NeedRefresh, + client: client, + exit: make(chan struct{}), + logger: logger.GetLoggerInstance(), } - r.key = strings.Join([]string{ops.Prefix, ops.Key}, ":") - r.waitChanKey = strings.Join([]string{ops.ChanPrefix, ops.Key, "wait"}, ":") return r } diff --git a/distributedlock/redis_rwlock.go b/distributedlock/redis_rwlock.go index caee734..7cf89f5 100644 --- a/distributedlock/redis_rwlock.go +++ b/distributedlock/redis_rwlock.go @@ -32,11 +32,14 @@ func (rl *RedissionRWLocker) RLock(timeout ...time.Duration) error { return fmt.Errorf("get read lock failed:%w", result) } - if (result.Code == constant.LockSuccess) && rl.needRefresh { - rl.once.Do(func() { - // async refresh lock timeout unitl receive exit singal - go rl.refreshLockTimeout() - }) + if result.Code == constant.LockSuccess { + if rl.needRefresh { + rl.once.Do(func() { + // async refresh lock timeout unitl receive exit singal + go rl.refreshLockTimeout() + }) + } + rl.logger.Info("success get the read by key and token", zap.String("key", rl.key), zap.String("token", rl.token)) return nil } @@ -50,7 +53,6 @@ func (rl *RedissionRWLocker) RLock(timeout ...time.Duration) error { acquireTimer := time.NewTimer(timeout[0]) for { select { - case _, ok := <-subMsg: if !ok { err := errors.New("failed to read the read lock waiting for for the channel message") @@ -117,12 +119,13 @@ func (rl *RedissionRWLocker) refreshLockTimeout() { } timer.Reset(lockTime) case <-rl.exit: - break + return } } } func (rl *RedissionRWLocker) UnRLock() error { + rl.logger.Info("unlock RLock by key and token", zap.String("key", rl.key), zap.String("token", rl.token)) res := rl.client.Eval(luascript.UnRLockScript, []string{rl.key, rl.rwTokenTimeoutPrefix, rl.waitChanKey}, unlockMessage, rl.token) val, err := res.Int() if err != redis.Nil && err != nil { @@ -237,24 +240,19 @@ func (rl *RedissionRWLocker) UnWLock() error { } func GetRWLocker(client *redis.Client, ops *RedissionLockConfig) *RedissionRWLocker { - r := &redissionLocker{ - token: uuid.New().String(), - needRefresh: true, - client: client, - exit: make(chan struct{}), - once: &sync.Once{}, - logger: logger.GetLoggerInstance(), + if ops.Token == "" { + ops.Token = uuid.New().String() } - if len(ops.Prefix) <= 0 { + if ops.Prefix == "" { ops.Prefix = "redission-rwlock" } - if len(ops.TimeoutPrefix) <= 0 { + if ops.TimeoutPrefix == "" { ops.TimeoutPrefix = "rwlock_timeout" } - if len(ops.ChanPrefix) <= 0 { + if ops.ChanPrefix == "" { ops.ChanPrefix = "redission-rwlock-channel" } @@ -262,9 +260,17 @@ func GetRWLocker(client *redis.Client, ops *RedissionLockConfig) *RedissionRWLoc ops.LockLeaseTime = internalLockLeaseTime } - r.key = strings.Join([]string{ops.Prefix, ops.Key}, ":") - r.lockLeaseTime = ops.LockLeaseTime - r.waitChanKey = strings.Join([]string{ops.ChanPrefix, ops.Key, "write"}, ":") + r := &redissionLocker{ + token: ops.Token, + key: strings.Join([]string{ops.Prefix, ops.Key}, ":"), + needRefresh: ops.NeedRefresh, + lockLeaseTime: ops.LockLeaseTime, + waitChanKey: strings.Join([]string{ops.ChanPrefix, ops.Key, "write"}, ":"), + client: client, + exit: make(chan struct{}), + once: &sync.Once{}, + logger: logger.GetLoggerInstance(), + } rwLocker := &RedissionRWLocker{ redissionLocker: *r, diff --git a/distributedlock/rwlock_test.go b/distributedlock/rwlock_test.go index 8ecaaaa..b3fc342 100644 --- a/distributedlock/rwlock_test.go +++ b/distributedlock/rwlock_test.go @@ -1,10 +1,12 @@ package distributed_lock import ( + "strings" "testing" "time" "github.com/go-redis/redis" + "github.com/stretchr/testify/assert" "go.uber.org/zap" ) @@ -14,6 +16,44 @@ func init() { log = zap.Must(zap.NewDevelopment()) } +func TestRWLockRLockAndUnRLock(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Network: "tcp", + Addr: "192.168.2.103:6379", + Password: "cnstar", + PoolSize: 50, + DialTimeout: 10 * time.Second, + }) + + rwLocker := GetRWLocker(rdb, &RedissionLockConfig{ + LockLeaseTime: 120, + NeedRefresh: true, + Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a", + }) + rwLocker.logger = log + + duration := 10 * time.Second + // 第一次加读锁 + err := rwLocker.RLock(duration) + assert.Equal(t, nil, err) + + tokenKey := strings.Join([]string{rwLocker.rwTokenTimeoutPrefix, rwLocker.token}, ":") + num, err := rdb.HGet(rwLocker.key, tokenKey).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + err = rwLocker.UnRLock() + assert.Equal(t, nil, err) + + num, err = rdb.HGet(rwLocker.key, tokenKey).Int() + assert.Equal(t, redis.Nil, err) + assert.Equal(t, 0, num) + t.Log("test success") + return +} + +// TODO 实现可重入读锁测试 func TestRWLockReentrantLock(t *testing.T) { rdb := redis.NewClient(&redis.Options{ Network: "tcp", @@ -25,22 +65,43 @@ func TestRWLockReentrantLock(t *testing.T) { rwLocker := GetRWLocker(rdb, &RedissionLockConfig{ LockLeaseTime: 120, + NeedRefresh: true, Key: "component", + Token: "fd348a84-e07c-4a61-8c19-f753e6bc556a", }) - rwLocker.logger = log - t.Logf("%+v\n", rwLocker) duration := 10 * time.Second // 第一次加读锁 err := rwLocker.RLock(duration) - t.Logf("err:%+v\n", err) - // TODO 实现可重入读锁测试 - // rwLocker.UnRLock() - // // 第二次加读锁 - // rwLocker.RLock(duration) - // // 查看 redis 中相关 key 的值 - // rwLocker.UnRLock() + assert.Equal(t, nil, err) + + tokenKey := strings.Join([]string{rwLocker.rwTokenTimeoutPrefix, rwLocker.token}, ":") + num, err := rdb.HGet(rwLocker.key, tokenKey).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 1, num) + + // 第二次加读锁 + err = rwLocker.RLock(duration) + assert.Equal(t, nil, err) + + num, err = rdb.HGet(rwLocker.key, tokenKey).Int() + assert.Equal(t, nil, err) + assert.Equal(t, 2, num) + + err = rwLocker.UnRLock() + assert.Equal(t, nil, err) + + num, err = rdb.HGet(rwLocker.key, tokenKey).Int() + assert.Equal(t, redis.Nil, err) + assert.Equal(t, 1, num) + + err = rwLocker.UnRLock() + assert.Equal(t, nil, err) + + num, err = rdb.HGet(rwLocker.key, tokenKey).Int() + assert.Equal(t, redis.Nil, err) + assert.Equal(t, 0, num) t.Log("test success") - select {} + return } diff --git a/go.mod b/go.mod index 3d655e8..21444d7 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/natefinch/lumberjack v2.0.0+incompatible github.com/panjf2000/ants/v2 v2.10.0 github.com/spf13/viper v1.19.0 + github.com/stretchr/testify v1.9.0 github.com/swaggo/files v1.0.1 github.com/swaggo/gin-swagger v1.6.0 github.com/swaggo/swag v1.16.4 @@ -30,6 +31,7 @@ require ( github.com/bytedance/sonic/loader v0.2.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gabriel-vasile/mimetype v1.4.7 // indirect github.com/gin-contrib/sse v0.1.0 // indirect @@ -60,6 +62,7 @@ require ( github.com/onsi/ginkgo v1.16.5 // indirect github.com/onsi/gomega v1.18.1 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect