[pytorch] model summary

spring · November 9, 2020

In PyTorch, the simplest way to get a brief overview of a model is to pass it to the built-in print function.

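The code below prints AlexNet with print. The advantage of this method is that you can inspect the model without specifying the network's input size. The trade-off is that it only shows each module's configuration, not output shapes or parameter counts.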

import torch
import torchvision
net = torchvision.models.alexnet()
print(net)
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
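The printed representation lists each layer's configuration but not how many parameters it holds. As a rough sketch, the counts can be worked out by hand from the numbers shown above. The helpers conv2d_params and linear_params are made up for this illustration and are not part of PyTorch.

```python
# Hypothetical helpers (not part of PyTorch): count a layer's parameters
# from the values shown in the print(net) output.

def conv2d_params(in_ch, out_ch, kernel, bias=True):
    """Weights are out_ch x in_ch x k x k, plus one bias per output channel."""
    return out_ch * in_ch * kernel * kernel + (out_ch if bias else 0)

def linear_params(in_features, out_features, bias=True):
    """Weights are out_features x in_features, plus one bias per output."""
    return in_features * out_features + (out_features if bias else 0)

# (in_channels, out_channels, kernel_size) read from the features block
convs = [(3, 64, 11), (64, 192, 5), (192, 384, 3), (384, 256, 3), (256, 256, 3)]
# (in_features, out_features) read from the classifier block
linears = [(9216, 4096), (4096, 4096), (4096, 1000)]

conv_counts = [conv2d_params(*c) for c in convs]
linear_counts = [linear_params(*l) for l in linears]
total = sum(conv_counts) + sum(linear_counts)

print(conv_counts)    # parameters per Conv2d layer
print(linear_counts)  # parameters per Linear layer
print(f"{total:,}")   # 61,100,840
```

ReLU, MaxPool2d, AdaptiveAvgPool2d, and Dropout have no learnable parameters, so they are left out of the sum.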

The second method uses the torchsummary library. Install it as follows.

pip install torchsummary

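Its advantage is that it prints a Keras-style summary. The downside is that it shows only each layer's output shape, not its input shape. The input size is passed without the batch dimension, which appears as -1 in the output.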

import torch
import torchvision
import torchsummary
net = torchvision.models.alexnet()
torchsummary.summary(net, (3, 256, 256), device='cpu')
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 63, 63]          23,296
              ReLU-2           [-1, 64, 63, 63]               0
         MaxPool2d-3           [-1, 64, 31, 31]               0
            Conv2d-4          [-1, 192, 31, 31]         307,392
              ReLU-5          [-1, 192, 31, 31]               0
         MaxPool2d-6          [-1, 192, 15, 15]               0
            Conv2d-7          [-1, 384, 15, 15]         663,936
              ReLU-8          [-1, 384, 15, 15]               0
            Conv2d-9          [-1, 256, 15, 15]         884,992
             ReLU-10          [-1, 256, 15, 15]               0
           Conv2d-11          [-1, 256, 15, 15]         590,080
             ReLU-12          [-1, 256, 15, 15]               0
        MaxPool2d-13            [-1, 256, 7, 7]               0
AdaptiveAvgPool2d-14            [-1, 256, 6, 6]               0
          Dropout-15                 [-1, 9216]               0
           Linear-16                 [-1, 4096]      37,752,832
             ReLU-17                 [-1, 4096]               0
          Dropout-18                 [-1, 4096]               0
           Linear-19                 [-1, 4096]      16,781,312
             ReLU-20                 [-1, 4096]               0
           Linear-21                 [-1, 1000]       4,097,000
================================================================
Total params: 61,100,840
Trainable params: 61,100,840
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 10.97
Params size (MB): 233.08
Estimated Total Size (MB): 244.80
----------------------------------------------------------------
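The Output Shape column and the size estimates can be checked by hand. The sketch below assumes the usual output-size formula for convolution and pooling layers, floor((n + 2p - k) / s) + 1, and assumes 4 bytes per value with 1 MB = 1024² bytes, which reproduces the figures in the table. out_size is a hypothetical helper, not a library function.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a conv/pool layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1

# (name, kernel, stride, padding) for the layers in AlexNet's features block
# that can change the spatial size; ReLU layers keep the shape as-is.
layers = [
    ("Conv2d-1", 11, 4, 2),
    ("MaxPool2d-3", 3, 2, 0),
    ("Conv2d-4", 5, 1, 2),
    ("MaxPool2d-6", 3, 2, 0),
    ("Conv2d-7", 3, 1, 1),
    ("Conv2d-9", 3, 1, 1),
    ("Conv2d-11", 3, 1, 1),
    ("MaxPool2d-13", 3, 2, 0),
]

size = 256  # height and width of the (3, 256, 256) input
sizes = {}
for name, k, s, p in layers:
    size = out_size(size, k, s, p)
    sizes[name] = size
    print(f"{name:>14}: {size} x {size}")

# Size estimates, assuming float32 values (4 bytes each)
BYTES_PER_VALUE = 4
MB = 1024 ** 2
input_mb = 3 * 256 * 256 * BYTES_PER_VALUE / MB
params_mb = 61_100_840 * BYTES_PER_VALUE / MB
print(f"Input size (MB): {input_mb:.2f}")
print(f"Params size (MB): {params_mb:.2f}")
```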

The third method uses pytorch-model-summary, which is installed as follows.

pip3 install pytorch-model-summary

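This library takes an actual torch tensor as the example input and can print each layer's input shape, which makes it a bit nicer. Setting the show_input parameter to True prints input shapes, and setting it to False prints output shapes. There does not seem to be an option to print both at once.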

import torch
import torchvision
import pytorch_model_summary

net = torchvision.models.alexnet()
print(pytorch_model_summary.summary(net, torch.zeros(1, 3, 256, 256), show_input=True))
-----------------------------------------------------------------------------
           Layer (type)          Input Shape         Param #     Tr. Param #
=============================================================================
               Conv2d-1     [1, 3, 256, 256]          23,296          23,296
                 ReLU-2      [1, 64, 63, 63]               0               0
            MaxPool2d-3      [1, 64, 63, 63]               0               0
               Conv2d-4      [1, 64, 31, 31]         307,392         307,392
                 ReLU-5     [1, 192, 31, 31]               0               0
            MaxPool2d-6     [1, 192, 31, 31]               0               0
               Conv2d-7     [1, 192, 15, 15]         663,936         663,936
                 ReLU-8     [1, 384, 15, 15]               0               0
               Conv2d-9     [1, 384, 15, 15]         884,992         884,992
                ReLU-10     [1, 256, 15, 15]               0               0
              Conv2d-11     [1, 256, 15, 15]         590,080         590,080
                ReLU-12     [1, 256, 15, 15]               0               0
           MaxPool2d-13     [1, 256, 15, 15]               0               0
   AdaptiveAvgPool2d-14       [1, 256, 7, 7]               0               0
             Dropout-15            [1, 9216]               0               0
              Linear-16            [1, 9216]      37,752,832      37,752,832
                ReLU-17            [1, 4096]               0               0
             Dropout-18            [1, 4096]               0               0
              Linear-19            [1, 4096]      16,781,312      16,781,312
                ReLU-20            [1, 4096]               0               0
              Linear-21            [1, 4096]       4,097,000       4,097,000
=============================================================================
Total params: 61,100,840
Trainable params: 61,100,840
Non-trainable params: 0
-----------------------------------------------------------------------------
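If both shapes are needed at once, one possible workaround is to record them yourself with forward hooks in plain PyTorch. This is only a sketch: collect_shapes is a hypothetical helper written for this post, it assumes every leaf module takes a tensor as its first argument and returns a single tensor, and it uses a small stand-in network rather than AlexNet so that it runs without torchvision.

```python
import torch
import torch.nn as nn

def collect_shapes(model, example_input):
    """Run one forward pass and record (name, input shape, output shape)
    for every leaf module, using forward hooks."""
    rows, handles = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            # inputs is a tuple of positional arguments; use the first one
            rows.append((name, tuple(inputs[0].shape), tuple(output.shape)))
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))

    was_training = model.training
    model.eval()  # disable dropout and similar layers for the dry run
    with torch.no_grad():
        model(example_input)
    model.train(was_training)  # restore the original mode

    for h in handles:  # remove the hooks again
        h.remove()
    return rows

# Small stand-in network (an assumption for illustration, not AlexNet)
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(8 * 16 * 16, 10),
)

rows = collect_shapes(net, torch.zeros(1, 3, 32, 32))
for name, in_shape, out_shape in rows:
    print(f"{name:>4}  {str(in_shape):>18}  {str(out_shape):>18}")
```

The same helper should also work with torchvision.models.alexnet() and a torch.zeros(1, 3, 256, 256) input, since all of its leaf modules return plain tensors.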

Researcher & Developer @ NAVER Corp | Designer @ HONGIK Univ.
