# Source code for pasteur.kedro.hooks.pasteur

import logging
from os import path
from typing import Any, Callable

import yaml
from kedro.config.abstract_config import MissingConfigException
from kedro.framework.context import KedroContext
from kedro.framework.hooks import hook_impl
from kedro.framework.project import pipelines
from kedro.io import DataCatalog, Version
from kedro.io.memory_dataset import MemoryDataset
from kedro_datasets.json import JSONDataset

from ...module import Module
from ..dataset import AutoDataset, Multiset, PickleDataset
from ..pipelines import generate_pipelines
from ..pipelines.main import NAME_LOCATION, get_view_names

logger = logging.getLogger(__name__)


def _load_config(fn: str):
    """Load a YAML mapping from ``fn`` and drop its underscore-prefixed keys.

    Args:
        fn: Path of the YAML file to read (decoded as UTF-8).

    Returns:
        A new dict with every top-level entry of the file whose key does
        not start with ``"_"``.

    Raises:
        AssertionError: If the file does not contain a mapping at its top
            level (this includes an empty file, which parses to ``None``).
    """
    import yaml

    with open(fn, encoding="utf8") as yml:
        d = yaml.safe_load(yml)

    # Raised explicitly instead of using `assert`, so the check also runs
    # under `python -O`; the exception type is unchanged.
    if not isinstance(d, dict):
        raise AssertionError(f"Could not load config file: '{fn}'")

    # YAML keys are not necessarily strings (e.g. `1:` or `true:`); only
    # string keys can carry the "_" prefix, everything else is kept.
    return {
        k: v
        for k, v in d.items()
        if not (isinstance(k, str) and k.startswith("_"))
    }


[docs] class PasteurHook: def __init__( self, modules: ( list[Module] | Callable[[Any], list[Module]] | Callable[[], list[Module]] ), ) -> None: self.lazy_modules = modules self.modules = None self._param_hash = None self._module_id = None self.load_any = False
[docs] def update_data(self): params = self.context.params _param_hash = hash(str(params)) _module_id = id(self.lazy_modules) if _param_hash == self._param_hash and _module_id == self._module_id: # SKip computation, params and modules are the same logger.debug("Using cached pipelines") return self._param_hash = _param_hash self._module_id = _module_id if callable(self.lazy_modules): try: self.modules = self.lazy_modules(params) # type: ignore except Exception: self.modules = self.lazy_modules() # type: ignore else: self.modules = self.lazy_modules ( self.pipelines, self.outputs, self.catalogs, self.parameters, ) = generate_pipelines(self.modules, params, self.locations)
# Has to be first to add location hook. # FIXME: remove try_first
[docs] @hook_impl(tryfirst=True) def after_context_created( self, context: KedroContext, ) -> None: try: # Try to use location configs for locations patterns = getattr(context.config_loader, "config_patterns", {}) if "locations" not in patterns: patterns["locations"] = ["location*", "location*/**", "**/location*"] locations = context.config_loader.get("locations") # Allow without duplicate errors if "hidden_raw" in locations: locations["raw"] = locations.pop("hidden_raw") if "hidden_base" in locations: locations["base"] = locations.pop("hidden_base") except MissingConfigException: locations = {} logger.warn( f"Consider using a 'locations.yml' file in the future. Using paths from params." ) def location_resolver(loc: str, default=None): if "_location" in loc: logger.warn( "Found '_location' in location name. Not required in locations.yml file." ) dir = locations.get(loc, default) if not dir: logger.warn( f"Location '{loc}' not found in 'locations.yml'. Falling back to `parameters.yml`." ) dir = context.params.get( loc + "_location", context.params.get(loc, None) ) assert dir, f"Dir '{loc}' not found." 
return context.project_path / dir # Try to register resolver with OmegaConfigLoader if hasattr(context.config_loader, "_register_new_resolvers"): getattr(context.config_loader, "_register_new_resolvers")( {"location": location_resolver} ) self.raw_location = location_resolver("raw") self.base_location = location_resolver("base") self.locations = {k: location_resolver(k) for k in [*locations, "raw", "base"]} self.context = context self.update_data() # FIXME: clean this up # Add pipelines pipelines._load_data() pipelines._content.update(self.pipelines) # Add view metadata for loaded modules runtime_params = {} for name, view_params in self.parameters.items(): # dict gets added straight away if isinstance(view_params, dict): runtime_params[name] = view_params # string is considered to point to a file else: runtime_params[name] = _load_config(view_params).copy() # Add hidden dict with views to remove their params in mlflow assert self.modules runtime_params["_views"] = get_view_names(self.modules) # FIXME: check if this is needed # Restore original overrides if context._runtime_params: runtime_params.update(context._runtime_params) # Apply overrides context._runtime_params = runtime_params setattr(context, "pasteur", self) # Save context variable so it is available globally import pasteur.kedro as kedro_init kedro_init.context = context
[docs] def get_version(self, name: str, versioned: bool): load_version = self.save_version if self.load_any: load_version = None elif self.load_versions: load_version = self.load_versions.get(name, load_version) if versioned: return Version(load_version, self.save_version) return None
[docs] def add_set(self, layer, name, path_seg, versioned=False, multi=False): fn = path.join( self.base_location, *path_seg[:-1], path_seg[-1], ) if multi: ds = Multiset( fn, { "type": AutoDataset, "save_args": self.pq_save_args, "metadata": {"kedro-viz": {"layer": layer}} if layer else None, }, version=self.get_version(name, versioned), ) else: ds = AutoDataset( fn + ".pq", save_args=self.pq_save_args, version=self.get_version(name, versioned), # type: ignore metadata={"kedro-viz": {"layer": layer}} if layer else None, ) self.catalog[name] = ds
# if layer: # self.catalog.layers[layer].add(name)
[docs] def add_pkl(self, layer, name, path_seg, versioned=False): self.catalog[name] = PickleDataset( path.join( self.base_location, *path_seg[:-1], path_seg[-1] + ".pkl", ), version=self.get_version(name, versioned), # type: ignore metadata={"kedro-viz": {"layer": layer}} if layer else None, )
[docs] def add_json(self, layer, name, path_seg, versioned=False): self.catalog[name] = JSONDataset( filepath=path.join( self.base_location, *path_seg[:-1], path_seg[-1] + ".json", ), version=self.get_version(name, versioned), # type: ignore metadata={"kedro-viz": {"layer": layer}} if layer else None, )
[docs] def add_mem(self, layer, name): self.catalog[name] = (MemoryDataset(metadata={"kedro-viz": {"layer": layer}} if layer else None),) # type: ignore
# if layer: # self.catalog.layers[layer].add(name)
[docs] @hook_impl def after_catalog_created( self, catalog: DataCatalog, conf_creds: dict[str, Any], save_version: str, load_versions: dict[str, str], ) -> None: # Parquet converts timestamps, but synthetic data can contain ns variations # which result in a loss of quality. This causes an exception. # By defining save args explicitly that exception is ignored. self.pq_save_args = { "coerce_timestamps": "us", "allow_truncated_timestamps": True, } self.catalog = catalog self.save_version = save_version self.load_versions = load_versions # if catalog.layers is None: # from collections import defaultdict # catalog.layers = defaultdict(set) # Add raw datasets from packaged datasets # Just replace `${<folder_name>_location}` with raw/<folder_name> or that parameter if self.catalogs: params = self.context.params for ds, folder_name, ds_catalog in self.catalogs: name = NAME_LOCATION.format(folder_name) if isinstance(ds_catalog, str): with open(ds_catalog, "r") as f: data = f.read() conf = yaml.safe_load(data) else: conf = ds_catalog if folder_name: def replace_fn(s: str): return s.replace( "${location}", params.get( name, path.join(self.raw_location, folder_name), ), ).replace( "${bootstrap}", path.join( self.base_location, "bootstrap", folder_name, ), ) for val in conf.values(): for k, v in val.items(): if isinstance(v, str): val[k] = replace_fn(v) elif isinstance(v, list): val[k] = [replace_fn(x) for x in v] # Normalize old catalog names to be '{ds}.raw@{name}' unless # they are already that # TODO: find clear criteria for when to do it conf = { f"{ds}.raw@{name}" if "." 
not in name else name: dataset for name, dataset in conf.items() } tmp_catalog = DataCatalog.from_config( conf, conf_creds, load_versions, save_version, ) # Add all traditional layers that exist for k, v in tmp_catalog.items(): catalog[k] = v depr_tag = set() # if hasattr(tmp_catalog, "layers"): # # Passthrough layers if they are not provided through metadata # cl = getattr(tmp_catalog, "layers") # for layer, children in cl.items(): # cl[layer].update(children) # depr_tag.update(children) # Skip constructor and set metadata attribute on datasets with # a raw layer. Datasets without a metadata key word argument crash otherwise. for n, d in tmp_catalog._datasets.items(): # Datasets with layer attribute are skipped if n in depr_tag: continue if not hasattr(d, "metadata") or getattr(d, "metadata") is None: setattr(d, "metadata", {"kedro-viz": {"layer": "raw"}}) elif "kedro-viz" not in getattr(d, "metadata"): getattr(d, "metadata")["kedro-viz"] = {"layer": "raw"} elif "layer" not in getattr(d, "metadata")["kedro-viz"]: getattr(d, "metadata")["kedro-viz"]["layer"] = "raw" # Add pipeline outputs for d in self.outputs: match d.type: case "pkl": self.add_pkl(d.layer, d.name, d.str_path, d.versioned) case "pq": self.add_set(d.layer, d.name, d.str_path, d.versioned) case "mpq": self.add_set(d.layer, d.name, d.str_path, d.versioned, multi=True) case "auto": self.add_set(d.layer, d.name, d.str_path, d.versioned) case "multi": self.add_set(d.layer, d.name, d.str_path, d.versioned, multi=True) case "mem": self.add_mem(d.layer, d.name) case "json": self.add_json(d.layer, d.name, d.str_path, d.versioned) case _: assert False, "Not implemented"