import torch
# dim 파라미터를 추가하지 않은 경우
a = torch.randn(4, 4)
argmax = torch.argmax(a)
print(a)
print(argmax)
# 출력결과
tensor([[ 0.5478, 0.6609, 0.6214, 1.0270],
[-0.0406, -1.1121, 0.7632, 0.3169],
[-1.2267, 1.6110, -0.3082, -0.4010],
[-0.9153, 0.1607, -1.0582, 0.8812]])
tensor(9)
위의 결과를 보면 알 수 있듯이, 1행 → 4행의 순서로 큰 숫자를 파악하여 tensor를 출력한다.
# dim = 0으로 할 경우
import torch
a = torch.randn(5, 3)
argmax_dim_0 = torch.argmax(a, dim = 0)
print(a)
print(argmax_dim_0)
# 출력결과
tensor([[ 0.7142, 0.4088, 0.1844],
[-0.9138, -1.2394, 0.1528],
[ 0.5173, 0.5606, -0.4164],
[ 1.7521, 0.4822, 0.6333],
[ 0.4563, -0.3746, 0.7837]])
tensor([3, 2, 4])
dim = 0의 파라미터를 추가할 경우, 결과는 위와 같다. 열을 기준으로 max 값이 있는 index를 출력한다.
# dim = 1으로 할 경우
argmax_dim_1 = torch.argmax(a, dim = 1)
print(a)
print(argmax_dim_1)
tensor([[ 0.7142, 0.4088, 0.1844],
[-0.9138, -1.2394, 0.1528],
[ 0.5173, 0.5606, -0.4164],
[ 1.7521, 0.4822, 0.6333],
[ 0.4563, -0.3746, 0.7837]])
tensor([0, 2, 1, 0, 2])
이전의 결과에서 dim = 0 → dim = 1로 변경할 경우의 결과이다. dim을 파라미터로 입력할 경우, 차원에 따라 값이 달라지는 것을 확인할 수 있다. 행을 기준으로 max 값이 있는 index를 출력한다.
import torch
a = torch.randn(5, 3)
argmax_dim_0 = torch.max(a, dim = 0)
print(a)
print(argmax_dim_0)
tensor([[-1.1839, -0.2547, -1.2708],
[ 0.6713, 0.0494, -0.2655],
[-0.3116, 0.7832, -1.2234],
[-0.6373, 0.1265, -1.8030],
[-0.1249, -1.7748, 1.2657]])
torch.return_types.max(
values=tensor([0.6713, 0.7832, 1.2657]),
indices=tensor([1, 2, 4]))
torch.max()의 경우 최댓값과 인덱스 모두 출력해준다. torch.argmax()와 마찬가지로 dim = 0이면, 열을 기준(각 열마다)으로 최댓값과 인덱스를 출력해준다.
argmax_dim_1 = torch.max(a, dim = 1)
print(a)
print(argmax_dim_1)
tensor([[-1.1839, -0.2547, -1.2708],
[ 0.6713, 0.0494, -0.2655],
[-0.3116, 0.7832, -1.2234],
[-0.6373, 0.1265, -1.8030],
[-0.1249, -1.7748, 1.2657]])
torch.return_types.max(
values=tensor([-0.2547, 0.6713, 0.7832, 0.1265, 1.2657]),
indices=tensor([1, 0, 1, 1, 2]))
dim = 1이면, 행을 기준(각 행마다)으로 최댓값과 인덱스를 동시에 출력해준다.