Flutter (21) Flutter + Deep Learning

Huisu·2023년 3월 11일
1

Flutter

목록 보기
21/21
post-thumbnail

  • 교육 주제: 나의 관심사
  • 세부 실습명: 나의 관심사 통계 페이지 UI 구현
  • 세부 실습목표: 통계 라이브러리를 설치해 페이지 안에 표 그리기
  • 세부 실습내용 요약: (아래는 예시입니다)
    • 위젯 페이지를 만들어서 페이지 그리기
    • fl_chart library 사용하기
    • 나의 관심사를 나타내는 대시보드 형식의 페이지 구현하기

TensorFlow Lite

  • 딥러닝 모델을 간편하게 실어 주는

  • flutter tensorflow lite 설치하는 라이브러리

    tflite | Flutter Package

  • puspec.yaml 파일에 의존성 추가해서 설치해 주기

  • ios setting

    • build.gradle 폴더에서 minsdkversion 바꿔서 버전 맞추기
  • android setting

    • build.gradle 파일에 아래와 같은 코드 추가

      aaptOptions {
              noCompress 'tflite'
              noCompress 'lite'
          }

Kaggle MNIST TensorFlow Model

  • kaggle 에 있는 model 오픈소스 중에 해당 파일의 tfile 다운로드

  • 프로젝트 파일 > assets 폴더 생성 후 안에 tfile 넣기

  • mnist.txt 파일 생성

  • class 이름 적어 주기

  • pubspec.yaml 파일에서 assets 추가하기

  • screeen 폴더, services 폴더 생성 후 파일 넣기

TFLite Model Load

  • assets 파일에 있는 모델과 class 값들을 불러오는 model load 코드

    import 'package:tflite/tflite.dart';
    
    class Recognizer {
      Future loadModel() {
        Tflite.close();
        return Tflite.loadModel(
            model: "assets/mnist.tfile",
            labels: "assets/mnist.txt",
        );
      }
    }
  • main.dart 파일 수정하기

  • main.dart 파일에서 앱 테마 변경하기

  • constants 파일 생성해서 상수 먼저 지정하기 (캔버스 사이즈, 패딩값 등등)

    class Constants {
      static double canvasSize = 300;
      static double borderSize = 2;
    
      static double imageSize = 300;
      static int mnistImageSize = 28;
    
      static double strokeWidth = 8;
    }

draw_screen.dart

  • recognizer 받아 오는 변수 생성하고 초기화하기
  • 모델을 초기화하는 함수 생성
  • Appbar 생성하기
  • 버튼 생성하기
  • mnist 분석기에 대해 설명하는 위젯 생성하기

drawing_canvus.dart

  • 숫자를 그리는 painting 부분의 캔버스를 그리는 위젯 파일을 생성하고 다음과 같이 작성

    import 'package:flutter/material.dart';
    import 'package:mnistdigitrecognizer/utils/constants.dart';
    
    class DrawingPainter extends CustomPainter {
      final List<Offset> points;
    
      DrawingPainter(this.points);
    
      final Paint _paint = Paint()
        ..strokeCap = StrokeCap.round
        ..color = Colors.black
        ..strokeWidth = Constants.strokeWidth;
    
      
      void paint(Canvas canvas, Size size) {
        for (int i = 0; i < points.length - 1; i++) {
          if (points[i] != null && points[i + 1] != null) {
            canvas.drawLine(points[i], points[i + 1], _paint);
          }
        }
      }
    
      
      bool shouldRepaint(CustomPainter oldDelegate) {
        return true;
      }
    }
  • 움직임을 기록하는 포인터 변수 생성

  • 컨테이너 안에 움직임을 감지하기 위해 gesture detection 위젯 넣기

  • 아래와 같이 포인터가 마우스 모양으로 움직임을 따고 있는 모습을 볼 수 있음 (화살표 모양)

  • 마우스가 이동할 때마다 현재의 위치 (Offset)을 list에 저장할 수 있도록 설정하는 코드 작성

  • 아래와 같이 그림이 그려지는 모습을 확인할 수 있음

  • 캔버스 크기 지정하기 위해 boxdecoration으로 border 주기

  • 캔버스 안에서만 그림 그려질 수 있도록 하기 위해서 조건문 사용

  • recognize.dart 파일에서도 points들의 움직임을 받아올 수 있도록 구현
  • 변수 생성하기
  • points들의 움직임을 그림으로 바꿔 주는 함수 작성
  • 전체 코드
    import 'dart:ui';
    import 'package:flutter/material.dart';
    import 'package:tflite/tflite.dart';
    import '../utils/constants.dart';
    
    final _canvasCullRect = Rect.fromPoints(
      Offset(0, 0),
      Offset(Constants.imageSize, Constants.imageSize),
    );
    
    final _whitePaint = Paint()
      ..strokeCap = StrokeCap.round
      ..color = Colors.white
      ..strokeWidth = Constants.strokeWidth;
    
    final _bgPaint = Paint()
      ..color = Colors.black;
    
    class Recognizer {
      Future loadModel() {
        Tflite.close();
        return Tflite.loadModel(
          model: "assets/mnist.tfile",
          labels: "assets/mnist.txt",
        );
      }
    
      Future recognize(List<Offset> points) async {}
    
      Picture _pointsToPicture(List<Offset> points) {
        final recorder = PictureRecorder();
        final canvas = Canvas(recorder, _canvasCullRect)
          ..scale(Constants.mnistImageSize / Constants.canvasSize);
    
        canvas.drawRect(
            Rect.fromLTWH(0, 0, Constants.imageSize, Constants.imageSize),
            _bgPaint);
    
        for (int i = 0; i < points.length - 1; i++) {
          if (points[i] != null && points[i + 1] != null) {
            canvas.drawLine(points[i], points[i + 1], _whitePaint);
          }
        }
      }
    }

Digit Recognize

  • 그림 숫자를 unit8list 형태로 바꾸는 함수인 _imageToByteListUnit8 작성

  • tfLite.runModelOnBinary를 이용해 예측하는 함수인 _predict 작성

  • _predict 값을 반환해 예측 결과를 말해 주는 recognize 함수 작성

    Future recognize(List<Offset?> points) async {
        final picture = _pointsToPicture(points);
        Uint8List bytes = await _imageToByteListUint8(
            picture, Constants.mnistImageSize);
        return _predict(bytes);
      }
    
      Future _predict(Uint8List bytes) {
        return Tflite.runModelOnBinary(binary: bytes);
      }
    
      Future<Uint8List> _imageToByteListUint8(Picture pic, int size) async {
        final img = await pic.toImage(size, size);
        final imgBytes = await img.toByteData();
        final resultBytes = Float32List(size * size);
        final buffer = Float32List.view(resultBytes.buffer);
    
        int index = 0;
    
        for (int i = 0; i < imgBytes!.lengthInBytes; i += 4) {
          final r = imgBytes?.getUint8(i);
          final g = imgBytes?.getUint8(i + 1);
          final b = imgBytes?.getUint8(i + 2);
          buffer[index++] = (r! + g! + b!) / 3.0 / 255.0;
        }
    
        return resultBytes.buffer.asUint8List();
      }
  • draw_screen.dart 에 prediction한 값을 가져오는 함수 설정

    void _recognize() async {
        List<dynamic> pred = await _recognizer.recognize(_points);
        print(pred);
      }
  • 예측 결과를 저장하는 Prediction class를 만들기 위해 models 폴더 안에 prediction.dart 파일 작성 (객체 생성)

    class Prediction {
      final double confidence;
      final int index;
      final String label;
    
      Prediction({required this.confidence, required this.index, required this.label});
    
      factory Prediction.fromJson(Map<dynamic, dynamic> json) {
        return Prediction(
          confidence: json['confidence'],
          index: json['index'],
          label: json['label'],
        );
      }
    }
  • draw_screen.dart 에서 객체 생성

  • 인식했을 때 json 파일로 prediction 값 받아 오는 setstate 구문 생성

  • 오류가 발생하면 final → var로 변경 뒤 initialize 변수 추가

  • 데이터 관리를 위해 소멸자 생성

Preview (option)

  • 이미지를 프리뷰로 보기 위해서 recognizer.dart 파일에 previewImage 함수 생성하기

    Future<Uint8List> previewImage(List<Offset?> points) async {
        final picture = _pointsToPicture(points);
        final image = await picture.toImage(Constants.mnistImageSize, Constants.mnistImageSize);
        var pngBytes = await image.toByteData(format: ImageByteFormat.png);
    
        return pngBytes!.buffer.asUint8List();
      }
  • 그려지는 이미지를 우측 상단에 미리보기로 나타나게 하기 위해 위젯을 설정해 주기

    Widget _mnistPreviewImage() {
        return Container(
          width: 100,
          height: 100,
          color: Colors.black,
          child: FutureBuilder(
            future: _previewImage(),
            builder: (BuildContext _, snapshot) {
              if (snapshot.hasData) {
                return Image.memory(
                  snapshot.data!,
                  fit: BoxFit.fill,
                );
              } else {
                return Center(
                  child: Text('Error'),
                );
              }
            },
          ),
        );
      }
    
      Future<Uint8List> _previewImage() async {
        return await _recognizer.previewImage(_points);
      }
    • 반환하는 메인 위젯에 추가해 주기

Prediction

  • 새로 list를 만들고 이 안에 index 값마다 차례대로 해당 숫자가 그려졌다고 판단할 정확도 넣기

  • 숫자 하나를 그리는 위젯 생성하기

  • 숫자 전체를 엮어서 밑에 넣고 정확도가 높은 숫자들은 색깔을 빨간색으로 바꾸도록 설정하기

실행 화면

화면 기록 2023-01-10 오전 3.36.27.mov

전체 코드

  • prediction.dart
    class Prediction {
      final double? confidence;
      final int? index;
      final String? label;
    
      Prediction({this.confidence, this.index, this.label});
    
      factory Prediction.fromJson(Map<dynamic, dynamic> json) {
        return Prediction(
          confidence: json['confidence'],
          index: json['index'],
          label: json['label'],
        );
      }
    }
  • draw_screen.dart
    import 'dart:typed_data';
    import 'package:flutter/material.dart';
    import 'package:mnistdigitrecognizer/screens/drawing_painter.dart';
    import 'package:tflite/tflite.dart';
    import '../models/prediction.dart';
    import '../services/recognizer.dart';
    import '../utils/constants.dart';
    
    class DrawScreen extends StatefulWidget {
      const DrawScreen({Key? key}) : super(key: key);
    
      
      State<DrawScreen> createState() => _DrawScreenState();
    }
    
    class _DrawScreenState extends State<DrawScreen> {
      final _points = <Offset?>[];
      final _recognizer = Recognizer();
      var _prediction = <Prediction>[];
      bool initialize = false;
    
      
      void initState() {
        // TODO: implement initState
        super.initState();
      }
    
      
      Widget build(BuildContext context) {
        return Scaffold(
          appBar: AppBar(
            title: Text('Digit Recognizer'),
          ),
            body: Column(
              children: <Widget>[
                Row(
                  children: <Widget>[
                    Expanded(
                      child: Padding(
                        padding: const EdgeInsets.all(8.0),
                        child: Column(
                          children: <Widget>[
                            Text(
                              'MNIST database of handwritten digits',
                              style: TextStyle(
                                fontWeight: FontWeight.bold,
                              ),
                            ),
                            Text(
                              'The digits have been size-normalized and centered in a fixed-size images (28 x 28)',
                            )
                          ],
                        ),
                      ),
                    ),
                    _mnistPreviewImage(),
                  ],
                ),
                SizedBox(
                  height: 10,
                ),
                Container(
                  width: Constants.canvasSize + Constants.borderSize * 2,
                  height: Constants.canvasSize + Constants.borderSize * 2,
                  decoration: BoxDecoration(
                    border: Border.all(
                      color: Colors.black,
                      width: Constants.borderSize,
                    ),
                  ),
                  child: GestureDetector(
                    onPanUpdate: (DragUpdateDetails details) {
                      Offset _localPosition = details.localPosition;
                      if (_localPosition.dx >= 0 &&
                          _localPosition.dx <= Constants.canvasSize &&
                          _localPosition.dy >= 0 &&
                          _localPosition.dy <= Constants.canvasSize) {
                        setState(() {
                          _points.add(_localPosition);
                        });
                      }
                    },
                    onPanEnd: (DragEndDetails details) {
                      _points.add(null);
                      _recognize();
                    },
                    child: CustomPaint(
                      painter: DrawingPainter(_points),
                    ),
                  ),
                ),
                //PredictionWidget(
                  //predictions: _prediction,
                //)
              ],
            ),
          floatingActionButton: FloatingActionButton(
            child: Icon(Icons.clear),
            onPressed: () {
              _points.clear();
            }
          )
        );
      }
    
      dispose() {
        Tflite.close();
      }
    
      Widget _mnistPreviewImage() {
        return Container(
          width: 100,
          height: 100,
          color: Colors.black,
          child: FutureBuilder(
            future: _previewImage(),
            builder: (BuildContext _, snapshot) {
              if (snapshot.hasData) {
                return Image.memory(
                  snapshot.data!,
                  fit: BoxFit.fill,
                );
              } else {
                return Center(
                  child: Text('Error'),
                );
              }
            },
          ),
        );
      }
    
      Future<Uint8List> _previewImage() async {
        return await _recognizer.previewImage(_points);
      }
    
      void _initModel() async {
        var res = await _recognizer.loadModel();
      }
    
      void _recognize() async {
        List<dynamic> pred = await _recognizer.recognize(_points);
        setState(() {
          _prediction = pred.map((json) => Prediction.fromJson(json)).toList();
        });
      }
    }
  • drawing_painter.dart
    import 'package:flutter/material.dart';
    import 'package:mnistdigitrecognizer/utils/constants.dart';
    
    class DrawingPainter extends CustomPainter {
      final List<Offset?> points;
    
      DrawingPainter(this.points);
    
      final Paint _paint = Paint()
        ..strokeCap = StrokeCap.round
        ..color = Colors.black
        ..strokeWidth = Constants.strokeWidth;
    
      
      void paint(Canvas canvas, Size size) {
        for (int i = 0; i < points.length - 1; i++) {
          if (points[i] != null && points[i + 1] != null) {
            canvas.drawLine(points[i]!, points[i + 1]!, _paint);
          }
        }
      }
    
      
      bool shouldRepaint(CustomPainter oldDelegate) {
        return true;
      }
    }
  • prediction_widget.dart
    import 'package:flutter/material.dart';
    import '../models/prediction.dart';
    
    class PredictionWidget extends StatelessWidget {
      final List<Prediction> predictions;
      const PredictionWidget({Key? key, required this.predictions}) : super(key: key);
    
      Widget _numberWidget(int num, Prediction prediction) {
        return Column(
          children: <Widget>[
            Text(
              '$num',
              style: TextStyle(
                fontSize: 60,
                fontWeight: FontWeight.bold,
                color: prediction == null
                    ? Colors.black
                    : Colors.red.withOpacity(
                  (prediction.confidence! * 2).clamp(0, 1).toDouble(),
                ),
              ),
            ),
            Text(
              '${prediction == null ? '' : prediction.confidence?.toStringAsFixed(3)}',
              style: TextStyle(
                fontSize: 12,
              ),
            )
          ],
        );
      }
    
      List<dynamic> getPredictionStyles(List<Prediction> predictions) {
        List<dynamic> data = [
          null,
          null,
          null,
          null,
          null,
          null,
          null,
          null,
          null,
          null
        ];
        predictions.forEach((prediction) {
          data[prediction.index!] = prediction;
        });
    
        return data;
      }
    
      
      Widget build(BuildContext context) {
    
        var styles = getPredictionStyles(this.predictions);
    
        return Column(
          children: <Widget>[
            Row(
              mainAxisAlignment: MainAxisAlignment.spaceEvenly,
              children: <Widget>[
                for (var i = 0; i < 5; i++) _numberWidget(i, styles[i])
              ],
            ),
            Row(
              mainAxisAlignment: MainAxisAlignment.spaceEvenly,
              children: <Widget>[
                for (var i = 5; i < 10; i++) _numberWidget(i, styles[i])
              ],
            )
          ],
        );
      }
    }
  • recognizer.dart
    import 'dart:typed_data';
    import 'dart:ui';
    import 'package:flutter/material.dart';
    import 'package:tflite/tflite.dart';
    import '../utils/constants.dart';
    
    final _canvasCullRect = Rect.fromPoints(
      Offset(0, 0),
      Offset(Constants.imageSize, Constants.imageSize),
    );
    
    final _whitePaint = Paint()
      ..strokeCap = StrokeCap.round
      ..color = Colors.white
      ..strokeWidth = Constants.strokeWidth;
    
    final _bgPaint = Paint()
      ..color = Colors.black;
    
    class Recognizer {
      Future loadModel() {
        Tflite.close();
        return Tflite.loadModel(
          model: "assets/mnist.tfile",
          labels: "assets/mnist.txt",
        );
      }
     
    
      Future recognize(List<Offset?> points) async {
        final picture = _pointsToPicture(points);
        Uint8List bytes = await _imageToByteListUint8(
            picture, Constants.mnistImageSize);
        return _predict(bytes);
      }
    
      Future _predict(Uint8List bytes) {
        return Tflite.runModelOnBinary(binary: bytes);
      }
    
      Future<Uint8List> _imageToByteListUint8(Picture pic, int size) async {
        final img = await pic.toImage(size, size);
        final imgBytes = await img.toByteData();
        final resultBytes = Float32List(size * size);
        final buffer = Float32List.view(resultBytes.buffer);
    
        int index = 0;
    
        for (int i = 0; i < imgBytes!.lengthInBytes; i += 4) {
          final r = imgBytes?.getUint8(i);
          final g = imgBytes?.getUint8(i + 1);
          final b = imgBytes?.getUint8(i + 2);
          buffer[index++] = (r! + g! + b!) / 3.0 / 255.0;
        }
    
        return resultBytes.buffer.asUint8List();
      }
    
      Future<Uint8List> previewImage(List<Offset?> points) async {
        final picture = _pointsToPicture(points);
        final image = await picture.toImage(Constants.mnistImageSize, Constants.mnistImageSize);
        var pngBytes = await image.toByteData(format: ImageByteFormat.png);
    
        return pngBytes!.buffer.asUint8List();
      }
    
      Picture _pointsToPicture(List<Offset?> points) {
        final recorder = PictureRecorder();
        final canvas = Canvas(recorder, _canvasCullRect)
          ..scale(Constants.mnistImageSize / Constants.canvasSize);
    
        canvas.drawRect(
            Rect.fromLTWH(0, 0, Constants.imageSize, Constants.imageSize),
            _bgPaint);
    
        for (int i = 0; i < points.length - 1; i++) {
          if (points[i] != null && points[i + 1] != null) {
            canvas.drawLine(points[i]!, points[i + 1]!, _whitePaint);
          }
        }
        return recorder.endRecording();
      }
    }
  • constant.dart
    class Constants {
      static double canvasSize = 300;
      static double borderSize = 2;
    
      static double imageSize = 300;
      static int mnistImageSize = 28;
    
      static double strokeWidth = 8;
    }
  • main.dart
    import 'package:flutter/material.dart';
    import 'package:mnistdigitrecognizer/screens/draw_screen.dart';
    
    void main() {
      runApp(const MyApp());
    }
    
    class MyApp extends StatelessWidget {
      const MyApp({super.key});
    
      // This widget is the root of your application.
      
      Widget build(BuildContext context) {
        return MaterialApp(
          title: 'MNIST DIGIT RECOGNIZER',
          theme: ThemeData(
            appBarTheme: AppBarTheme(
              color: Colors.black,
            ),
            floatingActionButtonTheme: FloatingActionButtonThemeData(
              backgroundColor: Colors.black,
            ),
          ),
          home: DrawScreen(),
        );
      }
    }

0개의 댓글