# Source code for pasteur.extras.metrics.visual

import logging
from typing import TYPE_CHECKING, Any, NamedTuple, Sequence, TypeVar, cast

import numpy as np
import pandas as pd
from numpy import ndarray
from pandas.core.frame import DataFrame
from pandas.core.series import Series

from pasteur.attribute import SeqValue
from pasteur.metric import AbstractColumnMetric, SeqColumnData, Summaries

from ...metric import (
    AbstractColumnMetric,
    ColumnMetric,
    RefColumnData,
    RefColumnMetric,
    SeqColumnMetric,
    Summaries,
    name_style_fn,
    name_style_title,
)
from ...utils import list_unique
from ...utils.mlflow import load_matplotlib_style, mlflow_log_hists

if TYPE_CHECKING:
    from matplotlib.figure import Figure

# Module-level logger used by the plotting helpers and metrics in this module.
logger = logging.getLogger(name=__name__)

# Generic type variable.
A = TypeVar("A")


def _percent_formatter(x, pos):
    """Render a fraction as a percentage tick label with one decimal.

    Used as a matplotlib tick formatter; ``pos`` (the tick index) is ignored.
    """
    percent = x * 100
    return f"{percent:.1f}%"


def _gen_hist(
    y_log: bool,
    title: str,
    bins: np.ndarray | Sequence[float],
    heights: dict[str, np.ndarray],
    xticks_x=None,
    xticks_label=None,
):
    """Draw a grouped histogram with one bar series per entry of ``heights``.

    The width of the first bin is split equally between the series. Series ``i``
    is drawn at each bin's left edge plus ``i`` sub-bar widths. Each series is
    normalised to sum to 1 before plotting; all-zero series are drawn as zeros
    rather than NaN. The y axis is formatted as a percentage.

    Args:
        y_log: Use a logarithmic y axis.
        title: Chart title.
        bins: Bin edges; each series should have ``len(bins) - 1`` values.
        heights: Series name -> per-bin values.
        xticks_x: Optional custom tick positions.
        xticks_label: Optional labels for ``xticks_x``.

    Returns:
        The figure, or ``None`` if plotting failed (the error is logged).
    """
    import matplotlib.pyplot as plt

    fig = None
    try:
        fig, ax = plt.subplots()
        edges = np.array(bins)
        x = edges[:-1]
        # Width of the first bin; only needs two edges, so single-bin
        # histograms also work.
        w = (edges[1] - edges[0]) / len(heights)

        for i, (name, h) in enumerate(heights.items()):
            total = h.sum()
            # Avoid 0/0 -> NaN bars for empty series (same rule as ``_gen_bar``).
            normalised = h / total if total > 0 else h
            ax.bar(x + w * i, normalised, width=w, label=name, log=y_log)

        ax.legend()
        ax.set_title(title)
        ax.yaxis.set_major_formatter(_percent_formatter)

        if xticks_x is not None:
            ax.set_xticks(xticks_x, xticks_label)

        plt.tight_layout()
        return fig
    except Exception as e:
        logger.error(f"Failed to generate histogram '{title}' with error:\n{e}")
        if fig is not None:
            # pyplot keeps open figures alive; release the half-built one.
            plt.close(fig)
        return None


def _gen_bar(y_log: bool, title: str, cols: list[str], counts: dict[str, np.ndarray]):
    """Draw a grouped bar chart comparing per-category frequencies.

    Each entry of ``counts`` is one bar series. It is normalised to sum to 1
    when its total is positive and drawn unchanged otherwise. All series share
    a 0.9-wide slot centred on each category tick. Tick labels are rotated
    when there are many categories.

    Returns:
        The created matplotlib figure.
    """
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()

    positions = np.array(range(len(cols)))
    bar_width = 0.9 / len(counts)
    slot_start = positions - 0.45

    for offset, (label, raw) in enumerate(counts.items()):
        total = raw.sum()
        if total > 0:
            share = raw / total
        else:
            share = raw
        ax.bar(
            slot_start + bar_width * offset,
            share,
            width=bar_width,
            align="edge",
            label=label,
            log=y_log,
        )

    plt.xticks(positions, cols)
    # Rotate labels progressively with the number of categories (max 90 deg).
    rotation = min(3 * len(cols), 90)
    if rotation > 10:
        plt.setp(ax.get_xticklabels(), rotation=rotation, horizontalalignment="right")

    ax.legend()
    ax.set_title(title)
    ax.yaxis.set_major_formatter(_percent_formatter)

    plt.tight_layout()
    return fig


class NumericalHist(ColumnMetric[Summaries[np.ndarray], Summaries[np.ndarray]]):
    """Histogram metric for numerical columns.

    Bins are either explicit (``bins`` given as a sequence of edges) or
    ``bins`` equal-width buckets between ``min`` and ``max``. When those are
    not given, they come from the data and are widened across partitions in
    ``reduce``. Summaries hold raw per-bin counts, so partial results from
    several partitions can be summed in ``combine``. Normalisation to
    proportions happens when plotting.
    """

    name = "numerical"

    def __init__(
        self,
        bins: Sequence[float] | int = 10,
        min: float | None = None,
        max: float | None = None,
        y_log: bool = False,
        _from_factory: bool = False,
        **_,
    ) -> None:
        super().__init__(_from_factory=_from_factory)
        self.y_log = y_log

        if isinstance(bins, Sequence):
            # Explicit bin edges: bounds are fixed and never widened by reduce.
            self.manual = True
            self.bins_arg = tuple(bins)
            self.bin_n = len(bins) - 1
            self.min_arg = bins[0]
            self.max_arg = bins[-1]
            assert (
                not min and not max
            ), "Either min,max or specific buckets can be provided."
        else:
            self.manual = False
            self.bins_arg = None
            self.bin_n = bins
            self.min_arg = min
            self.max_arg = max

    def fit(self, table: str, col: str, data: pd.Series):
        self.table = table
        self.col = col

        self.min = self.min_arg if self.min_arg is not None else data.min()
        self.max = self.max_arg if self.max_arg is not None else data.max()

        # An all-NA (or empty) column yields NA bounds (NaN, None or pd.NA);
        # disable the visualiser for it. ``pd.isna`` handles all three, unlike
        # ``np.isnan`` which fails on ``pd.NA``.
        self.disabled = bool(pd.isna(self.min) or pd.isna(self.max))

        if self.bins_arg is not None:
            self.bins = self.bins_arg
        elif self.disabled:
            self.bins = np.array([])
        else:
            self.bins = np.linspace(self.min, self.max, self.bin_n + 1)

    def reduce(self, other: "NumericalHist"):
        if self.manual or other.disabled:
            return
        if self.disabled:
            # This partition had no values: adopt the other partition's range
            # instead of mixing in NaN bounds.
            self.min, self.max = other.min, other.max
            self.disabled = False
        else:
            self.min = min(self.min, other.min)
            self.max = max(self.max, other.max)
        self.bins = np.linspace(self.min, self.max, self.bin_n + 1)

    def _process(self, data: pd.Series) -> np.ndarray:
        if self.disabled:
            return np.array([])
        # Raw counts (not densities): summaries from partitions of different
        # sizes are summed in ``combine`` and must be weighted by size.
        return np.histogram(data.astype(np.float32), self.bins)[0]

    def preprocess(self, wrk: Series, ref: Series):
        return Summaries(self._process(wrk), self._process(ref))

    def process(
        self, wrk: Series, ref: Series, syn: Series, pre: Summaries[np.ndarray]
    ):
        return pre.replace(syn=self._process(syn))

    def combine(self, summaries: list[Summaries[ndarray]]) -> Summaries[ndarray]:
        return Summaries(
            wrk=np.sum([s.wrk for s in summaries], axis=0),
            ref=np.sum([s.ref for s in summaries], axis=0),
            syn=np.sum([s.syn for s in summaries if s.syn is not None], axis=0),
        )

    def visualise(self, data: dict[str, Summaries[np.ndarray]]):
        if self.disabled:
            return

        # wrk/ref are taken from the first split; every split contributes its syn.
        keys = list(data.keys())
        splits = {"wrk": data[keys[0]].wrk, "ref": data[keys[0]].ref}
        for name, split in data.items():
            assert split.syn is not None, f"Received null syn split for split {name}."
            splits[name] = split.syn

        load_matplotlib_style()
        v = _gen_hist(
            self.y_log,
            self.col.capitalize(),
            self.bins,
            splits,
        )
        if v:
            mlflow_log_hists(self.table, self.col, v)
class CategoricalHist(ColumnMetric[Summaries[np.ndarray], Summaries[np.ndarray]]):
    """Bar-chart metric comparing category frequencies across splits.

    ``fit`` orders categories by descending frequency, and ``reduce`` merges
    in categories from other partitions via ``list_unique``. Summaries are
    per-category counts aligned with ``self.cols``.
    """

    name = "categorical"

    def __init__(self, y_log: bool = False, _from_factory: bool = False, **_) -> None:
        super().__init__(_from_factory=_from_factory)
        self.y_log = y_log

    def fit(self, table: str, col: str, data: pd.Series):
        self.table, self.col = table, col
        frequencies = data.value_counts()
        self.cols = list(frequencies.sort_values(ascending=False).index)

    def reduce(self, other: "CategoricalHist"):
        self.cols = list_unique(self.cols, other.cols)

    def _process(self, data: pd.Series):
        frequencies = data.value_counts()
        aligned = frequencies.reindex(self.cols, fill_value=0)
        return aligned.to_numpy()

    def _combine(self, summaries: list[np.ndarray]) -> np.ndarray:
        return np.sum(summaries, axis=0)

    def preprocess(self, wrk: Series, ref: Series):
        return Summaries(self._process(wrk), self._process(ref))

    def process(
        self, wrk: Series, ref: Series, syn: Series, pre: Summaries[np.ndarray]
    ):
        return pre.replace(syn=self._process(syn))

    def combine(self, summaries: list[Summaries[ndarray]]) -> Summaries[ndarray]:
        merged_wrk = self._combine([s.wrk for s in summaries])
        merged_ref = self._combine([s.ref for s in summaries])
        merged_syn = self._combine([s.syn for s in summaries if s.syn is not None])
        return Summaries(wrk=merged_wrk, ref=merged_ref, syn=merged_syn)

    def visualise(self, data: dict[str, Summaries[np.ndarray]]):
        keys = list(data.keys())
        first = data[keys[0]]
        # wrk/ref come from the first split; each split contributes its syn.
        splits = {"wrk": first.wrk, "ref": first.ref}
        for name, split in data.items():
            assert split.syn is not None, f"Received null syn split for split {name}."
            splits[name] = split.syn

        load_matplotlib_style()
        figure = _gen_bar(self.y_log, self.col.capitalize(), self.cols, splits)
        mlflow_log_hists(self.table, self.col, figure)
class OrdinalHist(CategoricalHist):
    """Categorical histogram whose categories are shown in sorted order.

    NA values are excluded from the categories. ``_process`` counts with
    ``value_counts``, which drops NA, so an NA category would always be
    empty; NA mixed with e.g. strings would also make the values unsortable.
    """

    name = "ordinal"

    def fit(self, table: str, col: str, data: pd.Series):
        super().fit(table, col, data)
        self.cols = self._sorted_categories(data.dropna().unique(), table, col)

    def reduce(self, other: "OrdinalHist"):
        # The parent merges categories by first appearance, which breaks the
        # ordering when partitions hold different values; merge and re-sort.
        merged = pd.unique(
            np.concatenate([np.asarray(self.cols), np.asarray(other.cols)])
        )
        self.cols = self._sorted_categories(merged, self.table, self.col)

    @staticmethod
    def _sorted_categories(values, table, col) -> pd.Index:
        """Return ``values`` sorted as an index; log and re-raise if unsortable."""
        try:
            return pd.Index(np.sort(values))
        except Exception:
            logger.error(f"Column '{table}:{col}' has non-sortable values:\n{values}")
            raise
class FixedHist(ColumnMetric[Any, Any]):
    """No-op metric for fixed-value columns.

    Fixed values can not be visualised; registering this metric removes the
    warning. Every hook does nothing, and summaries are empty lists.
    """

    name = "fixed"

    def fit(self, table: str, col: str, data: pd.Series):
        return None

    def reduce(self, other: AbstractColumnMetric):
        return None

    def process(
        self,
        wrk: Series | DataFrame,
        ref: Series | DataFrame,
        syn: Series | DataFrame,
        pre: Any,
    ) -> Any:
        return list()

    def combine(self, summaries: list[Any]) -> Any:
        return list()
[docs] class DateData(NamedTuple): span: np.ndarray | None = None weeks: np.ndarray | None = None days: np.ndarray | None = None na: np.ndarray | None = None
[docs] class DateHist(RefColumnMetric[Summaries[DateData], Summaries[DateData]]): name = "date" def __init__( self, span: str = "year", y_log: bool = False, nullable: bool = False, bins: int = 20, max_len: int | None = None, **_, ) -> None: self.span = span.split(".")[0] self.y_log = y_log self.nullable = nullable self.bin_n = bins self.max_len_arg = max_len
[docs] def fit(self, table: str, col: str | tuple[str, ...], data: RefColumnData): ddata = cast(pd.Series, data["data"]) dref = cast(pd.Series | None, data.get("ref", None)) self.table = table self.col = col self.weeks53 = self.span == "year53" if self.weeks53: self.span = "year" if dref is None: self.ref = ddata.min() else: self.ref = None # Find histogram bin edges if self.ref is None: assert dref is not None mask = ~pd.isna(ddata) & ~pd.isna(dref) ddata = ddata[mask] rf_dt = dref[mask].dt else: ddata = ddata[~pd.isna(ddata)] rf_dt = self.ref match self.span: case "year": segs = ddata.dt.year - rf_dt.year case "week": segs = ( (ddata.dt.normalize() - rf_dt.normalize()).dt.days + rf_dt.day_of_week ) // 7 case "day": segs = ( ddata.dt.normalize() - rf_dt.normalize() ).dt.days + rf_dt.day_of_week case _: assert False, f"Span {self.span} not supported by DateHist" segs = segs.astype("int16") if self.max_len_arg is None: self.max_len = float(np.percentile(segs, 90)) else: self.max_len = self.max_len_arg if self.max_len < self.bin_n: self.bin_n = int(self.max_len - 1) self.bins = np.linspace(0, self.max_len, self.bin_n + 1)
[docs] def reduce(self, other: "DateHist"): self.max_len = max(self.max_len, other.max_len) # type: ignore self.bin_n = max(self.bin_n, other.bin_n) if self.max_len < self.bin_n: self.bin_n = int(self.max_len - 1) self.bins = np.linspace(0, self.max_len, self.bin_n + 1)
def _process(self, data: pd.Series, ref: pd.Series | None = None) -> DateData: assert self.ref is not None or ref is not None # Based on date transformer if self.ref is None: assert ref is not None mask = ~pd.isna(data) & ~pd.isna(ref) data = data[mask] rf_dt = ref[mask].dt else: mask = ~pd.isna(data) data = data[mask] rf_dt = self.ref iso = data.dt.isocalendar() iso_rf = rf_dt.isocalendar() if self.ref is not None: rf_year = iso_rf.year rf_day = iso_rf.weekday else: rf_year = iso_rf["year"] rf_day = iso_rf["day"] weeks = iso["week"].astype("int16") - 1 days = iso["day"].astype("int16") - 1 # Push week 53 to next year if not self.weeks53: m = weeks == 52 weeks[m] = 0 match self.span: case "year": span = iso["year"] - rf_year if not self.weeks53: span[m] += 1 # type: ignore span = np.histogram(span, bins=self.bins, density=True)[0] case "week": span = ((data.dt.normalize() - rf_dt.normalize()).dt.days + rf_day) // 7 span = np.histogram(span, bins=self.bins, density=True)[0] case "day": span = (data.dt.normalize() - rf_dt.normalize()).dt.days + rf_day span = np.histogram(span, bins=self.bins, density=True)[0] case _: assert False, f"Span {self.span} not supported by DateHist" weeks = ( weeks.value_counts() .reindex(range(53 if self.weeks53 else 52), fill_value=0) .to_numpy() ) days = days.value_counts().reindex(range(7), fill_value=0).to_numpy() na = None if self.nullable: non_na_rate = np.sum(mask) / len(mask) # type: ignore na = np.array([non_na_rate, 1 - non_na_rate]) return DateData(span, weeks, days, na) def _combine(self, summaries: list[DateData]) -> DateData: return DateData( span=np.sum([s.span for s in summaries if s.span is not None], axis=0), weeks=np.sum([s.weeks for s in summaries if s.weeks is not None], axis=0), days=np.sum([s.days for s in summaries if s.days is not None], axis=0), na=np.sum([s.na for s in summaries if s.na is not None], axis=0), )
[docs] def preprocess( self, wrk: RefColumnData, ref: RefColumnData ) -> Summaries[DateData] | None: return Summaries( self._process(wrk["data"], wrk["ref"]), # type: ignore self._process(ref["data"], ref["ref"]), # type: ignore )
[docs] def process( self, wrk: RefColumnData, ref: RefColumnData, syn: RefColumnData, pre: Summaries[DateData], ) -> Summaries[DateData]: return pre.replace(syn=self._process(syn["data"], syn["ref"])) # type: ignore
[docs] def combine(self, summaries: list[Summaries[DateData]]) -> Summaries[DateData]: return Summaries( wrk=self._combine([s.wrk for s in summaries]), ref=self._combine([s.ref for s in summaries]), syn=self._combine([s.syn for s in summaries if s.syn is not None]), )
def _viz_days(self, data: dict[str, DateData]): return _gen_bar( y_log=self.y_log, title=name_style_title(self.col, "Weekday"), cols=[ "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday", ], counts={n: d.days for n, d in data.items() if d.days is not None}, ) def _viz_weeks(self, data: dict[str, DateData]): bins = np.array(range(54 if self.weeks53 else 53)) - 0.5 return _gen_hist( y_log=self.y_log, title=name_style_title(self.col, "Season"), bins=bins, heights={n: d.weeks for n, d in data.items() if d.weeks is not None}, xticks_x=[2, 15, 28, 41], xticks_label=["Winter", "Spring", "Summer", "Autumn"], ) def _viz_binned(self, data: dict[str, DateData]): return _gen_hist( self.y_log, name_style_title(self.col, f"{self.span.capitalize()}s"), self.bins, {n: d.span for n, d in data.items() if d.span is not None}, ) def _viz_na(self, data: dict[str, DateData]): return _gen_bar( self.y_log, name_style_title(self.col, "NA"), ["Val", "NA"], {n: d.na for n, d in data.items() if d.na is not None}, ) def _visualise(self, data: dict[str, Summaries[DateData]]) -> dict[str, "Figure"]: keys = list(data.keys()) splits = {"wrk": data[keys[0]].wrk, "ref": data[keys[0]].ref} for name, split in data.items(): assert split.syn is not None, f"Received null syn split for split {name}." splits[name] = split.syn s = self.span charts = { f"n{s}s": self._viz_binned(splits), "weeks": self._viz_weeks(splits), "days": self._viz_days(splits), } if self.nullable: charts["na"] = self._viz_na(splits) return charts
[docs] def visualise(self, data: dict[str, Summaries[DateData]]): load_matplotlib_style() v = self._visualise(data) mlflow_log_hists(self.table, name_style_fn(self.col), v)
class TimeHist(ColumnMetric[Summaries[np.ndarray], Summaries[np.ndarray]]):
    """Histogram of the time of day, in hour or half-hour segments.

    The last dotted component of ``span`` selects the resolution. ``"hour"``
    gives 24 segments; anything else gives 48 half-hour segments. Summaries
    hold per-segment counts.
    """

    name = "time"

    def __init__(
        self,
        span: str = "halfhour",
        y_log: bool = False,
        _from_factory: bool = False,
        **_,
    ) -> None:
        self.span = span.split(".")[-1]
        self.y_log = y_log
        super().__init__(_from_factory=_from_factory)

    def fit(self, table: str, col: str, data: pd.Series):
        # Segments are fixed, so only the column identity is recorded.
        self.table = table
        self.col = col

    def reduce(self, other: AbstractColumnMetric):
        # Nothing data-dependent to merge.
        return None

    def _n_segments(self) -> int:
        """Number of time-of-day segments for the configured span."""
        return 24 if self.span == "hour" else 48

    def _process(self, data: pd.Series):
        present = data[~pd.isna(data)]
        hour = present.dt.hour
        if self.span == "hour":
            segment = hour
        else:
            # Two segments per hour; minutes 30-59 land in the second one.
            segment = 2 * hour + (present.dt.minute > 29)
        counts = segment.value_counts()
        return counts.reindex(range(self._n_segments()), fill_value=0).to_numpy()

    def _combine(self, summaries: list[np.ndarray]) -> np.ndarray:
        return np.sum(summaries, axis=0)

    def preprocess(self, wrk: Series, ref: Series):
        return Summaries(self._process(wrk), self._process(ref))

    def process(
        self, wrk: Series, ref: Series, syn: Series, pre: Summaries[np.ndarray]
    ):
        return pre.replace(syn=self._process(syn))

    def combine(self, summaries: list[Summaries[ndarray]]) -> Summaries[ndarray]:
        return Summaries(
            wrk=self._combine([s.wrk for s in summaries]),
            ref=self._combine([s.ref for s in summaries]),
            syn=self._combine([s.syn for s in summaries if s.syn is not None]),
        )

    def _visualise(self, data: dict[str, Summaries[np.ndarray]]) -> "Figure | None":
        keys = list(data.keys())
        first = data[keys[0]]
        splits = {"wrk": first.wrk, "ref": first.ref}
        for name, split in data.items():
            assert split.syn is not None, f"Received null syn split for split {name}."
            splits[name] = split.syn

        seg_len = self._n_segments()
        per_hour = seg_len // 24
        # Centre each segment on its integer index.
        bins = np.array(range(seg_len + 1)) - 0.5
        tick_hours = list(range(0, 25, 3))
        tick_x = per_hour * np.array(tick_hours)
        tick_label = [f"{hour:02d}:00" for hour in tick_hours]

        label = self.col if isinstance(self.col, str) else " ".join(self.col)
        return _gen_hist(
            y_log=self.y_log,
            title=f"{label.capitalize()} Time",
            bins=bins,
            heights=splits,
            xticks_x=tick_x,
            xticks_label=tick_label,
        )

    def visualise(self, data: dict[str, Summaries[np.ndarray]]):
        load_matplotlib_style()
        figure = self._visualise(data)
        if figure:
            mlflow_log_hists(self.table, self.col, figure)
class DatetimeData(NamedTuple):
    """Pairs a date summary with a time-of-day summary."""

    # Date component summary.
    date: DateData
    # Time-of-day component summary.
    time: np.ndarray
class DatetimeHist(
    RefColumnMetric[
        tuple[Summaries[DateData], Summaries[ndarray]],
        tuple[Summaries[DateData], Summaries[ndarray]],
    ]
):
    """Composite metric for datetime columns.

    Combines a ``DateHist`` for the date part with a ``TimeHist`` for the time
    of day. Constructor arguments are forwarded to both sub-metrics, and
    summaries are ``(date_summary, time_summary)`` tuples.
    """

    name = "datetime"

    def __init__(self, *args, _from_factory: bool = False, **kwargs) -> None:
        super().__init__(*args, _from_factory=_from_factory, **kwargs)
        self.date = DateHist(*args, _from_factory=_from_factory, **kwargs)
        self.time = TimeHist(*args, _from_factory=_from_factory, **kwargs)

    def fit(self, table: str, col: str, data: RefColumnData):
        self.table = table
        self.col = col
        self.date.fit(table=table, col=col, data=data)
        self.time.fit(table=table, col=col, data=cast(pd.Series, data["data"]))

    def reduce(self, other: "DatetimeHist"):
        # Delegate so the date sub-metric merges its bin range across partitions.
        self.date.reduce(other.date)
        self.time.reduce(other.time)

    def preprocess(
        self, wrk: RefColumnData, ref: RefColumnData
    ) -> tuple[Summaries[DateData], Summaries[ndarray]] | None:
        return (
            self.date.preprocess(wrk, ref),
            self.time.preprocess(wrk["data"], ref["data"]),  # type: ignore
        )

    def process(
        self,
        wrk: RefColumnData,
        ref: RefColumnData,
        syn: RefColumnData,
        pre: tuple[Summaries[DateData], Summaries[ndarray]],
    ) -> tuple[Summaries[DateData], Summaries[ndarray]]:
        return (
            self.date.process(wrk, ref, syn, pre[0]),
            self.time.process(wrk["data"], ref["data"], syn["data"], pre[1]),  # type: ignore
        )

    def combine(
        self, summaries: list[tuple[Summaries[DateData], Summaries[ndarray]]]
    ) -> tuple[Summaries[DateData], Summaries[ndarray]]:
        return (
            self.date.combine([s[0] for s in summaries]),
            self.time.combine([s[1] for s in summaries]),
        )

    def visualise(
        self, data: dict[str, tuple[Summaries[DateData], Summaries[ndarray]]]
    ):
        load_matplotlib_style()
        date_fig = self.date._visualise({n: c[0] for n, c in data.items()})
        time_fig = self.time._visualise({n: c[1] for n, c in data.items()})
        figs = {**date_fig, "time": time_fig}
        # Drop charts that failed to render (helpers return None on error).
        mlflow_log_hists(self.table, self.col, {k: v for k, v in figs.items() if v})
class SeqHist(
    SeqColumnMetric[
        Summaries[np.ndarray],
        Summaries[np.ndarray],
    ]
):
    """Histogram of the highest ``seq`` value per parent row.

    For each parent id (``data["ids"][parent]``), the maximum ``seq`` is
    counted; buckets run from 0 to ``max_len - 1``. The chart title calls this
    "N-1", which presumes ``seq`` is a 0-based position (confirm against the
    sequencer).
    """

    name = "seq"

    def __init__(
        self, y_log=False, max_len: int | None = None, _from_factory: bool = False, **_
    ) -> None:
        super().__init__(_from_factory=_from_factory, **_)
        self.y_log = y_log
        self.max_len_arg = max_len

    def fit(
        self,
        table: str,
        col: str | tuple[str, ...],
        seq_val: SeqValue | None,
        data: SeqColumnData,
    ):
        # A falsy max_len argument (None or 0) falls back to the data.
        self.max_len = self.max_len_arg or int(data["seq"].max() + 1)
        self.table = table
        self.col = col
        assert seq_val is not None
        self.parent = seq_val.table

    def reduce(self, other: "SeqHist"):
        self.max_len = max(self.max_len, other.max_len)

    def _process(self, data: SeqColumnData):
        last_seq = data["seq"].groupby(data["ids"][self.parent]).max()
        counts = last_seq.value_counts()
        return counts.reindex(range(self.max_len), fill_value=0).to_numpy()

    def preprocess(self, wrk: SeqColumnData, ref: SeqColumnData) -> Summaries[ndarray]:
        return Summaries(self._process(wrk), self._process(ref))

    def combine(self, summaries: list[Summaries[ndarray]]) -> Summaries[ndarray]:
        totals_wrk = np.sum([s.wrk for s in summaries], axis=0)
        totals_ref = np.sum([s.ref for s in summaries], axis=0)
        totals_syn = np.sum([s.syn for s in summaries if s.syn is not None], axis=0)
        return Summaries(wrk=totals_wrk, ref=totals_ref, syn=totals_syn)

    def process(
        self,
        wrk: SeqColumnData,
        ref: SeqColumnData,
        syn: SeqColumnData,
        pre: Summaries[ndarray],
    ) -> Summaries[ndarray]:
        return pre.replace(syn=self._process(syn))

    def visualise(self, data: dict[str, Summaries[ndarray]]):
        load_matplotlib_style()
        keys = list(data.keys())
        first = data[keys[0]]
        splits = {"wrk": first.wrk, "ref": first.ref}
        splits.update({name: split.syn for name, split in data.items()})

        edges = np.arange(self.max_len + 1) - 0.5
        figure = _gen_hist(
            self.y_log,
            f"N-1 with parent '{self.parent}'",
            edges,
            splits,
        )
        if figure:
            mlflow_log_hists(self.table, f"_n_per_{self.parent}", figure)