[Springboot] WebSocket에 JWT?

LTT·2025년 5월 4일

Why?


WebScoket 을 이용하여 실시간으로 Binary 데이터를 서버로 전송하고, 그에 따른 텍스트 값을 클라이언트로 내려주는 서버를 구현했었다. 기존에 이 기능은 세션 방식으로 적용되어 있었지만,

원래같으면 WebSocket을 사용하는 기능을 따로 서버 분리하여 이용하고 싶었지만, 빠른 시간 내에 기능을 추가하고 배포해야 하기 때문에, 기존 서버에서 사용하는 인증 방식인 JWT 토큰을 WebSocket handshake 할 때 Interceptor를 이용하여 적용해 보려고 한다.

Configuration


WebSocketHandshakeConfig

@Configuration
public class WebSocketHandshakeConfig {
    @Bean
    public DefaultHandshakeHandler customHandshakeHandler() {
        return new DefaultHandshakeHandler() {
            @Override
            protected String selectProtocol(List<String> requestedProtocols, WebSocketHandler handler) {
                return !requestedProtocols.isEmpty() ? requestedProtocols.get(0) : null;
            }
        };
    }
}
  • 클라이언트가 요청한 WebSocket 서브프로토콜 목록 중 첫 번째 것을 선택
  • 클라이언트가 서브프로토콜을 제시하지 않았다면 null을 반환
    • ex. 클라이언트가 "chat", "superchat"이라는 프로토콜을 제시하면 "chat"을 선택
  • WebSocket 연결 시 클라이언트가 요청한 서브프로토콜 중 첫 번째를 선택해서 사용하는 핸드셰이크 핸들러를 등록하도록 함

WebSocketConfig

@Configuration
@EnableWebSocket
@RequiredArgsConstructor
public class WebSocketConfig implements WebSocketConfigurer {

    private final WebSocketHandler webSocketHandler;
    private final AuthWebSocketInterceptor authWebSocketInterceptor;
    private final DefaultHandshakeHandler defaultHandshakeHandler;

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        registry.addHandler(webSocketHandler, "/ws/v1/example")
                .setHandshakeHandler(defaultHandshakeHandler)
                .addInterceptors(authWebSocketInterceptor)
                .setAllowedOrigins("*");
    }
}
  • WebSocket 엔드포인트를 설정해주고, 해당 엔드포인트에 인터셉터와 핸들러를 적용해주는 설정

위와같이 설정해야 /ws/v1/example로 엔드포인트가 연결된다.

Interceptor


AuthWebSocketInterceptor

@Slf4j
@Component
@RequiredArgsConstructor
public class AuthWebSocketInterceptor implements HandshakeInterceptor {

    private final JWTUtil jwtUtil;

    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        if (request.getMethod() == HttpMethod.OPTIONS) {
            return true;
        }

        final String accessToken = getAccessToken(request);

        if (accessToken != null) {
            String userId = jwtUtil.getUserId(jwtUtil.getClaims(accessToken));
            String requestURI = request.getURI().getPath();
            String uuid = UUID.randomUUID().toString();

            // 요청에 대한 로그 ID 설정
            request.getHeaders().add("LOG_ID", uuid);

            attributes.put(JWTUtil.USER_ID, userId);
            attributes.put(JWTUtil.USER_NAME, jwtUtil.getUsername(accessToken));

            log.info("HANDSHAKE[{}] Auth by : {}", requestURI, userId);
        }

        return true;
    }

    /**
     * Sec-WebSocket-Protocol header 에 담긴 Access Token.
     *
     * @param request ServerHttpRequest
     * @return Access Token
     */
    private String getAccessToken(ServerHttpRequest request) {

        List<String> subProtocolHeader = request.getHeaders().get("Sec-WebSocket-Protocol");

        if (subProtocolHeader != null && !subProtocolHeader.isEmpty()) {
            String[] protocols = subProtocolHeader.get(0).split(",");
            String accessToken = Arrays.asList(protocols).get(0).trim();

            final boolean isValid = jwtUtil.isValidAccessToken(accessToken);

            return isValid ? accessToken : null;
        }

        return null;
    }

    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
        String requestURI = request.getURI().getPath();
        String logId = String.valueOf(request.getHeaders().getFirst("LOG_ID"));

        log.info("HANDSHAKE[{}][{}] Complete", logId, requestURI);
    }

}

HandshakeInterceptor 를 상속받아 위와같이 메소드들을 오버라이드해준다.

beforeHandshake

    private final JWTUtil jwtUtil;

    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        if (request.getMethod() == HttpMethod.OPTIONS) {
            return true;
        }

        final String accessToken = getAccessToken(request);

        if (accessToken != null) {
            String userId = jwtUtil.getUserId(jwtUtil.getClaims(accessToken));
            String requestURI = request.getURI().getPath();
            String uuid = UUID.randomUUID().toString();

            // 요청에 대한 로그 ID 설정
            request.getHeaders().add("LOG_ID", uuid);

            attributes.put(JWTUtil.USER_ID, userId);
            attributes.put(JWTUtil.USER_NAME, jwtUtil.getUsername(accessToken));

            log.info("HANDSHAKE[{}] Auth by : {}", requestURI, userId);
        }

        return true;
    }

    /**
     * Sec-WebSocket-Protocol header 에 담긴 Access Token.
     *
     * @param request ServerHttpRequest
     * @return Access Token
     */
    private String getAccessToken(ServerHttpRequest request) {

        List<String> subProtocolHeader = request.getHeaders().get("Sec-WebSocket-Protocol");

        if (subProtocolHeader != null && !subProtocolHeader.isEmpty()) {
            String[] protocols = subProtocolHeader.get(0).split(",");
            String accessToken = Arrays.asList(protocols).get(0).trim();

            final boolean isValid = jwtUtil.isValidAccessToken(accessToken);

            return isValid ? accessToken : null;
        }

        return null;
    }

beforeHandshake 에서 handshake가 일어나기 전 수행할 동작들을 작성해준다.

attribute에 미리 “userId”라는 키값으로 유저 아이디를 저장해두고, 토큰에 대한 검사도 진행해준다.


getAccessToken메소드에서는 Sec-WebSocket-Protocol 에서 엑세스 토큰을 추출해오는 코드이다. 프론트는 자세히 몰라서 설명은 못해드리지만, Vue.js에서 ts를 사용하는데 Websocket 객체에서 추가로 데이터를 보내고 싶을 땐 Sec-WebSocket-Protocol 프로토콜에 값을 집어넣는다고 하여 위와 같이 코드를 추출해준다.


추가로, 토큰 인증에 실패했을 경우는 beforeHandshake에서 작성하지 않았다. 이유는 이후에 나온다.

afterHandshake

    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
        String requestURI = request.getURI().getPath();
        String logId = String.valueOf(request.getHeaders().getFirst("LOG_ID"));

        log.info("HANDSHAKE[{}][{}] Complete", logId, requestURI);
    }

이건 handshake가 연결되고 나서 실행되는 코드이다.

handshake가 완료되었다고 로그를 남기기 위해 위와같이 작성한 것이다.

추가사항

코드를 보면 알겠지만, 토큰이 제대로 받아지지 않거나, 토큰이 잘못되었을 때, Interceptor에서 처리하지 않는다.

그 이유는 바로…

Interceptor에서 예외를 던질 경우, 예외 메세지가 제대로 전달되지 않고, 그냥 세션이 닫혀버리는 현상이 발생한다.

때문에 클라이언트 측에서 예외에 따른 동작들을 구현하기 위해 handler의 afterConnectionEstablished에서 예외처리를 해 주었다.

afterConnectionEstablished (in WebSocketHandler)


    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        String userId = (String) session.getAttributes().get(JWTUtil.USER_ID);

        if (userId != null) {
            // 처리해야할 동작...
        } else {
            // Interceptor에서 토큰이 유효하지 않아 예외를 발생시켜야 하는 경우
            session.sendMessage(new TextMessage("올바르지 않은 엑세스 토큰입니다."));
            session.close(new CloseStatus(401, "Token Invalid"));
            return;
        }
				
		// 추가 동작...
    }

이렇게 토큰을 읽지 못해 userId 값이 들어가지 않은 경우로 토큰이 잘못되었는지 판별하여 오류 메세지를 보내며 세션을 닫아준다.


더 좋고 깔끔한 방법이 있을 것 같은데, 빠른 시간 안에 만들었어야 하기에 이런 식으로 개발했습니다.

굉장히 정신없게 정리한 것 같은데, 추가로 궁금한 점이 있거나 피드백할만한 것이 있다면 말씀해주시면 너무 감사합니다…

그냥 시시콜콜한 것도 다 상관 없습니다ㅎㅎ

profile
개발자에서 엔지니어로, 엔지니어에서 리더로

0개의 댓글