Discrete Normalizing Flows

상솜공방·2025년 10월 23일

딥러닝

목록 보기
3/4

1. Discrete Normalizing Flows (Discrete NF)

  • 목표: 간단한 확률 분포 pz(z)p_z(z) (예: 표준 정규분포 zN(0,I)z \sim \mathcal{N}(0, I))를 복잡한 데이터 분포 px(x)p_x(x)로 변환하는 것입니다.
  • 방법: zzxx로 변환하는 함수 ff를 학습합니다. 이 함수 ff미분 가능하고 역변환이 가능(differentiable and invertible)해야 합니다.
  • 구조: ff는 여러 개의 간단한 변환 f1,f2,,fKf_1, f_2, \dots, f_K의 합성(composition)으로 이루어집니다.
    z0=zz_0 = z
    z1=f1(z0)z_1 = f_1(z_0)
    ...
    x=zK=fK(zK1)=fKf1(z0)x = z_K = f_K(z_{K-1}) = f_K \circ \dots \circ f_1(z_0)
  • 확률 밀도 계산 (핵심): '확률 변수의 변환(Change of Variables)' 공식을 사용합니다. z=f1(x)z = f^{-1}(x)일 때, xx의 로그 확률(log\log-likelihood)은 다음과 같습니다.
    logpx(x)=logpz(f1(x))+logdet(f1x)\log p_x(x) = \log p_z(f^{-1}(x)) + \log \left| \det\left( \frac{\partial f^{-1}}{\partial x} \right) \right|
    계산의 편의성을 위해, 역함수 f1f^{-1} 대신 순방향 함수 ff의 자코비안(Jacobian)을 사용하여 표현하는 것이 일반적입니다. (z0=f1(x)z_0 = f^{-1}(x))
    logpx(x)=logpz(z0)k=1Klogdet(Jfk(zk1))\log p_x(x) = \log p_z(z_0) - \sum_{k=1}^K \log \left| \det\left( J_{f_k}(z_{k-1}) \right) \right|
  • 단점: 각 변환 fkf_k는 (1) 역변환이 가능해야 하고 (2) 자코비안 행렬식(det(J)\det(J))을 계산하기 쉬워야 합니다. 이 두 가지 제약 조건 때문에 RealNVP, Glow, MAF 등 특정 구조(예: coupling layers, autoregressive)만 사용해야 하는 아키텍처의 제약이 매우 큽니다.

2. Discrete NF 수식 유도

2.1. 확률 밀도 함수 (Probability Density Function, PDF)

우선 확률 밀도가 무엇인지부터 짚고 넘어가야 합니다.

  • 이산(Discrete) vs. 연속(Continuous):
    • 이산 확률 변수 (예: 주사위 눈)는 '확률 질량 함수(PMF)'를 가집니다. P(X=3)=1/6P(X=3) = 1/6 처럼 특정 값에 대한 확률이 존재합니다.
    • 연속 확률 변수 (예: 키, 몸무게)는 '확률 밀도 함수(PDF)', px(x)p_x(x)를 가집니다. 키가 정확히 175.0000...cm일 확률은 0입니다. (측정 가능한 점이 무한히 많기 때문)
  • PDF의 의미: px(x)p_x(x) 자체는 확률이 아닙니다. (그래서 1보다 클 수도 있습니다.) px(x)p_x(x)xx 지점에서의 확률의 상대적인 밀도(촘촘함)를 나타냅니다.
  • 확률 계산: 우리는 PDF를 적분해야만 확률을 얻을 수 있습니다. xxaabb 사이에 있을 확률은 aa부터 bb까지 PDF 곡선 아래의 면적입니다.
    P(axb)=abpx(x)dxP(a \le x \le b) = \int_a^b p_x(x) dx
  • 핵심 속성: xx가 아주 작은 구간 dxdx 안에 존재할 "확률"은 px(x)dxp_x(x) \cdot dx (즉, 그 지점의 높이 ×\times 밑변)로 근사할 수 있습니다.

2.2. 자코비안 (Jacobian) 이란?

자코비안은 다변수 함수에서 "국소적인 변화율(스케일링 팩터)"을 나타내는 행렬입니다.

  • 1D (1차원)의 경우: 도함수(Derivative)
    xRx \in \mathbb{R}, yRy \in \mathbb{R}이고 y=f(x)y = f(x)일 때, 도함수 dydx\frac{dy}{dx}xx가 1만큼 변할 때 yy가 얼마나 변하는지를 나타냅니다. 즉, xx에서의 작은 길이 dxdxyy에서의 길이 dy=dydxdxdy = \frac{dy}{dx} dx로 "스케일링"됩니다.

  • nD (다차원)의 경우: 자코비안 행렬(Jacobian Matrix)
    zRnz \in \mathbb{R}^n, xRnx \in \mathbb{R}^n이고 x=f(z)x = f(z)일 때, ffnn개의 입력(z1,,znz_1, \dots, z_n)을 받아 nn개의 출력(x1,,xnx_1, \dots, x_n)을 내보내는 함수입니다.
    자코비안 Jf(z)J_f(z)는 이 변환의 모든 편미분(partial derivatives)을 모아놓은 행렬입니다.

    Jf(z)=xz=[x1z1x1z2x1znx2z1x2z2x2znxnz1xnz2xnzn]J_f(z) = \frac{\partial x}{\partial z} = \begin{bmatrix} \frac{\partial x_1}{\partial z_1} & \frac{\partial x_1}{\partial z_2} & \dots & \frac{\partial x_1}{\partial z_n} \\ \frac{\partial x_2}{\partial z_1} & \frac{\partial x_2}{\partial z_2} & \dots & \frac{\partial x_2}{\partial z_n} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial x_n}{\partial z_1} & \frac{\partial x_n}{\partial z_2} & \dots & \frac{\partial x_n}{\partial z_n} \end{bmatrix}
  • 자코비안 행렬식 (Jacobian Determinant)
    우리가 정말 관심 있는 것은 이 행렬의 행렬식(Determinant)det(Jf(z))\det(J_f(z)) 입니다.
    이 스칼라 값은 zz 공간에서의 아주 작은 부피(Volume) dVzdV_zff에 의해 xx 공간으로 변환될 때, 그 부피 dVxdV_x가 얼마나 스케일링되는지를 알려줍니다.

    dVx=det(Jf(z))dVzdV_x = |\det(J_f(z))| \cdot dV_z
    • det(Jf)>1|\det(J_f)| > 1 : 부피가 팽창 (Stretching)
    • det(Jf)<1|\det(J_f)| < 1 : 부피가 수축 (Shrinking)
    • det(Jf)=1|\det(J_f)| = 1 : 부피가 보존
      (부피는 항상 양수여야 하므로 절댓값 |\cdot|을 사용합니다.)

자코비안 예시

2차원 함수

{x=f1(u,v),y=f2(u,v)\begin{cases} x = f_1(u, v),\\ y = f_2(u, v) \end{cases}

를 생각해보자. 즉,
입력 공간 (u,v)(u, v)의 한 점이 함수 ff를 통해 (x,y)(x, y)로 매핑되는 상황이다.
그리고 각 함수가 아래와 같다고 가정하자.

{x=2u+vy=u+3v\begin{cases} x = 2u + v \\ y = u + 3v \end{cases}

이 함수는 (u,v)(u, v) 공간의 점을 (x,y)(x, y) 공간으로 선형 변환(linear transformation) 시킨다.

각 성분을 u,vu, v에 대해 편미분해서 자코비안 행렬을 만든다.

Jf(u,v)=[xuxvyuyv]=[2113]J_f(u, v) = \begin{bmatrix} \frac{\partial x}{\partial u} & \frac{\partial x}{\partial v} \\ \frac{\partial y}{\partial u} & \frac{\partial y}{\partial v} \\ \end{bmatrix} = \begin{bmatrix} 2 & 1 \\ 1 & 3 \end{bmatrix}

이제 이 행렬의 행렬식(Determinant) 을 계산한다.

det(Jf)=(2)(3)(1)(1)=61=5\det(J_f) = (2)(3) - (1)(1) = 6 - 1 = 5

이 말은 곧 다음을 의미한다:

(u,v)(u,v) 공간의 아주 작은 면적 요소 du,dvdu , dv가,
(x,y)(x,y) 공간에서는 55배 커진 면적 dx,dy=5,du,dvdx , dy = 5 , du , dv가 된다는 뜻이다.

즉, 이 변환은 모든 방향에서 국소적으로 5배 확대되는 변환이다.

  • 원래 (u,v)(u,v) 공간에서 1×1 정사각형 (면적 1)을 생각하자.
  • 이 정사각형은 (x,y)(x,y) 공간으로 매핑되면 평행사변형(parallelogram) 이 된다.
  • 그 평행사변형의 면적이 정확히 5배로 늘어난다.

즉,

면적 변화 비율=det(Jf)=5\text{면적 변화 비율} = |\det(J_f)| = 5

요약

개념수식의미
자코비안 행렬[2113]\begin{bmatrix} 2 & 1 \\ 1 & 3 \end{bmatrix}각 축 방향의 국소적 변화율
자코비안 행렬식det(Jf)=5\det(J_f)=5면적이 5배로 확대됨
스케일링 관계dAxy=5,dAuvdA_{xy} = 5, dA_{uv}(u,v)(u,v) 공간의 면적 요소가 (x,y)(x,y) 공간에서 5배 커짐

2.3. 첫 번째 공식 유도 (Change of Variables)

이제 이 두 개념을 합쳐서 첫 번째 공식을 유도해 보겠습니다.

  1. 기본 가정:

    • zz는 간단한 분포 pz(z)p_z(z)를 따릅니다. (예: 가우시안)
    • xxzz를 변환하여 얻어집니다: x=f(z)x = f(z).
    • ff는 역변환이 가능합니다: z=f1(x)z = f^{-1}(x).
    • 우리는 xx의 분포 px(x)p_x(x)를 알고 싶습니다.
  2. 핵심 원리: 확률 보존
    변환 전후에 zz가 아주 작은 부피 dVzdV_z 안에 존재할 확률과, 그에 대응하는 xxdVxdV_x 안에 존재할 확률은 반드시 같아야 합니다. (확률은 어디로 사라지거나 생겨나지 않습니다.)

    P(zdVz)=P(xdVx)P(z \in dV_z) = P(x \in dV_x)

  3. PDF로 표현하기 (1번 개념 적용)
    위 확률을 PDF로 표현하면 다음과 같습니다.

    pz(z)dVz=px(x)dVxp_z(z) |dV_z| = p_x(x) |dV_x|

  4. px(x)p_x(x)에 대해 정리하기

    px(x)=pz(z)dVzdVxp_x(x) = p_z(z) \left| \frac{dV_z}{dV_x} \right|

    여기서 dVzdVx\left| \frac{dV_z}{dV_x} \right|xx 공간에서 zz 공간으로 변환될 때의 부피 스케일링 팩터입니다.

  5. 자코비안 적용 (2번 개념 적용)

    • xx에서 zz로 가는 변환은 z=f1(x)z = f^{-1}(x) 입니다.
    • 이 변환의 자코비안은 Jf1(x)=f1xJ_{f^{-1}}(x) = \frac{\partial f^{-1}}{\partial x} 입니다.
    • 따라서 xx 공간의 부피 dVxdV_xzz 공간의 부피 dVzdV_z로 변환될 때의 스케일링 팩터는 det(f1x)\left| \det\left( \frac{\partial f^{-1}}{\partial x} \right) \right| 입니다.
    • 즉, dVzdVx=det(f1x)\left| \frac{dV_z}{dV_x} \right| = \left| \det\left( \frac{\partial f^{-1}}{\partial x} \right) \right| 입니다.
  6. 공식 완성
    위 식을 4번에 대입하고, z=f1(x)z = f^{-1}(x) 관계도 대입하면 다음과 같습니다.

    px(x)=pz(f1(x))det(f1x)p_x(x) = p_z(f^{-1}(x)) \left| \det\left( \frac{\partial f^{-1}}{\partial x} \right) \right|
  7. Log-Likelihood (로그 확률)
    확률 값들은 매우 작아져서 곱셈 연산 시 수치적으로 불안정(underflow)할 수 있습니다. 그래서 log\log를 씌워 덧셈으로 바꿔줍니다.

    logpx(x)=log(pz(f1(x))det(f1x))\log p_x(x) = \log \left( p_z(f^{-1}(x)) \left| \det\left( \frac{\partial f^{-1}}{\partial x} \right) \right| \right)

    log(ab)=log(a)+log(b)\log(a \cdot b) = \log(a) + \log(b) 이므로,

    logpx(x)=logpz(f1(x))+logdet(f1x)\log p_x(x) = \log p_z(f^{-1}(x)) + \log \left| \det\left( \frac{\partial f^{-1}}{\partial x} \right) \right|

    이것이 바로 질문에서 언급된 첫 번째 공식입니다.

2.4. 두 번째 공식 유도 (실용적 공식)

첫 번째 공식은 이론적으로는 맞지만 매우 비실용적입니다. 왜냐하면:

  • 우리는 보통 zxz \to x로 가는 순방향(forward) 함수 ff를 쉽게 설계합니다.
  • 하지만 xzx \to z로 가는 역방향(inverse) 함수 f1f^{-1}를 명시적으로 계산하는 것은 매우 어렵습니다.
  • f1f^{-1}의 자코비안 f1x\frac{\partial f^{-1}}{\partial x}를 계산하는 것은 더더욱 어렵습니다.

목표: 공식을 역함수 f1f^{-1}가 아닌, 우리가 아는 순방향 함수 ff와 그 자코비안 Jf(z)=fzJ_f(z) = \frac{\partial f}{\partial z}로 표현하고 싶습니다.

  1. 핵심 수학 정리 (2가지)

    • 정리 1 (역함수의 자코비안): 역함수의 자코비안은 원래 함수의 자코비안의 역행렬입니다.
      Jf1(x)=(Jf(z))1J_{f^{-1}}(x) = \left( J_f(z) \right)^{-1} (단, z=f1(x)z=f^{-1}(x))
    • 정리 2 (역행렬의 행렬식): 역행렬의 행렬식은 원래 행렬의 행렬식의 역수입니다.
      det(A1)=1det(A)=(det(A))1\det(A^{-1}) = \frac{1}{\det(A)} = (\det(A))^{-1}
  2. 자코비안 항 변환하기
    첫 번째 공식의 logdet(f1x)\log \left| \det\left( \frac{\partial f^{-1}}{\partial x} \right) \right| 항을 변환해 봅시다.

    logdet(f1x)=logdet((Jf(z))1)\log \left| \det\left( \frac{\partial f^{-1}}{\partial x} \right) \right| = \log \left| \det\left( (J_f(z))^{-1} \right) \right| (by 정리 1)

    =log(det(Jf(z)))1= \log \left| ( \det(J_f(z)) )^{-1} \right| (by 정리 2)

    =log(1det(Jf(z)))= \log \left( \frac{1}{\left| \det(J_f(z)) \right|} \right)

    log(1/a)=log(a)\log(1/a) = -\log(a) 이므로,

    =logdet(Jf(z))= - \log \left| \det(J_f(z)) \right|

  3. 단일 변환 공식 (순방향 기준)
    이것을 첫 번째 공식에 다시 대입합니다. (z0=f1(x)z_0 = f^{-1}(x) 라고 표기)

    logpx(x)=logpz(z0)+(logdet(Jf(z0)))\log p_x(x) = \log p_z(z_0) + \left( - \log \left| \det(J_f(z_0)) \right| \right)

    logpx(x)=logpz(z0)logdet(Jf(z0))\log p_x(x) = \log p_z(z_0) - \log \left| \det(J_f(z_0)) \right|

    이제 우리는 순방향 함수 ff의 자코비안만 계산하면 됩니다!

  4. KK개의 연속 변환으로 확장 (두 번째 공식 완성)
    Discrete NF는 ffKK개의 간단한 함수 f1,,fKf_1, \dots, f_K의 합성(composition)으로 만듭니다.

    z0f1z1f2z2fKzK=xz_0 \xrightarrow{f_1} z_1 \xrightarrow{f_2} z_2 \dots \xrightarrow{f_K} z_K = x

    이때 전체 변환 f=fKf1f = f_K \circ \dots \circ f_1 입니다.

    • 정리 3 (자코비안의 연쇄 법칙): 합성 함수의 자코비안은 각 자코비안의 행렬 곱입니다.
      Jf(z0)=JfK(zK1)JfK1(zK2)Jf1(z0)J_f(z_0) = J_{f_K}(z_{K-1}) \cdot J_{f_{K-1}}(z_{K-2}) \cdot \dots \cdot J_{f_1}(z_0)
    • 정리 4 (행렬식의 곱셈): 행렬 곱의 행렬식은 각 행렬식의 입니다.
      det(AB)=det(A)det(B)\det(A \cdot B) = \det(A) \cdot \det(B)

    이제 logdet(Jf(z0))\log \left| \det(J_f(z_0)) \right| 항을 이 KK개의 변환에 대해 풀어봅시다.

    logdet(Jf(z0))=logdet(JfK(zK1)Jf1(z0))\log \left| \det(J_f(z_0)) \right| = \log \left| \det(J_{f_K}(z_{K-1}) \cdot \dots \cdot J_{f_1}(z_0)) \right| (by 정리 3)

    =log(det(JfK(zK1))det(Jf1(z0)))= \log \left( \left| \det(J_{f_K}(z_{K-1})) \right| \cdot \dots \cdot \left| \det(J_{f_1}(z_0)) \right| \right) (by 정리 4)

    log(abc)=log(a)+log(b)+log(c)\log(a \cdot b \cdot c) = \log(a) + \log(b) + \log(c) 이므로,

    =logdet(JfK(zK1))++logdet(Jf1(z0))= \log \left| \det(J_{f_K}(z_{K-1})) \right| + \dots + \log \left| \det(J_{f_1}(z_0)) \right|

    =k=1Klogdet(Jfk(zk1))= \sum_{k=1}^K \log \left| \det\left( J_{f_k}(z_{k-1}) \right) \right|
    (여기서 zk1z_{k-1}kk번째 함수 fkf_k의 입력입니다.)

  5. 최종 공식
    이 합(sum)을 3번의 단일 변환 공식에 대입하면,

    logpx(x)=logpz(z0)k=1Klogdet(Jfk(zk1))\log p_x(x) = \log p_z(z_0) - \sum_{k=1}^K \log \left| \det\left( J_{f_k}(z_{k-1}) \right) \right|

    이것이 바로 우리가 원했던, 계산 가능한 형태의 두 번째 공식입니다.

3. 직접 손으로 써보며 이해하기

profile
상어 인형을 좋아하는 사람

0개의 댓글