[pytorch]nn.Linear

ma-kjh·2023년 9월 10일
0

Pytorch

목록 보기
6/20

https://pytorch.org/docs/stable/generated/torch.nn.Linear.html

새롭게 알게된 점

Pytorch의 Linear layer 모듈은 오직 TensorFloat32 dtype만을 지원한다.

근데 이 때, weight가 float16이면 32가 들어왔을 때 계산이 안되므로, weight도 float32로 넣어주어야 한다.

pretrainedweight = torch.nn.Parameter(h_zs.clone())
layer1 = nn.Linear(512, 100)
layer1.weight = pretrainedweight

를 아래와 같이 바꾸면 에러가 사라짐.

pretrainedweight = torch.nn.Parameter(h_zs.clone().float())
layer1 = nn.Linear(512, 100)
layer1.weight = pretrainedweight.float()

혹은 gpu에서 autocast를 실행해서 float16으로 계산하도록 만들어준다.

profile
거인의 어깨에 올라서서 더 넓은 세상을 바라보라 - 아이작 뉴턴

0개의 댓글