OpenCV | kNN 알고리즘을 이용한 2차원 점 분류

박나연·2021년 7월 16일
0

OpenCV

목록 보기
39/40

KNearest 클래스

create()

static Ptr<KNearest> KNearest::create();

Knearest 객체를 생성

setDefaultK()

viftual void KNearest::setDefaultK(int val);

val : kNN알고리즘에서 사용할 k값

setIsClassifier()

virtual void KNearest::setIsClassifier(bool val);

val : 이 값이 true이면 분류(classification), false이면 회귀(regression)으로 사용한다.

findNearest()

virtual float KNearest::findNearest(InputArray samples,
int k ,
OutputArray results,
OutputArray neighborResponses = noArray(),
OutputArray dist = noArray()) const;

samples : 테스트 데이터 벡터가 행단위로 저장된 행렬.
k : 사용할 최근접 이웃개수
results : 각 입력 샘플에 대한 예측 결과를 저장한 행렬
neighborResponses : 예측에 사용된 k개의 최근접 이웃 클래스 정보를 담고있는 행렬
dist : 입력 벡터와 예측에 사용된 k개의 최근접 이웃과의 거리를 저장한 행렬
반환값 : 입력 벡터가 하나인 경우에 대한 응답이 반환됨


kNN알고리즘을 이용한 2차원 점 분류 구현

#include "opencv2/opencv.hpp"
#include <iostream>
using namespace std;
using namespace cv;
using namespace cv::ml;

Mat img;
Mat train, label;
Ptr<KNearest> knn;
int k_value = 1;

void on_k_changed(int, void*);
void addPoint(const Point& pt, int cls);
void trainAndDisplay();

int main(void) {
	img = Mat::zeros(Size(500, 500), CV_8UC3);
	knn = KNearest::create();

	namedWindow("knn");
	createTrackbar("k", "knn", &k_value, 5, on_k_changed);

	const int NUM = 30;
	Mat rn(NUM, 2, CV_32SC1);

	randn(rn, 0, 50);
	for (int i = 0; i < NUM; i++)
		addPoint(Point(rn.at<int>(i, 0) + 150, rn.at<int>(i, 1) + 150), 0);

	randn(rn, 0, 50);
	for (int i = 0; i < NUM; i++)
		addPoint(Point(rn.at<int>(i, 0) + 350, rn.at<int>(i, 1) + 150), 1);

	randn(rn, 0, 70);
	for (int i = 0; i < NUM; i++)
		addPoint(Point(rn.at<int>(i, 0) + 250, rn.at<int>(i, 1) + 400), 2);

	trainAndDisplay();

	waitKey();
	return 0;
}

void on_k_changed(int, void*) {
	if (k_value < 1)
		k_value = 1;
	trainAndDisplay();
}

void addPoint(const Point& pt, int cls) {
	Mat new_sample = (Mat_<float>(1, 2) << pt.x, pt.y);
	train.push_back(new_sample);

	Mat new_label = (Mat_<int>(1, 1) << cls);
	label.push_back(new_label);

}

void trainAndDisplay() {
	knn->train(train, ROW_SAMPLE, label);

	for (int i = 0; i < img.rows; ++i) {
		for (int j = 0; j < img.cols; ++j) {
			Mat sample = (Mat_<float>(1, 2) << j, i);

			Mat res;
			knn->findNearest(sample, k_value, res);

			int response = cvRound(res.at<float>(0, 0));
			if (response == 0)
				img.at<Vec3b>(i, j) = Vec3b(128, 128, 255);
			else if (response == 1)
				img.at<Vec3b>(i, j) = Vec3b(128, 255, 128);
			else if (response == 2)
				img.at<Vec3b>(i, j) = Vec3b(255, 128, 128);

		}
	}

	for (int i = 0; i < train.rows; i++) {
		int x = cvRound(train.at<float>(i, 0));
		int y = cvRound(train.at<float>(i, 1));
		int l = label.at<int>(i, 0);

		if (l == 0)
			circle(img, Point(x, y), 5, Scalar(0, 0, 128), -1, LINE_AA);
		else if (l == 1)
			circle(img, Point(x, y), 5, Scalar(0, 128, 0), -1, LINE_AA);
		else if (l == 2)
			circle(img, Point(x, y), 5, Scalar(128, 0, 0), -1, LINE_AA);
	}

	imshow("knn", img);
}


profile
Data Science / Computer Vision

0개의 댓글