static Ptr<KNearest> KNearest::create();
Knearest 객체를 생성
viftual void KNearest::setDefaultK(int val);
val : kNN알고리즘에서 사용할 k값
virtual void KNearest::setIsClassifier(bool val);
val : 이 값이 true이면 분류(classification), false이면 회귀(regression)으로 사용한다.
virtual float KNearest::findNearest(InputArray samples,
int k ,
OutputArray results,
OutputArray neighborResponses = noArray(),
OutputArray dist = noArray()) const;
samples : 테스트 데이터 벡터가 행단위로 저장된 행렬.
k : 사용할 최근접 이웃개수
results : 각 입력 샘플에 대한 예측 결과를 저장한 행렬
neighborResponses : 예측에 사용된 k개의 최근접 이웃 클래스 정보를 담고있는 행렬
dist : 입력 벡터와 예측에 사용된 k개의 최근접 이웃과의 거리를 저장한 행렬
반환값 : 입력 벡터가 하나인 경우에 대한 응답이 반환됨
#include "opencv2/opencv.hpp"
#include <iostream>
using namespace std;
using namespace cv;
using namespace cv::ml;
Mat img;
Mat train, label;
Ptr<KNearest> knn;
int k_value = 1;
void on_k_changed(int, void*);
void addPoint(const Point& pt, int cls);
void trainAndDisplay();
int main(void) {
img = Mat::zeros(Size(500, 500), CV_8UC3);
knn = KNearest::create();
namedWindow("knn");
createTrackbar("k", "knn", &k_value, 5, on_k_changed);
const int NUM = 30;
Mat rn(NUM, 2, CV_32SC1);
randn(rn, 0, 50);
for (int i = 0; i < NUM; i++)
addPoint(Point(rn.at<int>(i, 0) + 150, rn.at<int>(i, 1) + 150), 0);
randn(rn, 0, 50);
for (int i = 0; i < NUM; i++)
addPoint(Point(rn.at<int>(i, 0) + 350, rn.at<int>(i, 1) + 150), 1);
randn(rn, 0, 70);
for (int i = 0; i < NUM; i++)
addPoint(Point(rn.at<int>(i, 0) + 250, rn.at<int>(i, 1) + 400), 2);
trainAndDisplay();
waitKey();
return 0;
}
void on_k_changed(int, void*) {
if (k_value < 1)
k_value = 1;
trainAndDisplay();
}
void addPoint(const Point& pt, int cls) {
Mat new_sample = (Mat_<float>(1, 2) << pt.x, pt.y);
train.push_back(new_sample);
Mat new_label = (Mat_<int>(1, 1) << cls);
label.push_back(new_label);
}
void trainAndDisplay() {
knn->train(train, ROW_SAMPLE, label);
for (int i = 0; i < img.rows; ++i) {
for (int j = 0; j < img.cols; ++j) {
Mat sample = (Mat_<float>(1, 2) << j, i);
Mat res;
knn->findNearest(sample, k_value, res);
int response = cvRound(res.at<float>(0, 0));
if (response == 0)
img.at<Vec3b>(i, j) = Vec3b(128, 128, 255);
else if (response == 1)
img.at<Vec3b>(i, j) = Vec3b(128, 255, 128);
else if (response == 2)
img.at<Vec3b>(i, j) = Vec3b(255, 128, 128);
}
}
for (int i = 0; i < train.rows; i++) {
int x = cvRound(train.at<float>(i, 0));
int y = cvRound(train.at<float>(i, 1));
int l = label.at<int>(i, 0);
if (l == 0)
circle(img, Point(x, y), 5, Scalar(0, 0, 128), -1, LINE_AA);
else if (l == 1)
circle(img, Point(x, y), 5, Scalar(0, 128, 0), -1, LINE_AA);
else if (l == 2)
circle(img, Point(x, y), 5, Scalar(128, 0, 0), -1, LINE_AA);
}
imshow("knn", img);
}