Mean Shift Clustering is a non-parametric, iterative algorithm that locates the maxima of a density function given discrete data sampled from that function. It is used to detect the underlying distribution of data points in a multi-dimensional space and to identify clusters accordingly. Unlike methods that require the number of clusters to be specified in advance, Mean Shift repeatedly moves a window of fixed bandwidth toward dense areas of data points. This makes it particularly effective when the number of clusters is not known beforehand and when clusters have arbitrary shapes.
The core idea of Mean Shift is to update candidates for centroids to the mean of the points within a given region, iteratively shifting this window towards the peak of each cluster's density. This process is based on the concept of kernel density estimation (KDE), a fundamental technique for probability density estimation. Given a set of data points in a $d$-dimensional space, Mean Shift considers the feature space as a probability density function and seeks modes (local maxima) of this function.
KDE is used to estimate the probability density function (PDF) of a random variable in a non-parametric way. For a given kernel $K$ and bandwidth $h$, the KDE at point $x$ is given by:
$$\hat{f}(x) = \frac{1}{n h^d} \sum_{i=1}^{n} K\left(\frac{x - x_i}{h}\right)$$
where $x_1, \dots, x_n$ are the data points, $h$ is the bandwidth, and $K$ is the kernel function, typically a Gaussian kernel. The choice of bandwidth is crucial as it controls the smoothness of the estimated density.
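The KDE formula can be sketched directly in NumPy. This is a minimal illustration, not the library's implementation: the function name `gaussian_kde_at` and the sample data are assumptions made for the example, and a normalized Gaussian kernel is assumed.

```python
import numpy as np

def gaussian_kde_at(x, data, h):
    """Evaluate a Gaussian KDE at a single point x (hypothetical helper).

    x    : array of shape (d,)
    data : array of shape (n, d)
    h    : bandwidth, a positive float
    """
    n, d = data.shape
    u = (x - data) / h                                # scaled offsets, shape (n, d)
    # Normalized Gaussian kernel evaluated at every scaled offset
    k = np.exp(-0.5 * np.sum(u ** 2, axis=1)) / (2 * np.pi) ** (d / 2)
    return k.sum() / (n * h ** d)

# Illustrative data: 500 points drawn from a standard 2-D normal
rng = np.random.default_rng(0)
data = rng.normal(loc=0.0, scale=1.0, size=(500, 2))

dense = gaussian_kde_at(np.array([0.0, 0.0]), data, h=0.5)   # near the mode
sparse = gaussian_kde_at(np.array([4.0, 4.0]), data, h=0.5)  # far in the tail
print(dense, sparse)
```

The estimate near the origin should be much larger than the estimate in the tail, which is the density contrast Mean Shift climbs.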
The mean shift vector at any point $x$ in the feature space is defined as the difference between the weighted mean of the data points within a window centered at $x$ and the current center $x$. Mathematically, for a given kernel $K$, the mean shift vector $m(x)$ is computed as:
$$m(x) = \frac{\sum_{x_i \in N(x)} K\left(\frac{x_i - x}{h}\right) x_i}{\sum_{x_i \in N(x)} K\left(\frac{x_i - x}{h}\right)} - x$$
where $N(x)$ denotes the neighborhood of points around $x$, determined by the bandwidth $h$. The kernel function $K$ assigns weights to points in the neighborhood, with points closer to $x$ receiving higher weights.
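A single mean shift vector can be computed as in the sketch below. It assumes Gaussian weights and takes the neighborhood to be all points within one bandwidth of the current center; both choices, and the name `mean_shift_vector`, are assumptions for illustration rather than the library's definition.

```python
import numpy as np

def mean_shift_vector(x, data, h):
    """Weighted mean of the neighbors of x, minus x (hypothetical helper)."""
    dist = np.linalg.norm(data - x, axis=1)
    in_window = dist <= h                 # neighborhood: assumed to be radius h
    if not in_window.any():
        return np.zeros_like(x)           # no neighbors, so no shift
    # Gaussian weights: closer points count more
    w = np.exp(-0.5 * (dist[in_window] / h) ** 2)
    weighted_mean = (w[:, None] * data[in_window]).sum(axis=0) / w.sum()
    return weighted_mean - x

# Four corners of the unit square
square = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

m_center = mean_shift_vector(np.array([0.5, 0.5]), square, h=2.0)
m_corner = mean_shift_vector(np.array([0.0, 0.0]), square, h=2.0)
print(m_center, m_corner)
```

At the center of the square all four points are equidistant, so the vector is zero; from a corner the vector points into the square, toward the bulk of the data.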
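A popular choice for the kernel function is the Gaussian kernel, defined as:
$$K(u) = (2\pi)^{-d/2} \exp\left(-\frac{\lVert u \rVert^2}{2}\right)$$
where $u = (x_i - x)/h$ is the offset between a data point and the current center, scaled by the bandwidth. The normalizing constant cancels in the mean shift vector, so only the exponential term affects the update. Using this kernel, nearby points dominate the weighted mean, so each update moves the centroid towards the densest part of the cluster.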
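To make the effect of the bandwidth concrete, the short sketch below evaluates unnormalized Gaussian weights at a few distances from the current center. The bandwidth value and the chosen distances are arbitrary, picked only for illustration.

```python
import numpy as np

h = 2.0                                             # illustrative bandwidth
distances = np.array([0.0, h, 2 * h, 3 * h])        # distance from the center

# Unnormalized Gaussian weight for each distance
weights = np.exp(-0.5 * (distances / h) ** 2)

for dist, w in zip(distances, weights):
    print(f"distance {dist:.1f} -> weight {w:.4f}")
# Weights are roughly 1.0, 0.61, 0.14, and 0.01
```

A point three bandwidths away contributes about one percent of the weight of a point at the center, which is why the update is driven by the local neighborhood.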
The algorithm converges when the mean shift vector's magnitude falls below a small threshold $\epsilon$, indicating that the centroid has reached a local density maximum:
$$\lVert m(x) \rVert < \epsilon$$
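The stopping rule can be seen in a minimal loop that shifts one seed until the update becomes smaller than a tolerance. This is a sketch under stated assumptions: `shift_to_mode` is a hypothetical helper, Gaussian weights are applied to all points rather than a truncated neighborhood, and the two-blob data is synthetic.

```python
import numpy as np

def shift_to_mode(seed, data, h, tol=1e-3, max_iter=300):
    """Shift a seed point until the mean shift vector is shorter than tol."""
    x = np.asarray(seed, dtype=float)
    n_iter = 0
    for n_iter in range(1, max_iter + 1):
        u = (data - x) / h
        w = np.exp(-0.5 * np.sum(u ** 2, axis=1))    # Gaussian weights
        m = (w[:, None] * data).sum(axis=0) / w.sum() - x
        x = x + m                                     # move to the weighted mean
        if np.linalg.norm(m) < tol:                   # convergence check
            break
    return x, n_iter

# Two well-separated synthetic blobs
rng = np.random.default_rng(42)
blob_a = rng.normal(loc=[0.0, 0.0], scale=0.5, size=(200, 2))
blob_b = rng.normal(loc=[10.0, 10.0], scale=0.5, size=(200, 2))
data = np.vstack([blob_a, blob_b])

mode, n_iter = shift_to_mode([1.0, 1.0], data, h=1.5)
print(mode, n_iter)
```

A seed started near the first blob should settle close to that blob's center in a handful of iterations. A full clustering would repeat this for many seeds and merge the modes that coincide.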
bandwidth : float, default = 2.0
max_iter : int, default = 300
tol : float, default = 0.001
Test on synthesized dataset with 3 blobs:
from luma.clustering.density import MeanShiftClustering
from luma.visual.evaluation import ClusterPlot
from luma.metric.clustering import SilhouetteCoefficient

from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

# Generate 600 two-dimensional points around 3 centers
X, y = make_blobs(n_samples=600,
                  centers=3,
                  cluster_std=2.0,
                  random_state=42)

# Fit Mean Shift; the number of clusters is not specified
mshift = MeanShiftClustering(bandwidth=5,
                             max_iter=300,
                             tol=0.001)
mshift.fit(X)

# Left panel: cluster assignments; right panel: silhouette plot
fig = plt.figure(figsize=(10, 5))
ax1 = fig.add_subplot(1, 2, 1)
ax2 = fig.add_subplot(1, 2, 2)

clst = ClusterPlot(mshift, X)
clst.plot(ax=ax1)

sil = SilhouetteCoefficient(X, mshift.labels)
sil.plot(ax=ax2, show=True)
Mean Shift Clustering is widely used in various domains, including: