Reference: 고성능 딥러닝 프레임워크 JAX/Flax - 박정현 I 모두의연구소 모두팝
"JAX is Autograd and XLA, brought together for high-performance numerical computing"
JAX는 NumPy 스타일의 API를 제공하여 연구자들과 엔지니어들이 적응하기 쉽다.
뿐만 아니라 배칭, 자동 미분, 병렬화 등이 가능하며, CPU, GPU, TPU에서 같은 코드로 돌아갈 수 있다는 장점이 있다.
코드로 보자.
import jax
import jax.numpy as jsp
import numpy as np
x_jnp = jnp.arange(10)
x_np = np.arange(10)
print(x_jnp)
print(x_np)
그러나 NumPy 와 달리 JAX NumPy Array 는 불변성을 가지므로 배열의 어떤 값을 바꾼다면 y = x.at(idx).set(value)와 같은 식으로 변경해야 한다.
JIT은 인터프리터와 컴파일 사이라고 볼 수 있다.
런타임 실행시 코드의 일부분을 미리 컴파일해두어 속도를 향상할 수 있다.
JAX에서 이 방식을 사용하고 싶다면 함수 위에 @jit을 사용하거나 jax.jit()으로 감싸 할 수 있다.
import jax
import jax.numpy as np
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x>0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arrange(1000000)
%timeit selu(x).block_until_ready()
selu_jit = jax.jit(selu)
# Warm up
selu_jit(x).block_until_ready()
%timeit selu_jit(x).block_until_ready()
jax.grad()로 자동 미분을 할 수 있다.
import jax
f = lambda x: x**3 + 2*x**2 - 2*x + 1
dfdx = jax.grad(f)
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
print(dfdx(1.)) #4.0
print(d2fdx(1.)) #10.0
print(d3fdx(1.)) #6.0
또 stopgrad()를 사용하면 해당 파라미터로 인한 학습을 멈출 수 있다.
예를 들어 A와 B파라미터가 존재할 때, 둘 중 한 파라미터로만 학습을 시키고 싶다면 다음과 같이 작성할 수 있다.
def train_fn(A_params, B_params, x, y):
a = model_A.apply({'params': A_params}, x)
b = model_B.apply({'params': B_params}, a)
return loss_fn(b, y)
def train_fn_A_learn(A_params, B_params, x, y):
return train_fn(A_params, jax.lax.stop_gradient(B_params), x, y)
def train_fn_B_learn(A_params, B_params, x, y):
return train_fn(jax.lax.stop_gradient(A_params), B_params, x, y)
jax.vmap()을 이용하여 벡터화할 수 있다.
아래 세 가지 방법을 보자.
def naively_batched_apply_matrix(v_batched):
return jnp.stack([apply_matrix(v) for v in v_batched] # 3.22ms
@jit
def vmap_batched_apply_matrix(v_batched):
return vmap(apply_matrix)(v_batched) #53μs
위 두 함수는 완전히 동일하지만, 속도는 확연하게 차이나는 것을 알 수 있다.
Lax는 JAX보다 lowlevel에 있는 API로 (XLA의 애너그램), JAX NumPy보다 타입에 엄격하다.
lax.add(1, 1.0) # 이는 허용되지 않는다.
Lax API의 경우 성능이 뛰어나지만 사용자 친화적이지는 않다.
JAX는 함수형프로그래밍을 따르기 때문에, 주의해야할 점이 몇몇 있다.
Pure function(순수 함수)
같은 입력값이 주어졌을 때 언제나 같은 결과값을 리턴하며, 부수효과를 만들지 않는다.
Stateless(비상태), Immutability(불변성)
데이터의 변경이 필요한 경우, 원본 데이터의 구조를 변경하지 않고, 그 데이터의 복사본을 만들어 그 일부를 변경한 뒤, 변경한 복사본을 사용하여 작업을 진행한다.
만약 내부적으로 상태를 가지는 객체를 사용하더라도, 외부 상태를 읽거나 쓰지 않는 한 순수 함수로 간주한다.
(iterator도 사용하면 부수효과가 발생한다!)
@partialjax.jit을 사용할 때, static한 전달인자가 있다면 @partial 데코레이터를 사용한다.
또한 벗어난 인덱스를 입력할 경우 JAX에서는 대하는 방법이 다르므로 이 또한 주의해야 한다.
JAX는 Random Number를 만들 때 key를 만들고 필요할 때마다 분할한다.
key = jax.random.PRNGKey(seed)
key1, key2 = jax.random.split(key)
NumPy에서는 병렬환경에서 재현불가능하며, 너무 많은 것을 가정하므로, JAX의 난수 생성방식은 디버깅에 유리하며, 순수함수를 구현할 수 있도록 도와준다.
난수가 고정되어 있다는 단점이 있다.
머신러닝 프로그램은 보통 Stateful하다.
모델 파라미터, 옵티마이저의 상태, 자체적으로 상태를 가지고 있는 레이어 등 거의 모든 모델은 상태를 갖는다.
하지만 함수형 프로그래밍에서는 Stateless해야 하므로, 이 상태를 외부에서 관리해줘야 한다.
fast_count = jax.jit(counter.count)
for _ in range(3):
value, state = fast_count(state)
print(value)
이 내용은 Flax에서 TrainState와도 연결된다.
JAX에서 기기를 병렬적으로 사용하기 위해서 jax.pmap을 사용한다.
pmap의 겨어 위에서 보았던 vmap 스타일을 그대로 사용하며, pmap을 사용할 때는 jax.jit을 써줄 필요가 없다. (알아서 JIT 컴파일이 된다!)
JAX는 Google Research만 쓰는 것이 아니다.
Deepmind의 경우에도 JAX를 도입하고 있으며, 대표적인 Framework로 다음과 같은 것들이 있다.
JAX는 강력하고 좋지만 직접 딥러닝 모델을 만들기엔 너무 낮은 레벨의 프레임워크다. (numpy로 모델을 짠다고 상상해보자)
Flax는 Google Research에서 개발해서 사용하고 있는 High level API이다.
HuggingFace에서도 Flax community week를 만들어서 변환하고 있다.
현재 Google Research에서 나온 대부분의 논문 구현은 Flax로 구현되어있다.
jax.value_and_grad로 value와 gradient를 계산(텐서플로우와 굉장히 유사하다!)
Class에서 PyTorch, Tensorflow를 비교했을 때 차이점
__init__초기화를 두지 않는다.JAX는 위에서 보았던 것처럼 함수를 Stateless하게 만들어줘야 한다.
TrainState를 만드는 방법
train_state.TrainState.create로 생성def create_train_state(key, learning_rate, momentum):
cnn = CNN()
params = cnn.init(key, jnp.ones((1, *mnist_img_size)))['params']
sgd_opt = optax.sgd(learning_rate, momentum)
return train_state.TrainState.create(apply_fn = cnn.apply, params=params, tx=sgd_opt)
Optax는 DeepMind에서 만든 JAX Ecosystem 중 하나로, gradient processing과 optimization 라이브러리이다. (Flax와 구분되어있다!)
Optax를 적용한다면 다음이 바뀌게 된다.