How to Sparsify Your Model
Sparsifying a model is straightforward with the cstorch API. Here’s an example where 30% of the values in every parameter (weights, biases, embeddings, and so on) are set to zero prior to training. After calling model.apply(sparsity), your model parameters are sparsified, enhancing training efficiency.
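The snippet below is a minimal sketch of this flow. It assumes a static sparsity algorithm is available as cstorch.sparse.Static with a sparsity argument, and that a model and backend have already been set up; treat the exact names and signatures as assumptions rather than the definitive API.

```python
import cerebras.pytorch as cstorch

# Assumed constructor: a static sparsity algorithm that zeroes out 30% of
# the values in every sparsifiable parameter.
sparsity = cstorch.sparse.Static(sparsity=0.3)

# Compile the model first so that all parameters live on the Cerebras
# device, then apply the sparsity algorithm to every parameter.
compiled_model = cstorch.compile(model, backend)
model.apply(sparsity)
```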
Important considerations
- Apply sparsity only after the model has been compiled with cstorch.compile, which ensures all parameters are on the Cerebras device.
- To exclude certain parameters from sparsity, set param.requires_dense = True. If a parameter does not have this attribute, the algorithm assumes that it is False (a small sketch follows this list).
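For illustration, excluding a single parameter might look like the following; the embedding attribute name is hypothetical and depends on your model:

```python
# Hypothetical parameter name; any torch.nn.Parameter works the same way.
# Marking it as requiring a dense representation excludes it from sparsity.
model.embedding.weight.requires_dense = True

model.apply(sparsity)  # the embedding weight is left dense
```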
Sparsifying Optimizers
For training, sparsifying the model’s parameters alone is insufficient; the optimizer must be sparsified as well. Executing optimizer.apply(sparsity) transforms your optimizer into a sparse optimizer, sparsifying the optimizer state associated with each sparse parameter and keeping the sparsity pattern up to date every time optimizer.step() is executed.
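A short sketch under the same assumptions as before (cstorch.optim.SGD is used only as an illustrative optimizer):

```python
# Any cstorch optimizer works the same way; SGD is only an example.
optimizer = cstorch.optim.SGD(model.parameters(), lr=0.01)

# Sparsify the optimizer so its state tensors follow the parameter masks
# and the sparsity pattern is updated after each optimizer.step() call.
optimizer.apply(sparsity)
```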
Important considerations
- The sparsity algorithm targets all optimizer state tensors associated with a sparse parameter, assuming the state tensor matches the parameter’s shape. To exempt certain state tensors from sparsification, mark them as requiring a dense representation by setting requires_dense = True on them; if a state tensor does not have this attribute, it is assumed to be False (see the sketch after this list).
- Sparsity algorithms typically include a hook that updates the sparsity pattern after each optimizer.step() call. This automatic update can be deactivated if necessary.
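As an illustration of the first point, the sketch below exempts the state tensors tracked for one parameter. It assumes the optimizer state for that parameter has already been initialized and that the model has an fc1 layer; both are hypothetical details:

```python
import torch

# Hypothetical: exempt every state tensor tracked for fc1.weight (for
# example, a momentum buffer) from sparsification.
for state_tensor in optimizer.state[model.fc1.weight].values():
    if isinstance(state_tensor, torch.Tensor):
        state_tensor.requires_dense = True
```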
Sparsity Algorithms
The cstorch API offers several out-of-the-box sparsity algorithms, including static sparsity as well as dynamic algorithms such as GMP, SET, and RigL.

Composing Sparsity Algorithms
You can apply distinct sparsity strategies to different parameter groups within your model. For instance, one group of weights might be statically reduced to 30% of its original values, while another group undergoes dynamic sparsity adjustments using the SET algorithm. This is achieved by composing different sparsity algorithms into a single strategy with the Group class. The sketch below applies static sparsity to parameters matching the fc1.* glob pattern, while employing the SET sparsity algorithm for parameters that match the fc2.* glob pattern.
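One way this composition might look is sketched here; the constructor keyword arguments (the sparsity levels, param_filter, and the SET options) are assumptions made for illustration:

```python
import cerebras.pytorch as cstorch

# Assumed constructors and keyword arguments, for illustration only.
static_fc1 = cstorch.sparse.Static(sparsity=0.3, param_filter="fc1.*")
set_fc2 = cstorch.sparse.SET(
    sparsity=0.9,
    update={"freq": 100},  # re-evaluate the mask every 100 steps
    drop_fraction=0.3,     # fraction of connections dropped and regrown
    param_filter="fc2.*",
)

# Compose both strategies into a single algorithm that can be applied
# to the model and the optimizer in one call each.
sparsity = cstorch.sparse.Group([static_fc1, set_fc2])

model.apply(sparsity)
optimizer.apply(sparsity)
```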
Writing Custom Sparsity Algorithms
All sparsity algorithms must inherit from the base SparsityAlgorithm class. The central method to implement is update, which takes care of updating the sparsity patterns for all sparse parameters.
For algorithms that dynamically change the sparsity pattern, there is a convenient DynamicSparsityAlgorithm class that you can inherit from; it takes care of many of the implementation details required to facilitate dynamic sparsity. DynamicSparsityAlgorithm already implements update, but it exposes a new abstract method update_mask that must be overridden instead. update_mask takes in the existing sparsity pattern in the form of a mask tensor and must return the new sparsity pattern, also as a mask tensor.
See GMP, SET, and RigL for examples of how to implement update_mask.
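As a toy illustration, the sketch below keeps only the largest-magnitude weights at a fixed sparsity level. The update_mask signature shown here (p, mask, sparsity) is an assumption about the API, and MagnitudePrune is a made-up class name:

```python
import torch
import cerebras.pytorch as cstorch


class MagnitudePrune(cstorch.sparse.DynamicSparsityAlgorithm):
    """Toy dynamic algorithm: keep the largest-magnitude weights."""

    def update_mask(self, p, mask, sparsity):
        # `p` is the dense parameter, `mask` the current boolean mask, and
        # `sparsity` the target fraction of zeros (assumed signature).
        num_zeros = int(float(sparsity) * p.numel())
        if num_zeros == 0:
            return torch.ones_like(mask, dtype=torch.bool)
        # Zero out the `num_zeros` smallest-magnitude values.
        threshold = torch.kthvalue(p.abs().flatten(), num_zeros).values
        return p.abs() > threshold
```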
In addition, many building blocks are provided that can be used directly, inherited from, or composed to help build new DynamicSparsityAlgorithm subclasses. See Customizing Sparsity & Reference for more details.
Once you’ve written your custom sparsity algorithm, as long as it is available in the global scope, you can use it directly or through a call to configure by setting algorithm to the name of your custom sparsity algorithm class. By extension, you can use it in ModelZoo the same way, by setting algorithm to the class name in your params YAML file.
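A sketch of how such a configure call might look; the configuration keys other than algorithm are assumptions:

```python
# MagnitudePrune refers to the toy class defined above and must be
# visible in the global scope.
sparsity = cstorch.sparse.configure({
    "algorithm": "MagnitudePrune",
    "sparsity": 0.5,            # assumed key
    "update": {"freq": 100},    # assumed key
})

model.apply(sparsity)
optimizer.apply(sparsity)
```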
Implementation Notes
The Cerebras Wafer-Scale Cluster natively implements sparse computations in the Compressed Sparse Row (CSR) format. For user convenience, sparse models are represented as a combination of dense tensors and masks at the PyTorch level, with the compiler seamlessly converting between these representations. While PyTorch provides tools for representing sparse tensors and utilities for pruning networks, these features might not fully align with the needs of the Cerebras Wafer-Scale Engine (WSE). Sparse tensors in PyTorch require specialized kernels and may not be entirely compatible with existing models and utilities. Notably, a torch.nn.Parameter cannot directly accommodate a torch.sparse.Tensor without specific adjustments. The torch.prune utilities are convenient, but the asynchronous and precompiled nature of computation on the WSE requires a custom solution.
Similar to how torch.prune handles its mask tensors, when the sparsity algorithm is applied to the model, every parameter that is sparsified has a mask tensor registered as a stateful buffer next to it in the module that owns the parameter.
For example, take the following simple model:
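The model shown here is a stand-in for illustration (the original example is not reproduced); the fc1 layer name is an assumption:

```python
import torch


class Model(torch.nn.Module):
    """Illustrative stand-in for the example model."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 256)

    def forward(self, x):
        return self.fc1(x)


# After compiling the model and calling model.apply(sparsity), the module
# that owns the sparsified parameter gains a mask buffer registered next
# to it, e.g. "fc1.weight_mask" alongside "fc1.weight" in state_dict().
```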
Once sparsity is applied, the weight and weight_mask tensors collectively represent the sparsified weight, showing how sparsity is represented within the model’s architecture.