Source code for malt.trainer

"""Utilities to train a supervised model."""
# =============================================================================
# IMPORTS
# =============================================================================
import torch

# =============================================================================
# MODULE FUNCTIONS
# =============================================================================
def get_default_trainer(
    optimizer: str = "Adam",
    learning_rate: float = 1e-3,
    n_epochs: int = 10,
    batch_size: int = -1,
    reduce_factor: float = 0.5,
    patience: int = 10,
    without_player: bool = False,
    min_learning_rate: float = 1e-6,
    no_validation_threshold: int = 20,
):
    """Get the default training scheme for models.

    Parameters
    ----------
    optimizer : str
        Name of the optimizer. Must be an attribute of `torch.optim`.
    learning_rate : float
        Initial learning rate.
    n_epochs : int
        Maximum number of epochs.
    batch_size : int
        Batch size. If -1, the entire training set is used as a single batch.
    reduce_factor : float
        Factor by which the learning rate is reduced on plateau.
    patience : int
        Number of epochs without improvement in validation loss before the
        learning rate is reduced.
    without_player : bool
        If True, return a trainer with signature
        ``(model, data_train, data_valid)`` instead of ``(player,)``.
    min_learning_rate : float
        Training stops once the learning rate falls below this value.
    no_validation_threshold : int
        Minimum portfolio size required to hold out a validation split;
        smaller portfolios are used for both training and validation.

    Returns
    -------
    Callable
        Trainer function.
    """
    def _default_trainer_without_player(
        model,
        data_train,
        data_valid,
        optimizer=optimizer,
        learning_rate=learning_rate,
        n_epochs=n_epochs,
        batch_size=batch_size,
        min_learning_rate=min_learning_rate,
        reduce_factor=reduce_factor,
    ):
        # see if CUDA is available
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")

        # remember the original device so the model can be moved back later
        original_device = next(model.parameters()).device

        # move model to CUDA if available
        model = model.to(device)

        # consider the case of one batch
        if batch_size == -1:
            batch_size = len(data_train)

        # put data into loaders
        data_train = data_train.view(batch_size=batch_size, pin_memory=True)
        data_valid = data_valid.view(batch_size=len(data_valid), pin_memory=True)

        # get optimizer object
        optimizer = getattr(torch.optim, optimizer)(
            model.parameters(), learning_rate,
        )

        # get scheduler
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=reduce_factor, patience=patience,
        )

        # train
        for idx_epoch in range(n_epochs):  # loop through the epochs
            for x in data_train:  # loop through the dataset
                x = [_x.to(device) for _x in x]
                optimizer.zero_grad()
                model.train()
                loss = model.loss(*x).mean()  # average just in case
                loss.backward()
                optimizer.step()

            # evaluate on the validation set once per epoch
            model.eval()
            with torch.no_grad():
                x = next(iter(data_valid))
                x = [_x.to(device) for _x in x]
                loss = model.loss(*x).mean()
            scheduler.step(loss)

            # stop once the learning rate has decayed below the threshold
            if optimizer.param_groups[0]["lr"] < min_learning_rate:
                break

        # one final forward pass on the training set
        x = next(iter(data_train))
        x = [_x.to(device) for _x in x]
        loss = model.loss(*x).mean()

        # move the model back to its original device and leave it in eval mode
        model = model.to(original_device)
        model.train()
        model.eval()
        return model

    def _default_trainer(player, *args, **kwargs):
        # shuffle the portfolio before splitting
        player.portfolio.shuffle()

        # hold out a validation set only if the portfolio is large enough
        if len(player.portfolio) >= no_validation_threshold:
            data_train, data_valid = player.portfolio.split([9, 1])
        else:
            data_train = data_valid = player.portfolio

        return _default_trainer_without_player(
            player.model, data_train, data_valid, *args, **kwargs,
        )

    if without_player:
        return _default_trainer_without_player

    return _default_trainer
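
A minimal usage sketch (not part of the module source): with `without_player=True` the returned trainer takes `(model, data_train, data_valid)`, where the model exposes a `loss(*batch)` method and the datasets expose `__len__` and a `view(batch_size=..., pin_memory=...)` loader, matching the interface used above. The `_ToyModel` and `_ToyDataset` classes below are hypothetical stand-ins for real malt models and datasets.

import torch
from torch.utils.data import DataLoader, TensorDataset

from malt.trainer import get_default_trainer


class _ToyModel(torch.nn.Module):
    """Stand-in model exposing the ``loss(*batch)`` interface the trainer expects."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    def loss(self, x, y):
        # per-example loss; the trainer averages it with ``.mean()``
        return torch.nn.functional.mse_loss(self.linear(x), y, reduction="none")


class _ToyDataset:
    """Stand-in dataset exposing ``__len__`` and a ``view`` loader."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def view(self, batch_size, pin_memory=False):
        # return an iterable of batches, mirroring the loader interface used above
        return DataLoader(
            TensorDataset(self.x, self.y),
            batch_size=batch_size,
            pin_memory=pin_memory,
        )


data_train = _ToyDataset(torch.randn(64, 4), torch.randn(64, 1))
data_valid = _ToyDataset(torch.randn(16, 4), torch.randn(16, 1))

trainer = get_default_trainer(n_epochs=5, without_player=True)
model = trainer(_ToyModel(), data_train, data_valid)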