Request Body 데이터 확인하기 (w. HttpServletRequestWrapper)

최준호·2022년 11월 14일
3

Spring

목록 보기
42/47
post-thumbnail

😂 Request Body 데이터 확인이 왜 필요할까?

업무를 진행하며 filter, Interceptor에서 Request에 담긴 parameter와 body 값에서 금칙어를 확인해야할 일이 생겼다. 그래서 filter와 interceptor를 고민하다가 특정 url에서 금칙어 확인을 제외해주어야하는 부분이 있어 url 등록과 제외가 편한 interceptor에 적용하려고 했다.

하지만 여기서 문제가 발생하는데... Request에서 Parameter를 가져오는 것은 문제가 없지만 Body에 담겨오는 데이터를 읽기 위해 getInputStream()을 사용하는 순간 filter가 종료된 이후 controller에서 데이터를 읽을 수 없게 되버리는 것이다.

그 이유는 Request의 getInputStream()는 한번만 반환하고 반환 이후에는 데이터가 사라져 버리기 때문에 Interceptor에서 한번 가져와 버리면 그 이후에는 데이터가 비어버리기 때문이다.

참고로 interceptor에서는 request의 body를 조작할 수 없다고 한다. 그래서 filter로 진행해야만 한다!!!
아마 내 생각에는 이미 interceptor로 넘어온 순간 servlet을 지났기 때문에 HttpRequestWrapper에서 처리되는 servlet 과정을 반복하지 않기 때문에 interceptor에서는 더이상 처리할 수 없는거 같다...
정확한 이유 누가 알면 댓글로 달아주실래요?ㅜㅜ

이 문제점을 해결하기 위해 HttpServletRequestWrapper를 filter와 함께 사용하여 문제를 해결해보려고 한다.

📗 HttpServletRequestWrapper

먼저 Spring 요청 구조를 이해하면 위 그림과 같다. 여기서 Interceptor 부분에서 요청을 가로채 데이터를 확인하고 다시 다음 로직으로 태워주려고 하는 것이다.

📄 HttpServletRequestWrapper 구현하기

먼저 HttpServletRequestWrapper class를 살펴보면 ServletRequestWrapper class를 상속받고 있는데 이 부분의 소스를 다시 파보면

    /**
     * The default behavior of this method is to return getInputStream() on the
     * wrapped request object.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return this.request.getInputStream();
    }
    
    /**
     * The default behavior of this method is to return getReader() on the
     * wrapped request object.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return this.request.getReader();
    }

getInputStream() getReader() 메서드를 확인해볼 수 있다. 이 두개의 메서드를 overried해주면 된다.

public class MyRequestWrapper extends HttpServletRequestWrapper {

    private String requestData;

    public MyRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        requestData = requestDataByte(request);
    }

    private String requestDataByte(HttpServletRequest request) throws IOException {
        ServletInputStream inputStream = request.getInputStream();
        byte[] rawData = StreamUtils.copyToByteArray(inputStream);
        return new String(rawData);
    }

    @Override
    public ServletInputStream getInputStream() {
        ByteArrayInputStream inputStream = new ByteArrayInputStream(this.requestData.getBytes(StandardCharsets.UTF_8));
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return inputStream.available() == 0;
            }

            @Override
            public boolean isReady() {
                return true;
            }

            @Override
            public void setReadListener(ReadListener listener) {
                throw new UnsupportedOperationException();
            }

            @Override
            public int read() {
                return inputStream.read();
            }
        };
    }

    @Override
    public BufferedReader getReader() {
        return new BufferedReader(new InputStreamReader(this.getInputStream()));
    }
}

request에서 getInputStream을 해서 데이터를 가져오고 해당 데이터를 override를 통해 기존에 실제 request에서 getInputStream()을 진행하는 것이 request에서 가져온 복사된 InputStream을 반환할 수 있도록 수정해주었다.

📄 filter 적용하기

참고 밸덩 OncePerRequestFilter shouldNotFilter 사용방법

@Slf4j
public abstract class FilterSupport extends OncePerRequestFilter {

    boolean isNotFilter(String requestURI, String[] whiteList) {
        boolean notFilter = false;
        for(String white : whiteList){
            if(requestURI.equals(white)) notFilter = true;
        }
        return notFilter;
    }

    /**
     * 에러 발생 response 세팅
     * @param httpStatus
     * @param response
     * @param message
     * @param error
     */
    void onError(HttpStatus httpStatus,HttpServletResponse response, String message, Error error) {
        ObjectMapper objectMapper = new ObjectMapper();
        ErrorResponse<Error> errorResponse = ErrorResponse.<Error>builder()
                .resultCode(ResultCode.FAIL)
                .resultType(ResultType.ALERT)
                .resultMsg(message)
                .error(error)
                .build();
        response.setStatus(httpStatus.value());
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);
        response.setCharacterEncoding(StandardCharsets.UTF_8.name());
        try {
            response.getWriter().write(objectMapper.writeValueAsString(errorResponse));
        } catch (IOException e) {
            log.error("filter error !",e);
            response.setStatus(HttpStatus.INTERNAL_SERVER_ERROR.value());
        }
    }

    void onError(HttpStatus httpStatus,HttpServletResponse response, String message) {
        onError(httpStatus, response, message, new Error());
    }
}

OncePerRequestFilter class를 상속받은 추상 클래스를 하나 두었다. filter에서는 exception을 handling하여 api를 반환할 수 없어서 response에 꽂아서 넣어줘야하기 때문에 Error가 발생했을때 처리를 하기 위해서다.

여기서 ObjectMapper를 Spring에 등록된 빈으로 가져오고 싶었는데 뭔짓을 해도 안되더라... 계속 npe가 떠서 포기... 일단은 ObjectMapper 객체를 생성해서 사용했다. 나중에 왜그런지 찾아봐야겠다

@RequiredArgsConstructor
@Slf4j
@Order(2)
public class BadWordFilter extends FilterSupport {
    private final ObjectMapper objectMapper;
    private final String[] blockWordList = {
            "화끈", "<", ">"
    };

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        String method = request.getMethod();
        if(method.equals(HttpMethod.GET.name())){
            validateRequestParameter(request, response, filterChain);
        }else{
            validateRequestBody(request, response, filterChain);
        }
    }

    private void validateRequestParameter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException {
        Enumeration<String> parameterNames = request.getParameterNames();
        while(parameterNames.hasMoreElements()){
            String parameterName = parameterNames.nextElement();
            String parameterValue = request.getParameter(parameterName);

            if (isContainBlockWord(response, parameterValue)) return;
        }
        filterChain.doFilter(request, response);
    }

    private boolean isContainBlockWord(HttpServletResponse response, String parameterValue) {
        for(String blockWord : blockWordList){
            if(parameterValue.contains(blockWord)){
                onError(HttpStatus.BAD_REQUEST, response, String.format("금지 단어가 포함되어 있습니다. 금지어 포함 단어 = %s", parameterValue), new BadWordError(blockWord, parameterValue));
                return true;
            }
        }
        return false;
    }

    private void validateRequestBody(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException {
        MyRequestWrapper myRequestWrapper = new MyRequestWrapper(request);
        StringBuilder sb = getBodyToStringBuilder(response, filterChain, myRequestWrapper);
        if (sb == null) return;

        try {
            HashMap map = objectMapper.readValue(sb.toString(), HashMap.class);
            Set<String> keys = map.keySet();
            for (String key : keys) {
                if(isContainBlockWord(response, map.get(key).toString())) return;
            }
        } catch (NullPointerException ne){
            // data 중 null 값 발생
        } catch (JsonProcessingException e) {
            onError(HttpStatus.INTERNAL_SERVER_ERROR, response, "서버 내부 에러 발생");
            return;
        }
        filterChain.doFilter(myRequestWrapper, response);
    }

    private StringBuilder getBodyToStringBuilder(HttpServletResponse response, FilterChain filterChain, MyRequestWrapper myRequestWrapper) throws IOException, ServletException {
        StringBuilder sb = new StringBuilder();
        BufferedReader br = null;
        //한줄씩 담을 변수
        String line = "";
        try {
            ServletInputStream inputStream = myRequestWrapper.getInputStream();
            if(inputStream != null){
                br = new BufferedReader(new InputStreamReader(inputStream));
                while ((line = br.readLine()) != null) {
                    sb.append(line);
                }
            }
        } catch (IOException e) {
            log.debug("body에 요청이 없습니다.");
            filterChain.doFilter(myRequestWrapper, response);
            return null;
        }
        return sb;
    }

    @Override
    protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
        String requestURI = request.getRequestURI();
        String[] whiteList = {""};  // ex) /v1/auth/email/signup

        return isNotFilter(requestURI, whiteList);
    }
}

@Order({순서}) 어노테이션을 통해 순서를 표시해두고 OncePerRequestFilter class를 상속하여 shouldNotFilter() method를 구현해주었다. shouldNotFilter()를 구현함으로 요청시 제외될 url을 관리해줄 수 있게 된다.

그리고 여기서 중요한 코드는 filterChain.doFilter(myRequestWrapper, response); request body를 파싱한 뒤에 이미 사용된 request를 반환하는게 아니라 request를 통해 복사해둔 myRequestWrapper를 반환하므로써 다음 servlet에서도 body를 파싱할 수 있도록 넘기는 것이 중요하다.

@Configuration
@RequiredArgsConstructor
public class WebMvcConfig implements WebMvcConfigurer{
    
    ...
    
    @Bean
    public FilterRegistrationBean<BadWordFilter> badWordFilterFilterRegistrationBean(){
        FilterRegistrationBean<BadWordFilter> registrationBean = new FilterRegistrationBean<>();
        registrationBean.setFilter(new BadWordFilter(objectMapper));
        registrationBean.addUrlPatterns(
                "/*"
        );

        return registrationBean;
    }
}

그 후에 전역으로 filter를 등록시켜주면 filter 등록이 끝났다!

📄 테스트 해보기

@SpringBootTest
@AutoConfigureMockMvc
@ExtendWith(MockitoExtension.class)
class BadWordFilterTest extends ControllerTest {

    @Test
    @DisplayName("금지어를 포함한 get 요청은 실패한다.")
    void badWordFilterFail1() throws Exception {
        // given & when
        ResultActions perform = mockMvc.perform(
                get("/v1/test/bad-word")
                        .with(request -> {
                            request.addParameter("value1", "<테스터>");
                            request.addParameter("value2", "화끈하게");
                            return request;
                        }
        )).andDo(print());
        // then
        perform.andExpect(status().is4xxClientError());
        assertTrue(perform.andReturn().getResponse().getContentAsString().contains("금지 단어가 포함되어 있습니다."));
    }

    @Test
    @DisplayName("금지어를 포함한 post 요청은 실패한다.")
    void badWordFilterFail2() throws Exception {
        // given
        Map<String, Object> map = new HashMap<>();
        map.put("value1", "테스터");
        map.put("value2", "화끈하게");
        // when
        ResultActions perform = mockMvc.perform(
                post("/v1/test/bad-word")
                        .contentType(MediaType.APPLICATION_JSON)
                        .content(convertToString(map))
        ).andDo(print());
        // then
        perform.andExpect(status().is4xxClientError());
        assertTrue(perform.andReturn().getResponse().getContentAsString().contains("금지 단어가 포함되어 있습니다."));
    }
}

그 후에 다음과 같이 테스트 코드를 작성하여 테스트를 진행해보면

테스트가 잘 진행되며 반환 결과도 내가 예상한 에러 데이터로 나오는 것을 확인할 수 있었다!

👏 진행하며 느낀점

요청 데이터의 필터링을 적용하는 것이 쉬울것 같았는데 interceptor쪽 부터 시작하여 삽질하기 시작해서 겨우 끝낼 수 있었던거 같다.

하지만 오히려 interceptor로 삽질하게 되어서 filter와 interceptor의 동작과정을 한번 더 살펴볼 수 있었고 filter에서 exclude 할 수 있도록 제공되는 OncePerRequestFilter class까지 학습할 수 있었던거 같다.

profile
코딩을 깔끔하게 하고 싶어하는 초보 개발자 (편하게 글을 쓰기위해 반말체를 사용하고 있습니다! 양해 부탁드려요!) 현재 KakaoVX 근무중입니다!

0개의 댓글