# Source code for pararealml.utils.time

import functools
from typing import Callable, Optional, Any, Tuple
from timeit import default_timer as timer

from mpi4py import MPI


def time(function_name: Optional[str] = None) -> Callable:
    """
    Times the execution of the wrapped function, prints the execution time
    using the provided function name, and returns the return value of the
    wrapped function along with the execution time.

    :param function_name: the name of the function to use in the printed
        message; if None, the repr of the wrapped callable's ``__name__`` is
        used, or the repr of its type's name if it has no ``__name__`` (e.g.
        a :class:`functools.partial` object)
    :return: a decorator; the function it produces returns a
        ``(return value, execution time in seconds)`` tuple when called
    """
    def _time_wrapper_provider(
            function: Callable,
            name: Optional[str]) -> Callable:
        if name is None:
            # Not every callable has a __name__ (e.g. functools.partial
            # objects or instances defining __call__), so fall back to the
            # name of its type instead of raising an AttributeError here.
            name = repr(
                getattr(function, '__name__', type(function).__name__))

        @functools.wraps(function)
        def _time_wrapper(*args: Any, **kwargs: Any) -> Tuple[Any, float]:
            start_time = timer()
            value = function(*args, **kwargs)
            end_time = timer()
            run_time = end_time - start_time
            print(f'{name} completed in {run_time}s')
            return value, run_time

        return _time_wrapper

    return lambda function: _time_wrapper_provider(function, function_name)
def mpi_time(function_name: Optional[str] = None) -> Callable:
    """
    Times the execution of the wrapped function using MPI, prints the
    execution time using the provided function name on the first rank, and
    returns the return value of the wrapped function along with the
    execution time.

    :param function_name: the name of the function to use in the printed
        message; if None, the repr of the wrapped callable's ``__name__`` is
        used, or the repr of its type's name if it has no ``__name__`` (e.g.
        a :class:`functools.partial` object)
    :return: a decorator; the function it produces returns a
        ``(return value, execution time)`` tuple on every rank, where the
        time is measured with ``MPI.Wtime`` between two barriers
    """
    def _mpi_time_wrapper_provider(
            function: Callable,
            name: Optional[str]) -> Callable:
        if name is None:
            # Not every callable has a __name__ (e.g. functools.partial
            # objects or instances defining __call__), so fall back to the
            # name of its type instead of raising an AttributeError here.
            name = repr(
                getattr(function, '__name__', type(function).__name__))

        @functools.wraps(function)
        def _mpi_time_wrapper(*args: Any, **kwargs: Any) -> Tuple[Any, float]:
            comm = MPI.COMM_WORLD
            # Synchronize all ranks before starting and before stopping the
            # clock so the measured interval covers the slowest rank.
            comm.barrier()
            start_time = MPI.Wtime()
            value = function(*args, **kwargs)
            comm.barrier()
            end_time = MPI.Wtime()
            run_time = end_time - start_time
            # Only rank 0 prints to avoid one line of output per process.
            if comm.rank == 0:
                print(f'{name} completed in {run_time}s')
            return value, run_time

        return _mpi_time_wrapper

    return lambda function: _mpi_time_wrapper_provider(
        function, function_name)