train_crappifier
from pssr.train import train_crappifier
- pssr.train.train_crappifier(model: ~torch.nn.modules.module.Module, dataset: ~torch.utils.data.dataset.Dataset, batch_size: int, optim: ~torch.optim.optimizer.Optimizer, epochs: int, sigma: int = 5, clip: float = 3, device: str = 'cpu', scheduler: ~torch.optim.lr_scheduler.LRScheduler = None, log_frequency: int = 50, checkpoint_dir: str = None, collage_dir: str = None, clamp: bool = False, dataloader_kwargs=None)
EXPERIMENTAL, NOT CURRENTLY RECOMMENDED FOR MOST WORKFLOWS!
Trains an nn.Module model as a crappifier on high-low-resolution paired data. The model must output an image the same size as its input, i.e. have a scale value of 1. This is not necessary if you are using a Crappifier instance as your crappifier.
- Parameters:
model (nn.Module) – Model to train on paired data.
dataset (Dataset) – Paired image dataset to load data from.
batch_size (int) – Batch size for dataloader.
optim (Optimizer) – Optimizer for weight calculation.
epochs (int) – Number of epochs to train model for.
sigma (int) – Precision of the noise distribution. Higher values better approximate the noise distribution but can produce larger gradients that are unstable during training. Default is 5.
clip (float) – Max gradient for gradient clipping. Use None for no clipping. Default is 3.
device (str) – Device to train model on. Default is “cpu”.
scheduler (LRScheduler) – Optional learning rate scheduler for training. Default is None.
log_frequency (int) – Frequency to log losses and recalculate metrics in steps. Default is 50.
checkpoint_dir (str) – Directory to save model checkpoints each epoch. A value of None skips checkpointing. Default is None.
collage_dir (str) – Directory to save validation collages each epoch. A value of None skips the collage. Default is None.
clamp (bool) – Whether to clamp model image output before weight calculation. Default is False.
dataloader_kwargs (dict[str, Any]) – Keyword arguments for the PyTorch DataLoader. Default is None.
callbacks (list[Callable]) – Callbacks after each training batch. Can optionally specify an argument for locals to be passed. Default is None.
- Returns:
train_losses (list[float]) – List of losses during training.
val_losses (list[float]) – Validation losses per epoch.
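A minimal usage sketch follows, assuming the signature above. The ConvCrappifier model, the paired_dataset placeholder, the optimizer and scheduler settings, and the directory names are illustrative assumptions, not part of the pssr API; any nn.Module that preserves image size (scale of 1) and any high-low-resolution paired Dataset can be substituted.

import torch
import torch.nn as nn
from pssr.train import train_crappifier

class ConvCrappifier(nn.Module):
    # Hypothetical scale-1 crappifier: output is the same size as the input,
    # as required when training an nn.Module crappifier with train_crappifier.
    def __init__(self, channels: int = 1, hidden: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, hidden, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.net(x)

model = ConvCrappifier()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, factor=0.5, patience=3)

# paired_dataset stands in for any Dataset yielding high-low-resolution image pairs
paired_dataset = ...

train_losses, val_losses = train_crappifier(
    model,
    paired_dataset,
    batch_size=8,
    optim=optim,
    epochs=10,
    sigma=5,                        # precision of the noise distribution
    clip=3,                         # max gradient for gradient clipping
    device="cuda" if torch.cuda.is_available() else "cpu",
    scheduler=scheduler,
    log_frequency=50,
    checkpoint_dir="checkpoints",   # per-epoch model checkpoints
    collage_dir="collages",         # per-epoch validation collages
    dataloader_kwargs={"num_workers": 4},
)

The two returned lists correspond to the train_losses and val_losses described above. If your model cannot preserve the input size, use a Crappifier instance instead, as noted at the top of this entry.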