Evaluating the fit at particular observations in Bayesian models.
library("dplyr")
library("ggplot2")
library("ggrepel")
library("loo")
library("purrr")
library("rstan")
library("tidyr")
theme479 <- theme_minimal() +
  theme(
    panel.grid.minor = element_blank(),
    panel.background = element_rect(fill = "#f7f7f7"),
    panel.border = element_rect(fill = NA, color = "#0c0c0c", size = 0.6),
    legend.position = "bottom"
  )
All the model visualization strategies we’ve looked at in the last few lectures have been dataset-wide: we asked whether the model made sense globally, across the whole dataset. Individual observations might warrant special attention, though.
The block below loads in the fitted models from the previous set of notes.
# Download an .rda file to a temporary location and return the object it contains
downloader <- function(link) {
  f <- tempfile()
  download.file(link, f)
  get(load(f))
}
models <- downloader("https://uwmadison.box.com/shared/static/x7dotair443mhx34yzie3m3lrsvhk19a.rda")
GM <- downloader("https://uwmadison.box.com/shared/static/2pzgdu7gyobhl5tezo63tns7by1aiy6d.rda")
Leave-one-out predictive probabilities can be computed with the loo function in the loo package. The code below computes these probabilities for each model, storing the difference in predictive probabilities for models two and three in the diff23 variable.
ggplot(elpd_diffs, aes(ID, diff23)) +
  geom_point(
    aes(col = super_region_name),
    size = 0.9, alpha = 0.8
  ) +
  geom_text_repel(
    data = outliers,
    aes(label = City_locality),
    size = 3
  ) +
  scale_color_brewer(palette = "Set2") +
  labs(
    y = "Influence (Model 2 vs. 3)",
    col = "WHO Region"
  )
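The plot above relies on elpd_diffs and outliers, whose construction is not shown in this section. As a minimal sketch of how pointwise leave-one-out differences between two models can be computed, the block below uses simulated log-likelihood matrices in place of the fitted models. The names ll_a, ll_b, diff_ab, and pointwise_elpd, the toy "posteriors", and the cutoff of 2 are all assumptions made for illustration; they are not taken from the original analysis.

```r
library(loo)

set.seed(479)
S <- 1000  # number of posterior draws (hypothetical)
N <- 50    # number of observations (hypothetical)

# Simulated data and two toy sets of posterior draws for a normal mean.
# These stand in for draws from real fitted models.
y <- rnorm(N)
mu_a <- rnorm(S, mean = 0, sd = 0.1)  # toy model A: mean near 0
mu_b <- rnorm(S, mean = 1, sd = 0.1)  # toy model B: mean near 1

# S x N log-likelihood matrices: entry [s, i] is log p(y_i | draw s)
ll_a <- sapply(y, function(yi) dnorm(yi, mean = mu_a, sd = 1, log = TRUE))
ll_b <- sapply(y, function(yi) dnorm(yi, mean = mu_b, sd = 1, log = TRUE))

# Pointwise expected log predictive density from PSIS-LOO.
# r_eff is set to 1 because the simulated draws are independent;
# with real MCMC output it would be estimated, e.g. with relative_eff().
pointwise_elpd <- function(ll) {
  fit <- loo(ll, r_eff = rep(1, ncol(ll)))
  fit$pointwise[, "elpd_loo"]
}

# Positive values mean model A predicts that observation better than model B
elpd_diffs <- data.frame(
  ID = seq_len(N),
  diff_ab = pointwise_elpd(ll_a) - pointwise_elpd(ll_b)
)

# Observations where the two models disagree most (cutoff is arbitrary)
outliers <- elpd_diffs[abs(elpd_diffs$diff_ab) > 2, ]
```

With real fitted models, the same pattern would apply to each model's log-likelihood draws, and the resulting difference column could be joined to the observation-level data before plotting.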
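The influence of individual observations can be measured with the psis function in the loo package. The pareto_k diagnostic summarizes how much the posterior predictive shifts when an observation is or isn’t included. For example, in the figure below, observation 2674 (Ulaanbaatar again) is highly influential.
# Extract the draws-by-observations log-likelihood matrix from each fitted model
loglik <- map(models, ~ as.matrix(., pars = "log_lik"))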
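kdata <- GM@data %>%
  mutate(
    # psis() expects log importance ratios; for leave-one-out these are
    # the negative log-likelihoods
    k_hat = psis(-loglik[[2]])$diagnostics$pareto_k,
    Index = row_number()
  )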
outliers <- kdata %>%
  filter(k_hat > 0.25)
ggplot(kdata, aes(x = Index, y = k_hat)) +
  geom_point(aes(col = super_region_name), size = 0.5, alpha = 0.9) +
  scale_color_brewer(palette = "Set2") +
  geom_text_repel(data = outliers, aes(label = Index)) +
  labs(y = "k-hat")
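To get a feel for what the k-hat diagnostic responds to, the sketch below computes Pareto k values on simulated data with one deliberately extreme observation. Everything here is hypothetical: the data, the toy posterior draws mu and sigma, and the variable names are assumptions for illustration. The 0.7 cutoff is the value commonly used in the loo package to flag unreliable estimates. Note that psis takes log importance ratios, which for leave-one-out are the negative log-likelihoods.

```r
library(loo)

set.seed(479)
S <- 1000  # number of posterior draws (hypothetical)
N <- 30    # number of observations (hypothetical)

# Simulated data; the last observation is a planted outlier
y <- c(rnorm(N - 1), 8)

# Toy posterior draws for a normal mean and standard deviation
mu <- rnorm(S, mean = 0, sd = 0.2)
sigma <- exp(rnorm(S, mean = 0, sd = 0.1))

# S x N log-likelihood matrix: entry [s, i] is log p(y_i | draw s)
ll <- sapply(y, function(yi) dnorm(yi, mean = mu, sd = sigma, log = TRUE))

# Leave-one-out importance ratios are 1 / p(y_i | draw), so the log ratios
# are -ll. r_eff is 1 because the simulated draws are independent.
fit <- psis(-ll, r_eff = rep(1, N))
k_hat <- fit$diagnostics$pareto_k

# Observations whose removal would shift the posterior the most
kdata_toy <- data.frame(Index = seq_len(N), k_hat = k_hat)
flagged <- kdata_toy[kdata_toy$k_hat > 0.7, ]
```

One would expect the planted outlier to receive one of the larger k-hat values, since the importance weights for an observation far from the bulk of the data tend to have a heavy tail, though the exact values depend on the simulated draws.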