# Source code for pasteur.kedro.pipelines.transform
from __future__ import annotations
import pandas as pd
from kedro.pipeline import Pipeline as pipeline
from ...encode import AttributeEncoderFactory, EncoderFactory, Encoder
from ...metadata import Metadata
from ...module import Module, get_module_dict
from ...table import AttributeEncoderHolder, TableTransformer
from ...utils import LazyFrame
from ...view import View
from .meta import TAGS_RETRANSFORM, TAGS_REVERSE, TAGS_TRANSFORM
from .meta import DatasetMeta as D
from .meta import PipelineMeta, node
def _fit_transformers(
    name: str,
    modules: list[Module],
    meta: Metadata,
    tables: dict[str, LazyFrame],
):
    """Build a ``TableTransformer`` for table ``name`` and fit it on ``tables``.

    Args:
        name: Table name forwarded to the ``TableTransformer`` constructor.
        modules: Modules forwarded to the ``TableTransformer`` constructor.
        meta: Metadata forwarded to the ``TableTransformer`` constructor.
        tables: Tables handed to ``TableTransformer.fit``.

    Returns:
        The ``TableTransformer`` instance after ``fit`` has been called on it.
    """
    # The class is also imported at module level; the local import is kept
    # unchanged from the original implementation.
    from ...table import TableTransformer

    transformer = TableTransformer(meta, name, modules)
    transformer.fit(tables)
    return transformer
def _transform_table(
    transformer: TableTransformer,
    tables: dict[str, LazyFrame],
):
    """Apply ``transformer.transform`` to ``tables`` and return its result.

    Args:
        transformer: Transformer whose ``transform`` method is invoked.
        tables: Tables passed through to ``transform`` unchanged.

    Returns:
        Whatever ``transformer.transform(tables)`` returns.
    """
    transformed = transformer.transform(tables)
    return transformed
def _reverse_table(
    transformer: TableTransformer,
    table: LazyFrame,
    ctx: dict[str, LazyFrame],
    ids: LazyFrame,
    parents: dict[str, LazyFrame],
):
    """Delegate to ``transformer.reverse`` and return its result.

    Args:
        transformer: Transformer whose ``reverse`` method is invoked.
        table: First positional argument of ``reverse``.
        ctx: Second positional argument of ``reverse``.
        ids: Third positional argument of ``reverse``.
        parents: Fourth positional argument of ``reverse``.

    Returns:
        Whatever ``transformer.reverse(table, ctx, ids, parents)`` returns.
    """
    reversed_table = transformer.reverse(table, ctx, ids, parents)
    return reversed_table
def _fit_encoder(
    enc: str,
    modules: list[Module],
    trns: dict[str, TableTransformer],
    tables: dict[str, LazyFrame],
    ctx: dict[str, dict[str, LazyFrame]],
    ids: dict[str, LazyFrame],
):
    """Create the encoder registered for encoding ``enc`` and fit it.

    The encoding is looked up among the ``EncoderFactory`` and
    ``AttributeEncoderFactory`` modules found in ``modules``. An
    ``EncoderFactory`` is built directly; an ``AttributeEncoderFactory`` is
    wrapped in an ``AttributeEncoderHolder``.

    Args:
        enc: Name of the encoding to look up.
        modules: Modules searched with ``get_module_dict``.
        trns: Transformers keyed by table name; ``get_attributes()`` is called
            on each one.
        tables: Forwarded to the encoder's ``fit``.
        ctx: Forwarded to the encoder's ``fit``.
        ids: Forwarded to the encoder's ``fit``.

    Returns:
        The encoder after ``fit`` has been called on it.

    Raises:
        AssertionError: If ``enc`` is provided by both factory kinds, or by
            neither of them.
    """
    # Get encoder factories
    attr_encs = get_module_dict(AttributeEncoderFactory, modules)
    encs = get_module_dict(EncoderFactory, modules)

    has_attr_enc = enc in attr_encs
    has_enc = enc in encs

    # Raised explicitly instead of using `assert`, so the validation is not
    # stripped when Python runs with -O. AssertionError is kept so the
    # exception type seen by callers does not change.
    if has_attr_enc and has_enc:
        raise AssertionError(
            f"Encoding '{enc}' is provided as both an Attribute Encoder and Encoder. Choose one."
        )
    if not (has_attr_enc or has_enc):
        raise AssertionError(f"Encoder for encoding '{enc}' not found.")

    # Get attrs
    attrs = {}
    ctx_attrs = {}
    for name, trn in trns.items():
        table_attrs, table_ctx_attrs = trn.get_attributes()
        attrs[name] = table_attrs
        ctx_attrs[name] = table_ctx_attrs

    # Create and fit
    if has_enc:
        e = encs[enc].build()
    else:
        e = AttributeEncoderHolder(attr_encs[enc])
    e.fit(attrs, tables, ctx_attrs, ctx, ids)
    return e
def _encode_view(
    encoder: Encoder,
    tables: dict[str, LazyFrame],
    ctx: dict[str, dict[str, LazyFrame]],
    ids: dict[str, LazyFrame],
):
    """Delegate to ``encoder.encode`` and return its result.

    Args:
        encoder: Encoder whose ``encode`` method is invoked.
        tables: First positional argument of ``encode``.
        ctx: Second positional argument of ``encode``.
        ids: Third positional argument of ``encode``.

    Returns:
        Whatever ``encoder.encode(tables, ctx, ids)`` returns.
    """
    encoded = encoder.encode(tables, ctx, ids)
    return encoded
def _decode_view(encoder: Encoder, data: dict[str, LazyFrame]):
    """Delegate to ``encoder.decode`` and return its result.

    Args:
        encoder: Encoder whose ``decode`` method is invoked.
        data: Passed through to ``decode`` unchanged.

    Returns:
        Whatever ``encoder.decode(data)`` returns.
    """
    decoded = encoder.decode(data)
    return decoded
# [docs]
def create_fit_pipeline(
    view: View,
    encs: list[str],
    modules: list[Module],
    split: str,
):
    """Assemble the pipeline that fits transformers and encoders for ``view``.

    One ``_fit_transformers`` node is created per table in ``view.tables``,
    followed by one ``_fit_encoder`` node per entry of ``encs`` other than
    ``"raw"`` and ``"bst"``.

    Args:
        view: View whose tables are used; it is also interpolated into every
            dataset name.
        encs: Encoding names; ``"raw"`` and ``"bst"`` get no encoder node.
        modules: Modules bound as an argument to every node function.
        split: Name of the split whose datasets the nodes read from.

    Returns:
        A ``PipelineMeta`` holding the pipeline (tagged ``TAGS_TRANSFORM``)
        and the ``type="pkl"`` dataset declarations for the fitted
        transformers and encoders.
    """
    table_names = list(view.tables)
    # "raw" and "bst" are excluded from encoder fitting.
    fitted_encs = [e for e in encs if e not in ("raw", "bst")]

    def per_table(prefix: str) -> dict[str, str]:
        # Map every table name to "<prefix><table>".
        return {t: f"{prefix}{t}" for t in table_names}

    nodes = []
    for t in table_names:
        nodes.append(
            node(
                name=f"fit_transformers_to_{t}",
                func=_fit_transformers,
                args=[t, modules],
                inputs={
                    "meta": f"{view}.metadata",
                    "tables": per_table(f"{view}.{split}."),
                },
                outputs=f"{view}.trn.{t}",
                namespace=f"{view}.trn",
            )
        )

    for enc in fitted_encs:
        nodes.append(
            node(
                name=f"fit_{enc}_encoder",
                func=_fit_encoder,
                args=[enc, modules],
                inputs={
                    "trns": per_table(f"{view}.trn."),
                    "tables": per_table(f"{view}.{split}.bst_"),
                    "ctx": per_table(f"{view}.{split}.ctx_"),
                    "ids": per_table(f"{view}.{split}.ids_"),
                },
                outputs=f"{view}.enc.{enc}",
                namespace=f"{view}.enc",
            )
        )

    pipe = pipeline(nodes, tags=TAGS_TRANSFORM)

    datasets = [
        D("transformers", f"{view}.trn.{t}", ["view", view, "trn", t], type="pkl")
        for t in table_names
    ]
    datasets += [
        D("encoders", f"{view}.enc.{enc}", ["view", view, "enc", enc], type="pkl")
        for enc in fitted_encs
    ]
    return PipelineMeta(pipe, datasets)
# [docs]
def create_transform_pipeline(
    view: View,
    split: str,
    types: list[str],
    retransform: bool = False,
):
    """Assemble the pipeline that transforms and encodes one split of ``view``.

    Unless ``retransform`` is set, one ``_transform_table`` node is created
    per table together with dataset declarations for its ``bst_``, ``ctx_``
    and ``ids_`` outputs. One ``_encode_view`` node is then created for every
    entry of ``types`` other than ``"bst"`` and ``"raw"``.

    Args:
        view: View whose tables are used; it is also interpolated into every
            dataset name.
        split: Split name interpolated into dataset names and namespaces.
        types: Encoding names; ``"bst"`` and ``"raw"`` get no encode node.
        retransform: When true, the per-table transform nodes and their
            datasets are skipped, encoded datasets are declared as versioned
            under ``"synth"``, and the pipeline is tagged
            ``TAGS_RETRANSFORM`` instead of ``TAGS_TRANSFORM``.

    Returns:
        A ``PipelineMeta`` holding the pipeline and the dataset declarations.
        If no nodes were created, the pipeline is empty and untagged.
    """
    nodes = []
    datasets = []

    def per_table(kind: str) -> dict[str, str]:
        # Map every table name to "<view>.<split>.<kind><table>".
        return {t: f"{view}.{split}.{kind}{t}" for t in view.tables}

    if not retransform:
        transformed_layer = (
            "view_transformed" if split == "view" else "split_transformed"
        )
        for t in view.tables:
            nodes.append(
                node(
                    func=_transform_table,
                    name=f"transform_{t}_for_{split}",
                    inputs={
                        "transformer": f"{view}.trn.{t}",
                        "tables": per_table(""),
                    },
                    outputs=[
                        f"{view}.{split}.bst_{t}",
                        f"{view}.{split}.ctx_{t}",
                        f"{view}.{split}.ids_{t}",
                    ],
                    namespace=f"{view}.{split}",
                )
            )
            datasets.extend(
                [
                    D(
                        transformed_layer,
                        f"{view}.{split}.ctx_{t}",
                        ["view", view, split, "ctx", t],
                        type="multi",
                    ),
                    D(
                        transformed_layer,
                        f"{view}.{split}.bst_{t}",
                        ["view", view, split, "bst", t],
                    ),
                    D(
                        transformed_layer,
                        f"{view}.{split}.ids_{t}",
                        ["view", view, split, "ids", t],
                    ),
                ]
            )

    # FIXME: Pass proper layer properly, don't infer
    if retransform:
        encoded_layer = "synth_reencoded"
    elif split == "view":
        encoded_layer = "view_encoded"
    else:
        encoded_layer = "split_encoded"
    root = "synth" if retransform else "view"

    for enc in types:
        if enc in ("bst", "raw"):
            continue

        nodes.append(
            node(
                func=_encode_view,
                name=f"encode_{enc}",
                inputs={
                    "encoder": f"{view}.enc.{enc}",
                    "tables": per_table("bst_"),
                    "ctx": per_table("ctx_"),
                    "ids": per_table("ids_"),
                },
                outputs=f"{view}.{split}.{enc}",
                namespace=f"{view}.{split}",
            )
        )
        datasets.append(
            D(
                encoded_layer,
                f"{view}.{split}.{enc}",
                [root, view, split, enc],
                versioned=retransform,
                type="multi",
            )
        )

    if not nodes:
        return PipelineMeta(pipeline([]), datasets)

    tags = TAGS_RETRANSFORM if retransform else TAGS_TRANSFORM
    return PipelineMeta(pipeline(nodes, tags=tags), datasets)
# [docs]
def create_reverse_pipeline(view: View, alg: str, enc: str):
    """Assemble the pipeline that decodes and reverses synthetic data.

    A single ``_decode_view`` node reads the encoder ``{view}.enc.{enc}`` and
    the dataset ``{view}.{alg}.enc`` and writes per-table ``bst_``, ``ctx_``
    and ``ids_`` datasets. One ``_reverse_table`` node per table then turns
    those into ``{view}.{alg}.{table}``.

    Args:
        view: View whose ``tables`` and ``trn_deps`` are used; it is also
            interpolated into every dataset name.
        alg: Name interpolated into dataset names and the node namespace.
        enc: Name of the encoding whose fitted encoder is used for decoding.

    Returns:
        A ``PipelineMeta`` holding the pipeline (tagged ``TAGS_REVERSE``) and
        versioned dataset declarations for the decoded and reversed tables.
    """

    def per_table(kind: str) -> dict[str, str]:
        # Map every table name to "<view>.<alg>.<kind>_<table>".
        return {t: f"{view}.{alg}.{kind}_{t}" for t in view.tables}

    nodes = [
        node(
            func=_decode_view,
            name="decode_synthetic_data",
            inputs={
                "encoder": f"{view}.enc.{enc}",
                "data": f"{view}.{alg}.enc",
            },
            outputs=[per_table("bst"), per_table("ctx"), per_table("ids")],
            namespace=f"{view}.{alg}",
        ),
    ]
    datasets = []

    for t in view.tables:
        # Tables listed in view.trn_deps for this table are passed as parents.
        parents = {req: f"{view}.{alg}.{req}" for req in view.trn_deps.get(t, [])}

        nodes.append(
            node(
                func=_reverse_table,
                name=f"reverse_{t}",
                inputs={
                    "transformer": f"{view}.trn.{t}",
                    "table": f"{view}.{alg}.bst_{t}",
                    "ctx": f"{view}.{alg}.ctx_{t}",
                    "ids": f"{view}.{alg}.ids_{t}",
                    "parents": parents,
                },
                outputs=f"{view}.{alg}.{t}",
                namespace=f"{view}.{alg}",
            )
        )

        datasets += [
            D(
                "synth_decoded",
                f"{view}.{alg}.bst_{t}",
                ["synth", view, alg, "bst", t],
                versioned=True,
            ),
            D(
                "synth_decoded",
                f"{view}.{alg}.ids_{t}",
                ["synth", view, alg, "ids", t],
                versioned=True,
            ),
            D(
                "synth_decoded",
                f"{view}.{alg}.ctx_{t}",
                ["synth", view, alg, "ctx", t],
                versioned=True,
                type="multi",
            ),
            D(
                "synth_reversed",
                f"{view}.{alg}.{t}",
                ["synth", view, alg, "tables", t],
                versioned=True,
            ),
        ]

    return PipelineMeta(pipeline(nodes, tags=TAGS_REVERSE), datasets)