✏️ 공부하며 작성하므로 개선될 부분이 있다면 댓글로 알려주세요!
이번에는 텐서플로 라이트를 활용하여 안드로이드 앱에서 손글씨 숫자를 그리고 해당 숫자가 어떤 숫자인지 인식하는 공부를 해보려고 합니다.
_먼저 텐서플로우와 호환되는 CUDA와 cuDNN가 설치 되어 있다고 가정합니다. 환경 구축은 이 시리즈의 환경 구축 글을 참고해주세요!
개발 환경
python = 3.8
tensorflow = 2.10.0
minSDK = 28
compileSDK = 34
targetSDK = 34
kaggle 같은 곳을 가면 아주 잘 학습되어 정확도가 높은 모델들이 많습니다. 이런걸 가져다가 사용해도 라이선스에만 맞게 사용한다면 문제가 되지 않습니다. 그러나 저는 공부하는 입장이기 때문에 이번에는 모델을 직접 개발해보려고 합니다. 그러나 데이터셋을 지금부터 모으기에는 한계가 명확하므로 MNIST 데이터셋을 사용하겠습니다.
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
MNIST는 손글씨 데이터셋으로 위키피디아 피셜로 위와 같은 데이터들이 있다고 합니다.
파이썬으로 위 코드를 실행하면 사용자 폴더의 .keras/datasets
폴더가 생성되고 그 안에 데이터셋 파일이 생성됩니다.
이번에 만들 모델의 입력 데이터는 28X28 이미지이고 각 픽셀 값은 0~255 입니다. 이 입력 데이터를 0~1의 범위의 값으로 정규화 할겁니다. 정규화를 한다면 약 5%의 성능 향상에 도움이 되므로 정규화는 꼭 해야할 작업이 됩니다.
손글씨 숫자를 추론하기 위해 가장 단순한 인공 신경망이라고 하는 MLP(다층 퍼셉트론)을 사용해도 되지만 CNN(합성곱 신경망)을 적용하려고 합니다.
MLP는 입력층, 은닉층, 출력층으로 구성되는 인공 신경망이고 CNN은 합성곱층과 폴링층으로 구성됩니다. CNN은 지역성에 기반한 특성을 갖고 있으며 가중치가 공유되어 MLP에 비해 파라미터가 적은 특성으로 이미지 데이터에 적합합니다.
MLP 모델 학습
CNN 모델 학습
훈련된 모델의 성능은 MLP에 비해 CNN이 높은 정확도를 보입니다.
CNN 손글씨 모델 파이썬 코드
import tensorflow as tf
# CNN 손글씨 분류 모델 클래스
class CNN_Model(tf.keras.Model):
def __init__(self):
super(CNN_Model, self).__init__()
self.conv_2d_32 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1))
self.max_pooling_2d = tf.keras.layers.MaxPooling2D((2, 2))
self.conv_2d_64_1 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu')
self.conv_2d_64_2 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu')
self.flatten = tf.keras.layers.Flatten()
self.dense_64 = tf.keras.layers.Dense(64, activation='relu')
self.softmax = tf.keras.layers.Dense(10, activation='softmax')
def call(self, inputs, training=None, mask=None):
x = self.conv_2d_32(inputs)
x = self.max_pooling_2d(x)
x = self.conv_2d_64_1(x)
x = self.max_pooling_2d(x)
x = self.conv_2d_64_2(x)
x = self.flatten(x)
x = self.dense_64(x)
return self.softmax(x)
# MNIST 데이터셋 불러오기
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 정규화
x_train, x_test = x_train / 255.0, x_test / 255.0
# 높이, 너비, 채널의 3차원 텐서를 사용하기 위한 변환
x_train_4d = x_train.reshape(-1, 28, 28, 1)
x_test_4d = x_test.reshape(-1, 28, 28, 1)
# 모델 객체 인스턴스 생성
cnn_model = CNN_Model()
# 모델 컴파일
cnn_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 모델 빌드
cnn_model.build(input_shape=(None, 28, 28, 1))
# 모델 요약
cnn_model.summary()
# 모델 훈련
cnn_model.fit(x_train_4d, y_train, epochs=5)
# 모델 테스트
eval_result = cnn_model.evaluate(x_test_4d, y_test, verbose=2)
print(eval_result)
tf.keras.model 클래스를 상속받아 객체지향 구조로 모델을 구현했습니다.
텐서플로에서는 모델을 SavedModel로 저장하고 불러오는 것을 권장한다고합니다.
텐서플로 라이트의 컨버터를 통해 모델을 간편하게 변환할 수 있습니다.
모델 저장 및 변환 파이썬 코드
# CNN 모델 저장
cnn_model.save("./cnn_model/")
# 모델 불러오기
saved_model = tf.keras.models.load_model("./cnn_model/")
# tflite 모델 변환
converter = tf.lite.TFLiteConverter.from_saved_model("./cnn_model/")
tflite_model = converter.convert()
# tflite 모델 저장
with open("./saved_cnn_model.tflite", "wb") as f:
f.write(tflite_model)
이렇게 tflite 파일이 저장됩니다.
파이썬 파일 전체 코드
import tensorflow as tf
# CNN 손글씨 분류 모델 클래스
class CNN_Model(tf.keras.Model):
def __init__(self):
super(CNN_Model, self).__init__()
self.conv_2d_32 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1))
self.max_pooling_2d = tf.keras.layers.MaxPooling2D((2, 2))
self.conv_2d_64_1 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu')
self.conv_2d_64_2 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu')
self.flatten = tf.keras.layers.Flatten()
self.dense_64 = tf.keras.layers.Dense(64, activation='relu')
self.softmax = tf.keras.layers.Dense(10, activation='softmax')
def call(self, inputs, training=None, mask=None):
x = self.conv_2d_32(inputs)
x = self.max_pooling_2d(x)
x = self.conv_2d_64_1(x)
x = self.max_pooling_2d(x)
x = self.conv_2d_64_2(x)
x = self.flatten(x)
x = self.dense_64(x)
return self.softmax(x)
# MNIST 데이터셋 불러오기
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 정규화
x_train, x_test = x_train / 255.0, x_test / 255.0
# 높이, 너비, 채널의 3차원 텐서를 사용하기 위한 변환
x_train_4d = x_train.reshape(-1, 28, 28, 1)
x_test_4d = x_test.reshape(-1, 28, 28, 1)
# 모델 객체 인스턴스 생성
cnn_model = CNN_Model()
# 모델 컴파일
cnn_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 모델 빌드
cnn_model.build(input_shape=(None, 28, 28, 1))
# 모델 요약
cnn_model.summary()
# 모델 훈련
cnn_model.fit(x_train_4d, y_train, epochs=5)
# 모델 테스트
eval_result = cnn_model.evaluate(x_test_4d, y_test, verbose=2)
print(eval_result)
# CNN 모델 저장
cnn_model.save("./cnn_model/")
# 모델 불러오기
saved_model = tf.keras.models.load_model("./cnn_model/")
# tflite 모델 변환
converter = tf.lite.TFLiteConverter.from_saved_model("./cnn_model/")
tflite_model = converter.convert()
# tflite 모델 저장
with open("./saved_cnn_model.tflite", "wb") as f:
f.write(tflite_model)
안드로이드에서 사용 가능한 tflite 파일을 변환하였으니 이제 손글씨를 그릴 수 있는 뷰를 생성하고 모델을 사용하면 됩니다. 뷰는 AndroidDrawView를 사용합니다.
안드로이드에서 간단하게 터치하여 그릴 수 있도록 만들어주는 뷰입니다.
gradle 8.0 기준으로 작성됩니다.
settings.gradle
dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
mavenCentral()
maven { url 'https://jitpack.io' }
}
}
build.gradle(Module :app)
implementation 'com.github.divyanshub024:AndroidDraw:v0.1'
gradle.properties
android.enableJetifier=true
gradle sync를 누르고 제대로 실행되는지 체크 후 레이아웃 작성
저는 메인 액티비티에서 작업했습니다.
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<Button
android:id="@+id/btn_classify"
android:layout_width="0dp"
android:layout_height="wrap_content"
app:layout_constraintTop_toTopOf="parent"
android:text="분석"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintEnd_toStartOf="@id/btn_clear"
android:layout_margin="10dp" />
<Button
android:id="@+id/btn_clear"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:text="지우기"
app:layout_constraintTop_toTopOf="parent"
app:layout_constraintStart_toEndOf="@id/btn_classify"
app:layout_constraintEnd_toEndOf="parent"
android:layout_margin="10dp" />
<com.divyanshu.draw.widget.DrawView
android:id="@+id/draw_view"
app:layout_constraintTop_toBottomOf="@id/btn_classify"
app:layout_constraintBottom_toTopOf="@id/text_view"
android:layout_width="match_parent"
android:layout_height="0dp" />
<TextView
android:textStyle="bold"
android:textSize="30sp"
android:gravity="center"
android:minHeight="100sp"
android:id="@+id/text_view"
app:layout_constraintTop_toBottomOf="@id/draw_view"
app:layout_constraintBottom_toBottomOf="parent"
android:layout_width="match_parent"
android:layout_height="wrap_content" />
</androidx.constraintlayout.widget.ConstraintLayout>
이 액티비티에서 손씨를 그리고 모델을 사용하여 추론할 예정입니다.
implementation "org.tensorflow:tensorflow-lite:2.8.0"
그리고 안드로이드에서 모델을 사용하기 위해 텐서플로 라이트 라이브러리를 추가합니다.
그리고 위에서 만든 모델(.tflite)을 안드로이드에서 불러서 쓰기 위해서 간편하게 사용할 수 있는 폴더를 만들겁니다.
프로젝트를 오른쪽 클릭 -> New -> Folder -> Assets Folder 클릭하면 Assets 폴더가 만들어지는데 여기에 복붙해줍니다.
안드로이드에서 텐서플로 라이트 모델을 동작시키기 위해서는 텐서플로 라이트에 존재하는 Interpreter를 통해 모델을 로드하고 동작시킵니다.
그래서 Interpreter를 이용하여 모델을 사용하는 분류기 클래스를 객체지향으로 작성하여 사용한다면 이해하기 쉽고 수정이 용이할 것입니다.
Classifier 클래스 코틀린 코드
class Classifier(private val context: Context) {
}
이 분류기 코드에서 모델을 로드하고 추론 결과를 뱉을겁니다.
생성자에 context를 넣고 init 블럭을 만들어서 모델을 로드하는 코드를 작성합니다.
분류기 초기화 코드
private val interpreter: Interpreter
private var modelInputWidth = 0
private var modelInputHeight = 0
private var modelInputChannel = 0
private var modelOutputClasses = 0
init {
val am = context.assets
val afd = am.openFd("cnn_model.tflite")
val fis = FileInputStream(afd.fileDescriptor)
val model = fis.channel.map(FileChannel.MapMode.READ_ONLY, afd.startOffset, afd.declaredLength)
model.order(ByteOrder.nativeOrder())
interpreter = Interpreter(model)
initModelShape()
}
private fun initModelShape() {
val inputTensor = interpreter.getInputTensor(0)
val inputShape = inputTensor.shape()
modelInputChannel = inputShape[0]
modelInputWidth = inputShape[1]
modelInputHeight = inputShape[2]
val outputTensor = interpreter.getOutputTensor(0)
val outputShape = outputTensor.shape()
modelOutputClasses = outputShape[1]
}
모델을 로드하고 Interpreter를 생성하였습니다.
이제 이미지를 입력 받고 추론하는 코드를 이어서 작성합니다.
이미지 전처리, 추론
fun classify(bitmap: Bitmap) : Pair<Int, Float> {
val buffer = convertBitmapToGrayByteBuffer(resizeBitmap(bitmap))
val result = arrayOf(FloatArray(modelOutputClasses))
interpreter.run(buffer, result)
return argmax(result[0])
}
private fun resizeBitmap(bitmap: Bitmap) : Bitmap {
return Bitmap.createScaledBitmap(bitmap, modelInputWidth, modelInputHeight, false)
}
private fun convertBitmapToGrayByteBuffer(bitmap: Bitmap) : ByteBuffer {
val byteBuffer = ByteBuffer.allocateDirect(bitmap.byteCount)
byteBuffer.order(ByteOrder.nativeOrder())
val pixels = IntArray(bitmap.width * bitmap.height)
bitmap.getPixels(pixels, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
for (pixel in pixels) {
val r = (pixel shr 16) and 0xFF
val g = (pixel shr 8) and 0xFF
val b = pixel and 0xFF
val avgPixelValue = (r + g + b) / 3.0f
val normalizedPixelValue = avgPixelValue / 255.0f
byteBuffer.putFloat(normalizedPixelValue)
}
return byteBuffer
}
private fun argmax(array: FloatArray) : Pair<Int, Float> {
val maxBy = array.withIndex().maxBy { it.value }
return Pair(maxBy.index, maxBy.value)
}
입력 받은 이미지를 모델 입력 크기인 28x28로 리스케일링하고 원래 모델의 학습 이미지에 맞춰서 그레이스케일링을 진행합니다. 그리고 0 ~ 1 값으로 정규화를 진행합니다. 모델을 학습할 때도 정규화를 진행해서 정확도를 향상 시켰기 때문이죠.
메인 액티비티 작성
package com.wnview.tensorflowlitestudy
import android.graphics.Color
import androidx.appcompat.app.AppCompatActivity
import android.os.Bundle
import android.util.Log
import com.wnview.tensorflowlitestudy.databinding.ActivityMainBinding
import org.opencv.android.OpenCVLoader
import org.opencv.android.Utils
import java.util.Locale
class MainActivity : AppCompatActivity() {
private lateinit var viewBinding: ActivityMainBinding
private lateinit var classifier: Classifier
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
viewBinding = ActivityMainBinding.inflate(layoutInflater)
setContentView(viewBinding.root)
viewBinding.drawView.setStrokeWidth(40.0f)
viewBinding.drawView.setBackgroundColor(Color.BLACK)
viewBinding.drawView.setColor(Color.WHITE)
viewBinding.btnClassify.setOnClickListener {
val bitmap = viewBinding.drawView.getBitmap()
val res = classifier.classify(bitmap)
val outStr = String.format(Locale.ENGLISH, "%d, %.0f%%", res.first, res.second * 100.0f)
viewBinding.textView.text = outStr
}
viewBinding.btnClear.setOnClickListener {
viewBinding.drawView.clearCanvas()
}
classifier = Classifier(this)
}
override fun onDestroy() {
super.onDestroy()
classifier.finish()
}
}
draw view와 분류기 클래스 간의 인터페이스를 작성합니다.
앱 결과
CNN 모델을 사용하여 추론된 결과는 MLP 모델을 사용한 것 보다 더 나은 결과를 보여줬습니다. 추론하는데 1ms 이하로 동작하니까 이렇게 간단한 작업은 굳이 서버를 통하지 않아도 모바일 기기 자체에서 가능하다는 것을 보여줬습니다.
텐서플로 라이트를 활용한 안드로이드 딥러닝
https://m.hanbit.co.kr/store/books/book_view.html?p_code=B2354289186