Model
Learn how to pass a model to the Trainer class.
The model is the main Module that all training and validation run on. Every Trainer instance requires one.
Prerequisites
Read the Trainer Overview and Trainer Configuration Overview to learn the basics of running Model Zoo models.
Configure the Model
Use the model argument to set the model you'd like to train or validate.
When using YAML, all model subkeys are passed as arguments to the model class. Your run script's model_fn determines which model class is used.
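As a rough illustration of this mapping, the sketch below treats the parsed YAML as a plain dictionary and unpacks every model subkey as a keyword argument to the model class. The class MyModel, its parameters (hidden_size, num_layers, dropout), the flat config layout, and the helper build_model_from_config are all hypothetical stand-ins; in a real run the model class comes from your run script's model_fn, and the actual YAML nesting may differ.

```python
from typing import Any, Dict


class MyModel:
    """Hypothetical model class; a real one would subclass torch.nn.Module."""

    def __init__(self, hidden_size: int, num_layers: int, dropout: float = 0.0):
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout


def build_model_from_config(config: Dict[str, Any], model_cls=MyModel):
    """Hypothetical helper: pass each subkey under `model` as a keyword argument."""
    model_params = config.get("model", {})
    return model_cls(**model_params)


# Simplified, hypothetical parsed form of a YAML block such as:
#   model:
#     hidden_size: 768
#     num_layers: 12
config = {"model": {"hidden_size": 768, "num_layers": 12}}
model = build_model_from_config(config)
print(model.hidden_size, model.num_layers)  # 768 12
```

Because the subkeys are unpacked directly, a misspelled or unsupported key surfaces as a TypeError from the model class constructor rather than being silently ignored.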
In Python, you can specify the model in two ways:
- As a callable that takes no arguments and returns a Module
- As a Module that the system uses directly
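The two options can be pictured as a small dispatch step. The sketch below is a hypothetical illustration of that contract, not the Trainer's implementation: Module is a placeholder for torch.nn.Module so the example runs without PyTorch, and TinyModel and resolve_model are invented names.

```python
from typing import Callable, Union


class Module:
    """Placeholder for torch.nn.Module so the sketch runs without PyTorch."""


class TinyModel(Module):
    """Hypothetical user model."""

    def __init__(self) -> None:
        self.built = True


def resolve_model(model: Union[Module, Callable[[], Module]]) -> Module:
    """Hypothetical helper mirroring the two documented options."""
    # A Module instance is used directly. This check comes first because a
    # real torch.nn.Module is itself callable (calling it runs forward).
    if isinstance(model, Module):
        return model
    # Otherwise, a zero-argument callable is invoked to build the Module.
    if callable(model):
        built = model()
        if not isinstance(built, Module):
            raise TypeError("model callable must return a Module")
        return built
    raise TypeError("model must be a Module or a zero-argument callable")


direct = TinyModel()
assert resolve_model(direct) is direct               # used as-is
assert isinstance(resolve_model(TinyModel), TinyModel)  # class as a factory
assert isinstance(resolve_model(lambda: TinyModel()), TinyModel)
```

Passing a callable defers construction until the system invokes it, while passing a Module means you control exactly where and when the model is built.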
When passing a Module directly, initialize the model inside the Cerebras device context for optimal performance. This approach automatically moves the model parameters to the Cerebras device, reducing memory usage and speeding up initialization. For more information, see Efficient weight initialization.