Spring Boot - GraphQL subscription JWT 인증 및 구독 에러 핸들링 구현

조제·2024년 10월 18일
0

GraphQL Subscription이란?

GraphQL에서 Subscription은 서버에서 클라이언트로 실시간 데이터를 푸시하는 기능입니다. 일반적인 Query 또는 Mutation과는 달리, Subscription은 클라이언트가 실시간으로 업데이트된 데이터를 받을 수 있도록 서버와 지속적인 연결을 유지합니다.

WebSocket에서의 인증 처리 과정

WebSocket 프로토콜은 HTTP처럼 헤더를 직접적으로 처리하지 않기 때문에 JWT와 같은 토큰을 인증하기 위해, connectionParams를 활용하여 토큰을 서버로 전송하고 검증하는 과정을 거칩니다.

build.gradle

plugins {
	id 'java'
	id 'org.springframework.boot' version '3.3.1'
	id 'io.spring.dependency-management' version '1.1.5'
}

group = 'com.ontacthealth'
version = '0.0.1-SNAPSHOT'

java {
	toolchain {
		languageVersion = JavaLanguageVersion.of(21)
	}
}

repositories {
	mavenCentral()
}

dependencies {
	// spring
	implementation 'org.springframework.boot:spring-boot-starter-web'

	//security
	implementation 'org.springframework.boot:spring-boot-starter-security'
	testImplementation 'org.springframework.security:spring-security-test'

	//jwt
	implementation 'io.jsonwebtoken:jjwt-api:0.11.2'
	runtimeOnly 'io.jsonwebtoken:jjwt-impl:0.11.2'
	runtimeOnly 'io.jsonwebtoken:jjwt-jackson:0.11.2' // or jjwt-gson for Gson

	// lombok
	compileOnly 'org.projectlombok:lombok'
	annotationProcessor 'org.projectlombok:lombok'
	testCompileOnly 'org.projectlombok:lombok'
	testAnnotationProcessor 'org.projectlombok:lombok'

	// db
	implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
	implementation 'org.mybatis.spring.boot:mybatis-spring-boot-starter:3.0.3'
	runtimeOnly 'org.postgresql:postgresql'
	implementation 'org.mybatis:mybatis-typehandlers-jsr310:1.0.2'

	//webflux
	implementation 'org.springframework.boot:spring-boot-starter-webflux'

	//graphql
	implementation 'org.springframework.boot:spring-boot-starter-graphql'
	testImplementation 'org.springframework.graphql:spring-graphql-test'
	testImplementation 'org.springframework:spring-webflux'
	implementation 'org.springframework.boot:spring-boot-starter-websocket'  // for subscriptions
}

클라이언트

const wsLink = new GraphQLWsLink(createClient({
    url: 'ws://localhost:8080/graphql',
    connectionParams: {
        Authorization: `Bearer ${token}`  // 헤더 대신 connectionParams로 설정
    },
}));

서버

WebSocketGraphQlInterceptor를 구현한 CustomGraphQLInterceptor 클래스가 WebSocket 연결 초기화 단계에서 JWT를 검증하는 역할을 합니다.

@Slf4j
@Component
@RequiredArgsConstructor
public class CustomGraphQLInterceptor implements WebSocketGraphQlInterceptor {

    private final SecurityServiceUseCase securityServiceUseCase;

    @Override
    public Mono<Object> handleConnectionInitialization(WebSocketSessionInfo sessionInfo, Map<String, Object> connectionInitPayload) {
        log.info("connectionInitPayload = {}", connectionInitPayload);

        String accessToken = (String) connectionInitPayload.get("Authorization");
        log.info("Authorization token: {}", accessToken);

        if (accessToken == null || accessToken.isEmpty()) {
            return Mono.error(new CustomAuthenticationException("Invalid Token"));
        }

        accessToken = accessToken.startsWith("Bearer ") ? accessToken.substring(7) : accessToken;

        try {
            securityServiceUseCase.loginByAccessTokenElseThrow(accessToken);
            return WebSocketGraphQlInterceptor.super.handleConnectionInitialization(sessionInfo, connectionInitPayload);
        } catch (Exception e) {
            log.error("Error validating token", e);
            return Mono.error(new CustomAuthenticationException("Invalid Token"));
        }
    }

}

Spring Boot - GraphQL 구독 프로세스

spring-graphql-1.3.1
webmvc/GraphQlWebSocketHandler.class : TextWebSocketHandler를 상속한 클래스

@RegisterReflectionForBinding(GraphQlWebSocketMessage.class)
public class GraphQlWebSocketHandler extends TextWebSocketHandler implements SubProtocolCapable {
	...
    // {1} 클라이언트로부터 웹소켓 메시지를 수신합니다.
    @Override
	protected void handleTextMessage(WebSocketSession session, TextMessage webSocketMessage) throws Exception {
		try (AutoCloseable closeable = ContextHandshakeInterceptor.setThreadLocals(session)) {
			handleInternal(session, webSocketMessage);
		}
	}
    
    // {2} 웹 소켓 메세지 유형에 따라 다음의 처리를 진행합니다.
    private void handleInternal(WebSocketSession session, TextMessage webSocketMessage) throws IOException {
		GraphQlWebSocketMessage message = decode(webSocketMessage);
		String id = message.getId();
		Map<String, Object> payload = message.getPayload();
		SessionState state = getSessionInfo(session);
		switch (message.resolvedType()) {
        	// {4} 인증에 성공했을 경우 웹소켓 메시지를 다시 수신합니다.
			case SUBSCRIBE -> {
				if (state.getConnectionInitPayload() == null) {
					GraphQlStatus.closeSession(session, GraphQlStatus.UNAUTHORIZED_STATUS);
					return;
				}
				if (id == null) {
					GraphQlStatus.closeSession(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
					return;
				}
				URI uri = session.getUri();
				Assert.notNull(uri, "Expected handshake url");
				HttpHeaders headers = session.getHandshakeHeaders();
				WebSocketGraphQlRequest request = new WebSocketGraphQlRequest(
						uri, headers, null, session.getRemoteAddress(), session.getAttributes(), payload, id, null, state.getSessionInfo());
				if (logger.isDebugEnabled()) {
					logger.debug("Executing: " + request);
				}
				this.graphQlHandler.handleRequest(request)
                		// {5} 비즈니스 로직 수행 후 응답을 처리합니다.
						.flatMapMany((response) -> handleResponse(session, request.getId(), response))
						.publishOn(state.getScheduler()) // Serial blocking send via single thread
						.subscribe(new SendMessageSubscriber(id, session, state));
			} // case SUBSCRIBE
            ...
			case CONNECTION_INIT -> {
				if (!state.setConnectionInitPayload(payload)) {
					GraphQlStatus.closeSession(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
					return;
				}
				this.webSocketGraphQlInterceptor.handleConnectionInitialization(state.getSessionInfo(), payload)
						.defaultIfEmpty(Collections.emptyMap())
						.publishOn(state.getScheduler()) // Serial blocking send via single thread
						.doOnNext((ackPayload) -> {
							TextMessage outputMessage = encode(GraphQlWebSocketMessage.connectionAck(ackPayload));
							try {
								session.sendMessage(outputMessage);
							}
							catch (IOException ex) {
								throw new IllegalStateException(ex);
							}
						})
						.onErrorResume((ex) -> {
                            // {3} 인터셉터에서 인증에 실패할경우 4401을 리턴합니다.
							GraphQlStatus.closeSession(session, GraphQlStatus.UNAUTHORIZED_STATUS);
							return Mono.empty();
						})
						.block(Duration.ofSeconds(10));

				if (this.keepAliveDuration != null) {
					Flux.interval(this.keepAliveDuration, this.keepAliveDuration)
							.filter((aLong) -> true)
							.publishOn(state.getScheduler()) // Serial blocking send via single thread
							.doOnNext((aLong) -> {
								try {
									session.sendMessage(encode(GraphQlWebSocketMessage.ping(null)));
								}
								catch (IOException ex) {
									ExceptionWebSocketHandlerDecorator.tryCloseWithError(session, ex, logger);
								}
							})
							.subscribe(state.getKeepAliveSubscriber());
				} // if
			} // case CONNECTION_INIT
			...
		} // switch
	} // handleInternal()
    ...
    // {5} 비즈니스 로직 수행 후 응답을 처리합니다.
    private Flux<TextMessage> handleResponse(WebSocketSession session, String id, WebGraphQlResponse response) {
		if (logger.isDebugEnabled()) {
			logger.debug("Execution result ready"
					+ (!CollectionUtils.isEmpty(response.getErrors()) ? " with errors: " + response.getErrors() : "")
					+ ".");
		}
		Flux<Map<String, Object>> responseFlux;
		if (response.getData() instanceof Publisher) {
			// Subscription
			responseFlux = Flux.from((Publisher<ExecutionResult>) response.getData())
					.map(ExecutionResult::toSpecification)
					.doOnSubscribe((subscription) -> {
							Subscription prev = getSessionInfo(session).getSubscriptions().putIfAbsent(id, subscription);
							if (prev != null) {
								throw new SubscriptionExistsException();
							}
					});
		}
		else {
			// Single response (query or mutation) that may contain errors
			responseFlux = Flux.just(response.toMap());
		}

		return responseFlux
				.map((responseMap) -> encode(GraphQlWebSocketMessage.next(id, responseMap)))
				.concatWith(Mono.fromCallable(() -> encode(GraphQlWebSocketMessage.complete(id))))
				.onErrorResume((ex) -> {
					if (ex instanceof SubscriptionExistsException) {
						CloseStatus status = new CloseStatus(4409, "Subscriber for " + id + " already exists");
						GraphQlStatus.closeSession(session, status);
						return Flux.empty();
					}
                    // {6} 구독 후 비즈니스 로직에서 발생한 에러는 모두 INTERNAL_ERROR로 반환됩니다.
					List<GraphQLError> errors = ((ex instanceof SubscriptionPublisherException) ?
							((SubscriptionPublisherException) ex).getErrors() :
							Collections.singletonList(GraphqlErrorBuilder.newError()
									.message("Subscription error")
									.errorType(ErrorType.INTERNAL_ERROR)
									.build()));
					return Mono.just(encode(GraphQlWebSocketMessage.error(id, errors)));
				});
	} // handleResponse()
    ...
}

{1} : GraphQlWebSocketHandler.handleTextMessage() : 클라이언트로부터 웹소켓 메시지를 수신합니다.
{2} : GraphQlWebSocketHandler.handleInternal() : 웹소켓 메시지의 유형에 따라 다음의 처리를 진행합니다.
- type : connection_init
- payload : 인증 토큰이 포함됩니다.
{3} : CustomGraphQLInterceptor 에서 인증에 실패할 경우 4401을 리턴합니다.

{
  "errors": [
    {
      "message": "Socket closed with event 4401 Unauthorized",
      "stack": "Error: Socket closed with event 4401 Unauthorized\n    at Object.error (https://unpkg.com/graphiql/graphiql.min.js:79281:46)\n    at https://unpkg.com/graphiql/graphiql.min.js:22255:14"
    }
  ]
}

인증에 성공했을 경우
{4} : GraphQlWebSocketHandler.handleTextMessage() : 웹소켓 메시지를 다시 수신합니다.
- handleInternal() : 웹소켓 메시지의 유형을 확인합니다.
- type : SUBSCRIBE
- payload : 이 단계에서는 쿼리만 포함되어 있습니다 (토큰 없음).
{5} : @SubscriptionMapping 비즈니스 로직 수행 후 응답을 처리합니다.
{6} : 구독 후 비즈니스 로직에서 발생한 에러는 모두 INTERNAL_ERROR로 반환됩니다.

서버에서 subscription 에러 핸들링

INTERNAL_ERROR 가 아닌 다른 값을 응답해주고 싶은 경우
@GraphQlExceptionHandler 어노테이션 사용하여 핸들링 할 수 있습니다.

@GraphQlExceptionHandler
public GraphQLError handleGraphQlException(CustomException e) {
    log.error("handleGraphQlException occurred", e);

    return GraphqlErrorBuilder.newError()
            .message(e.getMessage())
            .errorType(ErrorType.BAD_REQUEST)
            .build();
}
profile
조제

0개의 댓글