구현체를 참고할 때 대부분 텐서플로를 사용해서 batchnorm을 구현하는 데 난 파이토치를 이용해서 구현할거다. batchNorm에 대한 자세한 내용은 각자 학습하길 바람
import torch.nn as nn
def batch_norm(x, data_format, momentum=0.01):
if data_format == 'NHWC':
x = x.permute(0, 2, 3, 1)
elif data_format == 'NCHW':
pass
else:
raise NotImplementedError("data_format not supported")
return nn.BatchNorm2d(x.size(1), momentum=momentum)(x)
image_ops란 파일을 만들어주고 batch_norm을 만들어주자.
파이토치에서 제공하는 nn.BatchNorm2d()은 자동으로 관리해주지만 굳이 만드는 이유는 많은 논문에서 사용하는 cifar-10이라는 데이터셋을 표현하는데 관습적으로 두 가지 표현을 쓰기 때문이다.
NHWC, NCHW 두 가지 형태로 이미지 텐서를 표현하는데 파이토치의 batchNorm은 입력 텐서의 모양이 NCHW가 기본이다.
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
아래와 같이 기본 설정이 되어 있어서 NHWC인 경우엔 간단히 permute을 이용해서 텐서의 모양을 맞춰준다.
또한 기본 설정에 momentum이 0.1로 설정되어 있는데 mobileNet을 비롯한 기본적인 논문의 batchNorm의 모멘텀이 0.01, 0.001로 설정된 경우가 많아서 기본값을 0.01로 바꿔주었다.
텐서플로에 비해서 파이토치의 batchNorm이 좋은 이유는 training 상태를 자동적으로 추론해준다는 점이다. 따라서 논문 저자가 구현한 코드에 is_training을 따로 구현할 필요가 없다.