colossalai.utils
- colossalai.utils.checkpoint(function, activation_offload, *args)[source]
Checkpoint the computation while preserving the RNG states, adapted from PyTorch torch.utils.checkpoint.
- Parameters
function – The forward pass function. It should know how to handle the input tuples.
activation_offload (bool) – Whether to offload the checkpointed activations to CPU memory.
args – Tuple containing the inputs to the function.
- Returns
Output of running function with provided args.
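A minimal usage sketch (the module, shapes, and the run_block wrapper are illustrative):

import torch
import torch.nn as nn
from colossalai.utils import checkpoint

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

def run_block(x):
    # Recomputed during the backward pass instead of storing activations.
    return model(x)

x = torch.randn(8, 64, requires_grad=True)
out = checkpoint(run_block, False, x)  # activation_offload=False keeps inputs on device
out.sum().backward()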
- colossalai.utils.print_rank_0(msg, logger=None)[source]
Print a message and optionally save it to logs. This is executed only by the rank-0 process.
- Parameters
msg (str) – A string message to output.
logger (colossalai.logging.DistributedLogger, optional) – The logger to record the message, defaults to None.
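A minimal sketch (assuming the distributed environment is already initialized, e.g. via colossalai.launch, and that colossalai.logging exposes get_dist_logger):

from colossalai.logging import get_dist_logger
from colossalai.utils import print_rank_0

logger = get_dist_logger()
print_rank_0("training started", logger=logger)  # printed (and logged) on rank 0 only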
- colossalai.utils.sync_model_param(model, parallel_mode)[source]
Make sure the model's parameters are consistent across the ranks of the given parallel mode.
- Parameters
model (torch.nn.Module) – A PyTorch model whose parameters are checked for consistency.
parallel_mode (colossalai.context.ParallelMode) – Parallel mode to be checked.
Note
The given parallel_mode should be included in ParallelMode. More details about ParallelMode can be found in parallel_mode.
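A minimal sketch (assuming an initialized distributed environment):

import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.utils import sync_model_param

model = nn.Linear(16, 16)
# Broadcast parameters so every rank in the data-parallel group holds identical weights.
sync_model_param(model, ParallelMode.DATA)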
- colossalai.utils.clip_grad_norm_fp32(parameters, max_norm, norm_type=2)[source]
Clips gradient norm of an iterable of parameters whose gradients are in fp32.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_() with added functionality to handle model parallel parameters.
Note
The gradients are modified in place.
- Parameters
parameters (Iterable[torch.Tensor] or torch.Tensor) – An iterable of Tensors or a single Tensor that will have gradients normalized.
max_norm (Union[float, int]) – Max norm of the gradients.
norm_type (Union[float, int, 'inf']) – Type of the used p-norm. Can be 'inf' for infinity norm.
- Returns
Total norm of the parameters.
- Return type
float
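A minimal single-process sketch:

import torch
import torch.nn as nn
from colossalai.utils import clip_grad_norm_fp32

model = nn.Linear(32, 32)
model(torch.randn(4, 32)).sum().backward()

# Gradients are clipped in place; the total norm is returned.
total_norm = clip_grad_norm_fp32(model.parameters(), max_norm=1.0)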
- colossalai.utils.get_current_device()[source]
Returns the currently selected device (GPU/CPU). If CUDA is available, returns the GPU, otherwise returns the CPU.
- colossalai.utils.synchronize()[source]
Similar to torch.cuda.synchronize(). Waits for all kernels in all streams on a CUDA device to complete.
- colossalai.utils.empty_cache()[source]
Similar to torch.cuda.empty_cache(). Releases all unoccupied cached memory currently held by the caching allocator.
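A short sketch combining the three device helpers above (the synchronization only has an effect when a CUDA device is present):

import torch
from colossalai.utils import get_current_device, synchronize, empty_cache

device = get_current_device()              # GPU if CUDA is available, else CPU
x = torch.randn(1024, 1024, device=device)
y = x @ x                                  # launched asynchronously on CUDA
synchronize()                              # wait for all kernels to finish
del x, y
empty_cache()                              # release cached blocks back to the driver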
- colossalai.utils.set_to_cuda(models)[source]
Send the model(s) to the GPU.
- Parameters
models – An nn.Module or a list of nn.Module instances.
- colossalai.utils.report_memory_usage(message, logger=None, report_cpu=False)[source]
Calculate and print the memory usage (in GB).
- Parameters
message (str) – A prefix message to add in the log.
logger (colossalai.logging.DistributedLogger) – The logger used to record memory information.
report_cpu (bool, optional) – Whether to report CPU memory usage, defaults to False.
- Raises
EnvironmentError – Raised if no distributed environment has been initialized.
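A minimal sketch (an initialized distributed environment is required, otherwise EnvironmentError is raised as noted above):

from colossalai.utils import report_memory_usage

report_memory_usage("after forward pass", report_cpu=True)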
- colossalai.utils.colo_device_memory_capacity(device)[source]
Get the memory capacity of the device.
- Parameters
device (torch.device) – The device to query.
- Returns
Capacity in bytes.
- Return type
int
- colossalai.utils.colo_device_memory_used(device)[source]
Get the device memory used on the given device by the current process.
- Parameters
device (torch.device) – The device to query.
- Returns
Memory used, in bytes.
- Return type
int
- colossalai.utils.colo_set_process_memory_fraction(ratio)[source]
Set the fraction of CUDA memory that the current process may use on its GPU.
- Parameters
ratio (float) – A fraction between 0.0 and 1.0.
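A short sketch combining the three memory helpers above (assumes a CUDA device is present):

import torch
from colossalai.utils import (colo_device_memory_capacity,
                              colo_device_memory_used,
                              colo_set_process_memory_fraction)

device = torch.device('cuda:0')
total = colo_device_memory_capacity(device)   # bytes
used = colo_device_memory_used(device)        # bytes used by this process
print(f'{used / 1024**3:.2f} / {total / 1024**3:.2f} GB')

colo_set_process_memory_fraction(0.5)         # cap this process at half the GPU memory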
- class colossalai.utils.Timer[source]
A timer object which helps to log execution times and provides different tools to assess them.
- stop(keep_in_history=False)[source]
Stop the timer and record the start-stop time interval.
- Parameters
keep_in_history (bool, optional) – Whether to record this start-stop interval in the history, defaults to False.
- Returns
Start-stop interval.
- Return type
float
- get_history_mean()[source]
Mean of all history start-stop time intervals.
- Returns
Mean of time intervals
- Return type
float
- get_history_sum()[source]
Add up all the start-stop time intervals.
- Returns
Sum of time intervals.
- Return type
float
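A minimal sketch (assuming the usual start()/stop() pairing; time.sleep stands in for real work):

import time
from colossalai.utils import Timer

timer = Timer()
for step in range(3):
    timer.start()
    time.sleep(0.1)                    # stand-in for a training step
    timer.stop(keep_in_history=True)   # record each interval in the history

print(timer.get_history_mean())        # average interval
print(timer.get_history_sum())         # total time across intervals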
- class colossalai.utils.MultiTimer(on=True)[source]
An object that contains multiple timers.
- Parameters
on (bool, optional) – Whether the timer is enabled. Default is True.
- stop(name, keep_in_history)[source]
Stop the timer with the given name.
- Parameters
name (str) – Timer’s key.
keep_in_history (bool) – Whether to record this start-stop interval in the history.
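A minimal sketch (start(name) is assumed to mirror stop(name, ...)):

import time
from colossalai.utils import MultiTimer

timers = MultiTimer(on=True)
timers.start('forward')
time.sleep(0.05)                       # stand-in for the forward pass
timers.stop('forward', keep_in_history=True)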
- class colossalai.utils.DataParallelSampler(dataset, shuffle=False, seed=0, drop_last=False)[source]
A data sampler for distributed data parallelism.
- Parameters
dataset (torch.utils.data.Dataset) – The Dataset for sampling.
shuffle (bool, optional) – Whether to shuffle data, defaults to False.
seed (int, optional) – The random seed used for sampling, defaults to 0.
drop_last (bool, optional) – Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller, defaults to False.
- colossalai.utils.get_dataloader(dataset, shuffle=False, seed=1024, add_sampler=True, drop_last=False, pin_memory=False, num_workers=0, **kwargs)[source]
Set up a deterministic dataloader (also configures worker seeding, the sampler, and whether to shuffle).
Note
When pipeline parallel is enabled, shuffle cannot be True, as it will result in a mismatch between the input data on the first stage and the labels on the last stage.
- Parameters
dataset (torch.utils.data.Dataset) – The dataset to be loaded.
shuffle (bool, optional) – Whether to shuffle the dataset. Defaults to False.
seed (int, optional) – Random worker seed for sampling, defaults to 1024.
add_sampler (bool, optional) – Whether to add a DataParallelSampler to the dataset. Defaults to True.
drop_last (bool, optional) – Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional) – Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional) – Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict) – Optional parameters for torch.utils.data.DataLoader; more details can be found in DataLoader.
- Returns
A DataLoader used for training or testing.
- Return type
torch.utils.data.DataLoader
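A minimal single-process sketch (add_sampler is disabled here because the distributed sampler needs an initialized data-parallel environment; batch_size is forwarded to DataLoader via kwargs):

import torch
from torch.utils.data import TensorDataset
from colossalai.utils import get_dataloader

dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
loader = get_dataloader(dataset,
                        shuffle=True,
                        seed=1024,
                        add_sampler=False,
                        batch_size=16,
                        drop_last=True)
for inputs, labels in loader:
    pass  # training step goes here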
- colossalai.utils.load_checkpoint(file, model, optimizer=None, lr_scheduler=None, strict=True)[source]
Loads training states from a checkpoint file.
- Parameters
file – a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike object containing a file name.
model (torch.nn.Module) – Model to load the saved weights and buffers into.
optimizer (Union[torch.optim.Optimizer, colossalai.nn.optimizer]) – Optimizer to recuperate.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional) – lr_scheduler to recuperate, defaults to None.
strict (bool, optional) – Whether to strictly enforce that the keys in the checkpoint's state_dict match the names of parameters and buffers in the model, defaults to True.
- Returns
The saved epoch number.
- Return type
int
- Raises
RuntimeError – Raised if the model/optimizer cannot be successfully recuperated.
- colossalai.utils.save_checkpoint(file, epoch, model, optimizer=None, lr_scheduler=None, **kwargs)[source]
Stores the checkpoint to disk. Saves all the training components’ parameters or buffers, such as model, optimizer, lr_scheduler etc. into a checkpoint dictionary.
- Parameters
file – a file-like object (has to implement write and flush) or a string or os.PathLike object containing a file name.
epoch (int) – Epoch number (indicates how many epochs this model has been trained).
model (torch.nn.Module) – Model to be saved.
optimizer (Union[torch.optim.Optimizer, colossalai.nn.optimizer]) – Optimizer to be saved.
lr_scheduler (Union[torch.optim.lr_scheduler, colossalai.nn.lr_scheduler], optional) – lr_scheduler to be saved, defaults to None.
pickle_module – Module used for pickling metadata and objects.
pickle_protocol – Can be specified to override the default protocol.
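A minimal single-process sketch of the save/load pair (in a model-parallel setting these calls also coordinate weights across ranks):

import torch
import torch.nn as nn
from colossalai.utils import save_checkpoint, load_checkpoint

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

save_checkpoint('ckpt.pt', epoch=3, model=model, optimizer=optimizer)

# Restoring returns the saved epoch number, useful for resuming.
start_epoch = load_checkpoint('ckpt.pt', model, optimizer=optimizer)
assert start_epoch == 3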