class BasicBlock(nn.Module):
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, planes,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
중
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, planes,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes)
)
여기서 shortcut
은 그림에서
identity connection을 의미한다.
element wise addition을 하려면 차원의 동일해야 되니,
if stride != 1 or in_planes != planes:
을 넣는 것 이다.
stride가 2이상일 경우는 크기가 축소 될 것 이고
input(in_planes
)과 output(planes
) 차원이 다르면 당연히 차원이 다를 것 이다.
그럴 경우,
nn.Conv2d(in_planes, planes,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes)
을 통해 addition이 가능하도록 차원을 만들어 주는 것 이다.