import torch
from torch import nn
from torch.nn.parameter import Parameter

# Function
class Function_A(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def forward(self, x):
        x = x * 2
        return x

class Function_B(nn.Module):
    def __init__(self):
        super().__init__()
        self.W1 = Parameter(torch.Tensor([10]))
        self.W2 = Parameter(torch.Tensor([2]))

    def forward(self, x):
        x = x / self.W1
        x = x / self.W2
        return x

class Function_C(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('duck', torch.Tensor([7]), persistent=True)

    def forward(self, x):
        x = x * self.duck
        return x

class Function_D(nn.Module):
    def __init__(self):
        super().__init__()
        self.W1 = Parameter(torch.Tensor([3]))
        self.W2 = Parameter(torch.Tensor([5]))
        self.c = Function_C()

    def forward(self, x):
        x = x + self.W1
        x = self.c(x)
        x = x / self.W2
        return x

# Layer
class Layer_AB(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = Function_A('duck')
        self.b = Function_B()

    def forward(self, x):
        x = self.a(x) / 5
        x = self.b(x)
        return x

class Layer_CD(nn.Module):
    def __init__(self):
        super().__init__()
        self.c = Function_C()
        self.d = Function_D()

    def forward(self, x):
        x = self.c(x)
        x = self.d(x) + 1
        return x

# Model
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.ab = Layer_AB()
        self.cd = Layer_CD()

    def forward(self, x):
        x = self.ab(x)
        x = self.cd(x)
        return x

x = torch.tensor([7])
model = Model()
model(x)
[ ] = execution order

x = tensor 7

layer_ab(x)
    [1] Function_A (given the name 'duck'): __init__ runs
    [2] Function_B: __init__ runs
    [5] x = Function_A's forward() runs, then / 5
    [6] x = Function_B's forward() runs

layer_cd(x)
    [3] Function_C (which registers the 'duck' buffer): __init__ runs
    [4] Function_D: __init__ runs (building its own Function_C)
    [7] x = Function_C's forward() runs
    [8] x = Function_D's forward() runs, then + 1
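To verify this ordering on its own, here is a minimal sketch with two hypothetical modules, Inner and Outer, that only print when they run: every __init__ fires once at construction time, while forward() fires only when the module is actually called.

import torch
from torch import nn

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        print('Inner.__init__')

    def forward(self, x):
        print('Inner.forward')
        return x

class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()   # building the child runs its __init__ immediately
        print('Outer.__init__')

    def forward(self, x):
        print('Outer.forward')
        return self.inner(x)

m = Outer()             # prints Inner.__init__, then Outer.__init__
m(torch.tensor([1.]))   # prints Outer.forward, then Inner.forward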
Each __init__ runs once and the module then sits idle, waiting until the next function is called. If you register a value as a Parameter, it is kept on the module so you can fetch and use it whenever it is needed. If you assign it as a plain Tensor, computation proceeds exactly as it would with a parameter, but it cannot be differentiated and its value is never updated; moreover, when the model is saved, a plain tensor is not saved along with it and is simply ignored. A buffer registered with register_buffer sits in between: like a plain tensor it receives no gradient updates, but with persistent=True it is saved as part of the state_dict.
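A minimal sketch contrasting the three kinds of attributes (the Demo module and its attribute names are made up for illustration):

import torch
from torch import nn
from torch.nn.parameter import Parameter

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = Parameter(torch.tensor([1.]))        # trained, saved in the state_dict
        self.t = torch.tensor([2.])                   # plain tensor: not trained, NOT saved
        self.register_buffer('b', torch.tensor([3.])) # not trained, but saved in the state_dict

demo = Demo()
print(list(demo.state_dict().keys()))           # ['w', 'b'] -- the plain tensor 't' is gone
print([n for n, _ in demo.named_parameters()])  # ['w']

Back in the Model above, named_buffers() lists every registered buffer with its dotted path: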
for name, buffer in model.named_buffers():
    print(f"[ Name ] : {name}\n[ Buffer ] : {buffer}")
    print("-" * 30)
>>>
[ Name ] : cd.c.duck
[ Buffer ] : tensor([7.])
------------------------------
[ Name ] : cd.d.c.duck
[ Buffer ] : tensor([7.])
------------------------------
# TODO : fetch the Buffer that belongs to Function_C!
buffer = model.get_buffer("cd.c.duck")
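get_buffer simply resolves the dotted path shown by named_buffers, so the call above returns the tensor from the listing; the sibling accessors get_parameter and get_submodule fetch a parameter or a submodule by path in the same way.

print(buffer)
>>>
tensor([7.])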
for name, module in model.named_modules():
    print(f"[ Name ] : {name}\n[ Module ]\n{module}")
    print("-" * 30)
>>>
[ Name ] :
[ Module ]
Model(
  (ab): Layer_AB(
    (a): Function_A()
    (b): Function_B()
  )
  (cd): Layer_CD(
    (c): Function_C()
    (d): Function_D(
      (c): Function_C()
    )
  )
)
------------------------------
[ Name ] : ab
[ Module ]
Layer_AB(
  (a): Function_A()
  (b): Function_B()
)
------------------------------
[ Name ] : ab.a
[ Module ]
Function_A()
------------------------------
[ Name ] : ab.b
[ Module ]
Function_B()
------------------------------
[ Name ] : cd
[ Module ]
Layer_CD(
  (c): Function_C()
  (d): Function_D(
    (c): Function_C()
  )
)
------------------------------
[ Name ] : cd.c
[ Module ]
Function_C()
------------------------------
[ Name ] : cd.d
[ Module ]
Function_D(
  (c): Function_C()
)
------------------------------
[ Name ] : cd.d.c
[ Module ]
Function_C()
------------------------------
An nn.Module also keeps several collections of hooks that fire around execution and serialization:

- forward_pre_hooks
- forward_hooks
- full_backward_hooks
- state_dict_hooks # used internally
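The first three are attached with the matching register_* methods on any module, which return a handle so the hook can be removed later. A minimal sketch of a forward hook (print_shapes is a hypothetical hook function):

import torch
from torch import nn

def print_shapes(module, inputs, output):
    # a forward hook receives the module, its positional inputs, and its output
    print(type(module).__name__, tuple(inputs[0].shape), '->', tuple(output.shape))

layer = nn.Linear(5, 2)
handle = layer.register_forward_hook(print_shapes)
layer(torch.randn(3, 5))   # prints: Linear (3, 5) -> (3, 2)
handle.remove()            # detach the hook once it is no longer needed

Another recursive traversal utility is nn.Module.apply(fn): it calls fn on every submodule (children first) and then on the module itself, and is typically used for weight initialization: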
import torch
from torch import nn

@torch.no_grad()
def init_weights(m):
    print('module:', m)
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        print('linear apply:', m.weight)
    elif type(m) == nn.Sequential:
        print('It is sequential')

net = nn.Sequential(nn.Linear(5, 2), nn.Linear(2, 2))
print('------apply start------')
net.apply(init_weights)
print('---------end----------')
>>>
------apply start------
module: Linear(in_features=5, out_features=2, bias=True)
linear apply: Parameter containing:
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]], requires_grad=True)
module: Linear(in_features=2, out_features=2, bias=True)
linear apply: Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
module: Sequential(
  (0): Linear(in_features=5, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
It is sequential
---------end----------
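Note the order in the output: apply visits the two Linear children first and the parent Sequential last, which is why the 'It is sequential' branch fires at the end.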
The first assignment has finally come to an end after two weeks... I was able to study it carefully, step by step, but trying to understand it from the docs alone took a lot of time.
Even if I can't memorize it all, I wanted to understand it completely before moving on, so I asked a lot of questions and organized my notes along the way, which deepened my understanding. It took a long time, but I feel like I've grown a lot over these two weeks, and that feels great!