colossalai.engine
- class colossalai.engine.Engine(model, optimizer, criterion=None, gradient_handlers=None, clip_grad_norm=0.0, ophook_list=None, verbose=True, schedule=None)[source]
Basic engine class for training and evaluation. It runs a specific process method, step(), which is based on the given schedule, over each batch of a dataset. It controls one iteration of training.
- Parameters
model (torch.nn.Module) – The neural network model.
optimizer (colossalai.nn.optimizer.ColossalaiOptimizer) – Optimizer for updating the parameters.
criterion (torch.nn.modules.loss._Loss, optional) – Loss function for calculating the loss.
gradient_handlers (List[BaseGradientHandler], optional) – A list of gradient handlers used in the backward pass.
clip_grad_norm (float, optional) – The gradient norm used for gradient clipping.
ophook_list (list) – List of ophooks.
verbose (bool) – Whether to display log info.
schedule (BaseSchedule) – Runtime schedule.
Examples
>>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
>>> model = ...
>>> criterion = ...
>>> optimizer = ...
>>> train_dataloader = ...
>>> engine, _, _, _ = colossalai.initialize(model, optimizer, criterion)
>>> engine.train()
>>> for inputs, labels in train_dataloader:
>>>     # set gradients to zero
>>>     engine.zero_grad()
>>>     # run forward pass
>>>     outputs = engine(inputs)
>>>     # compute loss value and run backward pass
>>>     loss = engine.criterion(outputs, labels)
>>>     engine.backward(loss)
>>>     # update parameters
>>>     engine.step()
Examples of using Engine in training can be found in Training with engine and trainer and Run resnet cifar10 with engine.
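For comparison, the same four steps (zero gradients, forward, backward, update) can be written as a plain PyTorch loop without Colossal-AI. This is only a sketch for orientation, not the Engine's implementation: the toy model, the random data, and the `max_norm` value are assumptions made up for illustration, and the clipping call merely plays the role that `clip_grad_norm` suggests.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins for the objects in the example (all hypothetical).
model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
train_dataloader = [
    (torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(3)
]
max_norm = 1.0  # assumed value, stands in for clip_grad_norm

model.train()
losses = []
for inputs, labels in train_dataloader:
    optimizer.zero_grad()              # set gradients to zero
    outputs = model(inputs)            # run forward pass
    loss = criterion(outputs, labels)  # compute loss value
    loss.backward()                    # run backward pass
    # rescale gradients so their global norm does not exceed max_norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()                   # update parameters
    losses.append(loss.item())
```

In the Engine version, the model, criterion, and optimizer are reached through the engine object, so the loop body does not call them directly.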
- property ophooks
Shows the currently activated ophooks.
- property model
Model attached to the engine
- property optimizer
Optimizer attached to the engine
- property criterion
Criterion attached to the engine
- property schedule
Schedule attached to the engine
- property uses_pipeline
Shows whether pipeline parallelism is used.
- backward(loss)[source]
Start backward propagation given the loss value computed by a loss function.
- Parameters
loss (torch.Tensor) – Loss value computed by a loss function.
- backward_by_grad(tensor, grad)[source]
Start backward propagation given the gradient of the output tensor.
- Parameters
tensor (torch.Tensor) – Output tensor.
grad (torch.Tensor) – Gradient passed back to the output.
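To illustrate what "backward propagation given the gradient of the output tensor" means, here is a minimal sketch in plain PyTorch using `torch.autograd.backward`. It assumes the method behaves analogously; the tensors `x`, `y`, and `grad` are made up for the example. A gradient supplied this way is needed when the output is not a scalar loss, for instance when the gradient of the output arrives from elsewhere, such as a later pipeline stage.

```python
import torch

# A leaf tensor and a non-scalar output computed from it.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * x

# Gradient of some downstream loss with respect to y, supplied externally.
grad = torch.tensor([1.0, 0.5, 0.25])

# Backpropagate from y using the supplied gradient instead of a scalar loss.
torch.autograd.backward(tensors=y, grad_tensors=grad)

# By the chain rule, x.grad = grad * dy/dx = grad * 2x.
expected = grad * 2 * x.detach()
```

Here `x.grad` matches `expected`, which is the same result `y.backward(grad)` would give.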