Diagnostics for the quality of a clustering.
Clustering algorithms usually require the number of clusters \(K\) as an argument. How should it be chosen?
There are many possible criteria, but one common approach is to compute the silhouette statistic. It can be computed for each observation in a dataset and measures how strongly that observation is tied to its assigned cluster. If a whole cluster has large silhouette statistics, then that cluster is well-defined and clearly isolated from other clusters.
The plots below illustrate the computation of silhouette statistics for a clustering of the penguins dataset that used \(K = 3\). To set up, we first need to cluster the penguins dataset. The idea is the same as in the \(K\)-means notes, but we encapsulate the code in a function, so that we can easily extract data for different values of \(K\).
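library(tidyverse) # read_csv(), dplyr verbs, ggplot2
library(broom)     # augment() for kmeans fits
library(cluster)   # silhouette()
penguins <- read_csv("https://uwmadison.box.com/shared/static/ijh7iipc9ect1jf0z8qa2n3j7dgem1gh.csv") %>%
  na.omit() %>%
  mutate(id = row_number())
cluster_penguins <- function(penguins, K) {
  # standardize every column whose name mentions length, depth, or mass
  x <- penguins %>%
    select(matches("length|depth|mass")) %>%
    scale()
  kmeans(x, centers = K) %>%
    augment(penguins) %>% # creates column ".cluster" with cluster label
    mutate(silhouette = silhouette(as.integer(.cluster), dist(x))[, "sil_width"])
}
penguins3 <- cluster_penguins(penguins, K = 3)
obs_i <- penguins3 %>% slice(1) # observation i to highlight (hypothetical choice; any single row works)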
ggplot(penguins3, aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster)) +
geom_point(data = obs_i, size = 5, col = "black") +
geom_point() +
scale_color_brewer(palette = "Set2") +
scale_size(range = c(4, 1))
ggplot(penguins3, aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster)) +
geom_segment(
data = penguins3 %>% filter(.cluster == obs_i$.cluster),
aes(xend = obs_i$bill_length_mm, yend = obs_i$bill_depth_mm),
linewidth = 0.6, alpha = 0.3
) +
geom_point(data = obs_i, size = 5, col = "black") +
geom_point() +
scale_color_brewer(palette = "Set2") +
scale_size(range = c(4, 1)) +
labs(title = expression(paste("Distances used for ", a[i])))
ggplot(penguins3, aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster)) +
geom_segment(
data = penguins3 %>% filter(.cluster != obs_i$.cluster),
aes(xend = obs_i$bill_length_mm, yend = obs_i$bill_depth_mm, col = .cluster),
linewidth = 0.5, alpha = 0.3
) +
geom_point(data = obs_i, size = 5, col = "black") +
geom_point() +
scale_color_brewer(palette = "Set2") +
scale_size(range = c(4, 1)) +
labs(title = expression(paste("Distances used for ", b[i][1], " and ", b[i][2])))
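The silhouette statistic for observation \(i\) is derived from the relative lengths of the orange vs. green segments. Let \(a_{i}\) be the average distance from \(i\) to the other observations in its own cluster, and let \(b_{i} = \min\left(b_{i1}, b_{i2}\right)\), where \(b_{i1}\) and \(b_{i2}\) are the average distances from \(i\) to the observations in each of the two other clusters. Formally, the silhouette statistic for observation \(i\) is \(s_{i} := \frac{b_{i} - a_{i}}{\max\left(a_{i}, b_{i}\right)}\). This number is close to 1 if the orange segments are much longer than the green segments, close to 0 if the segments are about the same size, and close to -1 if the orange segments are much shorter than the green segments.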
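To make the definition concrete, here is a minimal sketch that computes \(a_{i}\), \(b_{i}\), and \(s_{i}\) directly from a distance matrix and checks the result against silhouette() from the cluster package. The helper manual_silhouette() and the one-dimensional toy data are hypothetical, chosen only for illustration; the convention \(s_{i} = 0\) for singleton clusters follows the cluster package.

```r
library(cluster)  # silhouette(), used here only to check the hand computation

# Hypothetical helper: s_i = (b_i - a_i) / max(a_i, b_i), computed from scratch.
#   a_i: mean distance from i to the other members of its own cluster
#   b_i: smallest mean distance from i to the members of any other cluster
manual_silhouette <- function(labels, D) {
  D <- as.matrix(D)
  n <- length(labels)
  s <- numeric(n)  # starts at 0, so singleton clusters keep s_i = 0
  for (i in seq_len(n)) {
    same <- labels == labels[i] & seq_len(n) != i
    if (!any(same)) next
    a_i <- mean(D[i, same])
    others <- setdiff(unique(labels), labels[i])
    b_i <- min(sapply(others, function(k) mean(D[i, labels == k])))
    s[i] <- (b_i - a_i) / max(a_i, b_i)
  }
  s
}

# Toy data: two tight groups on a line, plus one point (at 3) halfway between them
x <- c(0, 0.5, 1, 5, 5.5, 6, 3)
labels <- c(1L, 1L, 1L, 2L, 2L, 2L, 1L)
D <- dist(x)

s_manual <- manual_silhouette(labels, D)
s_check <- unname(silhouette(labels, D)[, "sil_width"])
round(s_manual, 3)
all.equal(s_manual, s_check)
```

The point at 3 lies exactly halfway between the two groups, so \(a_{i} = b_{i} = 2.5\) and its silhouette is 0, while every point inside the tight groups has a silhouette above 0.7.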
The median of these \(s_{i}\) for all observations within cluster \(k\) is a measure of how well-defined cluster \(k\) is overall. The higher this number, the more well-defined the cluster.
Denote the median of the silhouette statistics within cluster \(k\) by \(SS_{k}\). The quality of a choice of \(K\) can then be measured by the median of these medians: \(\text{Quality}(K) := \text{median}_{k = 1, \dots, K} SS_{k}\).
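The quality measure can be sketched as a small function of the cluster labels and the pairwise distances. This is an illustration under stated assumptions: quality_K() is a hypothetical helper, and the data are three simulated Gaussian blobs rather than the penguins, so the numbers will not match the plots in these notes.

```r
library(cluster)  # silhouette()

# Hypothetical helper: Quality(K) = median over clusters k of SS_k,
# where SS_k is the median silhouette within cluster k.
quality_K <- function(labels, D) {
  s <- silhouette(labels, D)[, "sil_width"]
  SS_k <- tapply(s, labels, median)  # one median per cluster
  median(SS_k)
}

# Assumed stand-in data: three Gaussian blobs of 50 points each in 2D
set.seed(1)
centers <- rbind(c(0, 0), c(4, 0), c(0, 4))
x <- do.call(rbind, lapply(1:3, function(k) {
  matrix(rnorm(100, mean = rep(centers[k, ], each = 50)), ncol = 2)
}))
D <- dist(x)

# Run k-means for several K and score each clustering
Ks <- 2:6
quality <- sapply(Ks, function(K) {
  labels <- kmeans(x, centers = K, nstart = 10)$cluster
  quality_K(labels, D)
})
data.frame(K = Ks, quality = quality)
Ks[which.max(quality)]
```

Larger values indicate a better choice of \(K\). With well-separated blobs like these one would expect the peak near \(K = 3\), but the exact values depend on the seed and the k-means initialization.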
In particular, this can be used to define (a) a good cut point in a hierarchical clustering or (b) a point at which a cluster should no longer be split into subgroups.
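For use (a), one way to pick a cut point is to cut a dendrogram into several numbers of clusters with cutree() and keep the cut that maximizes \(\text{Quality}(K)\). The sketch below assumes average-linkage hclust() and two simulated groups; both are illustrative choices, not part of the original notes.

```r
library(cluster)  # silhouette()

# Illustrative data (an assumption): two simulated groups of 30 points each
set.seed(2)
x <- rbind(
  matrix(rnorm(60, mean = 0, sd = 0.5), ncol = 2),
  matrix(rnorm(60, mean = 3, sd = 0.5), ncol = 2)
)
D <- dist(x)
tree <- hclust(D, method = "average")

# For each candidate cut, compute Quality(K): the median over clusters
# of the within-cluster median silhouette
Ks <- 2:5
quality <- sapply(Ks, function(K) {
  labels <- cutree(tree, k = K)
  s <- silhouette(labels, D)[, "sil_width"]
  median(tapply(s, labels, median))
})
best_K <- Ks[which.max(quality)]
best_labels <- cutree(tree, k = best_K)  # cluster assignments at the chosen cut
```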
In R, we can use the silhouette function from the cluster package to compute the silhouette statistic. The syntax is silhouette(cluster_labels, pairwise_distances), where cluster_labels is a vector of integer cluster IDs, one per observation, and pairwise_distances gives the lengths of the segments between all pairs of observations (for example, a dist object returned by dist()). An example of this function's usage is given in the cluster_penguins function at the start of the illustration.
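As a minimal usage sketch with hypothetical toy points, the object returned by silhouette() is a matrix with one row per observation. Besides sil_width, it records each observation's own cluster and its neighbor, the nearest other cluster (the one that defines \(b_{i}\)).

```r
library(cluster)

# Hypothetical toy data: five points in 2D with hand-assigned labels
pts <- matrix(c(0, 0,
                0, 1,
                1, 0,
                5, 5,
                5, 6), ncol = 2, byrow = TRUE)
cluster_labels <- c(1L, 1L, 1L, 2L, 2L)  # integer cluster IDs
pairwise_distances <- dist(pts)          # Euclidean distances between all pairs

sil <- silhouette(cluster_labels, pairwise_distances)
colnames(sil)                   # "cluster", "neighbor", "sil_width"
sil[, "sil_width"]              # one s_i per observation
summary(sil)$clus.avg.widths    # per-cluster means (these notes use medians instead)
```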
This is what the silhouette statistic looks like in the penguins dataset when we choose 3 clusters. Larger points have lower silhouette statistics. The points between clusters 2 and 3 have low silhouette statistics because those two clusters blend into one another.
ggplot(penguins3) +
geom_point(aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster, size = silhouette)) +
scale_color_brewer(palette = "Set2") +
scale_size(range = c(4, 1))
ggplot(penguins3) +
geom_histogram(aes(x = silhouette), binwidth = 0.05) +
facet_grid(~ .cluster)
penguins4 <- cluster_penguins(penguins, K = 4)
ggplot(penguins4) +
geom_point(aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster, size = silhouette)) +
scale_color_brewer(palette = "Set2") +
scale_size(range = c(4, 1))
ggplot(penguins4) +
geom_histogram(aes(x = silhouette), binwidth = 0.05) +
facet_grid(~ .cluster)
For attribution, please cite this work as
Sankaran (2024, Jan. 7). STAT 436 (Spring 2024): Silhouette Statistics. Retrieved from https://krisrs1128.github.io/stat436_s24/website/stat436_s24/posts/2024-12-27-week09-04/
BibTeX citation
@misc{sankaran2024silhouette, author = {Sankaran, Kris}, title = {STAT 436 (Spring 2024): Silhouette Statistics}, url = {https://krisrs1128.github.io/stat436_s24/website/stat436_s24/posts/2024-12-27-week09-04/}, year = {2024} }