JAX/Flax 에 대해 알아보자

Bard·2024년 1월 14일
post-thumbnail

Reference: 고성능 딥러닝 프레임워크 JAX/Flax - 박정현 I 모두의연구소 모두팝

JAX 101

What is JAX?

"JAX is Autograd and XLA, brought together for high-performance numerical computing"

JAX는 NumPy 스타일의 API를 제공하여 연구자들과 엔지니어들이 적응하기 쉽다.

뿐만 아니라 배칭, 자동 미분, 병렬화 등이 가능하며, CPU, GPU, TPU에서 같은 코드로 돌아갈 수 있다는 장점이 있다.

코드로 보자.

NumPy와 매우 유사한 문법

import jax
import jax.numpy as jsp

import numpy as np

x_jnp = jnp.arange(10)
x_np = np.arange(10)

print(x_jnp)
print(x_np)

그러나 NumPy 와 달리 JAX NumPy Array 는 불변성을 가지므로 배열의 어떤 값을 바꾼다면 y = x.at(idx).set(value)와 같은 식으로 변경해야 한다.

JIT(Just-In-Time Compilation)

JIT은 인터프리터와 컴파일 사이라고 볼 수 있다.

런타임 실행시 코드의 일부분을 미리 컴파일해두어 속도를 향상할 수 있다.

JAX에서 이 방식을 사용하고 싶다면 함수 위에 @jit을 사용하거나 jax.jit()으로 감싸 할 수 있다.

import jax
import jax.numpy as np

def selu(x, alpha=1.67, lambda_=1.05):
	return lambda_ * jnp.where(x>0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arrange(1000000)
%timeit selu(x).block_until_ready()

selu_jit = jax.jit(selu)

# Warm up
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()

Autograd

jax.grad()로 자동 미분을 할 수 있다.

import jax

f = lambda x: x**3 + 2*x**2 - 2*x + 1

dfdx = jax.grad(f)
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)

print(dfdx(1.))	 #4.0
print(d2fdx(1.)) #10.0
print(d3fdx(1.)) #6.0

stopgrad()를 사용하면 해당 파라미터로 인한 학습을 멈출 수 있다.

예를 들어 A와 B파라미터가 존재할 때, 둘 중 한 파라미터로만 학습을 시키고 싶다면 다음과 같이 작성할 수 있다.

def train_fn(A_params, B_params, x, y):
    a = model_A.apply({'params': A_params}, x)
    b = model_B.apply({'params': B_params}, a)
    return loss_fn(b, y)
    
def train_fn_A_learn(A_params, B_params, x, y):
    return train_fn(A_params, jax.lax.stop_gradient(B_params), x, y)
    
def train_fn_B_learn(A_params, B_params, x, y):
    return train_fn(jax.lax.stop_gradient(A_params), B_params, x, y)

vmap

jax.vmap()을 이용하여 벡터화할 수 있다.

아래 세 가지 방법을 보자.

  1. 배치마다 함수를 적용시키고 쌓는 방법
def naively_batched_apply_matrix(v_batched):
	return jnp.stack([apply_matrix(v) for v in v_batched] # 3.22ms
  1. vmap을 이용하여 벡터화하는 방법
@jit
def vmap_batched_apply_matrix(v_batched):
	return vmap(apply_matrix)(v_batched) #53μs

위 두 함수는 완전히 동일하지만, 속도는 확연하게 차이나는 것을 알 수 있다.

Lax

Lax는 JAX보다 lowlevel에 있는 API로 (XLA의 애너그램), JAX NumPy보다 타입에 엄격하다.

lax.add(1, 1.0) # 이는 허용되지 않는다.

Lax API의 경우 성능이 뛰어나지만 사용자 친화적이지는 않다.

JAX를 사용할 시 주의할 점

JAX는 함수형프로그래밍을 따르기 때문에, 주의해야할 점이 몇몇 있다.

  1. Pure function(순수 함수)
    같은 입력값이 주어졌을 때 언제나 같은 결과값을 리턴하며, 부수효과를 만들지 않는다.

  2. Stateless(비상태), Immutability(불변성)
    데이터의 변경이 필요한 경우, 원본 데이터의 구조를 변경하지 않고, 그 데이터의 복사본을 만들어 그 일부를 변경한 뒤, 변경한 복사본을 사용하여 작업을 진행한다.

만약 내부적으로 상태를 가지는 객체를 사용하더라도, 외부 상태를 읽거나 쓰지 않는 한 순수 함수로 간주한다.

(iterator도 사용하면 부수효과가 발생한다!)

  1. @partial

jax.jit을 사용할 때, static한 전달인자가 있다면 @partial 데코레이터를 사용한다.

또한 벗어난 인덱스를 입력할 경우 JAX에서는 대하는 방법이 다르므로 이 또한 주의해야 한다.

Random Numbers

JAX는 Random Number를 만들 때 key를 만들고 필요할 때마다 분할한다.

key = jax.random.PRNGKey(seed)
key1, key2 = jax.random.split(key)

NumPy에서는 병렬환경에서 재현불가능하며, 너무 많은 것을 가정하므로, JAX의 난수 생성방식은 디버깅에 유리하며, 순수함수를 구현할 수 있도록 도와준다.

난수가 고정되어 있다는 단점이 있다.

Stateless Class

머신러닝 프로그램은 보통 Stateful하다.

모델 파라미터, 옵티마이저의 상태, 자체적으로 상태를 가지고 있는 레이어 등 거의 모든 모델은 상태를 갖는다.

하지만 함수형 프로그래밍에서는 Stateless해야 하므로, 이 상태를 외부에서 관리해줘야 한다.

fast_count = jax.jit(counter.count)

for _ in range(3):
	value, state = fast_count(state)
    print(value)

이 내용은 Flax에서 TrainState와도 연결된다.

Parallel Evaluation in JAX

JAX에서 기기를 병렬적으로 사용하기 위해서 jax.pmap을 사용한다.
pmap의 겨어 위에서 보았던 vmap 스타일을 그대로 사용하며, pmap을 사용할 때는 jax.jit을 써줄 필요가 없다. (알아서 JIT 컴파일이 된다!)

JAX Ecosystem

JAX는 Google Research만 쓰는 것이 아니다.

Deepmind의 경우에도 JAX를 도입하고 있으며, 대표적인 Framework로 다음과 같은 것들이 있다.

  • Haiku: Neural Network
  • Optax: Optimizer만 따로 사용
  • RLax: 강화학습
  • Chex: 테스트환경

What is Flax?

JAX는 강력하고 좋지만 직접 딥러닝 모델을 만들기엔 너무 낮은 레벨의 프레임워크다. (numpy로 모델을 짠다고 상상해보자)

Flax는 Google Research에서 개발해서 사용하고 있는 High level API이다.

HuggingFace에서도 Flax community week를 만들어서 변환하고 있다.

현재 Google Research에서 나온 대부분의 논문 구현은 Flax로 구현되어있다.

  • ViT
  • PaLM
  • Imagen
  • JaxNeRF

Flax 사용법

  1. 초기화
  • Random Numbers를 위한 key 나눔
  • model 내에 있는 parameter를 초기화한다.
  1. 적용
  • model에 apply를 주고 예측을 실행한다.
  1. 학습
  • 원하는 epoch만큼 for문으로 돌림
  • 생성했던 loss에서 jax.value_and_grad로 value와 gradient를 계산
  • parameter는 tree_map을 이용하여 update를 실행한다.

(텐서플로우와 굉장히 유사하다!)

Flax 모델 선언 방법

Class에서 PyTorch, Tensorflow를 비교했을 때 차이점

  • __init__초기화를 두지 않는다.
  • 대신 @nn.compact를 사요아여 초기화를 생략할 수 있다.

TrainState 만들기

JAX는 위에서 보았던 것처럼 함수를 Stateless하게 만들어줘야 한다.

  • Train을 진행할 때 TrainState를 만들어서 학습을 진행하면 더 빠른 학습이 가능하다.

TrainState를 만드는 방법

  • 모델 선언
  • parameter 초기화
  • optimization 선언
  • train_state.TrainState.create로 생성
def create_train_state(key, learning_rate, momentum):
	cnn = CNN()
    params = cnn.init(key, jnp.ones((1, *mnist_img_size)))['params']
    sgd_opt = optax.sgd(learning_rate, momentum)
    
    return train_state.TrainState.create(apply_fn = cnn.apply, params=params, tx=sgd_opt)

Optax

Optax는 DeepMind에서 만든 JAX Ecosystem 중 하나로, gradient processing과 optimization 라이브러리이다. (Flax와 구분되어있다!)

Optax를 적용한다면 다음이 바뀌게 된다.

  • 이전에는 loss를 전부 JAX로 만들어 줘야 했다면, optax에서 편안하게 선언만 해주면 가능하다.
  • 항상 동일 패턴이므로 state를 외부로 빼줘야 한다.
  • gradient와 state를 update해주고, parameter의 경우에도 업데이트를 진행한다.
profile
돈 되는 건 다 공부합니다.

0개의 댓글