# Source code for pasteur.marginal.native_py

import numpy as np

from .native import marginal
from .numpy import AttrSelector, AttrSelectors

# One entry of the op list handed to the native ``marginal`` routine:
# ``(multiplier, column)``. The functions in this module build each op by
# pairing an integer (the running product of the domains already placed) with
# one column array.
# NOTE(review): presumably the native code scales the column values by the
# multiplier to form the flat output index -- confirm against `.native`.
Op = tuple[int, np.ndarray]


def calc_marginal(
    cols: dict[str, list[np.ndarray]],
    cols_noncommon: dict[str, list[np.ndarray]],
    domains: dict[str, list[int]],
    x: AttrSelector,
    p: AttrSelectors,
    partial: bool = False,
    out: np.ndarray | None = None,
    simd: bool = True,
):
    """Calculate the marginal between hierarchical attribute ``x`` and its parents ``p``.

    The function turns the selected columns into a list of
    ``(multiplier, column)`` ops, hands them together with a flat output
    buffer to the native ``marginal`` routine, and returns that buffer
    reshaped to ``(x_dom, p_dom)``.

    Args:
        cols: Column arrays, looked up as ``cols[name][index]`` using the
            ``(name, index)`` pairs found in each selector's ``cols`` mapping.
        cols_noncommon: Same layout as ``cols``; used instead of ``cols`` for
            every column after the first of an attribute whose ``common`` is
            non-zero (and for all columns of ``x`` when ``partial`` is set
            and ``x.common`` is non-zero).
        domains: Domain sizes, looked up as ``domains[name][index]``.
        x: Selector of the attribute placed on the first output axis.
        p: Mapping of attribute name to selector for the parents, placed
            together on the second output axis.
        partial: When true, the ``common`` values are left out of the domain
            of ``x`` and of any parent whose name equals ``x.name``.
        out: Optional preallocated buffer; it is flattened before use. When
            omitted, a zeroed ``uint32`` buffer of the full domain is created.
        simd: Forwarded unchanged to the native ``marginal`` routine.

    Returns:
        The filled buffer reshaped to ``(x_dom, p_dom)``.
    """
    ops: list[Op] = []
    stride = 1

    # Parents come first: each parent occupies a contiguous range of
    # multipliers, and ``stride`` accumulates the product of their domains.
    for parent_name, parent in p.items():
        shared = parent.common
        span = 1
        skip_common = partial and parent_name == x.name

        for idx, (name, height) in enumerate(parent.cols.items()):
            # Only the first column of an attribute with common values is
            # read from ``cols``; the rest come from ``cols_noncommon``.
            source = cols if shared == 0 or idx == 0 else cols_noncommon
            ops.append((stride * span, source[name][height]))
            span *= domains[name][height] - shared

        # The common values are counted once per attribute, unless this
        # parent is the same attribute as ``x`` in partial mode.
        if not skip_common:
            span += shared
        stride *= span

    p_dom = stride

    # Then x itself, laid out after all of the parents.
    shared = x.common
    span = 1
    for idx, (name, height) in enumerate(x.cols.items()):
        reads_full = shared == 0 or (idx == 0 and not partial)
        source = cols if reads_full else cols_noncommon
        ops.append((stride * span, source[name][height]))
        span *= domains[name][height] - shared

    if not partial:
        span += shared
    x_dom = span

    total = stride * span
    if out is None:
        flat = np.zeros((total,), dtype=np.uint32)
    else:
        flat = out.reshape((-1,))

    marginal(flat, ops, simd)
    return flat.reshape((x_dom, p_dom))
def calc_marginal_1way(
    cols: dict[str, list[np.ndarray]],
    cols_noncommon: dict[str, list[np.ndarray]],
    domains: dict[str, list[int]],
    x: AttrSelectors,
    out: np.ndarray | None = None,
    simd: bool = True,
):
    """Calculate the 1-way marginal over the selected subsections of attributes ``x``.

    The selectors are walked in reverse order. Each one contributes a
    ``(multiplier, column)`` op per selected column, and the resulting op
    list is handed with the output buffer to the native ``marginal`` routine.

    Args:
        cols: Column arrays, looked up as ``cols[name][index]`` using the
            ``(name, index)`` pairs found in each selector's ``cols`` mapping.
        cols_noncommon: Same layout as ``cols``; used instead of ``cols`` for
            every column after the first of an attribute whose ``common`` is
            non-zero.
        domains: Domain sizes, looked up as ``domains[name][index]``.
        x: Mapping of attribute name to selector.
        out: Optional preallocated buffer, passed to ``marginal`` as given.
            When omitted, a flat zeroed ``uint32`` buffer sized to the
            product of the attribute domains is created.
        simd: Forwarded unchanged to the native ``marginal`` routine.

    Returns:
        The buffer that was passed to ``marginal``.
    """
    ops: list[Op] = []
    stride = 1

    for selector in reversed(x.values()):
        shared = selector.common
        span = 1

        for idx, (name, height) in enumerate(selector.cols.items()):
            # Only the first column of an attribute with common values is
            # read from ``cols``; the rest come from ``cols_noncommon``.
            source = cols if shared == 0 or idx == 0 else cols_noncommon
            ops.append((span * stride, source[name][height]))
            span *= domains[name][height] - shared

        # The common values are counted once per attribute.
        stride *= span + shared

    if out is None:
        out = np.zeros((stride,), dtype=np.uint32)

    marginal(out, ops, simd)
    return out