[PyTorch] interpolation

김유상·2022년 12월 22일
0

torch.nn.functional 모듈에서는 interpolation을 지원한다. interpolation이 무엇인가 하면 사전적으로는 보간이라는 뜻을 가지며 작은 사이즈의 이미지를 큰 사이즈로 키울 때 사용된다. 단순히 업샘플링이라고 할 수도 있지만 늘어날 때 중간 값을 적절하게 보간해주는 옵션들을 구체적으로 구현하고 있다.

params

torch.nn.functional.interpolate(
    input, # input tensor, 샘플링할 데이터 삽입
    size=None, # output spatial size로 int나 int형 tuple을 입력으로 넣을 수 있음, (h,w)
    scale_factor=None, # spatial size에 곱해지는 scale 값
    mode='nearest', # 샘플링 방식을 선택 ['nearest', 'linear', 'bilinear', 'bicubic', 'trilinear', 'area']
    align_corners=False # 가장자리를 정렬할 수 있는 옵션
)

호출 방법만 알아서는 어떤 기능인지 이해하기 어려우니 JINSOL KIM이 작성한 아래 예시 코드를 확인하자.

import torch
import torch.nn as nn
import torch.nn.functional as F

input = torch.arange(0, 16, dtype=torch.float32).reshape(1, 1, 4, 4)
# size : torch.Size([1, 1, 4, 4])
# value : tensor([[[[ 0.,  1.,  2.,  3.],
#                   [ 4.,  5.,  6.,  7.],
#                   [ 8.,  9., 10., 11.],
#                   [12., 13., 14., 15.]]]])

F.interpolate(input, scale_factor=2, mode='nearest')
# tensor([[[[ 0.,  0.,  1.,  1.,  2.,  2.,  3.,  3.],
#           [ 0.,  0.,  1.,  1.,  2.,  2.,  3.,  3.],
#           [ 4.,  4.,  5.,  5.,  6.,  6.,  7.,  7.],
#           [ 4.,  4.,  5.,  5.,  6.,  6.,  7.,  7.],
#           [ 8.,  8.,  9.,  9., 10., 10., 11., 11.],
#           [ 8.,  8.,  9.,  9., 10., 10., 11., 11.],
#           [12., 12., 13., 13., 14., 14., 15., 15.],
#           [12., 12., 13., 13., 14., 14., 15., 15.]]]])

F.interpolate(input, scale_factor=0.8, mode='nearest')
# tensor([[[[ 0.,  1.,  2.],
#           [ 4.,  5.,  6.],
#           [ 8.,  9., 10.]]]])

F.interpolate(input, size=(5, 3), mode='bilinear')
# tensor([[[[ 0.1667,  1.5000,  2.8333],
#           [ 2.9667,  4.3000,  5.6333],
#           [ 6.1667,  7.5000,  8.8333],
#           [ 9.3667, 10.7000, 12.0333],
#           [12.1667, 13.5000, 14.8333]]]])

0~15까지의 값을 가지는 4x4 텐서를 2배로 키우거나 줄이거나 특정 사이즈로 변환하는 모습을 확인할 수 있다. 이렇게 업샘플링도 가능하고 다운샘플링도 가능하다. 보간 방식에 따라 늘어난 자리에 들어가는 값이 다르다. nearest는 가까운 위치의 값을 그대로 복사하고 bilinear는 범위의 중간이 되는 값을 연산을 통해 선정한다.

Referenced: https://gaussian37.github.io/dl-pytorch-snippets/#finterpolate와-nnupsample-1, https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html#torch.nn.functional.interpolate

profile
continuous programming

0개의 댓글