[ pytorch ] 특정 layer 학습하지 않기

hyunsooo·2021년 12월 9일
0
post-custom-banner
  • children() : 자식에 대한 iterator 반환
  • modules() : 전체에 대한 iterator 반환
ct = 0
for name, child in model.named_children():
    ct += 1
    if ct > 2:
        for p in child.parameters():
            p.requires_grad = False
  • state_dict() : 모델의 모든 상태를 딕셔너리로 반환함
#model weight key
print(model.state_dict().keys()

#전체 파라미터 false
for p in model.parameters():
	p.requires_grad = False

#특정 파라미터 ture
for name, p in model.named_parameters():
	if name in ['linear.1.weight', 'linear.2.weight']:
    	p.reguires.grad = Ture
profile
지식 공유
post-custom-banner

0개의 댓글