This generalizes the built-in predict method to the multimedia class. Given
an estimated multimedia object, this function supports prediction along the
estimated DAG. It first predicts $\hat{M} \mid T, X$ and then
$\hat{Y} \mid \hat{M}, T, X$. Each prediction step will call the prediction method within the
mediation and outcome models that make up the multimedia object on which this
is called. By passing in new treatment, mediator, or pretreatment data, you
can predict forward along the DAG using these as inputs.
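The two-step logic described above can be sketched in base R. This is an illustration only, not the package's internals: ordinary `lm()` fits stand in for the mediation and outcome models, and the simulated data, the column names (`treatment`, `x`, `mediator`, `outcome`), and the object names are all assumptions made for the example.

```r
# Illustrative sketch: base R linear models stand in for the package's
# mediation and outcome models. All names and data here are hypothetical.
set.seed(1)
n <- 200
dat <- data.frame(
  treatment = rbinom(n, 1, 0.5),
  x = rnorm(n)
)
dat$mediator <- 1 + 2 * dat$treatment + 0.5 * dat$x + rnorm(n, sd = 0.1)
dat$outcome <- -1 + 1.5 * dat$mediator + dat$treatment + rnorm(n, sd = 0.1)

# Estimate one model per edge set of the DAG: (T, X) -> M and (M, T, X) -> Y.
mediation_fit <- lm(mediator ~ treatment + x, data = dat)
outcome_fit <- lm(outcome ~ mediator + treatment + x, data = dat)

# Step 1: predict the mediator from new treatment and pretreatment values.
new_inputs <- data.frame(treatment = c(0, 1), x = c(0, 0))
m_hat <- predict(mediation_fit, newdata = new_inputs)

# Step 2: feed the predicted mediator into the outcome model.
y_hat <- predict(outcome_fit, newdata = cbind(new_inputs, mediator = m_hat))

# Alternative to step 2: supply mediator values directly, bypassing the
# mediation model (analogous in spirit to overriding the mediator inputs).
y_given_m <- predict(outcome_fit, newdata = cbind(new_inputs, mediator = c(0, 0)))

list(mediators = m_hat, outcomes = y_hat)
```

In this sketch the returned list mirrors the shape of the documented return value, with one entry for predicted mediators and one for predicted outcomes.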
Usage
# S4 method for class 'multimedia'
predict(object, profile = NULL, mediators = NULL, pretreatment = NULL, ...)
Arguments
- object
An object of class multimedia containing the estimated mediation and outcome models from which to obtain predictions.
- profile
An object of class treatment_profile containing the treatment profile to use for prediction. Defaults to a profile with all the unique treatment configurations observed in the original data, shared across both the mediators and outcomes.
- mediators
By default, we will return outcome predictions using the predicted mediators from the mediation model. Modify this argument if you would like to directly control the mediation inputs for the outcome model. Must be a data.frame whose column names match mediators(object).
- pretreatment
By default, we will return mediation and outcome model predictions using the same pretreatment variables as used when initially estimating the models (like setting newdata = NULL in the usual predict). To pass in different pretreatment variables, provide a data.frame here whose columns match the pretreatment variables used to train the original mediation and outcome models.
- ...
A placeholder for consistency with predict in base R. Never used.
Value
A list with two elements:
$mediators: A data.frame containing predicted values for the mediators. Each row corresponds to one row of the newdata, or one row of the default treatment profile, if no newdata is given.
$outcomes: A data.frame containing predicted values for the outcomes, given either (i) the predicted values of the mediators or (ii) the provided values of the mediators. Each row corresponds to one row of the newdata, or one row of the default treatment profile, if no newdata is given.
Examples
exper <- demo_spline(tau = c(2, 1)) |>
  mediation_data(starts_with("outcome"), "treatment", "mediator")
fit <- multimedia(exper, glmnet_model()) |>
  estimate(exper)
predict(fit)
#> $mediators
#> # A tibble: 2 × 1
#> mediator
#> <dbl>
#> 1 -1.50
#> 2 1.43
#>
#> $outcomes
#> # A tibble: 2 × 2
#> outcome_1 outcome_2
#> <dbl> <dbl>
#> 1 -0.00123 0.307
#> 2 2.00 0.693
#>
# at new treatment configurations
t1 <- data.frame(treatment = factor(rep(c(0, 1), each = 5)))
profile <- setup_profile(fit, t_mediator = t1, t_outcome = t1)
predict(fit, profile)
#> $mediators
#> # A tibble: 10 × 1
#> mediator
#> <dbl>
#> 1 -1.50
#> 2 -1.50
#> 3 -1.50
#> 4 -1.50
#> 5 -1.50
#> 6 1.43
#> 7 1.43
#> 8 1.43
#> 9 1.43
#> 10 1.43
#>
#> $outcomes
#> # A tibble: 10 × 2
#> outcome_1 outcome_2
#> <dbl> <dbl>
#> 1 -0.00123 0.307
#> 2 -0.00123 0.307
#> 3 -0.00123 0.307
#> 4 -0.00123 0.307
#> 5 -0.00123 0.307
#> 6 2.00 0.693
#> 7 2.00 0.693
#> 8 2.00 0.693
#> 9 2.00 0.693
#> 10 2.00 0.693
#>
# at new treatment and mediator configurations
mediators <- data.frame(mediator = rnorm(10, 0, 1))
predict(fit, profile, mediators)
#> $mediators
#> mediator
#> 1 0.29215821
#> 2 1.09227838
#> 3 -0.08819959
#> 4 -0.50850895
#> 5 0.77703487
#> 6 1.16579730
#> 7 0.73168451
#> 8 -1.13674199
#> 9 1.44782355
#> 10 -0.47185232
#>
#> $outcomes
#> # A tibble: 10 × 2
#> outcome_1 outcome_2
#> <dbl> <dbl>
#> 1 -0.0849 -0.106
#> 2 -0.122 -0.291
#> 3 -0.0671 -0.0186
#> 4 -0.0475 0.0784
#> 5 -0.108 -0.218
#> 6 2.01 0.755
#> 7 2.03 0.855
#> 8 2.12 1.29
#> 9 2.00 0.690
#> 10 2.09 1.13
#>