Spring Legacy(Spring boot 를 사용하지 않는 환경)에서 작성한 코드 및 설정입니다.
Spring Boot 쓰시는 분들은 Filter 등록법을 조금만 검색해서
이 글을 응용하시면 될 듯합니다 😎(절대 귀찮아서 안쓰는 거 아닙니다.)
<filter>
<filter-name>encodingFilter</filter-name>
<filter-class>org.springframework.web.filter.CharacterEncodingFilter</filter-class>
<init-param>
<param-name>encoding</param-name>
<param-value>utf-8</param-value>
</init-param>
</filter>
<filter-mapping>
<filter-name>encodingFilter</filter-name>
<url-pattern>/*</url-pattern>
</filter-mapping>
<!--
RemoteIpFilter 는 proxy 서버가 앞단에 있을 때, Http Header 의 값들을
분석해서 실제 요청을 보낸 클라이언트의 IP 를 알아내서 request.remoteAddr
에 값을 설정해주는 편의성 filter 입니다.
<!--
RemoteIpFilter 는 아마 filter-class 에 작성하면 IDE 가 빨간줄을 그을텐데,
그럴 때는 maven dependecy 로 ...
groupId: org.apache.tomcat
artifactId: tomcat-catalina
version : 여러분의 톰캣 버전
scope: provided
... 를 추가해주시기 바랍니다.
-->
<filter>
<filter-name>RemoteIpFilter</filter-name>
<filter-class>org.apache.catalina.filters.RemoteIpFilter</filter-class>
</filter>
<filter-mapping>
<filter-name>RemoteIpFilter</filter-name>
<url-pattern>/*</url-pattern>
<dispatcher>REQUEST</dispatcher>
</filter-mapping>
<!-- 이게 바로 RATE LIMIT FILTER!! -->
<filter>
<filter-name>redisRateLimitFilter</filter-name>
<filter-class>org.springframework.web.filter.DelegatingFilterProxy</filter-class>
</filter>
<filter-mapping>
<filter-name>redisRateLimitFilter</filter-name>
<url-pattern>/remote-api/*</url-pattern>
<dispatcher>REQUEST</dispatcher>
</filter-mapping>
RedisRateLimitFilter 를 Bean 으로 등록해야 됩니다.
Bean ID (또는 명칭) 은 반드시web.xml
에서 표기한filter-name
과 동일하게
"redisRateLimitFilter" 로 명시해야 됩니다. 너무 쉬운 파트니 이건 Skip!
주의: java 17, jakarta ee 을 사용한다는 점 유의하셔서 보시기 바랍니다.
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.util.StringUtils;
import jakarta.servlet.*;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
* <h2>Redis 기반 Rate Limiting Filter</h2>
* Redis 와 Lua Script 를 활용한 RateLimit 기능을 제공하는 Servlet Filter 이다.<br>
* 주목적은 DDOS 공격 방어용으로 만들었으며,<br>
* 특정 시간 (WINDOW_TIME_FRAME) 동안<br>
* 최대 요청 횟수(MAX_REQUEST_PER_WINDOW) 를 제한한다.<br>
*/
public class RedisRateLimitFilter implements Filter {
private static final Logger LOGGER
= LoggerFactory.getLogger(RedisRateLimitFilter.class);
/**
* Spring Application 에서 DI 받은 RedisTemplate bean instance
*/
private final RedisTemplate<String, Object> redisTemplate;
/**
* Lua Script Wrapper
*/
private final RedisScript<Long> script;
/**
* Window 하나당 요청할 수 있는 최대 요청 수
*/
private final String MAX_REQUEST_PER_WINDOW = "8";
/**
* Window 의 크기 (=Window 하나의 유지시간, 초단위)<br>
* 너무 큰 값을 주지 않도록 주의 바람.
*/
private final String WINDOW_TIME_FRAME = "1";
// 1초에 8번을 초과해서 요청하면 막는다
/**
* lua script 에 사용될 고정된 인자값
*/
private final Object[] SCRIPT_ARGS
= Arrays.asList(MAX_REQUEST_PER_WINDOW, WINDOW_TIME_FRAME).toArray();
/**
* 기본 에러 문구
*/
private final String DEFAULT_ERROR_MSG
= "짧은 시간 내에 너무 많은 요청을 보냈습니다. 잠시 기다렸다 다시 요청해주세요.";
// ApplicationContext 에서 미리 생성한 RedisTemplate 인스턴스
// bean 을 주입받습니다.
public RedisRateLimitFilter(RedisTemplate<String, Object> redisTemplate) {
this.redisTemplate = redisTemplate;
// redis lua script 사용
String rateLimitScript =
"local current = redis.call('get', KEYS[1]) " +
"if current then " +
"if tonumber(current) >= tonumber(ARGV[1]) " +
"then return 0 " +
"else " +
"redis.call('incr', KEYS[1]) " +
"redis.call('expire', KEYS[1], ARGV[2]) " +
"return 1 " +
"end " +
"else " +
"redis.call('set', KEYS[1], 1) " +
"redis.call('expire', KEYS[1], ARGV[2]) " +
"return 1 " +
"end";
// 설명:
// 1. 먼저 KEYS[1] (= ip + 요청 uri 를 합친 문자열) 을 조회한다.
//
// 2-1. 만약에 조회가 안되면(끝에 있는 else) 신규로 해당 KEYS[1] 에
// "1" 이라는 문자열값을 주고, EXPIRE 값(초 단위)도 준다.
//
// 2-2. 만약에 조회가 된다면 (if current then)
//
// 2-2-1. 읽어온 값(=current) 가 ARGV[1] (= Maximum 요청 제한 횟수)
// 을 같거나, 넘으면 0 을 반환한다.
//
// 2-2-2. 그게 아니라면 current 값을 1 증가시키고, expire 시간을
// 재조정한다. 그리고 나서 1을 반환한다.
//
// * 참고: 이 모든 과정은 하나의 트랜잭션 내에서 일어난다.
// Redis + lua script 의 기본 동작 방식이다.
// Lau Script Wrapper 생성
script = new DefaultRedisScript<>(rateLimitScript, Long.class);
}
@Override
public void doFilter(ServletRequest request,
ServletResponse response,
FilterChain chain) throws IOException, ServletException {
if (request instanceof HttpServletRequest servletRequest
&& response instanceof HttpServletResponse servletResponse) {
String remoteAddr = servletRequest.getRemoteAddr();
String requestURI = servletRequest.getRequestURI();
String method = servletRequest.getMethod();
// 등록(=POST) 요청만 체크하겠다.
if ("POST".equalsIgnoreCase(method)) {
// (중요) 막는 타깃은 [IP + 세션] 입니다!
String sessionId = request.getSession(true);
String clientKey
= "rate_limit:"
+ remoteAddr + ":"
+ sessionId + ":"
+ requestURI;
List<String> keys = Collections.singletonList(clientKey);
Long result = redisTemplate.execute(script, keys, SCRIPT_ARGS);
if (result != null && result == 0) {
LOGGER.error("Too Many Request! [ ip: {} , url: {} ]",
remoteAddr, requestURI);
servletResponse.setStatus(429);
servletResponse.setCharacterEncoding("UTF-8");
servletResponse.setHeader("Content-Type", "application/json");
String errorResponseJson = createErrorResponseJson(requestURI);
servletResponse.getWriter().write(errorResponseJson);
return;
}
}
}
chain.doFilter(request, response);
}
/**
* error Message 를 담는 Json String 을 반환한다.
* @param currentRequestUrl 현재 에러를 일으키는 요청 URL
* @return 에러 문구를 담는 json 포맷 형식의 string
*/
private String createErrorResponseJson(String currentRequestUrl) {
// java 17 text block 문법 사용
return """
{"errorMsg" : "%s","blockedUrl" : "%s"}"""
.formatted(DEFAULT_ERROR_MSG, currentRequestUrl);
}
}
주의사항
LAN 환경에서는 여러 호스트가 공용으로 사용하는 하나의 외부 IP 를 갖는 경우가
많습니다. 이런 경우를 생각해서 위의 코드에서는 타깃 코드(=clientKey)를 생성할 때
절대로 IP 만으로 막으면 안됩니다!// 막는 타깃이 [IP] 일 경우 String clientKey = "rate_limit:" + remoteAddr + ":" + requestURI;
그러니 꼭 아래처럼 세션(또는 각 호스트를 구별할 수 있는 어떤것이든)값을 하나
끼워서 clientKey 를 생성하시기 바랍니다!String sessionId = request.getSession(true); String clientKey = "rate_limit:" + remoteAddr + ":" + sessionId + ":" // 이게 핵심! + requestURI;