colossalai.trainer

class colossalai.trainer.Trainer(engine, timer=None, logger=None)[source]

This class simplifies training and evaluation so that users do not have to write their own training scripts. It plays a role similar to ignite.engine and keras.engine, but is called Trainer.

Parameters
  • engine (Engine) – Engine responsible for the process function.

  • timer (MultiTimer, optional) – Timer used to monitor the whole training.

  • logger (colossalai.logging.DistributedLogger, optional) – Logger used to record the whole training log.

Examples

>>> import colossalai
>>> from colossalai.core import global_context as gpc
>>> from colossalai.trainer import Trainer
>>> # Define the model, criterion, optimizer, and train_dataloader for your training.
>>> model = ...
>>> criterion = ...
>>> optimizer = ...
>>> train_dataloader = ...
>>> # Initialize the engine; initialize() returns engine, train_dataloader, test_dataloader, lr_scheduler.
>>> engine, train_dataloader, _, _ = colossalai.initialize(
...     model, optimizer, criterion, train_dataloader=train_dataloader)
>>> # Create the timer and logger, then build the trainer.
>>> timer = ...
>>> logger = ...
>>> trainer = Trainer(engine=engine, logger=logger, timer=timer)
>>> # Add the hooks you would like to use here.
>>> hook_list = []
>>> # Start training.
>>> trainer.fit(
...     train_dataloader=train_dataloader,
...     epochs=gpc.config.NUM_EPOCHS,
...     test_interval=1,
...     hooks=hook_list,
...     display_progress=True,
...     return_output_label=False
... )

More examples and details can be found in "Training with engine and trainer" and in ColossalAI-Examples.
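As a rough illustration of how a list of hooks can be driven by a training loop, here is a minimal sketch in plain Python. `HookSketch`, `EpochRecorder`, `run_epochs`, and the method name `on_epoch_end` are hypothetical names chosen for this illustration; they are not the `BaseHook` interface and do not show how `Trainer` dispatches hooks internally.

```python
from typing import List


class HookSketch:
    """Hypothetical hook interface; the method name is illustrative only."""

    def on_epoch_end(self, epoch: int) -> None:
        pass


class EpochRecorder(HookSketch):
    """Records the index of every epoch it is notified about."""

    def __init__(self) -> None:
        self.seen: List[int] = []

    def on_epoch_end(self, epoch: int) -> None:
        self.seen.append(epoch)


def run_epochs(num_epochs: int, hooks: List[HookSketch]) -> None:
    for epoch in range(num_epochs):
        # Training for one epoch would happen here.
        for hook in hooks:
            # Every hook in the list is notified in list order.
            hook.on_epoch_end(epoch)


recorder = EpochRecorder()
run_epochs(3, [recorder])
print(recorder.seen)
```

The point of the pattern is that logging, metrics, or checkpointing can be added by appending objects to the list, without changing the loop itself.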

property cur_epoch

Returns the index of the current epoch.

property cur_step

Returns how many iteration steps have been processed.

fit(train_dataloader, epochs, max_steps=None, test_dataloader=None, test_interval=1, hooks=None, display_progress=False, return_output_label=True)[source]

Trains the model to fit training data.

Parameters
  • train_dataloader (torch.utils.data.DataLoader) – DataLoader for training.

  • epochs (int) – Maximum number of epochs.

  • max_steps (int, optional) – Maximum number of running iterations.

  • test_dataloader (torch.utils.data.DataLoader, optional) – DataLoader for validation.

  • test_interval (int, optional) – Interval, in epochs, between validation runs.

  • hooks (list[BaseHook], optional) – A list of hooks used in training.

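  • display_progress (bool, optional) – If True, a progress bar will be displayed.

  • return_output_label (bool, optional) – If True, the output of model and the label will be returned. Defaults to True.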
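The interaction between `epochs`, `max_steps`, and `test_interval` can be sketched in plain Python. This is a simplified stand-in, not ColossalAI's implementation: `fit_sketch` and `run_test` are hypothetical names, and the sketch assumes that the step counter accumulates across epochs, that training stops as soon as `max_steps` is reached, and that validation runs after every `test_interval`-th epoch.

```python
from typing import Callable, List, Optional


def fit_sketch(
    train_batches: List[int],
    epochs: int,
    max_steps: Optional[int] = None,
    test_interval: int = 1,
    run_test: Optional[Callable[[int], None]] = None,
) -> dict:
    """Hypothetical fit loop that only tracks the epoch and step counters."""
    cur_step = 0
    cur_epoch = 0
    for cur_epoch in range(epochs):
        for _batch in train_batches:
            # A real trainer would run forward, backward, and the optimizer step here.
            cur_step += 1
            if max_steps is not None and cur_step >= max_steps:
                break
        # Assumption: validation runs after every `test_interval`-th epoch.
        if run_test is not None and (cur_epoch + 1) % test_interval == 0:
            run_test(cur_epoch)
        # Leave the epoch loop too once the step budget is used up.
        if max_steps is not None and cur_step >= max_steps:
            break
    return {"cur_epoch": cur_epoch, "cur_step": cur_step}


tested_epochs: List[int] = []
state = fit_sketch([0, 1, 2, 3], epochs=3, test_interval=2, run_test=tested_epochs.append)
capped = fit_sketch([0, 1, 2, 3], epochs=3, max_steps=5)
print(state, tested_epochs, capped)
```

In the second call the step budget of 5 is exhausted partway through the second epoch, so the loop ends before the `epochs` limit is reached.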

evaluate(test_dataloader, hooks=None, display_progress=False, return_output_label=True)[source]

Evaluates the model with testing data.

Parameters
  • test_dataloader (torch.utils.data.DataLoader) – DataLoader for testing.

  • hooks (list[BaseHook], optional) – A list of hooks used in evaluation. Defaults to None.

  • display_progress (bool, optional) – If True, the evaluation progress will be printed. Defaults to False.

  • return_output_label (bool, optional) – If True, the output of model and the label will be returned. Defaults to True.
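The effect of `return_output_label` can be pictured with a single evaluation step. The sketch below is an assumption-laden illustration in plain Python, not the library's code: `eval_step_sketch`, `double`, and `squared_error` are hypothetical names, and the sketch assumes that each step yields an `(output, label, loss)` triple in which the first two entries are `None` when the flag is off.

```python
from typing import Callable, Optional, Tuple


def eval_step_sketch(
    model: Callable[[float], float],
    loss_fn: Callable[[float, float], float],
    data: float,
    label: float,
    return_output_label: bool = True,
) -> Tuple[Optional[float], Optional[float], float]:
    """Hypothetical single evaluation step on one (data, label) pair."""
    output = model(data)
    loss = loss_fn(output, label)
    if return_output_label:
        # Keep output and label so that metrics can be computed from them.
        return output, label, loss
    # Drop output and label when only the loss is needed.
    return None, None, loss


def double(x: float) -> float:
    return 2 * x


def squared_error(output: float, label: float) -> float:
    return (output - label) ** 2


full = eval_step_sketch(double, squared_error, 1.0, 3.0)
loss_only = eval_step_sketch(double, squared_error, 1.0, 3.0, return_output_label=False)
print(full, loss_only)
```

Under this assumption, turning the flag off avoids holding on to outputs and labels, at the cost of making them unavailable to anything that computes metrics from them.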

predict(data)[source]

Uses the trained model to make a prediction for a tensor or a list of tensors.

Parameters
  • data (Union[torch.Tensor, List[torch.Tensor]]) – Input data.

Returns

The output of the model as the prediction.

Return type

torch.Tensor