MNIST 학습하기
MNIST 학습하기
- 학습 목표
- TensorFlow 또는 PyTorch를 이용하여 MNIST 필기체 숫자 인식을 학습하기
- 학습된 모델을 파일로 저장하기 (*.pb 또는 *.onnx)
- OpenCV에서 학습된 모델을 불러와서 필기체 인식 프로그램을 실행하기
- 준비사항
- Python 설치
- TensorFlow설치
- PyTorch 설치
- Tensorflow에서 모델 파일(*.pb) 저장하기
- Pytorch에서 모델파일(*.oonx) 저장하기
학습된 MNIST 모델을 OpenCV에서 사용하기
학습된 MNIST 모델을 OpenCV에서 사용하기
- DNN 모듈을 이용한 숫자 인식 예제
- 마우스를 이용하여 사용자가 직접 숫자를 그리고, 인식을 수행
#include <iostream>
#include "opencv2/opencv.hpp"
using namespace std;
using namespace cv;
using namespace cv::dnn;
void on_mouse(int event, int x, int y, int flags, void* userdata);
Mat norm_digit(Mat& src)
{
CV_Assert(!src.empty() && src.type() == CV_8UC1);
Mat src_bin;
threshold(src, src_bin, 0, 255, THRESH_BINARY | THRESH_OTSU);
Mat labels, stats, centroids;
int n = connectedComponentsWithStats(src_bin, labels, stats, centroids);
Mat dst = Mat::zeros(src.rows, src.cols, src.type());
for (int i = 1; i < n; i++) {
if (stats.at<int>(i, 4) < 20) continue;
int cx = cvRound(centroids.at<double>(i, 0));
int cy = cvRound(centroids.at<double>(i, 1));
double dx = 14 - cx;
double dy = 14 - cy;
Mat warpMat = (Mat_<double>(2, 3) << 1, 0, dx, 0, 1, dy);
warpAffine(src, dst, warpMat, dst.size());
}
return dst;
}
int main()
{
Net net = readNet("mnist.pb");
if (net.empty()) {
cerr << "Network load failed!" << endl;
return -1;
}
Mat img = Mat::zeros(400, 400, CV_8UC1);
imshow("img", img);
setMouseCallback("img", on_mouse, (void*)&img);
while (true) {
int c = waitKey();
if (c == 27) {
break;
} else if (c == ' ') {
Mat blr, resized;
GaussianBlur(img, blr, Size(), 1.0);
resize(blr, resized, Size(28, 28), 0, 0, INTER_AREA);
Mat blob = blobFromImage(norm_digit(resized), 1/255.f, Size(28, 28));
net.setInput(blob);
Mat prob = net.forward();
double maxVal;
Point maxLoc;
minMaxLoc(prob, NULL, &maxVal, NULL, &maxLoc);
int digit = maxLoc.x;
cout << digit << " (" << maxVal * 100 << "%)" <<endl;
img.setTo(0);
imshow("img", img);
}
}
}
Point ptPrev(-1, -1);
void on_mouse(int event, int x, int y, int flags, void* userdata)
{
Mat img = *(Mat*)userdata;
if (event == EVENT_LBUTTONDOWN) {
ptPrev = Point(x, y);
} else if (event == EVENT_LBUTTONUP) {
ptPrev = Point(-1, -1);
} else if (event == EVENT_MOUSEMOVE && (flags & EVENT_FLAG_LBUTTON)) {
line(img, ptPrev, Point(x, y), Scalar::all(255), 40, LINE_AA, 0);
ptPrev = Point(x, y);
imshow("img", img);
}
}
- 실행 결과