1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
| @Component public class RedisRateLimiter { @Autowired private RedisTemplate<String, Object> redisTemplate;
public boolean isAllowedFixedWindow(String key, int limit, int window) { String luaScript = "local key = KEYS[1] " + "local limit = tonumber(ARGV[1]) " + "local window = tonumber(ARGV[2]) " + "local current = redis.call('get', key) " + "if current == false then " + " redis.call('set', key, 1, 'EX', window) " + " return 1 " + "elseif tonumber(current) < limit then " + " redis.call('incr', key) " + " return 1 " + "else " + " return 0 " + "end"; DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>(); redisScript.setScriptText(luaScript); redisScript.setResultType(Long.class); Long result = redisTemplate.execute(redisScript, Collections.singletonList(key), String.valueOf(limit), String.valueOf(window)); return Long.valueOf(1).equals(result); }
public boolean isAllowedSlidingWindow(String key, int limit, int window) { long now = System.currentTimeMillis(); long windowStart = now - window * 1000L; String luaScript = "local key = KEYS[1] " + "local limit = tonumber(ARGV[1]) " + "local windowStart = tonumber(ARGV[2]) " + "local now = tonumber(ARGV[3]) " + "redis.call('zremrangebyscore', key, 0, windowStart) " + "local current = redis.call('zcard', key) " + "if current < limit then " + " redis.call('zadd', key, now, now) " + " redis.call('expire', key, " + window + ") " + " return 1 " + "else " + " return 0 " + "end"; DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>(); redisScript.setScriptText(luaScript); redisScript.setResultType(Long.class); Long result = redisTemplate.execute(redisScript, Collections.singletonList(key), String.valueOf(limit), String.valueOf(windowStart), String.valueOf(now)); return Long.valueOf(1).equals(result); }
public boolean isAllowedTokenBucket(String key, int capacity, double refillRate, int tokens) { long now = System.currentTimeMillis(); String luaScript = "local key = KEYS[1] " + "local capacity = tonumber(ARGV[1]) " + "local refillRate = tonumber(ARGV[2]) " + "local tokens = tonumber(ARGV[3]) " + "local now = tonumber(ARGV[4]) " + "local bucket = redis.call('hmget', key, 'tokens', 'lastRefill') " + "local currentTokens = tonumber(bucket[1]) or capacity " + "local lastRefill = tonumber(bucket[2]) or now " + "local timePassed = (now - lastRefill) / 1000 " + "local newTokens = math.min(capacity, currentTokens + timePassed * refillRate) " + "if newTokens >= tokens then " + " newTokens = newTokens - tokens " + " redis.call('hmset', key, 'tokens', newTokens, 'lastRefill', now) " + " redis.call('expire', key, 3600) " + " return 1 " + "else " + " redis.call('hmset', key, 'tokens', newTokens, 'lastRefill', now) " + " redis.call('expire', key, 3600) " + " return 0 " + "end"; DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>(); redisScript.setScriptText(luaScript); redisScript.setResultType(Long.class); Long result = redisTemplate.execute(redisScript, Collections.singletonList(key), String.valueOf(capacity), String.valueOf(refillRate), String.valueOf(tokens), String.valueOf(now)); return Long.valueOf(1).equals(result); } }
/**
 * Sample endpoint guarded by a per-client sliding-window rate limit.
 */
@RestController
public class ApiController {

    /** Requests allowed per client in each trailing window. */
    private static final int RATE_LIMIT_REQUESTS = 100;

    /** Length of the trailing rate-limit window, in seconds. */
    private static final int RATE_LIMIT_WINDOW_SECONDS = 60;

    /**
     * Addresses treated as our own proxies/load balancers, whose X-Forwarded-For header is honoured:
     * loopback, link-local, RFC 1918 private and 100.64/10 ranges, modelled on Tomcat
     * RemoteIpValve's default internalProxies. NOTE(review): adjust to the real deployment topology.
     */
    private static final java.util.regex.Pattern TRUSTED_PROXY = java.util.regex.Pattern.compile(
            "10\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}"
            + "|192\\.168\\.\\d{1,3}\\.\\d{1,3}"
            + "|169\\.254\\.\\d{1,3}\\.\\d{1,3}"
            + "|127\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}"
            + "|100\\.(6[4-9]|[7-9]\\d|1[01]\\d|12[0-7])\\.\\d{1,3}\\.\\d{1,3}"
            + "|172\\.(1[6-9]|2\\d|3[01])\\.\\d{1,3}\\.\\d{1,3}"
            + "|0:0:0:0:0:0:0:1|::1");

    @Autowired
    private RedisRateLimiter rateLimiter;

    /**
     * Returns the data payload, or 429 once the caller exceeds the rate limit.
     */
    @GetMapping("/api/data")
    public ResponseEntity<?> getData(HttpServletRequest request) {
        String rateLimitKey = "rate_limit:" + getClientIp(request);
        if (!rateLimiter.isAllowedSlidingWindow(rateLimitKey, RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW_SECONDS)) {
            return ResponseEntity.status(HttpStatus.TOO_MANY_REQUESTS)
                    .body("Rate limit exceeded");
        }
        return ResponseEntity.ok("Data response");
    }

    /**
     * Resolves the client address used as the rate-limit identity.
     *
     * <p>X-Forwarded-For is client-controlled: trusting its leftmost entry lets any caller pick a
     * fresh identity per request and bypass the limit. The header is therefore used only when the
     * direct peer is a trusted proxy, and is read right-to-left (proxies append the address they saw),
     * stopping at the first hop that is not itself a trusted proxy.
     */
    private String getClientIp(HttpServletRequest request) {
        String remoteAddr = request.getRemoteAddr();
        String xForwardedFor = request.getHeader("X-Forwarded-For");
        if (xForwardedFor == null || xForwardedFor.isEmpty() || !isTrustedProxy(remoteAddr)) {
            return remoteAddr;
        }
        String[] hops = xForwardedFor.split(",");
        String candidate = remoteAddr;
        for (int i = hops.length - 1; i >= 0; i--) {
            String hop = hops[i].trim();
            if (hop.isEmpty()) {
                continue;
            }
            candidate = hop;
            if (!isTrustedProxy(hop)) {
                break; // first untrusted hop from the right is the real client
            }
        }
        return candidate;
    }

    /** True if the address literal falls in one of the TRUSTED_PROXY ranges. */
    private static boolean isTrustedProxy(String address) {
        return address != null && TRUSTED_PROXY.matcher(address).matches();
    }
}
|