from __future__ import annotations
from typing import TYPE_CHECKING
from kedro.pipeline import Pipeline as pipeline
from ...synth import synth_fit, synth_sample
from .meta import DatasetMeta as D
from .meta import PipelineMeta, node, TAGS_SYNTH, TAG_GPU
from .utils import gen_closure
if TYPE_CHECKING:
from ...synth import SynthFactory
from ...view import View
[docs]
def create_synth_pipeline(
    view: View,
    split: str,
    fr: SynthFactory,
):
    """Assemble the fit-then-sample pipeline for one synthesizer factory.

    Two nodes are created, both under the namespace ``"{view}.{fr.name}"``
    and both tagged with a copy of ``TAGS_SYNTH``:

    1. ``synth_fit`` (named ``"fitting_{fr.name}"``), called with ``fr`` as
       an extra argument and the ``metadata``, ``encoder`` and ``data``
       inputs of the view; it writes ``"{view}.{fr.name}.model"``.
    2. ``synth_sample``, which receives that model as ``s`` (plus ``data``
       when ``fr.in_sample`` is truthy) and writes ``"{view}.{fr.name}.enc"``.

    When ``fr.in_types`` is ``None`` the ``data`` and ``encoder`` inputs are
    single dataset names built from ``fr.type``; otherwise each is a dict
    keyed by every entry of ``fr.in_types``.

    Args:
        view: View whose string form prefixes every dataset name.
        split: Name of the data split used as fitting input.
        fr: Factory supplying ``name``, ``type``, ``in_types`` and
            ``in_sample``.

    Returns:
        A ``PipelineMeta`` wrapping the pipeline together with the dataset
        descriptors of the model and the sampled output.

    Raises:
        AssertionError: If ``fr.in_types`` is given but lacks ``fr.type``.
    """
    tags: list[str] = list(TAGS_SYNTH)

    single_type = fr.in_types is None
    if not single_type:
        assert fr.type in fr.in_types, f"in_types must include type for '{fr.name}'"

    # Resolve the input dataset names: one string for a single-type
    # synthesizer, one mapping per input type otherwise.
    if single_type:
        data_in = f"{view}.{split}.{fr.type}"
        encoder_in = f"{view}.enc.{fr.type}"
    else:
        data_in = {t: f"{view}.{split}.{t}" for t in fr.in_types}
        encoder_in = {t: f"{view}.enc.{t}" for t in fr.in_types}

    namespace = f"{view}.{fr.name}"
    model_key = f"{namespace}.model"
    sample_key = f"{namespace}.enc"

    fit_node = node(
        func=synth_fit,
        name=f"fitting_{fr.name}",
        args=[fr],
        inputs={
            "metadata": f"{view}.metadata",
            "encoder": encoder_in,
            "data": data_in,  # type: ignore
        },
        namespace=namespace,
        outputs=model_key,
        tags=tags,
    )

    # Sampling always consumes the fitted model; the input data is only
    # forwarded when the factory asks for it.
    sample_inputs = {"s": model_key}
    if fr.in_sample:
        sample_inputs["data"] = data_in  # type: ignore

    sample_node = node(
        func=synth_sample,
        inputs=sample_inputs,
        outputs=sample_key,
        namespace=namespace,
        tags=tags,
    )

    pipe = pipeline([fit_node, sample_node])

    # (descriptor label, dataset leaf name, storage type) per output.
    output_specs = (
        ("synth_models", "model", "pkl"),
        ("synth_output", "enc", "multi"),
    )
    outputs = [
        D(
            label,
            f"{namespace}.{leaf}",
            ["synth", view, fr.name, leaf],
            versioned=True,
            type=kind,
        )
        for label, leaf, kind in output_specs
    ]

    return PipelineMeta(pipe, outputs)