JAX/Flax로 딥러닝 레벨업

Erdos·2024년 9월 28일
1

감상

목록 보기
37/42
post-thumbnail

🟣 책에 대하여

대상 독자

  • 파이썬 프로그래밍에 익숙한 개발자 혹은 연구원
  • 기본적인 머신러닝/딥러닝 지식을 보유한 독자
  • 고성능 컴퓨팅 및 GPU/TPU 활용에 관심 있는 사람

🟣 Chapter2 JAX의 특징

  • 구글에서 개발한 고성능 수치 계산 라이브러리
  • 병렬 가속화 기능을 통해 대규모 모델의 효율적인 학습과 추론 가능

1. JIT 컴파일

  • just-in-time compile: XLA(Accelerated Linear Algebra)를 사용하여 런타임 시 코드를 컴파일하고 최적화

✔️ 참고 JIT 컴파일 : 프로그램을 실제 실행하는 시점에 기계어로 번역하는 컴파일 기법. 일반적으로 실행 속도가 빨라짐.

2. 자동 벡터화

  • 벡터화: 연산을 한 번에 여러 값에 적용하여 실행 시간을 줄이는 기법
  • 하드웨어 가속기에서 병렬처리를 가능하게 함.

3. 자동 미분

  • 연산의 시퀀스가 주어지면 이 - 좌항만으로 &&와 || 연산 결과를 판별하는 기능
    • 불필요한 연산을 줄여 실행 속도를 높일 수는 있으나 예상 외의 결과가 나올 수 있으니 주의가 필요 프로그램의 미분(변화율)을 계산하는 일련의 기술
  • 파이토치는 torch.autograd의 동적 연산 그래프를 사용한다. (조건부 실행과 같은 복잡한 제어 흐름을 가진 모델을 쉽게 구현할 수 있음)
    반면에, JAX는 더 '함수적'인 접근 방식을 취하는데, 호출 가능한 함수에 자동 미분을 적용하여 새로운 함수를 만든다. (복잡한 계산 그래프를 더 명확하게 추론하고 최적화하는데 이점을 가짐)

4. JAX의 난수

  • 의사 난수 생성(pseudo random number generation, PRNG): 적절한 분포에서 추출된 난수 시퀀스의 속성과 근사한 속성을 가진 숫자 시퀀스를 알고리즘적으로 생성하는 프로세스

5. pytree

https://jax.readthedocs.io/en/latest/pytrees.html
개념적으로는 낯선 용어다. 😟

  • JAX에서의 트리 구조와 유사한 구조. 파이썬 객체로 구성된 컨테이너처럼 작동한다.

6. 병렬 처리

여러가지 방법 중에 가장 간단하고 기본적인 pmap를 살펴봄

  • SPMD(single-program, multiple-data) 병렬처리를 사용 : 병렬 프로그램을 수행하는 모든 프로세스나 스레드가 동일한 하나의 프로그램을 실행하면서 프로그램 내의 함수가 서로 다른 데이터를 계산하는 형태
  • parallel mapping(병렬 매핑): 내부적으로 3가지 과정.
    - 입력 데이터 분할
    • 병렬 실행
    • 결과 수집

Tutorial_JAX


🟣 Chapter3 Flax 소개

  • JAX + Flexibility = JAX를 조금 더 쉽게 사용할 수 있게 만든 프레임워크
  • 고성능 신경망 라이브러리이자 에코 시스템
  • JAX의 모든 기능을 제공
  • end-to-end
  • 특징(4가지)
    • 안정성
    • 제어 기능
    • 함수형 API
    • 코드 간결성

🟣 Chapter4 JAX/Flax를 활용한 딥러닝 모델 만들기

내게 생소했던 DCGAN, CLIP에 대해서만 정리함

1. CNN

2. ResNet

3. DCGAN(Deep Convolutional Generative Adversarial Network)


책에서 Flax를 이용하여 합성곱 레이어로 구성된 DCGAN을 구현하고 있다.

  • DCGAN: 고품질 이미지를 생성하기 위한 GAN의 한 유형.
  • 장점

    (이미지 출처 블로그)
    • 매우 사실적인 이미지를 생성하는 데 뛰어나다.(deep convolutional networks를 사용해서)
    • GAN보다 훈련 중에 더 안정적이다. 배치 정규화로 더 신뢰성 있는 학습과 모드 붕괴와 같은 이슈를 줄일 수 있다.
    • 이미지 품질 향상, 새로운 디자인 만들기 등 다양한 곳에 응용이 가능하다.
    • 훈련 이미지를 생성할 수 있어서, 모델 성능 향상에 도움이 된다. (특히, 데이터가 부족한 경우)

4. CLIP

openai github

  • Contrastive Language-Image Pre-Training, OpenAI에서 개발한 대규모 이미지-텍스트 쌍으로 학습된 멀티모달 신경망
  • 특징
    • 비지도 학습: 대규모 이미지-텍스트 쌍 데이터셋으로 이미지와 텍스트 쌍 간의 유사성을 측정하여 특징 학습
    • zero-shot learning: 모델이 훈련하지 않은 특정 클래스에 대해서도 인식하거나 분류할 수 있음.

5. DistilGPT2 미세조정 학습


(번외) PyTorch vs TensorFlow vs JAX

library개발주요 강점사례
PyTorchMeta동적 그래프, 사용 편의성연구, 프로토타입 제작 및 중소 규모 애플리케이션.
빠른 실험에 매우 적합
TensorFlowgooglescalable, 생산 준비가 완료된 생태계
(TensorFlow Serving, TensorFlow Lite, TensorFlow.js(JS용), TFX 등
모델을 쉽게 배포하고 관리하도록 설계된 도구 및 서비스가 제공됨)
클라우드, 모바일, 웹 등 여러 환경에 걸쳐 확장해야 하는 대규모 애플리케이션
JAXgoogle자동 미분, JIT 컴파일수치 컴퓨팅, 과학 연구, 고성능 ML 실험

책에 대한 감상

개인적으로 책에 대해 아쉬운 점

  • JAX라는 뛰어난 라이브러리의 중요성을 알리고자 하는 책.
  • 단, 책 자체는 미완성이라고 느꼈다. JAX/Flax 처음은 알 수 있지만, 완전하게 담았나? 했을 때 의문이 든다.

그럼에도 불구하고

  • JAX/Flax의 기본 내용부터 LLM 파인튜닝까지 다양한 예제들을 소개하고 있다.
    현재 자료가 부족한 터라, JAX/Flax를 처음 만나고 학습하기에 적합한 한글책인 것 같다.
  • JAX/Flax를 활용한 예제를 기본부터 탄탄하게 제공하고 있음.
    (CNN, ResNet, DCGAN, CLIP 중에서도 DCGAN, CLIP fine-tuning이 특히 좋았다. )
    더불어, 딥러닝을 배울 때 기본으로 배우는 예제들을 JAX/Flax를 활용해서 구현하기 때문에 차이점들을 보기 좋았다. 당장 캐글이나 실무에 도입이 가능한가를 확인해보고자 한다면, 더 도움이 될 것이라고 생각한다.
profile
수학을 사랑하는 애독자📚 Stop dreaming. Start living. - 'The Secret Life of Walter Mitty'

0개의 댓글