self.hooks
에 등록된 함수가 있으면 실행하게 된다.def hook_custom(x):
print(f'current value is {x}')
클래스객체.hooks = []
클래스객체.hooks.append(hook_custom)
register_hook
으로 등록tensor = torch.rand(1, requires_grad=True)
def tensor_hook(grad):
pass
tensor.register_hook(tensor_hook)
# tensor는 backward hook만 있다.
tensor._backward_hooks
register_forward_hook
, register_forward_pre_hook
, register_full_backward_hook
으로 등록할 수 있다.
register_forward_hook : forward pass를 하는 동안 (output이 계산할 때 마다) 만들어놓은 hook function을 호출.
이렇게 등록한 함수에 인자로 모듈이 실행되기 전 입력값과 실행 후 출력값을 받음 (input, output)
register_forward_pre_hook : forward pass를 하기 직전에 hook function을 호출.
이렇게 등록한 함수에 인자로 모듈이 실행되기 전 입력값만을 받음 (input)
register_full_backward_hook : backward pass를 하는 동안 (gradient가 계산될 때마다) hook function을 호출.
이렇게 등록한 함수에 인자로 backpropagation에서의 gradient 값들을 받음 (grad_input, grad_output)
class Model(nn.Module):
def __init__(self):
super().__init__()
def module_hook(grad):
pass
model = Model()
model.register_forward_pre_hook(module_hook)
model.register_forward_hook(module_hook)
model.register_full_backward_hook(module_hook)
# model.get_model_shortcuts ; 모델 각각의 모듈
for name, module in model.get_model_shortcuts():
if(name == 'target_layer_name'): # 만약 해당 모듈이름이 우리가 원하는 모듈 네임이면
module.register_forward_hook(module_hook) # 해당 모듈에 hook 등록
hook을 등록할 때, 이를 특정 변수에 지정하여 등록한 후, 해당 변수.remove()
를 하게 되면 등록된 hook을 지우고 사용할 수 있다.
* 관련 유튜브 자료 : https://www.youtube.com/watch?v=syLFCVYua6Q