Source code for pasteur.kedro.pipelines.synth
from __future__ import annotations
from typing import TYPE_CHECKING
from kedro.pipeline import Pipeline as pipeline
from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline
from ...synth import synth_fit, synth_sample
from .meta import DatasetMeta as D
from .meta import PipelineMeta, node, TAGS_SYNTH, TAG_GPU
from .utils import gen_closure
if TYPE_CHECKING:
from ...synth import SynthFactory
from ...view import View
[docs]def create_synth_pipeline(
view: View,
split: str,
fr: SynthFactory,
):
tags: list[str] = list(TAGS_SYNTH)
pipe = pipeline(
[
node(
func=synth_fit,
name=f"fitting_{fr.name}",
args=[fr],
inputs={
"metadata": f"{view}.metadata",
"encoder": f"{view}.enc.{fr.type}",
"data": f"{view}.{split}.{fr.type}",
},
namespace=f"{view}.{fr.name}",
outputs=f"{view}.{fr.name}.model",
tags=tags,
),
node(
func=synth_sample,
inputs=f"{view}.{fr.name}.model",
outputs=f"{view}.{fr.name}.enc",
namespace=f"{view}.{fr.name}",
tags=tags,
),
]
)
outputs = [
D(
"synth_models",
f"{view}.{fr.name}.model",
["synth", view, fr.name, 'model'],
versioned=True,
type="pkl",
),
D(
"synth_output",
f"{view}.{fr.name}.enc",
["synth", view, fr.name, 'enc'],
versioned=True,
type="multi",
),
]
return PipelineMeta(pipe, outputs)