K-means

Machine Learning

Introduction

K-means is a simple iterative clustering algorithm. Starting with \( K \) randomly chosen centroids, the algorithm alternately updates the cluster assignments and the centroids until equilibrium, minimizing the total within-cluster variance. It is primarily used with real-valued features because it relies on the Euclidean distance to discover cluster centroids.

Prerequisites

To understand the K-means clustering algorithm, we recommend familiarity with the prerequisite concepts linked in this section.

Follow those links first to get acquainted with the corresponding concepts.

Problem setting

In clustering, the goal of the model is to group the available data into clusters — groups of observations that are similar in some sense.

Consider observations of the form \( \vx \in \real^\ndim \) — vectors consisting of \( \ndim \) features, \(\vx = [x_1, x_2, \ldots, x_\ndim] \). A collection of such observations is provided in the unlabeled set \( \unlabeledset = \set{\vx_1,\ldots,\vx_\nunlabeled} \).

The \(K\)-means clustering algorithm needs to partition these \( \nunlabeled \) examples into \( K \) groups or clusters.
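
As a concrete illustration, the unlabeled set can be stored as an \( \nunlabeled \times \ndim \) array. The following sketch builds a small synthetic dataset with three loose groups; the data, the seed, and the use of NumPy are our own assumptions for illustration, not part of the algorithm.

```python
import numpy as np

# N unlabeled observations, each a vector of d real-valued features,
# stacked into an (N, d) array. The three "blobs" are purely illustrative.
rng = np.random.default_rng(seed=0)
X = np.vstack([
    rng.normal(loc=[0.0, 0.0], scale=0.5, size=(50, 2)),  # loose group near (0, 0)
    rng.normal(loc=[4.0, 0.0], scale=0.5, size=(50, 2)),  # loose group near (4, 0)
    rng.normal(loc=[2.0, 3.0], scale=0.5, size=(50, 2)),  # loose group near (2, 3)
])
print(X.shape)  # (150, 2): N = 150 observations, d = 2 features
```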

Intuition

Here's the intuitive strategy behind the \( K \)-means algorithm.

  1. Randomly pick \( K \) observations from the training set as the initial centroids.
  2. Assign each training observation to its closest centroid, based on Euclidean distance. These assignments form the \( K \) clusters.
  3. Re-estimate the centroid of each of the \( K \) clusters as the mean of the training observations assigned to that cluster.
  4. If the new centroids are significantly different from the previous centroids, go back to step 2 and iterate until convergence.

Note that \( K \)-means is an iterative approach that recomputes centroids and reassigns training examples to the clusters represented by these centroids until the centroids converge.
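
The four steps above translate almost directly into code. Here is a minimal sketch, assuming NumPy; the function name kmeans, the tolerance tol, and the seed are our own choices rather than anything prescribed by the algorithm.

```python
import numpy as np

def kmeans(X, K, max_iters=100, tol=1e-6, seed=0):
    """Minimal K-means sketch following the four steps above."""
    rng = np.random.default_rng(seed)

    # Step 1: randomly pick K observations as the initial centroids.
    centroids = X[rng.choice(len(X), size=K, replace=False)]

    for _ in range(max_iters):
        # Step 2: assign each observation to its closest centroid,
        # using squared Euclidean distance to each of the K centroids.
        distances = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        assignments = distances.argmin(axis=1)

        # Step 3: re-estimate each centroid as the mean of its cluster
        # (keep the old centroid if a cluster happens to become empty).
        new_centroids = np.array([
            X[assignments == k].mean(axis=0) if np.any(assignments == k) else centroids[k]
            for k in range(K)
        ])

        # Step 4: stop once the centroids no longer move appreciably.
        if np.linalg.norm(new_centroids - centroids) < tol:
            centroids = new_centroids
            break
        centroids = new_centroids

    return assignments, centroids
```

Calling assignments, centroids = kmeans(X, K=3) on the toy array from the previous sketch returns a cluster label for every row of X along with the \( K \) fitted centroids.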

Now, let's try to understand why this is expected to work and the mathematical motivation behind it.

K-means: demo

In the next interactive demo, load a dataset, choose a suitable number of clusters \( K \), and fit K-means to the data. We indicate the cluster membership of each data point by its color, outline each cluster with a circle, and mark the inferred centroids with stars.

Note that K-means may result in a different clustering of the same data each time it is run. This is because the centroid locations are randomly initialized each time K-means is fit to the data.

Also observe that the number of clusters \( K \) is a crucial hyperparameter for K-means. While it may be easy to visually identify a suitable number of clusters for this small 2-dimensional dataset, doing so is typically difficult for higher-dimensional datasets.

Demo: Fitting K-means to data
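
This text version has no interactive widget, but the experiment can be reproduced programmatically. A common remedy for the sensitivity to initialization is to run K-means several times and keep the run with the lowest loss. The sketch below does this with scikit-learn, which is our own choice of tool rather than something this article relies on; its KMeans estimator reports the final within-cluster sum of squared distances as inertia_.

```python
import numpy as np
from sklearn.cluster import KMeans  # assumed to be installed; our choice of tool

# Small synthetic dataset, as in the earlier sketch.
rng = np.random.default_rng(seed=0)
X = np.vstack([rng.normal(loc=c, scale=0.5, size=(50, 2)) for c in ([0, 0], [4, 0], [2, 3])])

# n_init random restarts are run internally; the fit with the lowest
# inertia (within-cluster sum of squared distances) is kept.
model = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)
print(model.labels_[:10])      # cluster membership of the first 10 points
print(model.cluster_centers_)  # the K = 3 inferred centroids
print(model.inertia_)          # final within-cluster sum of squared distances
```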

Grouping criterion

The goal of any clustering algorithm is to assign similar observations to the same cluster and keep dissimilar observations in separate clusters. This effectively means that a good clustering should satisfy the following property: if similar observations are grouped within the same cluster, the within-cluster (intra-cluster) distances should be low. An intra-cluster distance is the distance between two points belonging to the same cluster.

This within-cluster distance is a good loss function to minimize for distance-based clustering algorithms. Let's define it.

Let \( C(\nunlabeledsmall) \) denote the cluster assigned by the algorithm to the instance \( \vx_\nunlabeledsmall \) from among the set of clusters \( \sC \). Then, the within-cluster distance \( W(\sC) \) for the clustering is defined as

\begin{equation} W(\sC) = \frac{1}{2} \sum_{k=1}^K \sum_{C(\nunlabeledsmall)=k} \sum_{C(\dash{\nunlabeledsmall})=k} \norm{\vx_\nunlabeledsmall - \vx_{\dash{\nunlabeledsmall}}}{}^2 \label{eqn:loss} \end{equation}

This clustering loss function is also known as within-point scatter.
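
For concreteness, Equation \eqref{eqn:loss} can be evaluated directly by looping over all pairs of points within each cluster. The following is a deliberately naive sketch, assuming NumPy; the function name is ours.

```python
import numpy as np

def within_point_scatter(X, assignments, K):
    """W(C) = 1/2 * sum_k sum_{C(i)=k} sum_{C(j)=k} ||x_i - x_j||^2."""
    total = 0.0
    for k in range(K):
        members = X[assignments == k]
        # All pairwise differences within cluster k, squared and summed.
        diffs = members[:, None, :] - members[None, :, :]
        total += 0.5 * (diffs ** 2).sum()
    return total
```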

Centroids

Centroids or means are prototypes in the feature space whose coordinates are the averages of the points that they represent. This means that a centroid \( \bar{\vx}_k \) for a cluster \( k \) is defined as

$$ \bar{\vx}_k = \frac{1}{\nunlabeled_k} \sum_{C(\nunlabeledsmall)=k} \vx_\nunlabeledsmall $$

where \( \nunlabeled_k \) denotes the number of examples in cluster \( k \), so that

$$ \nunlabeled_k = \sum_{\nunlabeledsmall=1}^{\nunlabeled} \indicator{C(\nunlabeledsmall) = k} $$

Here, \( \indicator{C(\nunlabeledsmall) = k} \) is the indicator function that takes on value 1 if the condition is true and zero otherwise.
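
In code, \( \nunlabeled_k \) is just a per-cluster count and \( \bar{\vx}_k \) a per-cluster mean. A small sketch, again assuming NumPy and an integer assignments array like the one produced by the K-means sketch above:

```python
import numpy as np

def cluster_counts_and_centroids(X, assignments, K):
    """n_k (how many points the indicator 1[C(i) = k] selects) and the centroids."""
    counts = np.bincount(assignments, minlength=K)
    # Mean of the observations assigned to each cluster
    # (assumes every cluster is non-empty).
    centroids = np.array([X[assignments == k].mean(axis=0) for k in range(K)])
    return counts, centroids
```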

The within-point scatter loss function defined earlier in Equation \eqref{eqn:loss} can now be rewritten in terms of the centroids as

\begin{equation} W(\sC) = \sum_{k=1}^K \nunlabeled_k \sum_{C(\nunlabeledsmall)=k} \norm{\vx_\nunlabeledsmall - \bar{\vx}_k}{}^2 \label{eqn:loss-centroids} \end{equation}

Why? Let's find out, in case you are curious.

Proof of clustering loss in terms of centroids

First, note that

\begin{align} \norm{\vx_\nunlabeledsmall - \vx_{\dash{\nunlabeledsmall}}}{}^2 &= \left(\vx_\nunlabeledsmall - \vx_{\dash{\nunlabeledsmall}} \right)^T \left(\vx_\nunlabeledsmall - \vx_{\dash{\nunlabeledsmall}} \right) \\\\ &= \vx_\nunlabeledsmall^T\vx_\nunlabeledsmall- 2\vx_\nunlabeledsmall^T\dash{\vx_\nunlabeledsmall} + \vx_\nunlabeledsmall^{'T}\dash{\vx_\nunlabeledsmall} \end{align}

Now consider the inner sum in \( W(\sC) \).

\begin{align} \sum_{C(\dash{\nunlabeledsmall})=k} \norm{\vx_\nunlabeledsmall - \vx_{\dash{\nunlabeledsmall}}}{}^2 &= \sum_{C(\dash{\nunlabeledsmall})=k} \left[ \vx_\nunlabeledsmall^T\vx_\nunlabeledsmall- 2\vx_\nunlabeledsmall^T\dash{\vx_\nunlabeledsmall} + \vx_\nunlabeledsmall^{'T}\dash{\vx_\nunlabeledsmall} \right] \\\\ &= \nunlabeled_k\vx_\nunlabeledsmall^T\vx_\nunlabeledsmall- 2\vx_\nunlabeledsmall^T \left[\sum_{C(\dash{\nunlabeledsmall})=k} \dash{\vx_\nunlabeledsmall}\right] + \sum_{C(\dash{\nunlabeledsmall})=k} \vx_\nunlabeledsmall^{'T}\dash{\vx_\nunlabeledsmall} \\\\ &= \nunlabeled_k\vx_\nunlabeledsmall^T\vx_\nunlabeledsmall- 2\nunlabeled_k\vx_\nunlabeledsmall^T\bar{\vx}_k + \sum_{C(\dash{\nunlabeledsmall})=k} \vx_\nunlabeledsmall^{'T}\dash{\vx_\nunlabeledsmall} \end{align}

Now let's apply the middle sum, \( \sum_{C(\nunlabeledsmall)=k} \), to this result.

\begin{align} \sum_{C(\nunlabeledsmall)=k} \sum_{C(\dash{\nunlabeledsmall})=k} \norm{\vx_\nunlabeledsmall - \vx_{\dash{\nunlabeledsmall}}}{}^2 &= \sum_{C(\nunlabeledsmall)=k} \left[\nunlabeled_k\vx_\nunlabeledsmall^T\vx_\nunlabeledsmall- 2\nunlabeled_k\vx_\nunlabeledsmall^T\bar{\vx}_k + \sum_{C(\dash{\nunlabeledsmall})=k} \vx_\nunlabeledsmall^{'T}\dash{\vx_\nunlabeledsmall} \right] \\\\ &= \nunlabeled_k \left[ \sum_{C(\nunlabeledsmall)=k} \vx_\nunlabeledsmall^T\vx_\nunlabeledsmall \right] - 2\nunlabeled_k\left[ \sum_{C(\nunlabeledsmall)=k}\vx_\nunlabeledsmall \right]^T\bar{\vx}_k + \nunlabeled_k \left[\sum_{C(\dash{\nunlabeledsmall})=k} \vx_\nunlabeledsmall^{'T}\dash{\vx_\nunlabeledsmall} \right] \\\\ &= 2\nunlabeled_k \left[ \sum_{C(\nunlabeledsmall)=k} \vx_\nunlabeledsmall^T\vx_\nunlabeledsmall \right] - 2\nunlabeled_k\left[\nunlabeled_k\bar{\vx}_k\right]^T\bar{\vx}_k \\\\ &= 2 \nunlabeled_k \left[\sum_{C(\nunlabeledsmall)=k} \vx_\nunlabeledsmall^T\vx_\nunlabeledsmall - \nunlabeled_k\bar{\vx}_k^T\bar{\vx}_k\right] \end{align}

Now, note that

\begin{align} \sum_{C(\nunlabeledsmall)=k} \norm{\vx_\nunlabeledsmall - \bar{\vx}_k}{}^2 &= \sum_{C(\nunlabeledsmall)=k} \left[ \vx_\nunlabeledsmall^T\vx_\nunlabeledsmall- 2\vx_\nunlabeledsmall^T\bar{\vx}_k + \bar{\vx}_k^T \bar{\vx}_k\right] \\\\ &= \sum_{C(\nunlabeledsmall)=k} \vx_\nunlabeledsmall^T\vx_\nunlabeledsmall - 2\left[ \sum_{C(\nunlabeledsmall)=k} \vx_\nunlabeledsmall \right]^T \bar{\vx}_k + \nunlabeled_k \bar{\vx}_k^T \bar{\vx}_k \\\\ &= \sum_{C(\nunlabeledsmall)=k} \vx_\nunlabeledsmall^T\vx_\nunlabeledsmall - 2\nunlabeled_k \bar{\vx}_k^T \bar{\vx}_k + \nunlabeled_k \bar{\vx}_k^T \bar{\vx}_k \\\\ &= \sum_{C(\nunlabeledsmall)=k} \vx_\nunlabeledsmall^T\vx_\nunlabeledsmall - \nunlabeled_k \bar{\vx}_k^T \bar{\vx}_k \end{align}

Substituting this result in the previous expression, we get

\begin{align} \sum_{C(\nunlabeledsmall)=k} \sum_{C(\dash{\nunlabeledsmall})=k} \norm{\vx_\nunlabeledsmall - \vx_{\dash{\nunlabeledsmall}}}{}^2 &=2 \nunlabeled_k \left[\sum_{C(\nunlabeledsmall)=k} \vx_\nunlabeledsmall^T\vx_\nunlabeledsmall - \nunlabeled_k\bar{\vx}_k^T\bar{\vx}_k\right] \\\\ &= 2 \nunlabeled_k \sum_{C(\nunlabeledsmall)=k} \norm{\vx_\nunlabeledsmall - \bar{\vx}_k}{}^2 \end{align}

Thus, the within-point scatter loss function can be written as

\begin{equation} W(\sC) = \sum_{k=1}^K \nunlabeled_k \sum_{C(\nunlabeledsmall)=k} \norm{\vx_\nunlabeledsmall - \bar{\vx}_k}{}^2 \label{eqn:loss-centroids-proved} \end{equation}
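
The equality between Equations \eqref{eqn:loss} and \eqref{eqn:loss-centroids-proved} can also be sanity-checked numerically. The sketch below compares the pairwise form against the centroid form on random data with random assignments; NumPy and the variable names are our own choices.

```python
import numpy as np

rng = np.random.default_rng(seed=1)
X = rng.normal(size=(30, 2))                   # random 2-D observations
assignments = rng.integers(0, 3, size=len(X))  # arbitrary assignment into K = 3 clusters
K = 3

pairwise_form = 0.0  # 1/2 * sum of all within-cluster pairwise squared distances
centroid_form = 0.0  # sum_k n_k * sum_{C(i)=k} ||x_i - centroid_k||^2
for k in range(K):
    members = X[assignments == k]
    if len(members) == 0:
        continue
    diffs = members[:, None, :] - members[None, :, :]
    pairwise_form += 0.5 * (diffs ** 2).sum()
    centroid = members.mean(axis=0)
    centroid_form += len(members) * ((members - centroid) ** 2).sum()

print(np.isclose(pairwise_form, centroid_form))  # True
```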

Why the iterative approach?

Note that, given the within-point scatter loss, our goal is to find the optimal clustering \( \star{C} \) such that the loss is minimized

\begin{equation} \star{C} = \argmin_{C} \sum_{k=1}^K \nunlabeled_k \sum_{C(\nunlabeledsmall)=k} \norm{\vx_\nunlabeledsmall - \bar{\vx}_k}{}^2 \label{eqn:loss-argmin} \end{equation}

Note that for any set of observations \( \sA \), their centroid \( \bar{\vx}_\sA \) is the point with the smallest total squared Euclidean distance to all the points in \( \sA \).

$$ \bar{\vx}_{\sA} = \argmin_{\va} \sum_{\vx \in \sA} \norm{\vx - \va}{}^2 $$
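
Why? Setting the gradient of this objective with respect to \( \va \) to zero gives

\begin{align} \frac{\partial}{\partial \va} \sum_{\vx \in \sA} \norm{\vx - \va}{}^2 &= -2 \sum_{\vx \in \sA} \left(\vx - \va\right) = 0 \\\\ \implies \va &= \frac{1}{|\sA|} \sum_{\vx \in \sA} \vx = \bar{\vx}_{\sA} \end{align}

where \( |\sA| \) denotes the number of observations in \( \sA \).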

Thus, the iterative process is actually just solving the following optimization problem.

\begin{equation} \star{C} = \argmin_{C,\set{\va_k}_{k=1}^K} \sum_{k=1}^K \nunlabeled_k \sum_{C(\nunlabeledsmall)=k} \norm{\vx_\nunlabeledsmall - \va_k}{}^2 \label{eqn:loss-argmin-ak} \end{equation}

Therefore, the iterative steps alternate between two objectives (written out as explicit update rules after this list):

  1. Repeated estimations of the centroids are optimizing for \( \set{\va_k}_{k=1}^K \).
  2. Repeated cluster assignments are optimizing for \( C \).
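
Written as explicit update rules, the two alternating steps are

\begin{align} \va_k &\leftarrow \bar{\vx}_k = \frac{1}{\nunlabeled_k} \sum_{C(\nunlabeledsmall)=k} \vx_\nunlabeledsmall, \quad k = 1, \ldots, K \\\\ C(\nunlabeledsmall) &\leftarrow \argmin_{k} \norm{\vx_\nunlabeledsmall - \va_k}{}^2, \quad \nunlabeledsmall = 1, \ldots, \nunlabeled \end{align}

In the K-means sketch shown earlier, step 3 implements the first update and step 2 implements the second; the loop simply alternates between them until the centroids stop changing.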

Dealing with feature types

Note that the \( K \)-means algorithm relies on Euclidean distance and on centroids, that is, averages of the training instances belonging to a cluster. This makes it suitable for binary and continuous features, since both can be treated as real-valued.

In the case of categorical features, a direct distance calculation is not possible. Therefore, we first need to preprocess the categorical variables using one-hot encoding to arrive at a binary feature representation, although there are better alternatives for clustering categorical or mixed-type features.
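
As a brief illustration of this preprocessing step, the sketch below one-hot encodes a toy categorical column before clustering. The data values are invented, and scikit-learn is our own choice of tool here.

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder  # assumed to be installed

# A toy categorical feature (invented values, for illustration only).
colors = np.array([["red"], ["green"], ["blue"], ["green"], ["red"]])

# Expand each category into a 0/1 indicator column so that Euclidean
# distance and centroid averaging become meaningful for K-means.
X_binary = OneHotEncoder().fit_transform(colors).toarray()
print(X_binary)  # one row per observation, one binary column per category
```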
