# Source code for pasteur.kedro.pipelines.synth

from __future__ import annotations

from typing import TYPE_CHECKING

from kedro.pipeline import Pipeline as pipeline

from ...synth import synth_fit, synth_sample
from .meta import DatasetMeta as D
from .meta import PipelineMeta, node, TAGS_SYNTH, TAG_GPU
from .utils import gen_closure

if TYPE_CHECKING:
    from ...synth import SynthFactory
    from ...view import View


[docs] def create_synth_pipeline( view: View, split: str, fr: SynthFactory, ): tags: list[str] = list(TAGS_SYNTH) if fr.in_types is not None: assert fr.type in fr.in_types, f"in_types must include type for '{fr.name}'" data_in = ( f"{view}.{split}.{fr.type}" if fr.in_types is None else {t: f"{view}.{split}.{t}" for t in fr.in_types} ) synth_in = {"data": data_in} if fr.in_sample else {} pipe = pipeline( [ node( func=synth_fit, name=f"fitting_{fr.name}", args=[fr], inputs={ "metadata": f"{view}.metadata", "encoder": ( f"{view}.enc.{fr.type}" if fr.in_types is None else {t: f"{view}.enc.{t}" for t in fr.in_types} ), "data": data_in, # type: ignore }, namespace=f"{view}.{fr.name}", outputs=f"{view}.{fr.name}.model", tags=tags, ), node( func=synth_sample, inputs={ "s": f"{view}.{fr.name}.model", **synth_in, # type: ignore }, outputs=f"{view}.{fr.name}.enc", namespace=f"{view}.{fr.name}", tags=tags, ), ] ) outputs = [ D( "synth_models", f"{view}.{fr.name}.model", ["synth", view, fr.name, "model"], versioned=True, type="pkl", ), D( "synth_output", f"{view}.{fr.name}.enc", ["synth", view, fr.name, "enc"], versioned=True, type="multi", ), ] return PipelineMeta(pipe, outputs)