Discovering richer structure in partial dependence profiles.
library(tidyverse)
library(caret)
library(DALEX)
theme479 <- theme_minimal() + 
  theme(
    panel.grid.minor = element_blank(),
    panel.background = element_rect(fill = "#f7f7f7"),
    panel.border = element_rect(fill = NA, color = "#0c0c0c", linewidth = 0.6), # `size` is deprecated for rectangles in ggplot2 >= 3.4.0
    legend.position = "bottom"
  )
theme_set(theme479)

We say that there is an interaction between variables \(j\) and \(j^{\prime}\) if the relationship between \(x_{j}\) and \(y\) is modulated by variable \(j^{\prime}\). For example, in the figure below, the slope of cross-sections across \(j\) depends on \(j^{\prime}\).
Using the language of CP profiles, the figure above means that the shape of the CP profile in \(j\) depends on the particular setting of \(j^{\prime}\). This motivates the use of Grouped PD profiles — we compute several PD profiles in \(j\), restricting attention to CP profiles whose value \(x_{j^{\prime}}\) lies within a prespecified range.
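The grouping idea can be sketched in base R without any modeling package. This is a minimal illustration under stated assumptions, not the DALEX implementation: the data, the variable names `x1` and `g`, and the model `fit_toy` are all hypothetical, chosen so that the slope in `x1` depends on the group `g`.

```r
set.seed(1)
n <- 300

# Hypothetical data: the slope of y in x1 depends on the group g
df <- data.frame(
  x1 = runif(n, -1, 1),
  g = factor(sample(c("a", "b"), n, replace = TRUE))
)
df$y <- ifelse(df$g == "a", 2, -1) * df$x1 + rnorm(n, sd = 0.1)

# Any fitted model would do; a linear model with an interaction keeps this short
fit_toy <- lm(y ~ x1 * g, data = df)

# CP profiles: for each sample, sweep x1 over a grid while holding g fixed
grid <- seq(-1, 1, length.out = 21)
cp <- sapply(grid, function(v) {
  df_v <- df
  df_v$x1 <- v
  predict(fit_toy, newdata = df_v)
}) # n x length(grid) matrix, one row per CP profile

# Grouped PD profiles: average the CP profiles within each level of g
grouped_pd <- rowsum(cp, df$g) / as.vector(table(df$g))

# Overall slope of each grouped PD profile across the grid
pd_slopes <- (grouped_pd[, length(grid)] - grouped_pd[, 1]) /
  (grid[length(grid)] - grid[1])
pd_slopes
```

The two grouped profiles have clearly different slopes, which is the signature of an interaction between `x1` and `g` in the fitted model.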
To illustrate, we revisit the CP profiles for age from the Titanic dataset. Below, the profiles are grouped according to the class of the ticket holder. The result shows that the relationship between age and survival was not the same across all passengers. For all classes, there was a decrease in survival probability for adults, but the dropoff was most severe for crew members.
f <- tempfile()
download.file("https://uwmadison.box.com/shared/static/nau695mppsoxx0f6bns1ieo7kh1bje0j.rda", f)
fit <- get(load(f))
data(titanic)
titanic <- titanic %>%
  select(-country) %>%
  na.omit()
x <- select(titanic, -survived)
explanation <- explain(model = fit, data = x, y = titanic$survived)
Preparation of a new explainer is initiated
  -> model label       :  train  (  default  )
  -> data              :  2179  rows  7  cols 
  -> target variable   :  2179  values 
  -> predict function  :  yhat.train  will be used (  default  )
  -> predicted values  :  No value for predict function target column. (  default  )
  -> model_info        :  package caret , ver. 6.0.94 , task classification (  default  ) 
  -> model_info        :  Model info detected classification task but 'y' is a factor .  (  WARNING  )
  -> model_info        :  By deafult classification tasks supports only numercical 'y' parameter. 
  -> model_info        :  Consider changing to numerical vector with 0 and 1 values.
  -> model_info        :  Otherwise I will not be able to calculate residuals or loss function.
  -> predicted values  :  numerical, min =  0.007271972 , mean =  0.3239071 , max =  0.9885397  
  -> residual function :  difference between y and yhat (  default  )
  -> residuals         :  numerical, min =  NA , mean =  NA , max =  NA  
  A new explainer has been created!
profiles <- model_profile(explainer = explanation, groups = "class")
plot(profiles, geom = "profiles", variables = "age") +
  scale_color_brewer(palette = "Set2") +
  theme479 
Figure 1: Grouping the CP profiles by ticket class reveals an interaction effect with age in the Titanic dataset.
What should we do if there are many input variables and we don’t have a priori knowledge about which variables \(j^{\prime}\) might be interacting with \(j\)? One idea is to try to discover relevant interactions by clustering the original set of CP profiles.
In more detail, we can compute the CP profiles for all the samples, and then see whether there are subsets of profiles that all look similar. If we find features \(j^{\prime}\) that characterize these groupings, then we have found features that interact with \(j\) (with respect to the fitted model \(f\)). The plot below shows the same profiles as above, but clustered directly. It seems to recover the interaction between age and class, even though we have not explicitly provided this grouping variable.
profiles <- model_profile(explainer = explanation, variables = "age", k = 3)
plot(profiles, geom = "profiles", variables = "age") +
  scale_color_brewer(palette = "Set2") +
  theme479 
Figure 2: Discovered groupings in the CP profiles for age reveal an interaction effect.
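To make the clustering step concrete, here is a small base R sketch on simulated data. It is an assumption-laden stand-in: the data, `fit_toy`, and the choice of Ward hierarchical clustering on centered profiles are illustrative, and need not match what `model_profile(k = 3)` does internally.

```r
set.seed(2)
n <- 200

# Hypothetical data in which the slope in x1 depends on an unprovided group g
df <- data.frame(
  x1 = runif(n, -1, 1),
  g = factor(sample(c("a", "b"), n, replace = TRUE))
)
df$y <- ifelse(df$g == "a", 2, -1) * df$x1 + rnorm(n, sd = 0.1)
fit_toy <- lm(y ~ x1 * g, data = df)

# CP profiles for x1: one row per sample, one column per grid value
grid <- seq(-1, 1, length.out = 21)
cp <- sapply(grid, function(v) {
  df_v <- df
  df_v$x1 <- v
  predict(fit_toy, newdata = df_v)
})

# Center each profile so that clusters reflect shape rather than overall level
cp_centered <- cp - rowMeans(cp)

# Cluster the profiles without using g, then cut the tree into k = 2 groups
clusters <- cutree(hclust(dist(cp_centered), method = "ward.D2"), k = 2)

# Compare the discovered clusters with the candidate interacting variable
tab <- table(cluster = clusters, g = df$g)
tab
```

If a cross-tabulation like `tab` is close to diagonal for some feature, that feature characterizes the profile clusters and is a candidate interaction partner.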
data(apartments)
x <- select(apartments, -m2.price)
profiles_lm <- train(x, apartments$m2.price, method = "lm") %>%
  explain(x, apartments$m2.price, label = "LM") %>%
  model_profile()
Preparation of a new explainer is initiated
  -> model label       :  LM 
  -> data              :  1000  rows  5  cols 
  -> target variable   :  1000  values 
  -> predict function  :  yhat.train  will be used (  default  )
  -> predicted values  :  No value for predict function target column. (  default  )
  -> model_info        :  package caret , ver. 6.0.94 , task regression (  default  ) 
  -> predicted values  :  numerical, min =  1781.848 , mean =  3487.019 , max =  6176.032  
  -> residual function :  difference between y and yhat (  default  )
  -> residuals         :  numerical, min =  -247.4728 , mean =  -1.337654e-12 , max =  469.0023  
  A new explainer has been created!
profiles_rf <- train(x, apartments$m2.price, method = "rf", tuneGrid = data.frame(mtry = 10)) %>%
  explain(x, apartments$m2.price, label = "RF") %>%
  model_profile()
Preparation of a new explainer is initiated
  -> model label       :  RF 
  -> data              :  1000  rows  5  cols 
  -> target variable   :  1000  values 
  -> predict function  :  yhat.train  will be used (  default  )
  -> predicted values  :  No value for predict function target column. (  default  )
  -> model_info        :  package caret , ver. 6.0.94 , task regression (  default  ) 
  -> predicted values  :  numerical, min =  1671.409 , mean =  3486.406 , max =  6374.482  
  -> residual function :  difference between y and yhat (  default  )
  -> residuals         :  numerical, min =  -232.2327 , mean =  0.6134767 , max =  264.4474  
  A new explainer has been created!
plot(profiles_lm, profiles_rf, variables = c("surface", "floor")) +
  scale_color_brewer(palette = "Set2") +
  theme479 
Figure 3: A contrastive PD display suggests that the floor and surface features are linearly related to apartment price.
When making the comparison between a simple and a complex model, certain discrepancies might become apparent. For example, important nonlinearities or interactions might be visible from the PD profiles of the complex model. This information can guide the design of new features in the simpler model, so that it can continue to be used. This is exactly the case in the apartments dataset above – there is a strong nonlinear relationship for the construction year variable. This suggests that, if a linear model is still desired, then a new feature should be defined that identifies whether the apartment was built between 1935 and 1990.
plot(profiles_lm, profiles_rf, variables = "construction.year") +
  scale_color_brewer(palette = "Set2") +
  theme479 
Figure 4: The random forest learns a nonlinear relationship between construction year and apartment price. This suggests designing new features to include in the linear model.
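The feature-design step can be sketched as follows. This uses simulated data rather than the apartments dataset: the data frame `sim`, the column names `year`, `price`, and `mid_century`, and the size and direction of the price shift are all assumptions made for illustration. Only the 1935 to 1990 window comes from the discussion above.

```r
set.seed(3)
n <- 500

# Hypothetical data: price shifts for buildings from 1935 through 1990
sim <- data.frame(year = sample(1920:2010, n, replace = TRUE))
in_window <- sim$year >= 1935 & sim$year <= 1990
sim$price <- 3500 - 800 * in_window + rnorm(n, sd = 100)

# Baseline linear model on the raw construction year
fit_raw <- lm(price ~ year, data = sim)

# Linear model with an engineered indicator for the 1935-1990 window
sim$mid_century <- as.integer(in_window)
fit_ind <- lm(price ~ year + mid_century, data = sim)

# In-sample root mean squared error of each model
rmse <- function(fit) sqrt(mean(residuals(fit)^2))
c(raw = rmse(fit_raw), indicator = rmse(fit_ind))
```

The linear model stays linear in its inputs, but the indicator lets it capture the step that the raw year term cannot. On real data, the window boundaries would be read off the complex model's PD profile, and the comparison should use held-out data rather than in-sample error.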
For attribution, please cite this work as
Sankaran (2024, Jan. 7). STAT 436 (Spring 2024): Partial Dependence Profiles II. Retrieved from https://krisrs1128.github.io/stat436_s24/website/stat436_s24/posts/2024-12-27-week12-2/
BibTeX citation
@misc{sankaran2024partial,
  author = {Sankaran, Kris},
  title = {STAT 436 (Spring 2024): Partial Dependence Profiles II},
  url = {https://krisrs1128.github.io/stat436_s24/website/stat436_s24/posts/2024-12-27-week12-2/},
  year = {2024}
}