[알게된 것] ImportError: cannot import name 'shape_poly' from 'jax.experimental.jax2tf'

Chobby·2024년 1월 7일
1

😚문제상황

tensorflow 모델을 tensorflow-js 모델로 변환하는 과정에서

ImportError: cannot import name 'shape_poly' from 'jax.experimental.jax2tf' (/usr/local/lib/python3.10/dist-packages/jax/experimental/jax2tf/__init__.py)

에러가 발생함

😎해결방법

tensorflowjs_converter가 JAX 라이브러리의 shape_poly를 찾을 수 없어 발생한 것
JAX는 최근에 새로운 버전으로 업데이트되었고, 이 업데이트로 인해 shape_poly 모듈이 변경되거나 삭제되었음
tensorflowjs가 JAX의 이전 버전에 의존하고 있으므로, 이 문제를 해결하기 위해 JAX를 이전 버전으로 다운그레이드 해야하며,
다음과 같이 특정 버전의 JAX를 설치해 해결하였음.

!pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html

참고로, tensorflowjs_converter는 window에서 지원되지 않으니 반드시 mac, linux 환경에서 실행하세요. (colab 활용 추천)

profile
내 지식을 공유할 수 있는 대담함

0개의 댓글