<object object>, model=<object object>, optimizer=None, schedulers=None, precision=None, sparsity=None, loop=None, checkpoint=None, logging=None, callbacks=None, loggers=None, seed=None)
The Trainer class is the main entry point for training models in ModelZoo.
Parameters
- device (Optional[str]) – The device to train the model on. It must be one of “CSX”, “CPU”, or “GPU”.
- backend (Optional[Backend]) – The backend used to train the model. This argument is mutually exclusive with device.
- model_dir (str) – The directory where the model artifacts are saved.
- model (Union[Callable[[], torch.nn.Module], torch.nn.Module]) – The model to train. It must be one of the following:
  - If a callable is passed, it is assumed to be a function that takes no arguments and returns a torch.nn.Module.
  - If a torch.nn.Module is passed, it is used as is.
- optimizer (Union[Optimizer, Callable[[torch.nn.Module], Optimizer], None]) – The optimizer used to optimize the model. It must be an Optimizer, a callable that takes a torch.nn.Module and returns an Optimizer, or None.
- schedulers (SchedulersInput) – The set of optimizer schedulers to be used. Common schedulers include LR schedulers. It must be a list of these items:
  - If a cstorch.optim.scheduler.Scheduler is passed, it is used as is.
  - A callable that is assumed to be a function that takes in an Optimizer and returns a cstorch.optim.scheduler.Scheduler.
  - If None, there is no optimizer param group scheduling.
- precision (Optional[Precision]) – The Precision callback used during training.
- sparsity (Optional[SparsityAlgorithm]) – The sparsity algorithm used to sparsify weights during training/validation. It must be one of the following:
  - If a callable is passed, it is assumed to be a function that takes no arguments and returns a SparsityAlgorithm.
  - If a SparsityAlgorithm is passed, it is used as is.
- loop (Optional[LoopCallback]) – The loop callback to use for training. It must be an instance of LoopCallback. If not provided, the default loop is TrainingLoop(num_epochs=1).
- checkpoint (Optional[Checkpoint]) – The checkpoint callback to use for saving/loading checkpoints. It must be an instance of Checkpoint. If not provided, then no checkpoints are saved.
- logging (Optional[Logging]) – The logging callback used to set up Python logging. This callback also controls when logs are emitted. If not provided, the default logging settings Logging(log_steps=1, log_level="INFO") are used.
- callbacks (Optional[List[Callback]]) – A list of callbacks to be used by the trainer. The order in which the callbacks are provided is important, as it determines the order in which the callbacks' hooks are executed.
- loggers (Optional[List[Logger]]) – A list of loggers to use for logging.
- seed (Optional[int]) – Initial seed for the torch random number generator.
- hook_name (str) – The name of the hook to call. It must be the name of a method in the Callback class.
- args – Other positional arguments to forward along to the called hook.
- kwargs – Other keyword arguments to forward along to the called hook.
- batch – The batch of data to train on.
- batch_idx – The index of the batch in the dataloader.
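Two conventions described in the parameter list above can be sketched in plain Python: accepting either an instance or a zero-argument factory, and forwarding a named hook to callbacks in the order they were supplied. This is a minimal sketch, not the Trainer's implementation. `Model`, `Tag`, `resolve`, `call_hook`, and the hook name `on_train_batch_start` are hypothetical names chosen for illustration, and `Model` stands in for torch.nn.Module so the snippet runs without extra dependencies.

```python
from typing import Any, Callable, List, Union


class Model:
    """Hypothetical stand-in for torch.nn.Module."""


class Callback:
    """Hypothetical stand-in for the Callback base class; hooks are no-ops."""

    def on_train_batch_start(self, log: List[str], batch: Any, batch_idx: int) -> None:
        pass


class Tag(Callback):
    """Records its name when the hook fires, so the call order is observable."""

    def __init__(self, name: str) -> None:
        self.name = name

    def on_train_batch_start(self, log: List[str], batch: Any, batch_idx: int) -> None:
        log.append(f"{self.name}:{batch_idx}")


def resolve(value: Union[Callable[[], Any], Any], expected: type) -> Any:
    """Return an instance of `expected`, calling `value` only if it is a factory."""
    # Check the instance case first: a torch.nn.Module instance is itself
    # callable, so testing callable() first would invoke it by mistake.
    if isinstance(value, expected):
        return value
    if callable(value):
        result = value()
        if isinstance(result, expected):
            return result
    raise TypeError(f"expected {expected.__name__} or a zero-argument factory")


def call_hook(callbacks: List[Callback], hook_name: str, *args: Any, **kwargs: Any) -> None:
    """Invoke `hook_name` on each callback, in the order the list was given."""
    for callback in callbacks:
        getattr(callback, hook_name)(*args, **kwargs)


model = Model()
log: List[str] = []
call_hook([Tag("first"), Tag("second")], "on_train_batch_start", log, batch=None, batch_idx=0)
```

After the last call, `log` holds the entries for `first` and then `second`, which is why the order of the callbacks list matters.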
<object object>)
Complete a full training run on the given train and validation dataloaders.
Parameters
- train_dataloader (cerebras.appliance.log.named_class_logger) – The training dataloader.
- val_dataloader (Optional[Union[cerebras.appliance.log.named_class_logger, List[cerebras.appliance.log.named_class_logger]]]) – The validation dataloader. If provided, validation is run every eval_frequency steps as defined in the loop callback. If not provided, only training is run. If a list of dataloaders is provided, then each dataloader is validated in sequence.
- ckpt_path (Optional[str]) – The path to the checkpoint to load before starting training. If not provided and autoload_last_checkpoint is True, then the latest checkpoint is loaded.
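The interleaving of training and validation described above can be sketched as a schedule of events. This is a minimal sketch under stated assumptions, not the library's loop: `as_list` and `fit_schedule` are hypothetical helpers, steps are counted from 1, and validation is assumed to fire after every `eval_frequency`-th training step. Any iterable stands in for a dataloader.

```python
from typing import Any, Iterable, List, Optional, Tuple, Union


def as_list(val_dataloader: Optional[Union[Any, List[Any]]]) -> List[Any]:
    """Normalize None, a single dataloader, or a list of dataloaders to a list."""
    if val_dataloader is None:
        return []
    if isinstance(val_dataloader, list):
        return val_dataloader
    return [val_dataloader]


def fit_schedule(
    train_dataloader: Iterable[Any],
    val_dataloader: Optional[Union[Any, List[Any]]] = None,
    eval_frequency: int = 2,
) -> List[Tuple]:
    """Return the ordered train/validate events for one pass over the data."""
    events: List[Tuple] = []
    val_loaders = as_list(val_dataloader)
    for step, _batch in enumerate(train_dataloader, start=1):
        events.append(("train", step))
        # Assumption: validation fires after every eval_frequency-th step.
        if val_loaders and step % eval_frequency == 0:
            # Each validation dataloader is visited in sequence.
            for index in range(len(val_loaders)):
                events.append(("validate", step, index))
    return events
```

With no validation dataloader the schedule contains only training events; with a list of two, each validation point produces two consecutive validate events.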
<object object>, loop=None)
Complete a full validation run on the validation dataloader.
Parameters
- val_dataloader (Optional[cerebras.appliance.log.named_class_logger]) – The validation dataloader. If a list of dataloaders is provided, then each dataloader is validated in sequence.
- ckpt_path (Optional[str]) – The path to the checkpoint to load before starting validation. If not provided and autoload_last_checkpoint is True, then the latest checkpoint is loaded.
- loop (Optional[cerebras.modelzoo.trainer.callbacks.loop.ValidationLoop]) – The loop callback to use for validation. If not provided, the default loop is used. If provided, it must be an instance of ValidationLoop. Note, this should only be provided if the loop callback provided in the constructor is not sufficient.
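The `<object object>` default shown in the signature suggests that a sentinel is used to tell "ckpt_path was not provided" apart from an explicit value. The pattern can be sketched as follows. This is an illustration only: `_UNSET`, `pick_checkpoint`, and the `saved` list are hypothetical, and treating an explicit None as "load nothing" is an assumption rather than documented behavior.

```python
from typing import Any, List, Optional

# A unique sentinel; its repr is what appears as "<object object ...>".
_UNSET: Any = object()


def pick_checkpoint(
    ckpt_path: Any = _UNSET,
    autoload_last_checkpoint: bool = True,
    saved: Optional[List[str]] = None,
) -> Optional[str]:
    """Decide which checkpoint to load, if any.

    `saved` is a hypothetical list of known checkpoints, oldest first.
    """
    # An explicitly passed value always wins (assumption: None means "skip").
    if ckpt_path is not _UNSET:
        return ckpt_path
    # Not provided: fall back to the latest checkpoint when autoloading is on.
    if autoload_last_checkpoint and saved:
        return saved[-1]
    return None
```

A plain `None` default could not distinguish these cases, which is the usual reason for a sentinel of this kind.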
<object object>, loop=None)
Runs all upstream and downstream validation permutations.
    for ckpt_path in ckpt_paths:
        for val_dataloader in val_dataloaders:
            trainer.validate(val_dataloader, ckpt_path)
        # run downstream validation
        run_validation(…)
Parameters
- val_dataloaders (Optional[Union[cerebras.appliance.log.named_class_logger, List[cerebras.appliance.log.named_class_logger]]]) – A list of validation dataloaders to run validation on.
- ckpt_paths (Optional[Union[List[str], str]]) – A list of checkpoint paths to run validation on. Each checkpoint path must be a path to a checkpoint file, or a glob pattern.
- loop (Optional[cerebras.modelzoo.trainer.callbacks.loop.ValidationLoop]) – The validation loop to use for validation. If not provided, then the default loop is used.
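Because ckpt_paths accepts a single string, a list, and glob patterns, the set of (checkpoint, dataloader) pairs can be larger than it first appears. The sketch below shows one way such inputs might be expanded with the standard `glob` module. It is an assumption about the behavior, not the library's code: `expand_ckpt_paths`, `validation_plan`, `demo`, and the `checkpoint_*.mdl` file names are hypothetical.

```python
import glob
import os
import tempfile
from typing import Any, List, Sequence, Tuple, Union


def expand_ckpt_paths(ckpt_paths: Union[str, List[str]]) -> List[str]:
    """Expand a path, a glob pattern, or a list of either into file paths."""
    if isinstance(ckpt_paths, str):
        ckpt_paths = [ckpt_paths]
    expanded: List[str] = []
    for pattern in ckpt_paths:
        # glob.glob returns a literal path unchanged when that file exists.
        # sorted() is lexicographic, so "ckpt_10" would sort before "ckpt_2".
        expanded.extend(sorted(glob.glob(pattern)))
    return expanded


def validation_plan(
    val_dataloaders: Sequence[Any], ckpt_paths: Union[str, List[str]]
) -> List[Tuple[str, int]]:
    """List every (checkpoint, dataloader index) pair, checkpoint-major."""
    return [
        (ckpt, index)
        for ckpt in expand_ckpt_paths(ckpt_paths)
        for index in range(len(val_dataloaders))
    ]


def demo() -> List[Tuple[str, int]]:
    """Build two empty checkpoint files and plan validation over them."""
    with tempfile.TemporaryDirectory() as tmp:
        for step in (1, 2):
            open(os.path.join(tmp, f"checkpoint_{step}.mdl"), "w").close()
        pattern = os.path.join(tmp, "checkpoint_*.mdl")
        plan = validation_plan(["val_a", "val_b"], pattern)
    return [(os.path.basename(path), index) for path, index in plan]
```

With two checkpoints and two dataloaders, `demo()` yields four pairs, visiting every dataloader for the first checkpoint before moving to the second.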
Cerebras Model Zoo Callbacks API
cerebras.modelzoo.trainer.callbacks | This module contains the base Callback class, a number of core callbacks directly invoked by the Trainer, and other optional callbacks that can be used to extend the functionality of the Trainer.
Cerebras Model Zoo Extensions API
cerebras.modelzoo.trainer.extensions | This module contains integrations of external tools to the Trainer.
Cerebras Model Zoo Loggers API
cerebras.modelzoo.trainer.loggers | This module contains the base Logger class as well as a few useful Logger subclasses.