Discovering richer structure in partial dependence profiles.
library(tidyverse)
library(caret)
library(DALEX)
theme479 <- theme_minimal() +
  theme(
    panel.grid.minor = element_blank(),
    panel.background = element_rect(fill = "#f7f7f7"),
    panel.border = element_rect(fill = NA, color = "#0c0c0c", size = 0.6),
    legend.position = "bottom"
  )
theme_set(theme479)
We say that there is an interaction between variables \(j\) and \(j^{\prime}\) if the relationship between \(x_{j}\) and \(y\) is modulated by variable \(j^{\prime}\). For example, in the figure below, the slope of cross-sections across \(j\) depends on \(j^{\prime}\).
Using the language of CP profiles, the figure above means that the shape of the CP profile in \(j\) depends on the particular setting of \(j^{\prime}\). This motivates the use of Grouped PD profiles — we compute several PD profiles in \(j\), restricting attention to CP profiles whose value \(x_{j^{\prime}}\) lies within a prespecified range.
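As a concrete sketch of this idea, the code below computes a grouped PD profile by averaging CP profiles within each level of a second variable. It uses simulated data with illustrative names (`x1`, `x2`, `f_hat`) rather than the Titanic example, so the two grouped curves have opposite slopes by construction:

```r
library(tidyverse)

# A model with an interaction: the slope in x1 depends on x2.
# All names here are illustrative, not part of the Titanic example.
f_hat <- function(x1, x2) ifelse(x2 == "a", 2 * x1, -2 * x1)

set.seed(479)
samples <- tibble(
  x1 = runif(200),
  x2 = sample(c("a", "b"), 200, replace = TRUE)
)

# CP profiles: for each sample, vary x1 over a grid, holding x2 fixed
cp <- expand_grid(samples, x1_grid = seq(0, 1, length.out = 11)) %>%
  mutate(y_hat = f_hat(x1_grid, x2))

# Grouped PD profiles: average the CP profiles within each level of x2.
# The two averaged curves have opposite slopes, revealing the interaction.
grouped_pd <- cp %>%
  group_by(x2, x1_grid) %>%
  summarise(pd = mean(y_hat), .groups = "drop")
```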
To illustrate, we revisit the CP profiles for age from the Titanic dataset. Below, the profiles are grouped according to the class of the ticket holder. The result shows that the relationship between age and survival was not the same across all passengers. For all classes, there was a decrease in survival probability for adults, but the dropoff was most severe for crew members.
# download a previously fitted model for the Titanic data
f <- tempfile()
download.file("https://uwmadison.box.com/shared/static/nau695mppsoxx0f6bns1ieo7kh1bje0j.rda", f)
fit <- get(load(f))

# prepare the data and wrap the fitted model in a DALEX explainer
data(titanic)
titanic <- titanic %>%
  select(-country) %>%
  na.omit()
x <- select(titanic, -survived)
explanation <- explain(model = fit, data = x, y = titanic$survived)
Preparation of a new explainer is initiated
-> model label : train ( default )
-> data : 2179 rows 7 cols
-> target variable : 2179 values
-> predict function : yhat.train will be used ( default )
-> predicted values : No value for predict function target column. ( default )
-> model_info : package caret , ver. 6.0.93 , task classification ( default )
-> model_info : Model info detected classification task but 'y' is a factor . ( WARNING )
-> model_info : By deafult classification tasks supports only numercical 'y' parameter.
-> model_info : Consider changing to numerical vector with 0 and 1 values.
-> model_info : Otherwise I will not be able to calculate residuals or loss function.
-> predicted values : numerical, min = 0.007271972 , mean = 0.3239071 , max = 0.9885397
-> residual function : difference between y and yhat ( default )
-> residuals : numerical, min = NA , mean = NA , max = NA
A new explainer has been created!
# grouped PD profiles for age, grouped by ticket class
profiles <- model_profile(explainer = explanation, groups = "class")
plot(profiles, geom = "profiles", variables = "age") +
  scale_color_brewer(palette = "Set2") +
  theme479
What should we do if there are many input variables and we don’t have a priori knowledge about which variables \(j^{\prime}\) might be interacting with \(j\)? One idea is to try to discover relevant interactions by clustering the original set of CP profiles.
In more detail, we can compute the CP profiles for all the samples and then check whether there are subsets of profiles that all look similar. If we find features \(j^{\prime}\) that characterize these groupings, then we have found features that interact with \(j\) (with respect to the fitted model \(f\)). The plot below shows the same profiles as above, but with the grouping discovered by clustering rather than specified in advance. It seems to recover the interaction between age and class, even though we have not explicitly provided this grouping variable.
# cluster the CP profiles for age into three groups
profiles <- model_profile(explainer = explanation, variables = "age", k = 3)
plot(profiles, geom = "profiles", variables = "age") +
  scale_color_brewer(palette = "Set2") +
  theme479
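Under the hood, the `k = 3` argument asks for the CP profiles themselves to be clustered. A minimal sketch of the idea (not DALEX's exact implementation; the synthetic profiles below are illustrative) treats each profile as a vector and applies K-means:

```r
# Each row is one sample's CP profile evaluated on a common grid of ages.
# Synthetic profiles: half drop off steeply with age, half stay flat.
set.seed(479)
grid <- seq(0, 80, length.out = 25)
steep <- t(sapply(runif(50, 0.7, 0.9), function(p0) p0 * exp(-0.03 * grid)))
flat <- t(sapply(runif(50, 0.2, 0.4), function(p0) rep(p0, length(grid))))
profile_mat <- rbind(steep, flat)

# K-means on the stacked profiles recovers the two shapes; features that
# line up with the cluster labels are candidate interaction partners.
clusters <- kmeans(profile_mat, centers = 2)$cluster
```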
PD profiles can also be used to compare models. To illustrate, we fit both a linear model and a random forest to the apartments dataset and contrast their profiles.
# fit a linear model to the apartments data and compute its PD profiles
data(apartments)
x <- select(apartments, -m2.price)
profiles_lm <- train(x, apartments$m2.price, method = "lm") %>%
  explain(x, apartments$m2.price, label = "LM") %>%
  model_profile()
Preparation of a new explainer is initiated
-> model label : LM
-> data : 1000 rows 5 cols
-> target variable : 1000 values
-> predict function : yhat.train will be used ( default )
-> predicted values : No value for predict function target column. ( default )
-> model_info : package caret , ver. 6.0.93 , task regression ( default )
-> predicted values : numerical, min = 1781.848 , mean = 3487.019 , max = 6176.032
-> residual function : difference between y and yhat ( default )
-> residuals : numerical, min = -247.4728 , mean = 2.093656e-14 , max = 469.0023
A new explainer has been created!
# fit a random forest on the same data, for comparison
profiles_rf <- train(x, apartments$m2.price, method = "rf", tuneGrid = data.frame(mtry = 10)) %>%
  explain(x, apartments$m2.price, label = "RF") %>%
  model_profile()
Preparation of a new explainer is initiated
-> model label : RF
-> data : 1000 rows 5 cols
-> target variable : 1000 values
-> predict function : yhat.train will be used ( default )
-> predicted values : No value for predict function target column. ( default )
-> model_info : package caret , ver. 6.0.93 , task regression ( default )
-> predicted values : numerical, min = 1667.05 , mean = 3486.418 , max = 6385.49
-> residual function : difference between y and yhat ( default )
-> residuals : numerical, min = -222.6328 , mean = 0.6008956 , max = 263.8863
A new explainer has been created!
# overlay the PD profiles from both models
plot(profiles_lm, profiles_rf, variables = c("surface", "floor")) +
  scale_color_brewer(palette = "Set2") +
  theme479
When comparing a simple and a complex model, certain discrepancies might become apparent. For example, important nonlinearities or interactions might be visible from the PD profiles of the complex model. This information can guide the design of new features for the simpler model, so that it can continue to be used. This is exactly the case in the apartments dataset above — there is a strong nonlinear relationship for the construction year variable. This suggests that, if a linear model is still desired, a new feature should be defined that indicates whether the apartment was built between 1935 and 1990.
plot(profiles_lm, profiles_rf, variables = "construction.year") +
  scale_color_brewer(palette = "Set2") +
  theme479
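To follow through on that suggestion, a derived feature might be added like this (a sketch: the indicator name `mid_century` and the 1935/1990 cutoffs, read off the profile by eye, are assumptions):

```r
library(tidyverse)
library(DALEX)
data(apartments)

# Indicator for the flat mid-century region of the construction.year
# profile, then refit the linear model with this derived feature.
apartments_fe <- apartments %>%
  mutate(mid_century = construction.year >= 1935 & construction.year <= 1990)
fit_lm_fe <- lm(m2.price ~ ., data = apartments_fe)
```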
It is a simulated dataset, but designed to reflect properties of a real dataset.↩︎
For attribution, please cite this work as
Sankaran (2022, Dec. 30). STAT 436 (Spring 2023): Partial Dependence Profiles II. Retrieved from https://krisrs1128.github.io/stat436_s23/website/stat436_s23/posts/2022-12-27-week12-2/
BibTeX citation
@misc{sankaran2022partial,
  author = {Sankaran, Kris},
  title = {STAT 436 (Spring 2023): Partial Dependence Profiles II},
  url = {https://krisrs1128.github.io/stat436_s23/website/stat436_s23/posts/2022-12-27-week12-2/},
  year = {2022}
}