[Spring Boot 3] Tomcat RateLimitFilter 간단 적용법 및 테스트

식빵·2024년 3월 22일
0

Spring Lab

목록 보기
31/34
post-thumbnail

설정 코드


Spring Boot 프로젝트를 생성하면 main 메소드를 가진 애플리케이션 클래스가 만들어지는데,
편의상 모든 설정 코드를 그 파일 하나에 몰아서 작성했습니다.

package me.dailycode.hacktest;

import jakarta.servlet.*;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotEmpty;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import org.apache.catalina.filters.RateLimitFilter;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;

import java.io.IOException;

/**
 * Spring Boot entry point for the rate-limit demo application.
 */
@SpringBootApplication
public class HackTestApplication {

    /**
     * Boots the Spring application context with this class as the primary source.
     *
     * @param args command-line arguments forwarded to Spring Boot
     */
    public static void main(String[] args) {
        SpringApplication application = new SpringApplication(HackTestApplication.class);
        application.run(args);
    }
}

/**
 * Registers {@link CustomRateLimitFilter} with the servlet container.
 */
@Configuration
class FilterConfiguration {

    /**
     * Builds the servlet registration for the rate-limit filter.
     *
     * <p>The filter is mapped to {@code /simple/*} and only runs for
     * {@link DispatcherType#REQUEST} dispatches. The {@code bucketDuration} (seconds) and
     * {@code bucketRequests} init parameters are consumed by Tomcat's {@code RateLimitFilter}:
     * roughly 50 requests per 2-second bucket are allowed before requests are rejected.
     * NOTE(review): per the Tomcat docs the bucket duration may be adjusted internally and the
     * request budget scaled to match, which is why the effective limit can be ~51, not exactly 50.
     *
     * @return the filter registration bean
     */
    @Bean
    public FilterRegistrationBean<CustomRateLimitFilter> customRateLimitFilterRegistration() {
        FilterRegistrationBean<CustomRateLimitFilter> bean =
                new FilterRegistrationBean<>(new CustomRateLimitFilter());

        // Not strictly needed with a single filter, but makes ordering explicit if more are added.
        bean.setOrder(1);
        bean.setName("customRateLimitFilter");

        // Only requests whose path falls under /simple/ pass through this filter,
        // and only on the original request dispatch.
        bean.addUrlPatterns("/simple/*");
        bean.setDispatcherTypes(DispatcherType.REQUEST);

        // Rate-limit window: 2 seconds, 50 requests.
        bean.addInitParameter("bucketDuration", "2");
        bean.addInitParameter("bucketRequests", "50");

        return bean;
    }
}

/**
 * Applies Tomcat's {@link RateLimitFilter} only to POST, PUT and DELETE requests.
 *
 * <p>Those requests are delegated to the parent filter, which counts them and may reject them.
 * Every other request, including non-HTTP requests, goes straight down the chain and is never
 * counted.
 * NOTE(review): Tomcat's RateLimitFilter keys its buckets by the client's remote address; behind
 * a reverse proxy all clients may share one bucket — confirm for the deployment.
 */
@Slf4j
class CustomRateLimitFilter extends RateLimitFilter {

    /** HTTP methods (compared case-insensitively) that are subject to rate limiting. */
    private static final String[] LIMITED_METHODS = {"POST", "PUT", "DELETE"};

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        boolean limited = request instanceof HttpServletRequest httpRequest
                && isLimitedMethod(httpRequest.getMethod());

        if (limited) {
            // Counted against the caller's bucket; the parent filter continues the chain or rejects.
            super.doFilter(request, response, chain);
        } else {
            // Not rate-limited: hand off directly to the rest of the chain.
            chain.doFilter(request, response);
        }
    }

    /**
     * Checks whether the method is one of {@link #LIMITED_METHODS}, ignoring case.
     *
     * @param method the request's HTTP method
     * @return {@code true} if requests with this method should be rate-limited
     */
    private static boolean isLimitedMethod(String method) {
        for (String candidate : LIMITED_METHODS) {
            if (candidate.equalsIgnoreCase(method)) {
                return true;
            }
        }
        return false;
    }
}





간단 테스트 코드


테스트용 Controller

package me.dailycode.hacktest.controller;

import jakarta.validation.Valid;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotEmpty;
import lombok.Builder;
import me.dailycode.hacktest.*;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;


/**
 * Demo endpoints under {@code /simple} used to exercise the rate-limit filter.
 *
 * <p>{@code POST /simple} accepts the same payload either as form parameters
 * ({@code application/x-www-form-urlencoded}) or as a JSON body ({@code application/json}).
 */
@Controller
@RequestMapping("/simple")
public class SimpleController {

    /** Binding target for {@code GET /simple/{id}}; only {@code id} is read by the handler. */
    record SimpleGetRequestDTO(String id, String name, Integer age) {}

    @Builder
    record SimpleGetResponseDTO(String id, String name, Integer age) {}

    /** POST payload: a non-empty name and an age between 0 and 120 inclusive. */
    record SimplePostRequestDTO(@NotEmpty String name, @Min(0) @Max(120) Integer age) {}

    @Builder
    record SimplePostResponseDTO(String name, Integer age) {}


    /**
     * Echoes the bound {@code id} with a fixed name and age.
     *
     * @param dto request data bound from the path/query
     * @return 200 with the response DTO
     */
    @GetMapping("/{id}")
    public ResponseEntity<SimpleGetResponseDTO> getMethod(SimpleGetRequestDTO dto) {
        return
                ResponseEntity.ok(
                    SimpleGetResponseDTO.builder()
                    .id(dto.id())
                    .name("hello-world")
                    .age(100).build());
    }

    /**
     * Handles form-encoded (request-parameter) POSTs.
     *
     * @param requestDTO validated payload bound from request parameters
     * @return 200 echoing name and age
     */
    @PostMapping
    public ResponseEntity<SimplePostResponseDTO> postMethod(@Valid SimplePostRequestDTO requestDTO) {
        System.out.println("requestDTO = " + requestDTO);
        return ResponseEntity.ok(toResponse(requestDTO));
    }

    /**
     * Handles JSON-bodied POSTs.
     *
     * <p>Without {@code @RequestBody} the form handler above binds only request parameters, so a
     * JSON body would leave {@code name} null and fail validation with 400. Because this mapping
     * declares {@code consumes}, it is more specific and wins for {@code application/json}
     * requests. Other content types still reach {@link #postMethod}.
     *
     * @param requestDTO validated payload deserialized from the JSON body
     * @return 200 echoing name and age
     */
    @PostMapping(consumes = "application/json")
    public ResponseEntity<SimplePostResponseDTO> postJsonMethod(@Valid @RequestBody SimplePostRequestDTO requestDTO) {
        System.out.println("requestDTO = " + requestDTO);
        return ResponseEntity.ok(toResponse(requestDTO));
    }

    /** Maps a validated POST request to its echo response. */
    private static SimplePostResponseDTO toResponse(SimplePostRequestDTO requestDTO) {
        return SimplePostResponseDTO.builder()
                .name(requestDTO.name())
                .age(requestDTO.age())
                .build();
    }
}




JUNIT 5 테스트 코드

package me.dailycode.hacktest.repeater;

import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;

/**
 * Fires 100 concurrent form-encoded POST requests at {@code http://localhost:8081/simple}
 * from virtual threads to trigger the rate-limit filter, printing each response.
 * Requires the application to already be running on port 8081.
 */
public class RepeatSendRequestTest {

    /** Number of concurrent requests to send. */
    private static final int REQUEST_COUNT = 100;

    @Test
    void sendPostRequest() throws InterruptedException {

        CountDownLatch countDownLatch = new CountDownLatch(REQUEST_COUNT);
        List<Thread> threadList = new ArrayList<>(REQUEST_COUNT);
        for (int i = 0; i < REQUEST_COUNT; i++) {
            final int requestIndex = i;
            Thread vThread = Thread.ofVirtual()
                    .unstarted(() -> {
                        try {
                            sendOne(requestIndex);
                        } finally {
                            // Count down on every path (success, I/O failure, runtime error);
                            // otherwise await() below would block forever.
                            countDownLatch.countDown();
                        }
                    });
            threadList.add(vThread);
        }

        threadList.forEach(Thread::start);
        countDownLatch.await();
        System.out.println("all done");
    }

    /**
     * Sends a single form-encoded POST whose {@code age} is {@code requestIndex + 1} and prints
     * the response. I/O failures are reported to stderr rather than rethrown.
     *
     * @param requestIndex zero-based index of this request
     */
    private static void sendOne(int requestIndex) {
        try (HttpClient httpClient = HttpClient.newHttpClient()) {
            HttpRequest request = HttpRequest.newBuilder()
                    .uri(URI.create("http://localhost:8081/simple"))
                    .header("Content-Type", "application/x-www-form-urlencoded")
                    .POST(HttpRequest.BodyPublishers.ofString("name=dailycode&age=%d"
                    .formatted(requestIndex + 1)))
                    .build();

            HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
            System.out.printf("%d : response = %s%n", requestIndex, response);
        } catch (IOException e) {
            e.printStackTrace(System.err);
        } catch (InterruptedException e) {
            // Restore the interrupt flag before reporting.
            Thread.currentThread().interrupt();
            e.printStackTrace(System.err);
        }
    }
}




테스트 결과 확인

결과적으로 하나의 IP에서 짧은 시간 안에 너무 많은 요청을 보내면,
아래처럼 HTTP 응답 상태 코드 429(Too Many Requests)가 반환되고,
요청이 차단되었음을 알려주는 로그가 출력된다.



참고: pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.4</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <groupId>me.dailycode</groupId>
    <artifactId>hack-test</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>hack-test</name>
    <description>hack-test</description>
    <properties>
        <java.version>21</java.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-validation</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-devtools</artifactId>
            <scope>runtime</scope>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.30</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>

</project>



참고: HttpClient 로 Json Body 전송법

package me.dailycode.hacktest.simple;

import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

/**
 * Demonstrates sending a JSON request body with {@link HttpClient}: POSTs
 * {@code {"name":"dailyCode","age":21}} to {@code http://localhost:8081/simple} and prints the
 * status code and the pretty-printed JSON response. Requires the application to be running on
 * port 8081.
 */
public class SimpleHttpRequestJsonBodyTest {

    private static final ObjectMapper MAPPER =  new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
    private static final JsonNodeFactory NODE_FACTORY = JsonNodeFactory.instance;

    /**
     * Sends the JSON request and prints the response.
     *
     * @throws IOException if the request fails or the response body is not valid JSON
     * @throws InterruptedException if interrupted while waiting for the response; propagated so
     *     JUnit reports it and the thread's interrupt status is not lost
     */
    @Test
    void simpleHttpRequestTest() throws IOException, InterruptedException {

        ObjectNode objectNode = NODE_FACTORY.objectNode();
        objectNode.put("name", "dailyCode");
        objectNode.put("age", 21);

        try (HttpClient httpClient = HttpClient.newHttpClient()) {
            HttpRequest httpRequest = HttpRequest.newBuilder(URI.create("http://localhost:8081/simple"))
                    .header("Content-Type", "application/json")
                    .POST(HttpRequest.BodyPublishers.ofString(objectNode.toString()))
                    .build();
            HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
            int statusCode = response.statusCode();
            String body = response.body();
            System.out.println("statusCode = " + statusCode);
            if (body.isBlank()) {
                // Nothing to parse (e.g. an error response without a body).
                System.out.println("response body is empty");
                return;
            }
            JsonNode jsonNode = MAPPER.readTree(body);
            System.out.println("repsonse body = " + jsonNode.toPrettyString());
        }
    }
}
profile
백엔드를 계속 배우고 있는 개발자입니다 😊

0개의 댓글