colossalai.testing

colossalai.testing.parameterize(argument, values)[source]

This function simulates the behavior of pytest.mark.parametrize. As we want to avoid repeated distributed network initialization, we need this extra decorator on the function launched by torch.multiprocessing.

If a function is wrapped with this wrapper, non-parameterized arguments must be keyword arguments; positional arguments are not allowed.

Usage:

# Example 1:
@parameterize('person', ['xavier', 'davis'])
def say_something(person, msg):
    print(f'{person}: {msg}')

say_something(msg='hello')

# This will generate output:
# > xavier: hello
# > davis: hello

# Example 2:
@parameterize('person', ['xavier', 'davis'])
@parameterize('msg', ['hello', 'bye', 'stop'])
def say_something(person, msg):
    print(f'{person}: {msg}')

say_something()

# This will generate output:
# > xavier: hello
# > xavier: bye
# > xavier: stop
# > davis: hello
# > davis: bye
# > davis: stop
Parameters
  • argument (str) – the name of the argument to parameterize

  • values (List[Any]) – a list of values to iterate for this argument
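The stacking behavior shown above can be sketched with a minimal pure-Python decorator. This is a simplified illustration of the idea, not the actual Colossal-AI implementation:

```python
import functools

def parameterize(argument, values):
    """Sketch of a parameterize-style decorator: calls the wrapped
    function once per value, passed as a keyword argument."""
    def decorator(func):
        # The wrapper accepts keyword arguments only, matching the rule
        # that positional arguments are not allowed.
        @functools.wraps(func)
        def wrapper(**kwargs):
            for value in values:
                func(**{argument: value, **kwargs})
        return wrapper
    return decorator

@parameterize('person', ['xavier', 'davis'])
def say_something(person, msg):
    print(f'{person}: {msg}')

say_something(msg='hello')
# prints:
# xavier: hello
# davis: hello
```

Because each application wraps the previous wrapper, stacking two `parameterize` decorators yields the Cartesian product of values, as in Example 2 above.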

colossalai.testing.rerun_on_exception(exception_type=<class 'Exception'>, pattern=None, max_try=5)[source]

A decorator on a function to re-run when an exception occurs.

Usage:

# rerun for all kinds of exception
@rerun_on_exception()
def test_method():
    print('hey')
    raise RuntimeError('Address already in use')

# rerun for RuntimeError only
@rerun_on_exception(exception_type=RuntimeError)
def test_method():
    print('hey')
    raise RuntimeError('Address already in use')

# rerun for maximum 10 times if Runtime error occurs
@rerun_on_exception(exception_type=RuntimeError, max_try=10)
def test_method():
    print('hey')
    raise RuntimeError('Address already in use')

# rerun for infinite times if Runtime error occurs
@rerun_on_exception(exception_type=RuntimeError, max_try=None)
def test_method():
    print('hey')
    raise RuntimeError('Address already in use')

# rerun only when the exception message matches the pattern
@rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$")
def test_method():
    print('hey')
    raise RuntimeError('Address already in use')
Parameters
  • exception_type (Exception, Optional) – The type of exception to detect for rerun

  • pattern (str, Optional) – The pattern to match the exception message. If the pattern is not None and matches the exception message, the exception will be detected for rerun

  • max_try (int, Optional) – Maximum number of reruns for this function. The default value is 5. If max_try is None, it will rerun forever if the exception keeps occurring

colossalai.testing.rerun_if_address_is_in_use()[source]

This function reruns a wrapped function if “address already in use” occurs during testing spawned with torch.multiprocessing.

Usage:

@rerun_if_address_is_in_use()
def test_something():
    ...
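Behaviorally, this is a special case of rerun_on_exception with the pattern fixed to the “address already in use” message. A minimal standalone sketch (the retry budget and the substring check here are assumptions for illustration, not the actual Colossal-AI implementation):

```python
import functools

def rerun_if_address_is_in_use(max_try=5):
    """Sketch: retries the wrapped function when the raised exception
    message contains 'Address already in use', a common failure when a
    rendezvous port from a previous spawned test is not yet released."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_try):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    # re-raise unrelated errors, or when out of retries
                    if ('Address already in use' not in str(e)
                            or attempt == max_try - 1):
                        raise
        return wrapper
    return decorator
```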
colossalai.testing.skip_if_not_enough_gpus(min_gpus)[source]

This function checks the number of available GPUs on the system and automatically skips test cases that require more GPUs than are available.

Note

The wrapped function must have world_size in its keyword argument.

Usage:

@skip_if_not_enough_gpus(min_gpus=8)
def test_something():
    # will be skipped if there are fewer than 8 GPUs available
    do_something()

Parameters
  • min_gpus (int) – the minimum number of GPUs required to run this test