기존에는 로그인 시 JWT 를 하나만 생성해서 API 요청 시 인증을 위해 사용하고 있었다.
토큰이 만료되는 경우 재로그인이 필요하므로 만료 기간을 짧게 잡을 수는 없었다.
이 토큰이 탈취되는 경우에는 아무래도 보안적으로 문제가 될 수 있었다.
기존에는 기능적으로 동작하는 게 우선이었기 때문에 간단하게 작업을 했지만 보안을 조금 더 강화하기 위해서 access token, refresh token 을 나누어서 작업해보는 건 어떨까 싶었다.
API 요청 시 전달하는 access token은 만료 기간을 짧게 하여 탈취되더라도 큰 문제가 없도록 하고, refresh token 만료 기간은 길게한 후 access token 만료 시 refresh token 으로 재발급 받을 수 있도록 처리한다.
@Service
public class JwtUtil {
...
private final static String TOKEN_KEY = "userId";
private final static String REFRESH_TOKEN_KEY_PREFIX = "refresh_token_user_id_";
@Value("${jwt.secret}")
private String secretKey;
@Value("${jwt.access-token-validity-in-ms}")
private long accessTokenValidMillisecond;
@Value("${jwt.refresh-token-validity-in-ms}")
private long refreshTokenValidMillisecond;
...
public String createAccessToken(Long userId) {
Date now = new Date();
Key key = new SecretKeySpec(Base64.getDecoder().decode(this.secretKey), SignatureAlgorithm.HS256.getJcaName());
return Jwts.builder()
.claim(TOKEN_KEY, userId)
.setIssuedAt(now)
.setExpiration(new Date(now.getTime() + accessTokenValidMillisecond))
.signWith(key)
.compact();
}
public String createRefreshToken(Long userId) {
Date now = new Date();
Key key = new SecretKeySpec(Base64.getDecoder().decode(this.secretKey), SignatureAlgorithm.HS256.getJcaName());
String token = Jwts.builder()
.claim(TOKEN_KEY, userId)
.setIssuedAt(now)
.setExpiration(new Date(now.getTime() + refreshTokenValidMillisecond))
.signWith(key)
.compact();
redisTemplate.opsForValue().set(REFRESH_TOKEN_KEY_PREFIX + userId, token, Duration.ofMillis(refreshTokenValidMillisecond));
return token;
}
...
}
String accessToken = jwtUtil.createToken(user.getUserId());
String refreshToken = jwtUtil.createRefreshToken(user.getUserId());
@Service
public class JwtUtil {
...
private final RedisTemplate<String, String> redisTemplate;
private final static String TOKEN_KEY = "userId";
private final static String REFRESH_TOKEN_KEY_PREFIX = "refresh_token_user_id_";
@Value("${jwt.secret}")
private String secretKey;
@Value("${jwt.access-token-validity-in-ms}")
private long accessTokenValidMillisecond;
@Value("${jwt.refresh-token-validity-in-ms}")
private long refreshTokenValidMillisecond;
public JwtUtil(RedisTemplate<String, String> redisTemplate) {
this.redisTemplate = redisTemplate;
}
...
public String createRefreshToken(Long userId) {
Date now = new Date();
Key key = new SecretKeySpec(Base64.getDecoder().decode(this.secretKey), SignatureAlgorithm.HS256.getJcaName());
String token = Jwts.builder()
.claim(TOKEN_KEY, userId)
.setIssuedAt(now)
.setExpiration(new Date(now.getTime() + refreshTokenValidMillisecond))
.signWith(key)
.compact();
redisTemplate.opsForValue().set(REFRESH_TOKEN_KEY_PREFIX + userId, token, Duration.ofMillis(refreshTokenValidMillisecond));
return token;
}
public long getUserId(String token) {
return Jwts.parserBuilder()
.setSigningKey(secretKey)
.build()
.parseClaimsJws(token)
.getBody()
.get(TOKEN_KEY, Long.class);
}
public boolean isValidRefreshToken(String refreshToken) {
if (isValidToken(refreshToken)) {
long userId = getUserId(refreshToken);
String foundRefreshToken = redisTemplate.opsForValue().get(REFRESH_TOKEN_KEY_PREFIX + userId);
return refreshToken.equals(foundRefreshToken);
}
return false;
}
private boolean isValidToken(String token) {
try {
logger.debug(this.secretKey);
Jws<Claims> claims = Jwts.parserBuilder().setSigningKey(this.secretKey).build().parseClaimsJws(token);
return !claims.getBody().getExpiration().before(new Date());
} catch (SecurityException | MalformedJwtException | IllegalArgumentException | SignatureException exception) {
logger.info("잘못된 Jwt 토큰입니다");
} catch (ExpiredJwtException exception) {
logger.info("만료된 Jwt 토큰입니다");
} catch (UnsupportedJwtException exception) {
logger.info("지원하지 않는 Jwt 토큰입니다");
}
return false;
}
@Transactional
public void removeRefreshToken(Long userId) {
redisTemplate.delete(REFRESH_TOKEN_KEY_PREFIX + userId);
}
...
}
@RestController
@RequestMapping("/users")
public class UserController {
...
@PostMapping("/token-reissue")
public UserTokenRefreshResponse reissueToken(@RequestBody UserTokenRefreshRequest request) {
return userService.reissueToken(request);
}
}
@Service
public class UserService {
...
@Transactional(readOnly = true)
public UserTokenRefreshResponse reissueToken(UserTokenRefreshRequest request) {
if (!jwtUtil.isValidRefreshToken(request.getRefreshToken())) {
throw new UserTokenNotExistException();
}
return UserTokenRefreshResponse.builder()
.userId(request.getUserId())
.accessToken(jwtUtil.createAccessToken(request.getUserId()))
.refreshToken(jwtUtil.createRefreshToken(request.getUserId()))
.build();
}
}
@Service
public class JwtUtil {
...
private final static String TOKEN_HEADER = "Authorization";
...
public String resolveAccessToken(HttpServletRequest request) {
String header = request.getHeader(TOKEN_HEADER);
return header != null ? header.substring(7) : null;
}
}
@Service
public class JwtUtil {
...
private final RedisTemplate<String, String> redisTemplate;
private final static String ACCESS_TOKEN_BLACKLIST_VALUE = "access_token_blacklist";
public JwtUtil(RedisTemplate<String, String> redisTemplate) {
this.redisTemplate = redisTemplate;
}
...
public long getUserId(String token) {
return Jwts.parserBuilder()
.setSigningKey(secretKey)
.build()
.parseClaimsJws(token)
.getBody()
.get(TOKEN_KEY, Long.class);
}
@Transactional(readOnly = true)
public boolean isValidAccessToken(String accessToken) {
if (StringUtils.hasText(redisTemplate.opsForValue().get(accessToken))) {
return false;
}
return isValidToken(accessToken);
}
private boolean isValidToken(String token) {
try {
logger.debug(this.secretKey);
Jws<Claims> claims = Jwts.parserBuilder().setSigningKey(this.secretKey).build().parseClaimsJws(token);
return !claims.getBody().getExpiration().before(new Date());
} catch (SecurityException | MalformedJwtException | IllegalArgumentException | SignatureException exception) {
logger.info("잘못된 Jwt 토큰입니다");
} catch (ExpiredJwtException exception) {
logger.info("만료된 Jwt 토큰입니다");
} catch (UnsupportedJwtException exception) {
logger.info("지원하지 않는 Jwt 토큰입니다");
}
return false;
}
@Transactional
public void setAccessTokenBlacklist(String accessToken) {
redisTemplate.opsForValue().set(accessToken, ACCESS_TOKEN_BLACKLIST_VALUE, Duration.ofMillis(accessTokenValidMillisecond));
}
}