
This generalizes the built-in predict method to the multimedia class. Given an estimated multimedia object, this function makes predictions along the estimated DAG. It first predicts the mediators, $\hat{M} \mid T, X$, and then the outcomes, $\hat{Y} \mid \hat{M}, T, X$. Each step calls the predict method of the corresponding mediation or outcome model stored in the multimedia object. By passing in new treatment, mediator, or pretreatment data, you can propagate these inputs forward along the DAG.
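The two-step logic can be sketched in base R, with ordinary linear models standing in for the mediation and outcome models. This is a minimal illustration under assumed names (`predict_along_dag()`, and simulated columns `t`, `x`, `m`, `y`), not the package's implementation, which dispatches to whatever model types the multimedia object holds.

```r
# Minimal sketch of prediction along the DAG (T, X) -> M -> Y.
# Hypothetical: lm() stands in for the mediation and outcome models.
set.seed(1)
n <- 100
sim <- data.frame(
  t = factor(rep(c(0, 1), each = n / 2)),  # treatment
  x = rnorm(n)                             # pretreatment covariate
)
sim$m <- 1.5 * (sim$t == "1") + 0.5 * sim$x + rnorm(n)  # mediator
sim$y <- 2 * sim$m + 0.3 * sim$x + rnorm(n)             # outcome

med_fit <- lm(m ~ t + x, data = sim)      # mediation model: M | T, X
out_fit <- lm(y ~ m + t + x, data = sim)  # outcome model:   Y | M, T, X

# Step 1 predicts M-hat from (T, X); step 2 feeds M-hat into the outcome model.
# Supplying `mediators` skips step 1, mirroring the `mediators` argument.
predict_along_dag <- function(med_fit, out_fit, newdata, mediators = NULL) {
  if (is.null(mediators)) {
    mediators <- data.frame(m = predict(med_fit, newdata = newdata))
  }
  y_hat <- predict(out_fit, newdata = cbind(newdata, mediators))
  list(mediators = mediators, outcomes = data.frame(y = y_hat))
}

# One row per treatment level, with the covariate held at 0.
new_inputs <- data.frame(t = factor(c(0, 1), levels = c(0, 1)), x = 0)
pred <- predict_along_dag(med_fit, out_fit, new_inputs)
pred$mediators
pred$outcomes
```

Passing, for example, `mediators = data.frame(m = c(0, 0))` returns outcome predictions with the mediator held fixed rather than predicted.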

Usage

# S4 method for class 'multimedia'
predict(object, profile = NULL, mediators = NULL, pretreatment = NULL, ...)

Arguments

object

An object of class multimedia containing the estimated mediation and outcome models from which to obtain predictions.

profile

An object of class treatment_profile specifying the treatment configurations at which to make predictions. Defaults to a profile containing all unique treatment configurations observed in the original data, shared by the mediators and outcomes.

mediators

By default, outcome predictions use the mediator values predicted by the mediation model. Supply this argument to control the mediator inputs to the outcome model directly. Must be a data.frame whose column names match mediators(object).

pretreatment

By default, mediation and outcome predictions use the same pretreatment variables that were used to estimate the models (analogous to newdata = NULL in the usual predict). To use different pretreatment variables, provide a data.frame whose columns match the pretreatment variables used to train the original mediation and outcome models.

...

A placeholder for consistency with the predict generic in base R. Not used.
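The fallback and matching rules for `pretreatment` and `mediators` described above can be sketched as a small base R helper. `resolve_inputs()`, the column names (`age`, `bmi`, `mediator_1`, `mediator_2`), the toy values, and the choice to reorder rather than reject mediator columns given in a different order are all hypothetical; the package's own checks may differ.

```r
# Hypothetical helper: resolve the optional `pretreatment` and `mediators`
# inputs before prediction, following the argument descriptions above.
resolve_inputs <- function(pretreatment, mediators, train_pretreatment,
                           expected_mediators) {
  # NULL pretreatment falls back to the training data (like newdata = NULL).
  if (is.null(pretreatment)) {
    pretreatment <- train_pretreatment
  }
  if (!setequal(names(pretreatment), names(train_pretreatment))) {
    stop("pretreatment columns must match the training pretreatment variables")
  }
  # NULL mediators means "use the mediation model's predictions".
  if (!is.null(mediators)) {
    if (!is.data.frame(mediators) ||
        !setequal(names(mediators), expected_mediators)) {
      stop("mediators must be a data.frame with columns: ",
           paste(expected_mediators, collapse = ", "))
    }
    # Assumption: put supplied columns in the expected order.
    mediators <- mediators[, expected_mediators, drop = FALSE]
  }
  list(
    pretreatment = pretreatment[, names(train_pretreatment), drop = FALSE],
    mediators = mediators
  )
}

# Toy inputs for illustration only.
train_x <- data.frame(age = c(30, 45, 52), bmi = c(22.1, 27.4, 24.9))
expected <- c("mediator_1", "mediator_2")  # e.g. what mediators(object) names

# Defaults: training pretreatment, mediators left to the mediation model.
defaults <- resolve_inputs(NULL, NULL, train_x, expected)

# Supplied mediators with columns in a different order are reordered.
supplied <- data.frame(mediator_2 = c(0.1, 0.2, 0.3), mediator_1 = c(1, 2, 3))
custom <- resolve_inputs(NULL, supplied, train_x, expected)
names(custom$mediators)
```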

Value

A list with two elements:

$mediators: A data.frame containing predicted values for the mediators, or the supplied mediator values if the mediators argument is given. Each row corresponds to one row of the supplied profile, or to one row of the default treatment profile if profile is NULL.

$outcomes: A data.frame containing predicted values for the outcomes, given either (i) the predicted values of the mediators or (ii) the supplied values of the mediators. Each row corresponds to one row of the supplied profile, or to one row of the default treatment profile if profile is NULL.
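The one-row-per-profile-row behavior explains why the second example below returns five identical rows per treatment level: identical treatment rows produce identical predictions. A base R sketch of the same effect for a single mediation model, assuming `lm()` as a stand-in model and simulated training data:

```r
# Sketch: predictions align row-by-row with the treatment profile.
# Hypothetical stand-in: an lm() mediation model on simulated data.
set.seed(3)
train <- data.frame(treatment = factor(rep(c(0, 1), each = 20)))
train$mediator <- 1.5 * (train$treatment == "1") + rnorm(40)
med_fit <- lm(mediator ~ treatment, data = train)

# Same construction as t1 in the examples: 5 rows per treatment level.
profile_rows <- data.frame(treatment = factor(rep(c(0, 1), each = 5)))
m_hat <- data.frame(mediator = predict(med_fit, newdata = profile_rows))
nrow(m_hat)  # one prediction per profile row
```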

Examples

library(multimedia)
exper <- demo_spline(tau = c(2, 1)) |>
    mediation_data(starts_with("outcome"), "treatment", "mediator")
fit <- multimedia(exper, glmnet_model()) |>
    estimate(exper)
predict(fit)
#> $mediators
#> # A tibble: 2 × 1
#>   mediator
#>      <dbl>
#> 1    -1.50
#> 2     1.43
#> 
#> $outcomes
#> # A tibble: 2 × 2
#>   outcome_1 outcome_2
#>       <dbl>     <dbl>
#> 1  -0.00123     0.307
#> 2   2.00        0.693
#> 

# at new treatment configurations
t1 <- data.frame(treatment = factor(rep(c(0, 1), each = 5)))
profile <- setup_profile(fit, t_mediator = t1, t_outcome = t1)
predict(fit, profile)
#> $mediators
#> # A tibble: 10 × 1
#>    mediator
#>       <dbl>
#>  1    -1.50
#>  2    -1.50
#>  3    -1.50
#>  4    -1.50
#>  5    -1.50
#>  6     1.43
#>  7     1.43
#>  8     1.43
#>  9     1.43
#> 10     1.43
#> 
#> $outcomes
#> # A tibble: 10 × 2
#>    outcome_1 outcome_2
#>        <dbl>     <dbl>
#>  1  -0.00123     0.307
#>  2  -0.00123     0.307
#>  3  -0.00123     0.307
#>  4  -0.00123     0.307
#>  5  -0.00123     0.307
#>  6   2.00        0.693
#>  7   2.00        0.693
#>  8   2.00        0.693
#>  9   2.00        0.693
#> 10   2.00        0.693
#> 

# at new treatment and mediator configurations
# (mediators are drawn with rnorm() and no seed, so values differ between runs)
mediators <- data.frame(mediator = rnorm(10, 0, 1))
predict(fit, profile, mediators)
#> $mediators
#>       mediator
#> 1   0.29215821
#> 2   1.09227838
#> 3  -0.08819959
#> 4  -0.50850895
#> 5   0.77703487
#> 6   1.16579730
#> 7   0.73168451
#> 8  -1.13674199
#> 9   1.44782355
#> 10 -0.47185232
#> 
#> $outcomes
#> # A tibble: 10 × 2
#>    outcome_1 outcome_2
#>        <dbl>     <dbl>
#>  1   -0.0849   -0.106 
#>  2   -0.122    -0.291 
#>  3   -0.0671   -0.0186
#>  4   -0.0475    0.0784
#>  5   -0.108    -0.218 
#>  6    2.01      0.755 
#>  7    2.03      0.855 
#>  8    2.12      1.29  
#>  9    2.00      0.690 
#> 10    2.09      1.13  
#>