device를 mps로 주면 다음 에러가 발생함:
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
발생 원인: 현재 MPS 백엔드는 float64를 지원하지 않음 (float32를 사용해야 함). 따라서 device가 mps일 때 float64 관측값을 float32로 변환해주면 됨.
obs_as_tensor 함수에서 device가 mps일 때 아래 코드를 추가하면 됨:
파일 위치: site-packages/stable_baselines3/common/utils.py > obs_as_tensor
# Lines to add at the start of obs_as_tensor's body.
# MPS rejects float64 tensors, so float64 arrays are downcast to float32 first.
# A dict observation is handled value by value, since a dict itself has no
# .astype()/.float() method.
if device.type == 'mps':
    if isinstance(obs, np.ndarray) and obs.dtype == np.float64:
        obs = obs.astype(np.float32)
    elif isinstance(obs, dict):
        # Build a new dict rather than mutating the caller's observation.
        obs = {
            key: _obs.astype(np.float32) if isinstance(_obs, np.ndarray) and _obs.dtype == np.float64 else _obs
            for key, _obs in obs.items()
        }
def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, "TensorDict"]:
    """
    Moves the observation to the given device.

    The MPS backend cannot store float64 tensors ("MPS framework doesn't
    support float64"). So when ``device`` is an MPS device, any float64 array
    is downcast to float32 before being moved. This applies to the
    observation itself or to each value of a dict observation. Other dtypes,
    and all observations on non-MPS devices, go to ``th.as_tensor`` unchanged.

    :param obs: a numpy array, or a dict of numpy arrays
    :param device: PyTorch device
    :return: PyTorch tensor of the observation on a desired device
        (a new dict of tensors with the same keys when ``obs`` is a dict).
    :raises Exception: if ``obs`` is neither a numpy array nor a dict.
    """

    def _to_tensor(array: np.ndarray) -> th.Tensor:
        # Only float64 is downcast, since that is the dtype MPS rejects.
        # Leaving other dtypes alone avoids, e.g., widening uint8 images
        # to float32 (4x the size).
        if device.type == "mps" and isinstance(array, np.ndarray) and array.dtype == np.float64:
            array = array.astype(np.float32)
        return th.as_tensor(array, device=device)

    if isinstance(obs, np.ndarray):
        return _to_tensor(obs)
    elif isinstance(obs, dict):
        # Convert value by value: the previous MPS patch called .float() on the
        # dict itself, which raised AttributeError for every dict observation.
        return {key: _to_tensor(_obs) for (key, _obs) in obs.items()}
    else:
        raise Exception(f"Unrecognized type of observation {type(obs)}")