source("https://github.com/krisrs1128/stat479_notes/raw/refs/heads/master/activities/05-helpers.R")
library(mlr3)
library(mlr3learners)
library(mlr3tuning)
library(patchwork)
library(shapviz)
library(tidyverse)
theme_set(theme_classic())
set.seed(20251227)
options(mlr3.store_backends = TRUE)
SHAP Analysis of Drosophila Development
Problem Formulation
Enhancers are DNA sequences that coordinate downstream gene expression. Their activity depends on genome state, like transcription factor binding and chromatin state.
Understanding enhancers means understanding how genotype becomes phenotype, a fundamental problem in biology. To make the problem tractable, we usually study it in simple organisms, like the fruit fly Drosophila melanogaster, often during embryo development.
Basu et al. (2018) asked a precise question: given the genomic state of a fruit fly embryo at a fixed developmental stage, can we predict whether a given enhancer is active? They used genome-wide measurements from blastoderm (stage 5) embryos:
- DNA occupancy for 23 transcription factors
- Activity for 13 chromatin markers
- Binary enhancer activity labels
This is a binary classification problem. Each observation is a genomic region with associated regulatory features.
- Enhancer status: y \in \{0,1\}
- Predictors: x = (x_1, \ldots, x_D), the transcription factor binding intensities and chromatin signals
- Model: f(x_i) = \mathbb{P}(y_i = 1 \mid x_i)
SHAP quantifies each feature’s contribution to f(x_i).
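To make the idea of an attribution concrete, here is a minimal sketch of exact Shapley values for a toy three-feature function, not the enhancer model. The names `f_toy`, `bg`, `coalition_value`, and `shapley_exact` are hypothetical and introduced only for illustration. The value of a coalition is taken to be the average prediction when features outside the coalition are drawn from a background sample, which is one common convention and an assumption here.

```r
# Toy model with an interaction term; a hypothetical stand-in for f(x).
f_toy <- function(X) X[, 1] + 2 * X[, 2] + X[, 1] * X[, 3]

# Value of coalition S at point x: fix the features in S at their values in x,
# leave the remaining features at their background values, and average.
coalition_value <- function(f, x, bg, S) {
  Z <- bg
  if (length(S) > 0) Z[, S] <- rep(x[S], each = nrow(bg))
  mean(f(Z))
}

# Exact Shapley values by enumerating every subset of the other features.
shapley_exact <- function(f, x, bg) {
  D <- length(x)
  phi <- numeric(D)
  for (d in seq_len(D)) {
    others <- setdiff(seq_len(D), d)
    for (mask in 0:(2^length(others) - 1)) {
      S <- others[bitwAnd(mask, as.integer(2^(seq_along(others) - 1))) > 0]
      w <- factorial(length(S)) * factorial(D - length(S) - 1) / factorial(D)
      gain <- coalition_value(f, x, bg, c(S, d)) - coalition_value(f, x, bg, S)
      phi[d] <- phi[d] + w * gain
    }
  }
  phi
}

set.seed(1)
bg <- matrix(rnorm(300), ncol = 3)  # simulated background data
x <- c(1, -0.5, 2)                  # the point being explained
phi <- shapley_exact(f_toy, x, bg)
phi0 <- mean(f_toy(bg))             # baseline: average prediction over the background

# Attributions plus baseline recover the prediction at x.
all.equal(sum(phi) + phi0, f_toy(matrix(x, nrow = 1)))
```

Enumeration costs 2^(D-1) coalition evaluations per feature, so it is only feasible for a handful of features. With 80 predictors, an approximation or a model-specific algorithm is needed.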
Setup
Of the packages loaded at the top of this document, the first three (mlr3, mlr3learners, mlr3tuning) train the boosting model. The rest are for visualization and SHAP explanation.
This block downloads the data. Enhancer status is stored in the first column, y.
f <- tempfile()
download.file("https://zenodo.org/records/18371236/files/enhancer.Rdata?download=1", f)
load(f)
enhancer <- bind_cols(y = as.factor(Y), X)
enhancer
# A tibble: 7,809 × 81
y H3_c12 H3_c14a H3_c14c H3_c8 H3K18ac_c12 H3K18ac_c14a H3K18ac_c14c
<fct> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 0 6 10 10 8 5 7 8
2 0 6 7 6 5 19 10 11
3 0 14 9 14 13 11 9 19
4 0 13 17 13 13 19 14 25
5 1 6 6 6 10 13 23 28
6 1 10 9 9 11 12 23 28
7 0 6 8 8 10 11 13 24
8 0 8 10 16 6 7 19 16
9 0 16 20 11 10 34 8 15
10 0 16 16 16 11 13 20 22
# ℹ 7,799 more rows
# ℹ 73 more variables: H3K18ac_c8 <dbl>, H3K27ac_c12 <dbl>, H3K27ac_c14a <dbl>,
# H3K27ac_c14c <dbl>, H3K27ac_c8 <dbl>, H3K27me3_c12 <dbl>,
# H3K27me3_c14a <dbl>, H3K27me3_c14c <dbl>, H3K36me3_c12 <dbl>,
# H3K36me3_c14a <dbl>, H3K36me3_c14c <dbl>, H3K4me1_c12 <dbl>,
# H3K4me1_c14a <dbl>, H3K4me1_c14c <dbl>, H3K4me1_c8 <dbl>,
# H3K4me3_c12 <dbl>, H3K4me3_c14a <dbl>, H3K4me3_c14c <dbl>, …
Model Training
We train a gradient boosting classifier (XGBoost) on the enhancer data. The tuning grid covers the number of trees, nrounds, and the learning rate, eta, and hyperparameters are selected by three-fold cross-validation.
task <- as_task_classif(enhancer, target = "y", id = "enhancer", positive = "1")
learner <- lrn(
  "classif.xgboost",
  nrounds = to_tune(100, 200),
  eta = to_tune(1e-3, 0.1),
  predict_type = "prob"
)
instance <- ti(
  task = task,
  learner = learner,
  resampling = rsmp("cv", folds = 3),
  measures = msr("classif.auc"),
  terminator = trm("none")
)
We run the hyperparameter tuning next.
tuner <- tnr("grid_search", resolution = 5, batch_size = 5)
tuner$optimize(instance)
       eta nrounds learner_param_vals  x_domain classif.auc
     <num>   <int>             <list>    <list>       <num>
1: 0.02575     100          <list[5]> <list[2]>    0.825632
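The selected eta of 0.02575 is itself a grid point. The sketch below reconstructs the search space in base R, assuming that grid search with resolution = 5 places five equally spaced values across each tuned range. That is an assumption about how the grid is built, and `grid` is a hypothetical name.

```r
resolution <- 5

# Assumed grid: equally spaced values per hyperparameter, fully crossed.
grid <- expand.grid(
  nrounds = round(seq(100, 200, length.out = resolution)),
  eta = seq(1e-3, 0.1, length.out = resolution)
)

nrow(grid)        # 25 candidate configurations
unique(grid$eta)  # the second value is 0.02575, up to floating point
```

Under this assumption, each of the 25 configurations is evaluated on three folds, for 75 model fits in total.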
Let’s refit the model with the optimized hyperparameters.
learner$param_set$values <- instance$result_learner_param_vals
task <- as_task_classif(enhancer, target = "y", id = "enhancer", positive = "1")
learner$train(task)
The histogram of predicted probabilities shows that the model behaves as expected: inactive enhancers (class 0) cluster near zero probability, while active enhancers (class 1) spread toward higher probabilities. Still, the two classes overlap noticeably.
# Column "1" of the probability matrix is the positive (active) class.
tibble(p_hat = learner$predict_newdata(enhancer)$prob[, "1"], y = enhancer$y) |>
  ggplot() +
  geom_histogram(aes(p_hat, fill = y)) +
  facet_grid(y ~ ., scales = "free_y")
Interpretation Setup
Entry (i, d) of shp$S is the SHAP attribution \varphi_d(x_i), feature d’s contribution to f(x_i).
X_mat <- as.matrix(X)
shp <- shapviz(learner$model, X_pred = X_mat)
shp$S[1:5, 1:8]
H3_c12 H3_c14a H3_c14c H3_c8 H3K18ac_c12
VT0002 0.068839401 -0.034920834 -0.008018139 -0.07555990 0.0327704698
VT0003 0.064319402 -0.004164241 0.007724053 -0.03716381 -0.0258688685
VT0004 0.083743654 -0.037574559 -0.007068704 -0.07883391 0.0002055851
VT0005 0.055233277 0.009904596 -0.002665190 -0.03974437 -0.0256465226
VT0006 -0.004646949 -0.001682470 -0.053047363 -0.02233944 -0.0076694093
H3K18ac_c14a H3K18ac_c14c H3K18ac_c8
VT0002 0.041389603 0.0031723969 0.011215613
VT0003 0.016818035 0.0003347002 0.008839345
VT0004 0.041915886 -0.0014176624 0.001084273
VT0005 -0.001247977 0.0006725105 0.016408261
VT0006 -0.009059546 -0.0014770469 0.007148674
The efficiency axiom holds, though on the log-odds scale: \sum_d \varphi_d(x_i) + \varphi_0 = \text{logit}\,\mathbb{P}(y_i = 1 \mid x_i) = \text{logit}\, f(x_i). Most feature importance measures lack this sum-to-prediction property, which helps SHAP stand out.
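As a small illustration of why the scale matters, the sketch below uses a toy logistic model with simulated data, not the boosting model. For a linear log-odds model with features treated as independent, the SHAP attribution of feature d is its coefficient times the feature's deviation from its mean. All names here (`X_toy`, `beta`, `intercept`) are hypothetical.

```r
set.seed(1)
X_toy <- matrix(rnorm(200), ncol = 2, dimnames = list(NULL, c("a", "b")))
beta <- c(a = 1.5, b = -0.8)
intercept <- -0.3

# Model output on the log-odds scale.
log_odds <- drop(intercept + X_toy %*% beta)

# Linear-model attributions: coefficient times centered feature value.
centered <- sweep(X_toy, 2, colMeans(X_toy), "-")
phi <- sweep(centered, 2, beta, "*")
phi0 <- mean(log_odds)  # baseline: average log-odds

# Efficiency holds for log-odds.
all.equal(rowSums(phi) + phi0, log_odds, check.attributes = FALSE)

# Probabilities are recovered only after applying the inverse logit to the sum.
p_hat <- plogis(rowSums(phi) + phi0)
```

Because the inverse logit is nonlinear, the individual attributions do not add up on the probability scale. The same check on the fitted boosting model follows.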
y_hat <- predict(learner$model, X_mat, outputmargin = TRUE)
tibble(y_hat = y_hat, sum_phi = rowSums(shp$S) + shp$baseline) |>
  ggplot() +
  geom_point(aes(y_hat, sum_phi)) +
  labs(
    x = "Model Prediction (Log-Odds)",
    y = "Sum of SHAP Attributions"
  )
Variable-Level Plots
The beeswarm plot ranks features by their mean absolute attribution, the average of \left|\varphi_d(x_i)\right| over samples. The wt_ZLD, input_c14c, and sna1 features stand out. For most features, the attributions have a consistent sign across samples, but the tails show that the sign can flip, perhaps when the feature takes an unusual value or appears in an unusual context.
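A ranking by mean absolute attribution can be sketched on a small made-up matrix. `S_toy` is hypothetical and unrelated to the enhancer results. The example also shows why absolute values are taken before averaging: a feature whose attributions flip sign can have a column sum near zero and still matter a great deal.

```r
# Rows are samples, columns are features; the values are invented for illustration.
S_toy <- cbind(
  f1 = c(0.5, 0.4, 0.6, -0.1),
  f2 = c(0.9, -0.9, 0.8, -0.8),
  f3 = c(0.1, 0.0, -0.1, 0.1)
)

# Mean absolute attribution per feature, largest first.
importance <- sort(colMeans(abs(S_toy)), decreasing = TRUE)
names(importance)  # "f2" "f1" "f3"

# Signed sums cancel for f2, which would hide its influence.
colSums(S_toy)
```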
sv_importance(shp, kind = "bee")
The dependence plot for sna1 shows that high sna1 TF binding is associated with positive attributions.
sv_dependence(shp, "sna1")
This is consistent with the data, where larger sna1 values accompany higher enhancer activity. The sna1 gene is generally a transcriptional repressor, but in these data, locations where sna1 is bound are more likely to be active enhancers.
tibble(log_odds = predict(learner$model, X_mat, outputmargin = TRUE), y = enhancer$y, x = enhancer$sna1) |>
  ggplot() +
  geom_point(aes(x, log_odds, col = y)) +
  labs(x = "Value of SNA1 Feature", y = "Log-Odds", col = "True Class")
Sample-Level Plots
Waterfall plots break a single prediction down into feature contributions. Each bar is one feature, and tracking the endpoints across bars shows how the model arrives at its output by weighing evidence across many features.
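The endpoint bookkeeping behind a waterfall can be sketched with a cumulative sum. The baseline and attributions below are made-up numbers, and the bar order is arbitrary. How a plotting function orders or collapses bars is not modeled here.

```r
baseline <- -1.2                # hypothetical average log-odds
features <- c("f1", "f2", "f3")
phi <- c(0.8, -0.3, 0.5)        # hypothetical attributions for one sample

# Each bar starts where the previous one ended.
bar_end <- baseline + cumsum(phi)
bar_start <- c(baseline, head(bar_end, -1))

data.frame(feature = features, start = bar_start, end = bar_end)

# The last endpoint is the prediction: baseline plus the sum of all attributions.
prediction <- tail(bar_end, 1)
```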
map(seq_len(6), \(i) sv_waterfall(shp, max_display = 8, row = i)) |>
  reduce(`+`)
Waterfall plots don’t scale well to many samples. A stacked bar alternative solves this issue: each vertical bar below is one sample, colored by the SHAP attributions for the 11 most important features. Samples are sorted so that those with similar attribution profiles are side by side. The “x” marks give the actual log-odds predictions, and segments above or below zero indicate positive or negative attributions, respectively.
The block below calls our helper functions (in 05-helpers.R) to perform this ordering.
var_order <- order(abs(colSums(shp$S)), decreasing = TRUE)
important_vars <- colnames(shp$S)[var_order] |>
  head(11)
ix <- sample(seq_len(nrow(enhancer)), 500)
shap_bars <- shap_to_long(shp$S[ix, ], y_hat[ix], shp$baseline, important_vars) |>
  add_waterfall_offsets()
This view shows that samples with nearly identical predictions can have different reasons for those predictions. For example, kr1 drives high predictions for a cluster of sequences near the left, but has barely any influence elsewhere.
ggplot(shap_bars) +
  geom_hline(yintercept = 0) +
  geom_rect(aes(xmin = xpos - 0.45, xmax = xpos + 0.45, ymin = ymin, ymax = ymax, fill = name, col = name)) +
  geom_point(
    data = distinct(shap_bars, xpos, pred),
    aes(xpos, pred),
    size = 3,
    shape = "x"
  ) +
  labs(y = "Prediction relative to E[f(x)]", x = "Samples", fill = "Feature", col = "Feature") +
  scale_x_continuous(expand = c(0, 0)) +
  scale_fill_brewer(palette = "Set3", na.value = "#dcdcdc") +
  scale_color_brewer(palette = "Set3", na.value = "#dcdcdc") +
  theme(
    axis.ticks = element_blank(),
    axis.text.x = element_blank()
  )
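The helper functions that order the samples are not shown in this document. One plausible way to place similar attribution profiles side by side is hierarchical clustering, sketched below on simulated attributions. This is an assumed stand-in, not necessarily what 05-helpers.R does, and `S_toy` is a hypothetical matrix.

```r
set.seed(1)

# Two simulated groups of samples with opposite attribution profiles.
S_toy <- rbind(
  matrix(rnorm(20, mean = 1, sd = 0.1), ncol = 2),
  matrix(rnorm(20, mean = -1, sd = 0.1), ncol = 2)
)
S_toy <- S_toy[sample(nrow(S_toy)), ]  # shuffle so the groups are interleaved

# Cluster samples on their attribution vectors and read off the leaf order.
hc <- hclust(dist(S_toy), method = "average")
S_sorted <- S_toy[hc$order, ]

# After sorting, each group occupies a contiguous run of positions.
sign(S_sorted[, 1])
```

Plotting the bars in the order given by `hc$order` keeps samples with similar explanations adjacent, which is what makes clusters such as the kr1 group visible.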