아래 내용은 pytorch 2.1.0 버전으로 작성됨.
0. 바로 사용하기
torch.topk(input, k)
input
: 가장 큰 k개의 값과 index를 반환할 tensor
k
: 반환할 상위 요소의 갯수
1. 기본형
torch.topk(input,
k,
dim=None,
largest=True,
sorted=True,
out=None)
2. 기능
- 입력 받은 tensor의 상위 k개의 index와 value를 반환함.
import torch
x = torch.tensor([1, 2, 3, 4, 5])
values, indices = torch.topk(x, k=3)
print("Values:", values)
print("Indices:", indices)
3. 파라미터
- 상위 k개의 index, value를 반환할 tensor
k
k = integer
: 반환할 상위 요소의 갯수
dim
dim = integer
: 상위 요소를 찾을 차원값
largest
largest = True
: 가장 큰 k개의 값을 반환
largest = False
: 가장 작은 k개의 값을 반환
sorted
sorted = True
: 값을 내림차순으로 정렬하여 반환함.
sorted = False
: 값을 내림차순으로 정렬 하지 않고 반환함.
ref