# Source code for pasteur.marginal.postprocess

""" This module contains functions for post-processing marginals.
"""

import numpy as np
from .numpy import AttrSelectors, CalculationInfo

ZERO_FILL = 1e-24


def unpack(req: AttrSelectors, mar: np.ndarray, info: CalculationInfo):
    """Undoes the marginal packing done by the marginal calculation algorithm to save space.

    Args:
        req: Iterable of ``(table, attr, sel)`` triples, one per axis of the
            packed marginal. ``sel`` is either a dict mapping a value name to
            an index into ``info.domains[(table, name)]``, or a single index
            into the domain table of the attribute's common value.
        mar: The packed marginal. It is reshaped to one axis per entry of
            ``req`` (the packed domain sizes computed below).
        info: Lookup tables; this function reads ``info.domains``,
            ``info.common`` and ``info.common_names``.

    Returns:
        A new array (allocated with ``np.zeros``) with one axis per selected
        value, i.e. dict selectors contribute one axis per dict entry.
    """
    # Parallel lists with one entry per requested attribute:
    #   packed_doms - size of the attribute's axis in the packed marginal
    #   full_doms   - size of the attribute's axis once unpacked (flattened)
    #   val_doms    - individual domain sizes of the attribute's values
    #   commons     - number of common values (0 for non-dict selectors)
    packed_doms = []
    full_doms = []
    val_doms = []
    commons = []
    for (table, attr, sel) in req:
        packed_dom = 1
        full_dom = 1
        attr_dom = []
        if isinstance(sel, dict):
            common = info.common[(table, attr)]
            for n, h in sel.items():
                v_dom = info.domains[(table, n)][h]
                full_dom *= v_dom
                # Each value contributes only its non-common part to the
                # packed product; the common part is added back once below.
                packed_dom *= v_dom - common
                attr_dom.append(v_dom)
            commons.append(common)
            full_doms.append(full_dom)
            packed_doms.append(packed_dom + common)
            val_doms.append(attr_dom)
        else:
            # Single selector: packed and full sizes are the same, so the
            # axis needs no rearranging later (common recorded as 0).
            commons.append(0)
            v_dom = info.domains[(table, info.common_names[(table, attr)])][sel]
            packed_doms.append(v_dom)
            full_doms.append(v_dom)
            val_doms.append([v_dom])

    # Copy the packed marginal into the leading corner of a zero-filled array
    # that has the full (flattened per attribute) shape.
    out = np.zeros(full_doms)
    slices = tuple(slice(d) for d in packed_doms)
    out[slices] = mar.reshape(packed_doms)
    # Drops only this function's reference to the input array.
    del mar

    # Rearrange, in place and one axis at a time, every axis that has common
    # values. Order matters: each step reads data written by earlier steps.
    for i, (common, val, packed) in enumerate(zip(commons, val_doms, packed_doms)):
        if not common:
            continue

        # Move the entries that follow the common ones (indices
        # common..packed-1) to the tail of axis i. Source and destination can
        # overlap; this relies on NumPy handling overlapping assignment.
        write_slices = tuple(
            slice(-packed + common, None) if j == i else slice(None)
            for j in range(len(packed_doms))
        )
        read_slices = tuple(
            slice(common, packed) if j == i else slice(None)
            for j in range(len(packed_doms))
        )
        out[write_slices] = out[read_slices]

        # Relocate common entry k from index k to index ofs * k, highest k
        # first. Entry 0 is never moved.
        # NOTE(review): the source positions are not cleared after the copy,
        # and ofs is a sum of domain sizes rather than a stride product -
        # confirm both against the packing layout used by the calculation.
        # NOTE(review): when sel has a single entry, ofs is 0 and every
        # iteration writes to index 0 - confirm this case is intended.
        ofs = sum(val[:-1])
        for k in reversed(range(1, common)):
            write_slices = tuple(
                slice(ofs * k, ofs * k + 1) if j == i else slice(None)
                for j in range(len(packed_doms))
            )
            read_slices = tuple(
                slice(k, k + 1) if j == i else slice(None)
                for j in range(len(packed_doms))
            )
            out[write_slices] = out[read_slices]

    # Split every flattened attribute axis into one axis per value.
    expanded_doms = []
    for v in val_doms:
        expanded_doms.extend(v)
    return out.reshape(expanded_doms)
def two_way_normalize(
    req: AttrSelectors,
    mar: np.ndarray,
    info: CalculationInfo,
    zero_fill: float | None = ZERO_FILL,
):
    """Normalize a packed marginal into a two-way table and its two margins.

    The size of the last attribute in ``req`` is computed in packed form and
    ``mar`` is reshaped to ``(-1, size)``, so every other axis is collapsed
    into the rows.

    Args:
        req: Sequence of ``(table, attr, sel)`` triples; only the last entry
            is inspected.
        mar: Marginal to normalize; it is not modified (a float32 copy is
            made).
        info: Lookup tables; ``info.domains`` is always read, and either
            ``info.common`` (dict selector) or ``info.common_names``
            (non-dict selector).
        zero_fill: Constant added to every cell after normalization, or
            ``None`` to skip the addition.

    Returns:
        A tuple ``(joint, row_sums, column_sums)``: the normalized 2-D float32
        table, its sum along axis 1, and its sum along axis 0.
    """
    table, attr, sel = req[-1]

    if not isinstance(sel, dict):
        last_size = info.domains[(table, info.common_names[(table, attr)])][sel]
    else:
        shared = info.common[(table, attr)]
        last_size = 1
        for name, idx in sel.items():
            # Only the non-common part of each value enters the product.
            last_size *= info.domains[(table, name)][idx] - shared
        # The common values are counted once for the whole attribute.
        last_size += shared

    joint = mar.reshape((-1, last_size)).astype("float32")
    joint /= joint.sum()

    if zero_fill is not None:
        # Mutual info turns into NaN without this
        joint += zero_fill

    return joint, joint.sum(axis=1), joint.sum(axis=0)
def normalize(
    req: AttrSelectors,
    mar: np.ndarray,
    info: CalculationInfo,
    zero_fill: float | None = ZERO_FILL,
):
    """Scale a marginal so that its cells sum to one.

    Args:
        req: Not used by this function.
        mar: Marginal to normalize; it is not modified (a float32 copy is
            made).
        info: Not used by this function.
        zero_fill: Constant added to every cell after scaling, or ``None`` to
            skip the addition.

    Returns:
        A float32 array with the same shape as ``mar``.
    """
    probs = mar.astype("float32")
    # In-place division by the grand total.
    np.divide(probs, probs.sum(), out=probs)

    if zero_fill is None:
        return probs

    # Mutual info turns into NaN without this
    probs += zero_fill
    return probs