Source code for pasteur.extras.metrics.distr

from __future__ import annotations

import logging
from collections import defaultdict
from functools import reduce
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
from numpy import ndarray
from scipy.special import rel_entr
from scipy.stats import chisquare

from pasteur.metric import Summaries
from pasteur.utils import LazyDataset

from ...attribute import Attributes, CatValue, SeqValue, get_dtype
from ...metric import Metric, Summaries
from ...utils import LazyChunk, LazyFrame, data_to_tables
from ...utils.progress import process_in_parallel

if TYPE_CHECKING:
    from ...metadata import Metadata

# Small constant `_visualise_kl` adds to every marginal cell (followed by a
# re-normalisation) before computing the relative entropy.
KL_ZERO_FILL = 1e-24
# Font size handed to `gen_html_table` when the result tables are rendered.
FONT_SIZE = "13px"

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def calc_marginal_1way(
    data: np.ndarray,
    domain: np.ndarray,
    x: list[int],
    zero_fill: float | None = None,
):
    """Calculates the 1 way marginal of x, returned as a 1D array.

    The columns listed in `x` are folded into one flat index per row
    (mixed radix, the last column of `x` being the least significant digit).
    The occurrences of every index are counted and the counts are normalised
    so that they sum to 1.

    Args:
        data: 2D array of integer codes, indexed as `data[:, col]`.
        domain: 1D array holding the domain size of every column of `data`.
        x: Indices of the columns that form the marginal.
        zero_fill: When not None, this value is added to every cell of the
            normalised marginal, which is then normalised again. Note that
            `0` counts as "not None".

    Returns:
        A 1D float array with one entry per cell of the joint domain of `x`.

    Raises:
        AssertionError: If an index lands outside the joint domain, i.e. the
            domain does not match the encoded data.
    """
    # Plain Python ints: the product can not wrap around, and the multiplier
    # used below stays a scalar that does not widen or narrow NumPy results.
    x_dom = reduce(lambda a, b: a * int(b), domain[x], 1)
    # NOTE(review): presumably an unsigned dtype wide enough for `x_dom` —
    # confirm against `get_dtype`.
    dtype = get_dtype(x_dom)

    idx = np.zeros((len(data)), dtype=dtype)
    tmp = np.empty((len(data)), dtype=dtype)
    mul = 1
    for col in reversed(x):
        # idx += mul*data[:, col]
        # `dtype` forces the product to be computed in the index dtype.
        # Without it the computation dtype is derived from the inputs alone,
        # so a narrow `data` column could wrap around before the result is
        # written to `tmp` (or the cast into `tmp` could be rejected when the
        # multiplier is a NumPy int64 scalar).
        np.multiply(data[:, col], mul, out=tmp, dtype=dtype)
        np.add(idx, tmp, out=idx)
        mul *= int(domain[col])

    counts = np.bincount(idx, minlength=x_dom)
    assert (
        len(counts) == x_dom
    ), f"Overflow error, domain for columns `{x}` is wrong or there is a mistake in encoding."

    margin = counts.astype("float")
    margin /= margin.sum()
    if zero_fill is not None:
        # Mutual info turns into NaN without this
        margin += zero_fill
        margin /= margin.sum()

    return margin.reshape(-1)
def _visualise_cs(
    table: str,
    domain: dict[str, int],
    data: dict[str, Summaries[dict[str, np.ndarray]]],
):
    """Logs a per-column Chi-square comparison of the splits to mlflow.

    For every split in `data` the one-way marginal of each column in `domain`
    of the working data is compared to the one of the synthetic data with
    `scipy.stats.chisquare`. An additional "ref" entry compares the working
    data of the first split to its reference data. The resulting frames are
    styled with `color_dataframe` and logged as an HTML table.

    Args:
        table: Table name; "table" selects the un-prefixed file name.
        domain: Columns to compare (only the keys are used).
        data: Split name to summaries holding one array per column.
    """
    import mlflow

    from ...utils.mlflow import color_dataframe, gen_html_table

    def zfill(x):
        # Add one to every cell and normalise again, so that no cell is zero
        return (x + 1) / np.sum(x + 1)

    # The "ref" baseline goes first and is taken from the first split; a
    # split that is itself named "ref" replaces it, keeping the first slot.
    first = next(iter(data.values()))
    comparisons = {"ref": (first.wrk, first.ref)}
    comparisons.update((name, (split.wrk, split.syn)) for name, split in data.items())

    results = {}
    for name, (wrk, syn) in comparisons.items():
        res = []
        for col in domain:
            assert syn is not None
            chi, p = chisquare(zfill(wrk[col]), zfill(syn[col]))
            res.append([col, chi, p])
        results[name] = pd.DataFrame(res, columns=["col", "X^2", "p"])

    cs_formatters = {
        "X^2": {"precision": 3},
        "p": {"formatter": lambda x: f"{100*x:.1f}"},
    }
    style = color_dataframe(
        results,
        idx=["col"],
        cols=[],
        vals=["X^2", "p"],
        formatters=cs_formatters,
        split_ref="ref",
    )

    fn = "distr/cs.html" if table == "table" else f"distr/{table}_cs.html"
    mlflow.log_text(gen_html_table(style, FONT_SIZE), fn)


def _visualise_kl(
    table: str, data: dict[str, Summaries[dict[tuple[str, str], np.ndarray]]]
):
    """Logs a pairwise KL comparison of the splits to mlflow.

    For every split in `data` and every column pair, the relative entropy
    between the two-way marginal of the working data and the one of the
    synthetic data is computed and mapped to `1 / (1 + kl)`. An additional
    "ref" entry compares the working data of the first split to its reference
    data. The mean normalised value of every split is logged both as a log
    line and as the mlflow metric `kl_norm.<split>`, and the styled frames
    are logged as an HTML table.

    Args:
        table: Table name; "table" selects the un-prefixed file name.
        data: Split name to summaries holding one array per column pair.
    """
    import mlflow

    from ...utils.mlflow import color_dataframe, gen_html_table

    def zfill(x):
        # Keeps the relative entropy finite when a cell is empty
        return (x + KL_ZERO_FILL) / np.sum(x + KL_ZERO_FILL)

    # The "ref" baseline goes first and is taken from the first split; a
    # split that is itself named "ref" replaces it, keeping the first slot.
    first = next(iter(data.values()))
    comparisons = {"ref": (first.wrk, first.ref)}
    comparisons.update((name, (split.wrk, split.syn)) for name, split in data.items())

    results = {}
    for name, (wrk, syn) in comparisons.items():
        # Rejects both a missing and an empty set of marginals
        assert syn

        res = []
        for key in syn:
            col_i, col_j = key
            k = zfill(wrk[key])
            j = zfill(syn[key])

            kl = rel_entr(k, j).sum()
            kl_norm = 1 / (1 + kl)
            res.append([col_i, col_j, kl, kl_norm, len(k)])

        results[name] = pd.DataFrame(
            res,
            columns=[
                "col_i",
                "col_j",
                "kl",
                "kl_norm",
                "mlen",
            ],
        )

        mean_kl_norm = results[name]["kl_norm"].mean()
        logger.info(f"Split {name} mean norm KL={mean_kl_norm:.5f}.")
        mlflow.log_metric(f"kl_norm.{name}", mean_kl_norm)

    kl_formatters = {"kl_norm": {"precision": 3}}
    style = color_dataframe(
        results,
        idx=["col_j"],
        cols=["col_i"],
        vals=["kl_norm"],
        formatters=kl_formatters,
        split_ref="ref",
    )

    fn = "distr/kl.html" if table == "table" else f"distr/{table}_kl.html"
    mlflow.log_text(gen_html_table(style, FONT_SIZE), fn)


def _process_marginals_chunk(
    name: str,
    expand_parents: bool,
    domain: dict[str, dict[str, int]],
    ids: dict[str, LazyChunk],
    tables: dict[str, LazyChunk],
):
    """Computes the one-way and two-way marginals of one table chunk.

    Args:
        name: Name of the table to process.
        expand_parents: Has to be False, expanding parents is not supported.
        domain: Table name to a mapping of column name to domain size.
        ids: Not used by this function.
        tables: Table name to a callable that loads the chunk.

    Returns:
        A tuple `(one_way, two_way)`: `one_way` maps each column to its
        marginal, `two_way` maps every ordered column pair (including a
        column paired with itself) to the flattened joint marginal.
    """
    assert not expand_parents, "Expanding parents not supported yet"

    table_domain = domain[name]
    columns = list(table_domain)
    # NOTE(review): the uint16 conversion assumes every code fits in 16 bits —
    # confirm against the encoder.
    table = tables[name]()[columns].to_numpy(dtype="uint16")
    domain_arr = np.array(list(table_domain.values()))

    # One way for CS
    one_way: dict[str, ndarray] = {
        col: calc_marginal_1way(table, domain_arr, [i], 0)
        for i, col in enumerate(columns)
    }

    # Two way for KL
    two_way: dict[tuple[str, str], ndarray] = {}
    for i, col_i in enumerate(columns):
        for j, col_j in enumerate(columns):
            two_way[(col_i, col_j)] = calc_marginal_1way(table, domain_arr, [i, j], 0)

    return one_way, two_way
def _merge_chunk_marginals(
    chunk_hists: list[tuple[dict[str, ndarray], dict[tuple[str, str], ndarray]]],
) -> tuple[dict[str, ndarray], dict[tuple[str, str], ndarray]]:
    """Sums the one-way and two-way marginals of all chunks of a table, key by key."""
    one_way = {
        key: np.sum([hist[0][key] for hist in chunk_hists], axis=0)
        for key in chunk_hists[0][0]
    }
    two_way = {
        key: np.sum([hist[1][key] for hist in chunk_hists], axis=0)
        for key in chunk_hists[0][1]
    }
    return one_way, two_way


class DistributionMetric(
    Metric[
        Summaries[dict[str, tuple[dict[str, ndarray], dict[tuple[str, str], ndarray]]]],
        Summaries[dict[str, tuple[dict[str, ndarray], dict[tuple[str, str], ndarray]]]],
    ]
):
    """Metric built on the one-way and two-way marginals of categorical values.

    `fit` records the domain of every categorical value, `preprocess`
    computes the marginals of the working and reference data, `process` adds
    the ones of the synthetic data, and `visualise` logs a Chi-square table
    (one-way) and a KL table (two-way) per table.
    """

    name = "dstr"
    encodings = "idx"

    def fit(
        self,
        meta: dict[str, Attributes],
        data: dict[str, LazyFrame],
    ):
        """Stores the domain of every value of every table in `self.domain`.

        Sequence values are skipped; every other value has to be categorical.
        `data` is not used.
        """
        self.domain = defaultdict(dict)
        for table, attrs in meta.items():
            for attr in attrs.values():
                for name, val in attr.vals.items():
                    if isinstance(val, SeqValue):
                        continue
                    assert isinstance(val, CatValue)
                    self.domain[table][name] = val.domain

    def preprocess(
        self,
        wrk: dict[str, LazyDataset],
        ref: dict[str, LazyDataset],
    ) -> Summaries[
        dict[str, tuple[dict[str, ndarray], dict[tuple[str, str], ndarray]]]
    ]:
        """Computes the marginals of the working and reference data.

        Every chunk of every table in `self.domain` is processed in parallel
        and the per-chunk marginals of a table are summed.

        Returns:
            Summaries whose `wrk` and `ref` map a table name to its
            `(one_way, two_way)` marginals.
        """
        per_call = []
        per_call_meta = []

        for cwrk, cref in LazyDataset.zip_values([wrk, ref]):
            for split, split_data in (("wrk", cwrk), ("ref", cref)):
                ids, tables = data_to_tables(split_data)
                for table in self.domain:
                    per_call.append(
                        {
                            "name": table,
                            "expand_parents": False,
                            "ids": ids,
                            "tables": tables,
                        }
                    )
                    per_call_meta.append((split, table))

        # Process marginals
        out = process_in_parallel(
            _process_marginals_chunk,
            per_call,
            base_args={"domain": self.domain},
            desc="Preprocessing distribution metrics",
        )

        # Intertwine results: group the chunk results by split and table
        grouped = defaultdict(lambda: defaultdict(list))
        for (split, table), hist in zip(per_call_meta, out):
            grouped[split][table].append(hist)

        ret = defaultdict(dict)
        for split, split_hists in grouped.items():
            for table, table_hists in split_hists.items():
                ret[split][table] = _merge_chunk_marginals(table_hists)

        return Summaries(wrk=ret["wrk"], ref=ret["ref"])

    def process(
        self,
        wrk: dict[str, LazyDataset],
        ref: dict[str, LazyDataset],
        syn: dict[str, LazyDataset],
        pre: Summaries[
            dict[str, tuple[dict[str, ndarray], dict[tuple[str, str], ndarray]]]
        ],
    ) -> Summaries[
        dict[str, tuple[dict[str, ndarray], dict[tuple[str, str], ndarray]]]
    ]:
        """Computes the marginals of the synthetic data and attaches them to `pre`.

        Only `syn` and `pre` are read; `wrk` and `ref` are not used here.

        Returns:
            The result of `pre.replace(syn=...)`, where the new `syn` maps a
            table name to its `(one_way, two_way)` marginals.
        """
        per_call = []
        per_call_meta = []

        for csyn in LazyDataset.zip_values(syn):
            ids, tables = data_to_tables(csyn)
            for table in self.domain:
                per_call.append(
                    {
                        "name": table,
                        "expand_parents": False,
                        "ids": ids,
                        "tables": tables,
                    }
                )
                per_call_meta.append(table)

        # Process marginals
        out = process_in_parallel(
            _process_marginals_chunk,
            per_call,
            base_args={"domain": self.domain},
            desc="Processing distribution metrics",
        )

        # Intertwine results: group the chunk results by table
        grouped = defaultdict(list)
        for table, hist in zip(per_call_meta, out):
            grouped[table].append(hist)

        ret = {
            table: _merge_chunk_marginals(table_hists)
            for table, table_hists in grouped.items()
        }

        return pre.replace(syn=ret)

    def visualise(
        self,
        data: dict[
            str,
            Summaries[
                dict[str, tuple[dict[str, ndarray], dict[tuple[str, str], ndarray]]]
            ],
        ],
    ):
        """Logs the Chi-square and KL tables of every table in `self.domain`.

        Args:
            data: Split name to the summaries produced by `process`. The
                one-way marginals (index 0) go to `_visualise_cs`, the
                two-way marginals (index 1) to `_visualise_kl`.
        """
        for name in self.domain:
            _visualise_cs(
                name,
                self.domain[name],
                {
                    k: Summaries(
                        wrk=v.wrk[name][0],
                        ref=v.ref[name][0],
                        syn=v.syn[name][0] if v.syn else None,
                    )
                    for k, v in data.items()
                },
            )
            _visualise_kl(
                name,
                {
                    k: Summaries(
                        wrk=v.wrk[name][1],
                        ref=v.ref[name][1],
                        syn=v.syn[name][1] if v.syn else None,
                    )
                    for k, v in data.items()
                },
            )