H : height
W : width
C : channel (3 for an RGB color image; 1 for grayscale, or the channel dimension may be omitted entirely)
import torch

H, W = 2, 3  # matches the (3, 2, 3) tensors printed below
img1 = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
img2 = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
"""
img 1 :
tensor([[[193, 210, 209],
[160, 213, 52]],
[[ 22, 169, 202],
[164, 181, 17]],
[[175, 240, 109],
[ 0, 161, 185]]], dtype=torch.uint8)
"""
"""
img 2 :
tensor([[[104, 155, 197],
[ 10, 225, 143]],
[[237, 8, 60],
[229, 81, 124]],
[[144, 193, 159],
[185, 1, 67]]], dtype=torch.uint8)
"""
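Comparing torch.cat (also available as torch.concat) and torch.stack
- torch.cat joins tensors side by side along an existing dimension
- think of it as the tensors 'soaking into each other': the number of dimensions stays the same and the boundary between the inputs disappears
- torch.stack joins tensors along a new dimension
- the tensors are combined but stay 'separated': each input remains its own slice along the new dimension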
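The difference described above is easiest to see in the output shapes. The following is a minimal sketch, assuming two placeholder tensors `a` and `b` of shape (3, 2, 3); these names are illustrative and are not the `img1`/`img2` used in the rest of the post.

```python
import torch

# Two hypothetical image-like tensors with shape (C, H, W) = (3, 2, 3).
a = torch.zeros(3, 2, 3, dtype=torch.uint8)
b = torch.ones(3, 2, 3, dtype=torch.uint8)

for dim in range(3):
    # cat grows an existing dimension; stack inserts a new one of size 2.
    cat_shape = tuple(torch.cat([a, b], dim=dim).shape)
    stack_shape = tuple(torch.stack([a, b], dim=dim).shape)
    print(f"dim={dim}: cat {cat_shape}, stack {stack_shape}")

# stack also accepts dim=3 (a new trailing dimension). cat does not,
# because a 3-D tensor has no existing dimension 3.
last = torch.stack([a, b], dim=3)
print(tuple(last.shape))
```

With cat the result is still 3-D and one dimension doubles in size; with stack the result is 4-D and the new dimension has size 2, one slot per input.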
1) torch.cat (dim=0)
cat0 = torch.cat([img1, img2], dim=0)
"""
tensor([[[193, 210, 209],
[160, 213, 52]],
[[ 22, 169, 202],
[164, 181, 17]],
[[175, 240, 109],
[ 0, 161, 185]], # img1 ends here
[[104, 155, 197], # img2 starts here
[ 10, 225, 143]],
[[237, 8, 60],
[229, 81, 124]],
[[144, 193, 159],
[185, 1, 67]]], dtype=torch.uint8)
"""
cat0.size()
"""
torch.Size([6, 2, 3])
"""
2) torch.stack (dim=0)
stack0 = torch.stack([img1, img2], dim=0)
"""
tensor([[[[193, 210, 209],
[160, 213, 52]],
[[ 22, 169, 202],
[164, 181, 17]],
[[175, 240, 109],
[ 0, 161, 185]]], # img1 ends here
[[[104, 155, 197], # img2 starts here
[ 10, 225, 143]],
[[237, 8, 60],
[229, 81, 124]],
[[144, 193, 159],
[185, 1, 67]]]], dtype=torch.uint8)
"""
stack0.size()
"""
torch.Size([2, 3, 2, 3])
"""
1) torch.cat (dim=1)
cat1 = torch.cat([img1, img2], dim=1)
cat1
"""
tensor([[[193, 210, 209],
[160, 213, 52],
[104, 155, 197],
[ 10, 225, 143]],
[[ 22, 169, 202],
[164, 181, 17],
[237, 8, 60],
[229, 81, 124]],
[[175, 240, 109],
[ 0, 161, 185],
[144, 193, 159],
[185, 1, 67]]], dtype=torch.uint8)
"""
cat1.size()
"""
torch.Size([3, 4, 3])
"""
2) torch.stack (dim=1)
stack1 = torch.stack([img1, img2], dim=1)
stack1
"""
tensor([[[[193, 210, 209],
[160, 213, 52]],
[[104, 155, 197],
[ 10, 225, 143]]],
[[[ 22, 169, 202],
[164, 181, 17]],
[[237, 8, 60],
[229, 81, 124]]],
[[[175, 240, 109],
[ 0, 161, 185]],
[[144, 193, 159],
[185, 1, 67]]]], dtype=torch.uint8)
"""
stack1.size()
"""
torch.Size([3, 2, 2, 3])
"""
1) torch.cat (dim=2)
cat2 = torch.cat([img1, img2], dim=2)
cat2
"""
tensor([[[193, 210, 209, 104, 155, 197],
[160, 213, 52, 10, 225, 143]],
[[ 22, 169, 202, 237, 8, 60],
[164, 181, 17, 229, 81, 124]],
[[175, 240, 109, 144, 193, 159],
[ 0, 161, 185, 185, 1, 67]]], dtype=torch.uint8)
"""
cat2.size()
"""
torch.Size([3, 2, 6])
"""
2) torch.stack (dim=2)
stack2 = torch.stack([img1, img2], dim=2)
stack2
"""
tensor([[[[193, 210, 209],
[104, 155, 197]],
[[160, 213, 52],
[ 10, 225, 143]]],
[[[ 22, 169, 202],
[237, 8, 60]],
[[164, 181, 17],
[229, 81, 124]]],
[[[175, 240, 109],
[144, 193, 159]],
[[ 0, 161, 185],
[185, 1, 67]]]], dtype=torch.uint8)
"""
stack2.size()
"""
torch.Size([3, 2, 2, 3])
"""
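One way to relate the two operations compared above: stacking along `dim` gives the same result as first adding a size-1 dimension to each input with `unsqueeze(dim)` and then concatenating along that dimension. The check below is a small sketch using freshly generated tensors; the names `x` and `y` are illustrative.

```python
import torch

x = torch.randint(0, 256, size=(3, 2, 3), dtype=torch.uint8)
y = torch.randint(0, 256, size=(3, 2, 3), dtype=torch.uint8)

# stack(dim) == cat of the inputs after unsqueeze(dim), for every valid dim.
for dim in range(4):
    stacked = torch.stack([x, y], dim=dim)
    via_cat = torch.cat([x.unsqueeze(dim), y.unsqueeze(dim)], dim=dim)
    assert torch.equal(stacked, via_cat)

# After stack, indexing the new dimension returns each input unchanged.
stacked1 = torch.stack([x, y], dim=1)
assert torch.equal(stacked1[:, 0], x)
assert torch.equal(stacked1[:, 1], y)

# After cat, the inputs can only be recovered by slicing with known sizes.
cat1 = torch.cat([x, y], dim=1)
assert torch.equal(cat1[:, :2], x)
assert torch.equal(cat1[:, 2:], y)
```

This is the 'separated' versus 'soaked in' distinction in concrete terms: stack keeps an index that identifies each input, while cat only keeps the values.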
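In InceptionNet, the filter concatenation step uses cat to merge the outputs of several different filters into a single feature map along the channel dimension:
def forward(self, x):
    x = self.conv1(x)
    # Run parallel branches on the same input, then join their outputs
    # along the channel dimension (dim=1 for an (N, C, H, W) batch).
    x = [self.branch3x3_pool(x),
         self.branch3x3_conv(x)]
    x = torch.cat(x, 1)
    x = [self.branch7x7a(x),
         self.branch7x7b(x)]
    x = torch.cat(x, 1)
    x = [self.branchpoola(x),
         self.branchpoolb(x)]
    x = torch.cat(x, 1)
    return x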
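To make the channel-wise concatenation concrete, here is a minimal runnable sketch. `TinyBranchBlock` and its two branches are hypothetical and much simpler than the real InceptionNet layers; the layer sizes are assumptions chosen only so that the spatial size is preserved.

```python
import torch
import torch.nn as nn


class TinyBranchBlock(nn.Module):
    """Hypothetical two-branch block, not the actual InceptionNet layers."""

    def __init__(self, in_channels):
        super().__init__()
        # Both branches keep H and W unchanged (3x3 kernel, stride 1, padding 1).
        self.branch_conv = nn.Conv2d(in_channels, 8, kernel_size=3, padding=1)
        self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        outs = [self.branch_conv(x), self.branch_pool(x)]
        # Join along channels: (N, 8, H, W) and (N, 3, H, W) -> (N, 11, H, W).
        return torch.cat(outs, dim=1)


block = TinyBranchBlock(in_channels=3)
batch = torch.randn(4, 3, 16, 16)  # (N, C, H, W)
out = block(batch)
print(out.shape)  # torch.Size([4, 11, 16, 16])
```

cat fits here because the branches produce different channel counts (8 and 3); torch.stack requires all inputs to have exactly the same shape, so it would raise an error on these outputs.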
References
https://gaussian37.github.io/dl-pytorch-snippets/
https://discuss.pytorch.kr/t/torch-cat-torch-stack/26
https://masterzone.tistory.com/36