torch.py에서 Rearrange는 einops.rearrange와 같은 작업을 수행하는 layer라고 합니다.
(의식의 흐름 주의, 어떻게 동작하는지 알아보자)
einops.rearrange는multidimensional tensor를 쉽게 reordering하는 함수입니다.
rearrange(input, expression: str) #input을 expression에서 요구하는 형태로 변형함
Example
images = [np.random.randn(30, 40, 3) for _ in range(32)] #array를 원소로 가진 list
print("shape after converting to tensor", rearrange(images, 'b h w c -> b h w c').shape)
#(32, 30, 40, 3)
# height (vertical axis) 기준으로 concat, 960 = 32 * 30
print("height (vertical axis) 기준으로 concat", rearrange(images, 'b h w c -> (b h) w c').shape)
# horizontal axis 기준으로 concat, 1280 = 32 * 40
print("horizontal axis 기준으로 concat", rearrange(images, 'b h w c -> h (b w) c').shape)
# reordered axes to "b h w c -> b c h w"
print("reorder", rearrange(images, 'b h w c -> b c h w').shape)
# flattened each image -> vector, 3600 = 30 * 40 * 3
print("flatten", rearrange(images, 'b h w c -> b (c h w)').shape)
# 각 이미지를 네 부분으로 분할(top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
print("각 이미지 추가 분할", rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape)
# space-to-depth operation
print("space-to-depth operation", rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape)
Example
#x: torch.Size([1, 1, 28, 28])
#b = 1, c = 1, num_w = num_h = 7
rearrange = Rearrange('b c (num_w p1) (num_h p2) -> b (num_w num_h) (p1 p2 c) ', p1=4, p2=4)
print(rearrange(x).shape)
einops.rearrange와 동일한 방식으로 tensor의 형태를 바꿔주는 것을 확인할 수 있습니다.
어떻게 구현되는지 조금 뜯어보면...
from .._torch_specific import apply_for_scriptable_torch
class Rearrange(RearrangeMixin, torch.nn.Module):
def forward(self, input):
return apply_for_scriptable_torch(self._recipe, input, reduction_type='rearrange')
def _apply_recipe(self, x):
# overriding parent method to prevent it's scripting
pass
input을 self._recipe에서 제시하는 형식대로 변형해주는 작업을 하는 것 같은데, apply_for_scriptable_torch는 어떻게 생긴걸까?
_reconstruct_from_shape_uncached은? shape이라는 인자를 이용해서 파라미터들을 조정
참고문헌