Source code for pasteur.extras.metrics.distr

from __future__ import annotations

import logging
from collections import defaultdict
from functools import reduce
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
from numpy import ndarray
from scipy.special import rel_entr
from scipy.stats import chisquare
from scipy.stats.contingency import association

from pasteur.metric import Summaries
from pasteur.utils import LazyDataset

from ...attribute import Attributes, CatValue, SeqValue, get_dtype
from ...metric import Metric, Summaries
from ...utils import LazyChunk, LazyFrame, data_to_tables
from ...utils.progress import process_in_parallel

if TYPE_CHECKING:
    from ...metadata import Metadata

KL_ZERO_FILL = 1e-24
FONT_SIZE = "13px"

logger = logging.getLogger(__name__)

OneWaySummary = dict[str, ndarray]
TwoWaySummary = dict[tuple[str, str] | tuple[str, str | int, str], ndarray]
DistrSummary = Summaries[dict[str, tuple[OneWaySummary, TwoWaySummary]]]


[docs] def calc_marginal_1way( data: np.ndarray, domain: np.ndarray, x: list[int], zero_fill: float | None = None, ): """Calculates the 1 way marginal of x, returned as a 1D array.""" x_dom = reduce(lambda a, b: a * b, domain[x], 1) dtype = get_dtype(x_dom) idx = np.zeros((len(data)), dtype=dtype) tmp = np.empty((len(data)), dtype=dtype) mul = 1 for col in reversed(x): # idx += mul*data[:, col] np.add( idx, np.multiply(mul, data[:, col], out=tmp, casting="unsafe"), out=idx, ) mul *= domain[col] counts = np.bincount(idx, minlength=x_dom) assert ( len(counts) == x_dom ), f"Overflow error, domain for columns `{x}` is wrong or there is a mistake in encoding." return counts
def _visualise_cs( table: str, domain: dict[str, int], data: dict[str, Summaries[dict[str, np.ndarray]]], ): import mlflow from ...utils.mlflow import color_dataframe, gen_html_table results = {} # Add ref split first zfill = lambda x: (x + 1) / np.sum(x + 1) name = "ref" res = [] split = next(iter(data.values())) for col in domain: wrk, syn = split.wrk, split.ref assert syn is not None chi, p = chisquare(zfill(wrk[col]), zfill(syn[col])) res.append([col, chi, p]) results[name] = pd.DataFrame(res, columns=["col", "X^2", "p"]) for name, split in data.items(): res = [] for col in domain: wrk, syn = split.wrk, split.syn assert syn is not None chi, p = chisquare(zfill(wrk[col]), zfill(syn[col])) res.append([col, chi, p]) results[name] = pd.DataFrame(res, columns=["col", "X^2", "p"]) cs_formatters = { "X^2": {"precision": 3}, "p": {"formatter": lambda x: f"{100*x:.1f}"}, } style = color_dataframe( results, idx=["col"], cols=[], vals=["X^2", "p"], formatters=cs_formatters, split_ref="ref", ) fn = f"distr/cs.html" if table == "table" else f"distr/cs/{table}.html" mlflow.log_text(gen_html_table(style, FONT_SIZE), fn) def _get_histdata(val): if len(val) == 2 and val[0] == "None": v = _get_histdata(val[1]) if v is None: return None return [float("NaN"), *v] out = [] for v in val: if not isinstance(v, str): return None try: out.append(float(v)) continue except ValueError: pass if len(v) < 4: return None if v[0] not in "([": return None if v[-1] not in ")]": return None try: l, r = v[1:-1].split(", ", maxsplit=1) out.append((float(l) + float(r)) / 2) except ValueError: return None return out def _visualise_basetable( table: str, attrs: Attributes, data: dict[str, Summaries[dict[str, np.ndarray]]], ): import re from pasteur.hierarchy import RebalancedValue from ...utils.mlflow import gen_html_table, color_dataframe # Unroll splits ref_split = next(iter(data.values())) splits = { "wrk": ref_split.wrk, "ref": ref_split.ref, } for split, split_data in data.items(): splits[split] = split_data.syn # Handle them individually out_num = [] out_cat = [] CAT_VALS = 5 CAT_MIN_VAL = 0.001 TRE = re.compile(r"\d{2}:\d{2}") # 12:34 MRE = re.compile(r"\+?\d{2}:\d{2}") # +12:34 hvals_prev = {} for attr in attrs.values(): for name, col in attr.vals.items(): if not hasattr(col, "head"): continue for sname, split in splits.items(): bins = _get_histdata(getattr(col, "head")) if bins is None: break bins = np.array(bins) mean = np.nansum(bins * split[name]) / np.nansum(split[name]) std = np.sqrt( np.nansum((bins - mean) ** 2 * split[name]) / (np.nansum(split[name]) - 1) ) out_num.append( { "name": name, "split": sname, "mean": float(mean), "std": float(std), } ) counts = splits["wrk"][name] try: hval = RebalancedValue(counts, col) # type: ignore except Exception: logger.exception(f"Failed to get human values for {name}") hvals_prev[name] = (None, 0) continue height = 0 for h in range(hval.height): height = h dom = hval.get_domain(h) tmp = [0 for _ in range(dom)] for i, j in enumerate(hval.get_mapping(h)): tmp[j] += counts[i] # Missing values merge last, which can make # unecessary merges. Therefore, check the second largest # min is above min val vmins = sorted(tmp)[:2] if dom <= CAT_VALS and vmins[1] > CAT_MIN_VAL: break hvals_prev[name] = (hval, height) vnames = [[] for _ in range(hval.get_domain(height))] hnames = getattr(hval.original, "head").get_human_values() for i, v in enumerate(hval.get_mapping(height)): vnames[v].append(hnames[i]) def process_names(l): if not l: return "[Empty]" if len(l) == 1: return l[0] # Handle intervals if l[0] and l[0][0] in "[(" and l[-1] and l[-1][-1] in ")]": return f"{l[0].split(',')[0]}, {l[-1].split(', ')[-1]}" # Handle times if re.match(TRE, l[0]) and re.match(TRE, l[-1]): return f"{l[0]}-{l[-1]}" # Handle intervals (skip first +) if re.match(MRE, l[0]) and re.match(MRE, l[-1]): return f"{l[0]}-{l[-1][1:]}" # Handle numbers if all(v.isnumeric() for v in l): return f"[{min(l)}, {max(l)}]" return ", ".join([v for v in l if v])[:35] vnames = [process_names(v) for v in vnames] for i, vname in enumerate(vnames): for sname, split in splits.items(): nsum = np.sum(split[name]) mval = 0 for j, v in enumerate(hval.get_mapping(height)): if v == i: mval += split[name][j] rate = mval / nsum out_cat.append( { "name": name, "split": sname, "value": "[missing]" if vname == "None" else vname, "rate": 100 * float(rate), } ) import mlflow stylers = {} if out_num: stylers["Numerical"] = color_dataframe( out_num, idx=["name"], cols=[], vals=["mean", "std"], split_ref="wrk", split_col="split", formatters={"mean": {"precision": 3}, "std": {"precision": 3}}, ) if out_cat: stylers["Categorical"] = color_dataframe( out_cat, idx=["name", "value"], cols=[], vals=["rate"], split_ref="wrk", split_col="split", formatters={"rate": {"precision": 1}}, ) if stylers: fn = ( f"distr/basetable.html" if table == "table" else f"distr/basetable/{table}.html" ) mlflow.log_text(gen_html_table(stylers, FONT_SIZE), fn) def _visualise_kl( table: str, data: dict[str, Summaries[TwoWaySummary]], ): return _visualise_2way(table, data, "kl") ASSOC_METRICS = ["cramer", "tschuprow", "pearson"] METRICS = ["kl", *ASSOC_METRICS] def _visualise_2way( table: str, data: dict[str, Summaries[TwoWaySummary]], metr: str = "kl", domain=None ): import mlflow from ...utils.mlflow import color_dataframe, gen_html_table results = {} presults = {} ref_split = next(iter(data.values())) ref_split = Summaries(ref_split.wrk, ref_split.ref, ref_split.ref) for name, split in { "ref": ref_split, **data, }.items(): wrk, syn = split.wrk, split.syn assert syn res = [] pres = {} for key in syn: if len(key) == 3: col_i, p, col_j = key else: col_i, col_j = key p = None if metr == "kl": zfill = lambda x: (x + KL_ZERO_FILL) / np.sum(x + KL_ZERO_FILL) k = zfill(wrk[key]) j = zfill(syn[key]) kl = rel_entr(k / k.sum(), j).sum() kl_norm = 1 / (1 + kl) out = [col_i, col_j, kl, kl_norm, len(k)] elif metr in ASSOC_METRICS: assert domain if col_i == col_j and not p: continue k = wrk[key] + 1 j = syn[key] + 1 dom_i = domain[table][col_i] m_wrk = association(k.reshape((dom_i, -1)), method=metr) m_syn = association(j.reshape((dom_i, -1)), method=metr) m_res = np.abs(m_wrk - m_syn) out = [col_i, col_j, m_res, m_syn, len(k)] else: assert False, f"Metric {metr} not supported." if p: if p not in pres: pres[p] = [] pres[p].append(out) else: res.append(out) results[name] = pd.DataFrame( res, columns=[ "col_i", "col_j", "metr", "metr_norm", "mlen", ], ) sname = name.replace(" ", "_").replace("=", "_") # mlflow.log_metric(f"{sname}.kl_norm.{table}", results[name]["kl_norm"].mean()) if pres: presults[name] = { k: pd.DataFrame( v, columns=[ "col_i", "col_j", "metr", "metr_norm", "mlen", ], ) for k, v in pres.items() } for k, v in presults[name].items(): corrected = k.replace("-", "o") if k.startswith("-") else k mlflow.log_metric( f"{sname}.metr_norm.{table}.{corrected}", v["metr_norm"].mean(), ) kl_formatters = {"metr_norm": {"precision": 3}} kl_formatters_overall = {"mean_metr_norm": {"precision": 3}} res = {} for split in results: if split not in res: res[split] = [] res[split].append( { "table": "!", "split": split, "mean_metr_norm": results[split]["metr_norm"].mean(), } ) if presults: for p in presults[split]: res[split].append( { "table": p, "split": split, "mean_metr_norm": presults[split][p]["metr_norm"].mean(), } ) # Print results as a table outs = f"{metr.upper():>5s} Table '{table:15s}' results:\n" ores = [] for v in res.values(): ores.extend(v) outs += ( pd.DataFrame(ores) .pivot(index=["table"], columns=["split"], values=["mean_metr_norm"]) .xs("mean_metr_norm", axis=1) .sort_index() .to_markdown() ) outs += "\n" logger.info(outs) for v in results.values(): if v.empty: return res base = color_dataframe( results, idx=["col_j"], cols=["col_i"], vals=["metr_norm"], formatters=kl_formatters, split_ref="ref", ) overall = color_dataframe( {k: pd.DataFrame(v) for k, v in res.items()}, idx=["table"], cols=[], vals=["mean_metr_norm"], formatters=kl_formatters_overall, split_ref="ref", ) dfs = {"overall": overall, "same table": base} if presults: for p in next(iter(presults.values())): dfs[p] = color_dataframe( {k: v[p] for k, v in presults.items()}, idx=["col_i"], cols=["col_j"], vals=["metr_norm"], formatters=kl_formatters, split_ref="ref", ) pref = "" if metr in ASSOC_METRICS: pref = "assoc/" fn = ( f"distr/{pref}{metr}.html" if table == "table" else f"distr/{pref}{metr}/{table}.html" ) mlflow.log_text(gen_html_table(dfs, FONT_SIZE), fn) return res def _process_marginals_chunk( name: str, domain: dict[str, dict[str, int]], parents: dict[str, list[str]], seq: dict[str, SeqValue], ids: dict[str, LazyChunk], tables: dict[str, LazyChunk], ): tids = ids[name]() raw_table = tables[name]() table = raw_table[list(domain[name])].to_numpy(dtype="uint16") table_domain = domain[name] domain_arr = np.array(list(table_domain.values())) ofs = table.shape[1] # One way for CS one_way: dict[str, ndarray] = {} for i, cname in enumerate(table_domain): one_way[cname] = calc_marginal_1way(table, domain_arr, [i], 0) # Two way for KL two_way: dict[tuple[str, str] | tuple[str, str | int, str], ndarray] = {} for i, col_i in enumerate(table_domain): for j, col_j in enumerate(table_domain): two_way[(col_i, col_j)] = calc_marginal_1way(table, domain_arr, [i, j], 0) # Two way accross parents for p in parents[name]: p_table = ( tids[[p]] .join(tables[p](), on=p) .drop(columns=[p])[list(domain[p])] .to_numpy(dtype="uint16") ) p_domain = np.array(list(domain[p].values())) combined = np.concatenate((table, p_table), axis=1) combined_dom = np.concatenate((domain_arr, p_domain)) for i, col_i in enumerate(table_domain): for j, col_j in enumerate(domain[p]): two_way[(col_i, p, col_j)] = calc_marginal_1way( combined, combined_dom, [i, ofs + j], 0 ) _JOIN_NAME = "_id_zdjwk" _IDX_NAME = "_id_lkjijk" if name in seq: sval = seq[name] if sval.order: tseq = raw_table[sval.name] ids_seq = tids.join(tseq, how="right").reset_index(names=_IDX_NAME) for o in range(sval.order): ids_seq_prev = tids.join(tseq + o + 1, how="right").reset_index( names=_JOIN_NAME ) join_ids = ids_seq.merge( ids_seq_prev, on=[*tids.columns, sval.name], how="inner" ).set_index(_IDX_NAME)[[_JOIN_NAME]] ref_df = join_ids.join(raw_table, on=_JOIN_NAME)[ list(domain[name]) ].to_numpy(dtype="uint16") fkey = ( ~pd.isna( ids_seq.set_index(_IDX_NAME)[[]].join(join_ids, how="left") ).to_numpy() ).reshape(-1) combined = np.concatenate((table[fkey], ref_df), axis=1) combined_dom = np.concatenate([domain_arr, domain_arr]) for i, col_i in enumerate(table_domain): for j, col_j in enumerate(table_domain): two_way[(col_i, f"{-o-1}", col_j)] = calc_marginal_1way( combined, combined_dom, [i, ofs + j], 0 ) pass return one_way, two_way
[docs] class DistributionMetric(Metric[DistrSummary, DistrSummary]): name = "distr" encodings = "idx"
[docs] def fit( self, meta: dict[str, Attributes], data: dict[str, LazyFrame], ): self.domain = defaultdict(dict) self.attrs = meta self.parents = { k[:-4]: list(v.sample().columns) for k, v in data.items() if k.endswith("_ids") } self.seq = {} for table, attrs in meta.items(): for attr in attrs.values(): for name, val in attr.vals.items(): if isinstance(val, SeqValue): self.seq[table] = val else: assert isinstance(val, CatValue) self.domain[table][name] = val.domain
[docs] def preprocess( self, wrk: dict[str, LazyDataset], ref: dict[str, LazyDataset], ) -> Summaries[ dict[str, tuple[dict[str, ndarray], dict[tuple[str, str], ndarray]]] ]: per_call = [] per_call_meta = [] base_args = {"domain": self.domain, "parents": self.parents, "seq": self.seq} for cwrk, cref in LazyDataset.zip_values([wrk, ref]): for split, split_data in [("wrk", cwrk), ("ref", cref)]: ids, tables = data_to_tables(split_data) for table in self.domain: per_call.append( { "name": table, "ids": ids, "tables": tables, } ) per_call_meta.append({"split": split, "table": table}) # Process marginals out = process_in_parallel( _process_marginals_chunk, per_call, base_args=base_args, desc="Preprocessing distribution metrics", ) # Intertwine results res = defaultdict(lambda: defaultdict(list)) for meta, hist in zip(per_call_meta, out): res[meta["split"]][meta["table"]].append(hist) ret = defaultdict(dict) for split, split_hists in res.items(): for table, table_hists in split_hists.items(): one_way = {} for key in table_hists[0][0].keys(): one_way[key] = np.sum( [table_hists[i][0][key] for i in range(len(table_hists))], axis=0, ) two_way = {} for key in table_hists[0][1].keys(): two_way[key] = np.sum( [table_hists[i][1][key] for i in range(len(table_hists))], axis=0, ) ret[split][table] = one_way, two_way return Summaries(wrk=ret["wrk"], ref=ret["ref"])
[docs] def process( self, wrk: dict[str, LazyDataset], ref: dict[str, LazyDataset], syn: dict[str, LazyDataset], pre: DistrSummary, ) -> DistrSummary: per_call = [] per_call_meta = [] base_args = {"domain": self.domain, "parents": self.parents, "seq": self.seq} for csyn in LazyDataset.zip_values(syn): ids, tables = data_to_tables(csyn) for table in self.domain: per_call.append( { "name": table, "ids": ids, "tables": tables, } ) per_call_meta.append({"table": table}) # Process marginals out = process_in_parallel( _process_marginals_chunk, per_call, base_args=base_args, desc="Processing distribution metrics", ) # Intertwine results res = defaultdict(list) for meta, hist in zip(per_call_meta, out): res[meta["table"]].append(hist) ret = {} for table, table_hists in res.items(): one_way = {} for key in table_hists[0][0].keys(): one_way[key] = np.sum( [table_hists[i][0][key] for i in range(len(table_hists))], axis=0, ) two_way = {} for key in table_hists[0][1].keys(): two_way[key] = np.sum( [table_hists[i][1][key] for i in range(len(table_hists))], axis=0, ) ret[table] = one_way, two_way return pre.replace(syn=ret)
[docs] def visualise( self, data: dict[ str, DistrSummary, ], ): # import time overall_metr = {} for name in self.domain: # start = time.perf_counter() _visualise_cs( name, self.domain[name], { k: Summaries( wrk=v.wrk[name][0], ref=v.ref[name][0], syn=v.syn[name][0] if v.syn else None, ) for k, v in data.items() }, ) # logger.info(f"cs {name} {time.perf_counter()-start:.2f}s") # start = time.perf_counter() _visualise_basetable( name, self.attrs[name], { k: Summaries( wrk=v.wrk[name][0], ref=v.ref[name][0], syn=v.syn[name][0] if v.syn else None, ) for k, v in data.items() }, ) # logger.info(f"bs {name} {time.perf_counter()-start:.2f}s") for metric in METRICS: if metric not in overall_metr: overall_metr[metric] = {} # start = time.perf_counter() overall_metr[metric][name] = _visualise_2way( name, { k: Summaries( wrk=v.wrk[name][1], ref=v.ref[name][1], syn=v.syn[name][1] if v.syn else None, ) for k, v in data.items() }, metric, domain=self.domain, ) # logger.info(f"2w {metric} {name} {time.perf_counter()-start:.2f}s") from pasteur.utils.styles import use_style import matplotlib.pyplot as plt import mlflow use_style("mlflow") for metr in METRICS: scores = {} scores_per_table = {} for table, table_res in overall_metr[metr].items(): scores_per_table[table] = {} for split, split_res in table_res.items(): if split not in scores: scores[split] = { "intra": [], "seq": [], "hist": [], } if split not in scores_per_table[table]: scores_per_table[table][split] = { "intra": [], "seq": [], "hist": [], } for res in split_res: if res["table"] == "!": scores[split]["intra"].append(res["mean_metr_norm"]) scores_per_table[table][split]["intra"].append( res["mean_metr_norm"] ) elif res["table"].startswith("-"): scores[split]["seq"].append(res["mean_metr_norm"]) scores_per_table[table][split]["seq"].append( res["mean_metr_norm"] ) else: scores[split]["hist"].append(res["mean_metr_norm"]) scores_per_table[table][split]["hist"].append( res["mean_metr_norm"] ) fancy_names = { "intra": "Intra-table", "seq": "Sequential", "hist": "Inter-table", } lines = {} mlflow.log_dict(scores, f"_raw/metrics/distr/{metr}_overall.json") mlflow.log_dict( scores_per_table, f"_raw/metrics/distr/{metr}_overall_per_table.json" ) for table, split_scores_per_table in [ ("_overall_single", scores), ("_overall", scores), *scores_per_table.items(), ]: combined = "_single" in table fig, ax = plt.subplots() bar_width = 0.3 for split, split_scores in split_scores_per_table.items(): for stype, type_scores in split_scores.items(): if stype not in lines: lines[stype] = {} lines[stype][split] = np.mean(type_scores) if type_scores else 0 l_res = 0 split_scores = {} if combined: l_res = len(split_scores_per_table) split_scores = split_scores_per_table for x, y in enumerate(split_scores_per_table.values()): ax.bar( x, np.mean([np.mean(v) for v in y.values()]), ) else: for i, (stype, split_scores) in enumerate(lines.items()): l_res = len(split_scores) x = np.arange(l_res) ax.bar( x + i * bar_width, split_scores.values(), bar_width, label=fancy_names[stype], ) ax.set_xlabel("Experiment") ax.set_ylabel(f"Mean Norm {metr.upper()}") ax.set_title(f"Overall Mean Norm {metr.upper()}") max_len = 0 labels = [k.split(" ") for k in split_scores.keys()] for params in labels: for param in params: max_len = max(max_len, len(param)) ax.set_xticks(np.arange(l_res) + (0 if combined else 0.3)) if max_len > 15 or l_res > 7: tick_labels = [" ".join(l) for l in labels] rot = min(3 * l_res, 90) ax.set_xticklabels(tick_labels) plt.setp( ax.get_xticklabels(), rotation=rot, horizontalalignment="right" ) else: tick_labels = ["\n".join(l) for l in labels] ax.set_xticklabels(tick_labels) if combined: # Dont use legend on combined graph pass elif metr == "kl": # ax.set_ylim([0.55, 1.03]) ax.legend(loc="lower right") elif metr in ASSOC_METRICS: ax.legend(loc="upper right") else: ax.legend(loc="lower right") # elif metr == "chi2": # ax.set_ylim([0.5, 1.03]) plt.tight_layout() pref = "" if metr in ASSOC_METRICS: pref = "assoc/" mlflow.log_figure(fig, f"distr/{pref}{metr}_overall/{table}.png") plt.close()