[Pytorch] Hook & Apply

lijm1358·2023년 3월 14일

hook

패키지화된 코드에서 다른 사용자가 만든 코드를 중간에 실행시킬 수 있도록 만들어놓은 인터페이스(https://www.techtarget.com/whatis/search/query?q=Hook)

Pytorch에서의 hook은 Tensor에 적용하는 hook과 Module에 적용하는 hook으로 나뉜다. 보통 디버깅, layer중간의 feature 출력, gradient clipping등에 활용될 수 있다.

Tensor는 backward hook만 존재.

  • register_hook(hook) : tensor에 대한 gradient가 계산된 후 실행될 hook을 지정. hook의 형태는 hook(grad)->Tensor or None.

Module은 backward hook외에도 여러 hook이 존재한다. hook의 종류는 model.__dict__의 결과로도 확인할 수 있다.

  • register_forward_hook(hook) : forward()가 호출된 후 출력이 계산되면 실행될 hook을 지정. hook의 형태는 hook(module, input, output) -> None or modified output.
  • register_forward_pre_hook(hook) : forward() 호출 전에 실행될 hook을 지정. hook의 형태는 hook(module, input) -> None or modified input.
  • register_backward_hook(hook) (deprecated)
  • register_full_backward_hook(hook) : module input의 gradient가 계산될 떄 마다 호출될 hook을 지정. hook의 형태는 hook(module, grad_input, grad_output) -> tuple(Tensor) or None. grad_input과 grad_output의 내용을 바꿔버리면 오류가 발생한다.
  • _register_state_dict_hook(hook) (삭제됨) register_load_state_dict_post_hook(hook) : 내부적으로만 사용. load_state_dict()가 호출된 후에 실행될 hook을 지정.

forward hook을 이용하면 forward의 결과로 나온 output을 임의로 수정하여 출력하도록 할 수 있다.

hook에 return값이 있다면 해당 return값이 출력화 되어 gradient 계산에 사용된다.

apply

module을 구성하는 전체 module 각각에 동일한 함수를 적용할 때 사용한다.

@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        print(m.weight)
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)

apply에 넣은 함수의 파라미터로는 전체 module에 포함된 module들이 하나씩 들어가게 된다. submodule들은 postorder 방식으로 탐색하며 apply를 적용한다. 주로 파라미터의 가중치를 초기화(module.weight.data.fill_())할 때 사용된다.

profile
ML, DL 공부중

0개의 댓글