붓꽃 종 분류(K-means clustering)

meta fring·2023년 7월 28일
0

붓꽃의 꽃받침 길이와 너비, 꽃잎의 길이와 너비, 그리고 종을 담은 데이터 셋이 있다.
종을 모른다고 가정하고 K-means clustering을 이용해서 군집을 지어보겠다. K-means clustering으로 묶은 군집이 실제 종과 얼마나 일치하는 지를 비교해 볼 것이다.

import pandas as pd

df = pd.read_csv("./data/iris.csv")

판다스를 임포트하여 CSV 파일을 데이터 프레임으로 만든다.
df를 확인해 보면 아래와 같다.

from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
df["color"] = label_encoder.fit_transform(df["Species"])

sklearn.preprocessing 모듈에서 LabelEncoder 클래스를 사용하여 "Species" 컬럼의 값을 숫자로 인코딩한다.

cur_df = df[["PetalLengthCm", "SepalLengthCm", "color"]]
data = cur_df.drop(["color"], axis=1).values

세개의 컬럼을 cur_df에 할당한다.
그리고 data에는 color 컬럼을 제외한 후, 넘파이 배열 형태로 할당한다.

from sklearn.cluster import KMeans

model = KMeans(
    n_init=10,
    n_clusters=3,
)
model.fit(data)

sklearn.cluster 모듈에서 KMeans 클래스를 사용하여 k-평균 군집화(K-means clustering) 모델을 생성한다.
n_init=10으로 설정하여 초기 중심점 설정을 10번 시도하고, n_clusters=3으로 설정하여 3개의 클러스터를 생성한다.
그리고 data를 학습시킨다.

def inference(model, data):
    h = 0.02
    x_min, x_max = data[:, 0].min() - 1, data[:, 0].max() + 1
    y_min, y_max = data[:, 1].min() - 1, data[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]) 
    return xx, yy, Z

k-평균 군집화 모델과 데이터를 받아들이고, 군집화 결과를 시각화하기 위한 그리드 데이터를 생성하는 역할을 수행하는 함수를 정의한다.
시각화할 그리드의 간격을 0.02로 주고,
그리드 범위를 설정하고,
xx, yy 변수에 x_min부터 x_max까지, y_min부터 y_max까지의 범위에서 h 간격으로 그리드 데이터를 생성한다.
생성된 그리드 데이터 xx, yy를 펼쳐서 1차원으로 만든 후, k-평균 군집화 모델 model을 사용하여 군집 레이블 Z를 예측한다.

import numpy as np
from matplotlib import pyplot as plt

def plot_kmeans_cluster(xx, yy, Z, model, data, columns, species):
    labels = model.labels_
    centroids = model.cluster_centers_
    plt.clf()
    Z = Z.reshape(xx.shape)

    plt.scatter(data[:, 0], data[:, 1], c=labels, s=10)
    plt.xlabel(columns[0])
    plt.ylabel(columns[1])
    plt.show()
    
    plt.scatter(data[:, 0], data[:, 1], c=species, s=10)
    plt.xlabel(columns[0])
    plt.ylabel(columns[1])
    plt.show()

labels = model.labels_: 군집화 모델 model에서 예측된 레이블 labels를 가져온다. 각 데이터 포인트가 속한 군집의 레이블이 저장된다.

centroids = model.clustercenters: 군집화 모델 model에서 예측된 클러스터 중심점들을 가져온다.

plt.clf(): 그래프를 초기화한다.

Z = Z.reshape(xx.shape): 군집화 결과인 Z를 xx의 shape에 맞게 재구성한다.

plt.scatter(data[:, 0], data[:, 1], c=labels, s=10): 데이터 포인트들을 산점도로 시각화한다. 각 데이터 포인트의 x와 y 좌표가 data[:, 0]과 data[:, 1]에 저장되어 있고, 각 데이터 포인트의 군집 레이블이 labels로 지정된다.

plt.xlabel(columns[0]), plt.ylabel(columns[1]): x축과 y축의 라벨을 columns에 지정된 컬럼 이름으로 설정한다.

plt.show(): 시각화된 그래프를 화면에 출력한다.

그리고 데이터 포인트들을 종(species) 정보를 이용하여 다시 산점도로 시각화한다.

columns = cur_df.columns
xx, yy, Z = inference(model, data)
plot_kmeans_cluster(xx, yy, Z, model, data, columns, cur_df["color"])

plot_kmeans_cluster로 전달할 파라미터들을 정의하고, plot_kmeans_cluster를 호출하여 시각화한다.

새로운 데이터가 들어왔을 때 어떤 클러스터에 속하는지 분류를 쉽게 할 수 있다면 인간들의 업무에 큰 도움이 될 것이다.
더 정확하고 세밀한 모델을 만들 수 있도록 공부를 더 해야겠다.

profile
긍정적인 개발자

0개의 댓글