Silhouette Statistics

Diagnostics for the quality of a clustering.

Kris Sankaran (UW Madison)
01-07-2024


  1. Clustering algorithms usually require the number of clusters \(K\) as an argument. How should it be chosen?

  2. There are many possible criteria, but one common approach is to compute the silhouette statistic. It is a statistic that can be computed for each observation in a dataset, measuring how strongly it is tied to its assigned cluster. If a whole cluster has large silhouette statistics, then that cluster is well-defined and clearly isolated from other clusters.

  3. The plots below illustrate the computation of silhouette statistics for a clustering of the penguins dataset that used \(K = 3\). To set up, we first need to cluster the penguins dataset. The idea is the same as in the \(K\)-means notes, but we encapsulate the code in a function, so that we can easily extract data for different values of \(K\).

library(tidyverse) # read_csv, dplyr verbs, ggplot2
library(broom)     # augment
library(cluster)   # silhouette

penguins <- read_csv("https://uwmadison.box.com/shared/static/ijh7iipc9ect1jf0z8qa2n3j7dgem1gh.csv") %>%
  na.omit() %>%
  mutate(id = row_number())

cluster_penguins <- function(penguins, K) {
  # standardize the measurements so that no single variable dominates the distances
  x <- penguins %>%
    select(matches("length|depth|mass")) %>%
    scale()

  kmeans(x, centers = K) %>%
    augment(penguins) %>% # creates column ".cluster" with cluster label
    mutate(silhouette = silhouette(as.integer(.cluster), dist(x))[, "sil_width"])
}
  1. Denote the silhouette statistic of observation \(i\) by \(s_{i}\). We will compute \(s_i\) for the observation with the black highlight below.
cur_id <- 2
penguins3 <- cluster_penguins(penguins, K = 3)
obs_i <- penguins3 %>%
  filter(id == cur_id)
ggplot(penguins3, aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster)) +
  geom_point(data = obs_i, size = 5, col = "black") + 
  geom_point() +
  scale_color_brewer(palette = "Set2") +
  scale_size(range = c(4, 1))

Figure 1: The observation on which we will compute the silhouette statistic.

  1. The first step in the calculation of the silhouette statistic is to measure the pairwise distances between observation \(i\) and all other observations in the same cluster. These distances are the lengths of the small lines below. Call the average of these lengths \(a_{i}\).
ggplot(penguins3, aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster)) +
  geom_segment(
    data = penguins3 %>% filter(.cluster == obs_i$.cluster), 
    aes(xend = obs_i$bill_length_mm, yend = obs_i$bill_depth_mm),
    size = 0.6, alpha = 0.3
  ) +
  geom_point(data = obs_i, size = 5, col = "black") + 
  geom_point() +
  scale_color_brewer(palette = "Set2") +
  scale_size(range = c(4, 1)) +
  labs(title = expression(paste("Distances used for ", a[i])))

Figure 2: The average distance between the target observation and all others in the same cluster.
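
As a numeric counterpart to the figure, \(a_{i}\) can be computed directly from a matrix of pairwise distances. The sketch below uses a small made-up dataset rather than the penguins (the points, labels, and variable names are assumptions for illustration). It leaves the target observation out of its own average, which is the convention that `cluster::silhouette` follows.

```r
# Hypothetical toy data: seven points in 2D with known cluster labels.
x <- rbind(
  c(0, 0), c(0, 1), c(1, 0),  # cluster 1
  c(5, 5), c(5, 6),           # cluster 2
  c(10, 0), c(10, 1)          # cluster 3
)
labels <- c(1, 1, 1, 2, 2, 3, 3)

# All pairwise Euclidean distances, as a full symmetric matrix.
d <- as.matrix(dist(x))

# Target observation and the other members of its cluster.
i <- 1
same <- setdiff(which(labels == labels[i]), i)

# a_i: average distance from observation i to its cluster-mates.
a_i <- mean(d[i, same])
a_i
```

Here the two cluster-mates of the first point each sit at distance 1, so the average is 1.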

  1. Next, we compute pairwise distances to all observations in clusters 2 and 3. The averages of these pairwise distances are called \(b_{i2}\) and \(b_{i3}\). Choose the smaller of \(b_{i2}\) and \(b_{i3}\), and call it \(b_{i}\). In a sense, this identifies the “next best” cluster in which to put observation \(i\). For a general \(K\), you would compute \(b_{ik}\) for all \(k\) (other than observation \(i\)’s cluster) and take the minimum across all of them. In this case, the orange segments are on average shorter than the blue segments, so \(b_i\) is defined as the average length of the orange segments.
ggplot(penguins3, aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster)) +
  geom_segment(
    data = penguins3 %>% filter(.cluster != obs_i$.cluster), 
    aes(xend = obs_i$bill_length_mm, yend = obs_i$bill_depth_mm, col = .cluster),
    size = 0.5, alpha = 0.3
  ) +
  geom_point(data = obs_i, size = 5, col = "black") + 
  geom_point() +
  scale_color_brewer(palette = "Set2") +
  scale_size(range = c(4, 1)) +
  labs(title = expression(paste("Distances used for ", b[i2], " and ", b[i3])))

Figure 3: The average distance between the target observation and all others in different clusters.
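
The same idea can be written out numerically: average the distances from observation \(i\) to each of the other clusters, then keep the smallest average. This is a minimal sketch on made-up data (the points, labels, and variable names are assumptions for illustration, not the penguins measurements).

```r
# Hypothetical toy data: seven points in 2D with known cluster labels.
x <- rbind(
  c(0, 0), c(0, 1), c(1, 0),  # cluster 1
  c(5, 5), c(5, 6),           # cluster 2
  c(10, 0), c(10, 1)          # cluster 3
)
labels <- c(1, 1, 1, 2, 2, 3, 3)
d <- as.matrix(dist(x))

# Target observation and the clusters it does not belong to.
i <- 1
others <- setdiff(unique(labels), labels[i])

# b_ik: average distance from observation i to every member of cluster k.
b_ik <- sapply(others, function(k) mean(d[i, labels == k]))
names(b_ik) <- others

# b_i is the smallest of these averages; its cluster is the "next best" one.
b_i <- min(b_ik)
nearest <- names(b_ik)[which.min(b_ik)]
b_ik
```

In this toy example cluster 2 is closer on average than cluster 3, so `b_i` is the average distance to cluster 2.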

  1. The silhouette statistic for observation \(i\) is derived from the relative lengths of the orange vs. green segments. Formally, the silhouette statistic for observation \(i\) is \(s_{i}:= \frac{b_{i} - a_{i}}{\max\left({a_{i}, b_{i}}\right)}\). This number is close to 1 if the orange segments are much longer than the green segments, close to 0 if the segments are about the same size, and close to -1 if the orange segments are much shorter than the green segments.

  2. The median of these \(s_{i}\) for all observations within cluster \(k\) is a measure of how well-defined cluster \(k\) is overall. The higher this number, the more well-defined the cluster.

  3. Denote the median of the silhouette statistics within cluster \(k\) by \(SS_{k}\). A measure of how good a choice of \(K\) is can be obtained from the median of these medians: \(\text{Quality}(K) := \text{median}_{k = 1, \dots, K} SS_{k}\).

  4. In particular, this can be used to define (a) a good cut point in a hierarchical clustering or (b) a point at which a cluster should no longer be split into subgroups.

  5. In R, we can use the silhouette function from the cluster package to compute the silhouette statistic. The syntax is silhouette(cluster_labels, pairwise_distances), where cluster_labels is a vector of (integer) cluster IDs for each observation and pairwise_distances gives the lengths of the segments between all pairs of observations. An example of this function’s usage is given in the cluster_penguins function at the start of the illustration.

  6. This is what the silhouette statistic looks like in the penguins dataset when we choose 3 clusters. The larger points have lower silhouette statistics. The points between clusters 2 and 3 have low silhouette statistics (and so are drawn large) because those two clusters blend into one another.

ggplot(penguins3) +
  geom_point(aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster, size = silhouette)) +
  scale_color_brewer(palette = "Set2") +
  scale_size(range = c(4, 1))

Figure 4: The silhouette statistics on the Palmer Penguins dataset, when using \(K\)-means with \(K = 3\).
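
The formula for \(s_{i}\) can be checked by hand against the `silhouette` function. The sketch below does this on a small made-up dataset; the points, labels, and the helper name `silhouette_by_hand` are assumptions for illustration and are not part of the original notes. It excludes each observation from its own within-cluster average, and it assumes every cluster has at least two members.

```r
library(cluster)

# Hypothetical toy data: three small, well-separated groups in 2D.
x <- rbind(
  c(0, 0), c(0, 1), c(1, 0),
  c(5, 5), c(5, 6),
  c(10, 0), c(10, 1)
)
labels <- c(1L, 1L, 1L, 2L, 2L, 3L, 3L)
d <- as.matrix(dist(x))

# Hypothetical helper: s_i = (b_i - a_i) / max(a_i, b_i) for one observation.
silhouette_by_hand <- function(i, d, labels) {
  same <- setdiff(which(labels == labels[i]), i)
  a_i <- mean(d[i, same])
  others <- setdiff(unique(labels), labels[i])
  b_i <- min(sapply(others, function(k) mean(d[i, labels == k])))
  (b_i - a_i) / max(a_i, b_i)
}

# Apply the formula to every observation, then compare with the package result.
s_manual <- sapply(seq_along(labels), silhouette_by_hand, d = d, labels = labels)
s_package <- silhouette(labels, dist(x))[, "sil_width"]
all.equal(unname(s_manual), unname(s_package))
```

Agreement between the two vectors is a useful sanity check that \(a_{i}\) and \(b_{i}\) are being read the same way as in the package.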

  1. We can also visualize the histogram of silhouette statistics within each cluster. Since the silhouette statistics for cluster 2 are generally lower than those for the other two clusters (in particular, its median is lower), we can conclude that it is less well-defined.
ggplot(penguins3) +
  geom_histogram(aes(x = silhouette), binwidth = 0.05) +
  facet_grid(~ .cluster)

Figure 5: The per-cluster histograms of silhouette statistics summarize how well-defined each cluster is.
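
The per-cluster medians \(SS_{k}\) that the histograms summarize visually can also be computed as a grouped median. The sketch below uses made-up silhouette values and labels (both are assumptions for illustration, not output from the penguins clustering).

```r
# Hypothetical silhouette values and cluster labels, for illustration only.
sil <- c(0.8, 0.7, 0.9, 0.2, 0.1, 0.4, 0.6, 0.5, 0.7)
cluster <- c(1, 1, 1, 2, 2, 2, 3, 3, 3)

# SS_k: median silhouette statistic within each cluster k.
ss_k <- tapply(sil, cluster, median)
ss_k

# The cluster with the smallest median is the least well-defined one.
least_defined <- names(ss_k)[which.min(ss_k)]
least_defined
```

With the `penguins3` data frame from these notes, the analogous summary would be a `group_by(.cluster)` followed by `summarise(median(silhouette))`.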

  1. If we choose even more clusters, then there are more points lying along the boundaries of poorly defined clusters. Their associated silhouette statistics end up becoming smaller, so they are drawn as larger points. From the histogram, we can also see a deterioration in the median silhouette scores across all clusters.
penguins4 <- cluster_penguins(penguins, K = 4)
ggplot(penguins4) +
  geom_point(aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster, size = silhouette)) +
  scale_color_brewer(palette = "Set2") +
  scale_size(range = c(4, 1))

Figure 6: We can repeat the same exercise, but with \(K = 4\) clusters instead.

ggplot(penguins4) +
  geom_histogram(aes(x = silhouette), binwidth = 0.05) +
  facet_grid(~ .cluster)
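
The comparison between \(K = 3\) and \(K = 4\) can be extended to a range of \(K\) by computing \(\text{Quality}(K)\), the median of the per-cluster medians, for each candidate. The sketch below is one way this might look; it runs on simulated data rather than the penguins, and the helper name `quality`, the simulated blobs, and the `nstart = 20` setting are all assumptions for illustration.

```r
library(cluster)

set.seed(1)
# Hypothetical data: three Gaussian blobs in 2D, standing in for scaled measurements.
blob_means <- rbind(c(0, 0), c(6, 0), c(0, 6))
x <- do.call(rbind, lapply(1:3, function(k) {
  cbind(rnorm(50, blob_means[k, 1]), rnorm(50, blob_means[k, 2]))
}))

# Hypothetical helper: Quality(K) is the median, across clusters,
# of the within-cluster median silhouette statistic SS_k.
quality <- function(x, K) {
  fit <- kmeans(x, centers = K, nstart = 20)
  sil <- silhouette(fit$cluster, dist(x))[, "sil_width"]
  ss_k <- tapply(sil, fit$cluster, median)
  median(as.numeric(ss_k))
}

# Evaluate a range of candidate K and tabulate the results.
Ks <- 2:6
q <- sapply(Ks, function(K) quality(x, K))
data.frame(K = Ks, quality = q)
```

Because the data were simulated from three well-separated groups, we would expect the score at \(K = 3\) to be among the highest, though on real data the curve is often much flatter and should be read alongside the plots.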

Citation

For attribution, please cite this work as

Sankaran (2024, Jan. 7). STAT 436 (Spring 2024): Silhouette Statistics. Retrieved from https://krisrs1128.github.io/stat436_s24/website/stat436_s24/posts/2024-12-27-week09-04/

BibTeX citation

@misc{sankaran2024silhouette,
  author = {Sankaran, Kris},
  title = {STAT 436 (Spring 2024): Silhouette Statistics},
  url = {https://krisrs1128.github.io/stat436_s24/website/stat436_s24/posts/2024-12-27-week09-04/},
  year = {2024}
}