train_crappifier

from pssr.train import train_crappifier
pssr.train.train_crappifier(model: ~torch.nn.modules.module.Module, dataset: ~torch.utils.data.dataset.Dataset, batch_size: int, optim: ~torch.optim.optimizer.Optimizer, epochs: int, sigma: int = 5, clip: float = 3, device: str = 'cpu', scheduler: ~torch.optim.lr_scheduler.LRScheduler = None, log_frequency: int = 50, checkpoint_dir: str = None, collage_dir: str = None, clamp: bool = False, dataloader_kwargs=None)

EXPERIMENTAL, NOT CURRENTLY RECOMMENDED FOR MOST WORKFLOWS!

Trains an nn.Module model as a crappifier on paired high-/low-resolution data. The model must output an image the same size as its input, i.e. it must have a scale factor of 1. This function is not needed if you use a Crappifier instance as your crappifier.
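To illustrate the same-size requirement, here is a minimal sketch of a model that satisfies it. `SameSizeCrappifier` is a hypothetical example, not part of pssr; it assumes single-channel images by default, and its stride-1 convolutions with `padding=1` keep height and width unchanged.

```python
import torch
from torch import nn


class SameSizeCrappifier(nn.Module):
    """Hypothetical example model: output has the same shape as the input (scale factor 1)."""

    def __init__(self, channels: int = 1, hidden: int = 16):
        super().__init__()
        # kernel_size=3 with padding=1 and stride 1 preserves height and width
        self.net = nn.Sequential(
            nn.Conv2d(channels, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = SameSizeCrappifier()
x = torch.rand(2, 1, 64, 64)  # batch of 2 single-channel 64x64 images
y = model(x)
# The output shape matches the input shape, as train_crappifier requires
print(y.shape)
```

Any architecture works as long as it does not upsample or downsample; adjust `channels` to match your data.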

Parameters:
  • model (nn.Module) – Model to train on paired data.

  • dataset (Dataset) – Paired image dataset to load data from.

  • batch_size (int) – Batch size for dataloader.

  • optim (Optimizer) – Optimizer for weight calculation.

  • epochs (int) – Number of epochs to train model for.

  • sigma (int) – Precision of the noise distribution. Higher values approximate the noise distribution more closely but can produce larger gradients that destabilize training. Default is 5.

  • clip (float) – Maximum gradient for gradient clipping. Use None for no clipping. Default is 3.

  • device (str) – Device to train model on. Default is “cpu”.

  • scheduler (LRScheduler) – Optional learning rate scheduler for training. Default is None.

  • log_frequency (int) – Frequency, in training steps, at which to log losses and recalculate metrics. Default is 50.

  • checkpoint_dir (str) – Directory to save model checkpoints each epoch. A value of None skips checkpointing. Default is None.

  • collage_dir (str) – Directory to save validation collages each epoch. A value of None skips the collage. Default is None.

  • clamp (bool) – Whether to clamp model image output before weight calculation. Default is False.

  • dataloader_kwargs (dict[str, Any]) – Keyword arguments for the PyTorch DataLoader. Default is None.

  • callbacks (list[Callable]) – Callbacks invoked after each training batch. A callback may optionally accept an argument through which the training loop's local variables are passed. Default is None.
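The `clip`, `clamp`, and `scheduler` parameters map onto standard PyTorch training mechanics. The sketch below shows one conventional way they fit into a single optimization step. It is not pssr's internal loop: `training_step` is a hypothetical helper, MSE stands in for whatever noise-distribution objective pssr actually uses, and norm-based clipping, a `[0, 1]` clamp range, and per-step scheduling are all assumptions.

```python
import torch
from torch import nn


def training_step(model, optim, high_res, low_res, clip=3.0, clamp=False, scheduler=None):
    """Hypothetical single optimization step (not pssr's implementation)."""
    model.train()
    optim.zero_grad()
    pred = model(high_res)  # crappifier maps a high-res image to a degraded one
    if clamp:
        # Assumed output range of [0, 1]; clamping happens before the loss
        pred = pred.clamp(0.0, 1.0)
    # Stand-in loss; pssr's sigma-dependent objective is not reproduced here
    loss = nn.functional.mse_loss(pred, low_res)
    loss.backward()
    if clip is not None:
        # Assumption: clip by total gradient norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)
    optim.step()
    if scheduler is not None:
        # Assumption: scheduler is stepped once per batch
        scheduler.step()
    return loss.item()


torch.manual_seed(0)
model = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # same-size toy model
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
sched = torch.optim.lr_scheduler.StepLR(optim, step_size=10, gamma=0.5)
high_res = torch.rand(4, 1, 32, 32)
low_res = high_res + 0.1 * torch.randn_like(high_res)  # synthetic noisy target
loss = training_step(model, optim, high_res, low_res, clip=3.0, clamp=True, scheduler=sched)
print(loss)
```

Passing `clip=None` skips clipping, mirroring the documented behavior of the `clip` parameter.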

Returns:

  • train_losses (list[float]) – List of losses during training.

  • val_losses (list[float]) – Validation losses per epoch.
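Because validation losses are recorded once per epoch, they can be used to pick the best epoch, for example to decide which checkpoint in `checkpoint_dir` to keep. A minimal sketch, assuming the function returns `train_losses` and `val_losses` as a pair and that index `i` of `val_losses` corresponds to epoch `i`; `best_epoch` is a hypothetical helper and the loss values below are made up for illustration.

```python
def best_epoch(val_losses: list[float]) -> int:
    """Return the 0-based index of the epoch with the lowest validation loss."""
    if not val_losses:
        raise ValueError("val_losses is empty")
    # min over indices; ties resolve to the earliest epoch
    return min(range(len(val_losses)), key=val_losses.__getitem__)


# Made-up values for illustration only
val_losses = [0.42, 0.31, 0.29, 0.33]
print(best_epoch(val_losses))  # 2
```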