Stable Baseline 3 MPS Error

Hansol Kang·2024년 7월 4일
0

SB3

목록 보기
1/1
post-thumbnail

MPS에서 float 64 아직 미지원

  • device를 mps로 주면 에러 발생

에러 내용

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

발생하는 원인은 현재 MPS에서는 float 32만 지원함. 따라서 mps일  floast 32로 변환해주면 됨
obs_as_tesnsor 함수에서 msp일 때 아래 줄 추가하면 됨

site-packates/stable_baseline3/common/utils.py > obs_as_tensor

if device.type == 'mps':
	obs = obs.astype('float32') if isinstance(obs, np.ndarray) else obs.float()

최종본

def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]:
"""
Moves the observation to the given device.
:param obs:
:param device: PyTorch device
:return: PyTorch tensor of the observation on a desired device.
"""
if device.type == 'mps':
	obs = obs.astype('float32') if isinstance(obs, np.ndarray) else obs.float()
if isinstance(obs, np.ndarray):
	return th.as_tensor(obs, device=device)
elif isinstance(obs, dict):
	return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()}
else:
	raise Exception(f"Unrecognized type of observation {type(obs)}")
profile
놀면 뭐하니 정리해

0개의 댓글

관련 채용 정보