torch.max / dim 역할

jaeha_lee·2021년 2월 23일
1

torch.max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor)

input의 형태가 [A,B,C,D]라고 할 때
dim=n 이라고 하면 n번째를 제외한 output이 나오게 되고
ex) dim = 2, C가 빠져 [A,B,D]가 나오게 된다.
이 C에 해당하는 데이터를 기준으로 최대값 및 인덱스가 튜플로 나오게 된다.

참고로 예시에서 [3,-1,5] 이 부분은 -1에 해당하는 부분은 차원이 맞게 끔 알아서 조절된다.
input이 [3,5,2,2] 였기 때문에 -1에은 4가 들어가게 된다.

Code :

import torch
aa = torch.randint(0,5,[3,5,2,2])

# print(aa)

bb = torch.reshape(aa,[3,-1,5])
print("INPUT : ")
print(bb)
c = torch.max(bb,dim=2)
print("dim=2")
print(c)
d = torch.max(bb,dim=1)
print("dim=1")
print(d)
d = torch.max(bb,dim=0)
print("dim=0")
print(d)
d = torch.max(bb,dim=-1)
print("dim= -1")
print(d)

Output :

INPUT : 
tensor([[[3, 4, 2, 0, 3],
         [0, 4, 0, 3, 0],
         [2, 0, 4, 3, 3],
         [1, 3, 3, 0, 4]],

        [[0, 3, 4, 2, 0],
         [0, 3, 4, 0, 2],
         [4, 0, 3, 2, 2],
         [3, 1, 1, 0, 1]],

        [[3, 4, 3, 0, 2],
         [2, 4, 1, 3, 1],
         [4, 0, 1, 1, 4],
         [4, 4, 1, 1, 0]]])
dim=2
torch.return_types.max(
values=tensor([[4, 4, 4, 4],
        [4, 4, 4, 3],
        [4, 4, 4, 4]]),
indices=tensor([[1, 1, 2, 4],
        [2, 2, 0, 0],
        [1, 1, 0, 0]]))
dim=1
torch.return_types.max(
values=tensor([[3, 4, 4, 3, 4],
        [4, 3, 4, 2, 2],
        [4, 4, 3, 3, 4]]),
indices=tensor([[0, 0, 2, 1, 3],
        [2, 0, 0, 0, 1],
        [2, 0, 0, 1, 2]]))
dim=0
torch.return_types.max(
values=tensor([[3, 4, 4, 2, 3],
        [2, 4, 4, 3, 2],
        [4, 0, 4, 3, 4],
        [4, 4, 3, 1, 4]]),
indices=tensor([[0, 0, 1, 1, 0],
        [2, 0, 1, 0, 1],
        [1, 0, 0, 0, 2],
        [2, 2, 0, 2, 0]]))
dim= -1
torch.return_types.max(
values=tensor([[4, 4, 4, 4],
        [4, 4, 4, 3],
        [4, 4, 4, 4]]),
indices=tensor([[1, 1, 2, 4],
        [2, 2, 0, 0],
        [1, 1, 0, 0]]))

0개의 댓글