[Spring] Redis 를 사용한 Rate Limit Filter 구현

식빵·2024년 3월 25일
0

Spring Lab

목록 보기
32/33
post-thumbnail

Spring Legacy(Spring boot 를 사용하지 않는 환경)에서 작성한 코드 및 설정입니다.
Spring Boot 쓰시는 분들은 Filter 등록법을 조금만 검색해서
이 글을 응용하시면 될 듯합니다 😎 (절대 귀찮아서 안쓰는 거 아닙니다.)


web.xml 일부

  <filter>
    <filter-name>encodingFilter</filter-name>
    <filter-class>org.springframework.web.filter.CharacterEncodingFilter</filter-class>
    <init-param>
      <param-name>encoding</param-name>
      <param-value>utf-8</param-value>
    </init-param>
  </filter>
  <filter-mapping>
    <filter-name>encodingFilter</filter-name>
    <url-pattern>/*</url-pattern>
  </filter-mapping>


  <!-- 
	RemoteIpFilter 는 proxy 서버가 앞단에 있을 때, Http Header 의 값들을
	분석해서 실제 요청을 보낸 클라이언트의 IP 를 알아내서 request.remoteAddr
	에 값을 설정해주는 편의성 filter 입니다.
  <!-- 
	RemoteIpFilter 는 아마 filter-class 에 작성하면 IDE 가 빨간줄을 그을텐데,
	그럴 때는 maven dependecy 로 ...

	groupId: org.apache.tomcat
	artifactId: tomcat-catalina
	version : 여러분의 톰캣 버전
	scope: provided 

	... 를 추가해주시기 바랍니다.
  -->
  <filter>
    <filter-name>RemoteIpFilter</filter-name>
    <filter-class>org.apache.catalina.filters.RemoteIpFilter</filter-class>
  </filter>

  <filter-mapping>
    <filter-name>RemoteIpFilter</filter-name>
    <url-pattern>/*</url-pattern>
    <dispatcher>REQUEST</dispatcher>
  </filter-mapping>


  <!-- 이게 바로 RATE LIMIT FILTER!! -->
  <filter>
    <filter-name>redisRateLimitFilter</filter-name>
    <filter-class>org.springframework.web.filter.DelegatingFilterProxy</filter-class>
  </filter>
  <filter-mapping>
    <filter-name>redisRateLimitFilter</filter-name>
    <url-pattern>/remote-api/*</url-pattern>
    <dispatcher>REQUEST</dispatcher>
  </filter-mapping>




Rate Limit Filter 구현

RedisRateLimitFilter 를 Bean 으로 등록해야 됩니다.
Bean ID (또는 명칭) 은 반드시 web.xml 에서 표기한 filter-name 과 동일하게
"redisRateLimitFilter" 로 명시해야 됩니다. 너무 쉬운 파트니 이건 Skip!

package me.dailycode.filter;

import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.support.ServletUriComponentsBuilder;
import org.springframework.web.util.UriComponents;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * <h2>Redis 기반 Rate Limiting Filter</h2>
 * Redis 와 Lua Script 를 활용한 RateLimit 기능을 제공하는 Servlet Filter 이다.<br>
 * 주목적은 DDOS 공격 방어용으로 만들었으며,<br>
 * 특정 시간 (WINDOW_TIME_FRAME) 동안<br>
 * 최대 요청 횟수(MAX_REQUEST_PER_WINDOW) 를 제한한다.<br>
 */
public class RedisRateLimitFilter implements Filter {
    private static final Logger LOGGER = LoggerFactory.getLogger(RedisRateLimitFilter.class);

    /**
     * Spring Application 에서 DI 받은 RedisTemplate bean instance
     */
    private final RedisTemplate<String, Object> redisTemplate;

    /**
     * Lua Script Wrapper
     */
    private final RedisScript<Long> script;

    /**
     * Window 하나당 요청할 수 있는 최대 요청 수
     */
    private final String MAX_REQUEST_PER_WINDOW = "5";

    /**
     * Window 의 크기 (=Window 하나의 유지시간, 초단위)<br>
     * 너무 큰 값을 주지 않도록 주의 바람.
     */
    private final String WINDOW_TIME_FRAME = "30";

    /**
     * lua script 에 사용될 고정된 인자값
     */
    private final Object[] SCRIPT_ARGS = Arrays.asList(MAX_REQUEST_PER_WINDOW, WINDOW_TIME_FRAME).toArray();

    /**
     * Error Json Format Message 생성용 JsonNodeFactory.<br>
     * Thread Safe 함으로 안심하고 사용 가능.
     */
    private final JsonNodeFactory nodeFactory = JsonNodeFactory.instance;

    /**
     * 기본 에러 문구
     */
    private final String DEFAULT_ERROR_MSG = "짧은 시간 내에 너무 많은 요청을 보냈습니다. 잠시 기다렸다 다시 요청해주세요.";


    // ApplicationContext 에서 미리 생성한  RedisTemplate 인스턴스
    //  bean 을 주입받습니다.
    public RedisRateLimitFilter(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
        // redis lua script 사용
        String rateLimitScript =
                "local current = redis.call('get', KEYS[1]) " +
                "if current then " +
                    "if tonumber(current) >= tonumber(ARGV[1]) " +
                        "then return 0 " +
                    "else " +
                        "redis.call('incr', KEYS[1]) " +
                        "redis.call('expire', KEYS[1], ARGV[2]) " +
                        "return 1 " +
                    "end " +
                "else " +
                    "redis.call('set', KEYS[1], 1) " +
                    "redis.call('expire', KEYS[1], ARGV[2]) " +
                    "return 1 " +
                "end";
        // 설명:
        // 1. 먼저 KEYS[1] (= ip + 요청 uri 를 합친 문자열) 을 조회한다.
        // 2-1. 만약에 조회가 안되면(끝에 있는 else) 신규로 해당 KEYS[1] 에 "1" 이라는 문자열값을 주고, EXPIRE 값(초 단위)도 준다.
        // 2-2. 만약에 조회가 된다면 (if current then)
        //      2-2-1. 읽어온 값(=current) 가 ARGV[1] (= Maximum 요청 제한 횟수) 을 같거나, 넘으면 0 을 반환한다.
        //      2-2-2. 그게 아니라면 current 값을 1 증가시키고, expire 시간을 재조정한다. 그리고 나서 1을 반환한다.
        //
        // * 참고: 이 모든 과정은 하나의 트랜잭션 내에서 일어난다. Redis + lua script 의 기본 동작 방식이다.

        // Lau Script Wrapper 생성
        script = new DefaultRedisScript<>(rateLimitScript, Long.class);
    }


    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {

        if (request instanceof HttpServletRequest && response instanceof HttpServletResponse) {

            HttpServletRequest servletRequest = (HttpServletRequest) request;
            HttpServletResponse servletResponse = (HttpServletResponse) response;
            String remoteAddr = servletRequest.getRemoteAddr();
            String requestURI = servletRequest.getRequestURI();
            String method = servletRequest.getMethod();

            // 등록(=POST)과 관련해서만 체크한다. 필요하다면 PUT, PATCH, DELETE 도 사용할 수 있다.
            if ("POST".equalsIgnoreCase(method)) {

                String clientKey = "rate_limit:" + remoteAddr + ":" + requestURI;
                List<String> keys = Collections.singletonList(clientKey);
                Long result = redisTemplate.execute(script, keys, SCRIPT_ARGS);

                if (result != null && result == 0) {
                    LOGGER.error("Too Many Request! Blocking! => [ ip: {} , url: {} ]", remoteAddr, requestURI);
                    servletResponse.setStatus(429);
                    servletResponse.setCharacterEncoding("UTF-8");
                    servletResponse.setHeader("Content-Type", "application/json");
                    String errorResponseJson = createErrorResponseJson(requestURI);
                    servletResponse.getWriter().write(errorResponseJson);
                    return;
                }
            }
        }
        chain.doFilter(request, response);
    }

    /**
     * error Message 를 담는 Json String 을 반환한다.
     * @param currentRequestUrl 현재 에러를 일으키는 요청 URL
     * @return 에러 문구를 담는 json 포맷 형식의 string
     */
    private String createErrorResponseJson(String currentRequestUrl) {
        ObjectNode errorNode = nodeFactory.objectNode();
        errorNode.put("errorMsg", DEFAULT_ERROR_MSG);
        errorNode.put("blockedUrl", currentRequestUrl);
        return errorNode.toString();
    }

}
profile
백엔드를 계속 배우고 있는 개발자입니다 😊

0개의 댓글