Saving Predictions

Save predictions from Python so that they can be visualized in R.

Kris Sankaran (UW Madison)
2021-02-15

  1. Once we have trained a model, it’s important to visualize its predictions. We’re going to use R to make our visualizations, but our model has been saved as a Python object. This script helps with the transition between languages, saving all the model’s predictions as numpy arrays, which can be read in R using the reticulate package. We’ll make predictions on both training and test data, to gauge the degree of over- or underfitting.

  2. The block below defines some high-level parameters for this script. If you are running this on your own machine, you should change the data_dir parameter to wherever you have been storing the raw and processed data. Also, if you have access to a GPU, you should change the device parameter, since that will let us compute the predictions more quickly.

from pathlib import Path

data_dir = Path("/home/jovyan/data")
process_dir = data_dir / "processed"
args = {
    "device": "cpu", # set to "cuda" if gpu is available
    "out_dir": data_dir / "predictions"
}
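
Rather than editing the device entry by hand, it can be chosen at runtime. The sketch below is one possible way to do that and is not part of the original script: pick_device is a hypothetical helper name, and the check for a missing torch installation is only there so that the snippet runs on any machine.

```python
import importlib.util


def pick_device():
    """Return "cuda" when PyTorch can see a GPU, and "cpu" otherwise."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"  # PyTorch is not installed, so there is nothing to query
    import torch

    return "cuda" if torch.cuda.is_available() else "cpu"


# Hypothetical usage: this value could replace the hard-coded "cpu" in args.
device = pick_device()
```

The result could then be stored as args["device"], so the same script runs unchanged on machines with and without a GPU.
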

  3. We left the last notebook without fully training the model. We also never generated the test data, which the earlier 2-preprocessing.Rmd script would have produced. Instead, in this block, we download a test data set and a trained model, both currently stored in a UW Madison Box folder.

from data import download_data

links = {
    "test_data": "https://uwmadison.box.com/shared/static/zs8vtmwbl92j5oq6ekzcfod11ym1w599.gz",
    "model": "https://uwmadison.box.com/shared/static/byb5lpny6rjr15zbx28o8liku8g6nga6.pt"
}

download_data(links["test_data"], process_dir / "test.tar.gz")
download_data(links["model"], data_dir / "model.pt", unzip = False)
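
The download_data helper lives in data.py and is not shown here. As a rough idea of what a helper like this might do, the sketch below downloads a file and optionally unpacks it as a tar archive. The function name fetch and its behavior are assumptions for illustration, not the actual implementation. The demonstration uses a local file:// URL so that it runs without network access.

```python
import tarfile
import tempfile
import urllib.request
from pathlib import Path


def fetch(url, dest, unzip=True):
    """Download `url` to `dest`; optionally extract it as a tar archive beside `dest`."""
    dest = Path(dest)
    dest.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(url, dest)
    if unzip:
        with tarfile.open(dest) as archive:
            archive.extractall(dest.parent)
    return dest


# Offline demonstration: build a small archive and "download" it via a file:// URL.
with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    (tmp / "sample.txt").write_text("hello")
    with tarfile.open(tmp / "source.tar.gz", "w:gz") as archive:
        archive.add(tmp / "sample.txt", arcname="test/sample.txt")
    fetch((tmp / "source.tar.gz").as_uri(), tmp / "out" / "test.tar.gz")
    extracted_text = (tmp / "out" / "test" / "sample.txt").read_text()
```

Passing unzip=False, as in the model download above, would keep the file exactly as it was served.
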
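
  4. This block sets up the model that we just downloaded. The .eval() call switches the model from training to inference mode, which changes the behavior of layers such as dropout and batch normalization, if the model has any. We also don’t need to keep track of model gradients anymore, since all we care about are predictions made with the existing weights. Note that .eval() does not by itself turn off gradient tracking; in PyTorch, that is done with torch.no_grad().
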
import torch
from unet import Unet

state = torch.load(data_dir / "model.pt", map_location=args["device"])
model = Unet(13, 3, 4).to(args["device"])
model.load_state_dict(state)
model = model.eval()
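
To see what inference mode changes in practice, here is a small sketch using a toy network in place of the U-Net. The toy_model and its layer sizes are made up for illustration. It shows that dropout stops being random after .eval(), and that torch.no_grad() is what prevents gradient tracking.

```python
import torch
from torch import nn

# Hypothetical stand-in for the U-Net: a tiny network that contains dropout.
toy_model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
x = torch.randn(3, 4)

toy_model.eval()  # dropout now acts as the identity
with torch.no_grad():  # no computation graph is built inside this block
    first = toy_model(x)
    second = toy_model(x)

same_output = torch.equal(first, second)  # repeated calls agree in eval mode
tracks_gradients = first.requires_grad  # False because of torch.no_grad()
```

Without the call to .eval(), the two forward passes would generally differ, because dropout would zero out a different random subset of units each time.
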

  5. The block below creates Dataset objects from which we can load the preprocessed training and test samples. We rely on the fact that our directory structure completely specifies the train / test split. We will iterate over these images one by one, saving a prediction for each. In principle, it’s possible to save predictions over batches of images by first defining a data loader. This would be a bit more complex to implement, though, and we’re aiming for simplicity here.

from data import GlacierDataset
from torch.utils.data import DataLoader

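# Collect the input (x) and label (y) files for each split, sorted so that
# the i-th input lines up with the i-th label.
paths = {}
for split in ["train", "test"]:
    paths[split] = {}
    for v in ["x", "y"]:
        paths[split][v] = sorted((process_dir / split).glob(v + "*"))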

ds = {
    "train": GlacierDataset(paths["train"]["x"], paths["train"]["y"]),
    "test": GlacierDataset(paths["test"]["x"], paths["test"]["y"])
}
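
Pairing inputs and labels by sorting two separate file lists only works if the two lists really do line up. The sketch below adds an explicit check. It assumes a naming scheme in which matching files differ only in their first character (for example x-1.npy and y-1.npy); both the scheme and the paired_paths helper are assumptions, since the actual file names are not shown here.

```python
import tempfile
from pathlib import Path


def paired_paths(folder):
    """Return sorted (x, y) path pairs, checking that their names line up."""
    xs = sorted(Path(folder).glob("x*"))
    ys = sorted(Path(folder).glob("y*"))
    if len(xs) != len(ys):
        raise ValueError(f"{len(xs)} inputs but {len(ys)} labels in {folder}")
    for x, y in zip(xs, ys):
        # Assumed naming scheme: everything after the first character is the ID.
        if x.name[1:] != y.name[1:]:
            raise ValueError(f"mismatched pair: {x.name} vs {y.name}")
    return list(zip(xs, ys))


# Demonstration on empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    for i in [1, 2, 10]:
        (Path(tmp) / f"x-{i}.npy").touch()
        (Path(tmp) / f"y-{i}.npy").touch()
    pair_names = [(x.name, y.name) for x, y in paired_paths(tmp)]
```

The sort is lexicographic, so x-10.npy comes before x-2.npy. This is harmless here, because the label files are sorted by the same rule and the pairs still match.
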

  6. Finally, we save predictions to the args["out_dir"] folder. The code for the predictions function is given in the train.py script. It iterates over the samples and saves a numpy array with predictions for each one. Somewhat counterintuitively, we also save the x and y associated with each prediction. The reason is that the output from the Dataset object is not deterministic: we may return a random rotation or flip of the original image. This was done to encourage invariance to these transformations in our model, but it makes it hard to compare the predictions directly with the objects in the processed directory. By writing all the matched input, label, and prediction data again at this point, we make it easier to study the specific versions of the inputs that are associated with good and bad model performance.

from train import predictions

predictions(model, ds["train"], args["out_dir"] / "train", args["device"])
predictions(model, ds["test"], args["out_dir"] / "test", args["device"])
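
The actual predictions function is in train.py and is not reproduced here. To illustrate why the inputs and labels are saved alongside each prediction, the sketch below uses plain numpy arrays, a hypothetical dataset that flips samples at random, and a stand-in model. All names in it (FlippingDataset, save_predictions, stand_in_model, and the output file names) are made up for illustration.

```python
import tempfile
from pathlib import Path

import numpy as np

rng = np.random.default_rng(0)


class FlippingDataset:
    """Hypothetical dataset that randomly flips each sample when it is read."""

    def __init__(self, xs, ys):
        self.xs, self.ys = xs, ys

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, i):
        x, y = self.xs[i], self.ys[i]
        if rng.random() < 0.5:  # flip input and label together
            x, y = x[..., ::-1], y[..., ::-1]
        return x, y


def save_predictions(model, dataset, out_dir):
    """Write the exact x and y that were drawn, plus the prediction, per sample."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for i in range(len(dataset)):
        x, y = dataset[i]  # one random draw, reused for everything saved below
        np.save(out_dir / f"x-{i}.npy", x)
        np.save(out_dir / f"y-{i}.npy", y)
        np.save(out_dir / f"y_hat-{i}.npy", model(x))


def stand_in_model(x):
    """Placeholder for the trained network: threshold at the image mean."""
    return x > x.mean()


xs = rng.normal(size=(4, 8, 8))
ys = xs > 0
with tempfile.TemporaryDirectory() as tmp:
    save_predictions(stand_in_model, FlippingDataset(xs, ys), tmp)
    saved = sorted(p.name for p in Path(tmp).glob("*.npy"))
    x0 = np.load(Path(tmp) / "x-0.npy")
    y0 = np.load(Path(tmp) / "y-0.npy")
    y_hat0 = np.load(Path(tmp) / "y_hat-0.npy")
```

Because x, y, and the prediction all come from the same draw, the saved files stay consistent with one another whether or not a flip happened, which is what makes later comparisons meaningful.
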