Continuous Normalizing Flows

상솜공방·2025년 10월 24일

딥러닝

목록 보기
4/4

1. Continuous Normalizing Flows (CNF)

Continuous Normalizing Flows(CNF)는 딥러닝 기반 생성 모델, 특히 Normalizing Flows (NF) 계열에서 매우 중요하고 흥미로운 개념입니다. CNF는 KK개의 이산적인 변환 fkf_k를 사용하는 대신, 이 변환 과정을 무한히 작은 변환의 연속으로 일반화합니다.

  • Discrete NF에서 변환된 xx의 확률 밀도 함수는 다음과 같았습니다.
logpx(x)=logpz(z0)k=1Klogdet(Jfk(zk1))\log p_x(x) = \log p_z(z_0) - \sum_{k=1}^K \log \left| \det\left( J_{f_k}(z_{k-1}) \right) \right|
  • 발상의 전환: 여기서 KK \to \infty가 되면, 이산적인 변환 zk=fk(zk1)z_k = f_k(z_{k-1})상미분방정식(Ordinary Differential Equation, ODE) 으로 기술할 수 있습니다.

  • CNF의 정의: 시간 tt에 따른 상태 z(t)z(t)의 변화를 신경망 ff로 모델링합니다. (이제부터 ztz_tz(t)z(t)의 함수 꼴로 표기하겠습니다.) 이 신경망 ffz(t)z(t)가 시간 tt에서 어떻게 변해야 하는지(즉, 속도 dzdt\frac{dz}{dt})를 정의하는 벡터 필드(vector field) 역할을 합니다.

    dz(t)dt=f(z(t),t,θ)\frac{dz(t)}{dt} = f(z(t), t, \theta)

    여기서 θ\theta는 신경망 ff의 파라미터입니다.

  • 변환 과정:

    • 샘플링 (Sampling, zxz \to x): t=0t=0에서의 초기값 z(0)pzz(0) \sim p_z를 위 ODE에 넣고 t=0t=0부터 t=Tt=T까지 순방향으로 수치 적분(ODE 풀이)합니다. 그 결과가 x=z(T)x = z(T)입니다.
      x=z(T)=z(0)+0Tf(z(t),t,θ)dtx = z(T) = z(0) + \int_{0}^{T} f(z(t), t, \theta) dt
    • 밀도 추정 (Density, xzx \to z): t=Tt=T에서의 값 x=z(T)x = z(T)를 ODE에 넣고 t=Tt=T부터 t=0t=0까지 역방향으로 수치 적분합니다. 그 결과가 z(0)z(0)입니다.
      z(0)=z(T)0Tf(z(t),t,θ)dt=x+T0f(z(t),t,θ)dtz(0) = z(T) - \int_{0}^{T} f(z(t), t, \theta) dt = x + \int_{T}^{0} f(z(t), t, \theta) dt
  • 가장 큰 장점 (자유로운 아키텍처): Discrete NF에선 ff의 역함수가 존재했어야 하나, CNF에선 그럴 필요가 없습니다! ODE는 ff가 기본적인 조건(예: Lipschitz 연속)만 만족하면 해가 유일하게 존재하며, 시간을 거꾸로 돌리는 것만으로 역변환(xz(0)x \to z(0))이 보장됩니다. 따라서 ff어떤 신경망 아키텍처(ResNet, MLP 등)든 자유롭게 사용할 수 있습니다.

1.1. CNF의 수식 유도

상미분방정식(ODE)이란?

상미분방정식(Ordinary Differential Equation, ODE) 은 한 개의 독립 변수(보통 '시간' tt)에 대한 함수(z(t)z(t))와 그 함수의 도함수(dzdt\frac{dz}{dt}) 사이의 관계를 기술하는 방정식입니다.

  • 핵심 아이디어: 시스템의 현재 상태(z(t)z(t))가 주어졌을 때, 다음 순간의 변화율(dzdt\frac{dz}{dt})이 얼마인지 알려주는 규칙(rule)입니다.
  • 예시: 가장 간단한 예로 dzdt=z\frac{dz}{dt} = z가 있습니다. 이는 "함수 zz의 변화율이 현재 zz의 값과 같다"는 뜻이며, 이 규칙을 따르는 함수는 z(t)=Cetz(t) = C e^t (지수 함수) 형태가 됩니다.
  • 물리적 비유: 댐에서 물이 흐를 때, 각 위치(zz)에서 물의 속도와 방향(dzdt\frac{dz}{dt})을 정의하는 유속 지도(vector field)와 같습니다.

KK \to \infty가 ODE가 되는 이유

이산적인(discrete) 변환 zk=fk(zk1)z_k = f_k(z_{k-1})가 어떻게 연속적인(continuous) ODE가 되는지 이해하는 것이 핵심입니다.

Discrete NF, 특히 ResNet과 유사한 잔차 연결(residual connection) 형태의 변환을 생각해 보면 이해하기 쉽습니다.

  1. Discrete 변환: zkz_kzk1z_{k-1}에서 "작은 변화" gk(zk1)g_k(z_{k-1})만큼 바뀐다고 가정해 봅시다. (여기서 fk(z)=z+gk(z)f_k(z) = z + g_k(z)입니다.)
    zk=zk1+gk(zk1)z_k = z_{k-1} + g_k(z_{k-1})

  2. 시간 개념 도입: KK개의 변환을 t=0t=0부터 t=Tt=T까지의 시간 동안 일어나는 일이라고 생각해 봅시다. 총 시간 TTKK개의 작은 시간 단계(Δt=T/K\Delta t = T/K)로 나눕니다.

    • z0z_0z(t=0)z(t=0)
    • z1z_1z(t=Δt)z(t=\Delta t)
    • zkz_kz(t=kΔt)z(t=k \cdot \Delta t)
  3. 식 변형: 위 식에서 gk(zk1)g_k(z_{k-1})를 "Δt\Delta t 시간 동안의 변화량"으로 만들기 위해 Δt\Delta t를 곱한 형태로 모델링합니다. 즉, ff를 '속도'로 정의합니다.
    z(tk)=z(tk1)+f(z(tk1),tk1)Δtz(t_k) = z(t_{k-1}) + f(z(t_{k-1}), t_{k-1}) \cdot \Delta t
    (gkg_kfΔtf \cdot \Delta t가 된 것입니다.)

  4. 도함수의 정의: 위 식을 Δt\Delta t로 나누고 z(tk)z(tk1)z(t_k) - z(t_{k-1})Δz\Delta z로 표기하면,
    z(tk)z(tk1)Δt=ΔzΔt=f(z(tk1),tk1)\frac{z(t_k) - z(t_{k-1})}{\Delta t} = \frac{\Delta z}{\Delta t} = f(z(t_{k-1}), t_{k-1})
    이것은 도함수의 근사식(finite difference) 입니다.

  5. KK \to \infty 극한: 이제 KK \to \infty (변환 횟수를 무한대로)로 보냅니다.

    limK(z(tk)z(tk1)Δt)LHS=limK(f(z(tk1),tk1))RHS\lim_{K \to \infty} \underbrace{\left( \frac{z(t_k) - z(t_{k-1})}{\Delta t} \right)}_{\text{LHS}} = \lim_{K \to \infty} \underbrace{\left( f(z(t_{k-1}), t_{k-1}) \right)}_{\text{RHS}}

    이것을 좌변(LHS)과 우변(RHS)으로 나누어 살펴보자.

    좌변 (LHS): limΔzΔtdz(t)dt\lim \frac{\Delta z}{\Delta t} \to \frac{dz(t)}{dt}

    • 이것은 님이 이미 이해한 부분이다.
    • KK \to \infty가 되면 Δt0\Delta t \to 0이 된다.
    • tkt_ktk1t_{k-1} 사이의 간격이 무한히 좁아지므로, 이 구간에서의 평균 변화율은 tt라는 특정 시점에서의 순간 변화율(도함수) dz(t)dt\frac{dz(t)}{dt}로 정의된다.

    우변 (RHS): limf(z(tk1),tk1)f(z(t),t)\lim f(z(t_{k-1}), t_{k-1}) \to f(z(t), t)

    • tk1tt_{k-1} \to t: KK \to \infty 극한에서, tkt_ktk1t_{k-1}은 모두 동일한 특정 시점 tt로 수렴한다. (간격 Δt\Delta t가 0이 되므로) 따라서 이산적인 시간 인덱스 tk1t_{k-1}은 연속적인 시간 변수 tt가 된다.

    • z(tk1)z(t)z(t_{k-1}) \to z(t): tk1t_{k-1}tt로 수렴하므로, tk1t_{k-1}에서의 상태 z(tk1)z(t_{k-1}) 역시 tt에서의 상태 z(t)z(t)로 수렴한다. (함수 zz가 연속적이라고 가정)

    • f()f()f(\cdot) \to f(\cdot): 함수 ff 자체(신경망)는 KK가 변한다고 해서 바뀌지 않는다. ff는 우리가 정의한 규칙일 뿐이다. 이 ff가 연속 함수라고 가정하면 (ODE가 잘 정의되기 위한 기본 조건), 입력값이 극한으로 수렴할 때 함숫값도 극한의 함숫값으로 수렴한다.

      • 즉, limf(z(tk1),tk1)=f(limz(tk1),limtk1)=f(z(t),t)\lim f(z(t_{k-1}), t_{k-1}) = f(\lim z(t_{k-1}), \lim t_{k-1}) = f(z(t), t) 이다.

    결론적으로, 좌변과 우변을 합치면 다음과 같다.

    1. 좌변tt라는 시점에서의 순간 속도(dzdt\frac{dz}{dt})가 되었다.
    2. 우변tt라는 시점의 상태(z(t)z(t))와 시각(tt)을 입력받아 속도를 계산하는 함수(ff)가 되었다.

    따라서 이산적인(discrete) 관계식:

    " tk1t_{k-1}부터 tkt_k까지의 평균 속도(ΔzΔt\frac{\Delta z}{\Delta t})는 tk1t_{k-1}에서의 상태 z(tk1)z(t_{k-1})로 계산한 속도 f()f(\cdot)와 같다."

    이것이 KK \to \infty 극한을 만나 연속적인(continuous) 관계식:

    " tt 시점에서의 순간 속도(dzdt\frac{dz}{dt})는 tt 시점에서의 상태 z(t)z(t)로 계산한 속도 f(z(t),t)f(z(t), t)와 같다."

    라는 상미분방정식(ODE) 으로 일반화되는 것이다. (여기서 ff는 학습 가능한 파라미터 θ\theta를 가지므로 f(z(t),t,θ)f(z(t), t, \theta)θ\theta를 추가하여 표기하기로 한다.)


1.2. dz(t)dt=f(z(t),t,θ)\frac{dz(t)}{dt} = f(z(t), t, \theta)의 의미

이제 CNF 모델의 핵심적인 정의(definition) 로서 유도된 위 수식을 이해해보자.

  • dz(t)dt\frac{dz(t)}{dt}: tt라는 가상의 '시간'이 흐름에 따라 zz가 얼마나 빠르고 어느 방향으로 변하는지 나타내는 순간 속도(velocity) 벡터입니다.

  • f(z(t),t,θ)f(z(t), t, \theta): 이 속도를 계산해내는 함수, 즉 벡터 필드(vector field) 입니다. 이 함수 ff신경망으로 구현되며 θ\theta는 이 신경망의 학습 가능한 파라미터입니다.

이 방정식의 의미는 다음과 같습니다.

"어떤 데이터 포인트 zztt라는 시점에 특정 위치 z(t)z(t)에 있을 때, 이 zz가 다음 순간(t+dtt+dt)에 어디로 얼마나 빠르게 움직여야 하는지(dzdt\frac{dz}{dt})는 신경망 ff가 결정한다."

그리고 신경망 ff의 입력과 역할은 다음과 같습니다.
1. z(t)z(t) (현재 위치): 속도는 현재 위치에 따라 달라야 합니다. (예: 강물의 유속은 강둑 근처와 중앙이 다릅니다.)
2. tt (현재 시간): 벡터 필드(유속) 자체가 시간에 따라 변할 수 있습니다. (예: 밀물/썰물에 따라 강물의 흐름이 바뀔 수 있습니다.) 이는 모델의 표현력을 크게 높여줍니다.
3. θ\theta (파라미터): 우리가 '학습'하는 대상입니다.

결론: 우리는 zzpzp_z(단순한 데이터 분포)에서 pxp_x(복잡한 데이터 분포)로 흘러가는 가장 그럴듯한 '흐름' 또는 '경로' 를 만들 수 있는 최적의 벡터 필드 ff 를 딥러닝을 통해 학습하는 것입니다.


1.3. 샘플링(Sampling)과 밀도 추정(Density)의 수식 유도

이 두 과정은 모두 미적분학의 기본 정리(Fundamental Theorem of Calculus) 로부터 직접 유도됩니다.

핵심 방정식은 dz(t)dt=f(z(t),t,θ)\frac{dz(t)}{dt} = f(z(t), t, \theta)입니다.

이 방정식의 양변을 tt에 대해 tat_a부터 tbt_b까지 적분해 봅시다.

tatbdz(t)dtdt=tatbf(z(t),t,θ)dt\int_{t_a}^{t_b} \frac{dz(t)}{dt} dt = \int_{t_a}^{t_b} f(z(t), t, \theta) dt

미적분학의 기본 정리에 의해, 도함수(dzdt\frac{dz}{dt})를 적분하면 원시 함수의 차이(z(tb)z(ta)z(t_b) - z(t_a))가 됩니다.

z(tb)z(ta)=tatbf(z(t),t,θ)dtz(t_b) - z(t_a) = \int_{t_a}^{t_b} f(z(t), t, \theta) dt

이것이 CNF의 모든 변환을 설명하는 일반 해(General Solution) 입니다. 이제 이 식을 두 가지 상황에 적용해 봅시다.

1.3.1. 순방향 샘플링 (Sampling, zxz \to x) ➡️

  • 목표: 간단한 분포 pzp_z에서 뽑은 z(0)z(0)(초기 상태)로부터 실제 데이터 x=z(T)x = z(T)(최종 상태)를 생성합니다.

  • 과정: 시간을 t=0t=0에서 t=Tt=T까지 순방향으로 흐르게 합니다.

  • 유도:

    1. 일반 해에서 ta=0t_a = 0, tb=Tt_b = T로 설정합니다.

    2. z(T)z(0)=0Tf(z(t),t,θ)dtz(T) - z(0) = \int_{0}^{T} f(z(t), t, \theta) dt

    3. z(0)z(0)를 우변으로 넘기면, z(T)z(T)를 얻는 식이 나옵니다.

    4. x=z(T)x = z(T) 이므로, 다음과 같이 정리됩니다.

      x=z(T)=z(0)+0Tf(z(t),t,θ)dtx = z(T) = z(0) + \int_{0}^{T} f(z(t), t, \theta) dt
    • 의미: t=0t=0일 때의 초기값 z(0)z(0)에, 00초부터 TT초까지 신경망 ff가 계산해준 모든 '순간적인 변화(속도)'를 전부 더하면(0T\int_{0}^{T}) 최종 위치 x=z(T)x=z(T)를 알 수 있다는 뜻입니다. 이 적분은 실제로 ODE 솔버라는 수치해석 기법으로 풀게 됩니다.

1.3.2. 역방향 밀도 추정 (Density, xzx \to z) ⬅️

  • 목표: 주어진 데이터 x=z(T)x = z(T)(최종 상태)가 어떤 z(0)z(0)(초기 상태)로부터 왔는지 역추적합니다. (이 z(0)z(0)를 알아야 logpz(z(0))\log p_z(z(0)) 값을 계산할 수 있습니다.)

  • 과정: 시간을 t=Tt=T에서 t=0t=0까지 역방향으로 흐르게 합니다.

  • 유도:

    1. 일반 해에서 ta=Tt_a = T, tb=0t_b = 0로 설정합니다. (시작이 TT, 끝이 00)

    2. z(0)z(T)=T0f(z(t),t,θ)dtz(0) - z(T) = \int_{T}^{0} f(z(t), t, \theta) dt

    3. z(T)z(T)를 우변으로 넘기면, z(0)z(0)를 얻는 식이 나옵니다.

    4. x=z(T)x = z(T) 이므로, 다음과 같이 정리됩니다.

      z(0)=z(T)+T0f(z(t),t,θ)dt=x+T0f(z(t),t,θ)dtz(0) = z(T) + \int_{T}^{0} f(z(t), t, \theta) dt = x + \int_{T}^{0} f(z(t), t, \theta) dt
    • 의미: t=Tt=T일 때의 값 x=z(T)x=z(T)에서 출발하여, 시간을 거꾸로(T0\int_{T}^{0}) 흐르게 하면서 신경망 ff가 알려주는 변화를 (거꾸로) 더해가면, t=0t=0일 때의 초기 위치 z(0)z(0)를 복원할 수 있다는 뜻입니다.
    • 핵심 장점: ff의 역함수 f1f^{-1}를 따로 구할 필요 없이, 동일한 ff를 사용하되 ODE 솔버를 반대 방향(T에서 0으로) 으로 작동시키기만 하면 역변환이 자동으로 계산됩니다. 이것이 CNF가 아키텍처에 제약이 없는 가장 큰 이유입니다.

2. CNF의 확률 변수 변환

Discrete NF의 로그-가능도(log-likelihood) 공식은 다음과 같았다.

logpx(x)=logpz(z0)k=1Klogdet(Jfk(zk1))\log p_x(x) = \log p_z(z_0) - \sum_{k=1}^K \log \left| \det\left( J_{f_k}(z_{k-1}) \right) \right|

여기서 첫 번째 항 logpz(z0)\log p_z(z_0)z0z_0를 CNF에서 z(0)z(0)로 구하는 방법은 위에서 다루었다.

z(0)=x+T0f(z(t),t,θ)dtz(0) = x + \int_{T}^{0} f(z(t), t, \theta) dt

이제 두 번째 항, 즉 야코비안 행렬식의 로그 값 합계(sum)KK \to \infty 극한에서 어떻게 적분(integral) 으로 변하는지 유도해 보겠다.

결론부터 말하자면, 우리가 보여야 할 것은 다음과 같다.

k=1Klogdet(Jfk(zk1))K0TTr(f(z(t),t,θ)z(t))dt\sum_{k=1}^K \log \left| \det\left( J_{f_k}(z_{k-1}) \right) \right| \quad \xrightarrow{K \to \infty} \quad \int_{0}^{T} \text{Tr}\left( \frac{\partial f(z(t), t, \theta)}{\partial z(t)} \right) dt

이 유도는 몇 가지 핵심적인 수학적 근사 단계를 거친다.

2.1. 로그-가능도 변화율 dLdt\frac{dL}{dt} 정의하기

전체 로그-가능도의 변화량(두 번째 항)을 시간 tt에 대한 함수 L(t)L(t)라고 생각하자. L(T)L(T)가 우리가 구하려는 총 변화량이다.

미적분학의 기본 정리에 의해, t=0t=0부터 t=Tt=T까지의 총변화량 L(T)L(T)L(t)L(t)순간 변화율(도함수) dLdt\frac{dL}{dt}00부터 TT까지 적분한 것과 같다.

L(T)=0TdL(t)dtdtL(T) = \int_0^T \frac{dL(t)}{dt} dt

따라서 우리의 목표는 L(t)L(t)의 순간 변화율 dLdt\frac{dL}{dt}을 찾는 것이다. dLdt\frac{dL}{dt}tt 시점에서 Δt\Delta t라는 매우 짧은 시간 동안의 변화량을 Δt\Delta t로 나눈 극한값이다.

dL(t)dt=limΔt0L(t+Δt)L(t)Δt\frac{dL(t)}{dt} = \lim_{\Delta t \to 0} \frac{L(t+\Delta t) - L(t)}{\Delta t}

여기서 ΔL=L(t+Δt)L(t)\Delta L = L(t+\Delta t) - L(t)는 Discrete NF의 단일 스텝(single step) kk에서의 로그-행렬식 값, 즉 logdet(Jfk)\log |\det(J_{f_k})|와 같다.
(이제부터 logdet(Jfk(zk1))\log |\det\left( J_{f_k}(z_{k-1}) \right)|logdet(Jfk)\log |\det(J_{f_k})|로 줄여 부르기로 한다.)

따라서 우리는 dLdt=limΔt0ΔLΔt=limΔt0logdet(Jfk)Δt\frac{dL}{dt} = \lim_{\Delta t \to 0} \frac{\Delta L}{\Delta t} = \lim_{\Delta t \to 0} \frac{\log |\det(J_{f_k})|}{\Delta t} 를 계산해야 한다.

2.2. 단일 스텝(fkf_k)의 야코비안 JfkJ_{f_k}

KK \to \infty일 때, kk번째 변환 fkf_k는 다음과 같이 근사된다. (이전 질문에서 유도)
zk=z(t+Δt)z_k = z(t+\Delta t) 이고 zk1=z(t)z_{k-1} = z(t) 이다.

z(t+Δt)z(t)+f(z(t),t)Δtz(t+\Delta t) \approx z(t) + f(z(t), t) \cdot \Delta t

이제 이 단일 스텝 변환의 야코비안 JfkJ_{f_k}z(t)z(t)에 대해 계산해 보자.

Jfk=z(t+Δt)z(t)=z(t)(z(t)+f(z(t),t)Δt)J_{f_k} = \frac{\partial z(t+\Delta t)}{\partial z(t)} = \frac{\partial}{\partial z(t)} \left( z(t) + f(z(t), t) \cdot \Delta t \right)

z(t)z(t)z(t)z(t)로 미분하면 항등 행렬 II가 되고, f()Δtf(\cdot) \cdot \Delta t 항은 ff의 야코비안 fz(t)\frac{\partial f}{\partial z(t)}Δt\Delta t가 곱해진 형태가 된다.

Jfk=I+f(z(t),t)z(t)ΔtJ_{f_k} = I + \frac{\partial f(z(t), t)}{\partial z(t)} \cdot \Delta t

(여기서 f(z(t),t)z(t)\frac{\partial f(z(t), t)}{\partial z(t)}JfJ_f라고 줄여서 부르겠다.)

Jfk=I+JfΔtJ_{f_k} = I + J_f \cdot \Delta t

2.3. 두 가지 핵심 근사

이제 우리가 구해야 할 logdet(Jfk)\log |\det(J_{f_k})|에 위 식을 대입한다.

logdet(I+JfΔt)\log \left| \det\left( I + J_f \cdot \Delta t \right) \right|

Δt\Delta t00에 매우 가까운 작은 값이므로, 두 가지 근사를 사용할 수 있다.

근사 1: 야코비의 공식 (Jacobi's Formula)
행렬 AA와 매우 작은 스칼라 ϵ\epsilon에 대해, 항등 행렬 II에 가까운 행렬의 행렬식(determinant)은 다음과 같이 근사된다.

det(I+ϵA)1+ϵTr(A)\det(I + \epsilon A) \approx 1 + \epsilon \cdot \text{Tr}(A)

여기서 Tr(A)\text{Tr}(A)는 행렬 AA대각합(Trace) 이다.
우리의 경우 ϵ=Δt\epsilon = \Delta t 이고 A=JfA = J_f 이다.

det(Jfk)=det(I+JfΔt)1+Tr(Jf)Δt\det(J_{f_k}) = \det(I + J_f \cdot \Delta t) \approx 1 + \text{Tr}(J_f) \cdot \Delta t

근사 2: 로그의 테일러 근사
xx00에 매우 가까울 때, log(1+x)\log(1+x)는 다음과 같이 근사된다.

log(1+x)x\log(1+x) \approx x

우리의 경우 x=Tr(Jf)Δtx = \text{Tr}(J_f) \cdot \Delta t 이다. Δt0\Delta t \to 0 이므로 이 값은 00에 매우 가깝다.
(또한 det(Jfk)1\det(J_{f_k}) \approx 1이므로 양수라서 절대값 |\cdot| 기호는 생략할 수 있다.)

log(det(Jfk))log(1+Tr(Jf)Δt)Tr(Jf)Δt\log(\det(J_{f_k})) \approx \log(1 + \text{Tr}(J_f) \cdot \Delta t) \approx \text{Tr}(J_f) \cdot \Delta t

이것이 바로 Δt\Delta t 시간 동안의 로그-가능도 변화량이다.

2.4. 순간 변화율 dLdt\frac{dL}{dt} 계산 및 적분

이제 1번에서 정의한 순간 변화율 dLdt\frac{dL}{dt}을 계산할 수 있다.

dL(t)dt=limΔt0logdet(Jfk)Δt\frac{dL(t)}{dt} = \lim_{\Delta t \to 0} \frac{\log |\det(J_{f_k})|}{\Delta t}

위에서 구한 근사식을 대입하면,

dL(t)dt=limΔt0Tr(Jf)ΔtΔt=Tr(Jf)\frac{dL(t)}{dt} = \lim_{\Delta t \to 0} \frac{\text{Tr}(J_f) \cdot \Delta t}{\Delta t} = \text{Tr}(J_f)

JfJ_f의 원래 표기(파라미터 θ\theta 포함)를 다시 사용하면, tt 시점에서의 순간 변화율은 ff의 야코비안의 대각합(Trace)이 된다.

dL(t)dt=Tr(f(z(t),t,θ)z(t))\frac{dL(t)}{dt} = \text{Tr}\left( \frac{\partial f(z(t), t, \theta)}{\partial z(t)} \right)

KK \to \infty일 때, 이산적인 합 k=1K\sum_{k=1}^K은 이 순간 변화율을 t=0t=0부터 t=Tt=T까지 연속적으로 더하는 적분 0Tdt\int_0^T dt로 수렴한다.

limKk=1Klogdet(Jfk)=0TdL(t)dtdt=0TTr(f(z(t),t,θ)z(t))dt\lim_{K \to \infty} \sum_{k=1}^K \log \left| \det\left( J_{f_k} \right) \right| = \int_{0}^{T} \frac{dL(t)}{dt} dt = \int_{0}^{T} \text{Tr}\left( \frac{\partial f(z(t), t, \theta)}{\partial z(t)} \right) dt

2.5. CNF 로그-가능도 공식 완성

이제 Discrete NF 공식의 각 항을 우리가 유도한 CNF 버전으로 대체한다.

  1. z0z_0 \longrightarrow z(0)z(0) (ODE 역방향 적분으로 계산)
  2. k=1Klogdet(Jfk)\sum_{k=1}^K \log |\det(J_{f_k})| \longrightarrow 0TTr(fz(t))dt\int_{0}^{T} \text{Tr}\left( \frac{\partial f}{\partial z(t)} \right) dt

이 두 결과를 원래 식에 대입하면, CNF의 최종 로그-가능도 공식이 완성된다.

logpx(x)=logpz(z(0))0TTr(f(z(t),t,θ)z(t))dt\log p_x(x) = \log p_z(z(0)) - \int_{0}^{T} \text{Tr}\left( \frac{\partial f(z(t), t, \theta)}{\partial z(t)} \right) dt

이를 Discrete NF의 로그-가능도 공식과 비교해보자.

logpx(x)=logpz(z0)k=1Klogdet(Jfk(zk1))\log p_x(x) = \log p_z(z_0) - \sum_{k=1}^K \log \left| \det\left( J_{f_k}(z_{k-1}) \right) \right|
profile
상어 인형을 좋아하는 사람

0개의 댓글