[Study] JAX&Flax(1) - XLA {작성중}

HamYHoon·2024년 9월 28일

Study

목록 보기
1/2

본 포스트는 "모두의 연구소"JAX&Flax로 딥러닝 레벨업이라는 책을 공부하며 정리하는 글입니다. chatGPT, Napkin AI 등을 사용하여 작성합니다.


INTRO

자동 미분(autograd)과 XLA를 결합한 고성능 ML 프레임워크

여기서 XLA란 "Accelerated Linear Algebra"를 말하며, 구글에서 만든 딥러닝 전용 하드웨어인 TPU에서 numpy 프로그램을 컴파일하고 실행할 수 있게 만든다.

XLA

  • TensorFlow와 같은 머신러닝 프레임워크에서 사용되는 최적화 컴파일러
  • 모델 성능 향상을 위해 그래프를 최적화하고 하드웨어 가속기를 보다 효율적으로 사용하게 함
    => 계산 속도를 빠르게 하고 메모리 사용량을 줄이며 하드웨어 리소스를 더 효율적으로 활용할 수 있음

동작 원리

1. 계산 그래프

그래프 생성(Graph Construction)

딥러닝 모델의 기본 구조는 여러 수학적 연산의 조합으로 이루어져 있으며, 이 연산들이 연결된 형태를 계산 그래프라고 한다. 계산 그래프는 딥러닝 모델이 입력 데이터로부터 출력 값을 계산하는 전체 과정을 명확하게 시각적으로 나타낸 것이라고 볼 수 있다.

  • 각각의 연산은 노드(Node) 로 표현되며, 연산 간의 데이터 흐름은 엣지(Edge) 로 연결된다.

  • TensorFlow, PyTorch와 같은 프레임워크에서는 각 연산을 정의할 때 그래프가 자동으로 생성된다. 예를 들어, y = a * x + b라는 선형 회귀 모델을 생각해 보면, 이 수식은 두 개의 연산인 곱셈 연산과 덧셈 연산으로 나눌 수 있다.

    • a * x: 곱셈 노드
    • a * x + b: 덧셈 노드

이들은 두 개의 노드로 변환되고, 각각의 노드는 입력 변수 a, x, b로부터 값을 받아 연산을 수행한 후 결과값을 다음 노드로 전달한다.


정적(Static) 그래프와 동적(Dynamic) 그래프

그래프는 크게 정적 그래프와 동적 그래프로 나뉩니다.

  • 정적 그래프(Static Graph)
    • TensorFlow는 전통적으로 정적 그래프 방식을 사용한다.
    • 정적 그래프는 모델을 정의한 후 한 번 그래프를 생성하면 그 그래프를 수정하지 않고 그대로 실행한다.
    • 이러한 방식은 모델을 실행하기 전에 전체 그래프를 최적화하고, 이후 여러 번 실행할 때 성능이 매우 뛰어난 장점이 있다.
      • XLA가 성능 최적화를 할 수 있는 주요 이유는 이 정적 그래프 덕분에 그래프 전체를 분석하고 최적화할 수 있기 때문이다.
  • 동적 그래프(Dynamic Graph)
    • PyTorch는 동적 그래프 방식을 채택하고 있다.
    • 이 방식에서는 그래프가 매번 실행될 때마다 동적으로 생성된다. 즉, 연산이 일어날 때마다 그때그때 그래프가 만들어지는 방식이다.
      • 동적 그래프는 유연성이 크고 디버깅이 쉽지만, 최적화 측면에서는 정적 그래프에 비해 어려움이 있다.

그래프의 구성 요소 (Components of a Computational Graph)

  • 노드(Node) : 노드는 각각의 수학적 연산을 의미한다. 딥러닝 모델의 모든 연산은 각각 노드로 표현되며, 이는 결국 신경망에서 가중치와 입력 데이터를 기반으로 계산이 이루어지는 개별 연산을 나타낸다.

  • 엣지(Edge) : 엣지는 한 노드의 출력이 다음 노드의 입력으로 전달되는 흐름을 나타낸다. 즉, 엣지는 데이터가 어떻게 이동하고 연산에 사용되는지를 정의하는 중요한 요소이다.

  • 입력 노드(Input Node) : 입력 노드는 외부에서 주어진 데이터를 나타낸다. 예를 들어, 이미지 데이터가 입력되는 경우, 이 입력 데이터는 연산을 거쳐 모델을 통과하면서 다양한 노드들에서 처리된다.

  • 출력 노드 (Output Node) : 마지막으로 연산 결과가 도출되는 노드이다. 이 노드에서 모델의 예측값 또는 분류 결과가 계산된다.


그래프의 구조 최적화 가능성

연산 최적화의 관점에서 그래프를 생성하는 것만으로도 최적화의 출발점이 된다.

XLA는 이 그래프의 구조를 분석하여 불필요한 연산을 찾아 제거하거나, 연산이 병목현상을 일으키는 구간을 찾아 병합할 수 있다. 예를 들어, 여러 연산들이 동일한 입력 데이터를 참조하는 경우, 데이터를 여러 번 복사하지 않고 한 번만 메모리에서 불러오도록 최적화할 수 있다.

  • 연산의 병합 : 여러 개의 연산을 하나로 결합하여, 메모리 접근 횟수를 줄이고, 불필요한 데이터 이동을 최소화할 수 있다. 이를 통해 하드웨어에서의 cash 활용도를 높이고 연산 속도를 극대화한다.
  • 데이터 흐름의 병목 제거 : 그래프 내에서 특정 연산이 병목을 일으키는 경우가 있는데, XLA는 이런 병목을 찾아내어 연산을 병렬화하거나 병합하여 성능을 향상시킨다.
profile
To Be Data Scientist

0개의 댓글