tensorflow
모델을 tensorflow-js
모델로 변환하는 과정에서
ImportError: cannot import name 'shape_poly' from 'jax.experimental.jax2tf' (/usr/local/lib/python3.10/dist-packages/jax/experimental/jax2tf/__init__.py)
에러가 발생함
tensorflowjs_converter
가 JAX 라이브러리의 shape_poly를 찾을 수 없어 발생한 것
JAX는 최근에 새로운 버전으로 업데이트되었고, 이 업데이트로 인해 shape_poly 모듈이 변경되거나 삭제되었음
tensorflowjs가 JAX의 이전 버전에 의존하고 있으므로, 이 문제를 해결하기 위해 JAX를 이전 버전으로 다운그레이드 해야하며,
다음과 같이 특정 버전의 JAX를 설치해 해결하였음.
!pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html
참고로, tensorflowjs_converter는 window에서 지원되지 않으니 반드시 mac, linux 환경에서 실행하세요. (colab 활용 추천)