torch.argmax() & torch.max()

Jarvis_Geun·2021년 8월 26일
2

Pytorch

목록 보기
1/1
post-thumbnail

torch.argmax()

Pytorch docs : torch.argmax()

import torch

# dim 파라미터를 추가하지 않은 경우
a = torch.randn(4, 4)
argmax = torch.argmax(a)
print(a)
print(argmax)
# 출력결과
tensor([[ 0.5478,  0.6609,  0.6214,  1.0270],
        [-0.0406, -1.1121,  0.7632,  0.3169],
        [-1.2267,  1.6110, -0.3082, -0.4010],
        [-0.9153,  0.1607, -1.0582,  0.8812]])
tensor(9)

위의 결과를 보면 알 수 있듯이, 1행 → 4행의 순서로 큰 숫자를 파악하여 tensor를 출력한다.

# dim = 0으로 할 경우
import torch

a = torch.randn(5, 3)
argmax_dim_0 = torch.argmax(a, dim = 0)
print(a)
print(argmax_dim_0)
# 출력결과
tensor([[ 0.7142,  0.4088,  0.1844],
        [-0.9138, -1.2394,  0.1528],
        [ 0.5173,  0.5606, -0.4164],
        [ 1.7521,  0.4822,  0.6333],
        [ 0.4563, -0.3746,  0.7837]])
tensor([3, 2, 4])

dim = 0의 파라미터를 추가할 경우, 결과는 위와 같다. 열을 기준으로 max 값이 있는 index를 출력한다.

# dim = 1으로 할 경우
argmax_dim_1 = torch.argmax(a, dim = 1)
print(a)
print(argmax_dim_1)
tensor([[ 0.7142,  0.4088,  0.1844],
        [-0.9138, -1.2394,  0.1528],
        [ 0.5173,  0.5606, -0.4164],
        [ 1.7521,  0.4822,  0.6333],
        [ 0.4563, -0.3746,  0.7837]])
tensor([0, 2, 1, 0, 2])

이전의 결과에서 dim = 0 → dim = 1로 변경할 경우의 결과이다. dim을 파라미터로 입력할 경우, 차원에 따라 값이 달라지는 것을 확인할 수 있다. 행을 기준으로 max 값이 있는 index를 출력한다.


torch.max()

Pytorch docs : torch.max()

import torch

a = torch.randn(5, 3)
argmax_dim_0 = torch.max(a, dim = 0)
print(a)
print(argmax_dim_0)
tensor([[-1.1839, -0.2547, -1.2708],
        [ 0.6713,  0.0494, -0.2655],
        [-0.3116,  0.7832, -1.2234],
        [-0.6373,  0.1265, -1.8030],
        [-0.1249, -1.7748,  1.2657]])
torch.return_types.max(
values=tensor([0.6713, 0.7832, 1.2657]),
indices=tensor([1, 2, 4]))

torch.max()의 경우 최댓값과 인덱스 모두 출력해준다. torch.argmax()와 마찬가지로 dim = 0이면, 열을 기준(각 열마다)으로 최댓값과 인덱스를 출력해준다.

argmax_dim_1 = torch.max(a, dim = 1)
print(a)
print(argmax_dim_1)
tensor([[-1.1839, -0.2547, -1.2708],
        [ 0.6713,  0.0494, -0.2655],
        [-0.3116,  0.7832, -1.2234],
        [-0.6373,  0.1265, -1.8030],
        [-0.1249, -1.7748,  1.2657]])
torch.return_types.max(
values=tensor([-0.2547,  0.6713,  0.7832,  0.1265,  1.2657]),
indices=tensor([1, 0, 1, 1, 2]))

dim = 1이면, 행을 기준(각 행마다)으로 최댓값과 인덱스를 동시에 출력해준다.

profile
Although. 그럼에도 불구하고.

0개의 댓글