import logging
from typing import Any, Callable, cast
from kedro.io import DataCatalog
from kedro.pipeline.node import Node
from pluggy import PluginManager
from kedro.pipeline import Pipeline, transcoding
from rich import get_console
from ...utils.perf import PerformanceTracker
from ...utils.progress import RICH_TRACEBACK_ARGS, process_in_parallel, set_node_name
logger = logging.getLogger(__name__)
def _collect_inputs_from_hook( # noqa: PLR0913
node: Node,
catalog: Any,
inputs: dict[str, Any],
is_async: bool,
hook_manager: PluginManager,
session_id: str | None = None,
run_id: str | None = None,
) -> dict[str, Any]:
inputs = inputs.copy() # shallow copy to prevent in-place modification by the hook
hook_response = hook_manager.hook.before_node_run(
node=node,
catalog=catalog,
inputs=inputs,
is_async=is_async,
session_id=session_id,
run_id=run_id,
)
additional_inputs = {}
if (
hook_response is not None
): # all hooks on a _NullPluginManager will return None instead of a list
for response in hook_response:
if response is not None and not isinstance(response, dict):
response_type = type(response).__name__
raise TypeError(
f"'before_node_run' must return either None or a dictionary mapping "
f"dataset names to updated values, got '{response_type}' instead."
)
additional_inputs.update(response or {})
return additional_inputs
def _call_node_run( # noqa: PLR0913
node: Node,
catalog: Any,
inputs: dict[str, Any],
is_async: bool,
hook_manager: PluginManager,
session_id: str | None = None,
run_id: str | None = None,
) -> dict[str, Any]:
try:
outputs = node.run(inputs)
except Exception as exc:
hook_manager.hook.on_node_error(
error=exc,
node=node,
catalog=catalog,
inputs=inputs,
is_async=is_async,
session_id=session_id,
)
raise exc
hook_manager.hook.after_node_run(
node=node,
catalog=catalog,
inputs=inputs,
outputs=outputs,
is_async=is_async,
session_id=session_id,
run_id=run_id,
)
return outputs
def _task_worker(fun: Callable[..., dict[str, Any]], catalog: DataCatalog):
res = fun()
for name, data in res.items():
catalog.save(name, data)
[docs]
def run_expanded_node(
node: Node,
catalog: DataCatalog,
hook_manager: PluginManager,
session_id: str | None = None,
run_id: str | None = None,
) -> Node:
"""Handles expanded output's option of returning a set of callables.
Callables are processed by the process pool and the result is untangled
by `ExpandedNode` into the dictionary that gets saved in the catalog.
It also handles printing exceptions."""
node_name = node.name.split("(")[0]
set_node_name(node_name)
try:
t = PerformanceTracker.get("nodes")
t.log_to_file()
t.start(node_name)
# Readd mlflow tracking
if run_id:
import mlflow
if not mlflow.active_run():
mlflow.start_run(run_id=run_id)
inputs = {}
session_id = cast(str, session_id)
for name in node.inputs:
hook_manager.hook.before_dataset_loaded(node=node, dataset_name=name) # type: ignore
inputs[name] = catalog.load(name)
hook_manager.hook.after_dataset_loaded(node=node, dataset_name=name, data=inputs[name]) # type: ignore
is_async = False
additional_inputs = _collect_inputs_from_hook(
node, catalog, inputs, is_async, hook_manager, session_id=session_id, run_id=run_id
)
inputs.update(additional_inputs)
outputs = _call_node_run(
node, catalog, inputs, is_async, hook_manager, session_id=session_id, run_id=run_id
)
except Exception as e:
if not (isinstance(e, RuntimeError) and str(e) == "subprocess failed"):
# Prevent printing traceback for subprocesses that crash
get_console().print_exception(**RICH_TRACEBACK_ARGS)
logger.error(
f'Node "{node_name}" failed with error:\n{type(e).__name__}: {e}'
)
logger.info(
f'To continue from this node, add `-c "{node.name.split("(", 1)[0]}" -s "{session_id}"` to a pipeline run\'s arguments.'
)
raise e
# Clear outputs
for name in node.outputs:
d = catalog[name]
if hasattr(d, "reset"):
getattr(d, "reset")()
if isinstance(outputs, set):
# TODO: Fix hooks
# When a set is received, process it in parallel
try:
process_in_parallel(
_task_worker,
per_call_args=[{"fun": fun} for fun in outputs],
base_args={"catalog": catalog},
desc=f"Processing tasks ({node.name.split('(')[0]:>25s})",
)
except Exception as e:
logger.info(
f'To continue from this node, add `-c "{node.name.split("(", 1)[0]}" -s "{session_id}"` to a pipeline run\'s arguments.'
)
raise e
else:
try:
for name, data in outputs.items():
hook_manager.hook.before_dataset_saved(node=node, dataset_name=name, data=data) # type: ignore
catalog.save(name, data)
hook_manager.hook.after_dataset_saved(node=node, dataset_name=name, data=data) # type: ignore
except Exception as e:
if not (isinstance(e, RuntimeError) and str(e) == "subprocess failed"):
# Prevent printing traceback for subprocesses that crash
get_console().print_exception(**RICH_TRACEBACK_ARGS)
logger.error(
f'Saving "{node_name}" failed with error:\n{type(e).__name__}: {e}'
)
# TODO: Handle dataset errors better
for name in node.confirms:
catalog.confirm(name)
t.stop(node.name.split("(")[0])
return node
[docs]
def resume_from(pipeline: Pipeline, node: str | None):
# Since the list is always the same, resume from the last node that was ran
groups = pipeline.grouped_nodes
if node is None:
return [n for group in groups for n in group]
out = []
started = False
for g in groups:
group_started = started
for n in g:
if group_started:
out.append(n)
continue
# Append us, skip the rest of the group
# assuming it ran. Fair assumption that skips extra
# steps but could cause failures down the pipeline
if n.name == node:
started = True
out.append(n)
return out
[docs]
def resume_from_dependencies(pipeline: Pipeline, node: str):
# For the parallel runner, get a bit more complicated.
# Only run nodes that are children of the resume node
dependencies: dict[Node, set[Node]] = pipeline.node_dependencies.copy()
# Find resume node
resume_node = None
for n in dependencies:
if n.name.startswith(node):
resume_node = n
break
assert resume_node is not None, f"Could not find node {node}"
keep = {resume_node}
dependencies[resume_node] = set()
stack: set[Node] = {resume_node}
while stack:
n = stack.pop()
for child, parents in dependencies.items():
if n in parents and child not in keep:
keep.add(child)
stack.add(child)
new_dependencies = {}
for n in keep:
new_dependencies[n] = dependencies[n].intersection(keep)
return new_dependencies, list(new_dependencies)