# Source code for pasteur.kedro.hooks.pasteur

import logging
from os import path
from typing import Any, Callable

import yaml
from kedro.framework.context import KedroContext
from kedro.framework.hooks import hook_impl
from kedro.framework.project import pipelines
from kedro.io import DataCatalog, Version
from kedro.io.memory_dataset import MemoryDataset

from ...module import Module
from ..dataset import AutoDataset, PickleDataset, Multiset
from ..pipelines import generate_pipelines
from ..pipelines.main import NAME_LOCATION, get_view_names

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


[docs]class PasteurHook: def __init__( self, modules: list[Module] | Callable[[Any], list[Module]] | Callable[[], list[Module]], ) -> None: self.lazy_modules = modules self.modules = None self._param_hash = None self._module_id = None
[docs] def update_data(self): params = self.context.params _param_hash = hash(str(params)) _module_id = id(self.lazy_modules) if _param_hash == self._param_hash and _module_id == self._module_id: # SKip computation, params and modules are the same logger.debug("Using cached pipelines") return self._param_hash = _param_hash self._module_id = _module_id if callable(self.lazy_modules): try: self.modules = self.lazy_modules(params) # type: ignore except Exception: self.modules = self.lazy_modules() # type: ignore else: self.modules = self.lazy_modules ( self.pipelines, self.outputs, self.catalogs, self.parameters, ) = generate_pipelines(self.modules, params)
[docs] @hook_impl def after_context_created( self, context: KedroContext, ) -> None: self.raw_location = context.params["raw_location"] self.base_location = context.params["base_location"] self.context = context self.update_data() # FIXME: clean this up # Add pipelines pipelines._load_data() pipelines._content.update(self.pipelines) # Add view metadata for loaded modules extra_params = {} for name, view_params in self.parameters.items(): # dict gets added straight away if isinstance(view_params, dict): extra_params[name] = view_params # string is considered to point to a file else: extra_params[name] = context.config_loader.get(view_params).copy() # Add hidden dict with views to remove their params in mlflow assert self.modules extra_params["_views"] = get_view_names(self.modules) # Restore original overrides if context._extra_params: extra_params.update(context._extra_params) # Apply overrides context._extra_params = extra_params setattr(context, "pasteur", self)
[docs] def get_version(self, name: str, versioned: bool): load_version = ( self.load_versions.get(name, None) if self.load_versions else None ) if versioned: return Version(load_version, self.save_version) return None
[docs] def add_set(self, layer, name, path_seg, versioned=False, multi=False): fn = path.join( self.base_location, *path_seg[:-1], path_seg[-1], ) if multi: ds = Multiset( fn, { "type": AutoDataset, "save_args": self.pq_save_args, "version": self.get_version(name, versioned), "metadata": {"kedro-viz": {"layer": layer}} if layer else None, }, ) else: ds = AutoDataset( fn + ".pq", save_args=self.pq_save_args, version=self.get_version(name, versioned), # type: ignore metadata={"kedro-viz": {"layer": layer}} if layer else None, ) self.catalog.add( name, ds, ) if layer: self.catalog.layers[layer].add(name)
[docs] def add_pkl(self, layer, name, path_seg, versioned=False): self.catalog.add( name, PickleDataset( path.join( self.base_location, *path_seg[:-1], path_seg[-1] + ".pkl", ), version=self.get_version(name, versioned), # type: ignore metadata={"kedro-viz": {"layer": layer}} if layer else None, ), )
[docs] def add_mem(self, layer, name): self.catalog.add( name, MemoryDataset(metadata={"kedro-viz": {"layer": layer}} if layer else None), ) if layer: self.catalog.layers[layer].add(name)
[docs] @hook_impl def after_catalog_created( self, catalog: DataCatalog, conf_creds: dict[str, Any], save_version: str, load_versions: dict[str, str], ) -> None: # Parquet converts timestamps, but synthetic data can contain ns variations # which result in a loss of quality. This causes an exception. # By defining save args explicitly that exception is ignored. self.pq_save_args = { "coerce_timestamps": "us", "allow_truncated_timestamps": True, } self.catalog = catalog self.save_version = save_version self.load_versions = load_versions if catalog.layers is None: from collections import defaultdict catalog.layers = defaultdict(set) # Add raw datasets from packaged datasets # Just replace `${<folder_name>_location}` with raw/<folder_name> or that parameter if self.catalogs: params = self.context.params for ds, folder_name, ds_catalog in self.catalogs: name = NAME_LOCATION.format(folder_name) if isinstance(ds_catalog, str): with open(ds_catalog, "r") as f: data = f.read() if folder_name: raw_dir = params.get( name, path.join(self.raw_location, folder_name) ) data = data.replace(f"${{location}}", raw_dir) data = data.replace( f"${{bootstrap}}", path.join(self.base_location, "bootstrap", folder_name), ) conf = yaml.safe_load(data) else: conf = ds_catalog # Normalize old catalog names to be '{ds}.raw@{name}' unless # they are already that # TODO: find clear criteria for when to do it conf = { f"{ds}.raw@{name}" if "." 
not in name else name: dataset for name, dataset in conf.items() } # Place all datasets to the raw layer for d in conf.values(): if "metadata" not in d: d["metadata"] = {"kedro-viz": {"layer": "raw"}} elif "kedro-viz" not in d["metadata"]: d["metadata"]["kedro-viz"] = {"layer": "raw"} elif "layer" not in d["metadata"]["kedro-viz"]: d["metadata"]["kedro-viz"]["layer"] = "raw" tmp_catalog = DataCatalog.from_config( conf, conf_creds, load_versions, save_version, ) catalog.add_all(tmp_catalog._data_sets) if tmp_catalog.layers: # Passthrough layers if they are not provided through metadata for layer, children in tmp_catalog.layers.items(): catalog.layers[layer] = { *children, *catalog.layers.get(layer, set()), } # Add pipeline outputs for d in self.outputs: match d.type: case "pkl": self.add_pkl(d.layer, d.name, d.str_path, d.versioned) case "pq": self.add_set(d.layer, d.name, d.str_path, d.versioned) case "mpq": self.add_set(d.layer, d.name, d.str_path, d.versioned, multi=True) case "auto": self.add_set(d.layer, d.name, d.str_path, d.versioned) case "multi": self.add_set(d.layer, d.name, d.str_path, d.versioned, multi=True) case "mem": self.add_mem(d.layer, d.name) case _: assert False, "Not implemented"