# Source code for pasteur.kedro.pipelines.views

from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING

from kedro.pipeline import Pipeline as pipeline
from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline

from ...metadata import Metadata
from ...utils import LazyFrame
from ...utils.parser import get_params_for_pipe
from .meta import TAGS_VIEW, TAGS_VIEW_META, TAGS_VIEW_SPLIT
from .meta import DatasetMeta as D
from .meta import PipelineMeta, node
from .utils import get_params_closure

if TYPE_CHECKING:
    from ...view import View


def _create_metadata(view: str, params: dict):
    """Build the ``Metadata`` object for a view.

    Args:
        view: Name of the view; forwarded to ``get_params_for_pipe``.
        params: Parameter dictionary; forwarded to ``get_params_for_pipe``.

    Returns:
        A ``Metadata`` instance constructed from the value returned by
        ``get_params_for_pipe(view, params)``.
    """
    return Metadata(get_params_for_pipe(view, params))


def _check_tables(metadata: Metadata, **tables: LazyFrame):
    """Run the metadata check on a sample of every table.

    Args:
        metadata: Object whose ``check`` method receives the samples.
        **tables: Tables keyed by name; ``sample()`` is called once on each.

    Returns:
        None. The return value of ``metadata.check`` is discarded.
    """
    samples = {name: frame.sample() for name, frame in tables.items()}
    metadata.check(samples)


def create_view_pipeline(view: View):
    """Create the pipeline that produces every table of ``view``.

    One node is built per entry of ``view.tables``. Each node uses
    ``view.query`` as its function with ``args=[t]``, reads the table's
    dependencies (``view.deps[t]``) from the ``{view.dataset}.{dep}``
    datasets and writes to ``{view}.view.{t}``.

    Args:
        view: View whose tables are queried.

    Returns:
        A ``PipelineMeta`` holding the pipeline and one ``"primary"``
        dataset descriptor (``type="pq"``) per table.
    """
    query_nodes = [
        node(
            func=view.query,
            name=f"query_{t}",
            args=[t],
            inputs={dep: f"{view.dataset}.{dep}" for dep in view.deps[t]},
            namespace=f"{view}.view",
            outputs=f"{view}.view.{t}",
            tags=TAGS_VIEW,
        )
        for t in view.tables
    ]
    pipe = pipeline(query_nodes)

    datasets = [
        D("primary", f"{view}.view.{t}", ["view", view, "tables", t], type="pq")
        for t in view.tables
    ]
    return PipelineMeta(pipe, datasets)
def create_check_tables_pipeline(view: View):
    """Create the pipeline that checks the view's tables against its metadata.

    The single ``check_tables`` node runs ``_check_tables`` with the
    ``{view}.metadata`` dataset and every ``{view}.view.{t}`` table dataset
    as inputs. It has no outputs.

    Args:
        view: View whose tables are checked.

    Returns:
        The pipeline itself; unlike the sibling factories it is not wrapped
        in a ``PipelineMeta``.
    """
    check_node = node(
        func=_check_tables,
        name="check_tables",
        inputs={
            "metadata": f"{view}.metadata",
            **{t: f"{view}.view.{t}" for t in view.tables},
        },
        outputs=None,
        namespace=f"{view}.view",
        tags=TAGS_VIEW_META,
    )
    return pipeline([check_node])
def create_meta_pipeline(view: View):
    """Create the pipeline that builds the view's metadata.

    The single ``create_metadata`` node uses ``_create_metadata`` as its
    function with ``args=[view.name]``, reads the ``parameters`` dataset and
    writes to ``{view}.metadata``. The tags are set on the pipeline rather
    than on the node.

    Args:
        view: View for which metadata is created.

    Returns:
        A ``PipelineMeta`` holding the pipeline and a ``"metadata"`` dataset
        descriptor (``type="pkl"``) for ``{view}.metadata``.
    """
    meta_node = node(
        func=_create_metadata,
        name="create_metadata",
        args=[view.name],
        inputs="parameters",
        outputs=f"{view}.metadata",
        namespace=f"{view}",
    )
    pipe = pipeline([meta_node], tags=TAGS_VIEW_META)

    datasets = [
        D("metadata", f"{view}.metadata", ["view", view, "metadata"], type="pkl")
    ]
    return PipelineMeta(pipe, datasets)
def _filter_keys(
    view: View,
    req_splits: list[str] | None,
    ratios: dict[str, float],
    random_state: int,
    keys: LazyFrame,
):
    """Delegate key splitting to ``view.split_keys``.

    The parameter order here differs from the one ``split_keys`` is called
    with: ``keys`` comes last in this signature but is passed first.

    Args:
        view: Object providing ``split_keys``.
        req_splits: Requested split names, or ``None``.
        ratios: Ratio per split name.
        random_state: Seed value forwarded unchanged.
        keys: Keys to split.

    Returns:
        Whatever ``view.split_keys(keys, req_splits, ratios, random_state)``
        returns.
    """
    result = view.split_keys(keys, req_splits, ratios, random_state)
    return result
def create_keys_pipeline(view: View, splits: list[str]):
    """Create the pipeline that splits the dataset keys for ``view``.

    The single ``split_keys`` node reads ``parameters`` (as ``params``) and
    ``{view.dataset}.keys`` (as ``keys``) and writes one output per split to
    ``{view}.keys.{s}``.

    Args:
        view: View whose ``split_keys`` method performs the split.
        splits: Names of the splits to produce.

    Returns:
        A ``PipelineMeta`` holding the pipeline and one ``"keys"`` dataset
        descriptor per split.
    """
    # NOTE(review): get_params_closure presumably pulls "ratios" and
    # "random_state" for this view out of the node's `params` input and
    # passes them on to _filter_keys -- confirm against .utils.
    fun = get_params_closure(
        partial(_filter_keys, view, splits),
        str(view),
        "ratios",
        "random_state",
    )

    split_node = node(
        func=fun,
        inputs={
            "params": "parameters",
            "keys": f"{view.dataset}.keys",
        },
        name="split_keys",
        namespace=f"{view}.keys",
        outputs={s: f"{view}.keys.{s}" for s in splits},
        tags=TAGS_VIEW_SPLIT,
    )
    pipe = pipeline([split_node])

    return PipelineMeta(
        pipe,
        [D("keys", f"{view}.keys.{s}", ["view", view, "keys", s]) for s in splits],
    )
def create_filter_pipeline(view: View, splits: list[str]):
    """Create the pipeline that filters every view table for every split.

    One node is built per (split, table) pair, iterating splits in the outer
    loop. Each node uses ``view.filter_table`` with ``args=[table]``, reads
    ``keys.{split}`` and ``view.{table}`` and writes ``{split}.{table}``.
    The nodes are then wrapped with ``modular_pipeline`` under the namespace
    ``view.name``.

    Args:
        view: View providing ``tables``, ``filter_table`` and ``name``.
        splits: Names of the splits to filter for.

    Returns:
        A ``PipelineMeta`` holding the namespaced pipeline and one
        ``"splits"`` dataset descriptor per (table, split) pair.
    """
    tables = view.tables

    # Dataset names here omit the view prefix. NOTE(review): the namespace
    # applied by modular_pipeline below is presumably what turns them into
    # the "{view}.…" names used by the descriptors -- confirm.
    nodes = [
        node(
            func=view.filter_table,
            args=[table],
            name=f"filter_{table}_{split}",
            inputs={
                "keys": f"keys.{split}",
                table: f"view.{table}",
            },
            outputs=f"{split}.{table}",
            namespace=split,
            tags=TAGS_VIEW_SPLIT,
        )
        for split in splits
        for table in tables
    ]

    return PipelineMeta(
        modular_pipeline(
            pipe=pipeline(nodes, tags=["view"]),
            namespace=view.name,
        ),
        [
            D("splits", f"{view}.{s}.{t}", ["view", view, s, "tables", t])
            for t in tables
            for s in splits
        ],
    )