Our typical workflow involves using a training script provided in the Cerebras Model Zoo. However, if that training loop is insufficient for your model's needs, you can write your own training loop using the Cerebras PyTorch API.
The following steps take you through the minimum code required to run a simple, small model on the Cerebras Wafer-Scale Cluster. To extend the script with features such as learning rate scheduling or gradient scaling, continue to the Further Reading section.
Prerequisites
You have installed the cerebras.pytorch package in your environment.
Validate the Package Installation
Run the following Python import to validate that the cerebras.pytorch package is installed correctly. If it completes without an error, the package is available:
import cerebras.pytorch as cstorch
From here on, we use cstorch as the alias for cerebras.pytorch.
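If you want a setup script to fail with a clear message rather than a traceback, you can check for the package before importing it. The following is a minimal sketch using only the standard library; the helper name is_installed is hypothetical and not part of cerebras.pytorch.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if `module_name` can be located by the import system.

    For a dotted name such as "cerebras.pytorch", find_spec imports the
    parent package first; if the parent is missing it raises
    ModuleNotFoundError, which is treated here as "not installed".
    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        return False
```

For example, a script could call is_installed("cerebras.pytorch") and raise SystemExit with an installation hint when it returns False.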
To configure the Cerebras Wafer-Scale Cluster, construct a ClusterConfig object, then use it to construct a backend object:
cluster_config = cstorch.distributed.ClusterConfig(
max_wgt_servers=1,
max_act_per_csx=1,
num_workers_per_csx=1,
)
backend = cstorch.backend(
"CSX",
cluster_config=cluster_config,
)
See the class documentation for ClusterConfig to view all configurable options.
Most options have reasonable defaults and do not need to be changed.
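If you do want to vary the options shown above between runs without editing the script, one approach is to read them from the command line. This is a minimal sketch with the standard argparse module; the helper parse_cluster_args is hypothetical, and the defaults of 1 are simply the values from the example above, not necessarily the library's own defaults.

```python
import argparse


def parse_cluster_args(argv=None) -> dict:
    """Collect the ClusterConfig options used in the example from the CLI.

    Returns a dict whose keys match the keyword arguments passed to
    ClusterConfig above, so it can be unpacked with **.
    """
    parser = argparse.ArgumentParser(description="Cluster settings")
    parser.add_argument("--max_wgt_servers", type=int, default=1)
    parser.add_argument("--max_act_per_csx", type=int, default=1)
    parser.add_argument("--num_workers_per_csx", type=int, default=1)
    return vars(parser.parse_args(argv))
```

The result would then be passed as cstorch.distributed.ClusterConfig(**parse_cluster_args()), assuming you only expose keyword arguments that ClusterConfig accepts.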
Define Your Model
When using the Cerebras PyTorch API, define your model the same way you would in a vanilla PyTorch workflow:
import torch
import torch.nn.functional as F


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 256)
        self.fc2 = torch.nn.Linear(256, 10)

    def forward(self, inputs):
        inputs = torch.flatten(inputs, 1)
        outputs = F.relu(self.fc1(inputs))
        return F.relu(self.fc2(outputs))


model = Model()
Initializing the weights of a large model can cause out-of-memory errors, and eagerly initializing an extremely large model can also be very slow. See the page on Efficient weight initialization for how to work around this issue.
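To see why eager initialization becomes a problem at scale, it helps to estimate how much memory the weights alone occupy. This back-of-the-envelope sketch counts the parameters of torch.nn.Linear layers (weight plus bias) and assumes 4 bytes per float32 parameter; it ignores optimizer state, gradients, and activations, which add more on top.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Number of parameters in a torch.nn.Linear layer (weight + optional bias)."""
    return in_features * out_features + (out_features if bias else 0)


def weight_bytes(num_params: int, bytes_per_param: int = 4) -> int:
    """Memory for the weights alone; 4 bytes per parameter assumes float32."""
    return num_params * bytes_per_param


# The two layers of the example Model: fc1 (784 -> 256) and fc2 (256 -> 10).
example_params = linear_params(784, 256) + linear_params(256, 10)
print(example_params, weight_bytes(example_params))  # 203530 814120
```

The example model needs under 1 MB, but the same arithmetic for a 7-billion-parameter model gives 28,000,000,000 bytes (about 28 GB) of float32 weights before anything else is allocated.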
Compile Your Model
Once the model has been instantiated, compile it by calling cerebras.pytorch.compile, for example:
compiled_model = cstorch.compile(model, backend)
The call to cstorch.compile doesn't actually compile the model. Similar to torch.compile, it only prepares the model for compilation. Compilation happens only when the first iteration runs, once the input shapes are known.
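The idea of deferred compilation can be illustrated with a toy, pure-Python wrapper: wrapping is cheap, and the expensive step runs on the first call, when the shape of a real input is available. This is only a conceptual sketch; LazyCompiled is a hypothetical class and does not reflect how cstorch.compile is implemented.

```python
class LazyCompiled:
    """Toy stand-in for a deferred-compilation wrapper.

    Wrapping does no work; the pretend compile step runs on the first
    call, when the input shape is finally known. Shape changes after the
    first call are not handled in this toy.
    """

    def __init__(self, fn):
        self.fn = fn
        self.compiled_for = None  # input shape seen at (pretend) compile time
        self.compile_count = 0

    def __call__(self, batch):
        shape = (len(batch), len(batch[0]))
        if self.compiled_for is None:
            self.compile_count += 1  # a real backend would compile here
            self.compiled_for = shape
        return self.fn(batch)


wrapped = LazyCompiled(lambda batch: [sum(row) for row in batch])
print(wrapped.compile_count)      # 0: wrapping alone does not compile
print(wrapped([[1, 2], [3, 4]]))  # [3, 7]: the first call triggers the compile
print(wrapped([[5, 6], [7, 8]]))  # [11, 15]: reuses the compiled version
print(wrapped.compile_count)      # 1
```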
Optimize Model Parameters
To optimize model parameters using the Cerebras Wafer-Scale Cluster, you must use a Cerebras-compliant optimizer. cerebras.pytorch.optim provides exact drop-in replacements for all commonly used optimizers, for example:
optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
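Because cstorch.optim.SGD is described as a drop-in replacement, the sketch below assumes it follows the update rule documented for torch.optim.SGD with the default dampening=0, nesterov=False and weight_decay=0. It applies that rule to a single scalar parameter to show what lr=0.001 and momentum=0.9 do; sgd_momentum_step is a hypothetical helper, not a library function.

```python
def sgd_momentum_step(param, grad, buf, lr=0.001, momentum=0.9):
    """One SGD-with-momentum update for a single scalar parameter.

    `buf` is the momentum buffer; it is None before the first step, in
    which case it is initialized to the gradient itself.
    """
    buf = grad if buf is None else momentum * buf + grad
    return param - lr * buf, buf


p, buf = 1.0, None
for _ in range(2):
    p, buf = sgd_momentum_step(p, grad=1.0, buf=buf)
print(p)  # approximately 0.9971 (0.999 after step 1, then minus 0.001 * 1.9)
```

With a constant gradient the momentum buffer grows from 1.0 to 1.9, so the second step is almost twice as large as the first.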
Dataloaders
To send data to the Wafer-Scale Cluster, wrap your PyTorch dataloader with cerebras.pytorch.utils.data.DataLoader.
For example:
def get_torch_dataloader(batch_size, train):
    from torchvision import datasets, transforms

    train_dataset = datasets.MNIST(
        "/path/to/data",  # change this path
        train=train,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        ),
    )
    return torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )


training_dataloader = cstorch.utils.data.DataLoader(
    get_torch_dataloader, batch_size=64, train=True
)
The Cerebras PyTorch dataloader takes a callable that creates a PyTorch dataloader, not a dataloader instance; the remaining arguments (here, batch_size and train) are passed through to that callable. This lets each worker independently create its own dataloader instance, so data loading can run in parallel across workers.
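The benefit of passing a factory instead of an object can be shown with plain Python: each worker calls the same function with the same arguments and gets its own independent loader, so iterating one never advances another. This is a toy illustration; make_worker_loaders and toy_loader are hypothetical names and are not how cerebras.pytorch distributes data internally.

```python
def make_worker_loaders(loader_fn, num_workers, *args, **kwargs):
    """Build one independent loader per (hypothetical) worker.

    Every worker calls the same factory with the same arguments, so no
    loader object or iterator state is shared between workers.
    """
    return [iter(loader_fn(*args, **kwargs)) for _ in range(num_workers)]


def toy_loader(batch_size, train):
    # Stand-in for get_torch_dataloader: yields batches of integers.
    data = list(range(8)) if train else list(range(100, 104))
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]


workers = make_worker_loaders(toy_loader, 2, batch_size=4, train=True)
print(next(workers[0]))  # [0, 1, 2, 3]
print(next(workers[0]))  # [4, 5, 6, 7]
print(next(workers[1]))  # [0, 1, 2, 3]: unaffected by worker 0
```

Had a single, already-constructed iterator been shared instead, the workers would have competed for the same batches.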
Define the Training Step
To execute a single training iteration on the Wafer-Scale Cluster, you first need to capture all operations intended to run on the cluster. Do this by defining a function that performs all actions for a single training iteration and decorating it with cerebras.pytorch.trace.
For example:
loss_fn = torch.nn.CrossEntropyLoss()


@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss
This function is traced and sent to the cluster for compilation and execution.
Define an Execution
To configure a run on the Cerebras Wafer-Scale Cluster, create an instance of cerebras.pytorch.utils.data.DataExecutor.
For example:
train_executor = cstorch.utils.data.DataExecutor(
training_dataloader,
num_steps=100,
checkpoint_steps=50,
)
The executor takes the Cerebras PyTorch dataloader to use during the run, the total number of steps to run, and the interval (in steps) at which checkpoints are taken.
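To check which steps a given configuration would checkpoint, the interval arithmetic can be written out explicitly. This sketch assumes a checkpoint every checkpoint_steps steps, counted from 1 (matching how global_step is incremented before save_checkpoint in the training loop), and does not add an extra checkpoint at the final step; consult the DataExecutor documentation for the exact behavior. checkpoint_steps_for is a hypothetical helper.

```python
def checkpoint_steps_for(num_steps: int, checkpoint_steps: int) -> list:
    """Steps (1-indexed) at which a checkpoint would be taken.

    Assumes one checkpoint every `checkpoint_steps` steps; an interval of
    0 or less is treated as "no checkpoints" in this sketch.
    """
    if checkpoint_steps <= 0:
        return []
    return list(range(checkpoint_steps, num_steps + 1, checkpoint_steps))


print(checkpoint_steps_for(100, 50))  # [50, 100]
```

Under this assumption, the example executor (num_steps=100, checkpoint_steps=50) writes checkpoint_50.mdl and checkpoint_100.mdl.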
Train Your Model
Once the above is defined, iterate through the executor to train your model:
@cstorch.step_closure
def print_loss(mode, loss: torch.Tensor, step: int):
    print(f"{mode} Loss {step}: {loss.item()}")


@cstorch.checkpoint_closure
def save_checkpoint(step):
    cstorch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        f"checkpoint_{step}.mdl",
    )


global_step = 0

for inputs, targets in train_executor:
    loss = training_step(inputs, targets)
    print_loss("Training", loss, global_step)
    global_step += 1
    save_checkpoint(global_step)
- Notice that the loss is passed into a function decorated with step_closure. This is required to retrieve the loss value from the cluster before it can be used. See the page on step closures for more details.
- Also notice that checkpoints are saved inside a function decorated with checkpoint_closure. This is required to retrieve the model weights and optimizer state from the cluster before they can be saved. See the page on saving checkpoints.
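Step closures can hold any per-step logic that needs the loss value, not just printing. The sketch below keeps a running average of the loss. The no-op fallback decorator is an addition for illustration only: it lets the averaging logic run and be unit-tested on a machine without cerebras.pytorch, and is not part of the Cerebras API. LossTracker and record_loss are hypothetical names.

```python
try:
    import cerebras.pytorch as cstorch

    step_closure = cstorch.step_closure
except ImportError:
    # Hypothetical fallback for local testing: run the function directly.
    def step_closure(fn):
        return fn


class LossTracker:
    """Accumulates loss values to report a running average."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    @property
    def average(self):
        return self.total / self.count if self.count else float("nan")


tracker = LossTracker()


@step_closure
def record_loss(loss, step):
    # Inside a step closure the loss value is available, so .item() is safe.
    tracker.total += loss.item()
    tracker.count += 1
```

In the training loop, record_loss(loss, global_step) would be called next to print_loss, and tracker.average read after the loop finishes.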
Putting It All Together
Combining all of the above steps gives a minimal training script for a simple, fully connected model trained on the MNIST dataset:
# Import the Cerebras PyTorch module
import cerebras.pytorch as cstorch
import torch
import torch.nn.functional as F


def main():
    # Define a model
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(784, 256)
            self.fc2 = torch.nn.Linear(256, 10)

        def forward(self, inputs):
            inputs = torch.flatten(inputs, 1)
            outputs = F.relu(self.fc1(inputs))
            return F.relu(self.fc2(outputs))

    backend = cstorch.backend(
        "CSX",
        cluster_config=cstorch.distributed.ClusterConfig(
            max_wgt_servers=1,
            max_act_per_csx=1,
            num_workers_per_csx=1,
        ),
    )

    model = Model()

    # Compile the model
    compiled_model = cstorch.compile(model, backend=backend)

    # Define an optimizer
    optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Define a data loader
    def get_torch_dataloader(batch_size, train):
        from torchvision import datasets, transforms

        train_dataset = datasets.MNIST(
            "/path/to/data",  # change this path
            train=train,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
            target_transform=transforms.Lambda(
                lambda x: torch.as_tensor(x, dtype=torch.int32)
            ),
        )
        return torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )

    training_dataloader = cstorch.utils.data.DataLoader(
        get_torch_dataloader, batch_size=64, train=True
    )

    # Define the training step
    loss_fn = torch.nn.CrossEntropyLoss()

    @cstorch.trace
    def training_step(inputs, targets):
        outputs = compiled_model(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return loss

    @cstorch.step_closure
    def print_loss(loss: torch.Tensor, step: int):
        print(f"Train Loss {step}: {loss.item()}")

    @cstorch.checkpoint_closure
    def save_checkpoint(step):
        cstorch.save(
            {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            },
            f"checkpoint_{step}.mdl",
        )

    global_step = 0

    train_executor = cstorch.utils.data.DataExecutor(
        training_dataloader,
        num_steps=100,
        checkpoint_steps=50,
    )

    model.train()
    for inputs, targets in train_executor:
        loss = training_step(inputs, targets)
        print_loss(loss, global_step)
        global_step += 1
        save_checkpoint(global_step)


if __name__ == "__main__":
    main()
Further Reading