An introduction to partial dependence profiles.
library(caret)
library(tidyverse)
library(DALEX)
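# Shared ggplot2 theme for the figures in these notes.
# `linewidth` replaces the `size` argument deprecated in ggplot2 3.4.0.
theme479 <- theme_minimal() +
  theme(
    panel.grid.minor = element_blank(),
    panel.background = element_rect(fill = "#f7f7f7"),
    panel.border = element_rect(fill = NA, color = "#0c0c0c", linewidth = 0.6),
    legend.position = "bottom"
  )
theme_set(theme479)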
As complex models become more common in practice, visualization has emerged as a key way to (a) summarize their essential structure and (b) motivate further modeling refinements.
In modern machine learning, it’s common to use a function $f$ to approximate the relationship between a $D$-dimensional input $x$ and a univariate response $y$. We are given a sample of $n$ pairs $(x_i, y_i)$ with which to learn this relationship, and we hope that the function we learn will generalize to future observations.
Some further notation: We will write $x_j$ for the $j^{th}$ coordinate of $x$. We will write $x^{j|=z}$ to denote the observation $x$ with the $j^{th}$ coordinate set to $z$.
Figure 1: Illustration of the $x^{j|=z}$ operation. The $j^{th}$ coordinate (1 in this case) for a selected observation is set equal to $z$.
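The substitution in Figure 1 can be sketched in a few lines of base R. The helper `set_coordinate` and the example values are hypothetical names chosen for illustration; they are not part of DALEX or of the original notes.

```r
# Hypothetical helper (not a DALEX function): return a copy of the
# observation x whose j-th coordinate has been replaced by z.
set_coordinate <- function(x, j, z) {
  x[[j]] <- z
  x
}

# A single observation with D = 3 features (made-up values).
x <- c(x1 = 0.2, x2 = 1.5, x3 = -0.7)

# Set the first coordinate to z = 3; the other coordinates are untouched,
# and the original x is not modified.
x_new <- set_coordinate(x, j = 1, z = 3)
print(x_new)
```

Because R copies on modification, the original observation is left unchanged, which matches the idea that $x^{j|=z}$ is a new, hypothetical observation rather than an edit to the data.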
Linear models are simple enough that they don’t require any follow-up visual inspection. Since they assume $f(x) = \hat{\beta}^{T} x$, they are completely described by the vector of coefficients $\hat{\beta}$. We can exactly describe what happens to $f$ when we increase $x_j$ by one unit: we just increase the prediction by $\hat{\beta}_j$.
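As a quick check of this claim, the sketch below fits a linear model to simulated data and compares the change in prediction with the fitted coefficient. The data, the seed, and the names `sim` and `fit_lm` are assumptions made for illustration only. The fitted model includes an intercept, which cancels when two predictions are subtracted.

```r
# Simulated data (illustrative only): y depends linearly on x1 and x2.
set.seed(479)
n <- 200
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
sim$y <- 1 + 2 * sim$x1 - 0.5 * sim$x2 + rnorm(n, sd = 0.1)

fit_lm <- lm(y ~ x1 + x2, data = sim)
beta_hat <- coef(fit_lm)

# Take one observation and increase x1 by one unit.
x_old <- sim[1, c("x1", "x2")]
x_new <- x_old
x_new$x1 <- x_new$x1 + 1

# The change in prediction matches the coefficient on x1,
# regardless of which observation we started from.
change <- predict(fit_lm, x_new) - predict(fit_lm, x_old)
all.equal(unname(change), unname(beta_hat["x1"]))
```

Repeating the comparison with a different row of `sim` gives the same change, which is exactly the property that more flexible models lack.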
More complex models — think random forests or neural networks — don’t have this property. While these models often have superior performance, it’s hard to say how changes in particular input features will affect the prediction.
Partial dependence plots try to address this problem. They describe how changing the $j^{th}$ input feature affects the predictions made by complex models.
To motivate the definition, consider the toy example below. The surface is the fitted function $f(x)$, mapping a two-dimensional input $x$ to a real-valued response. How would you summarize the relationship between $x_1$ and $y$? The main problem is that the shape of the relationship depends on which value of $x_2$ we start at.
Figure 2: An example of why it is difficult to summarize the relationship between an input variable and a fitted surface for nonlinear models.
One idea is to consider the values of $x_2$ that were observed in our dataset. Then, we can evaluate our model over a range of values of $x_1$ after fixing those values of $x_2$. These curves are called Ceteris Paribus (CP) profiles.
The same principle holds in higher dimensions. We can fix $D - 1$ coordinates of an observation and then evaluate what happens to a sample’s predictions when we vary coordinate $j$. Mathematically, this is expressed by $h_{x}^{f,j}(z) := f(x^{j|=z})$.
Figure 3: Visual intuition behind the CP profile. Varying the jth coordinate for an observation traces out a curve in the prediction surface.
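The definition of a CP profile can be computed by hand without any modeling package. In the sketch below, `f` is a made-up function standing in for a fitted model, and `cp_profile` is a hypothetical helper; neither comes from DALEX. The function is chosen so that the effect of `x1` depends on `x2`, as in the toy surface above.

```r
# Stand-in for a fitted model's prediction function (an assumption for
# illustration): the effect of x1 depends on the value of x2.
f <- function(x) x[["x2"]] * sin(x[["x1"]])

# Hypothetical helper: for each z in grid, set coordinate j of x to z
# and evaluate f, tracing out one curve on the prediction surface.
cp_profile <- function(f, x, j, grid) {
  sapply(grid, function(z) {
    x[[j]] <- z  # modifies a local copy only
    f(x)
  })
}

grid <- seq(-3, 3, length.out = 50)

# Two observations that differ only in x2.
obs_a <- c(x1 = 0, x2 = 1)
obs_b <- c(x1 = 0, x2 = -2)

cp_a <- cp_profile(f, obs_a, j = "x1", grid = grid)
cp_b <- cp_profile(f, obs_b, j = "x1", grid = grid)

# Same feature, different starting observations, different curves.
head(cbind(z = grid, cp_a = cp_a, cp_b = cp_b))
```

The two profiles have different shapes even though they vary the same feature, which is why a single observation's CP profile cannot summarize a nonlinear model on its own.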
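# The titanic data ships with DALEX; drop country and incomplete rows
data(titanic)
titanic <- select(titanic, -country) %>%
  na.omit()
x <- select(titanic, -survived)

# Fit a gradient boosting machine at a single, fixed hyperparameter setting
hyper <- data.frame(n.trees = 100, interaction.depth = 8, shrinkage = 0.1, n.minobsinnode = 10)
fit <- train(x = x, y = titanic$survived, method = "gbm", tuneGrid = hyper, verbose = FALSE)

# Wrap the model, features, and response in a DALEX explainer
explanation <- explain(model = fit, data = x, y = titanic$survived)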
Preparation of a new explainer is initiated
-> model label : train ( default )
-> data : 2179 rows 7 cols
-> target variable : 2179 values
-> predict function : yhat.train will be used ( default )
-> predicted values : No value for predict function target column. ( default )
-> model_info : package caret , ver. 6.0.94 , task classification ( default )
-> model_info : Model info detected classification task but 'y' is a factor . ( WARNING )
-> model_info : By deafult classification tasks supports only numercical 'y' parameter.
-> model_info : Consider changing to numerical vector with 0 and 1 values.
-> model_info : Otherwise I will not be able to calculate residuals or loss function.
-> predicted values : numerical, min = 0.01904679 , mean = 0.3272747 , max = 0.9865176
-> residual function : difference between y and yhat ( default )
-> residuals : numerical, min = NA , mean = NA , max = NA
A new explainer has been created!
profile <- model_profile(explainer = explanation)
plot(profile, geom = "profiles", variables = "age") +
  theme479
Figure 4: CP and PDP profiles for age, for a GBM fitted to the Titanic dataset.
It seems that children had the highest predicted probability of survival.
We used the DALEX package to produce these curves. The `explain` function takes the fitted model and original dataset as input. It returns an object with many kinds of model summaries. To extract the CP profiles from these summaries, we use `model_profile`. The output of this function has been designed so that calling `plot` with `geom = "profiles"` will show the CP profiles.
The partial dependence (PD) profile is computed by averaging across all the CP profiles. It is a more concise alternative to CP profiles, showing one curve per feature, rather than one curve per sample.
plot(profile, geom = "aggregates") +
  theme479
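To make the averaging step concrete, here is a minimal base R sketch that does not rely on DALEX. One way to write the average is $\frac{1}{n} \sum_{i=1}^{n} f(x_i^{j|=z})$, evaluated over a grid of values $z$. The function `f`, the simulated data `X`, and the grid are all assumptions chosen for illustration, not the Titanic model from above.

```r
# Stand-in for a fitted model, vectorized over the rows of a data frame
# (an assumption for illustration).
f <- function(X) X$x2 * sin(X$x1) + X$x1

# Simulated feature matrix with n = 100 observations.
set.seed(479)
X <- data.frame(x1 = runif(100, -3, 3), x2 = rnorm(100))
grid <- seq(-3, 3, length.out = 25)

# CP profiles for x1: for each grid value z, set x1 to z in every row
# and predict. Rows index observations, columns index grid values.
cp <- sapply(grid, function(z) {
  X_z <- X
  X_z$x1 <- z
  f(X_z)
})

# PD profile: average the CP profiles pointwise over observations.
pd <- colMeans(cp)
head(data.frame(z = grid, pd = pd))
```

Each row of `cp` is one sample's CP profile, and `pd` collapses them into a single curve for the feature. For this particular `f`, the average works out to `mean(X$x2) * sin(grid) + grid`, so the interaction with `x2` is largely averaged away, which illustrates what the more concise summary can hide.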
For attribution, please cite this work as
Sankaran (2024, Jan. 7). STAT 436 (Fall 2024): Partial Dependence Profiles I. Retrieved from https://krisrs1128.github.io/stat436_s24/website/stat436_f24/posts/2024-12-27-week12-1/
BibTeX citation
@misc{sankaran2024partial,
  author = {Sankaran, Kris},
  title = {STAT 436 (Fall 2024): Partial Dependence Profiles I},
  url = {https://krisrs1128.github.io/stat436_s24/website/stat436_f24/posts/2024-12-27-week12-1/},
  year = {2024}
}