input의 형태가 [A,B,C,D]라고 할 때
dim=n 이라고 하면 n번째를 제외한 output이 나오게 되고
ex) dim = 2, C가 빠져 [A,B,D]가 나오게 된다.
이 C에 해당하는 데이터를 기준으로 최대값 및 인덱스가 튜플로 나오게 된다.
참고로 예시에서 [3,-1,5] 이 부분은 -1에 해당하는 부분은 차원이 맞게 끔 알아서 조절된다.
input이 [3,5,2,2] 였기 때문에 -1에은 4가 들어가게 된다.
Code :
import torch
aa = torch.randint(0,5,[3,5,2,2])
# print(aa)
bb = torch.reshape(aa,[3,-1,5])
print("INPUT : ")
print(bb)
c = torch.max(bb,dim=2)
print("dim=2")
print(c)
d = torch.max(bb,dim=1)
print("dim=1")
print(d)
d = torch.max(bb,dim=0)
print("dim=0")
print(d)
d = torch.max(bb,dim=-1)
print("dim= -1")
print(d)
Output :
INPUT :
tensor([[[3, 4, 2, 0, 3],
[0, 4, 0, 3, 0],
[2, 0, 4, 3, 3],
[1, 3, 3, 0, 4]],
[[0, 3, 4, 2, 0],
[0, 3, 4, 0, 2],
[4, 0, 3, 2, 2],
[3, 1, 1, 0, 1]],
[[3, 4, 3, 0, 2],
[2, 4, 1, 3, 1],
[4, 0, 1, 1, 4],
[4, 4, 1, 1, 0]]])
dim=2
torch.return_types.max(
values=tensor([[4, 4, 4, 4],
[4, 4, 4, 3],
[4, 4, 4, 4]]),
indices=tensor([[1, 1, 2, 4],
[2, 2, 0, 0],
[1, 1, 0, 0]]))
dim=1
torch.return_types.max(
values=tensor([[3, 4, 4, 3, 4],
[4, 3, 4, 2, 2],
[4, 4, 3, 3, 4]]),
indices=tensor([[0, 0, 2, 1, 3],
[2, 0, 0, 0, 1],
[2, 0, 0, 1, 2]]))
dim=0
torch.return_types.max(
values=tensor([[3, 4, 4, 2, 3],
[2, 4, 4, 3, 2],
[4, 0, 4, 3, 4],
[4, 4, 3, 1, 4]]),
indices=tensor([[0, 0, 1, 1, 0],
[2, 0, 1, 0, 1],
[1, 0, 0, 0, 2],
[2, 2, 0, 2, 0]]))
dim= -1
torch.return_types.max(
values=tensor([[4, 4, 4, 4],
[4, 4, 4, 3],
[4, 4, 4, 4]]),
indices=tensor([[1, 1, 2, 4],
[2, 2, 0, 0],
[1, 1, 0, 0]]))