[Rasberry Pi 4] Tensorflow lite MobileBert 사용하기

김지원·2022년 11월 2일
0

RasberryPi4

목록 보기
3/4

BERT Question Answer with TensorFlow Lite Model Maker

https://github.com/google-research/bert#tokenization
https://github.com/gemde001/MobileBERT

https://stackoverflow.com/questions/59759522/mobilebert-from-tensorflow-lite-in-python

💡 MobileBert Tensorflow Lite 모델 다운

tfhub 사이트 에서 아래 TFLite(v1, metadata) 를 다운 받았다.

💡 Rasberrypi 4 Tensorflow 설치하기

Install tensorflow==2.4.0-rc2 with pip (requires Python 3.7, Raspberry Pi 4)

## python version 3.7.3

$ pip install https://github.com/bitsy-ai/tensorflow-arm-bin/releases/downl

Requirements

!pip install tflite_runtime
!pip install transformers
!pip install bert-for-tf2

💡 참고한 코드

Github 페이지
Python API for interacting with pre-trained tflite BERT model provided by Tensorflow

BertQuestionAnswerer API 주요 특징

  • 두 개의 텍스트 입력을 1. question 및 2. context로 받아서 가능한 답변 목록을 출력
  • 입력 텍스트에서 그래프 외 Wordpiece 또는 Sentencepiece 토큰화를 수행합니다.

지원되는 BertQuestionAnswerer 모델

다음 모델은 BertNLClassifier API와 호환됩니다.

  • 질문 응답기를 위한 TensorFlow Lite Model Maker에서 만든 모델
  • TensorFlow Hub에서 사전 훈련된 BERT 모델
  • 모델 호환성 요구 사항을 충족하는 사용자 정의 모델

모델 호환성 요구 사항

BertQuestionAnswerer API는 필수 TFLite 모델 메타데이터가 있는 TFLite 모델을 예상합니다.메타 데이터는 다음 요구 사항을 충족해야 합니다.

  • Wordpiece/Sentencepiece Tokenizer를 위한 input_process_units
  • Tokenizer의 출력을 위한 이름이 "ids", "mask" 및 "segment_ids"인 3개의 입력 텐서
  • 컨텍스트에서 답변의 상대적 위치를 나타내는 이름이 "end_logits"및 "start_logits"인 2개의 출력 텐서
import numpy as np
import tensorflow as tf

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="mobilebert_float_20191023.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)

input_data = np.array(np.random.random_sample(input_shape), dtype=np.int32)

interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
profile
Make your lives Extraordinary!

0개의 댓글