[PyTorch] #1 nn.Module & Number of Parameters

Clay Ryu's sound lab · January 15, 2024

nn.Module

We often use nn.Module without really knowing what it is or how it works. Below are some of its methods that I use from time to time.

  • Module.train()
  • Module.to()
  • Module.zero_grad()
  • Module.state_dict()
  • Module.register_buffer()
  • Module.forward()
  • Module.parameters()

https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module
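Here is a small sketch that calls each method in the list on a toy module. `TinyNet` is a hypothetical example, not part of the original post. It runs on CPU and moves to CUDA only if a GPU is available.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical toy module used only to exercise common nn.Module methods."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)
        # A buffer is saved in state_dict() and moved by .to(), but it is not a parameter.
        self.register_buffer("scale", torch.tensor(2.0))

    def forward(self, x):
        return self.linear(x) * self.scale


model = TinyNet()
model.train()  # sets training=True on this module and all children (model.eval() == model.train(False))
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)  # moves parameters and buffers, returns the module itself

x = torch.randn(3, 4, device=device)
loss = model(x).sum()  # calling the module runs forward() via __call__ (plus any hooks)
loss.backward()
print(model.linear.weight.grad is not None)  # True

model.zero_grad()  # since PyTorch 2.0 this sets grads to None by default (set_to_none=True)
print(list(model.state_dict().keys()))  # ['scale', 'linear.weight', 'linear.bias']
print(sum(p.numel() for p in model.parameters()))  # 10 (4*2 weights + 2 biases; the buffer is excluded)
```

Calling `model(x)` instead of `model.forward(x)` matters because `__call__` is what runs registered hooks.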

parameters()

Today I'm going to check the number of parameters in my transformer decoder using Module.parameters().
The model dimension is 512, and the decoder has a single layer, i.e. one attention block followed by one feed-forward block.
I used a custom decoder from x-transformers, which uses pre-norm (LayerNorm applied to the input of each sublayer) by default.
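To make "pre-norm" concrete, here is a minimal sketch of a pre-norm residual wrapper in plain PyTorch. The class name `PreNormResidual` is hypothetical and this is not x-transformers' actual code; it only shows the ordering: normalize the input first, run the sublayer, then add the residual. Because the residual stream itself is never normalized in a pre-norm stack, such models usually end with one extra LayerNorm, which is presumably what `final_norm` in the printed decoder further down is.

```python
import torch
from torch import nn


class PreNormResidual(nn.Module):
    """Hypothetical pre-norm wrapper: x + sublayer(LayerNorm(x))."""

    def __init__(self, dim: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm: normalize *before* the sublayer, then add the untouched input back.
        # A post-norm block would instead compute LayerNorm(x + sublayer(x)).
        return x + self.sublayer(self.norm(x))


dim = 512
block = PreNormResidual(dim, nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)))
x = torch.randn(2, 10, dim)  # (batch, sequence, dim)
y = block(x)
print(y.shape)  # torch.Size([2, 10, 512])
```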

LayerNorm

Scale (weight): 512
Bias: 512
Total per LayerNorm: 1,024
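A quick way to confirm these numbers is to list the named parameters of a standalone `nn.LayerNorm(512)`. In PyTorch the scale is stored as `weight` and the shift as `bias`. The printed decoder below contains three such LayerNorms (one per block plus `final_norm`), i.e. 3 x 1,024 = 3,072 parameters.

```python
from torch import nn

norm = nn.LayerNorm(512)  # elementwise_affine=True by default, as in the printed decoder
for name, p in norm.named_parameters():
    print(name, tuple(p.shape), p.numel())
# weight (512,) 512   <- the scale (gamma)
# bias (512,) 512     <- the shift (beta)
print(sum(p.numel() for p in norm.parameters()))  # 1024

# Without the affine transform, a LayerNorm has no learnable parameters at all.
print(sum(p.numel() for p in nn.LayerNorm(512, elementwise_affine=False).parameters()))  # 0
```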

Attention

to_q: 262,144 (512 ^ 2)
to_k: 262,144 (512 ^ 2)
to_v: 262,144 (512 ^ 2)
to_out: 262,144 (512 ^ 2)
Total: 1,048,576 (4 x 512 ^ 2; all four projections have bias=False, so only weights count)

# to_q of the attention block (layers[0][1]) -> 262144
sum(p.numel() for p in lm_model.main_decoder.transformer_decoder.layers[0][1].to_q.parameters())
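Without access to `lm_model`, the same count can be reproduced with four bias-free `nn.Linear(512, 512)` layers. The `ModuleDict` below is a hypothetical stand-in for the projections inside the printed `Attention` module; `Attend` only holds a dropout, which has no parameters.

```python
from torch import nn

dim = 512
# Hypothetical stand-in for to_q / to_k / to_v / to_out (all bias=False in the printed model).
proj = nn.ModuleDict({name: nn.Linear(dim, dim, bias=False) for name in ["to_q", "to_k", "to_v", "to_out"]})

per_proj = {name: sum(p.numel() for p in m.parameters()) for name, m in proj.items()}
print(per_proj)  # {'to_q': 262144, 'to_k': 262144, 'to_v': 262144, 'to_out': 262144}
print(sum(per_proj.values()))  # 1048576 == 4 * 512 ** 2

# With bias=True, each projection would carry 512 extra parameters.
print(sum(p.numel() for p in nn.Linear(dim, dim).parameters()))  # 262656
```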

FeedForward

Linear(in_features=512, out_features=2048, bias=True): 1,050,624 (512 x 2048 + 2048)
Linear(in_features=2048, out_features=512, bias=True): 1,049,088 (2048 x 512 + 512)
Total: 2,099,712

# first Linear (512 -> 2048) of the feed-forward block (layers[1][1].ff[0][0]) -> 1050624
sum(p.numel() for p in lm_model.main_decoder.transformer_decoder.layers[1][1].ff[0][0].parameters())
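The feed-forward numbers can be checked the same way. The `Sequential` below is a hypothetical re-creation of the printed `FeedForward.ff` with the same nesting, so the parameter names match the indices used above (`ff[0][0]` is the first Linear). Note that `nn.Linear` stores its weight as `(out_features, in_features)`.

```python
from torch import nn

dim, mult = 512, 4  # hidden size 2048 = 4 * 512, matching the printed Linear layers

# Hypothetical re-creation of the printed FeedForward.ff (same nesting and shapes).
ff = nn.Sequential(
    nn.Sequential(nn.Linear(dim, dim * mult), nn.GELU()),
    nn.Dropout(0.1),
    nn.Linear(dim * mult, dim),
)

for name, p in ff.named_parameters():
    print(name, tuple(p.shape))
# 0.0.weight (2048, 512)
# 0.0.bias (2048,)
# 2.weight (512, 2048)
# 2.bias (512,)

total = sum(p.numel() for p in ff.parameters())
print(total)  # 2099712 = 1,050,624 + 1,049,088
# Closed form for a 4x MLP with biases: 2 * d * 4d + 4d + d = 8d^2 + 5d
print(total == 8 * dim**2 + 5 * dim)  # True
```

GELU and Dropout contribute nothing, which is why only indices `0.0` and `2` appear.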
Printing the decoder itself shows where each of these submodules lives:

lm_model.main_decoder.transformer_decoder
Decoder(
  (layers): ModuleList(
    (0): ModuleList(
      (0): ModuleList(
        (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (1-2): 2 x None
      )
      (1): Attention(
        (to_q): Linear(in_features=512, out_features=512, bias=False)
        (to_k): Linear(in_features=512, out_features=512, bias=False)
        (to_v): Linear(in_features=512, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.1, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=512, bias=False)
      )
      (2): Residual()
    )
    (1): ModuleList(
      (0): ModuleList(
        (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (1-2): 2 x None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=512, out_features=2048, bias=True)
            (1): GELU(approximate='none')
          )
          (1): Dropout(p=0.1, inplace=False)
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
      )
      (2): Residual()
    )
  )
  (final_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
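Putting the pieces together, the sketch below rebuilds the printed structure from plain PyTorch modules and counts parameters per top-level child. It is an approximation under stated assumptions: the nested `ModuleList([LayerNorm, None, None])` is flattened to just the LayerNorm, `Attention` is replaced by a `ModuleDict` of its four projections, and `Residual()` is assumed to hold no parameters. A module's repr does not list `nn.Parameter`s registered directly on it, so running `sum(p.numel() for p in lm_model.main_decoder.transformer_decoder.parameters())` on the real model is the way to confirm the total. The helper `count` is hypothetical; its `trainable_only` flag filters on `requires_grad`.

```python
from torch import nn

dim = 512


def make_attention(d: int) -> nn.Module:
    # Stand-in for the printed Attention: four bias-free d x d projections.
    return nn.ModuleDict({k: nn.Linear(d, d, bias=False) for k in ["to_q", "to_k", "to_v", "to_out"]})


def make_feedforward(d: int, mult: int = 4) -> nn.Module:
    # Same nesting as the printed FeedForward.ff
    return nn.Sequential(nn.Sequential(nn.Linear(d, d * mult), nn.GELU()), nn.Dropout(0.1), nn.Linear(d * mult, d))


decoder = nn.ModuleDict({
    "layers": nn.ModuleList([
        nn.ModuleList([nn.LayerNorm(dim), make_attention(dim)]),   # pre-norm attention block
        nn.ModuleList([nn.LayerNorm(dim), make_feedforward(dim)]),  # pre-norm feed-forward block
    ]),
    "final_norm": nn.LayerNorm(dim),
})


def count(m: nn.Module, trainable_only: bool = False) -> int:
    # Sum element counts; optionally skip parameters with requires_grad=False.
    return sum(p.numel() for p in m.parameters() if p.requires_grad or not trainable_only)


for i, block in enumerate(decoder["layers"]):
    print(f"layers[{i}]", count(block))
print("final_norm", count(decoder["final_norm"]))
print("total", count(decoder))
# layers[0] 1049600
# layers[1] 2100736
# final_norm 1024
# total 3151360

decoder["final_norm"].requires_grad_(False)  # e.g. freeze one submodule
print("trainable", count(decoder, trainable_only=True))  # 3150336
```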