Source code for pasteur.utils.progress

""" This utility module provides constants and functions for multiprocessing
and progress monitoring in Pasteur.

In most cases, the functions and constants are simple wrappers around existing libraries.
The common use of the primitives in the codebase allows for using different
implementations in the future."""

import functools
import io
import logging
import sys
import time
from contextlib import contextmanager
from functools import reduce
from multiprocessing.pool import AsyncResult, Pool
from os import cpu_count, environ
from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TextIO, TypeGuard, TypeVar

from rich import get_console
from tqdm import tqdm, trange

if TYPE_CHECKING:
    from multiprocessing.managers import SyncManager


# Generic return type and parameter spec used to type the process wrappers.
X = TypeVar("X")
P = ParamSpec("P")
logger = logging.getLogger(__name__)

# Jupyter doesn't support going up lines (moving cursor)
# This means up to 1 loading bar works
JUPYTER_MAX_NEST = 1
# Colour, left padding and layout applied to every progress bar.
PBAR_COLOR = "blue"
PBAR_OFFSET = 11
PBAR_FORMAT = (" " * PBAR_OFFSET) + ">>>>>>>  {l_bar}{bar}{r_bar}"
# Exact number for notebooks rendered in github to use up the whole width
# Assumes a stripping github filter is used to remove the empty space (or time)
# at the start
PBAR_JUP_NCOLS = 135 + PBAR_OFFSET

# Keyword arguments forwarded to rich's `Console.print_exception` when a
# wrapped function fails.
RICH_TRACEBACK_ARGS = {
    "show_locals": False,
    # "max_frames": 10,
    "suppress": ["kedro", "click"],
}
# Minimum time (in nanoseconds, i.e. 50 ms) a worker waits between two
# progress messages; compared against `time.perf_counter_ns()` deltas.
PROGRESS_STEP_NS = 50_000_000

# Debug flag disables exception wrapping to print the correct stack trace so
# debugger stops at the top-most exception
DEBUG = bool(int(environ.get("_DEBUG", 0)))
# Set the `_MULTIPROCESS` environment variable to 0 to run everything in the
# calling process instead of the pool.
MULTIPROCESS_ENABLE = bool(int(environ.get("_MULTIPROCESS", 1)))
# Flipped to True inside pool workers by `_init_subprocess`.
IS_SUBPROCESS = False

# When True, wrapped functions are executed a second time and checked with
# `pasteur.utils.leaks` (see `_wrap_exceptions` and `_calc_worker`).
CHECK_LEAKS = False


def is_jupyter() -> bool:  # pragma: no cover
    """Tell whether the code is executing inside a Jupyter-style kernel.

    The check mirrors the one used by the rich package: the IPython shell
    object is inspected and only the ZMQ-based shell (notebook / qtconsole)
    or a Google Colab shell count as Jupyter.

    Returns:
        True for a notebook-like shell, False for a plain interpreter,
        terminal IPython or any other shell type.
    """
    try:
        shell_factory = get_ipython  # type: ignore[name-defined]
    except NameError:
        # Not running under IPython at all.
        return False

    shell_cls = shell_factory().__class__
    runs_on_colab = "google.colab" in str(shell_cls)
    # Terminal IPython ("TerminalInteractiveShell") and unknown shells
    # are both treated as "not Jupyter".
    return runs_on_colab or shell_cls.__name__ == "ZMQInteractiveShell"
def get_tqdm_args():
    """Build the keyword arguments shared by every progress bar.

    The bar is disabled inside pool workers, and in Jupyter once
    `JUPYTER_MAX_NEST` bars are already active. In Jupyter the bar uses a
    fixed width, ASCII characters and stdout; elsewhere it resizes
    dynamically and writes to stderr.

    Returns:
        A dict of keyword arguments accepted by `tqdm`.
    """
    in_notebook = is_jupyter()

    if IS_SUBPROCESS:
        # Disable subprocess pbars until a better solution.
        hide_bar = True
    else:
        open_bars = len(tqdm._instances)  # type: ignore
        hide_bar = in_notebook and open_bars >= JUPYTER_MAX_NEST

    if in_notebook:
        width, use_ascii, target = PBAR_JUP_NCOLS, True, sys.stdout
    else:
        width, use_ascii, target = None, None, sys.stderr

    return {
        "disable": hide_bar,
        "colour": PBAR_COLOR,
        "bar_format": PBAR_FORMAT,
        "ncols": width,
        "dynamic_ncols": not in_notebook,
        "ascii": use_ascii,
        "file": target,
    }
# Type variable for any callable, so a wrapper can be annotated as returning
# the same callable type it was given.
A = TypeVar("A", bound=Callable)
def limit_pbar_nesting(pbar_gen: A) -> A:
    """Wrap a progress bar factory so it always receives the shared settings.

    Every call to the returned wrapper forwards its own arguments plus the
    keyword arguments produced by `get_tqdm_args()`, which is where bars
    get disabled when too many are nested in Jupyter (nesting causes ugly
    gaps to be generated on vs code).

    Because both sets of keywords are passed together, callers must not
    supply a keyword that `get_tqdm_args()` already provides.

    Args:
        pbar_gen: Progress bar factory, such as `tqdm` or `trange`.

    Returns:
        A wrapper with the metadata of `pbar_gen`.
    """

    @functools.wraps(pbar_gen)
    def with_shared_settings(*args, **kwargs):
        shared = get_tqdm_args()
        return pbar_gen(*args, **kwargs, **shared)

    return with_shared_settings  # type: ignore
# Progress-bar aware replacements for `tqdm.trange` and `tqdm.tqdm`.
prange = limit_pbar_nesting(trange)
piter = limit_pbar_nesting(tqdm)


def _wrap_exceptions(
    fun: Callable[P, X], /, node_name: str, *args: P.args, **kwargs: P.kwargs
) -> X:
    """Run `fun(*args, **kwargs)` with the node name set and errors reported.

    On failure the traceback is printed through the rich console, an error
    is logged and a `RuntimeError("subprocess failed")` chained to the
    original exception is raised. With `DEBUG` set, `fun` is called
    directly and its exception propagates untouched. With `CHECK_LEAKS`
    set, `fun` is executed a second time between `leaks.clear()` and
    `leaks.check()`; the result of the first call is what gets returned.
    """
    set_node_name(node_name)

    if DEBUG:
        # skip catching exception so that debugger catches internal exception.
        return fun(*args, **kwargs)

    try:
        res = fun(*args, **kwargs)
        if CHECK_LEAKS:
            # check for leaks after first execution
            from pasteur.utils.leaks import check, clear

            clear()
            a = fun(*args, **kwargs)
            del a
            check(f"Node {node_name} leaks")
        return res
    except Exception as e:
        get_console().print_exception(**RICH_TRACEBACK_ARGS)
        logger.error(
            f'Subprocess of "{get_node_name()}" failed with error:\n{type(e).__name__}: {e}'
        )
        raise RuntimeError("subprocess failed") from e


def _calc_worker(args):
    """Process one chunk of calls; submitted to the pool by `process_in_parallel`.

    `args` is a tuple of `(node_name, progress_send, progress_lock,
    initializer, fun, finalizer, base_args, chunk)`. For every entry of
    `chunk`, `fun` is called with `base_args` merged with the entry (entry
    keys win). The number of calls completed since the last report is sent
    through `progress_send` while holding `progress_lock`, at most once
    every `PROGRESS_STEP_NS`; any remainder is sent before returning.

    Failures in the initializer, in `fun` or in the finalizer are printed,
    logged and re-raised.

    Returns:
        The list of results of `fun`, in chunk order.
    """
    (
        node_name,
        progress_send,
        progress_lock,
        initializer,
        fun,
        finalizer,
        base_args,
        chunk,
    ) = args
    set_node_name(node_name)

    if initializer is not None:
        try:
            # The initializer may replace both the shared and per-call arguments.
            base_args, chunk = initializer(base_args, chunk)
        except Exception as e:
            get_console().print_exception(**RICH_TRACEBACK_ARGS)
            logger.error(
                f'Subprocess initialization of "{get_node_name()}" failed with error:\n{type(e).__name__}: {e}'
            )
            raise e

    last_update = time.perf_counter_ns()
    out = []
    # Number of completed calls that have not been reported yet.
    u = 0
    for i, op in enumerate(chunk):
        try:
            args = {**base_args, **op} if base_args else op
            out.append(fun(**args))

            if CHECK_LEAKS:
                # Run second so first run loads modules
                from pasteur.utils.leaks import check, clear

                clear()
                a = fun(**args)
                del a
                check(f"Node {node_name}:{i} leaks")
        except Exception as e:
            get_console().print_exception(**RICH_TRACEBACK_ARGS)
            logger.error(
                f'Subprocess of "{get_node_name()}" at index {i} failed with error:\n{type(e).__name__}: {e}'
            )
            raise e
        u += 1

        curr_time = time.perf_counter_ns()
        if curr_time - last_update > PROGRESS_STEP_NS:
            # The lock serializes writes from all workers onto the one pipe.
            with progress_lock:
                progress_send.send(u)
            last_update = curr_time
            u = 0

    if finalizer is not None:
        try:
            finalizer(base_args, chunk)
        except Exception as e:
            get_console().print_exception(**RICH_TRACEBACK_ARGS)
            logger.error(
                f'Subprocess finalization of "{get_node_name()}" failed with error:\n{type(e).__name__}: {e}'
            )
            raise e

    # Flush whatever was completed after the last periodic report.
    if u != 0:
        with progress_lock:
            progress_send.send(u)
    # progress_send.close()

    return out


# Worker count recorded by `_init_pool`; used to size chunks.
_max_workers: int = 1
# Current (pool, manager, log queue) triple, or None when no pool is open.
_pool: "tuple[Pool, SyncManager, Any] | None" = None


def _logging_thread_fun(q):
    """Forward log records read from `q` to the logger named in each record.

    Stops when a `None` sentinel is received or when reading from the queue
    raises `EOFError`.
    """
    try:
        while True:
            record = q.get()
            if record is None:
                break
            logger = logging.getLogger(record.name)
            logger.handle(record)
    except EOFError:
        pass


def _replace_loggers_with_queue(q):
    """Strip every known logger and route all records to `q`.

    All loggers registered with the logging manager, plus the root logger,
    get propagation enabled, their handlers removed and their level reset
    to `NOTSET`. A single `QueueHandler(q)` is then attached to the root.
    """
    loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
    loggers.append(logging.root)

    for l in loggers:
        l.propagate = True
        l.handlers = []
        l.level = logging.NOTSET

    from logging.handlers import QueueHandler

    logging.root.handlers.append(QueueHandler(q))


def _init_subprocess(log_queue):
    """Pool worker initializer.

    Ignores SIGINT in the worker, redirects logging to `log_queue` when one
    is given, and marks the process as a subprocess.
    """
    import signal

    signal.signal(signal.SIGINT, signal.SIG_IGN)

    if log_queue is not None:
        # Kedro installs rich logging when importing the following module
        # and messes with loggers. Import before replacing the loggers.
        import kedro.framework.project as _

        _replace_loggers_with_queue(log_queue)

    global IS_SUBPROCESS
    IS_SUBPROCESS = True


def _get_pool():
    """Return the current (pool, manager, log queue) triple.

    If no pool exists one is launched with default settings after logging
    a warning; this asserts that multiprocessing has not been disabled.
    """
    global _pool
    if _pool is None:
        assert MULTIPROCESS_ENABLE, "Multiprocessing has been disabled. Preventing pool creation."
        logger.warning(
            "Launching a process pool implicitly. Use `init_pool()` to explicitly control pool creation."
        )
        init_pool()

    assert _pool is not None
    return _pool
def get_manager():
    """Return the manager that belongs to the current process pool.

    The pool is obtained through `_get_pool()`, so one is launched
    implicitly if none exists yet.
    """
    _, manager, _ = _get_pool()
    return manager
# Holder for the current node name. Starts as None and is replaced by a
# `threading.local` instance on the first call to `set_node_name`.
_node_name: Any | None = None
def set_node_name(name: str):
    """Record `name` as the node name of the calling thread.

    The thread-local holder is created lazily on the first call.

    Args:
        name: Name reported in error messages for work done by this thread.
    """
    global _node_name

    if _node_name is None:
        import threading

        _node_name = threading.local()

    _node_name.name = name  # type: ignore
def get_node_name():
    """Return the node name set for the calling thread.

    Returns:
        The stored name, or "UKN_NODE" when no name has been set.
    """
    if not _node_name:
        return "UKN_NODE"
    return getattr(_node_name, "name", "UKN_NODE")
def close_pool():
    """Shut down the shared process pool, if one is open.

    A `None` sentinel is queued so the log-forwarding thread stops, the
    pool is terminated and the module-level reference is cleared. Calling
    this without an open pool does nothing.
    """
    global _pool

    current = _pool
    if current is None:
        return

    worker_pool, _manager, log_queue = current
    log_queue.put(None)
    worker_pool.terminate()
    _pool = None
def _init_pool(max_workers: int | None = None, refresh_processes: int | None = None):
    """Create the shared process pool, replacing any existing one.

    A "spawn" context is used for the manager and the pool. A manager
    queue carries log records from the workers back to this process, where
    a background thread replays them through the local loggers.

    Args:
        max_workers: Number of worker processes. When None, the pool picks
            its own default and `cpu_count()` (or 1) is recorded as the
            worker count used for chunking.
        refresh_processes: Passed as `maxtasksperchild` to the pool.

    Returns:
        The new `(pool, manager, log_queue)` triple, also stored in `_pool`.
    """
    import threading
    from multiprocessing import get_context

    global _pool, _max_workers
    close_pool()

    ctx = get_context("spawn")
    manager = ctx.Manager()
    _max_workers = max_workers or cpu_count() or 1

    # set up logging handler for subprocesses
    log_queue = manager.Queue()
    # The thread blocks on `log_queue.get()` until `close_pool()` sends the
    # sentinel. It must be a daemon: a non-daemon thread would be joined at
    # interpreter shutdown and hang the process whenever the pool was never
    # closed explicitly.
    lp = threading.Thread(
        target=_logging_thread_fun, args=(log_queue,), daemon=True
    )
    lp.start()

    pool = ctx.Pool(
        processes=max_workers,
        initializer=_init_subprocess,
        initargs=(log_queue,),
        maxtasksperchild=refresh_processes,
    )
    _pool = (pool, manager, log_queue)
    return _pool
class init_pool:
    """Opens the shared process pool; usable as a plain call or in a `with`.

    The pool is created as soon as the object is constructed. When used as
    a context manager, the pool is closed on exit.
    """

    def __init__(
        self, max_workers: int | None = None, refresh_processes: int | None = None
    ) -> None:
        """Creates a shared process pool for all threads in this process.

        Args:
            max_workers: Number of worker processes. Should be chosen based
                either on cores or on how many RAM GBs will be required by
                each process.
            refresh_processes: Sets `maxtasksperchild` for the pool, which
                prevents memory leaks from snowballing from node to node.
                However, due to additional imports every restart, it is
                slower.
        """
        _init_pool(max_workers, refresh_processes)

    def __enter__(self):
        """Enter the context; nothing is bound by `as`."""
        return None

    def __exit__(self, type, value, traceback):
        """Close the shared pool; exceptions are not suppressed."""
        close_pool()
def process(fun: Callable[P, X], *args: P.args, **kwargs: P.kwargs) -> X:
    """Run `fun(*args, **kwargs)` in a worker of the common pool and wait.

    When multiprocessing is disabled, or when already inside a pool
    worker, `fun` is simply called in the current process.

    Returns:
        Whatever `fun` returns.
    """
    run_inline = IS_SUBPROCESS or not MULTIPROCESS_ENABLE
    if run_inline:
        return fun(*args, **kwargs)

    pool, _, _ = _get_pool()
    call_args = (fun, get_node_name(), *args)
    return pool.apply(_wrap_exceptions, call_args, kwargs)  # type: ignore
class AsyncResultStub(AsyncResult):
    """Already-completed stand-in for `AsyncResult`.

    Used when work was executed synchronously, so callers can keep using
    the `ready()` / `wait()` / `get()` interface. It subclasses
    `AsyncResult` only so `isinstance` checks keep working.
    """

    def __init__(self, obj):
        """Store the finished result `obj`.

        The parent constructor is deliberately not called: it registers the
        result in `pool._cache` and therefore fails when there is no pool.
        Every public method that relies on parent state is overridden below.
        """
        self.obj = obj

    def ready(self):
        """Always True: the value was computed before construction."""
        return True

    def successful(self):
        """Always True: a stub only exists if the call did not raise."""
        return True

    def wait(self, timeout=None):
        """Return immediately; there is nothing to wait for."""
        ...

    def get(self, timeout=None):
        """Return the stored result; `timeout` is ignored."""
        return self.obj
def process_async(
    fun: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
) -> AsyncResult:
    """Submit `fun(*args, **kwargs)` to the common pool without waiting.

    When multiprocessing is disabled, or when already inside a pool
    worker, `fun` is executed immediately in the current process and its
    result is wrapped in an `AsyncResultStub`.

    Returns:
        A result handle exposing `ready()`, `wait()` and `get()`.
    """
    run_inline = IS_SUBPROCESS or not MULTIPROCESS_ENABLE
    if run_inline:
        finished = fun(*args, **kwargs)
        return AsyncResultStub(finished)

    pool, _, _ = _get_pool()
    call_args = (fun, get_node_name(), *args)
    return pool.apply_async(_wrap_exceptions, call_args, kwargs)  # type: ignore
def process_in_parallel(
    fun: Callable[..., X],
    per_call_args: list[dict],
    base_args: dict[str, Any] | None = None,
    min_chunk_size: int = 1,
    desc: str | None = None,
    max_worker_mult: int = 1,
    initializer: Callable | None = None,
    finalizer: Callable[..., None] | None = None,
) -> list[X]:
    """Processes arguments in parallel using the common process pool and prints progress bar.

    Implements a custom form of chunk iteration, where `base_args` contains
    arguments with large size that are common in all function calls and
    `per_call_args` which change every iteration. When a key appears in
    both, the per-call value is used.

    When multiprocessing is disabled, or when already inside a pool
    worker, the calls run sequentially in the current process.

    Args:
        fun: Function called once per entry, with keyword arguments only.
        per_call_args: One dict of keyword arguments per call.
        base_args: Keyword arguments shared by every call.
        min_chunk_size: Smallest number of calls worth sending to a worker.
        desc: Progress bar description.
        max_worker_mult: Upper bound on chunks, as a multiple of the number
            of workers.
        initializer: Called as `initializer(base_args, calls)` before the
            calls run and must return the `(base_args, calls)` to use. In
            the pool it runs once per chunk, otherwise once for all calls.
        finalizer: Called as `finalizer(base_args, calls)` after the calls
            ran, with the same granularity as `initializer`.

    Returns:
        The results of `fun`, in the order of `per_call_args`.
    """
    from multiprocessing import Pipe

    if (
        # len(per_call_args) < 2 * min_chunk_size
        not MULTIPROCESS_ENABLE
        or IS_SUBPROCESS
    ):
        if initializer is not None:
            base_args, per_call_args = initializer(base_args, per_call_args)

        out = []
        for call_args in piter(
            per_call_args,
            desc=desc,
            leave=False,
        ):
            # Per-call values win over shared ones, exactly as in the pool
            # workers, so both execution modes give the same result.
            kwargs = {**base_args, **call_args} if base_args else call_args.copy()
            out.append(_wrap_exceptions(fun, get_node_name(), **kwargs))

        if finalizer is not None:
            finalizer(base_args, per_call_args)

        return out

    # Nothing to do: return before allocating a pipe, lock or pool.
    n_tasks = len(per_call_args)
    if n_tasks == 0:
        return []

    pool, manager, _ = _get_pool()

    # Aim for at most `max_worker_mult` chunks per worker while keeping
    # every chunk at least `min_chunk_size` long, then even out the sizes.
    chunk_n_suggestion = min(
        max_worker_mult * _max_workers, (n_tasks - 1) // min_chunk_size + 1
    )
    chunk_len = (n_tasks - 1) // chunk_n_suggestion + 1
    chunk_n = (n_tasks - 1) // chunk_len + 1
    chunks = [
        per_call_args[chunk_len * j : min(chunk_len * (j + 1), n_tasks)]
        for j in range(chunk_n)
    ]

    # Seconds to wait for a progress message before re-checking the pool.
    poll_interval = 0.1

    progress_recv, progress_send = Pipe(duplex=False)
    pbar = None
    try:
        progress_lock = manager.Lock()
        node_name = get_node_name()
        worker_args = [
            (
                node_name,
                progress_send,
                progress_lock,
                initializer,
                fun,
                finalizer,
                base_args,
                chunk,
            )
            for chunk in chunks
        ]

        res = pool.map_async(_calc_worker, worker_args)

        pbar = piter(desc=desc, leave=False, total=n_tasks)
        n = 0
        while n < n_tasks and not res.ready():
            # A failed worker never reports its remaining items, so a
            # blocking `recv()` could wait forever. Poll instead, which
            # lets the loop notice that the pool has finished.
            if not progress_recv.poll(poll_interval):
                continue
            u = progress_recv.recv()
            n += u
            pbar.update(u)

        # Re-raises the worker exception if any chunk failed.
        out = []
        for sub_arr in res.get():
            out.extend(sub_arr)
        return out
    finally:
        # Release the pipe and the bar on success and on failure alike.
        progress_send.close()
        progress_recv.close()
        if pbar is not None:
            pbar.close()
def _is_console_logging_handler(handler) -> TypeGuard[logging.StreamHandler]: return isinstance(handler, logging.StreamHandler) and handler.stream in { sys.stdout, sys.stderr, }
@contextmanager
def logging_redirect_pbar():
    """Implementation of the logging_redirect_tqdm context manager that
    supports the rich handler.

    While the context is active, the rich console and every console
    stream handler of every known logger write through `tqdm.write`. The
    original streams are restored on exit, also when the body raises.
    """
    from rich import get_console

    all_loggers = [
        logging.getLogger(name) for name in logging.root.manager.loggerDict
    ]
    all_loggers.append(logging.root)
    orig_streams: list[list[TextIO | None]] = []

    # Use stdout for jupyter to avoid coloring it red
    if is_jupyter():
        out_stream = sys.stdout
    else:
        out_stream = sys.stderr

    class PbarIO(io.StringIO):
        def write(self, text: str):
            # The last character of each write is dropped before the text
            # is handed to tqdm.
            tqdm.write(text[:-1], file=out_stream)

    # Swap rich logger
    pbar_stream = PbarIO()
    console = get_console()
    rich_file = console.file
    console.file = pbar_stream

    try:
        for lg in all_loggers:
            # Swap standard loggers
            saved: list[TextIO | None] = []
            orig_streams.append(saved)
            for hdl in lg.handlers:
                if not _is_console_logging_handler(hdl):
                    # Placeholder keeps the saved list aligned with handlers.
                    saved.append(None)
                    continue
                saved.append(hdl.stream)  # type: ignore
                hdl.setStream(pbar_stream)

        yield
    finally:
        for lg, saved in zip(all_loggers, orig_streams):
            for hdl, previous in zip(lg.handlers, saved):
                if previous is None:
                    continue
                if isinstance(hdl, logging.StreamHandler):
                    hdl.setStream(previous)

        if rich_file is not None:
            console.file = rich_file
# Public names of this module. "reduce" re-exports `functools.reduce`, which
# is imported at the top of the file.
__all__ = [
    "MULTIPROCESS_ENABLE",
    "piter",
    "prange",
    "process",
    "process_async",
    "process_in_parallel",
    "logging_redirect_pbar",
    "init_pool",
    "close_pool",
    "get_manager",
    "set_node_name",
    "get_node_name",
    "reduce"
]