The Trainer class was designed to be easily extendable using Callback classes. The Trainer exposes a number of hooks which can be overridden using a Callback.
On this page, you will learn about the basic Callback mechanism. By the end you should be able to write and use your own custom Callback.
Prerequisites
Please ensure that you have read through the Cerebras Model Zoo Trainer Overview beforehand. The rest of this page assumes that you already have at least a cursory understanding of what the Cerebras Model Zoo Trainer is and how to use the Python API.
Callbacks
The callback mechanism is the backbone of the Trainer’s implementation. A lot of the heavy lifting in the Trainer is actually done by various Core Callbacks.
In general, the Callback mechanism exposes a number of useful hooks that allow you to inject certain behaviour into the Trainer. These hooks include (but are not limited to):
- setup
- on_{fit,train,validate}_{start,end}
- on_{train,validate}_batch_{start,end}
- on_{after,before}_{forward,backward}
- on_{after,before}_optimizer_{step,zero_grad}
- on_{after,before}_scheduler_step
- on_{save,load}_checkpoint
- on_after_save_checkpoint
- on_before_load_checkpoint
[Figure: a fit call and where the various hooks get called.]
For the full list of hooks, see the Callback class.
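To make the hook sequence more concrete, here is a toy sketch of how a trainer might walk through its hooks during a single fit call. HookRecorder and ToyTrainer are hypothetical stand-ins written for this page, not the Model Zoo classes, and the ordering shown is an assumption that covers only a subset of the hooks listed above (validation, zero_grad, scheduler, and checkpoint hooks are left out).

```python
class HookRecorder:
    """Toy callback that records the name of every hook it receives."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Reached only when normal lookup fails, i.e. for hooks that are
        # not defined explicitly; each one becomes a recording function.
        if name.startswith("on_"):
            return lambda *args, **kwargs: self.calls.append(name)
        raise AttributeError(name)


class ToyTrainer:
    """Hypothetical stand-in for a trainer; NOT the Model Zoo Trainer."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def _call_hook(self, hook_name):
        # Invoke the same hook on every callback, in list order.
        for callback in self.callbacks:
            getattr(callback, hook_name)(self)

    def fit(self, num_batches):
        # Assumed ordering, for illustration only.
        self._call_hook("on_fit_start")
        self._call_hook("on_train_start")
        for _ in range(num_batches):
            self._call_hook("on_train_batch_start")
            for stage in ("forward", "backward", "optimizer_step"):
                self._call_hook(f"on_before_{stage}")
                # The real work for this stage would happen here.
                self._call_hook(f"on_after_{stage}")
            self._call_hook("on_train_batch_end")
        self._call_hook("on_train_end")
        self._call_hook("on_fit_end")


recorder = HookRecorder()
ToyTrainer([recorder]).fit(num_batches=1)
print(recorder.calls)
```

The point of the sketch is that the trainer owns the control flow, while callbacks only react at named points in it.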
Pre-packaged Callbacks
There are many callbacks that come pre-packaged inside of the Model Zoo. See Add-on Callbacks for a complete list of all the callbacks available out-of-the-box in the Model Zoo. You can use any number of them to enhance the Trainer for your run.
Global Callbacks
Any callback can be registered globally so that all Trainer instances know about it and will invoke that callback’s hooks.
There are two ways to globally register a callback. The first way is to treat the callback as a context manager. For example, inside a CheckLoss context, all trainer fit calls will check the loss values that come out of the model.
The other way to register a callback is to call register_global_callback (cerebras.modelzoo.trainer.callbacks.register_global_callback).
For example, once CheckLoss has been registered this way, all subsequent trainer fit calls will check the loss values that come out of the model.
register_global_callback returns a removable handle object that can be used to remove the added callback by calling handle.remove().
Callback Ordering
The Trainer is comprised of many different callbacks that all serve to enhance its functionality.
All of these callbacks share common hooks. These hooks must be called in a specific order. The order in which callbacks are invoked is as follows:
- Core Callbacks: The callbacks that implement the most fundamental behaviour of the Trainer get called first.
- User-defined callbacks: The callbacks that are passed into the callbacks argument of the Trainer’s constructor are called next.
- Global callbacks: Finally, the callbacks that are registered globally are called.
For example, suppose that TrainingLoop, ComputeNorm, and CheckLoss all implement the on_fit_start hook. Between these three callbacks, the order in which each callback’s on_fit_start hook is invoked is as follows:
- TrainingLoop.on_fit_start: As TrainingLoop is a core callback.
- ComputeNorm.on_fit_start: As ComputeNorm was passed into the Trainer’s constructor.
- CheckLoss.on_fit_start: As it is a globally registered callback.
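The three-tier ordering can be reproduced with a short toy model. NamedCallback and ToyTrainer are hypothetical stand-ins that borrow only the callback names from the example above; the way the real Trainer stores and merges its callback lists is not shown on this page and is assumed here.

```python
# Module-level registry standing in for "global" callbacks (hypothetical).
_GLOBAL_CALLBACKS = []


class NamedCallback:
    """Toy callback that logs its own name when on_fit_start is invoked."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def on_fit_start(self, trainer):
        self.log.append(f"{self.name}.on_fit_start")


class ToyTrainer:
    def __init__(self, log, callbacks=()):
        # Core callbacks are created by the trainer itself.
        self.core_callbacks = [NamedCallback("TrainingLoop", log)]
        # User-defined callbacks come from the constructor argument.
        self.user_callbacks = list(callbacks)

    def fit(self):
        # Core first, then user-defined, then global.
        ordered = self.core_callbacks + self.user_callbacks + _GLOBAL_CALLBACKS
        for callback in ordered:
            callback.on_fit_start(self)


log = []
# Registered before the trainer exists, yet still invoked last.
_GLOBAL_CALLBACKS.append(NamedCallback("CheckLoss", log))
trainer = ToyTrainer(log, callbacks=[NamedCallback("ComputeNorm", log)])
trainer.fit()
print(log)
# ['TrainingLoop.on_fit_start', 'ComputeNorm.on_fit_start', 'CheckLoss.on_fit_start']
```

The sketch shows that the invocation order depends on which tier a callback belongs to, not on when it was created or registered.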
Writing a Custom Callback
To write your own custom callback class, all you need to do is inherit from the base Callback class and override the hooks that you need.
For example, let’s implement a simple callback that scales the loss value by some constant value before we call loss.backward(). An instance of the callback can then be passed to the Trainer through the callbacks argument of its constructor.
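As a minimal sketch of the loss-scaling idea, the toy below overrides a single hook and passes the callback to a trainer's constructor. ToyCallback, ScaleLoss, and ToyTrainer are hypothetical, and the hook signature (trainer, outputs) is an assumption made for illustration; consult the Callback class for the real hook signatures before adapting this.

```python
class ToyCallback:
    """Toy base class: every hook is a no-op unless overridden."""

    def on_before_backward(self, trainer, outputs):
        pass


class ScaleLoss(ToyCallback):
    """Scale the loss by a constant just before the backward pass."""

    def __init__(self, scale):
        self.scale = scale

    def on_before_backward(self, trainer, outputs):
        # Mutate the outputs in place so the trainer sees the scaled loss.
        outputs["loss"] = outputs["loss"] * self.scale


class ToyTrainer:
    """Hypothetical stand-in for a trainer; NOT the Model Zoo Trainer."""

    def __init__(self, callbacks=()):
        self.callbacks = list(callbacks)

    def training_step(self, raw_loss):
        # `raw_loss` fakes the value a model's forward pass would produce.
        outputs = {"loss": raw_loss}
        for callback in self.callbacks:
            callback.on_before_backward(self, outputs)
        # A real trainer would call outputs["loss"].backward() at this point.
        return outputs["loss"]


trainer = ToyTrainer(callbacks=[ScaleLoss(scale=0.5)])
scaled = trainer.training_step(raw_loss=4.0)
print(scaled)  # 2.0
```

Because the base class defines every hook as a no-op, the subclass only needs to override the one hook it cares about.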