Mean Shift Clustering (+implementation)
Mean Shift is an unsupervised clustering algorithm that aims to discover clusters without prior knowledge of labels. It works by repeatedly updating candidate centroids to be the mean of the points within a given region, whose size is set by a parameter called the bandwidth. These candidates are then filtered in a post-processing stage to eliminate near-duplicates, forming the final set of centroids. Therefore, unlike KMeans, we don’t need to choose the number of clusters ourselves.
Simply put, for each data point we take a weighted average of all the data points in the sample, giving more weight to nearby points, and move the point there. This approach has some important advantages:
- It doesn’t require selecting the number of clusters in advance; instead, it only requires a bandwidth to be specified, which can often be estimated automatically from the data.
- It can handle clusters of any shape, whereas k-means (without special extensions) requires clusters to be roughly ball-shaped.
ALGORITHM STEPS:
- For each data point x in the sample X, find the distance between x and every point in X
- Create a weight for each point in X by applying the Gaussian kernel to that point’s distance from x
- This weighting penalizes points that are further away from x
- The rate at which the weights fall towards zero is determined by the bandwidth, which is the standard deviation of the Gaussian
- Update x as the weighted average of all points in X, using the weights from the previous steps
Repeating this update pushes points that are close together even closer, until they sit right next to each other, as shown in the diagram above. Now, let’s move on to implementing the Mean Shift algorithm.
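Before the PyTorch version, the steps above can be sketched in plain Python using only the standard library. This is an illustrative sketch, not the article’s implementation: the helper names `gaussian_weight`, `shift_point` and `mean_shift` and the toy points are made up for the example, and the kernel drops the `1/(bw*sqrt(2*pi))` factor because it cancels once the weights are normalised.

```python
import math

def gaussian_weight(d, bw):
    # Unnormalised Gaussian kernel; the constant factor cancels in the weighted average
    return math.exp(-0.5 * (d / bw) ** 2)

def shift_point(x, points, bw):
    """One step: move x to the Gaussian-weighted mean of all points."""
    weights = [gaussian_weight(math.dist(x, p), bw) for p in points]
    total = sum(weights)
    return tuple(sum(w * p[k] for w, p in zip(weights, points)) / total
                 for k in range(len(x)))

def mean_shift(points, bw, n_iter=10):
    """Shift every point n_iter times, overwriting each one as we go
    (later points see the new positions of earlier ones)."""
    X = [tuple(p) for p in points]
    for _ in range(n_iter):
        for i in range(len(X)):
            X[i] = shift_point(X[i], X, bw)
    return X

# two small, well-separated groups of toy points
pts = [(0.0, 0.0), (0.5, 0.2), (0.2, 0.6), (10.0, 10.0), (10.4, 9.7), (9.8, 10.3)]
shifted = mean_shift(pts, bw=1.0)
for p in shifted:
    print(f"({p[0]:.3f}, {p[1]:.3f})")
```

Each group of three points collapses to (almost) a single location, while the two groups stay far apart because their mutual weights are vanishingly small.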
Implementation:
```python
import math, matplotlib.pyplot as plt, operator, torch
from functools import partial
from torch.distributions.multivariate_normal import MultivariateNormal
from torch import tensor

torch.manual_seed(42)
torch.set_printoptions(precision=3, linewidth=140, sci_mode=False)
```
To generate our data, we pick 6 random points, which we’ll call centroids, and around each of them we generate 250 random points.
```python
n_clusters = 6    # number of centroids
n_samples = 250   # points generated around each centroid

# `,2` gives each centroid an x and a y coordinate, drawn uniformly from [-35, 35)
centroids = torch.rand(n_clusters, 2)*70 - 35

# randomly generate data around each of the six centroids
def sample(m):
    # m: mean of the distribution (one centroid)
    # torch.diag(tensor([5.,5.])): covariance matrix; being diagonal, the x and y
    #   coordinates are independent, each with variance 5
    # .sample((n_samples,)): draw n_samples points
    return MultivariateNormal(m, torch.diag(tensor([5.,5.]))).sample((n_samples,))

slices = [sample(c) for c in centroids]
data = torch.cat(slices)
data.shape  # torch.Size([1500, 2])
```
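The covariance `torch.diag(tensor([5.,5.]))` means each coordinate has variance 5 (a standard deviation of about 2.24) and that x and y are uncorrelated, so every blob is a round cloud with roughly that spread, comparable to the 2.5 bandwidth used later. As a hedged, standard-library stand-in for `MultivariateNormal` with a diagonal covariance (the helper `sample_cluster` is hypothetical), drawing each coordinate independently gives the same distribution:

```python
import math
import random
import statistics

def sample_cluster(centre, variance=5.0, n=250, rng=random):
    # A diagonal covariance diag([v, v]) means x and y are independent normals,
    # each with variance v, i.e. standard deviation sqrt(v).
    sd = math.sqrt(variance)
    return [(rng.gauss(centre[0], sd), rng.gauss(centre[1], sd)) for _ in range(n)]

rng = random.Random(0)
# many samples so the empirical variance lands close to the true value
pts = sample_cluster((3.0, -7.0), n=20000, rng=rng)
xs = [p[0] for p in pts]
ys = [p[1] for p in pts]
print(statistics.variance(xs), statistics.variance(ys))  # both close to 5
```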
To plot the data, we’ll use the following function:
```python
def plot_data(centroids, data, n_samples, ax=None):
    if ax is None: _, ax = plt.subplots()
    for i, centroid in enumerate(centroids):
        # the points generated around centroid i form one contiguous slice of data
        samples = data[i*n_samples:(i+1)*n_samples]
        ax.scatter(samples[:,0], samples[:,1], s=1)
        # a black cross with a smaller magenta cross on top marks the centroid
        ax.plot(*centroid, markersize=10, marker="x", color='k', mew=5)
        ax.plot(*centroid, markersize=5, marker="x", color='m', mew=2)

plot_data(centroids, data, n_samples)
```
Our data looks a bit like this:
Now, let’s implement mean shift.
```python
midp = data.mean(0)  # mean of all points, per coordinate
midp
```
Our initial mean is at the centre of all the data:
```python
plot_data([midp]*6, data, n_samples)
```
For weighting, we’ll use a typical Gaussian kernel with a bandwidth of 2.5:
```python
def gaussian(d, bw):
    # Gaussian density with standard deviation bw, evaluated at distance d
    return torch.exp(-0.5*((d/bw))**2) / (bw*math.sqrt(2*math.pi))
```
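To get a feel for the bandwidth, it helps to evaluate the kernel at a few distances. Relative to a point’s weight for itself (distance 0), a point one bandwidth away gets e^(-0.5) ≈ 0.61, two bandwidths away e^(-2) ≈ 0.14, and three bandwidths away e^(-4.5) ≈ 0.011. The sketch below re-implements the same `gaussian` formula with the `math` module (so it runs without PyTorch) and checks that the normalising constant does not change the weighted average, since it appears in both the numerator and the denominator; `weighted_mean` is a hypothetical helper.

```python
import math

def gaussian(d, bw):
    # Same formula as the torch version above, for a single float distance d
    return math.exp(-0.5 * (d / bw) ** 2) / (bw * math.sqrt(2 * math.pi))

def weighted_mean(values, weights):
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

bw = 2.5
for d in [0.0, 2.5, 5.0, 7.5]:
    # weight relative to the point's own weight at distance 0
    rel = gaussian(d, bw) / gaussian(0.0, bw)
    print(f"d = {d}: weight = {gaussian(d, bw):.4f}, relative = {rel:.3f}")

# 1-D example seen from x = 0: weighted mean with and without the constant
xs = [0.0, 1.0, 4.0]
w_norm = [gaussian(abs(v), bw) for v in xs]
w_raw = [math.exp(-0.5 * (abs(v) / bw) ** 2) for v in xs]
print(weighted_mean(xs, w_norm), weighted_mean(xs, w_raw))  # equal up to rounding
```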
And the mean shift algorithm goes like this:
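```python
def one_update(X):
    for i, x in enumerate(X):
        # Euclidean distance from x to every point in X
        dist = torch.sqrt(((x-X)**2).sum(1))
        # Gaussian weight for each point: the closer, the heavier
        weight = gaussian(dist, 2.5)
        # move x to the weighted average of all points (overwrites X in place)
        X[i] = (weight[:,None]*X).sum(0)/weight.sum()

def meanshift(data):
    X = data.clone()  # work on a copy so the original data stays untouched
    for it in range(5): one_update(X)
    return X
```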
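One design detail worth noticing: `one_update` overwrites `X[i]` inside the loop, so points processed later in the same pass already see the new positions of earlier points, and the result can depend on the order of the points. A common alternative computes every new position from a frozen copy of the previous positions. The plain-Python sketch below shows that variant (the function names are hypothetical, and the kernel normalisation is omitted because it cancels); since each step depends only on the old positions, reordering the input does not change the result.

```python
import math

def kernel(d, bw):
    # unnormalised Gaussian; the constant cancels in the weighted average
    return math.exp(-0.5 * (d / bw) ** 2)

def one_update_simultaneous(X, bw=2.5):
    """Return new positions computed from a frozen copy of X,
    instead of overwriting X[i] during the loop."""
    old = [tuple(p) for p in X]
    new = []
    for x in old:
        w = [kernel(math.dist(x, p), bw) for p in old]
        s = sum(w)
        new.append(tuple(sum(wi * p[k] for wi, p in zip(w, old)) / s
                         for k in range(len(x))))
    return new

def meanshift_simultaneous(data, bw=2.5, n_iter=5):
    X = [tuple(p) for p in data]
    for _ in range(n_iter):
        X = one_update_simultaneous(X, bw)
    return X

# usage on toy points
print(meanshift_simultaneous([(0.0, 0.0), (1.0, 0.5), (8.0, 8.0), (8.5, 7.6)], bw=1.0))
```

In PyTorch, this snapshot form is also the one that can be written as a single batched computation over all points, since no point’s update depends on another point’s new position.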
```python
# %time is an IPython/Jupyter magic that reports how long the statement takes
%time X = meanshift(data)
```
```
CPU times: user 453 ms, sys: 0 ns, total: 453 ms
Wall time: 452 ms
```
Visualization of the overall process:
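The introduction mentions a post-processing stage that removes near-duplicate candidates, but the code above stops at the shifted points `X`. A minimal sketch of that step, assuming a simple greedy merge with a radius equal to the 2.5 bandwidth (the helper `merge_centroids` and the toy points are hypothetical, not the article’s code), could look like this:

```python
import math

def merge_centroids(points, radius):
    """Greedy merge: each shifted point joins the first centroid within `radius`,
    otherwise it starts a new centroid. Returns (centroids, labels)."""
    centroids, counts, labels = [], [], []
    for p in points:
        for j, c in enumerate(centroids):
            if math.dist(p, c) < radius:
                # fold p into the running mean of centroid j
                n = counts[j] + 1
                centroids[j] = tuple(ci + (pi - ci) / n for ci, pi in zip(c, p))
                counts[j] = n
                labels.append(j)
                break
        else:  # no existing centroid was close enough
            centroids.append(tuple(p))
            counts.append(1)
            labels.append(len(centroids) - 1)
    return centroids, labels

# toy "already shifted" points: two tight groups
shifted = [(0.01, 0.0), (0.0, 0.02), (-0.01, 0.01), (9.99, 10.0), (10.01, 10.02)]
cents, labels = merge_centroids(shifted, radius=2.5)
print(len(cents), labels)  # → 2 [0, 0, 0, 1, 1]
```

With the article’s tensor, something like `merge_centroids(X.tolist(), radius=2.5)` would return one centroid per group of converged points plus a cluster label for each original point. Greedy merging depends on point order; it behaves well here because shifted points in the same cluster end up much closer together than the merge radius.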