고등학생 때 DLL injection으로 피카츄 배구 해킹같은걸 했었는데, 그 때 사용했던 기법들이 일종의 hooking이다. 그런 기법들을 공식적으로 pytorch의 nn.Module에서 지원해준다.
pytorch의 hook들은 다음과 같은 규칙을 가진다.
아래 코드들을 보면 알거다.
tensor는 backward에 대해서만 hook을 지원한다.
torch.tensor.register_hook(function)
아래와 같은 4개의 hook을 지원한다.
def pre_hook(module, input)
return Anything
return이 있다면 forward의 input을 Anything으로 바꿀 수 있다.
return이 없다면 단순히 input을 조회할 뿐이다.
def hook(module, input, output)
return Anything
return이 있다면 forward의 결과값이 Anything으로 교체된다.
return이 없다면 단순 조회.
def module_hook(module, grad_input, grad_output)
return이 있다면 backard()를 통해 grad_ouput으로 업데이트 될 때, grad_output을 교체할 수 있다.
return이 없다면 단순 교체.