2x2 filter와 stride=2인 상태 → 원소별 곱하고 총 합 X → window에 들어있는 값들의 최댓값/평균
계산의 편의를 위해 는 2의 배수라고 가정
pooling은 일반적으로 height, width를 절반으로 줄임
import torch
import torch.nn as nn
torch.set_printoptions(linewidth=100)
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
input_ = torch.randn(1, 8, 8)
print(f"input: {input_.shape}\n{input_}\n")
output = max_pool(input_)
print(f"max pooling output: {output.shape}\n{output}\n")
output = avg_pool(input_)
print(f"avg pooling output: {output.shape}\n{output}")