Source code for pasteur.marginal.numpy

from typing import NamedTuple, cast

import numpy as np
import pandas as pd

from ..attribute import Attributes, get_dtype, CatValue

# Tiny constant added to every cell of a normalised marginal (see `normalize`
# and `normalize_1way`) so that zero-probability cells do not turn mutual
# information into NaN.
ZERO_FILL: float = 1.0e-24


AttrSelector = NamedTuple(
    "AttrSelector",
    [
        ("name", str),  # name of the attribute being selected
        ("common", int),  # number of common values; codes below this are common
        ("cols", dict[str, int]),  # column name -> height used for that column
    ],
)
AttrSelector.__doc__ = """Selects a subsection of one attribute for marginal calculations.

Fields:
    name: name of the attribute.
    common: number of common values of the attribute; column codes below this
        value are treated as common.
    cols: mapping from column name to the height at which that column is used.
"""

# Mapping from attribute name to its selector.
AttrSelectors = dict[str, AttrSelector]
def expand_table(
    attrs: Attributes,
    table: pd.DataFrame,
    *,
    out_cols: dict[str, list[np.ndarray]] | None = None,
    out_noncommon: dict[str, list[np.ndarray]] | None = None,
) -> tuple[
    dict[str, list[np.ndarray]], dict[str, list[np.ndarray]], dict[str, list[int]]
]:
    """Precompute every (column, height) grouping of the idx-encoded ``table``.

    For every column of every attribute in ``attrs`` that is present in
    ``table``, and for every height of that column, each row's code is looked
    up in that height's mapping. Attributes with common values
    (``attr.common > 0``, e.g. one each for NA/unknown values) also get a
    "non-common" version of every column. In that version, codes at or above
    ``attr.common`` are shifted down by ``attr.common`` and common codes
    become 0.

    Args:
        attrs: the attributes whose columns are expanded.
        table: the raw idx-encoded table; columns missing from it are skipped.
        out_cols: optional preallocated buffers indexed as
            ``out_cols[name][height]``. When given (non-empty), results are
            written into them in place and they are returned.
        out_noncommon: same as ``out_cols``, for the non-common versions.

    Returns:
        cols: ``cols[name][height]`` holds each row's group for column
            <name> at height <height>.
        cols_noncommon: ``cols_noncommon[name][height]`` holds the non-common
            version (an empty list for attributes without common values).
        domains: ``domains[name][height]`` is the domain size of that grouping.

    With these, the marginal of an attribute with ``k`` common values over
    columns a, b, c at heights d, e, f (the same encoding ``calc_marginal``
    uses) is::

        groups = cols[a][d] + (domains[a][d] - k) * (
            cols_noncommon[b][e] + (domains[b][e] - k) * cols_noncommon[c][f]
        )
        np.bincount(
            groups,
            minlength=(domains[a][d] - k)
            * (domains[b][e] - k)
            * (domains[c][f] - k)
            + k,
        )

    Each extra column costs one vector multiplication and one vector addition,
    and ``bincount()`` scales linearly with the dataset size ``n``. For
    n=500k rows and 6 columns this takes about 1.3 ms of wall time, compared
    to about 30 ms for ``np.histogramdd``.
    """
    cols = {}
    cols_noncommon = {}
    domains = {}
    for attr in attrs.values():
        common = attr.common
        for name, col in attr.vals.items():
            if name not in table:
                continue
            col = cast(CatValue, col)

            col_hier = []
            col_noncommon = []
            col_dom = []
            for height in range(col.height):
                domain = col.get_domain(height)
                col_dom.append(domain)

                # Look up every row's code in this height's mapping and store
                # it in the dtype chosen for this height's domain size.
                col_lvl = col.get_mapping(height)[table[name]]
                col_lvl = col_lvl.astype(get_dtype(domain))
                if out_cols:
                    out_cols[name][height][:] = col_lvl
                else:
                    col_hier.append(col_lvl)

                if common > 0:
                    # Non-common offset: codes >= common shift down by
                    # `common`, and common codes collapse to 0.
                    nc = np.where(col_lvl > common, col_lvl - common, 0)
                    if out_noncommon:
                        # Write the offset codes (previously the full codes
                        # were written here by mistake).
                        out_noncommon[name][height][:] = nc
                    else:
                        col_noncommon.append(nc)

            domains[name] = col_dom
            cols[name] = col_hier
            cols_noncommon[name] = col_noncommon

    return out_cols or cols, out_noncommon or cols_noncommon, domains
def get_domains(attrs: Attributes) -> dict[str, list[int]]:
    """Return the domain size of every column of ``attrs`` at every height.

    The result is indexed as ``domains[name][height]``, the same layout as
    the domains returned by ``expand_table``. It covers all columns of
    ``attrs`` and needs no table.
    """
    domains: dict[str, list[int]] = {}
    for attr in attrs.values():
        for name, raw in attr.vals.items():
            value = cast(CatValue, raw)
            domains[name] = [value.get_domain(h) for h in range(value.height)]
    return domains
def calc_marginal(
    cols: dict[str, list[np.ndarray]],
    cols_noncommon: dict[str, list[np.ndarray]],
    domains: dict[str, list[int]],
    x: AttrSelector,
    p: AttrSelectors,
    partial: bool = False,
    out: np.ndarray | None = None,
) -> np.ndarray:
    """Calculates the 1 way and 2 way marginals between the subsection of the
    hierarchical attribute x and the attributes p(arents).

    Every row is mapped to a single flattened group index, and the indices are
    counted with ``np.bincount``. Parents occupy the low-order part of the
    index, in ``p`` iteration order. ``x`` occupies the high-order part. For
    an attribute with ``common > 0``, its first column contributes its full
    code. Each later column contributes its non-common code
    (``cols_noncommon``), scaled by the product of
    ``domains[col][height] - common`` of the preceding columns. The attribute
    spans that product plus ``common`` index values.

    Args:
        cols: ``cols[name][height]`` group codes per row (see ``expand_table``).
        cols_noncommon: non-common versions of ``cols``.
        domains: ``domains[name][height]`` domain sizes.
        x: the selected subsection of the child attribute.
        p: the selected subsections of the parent attributes.
        partial: if True, all of x's columns use their non-common codes, and
            rows whose first x column (at height 0) holds a common code are
            dropped. The returned x dimension then shrinks by ``x.common``.
        out: optional array to accumulate the counts into. Its flattened
            reshape is incremented in place.

    Returns:
        Counts reshaped to ``(x_dom, p_dom)``.
    """
    # Find integer dtype based on domain
    p_dom = 1
    for attr in p.values():
        common = attr.common
        l_dom = 1
        for i, (n, h) in enumerate(attr.cols.items()):
            l_dom *= domains[n][h] - common
        p_dom *= l_dom + common
    x_dom = 1
    for i, (n, h) in enumerate(x.cols.items()):
        x_dom *= domains[n][h] - x.common
    x_dom += x.common
    # NOTE(review): get_dtype presumably returns the smallest integer dtype
    # that can hold values up to its argument; confirm in ..attribute.
    dtype = get_dtype(p_dom * x_dom)
    # Row count. `n` is later reused as the column-name loop variable.
    n = len(next(iter(cols.values()))[0])
    # Preallocated buffers. Every column below is multiplied into _tmp_nd and
    # accumulated into _sum_nd in place, avoiding per-column temporaries.
    _sum_nd = np.zeros((n,), dtype=dtype)
    _tmp_nd = np.zeros((n,), dtype=dtype)
    # Handle parents
    mul = 1
    for attr_name, attr in p.items():
        common = attr.common
        l_mul = 1
        # The parent is the same attribute as x while computing a partial
        # marginal.
        p_partial = partial and attr_name == x.name
        for i, (n, h) in enumerate(attr.cols.items()):
            if common == 0 or i == 0:
                np.multiply(cols[n][h], mul * l_mul, out=_tmp_nd, dtype=dtype)
            else:
                np.multiply(cols_noncommon[n][h], mul * l_mul, out=_tmp_nd, dtype=dtype)
            np.add(_sum_nd, _tmp_nd, out=_sum_nd, dtype=dtype)
            l_mul *= domains[n][h] - common
        # NOTE(review): for p_partial, the stride advances by l_mul only, but
        # p_dom above still counts l_dom + common for this parent, and its
        # first column still uses full codes (range [0, l_mul + common)). The
        # stride used for x (mul) then differs from the p_dom used in the
        # final reshape, which looks like it misaligns the (x_dom, p_dom)
        # layout. Confirm whether this case is reachable and intended.
        if p_partial:
            mul *= l_mul
        else:
            mul *= l_mul + common
    # Handle x
    common = x.common
    for i, (n, h) in enumerate(x.cols.items()):
        if common == 0 or (i == 0 and not partial):
            np.multiply(cols[n][h], mul, out=_tmp_nd, dtype=dtype)
        else:
            np.multiply(cols_noncommon[n][h], mul, out=_tmp_nd, dtype=dtype)
        np.add(_sum_nd, _tmp_nd, out=_sum_nd, dtype=dtype)
        mul *= domains[n][h] - common
    # Keep only non-common items if there is a parent to source the others
    if partial:
        n = next(iter(x.cols))
        # NOTE(review): filters on the height-0 codes of x's first column,
        # regardless of the height selected in x.cols; presumably common
        # codes are identical at every height. Confirm.
        _sum_nd = _sum_nd[cols[n][0] >= common]
        x_dom = x_dom - x.common
    counts = np.bincount(_sum_nd, minlength=p_dom * x_dom)
    if out is not None:
        # NOTE(review): reshape returns a view only when `out` is contiguous.
        # For a non-contiguous `out` this is a copy, and the caller's array is
        # not updated in place.
        out = out.reshape((-1,))
        out += counts
    else:
        out = counts
    return out.reshape((x_dom, p_dom))
def calc_marginal_1way(
    cols: dict[str, list[np.ndarray]],
    cols_noncommon: dict[str, list[np.ndarray]],
    domains: dict[str, list[int]],
    x: AttrSelectors,
    out: np.ndarray | None = None,
):
    """Calculates the 1 way marginal of the subsections of attributes x.

    Every row is mapped to one flattened group index and counted with
    ``np.bincount``. Attributes are laid out with the first attribute of
    ``x`` as the most significant. Within an attribute with common values,
    the first column contributes its full code and later columns their
    non-common codes.

    Args:
        cols: ``cols[name][height]`` group codes per row (see ``expand_table``).
        cols_noncommon: non-common versions of ``cols``.
        domains: ``domains[name][height]`` domain sizes.
        x: the selected subsections of the attributes.
        out: optional array that the counts are added into in place.

    Returns:
        The flat count array (or ``out`` after accumulation).
    """
    # Total flattened domain: for each attribute, the product of its columns'
    # non-common domain sizes, plus its common values.
    dom = 1
    for sel in x.values():
        attr_dom = 1
        for col_name, height in sel.cols.items():
            attr_dom *= domains[col_name][height] - sel.common
        dom *= attr_dom + sel.common

    dtype = get_dtype(dom)
    n_rows = len(next(iter(cols.values()))[0])
    group = np.zeros((n_rows,), dtype=dtype)
    scratch = np.empty((n_rows,), dtype=dtype)

    # Walk the attributes last-to-first so the first one gets the largest stride.
    stride = 1
    for sel in reversed(x.values()):
        common = sel.common
        local = 1
        for idx, (col_name, height) in enumerate(sel.cols.items()):
            source = cols if (common == 0 or idx == 0) else cols_noncommon
            np.multiply(
                source[col_name][height], stride * local, out=scratch, dtype=dtype
            )
            np.add(group, scratch, out=group, dtype=dtype)
            local *= domains[col_name][height] - common
        stride *= local + common

    counts = np.bincount(group, minlength=dom)
    if out is None:
        return counts
    out += counts
    return out
def normalize(counts: np.ndarray, zero_fill: float | None = ZERO_FILL):
    """Convert a 2 way count marginal into probabilities plus its two 1 way sums.

    Args:
        counts: count array whose axis 0 is the child (x) and axis 1 the
            parents, as laid out by ``calc_marginal`` (shape ``(x_dom, p_dom)``).
        zero_fill: constant added to every cell after normalisation, or None
            to skip it. Without it, mutual information turns into NaN on
            zero cells.

    Returns:
        ``(j_mar, x_mar, p_mar)``: the float32 joint distribution (same shape
        as ``counts``), its sum over axis 1 and its sum over axis 0.
    """
    margin = counts.astype("float32")
    total = margin.sum()
    # An all-zero marginal (e.g. from an empty table) would otherwise become
    # 0/0 = NaN in every cell, which zero_fill cannot repair.
    if total != 0:
        margin /= total
    if zero_fill is not None:
        # Mutual info turns into NaN without this
        margin += zero_fill
    x_mar = margin.sum(axis=1)
    p_mar = margin.sum(axis=0)
    return margin, x_mar, p_mar
def normalize_1way(counts: np.ndarray, zero_fill: float | None = ZERO_FILL):
    """Convert a count marginal into a float32 probability array.

    Args:
        counts: count array of any shape (e.g. from ``calc_marginal_1way``).
        zero_fill: constant added to every cell after normalisation, or None
            to skip it. Without it, mutual information turns into NaN on
            zero cells.

    Returns:
        The normalised float32 array, same shape as ``counts``.
    """
    margin = counts.astype("float32")
    total = margin.sum()
    # An all-zero marginal (e.g. from an empty table) would otherwise become
    # 0/0 = NaN in every cell, which zero_fill cannot repair.
    if total != 0:
        margin /= total
    if zero_fill is not None:
        # Mutual info turns into NaN without this
        margin += zero_fill
    return margin