JAX 기능: Tensorflow에 있던 XLA 기능을 빼와서 독립적인 모듈로 만들어낸 것. @tf.function 데코레이터만 함수에 붙여주면 JIT 연산이 가능해집니다.
fn, fn_jit, fn_tf2 비교
출처
https://brunch.co.kr/@chris-song/99
JAX JIT 이 numpy에 비해 약 367배 빠르다.
참고 사이트
https://www.tensorflow.org/xla?hl=ko
https://github.com/google/jax
간단한 신경망을 구현.
이 신경망은 두 개의 완전 연결 계층(Dense layer)으로 구성됩니다.
import numpy as np
import jax.numpy as jnp
from jax import jit, random
import tensorflow as tf
# 순수 numpy를 사용하는 함수
def fn(x, w1, w2, b1, b2):
z1 = np.dot(x, w1) + b1
a1 = np.maximum(z1, 0) # ReLU activation
z2 = np.dot(a1, w2) + b2
return z2
# JAX의 JIT 컴파일을 사용하는 함수
@jit
def fn_jit(x, w1, w2, b1, b2):
z1 = jnp.dot(x, w1) + b1
a1 = jnp.maximum(z1, 0) # ReLU activation
z2 = jnp.dot(a1, w2) + b2
return z2
# 테스트 데이터 및 가중치 생성
x = np.random.randn(1000, 784)
w1 = np.random.randn(784, 512)
w2 = np.random.randn(512, 10)
b1 = np.random.randn(512)
b2 = np.random.randn(10)
# 실행 시간 측정
import time
start = time.time()
fn(x, w1, w2, b1, b2)
numpy_time = time.time() - start
start = time.time()
fn_jit(x, w1, w2, b1, b2)
jax_time = time.time() - start
print(f"Numpy 시간: {numpy_time:.5f} 초")
print(f"JAX JIT 시간: {jax_time:.5f} 초")
Numpy 시간: 0.07953 초
JAX JIT 시간: 0.03624 초