Source code for pasteur.metric

""" This module provides the definitions for Metric Modules.
Metric modules can fit to a column, a table, or a whole View.
In each case, modules are instanciated as required (for columns one is instantiated
per column type, for tables one per table and View metrics are instantiated once)."""

import logging
from collections import defaultdict
from typing import Any, Callable, Generic, Literal, TypedDict, TypeVar, cast

import pandas as pd
from pasteur.attribute import SeqValue
from pasteur.metadata import ColumnMeta

from pasteur.utils import LazyDataset

from .attribute import SeqValue
from .encode import Encoder
from .metadata import ColumnMeta, Metadata
from .module import Module, ModuleClass, ModuleFactory, get_module_dict_multiple
from .table import (
    ReferenceManager,
    TableTransformer,
    _calc_joined_refs,
    _calc_unjoined_refs,
    _backref_cols,
)
from .transform import SeqTransformer
from .utils import LazyChunk, LazyDataset, LazyFrame, lazy_load_tables
from .utils.progress import piter, process_in_parallel, reduce

logger = logging.getLogger(__name__)


[docs] class ColumnMetricFactory(ModuleFactory["AbstractColumnMetric"]): ...
[docs] class MetricFactory(ModuleFactory["Metric"]): def __init__( self, cls: type["Metric"], *args, name: str | None = None, **kwargs ) -> None: super().__init__(cls, *args, name=name, **kwargs) self.encodings = cls.encodings
A = TypeVar("A") _DATA = TypeVar("_DATA") _INGEST = TypeVar("_INGEST") _SUMMARY = TypeVar("_SUMMARY")
[docs] class Metric(ModuleClass, Generic[_INGEST, _SUMMARY]): """Encapsulates a special way to visualize results. The metric is provided with the metrics requested in `encodings`. If one encoding is requested and `encodings` is a string, `meta` and `data` will contain the metadata and data of that encoding. If `encodings` is a list, `meta` and `data` will be dictionaries containing the metadata and data for each encoding.""" _factory = MetricFactory encodings: str | list[str] = "raw"
[docs] def fit( self, meta: Any | dict[str, Any], data: dict[str, LazyDataset] | dict[str, dict[str, LazyDataset]], ): """Fit is used to capture information about the table or column the metric will process. It should be used to store information such as column value names, which is common among different executions of the view.""" raise NotImplementedError()
[docs] def preprocess( self, wrk: dict[str, LazyDataset] | dict[str, dict[str, LazyDataset]], ref: dict[str, LazyDataset] | dict[str, dict[str, LazyDataset]], ) -> _INGEST | None: """Preprocess is called to cache the summaries for the wrk and ref sets during ingest. Implementation is optional.""" ...
[docs] def process( self, wrk: dict[str, LazyDataset] | dict[str, dict[str, LazyDataset]], ref: dict[str, LazyDataset] | dict[str, dict[str, LazyDataset]], syn: dict[str, LazyDataset] | dict[str, dict[str, LazyDataset]], pre: _INGEST, ) -> _SUMMARY: """Process is called with each set of data from the view (reference, work, synthetic). It should capture data relevant to each metric but in a synopsis or compressed form, that can be used to compute the metric for different algorithm/split combinations. If `preprocess()` is implemented, `pre` will contain the results of the function.""" raise NotImplementedError()
[docs] def visualise(self, data: dict[str, _SUMMARY]): """Visualise is called for dicts of runs that run within the same view. It is expected to create detailed visualizations (such as tables, figures) which utilize the structure of the view (columns etc.). `comparison` is set to False when the method is run when executing a run and to true when run to compare multiple runs. It can be used to provide different summaries If required by the visualization, `wrk_set` and `ref_set` provide the names of the synthesis source data (wrk) and reference data (ref) which can be used as a reference.""" ...
[docs] def summarize(self, data: dict[str, _SUMMARY]): """Summarize is called for dicts of runs that are not necessarily from the same view. It is expected to create detailed summary metrics for the run which are dataset structure independent (such as avg KL, etc). `comparison` is set to False when the method is run when executing a run and to true when run to compare multiple runs. It can be used to provide different summaries""" ...
[docs] def unique_name(self) -> str: """Provides a unique name for the metric which will be used for the system. (currently saving artifacts).""" return self.name
[docs] class Summaries(Generic[A]): wrk: A ref: A syn: A def __init__(self, wrk: A, ref: A, syn: A | None = None) -> None: self.wrk = wrk self.ref = ref self.syn = cast(A, syn) # Skip lint check for syn being None
[docs] def replace(self, **kwargs): params = {"wrk": self.wrk, "ref": self.ref, "syn": self.syn} params.update(kwargs) return type(self)(**params)
[docs] class RefColumnData(TypedDict): data: pd.Series | pd.DataFrame ref: pd.Series | pd.DataFrame
[docs] class SeqColumnData(TypedDict): data: pd.Series | pd.DataFrame ref: dict[str, pd.DataFrame] ids: pd.DataFrame seq: pd.Series
[docs] class AbstractColumnMetric(ModuleClass, Generic[_DATA, _INGEST, _SUMMARY]): type: Literal["col", "ref", "seq"] = "col" _factory = ColumnMetricFactory
[docs] def fit(self, table: str, col: str | tuple[str, ...], data: _DATA): """Fit is used to capture information about the table or column the metric will process. It should be used to store information such as column value names, which is common among different executions of the view.""" raise NotImplementedError()
[docs] def reduce(self, other: "AbstractColumnMetric"): ...
[docs] def preprocess(self, wrk: _DATA, ref: _DATA) -> _INGEST | None: """Preprocess is called to cache the summaries for the wrk and ref sets during ingest. Implementation is optional.""" ...
[docs] def process(self, wrk: _DATA, ref: _DATA, syn: _DATA, pre: _INGEST) -> _SUMMARY: raise NotImplementedError()
[docs] def combine(self, summaries: list[_SUMMARY]) -> _SUMMARY: raise NotImplementedError()
[docs] def visualise(self, data: dict[str, _SUMMARY]): ...
[docs] def summarize(self, data: dict[str, _SUMMARY]): ...
[docs] class ColumnMetric( AbstractColumnMetric[pd.Series | pd.DataFrame, _INGEST, _SUMMARY], Generic[_INGEST, _SUMMARY], ): pass
[docs] class RefColumnMetric( AbstractColumnMetric[RefColumnData, _INGEST, _SUMMARY], Generic[_INGEST, _SUMMARY], ): pass
[docs] class SeqColumnMetric( AbstractColumnMetric[SeqColumnData, _INGEST, _SUMMARY], Generic[_INGEST, _SUMMARY], ):
[docs] def fit( self, table: str, col: str | tuple[str, ...], seq_val: SeqValue | None, data: SeqColumnData, ): raise NotImplementedError()
B = TypeVar("B", bound="Any") def _reduce_inner_2d( a: dict[str | tuple[str, ...], list[B]], b: dict[str | tuple[str, ...], list[B]], ): for key in a.keys(): for i in range(len(a[key])): a[key][i].reduce(b[key][i]) return a def _get_sequence( name: str, meta: Metadata, trn: SeqTransformer, ids: pd.DataFrame, table: pd.DataFrame, get_parent: Callable[[str], pd.DataFrame], ) -> pd.Series | None: seq_name = meta[name].sequencer assert seq_name col = meta[name].cols[seq_name] ref_cols = _calc_unjoined_refs(name, get_parent, col.ref, table) res = trn.transform(table[seq_name], ref_cols, ids) assert len(res) == 3 _, _, seq = res return seq def _fit_column_metrics( name: str, meta: Metadata, ref: ReferenceManager, trn: SeqTransformer | None, tables: dict[str, LazyChunk], metrics: dict[str, list[ColumnMetricFactory]], ): get_table = lazy_load_tables(tables) table = get_table(name) seq_val = None seq = None if ref.table_has_reference(): ids = ref.find_foreign_ids(name, get_table) if len(table.index.symmetric_difference(ids.index)): old_len = len(table) table = table.reindex(ids.index) logger.warn( f"There are missing ids for rows in {name}, dropping {old_len-len(table)}/{old_len} rows with missing ids." ) if trn is not None: seq_val = trn.get_seq_value() seq = _get_sequence(name, meta, trn, ids, table, get_table) else: ids = None out: dict[str | tuple[str, ...], list[AbstractColumnMetric]] = defaultdict(list) for col_name, col in meta[name].cols.items(): if col.is_id() or col.type not in metrics: continue for factory in metrics[col.type]: # Create metric if "main_param" in col.args: m = factory.build(col.args["main_param"], **col.args) else: m = factory.build(**col.args) if isinstance(m, ColumnMetric): m.fit(name, col_name, table[col_name]) elif isinstance(m, RefColumnMetric): cref = col.ref ref_col = _calc_joined_refs(name, get_table, ids, cref, table) if cref else None m.fit( name, col_name, RefColumnData(data=table[col_name], ref=ref_col), # type: ignore ) elif isinstance(m, SeqColumnMetric): ref_col = _calc_unjoined_refs(name, get_table, col.ref, table) assert ids is not None and seq is not None m.fit( name, col_name, seq_val, SeqColumnData(data=table[col_name], ref=ref_col, ids=ids, seq=seq), ) else: assert False, f"Unknown column metric type: {type(m)}" out[col_name].append(m) return out def _preprocess_metrics( name: str, meta: Metadata, ref: ReferenceManager, trn: SeqTransformer | None, tables_wrk: dict[str, LazyChunk], tables_ref: dict[str, LazyChunk], metrics: dict[str | tuple[str, ...], list[AbstractColumnMetric]], ): get_table_wrk = lazy_load_tables(tables_wrk) get_table_ref = lazy_load_tables(tables_ref) table_wrk = get_table_wrk(name) table_ref = get_table_ref(name) seq_wrk = None seq_ref = None if ref.table_has_reference(): ids_wrk = ref.find_foreign_ids(name, get_table_wrk) ids_ref = ref.find_foreign_ids(name, get_table_ref) if len(table_wrk.index.symmetric_difference(ids_wrk.index)): old_len = len(table_wrk) table_wrk = table_wrk.reindex(ids_wrk.index) logger.warn( f"There are missing ids for rows in {name}, dropping {old_len-len(table_wrk)}/{old_len} rows with missing ids." ) if len(table_ref.index.symmetric_difference(ids_ref.index)): old_len = len(table_ref) table_ref = table_ref.reindex(ids_ref.index) logger.warn( f"There are missing ids for rows in {name}, dropping {old_len-len(table_ref)}/{old_len} rows with missing ids." ) if trn is not None: seq_wrk = _get_sequence(name, meta, trn, ids_wrk, table_wrk, get_table_wrk) seq_ref = _get_sequence(name, meta, trn, ids_ref, table_ref, get_table_ref) else: ids_wrk = None ids_ref = None out = defaultdict(list) for col_name, ms in metrics.items(): for m in ms: col = meta[name][col_name] cref = col.ref if isinstance(m, ColumnMetric): prec = m.preprocess( table_wrk[col_name], table_ref[col_name], ) elif isinstance(m, RefColumnMetric): prec = m.preprocess( RefColumnData( data=table_wrk[col_name], ref=_calc_joined_refs( name, get_table_wrk, ids_ref, cref, table_wrk ) if cref else None, # type: ignore ), RefColumnData( data=table_ref[col_name], ref=_calc_joined_refs( name, get_table_ref, ids_ref, cref, table_ref ) if cref else None, # type: ignore ), ) elif isinstance(m, SeqColumnMetric): assert ( ids_wrk is not None and seq_wrk is not None and ids_ref is not None and seq_ref is not None ) prec = m.preprocess( SeqColumnData( data=table_wrk[col_name], ref=_calc_unjoined_refs( name, get_table_wrk, col.ref, table_wrk ), ids=ids_wrk, seq=seq_wrk, ), SeqColumnData( data=table_ref[col_name], ref=_calc_unjoined_refs( name, get_table_ref, col.ref, table_ref ), ids=ids_ref, seq=seq_ref, ), ) else: assert False, f"Unknown column metric type: {type(m)}" out[col_name].append(prec) return out def _process_metrics( name: str, meta: Metadata, ref: ReferenceManager, trn: SeqTransformer | None, tables_wrk: dict[str, LazyChunk], tables_ref: dict[str, LazyChunk], tables_syn: dict[str, LazyChunk], metrics: dict[str | tuple[str, ...], list[AbstractColumnMetric]], preprocess: dict[str | tuple[str, ...], list[Any]], ): get_table_wrk = lazy_load_tables(tables_wrk) get_table_ref = lazy_load_tables(tables_ref) get_table_syn = lazy_load_tables(tables_syn) table_wrk = get_table_wrk(name) table_ref = get_table_ref(name) table_syn = get_table_syn(name) seq_wrk = None seq_ref = None seq_syn = None if ref.table_has_reference(): ids_wrk = ref.find_foreign_ids(name, get_table_wrk) ids_ref = ref.find_foreign_ids(name, get_table_ref) ids_syn = ref.find_foreign_ids(name, get_table_syn) if len(table_wrk.index.symmetric_difference(ids_wrk.index)): old_len = len(table_wrk) table_wrk = table_wrk.reindex(ids_wrk.index) logger.warn( f"There are missing ids for rows in {name}, dropping {old_len-len(table_wrk)}/{old_len} rows with missing ids." ) if len(table_ref.index.symmetric_difference(ids_ref.index)): old_len = len(table_ref) table_ref = table_ref.reindex(ids_ref.index) logger.warn( f"There are missing ids for rows in {name}, dropping {old_len-len(table_ref)}/{old_len} rows with missing ids." ) if len(table_syn.index.symmetric_difference(ids_syn.index)): old_len = len(table_syn) table_syn = table_syn.reindex(ids_syn.index) logger.warn( f"There are missing ids for rows in {name}, dropping {old_len-len(table_syn)}/{old_len} rows with missing ids." ) if trn is not None: seq_wrk = _get_sequence(name, meta, trn, ids_wrk, table_wrk, get_table_wrk) seq_ref = _get_sequence(name, meta, trn, ids_ref, table_ref, get_table_ref) seq_syn = _get_sequence(name, meta, trn, ids_syn, table_syn, get_table_syn) else: ids_wrk = None ids_ref = None ids_syn = None out = defaultdict(list) for col_name, ms in metrics.items(): for m, prec in zip(ms, preprocess[col_name]): col = meta[name][col_name] if isinstance(m, ColumnMetric): proc = m.process( get_table_wrk(name)[col_name], get_table_ref(name)[col_name], get_table_syn(name)[col_name], prec, ) elif isinstance(m, RefColumnMetric): proc = m.process( RefColumnData( data=table_wrk[col_name], ref=_calc_joined_refs( name, get_table_wrk, ids_wrk, col.ref, table_wrk ) if col.ref else None, # type: ignore ), RefColumnData( data=table_ref[col_name], ref=_calc_joined_refs( name, get_table_ref, ids_ref, col.ref, table_ref ) if col.ref else None, # type: ignore ), RefColumnData( data=table_syn[col_name], ref=_calc_joined_refs( name, get_table_syn, ids_syn, col.ref, table_syn ) if col.ref else None, # type: ignore ), prec, ) elif isinstance(m, SeqColumnMetric): assert ( ids_wrk is not None and seq_wrk is not None and ids_ref is not None and seq_ref is not None and ids_syn is not None and seq_syn is not None ) proc = m.process( SeqColumnData( data=table_wrk[col_name], ref=_calc_unjoined_refs( name, get_table_wrk, col.ref, table_wrk ), ids=ids_wrk, seq=seq_wrk, ), SeqColumnData( data=table_ref[col_name], ref=_calc_unjoined_refs( name, get_table_ref, col.ref, table_ref ), ids=ids_ref, seq=seq_ref, ), SeqColumnData( data=table_syn[col_name], ref=_calc_unjoined_refs( name, get_table_syn, col.ref, table_syn ), ids=ids_syn, seq=seq_syn, ), prec, ) else: assert False, f"Unknown column metric type: {type(m)}" out[col_name].append(proc) return out
[docs] def name_add_prefix(col: str | tuple[str, ...], suffix: str): if isinstance(col, str): return (col, suffix) return (*col, suffix)
[docs] def name_style_fn(col: str | tuple[str, ...], ext: str | None = None): if not ext: ext = "" if isinstance(col, str): return col + ext return "_".join(col) + ext
[docs] def name_style_title(col: str | tuple[str, ...], title: str | None = None): if isinstance(col, str): if title: return f"{col.capitalize()} {title}" else: return col.capitalize() else: if title: return f"{', '.join([c.capitalize() for c in col])} {title}" else: return ", ".join([c.capitalize() for c in col])
[docs] class ColumnMetricHolder( Metric[ dict[str, list[dict[str | tuple[str, ...], list[Any]]]], dict[str, dict[str | tuple[str, ...], list[Any]]], ] ): name = "cols" encodings = "raw" metrics: dict[str, dict[str | tuple[str, ...], list[AbstractColumnMetric]]] def __init__(self, modules: list[Module]): self.table = "" self.metric_cls = get_module_dict_multiple( ColumnMetricFactory, [*modules, SeqMetricWrapper.get_factory(modules=modules)], ) self.metrics = {}
[docs] def fit( self, meta: Metadata, trns: dict[str, TableTransformer], data: dict[str, LazyFrame], ): per_call = [] per_call_meta = [] self.seqs = { k: v.get_sequencer() for k, v in trns.items() if v.get_sequencer() is not None } # Create fitting tasks for name in meta.tables: ref_mgr = ReferenceManager(meta, name) for tables in LazyFrame.zip_values(data): per_call.append( { "name": name, "meta": meta, "ref": ref_mgr, "trn": self.seqs.get(name, None), "tables": tables, "metrics": self.metric_cls, } ) per_call_meta.append(name) # Process them out = process_in_parallel( _fit_column_metrics, per_call, desc="Fitting column metrics" ) metrics = defaultdict(list) for chunk_metrics, table in zip(out, per_call_meta): metrics[table].append(chunk_metrics) self.metrics = {} for name in piter( meta.tables, desc="Reducing table modules for each table.", leave=False ): self.metrics[name] = reduce(_reduce_inner_2d, metrics[name]) self.meta = meta self.fitted = True
[docs] def preprocess( self, wrk: dict[str, LazyDataset], ref: dict[str, LazyDataset], ) -> dict[str, list[dict[str | tuple[str, ...], list[Any]]]]: per_call = [] per_call_meta = [] # Create preprocess tasks for name in self.meta.tables: ref_mgr = ReferenceManager(self.meta, name) for tables_wrk, tables_ref in LazyDataset.zip_values([wrk, ref]): per_call.append( { "name": name, "meta": self.meta, "ref": ref_mgr, "trn": self.seqs.get(name, None), "tables_wrk": tables_wrk, "tables_ref": tables_ref, "metrics": self.metrics[name], } ) per_call_meta.append(name) out = process_in_parallel( _preprocess_metrics, per_call, desc="Preprocessing column metric synopsis" ) # Fix by partition pre_dict = defaultdict(list) for name, pre in zip(per_call_meta, out): pre_dict[name].append(pre) return pre_dict
[docs] def process( self, wrk: dict[str, LazyDataset], ref: dict[str, LazyDataset], syn: dict[str, LazyDataset], pre: dict[str, list[dict[str | tuple[str, ...], list[Any]]]], ) -> dict[str, dict[str | tuple[str, ...], list[Any]]]: per_call = [] per_call_meta = [] # Create preprocess tasks for name in self.meta.tables: ref_mgr = ReferenceManager(self.meta, name) # FIXME: Syn may not be partitioned the same way as the others # This ideally would need an API change. If partition numbers are # different, there will be a truncation. for i, (tables_wrk, tables_ref, tables_syn) in enumerate(zip( LazyDataset.zip_values(wrk), LazyDataset.zip_values(ref), LazyDataset.zip_values(syn) ) ): per_call.append( { "name": name, "meta": self.meta, "ref": ref_mgr, "trn": self.seqs.get(name, None), "tables_wrk": tables_wrk, "tables_ref": tables_ref, "tables_syn": tables_syn, "metrics": self.metrics[name], "preprocess": pre[name][i], } ) per_call_meta.append(name) out = process_in_parallel( _process_metrics, per_call, desc="Processing column metric synopsis" ) # Fix by partition proc_dict = defaultdict(list) for name, proc in zip(per_call_meta, out): proc_dict[name].append(proc) procs: dict[str, dict[str | tuple[str, ...], list[Any]]] = defaultdict( lambda: defaultdict(list) ) for name, table_metrics in self.metrics.items(): for cols, col_metrics in table_metrics.items(): for i, metric in enumerate(col_metrics): d = [ proc_dict[name][j][cols][i] for j in range(len(proc_dict[name])) ] o = metric.combine(d) procs[name][cols].append(o) return dict(procs)
[docs] def visualise( self, data: dict[str, dict[str, dict[str | tuple[str, ...], list[Any]]]] ): for table, table_metrics in self.metrics.items(): for col_name, col_metrics in table_metrics.items(): for i, metric in enumerate(col_metrics): metric.visualise( {n: d[table][col_name][i] for n, d in data.items()} )
[docs] def summarize( self, data: dict[str, dict[str, dict[str | tuple[str, ...], list[Any]]]] ): for table, table_metrics in self.metrics.items(): for col_name, col_metrics in table_metrics.items(): for i, metric in enumerate(col_metrics): metric.summarize( {n: d[table][col_name][i] for n, d in data.items()} )
[docs] def fit_column_holder( modules: list[Module], metadata: Metadata, trns: dict[str, TableTransformer], data: dict[str, LazyFrame], ): holder = ColumnMetricHolder(modules) holder.fit(meta=metadata, trns=trns, data=data) return holder
[docs] def fit_metric( fs: MetricFactory, metadata: Metadata, encoder: Encoder | dict[str, Encoder], data: dict[str, LazyDataset] | dict[str, dict[str, LazyDataset]], ): module = fs.build(**metadata.metrics.get(fs.name, {})) if isinstance(fs.encodings, list): assert isinstance(encoder, dict) meta = {name: enc.get_metadata() for name, enc in encoder.items() if name in fs.encodings} if "raw" in fs.encodings: meta["raw"] = metadata else: if fs.encodings == "raw": meta = metadata elif fs.encodings == "bst": # FIXME: Should be something meta = None else: assert isinstance(encoder, Encoder) meta = encoder.get_metadata() module.fit(meta=meta, data=data) return module
[docs] class SeqMetricWrapper(SeqColumnMetric): name = "seq" mode: Literal["dual", "single", "notrn"] def __init__( self, modules: list[Module], visual: dict[str, Any] | None = None, seq: dict[str, Any] | None = None, ctx: dict[str, Any] | None = None, seq_col: str | None = None, ctx_to_ref: dict[str, str] | None = None, order: int | None = None, **kwargs, ) -> None: super().__init__(**kwargs) self.seq_col_ref = seq_col self.ctx_to_ref = ctx_to_ref self.order = order self.seq = [] self.ctx = [] self.visual = [] # Load metrics # Use three modes if seq is not None and ctx is not None: self.mode = "dual" elif seq is not None: self.mode = "single" else: self.mode = "notrn" if seq is not None: seq_kwargs = seq.copy() seq_type = seq_kwargs.pop("type") seq_kwargs["nullable"] = True seq_frs = get_module_dict_multiple(ColumnMetricFactory, modules).get( cast(str, seq_type), [] ) self.seq = [f.build(**seq_kwargs) for f in seq_frs] if ctx is not None: ctx_kwargs = ctx.copy() ctx_type = ctx_kwargs.pop("type") ctx_frs = get_module_dict_multiple(ColumnMetricFactory, modules).get( cast(str, ctx_type), [] ) self.ctx = [f.build(**ctx_kwargs) for f in ctx_frs] if visual is not None: visual_kwargs = visual.copy() visual_type = visual_kwargs.pop("type") visual_frs = get_module_dict_multiple(ColumnMetricFactory, modules).get( cast(str, visual_type), [] ) self.visual = [f.build(**visual_kwargs) for f in visual_frs]
[docs] def fit( self, table: str, col: str | tuple[str, ...], seq_val: SeqValue | None, data: SeqColumnData, ): seq = data["seq"] assert ( seq_val is not None and seq is not None ), "Wrapping RefTransformers requires sequenced data, fill in `sequencer` for the table." self.table = table self.col = col self.max_len = cast(int, seq.max()) + 1 self.parent = seq_val.table match self.mode: case "dual": if self.ctx: ctx_in = _wrap_get_data_ctx(self.parent, **data) for c in self.ctx: c.fit( self.table, name_add_prefix(self.col, "ctx"), ctx_in, ) # Data series is all rows where seq > 0 (skip initial) if self.seq: seq_in = _wrap_get_data_seq_dual(self.parent, **data) for c in self.seq: c.fit( self.table, name_add_prefix(self.col, "seq"), seq_in, ) case "single": if self.seq: seq_in = _wrap_get_data_seq_single( self.parent, **data, ctx_to_ref=self.ctx_to_ref ) for c in self.seq: c.fit( self.table, name_add_prefix(self.col, "seq"), seq_in, ) case "notrn": pass for c in self.visual: if isinstance(c, SeqColumnMetric): c.fit(self.table, self.col, seq_val, data) else: c.fit(self.table, self.col, data)
[docs] def preprocess(self, wrk: SeqColumnData, ref: SeqColumnData) -> Any | None: pre_viz = [] for c in self.visual: pre_viz.append(c.preprocess(wrk, ref)) match self.mode: case "dual": pre_ctx = [] if self.ctx: data_wrk = _wrap_get_data_ctx(self.parent, **wrk) data_ref = _wrap_get_data_ctx(self.parent, **ref) for c in self.ctx: pre_ctx.append(c.preprocess(data_wrk, data_ref)) pre_seq = [] if self.seq: data_wrk = _wrap_get_data_seq_dual(self.parent, **wrk) data_ref = _wrap_get_data_seq_dual(self.parent, **ref) for c in self.seq: pre_seq.append(c.preprocess(data_wrk, data_ref)) return (pre_viz, pre_ctx, pre_seq) case "single": pre_seq = [] if self.seq: data_wrk = _wrap_get_data_seq_single( self.parent, **wrk, ctx_to_ref=self.ctx_to_ref ) data_ref = _wrap_get_data_seq_single( self.parent, **ref, ctx_to_ref=self.ctx_to_ref ) for c in self.seq: pre_seq.append(c.preprocess(data_wrk, data_ref)) return (pre_viz, pre_seq) case "notrn": return (pre_viz,) assert False
[docs] def process( self, wrk: SeqColumnData, ref: SeqColumnData, syn: SeqColumnData, pre: Any ) -> Any: proc_viz = [] for c, p in zip(self.visual, pre[0]): proc_viz.append(c.process(wrk, ref, syn, p)) match self.mode: case "dual": proc_ctx = [] if self.ctx: data_wrk = _wrap_get_data_ctx(self.parent, **wrk) data_ref = _wrap_get_data_ctx(self.parent, **ref) data_syn = _wrap_get_data_ctx(self.parent, **syn) for c, p in zip(self.ctx, pre[1]): proc_ctx.append(c.process(data_wrk, data_ref, data_syn, p)) proc_seq = [] if self.seq: data_wrk = _wrap_get_data_seq_dual(self.parent, **wrk) data_ref = _wrap_get_data_seq_dual(self.parent, **ref) data_syn = _wrap_get_data_seq_dual(self.parent, **syn) for c, p in zip(self.seq, pre[2]): proc_seq.append(c.process(data_wrk, data_ref, data_syn, p)) return (proc_viz, proc_ctx, proc_seq) case "single": proc_seq = [] if self.seq: data_wrk = _wrap_get_data_seq_single( self.parent, **wrk, ctx_to_ref=self.ctx_to_ref ) data_ref = _wrap_get_data_seq_single( self.parent, **ref, ctx_to_ref=self.ctx_to_ref ) data_syn = _wrap_get_data_seq_single( self.parent, **syn, ctx_to_ref=self.ctx_to_ref ) for c, p in zip(self.seq, pre[1]): proc_seq.append(c.process(data_wrk, data_ref, data_syn, p)) return (proc_viz, proc_seq) case "notrn": return (proc_viz,) assert False
[docs] def combine(self, summaries: list) -> Any: sum_viz = [] for i, c in enumerate(self.visual): sum_viz.append(c.combine([s[0][i] for s in summaries])) match self.mode: case "dual": sum_ctx = [] for i, c in enumerate(self.ctx): sum_ctx.append(c.combine([s[1][i] for s in summaries])) sum_seq = [] for i, c in enumerate(self.seq): sum_seq.append(c.combine([s[2][i] for s in summaries])) return (sum_viz, sum_ctx, sum_seq) case "single": sum_seq = [] for i, c in enumerate(self.seq): sum_seq.append(c.combine([s[1][i] for s in summaries])) return (sum_viz, sum_seq) case "notrn": return (sum_viz,) assert False
def _distr(self, data: dict[str, Any], fun: str): for i, c in enumerate(self.visual): getattr(c, fun)({k: v[0][i] for k, v in data.items()}) match self.mode: case "dual": for i, c in enumerate(self.ctx): getattr(c, fun)({k: v[1][i] for k, v in data.items()}) for i, c in enumerate(self.seq): getattr(c, fun)({k: v[2][i] for k, v in data.items()}) case "single": for i, c in enumerate(self.seq): getattr(c, fun)({k: v[1][i] for k, v in data.items()}) case "notrn": pass
[docs] def visualise(self, data: dict[str, Any]): return self._distr(data, "visualise")
[docs] def summarize(self, data: dict[str, Any]): return self._distr(data, "summarize")
def _wrap_get_data_ctx( parent: str, data: pd.Series | pd.DataFrame, ref: dict[str, pd.DataFrame], ids: pd.DataFrame, seq: pd.Series, ): ctx_data = ( ids[[parent]] .join(data[seq == 0], how="right") .drop_duplicates(subset=[parent]) .set_index(parent) ) if isinstance(data, pd.Series): ctx_data = ctx_data[next(iter(ctx_data))] ctx_ref = None if ref: ctx_ref = ids.drop_duplicates(subset=[parent]) for name, ref_table in ref.items(): ctx_ref = ctx_ref.join(ref_table, on=name, how="left") ctx_ref = ctx_ref.set_index(parent).drop( columns=[d for d in ids.columns if d != parent] ) if ctx_ref.shape[1] == 1: ctx_ref = ctx_ref[next(iter(ctx_ref))] return {"data": ctx_data, "ref": ctx_ref} def _wrap_get_data_seq_dual( parent: str, data: pd.Series | pd.DataFrame, ref: dict[str, pd.DataFrame], ids: pd.DataFrame, seq: pd.Series, ): # Data series is all rows where seq > 0 (skip initial) ref_df = _backref_cols(ids, seq, data, parent) return {"data": data[seq > 0], "ref": ref_df} def _wrap_get_data_seq_single( parent: str, data: pd.Series | pd.DataFrame, ref: dict[str, pd.DataFrame], ids: pd.DataFrame, seq: pd.Series, ctx_to_ref, ): ref_df = _backref_cols(ids, seq, data, parent) if ref: ctx_ref = ids[seq == 0].drop_duplicates(subset=[parent]) for name, ref_table in ref.items(): ctx_ref = ctx_ref.join(ref_table, on=name, how="left") ctx_ref = ctx_ref.drop(columns=ids.columns) if ctx_ref.shape[1] == 1: ctx_ref = ctx_ref[next(iter(ctx_ref))] if isinstance(ref_df, pd.Series) and isinstance(ctx_ref, pd.Series): ref_df = pd.concat([ctx_ref, ref_df]) elif isinstance(ref_df, pd.DataFrame) and isinstance(ctx_ref, pd.DataFrame): if ctx_to_ref: ctx_ref = ctx_ref.rename(columns=ctx_to_ref) ref_df = pd.concat([ctx_ref, ref_df], axis=0) assert ( ref_df.shape[1] == ctx_ref.shape[1] ), f"Parent columns not joined correctly to reference ones. If they have different names, pass in `ctx_to_ref` with names mapping them to parents" else: assert ( False ), "fixme: mismatched reference column counts. If single column transformer, both should be series, otherwise both should be dataframes" return {"data": data, "ref": ref_df} __all__ = [ "ColumnMetricFactory", "MetricFactory", "Metric", "Summaries", "ColumnMetric", "RefColumnMetric", "SeqColumnMetric", "Metric", ]