k-means
는 Unsupervised Learning
에 속하는 Clustering Algorithm
이다.
Unsupervised Learning
은 Training Data에 Label 이 없기 때문에, 비슷한 군집 끼리 분류를 해야한다.
클러스터를 정의하는 방법에는 많은 방법이 있지만, 여기서 소개할 k-means는 클러스터와 데이터가 가까운것을 찾는다.
이 문제는 NP-hard 로써 각 클러스터와 데이터간의 유클리드 거리를 최소로 하는 집합을 찾는것이 목표이다.
알고리즘은 간단하다.
위 수식을 설명하면, K개의 클러스터와 모든 데이터의 거리의 최소값을 찾는다는 뜻이다.
where뒤의 식은 mj는 클러스터의 중심점인데, 무게중심을 개선한다는 뜻이다.
실제로 k-means 자체는 간단하다.
이를 간단하게 구현해 보았다.
kmeans.pde
/*
* @Project : k-means visualization
* @Architecture : Kim Bom
* kmeans.pde
*
* @Created by KimBom On 2016. 06. 12...
* @Copyright (C) 2016 KimBom. All rights reserved.
*/
import java.util.LinkedList;
import java.awt.Point;
LinkedList<Point> list=new LinkedList<Point>(); //point set
final int padding=50; //for drawing
final int radius=3; //for drawing
final int pt_size=5000;
final int K=7; //K-means, num of cluster
Cluster[] cluster=new Cluster[K];
//maximan num of cluster is 7
color[] rgbs=new color[]{
color(241, 95, 95),
color(250, 237, 125),
color(134, 229, 127),
color(92, 209, 229),
color(67, 116, 217),
color(217, 65, 197),
color(255, 255, 255)
};
boolean isEnd=false;
void setup() {
size(900, 900);
//init point position
for (int i=0; i<pt_size; i++) {
int x=(int)random(padding, width-padding);
int y=(int)random(padding, width-padding);
Point p=new Point(x, y);
list.add(p);
}
//set cluster;
try {
for (int i=0; i<K; i++) {
int x=(int)random(padding, width-padding);
int y=(int)random(padding, width-padding);
cluster[i]=new Cluster(x, y, rgbs[i]);
}
}
catch(ArrayIndexOutOfBoundsException e) {
e.printStackTrace();
isEnd=true;
}
frameRate(5);
}
void draw() {
background(33, 33, 33);
if (!isEnd) {
SetCluster();
} else {
fill(255,255,255);
textSize(20);
text("k-means end",20,30);
}
DrawClusterPoint();
if (!isEnd && ResetCluster()) {
isEnd=true;
}
}
void DrawClusterPoint() {
for (int i=0; i<K; i++) {
fill(cluster[i].rgb);
cluster[i].draw_x(radius);
for (Point p : cluster[i].set) {
ellipse(p.x, p.y, radius, radius);
}
}
}
//calculate Euclidean distance
float getDistance(float x, float y) {
return pow(x, 2)+pow(y, 2);
}
void SetCluster() {
for (Cluster c : cluster) {
c.set.clear();
}
for (Point p : list) {
int min_idx=0;
float min_value=Float.MAX_VALUE;
for (int i=0; i<cluster.length; i++) {
float d=getDistance(p.x-cluster[i].base.x, p.y-cluster[i].base.y);
if (d<min_value) {
min_idx=i;
min_value=d;
}
}
cluster[min_idx].set.add(p);
}
}
boolean ResetCluster() {
//center position : sum(x),sum(y) div size(x,y)
boolean b=true;
for (int i=0; i<K; i++) {
int sum_x=0;
int sum_y=0;
for (Point p : cluster[i].set) {
sum_x+=p.x;
sum_y+=p.y;
}
sum_x/=cluster[i].set.size();
sum_y/=cluster[i].set.size();
if (sum_x!=cluster[i].base.x || sum_y!=cluster[i].base.y) {
cluster[i].base=new Point(sum_x, sum_y);
b&=false;
} else {
b&=true;
}
}
return b;
}
cluster.pde
/*
* @Project : k-means visualization
* @Architecture : Kim Bom
* cluster.pde
*
* @Created by KimBom On 2016. 06. 12...
* @Copyright (C) 2016 KimBom. All rights reserved.
*/
import java.util.LinkedList;
import java.awt.Point;
class Cluster {
public LinkedList<Point> set;
public Point base;
public color rgb;
public Cluster(int x, int y,color c) {
this.base=new Point(x,y);
this.set=new LinkedList<Point>();
rgb=c;
}
public void draw_x(int radius){
stroke(red(rgb),green(rgb),blue(rgb));
line(this.base.x-radius,this.base.y-radius,this.base.x+radius,this.base.y+radius);
line(this.base.x+radius,this.base.y-radius,this.base.x-radius,this.base.y+radius);
}
};