This page introduces the Trainer class and gives you a basic understanding of how to use it.
Prerequisites
Please ensure that you have installed the Cerebras Model Zoo package by going through the installation guide. Optionally, you can also read through the basic Cerebras PyTorch guide to first gain an understanding of the underlying API that underpins the `Trainer` class.
Basic Usage
The `Trainer` class can be imported and used as follows.
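The sketch below assumes that `model` is an existing `torch.nn.Module` instance; the optimizer choice and its hyperparameters are purely illustrative:

```python
import cerebras.pytorch as cstorch
from cerebras.modelzoo import Trainer

# `model` is assumed to be an existing torch.nn.Module instance.
# The optimizer and its hyperparameters here are illustrative.
optimizer = cstorch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

trainer = Trainer(
    device="CSX",             # device to run training/validation on
    model_dir="./model_dir",  # where checkpoints and other artifacts go
    model=model,
    optimizer=optimizer,
)
```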
The `Trainer` class takes in the following arguments:

- `device`: The device to run training/validation on.
- `model_dir`: The directory in which to store model-related artifacts (e.g. model checkpoints).
- `model`: The `torch.nn.Module` instance that we are training/validating.
- `optimizer`: Optionally, a `cerebras.pytorch.optim.Optimizer` instance can be passed in to optimize the model weights during the training phase.

Learn more about these parameters in our Trainer Configuration guide.
The `fit` method takes in the following arguments:

- `train_dataloader`: The `cerebras.pytorch.utils.data.DataLoader` instance to use during training.
- `val_dataloader`: Optionally, a `cerebras.pytorch.utils.data.DataLoader` instance can be passed in to run validation during and/or at the end of training.
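For example, here is a minimal sketch of a `fit` call, assuming `train_input_fn` and `val_input_fn` are hypothetical functions that each construct and return a `torch.utils.data.DataLoader`:

```python
import cerebras.pytorch as cstorch

# `train_input_fn` and `val_input_fn` are hypothetical callables that
# each return a torch.utils.data.DataLoader.
train_dataloader = cstorch.utils.data.DataLoader(train_input_fn)
val_dataloader = cstorch.utils.data.DataLoader(val_input_fn)

# By default, this runs a single epoch of training followed by a
# single epoch of validation.
trainer.fit(train_dataloader, val_dataloader)
```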
The sections below describe how to further configure the `Trainer` to fit your needs.
Configuring the Training loop
As mentioned above, if both a `train_dataloader` and a `val_dataloader` are provided to the `fit` call, the default behaviour is to run a single epoch of training followed by a single epoch of validation.
This behaviour can be configured by passing in a `TrainingLoop` instance to the `Trainer`.
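For example, the following sketch trains for 1000 steps and runs validation every 100 steps. The import path and the `loop` argument name are assumptions based on the Model Zoo's callback conventions, and the `eval_steps` value is illustrative:

```python
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import TrainingLoop

trainer = Trainer(
    device="CSX",
    model_dir="./model_dir",
    model=model,
    optimizer=optimizer,
    # Train for 1000 batches, validating for 50 batches every 100
    # training steps.
    loop=TrainingLoop(num_steps=1000, eval_steps=50, eval_frequency=100),
)
trainer.fit(train_dataloader, val_dataloader)
```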
- `num_steps` represents the total number of batches to train for. If `num_steps` exceeds the number of available batches in the train dataloader, the dataloader is automatically repeated so that training can run for the full `num_steps`.
- `eval_steps` represents the number of batches to run validation for each time validation is run. As with training, if `eval_steps` exceeds the number of available batches in the val dataloader, the dataloader is automatically repeated. However, validation is typically not run for more than a single epoch, so it is advised to set `eval_steps` to be less than the length of the validation dataloader; otherwise, the validation metrics may be incorrect.
- `eval_frequency` represents how often validation is run during training. In the above example, validation is run every 100 steps of training; that is, throughout the 1000 steps of training, validation is run 10 times. If `eval_frequency` is greater than zero, validation is always run at the end of training as well.
Checkpointing
The `Trainer` can be further configured to save checkpoints at regular intervals by passing in a `Checkpoint` instance.
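For example, the following sketch saves a checkpoint every 100 training steps. The `checkpoint` argument name and the `Checkpoint` import path are assumptions consistent with the callback conventions above:

```python
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import Checkpoint, TrainingLoop

trainer = Trainer(
    device="CSX",
    model_dir="./model_dir",
    model=model,
    optimizer=optimizer,
    loop=TrainingLoop(num_steps=1000),
    # Save a checkpoint every 100 training steps.
    checkpoint=Checkpoint(steps=100),
)
trainer.fit(train_dataloader)
```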
In the above example, a checkpoint is saved every 100 training steps. A checkpoint is also always saved at the end of training, regardless of whether `num_steps` is a multiple of the checkpoint steps.
The checkpoints are saved in the `model_dir` directory that was passed to the `Trainer`.

This checkpoint is meant for resuming training from the same point in the future. As such, it will contain the model weights, optimizer state, and any other state that is necessary to resume training. Please see Selective Checkpoint State Saving for examples of how to configure what state is saved into the checkpoint.
To load a checkpoint, pass the `ckpt_path` argument to the call to `fit`.
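For example, a sketch with a purely illustrative checkpoint path:

```python
# The checkpoint file name here is illustrative; pass the path to
# whichever checkpoint you want to resume from.
trainer.fit(
    train_dataloader,
    val_dataloader,
    ckpt_path="./model_dir/checkpoint_1000.mdl",
)
```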
If a `ckpt_path` is not provided but a checkpoint is found inside the `model_dir`, the `Trainer` will automatically load the latest checkpoint found in the `model_dir`. To learn more about how checkpointing can be configured in the `Trainer`, see Checkpointing.
What’s next?
To learn about how to specify a schedule for learning rates, please see Optimizer and Scheduler. To learn about how you can configure a `Trainer` instance using a YAML configuration file, you can check out:
- Trainer YAML Overview
To learn more about how you can use the `Trainer` in some core workflows and extend its capabilities, you can check out:
- Defer Weight Initialization
- Numeric Precision
- Train a model with weight sparsity
- Checkpointing
- Customizing the Trainer with Callbacks
- Logging
- Performance Flags
To learn more about what the `Trainer` class outputs during the run, see the Logging guide above.

