import numpy as np
import matplotlib.pyplot as plt
from sklearn_extra.cluster import KMedoids
from sklearn.datasets import make_blobs
# Synthetic 2-D dataset: 500 samples generated around 3 blob centers.
X, _ = make_blobs(
    n_samples=500,
    n_features=2,
    centers=3,
    cluster_std=0.9,
    random_state=0,
)

# Fit a 3-cluster k-medoids model with the Euclidean metric, then predict a
# cluster label for every sample in X.
model = KMedoids(n_clusters=3, metric="euclidean")
model.fit(X)
b_pred = model.predict(X)

# Samples colored by predicted label; cluster centers drawn as red X markers.
medoids = model.cluster_centers_
plt.scatter(X[:, 0], X[:, 1], c=b_pred, cmap="viridis")
plt.scatter(medoids[:, 0], medoids[:, 1], c="red", marker="X", s=125)
plt.grid()
plt.show()

from sklearn.mixture import GaussianMixture
# Fit a 3-component Gaussian mixture on X, then predict a component label
# for every sample.
gmm = GaussianMixture(random_state=42, n_components=3).fit(X)
labels = gmm.predict(X)

# Samples colored by predicted component; component means as red X markers.
component_means = gmm.means_
plt.scatter(X[:, 0], X[:, 1], c=labels, cmap="viridis")
plt.scatter(
    component_means[:, 0],
    component_means[:, 1],
    marker="X",
    s=200,
    c="red",
)
plt.grid()
plt.show()

from sklearn.cluster import KMeans
from yellowbrick.cluster import KElbowVisualizer
# Elbow analysis for k-means on X.
# NOTE(review): the (1, 10) tuple is presumably the range of candidate k
# values scanned by the visualizer -- confirm the bounds against yellowbrick.
elbow_k = (1, 10)
model = KMeans(random_state=42)
visualizer = KElbowVisualizer(model, k=elbow_k)
visualizer.fit(X)
visualizer.show()

# Original data
# Seed NumPy's global RNG before drawing the three synthetic groups.
np.random.seed(42)

# 200 points from a 2-D Gaussian at the origin with correlated axes.
cov = [[10, 8], [8, 10]]
ellipse_cluster = np.random.multivariate_normal(mean=[0, 0], cov=cov, size=200)
# 200 points around (25, 0) with standard deviation 0.5 on each axis.
circle_cluster = np.random.normal(loc=[25, 0], scale=0.5, size=(200, 2))
# 10 points around (10, 5) with standard deviation 0.5 on each axis.
outlier_cluster = np.random.normal(loc=[10, 5], scale=0.5, size=(10, 2))

# Stack all groups into the single sample matrix used by the clustering code.
X = np.vstack((ellipse_cluster, circle_cluster, outlier_cluster))

# Plot each group with its own color/label; only the outliers get an X marker.
for group, style in (
    (ellipse_cluster, {"color": "blue", "label": "ellipse_cluster"}),
    (circle_cluster, {"color": "green", "label": "circle_cluster"}),
    (outlier_cluster, {"color": "red", "marker": "X", "label": "outlier_cluster"}),
):
    plt.scatter(group[:, 0], group[:, 1], s=20, **style)
plt.title("Original Data")
plt.legend()
plt.grid(True)
plt.show()

# Comparison
# Number of clusters/components shared by all three algorithms.
# NOTE: 2 would be the correct choice here, but 3 was requested.
k = 3

# Fit each algorithm on the same data with the same seed.
kmeans = KMeans(n_clusters=k, random_state=42).fit(X)
kmeans_labels = kmeans.labels_
kmedoids = KMedoids(n_clusters=k, random_state=42).fit(X)
kmedoids_labels = kmedoids.labels_
gmm = GaussianMixture(n_components=k, random_state=42)
gmm.fit(X)
gmm_labels = gmm.predict(X)

# Visualize the results: one panel per algorithm, samples colored by assigned
# label and the fitted centers (means for the GMM) drawn as red X markers.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
panels = (
    ("k-means", kmeans_labels, kmeans.cluster_centers_),
    ("k-medoids", kmedoids_labels, kmedoids.cluster_centers_),
    ("Gaussian Mixture Model (GMM)", gmm_labels, gmm.means_),
)
for ax, (title, point_labels, centers) in zip(axes, panels):
    ax.scatter(X[:, 0], X[:, 1], c=point_labels, cmap="viridis", s=20)
    ax.scatter(
        centers[:, 0],
        centers[:, 1],
        c="red",
        marker="X",
        s=150,
        label="Centers",
    )
    ax.set_title(title)
    ax.grid(True)
    ax.legend()
plt.suptitle("Clustering Algorithm Comparison", fontsize=16)
plt.show()

# Generate swiss roll data.
# Imported here because the file's top-level imports only bring in make_blobs
# from sklearn.datasets; without this the call below raises NameError.
from sklearn.datasets import make_swiss_roll

# make_swiss_roll is called without random_state, so NumPy's global RNG is
# seeded first (presumably what keeps the sample reproducible).
np.random.seed(42)
X, _ = make_swiss_roll(n_samples=1500, noise=0.1)

# Original data - visualization.
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.scatter(X[:, 0], X[:, 1], X[:, 2], s=10, color="gray")
plt.show()

# Initial k setting, shared by all three algorithms.
k = 5

# k-means
kmeans = KMeans(n_clusters=k, random_state=42).fit(X)
kmeans_labels = kmeans.labels_

# k-medoids
kmedoids = KMedoids(n_clusters=k, random_state=42).fit(X)
kmedoids_labels = kmedoids.labels_

# GMM
gmm = GaussianMixture(n_components=k, random_state=42)
gmm.fit(X)
gmm_labels = gmm.predict(X)

# Visualization: one 3-D panel per algorithm (colored by assigned label),
# followed by the uncolored data for reference.
fig = plt.figure(figsize=(18, 5))
panels = (
    ("k-means", {"c": kmeans_labels, "cmap": "viridis"}),
    ("k-medoids", {"c": kmedoids_labels, "cmap": "viridis"}),
    ("gmm", {"c": gmm_labels, "cmap": "viridis"}),
    ("Swiss-Roll", {"color": "gray"}),
)
for position, (title, color_kwargs) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 4, position, projection="3d")
    ax.scatter(X[:, 0], X[:, 1], X[:, 2], s=10, **color_kwargs)
    ax.set_title(title)
plt.suptitle("3D Swiss-Roll : k-means, k-medoids, gmm vs. original", fontsize=16)
plt.tight_layout()
plt.show()


# It is important to compare the algorithms above and choose the clustering
# technique that fits the characteristics of the data.