JAX
jax.numpy->cpu memory 대신 gpu memory 사용
grad->함수를 미분
jit->함수를 input으로 받고 컴파일된 함수를 return. cache개념
vmap->함수를 vectorize함
pmap->여러개의 gpu로 vmap 실행
reduce
얘가 결국 tree 형태다. tree 형태이기 때문에 각각의 leaf에 함수를 병렬적으로 적용시킬 수 있다.
jax.tree_util.tree_reduce
jnp.prod
원소들의 곱을 return
jax.value_and_grad
함수의 결과값과 도함수를 return
has_aux=True면 (함수값, auxilary data)와 도함수를 return
jax.tree_map
getattr()
flax.optim
functools.partial
함수의 class instance 느낌
numpy.clip
상한값, 하한값 설정
Random
PRNGKey -> 난수 생성기
random.split -> 난수 생성기를 쪼개서 여러개를 만듦
random.uniform -> 0-1사이의 숫자를 뽑음
call
https://technote.kr/344
https://jax.readthedocs.io/en/latest/index.html
https://hamait.tistory.com/823
https://wjunsea.tistory.com/61