Spring Legacy(Spring boot 를 사용하지 않는 환경)에서 작성한 코드 및 설정입니다.
Spring Boot 쓰시는 분들은 Filter 등록법을 조금만 검색해서
이 글을 응용하시면 될 듯합니다 😎(절대 귀찮아서 안쓰는 거 아닙니다.)
<filter>
<filter-name>encodingFilter</filter-name>
<filter-class>org.springframework.web.filter.CharacterEncodingFilter</filter-class>
<init-param>
<param-name>encoding</param-name>
<param-value>utf-8</param-value>
</init-param>
</filter>
<filter-mapping>
<filter-name>encodingFilter</filter-name>
<url-pattern>/*</url-pattern>
</filter-mapping>
<!--
RemoteIpFilter 는 proxy 서버가 앞단에 있을 때, Http Header 의 값들을
분석해서 실제 요청을 보낸 클라이언트의 IP 를 알아내서 request.remoteAddr
에 값을 설정해주는 편의성 filter 입니다.
<!--
RemoteIpFilter 는 아마 filter-class 에 작성하면 IDE 가 빨간줄을 그을텐데,
그럴 때는 maven dependecy 로 ...
groupId: org.apache.tomcat
artifactId: tomcat-catalina
version : 여러분의 톰캣 버전
scope: provided
... 를 추가해주시기 바랍니다.
-->
<filter>
<filter-name>RemoteIpFilter</filter-name>
<filter-class>org.apache.catalina.filters.RemoteIpFilter</filter-class>
</filter>
<filter-mapping>
<filter-name>RemoteIpFilter</filter-name>
<url-pattern>/*</url-pattern>
<dispatcher>REQUEST</dispatcher>
</filter-mapping>
<!-- 이게 바로 RATE LIMIT FILTER!! -->
<filter>
<filter-name>redisRateLimitFilter</filter-name>
<filter-class>org.springframework.web.filter.DelegatingFilterProxy</filter-class>
</filter>
<filter-mapping>
<filter-name>redisRateLimitFilter</filter-name>
<url-pattern>/remote-api/*</url-pattern>
<dispatcher>REQUEST</dispatcher>
</filter-mapping>
RedisRateLimitFilter 를 Bean 으로 등록해야 됩니다.
Bean ID (또는 명칭) 은 반드시web.xml
에서 표기한filter-name
과 동일하게
"redisRateLimitFilter" 로 명시해야 됩니다. 너무 쉬운 파트니 이건 Skip!
package me.dailycode.filter;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.support.ServletUriComponentsBuilder;
import org.springframework.web.util.UriComponents;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
* <h2>Redis 기반 Rate Limiting Filter</h2>
* Redis 와 Lua Script 를 활용한 RateLimit 기능을 제공하는 Servlet Filter 이다.<br>
* 주목적은 DDOS 공격 방어용으로 만들었으며,<br>
* 특정 시간 (WINDOW_TIME_FRAME) 동안<br>
* 최대 요청 횟수(MAX_REQUEST_PER_WINDOW) 를 제한한다.<br>
*/
public class RedisRateLimitFilter implements Filter {
private static final Logger LOGGER = LoggerFactory.getLogger(RedisRateLimitFilter.class);
/**
* Spring Application 에서 DI 받은 RedisTemplate bean instance
*/
private final RedisTemplate<String, Object> redisTemplate;
/**
* Lua Script Wrapper
*/
private final RedisScript<Long> script;
/**
* Window 하나당 요청할 수 있는 최대 요청 수
*/
private final String MAX_REQUEST_PER_WINDOW = "5";
/**
* Window 의 크기 (=Window 하나의 유지시간, 초단위)<br>
* 너무 큰 값을 주지 않도록 주의 바람.
*/
private final String WINDOW_TIME_FRAME = "30";
/**
* lua script 에 사용될 고정된 인자값
*/
private final Object[] SCRIPT_ARGS = Arrays.asList(MAX_REQUEST_PER_WINDOW, WINDOW_TIME_FRAME).toArray();
/**
* Error Json Format Message 생성용 JsonNodeFactory.<br>
* Thread Safe 함으로 안심하고 사용 가능.
*/
private final JsonNodeFactory nodeFactory = JsonNodeFactory.instance;
/**
* 기본 에러 문구
*/
private final String DEFAULT_ERROR_MSG = "짧은 시간 내에 너무 많은 요청을 보냈습니다. 잠시 기다렸다 다시 요청해주세요.";
// ApplicationContext 에서 미리 생성한 RedisTemplate 인스턴스
// bean 을 주입받습니다.
public RedisRateLimitFilter(RedisTemplate<String, Object> redisTemplate) {
this.redisTemplate = redisTemplate;
// redis lua script 사용
String rateLimitScript =
"local current = redis.call('get', KEYS[1]) " +
"if current then " +
"if tonumber(current) >= tonumber(ARGV[1]) " +
"then return 0 " +
"else " +
"redis.call('incr', KEYS[1]) " +
"redis.call('expire', KEYS[1], ARGV[2]) " +
"return 1 " +
"end " +
"else " +
"redis.call('set', KEYS[1], 1) " +
"redis.call('expire', KEYS[1], ARGV[2]) " +
"return 1 " +
"end";
// 설명:
// 1. 먼저 KEYS[1] (= ip + 요청 uri 를 합친 문자열) 을 조회한다.
// 2-1. 만약에 조회가 안되면(끝에 있는 else) 신규로 해당 KEYS[1] 에 "1" 이라는 문자열값을 주고, EXPIRE 값(초 단위)도 준다.
// 2-2. 만약에 조회가 된다면 (if current then)
// 2-2-1. 읽어온 값(=current) 가 ARGV[1] (= Maximum 요청 제한 횟수) 을 같거나, 넘으면 0 을 반환한다.
// 2-2-2. 그게 아니라면 current 값을 1 증가시키고, expire 시간을 재조정한다. 그리고 나서 1을 반환한다.
//
// * 참고: 이 모든 과정은 하나의 트랜잭션 내에서 일어난다. Redis + lua script 의 기본 동작 방식이다.
// Lau Script Wrapper 생성
script = new DefaultRedisScript<>(rateLimitScript, Long.class);
}
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
if (request instanceof HttpServletRequest && response instanceof HttpServletResponse) {
HttpServletRequest servletRequest = (HttpServletRequest) request;
HttpServletResponse servletResponse = (HttpServletResponse) response;
String remoteAddr = servletRequest.getRemoteAddr();
String requestURI = servletRequest.getRequestURI();
String method = servletRequest.getMethod();
// 등록(=POST)과 관련해서만 체크한다. 필요하다면 PUT, PATCH, DELETE 도 사용할 수 있다.
if ("POST".equalsIgnoreCase(method)) {
String clientKey = "rate_limit:" + remoteAddr + ":" + requestURI;
List<String> keys = Collections.singletonList(clientKey);
Long result = redisTemplate.execute(script, keys, SCRIPT_ARGS);
if (result != null && result == 0) {
LOGGER.error("Too Many Request! Blocking! => [ ip: {} , url: {} ]", remoteAddr, requestURI);
servletResponse.setStatus(429);
servletResponse.setCharacterEncoding("UTF-8");
servletResponse.setHeader("Content-Type", "application/json");
String errorResponseJson = createErrorResponseJson(requestURI);
servletResponse.getWriter().write(errorResponseJson);
return;
}
}
}
chain.doFilter(request, response);
}
/**
* error Message 를 담는 Json String 을 반환한다.
* @param currentRequestUrl 현재 에러를 일으키는 요청 URL
* @return 에러 문구를 담는 json 포맷 형식의 string
*/
private String createErrorResponseJson(String currentRequestUrl) {
ObjectNode errorNode = nodeFactory.objectNode();
errorNode.put("errorMsg", DEFAULT_ERROR_MSG);
errorNode.put("blockedUrl", currentRequestUrl);
return errorNode.toString();
}
}