이번 포스터는 QR-DQN에 대해서 알아보겠습니다!
QR은 Quantile Regression입니다.
여기서 재밌는 점은 기존에 이론적으로는 훌륭하나 practical하게 사용하기에는 어려웠던 wasserstein metric을 사용할 수 있게 됩니다.
QR-DQN에서는 기존 C51과는 반대로 확률을 uniform 확률로 fix시키고, support를 fixed support가 아닌, learnable support로 바꿔줍니다. 이 support에 해당하는 return값을 학습하는 방향으로 진행이 됩니다.
support의 개수를 5라고 하면 이제는 모두 같은 1/5 확률을 갖고 있게 됩니다.
이것을 표현하기 위해 각 0.2 면적을 기준으로 나누고, 0.1 0.1로 면적을 절반으로 나누는 곳을 support로 지정하게 됩니다. 이를 표현하기 위해 CDF를 이용합니다. 0.2 0.4 0.6 0.8 1
그리고 타우를 통해 0.1 0.3 0.5 0.7 0.9의 CDF값에 해당하는 return 세타를 찾아줍니다!
이렇게 우선 타우들을 계산합니다. 이것들을 quantile midpoints라고 합니다.
더 이상 lower, upper bound를 찾을 필요가 없게 됩니다. C51과 마찬가지로 action의 개수 x 타우개수 가 output개수가 됩니다.
QR분포를 사용하는 핵심 이유는
Wasserstein metric을 사용하여 target분포를 추정한다고 할 때
Wasserstein metric을 사용하여 Z와 ZΘ를 계산하면 두 CDF 사이의 면적을 계산한 것과 같습니다. 여기서 Wasserstein metric이 최소가 되는 곳은 타우라는 것 입니다.
우리의 목표는 Wasserstein metric을 줄이는 방향 즉, 세타를 잘 움직여 CDF간 면적을 최소화시키는 것 입니다. 이를 만족하기 위한 곳은 타우에 해당하는 부분이라는 것 입니다.
또한 quantile 분포를 학습할 때, QR을 사용하는데 SGD를 사용할 수 있다는 장점이 있습니다.
그렇기에 Wasserstein metric을 직접 계산할 필요가 없습니다.
이제는 quantile 분포를 학습할 때, SGD를 사용할 수 있는 QR를 사용할 수 있고,
이렇게 해서 찾은 값이 Wasserstein metric을 최소화하는 것을 충족하게 됩니다!
그렇기에 Theory와 practice 사이의 이론 갭이 없어집니다.
Quantile Regression에 대해 알아보겠습니다.
데이터들을 잘 표현할 수 있는 직선을 Liear Regression을 통해 찾을 수 있습니다.
여기서는 loss ft으로 MSE가 사용됩니다. 이것을 최소화하는 방향으로 학습하게 됩니다.
이번에는 L2 regression을 보겠습니다. 이것은 데이터들의 평균이 되는 지점을 찾는 것 입니다.
한 단면의 분포에서 mean을 찾는 것 입니다. loss ft으로 x^2을 사용합니다.
이것을 통해 데이터들의 평균을 나타내는 곡선을 찾을 수 있습니다.
이번에는 L1 regression을 보겠습니다. loss ft은 |x|입니다.
중앙값을 찾게 계산하게 됩니다.
τ-quantile regression은 τ=0.1이라고 했을 때, 왼쪽 면적은 0.1 오른쪽 면적은 0.9가 되는 지점을 찾고 싶은 것 입니다. 그림으로 표현하면 선을 기준으로 10% 데이터는 아래에 존재하고 90%데이터는 위에 존재하게끔 하는 것 입니다. L1 regression은 τ가 0.5인 것 입니다.
벨만 업데이트를 하여 나온 target값과 behavior값의 차이를 minimize시키기 위해 QR loss를 사용합니다. QR loss를 최소화시키는 Θi가 역시 Wasserstein loss를 최소화시킵니다.
또한 QR loss는 미분 가능이기에 SGD를 사용할 수 있습니다.
기존 QR loss는 미분 불가한 지점이 있기에, Quantile Huber loss를 통해 smoth하게 바꿔줍니다.
upper lower bound를 구할 필요가 없고, disjoint support issue를 해결했습니다. 또한 SGD를 사용할 수 있게 되었습니다.
고려대학교 오승상 교수님 강화학습 강의 : https://www.youtube.com/watch?v=NZP7Va21WO8&list=PLvbUC2Zh5oJtYXow4jawpZJ2xBel6vGhC&index=33