Source code for pasteur.marginal.numpy

from collections import defaultdict
from typing import Mapping, NamedTuple, Sequence, TypeGuard, cast

import numpy as np
import pandas as pd

from ..attribute import (
    Attributes,
    CatValue,
    Grouping,
    SeqAttributes,
    StratifiedValue,
    get_dtype,
)

# Maps a column name to the height at which that column is read;
# `calc_marginal` uses each (name, height) pair to index `data` and `info.domains`.
ChildSelector = dict[str, int]
# Height at which an attribute's common column is read.
CommonSelector = int
# Identifies a table: a name, a `(name, index)` pair (built by `expand_table`
# for the entries of `SeqAttributes.hist`), or None.
TableSelector = str | tuple[str, int] | None
# Name of an attribute: a single string or a tuple of strings.
AttrName = str | tuple[str, ...]


# One marginal component: (table, attribute, columns/heights to use).
AttrSelector = tuple[TableSelector, AttrName, ChildSelector | CommonSelector]
AttrSelectors = Sequence[AttrSelector]

# Keyed by (table, column name, flag). The flag is False for the plain
# per-height columns and True for their "non-common" variant; the list is
# indexed by height.
CalculationData = dict[tuple[TableSelector, str, bool], list[np.ndarray]]


class CalculationInfo(NamedTuple):
    """Metadata returned by ``expand_table`` and read by ``calc_marginal``."""

    # (table, column name) -> domain size of the column at each height.
    domains: dict[tuple[TableSelector, str], list[int]]
    # (table, attribute name) -> number of common values of the attribute
    # (0 when there are none).
    common: dict[tuple[TableSelector, AttrName | None], int]
    # (table, attribute name) -> name of the attribute's common column.
    common_names: dict[tuple[TableSelector, AttrName], str]
def _calc_common_rec(common: Grouping, col: Grouping):
    """Count the leading plain values that `col` shares with `common`.

    Both groupings are walked in lockstep. Groups present on the `common`
    side must also be groups on the `col` side and are descended into; the
    walk stops at the first position where only `col` has a group.

    Raises:
        AssertionError: if the groupings differ in length, or `common` has a
            group where `col` has a plain value.
    """
    assert len(col) == len(common)

    ofs = 0
    for k, l in zip(common, col):
        if isinstance(k, Grouping):
            assert isinstance(l, Grouping)
            ofs += _calc_common_rec(k, l)
        elif isinstance(l, Grouping):
            return ofs
        else:
            ofs += 1
    return ofs


def _calc_common_seq_rec(seq: Grouping, col: Grouping):
    """Count the leading plain values that `col` shares with `seq`.

    Works like `_calc_common_rec`, but is lenient about mismatches: the walk
    stops (instead of asserting) as soon as the two sides disagree on whether
    a position is a group, or when two groups differ in length.

    Raises:
        AssertionError: if the two top-level groupings differ in length.
    """
    assert len(col) == len(seq)

    ofs = 0
    for k, l in zip(seq, col):
        k_is_group = isinstance(k, Grouping)
        l_is_group = isinstance(l, Grouping)
        if k_is_group and l_is_group:
            if len(k) != len(l):
                return ofs
            ofs += _calc_common_seq_rec(k, l)
        elif k_is_group or l_is_group:
            # Exactly one side is a group: the structures diverge here.
            # (The original tested `l` twice, so a group on the `seq` side
            # facing a plain value was wrongly counted as a shared value.)
            return ofs
        else:
            ofs += 1
    return ofs


def _calc_common(col: CatValue, common: CatValue | None):
    """Return how many leading values `col` shares with `common`.

    Returns 0 when `common` is falsy or when either value is not a
    `StratifiedValue`.
    """
    if not common:
        return 0
    if not isinstance(col, StratifiedValue):
        return 0
    if not isinstance(common, StratifiedValue):
        return 0
    return _calc_common_rec(common.head, col.head)


def _calc_common_seq(col: CatValue, seq: CatValue | None):
    """Return how many leading values `col` shares with the sequence value `seq`.

    Returns 0 when `seq` is falsy or when either value is not a
    `StratifiedValue`.
    """
    if not seq:
        return 0
    if not isinstance(col, StratifiedValue):
        return 0
    if not isinstance(seq, StratifiedValue):
        return 0
    return _calc_common_seq_rec(seq.head, col.head)


def _map_column(table: pd.DataFrame, col: CatValue, common: CatValue | None):
    """Expand one encoded column of `table` into one array per height.

    Args:
        table: Table holding the encoded column `col.name`.
        col: Value describing the column (mappings and domains per height).
        common: The attribute's common value, if any.

    Returns:
        A tuple `(cols, cols_noncommon, domains, common_num)`:
        `cols[h]` is the column mapped to height `h`; `cols_noncommon[h]` is
        the same column with values above `common_num` shifted down by
        `common_num` and all other values set to 0 (only filled when
        `common_num > 0`); `domains[h]` is the domain at height `h`;
        `common_num` is the result of `_calc_common(col, common)`.
    """
    cols = []
    cols_noncommon = []
    domains = []

    common_num = _calc_common(col, common)
    for height in range(col.height):
        domain = col.get_domain(height)
        domains.append(domain)

        # Look up every row's group at this height, then shrink the dtype
        # to what the domain requires.
        col_lvl = col.get_mapping(height)[table[col.name]]
        col_lvl = col_lvl.astype(get_dtype(domain))
        cols.append(col_lvl)

        if common_num > 0:
            non_common = np.where(col_lvl > common_num, col_lvl - common_num, 0)
            cols_noncommon.append(non_common)

    return cols, cols_noncommon, domains, common_num
def expand_table(
    attrs: Mapping[str | None, Attributes | SeqAttributes] = {},
    tables: dict[TableSelector, pd.DataFrame] = {},
    *,
    prealloc: CalculationData | None = None,
) -> tuple[CalculationData, CalculationInfo]:
    """Precompute every column/height combination of the encoded tables.

    For each categorical column found in `tables`, the column is expanded
    into one array per height, plus a "non-common" variant when the column
    shares common values with its attribute.

    Args:
        attrs: Attributes per table name. A `SeqAttributes` entry yields one
            table version per item of its `hist` (selected as
            `(table_name, key)`) and, if present, one for its `attrs`.
        tables: The encoded tables, keyed by table selector.
        prealloc: Optional, already allocated output. When truthy, results
            are copied into its arrays instead of being stored in a new
            dictionary.

    Returns:
        A tuple `(data, info)`. `data` is `prealloc` when it was provided
        (and truthy), otherwise a new dictionary; it maps
        `(table, column, False)` to the per-height arrays and
        `(table, column, True)` to their non-common variants. `info` is a
        `CalculationInfo` holding the per-height domains, the number of
        common values per attribute and the name of each attribute's common
        column.

    A marginal can then be computed with one multiplication and one addition
    per column followed by a `np.bincount()`. For two columns a, b without
    common values, at heights d, e of table t:

    ```
    idx = data[(t, a, False)][d] + domains[(t, a)][d] * data[(t, b, False)][e]
    np.bincount(idx, minlength=domains[(t, a)][d] * domains[(t, b)][e])
    ```

    The original notes for this function report roughly 1.3ms for n=500k
    rows and 6 columns, against 30ms for `np.histogramdd`.
    """
    expanded = {}
    domain_map = {}
    common_counts = {}
    common_cols = {}

    # Attributes are provided per table. The main table is keyed by None,
    # since it may be partially synthesized and its name can therefore also
    # appear among the (optional) historical tables.
    for table_name, table_attrs in attrs.items():
        # Sequential data produce several versions of the same table.
        versions: list[
            tuple[tuple[str, int] | str | None, StratifiedValue | None, Attributes]
        ]
        if isinstance(table_attrs, SeqAttributes):
            assert isinstance(table_name, str)
            versions = []
            for hist_key, hist_attrs in table_attrs.hist.items():
                versions.append(((table_name, hist_key), table_attrs.seq, hist_attrs))
            if table_attrs.attrs is not None:
                versions.append((table_name, None, table_attrs.attrs))
        else:
            versions = [(table_name, None, table_attrs)]

        for table_sel, seq, attr_set in versions:
            table_common = 1024
            for attr in attr_set.values():
                candidates = list(attr.vals.items())
                if attr.common:
                    common_name = attr.common.name
                    candidates.append((common_name, attr.common))
                    common_cols[(table_sel, attr.name)] = common_name

                # 1024 acts as "no column seen yet"; see the check below.
                attr_common = 1024
                for name, col in candidates:
                    if not isinstance(col, CatValue) or name not in tables[table_sel]:
                        continue

                    levels, levels_noncommon, level_domains, shared = _map_column(
                        tables[table_sel], col, attr.common
                    )

                    if prealloc:
                        for height, values in enumerate(levels):
                            prealloc[(table_sel, name, False)][height][:] = values
                        for height, values in enumerate(levels_noncommon):
                            prealloc[(table_sel, name, True)][height][:] = values
                    else:
                        expanded[(table_sel, name, False)] = levels
                        if levels_noncommon:
                            expanded[(table_sel, name, True)] = levels_noncommon

                    domain_map[(table_sel, name)] = level_domains
                    attr_common = min(attr_common, shared)
                    # NOTE(review): table_common is updated here but never
                    # read afterwards.
                    table_common = min(table_common, _calc_common_seq(col, seq))

                if attr_common < 1000:
                    common_counts[(table_sel, attr.name)] = attr_common
                else:
                    common_counts[(table_sel, attr.name)] = 0

    return prealloc or expanded, CalculationInfo(
        domain_map, common_counts, common_cols
    )
def get_domains(attrs: Attributes) -> dict[str, list[int]]:
    """Return the domain of every column of `attrs` at each of its heights.

    The result maps a column name to a list indexed by height. Every value
    of every attribute is treated as a `CatValue`.
    """
    per_column: dict[str, list[int]] = {}
    for attribute in attrs.values():
        for col_name, value in attribute.vals.items():
            cat = cast(CatValue, value)
            per_column[col_name] = [cat.get_domain(h) for h in range(cat.height)]
    return per_column
def calc_marginal(
    data: CalculationData,
    info: CalculationInfo,
    x: AttrSelectors,
    out: np.ndarray | None = None,
):
    """Calculate the marginal (joint histogram) of the attribute selections `x`.

    Every row is turned into a single flat bin index by combining the
    selected columns in mixed radix, and the indices are then counted with
    `np.bincount()`.

    Args:
        data: Expanded columns, keyed by `(table, column, noncommon_flag)`
            and indexed by height.
        info: Domains, common value counts and common column names that
            describe `data`.
        x: Selections `(table, attribute, selector)`. A dict selector maps
            column names to heights; an int selector is the height at which
            the attribute's common column is used.
        out: Optional array to accumulate into.

    Returns:
        The bin counts. When `out` is given, the counts are added to it in
        place and `out` itself is returned.
    """
    # Total number of bins; it also decides the integer dtype of the index.
    dom = 1
    for table, attr, sel in x:
        if isinstance(sel, dict):
            common = info.common[(table, attr)]
            l_dom = 1
            for name, height in sel.items():
                l_dom *= info.domains[(table, name)][height] - common
            # Common values are shared by all columns, so they are counted
            # once for the attribute rather than once per column.
            dom *= l_dom + common
        else:
            dom *= info.domains[(table, info.common_names[(table, attr)])][sel]

    dtype = get_dtype(dom)
    n_rows = len(next(iter(data.values()))[0])

    _sum_nd = np.zeros((n_rows,), dtype=dtype)
    _tmp_nd = np.empty((n_rows,), dtype=dtype)

    mul = 1
    for table, attr, sel in reversed(x):
        if isinstance(sel, dict):
            common = info.common[(table, attr)]
            l_mul = 1
            for i, (name, height) in enumerate(reversed(sel.items())):
                # The first column keeps its full range (including the common
                # values); the remaining ones use the non-common variant.
                noncommon = not (common == 0 or i == 0)
                np.multiply(
                    data[(table, name, noncommon)][height],
                    mul * l_mul,
                    out=_tmp_nd,
                    dtype=dtype,
                )
                np.add(_sum_nd, _tmp_nd, out=_sum_nd, dtype=dtype)
                l_mul *= info.domains[(table, name)][height] - common
            mul *= l_mul + common
        else:
            common_col = info.common_names[(table, attr)]
            np.multiply(
                data[(table, common_col, False)][sel],
                mul,
                out=_tmp_nd,
                dtype=dtype,
            )
            # Accumulate this column as well. Without this addition the
            # scaled values were discarded and the attribute never
            # influenced the bin index.
            np.add(_sum_nd, _tmp_nd, out=_sum_nd, dtype=dtype)
            mul *= info.domains[(table, common_col)][sel]

    counts = np.bincount(_sum_nd, minlength=dom)
    if out is not None:
        out += counts
    else:
        out = counts

    return out