JAX: 속도 비교

Serendipity·2023년 10월 7일
0

Google  JAX/Flax

목록 보기
4/7

JAX 기능: Tensorflow에 있던 XLA 기능을 빼와서 독립적인 모듈로 만들어낸 것. @tf.function 데코레이터만 함수에 붙여주면 JIT 연산이 가능해집니다.

Case 0

fn, fn_jit, fn_tf2 비교

출처
https://brunch.co.kr/@chris-song/99
JAX JIT 이 numpy에 비해 약 367배 빠르다.

참고 사이트
https://www.tensorflow.org/xla?hl=ko
https://github.com/google/jax

  1. fn: 순수한 numpy를 사용하는 함수
  2. fn_jit: JAX의 JIT 컴파일을 사용하는 함수

Case 1

간단한 신경망을 구현.
이 신경망은 두 개의 완전 연결 계층(Dense layer)으로 구성됩니다.

  1. fn: 순수한 numpy를 사용하는 함수
  2. fn_jit: JAX의 JIT 컴파일을 사용하는 함수
import numpy as np
import jax.numpy as jnp
from jax import jit, random
import tensorflow as tf

# 순수 numpy를 사용하는 함수
def fn(x, w1, w2, b1, b2):
    z1 = np.dot(x, w1) + b1
    a1 = np.maximum(z1, 0)  # ReLU activation
    z2 = np.dot(a1, w2) + b2
    return z2

# JAX의 JIT 컴파일을 사용하는 함수
@jit
def fn_jit(x, w1, w2, b1, b2):
    z1 = jnp.dot(x, w1) + b1
    a1 = jnp.maximum(z1, 0)  # ReLU activation
    z2 = jnp.dot(a1, w2) + b2
    return z2


# 테스트 데이터 및 가중치 생성
x = np.random.randn(1000, 784)
w1 = np.random.randn(784, 512)
w2 = np.random.randn(512, 10)
b1 = np.random.randn(512)
b2 = np.random.randn(10)


# 실행 시간 측정
import time

start = time.time()
fn(x, w1, w2, b1, b2)
numpy_time = time.time() - start

start = time.time()
fn_jit(x, w1, w2, b1, b2)
jax_time = time.time() - start

print(f"Numpy 시간: {numpy_time:.5f} 초")
print(f"JAX JIT 시간: {jax_time:.5f} 초")

Numpy 시간: 0.07953 초
JAX JIT 시간: 0.03624 초

profile
I'm an graduate student majoring in Computer Engineering at Inha University. I'm interested in Machine learning developing frameworks, Formal verification, and Concurrency.

0개의 댓글