🖥️ PyTorch MPS 연산 시 주의사항 정리
🔹 들어가며
Apple Silicon(M1/M2) 환경에서 PyTorch는 MPS(Metal Performance Shaders) 백엔드를 통해 GPU 가속을 지원합니다.
특히 3D 연산(Conv3d, ConvTranspose3d 등)과 Core ML 연동을 고려하는 경우, 버전별 지원 범위와 제약 사항을 반드시 확인해야 합니다.
아래는 PyTorch(MPS) 사용 시 주의할 점을 정리한 가이드입니다.
🔹 Conv3d 지원 현황 (버전별)
-
PyTorch 2.1~2.2 전후 → Conv3d의 초기 지원 시작
-
PyTorch 2.3~2.5 → 지원 범위 확장, 안정성 개선
-
PyTorch 2.7~2.8 → fp32 forward는 대체로 정상 동작.
- 단, AMP/autocast,
bfloat16/float16 dtype, 특정 kernel/stride 조합, backward 연산 등은 여전히 제약이 있을 수 있음.
👉 따라서 “정확히 몇 버전부터 완벽 지원”이라고 단정하기보다는,
2.2 근처에서 초기 지원, 2.4~2.8로 오며 안정화라고 이해하는 게 현실적입니다.
🔹 현실적인 권장사항
-
직접 테스트
- 현재 사용 중인 PyTorch/LibTorch 버전, macOS 버전, Metal 드라이버, dtype, 배치 크기, AMP 여부 등 조건을 동일하게 맞춰 간단한 Conv3d 모듈을 실행해 보는 것이 가장 확실합니다.
-
LibTorch 배포 시
- Python(MPS)에서 최소 재현 코드 테스트
- TorchScript export
- 동일 시스템에서 LibTorch 실행 검증
-
AMP/autocast 활용 시
- 먼저 fp32에서 확인 → 이후 AMP를 단계적으로 켜며 문제 지점 파악
🔹 ConvTranspose3d 및 기타 3D 연산
- Conv3d보다 지원 도입 시기가 늦고 제약이 많았던 이력이 있음.
- 따라서 UNet3D 같은 모델을 쓸 경우 ConvTranspose3d 별도 검증 필수.
🔹 복소수(complex) 연산 주의
- MPS와 Core ML 모두 복소수 타입을 직접 지원하지 않음.
torch.complex64 등을 사용하면 MPS 백엔드에서 에러 발생.
- Core ML 변환 시에도 complex dtype이 포함되면 변환 자체가 실패.
👉 해결 방법
- 실수/허수 채널 분리 →
[real, imag] 채널로 분리 후 수학적으로 풀어 작성
- Magnitude/Phase 변환 → FFT 결과를 magnitude/phase 실수 피처로 변환
🔹 요약 가이드
- Conv3d: PyTorch 2.2 전후부터 지원, 2.4~2.8에서 안정성 ↑
- LibTorch: TorchScript 기반으로 MPS 테스트 후 배포 권장
- AMP/autocast: 단계적으로 확인 필요
- ConvTranspose3d: Conv3d보다 제약 많음 → 반드시 별도 검증
- Complex dtype: MPS/Core ML 모두 미지원 → 실수 변환 필요
✨ 결론
PyTorch MPS는 Apple Silicon 환경에서 GPU 가속의 중요한 옵션이지만,
아직까지는 버전/연산/조건별 제약이 남아 있으므로 “직접 테스트 후 사용하는 것”이 최선입니다.
특히 3D 연산(Conv3d, ConvTranspose3d)과 복소수 연산은 반드시 사전 검증이 필요합니다.