[DL] Mac M1 PyTorch GPU 가속

Jungwoo Kim·2024년 1월 25일
0

ML/DL

목록 보기
5/5
! conda install pytorch -c pytorch-nightly

import torch

print(torch.backends.mps.is_built())
print(torch.backends.mps.is_available())

TRUE로 출력시 가속 가능

import torch

mps_device = torch.device("mps")

x = torch.ones(5, device=mps_device)
# or
x = torch.ones(5, device="mps")

# GPU 상에서 연산
y = x * 2

# 또는, 다른 장치와 마찬가지로 MPS로 이동 가능
model = YourFavoriteNet()  # 어떤 모델의 객체를 생성한 뒤에
model.to(mps_device)       # MPS 장치로 이동

# 이제 모델과 텐서를 호출하면 GPU에서 연산
pred = model(x)

0개의 댓글