Partial Dependence Profiles II

Discovering richer structure in partial dependence profiles.

Kris Sankaran (UW Madison)
01-07-2024

Reading, Recording, Rmarkdown

library(tidyverse)
library(caret)
library(DALEX)
theme479 <- theme_minimal() + 
  theme(
    panel.grid.minor = element_blank(),
    panel.background = element_rect(fill = "#f7f7f7"),
    panel.border = element_rect(fill = NA, color = "#0c0c0c", linewidth = 0.6),
    legend.position = "bottom"
  )
theme_set(theme479)
  1. Partial dependence (PD) plots help answer the question, "How is feature \(j\) used by my model \(f\)?" Slight variations on PD plots are useful for some related follow-up questions:
    • Has my model \(f\) learned interactions between features \(j\) and \(j^{\prime}\)?
    • How do the models \(f\) and \(f^{\prime}\) differ in the way that they use feature \(j\)?
  2. The variants, called Grouped and Contrastive PD plots, summarize the original CP profiles less aggressively than PD plots do, without becoming overwhelmingly complicated.
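
As a reminder of what is being summarized, and assuming the standard definitions of CP and PD profiles (the symbols \(x_{i}\), \(n\), \(p\), \(z\), and the index set \(G\) are notation introduced here for illustration), the CP profile of sample \(i\) and the PD profile for feature \(j\) can be written as

\[
\text{CP}_{i,j}(z) = f(x_{i,1}, \dots, x_{i,j-1}, z, x_{i,j+1}, \dots, x_{i,p}), \qquad \text{PD}_{j}(z) = \frac{1}{n}\sum_{i=1}^{n} \text{CP}_{i,j}(z).
\]

A Grouped PD profile averages over only a subset \(G\) of the samples, while a Contrastive display overlays \(\text{PD}_{j}\) computed from two models \(f\) and \(f^{\prime}\):

\[
\text{PD}_{j}^{G}(z) = \frac{1}{|G|}\sum_{i \in G} \text{CP}_{i,j}(z).
\]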

Interactions

  1. We say that there is an interaction between variables \(j\) and \(j^{\prime}\) if the relationship between \(x_{j}\) and \(y\) is modulated by variable \(j^{\prime}\). For example, in the figure below, the slope of cross-sections across \(j\) depends on \(j^{\prime}\).

  2. Using the language of CP profiles, the figure above means that the shape of the CP profile in \(j\) depends on the particular setting of \(j^{\prime}\). This motivates the use of Grouped PD profiles: we compute several PD profiles in \(j\), each restricted to the CP profiles whose value \(x_{j^{\prime}}\) lies within a prespecified range.

  3. To illustrate, we revisit the CP profiles for age from the Titanic dataset. Below, the profiles are grouped according to the class of the ticket holder. The result shows that the relationship between age and survival was not the same across all passengers. For all classes, there was a decrease in survival probability for adults, but the dropoff was most severe for crew members.

f <- tempfile()
download.file("https://uwmadison.box.com/shared/static/nau695mppsoxx0f6bns1ieo7kh1bje0j.rda", f)
fit <- get(load(f))

data(titanic)
titanic <- titanic %>%
  select(-country) %>%
  na.omit()
x <- select(titanic, -survived)

explanation <- explain(model = fit, data = x, y = titanic$survived)
Preparation of a new explainer is initiated
  -> model label       :  train  (  default  )
  -> data              :  2179  rows  7  cols 
  -> target variable   :  2179  values 
  -> predict function  :  yhat.train  will be used (  default  )
  -> predicted values  :  No value for predict function target column. (  default  )
  -> model_info        :  package caret , ver. 6.0.94 , task classification (  default  ) 
  -> model_info        :  Model info detected classification task but 'y' is a factor .  (  WARNING  )
  -> model_info        :  By deafult classification tasks supports only numercical 'y' parameter. 
  -> model_info        :  Consider changing to numerical vector with 0 and 1 values.
  -> model_info        :  Otherwise I will not be able to calculate residuals or loss function.
  -> predicted values  :  numerical, min =  0.007271972 , mean =  0.3239071 , max =  0.9885397  
  -> residual function :  difference between y and yhat (  default  )
  -> residuals         :  numerical, min =  NA , mean =  NA , max =  NA  
  A new explainer has been created!  
profiles <- model_profile(explainer = explanation, groups = "class")
plot(profiles, geom = "profiles", variables = "age") +
  scale_color_brewer(palette = "Set2") +
  theme479

Figure 1: Grouping the CP profiles by ticket class reveals an interaction effect with age in the Titanic dataset.
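
To make the grouping step concrete, here is a minimal base R sketch of what a Grouped PD profile averages. It is not how `model_profile` is implemented. The simulated data, the helper `cp_profiles`, and the names `fit_sim`, `x1`, and `g` are all assumptions made for illustration.

```r
# Simulated data (an assumption for illustration, not the Titanic data):
# the slope of y in x1 depends on the group g, so x1 and g interact.
set.seed(1)
n <- 300
dat <- data.frame(
  x1 = runif(n, 0, 1),
  g = factor(sample(c("a", "b"), n, replace = TRUE))
)
dat$y <- ifelse(dat$g == "a", 1, 3) * dat$x1 + rnorm(n, sd = 0.1)
fit_sim <- lm(y ~ x1 * g, data = dat)

# CP profiles: for each sample, vary one variable over a grid and hold
# every other variable at its observed value.
cp_profiles <- function(model, data, variable, grid) {
  out <- lapply(seq_len(nrow(data)), function(i) {
    row <- data[rep(i, length(grid)), , drop = FALSE]
    row[[variable]] <- grid
    data.frame(id = i, value = grid, yhat = as.numeric(predict(model, newdata = row)))
  })
  do.call(rbind, out)
}

grid <- seq(0, 1, length.out = 11)
cp <- cp_profiles(fit_sim, dat, "x1", grid)
cp$group <- dat$g[cp$id]

# Ordinary PD profile: average all CP profiles at each grid value.
pd <- aggregate(yhat ~ value, data = cp, FUN = mean)

# Grouped PD profiles: average the CP profiles separately within each level of g.
grouped_pd <- aggregate(yhat ~ value + group, data = cp, FUN = mean)

# Slope of each grouped profile; different slopes indicate an interaction.
slopes <- sapply(split(grouped_pd, grouped_pd$group), function(d) {
  unname(coef(lm(yhat ~ value, data = d))["value"])
})
slopes
```

The single PD profile in `pd` blends the two groups together, while `grouped_pd` keeps one curve per level of `g`, which is what makes the difference in slopes visible.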

  1. What should we do if there are many input variables and we don’t have a priori knowledge about which variables \(j^{\prime}\) might be interacting with \(j\)? One idea is to try to discover relevant interactions by clustering the original set of CP profiles.

  2. In more detail, we can compute the CP profiles for all the samples, and then see whether there are subsets of profiles that all look similar. If we find features \(j^{\prime}\) that characterize these groupings, then we have found features that interact with \(j\) (with respect to the fitted model \(f\)). The plot below shows the same profiles as above, but with groups found by clustering the profiles directly. It seems to recover the interaction between age and class, even though we have not explicitly provided this grouping variable.

profiles <- model_profile(explainer = explanation, variables = "age", k = 3)
plot(profiles, geom = "profiles", variables = "age") +
  scale_color_brewer(palette = "Set2") +
  theme479

Figure 2: Discovered groupings in the CP profiles for age reveal an interaction effect.
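
The clustering idea can be sketched in base R as follows. This uses hierarchical clustering with `hclust` as a stand-in and should not be read as the algorithm that `model_profile` applies when `k` is set. The simulated data and the names `fit_sim`, `x1`, `x2`, and `g` are assumptions made for illustration.

```r
# Simulated data (an assumption for illustration): the effect of x1 on y
# changes sign with g, and x2 only shifts the overall level.
set.seed(2)
n <- 200
dat <- data.frame(
  x1 = runif(n),
  x2 = rnorm(n),
  g = factor(sample(c("a", "b"), n, replace = TRUE))
)
dat$y <- ifelse(dat$g == "a", -2, 2) * dat$x1 + dat$x2 + rnorm(n, sd = 0.1)
fit_sim <- lm(y ~ x1 * g + x2, data = dat)

# One row per sample, one column per grid value: the CP profiles in x1.
grid <- seq(0, 1, length.out = 11)
cp_matrix <- t(sapply(seq_len(n), function(i) {
  row <- dat[rep(i, length(grid)), , drop = FALSE]
  row$x1 <- grid
  as.numeric(predict(fit_sim, newdata = row))
}))

# Center each profile so that clusters reflect shape rather than overall level.
cp_centered <- cp_matrix - rowMeans(cp_matrix)

# Cluster the profiles without telling the algorithm about g.
tree <- hclust(dist(cp_centered), method = "ward.D2")
clusters <- cutree(tree, k = 2)

# Cross-tabulate the discovered clusters against the candidate variable g.
agreement <- table(cluster = clusters, g = dat$g)
agreement
```

If a cross-tabulation like `agreement` is close to diagonal for some candidate variable, that variable characterizes the clusters and is a plausible interaction partner for the profiled feature.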

Model Comparison

  1. The comparison of different models’ PD profiles can be used to:
    • Validate a simple model,
    • Guide the design of new features, and
    • Characterize overfitting.
    PD profiles that are used to compare different models are sometimes called "Contrastive" PD profiles.
  2. To validate a simple model, we can compare its PD profiles with those of a more sophisticated model. We will illustrate this by fitting linear and random forest models to a dataset of apartment prices. Given various properties of an apartment, the goal is to determine its price. The code below fits the two models and extracts their CP and PD profiles.
data(apartments)
x <- select(apartments, -m2.price)
profiles_lm <- train(x, apartments$m2.price, method = "lm") %>%
  explain(x, apartments$m2.price, label = "LM") %>%
  model_profile()
Preparation of a new explainer is initiated
  -> model label       :  LM 
  -> data              :  1000  rows  5  cols 
  -> target variable   :  1000  values 
  -> predict function  :  yhat.train  will be used (  default  )
  -> predicted values  :  No value for predict function target column. (  default  )
  -> model_info        :  package caret , ver. 6.0.94 , task regression (  default  ) 
  -> predicted values  :  numerical, min =  1781.848 , mean =  3487.019 , max =  6176.032  
  -> residual function :  difference between y and yhat (  default  )
  -> residuals         :  numerical, min =  -247.4728 , mean =  -1.337654e-12 , max =  469.0023  
  A new explainer has been created!  
profiles_rf <- train(x, apartments$m2.price, method = "rf", tuneGrid = data.frame(mtry = 10)) %>%
  explain(x, apartments$m2.price, label = "RF") %>%
  model_profile()
Preparation of a new explainer is initiated
  -> model label       :  RF 
  -> data              :  1000  rows  5  cols 
  -> target variable   :  1000  values 
  -> predict function  :  yhat.train  will be used (  default  )
  -> predicted values  :  No value for predict function target column. (  default  )
  -> model_info        :  package caret , ver. 6.0.94 , task regression (  default  ) 
  -> predicted values  :  numerical, min =  1671.409 , mean =  3486.406 , max =  6374.482  
  -> residual function :  difference between y and yhat (  default  )
  -> residuals         :  numerical, min =  -232.2327 , mean =  0.6134767 , max =  264.4474  
  A new explainer has been created!  
  1. The PD profiles below show that the random forest learns approximately linear relationships with price for both the surface and floor variables. If all the effects were like this, then we would have a good reason for preferring the linear model.
plot(profiles_lm, profiles_rf, variables = c("surface", "floor")) +
  scale_color_brewer(palette = "Set2") +
  theme479

Figure 3: A contrastive PD display suggests that the floor and surface features are linearly related with apartment price.

  1. When making the comparison between a simple and a complex model, certain discrepancies might become apparent. For example, important nonlinearities or interactions might be visible from the PD profiles of the complex model. This information can guide the design of new features in the simpler model, so that it can continue to be used. This is exactly the case in the apartments dataset above: there is a strong nonlinear relationship for the construction year variable. This suggests that, if a linear model is still desired, then a new feature should be defined that identifies whether the apartment was built between 1935 and 1990.

plot(profiles_lm, profiles_rf, variables = "construction.year") +
  scale_color_brewer(palette = "Set2") +
  theme479

Figure 4: The random forest learns a nonlinear relationship between construction year and apartment price. This suggests designing new features to include in the linear model.
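
The feature-design step suggested above might look like the following sketch. It uses simulated data rather than the `apartments` dataset, and the data-generating process, the size of the price dip, and the names `sim`, `mid_century`, `fit_linear`, and `fit_indicator` are assumptions made for illustration.

```r
# Simulated data (an assumption for illustration): prices dip for
# buildings constructed between 1935 and 1990.
set.seed(3)
n <- 500
year <- sample(1920:2010, n, replace = TRUE)
price <- 4000 - 800 * (year >= 1935 & year <= 1990) + rnorm(n, sd = 100)
sim <- data.frame(year, price)

# A plain linear model cannot represent the dip.
fit_linear <- lm(price ~ year, data = sim)

# New feature suggested by the contrastive PD display: an indicator for
# whether the building dates from 1935 to 1990.
sim$mid_century <- as.numeric(sim$year >= 1935 & sim$year <= 1990)
fit_indicator <- lm(price ~ year + mid_century, data = sim)

# In-sample root mean squared error of each model.
rmse <- function(model) sqrt(mean(residuals(model)^2))
c(linear = rmse(fit_linear), indicator = rmse(fit_indicator))
```

The model stays linear in its inputs and keeps its interpretable coefficients, but the extra indicator lets it capture the step that the random forest found.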

  1. Suppose you have found that a model is overfitting (e.g., by finding that its training error is much lower than its test error). One way to investigate this overfitting is to compare the PD profiles between the simple and complex models. If the profiles are very different for one of the features, then that feature may be the source of overfitting.
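
One hedged way to turn this comparison into a number is to measure, feature by feature, the gap between the two models' PD profiles. The sketch below is an illustration in base R, not a DALEX feature. The simulated data, the deliberately overflexible polynomial model, and the helper `pd_profile` are assumptions made for illustration.

```r
# Simulated data (an assumption for illustration): y is linear in x1 and x2.
set.seed(4)
n <- 40
train <- data.frame(x1 = runif(n), x2 = runif(n))
train$y <- train$x1 + train$x2 + rnorm(n, sd = 0.5)

# A simple model and a complex model that is overly flexible in x2 only.
fit_simple <- lm(y ~ x1 + x2, data = train)
fit_complex <- lm(y ~ x1 + poly(x2, 12), data = train)

# PD profile: set one variable to each grid value for every sample,
# then average the predictions.
pd_profile <- function(model, data, variable, grid) {
  sapply(grid, function(z) {
    data[[variable]] <- z
    mean(predict(model, newdata = data))
  })
}

# Mean absolute gap between the two models' PD profiles, one value per feature.
discrepancy <- sapply(c("x1", "x2"), function(v) {
  grid <- seq(min(train[[v]]), max(train[[v]]), length.out = 25)
  gap <- pd_profile(fit_complex, train, v, grid) - pd_profile(fit_simple, train, v, grid)
  mean(abs(gap))
})
discrepancy
```

A feature whose discrepancy stands out, here `x2` by construction, is a candidate for closer inspection. A large gap can also reflect genuine nonlinearity that the simple model misses, so it is worth checking against held-out data before concluding that the feature drives the overfitting.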

Citation

For attribution, please cite this work as

Sankaran (2024, Jan. 7). STAT 436 (Spring 2024): Partial Dependence Profiles II. Retrieved from https://krisrs1128.github.io/stat436_s24/website/stat436_s24/posts/2024-12-27-week12-2/

BibTeX citation

@misc{sankaran2024partial,
  author = {Sankaran, Kris},
  title = {STAT 436 (Spring 2024): Partial Dependence Profiles II},
  url = {https://krisrs1128.github.io/stat436_s24/website/stat436_s24/posts/2024-12-27-week12-2/},
  year = {2024}
}