딥러닝 모델을 간편하게 불러와 실행해 주는
Flutter용 TensorFlow Lite 라이브러리(tflite) 설치하기
pubspec.yaml 파일에 의존성을 추가해서 설치해 주기
iOS 설정
Android 설정
android/app/build.gradle 파일의 android 블록 안에 아래와 같은 코드 추가 (모델 파일이 압축되지 않도록 하는 설정)
// Tell aapt not to compress assets with these extensions.
// NOTE(review): presumably required so the tflite plugin can open the model
// file directly from the APK assets — the model asset must therefore use one of
// these extensions (e.g. `.tflite`).
aaptOptions {
noCompress 'tflite'
noCompress 'lite'
}
Kaggle에 공개된 오픈소스 MNIST 모델 중에서 .tflite 모델 파일 다운로드
프로젝트 루트에 assets 폴더를 만든 뒤 그 안에 .tflite 파일 넣기
assets 폴더에 mnist.txt 파일 생성
한 줄에 하나씩 class 이름(라벨) 적어 주기
pubspec.yaml 파일에 assets 경로 추가하기
screens 폴더, services 폴더 생성 후 파일 넣기
assets 폴더에 있는 모델과 class(라벨) 값들을 불러오는 모델 로드 코드
import 'package:tflite/tflite.dart';
/// Loads the MNIST digit model used by the drawing screen.
class Recognizer {
  /// Loads the MNIST model and its label file from the app assets.
  ///
  /// Any previously loaded model is closed first (and that close is awaited
  /// before loading). Completes with the value returned by [Tflite.loadModel].
  Future loadModel() async {
    await Tflite.close();
    // Use the `.tflite` extension: it is the one listed in the Android
    // `aaptOptions { noCompress 'tflite' }` build setting.
    return Tflite.loadModel(
      model: "assets/mnist.tflite",
      labels: "assets/mnist.txt",
    );
  }
}
main.dart 파일 수정하기
main.dart 파일에서 앱 테마 변경하기
constants 파일을 생성해서 상수를 먼저 지정하기 (캔버스 크기, 테두리 두께, 선 굵기 등)
/// Layout and image-size constants shared by the drawing screen, painters and
/// recognizer.
///
/// Declared `static const` so the values cannot be reassigned at runtime and
/// can be used in constant expressions.
class Constants {
  /// Side length of the square drawing area (the canvas container size).
  static const double canvasSize = 300;

  /// Width of the border drawn around the canvas.
  static const double borderSize = 2;

  /// Side length of the offscreen picture before it is scaled down.
  static const double imageSize = 300;

  /// Side length, in pixels, of the square image fed to the model.
  static const int mnistImageSize = 28;

  /// Stroke width used when drawing digits.
  static const double strokeWidth = 8;
}
숫자를 그리는 캔버스를 담당하는 CustomPainter 위젯 파일(drawing_painter.dart)을 생성하고 다음과 같이 작성
import 'package:flutter/material.dart';
import 'package:mnistdigitrecognizer/utils/constants.dart';
/// Paints the user's strokes as black line segments between consecutive points.
class DrawingPainter extends CustomPainter {
  /// Stroke points in canvas coordinates. A `null` entry marks a pen-up, so the
  /// segment across it is skipped and separate strokes are not joined.
  final List<Offset?> points;

  DrawingPainter(this.points);

  final Paint _paint = Paint()
    ..strokeCap = StrokeCap.round
    ..color = Colors.black
    ..strokeWidth = Constants.strokeWidth;

  @override
  void paint(Canvas canvas, Size size) {
    for (var i = 0; i < points.length - 1; i++) {
      final from = points[i];
      final to = points[i + 1];
      if (from != null && to != null) {
        canvas.drawLine(from, to, _paint);
      }
    }
  }

  @override
  bool shouldRepaint(CustomPainter oldDelegate) {
    // The points list may be mutated in place between frames, so comparing the
    // old and new lists could miss new strokes; always repaint.
    return true;
  }
}
움직임(터치 좌표)을 기록하는 points 리스트 변수 생성
컨테이너 안에서 움직임을 감지하기 위해 GestureDetector 위젯 넣기
아래와 같이 포인터가 마우스 모양(화살표 모양)으로 움직임을 따라가는 모습을 볼 수 있음
마우스가 이동할 때마다 현재 위치(Offset)를 list에 저장하도록 하는 코드 작성
아래와 같이 그림이 그려지는 모습을 확인할 수 있음
캔버스 크기를 지정하기 위해 BoxDecoration으로 border 주기
캔버스 안에서만 그림이 그려지도록 조건문 사용
import 'dart:ui';
import 'package:flutter/material.dart';
import 'package:tflite/tflite.dart';
import '../utils/constants.dart';
/// Bounds of the offscreen recording canvas, in unscaled picture coordinates.
final _canvasCullRect =
    Rect.fromLTWH(0, 0, Constants.imageSize, Constants.imageSize);

/// Paint for digit strokes: white lines drawn over the black background.
final _whitePaint = Paint()
  ..color = Colors.white
  ..strokeCap = StrokeCap.round
  ..strokeWidth = Constants.strokeWidth;

/// Fill used for the background of the offscreen picture.
final _bgPaint = Paint()..color = Colors.black;
/// Loads the MNIST model and converts drawn strokes into model input.
class Recognizer {
  /// Loads the MNIST model and its label file from the app assets.
  ///
  /// Any previously loaded model is closed first (and that close is awaited
  /// before loading). Completes with the value returned by [Tflite.loadModel].
  Future loadModel() async {
    await Tflite.close();
    // Use the `.tflite` extension: it is the one listed in the Android
    // `aaptOptions { noCompress 'tflite' }` build setting.
    return Tflite.loadModel(
      model: "assets/mnist.tflite",
      labels: "assets/mnist.txt",
    );
  }

  /// Placeholder for digit recognition at this step of the tutorial.
  ///
  /// Not implemented yet, so the returned future completes with `null`.
  Future recognize(List<Offset?> points) async {}

  /// Records [points] as white strokes on a black background, scaled from the
  /// canvas size down to [Constants.mnistImageSize].
  ///
  /// A `null` entry in [points] marks a pen-up; the segment across it is skipped.
  Picture _pointsToPicture(List<Offset?> points) {
    final recorder = PictureRecorder();
    final canvas = Canvas(recorder, _canvasCullRect)
      ..scale(Constants.mnistImageSize / Constants.canvasSize);
    canvas.drawRect(
      Rect.fromLTWH(0, 0, Constants.imageSize, Constants.imageSize),
      _bgPaint,
    );
    for (var i = 0; i < points.length - 1; i++) {
      final from = points[i];
      final to = points[i + 1];
      if (from != null && to != null) {
        canvas.drawLine(from, to, _whitePaint);
      }
    }
    // Finish recording and hand back the picture; without this the method
    // has no return value for its non-nullable `Picture` return type.
    return recorder.endRecording();
  }
}
그려진 숫자 그림을 Uint8List 형태로 바꾸는 함수인 _imageToByteListUint8 작성
Tflite.runModelOnBinary를 이용해 예측하는 함수인 _predict 작성
_predict 결과를 반환해 예측 결과를 알려 주는 recognize 함수 작성
/// Rasterises [points] to the model's input size and runs inference on it.
///
/// The intermediate picture is disposed once its pixels have been read.
/// Completes with whatever the model run returns.
Future recognize(List<Offset?> points) async {
  final picture = _pointsToPicture(points);
  try {
    final bytes =
        await _imageToByteListUint8(picture, Constants.mnistImageSize);
    return _predict(bytes);
  } finally {
    picture.dispose();
  }
}
/// Runs the loaded TFLite model on the raw input [bytes] and returns its output.
Future _predict(Uint8List bytes) => Tflite.runModelOnBinary(binary: bytes);
/// Renders [pic] at [size] x [size] and converts it to model input bytes.
///
/// Each pixel becomes one float32: the mean of its R, G and B channels scaled
/// to the range 0..1. The float buffer is returned reinterpreted as bytes.
///
/// Throws a [StateError] if the pixel data cannot be read.
Future<Uint8List> _imageToByteListUint8(Picture pic, int size) async {
  final img = await pic.toImage(size, size);
  try {
    // Default format is raw RGBA: 4 bytes per pixel.
    final rgba = await img.toByteData();
    if (rgba == null) {
      throw StateError('Failed to read pixel data from the rendered image');
    }
    final pixels = Float32List(size * size);
    for (var p = 0; p < pixels.length; p++) {
      final offset = p * 4;
      final sum = rgba.getUint8(offset) +
          rgba.getUint8(offset + 1) +
          rgba.getUint8(offset + 2);
      pixels[p] = sum / 3.0 / 255.0;
    }
    return pixels.buffer.asUint8List();
  } finally {
    // The image was created here, so release its native memory here.
    img.dispose();
  }
}
draw_screen.dart에 예측(prediction) 결과를 가져오는 함수 작성
/// Runs recognition on the current strokes and prints the raw model output.
///
/// Debug step: the result is only logged, not stored.
Future<void> _recognize() async {
  // Left untyped: the result may be null, which a `List<dynamic>` declaration
  // would reject with a runtime TypeError.
  final pred = await _recognizer.recognize(_points);
  print(pred);
}
예측 결과를 담는 Prediction 클래스를 만들기 위해 models 폴더 안에 prediction.dart 파일 작성
/// One recognition result: a label, its index, and the model's confidence.
class Prediction {
  final double confidence;
  final int index;
  final String label;

  Prediction({required this.confidence, required this.index, required this.label});

  /// Builds a [Prediction] from a result map with `confidence`, `index` and
  /// `label` keys.
  ///
  /// `confidence` is accepted as any [num] and converted to [double], so an
  /// integral value (e.g. `1`) does not cause a TypeError.
  factory Prediction.fromJson(Map<dynamic, dynamic> json) {
    return Prediction(
      confidence: (json['confidence'] as num).toDouble(),
      index: json['index'] as int,
      label: json['label'] as String,
    );
  }
}
draw_screen.dart에서 필요한 객체(Recognizer, Prediction 리스트) 생성
인식이 끝나면 결과(Map 리스트)를 Prediction 객체로 변환해 저장하는 setState 구문 작성
오류가 발생하면 final → var로 변경한 뒤 initialize 변수 추가
리소스 정리를 위해 dispose 메서드 작성 (Tflite.close 호출)
이미지를 미리보기로 보기 위해 recognizer.dart 파일에 previewImage 함수 생성하기
/// Renders [points] at the model's input size and encodes the result as PNG.
///
/// Useful for showing the user what the model actually sees. The intermediate
/// picture and image are disposed before returning.
///
/// Throws a [StateError] if PNG encoding yields no data.
Future<Uint8List> previewImage(List<Offset?> points) async {
  final picture = _pointsToPicture(points);
  try {
    final image = await picture.toImage(
        Constants.mnistImageSize, Constants.mnistImageSize);
    try {
      final pngBytes = await image.toByteData(format: ImageByteFormat.png);
      if (pngBytes == null) {
        throw StateError('Failed to encode the preview image as PNG');
      }
      return pngBytes.buffer.asUint8List();
    } finally {
      image.dispose();
    }
  } finally {
    picture.dispose();
  }
}
그려지는 이미지를 우측 상단에 미리보기로 나타나게 하기 위해 위젯을 설정해 주기
/// A 100x100 black box showing the downscaled image the model receives.
///
/// Shows 'Error' only when rendering the preview actually failed. While the
/// first preview is still being produced, the box stays plain black instead of
/// flashing an error message.
Widget _mnistPreviewImage() {
  return Container(
    width: 100,
    height: 100,
    color: Colors.black,
    child: FutureBuilder<Uint8List>(
      future: _previewImage(),
      builder: (BuildContext _, AsyncSnapshot<Uint8List> snapshot) {
        if (snapshot.hasError) {
          return const Center(
            child: Text('Error'),
          );
        }
        final data = snapshot.data;
        if (data == null) {
          // Still waiting for the first preview.
          return const SizedBox.shrink();
        }
        return Image.memory(
          data,
          fit: BoxFit.fill,
          // Keep showing the previous frame while the new bytes decode.
          gaplessPlayback: true,
        );
      },
    ),
  );
}
/// Produces the PNG preview bytes for the strokes drawn so far.
Future<Uint8List> _previewImage() => _recognizer.previewImage(_points);
build에서 반환하는 메인 위젯에 미리보기 위젯 추가해 주기
새 list를 만들고, 각 index(숫자)마다 해당 숫자가 그려졌다고 판단한 정확도(confidence)를 차례대로 넣기
숫자 하나를 표시하는 위젯 생성하기
숫자 전체를 엮어서 캔버스 아래에 배치하고, 정확도가 높을수록 숫자 색이 진한 빨간색이 되도록 설정하기
화면 기록 2023-01-10 오전 3.36.27.mov
/// One recognition result; any field may be missing from the source map.
class Prediction {
  final double? confidence;
  final int? index;
  final String? label;

  Prediction({this.confidence, this.index, this.label});

  /// Builds a [Prediction] from a result map with optional `confidence`,
  /// `index` and `label` keys.
  ///
  /// `confidence` is accepted as any [num] and converted to [double], so an
  /// integral value (e.g. `1`) does not cause a TypeError.
  factory Prediction.fromJson(Map<dynamic, dynamic> json) {
    return Prediction(
      confidence: (json['confidence'] as num?)?.toDouble(),
      index: json['index'] as int?,
      label: json['label'] as String?,
    );
  }
}
import 'dart:typed_data';
import 'package:flutter/material.dart';
import 'package:mnistdigitrecognizer/screens/drawing_painter.dart';
import 'package:tflite/tflite.dart';
import '../models/prediction.dart';
import '../services/recognizer.dart';
import '../utils/constants.dart';
/// Screen where the user draws a digit and the model tries to recognize it.
class DrawScreen extends StatefulWidget {
  const DrawScreen({super.key});

  @override
  State<DrawScreen> createState() => _DrawScreenState();
}
class _DrawScreenState extends State<DrawScreen> {
  /// Touch positions in canvas coordinates. A `null` entry marks the end of a
  /// stroke so separate strokes are not joined when painted.
  final _points = <Offset?>[];
  final _recognizer = Recognizer();

  /// Latest recognition results, parsed from the model output.
  var _prediction = <Prediction>[];

  /// Whether the model has finished loading; recognition is skipped until then.
  bool initialize = false;

  @override
  void initState() {
    super.initState();
    _initModel();
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text('Digit Recognizer'),
      ),
      body: Column(
        children: <Widget>[
          Row(
            children: <Widget>[
              Expanded(
                child: Padding(
                  padding: const EdgeInsets.all(8.0),
                  child: Column(
                    children: <Widget>[
                      Text(
                        'MNIST database of handwritten digits',
                        style: TextStyle(
                          fontWeight: FontWeight.bold,
                        ),
                      ),
                      Text(
                        'The digits have been size-normalized and centered in a fixed-size images (28 x 28)',
                      ),
                    ],
                  ),
                ),
              ),
              _mnistPreviewImage(),
            ],
          ),
          SizedBox(
            height: 10,
          ),
          Container(
            width: Constants.canvasSize + Constants.borderSize * 2,
            height: Constants.canvasSize + Constants.borderSize * 2,
            decoration: BoxDecoration(
              border: Border.all(
                color: Colors.black,
                width: Constants.borderSize,
              ),
            ),
            child: GestureDetector(
              onPanUpdate: (DragUpdateDetails details) {
                final localPosition = details.localPosition;
                // Only record points that fall inside the drawable area.
                if (localPosition.dx >= 0 &&
                    localPosition.dx <= Constants.canvasSize &&
                    localPosition.dy >= 0 &&
                    localPosition.dy <= Constants.canvasSize) {
                  setState(() {
                    _points.add(localPosition);
                  });
                }
              },
              onPanEnd: (DragEndDetails details) {
                // Pen lifted: insert a stroke separator, then run recognition.
                _points.add(null);
                _recognize();
              },
              child: CustomPaint(
                painter: DrawingPainter(_points),
              ),
            ),
          ),
          //PredictionWidget(
          //predictions: _prediction,
          //)
        ],
      ),
      floatingActionButton: FloatingActionButton(
        child: Icon(Icons.clear),
        onPressed: () {
          // setState is required, otherwise the canvas and preview never
          // repaint after the points are cleared.
          setState(() {
            _points.clear();
            _prediction = <Prediction>[];
          });
        },
      ),
    );
  }

  @override
  void dispose() {
    // Release the loaded model, then let the framework finish tearing down.
    Tflite.close();
    super.dispose();
  }

  /// A 100x100 black box showing the downscaled image the model receives.
  ///
  /// Shows 'Error' only when rendering the preview failed; while the first
  /// preview is still being produced the box stays plain black.
  Widget _mnistPreviewImage() {
    return Container(
      width: 100,
      height: 100,
      color: Colors.black,
      child: FutureBuilder<Uint8List>(
        future: _previewImage(),
        builder: (BuildContext _, AsyncSnapshot<Uint8List> snapshot) {
          if (snapshot.hasError) {
            return const Center(
              child: Text('Error'),
            );
          }
          final data = snapshot.data;
          if (data == null) {
            return const SizedBox.shrink();
          }
          return Image.memory(
            data,
            fit: BoxFit.fill,
            // Keep showing the previous frame while the new bytes decode.
            gaplessPlayback: true,
          );
        },
      ),
    );
  }

  /// Produces the PNG preview bytes for the strokes drawn so far.
  Future<Uint8List> _previewImage() => _recognizer.previewImage(_points);

  /// Loads the model and marks the screen as ready for recognition.
  Future<void> _initModel() async {
    await _recognizer.loadModel();
    initialize = true;
  }

  /// Runs the model on the current strokes and stores the parsed results.
  Future<void> _recognize() async {
    if (!initialize) return;
    final pred = await _recognizer.recognize(_points);
    // The result may be null; also bail out if this state was disposed while
    // inference was running (setState would throw).
    if (!mounted || pred is! List) return;
    setState(() {
      _prediction = pred.map((json) => Prediction.fromJson(json)).toList();
    });
  }
}
import 'package:flutter/material.dart';
import 'package:mnistdigitrecognizer/utils/constants.dart';
/// Paints the user's strokes as black line segments between consecutive points.
class DrawingPainter extends CustomPainter {
  /// Stroke points in canvas coordinates. A `null` entry marks a pen-up, so the
  /// segment across it is skipped and separate strokes are not joined.
  final List<Offset?> points;

  DrawingPainter(this.points);

  final Paint _paint = Paint()
    ..strokeCap = StrokeCap.round
    ..color = Colors.black
    ..strokeWidth = Constants.strokeWidth;

  @override
  void paint(Canvas canvas, Size size) {
    for (var i = 0; i < points.length - 1; i++) {
      // Copy to locals so null checks promote without `!`.
      final from = points[i];
      final to = points[i + 1];
      if (from != null && to != null) {
        canvas.drawLine(from, to, _paint);
      }
    }
  }

  @override
  bool shouldRepaint(CustomPainter oldDelegate) {
    // The points list may be mutated in place between frames, so comparing the
    // old and new lists could miss new strokes; always repaint.
    return true;
  }
}
import 'package:flutter/material.dart';
import '../models/prediction.dart';
/// Shows the digits 0-9 in two rows, tinting each red by the model's
/// confidence that it is the drawn digit.
class PredictionWidget extends StatelessWidget {
  /// Recognition results to display; at most one is shown per digit.
  final List<Prediction> predictions;

  const PredictionWidget({Key? key, required this.predictions}) : super(key: key);

  /// Builds the label for [digit] and its confidence text.
  ///
  /// [prediction] is null when the model returned no result for this digit;
  /// the digit is then drawn black with no confidence text.
  Widget _numberWidget(int digit, Prediction? prediction) {
    final confidence = prediction?.confidence;
    return Column(
      children: <Widget>[
        Text(
          '$digit',
          style: TextStyle(
            fontSize: 60,
            fontWeight: FontWeight.bold,
            // Doubling makes any confidence >= 0.5 fully opaque red.
            color: confidence == null
                ? Colors.black
                : Colors.red.withOpacity(
                    (confidence * 2).clamp(0, 1).toDouble(),
                  ),
          ),
        ),
        Text(
          confidence?.toStringAsFixed(3) ?? '',
          style: TextStyle(
            fontSize: 12,
          ),
        ),
      ],
    );
  }

  /// Returns a 10-slot list where slot `i` holds the prediction whose index is
  /// `i`, or null when there is none.
  ///
  /// Predictions without an index, or with an index outside 0-9, are ignored
  /// rather than crashing the build.
  List<Prediction?> getPredictionStyles(List<Prediction> predictions) {
    final data = List<Prediction?>.filled(10, null);
    for (final prediction in predictions) {
      final index = prediction.index;
      if (index != null && index >= 0 && index < data.length) {
        data[index] = prediction;
      }
    }
    return data;
  }

  @override
  Widget build(BuildContext context) {
    final styles = getPredictionStyles(predictions);
    return Column(
      children: <Widget>[
        Row(
          mainAxisAlignment: MainAxisAlignment.spaceEvenly,
          children: <Widget>[
            for (var i = 0; i < 5; i++) _numberWidget(i, styles[i]),
          ],
        ),
        Row(
          mainAxisAlignment: MainAxisAlignment.spaceEvenly,
          children: <Widget>[
            for (var i = 5; i < 10; i++) _numberWidget(i, styles[i]),
          ],
        ),
      ],
    );
  }
}
import 'dart:typed_data';
import 'dart:ui';
import 'package:flutter/material.dart';
import 'package:tflite/tflite.dart';
import '../utils/constants.dart';
/// Bounds of the offscreen recording canvas, in unscaled picture coordinates.
final _canvasCullRect =
    Rect.fromLTWH(0, 0, Constants.imageSize, Constants.imageSize);

/// Paint for digit strokes: white lines drawn over the black background.
final _whitePaint = Paint()
  ..color = Colors.white
  ..strokeCap = StrokeCap.round
  ..strokeWidth = Constants.strokeWidth;

/// Fill used for the background of the offscreen picture.
final _bgPaint = Paint()..color = Colors.black;
/// Loads the MNIST model, converts drawn strokes into model input, and runs
/// inference on it.
class Recognizer {
  /// Loads the MNIST model and its label file from the app assets.
  ///
  /// Any previously loaded model is closed first (and that close is awaited
  /// before loading). Completes with the value returned by [Tflite.loadModel].
  Future loadModel() async {
    await Tflite.close();
    // Use the `.tflite` extension: it is the one listed in the Android
    // `aaptOptions { noCompress 'tflite' }` build setting.
    return Tflite.loadModel(
      model: "assets/mnist.tflite",
      labels: "assets/mnist.txt",
    );
  }

  /// Rasterises [points] to the model's input size and runs inference on it.
  ///
  /// Completes with whatever [Tflite.runModelOnBinary] returns. The
  /// intermediate picture is disposed once its pixels have been read.
  Future recognize(List<Offset?> points) async {
    final picture = _pointsToPicture(points);
    try {
      final bytes =
          await _imageToByteListUint8(picture, Constants.mnistImageSize);
      return _predict(bytes);
    } finally {
      picture.dispose();
    }
  }

  /// Runs the loaded model on raw input [bytes].
  Future _predict(Uint8List bytes) {
    return Tflite.runModelOnBinary(binary: bytes);
  }

  /// Renders [pic] at [size] x [size] and converts it to model input bytes.
  ///
  /// Each pixel becomes one float32: the mean of its R, G and B channels scaled
  /// to 0..1. The float buffer is returned reinterpreted as bytes.
  ///
  /// Throws a [StateError] if the pixel data cannot be read.
  Future<Uint8List> _imageToByteListUint8(Picture pic, int size) async {
    final img = await pic.toImage(size, size);
    try {
      // Default format is raw RGBA: 4 bytes per pixel.
      final rgba = await img.toByteData();
      if (rgba == null) {
        throw StateError('Failed to read pixel data from the rendered image');
      }
      final pixels = Float32List(size * size);
      for (var p = 0; p < pixels.length; p++) {
        final offset = p * 4;
        final sum = rgba.getUint8(offset) +
            rgba.getUint8(offset + 1) +
            rgba.getUint8(offset + 2);
        pixels[p] = sum / 3.0 / 255.0;
      }
      return pixels.buffer.asUint8List();
    } finally {
      img.dispose();
    }
  }

  /// Renders [points] at the model's input size and encodes the result as PNG.
  ///
  /// The intermediate picture and image are disposed before returning.
  /// Throws a [StateError] if PNG encoding yields no data.
  Future<Uint8List> previewImage(List<Offset?> points) async {
    final picture = _pointsToPicture(points);
    try {
      final image = await picture.toImage(
          Constants.mnistImageSize, Constants.mnistImageSize);
      try {
        final pngBytes = await image.toByteData(format: ImageByteFormat.png);
        if (pngBytes == null) {
          throw StateError('Failed to encode the preview image as PNG');
        }
        return pngBytes.buffer.asUint8List();
      } finally {
        image.dispose();
      }
    } finally {
      picture.dispose();
    }
  }

  /// Records [points] as white strokes on a black background, scaled from the
  /// canvas size down to [Constants.mnistImageSize].
  ///
  /// A `null` entry in [points] marks a pen-up; the segment across it is skipped.
  Picture _pointsToPicture(List<Offset?> points) {
    final recorder = PictureRecorder();
    final canvas = Canvas(recorder, _canvasCullRect)
      ..scale(Constants.mnistImageSize / Constants.canvasSize);
    canvas.drawRect(
      Rect.fromLTWH(0, 0, Constants.imageSize, Constants.imageSize),
      _bgPaint,
    );
    for (var i = 0; i < points.length - 1; i++) {
      final from = points[i];
      final to = points[i + 1];
      if (from != null && to != null) {
        canvas.drawLine(from, to, _whitePaint);
      }
    }
    return recorder.endRecording();
  }
}
/// Layout and image-size constants shared by the drawing screen, painters and
/// recognizer.
///
/// Declared `static const` so the values cannot be reassigned at runtime and
/// can be used in constant expressions.
class Constants {
  /// Side length of the square drawing area (the canvas container size).
  static const double canvasSize = 300;

  /// Width of the border drawn around the canvas.
  static const double borderSize = 2;

  /// Side length of the offscreen picture before it is scaled down.
  static const double imageSize = 300;

  /// Side length, in pixels, of the square image fed to the model.
  static const int mnistImageSize = 28;

  /// Stroke width used when drawing digits.
  static const double strokeWidth = 8;
}
import 'package:flutter/material.dart';
import 'package:mnistdigitrecognizer/screens/draw_screen.dart';
/// Entry point: launches the digit recognizer app.
void main() => runApp(const MyApp());
/// Root widget: applies a black app-bar/FAB theme and opens [DrawScreen].
class MyApp extends StatelessWidget {
  const MyApp({super.key});

  // This widget is the root of your application.
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'MNIST DIGIT RECOGNIZER',
      theme: ThemeData(
        appBarTheme: const AppBarTheme(
          color: Colors.black,
        ),
        floatingActionButtonTheme: const FloatingActionButtonThemeData(
          backgroundColor: Colors.black,
        ),
      ),
      home: const DrawScreen(),
    );
  }
}