Source code for pasteur.table

"""Contains the logic for handling multiple tables, and holding transformers and
encoders.

The functionality is achieved through a class named `ReferenceManager`, which
is used to generate the ID tables, and the `TransformHolder`, which holds
everything required to encode and transform a table.

The `TransformHolder` holds a `TableTransformer` which hosts the Table's transformers
and multiple `TableEncoder`s, which can be accessed with array syntax (ex. `['idx']`),
one for each supported encoding.

Once the TransformHolder is fit, it can be loaded and used to transform, encode,
reverse, and decode table partitions."""

import logging
from collections import defaultdict
from typing import Any, Callable, Generic, Literal, Mapping, TypeVar, cast

import numpy as np
import pandas as pd
from pandas.api.types import is_float_dtype

from pasteur.attribute import (
    Attribute,
    Attributes,
    SeqAttribute,
    SeqValue,
    get_dtype,
)
from pasteur.module import Module, ModuleFactory, get_module_dict
from pasteur.transform import SeqTransformer, Transformer, TransformerFactory

from .attribute import Attribute, Attributes
from .encode import (
    AttributeEncoder,
    AttributeEncoderFactory,
    EncoderFactory,
    PostprocessEncoder,
    ViewEncoder,
)
from .metadata import ColumnRef, Metadata
from .module import Module, get_module_dict
from .transform import RefTransformer, SeqTransformer, Transformer, TransformerFactory
from .utils import LazyChunk, LazyFrame, LazyPartition, lazy_load_tables, to_chunked, get_relationships
from .utils.progress import process_in_parallel, reduce

# Module-level logger. It is named after the module (`__name__`) rather than
# the file path, so it takes part in the package's logger hierarchy and picks
# up configuration applied to parent loggers.
logger = logging.getLogger(__name__)

# Type of the objects stored in the dictionaries combined by `_reduce_inner`.
A = TypeVar("A", bound="Any")
# Type parameter of `AttributeEncoderHolder` (it is `Generic[META]`).
META = TypeVar("META")



def _reduce_inner(
    a: dict[str | tuple[str, ...], A],
    b: dict[str | tuple[str, ...], A],
):
    for key in a.keys():
        a[key].reduce(b[key])
    return a


class ReferenceManager:
    """Manages the foreign relationships of a table."""

    def __init__(
        self,
        meta: Metadata,
        name: str,
    ) -> None:
        self.name = name
        self.meta = meta

    def find_parents(self, table: str) -> list[tuple[str, str, str]]:
        """List the id columns of `table` that reference a parent table.

        Each entry is a tuple `(column, parent table, parent key column)`.
        When the reference does not name a column, the parent's primary key
        is used as the key column.
        """
        parents = []
        for name, col in self.meta[table].cols.items():
            if not col.is_id():
                continue
            assert isinstance(name, str), "ids only require one column"

            if not col.ref:
                continue
            assert not isinstance(col.ref, list), "ids only have 1 reference"

            ref_table = col.ref.table
            assert ref_table
            ref_col = col.ref.col or self.meta[ref_table].primary_key
            parents.append((name, ref_table, ref_col))

        return parents

    def get_id_cols(self, name: str, ref: bool = False):
        """Return the names of the id columns of table `name`.

        The primary key is never included. With `ref=True`, only id columns
        that carry a reference are returned.
        """
        meta = self.meta[name]

        id_cols = []
        for col_name, col in meta.cols.items():
            if not col.is_id() or col_name == meta.primary_key:
                continue
            if ref and not meta[col_name].ref:
                continue
            id_cols.append(col_name)
        return id_cols

    def get_table_data(self, name: str, table: pd.DataFrame):
        """Return `table` without its (non primary key) id columns."""
        id_cols = self.get_id_cols(name, False)
        return table.drop(columns=id_cols)

    def get_foreign_keys(self, name: str, table: pd.DataFrame):
        """Return only the id columns of `table` that carry a reference."""
        ref_cols = self.get_id_cols(name, True)
        return table[ref_cols]

    def find_foreign_ids(self, name: str, get_table: Callable[[str], pd.DataFrame]):
        """Build the id lookup table of table `name`.

        The result starts from the referencing id columns of the table. Each
        of them is renamed to the name of the parent table it points to, rows
        without a value for that parent are dropped, and the parent's own
        lookup table (computed recursively) is inner-joined on that column.
        """
        ids = self.get_foreign_keys(name, get_table(name))

        for col, table, foreign_col in self.find_parents(name):
            parent_key = self.meta[table].primary_key
            assert (
                foreign_col == parent_key
            ), "Only referencing primary keys supported for now."

            ids = ids.rename(columns={col: table})
            foreign_ids = self.find_foreign_ids(table, get_table)

            # Rows that do not point to a parent can not be joined.
            has_parent = ~pd.isna(ids[table])
            ids = ids[has_parent].join(foreign_ids, on=table, how="inner")

        return ids

    def table_has_reference(self):
        """Tell whether the managed table has an id column with a reference."""
        for col in self.meta[self.name].cols.values():
            if not col.is_id() or col.ref is None:
                continue

            assert not isinstance(col.ref, list), "ids can only have one reference"
            assert (
                col.ref.table is not None
            ), "ids with a reference should have one on a foreign table"
            return True

        return False
def _calc_joined_refs(
    name: str,
    get_table: Callable[[str], pd.DataFrame],
    ids: pd.DataFrame | None,
    cref: list[ColumnRef] | ColumnRef,
    table: pd.DataFrame | None = None,
):
    """Collect the reference values described by `cref`, joined through `ids`.

    `table` is the data of table `name`; it is loaded with `get_table` when
    not provided. References without a table are read from `table` directly.
    References to another table are looked up by joining that table's columns
    onto `ids`, using the `ids` column named after the table.

    A single reference yields a series. A list of references yields a
    dataframe, in which columns from foreign tables are prefixed with
    `<table>.`. Callers are expected to pass a non-empty `cref`.
    """
    if table is None:
        table = get_table(name)

    if not isinstance(cref, list):
        ref = cast(ColumnRef, cref)
        f_table, f_col = ref.table, ref.col
        assert f_col

        if not f_table:
            # Local column
            return table[f_col]

        # Foreign column from another table
        assert ids is not None
        return ids.join(get_table(f_table)[f_col], on=f_table)[f_col]

    # Group the requested columns per source table (None means local).
    table_cols: dict[str | None, list[str]] = defaultdict(list)
    for ref in cref:
        assert ref.col is not None
        table_cols[ref.table].append(ref.col)

    dfs = []
    for rtable, refs in table_cols.items():
        if not rtable:
            dfs.append(table[refs])
            continue

        assert ids is not None
        joined = ids.join(get_table(rtable)[refs], on=rtable)
        dfs.append(joined[refs].add_prefix(f"{rtable}."))

    return pd.concat(dfs, axis=1)


def _calc_unjoined_refs(
    name: str,
    get_table: Callable[[str], pd.DataFrame],
    cref: list[ColumnRef] | ColumnRef | None,
    table: pd.DataFrame | None = None,
):
    """Collect the columns required by `cref`, grouped by source table.

    Returns a dictionary that maps each referenced table name to a dataframe
    holding the referenced columns of that table. References without a table
    are attributed to table `name`, whose data is taken from `table` when it
    is provided and from `get_table` otherwise. No joining is performed.

    An empty dictionary is returned when `cref` is empty or None.
    """
    if not cref:
        refs = []
    elif isinstance(cref, list):
        refs = cref
    else:
        refs = [cast(ColumnRef, cref)]

    table_cols: dict[str, list[str]] = defaultdict(list)
    for ref in refs:
        assert ref.col is not None
        table_cols[ref.table or name].append(ref.col)

    out = {}
    for tname, cols in table_cols.items():
        # Prefer the already loaded copy of the current table.
        if tname == name and table is not None:
            source = table
        else:
            source = get_table(tname)
        out[tname] = source[cols]
    return out
class TableTransformer:
    """Holds and applies the transformers of a single table.

    One transformer is built for every non-id column (or column group) listed
    in the table's metadata. Once `fit()` has run, partitions of the table can
    be transformed with `transform()` and transformed partitions can be
    turned back into the original column layout with `reverse()`.
    """

    def __init__(
        self,
        meta: Metadata,
        name: str,
        modules: list[Module],
    ) -> None:
        self.name = name
        self.meta = meta
        # Transformer factories by column type. A factory for
        # `SeqTransformerWrapper` is registered next to the provided modules.
        self.transformer_cls = get_module_dict(
            TransformerFactory,
            [*modules, SeqTransformerWrapper.get_factory(modules=modules)],
        )

        self.ref = ReferenceManager(meta, name)
        self.transformers: dict[str | tuple[str, ...], Transformer] = {}
        self.fitted = False

    def fit(
        self,
        tables: dict[str, LazyFrame],
        ids: LazyFrame | None = None,
    ):
        """Fit the transformers on all partitions of the table.

        `fit_chunk()` is run once per partition through `process_in_parallel`
        and the per-partition transformers are merged with `_reduce_inner`.
        """
        per_call = []
        for cids, ctables in LazyFrame.zip_values([ids, tables]):
            per_call.append({"ids": cids, "tables": ctables})

        transformer_chunks: list[dict[str | tuple[str, ...], Transformer]] = (
            process_in_parallel(
                self.fit_chunk,
                per_call,
                desc=f"Fitting transformers for '{self.name}'",
            )
        )

        self.transformers = reduce(_reduce_inner, transformer_chunks)
        self.fitted = True

    def _build_transformer(self, col) -> Transformer:
        """Instantiate the transformer registered for the type of `col`.

        When the column arguments contain `main_param`, it is additionally
        passed as the first positional argument of the factory.
        """
        assert (
            col.type in self.transformer_cls
        ), f"Column type {col.type} not in transformers:\n{list(self.transformer_cls.keys())}"

        factory = self.transformer_cls[col.type]
        if "main_param" in col.args:
            return factory.build(col.args["main_param"], **col.args)
        return factory.build(**col.args)

    def _align_to_ids(
        self, table: pd.DataFrame, loaded_ids: pd.DataFrame | None
    ) -> pd.DataFrame:
        """Reindex `table` to the index of `loaded_ids` when the two differ.

        The table is returned untouched when there are no ids or when both
        indices hold the same labels. Otherwise a warning is logged.
        """
        if loaded_ids is None or not len(
            table.index.symmetric_difference(loaded_ids.index)
        ):
            return table

        old_len = len(table)
        table = table.reindex(loaded_ids.index)
        logger.warning(
            f"There are missing ids for rows in {self.name}, dropping {old_len-len(table)}/{old_len} rows with missing ids."
        )
        return table

    def fit_chunk(
        self,
        tables: dict[str, LazyChunk],
        ids: LazyChunk | None = None,
    ):
        """Fit a new set of transformers on a single partition.

        Returns a dictionary from column name (or tuple of names) to the
        fitted transformer. The sequencer column, if the metadata defines
        one, is fitted first because the other sequence transformers receive
        its output.
        """
        get_table = lazy_load_tables(tables)  # type: ignore
        loaded_ids = self._load_ids(ids, get_table)

        meta = self.meta[self.name]
        table = self._align_to_ids(get_table(self.name), loaded_ids)
        transformers = {}

        assert (
            meta.primary_key == table.index.name
        ), "Properly formatted datasets should have their primary key as their index column"

        # Process sequencer first
        seq_name = meta.sequencer
        if seq_name:
            col = meta.cols[seq_name]
            t = self._build_transformer(col)
            assert isinstance(
                t, SeqTransformer
            ), f"Sequencer must be of type 'SeqTransformer', not '{type(t)}'"

            # Add foreign column if required
            ref_cols = _calc_unjoined_refs(self.name, get_table, col.ref, table)
            res = t.fit(self.name, table[seq_name], ref_cols, loaded_ids)
            assert res
            seq_attr, seq = res
            transformers[seq_name] = t
        else:
            seq_attr = seq = None

        for name, col in meta.cols.items():
            if seq_name == name:
                continue
            if col.is_id():
                continue

            # Column groups are keyed by a tuple but selected with a list.
            name_l = list(name) if isinstance(name, tuple) else name
            t = self._build_transformer(col)

            if isinstance(t, SeqTransformer):
                # Add foreign column if required
                ref_cols = _calc_unjoined_refs(self.name, get_table, col.ref, table)
                t.fit(self.name, table[name_l], ref_cols, loaded_ids, seq_attr, seq)
            elif isinstance(t, RefTransformer):
                # Add foreign column if required
                cref = col.ref
                ref_cols = (
                    _calc_joined_refs(self.name, get_table, loaded_ids, cref, table)
                    if cref
                    else None
                )
                t.fit(table[name_l], ref_cols)
            else:
                t.fit(table[name_l])

            transformers[name] = t

        return transformers

    def _load_ids(
        self,
        ids: LazyChunk | pd.DataFrame | None,
        get_table: Callable[[str], pd.DataFrame],
    ):
        """Load the id lookup table of this table, only if it is required.

        Returns None when the table has no id column with a reference. An
        already loaded dataframe is returned as is and a lazy chunk is loaded
        by calling it. When `ids` is None, the lookup table is computed with
        `get_table` and the reference manager.
        """
        if not self.ref.table_has_reference():
            return None

        # A dataframe is not callable, so it has to be checked explicitly;
        # otherwise provided ids would be discarded and recomputed.
        if isinstance(ids, pd.DataFrame):
            return ids
        if callable(ids):
            return ids()

        return self.ref.find_foreign_ids(self.name, get_table)

    def transform_chunk(
        self,
        tables: dict[str, LazyChunk],
        ids: LazyChunk | None = None,
    ):
        """Transform a single partition with the fitted transformers.

        Returns a tuple `(table, ctx, ids)`: the transformed table, a
        dictionary of context tables produced by sequence transformers, and
        the id lookup table (a dataframe holding only the table index when
        there are no ids).
        """
        assert self.fitted

        get_table = lazy_load_tables(tables)  # type: ignore
        loaded_ids = self._load_ids(ids, get_table)

        meta = self.meta[self.name]
        table = get_table(self.name)

        tts = []
        ctxs = defaultdict(list)

        # Return the index for an empty table
        if not [c for c in meta.cols.values() if not c.is_id()]:
            edf = pd.DataFrame(index=table.index)
            return edf, {}, loaded_ids if loaded_ids is not None else edf

        table = self._align_to_ids(table, loaded_ids)

        # Process sequencer first
        seq_name = meta.sequencer
        if seq_name:
            col = meta.cols[seq_name]
            trn = self.transformers[seq_name]
            assert isinstance(trn, SeqTransformer)

            ref_cols = _calc_unjoined_refs(self.name, get_table, col.ref, table)
            assert loaded_ids is not None
            res = trn.transform(table[seq_name], ref_cols, loaded_ids)
            assert len(res) == 3
            tt, ctx, seq = res
            tts.append(tt)
            for n, c in ctx.items():
                ctxs[n].append(c)
        else:
            seq = None

        for name, col in meta.cols.items():
            # Skip sequencer
            if seq_name == name:
                continue
            if col.is_id():
                continue

            name_l = list(name) if isinstance(name, tuple) else name
            trn = self.transformers[name]

            if isinstance(trn, SeqTransformer):
                # Add foreign column if required
                ref_cols = _calc_unjoined_refs(self.name, get_table, col.ref, table)
                assert loaded_ids is not None
                res = trn.transform(table[name_l], ref_cols, loaded_ids, seq)
                tt = res[0]
                ctx = res[1]
                for n, c in ctx.items():
                    ctxs[n].append(c)
            elif isinstance(trn, RefTransformer):
                # Add foreign column if required
                cref = col.ref
                ref_cols = (
                    _calc_joined_refs(self.name, get_table, loaded_ids, cref, table)
                    if cref
                    else None
                )
                tt = trn.transform(table[name_l], ref_cols)
            else:
                tt = trn.transform(table[name_l])

            tts.append(tt)

        table = pd.concat(tts, axis=1, copy=False, join="inner")
        ctx = {
            n: pd.concat(c, axis=1, copy=False, join="inner") for n, c in ctxs.items()
        }
        ids_table = (
            loaded_ids if loaded_ids is not None else pd.DataFrame(index=table.index)
        )
        return table, ctx, ids_table

    @to_chunked
    def _transform_chunk(
        self,
        tables: dict[str, LazyChunk],
        ids: LazyChunk | None = None,
    ):
        """`transform_chunk()` wrapped with `to_chunked`."""
        return self.transform_chunk(tables, ids)

    def transform(
        self,
        tables: dict[str, LazyFrame],
        ids: LazyFrame | None = None,
    ):
        """Transform the table; delegates to the chunked `_transform_chunk`."""
        return self._transform_chunk(tables, ids)  # type: ignore

    @to_chunked
    def _reverse_chunk(
        self,
        data: LazyChunk,
        ctx: dict[str, LazyChunk],
        ids: LazyChunk,
        parent_tables: dict[str, LazyChunk],
    ):
        """Reverse a transformed partition into the original column layout.

        Columns are reversed in dependency order: a column that references
        another local column is only processed once that column has been
        decoded. The resulting columns follow the order of the metadata and
        the primary key is not included as a column.
        """
        # If there are no ids that reference a foreign table, the ids and
        # parent_table parameters can be set to None (ex. tabular data).
        assert self.fitted

        cached_table = data()
        cached_ids = ids() if ids is not None else pd.DataFrame()
        cached_ctx = {n: c() for n, c in ctx.items()}
        get_parent = (
            lazy_load_tables(parent_tables)
            if parent_tables
            else lambda _: pd.DataFrame()
        )
        index = cached_table.index

        meta = self.meta[self.name]

        # Add ids
        # If an id references another table it will be merged from the ids
        # dataframe. Otherwise, it will be set to 0 (irrelevant to data synthesis).
        tts = {}
        for name, col in meta.cols.items():
            if not col.is_id() or name == meta.primary_key:
                continue

            if col.ref is not None:
                assert (
                    not isinstance(col.ref, list)
                    and col.ref.table
                    and cached_ids is not None
                )
                tts[name] = cached_ids[col.ref.table].rename(name)
            else:
                # Has to be a series: `pd.concat` rejects scalar entries.
                tts[name] = pd.Series(0, index=index, name=name)

        def decoded_so_far() -> pd.DataFrame:
            """Return the columns decoded up to now as one dataframe."""
            if not tts:
                # `pd.concat` raises on an empty collection.
                return pd.DataFrame(index=index)
            # TODO: remove pd.concat if it's slow
            return pd.concat(tts.values(), axis=1, copy=False, join="inner")

        # Process columns using a for loop based topological sort
        completed_cols = set()
        processed_col = True
        while processed_col:
            processed_col = False

            for name, col in meta.cols.items():
                # Skip already processed columns
                if name in completed_cols:
                    continue
                # Skip ids
                if col.is_id():
                    continue

                # Check column requirements. Local references are looked up
                # in `tts`, which is keyed by individual column names and
                # also holds the id columns.
                cref = col.ref
                if cref and isinstance(cref, list):
                    if any(not ref.table and ref.col not in tts for ref in cref):
                        continue
                elif cref:
                    ref = cast(ColumnRef, cref)
                    if not ref.table and ref.col not in tts:
                        continue

                t = self.transformers[name]
                if isinstance(t, SeqTransformer):
                    ref_col = _calc_unjoined_refs(
                        self.name, get_parent, cref, decoded_so_far()
                    )
                    tt = t.reverse(cached_table, cached_ctx, ref_col, cached_ids)
                elif isinstance(t, RefTransformer):
                    if cref:
                        for r in cref if isinstance(cref, list) else [cref]:
                            rtable = r.table
                            if rtable:
                                assert (
                                    parent_tables is not None
                                    and rtable in parent_tables
                                ), f"Table '{self.name}' depends on table '{r.table}', but it was not specified in the parameter 'trn_deps' of the view."
                        ref_col = _calc_joined_refs(
                            self.name, get_parent, cached_ids, cref, decoded_so_far()
                        )
                    else:
                        ref_col = None
                    tt = t.reverse(cached_table, ref_col)
                else:
                    tt = t.reverse(cached_table)

                processed_col = True
                completed_cols.add(name)
                if isinstance(name, str):
                    tts[name] = tt
                else:
                    for n in name:
                        tts[n] = tt[n]

        decoded_cols = sum(
            len(n) if isinstance(n, tuple) else 1
            for n in meta.cols.keys()
            if n != meta.primary_key
        )
        assert (
            len(tts) == decoded_cols
        ), "Did not process column in this loop. There are columns with cyclical dependencies."

        # If the table is empty, return an empty dataframe
        if decoded_cols == 0:
            return pd.DataFrame(index=index)

        # Create decoded table
        del cached_table, cached_ids, get_parent
        dec_table = pd.concat(tts.values(), axis=1, copy=False, join="inner")
        del tts

        # Re-order columns to metadata based order
        cols = []
        for key in meta.cols.keys():
            if key == meta.primary_key:
                continue
            if isinstance(key, str):
                cols.append(key)
            else:
                cols.extend(key)

        return dec_table[cols]

    def reverse(
        self,
        data: LazyFrame,
        ctx: dict[str, LazyFrame],
        ids: LazyFrame | None = None,
        parent_tables: dict[str, LazyFrame] | None = None,
    ):
        """Reverse transformed data; delegates to the chunked `_reverse_chunk`."""
        return self._reverse_chunk(data, ctx, ids, parent_tables)  # type: ignore

    def get_attributes(self) -> tuple[Attributes, dict[str, Attributes]]:
        """Returns information about the transformed columns and their
        hierarchical attributes.

        The first element merges the attributes of all transformers. The
        second one merges the context attributes reported by sequence
        transformers, keyed the way those transformers key them."""
        assert self.fitted

        attrs = {}
        ctx_attrs = defaultdict(dict)
        for t in self.transformers.values():
            if isinstance(t, SeqTransformer):
                t_attrs, t_ctx_attrs = t.get_attributes()
                for n, c in t_ctx_attrs.items():
                    ctx_attrs[n].update(c)
            else:
                t_attrs = t.get_attributes()

            assert isinstance(
                t_attrs, dict
            ), f"Transformer `{t.name}` did not return a dictionary as attributes."
            attrs.update(t_attrs)

        return attrs, dict(ctx_attrs)

    def get_sequencer(self) -> SeqTransformer | None:
        """Return the transformer of the sequencer column, or None if the
        metadata of the table does not define a sequencer."""
        s = self.meta[self.name].sequencer
        if s is not None:
            return cast(SeqTransformer, self.transformers[s])
        return None
def _fit_encoders_for_table(
    factory: AttributeEncoderFactory, attrs: Attributes, data: LazyChunk
):
    """Fit one encoder per attribute on a single table partition.

    The partition is loaded once by calling `data`. A new encoder is built
    from `factory` for each attribute and fitted on the loaded table. The
    encoders are returned in a dictionary keyed by attribute name.
    """
    table = data()

    def fit_one(attr):
        enc = factory.build()
        enc.fit(attr, table)
        return enc

    return {name: fit_one(attr) for name, attr in attrs.items()}


@to_chunked
def _return_df(name: str, df: LazyChunk):
    """Load `df` and return it in a dictionary under `name`."""
    loaded = df()
    return {name: loaded}


@to_chunked
def _return_ids(name: str, df: LazyChunk):
    """Load `df` and return it as the ids entry of a three-part result.

    The first two elements of the returned tuple are empty dictionaries and
    the third one holds the loaded frame under `name`.
    """
    loaded = df()
    return {}, {}, {name: loaded}
class AttributeEncoderHolder(
    ViewEncoder[dict[str, dict[str | tuple[str, ...], META]]], Generic[META]
):
    """Receives tables that have been encoded by the base transformers and have
    attributes, and reformats them to fit a specific model."""

    # Encoders of each table, by table name and then by attribute name.
    table_encoders: dict[str, dict[str | tuple[str, ...], AttributeEncoder[META]]]
    # Encoders of context tables: creator table -> table -> attribute name.
    ctx_encoders: dict[
        str, dict[str, dict[str | tuple[str, ...], AttributeEncoder[META]]]
    ]
    # An encoder instance when the factory builds `PostprocessEncoder`s,
    # otherwise None.
    postprocess_enc: PostprocessEncoder[META, Any] | None

    def __init__(self, encoder: AttributeEncoderFactory, **kwargs) -> None:
        self.kwargs = kwargs
        self._encoder_fr = encoder
        self.table_encoders = {}
        self.ctx_encoders = {}

        # Build a probe instance to find out whether the encoder
        # post-processes its output.
        tst = encoder.build()
        if isinstance(tst, PostprocessEncoder):
            self.postprocess_enc = tst
        else:
            self.postprocess_enc = None

    def fit(
        self,
        attrs: dict[str, Attributes],
        tables: dict[str, LazyFrame],
        ctx_attrs: dict[str, dict[str, Attributes]],
        ctx: dict[str, dict[str, LazyFrame]],
        ids: dict[str, LazyFrame],
    ):
        """Fit encoders for every table and context table.

        One task is created per partition of each table and context table.
        The encoders fitted on the partitions of the same table are then
        merged with `_reduce_inner`.
        """
        self.encoders = {}
        self.relationships = get_relationships(ids)
        self.tables = list(tables.keys())

        base_args = {"factory": self._encoder_fr}
        per_call = []
        per_call_meta = []

        # Create granular tasks for all table partitions
        for name in tables:
            for part in tables[name].values():
                per_call.append({"attrs": attrs[name], "data": part})
                per_call_meta.append({"ctx": False, "table": name})

        for creator, ctxs in ctx.items():
            for name in ctxs:
                for part in ctxs[name].values():
                    per_call.append({"attrs": ctx_attrs[creator][name], "data": part})
                    per_call_meta.append(
                        {"ctx": True, "creator": creator, "table": name}
                    )

        # Process them
        out = process_in_parallel(
            _fit_encoders_for_table, per_call, base_args, desc="Fitting encoders"
        )

        # Entangle output
        table_enc = defaultdict(list)
        ctx_enc = defaultdict(lambda: defaultdict(list))
        for enc, meta in zip(out, per_call_meta):
            if meta["ctx"]:
                ctx_enc[meta["creator"]][meta["table"]].append(enc)
            else:
                table_enc[meta["table"]].append(enc)

        # Reduce resulting encoders
        self.table_encoders = {}
        self.ctx_encoders = defaultdict(dict)
        for name, encs in table_enc.items():
            self.table_encoders[name] = reduce(_reduce_inner, encs)
        for creator, ctx_encs in ctx_enc.items():
            for name, encs in ctx_encs.items():
                self.ctx_encoders[creator][name] = reduce(_reduce_inner, encs)

        self.fitted = True

    @to_chunked
    def _encode_postprocess(
        self,
        ids: Mapping[str, LazyChunk],
        tables: Mapping[str, LazyChunk],
        ctx: Mapping[str, Mapping[str, LazyChunk]],
    ):
        """Encode all tables and context tables, then hand them, along with
        the loaded ids, to the `finalize()` method of the postprocess encoder."""
        assert self.postprocess_enc is not None

        def encode_table(enc, table: LazyChunk):
            cached_table = table()
            tts = [attr_enc.encode(cached_table) for attr_enc in enc.values()]
            if tts:
                return pd.concat(tts, axis=1, copy=False, join="inner")
            return pd.DataFrame(index=cached_table.index)

        tables_enc = {}
        ctx_enc = {}

        for name, enc in self.table_encoders.items():
            tables_enc[name] = encode_table(enc, tables[name])

        for creator, cencs in self.ctx_encoders.items():
            ctx_enc[creator] = {}
            for name, enc in cencs.items():
                if not enc:
                    continue
                ctx_enc[creator][name] = encode_table(enc, ctx[creator][name])

        return self.postprocess_enc.finalize(
            self.get_metadata(),
            {k: v() for k, v in ids.items()},
            tables_enc,
            ctx_enc,
        )

    @to_chunked
    def _encode_chunk(self, name: str, table: LazyChunk, ctx: dict[str, LazyChunk]):
        """Encode table `name` and its context tables into a single frame.

        `ctx` maps the name of each creator table to the context table it
        produced for `name`. Returns `{name: encoded frame}`.
        """
        tts = []
        cached_table = table()
        cached_ctx = {n: c() for n, c in ctx.items()}

        for enc in self.table_encoders[name].values():
            tts.append(enc.encode(cached_table))

        for creator, ctx_encs in self.ctx_encoders.items():
            if name not in ctx_encs:
                continue
            for enc in ctx_encs[name].values():
                tts.append(enc.encode(cached_ctx[creator]))

        if tts:
            return {name: pd.concat(tts, axis=1, copy=False, join="inner")}
        return {name: pd.DataFrame(index=cached_table.index)}

    def encode(
        self,
        tables: dict[str, LazyFrame],
        ctx: dict[str, dict[str, LazyFrame]],
        ids: dict[str, LazyFrame],
    ):
        """Encode the provided tables.

        With a postprocess encoder, everything is delegated to
        `_encode_postprocess`. Otherwise each table is encoded on its own and
        the ids are passed through under the name `<table>_ids`.
        """
        if self.postprocess_enc is not None:
            return self._encode_postprocess(ids, tables, ctx)

        lazies = set()
        for name in tables:
            table_ctx = {}
            for creator, child_ctx in ctx.items():
                if name in child_ctx:
                    table_ctx[creator] = child_ctx[name]
            lazies |= self._encode_chunk(name, tables[name], table_ctx)

        # Passthrough ids
        for name, tid in ids.items():
            lazies |= _return_df(name + "_ids", tid)

        return lazies

    @to_chunked
    def _decode_chunk(self, name: str, data: LazyChunk) -> tuple[
        dict[str, pd.DataFrame],
        dict[str, dict[str, pd.DataFrame]],
        dict[str, pd.DataFrame],
    ]:
        """Decode an encoded frame of table `name`.

        Returns `(tables, ctx, ids)`, where the ids dictionary is always
        empty; ids are handled separately by `decode()`.
        """
        table = data()

        # Decode main table
        tts = []
        for enc in self.table_encoders[name].values():
            tts.append(enc.decode(table))
        if tts:
            tables = {name: pd.concat(tts, axis=1, copy=False, join="inner")}
        else:
            tables = {name: pd.DataFrame(index=table.index)}

        # Decode context tables
        ctx_tts: dict[str, list] = defaultdict(list)
        for creator, ctx_encs in self.ctx_encoders.items():
            if name not in ctx_encs:
                continue
            for ctx_enc in ctx_encs[name].values():
                ctx_tts[creator].append(ctx_enc.decode(table))

        ctx = {
            creator: {name: pd.concat(tts, axis=1, copy=False, join="inner")}
            for creator, tts in ctx_tts.items()
        }
        return tables, ctx, {}

    @to_chunked
    def _decode_postprocess(self, data: Mapping[str, LazyChunk]):
        """Decode data produced by a postprocess encoder (not implemented)."""
        # The assertion below disables this path on purpose; the code after
        # it is what the decoding is meant to do once it is supported.
        assert False, "Decoding for postprocess encoders not implemented yet. Only used for flat encoder."
        assert self.postprocess_enc is not None
        ids, tables, ctx = self.postprocess_enc.undo(self.get_metadata(), dict(data))
        return tables, ctx, ids

    def decode(
        self,
        data: dict[str, LazyFrame],
    ):
        """Decode the output of `encode()`.

        With a postprocess encoder, everything is delegated to
        `_decode_postprocess`. Otherwise every entry that is not named
        `<table>_ids` is decoded as a table and the `<table>_ids` entries are
        passed through as the ids of `<table>`.
        """
        if self.postprocess_enc is not None:
            # Mirror `encode()`: the postprocess path replaces the regular
            # one, it must not fall through to it.
            return self._decode_postprocess(data)

        lazies = set()
        for name, table in data.items():
            if not name.endswith("_ids"):
                lazies |= self._decode_chunk(name, table)

        # Passthrough ids
        for name, tid in data.items():
            if name.endswith("_ids"):
                # Only strip the suffix added by `encode()`; the table name
                # itself may contain "_ids".
                lazies |= _return_ids(name.removesuffix("_ids"), tid)

        return lazies

    def get_metadata(self) -> dict[str, dict[str | tuple[str, ...], META]]:
        """Collect the metadata reported by the fitted encoders.

        Without a postprocess encoder, the metadata of table and context
        encoders is merged per table. With one, the table and context
        metadata are kept apart and passed, together with the table
        relationships, to its `get_post_metadata()`.
        """
        out = {k: {} for k in self.tables}
        attrs = defaultdict(dict)
        ctx_attrs = defaultdict(lambda: defaultdict(dict))

        for table, encs in self.table_encoders.items():
            for enc in encs.values():
                out[table].update(enc.get_metadata())
                attrs[table].update(enc.get_metadata())

        for cname, cencs in self.ctx_encoders.items():
            for table, encs in cencs.items():
                for enc in encs.values():
                    out[table].update(enc.get_metadata())
                    ctx_attrs[cname][table].update(enc.get_metadata())

        if self.postprocess_enc is not None:
            return self.postprocess_enc.get_post_metadata(self.relationships, dict(attrs), dict(ctx_attrs))  # type: ignore

        return dict(out)
def _backref_cols( ids: pd.DataFrame, seq: pd.Series, data: pd.DataFrame | pd.Series, parent: str ): # Ref is calculated by mapping each id in data_df by merging its parent # key, sequence number to parent key, and the number - 1 and finding the # corresponding id for that row. Then, a join is performed. _IDX_NAME = "_id_lkjijk" _JOIN_NAME = "_id_zdjwk" ids_seq_prev = ids.join(seq + 1, how="right").reset_index(names=_JOIN_NAME) ids_seq = ids.join(seq, how="right").reset_index(names=_IDX_NAME) # FIXME: ids become float join_ids = ids_seq.merge( ids_seq_prev, on=[parent, seq.name], how="inner" ).set_index(_IDX_NAME)[ [_JOIN_NAME] ] # type: ignore ref_df = join_ids.join(data, on=_JOIN_NAME).drop(columns=_JOIN_NAME) ref_df.index.name = data.index.name if isinstance(data, pd.Series): return ref_df[data.name] return ref_df def _calculate_seq(data: pd.Series, ids: pd.DataFrame, parent: str, col_seq: str): _ID_SEQ = "_id_sdfasdf" seq = ( cast( pd.Series, pd.concat({parent: ids[parent], _ID_SEQ: data}, axis=1) .groupby(parent)[_ID_SEQ] .rank("first", na_option="top"), ) - 1 ) max_len = int(cast(float, seq.max())) + 1 return seq.astype(get_dtype(max_len + 1)).rename(col_seq)
class SeqTransformerWrapper(SeqTransformer):
    name = "seq"
    # "dual": both a `ctx` and a `seq` transformer were configured,
    # "single": only a `seq` transformer, "notrn": no `seq` transformer.
    mode: Literal["dual", "single", "notrn"]

    def __init__(
        self,
        modules: list[Module],
        parent: str | None = None,
        seq: dict[str, Any] | None = None,
        ctx: dict[str, Any] | None = None,
        seq_col: str | None = None,
        ctx_to_ref: dict[str, str] | None = None,
        order: int | None = None,
        max_len: int | None = None,
        first_seq_ref_itself: bool = False,
        ref: Any | None = None,
        **kwargs,
    ) -> None:
        """Store the configuration and build the wrapped transformers.

        `seq` and `ctx` are the arguments of the wrapped transformers. Their
        `type` entry selects the transformer factory among `modules` and the
        remaining entries are passed to the factory's `build()`. The `seq`
        transformer has to be a `RefTransformer`.
        """
        super().__init__(**kwargs)

        self.parent_arg = parent
        self.seq_col_ref = seq_col
        self.ctx_to_ref = ctx_to_ref
        self.order = order
        self.max_len_set = max_len
        self.first_seq_ref_itself = first_seq_ref_itself

        if ctx is None:
            assert (
                not first_seq_ref_itself
            ), "For the first item in the sequence to reference itself during transform, it needs to be generated by a ctx transformer. Either set `first_seq_ref_itself` to False or provide a ctx transformer."

        # Load transformers
        # Use three modes
        if seq is None:
            self.mode = "notrn"
        elif ctx is None:
            self.mode = "single"
        else:
            self.mode = "dual"

        if seq is not None:
            seq_kwargs = seq.copy()
            seq_type = seq_kwargs.pop("type")

            force_nullable = (
                not seq_kwargs.get("nullable", False)
                and not first_seq_ref_itself
                and (not ref or ctx)
            )
            if force_nullable:
                logger.warning(
                    f"Setting `nullable` to true for ref transformer `{seq_type}`, since reference for first seq will always be None. Set `first_seq_ref_itself` to True, to avoid introducing null values."
                )
                seq_kwargs["nullable"] = True

            seq_factory = get_module_dict(TransformerFactory, modules)[
                cast(str, seq_type)
            ]
            self.seq = seq_factory.build(**seq_kwargs)
            assert isinstance(self.seq, RefTransformer)

        if ctx is not None:
            ctx_kwargs = ctx.copy()
            ctx_type = ctx_kwargs.pop("type")

            ctx_factory = get_module_dict(TransformerFactory, modules)[
                cast(str, ctx_type)
            ]
            self.ctx = ctx_factory.build(**ctx_kwargs)
            assert isinstance(self.ctx, Transformer)
def fit(
    self,
    table: str,
    data: pd.Series | pd.DataFrame,
    ref: dict[str, pd.DataFrame],
    ids: pd.DataFrame,
    seq_val: SeqValue | None = None,
    seq: pd.Series | None = None,
) -> tuple[SeqValue, pd.Series] | None:
    """Fit the wrapped transformers on the sequences of `table`.

    The parent table is taken from `seq_val` when it is provided, otherwise
    from the `parent` argument of the constructor, otherwise it is inferred
    from `ids` (single column) or from the first key of `ref`. When `seq` is
    not provided, the sequence numbers are calculated from `data`.

    When `seq_val` is None this transformer acts as the sequencer and
    returns its `SeqValue` and the sequence series; otherwise None is
    returned.
    """
    self.table = table

    # Grab parent from seq_val if available
    if seq_val is not None:
        assert (
            seq_val.table is not None
        ), f"Transformer Wrapper does not support seq values without parents yet"
        self.parent = seq_val.table
        self.col_seq = seq_val.name
    else:
        self.col_seq = f"{table}_seq"

        if self.parent_arg:
            self.parent = self.parent_arg
        else:
            try:
                if len(ids.columns) == 1:
                    # If there is only 1 column in ids, assume it's the parent
                    self.parent = cast(str, ids.columns[0])
                else:
                    # Try using the parent as the first reference table
                    self.parent = cast(str, next(iter(ref)))
                    logger.info(
                        f"Assuming parent of table '{table}' is '{self.parent}' from references."
                    )
            except Exception as e:
                raise Exception(
                    "Could not infer parent from references, please specify the table which acts as the parent with parameter `parent`."
                ) from e

    assert (
        self.parent
    ), "Parent table not specified, use parameter 'parent' or a foreign reference."

    # The sequence is generated here only if it was not provided
    self.generate_seq = seq is None
    if self.generate_seq:
        if isinstance(data, pd.DataFrame):
            assert (
                self.seq_col_ref is not None
            ), f"Multiple columns are provided as input, specify which one is used sequence the table through parameter `seq_col`."
            seq_col = data[self.seq_col_ref]
        else:
            seq_col = data
        seq = _calculate_seq(seq_col, ids, self.parent, self.col_seq)

    if self.max_len_set is None:
        self.max_len = cast(int, seq.max()) + 1
    else:
        self.max_len = self.max_len_set

    if self.mode == "dual":
        self._dual_fit(self.parent, data, ref, ids, seq)
    elif self.mode == "single":
        self._single_fit(self.parent, data, ref, ids, seq)
    elif self.mode == "notrn":
        assert (
            self.generate_seq
        ), "No transformers, so column is going to be used as a sequence"
        assert isinstance(
            data, pd.Series
        ), "Expected a single column for sequencing this table"
        self.col_orig = cast(str, data.name)

    # If a seq_val was not provided, assume seq was also none and
    # become the sequencer
    if seq_val is None:
        seq_value = SeqValue(self.col_seq, self.parent, self.order, self.max_len)
        return seq_value, cast(pd.Series, seq)
def _dual_fit(
    self,
    parent: str,
    data: pd.Series | pd.DataFrame,
    ref: dict[str, pd.DataFrame],
    ids: pd.DataFrame,
    seq: pd.Series,
):
    """Fit both wrapped transformers.

    The `ctx` transformer is fitted on the first row of every sequence
    (`seq == 0`), indexed by parent id, along with the columns of `ref` when
    there are any. The `seq` transformer is fitted on the remaining rows,
    with the values of the previous row as reference.
    """
    first_rows = ids[[parent]].join(data[seq == 0], how="right")
    ctx_data = first_rows.drop_duplicates(subset=[parent]).set_index(parent)
    if isinstance(data, pd.Series):
        ctx_data = ctx_data[next(iter(ctx_data))]

    if not ref:
        self.ctx.fit(ctx_data)
    else:
        ctx_ref = ids.drop_duplicates(subset=[parent])
        for name, ref_table in ref.items():
            ctx_ref = ctx_ref.join(ref_table, on=name, how="left")

        # Keep only the joined reference columns, indexed by parent
        other_ids = [d for d in ids.columns if d != parent]
        ctx_ref = ctx_ref.set_index(parent).drop(columns=other_ids)
        if ctx_ref.shape[1] == 1:
            ctx_ref = ctx_ref[next(iter(ctx_ref))]

        assert isinstance(
            self.ctx, RefTransformer
        ), f"Reference found, initial transformer should be a reference transformer."
        self.ctx.fit(ctx_data, ctx_ref)

    # Data series is all rows where seq > 0 (skip initial)
    ref_df = _backref_cols(ids, seq, data, parent)
    self.seq.fit(data[seq > 0], ref_df)


def _single_fit(
    self,
    parent: str,
    data: pd.Series | pd.DataFrame,
    ref: dict[str, pd.DataFrame],
    ids: pd.DataFrame,
    seq: pd.Series,
):
    """Fit the `seq` transformer on all rows.

    The reference of a row is the previous row of its sequence. When `ref`
    is not empty, the reference of the first row of each sequence is taken
    from the columns of `ref`, which are placed before the row-to-row
    references.
    """
    ref_df = _backref_cols(ids, seq, data, parent)

    if ref:
        ctx_ref = ids[seq == 0].drop_duplicates(subset=[self.parent])
        for name, ref_table in ref.items():
            ctx_ref = ctx_ref.join(ref_table, on=name, how="left")
        ctx_ref = ctx_ref.drop(columns=ids.columns)
        if ctx_ref.shape[1] == 1:
            ctx_ref = ctx_ref[next(iter(ctx_ref))]

        both_series = isinstance(ref_df, pd.Series) and isinstance(ctx_ref, pd.Series)
        both_frames = isinstance(ref_df, pd.DataFrame) and isinstance(
            ctx_ref, pd.DataFrame
        )

        if both_series:
            ref_df = pd.concat([ctx_ref, ref_df])
        elif both_frames:
            if self.ctx_to_ref:
                ctx_ref = ctx_ref.rename(columns=self.ctx_to_ref)
            ref_df = pd.concat([ctx_ref, ref_df], axis=0)
            # Columns with mismatched names would have been added as new ones
            assert (
                ref_df.shape[1] == ctx_ref.shape[1]
            ), f"Parent columns not joined correctly to reference ones. If they have different names, pass in `ctx_to_ref` with names mapping them to parents"
        else:
            assert (
                False
            ), "fixme: mismatched reference column counts. If single column transformer, both should be series, otherwise both should be dataframes"

    self.seq.fit(data, ref_df)
def reduce(self, other: "SeqTransformerWrapper"):
    """Merge the state of `other` into this wrapper, in place.

    The wrapped transformers in use for the current mode are reduced with
    their counterparts from `other` and the larger `max_len` is kept.
    """
    mode = self.mode
    if mode == "dual":
        self.ctx.reduce(other.ctx)
    if mode != "notrn":
        self.seq.reduce(other.seq)

    self.max_len = max(other.max_len, self.max_len)
[docs] def transform( self, data: pd.Series | pd.DataFrame, ref: dict[str, pd.DataFrame], ids: pd.DataFrame, seq: pd.Series | None = None, ) -> ( tuple[pd.DataFrame, dict[str, pd.DataFrame]] | tuple[pd.DataFrame, dict[str, pd.DataFrame], pd.Series] ): parent = cast(str, self.parent) if self.generate_seq: if isinstance(data, pd.DataFrame): assert ( self.seq_col_ref is not None ), f"Multiple columns are provided as input, specify which one is used sequence the table through parameter `seq_col`." seq_col = data[self.seq_col_ref] else: seq_col = data seq = _calculate_seq(seq_col, ids, parent, self.col_seq) else: assert seq is not None match self.mode: case "dual": enc, ctx = self._dual_trn(parent, data, ref, ids, seq) case "single": enc, ctx = self._single_trn(parent, data, ref, ids, seq) case "notrn": enc = pd.DataFrame() ctx = pd.DataFrame() if self.generate_seq: return ( pd.concat([enc, seq], axis=1), {parent: ctx}, seq, ) return enc, {parent: ctx}
def _dual_trn(
    self,
    parent: str,
    data: pd.Series | pd.DataFrame,
    ref: dict[str, pd.DataFrame],
    ids: pd.DataFrame,
    seq: pd.Series,
):
    """Transform a partition with separate context and sequence transformers.

    The ``seq == 0`` rows are transformed by ``self.ctx`` (keyed by parent
    id) and the rows are transformed by ``self.seq`` against the
    back-references produced by ``_backref_cols``.

    Args:
        parent: Name of the column of ``ids`` that holds the parent key.
        data: Column(s) to transform; joined to ``ids`` by index.
        ref: Reference tables keyed by the ``ids`` column they are joined on.
            When non-empty, ``self.ctx`` must be a ``RefTransformer``.
        ids: Id columns of the table.
        seq: Sequence number of every row.

    Returns:
        ``(enc, ctx)``: the output of the sequence transformer and the output
        of the context transformer.
    """
    # Initial row (seq == 0) of every sequence, keyed by the parent id.
    ctx_data = (
        ids[[parent]]
        .join(data[seq == 0], how="right")
        .drop_duplicates(subset=[parent])
        .set_index(parent)
    )
    if ctx_data.shape[1] == 1:
        ctx_data = ctx_data[next(iter(ctx_data))]

    if ref:
        # Reference values for the initial row: one row per parent, with the
        # reference tables joined in and every id column removed.
        ctx_ref = ids.drop_duplicates(subset=[parent])
        for name, ref_table in ref.items():
            ctx_ref = ctx_ref.join(ref_table, on=name, how="left")
        ctx_ref = ctx_ref.set_index(parent).drop(
            columns=[d for d in ids.columns if d != parent]
        )
        if ctx_ref.shape[1] == 1:
            ctx_ref = ctx_ref[next(iter(ctx_ref))]

        assert isinstance(
            self.ctx, RefTransformer
        ), "Reference found, initial transformer should be a reference transformer."
        ctx = self.ctx.transform(ctx_data, ctx_ref)
    else:
        ctx = self.ctx.transform(ctx_data)

    if self.first_seq_ref_itself:
        # Avoid introducing null values to a non null column by having
        # the sequence transformer produce stand-in values for the first seq.
        # These values are thrown away during reversing, since the first
        # value is produced by the ctx transformer.
        # The model may incorrectly fit to them for subsequent rows but it
        # may be better than it introducing null values.
        ref_df = _backref_cols(ids, seq, data, parent)
        # seq == 0 rows reference themselves, the rest use the back-reference.
        enc = self.seq.transform(data, data.where(seq == 0, ref_df))
    else:
        # Data series is all rows where seq > 0 (skip initial)
        ref_df = _backref_cols(ids, seq, data, parent)
        enc = self.seq.transform(data[seq > 0], ref_df)

        # Fill seq == 0 ints with 0 and floats with nan
        enc = enc.reindex(data.index, fill_value=0)
        for k, d in enc.dtypes.items():
            if is_float_dtype(d):
                enc.loc[seq == 0, k] = np.nan

    return enc, ctx


def _single_trn(
    self,
    parent: str,
    data: pd.Series | pd.DataFrame,
    ref: dict[str, pd.DataFrame],
    ids: pd.DataFrame,
    seq: pd.Series,
):
    """Transform every row with the single sequence transformer.

    The reference handed to ``self.seq`` is the output of ``_backref_cols``;
    when ``ref`` is non-empty, the external reference values of the
    ``seq == 0`` rows are concatenated in front of it.

    Args:
        parent: Name of the column of ``ids`` that holds the parent key.
        data: Column(s) to transform.
        ref: Reference tables keyed by the ``ids`` column they are joined on.
        ids: Id columns of the table.
        seq: Sequence number of every row.

    Returns:
        ``(enc, ctx)`` where ``ctx`` is always an empty DataFrame, since this
        mode produces no context columns.

    Raises:
        AssertionError: If the external and back-references do not have the
            same kind (both Series or both DataFrames), or if their columns
            do not line up after applying ``self.ctx_to_ref``.
    """
    ref_df = _backref_cols(ids, seq, data, parent)

    if ref:
        # External reference values for the initial (seq == 0) rows.
        ctx_ref = ids[seq == 0].drop_duplicates(subset=[self.parent])
        for name, ref_table in ref.items():
            ctx_ref = ctx_ref.join(ref_table, on=name, how="left")
        ctx_ref = ctx_ref.drop(columns=ids.columns)
        if ctx_ref.shape[1] == 1:
            ctx_ref = ctx_ref[next(iter(ctx_ref))]

        if isinstance(ref_df, pd.Series) and isinstance(ctx_ref, pd.Series):
            ref_df = pd.concat([ctx_ref, ref_df])
        elif isinstance(ref_df, pd.DataFrame) and isinstance(ctx_ref, pd.DataFrame):
            if self.ctx_to_ref:
                ctx_ref = ctx_ref.rename(columns=self.ctx_to_ref)
            ref_df = pd.concat([ctx_ref, ref_df], axis=0)
            # Differently named columns are unioned by concat, which would
            # widen the frame; that means the renaming above is incomplete.
            assert (
                ref_df.shape[1] == ctx_ref.shape[1]
            ), "Parent columns not joined correctly to reference ones. If they have different names, pass in `ctx_to_ref` with names mapping them to parents"
        else:
            # Raised explicitly (instead of `assert False`) so the check is
            # not stripped when Python runs with -O.
            raise AssertionError(
                "fixme: mismatched reference column counts. If single column transformer, both should be series, otherwise both should be dataframes"
            )

    return self.seq.transform(data, ref_df), pd.DataFrame()


def _single_reverse(
    self,
    data: pd.DataFrame,
    ctx: dict[str, pd.DataFrame],
    ref: dict[str, pd.DataFrame],
    ids: pd.DataFrame,
) -> pd.DataFrame:
    """Reverse a partition encoded in "single" mode, one step at a time.

    Step ``i`` (rows with ``seq == i``) is decoded against the decoded values
    of step ``i - 1`` of the same parent. Step 0 is decoded against the
    external reference values, or ``None`` when ``ref`` is empty.

    Args:
        data: Encoded columns; must contain the ``self.col_seq`` column.
        ctx: Unused in this mode.
        ref: Reference tables keyed by the ``ids`` column they are joined on.
        ids: Id columns of the table.

    Returns:
        The decoded rows of all steps, concatenated in step order.
    """
    seq = data[self.col_seq]
    parent = cast(str, self.parent)

    if ref:
        # External reference values for the initial (seq == 0) rows.
        ctx_ref = ids[seq.reindex(ids.index) == 0].drop_duplicates(
            subset=[self.parent]
        )
        for name, ref_table in ref.items():
            ctx_ref = ctx_ref.join(ref_table, on=name, how="left")
        ctx_ref = ctx_ref.drop(columns=ids.columns)
        if self.ctx_to_ref:
            ctx_ref = ctx_ref.rename(columns=self.ctx_to_ref)
        if ctx_ref.shape[1] == 1:
            ctx_ref = ctx_ref[next(iter(ctx_ref))]
    else:
        ctx_ref = None

    # Decode step by step, starting from the initial rows (seq == 0), until
    # a step has no rows.
    out = []
    i = 0
    while True:
        data_df = data[seq == i]
        if not len(data_df):
            break

        if i > 0:
            # Reference of this step: the previous step's decoded values,
            # looked up through the parent id of each row.
            ref_df = (
                ids.loc[data_df.index]
                .merge(
                    (
                        ids[[parent]]
                        .join(out[-1], how="left")
                        .groupby(parent)
                        .first()
                    ),
                    left_on=parent,
                    right_index=True,
                    how="left",
                )
                .drop(columns=ids.columns)
            )
            if ref_df.shape[1] == 1:
                ref_df = ref_df[next(iter(ref_df))]
            assert len(ref_df) == len(
                data_df
            ), "fixme: experimental, there is a join error."
        else:
            ref_df = ctx_ref

        out.append(pd.DataFrame(self.seq.reverse(data_df, ref_df)))
        i += 1

    return pd.concat(out, axis=0)


def _dual_reverse(
    self,
    data: pd.DataFrame,
    ctx: dict[str, pd.DataFrame],
    ref: dict[str, pd.DataFrame],
    ids: pd.DataFrame,
) -> pd.DataFrame:
    """Reverse a partition encoded in "dual" mode.

    The initial rows (``seq == 0``) are decoded by ``self.ctx`` from the
    context tables. Every later step ``i`` is decoded by ``self.seq`` against
    the decoded values of step ``i - 1`` of the same parent.

    Args:
        data: Encoded columns; must contain the ``self.col_seq`` column.
        ctx: Encoded context tables keyed by the ``ids`` column they are
            joined on.
        ref: Reference tables keyed by the ``ids`` column they are joined on.
            When non-empty, ``self.ctx`` must be a ``RefTransformer``.
        ids: Id columns of the table.

    Returns:
        The decoded rows of all steps, concatenated in step order.
    """
    seq = data[self.col_seq]
    parent = cast(str, self.parent)

    # Encoded context values, one row per parent, keyed by the parent id.
    ctx_data = ids.drop_duplicates(subset=[parent])
    for name, ctx_table in ctx.items():
        ctx_data = ctx_data.join(ctx_table, on=name, how="left")
    ctx_data = ctx_data.set_index(parent)

    if ref:
        ctx_ref = ids.drop_duplicates(subset=[parent])
        for name, ref_table in ref.items():
            ctx_ref = ctx_ref.join(ref_table, on=name, how="left")
        ctx_ref = ctx_ref.set_index(parent).drop(
            columns=[d for d in ids.columns if d != parent]
        )
        if ctx_ref.shape[1] == 1:
            ctx_ref = ctx_ref[next(iter(ctx_ref))]

        assert isinstance(
            self.ctx, RefTransformer
        ), "Reference found, initial transformer should be a reference transformer."
        ctx_dec = self.ctx.reverse(ctx_data, ctx_ref)
    else:
        ctx_dec = self.ctx.reverse(ctx_data)

    # Step 0: spread the per-parent decoded context back onto the seq == 0 rows.
    out = [
        ids.loc[seq == 0, [parent]]
        .join(ctx_dec, on=parent, how="right")
        .drop(columns=[parent])
    ]

    # Data series is all rows where seq > 0 (skip initial)
    i = 1
    while True:
        data_df = data[seq == i]
        i += 1
        if not len(data_df):
            break

        # Reference of this step: the previous step's decoded values, looked
        # up through the parent id of each row.
        ref_df = (
            ids.loc[data_df.index]
            .merge(
                (ids[[parent]].join(out[-1], how="left").groupby(parent).first()),
                left_on=parent,
                right_index=True,
                how="left",
            )
            .drop(columns=ids.columns)
        )
        if ref_df.shape[1] == 1:
            ref_df = ref_df[next(iter(ref_df))]
        assert len(ref_df) == len(
            data_df
        ), "fixme: experimental, there is a join error."

        out.append(pd.DataFrame(self.seq.reverse(data_df, ref_df)))

    return pd.concat(out, axis=0)
[docs] def reverse( self, data: pd.DataFrame, ctx: dict[str, pd.DataFrame], ref: dict[str, pd.DataFrame], ids: pd.DataFrame, ) -> pd.DataFrame: match self.mode: case "dual": return self._dual_reverse(data, ctx, ref, ids) case "single": return self._single_reverse(data, ctx, ref, ids) case "notrn": return pd.DataFrame(data[self.col_seq].rename(self.col_orig)) assert False
def get_attributes(self) -> tuple[Attributes, dict[str, Attributes]]:
    """Return the attributes of the table and of its context tables.

    Returns:
        A ``(attrs, ctx_attrs)`` tuple. ``attrs`` holds the attributes of the
        sequence transformer (none in ``"notrn"`` mode) plus, when the
        sequence is generated by this wrapper, a ``SeqAttribute`` under
        ``self.col_seq``. ``ctx_attrs`` maps ``self.parent`` to the context
        transformer's attributes in ``"dual"`` mode.

    Raises:
        AssertionError: If ``self.mode`` is unknown, or if it is ``"notrn"``
            while the sequence is not generated by this wrapper.
    """
    mode = self.mode
    if mode not in ("dual", "single", "notrn"):
        # Raised explicitly (instead of `assert False`) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError(f"Unsupported mode: {mode!r}")

    if not self.generate_seq:
        if mode == "notrn":
            raise AssertionError(
                "Mode 'notrn' is only valid when the sequence is generated."
            )
        attrs = self.seq.get_attributes()
        if mode == "dual":
            return attrs, {self.parent: self.ctx.get_attributes()}
        # NOTE(review): unlike the generated-sequence case below, "single"
        # returns no entry for the parent here - confirm this is intended.
        return attrs, {}

    # The sequence column is produced by this wrapper, so it is advertised
    # next to the transformer's own attributes.
    attrs = {} if mode == "notrn" else dict(self.seq.get_attributes())
    attrs[self.col_seq] = SeqAttribute(
        self.col_seq, self.parent, self.order, self.max_len
    )
    ctx_attrs = self.ctx.get_attributes() if mode == "dual" else {}
    return attrs, {self.parent: ctx_attrs}
def get_seq_value(self) -> SeqValue | None:
    """Return the ``SeqValue`` describing this wrapper's sequence column.

    It is built from the sequence column name, the parent, the order and
    the maximum sequence length stored on the wrapper.
    """
    seq_value = SeqValue(
        self.col_seq,
        self.parent,
        self.order,
        self.max_len,
    )
    return seq_value