"""This utility module provides constants and functions for multiprocessing
and progress monitoring in Pasteur.
In most cases, the functions and constants are simple wrappers around existing libraries.
Because the codebase routes this work through these primitives, their
implementations can be swapped out in the future."""
import functools
import io
import logging
import re
import shutil
import sys
import time
from contextlib import contextmanager
from functools import reduce
from multiprocessing.pool import AsyncResult, Pool
from os import cpu_count, environ
from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TextIO, TypeGuard, TypeVar
from rich import get_console
from tqdm import tqdm, trange
if TYPE_CHECKING:
from multiprocessing.managers import SyncManager
X = TypeVar("X")
P = ParamSpec("P")
logger = logging.getLogger(__name__)
# Jupyter doesn't support going up lines (moving cursor)
# This means up to 1 loading bar works
JUPYTER_MAX_NEST = 1
PBAR_COLOR = "blue"
PBAR_OFFSET = 11
PBAR_FORMAT = (" " * PBAR_OFFSET) + ">>>>>>> {l_bar}{bar}{r_bar}"
# Exact number for notebooks rendered in github to use up the whole width
# Assumes a stripping github filter is used to remove the empty space (or time)
# at the start
PBAR_JUP_NCOLS = 135 + PBAR_OFFSET
RICH_TRACEBACK_ARGS = {
"show_locals": False,
# "max_frames": 10,
"suppress": ["kedro", "click"],
}
PROGRESS_STEP_NS = 50_000_000
# Debug flag disables exception wrapping to print the correct stack trace so
# debugger stops at the top-most exception
DEBUG = bool(int(environ.get("_DEBUG", 0)))
MULTIPROCESS_ENABLE = bool(int(environ.get("_MULTIPROCESS", 1)))
IS_SUBPROCESS = False
CHECK_LEAKS = False
def _is_jupyter() -> bool:  # pragma: no cover
    """Detect a notebook front end (Jupyter notebook, qtconsole or Google Colab).

    Adapted from the ``rich`` package. Returns False outside IPython and for
    the plain terminal IPython shell."""
    try:
        ipython_getter = get_ipython  # type: ignore[name-defined]
    except NameError:
        return False
    shell_class = ipython_getter().__class__
    in_colab = "google.colab" in str(shell_class)
    return in_colab or shell_class.__name__ == "ZMQInteractiveShell"
# Cached result of `_is_jupyter()`; None until `is_jupyter()` first runs it.
_jupyter: bool | None = None
[docs]
def is_jupyter() -> bool:
    """Return whether this process runs inside a notebook front end.

    The probe in ``_is_jupyter`` runs only once; its answer is cached in the
    module-level ``_jupyter`` so later calls do not repeat the ``NameError``
    handling."""
    global _jupyter
    if _jupyter is not None:
        return _jupyter
    _jupyter = _is_jupyter()
    return _jupyter
[docs]
def get_tqdm_args() -> dict[str, Any]:
    """Return the keyword arguments Pasteur passes to every ``tqdm`` progress bar.

    Bars are disabled inside pool subprocesses. In Jupyter they are also
    disabled once ``JUPYTER_MAX_NEST`` bars are already active, because Jupyter
    cannot move the cursor up to redraw nested bars. In Jupyter, bars use a
    fixed width, ASCII characters and stdout. Elsewhere they resize with the
    terminal and write to stderr.
    """
    jupyter = is_jupyter()
    if IS_SUBPROCESS:
        # Disable subprocess pbars until a better solution.
        disable = True
    else:
        # `_instances` is private to tqdm; read it with a default, as
        # `logging_redirect_pbar` does.
        active_pbars = len(getattr(tqdm, "_instances", ()))
        disable = jupyter and active_pbars >= JUPYTER_MAX_NEST
    return {
        "disable": disable,
        "colour": PBAR_COLOR,
        "bar_format": PBAR_FORMAT,
        "ncols": PBAR_JUP_NCOLS if jupyter else None,
        "dynamic_ncols": not jupyter,
        "ascii": True if jupyter else None,
        "file": sys.stdout if jupyter else sys.stderr,
    }
# Type of the progress-bar factory wrapped by `limit_pbar_nesting`, so the
# returned wrapper is typed as the same callable it receives.
A = TypeVar("A", bound=Callable)
[docs]
def limit_pbar_nesting(pbar_gen: A) -> A:
    """Wrap a tqdm-style progress-bar factory so it applies Pasteur's defaults.

    The defaults come from ``get_tqdm_args()``. Besides styling, they disable the
    bar inside pool subprocesses and, on Jupyter, once ``JUPYTER_MAX_NEST`` bars
    are already active. Nesting more bars there causes ugly gaps to be
    generated, e.g. on VS Code.

    Keyword arguments passed by the caller override the defaults instead of
    raising ``TypeError`` for a duplicate keyword. ``disable`` is combined
    rather than replaced: the bar is disabled whenever the defaults say so, and
    otherwise the caller's value is used.
    """

    @functools.wraps(pbar_gen)
    def closure(*args, **kwargs):
        pbar_kwargs = get_tqdm_args()
        if "disable" in kwargs:
            kwargs["disable"] = pbar_kwargs["disable"] or kwargs["disable"]
        pbar_kwargs.update(kwargs)
        return pbar_gen(*args, **pbar_kwargs)

    return closure  # type: ignore


prange = limit_pbar_nesting(trange)
piter = limit_pbar_nesting(tqdm)
def _wrap_exceptions(
    fun: Callable[P, X],
    node_name: str,
    lock,
    first_error,
    /,
    *args: P.args,
    **kwargs: P.kwargs,
) -> X:
    """Run ``fun(*args, **kwargs)`` under *node_name*, reporting only the first failure.

    Used as the task target by ``process``/``process_async`` and inline by
    ``process_in_parallel``. ``lock`` and ``first_error`` are a manager lock
    and a manager value shared by the tasks of one call. Only the task that
    finds ``first_error.value`` still true prints its traceback and logs the
    error. Pass ``None`` for both to always report.

    The leading parameters are positional-only, so keyword arguments meant for
    ``fun`` (e.g. one called ``lock``) cannot collide with them.

    Raises:
        RuntimeError: chained from the original exception. When ``DEBUG`` is
            set, the first error is re-raised unchanged so the debugger stops
            at its origin.
    """
    set_node_name(node_name)
    try:
        res = fun(*args, **kwargs)
        if CHECK_LEAKS:
            # check for leaks after first execution
            from pasteur.utils.leaks import check, clear

            clear()
            a = fun(*args, **kwargs)
            del a
            check(f"Node {node_name} leaks")
        return res
    except Exception as e:
        # Use some logic to only print the first error
        if lock is not None and first_error is not None:
            with lock:
                is_first = first_error.value
                first_error.value = False
        else:
            is_first = True
        if is_first:
            # raise original exception to catch proper breakpoint
            if DEBUG:
                raise e
            get_console().print_exception(**RICH_TRACEBACK_ARGS)
            logger.error(
                f'Subprocess of "{get_node_name()}" failed with error:\n{type(e).__name__}: {e}'
            )
        raise RuntimeError("subprocess failed") from e
def _raise_worker_error(e: Exception, progress_lock, progress_first_error, context: str):
    """Report a failure inside `_calc_worker`, then raise ``RuntimeError`` from it.

    Only the first failure across all workers, tracked by the shared
    `progress_first_error` flag, prints its traceback and logs
    ``"<context> failed with error: ..."``. With `DEBUG` set, that first
    failure is re-raised unchanged instead. Call it from inside an ``except``
    block, because ``print_exception`` is invoked without an explicit
    exception. Never returns normally.
    """
    with progress_lock:
        is_first = progress_first_error.value
        progress_first_error.value = False
    if is_first:
        if DEBUG:
            raise e
        get_console().print_exception(**RICH_TRACEBACK_ARGS)
        logger.error(f"{context} failed with error:\n{type(e).__name__}: {e}")
    raise RuntimeError("subprocess failed") from e


def _calc_worker(args):
    """Pool task used by `process_in_parallel` to process one chunk of calls.

    `args` is the tuple ``(node_name, progress_send, progress_lock,
    progress_first_error, initializer, fun, finalizer, base_args, chunk)``.
    The task runs in three steps:

    1. Runs the optional ``initializer(base_args, chunk)``, which returns the
       replacement ``(base_args, chunk)``.
    2. Calls ``fun(**{**base_args, **op})`` for every ``op`` in the chunk.
    3. Runs the optional ``finalizer(base_args, chunk)``.

    Finished-call counts are sent through `progress_send` (under
    `progress_lock`) at most once per `PROGRESS_STEP_NS` nanoseconds. Any
    remainder is flushed at the end, and the pipe end is always closed.

    Returns the list of ``fun`` results in chunk order.

    Raises:
        RuntimeError: if the initializer, a call, or the finalizer fails
            (see `_raise_worker_error`).
    """
    (
        node_name,
        progress_send,
        progress_lock,
        progress_first_error,
        initializer,
        fun,
        finalizer,
        base_args,
        chunk,
    ) = args
    set_node_name(node_name)
    # Finished calls not yet reported to the parent; flushed in `finally`.
    unreported = 0
    try:
        if initializer is not None:
            try:
                base_args, chunk = initializer(base_args, chunk)
            except Exception as e:
                _raise_worker_error(
                    e,
                    progress_lock,
                    progress_first_error,
                    f'Subprocess initialization of "{get_node_name()}"',
                )

        last_update = time.perf_counter_ns()
        out = []
        for i, op in enumerate(chunk):
            try:
                call_kwargs = {**base_args, **op} if base_args else op
                out.append(fun(**call_kwargs))
                if CHECK_LEAKS:
                    # Run second so first run loads modules
                    from pasteur.utils.leaks import check, clear

                    clear()
                    a = fun(**call_kwargs)
                    del a
                    check(f"Node {node_name}:{i} leaks")
            except Exception as e:
                _raise_worker_error(
                    e,
                    progress_lock,
                    progress_first_error,
                    f'Subprocess of "{get_node_name()}" at index {i}',
                )
            unreported += 1
            now = time.perf_counter_ns()
            if now - last_update > PROGRESS_STEP_NS:
                with progress_lock:
                    progress_send.send(unreported)
                last_update = now
                unreported = 0

        if finalizer is not None:
            try:
                finalizer(base_args, chunk)
            except Exception as e:
                _raise_worker_error(
                    e,
                    progress_lock,
                    progress_first_error,
                    f'Subprocess finalization of "{get_node_name()}"',
                )
        return out
    finally:
        # Flush the remaining count; close this process's pipe end even if
        # that send fails.
        try:
            if unreported != 0:
                with progress_lock:
                    progress_send.send(unreported)
        finally:
            progress_send.close()
# Worker count of the most recently created pool (set by `_init_pool`); 1 before any pool exists.
_max_workers: int = 1
# Shared `(pool, manager, log_queue)` tuple, or None when no pool is running.
_pool: "tuple[Pool, SyncManager, Any] | None" = None
[docs]
def get_max_workers() -> int:
    """Report the worker count recorded when the shared pool was last created.

    This is 1 until a pool has been initialized, and ``close_pool()`` does not
    reset it."""
    workers: int = _max_workers
    return workers
def _logging_thread_fun(q):
    """Re-emit log records arriving on *q* through this process's loggers.

    Loops until it receives the ``None`` sentinel. An ``EOFError`` raised by
    ``q.get`` also ends it quietly."""
    try:
        while (record := q.get()) is not None:
            logging.getLogger(record.name).handle(record)
    except EOFError:
        pass
def _replace_loggers_with_queue(q):
    """Route every log record of this process into the queue *q*.

    All known loggers, and the root, lose their handlers and level and
    propagate upwards. The root's only handler becomes a ``QueueHandler``
    writing to *q*."""
    from logging.handlers import QueueHandler

    root = logging.root
    every_logger = [logging.getLogger(name) for name in root.manager.loggerDict]
    every_logger.append(root)
    for target in every_logger:
        target.propagate = True
        target.handlers = []
        target.level = logging.NOTSET
    root.handlers.append(QueueHandler(q))
def _init_subprocess(log_queue):
    """Initializer run once in every pool worker.

    Does three things:

    - Ignores SIGINT in the worker (presumably so that Ctrl+C is handled by
      the parent process; confirm).
    - Routes the worker's logging into *log_queue* when one is given.
    - Marks the process via ``IS_SUBPROCESS``."""
    global IS_SUBPROCESS
    import signal

    signal.signal(signal.SIGINT, signal.SIG_IGN)
    if log_queue is not None:
        # Kedro installs rich logging when importing the following module
        # and messes with loggers. Import before replacing the loggers.
        import kedro.framework.project as _

        _replace_loggers_with_queue(log_queue)
    IS_SUBPROCESS = True
def _get_pool():
    """Return the shared ``(pool, manager, log_queue)`` tuple, creating the pool if needed.

    Implicit creation logs a warning; prefer calling ``init_pool()`` explicitly.

    Raises:
        AssertionError: if no pool exists and multiprocessing is disabled
            (``_MULTIPROCESS=0``).
    """
    global _pool
    if _pool is None:
        # An explicit raise, unlike `assert`, survives `python -O`.
        # AssertionError is kept so callers see the same exception type.
        if not MULTIPROCESS_ENABLE:
            raise AssertionError(
                "Multiprocessing has been disabled. Preventing pool creation."
            )
        logger.warning(
            "Launching a process pool implicitly. Use `init_pool()` to explicitly control pool creation."
        )
        init_pool()
    assert _pool is not None
    return _pool
[docs]
def get_manager():
    """Return the ``SyncManager`` backing the shared pool, creating the pool on demand."""
    _pool_obj, manager, _log_queue = _get_pool()
    return manager
# Thread-local holder for the current node name; created on the first `set_node_name` call.
_node_name: Any | None = None
[docs]
def set_node_name(name: str):
    """Record *name* as the current node for the calling thread (see `get_node_name`).

    The thread-local holder is created lazily on the first call."""
    global _node_name
    if _node_name is None:
        from threading import local

        _node_name = local()
    _node_name.name = name  # type: ignore
[docs]
def get_node_name() -> str:
    """Return the node name set for the calling thread, or ``"UKN_NODE"`` if none was set."""
    return getattr(_node_name, "name", "UKN_NODE")
[docs]
def close_pool():
    """Shut down the shared process pool, if one is running.

    Workers are terminated first, and only then is the ``None`` sentinel queued
    for the log-forwarding thread. Records the workers queued before stopping
    are therefore still ahead of the sentinel and get forwarded. The pool is
    unregistered even if termination fails.
    """
    global _pool
    if _pool is None:
        return
    pool, _, log_queue = _pool
    _pool = None
    try:
        pool.terminate()
    finally:
        try:
            log_queue.put(None)
        except (EOFError, OSError):
            # The manager behind the queue is already gone; nothing more can be
            # sent to the forwarding thread.
            pass
def _init_pool(max_workers: int | None = None, refresh_processes: int | None = None):
    """(Re)create the shared ``spawn`` pool, its manager and the log-forwarding thread.

    Any existing pool is closed first. The pool is sized to ``_max_workers``
    (``max_workers``, else the CPU count, else 1), so ``get_max_workers()``
    always matches it.

    Returns the new ``(pool, manager, log_queue)`` tuple, which is also stored
    in the module-level ``_pool``.
    """
    import threading
    from multiprocessing import get_context

    global _pool, _max_workers
    close_pool()

    ctx = get_context("spawn")
    manager = ctx.Manager()
    _max_workers = max_workers or cpu_count() or 1

    # Set up logging handler for subprocesses. The forwarding thread only
    # stops on the sentinel sent by `close_pool()` (or an EOFError from the
    # queue). A program does not exit while non-daemon threads are alive, so
    # it is a daemon: a pool that is never closed must not hang interpreter exit.
    log_queue = manager.Queue()
    log_thread = threading.Thread(
        target=_logging_thread_fun, args=(log_queue,), daemon=True
    )
    log_thread.start()

    try:
        pool = ctx.Pool(
            processes=_max_workers,
            initializer=_init_subprocess,
            initargs=(log_queue,),
            maxtasksperchild=refresh_processes,
        )
    except BaseException:
        # Do not leave the forwarding thread and the manager process running
        # for a pool that was never created.
        log_queue.put(None)
        manager.shutdown()
        raise
    _pool = (pool, manager, log_queue)
    return _pool
[docs]
class init_pool:
[docs]
def __init__(
self, max_workers: int | None = None, refresh_processes: int | None = None
) -> None:
"""Creates a shared process pool for all threads in this process.
`max_workers` should be set based either on cores or on how many RAM GBs
will be required by each process.
`log_queue` connects the subprocesses to the main process logger, see `pasteur.kedro.runner.parallel.py`
`refresh_processes` sets `maxtasksperchild` for the pool, which
prevents memory leaks from snowballing from node to node. However,
due to additional imports every restart, it is slower."""
_init_pool(max_workers, refresh_processes)
def __enter__(self):
return None
def __exit__(self, type, value, traceback):
close_pool()
[docs]
def process(fun: Callable[P, X], *args: P.args, **kwargs: P.kwargs) -> X:
    """Run ``fun(*args, **kwargs)`` on a worker of the common pool and wait for it.

    Falls back to a direct call when multiprocessing is disabled or when this
    is already a pool worker. In the pool, the call goes through
    ``_wrap_exceptions`` so only the first failure is reported."""
    if IS_SUBPROCESS or not MULTIPROCESS_ENABLE:
        return fun(*args, **kwargs)
    pool, manager, _ = _get_pool()
    node_name = get_node_name()
    task = (
        fun,
        node_name,
        manager.Lock(),
        manager.Value("first_error", True),
        *args,
    )
    return pool.apply(_wrap_exceptions, task, kwargs)  # type: ignore
[docs]
class AsyncResultStub(AsyncResult):
    """An already-completed ``AsyncResult`` holding a value computed inline.

    Returned by ``process_async`` when the work runs in the current process, so
    callers can use the same ``ready()``/``get()`` interface either way.
    """

    def __init__(self, obj):
        # Deliberately skip AsyncResult.__init__: since Python 3.8 it takes a
        # real pool and dereferences `pool._cache`, so passing None raises
        # AttributeError. The public methods are all overridden below and need
        # none of that state.
        self.obj = obj

    def ready(self):
        return True

    def successful(self):
        return True

    def wait(self, timeout=None):
        # The value is already available; there is nothing to wait for.
        return None

    def get(self, timeout=None):
        return self.obj
[docs]
def process_async(
    fun: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
) -> AsyncResult:
    """Submit ``fun(*args, **kwargs)`` to the common pool without waiting.

    When multiprocessing is disabled, or inside a pool worker, the call runs
    immediately and its value is wrapped in an ``AsyncResultStub``."""
    if IS_SUBPROCESS or not MULTIPROCESS_ENABLE:
        return AsyncResultStub(fun(*args, **kwargs))
    pool, manager, _ = _get_pool()
    node_name = get_node_name()
    task = (
        fun,
        node_name,
        manager.Lock(),
        manager.Value("first_error", True),
        *args,
    )
    return pool.apply_async(_wrap_exceptions, task, kwargs)  # type: ignore
[docs]
def process_in_parallel(
fun: Callable[..., X],
per_call_args: list[dict],
base_args: dict[str, Any] | None = None,
min_chunk_size: int = 1,
desc: str | None = None,
max_worker_mult: int = 1,
initializer: Callable | None = None,
finalizer: Callable[..., None] | None = None,
) -> list[X]:
"""Processes arguments in parallel using the common process pool and prints progress bar.
Implements a custom form of chunk iteration, where `base_args` contains arguments
with large size that are common in all function calls and `per_call_args` which
change every iteration."""
from multiprocessing import Lock, Pipe
if (
# len(per_call_args) < 2 * min_chunk_size
not MULTIPROCESS_ENABLE
or IS_SUBPROCESS
):
if initializer is not None:
base_args, per_call_args = initializer(base_args, per_call_args)
out = []
for args in piter(
per_call_args,
desc=desc,
leave=False,
):
kwargs = args.copy()
if base_args:
kwargs.update(base_args)
res = _wrap_exceptions(fun, get_node_name(), None, None, **kwargs)
out.append(res)
if finalizer is not None:
finalizer(base_args, per_call_args)
return out
pool, manager, _ = _get_pool()
progress_recv, progress_send = Pipe(duplex=False)
progress_lock = manager.Lock()
progress_first_error = manager.Value("first_error", True)
n_tasks = len(per_call_args)
if n_tasks == 0:
return []
chunk_n_suggestion = min(
max_worker_mult * _max_workers, (n_tasks - 1) // min_chunk_size + 1
)
chunk_len = (n_tasks - 1) // chunk_n_suggestion + 1
chunk_n = (n_tasks - 1) // chunk_len + 1
chunks = [
per_call_args[chunk_len * j : min(chunk_len * (j + 1), n_tasks)]
for j in range(chunk_n)
]
args = []
for chunk in chunks:
args.append(
(
get_node_name(),
progress_send,
progress_lock,
progress_first_error,
initializer,
fun,
finalizer,
base_args,
chunk,
)
)
res = pool.map_async(
_calc_worker, args, error_callback=lambda _: progress_send.close()
)
pbar = piter(desc=desc, leave=False, total=n_tasks)
n = 0
while not res.ready():
try:
u = progress_recv.recv()
n += u
pbar.update(u)
if n == n_tasks:
break
except Exception:
break
out = []
for sub_arr in res.get():
out.extend(sub_arr)
progress_send.close()
progress_recv.close()
pbar.close()
return out
def _is_console_logging_handler(handler) -> TypeGuard[logging.StreamHandler]:
return isinstance(handler, logging.StreamHandler) and handler.stream in {
sys.stdout,
sys.stderr,
}
[docs]
@contextmanager
def logging_redirect_pbar():
    """Implementation of the logging_redirect_tqdm context manager that supports the rich handler.

    While active, two kinds of output are redirected into a stream that prints
    through ``tqdm.write``, so log output does not break active progress bars:

    - logging ``StreamHandler``s that write to stdout/stderr;
    - the global rich console.

    That output goes to stdout in Jupyter (to avoid red stderr styling) and to
    stderr elsewhere.

    Text containing the ``:ephemeral:`` marker is printed without the marker.
    Only its first line is kept, plus the newest non-empty lines that fit in the
    terminal rows left over by the active progress bars. When more than one
    line was printed, the next write is prefixed with ANSI cursor-up and
    clear-line codes so it overwrites that output, and it ends with a carriage
    return.

    The original streams are restored on exit, even if the body raises.
    """
    loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
    loggers.append(logging.root)
    # (handler, original stream) pairs. Restoring from these does not depend on
    # the handler lists staying unchanged while the context is active.
    swapped: list[tuple[logging.StreamHandler, TextIO]] = []

    # Use stdout for jupyter to avoid coloring it red
    out_stream = sys.stdout if is_jupyter() else sys.stderr

    def active_pbar_rows() -> int:
        """Rows spanned by live tqdm bars drawing on a console stream (0 if none)."""
        # Derive from external_write_mode
        with tqdm.get_lock():
            # NOTE: tqdm stores bars with a fixed `position=` as a negative
            # `pos` (its own display uses abs(pos)); verify on tqdm upgrades.
            rows = [
                abs(inst.pos)
                for inst in getattr(tqdm, "_instances", [])
                if hasattr(inst, "start_t")
                and hasattr(inst, "pos")
                and inst.fp in [out_stream, sys.stdout, sys.stderr]
            ]
        return max(rows) - min(rows) + 1 if rows else 0

    def trim_ephemeral(cleaned: str) -> tuple[str, int]:
        """Keep the first line plus the newest non-empty lines that fit on screen.

        Returns the trimmed text and how many lines follow the first one. Each
        line counts as one row, because rich already handles wrapping and
        computing it here gives wrong results.
        """
        term_rows = shutil.get_terminal_size(fallback=(80, 24)).lines
        room = max(1, term_rows - active_pbar_rows() + 1) - 1
        first, *rest = cleaned.splitlines() or [""]
        tail = [line for line in rest if line]
        tail = tail[-room:] if room > 0 else []
        return "\n".join([first, *tail]), len(tail)

    class PbarIO(io.StringIO):
        # Escape codes that move back over the previous multi-line ephemeral
        # message, or "" when there is nothing to overwrite.
        rmtext = ""

        def write(self, text: str) -> int:
            written = len(text)
            if self.rmtext:
                rm_prefix, rm_suffix = self.rmtext, "\r"
                self.rmtext = ""
            else:
                rm_prefix, rm_suffix = "", ""

            if ":ephemeral:" in text:
                text, extra_lines = trim_ephemeral(text.replace(":ephemeral:", ""))
                self.rmtext = "\033[F\033[2K" * extra_lines

            tqdm.write(rm_prefix + text, file=out_stream, end=rm_suffix)
            return written

    pbar_stream = PbarIO()
    # Swap rich logger
    console = get_console()
    rich_fn = console.file
    console.file = pbar_stream
    try:
        # Swap standard loggers
        for lg in loggers:
            for handler in lg.handlers:
                if _is_console_logging_handler(handler):
                    swapped.append((handler, handler.stream))
                    handler.setStream(pbar_stream)
        yield
    finally:
        for handler, stream in swapped:
            handler.setStream(stream)
        if rich_fn is not None:
            console.file = rich_fn
# Public API, grouped by purpose (order unchanged).
__all__ = [
    # Configuration flag
    "MULTIPROCESS_ENABLE",
    # Progress bars
    "piter",
    "prange",
    # Running work on the shared pool
    "process",
    "process_async",
    "process_in_parallel",
    # Logging alongside progress bars
    "logging_redirect_pbar",
    # Pool lifecycle
    "init_pool",
    "close_pool",
    "get_manager",
    # Per-thread node names
    "set_node_name",
    "get_node_name",
    # Re-exported from functools
    "reduce",
]