Rearrange layer

반디·2023년 3월 24일
0

PyTorch

목록 보기
2/3
post-custom-banner

torch.py에서 Rearrange는 einops.rearrange와 같은 작업을 수행하는 layer라고 합니다.
(의식의 흐름 주의, 어떻게 동작하는지 알아보자)

einops.rearrange

einops.rearrange는multidimensional tensor를 쉽게 reordering하는 함수입니다.

rearrange(input, expression: str) #input을 expression에서 요구하는 형태로 변형함

Example

images = [np.random.randn(30, 40, 3) for _ in range(32)] #array를 원소로 가진 list 

print("shape after converting to tensor", rearrange(images, 'b h w c -> b h w c').shape)
#(32, 30, 40, 3)

# height (vertical axis) 기준으로 concat, 960 = 32 * 30
print("height (vertical axis) 기준으로 concat", rearrange(images, 'b h w c -> (b h) w c').shape)

# horizontal axis 기준으로 concat, 1280 = 32 * 40
print("horizontal axis 기준으로 concat", rearrange(images, 'b h w c -> h (b w) c').shape)

# reordered axes to "b h w c -> b c h w" 
print("reorder", rearrange(images, 'b h w c -> b c h w').shape)

# flattened each image -> vector, 3600 = 30 * 40 * 3
print("flatten", rearrange(images, 'b h w c -> b (c h w)').shape)

# 각 이미지를 네 부분으로 분할(top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
print("각 이미지 추가 분할", rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape)

# space-to-depth operation
print("space-to-depth operation", rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape)

Rearrange

Example

#x: torch.Size([1, 1, 28, 28])
#b = 1, c = 1, num_w = num_h = 7
rearrange = Rearrange('b c (num_w p1) (num_h p2) -> b (num_w num_h) (p1 p2 c) ', p1=4, p2=4)
print(rearrange(x).shape)

einops.rearrange와 동일한 방식으로 tensor의 형태를 바꿔주는 것을 확인할 수 있습니다.


어떻게 구현되는지 조금 뜯어보면...

from .._torch_specific import apply_for_scriptable_torch

class Rearrange(RearrangeMixin, torch.nn.Module):
    def forward(self, input):
        return apply_for_scriptable_torch(self._recipe, input, reduction_type='rearrange')

    def _apply_recipe(self, x):
        # overriding parent method to prevent it's scripting
        pass

input을 self._recipe에서 제시하는 형식대로 변형해주는 작업을 하는 것 같은데, apply_for_scriptable_torch는 어떻게 생긴걸까?

_reconstruct_from_shape_uncached은? shape이라는 인자를 이용해서 파라미터들을 조정

참고문헌

profile
꾸준히!
post-custom-banner

0개의 댓글