PyTorch - 모델 가중치 초기화하기

sp·2022년 3월 24일
0

PyTorch

목록 보기
2/2
post-thumbnail

모델을 학습하기 전에 모델의 가중치를 특정 분포로 초기화를 해주면, 학습 속도나 정확도 등 학습에 도움이 될 수 있습니다. 여기서는 모델의 가중치를 초기화하는 방법에 대해 알아보겠습니다.

텐서 초기화하기

먼저 텐서를 특정 분포로 초기화하는 함수들을 알아보겠습니다. 이는 torch.nn.init 모듈에 존재합니다. 대표적인 초기화 함수들을 살펴보겠습니다.

torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
torch.nn.init.normal_(tensor, mean=0.0, std=1.0)

torch.nn.init.constant_(tensor, val)
torch.nn.init.ones_(tensor)
torch.nn.init.zeros_(tensor)

함수의 이름에서 알 수 있듯이, 균등분포로 uniform_과 정규분포 normal_ 분포로 초기화하거나, 0, 1 아니면 특정 값으로 초기화할 수 있습니다.

추가적으로 xavier, kaiming 초기화도 다음 함수로 수행 가능합니다.

torch.nn.init.xavier_uniform_(tensor, gain=1.0)
torch.nn.init.xavier_normal_(tensor, gain=1.0)

torch.nn.init.kaiming_uniform_(tensor, a=0, ...)
torch.nn.init.kaiming_normal_(tensor, a=0, ...)

모델 초기화하기

모델 초기화는 모델에 들어있는 레이어의 가중치를 초기화하는 것으로 볼 수 있습니다. 그런데 모델에는 많은 레이어가 들어있기 때문에 위와 같이 하나씩 초기화하기는 쉽지 않습니다. 모델을 한번에 초기화하기 위해 다음과 같은 방법을 사용합니다.

import torch.nn as nn

def init_weight(module):
    class_name = module.__class__.__name__

    if class_name.find("Conv") != -1:
        nn.init.normal_(module.weight.data, 0.0, 0.02)
    elif class_name.find("BatchNorm2d") != -1:
        nn.init.normal_(module.weight.data, 1.0, 0.02)
        nn.init.constant(module.bias.data, 0.0)

if __name__ == "__main__":
    netG_A2B = Generator(6).to(device)
    netG_A2B.apply(init_weight)

모델에 대해 apply 메서드를 적용하는데, 그 인수로 함수가 들어갑니다. 이 메서드에서는 전체 모델 내의 레이어를 재귀 방식으로 접근해 함수로 들어아게 되고, 각 레이어에서 가중치를 찾아 초기화를 수행하게 됩니다.

재귀적으로 접근하기 때문에 Conv2d, BatchNorm, Sequential 등의 인스턴스가 매개변수로 들어오게 되고, 몇몇 클래스들에만 적용하기 위해 string에 적용하는 find 함수를 사용하게 됩니다. 즉, 위 코드는 컨볼루션 레이어일 때 weight.data를 초기화하게 되고, 배치 정규화에 대해서는 weight.databias.data를 초기화하는 것으로 이해할 수 있습니다.

0개의 댓글