torch.cat과 torch.stack의 차이

hangdi·2023년 12월 7일
0

1. tensor와 vector 연관짓기



이미지를 Tensor로 표현하기

H : height 높이
W : width 너비
C : channel 채널 (RGB 컬러의 경우 3, grayscale의 경우 1또는 0)
H, W = 3,5
img1 = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
img2 = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)

"""
img 1 :
tensor([[[193, 210, 209],
         [160, 213,  52]],

        [[ 22, 169, 202],
         [164, 181,  17]],

        [[175, 240, 109],
         [  0, 161, 185]]], dtype=torch.uint8)
"""

"""
img 2 : 
tensor([[[104, 155, 197],
         [ 10, 225, 143]],

        [[237,   8,  60],
         [229,  81, 124]],

        [[144, 193, 159],
         [185,   1,  67]]], dtype=torch.uint8)
"""

torch.concat과 stack 비교

  • torch.cat은 원하는 dimension 방향으로 텐서를 나란히 쌓아줌
    • '스며들어 합쳐진다'로 이해
  • torch.stack은 텐서를 새로운 차원에 쌓아줌
    • 합쳐지긴 하지만 '구분되어' 합쳐진다

2. 다양한 dimension으로 쌓아보기

채널로 쌓기

1) torch.cat

  • 원래 tensor의 크기가 확장
cat0 = torch.cat([img1, img2], dim=0) 

"""
tensor([[[193, 210, 209],
         [160, 213,  52]],

        [[ 22, 169, 202],
         [164, 181,  17]],

        [[175, 240, 109],
         [  0, 161, 185]], # 여기까지가 img1

        [[104, 155, 197], # 여기부터가 img2
         [ 10, 225, 143]],

        [[237,   8,  60],
         [229,  81, 124]],

        [[144, 193, 159],
         [185,   1,  67]]], dtype=torch.uint8)
"""
cat0.size()

"""
torch.Size([6, 2, 3])
"""

2) torch.stack

  • 원래 텐서의 크기는 그대로고 차원이 늘어남
stack0 = torch.stack([img1, img2], dim=0)

"""
tensor([[[[193, 210, 209],
          [160, 213,  52]],

         [[ 22, 169, 202],
          [164, 181,  17]],

         [[175, 240, 109],
          [  0, 161, 185]]], # 여기부터가 img1


        [[[104, 155, 197], # 여기부터가 img2
          [ 10, 225, 143]],

         [[237,   8,  60],
          [229,  81, 124]],

         [[144, 193, 159],
          [185,   1,  67]]]], dtype=torch.uint8)
"""
stack0.size()

"""
torch.Size([2, 3, 2, 3])
"""

Height를 기준으로 쌓기

1) torch.cat

  • height가 img1 height + img2 height 길이만큼 되는 것
cat1 = torch.cat([img1, img2], dim=1)
cat1

"""
tensor([[[193, 210, 209],
         [160, 213,  52],
         [104, 155, 197],
         [ 10, 225, 143]],

        [[ 22, 169, 202],
         [164, 181,  17],
         [237,   8,  60],
         [229,  81, 124]],

        [[175, 240, 109],
         [  0, 161, 185],
         [144, 193, 159],
         [185,   1,  67]]], dtype=torch.uint8)
"""
cat1.size()
"""
torch.Size([3, 4, 3])
"""

2) torch.stack

  • height 기준으로
stack1 = torch.stack([img1, img2], dim=1)
stack1

"""
tensor([[[[193, 210, 209],
          [160, 213,  52]],

         [[104, 155, 197],
          [ 10, 225, 143]]],


        [[[ 22, 169, 202],
          [164, 181,  17]],

         [[237,   8,  60],
          [229,  81, 124]]],


        [[[175, 240, 109],
          [  0, 161, 185]],

         [[144, 193, 159],
          [185,   1,  67]]]], dtype=torch.uint8)
"""
stack1.size()

"""
torch.Size([3, 2, 2, 3])
"""

width 기준으로 쌓기

1) torch.cat

  • width가 img1 width + img2 width만큼 되는 것
cat2 = torch.cat([img1, img2], dim=2)
cat2

"""
tensor([[[193, 210, 209, 104, 155, 197],
         [160, 213,  52,  10, 225, 143]],

        [[ 22, 169, 202, 237,   8,  60],
         [164, 181,  17, 229,  81, 124]],

        [[175, 240, 109, 144, 193, 159],
         [  0, 161, 185, 185,   1,  67]]], dtype=torch.uint8)
"""
cat2.size()

"""
torch.Size([3, 2, 6])
"""

2) torch.stack

stack2 = torch.stack([img1, img2], dim=2)
stack2

"""
tensor([[[[193, 210, 209],
          [104, 155, 197]],

         [[160, 213,  52],
          [ 10, 225, 143]]],


        [[[ 22, 169, 202],
          [237,   8,  60]],

         [[164, 181,  17],
          [229,  81, 124]]],


        [[[175, 240, 109],
          [144, 193, 159]],

         [[  0, 161, 185],
          [185,   1,  67]]]], dtype=torch.uint8)
"""
stack2.size()

"""
torch.Size([3, 2, 2, 3])
"""
  • stack1이랑 결과가 동일하게 나왔지만, stack1은 torch.Size([3, height가 2개, 2, 3])이고, torch.Size([3, 2, width가 2개, 3])이다.

3. 정리



4. InceptionNet에서의 구현


InceptionNet에서 filter concat을 할 때는 다양한 filter를 하나의 filter로 합치기 위해서 cat을 사용한다 !

  • x = torch.cat(x, 1)에서 1은 channel
def forward(self, x):
	x = self.conv1(x)
    x = [self.branch3x3_pool(x),
         self.branch3x3_conv(x)]
    x = torch.cat(x, 1)

    x = [self.branch7x7a(x),
         self.branch7x7b(x)]
    x = torch.cat(x, 1)

    x = [self.branchpoola(x),
         self.branchpoolb(x)]
    x = torch.cat(x, 1)
    
    return x

reference
https://gaussian37.github.io/dl-pytorch-snippets/
https://discuss.pytorch.kr/t/torch-cat-torch-stack/26
https://masterzone.tistory.com/36

profile
눈물 콧물 흘리면서 배우는 코딩

0개의 댓글

관련 채용 정보