Rate Limiter는 간단하게 말해서 서버가 클라이언트의 시간당 요청횟수를 제한하는 기술을 의미한다.
보통 Api를 사용할 때 분당, 혹은 시간당 몇회 요청이 제한되어 있는 것을 본적이 있을 것이다.
서버가 제공할 수 있는 자원에는 한계가 있기 때문에 안정적으로 서비스를 제공하기 위해 사용하는 대표적인 혼잡 제어 기법이다.
실제로 규모가 큰 서비스(Linkedin, Github, Facebook, Amazon, Stripe 등)에는 대부분 Rate Limit 정책을 사용하고 있을 만큼 API 설계에 있어서 필수 적이다.
우리 팀은 현재 위치 추적 모듈을 개발하고 있으며, 이는 개발자들에게 편리한 위치 추적 api를 제공하는 서비스 이다.
일단 Spring Cloud Gateway(최근에는 mvc도 생김,, 여기서는 reactive gateway 만 논함)에는 RequestRateLimiter라는 기본 구현체가 존재한다.
사용법도 굉장히 간단한데,
dependencies {
//...
implementation 'org.springframework.boot:spring-boot-starter-data-redis-reactive'
implementation 'org.springframework.cloud:spring-cloud-starter-gateway'
//...
}
이렇게 두개만 추가해주면 AutoConfiguration에 의해 자동으로 등록되게 된다. (RedisTemplate의 존재 여부에 트리거링 되도록 되어 있다.)
package org.springframework.cloud.gateway.config;
@Configuration(proxyBeanMethods = false)
@AutoConfigureAfter(RedisReactiveAutoConfiguration.class)
@AutoConfigureBefore(GatewayAutoConfiguration.class)
@ConditionalOnBean(ReactiveRedisTemplate.class)
@ConditionalOnClass({ RedisTemplate.class, DispatcherHandler.class })
@ConditionalOnProperty(name = "spring.cloud.gateway.redis.enabled", matchIfMissing = true)
class GatewayRedisAutoConfiguration {
한가지는 필수적으로 제공해줘야 하는데, 바로 KeyResolver
인터페이스의 구현체를 빈으로 등록해줘야한다.
public interface KeyResolver {
Mono<String> resolve(ServerWebExchange exchange);
}
단순하다, exchange로 부터 키를 어떻게 구성할까라는 부분이다.
여기서 key란 하나의 사용자를 어떻게 구분할 것인가를 나타낸다.
ip를 기반으로 할 수도, api key를 기반으로 할 수도 세션을 기반으로 할 수도 있다.
그리고 다음과 같이 application.yml 에 설정 해준다.
spring:
cloud:
gateway:
routes:
- id: rate_limiter_route
uri: http://localhost:19000
filters:
- name: RequestRateLimiter
args:
key-resolver: "#{@userKeyResolver}"
redis-rate-limiter.replenishRate: 1
redis-rate-limiter.burstCapacity: 10
redis-rate-limiter.requestedTokens: 1
라우트가 여러개 일때는 저 속성이 공유되나요? 라고 질문할 수가 있다.
아니다. 저 라우트에 정의된 필터 하나당 AbstractGatewayFilterFactory#apply(C config)
를 통해 생성한다.
RateLimiter는 RequestRateLimiterGatewayFilterFactory
라는 구현체를 통해 GatewayFilter
가 생성된다.
그렇다 되는게 중요한게 아니라 원리가 중요하다.
이 구현체에서는 토큰 버킷 알고리즘이 사용되었다.
여러가지 알고리즘이 있지만, 이 글의 범위를 벗어나므로, 링크만 첨부한다(링크)
보다 자세한 원리는 이 글의 뒷부분에 나올 예정이다(중복되므로 뒤로 뺌)
다음과 같은 요구사항을 충족시키지 못했다.
따라서 기존 구현체를 참고하여 나만의 RateLimiter를 직접 구현해 보았다. 전체 코드는
Pull Request 나, code 를 참고하길 바란다.
핵심은 redis의 EVAL 커맨드다.
이 커맨드는 서버사이드(레디스)에서 Lua 스크립트를 실행시켜준다.
redis는 싱글 쓰레드로 동작하기 때문에(엄밀히 말하면 아니지만), 이 스크립트는 원자적으로 실행된다.
이 커맨드를 통해 Lua Script를 실행시키는게 핵심이다.
다음 Lua Script는 기존 구현체의 lua script를 참고해서 수정한 나의 코드이다.
redis.replicate_commands()
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)
local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local ttl = tonumber(ARGV[5])
local fill_time = capacity / rate
-- local ttl = math.floor(fill_time * 2)
if ttl == nil then
ttl = math.floor(fill_time * 2)
end
-- for testing, it should use redis system time in production
if now == nil then
now = redis.call('TIME')[1]
end
--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
--redis.log(redis.LOG_WARNING, "now " .. now)
--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
--redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
--redis.log(redis.LOG_WARNING, "ttl " .. ttl)
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)
local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
new_tokens = filled_tokens - requested
allowed_num = 1
end
--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)
if ttl > 0 then
redis.call("SET", tokens_key, new_tokens, "EX", ttl)
redis.call("SET", timestamp_key, now, "EX", ttl)
end
-- return { allowed_num, new_tokens, capacity, filled_tokens, requested, new_tokens }
return { allowed_num, new_tokens }
부분부분 나눠서 설명하겠다.
redis.replicate_commands()
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)
local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local ttl = tonumber(ARGV[5])
local fill_time = capacity / rate
-- local ttl = math.floor(fill_time * 2)
if ttl == nil then
ttl = math.floor(fill_time * 2)
end
-- for testing, it should use redis system time in production
if now == nil then
now = redis.call('TIME')[1]
end
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
last_tokens = capacity
end
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
last_refreshed = 0
end
local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
new_tokens = filled_tokens - requested
allowed_num = 1
end
if ttl > 0 then
redis.call("SET", tokens_key, new_tokens, "EX", ttl)
redis.call("SET", timestamp_key, now, "EX", ttl)
end
return { allowed_num, new_tokens }
아무튼 이 스크립트를 쓰게 되면, 말 그대로 token bucket 알고리즘이 수행되며, 분산환경에서도(서버가 여러대 있어도) 올바르게 rate limiter 기능을 할 수 있게 된다.
@Slf4j
@Getter @Setter
@ConfigurationProperties("spring.cloud.gateway.api-rate-limiter")
public class ApiRateLimiterGatewayFilterFactory extends AbstractGatewayFilterFactory<ApiRateLimiterGatewayFilterFactory.Config> {
private final ApiRateLimiter defaultRateLimiter;
private final ApiRateContextResolver defaultContextResolver;
private boolean denyEmptyKey = true;
private Long defaultRequestedTokens = 1L;
private HttpStatus emptyKeyStatus = HttpStatus.UNAUTHORIZED;
private HttpStatus invalidKeyStatus = HttpStatus.FORBIDDEN;
// ... main code
@Getter @Setter
public static class Config implements HasRouteId {
private ApiRateContextResolver contextResolver;
private ApiRateLimiter rateLimiter;
private Long requestedTokens = 1L;
private HttpStatus notAllowedStatus = HttpStatus.TOO_MANY_REQUESTS;
private Boolean denyEmptyKey;
private HttpStatus emptyKeyStatus = HttpStatus.UNAUTHORIZED;
private HttpStatus invalidKeyStatus = HttpStatus.FORBIDDEN;
private String routeId;
}
이런식으로 설정해서 bean으로 등록하게 되면, 다음과 같은 설정이 가능하게 된다.
#spring.application.name=gateway
spring:
cloud:
gateway:
routes:
- id: route1
uri: http://localhost:8000
predicates:
- Path=/test/**
filters:
- name: ApiRateLimiter
args:
contextResolver: "#{@apiRateContextResolver}"
rateLimiter: "#{@apiKeyRateLimiter}"
requestedTokens: 100
notAllowedStatus: TOO_MANY_REQUESTS
denyEmptyKey: true
emptyKeyStatus: UNAUTHORIZED
이렇게 설정하면 어떻게 되냐?
바로 저 args의 설정들이 내가 제너릭으로 명시한 저 Config 클래스에 매핑되어서 apply 메서드를 호출하게 된다.
그리고 여기서 바로 GatewayFilter를 만들어서 리턴하게 된다.
@Override
public GatewayFilter apply(Config config) {
ApiRateContextResolver rateContextResolver = getOrDefault(config.getContextResolver(), this.defaultContextResolver);
ApiRateLimiter rateLimiter = getOrDefault(config.getRateLimiter(), this.defaultRateLimiter);
boolean denyEmptyKey = getOrDefault(config.getDenyEmptyKey(), this.denyEmptyKey);
HttpStatus emptyKeyStatus = getOrDefault(config.getEmptyKeyStatus(), this.emptyKeyStatus);
HttpStatus invalidKeyStatus = getOrDefault(config.getInvalidKeyStatus(), this.invalidKeyStatus);
HttpStatus notAllowedStatus = getOrDefault(config.getNotAllowedStatus(), HttpStatus.TOO_MANY_REQUESTS);
Long requestedTokens = Math.max(0,getOrDefault(config.getRequestedTokens(), 1L));
return (exchange, chain) -> rateContextResolver.resolve(exchange).flatMap(rateContext -> {
if (!(rateContext instanceof ValidApiRateContext)) {
if (rateContext instanceof AbsentApiRateContext) {
if (denyEmptyKey) {
ServerWebExchangeUtils.setResponseStatus(exchange, emptyKeyStatus);
return exchange.getResponse().setComplete();
}
return chain.filter(exchange);
}
if (rateContext instanceof InvalidApiRateContext) {
ServerWebExchangeUtils.setResponseStatus(exchange, invalidKeyStatus);
return exchange.getResponse().setComplete();
}
// should never happen
log.error("Unknown ApiRateContext type: {}", rateContext.getClass().getName());
ServerWebExchangeUtils.setResponseStatus(exchange, emptyKeyStatus);
return exchange.getResponse().setComplete();
}
ValidApiRateContext context = (ValidApiRateContext) rateContext;
String routeId = config.getRouteId();
if (routeId == null) {
Route route = exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR);
routeId = route.getId();
}
return rateLimiter.isAllowed(routeId, context, requestedTokens).flatMap(response -> {
for (Map.Entry<String, String> header : response.getHeaders().entrySet()) {
exchange.getResponse().getHeaders().add(header.getKey(), header.getValue());
}
if (response.isAllowed()) {
return chain.filter(exchange);
}
ServerWebExchangeUtils.setResponseStatus(exchange, notAllowedStatus);
return exchange.getResponse().setComplete();
});
});
}
그러니, 저렇게 application.yml 에 라우트 마다 같은 필터를 여러개 선언해도, 각각 다른 인스턴스가 생성되는 것이다.
나는 애초에 사용자 별로 다른 토큰 버킷 알고리즘을 제공하고 싶었다. 그래서 직접 구현을 한 것이다.
그래서 ApiRateContext라는 인터페이스를 만들고 상속 자식들을 제한하였다.
public sealed interface ApiRateContext permits ValidApiRateContext, InvalidApiRateContext, AbsentApiRateContext {
}
그리고 이것을 resolve하는 인터페이스를 정의했다. 이것은 기본 구현체의 keyResolver 정도의 역할을 담당하는 것이다. 기본 구현체의 경우 String을 리턴하는 간단한 구현이었다.
public interface ApiRateContextResolver {
/**
* API 키에 대한 Rate Context를 조회한다.
* @param exchange ServerWebExchange
* @return Rate Context
*/
Mono<ApiRateContext> resolve(ServerWebExchange exchange);
}
이 Context를 기반으로 다양한 분기처리를 하도록 구현하였고, 또 토큰 버킷 알고리즘도, 각 사용자, api에 맞게 커스텀 할 수 있도록 확장 포인트를 열어 놓았다.
일단 다음 개발까지 임시로 이렇게 막아놓았다.
/**
* 다음 개발시까지 목업 해두기 위해 사용
*/
public class MockApiRateContextResolver implements ApiRateContextResolver {
private final static AbsentApiRateContext ABSENT_API_RATE_CONTEXT = new AbsentApiRateContext();
private final static ValidApiRateContext VALID_API_RATE_CONTEXT_FOR_MOCKED_TEST = new ValidApiRateContext("test-key", 10, 300);
@Override
public Mono<ApiRateContext> resolve(ServerWebExchange exchange) {
String apiKey = exchange.getRequest().getHeaders().getFirst(GatewayConstant.API_KEY_HEADER);
return resolve(apiKey);
}
protected Mono<ApiRateContext> resolve(String apiKey) {
if (apiKey == null) {
return Mono.just(ABSENT_API_RATE_CONTEXT);
}
// 실제 상황에서는 fetching 과 검증이 일어나야함
return Mono.just(VALID_API_RATE_CONTEXT_FOR_MOCKED_TEST);
}
}
이 것의 경우 기본 구현체의 RateLimiter 정도에 해당하는 역할인데, 인자와 응답값을 내가 다르게 하고 싶어서 따로 정의했다.
public interface ApiRateLimiter {
Mono<Response> isAllowed(String routeId, ValidApiRateContext apiRateContext, Long requestedTokens);
class Response {
private final boolean allowed;
private final Map<String, String> headers;
public Response(boolean allowed, Map<String, String> headers) {
this.allowed = allowed;
Assert.notNull(headers, "headers may not be null");
this.headers = headers;
}
public boolean isAllowed() {
return allowed;
}
public Map<String, String> getHeaders() {
return headers;
}
}
}
그리고 전체 코드는 아래와 같다.
isAllowed 를 따라가면서 읽어보면 되는데, 핵심은 tryAcquireToken에서 인자와 key를 만들어서 아까 설명했던 lua script를 실행시키는 내용이라는 것만 유념하면 된다.
@Slf4j
@Getter @Setter
@RequiredArgsConstructor
public class ApiRedisRateLimiter implements ApiRateLimiter {
/**
* Redis Rate Limiter property name.
*/
public static final String CONFIGURATION_PROPERTY_NAME = "redis-rate-limiter";
/**
* Replenish Rate Limit header name.
*/
public static final String REPLENISH_RATE_HEADER = "X-RateLimit-Replenish-Rate";
/**
* Burst Capacity header name.
*/
public static final String BURST_CAPACITY_HEADER = "X-RateLimit-Burst-Capacity";
/**
* Requested Tokens header name.
*/
public static final String REQUESTED_TOKENS_HEADER = "X-RateLimit-Requested-Tokens";
/**
* Remaining Rate Limit header name.
*/
public static final String REMAINING_HEADER = "X-RateLimit-Remaining";
private final ReactiveStringRedisTemplate redisTemplate;
private final RedisScript<List<Long>> script;
private final ApiRateContextResolver rateContextResolver;
// configuration properties
/**
* Whether or not to include headers containing rate limiter information, defaults to
* true.
*/
private boolean includeHeaders = true;
/**
* The name of the header that returns number of remaining requests during the current
* second.
*/
private String remainingHeader = REMAINING_HEADER;
/** The name of the header that returns the replenish rate configuration. */
private String replenishRateHeader = REPLENISH_RATE_HEADER;
/** The name of the header that returns the burst capacity configuration. */
private String burstCapacityHeader = BURST_CAPACITY_HEADER;
/** The name of the header that returns the requested tokens configuration. */
private String requestedTokensHeader = REQUESTED_TOKENS_HEADER;
static List<String> getKeys(String id) {
// use `{}` around keys to use Redis Key hash tags
// this allows for using redis cluster
// Make a unique key per user.
String prefix = "api_rate_limiter.{" + id;
// You need two Redis keys for Token Bucket.
String tokenKey = prefix + "}.tokens";
String timestampKey = prefix + "}.timestamp";
return Arrays.asList(tokenKey, timestampKey);
}
@Override
public Mono<Response> isAllowed(String routeId, ValidApiRateContext apiRateContext, Long requestedTokens) {
long replenishRate = apiRateContext.replenishRate();
long burstCapacity = apiRateContext.burstCapacity();
String apiKey = apiRateContext.key();
try {
return tryAcquireToken(apiKey, replenishRate, burstCapacity, requestedTokens);
} catch (Exception e) {
// redis에 의해 single point of failure가 되지 않도록 그냥 로그만 남기고 통과
log.error("Error determining if user allowed from redis", e);
return Mono.just(new Response(true, getHeaders(replenishRate, burstCapacity, requestedTokens, -1L)));
}
}
protected Mono<Response> tryAcquireToken(String key, long replenishRate, long burstCapacity, long requestedTokens) {
long ttl = timeToLive(replenishRate, burstCapacity, requestedTokens);
long now = Instant.now().getEpochSecond();
List<String> keys = getKeys(key);
List<String> scriptArgs = Arrays.asList(
String.valueOf(replenishRate), //ARGV[1]
String.valueOf(burstCapacity), //ARGV[2]
String.valueOf(now), //ARGV[3]
String.valueOf(requestedTokens), //ARGV[4]
String.valueOf(ttl) //ARGV[5]
);
return this.redisTemplate.execute(this.script, keys, scriptArgs)
.onErrorResume(throwable -> {
if (log.isDebugEnabled()) {
log.debug("Error calling rate limiter lua", throwable);
}
return Flux.just(Arrays.asList(1L, -1L));
}).reduce(new ArrayList<Long>(), (longs, l) -> {
longs.addAll(l);
return longs;
}).map(results -> {
boolean allowed = results.get(0) == 1L;
Long tokensLeft = results.get(1);
Response response = new Response(allowed, getHeaders(replenishRate, burstCapacity, requestedTokens, tokensLeft));
if (log.isDebugEnabled()) {
log.debug("response: " + response);
}
return response;
});
}
protected long timeToLive(long replenishRate, long burstCapacity, long requestedTokens) {
long fillTime = burstCapacity / replenishRate;
return fillTime * 2;
}
@Override
public String toString() {
return new ToStringCreator(this)
.append("redisTemplate", redisTemplate)
.append("script", script)
.toString();
}
public Map<String, String> getHeaders(long replenishRateHeader, long burstCapacityHeader, long requestedTokensHeader, Long tokensLeft) {
Map<String, String> headers = new HashMap<>();
if (isIncludeHeaders()) {
headers.put(this.remainingHeader, tokensLeft.toString());
headers.put(this.replenishRateHeader, String.valueOf(replenishRateHeader));
headers.put(this.burstCapacityHeader, String.valueOf(burstCapacityHeader));
headers.put(this.requestedTokensHeader, String.valueOf(requestedTokensHeader));
}
return headers;
}
}
오늘의 핵심 내용은 세가지다.