TorchScript 소개

심준석·2024년 3월 26일
8

TorchSciprt는 PyTorch 모델을 프로덕션 환경에서 쉽게 사용할 수 있도록 설계된 언어 및 컴파일러입니다. PyTorch는 주로 연구 및 개발에 사용되는 파이썬 기반의 딥러닝 프레임워크이지만, 프로덕션 환경에서는 모델을 배포할 때 다양한 문제를 직면할 수 있습니다. 예를 들어, 파이썬은 실행 속도가 느리거나 멀티스레딩을 효율적으로 처리하지 못하는데, TorchScript는 이러한 문제를 해결하기 위해 도입되었죠.

TorchScript를 사용하면 PyTorch 모델을 고정된 그래프로 변환할 수 있으며, 이 그래프는 파이썬의 동적 특성에 의존하지 않고도 실행할 수 있습니다. 이를 통해 모델을 더 빠르고 효율적으로 실행할 수 있으며, 다양한 플랫폼과 환경에서 모델을 배포할 수 있는 유연성을 제공합니다.

TorchScript를 사용하는 과정은 크게 두 가지 방법으로 나누어집니다.

트레이싱(Tracing)

모델 입력을 기반으로 실행 흐름을 추적하여 그래프를 생성시키는 방법입니다. 간단한 모델에서는 잘 작동하지만, 입력에 따라 변경되는 제어 흐름(예를 들면, if문이나 for문 등등..)을 가진 모델에서는 제대로 작동하지 않을 수 있죠.

스트립팅(Scripting)

스크립팅을 사용한다면 더 복잡한 모델에 대해 정적 그래프를 생성할 수 있습니다. PyTorch의 torch.jit.script 데코레이터를 사용하여 함수나 모듈을 TorchScript 코드로 변환합니다. 이 방법을 통해 트레이싱(Tracing)에서 수행하기 어려운 동적 제어 흐름을 포함하는 모델에서도 잘 작동할 수 있게 해줍니다.

TorchScript로 변환된 모델은 '.pt' 또는 '.pth' 파일 형식으로 저장할 수 있으며, PyTorch에서 더 이상 파이썬 인터프리터에 의존하지 않고 실행할 수 있습니다. 이는 모델을 C++에서 로드하여 사용하거나, PyTorch의 JIT 컴파일러를 통해 성능을 최적화하는 등 다양한 방식으로 모델을 배포할 때 유용합니다.

그럼 간단한 예제를 통해, TorchScript 실습을 시작하도록 하겠습니다.

1. PyTorch 기반 모델 구현

간단한 Module을 정의하는 것부터 시작하겠습니다. Module은 PyTorch의 기본 구성단위로,

  • 호출을 위해 모듈을 준비하는 생성자(Constructor)
  • 매개변수(Parameter)하위 모듈(Sub Module)
  • 모듈이 호출될 때 실행되는 포워드 함수(Forward Function)

으로 구성되어 집니다.

1-1. 기본 모듈 생성하기

간단한 예제를 보겠습니다

(예제 1.1)

import torch

class MyCell(torch.nn.Module): 				# --- (1)
    def __init__(self):						# --- (2)
        super(MyCell, self).__init__()		# --- (3)

    def forward(self, x, h):				# --- (4)
        new_h = torch.tanh(x + h)			# --- (5)
        return new_h, new_h					# --- (6)

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

(1) torch.nn.Module을 상속받아 Subclass MyCell을 생성합니다.
(2) 생성자를 정의하고
(3) 상속받은 부모 클래스의 생성자를 호출합니다.
(4) 포워드 함수를 정의하는 부분입니다.
(5) 임의의 연산을 통해 new_h에 결과를 저장하고,
(6) 포워드 함수의 연산결과 new_h를 반환하는 부분입니다.


[Output]

(tensor([[0.9477, 0.9414, 0.9269, 0.8735],
        [0.8352, 0.9189, 0.6752, 0.9103],
        [0.7473, 0.9511, 0.9347, 0.9080]]), 
 tensor([[0.9477, 0.9414, 0.9269, 0.8735],
        [0.8352, 0.9189, 0.6752, 0.9103],
        [0.7473, 0.9511, 0.9347, 0.9080]]))

1-2. Self.linear 속성 추가하기

이번에는 self.linear 속성을 추가해 MyCell 모듈을 재정의하도록 하겠습니다. MyCell 모듈에 self.linear 속성을 추가하고, 포워드 함수에서 self.linear 속성을 호출하는 과정을 추가하겠습니다.

(예제 1.2)

class MyCell(torch.nn.Module): 				            # --- (1)
    def __init__(self):						            # --- (2)
        super(MyCell, self).__init__()		            # --- (3)
        self.linear = torch.nn.Linear(4, 4)             # --- (4)

    def forward(self, x, h):				            # --- (5)
        new_h = torch.tanh(self.linear(x) + h)			# --- (6)
        return new_h, new_h					            # --- (7)
    
my_cell = MyCell()
print(my_cell)											# --- (8)
print(my_cell(x, h))

(1) torch.nn.Module을 상속받아 Subclass MyCell을 생성합니다.
(2) 생성자를 정의하고
(3) 상속받은 부모 클래스의 생성자를 호출합니다.
(4) self.linear 속성을 생성자에 추가합니다.
(5) 포워드 함수를 정의하는 부분입니다.
(6) 임의의 연산에 self.linear 속성의 연산을 추가했습니다.
(7) 포워드 함수의 연산결과 new_h를 반환하는 부분입니다.
(8) 에서는 Module을 print 함으로써 Module의 하위 클래스 계층에 대한 시각적 표현이 가능합니다. linear을 하위 클래스로 사용했으므로, 하위 클래스 linear과 하위 클래스의 매개 변수를 확인할 수 있습니다.


[Output]

MyCell(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)

(tensor([[ 0.9006, -0.2039,  0.6150,  0.0858],
        [ 0.6858,  0.2722,  0.6303,  0.1171],
        [ 0.6038, -0.0257,  0.6882,  0.3034]], grad_fn=<TanhBackward0>), tensor([[ 0.9006, -0.2039,  0.6150,  0.0858],
        [ 0.6858,  0.2722,  0.6303,  0.1171],
        [ 0.6038, -0.0257,  0.6882,  0.3034]], grad_fn=<TanhBackward0>))

기본 모듈과 다른 출력을 확인할 수 있습니다. 여기서 grad_fn을 확인할 수 있습니다. 이것은 오토그라드(autograd)라 불리는 PyTorch의 자동 미분 방법의 세부 정보로, 출력과 연결된 그래디언트 함수(Gradient function)를 나타내며 역전파(Back-propagation) 단계에서 그래디언트를 계산하는데 사용됩니다. 이 방법을 통해 모델 제작에 엄청난 유연성을 얻을 수 있죠.

1-3. 오토그라드를 통한 복잡한 모듈 구현하기

제어문이 포함된 복잡한 형태의 모듈을 구현해보도록 하겠습니다.

(예제 1.3)

class MyDecisionGate(torch.nn.Module):                  # --- (1)
    def forward(self, x):                               # --- (2)
        if x.sum() > 0:                                 # --- (3)
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()                      # --- (4)
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h) # --- (5)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

(1) 복잡한 모듈 구성을 위해 제어문이 포함된 모듈인 MyDecisionGate를 정의했습니다.
(2) MyDecisionGate의 포워드 함수를 정의하는 부분으로
(3) 포워드 함수에는 if문을 활용한 제어문이 포함되어 있습니다
(4) 클래스 기반으로 정의된 MyDecisionGate를 인스턴스 방식으로 MyCell에서 속성(self.dg)으로 선언합니다.
(5) linear를 하위 클래스로 사용한것과 같이 self.dg를 하위 클래스로 사용합니다.

[Output]
MyCell(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)

(tensor([[ 0.5814,  0.8933,  0.6983,  0.6155],
        [ 0.3094,  0.9222,  0.6860,  0.5930],
        [-0.0626,  0.9101,  0.7345,  0.7340]], grad_fn=<TanhBackward0>), tensor([[ 0.5814,  0.8933,  0.6983,  0.6155],
        [ 0.3094,  0.9222,  0.6860,  0.5930],
        [-0.0626,  0.9101,  0.7345,  0.7340]], grad_fn=<TanhBackward0>))

출력 결과와 같이, MyCellself.dg가 하위 클래스로 포함되는 것을 확인할 수 있습니다.

2. TorchScript 기초

왜 TorchSciprt를 사용해야할 까요?

TorchScript를 사용하는 이유는 크게 네 가지로 요약할 수 있습니다:

병렬 처리 및 성능 향상: TorchScript 코드는 자체적인 인터프리터에서 실행될 수 있으며, 이는 Global Interpreter Lock(GIL)을 필요로 하지 않습니다. GIL이 없음으로써 동일한 인스턴스에서 여러 요청을 동시에 처리할 수 있게 되어, 병렬 처리 능력과 성능이 크게 향상됩니다.

플랫폼 및 언어 독립성: TorchScript를 통해 생성된 모델은 디스크에 저장되어 Python 이외의 다른 프로그래밍 언어로 작성된 환경에서도 쉽게 불러와 사용할 수 있습니다. 이는 모델을 다양한 환경과 플랫폼에 배포하는 데 큰 유연성을 제공합니다.

컴파일러 최적화: TorchScript는 코드를 보다 효율적으로 실행할 수 있도록 컴파일러 최적화를 가능하게 하는 특정 표현을 제공합니다. 이 최적화를 통해 모델의 실행 속도와 효율성이 개선됩니다.

백엔드 및 장치 런타임과의 상호작용: TorchScript를 사용하면 개별 연산자보다는 프로그램 전체를 더 넓은 관점에서 보고, 다양한 백엔드 및 장치 런타임과 효과적으로 상호작용할 수 있습니다. 이는 특히 다양한 하드웨어 플랫폼에서 모델을 최적화하고 실행할 필요가 있을 때 유용합니다.

이러한 이유들 때문에 TorchScript는 PyTorch 모델을 더 넓은 범위의 환경과 플랫폼에 배포하고, 성능을 향상시키며, 병렬 처리 능력을 극대화하기 위한 강력한 도구로 사용됩니다.

Global Interpreter Lock(GIL) ? GIL은 멀티스레딩 환경에서 한 번에 하나의 스레드만이 Python 객체에 접근하도록 제한하는 메커니즘입니다. 이는 Python의 메모리 관리가 스레드에 안전하지 않기 때문에 필요한 것으로, 데이터의 무결성을 보장하고 race condition(데이터나 자원에 동시에 접근하고 변경을 시도할 때 발생하는 문제)을 방지하기 위해 도입되었습니다.

2-1. 모듈 트레이싱(Tracing)

TorchScript는 Python의 동적 특성을 지원하면서도 정적 분석이 가능한 코드 형태로, PyTorch 모델을 더 빠르게 실행할 수 있게 해주며, 다양한 환경에서 모델을 실행할 수 있는 호환성을 제공합니다. TorchScript를 사용하는 과정은 크게 두 가지 방법으로 나눌 수 있는데, 먼저 트레이싱부터 알아보도록 하겠습니다.

트레이싱이란? 트레이싱은 PyTorch 모델의 실행을 추적하여, 모델이 수행하는 연산들을 TorchScript라는 중간 표현(IR)으로 변환하는 과정입니다.

(예제 1.2)를 가져와 MyCell모듈을 TorchScript로 변환해보겠습니다. torch.nn.Module을 TorchScript로 변환하기 위해선 torch.jit.trace 함수가 필요합니다.

(예제 1.4)


class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))      # --- (1)
print(traced_cell)                                  # --- (2)
print(traced_cell.graph)                            # --- (3)
print(traced_cell.code)                             # --- (4)

(1) MyCell 모듈 기반의 TorchScript가 생성되는 과정입니다. TorchScript가 생성되는 과정은 다음과 같습니다.

  • 인스턴스화 및 트레이싱 : MyCell 모듈을 인스턴스화 하고, 모델이 처리할 입력 예제를 준비합니다. 그런 다음 torch.jit.trace 함수에 이 모듈과 입력 예제를 전달하여 트레이싱 과정이 시작됩니다.
  • 연산 기록 : 모듈이 입력을 처리할 때, torch.jit.trace는 이 과정에서 발생하는 모든 연산을 추적하고 기록합니다. 이렇게 하여 모델의 실행 경로를 정확하게 파악할 수 있습니다.
  • TorchScript 생성 : 추적 과정이 완료되면, 기록된 연산들은 TorchScript 코드로 변환되어 torch.jit.ScriptModule 인스턴스로 저장됩니다. 이 코드는 Python 인터프리터 없이도 실행될 수 있으며, 다양한 환경에서 모델을 더 빠르게 실행할 수 있습니다.

(3) 그래프를 확인하는 부분으로, 모델의 계산 그래프를 확인할 수 있습니다.

[Output]
graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # /var/folders/p3/5cd6lj5j7q11887yhw6dp9cm0000gn/T/ipykernel_43557/3757428429.py:7:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /var/folders/p3/5cd6lj5j7q11887yhw6dp9cm0000gn/T/ipykernel_43557/3757428429.py:7:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /var/folders/p3/5cd6lj5j7q11887yhw6dp9cm0000gn/T/ipykernel_43557/3757428429.py:7:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)

(4) 그러나 .graph를 사용할 경우, 저수준으로 표현되며 최종 사용자 입장에서는 직관성이 떨어집니다. 대신 .code 속성을 사용해 코드에 대한 Python 구문 해석을 제공할 수 있습니다.

[Output]
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)

trace_cell을 호출하면 my_cell(Python 모듈)과 동일한 결과가 생성됩니다.

Python 모듈의 포워드 함수 연산 결과

print(my_cell(x, h))
[Output]
(tensor([[0.4135, 0.8980, 0.4250, 0.7770],
        [0.5633, 0.5571, 0.6075, 0.6998],
        [0.6665, 0.8463, 0.2813, 0.4935]], grad_fn=<TanhBackward0>), tensor([[0.4135, 0.8980, 0.4250, 0.7770],
        [0.5633, 0.5571, 0.6075, 0.6998],
        [0.6665, 0.8463, 0.2813, 0.4935]], grad_fn=<TanhBackward0>))

TorchScript 모듈의 포워드 함수 연산 결과

print(traced_cell(x, h))
[Output]
(tensor([[0.4135, 0.8980, 0.4250, 0.7770],
        [0.5633, 0.5571, 0.6075, 0.6998],
        [0.6665, 0.8463, 0.2813, 0.4935]], grad_fn=<TanhBackward0>), tensor([[0.4135, 0.8980, 0.4250, 0.7770],
        [0.5633, 0.5571, 0.6075, 0.6998],
        [0.6665, 0.8463, 0.2813, 0.4935]], grad_fn=<TanhBackward0>))

2-2. 스크립팅(Scripting)

트레이싱만 사용할 경우, 하위 모듈로 제어문이 포함된 복잡한 모듈을 사용한다면 제대로 작동하지 않는 문제가 있습니다.

(예제 1.5)


class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

print(traced_cell.dg.code)						# --- (1)
print(traced_cell.code)							# --- (2)

(예제 1.5)에서는

(1)의 출력으로는

[Output]
def forward(self,
    argument_1: Tensor) -> Tensor:
  return torch.neg(argument_1)

(2)의 출력으로는

[Output]
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  _1 = torch.tanh(_0)
  return (_1, _1)

입니다. 어디에서도 제어문(if-else)이 확인할 수 없습니다! 트레이싱은 코드를 실행하고 발생하는 작업을 기록하며 스크립트 모듈(ScriptModule)을 구성하는 일을 수행합니다. 안타깝지만, 제어 흐름과 같은 것들이 지워지게 되죠.

TorchScript에서 제어문을 사용하려면, 스크립팅 메커니즘을 사용해야합니다. TorchScript는 정적 그래프를 생성하기 위해 모델의 코드를 분석하는데, 이 과정에서 동적인 Python 코드를 처리하는 데 제한이 있죠. 그러나 torch.jit.script 데코레이터나 함수를 사용하면 제어문이 포함된 Python 함수나 모듈을 TorchScript 코드로 변환할 수 있고, 동적 제어 구조도 처리할 수 있게 됩니다.

아래 예제를 한번 볼까요?

(예제 1.6)

scripted_gate = torch.jit.script(MyDecisionGate())  # --- (1)

my_cell = MyCell(scripted_gate)                     # --- (2)
scripted_cell = torch.jit.script(my_cell)           # --- (3)

print(scripted_gate.code)							# --- (4)
print(scripted_cell.code)							# --- (5)

MyCell은 (예제 1.5)에서 사용한 모듈을 다시 사용하도록 하겠습니다.

(1) MyDecisionGate() 인스턴스를 생성하고, 이를 torch.jit.script 함수에 전달하여 TorchScript로 변환합니다. 이 과정에서 MyDecisionGate 클래스 내의 코드는 정적 분석을 거쳐 TorchScript 호환 코드로 변환됩니다.

결과적으로, 동적인 Python 코드가 실행 가능한 정적인 그래프 형태로 컴파일 되어 scripted_gate 변수에 저장됩니다.

(2) scripted_gate 인스턴스를 매개변수로 사용하여 MyCell 클래스의 인스턴스르 생성합니다. MyCell의 생성자 함수가 scripted_gate 객체를 내부적으로 사용하거나 저장하는 방식으로 구현되었음을 의미합니다. 이 과정은 TorchScript 변환과는 직접적인 관련은 없으며, Python 레벨에서의 일반적인 객체 생성과 인스턴스화 과정입니다.

(3) my_cell 인스턴스를 torch.jit.script 함수에 전달하여 TorchScript로 변환합니다. 이 단계에서 MyCell 내부의 구현이 TorchScript 코드로 컴파일됩니다.

변환된 객체는 scripted_cell 변수에 저장되고, scripted_cellPython 코드가 아닌, TorchScript 환경에서 최적화되어 실행될 수 있는 코드를 포함합니다.

(4) 의 출력을 보면, 제어문이 포함되어 있음을 알 수 있고

[Output]
def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

(5) 의 출력을 보면, (4)의 포워드 함수가 하위 모듈로 포함되어 있음을 확인할 수 있습니다.

[Output]
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)

2-3 트레이싱과 스크립팅

스크립팅과 트레이싱, 이 두 방법은 서로 다른 사용 사례와 제약 사항을 가지고 있으며, 어느 것이 더 좋다고 일반화하기보다는 각각의 장단점을 이해하고 상황에 맞게 선택하는 것이 중요합니다.

트레이싱의 장/단점

  • 장점 :
    동적인 제어 흐름(예: 반복문, 조건문 등)을 처리할 수 있습니다. 즉, 실행 시점에 따라 변할 수 있는 모델의 동작을 정확히 반영할 수 있습니다.
    타입 시스템을 통해 더 엄격한 코드 검사를 제공하므로, 버그를 미리 발견하고 수정할 수 있습니다.

  • 단점 :
    모든 코드가 TorchScript의 제한된 서브셋으로 작성되어야 하며, 일부 Python 기능은 지원되지 않습니다.
    스크립팅 과정이 복잡한 모델에서는 더 오래 걸릴 수 있습니다.

스크립팅의 장/단점

  • 장점 :
    사용하기 쉽고, 모델의 입력과 출력만으로 작동하여, 코드 변환 과정이 매우 빠릅니다.
    대부분의 PyTorch 연산을 지원하며, 순수하게 데이터 흐름에 기반한 모델에 대해서는 매우 효율적입니다.

  • 단점 :
    동적인 제어 흐름을 추적할 수 없습니다. 즉, 모델의 실행 동안 조건문이나 반복문의 결과가 입력 데이터에 따라 달라지는 경우, 이를 정확히 반영하지 못할 수 있습니다.
    트레이싱은 실행 시점에 실제로 수행된 연산만을 기록하기 때문에, 입력 데이터에 크게 의존적입니다. 다양한 조건이나 경로를 포함하는 모델의 경우, 일부 경로가 누락될 수 있습니다.

스크립팅과 트레이싱의 혼합 사용

스크립팅과 트레이싱을 함께 사용하는 접근 방식은 모델의 특정 부분에서는 동적인 제어 흐름을 처리하는 스크립팅의 유연성을, 다른 부분에서는 트레이싱의 단순성과 효율성을 활용해야합니다.

트레이싱으로 생성된 모듈(또는 함수)은 스크립팅된 모델 내에서 호출될 때 해당 호출이 "인라인"되어, 트레이싱된 모델의 연산이 스크립팅된 모델의 TorchScript 코드 내에 직접 포함됩니다. 반대로도 마찬가지로, 스크립팅된 부분이 트레이싱 과정 중에 인라인 될 수 있습니다.

(예저 1.7)

class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))  # --- (1)

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())                            # --- (2)
print(rnn_loop.code)

(1) MyCell 인스턴스는 (x, h)라는 특정 입력에 대해 torch.jit.trace 함수를 사용해 트레이싱됩니다. 특정 입력 (x, h)MyCell이 처리할 텐서의 예시로, 트레이싱 과정에서 MyCell의 실행 경로를 추적하는 데 사용됩니다.

트레이싱을 통해, MyCell 내부에서 실행되는 연산들이 TorchScript 코드로 변환되며, 이 과정은 모듈이 실행 시간에 정적인 구조를 가진다고 가정됩니다. (동적인 행위(제어문 등)를 정확히 캡처하지 못할 수 있습니다)

(2) torch.jit.script를 사용해 MyRNNLoop클래스를 전체적으로 스크립팅합니다. 스크립팅 과정에서 MyRNNLoop의 포워드 함수 내부 로직을 분석하고, Python 코드가 TorchScript 코드로 변환됩니다.

이 과정을 통해 MyRNNLoop의 인스턴스는 TorchScript 환경에서 실행 가능한 코드를 포함하며, 제어문 등이 포함된 동적인 특성도 정확히 처리할 수 있게 됩니다.

(예제 1.8)

class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())               # --- (1)

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))     # --- (2)
print(traced.code)

(1) WrapRNNinit 메서드 내에서 MyRNNLoop() 인스턴스를 스크립팅합니다. MyRNNLoop() 모듈의 인스턴스를 생성하고, torch.jit.script 함수를 사용해 스크립팅합니다. 이를 통해 MyRNNLoop 클래스를 정적 그래프 형태로 변환합니다.

이 변환을 통해 동적인 제어 흐름과 조건문 등을 처리할 수 있으며, 모델의 실행 속도 향상 및 다양한 플랫폼에서 모델을 배포할 수 있게 되어 호환성이 향상되게 됩니다.

(2) WrapRNN()의 인스턴스를 트레이싱하는 과정입니다. WrapRNN() 클래스의 인스턴스를 생성하고, 이를 초기 입력값(torch.rand(10, 3, 4))과 함께 torch.jit.trace 함수에 전달하여 트레이싱합니다. 모델에 실제 데이터를 통과시키며 모델 실행 중 발생하는 모든 연산을 추적하여, 이 연산들로 구성된 정적 그래프를 생성합니다.

[Output]
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)

3. 모델 저장 및 불러오기

TorchScript는 PyTorch 모델을 최적화하고, 다양한 플랫폼에서 모델을 실행할 수 있도록 도와주는 기술입니다. 모델을 TorchScript로 변환하면, 이 모델은 Python 인터프리터 없이도 실행 가능한 독립적인 형태로 저장됩니다. 이 변환된 모델을 디스크에 저장하고, 나중에 다시 불러와서 사용할 수 있는 기능을 PyTorch는 제공합니다. 이 과정을 통해 모델을 다른 서버나 기기에서도 쉽게 배포하고 사용할 수 있습니다.

3-1 TorchScript 모델 저장하기

모델을 TorchScript 형식으로 저장할 떄는 .save 함수를 사용합니다. 예를 들어, trace.save("wrapped_rnn.pt") 코드는 TorchScript 형식의traced 모델을 'wrapped_rnn.pt'라는 파일로 저장합니다.

여기서 .pt 확장자는 PyTorch 모델 파일임을 나타내며, 모델의 코드 매개변수, 속성, 그리고 디버그 정보 등이 포함되어 있어 있습니다.


traced.save('wrapped_rnn.pt')

[Output]

3-2 TorchScript 모델 불러오기

저장된 TorchScript 모델을 불러올 때는 torch.jit.load 함수를 사용합니다. 예를 들어, loaded = torch.jit.load('wrapped_rnn.pt') 는 'wrapped_rnn.pt'파일에서 모델을 불러와 loaded변수에 할당합니다. 이렇게 불러온 모델은 저장할 때와 동일한 상태를 가지며, 즉시 실행이 가능하도록 합니다.

loaded = torch.jit.load('wrapped_rnn.pt')

4. 마무리

지금까지 Python 기반의 PyTorch 모델을 TorchScript로 변환하는 과정을 알아보았습니다. 다음 장에서는 이번 장에서 생성한 TorchScript를 C++에 로드하는 방법을 알아보도록 하겠습니다.

읽어주셔서 감사합니다. 😊

profile
Developer & Publisher 심준석 입니다.

0개의 댓글

관련 채용 정보