train_paired

from pssr.train import train_paired
pssr.train.train_paired(model: ~torch.nn.modules.module.Module, dataset: ~torch.utils.data.dataset.Dataset, batch_size: int, loss_fn: ~torch.nn.modules.module.Module, optim: ~torch.optim.optimizer.Optimizer, epochs: int, device: str = 'cpu', scheduler: ~torch.optim.lr_scheduler.LRScheduler = None, log_frequency: int = 50, checkpoint_dir: str = None, collage_dir: str = None, clamp: bool = False, dataloader_kwargs=None, callbacks=None)

Trains a model on paired data consisting of high-resolution images and their crappified (synthetically degraded) low-resolution counterparts.

Parameters:
  • model (nn.Module) – Model to train on paired data.

  • dataset (Dataset) – Paired image dataset to load data from.

  • batch_size (int) – Batch size for dataloader.

  • loss_fn (nn.Module) – Loss function used to compute the training loss.

  • optim (Optimizer) – Optimizer used to update model weights.

  • epochs (int) – Number of epochs to train model for.

  • device (str) – Device to train model on. Default is “cpu”.

  • scheduler (LRScheduler) – Optional learning rate scheduler for training. Default is None.

  • log_frequency (int) – Frequency, in steps, at which to log losses and recalculate metrics. Default is 50.

  • checkpoint_dir (str) – Directory to save model checkpoints each epoch. A value of None skips checkpointing. Default is None.

  • collage_dir (str) – Directory to save validation collages each epoch. A value of None skips the collage. Default is None.

  • clamp (bool) – Whether to clamp model image output before weight calculation. Default is False.

  • dataloader_kwargs (dict[str, Any]) – Keyword arguments for the PyTorch DataLoader. Default is None.

  • callbacks (list[Callable]) – Callbacks to run after each training batch. A callback can optionally declare an argument through which the training loop's local variables are passed. Default is None.
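The callbacks parameter does not state how pssr decides whether to pass locals to a callback. The sketch below is a hypothetical dispatcher (run_callbacks is not part of pssr) that assumes a callback receives the training loop's local variables as a single dict argument when it declares at least one parameter, and is called with no arguments otherwise. Check the pssr source for the actual rule before relying on it.

```python
from __future__ import annotations

import inspect
from typing import Any, Callable, Optional


def run_callbacks(callbacks: Optional[list[Callable]], local_vars: dict[str, Any]) -> None:
    """Hypothetical dispatcher: pass local_vars only to callbacks that declare a parameter."""
    for callback in callbacks or []:
        # Assumption: declaring any parameter means "give me the loop's locals".
        if len(inspect.signature(callback).parameters) > 0:
            callback(local_vars)
        else:
            callback()


calls = []


def count_batches():
    # Takes no arguments, so it is called without locals.
    calls.append("batch")


def record_loss(locals_):
    # Reads a "loss" entry; the real name of the loss variable in pssr is unknown here.
    calls.append(locals_["loss"])


run_callbacks([count_batches, record_loss], {"loss": 0.25})
print(calls)  # ['batch', 0.25]
```

Under this assumption, simple counters or progress hooks can stay argument-free, while callbacks that need training state opt in by declaring one parameter.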

Returns:
  • train_losses (list[float]) – List of losses during training.

  • val_losses (list[float]) – Validation losses per epoch.

Return type:

tuple[list[float], list[float]]
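Assuming train_paired returns train_losses and val_losses as a tuple in that order (for example, train_losses, val_losses = train_paired(...)), the per-epoch validation losses can be used to find the epoch that generalized best, which may help decide which checkpoint to keep when checkpoint_dir is set. The helper best_epoch below is hypothetical and not part of pssr, and the loss values are placeholders, not real training results.

```python
def best_epoch(val_losses: list[float]) -> int:
    """Hypothetical helper: return the 1-based epoch with the lowest validation loss."""
    if not val_losses:
        raise ValueError("val_losses is empty")
    # min() keeps the first index on ties, so the earliest best epoch wins.
    best_index = min(range(len(val_losses)), key=val_losses.__getitem__)
    return best_index + 1


# Placeholder values standing in for the val_losses returned by train_paired.
val_losses = [0.41, 0.32, 0.35]
print(best_epoch(val_losses))  # 2
```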