Source code for pasteur.kedro.runner.common
import logging
from typing import Any, Callable, cast
from kedro.io import DataCatalog
from kedro.pipeline.node import Node
from kedro.runner.runner import _call_node_run, _collect_inputs_from_hook
from pluggy import PluginManager
from rich import get_console
from ...utils.perf import PerformanceTracker
from ...utils.progress import RICH_TRACEBACK_ARGS, process_in_parallel, set_node_name
logger = logging.getLogger(__name__)
def _task_worker(fun: Callable[..., dict[str, Any]], catalog: DataCatalog):
    """Run one task callable and save every entry it returns to the catalog.

    Args:
        fun: Zero-argument callable returning a mapping of dataset name to
            the data to store under that name.
        catalog: Catalog whose ``save(name, data)`` is invoked once per
            returned entry, in the mapping's iteration order.
    """
    produced = fun()
    for dataset_name in produced:
        catalog.save(dataset_name, produced[dataset_name])
def _report_failure(stage: str, node_name: str, exc: Exception) -> None:
    """Print and log an exception raised while running or saving a node.

    Intended to be called from inside the ``except`` block that is handling
    ``exc``, since the console prints the exception currently being handled.

    Args:
        stage: Leading word of the log message (e.g. ``"Node"``, ``"Saving"``).
        node_name: Short node name to embed in the log message.
        exc: The exception being handled.
    """
    if not (isinstance(exc, RuntimeError) and str(exc) == "subprocess failed"):
        # Prevent printing traceback for subprocesses that crash
        get_console().print_exception(**RICH_TRACEBACK_ARGS)
    logger.error(
        f'{stage} "{node_name}" failed with error:\n{type(exc).__name__}: {exc}'
    )


def run_expanded_node(
    node: Node,
    catalog: DataCatalog,
    hook_manager: PluginManager,
    session_id: str | None = None,
) -> Node:
    """Handles expanded output's option of returning a set of callables.

    Callables are processed by the process pool and the result is untangled
    by `ExpandedNode` into the dictionary that gets saved in the catalog.
    It also handles printing exceptions.

    Args:
        node: Node to run. Its inputs are loaded from ``catalog`` and its
            outputs are saved back to it.
        catalog: Catalog used to load inputs, save outputs and confirm
            the datasets listed in ``node.confirms``.
        hook_manager: Plugin manager whose dataset load/save hooks are fired
            around each load and each (non-parallel) save.
        session_id: Session identifier forwarded to the node-run helpers.

    Returns:
        The ``node`` that was passed in.

    Raises:
        Exception: Any exception raised while loading inputs or running the
            node is reported and re-raised. Exceptions raised while saving
            a dictionary of outputs are reported but NOT re-raised.
    """
    # Short name: everything before the first "(" of the full node name.
    node_name = node.name.split("(")[0]
    set_node_name(node_name)

    try:
        t = PerformanceTracker.get("nodes")
        t.log_to_file()
        t.start(node_name)

        inputs = {}
        session_id = cast(str, session_id)

        for name in node.inputs:
            hook_manager.hook.before_dataset_loaded(node=node, dataset_name=name)  # type: ignore
            inputs[name] = catalog.load(name)
            hook_manager.hook.after_dataset_loaded(node=node, dataset_name=name, data=inputs[name])  # type: ignore

        is_async = False

        additional_inputs = _collect_inputs_from_hook(
            node, catalog, inputs, is_async, hook_manager, session_id=session_id
        )
        inputs.update(additional_inputs)

        outputs = _call_node_run(
            node, catalog, inputs, is_async, hook_manager, session_id=session_id
        )
    except Exception as e:
        _report_failure("Node", node_name, e)
        raise

    # Clear outputs: datasets exposing a `reset` attribute are reset before
    # anything new is saved into them.
    for name in node.outputs:
        d = catalog._get_dataset(name)
        if hasattr(d, "reset"):
            getattr(d, "reset")()

    if isinstance(outputs, set):
        # TODO: Fix hooks
        # When a set is received, process it in parallel. Each callable is
        # handed to `_task_worker`, which saves its results to the catalog;
        # the dataset save hooks are not fired on this path.
        process_in_parallel(
            _task_worker,
            per_call_args=[{"fun": fun} for fun in outputs],
            base_args={"catalog": catalog},
            desc=f"Processing tasks ({node_name:>25s})",
        )
    else:
        try:
            for name, data in outputs.items():
                hook_manager.hook.before_dataset_saved(node=node, dataset_name=name, data=data)  # type: ignore
                catalog.save(name, data)
                hook_manager.hook.after_dataset_saved(node=node, dataset_name=name, data=data)  # type: ignore
        except Exception as e:
            # NOTE(review): unlike the run failure above, a save failure is
            # only reported and execution continues to the confirm step.
            # Preserved from the original - confirm whether it should re-raise.
            _report_failure("Saving", node_name, e)
            # TODO: Handle dataset errors better

    for name in node.confirms:
        catalog.confirm(name)

    t.stop(node_name)
    return node