PSSR2 Workflow#

Note

This section explains how to use an example PSSR2 workflow, similar to that implemented in the CLI. It does not necessarily apply to all use cases and is meant to be expanded upon.

Training a Basic Model#

Before diving into the code, we will first specify our imports.

import torch
from pssr.data import ImageDataset
from pssr.crappifiers import Poisson
from pssr.models import ResUNet
from pssr.util import SSIMLoss
from pssr.train import train_paired
from torch.optim.lr_scheduler import ReduceLROnPlateau

Defining Objects#

Before defining our dataset, we must first define our Crappifier, as it is used by the dataset.

crappifier = Poisson(intensity=1, gain=0)

This sets the crappifier variable to an instance of the Poisson crappifier with default arguments. It will be used to synthetically generate low-resolution images to train on, given the high-resolution images in our dataset.
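
To build intuition for what crappification does, below is a minimal conceptual sketch (not pssr's implementation) that mimics a Poisson crappifier with numpy: the high-resolution image is downscaled by block averaging, then each pixel is resampled from a Poisson distribution to inject shot noise.

import numpy as np

def toy_poisson_crappify(hr, scale=4):
   # Conceptual sketch only; pssr's Poisson crappifier differs in detail
   h, w = hr.shape
   lr = hr.reshape(h // scale, scale, w // scale, scale).mean(axis=(1, 3))  # downscale
   noisy = np.random.poisson(np.clip(lr, 0, None)).astype(np.float32)  # shot noise
   return np.clip(noisy, 0, 255)

hr = np.random.rand(512, 512).astype(np.float32) * 255
print(toy_poisson_crappify(hr).shape)  # (128, 128)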


Which dataset to use depends on the format of your images. If you are using multidimensional or time-series images, see Advanced Dataloading.

In this example, we will use ImageDataset, assuming that our images are already sliced.

dataset = ImageDataset("your/hr", hr_res=512, lr_scale=4, crappifier=crappifier, extension="tif")

This sets the dataset variable to an instance of ImageDataset, loading high-resolution .tif images from your/hr. The high-resolution images are specified to have a horizontal and vertical resolution of hr_res=512. If the images provided are not square or are of the wrong resolution, they will be cropped and/or rescaled to fit.

We provide the Crappifier we defined earlier as an argument that will generate low-resolution images lr_scale=4 times smaller than the high-resolution images, for a horizontal and vertical resolution of 128.
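
As a quick sanity check, we can fetch one item from the dataset and inspect its tensor shapes. This sketch assumes each item is a high-/low-resolution pair, as implied by the paired training workflow; check the ImageDataset documentation for the exact item structure.

sample = dataset[0]
for tensor in sample:
   print(tensor.shape)  # expect 512x512 (high-resolution) and 128x128 (low-resolution)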

Note

Users are advised to keep image resolutions at a power of 2, even if the raw input images have a different size. This is elaborated on in models.


The last thing we need to define before training is our model.

model = ResUNet(
   hidden=[64, 128, 256, 512, 1024],
   scale=4,
   depth=3,
)

This sets the model variable to an instance of ResUNet. The scale argument sets the factor by which the input low-resolution images will be upscaled, and should be equivalent to the lr_scale argument of our dataset. The other arguments specify the number of channels per hidden layer and the depth of each hidden layer (the number of hidden convolutions).
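
As a quick sanity check, we can pass a dummy low-resolution batch through the model and confirm the output is upscaled by scale=4. This sketch assumes the default single-channel configuration.

lr_batch = torch.randn(1, 1, 128, 128)  # (batch, channels, height, width)
with torch.no_grad():
   hr_pred = model(lr_batch)
print(hr_pred.shape)  # expected: torch.Size([1, 1, 512, 512])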


Train Arguments#

As we are training on a synthetic paired high-low-resolution dataset, we will use the train_paired function.

For simplicity, we will define our arguments before beginning training.


We will first define our loss function.

loss_fn = SSIMLoss(mix=.8, ms=True)

While MSE loss can also produce good results, we will instead use SSIMLoss here, which optimizes visually significant elements of our predictions. The mix argument controls the inverse contribution of corrected L1 loss, while the ms argument enables MS-SSIM, a more robust version of SSIM.
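
Conceptually, the mix argument blends the two terms roughly as follows (an illustrative sketch, not the library's exact implementation):

import torch.nn.functional as F

def blended_loss(pred, target, ms_ssim_value, mix=0.8):
   # Illustration only: mix weights the (MS-)SSIM term, (1 - mix) weights the L1 term.
   # SSIM is a similarity in [0, 1], so it is inverted into a loss.
   return mix * (1 - ms_ssim_value) + (1 - mix) * F.l1_loss(pred, target)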


We also need to provide an optimizer.

optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optim, factor=0.1, patience=5, verbose=True)

This defines the optimizer of our model with a starting learning rate of 1e-3. Because we have also defined a scheduler, the learning rate of the optimizer will decay by factor=0.1 whenever model performance doesn't improve for patience=5 epochs.
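
To see the decay behavior in isolation, here is a short illustration using a throwaway optimizer and scheduler (train_paired steps the real scheduler for us during training):

demo_optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
demo_sched = ReduceLROnPlateau(demo_optim, factor=0.1, patience=5)
for loss_value in [1.0] * 7:  # six epochs with no improvement after the first
   demo_sched.step(loss_value)
print(demo_optim.param_groups[0]["lr"])  # ~1e-4: decayed once patience was exceeded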


And finally we define our miscellaneous arguments.

batch_size = 16
device = "cuda" if torch.cuda.is_available() else "cpu"

kwargs = dict(
   num_workers = 4,
   pin_memory = True,
)

This sets our batch size and training device, along with our torch DataLoader arguments. The batch size can be adjusted depending on the amount of memory available for training.


Training#

We can now train our model using the train_paired function.

losses = train_paired(
   model=model,
   dataset=dataset,
   batch_size=batch_size,
   loss_fn=loss_fn,
   optim=optim,
   epochs=20,
   device=device,
   scheduler=scheduler,
   dataloader_kwargs=kwargs,
)

While training, various metrics will be provided along with the loss to easily monitor progress.

Additionally, at the end of every epoch a collage will be saved to the preds folder containing low-resolution crappified images, upscaled high-resolution predictions, and ground truth high-resolution images in that order.


After training is over, we should save our model for future use.

torch.save(model.state_dict(), "model.pth")

We can also plot the training losses returned by train_paired to see the progress of our model over time.

import matplotlib.pyplot as plt

plt.plot(losses)
plt.show()

Using the Model for Predictions#

We now have our trained model, which takes in low-resolution input images and outputs upscaled high-resolution images.

There are two things we can now do with our trained model: use it for predictions, or benchmark it.


If you decide to run your model predictions in a separate file, you will want to load your trained model before proceeding with

model.load_state_dict(torch.load("model.pth"))

where model is an instance of the same architecture you used previously.
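
For example, to load the trained weights in a fresh session:

model = ResUNet(
   hidden=[64, 128, 256, 512, 1024],
   scale=4,
   depth=3,
)
model.load_state_dict(torch.load("model.pth"))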


Predicting Images#

To run predictions with our model, we will use the predict_images function.

from pssr.predict import predict_images

During the training phase, we loaded high-resolution images to create synthetic low-resolution images using a Crappifier. While predicting images, we will instead use experimentally acquired low-resolution images to predict upscaled high-resolution images.

We can do this by creating the same ImageDataset, only now we provide the path to our low-resolution images.

test_dataset = ImageDataset("your/lr", hr_res=512, lr_scale=4, extension="tif")

The low-resolution images are inferred to have a horizontal and vertical resolution of 128 (hr_res=512 / lr_scale=4). A crappifier does not have to be specified, as it will not be used.
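
As a quick check that the images were found, we can print the dataset length (assuming ImageDataset is a standard map-style torch dataset):

print(len(test_dataset))  # number of low-resolution images loaded from your/lr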


We can now use our model to upscale the low-resolution images.

predict_images(model, test_dataset, device)

This will super-resolve high-resolution images from our low-resolution images and save them to the preds folder.


Benchmarking the Model#

If you have a dataset containing aligned high-low-resolution pairs (every high-resolution image has an aligned low-resolution counterpart), we can use test_metrics to quantify the performance of our model on real-world data.

Note

Metrics can still be acquired from training datasets with only high-resolution images, but they will only represent training performance on crappified data and may not represent real-world performance.


We can do this by creating a new PairedImageDataset instance, containing our high-low-resolution image pairs.

from pssr.data import PairedImageDataset

paired_dataset = PairedImageDataset("your/hr", "your/lr", hr_res=512, lr_scale=4)

The images in each folder should be properly aligned and share a consistent naming/ordering scheme so that pairs are returned together when the dataset is iterated.
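
A simple way to catch misaligned folders before benchmarking is to compare the sorted filenames (a hypothetical check, independent of pssr):

from pathlib import Path

hr_names = sorted(path.stem for path in Path("your/hr").glob("*.tif"))
lr_names = sorted(path.stem for path in Path("your/lr").glob("*.tif"))
assert hr_names == lr_names, "high/low-resolution images are not aligned by name"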


We can then compute metrics for all images.

from pssr.predict import test_metrics

test_metrics(model, paired_dataset, device=device)