diff --git a/distributedlock/luascript/rlock_script.go b/distributedlock/luascript/rlock_script.go index d2d72f1..ee5805c 100644 --- a/distributedlock/luascript/rlock_script.go +++ b/distributedlock/luascript/rlock_script.go @@ -1,14 +1,15 @@ // Package luascript defines the lua script used for redis distributed lock package luascript -// RlockScript is the lua script for the lock read lock command +// RLockScript is the lua script for the lock read lock command /* KEYS[1]:锁的键名(key),通常是锁的唯一标识。 KEYS[2]:锁的超时键名前缀(rwTimeoutPrefix),用于存储每个读锁的超时键。 -ARGV[1]:锁的过期时间(lockLeaseTime),单位为毫秒。 +ARGV[1]:锁的过期时间(lockLeaseTime),单位为秒。 ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 */ -var RlockScript = `local mode = redis.call('hget', KEYS[1], 'mode'); +var RLockScript = ` +local mode = redis.call('hget', KEYS[1], 'mode'); local lockKey = KEYS[2] .. ARGV[2]; if (mode == false) then redis.call('hset', KEYS[1], 'mode', 'read'); @@ -19,6 +20,9 @@ if (mode == false) then end; if (mode == 'write') then + -- TODO 放到 list 中等待写锁释放后再次尝试加锁并且订阅写锁释放的消息 + local key = KEYS[1] .. ':read'; + redis.call('rpush', key, ARGV[2]); return -1; end; @@ -42,7 +46,7 @@ if (mode == 'read') then for i = 1, #fields,2 do local field = fields[i]; local remainTime = redis.call('httl', KEYS[1], 'fields', '1', field); - maxRemainTime = math.max(tonumber(remainTime[1]), maxRemainTime); + maxRemainTime = math.max(tonumber(remainTime[1]), maxRemainTime); end; until cursor == 0; @@ -52,56 +56,198 @@ if (mode == 'read') then end; ` -// TODO 优化读锁解锁语句 -// UnRlockScript is the lua script for the unlock read lock command +// UnRLockScript is the lua script for the unlock read lock command /* KEYS[1]:锁的键名(key),通常是锁的唯一标识。 -KEYS[2]:锁的释放通知频道(chankey),用于通知其他客户端锁已释放。 -KEYS[3]:锁的超时键名前缀(rwTimeoutTokenPrefix),用于存储每个读锁的超时键。 -KEYS[4]:锁的超时键名前缀(prefixKey),用于存储每个读锁的超时键。 +KEYS[2]:锁的超时键名前缀(rwTimeoutPrefix),用于存储每个读锁的超时键。 +KEYS[3]:锁的释放通知读频道(chankey),用于通知其他客户端锁已释放。 +KEYS[4]:锁的释放通知写频道(chankey),用于通知其他客户端锁已释放。 ARGV[1]:解锁消息(unlockMessage),用于通知其他客户端锁已释放。 ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 */ -var UnRlockScript = `local mode = redis.call('hget', KEYS[1], 'mode'); +var UnRLockScript = ` +local lockKey = KEYS[2] .. ARGV[2]; +local mode = redis.call('hget', KEYS[1], 'mode'); if (mode == false) then - redis.call('publish', KEYS[2], ARGV[1]); + local writeWait = KEYS[1] .. ':write'; + -- 优先写锁加锁,无写锁的情况通知读锁加锁 + local counter = redis.call('llen',writeWait) + if (counter >= 1) then + redis.call('publish', KEYS[4], ARGV[1]); + else + redis.call('publish', KEYS[3], ARGV[1]); + end; + return 1; +elseif (mode == 'write') then + return -2; +end; + +-- 判断当前的确是读模式但是当前 token 并没有加读锁的情况,返回 1 +local lockExists = redis.call('hexists', KEYS[1], lockKey); +if ((mode == 'read') and (lockExists == 0)) then return 1; end; -local lockExists = redis.call('hexists', KEYS[1], ARGV[2]); -if (lockExists == 0) then - return nil; -end; -local counter = redis.call('hincrby', KEYS[1], ARGV[2], -1); +local counter = redis.call('hincrby', KEYS[1], lockKey, -1); if (counter == 0) then - redis.call('hdel', KEYS[1], ARGV[2]); + redis.call('hdel', KEYS[1], lockKey); end; -redis.call('del', KEYS[3] .. ':' .. (counter+1)); if (redis.call('hlen', KEYS[1]) > 1) then - local maxRemainTime = -3; - local keys = redis.call('hkeys', KEYS[1]); - for n, key in ipairs(keys) do - counter = tonumber(redis.call('hget', KEYS[1], key)); - if type(counter) == 'number' then - for i=counter, 1, -1 do - local remainTime = redis.call('ttl', KEYS[4] .. ':' .. key .. ':rwlock_timeout:' .. i); - maxRemainTime = math.max(remainTime, maxRemainTime); + local cursor = 0; + local maxRemainTime = 0; + local pattern = KEYS[2] .. ':*'; + repeat + local hscanResult = redis.call('hscan', KEYS[1], cursor, 'match', pattern, 'count', '100'); + cursor = tonumber(hscanResult[1]); + local fields = hscanResult[2]; + + for i = 1, #fields,2 do + local field = fields[i]; + local remainTime = redis.call('httl', KEYS[1], 'fields', '1', field); + maxRemainTime = math.max(tonumber(remainTime[1]), maxRemainTime); + end; + until cursor == 0; + + if (maxRemainTime > 0) then + local remainTime = redis.call('ttl', KEYS[1]); + redis.call('expire', KEYS[1], math.max(tonumber(remainTime),maxRemainTime)); + end; +else + redis.call('del', KEYS[1]); + local writeWait = KEYS[1] .. ':write'; + -- 优先写锁加锁,无写锁的情况通知读锁加锁 + local counter = redis.call('llen',writeWait) + if (counter >= 1) then + redis.call('publish', KEYS[4], ARGV[1]); + else + redis.call('publish', KEYS[3], ARGV[1]); + end; + return 1; +end; +` + +// WLockScript is the lua script for the lock write lock command +/* +KEYS[1]:锁的键名(key),通常是锁的唯一标识。 +KEYS[2]:锁的超时键名前缀(rwTimeoutPrefix),用于存储每个读锁的超时键。 +ARGV[1]:锁的过期时间(lockLeaseTime),单位为秒。 +ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 +*/ +var WLockScript = ` +local mode = redis.call('hget', KEYS[1], 'mode'); +local lockKey = KEYS[2] .. ARGV[2]; +local waitKey = KEYS[1] .. ':write'; +if (mode == false) then + local firstToken = redis.call('lindex', waitKey,'0') + if (firstToken ~= ARGV[2]) then + return -7; + end; + redis.call('hset', KEYS[1], 'mode', 'write'); + redis.call('hset', KEYS[1], lockKey, 1); + redis.call('hexpire', KEYS[1], ARGV[1] 'fields' '1' lockKey); + redis.call('expire', KEYS[1], ARGV[1]); + redis.call('lpop', waitKey, '1') + return 1; +elseif (mode == 'read') then + -- TODO 放到 list 中等待读锁释放后再次尝试加锁并且订阅读锁释放的消息 + redis.call('rpush', waitkey, ARGV[2]); + return -3; +else + // 可重入写锁逻辑 + local lockKey = KEYS[2] .. ARGV[2] + local lockExists = redis.call('hexists', KEYS[1], lockKey) + if (lockExists == 1) then + redis.call('hincrby', KEYS[1], lockKey, 1); + redis.call('hexpire', KEYS[1], ARGV[1] 'fields' '1' lockKey); + redis.call('expire', KEYS[1], ARGV[1]); + return 1; + end; + -- 放到 list 中等待写锁释放后再次尝试加锁并且订阅写锁释放的消息 + local key = KEYS[1] .. ':write'; + redis.call('rpush', key, ARGV[2]); + return -4; +end; +` + +// UnWLockScript is the lua script for the unlock write lock command +/* +KEYS[1]:锁的键名(key),通常是锁的唯一标识。 +KEYS[2]:锁的超时键名前缀(rwTimeoutPrefix),用于存储每个读锁的超时键。 +KEYS[3]:锁的释放通知读频道(chankey),用于通知其他客户端锁已释放。 +KEYS[4]:锁的释放通知写频道(chankey),用于通知其他客户端锁已释放。 +ARGV[1]:解锁消息(unlockMessage),用于通知其他客户端锁已释放。 +ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 +*/ +var UnWLockScript = ` +local mode = redis.call('hget', KEYS[1], 'mode'); +local writeWait = KEYS[1] .. ':write'; +if (mode == false) then + -- 优先写锁加锁,无写锁的情况通知读锁加锁 + local counter = redis.call('llen',writeWait) + if (counter >= 1) then + redis.call('publish', KEYS[4], ARGV[1]); + else + redis.call('publish', KEYS[3], ARGV[1]); + end; + return 1; +elseif (mode == 'read') then + return -5; +else + // 可重入写锁逻辑 + local lockKey = KEYS[2] .. ARGV[2] + local lockExists = redis.call('hexists', KEYS[1], lockKey) + if (lockExists == 1) then + local incrRes = redis.call('hincrby', KEYS[1], lockKey, -1); + if (incrRes == 0) then + redis.call('del', KEYS[1]); + local counter = redis.call('llen',writeWait) + if (counter >= 1) then + redis.call('publish', KEYS[4], ARGV[1]); + else + redis.call('publish', KEYS[3], ARGV[1]); end; + return 1 end; end; - - if maxRemainTime > 0 then - redis.call('pexpire', KEYS[1], maxRemainTime); - return 0; - end; - - if mode == 'write' then - return 0; - end; + return -6; end; - -redis.call('del', KEYS[1]); -redis.call('publish', KEYS[2], ARGV[1]); -return 1; +` + +// RefreshLockScript is the lua script for the refresh lock command +/* +KEYS[1]:锁的键名(key),通常是锁的唯一标识。 +KEYS[2]:锁的超时键名前缀(rwTimeoutPrefix),用于存储每个读锁的超时键。 +ARGV[1]:锁的过期时间(lockLeaseTime),单位为秒。 +ARGV[2]:当前客户端的唯一标识(token),用于区分不同的客户端。 +*/ +var RefreshLockScript = ` +local lockKey = KEYS[2] .. ARGV[2] +local lockExists = redis.call('hexists', KEYS[1], lockKey); +local mode = redis.call('hget', KEYS[1], 'mode') +if (lockExists == 1) then + redis.call('hexpire', KEYS[1], ARGV[1] 'fields' '1' lockKey); + if (mode == 'read' ) then + local cursor = 0; + local maxRemainTime = tonumber(ARGV[1]); + local pattern = KEYS[2] .. ':*'; + repeat + local hscanResult = redis.call('hscan', KEYS[1], cursor, 'match', pattern, 'count', '100'); + cursor = tonumber(hscanResult[1]); + local fields = hscanResult[2]; + + for i = 1, #fields,2 do + local field = fields[i]; + local remainTime = redis.call('httl', KEYS[1], 'fields', '1', field); + maxRemainTime = math.max(tonumber(remainTime[1]), maxRemainTime); + end; + until cursor == 0; + if (maxRemainTime > 0) then + local remainTime = redis.call('ttl', KEYS[1]); + redis.call('expire', KEYS[1], math.max(tonumber(remainTime),maxRemainTime)); + end; + end; + return 1; +end; +return -8; ` diff --git a/distributedlock/redis_lock.go b/distributedlock/redis_lock.go index e66c03c..cbf79dc 100644 --- a/distributedlock/redis_lock.go +++ b/distributedlock/redis_lock.go @@ -80,7 +80,7 @@ type redissionLocker struct { } func (rl *redissionLocker) Lock(ctx context.Context, timeout ...time.Duration) { - fmt.Println(luascript.RlockScript) + fmt.Println(luascript.RLockScript) if rl.exit == nil { rl.exit = make(chan struct{}) }