🟣 책에 대하여
대상 독자
- 파이썬 프로그래밍에 익숙한 개발자 혹은 연구원
- 기본적인 머신러닝/딥러닝 지식을 보유한 독자
- 고성능 컴퓨팅 및 GPU/TPU 활용에 관심 있는 사람
🟣 Chapter2 JAX의 특징
- 구글에서 개발한 고성능 수치 계산 라이브러리
- 병렬 가속화 기능을 통해 대규모 모델의 효율적인 학습과 추론 가능
1. JIT 컴파일
- just-in-time compile: XLA(Accelerated Linear Algebra)를 사용하여 런타임 시 코드를 컴파일하고 최적화
✔️ 참고 JIT 컴파일 : 프로그램을 실제 실행하는 시점에 기계어로 번역하는 컴파일 기법. 일반적으로 실행 속도가 빨라짐.
2. 자동 벡터화
- 벡터화: 연산을 한 번에 여러 값에 적용하여 실행 시간을 줄이는 기법
- 하드웨어 가속기에서 병렬처리를 가능하게 함.
3. 자동 미분
- 연산의 시퀀스가 주어지면 이 - 좌항만으로 &&와 || 연산 결과를 판별하는 기능
- 불필요한 연산을 줄여 실행 속도를 높일 수는 있으나 예상 외의 결과가 나올 수 있으니 주의가 필요 프로그램의 미분(변화율)을 계산하는 일련의 기술
- 파이토치는 torch.autograd의 동적 연산 그래프를 사용한다. (조건부 실행과 같은 복잡한 제어 흐름을 가진 모델을 쉽게 구현할 수 있음)
반면에, JAX는 더 '함수적'인 접근 방식을 취하는데, 호출 가능한 함수에 자동 미분을 적용하여 새로운 함수를 만든다. (복잡한 계산 그래프를 더 명확하게 추론하고 최적화하는데 이점을 가짐)
4. JAX의 난수
- 의사 난수 생성(pseudo random number generation, PRNG): 적절한 분포에서 추출된 난수 시퀀스의 속성과 근사한 속성을 가진 숫자 시퀀스를 알고리즘적으로 생성하는 프로세스
5. pytree
https://jax.readthedocs.io/en/latest/pytrees.html
개념적으로는 낯선 용어다. 😟
- JAX에서의 트리 구조와 유사한 구조. 파이썬 객체로 구성된 컨테이너처럼 작동한다.
6. 병렬 처리
여러가지 방법 중에 가장 간단하고 기본적인 pmap를 살펴봄
- SPMD(single-program, multiple-data) 병렬처리를 사용 : 병렬 프로그램을 수행하는 모든 프로세스나 스레드가 동일한 하나의 프로그램을 실행하면서 프로그램 내의 함수가 서로 다른 데이터를 계산하는 형태
- parallel mapping(병렬 매핑): 내부적으로 3가지 과정.
- 입력 데이터 분할
Tutorial_JAX
🟣 Chapter3 Flax 소개
- JAX + Flexibility = JAX를 조금 더 쉽게 사용할 수 있게 만든 프레임워크
- 고성능 신경망 라이브러리이자 에코 시스템
- JAX의 모든 기능을 제공
- end-to-end
- 특징(4가지)
🟣 Chapter4 JAX/Flax를 활용한 딥러닝 모델 만들기
내게 생소했던 DCGAN, CLIP에 대해서만 정리함
1. CNN
2. ResNet
3. DCGAN(Deep Convolutional Generative Adversarial Network)
책에서 Flax를 이용하여 합성곱 레이어로 구성된 DCGAN을 구현하고 있다.
- DCGAN: 고품질 이미지를 생성하기 위한 GAN의 한 유형.
- 장점
(이미지 출처 블로그)
- 매우 사실적인 이미지를 생성하는 데 뛰어나다.(deep convolutional networks를 사용해서)
- GAN보다 훈련 중에 더 안정적이다. 배치 정규화로 더 신뢰성 있는 학습과 모드 붕괴와 같은 이슈를 줄일 수 있다.
- 이미지 품질 향상, 새로운 디자인 만들기 등 다양한 곳에 응용이 가능하다.
- 훈련 이미지를 생성할 수 있어서, 모델 성능 향상에 도움이 된다. (특히, 데이터가 부족한 경우)
4. CLIP
openai github
- Contrastive Language-Image Pre-Training, OpenAI에서 개발한 대규모 이미지-텍스트 쌍으로 학습된 멀티모달 신경망
- 특징
- 비지도 학습: 대규모 이미지-텍스트 쌍 데이터셋으로 이미지와 텍스트 쌍 간의 유사성을 측정하여 특징 학습
- zero-shot learning: 모델이 훈련하지 않은 특정 클래스에 대해서도 인식하거나 분류할 수 있음.
5. DistilGPT2 미세조정 학습
(번외) PyTorch vs TensorFlow vs JAX
library | 개발 | 주요 강점 | 사례 |
---|
PyTorch | Meta | 동적 그래프, 사용 편의성 | 연구, 프로토타입 제작 및 중소 규모 애플리케이션. 빠른 실험에 매우 적합 |
TensorFlow | google | scalable, 생산 준비가 완료된 생태계 (TensorFlow Serving, TensorFlow Lite, TensorFlow.js(JS용), TFX 등 모델을 쉽게 배포하고 관리하도록 설계된 도구 및 서비스가 제공됨) | 클라우드, 모바일, 웹 등 여러 환경에 걸쳐 확장해야 하는 대규모 애플리케이션 |
JAX | google | 자동 미분, JIT 컴파일 | 수치 컴퓨팅, 과학 연구, 고성능 ML 실험 |
책에 대한 감상
개인적으로 책에 대해 아쉬운 점
- JAX라는 뛰어난 라이브러리의 중요성을 알리고자 하는 책.
- 단, 책 자체는 미완성이라고 느꼈다. JAX/Flax 처음은 알 수 있지만, 완전하게 담았나? 했을 때 의문이 든다.
그럼에도 불구하고
- JAX/Flax의 기본 내용부터 LLM 파인튜닝까지 다양한 예제들을 소개하고 있다.
현재 자료가 부족한 터라, JAX/Flax를 처음 만나고 학습하기에 적합한 한글책인 것 같다.
- JAX/Flax를 활용한 예제를 기본부터 탄탄하게 제공하고 있음.
(CNN, ResNet, DCGAN, CLIP 중에서도 DCGAN, CLIP fine-tuning이 특히 좋았다. )
더불어, 딥러닝을 배울 때 기본으로 배우는 예제들을 JAX/Flax를 활용해서 구현하기 때문에 차이점들을 보기 좋았다. 당장 캐글이나 실무에 도입이 가능한가를 확인해보고자 한다면, 더 도움이 될 것이라고 생각한다.