Source code for pasteur.extras.metrics.visual

from typing import Any, NamedTuple, TypeVar, cast

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure
from numpy import ndarray
from pandas.core.frame import DataFrame
from pandas.core.series import Series

from ...metadata import ColumnMeta, Metadata
from ...metric import (
    AbstractColumnMetric,
    ColumnMetric,
    RefColumnData,
    RefColumnMetric,
    Summaries,
)
from ...utils import list_unique
from ...utils.mlflow import load_matplotlib_style, mlflow_log_hists

# Generic type variable.
# NOTE(review): it is not referenced anywhere in this module's visible code —
# confirm no importer relies on it before removing.
A = TypeVar("A")


def _percent_formatter(x, pos):
    return f"{100*x:.1f}%"


def _gen_hist(
    meta: ColumnMeta,
    title: str,
    bins: np.ndarray,
    heights: dict[str, np.ndarray],
    xticks_x=None,
    xticks_label=None,
):
    """Draw grouped histogram bars for several splits over shared bin edges.

    Args:
        meta: column metadata; ``meta.metrics.y_log == True`` enables a log
            y-axis.
        title: axes title.
        bins: bin edges (``len(bins) - 1`` bins); spacing is assumed uniform.
        heights: per-split bin values; each series is normalised to fractions
            of its own total before plotting.
        xticks_x: optional custom x tick positions.
        xticks_label: optional labels for ``xticks_x``.

    Returns:
        The created matplotlib figure.
    """
    fig, ax = plt.subplots()
    x = bins[:-1]
    # Take the width from the edges so a single-bin histogram works too
    # (``x[1]`` would not exist), then split each bin between the series.
    w = (bins[1] - bins[0]) / max(len(heights), 1)

    is_log = meta.metrics.y_log == True  # noqa: E712 - keep exact-True check
    for i, (name, h) in enumerate(heights.items()):
        total = h.sum()
        # An all-zero series stays zero instead of turning into NaN (0 / 0).
        share = h / total if total > 0 else h
        # Edge alignment packs the series inside [edge, edge + bin width];
        # centre alignment would shift every group left by half a bar.
        ax.bar(x + w * i, share, width=w, align="edge", label=name, log=is_log)

    ax.legend()
    ax.set_title(title)
    ax.yaxis.set_major_formatter(_percent_formatter)

    if xticks_x is not None:
        ax.set_xticks(xticks_x, xticks_label)

    plt.tight_layout()
    return fig


def _gen_bar(
    meta: ColumnMeta, title: str, cols: list[str], counts: dict[str, np.ndarray]
):
    """Draw a grouped bar chart of per-category shares for several splits.

    Each entry of ``counts`` becomes one series of bars. Its values are
    normalised to sum to 1 unless they are all zero. The bars of one category
    sit side by side in a 0.9-wide slot centred on that category's tick.

    Returns:
        The created matplotlib figure.
    """
    fig, ax = plt.subplots()

    positions = np.arange(len(cols))
    bar_width = 0.9 / len(counts)
    use_log = meta.metrics.y_log == True  # noqa: E712 - keep exact-True check

    for offset, (label, values) in enumerate(counts.items()):
        total = values.sum()
        shares = values / total if total > 0 else values
        left_edges = positions - 0.45 + bar_width * offset
        ax.bar(
            left_edges,
            shares,
            width=bar_width,
            align="edge",
            label=label,
            log=use_log,
        )

    plt.xticks(positions, cols)
    # Tilt the labels 3 degrees per category (capped at 90), but only once
    # there are enough categories for the tilt to exceed 10 degrees.
    rotation = min(3 * len(cols), 90)
    if rotation > 10:
        plt.setp(
            ax.get_xticklabels(), rotation=rotation, horizontalalignment="right"
        )

    ax.legend()
    ax.set_title(title)
    ax.yaxis.set_major_formatter(_percent_formatter)

    plt.tight_layout()
    return fig


class NumericalHist(ColumnMetric[Summaries[np.ndarray], Summaries[np.ndarray]]):
    """Histogram of a numerical column comparing wrk/ref/syn splits.

    Bin edges are fixed in ``fit``/``reduce`` so that every partition and
    split is binned identically. Summaries hold raw per-bin counts, and the
    plot normalises them.
    """

    name = "numerical"

    def fit(self, table: str, col: str, meta: ColumnMeta, data: pd.Series):
        self.meta = meta
        self.table = table
        self.col = col
        args = meta.args
        metrics = meta.metrics

        # Range precedence: metric override, then column args, then the data.
        if metrics.x_min is not None:
            x_min = metrics.x_min
        elif "min" in args:
            x_min = args["min"]
        else:
            x_min = data.min()

        if metrics.x_max is not None:
            x_max = metrics.x_max
        elif "max" in args:
            x_max = args["max"]
        else:
            x_max = data.max()

        # An all-NA column gives NA bounds (None, NaN or pd.NA). pd.isna
        # handles all of them, unlike np.isnan. There is nothing to bin, so
        # the visualiser is disabled.
        self.disabled = bool(pd.isna(x_min) or pd.isna(x_max))
        if self.disabled:
            return

        main_param = args.get("main_param", None)
        if main_param and isinstance(main_param, int):
            self.bin_n = main_param
        else:
            self.bin_n = args.get("bins", 20)
        self.x_min = x_min
        self.x_max = x_max
        self.bins = np.histogram_bin_edges(
            data, bins=self.bin_n, range=(x_min, x_max)
        )

    def reduce(self, other: "NumericalHist"):
        if other.disabled:
            # An all-NA partition contributes no range.
            return
        if self.disabled:
            # Adopt the binning of the first partition that had data.
            self.disabled = False
            self.bin_n = other.bin_n
            self.x_min = other.x_min
            self.x_max = other.x_max
        else:
            self.x_min = min(self.x_min, other.x_min)
            self.x_max = max(self.x_max, other.x_max)
        # Unlike a bare linspace, histogram_bin_edges widens a zero-width
        # range (constant column), so the edges stay strictly increasing.
        self.bins = np.histogram_bin_edges(
            np.empty(0), bins=self.bin_n, range=(self.x_min, self.x_max)
        )

    def _process(self, data: pd.Series):
        """Return per-bin counts of ``data`` (empty when disabled)."""
        if self.disabled:
            return np.array([])
        # Raw counts rather than densities. Summing partitions in ``combine``
        # then stays size-weighted, and empty partitions add zeros, not NaN.
        return np.histogram(data.astype(np.float32), self.bins)[0]

    def preprocess(self, wrk: Series, ref: Series):
        return Summaries(self._process(wrk), self._process(ref))

    def process(
        self, wrk: Series, ref: Series, syn: Series, pre: Summaries[np.ndarray]
    ):
        return pre.replace(syn=self._process(syn))

    def combine(self, summaries: list[Summaries[ndarray]]) -> Summaries[ndarray]:
        return Summaries(
            wrk=np.sum([s.wrk for s in summaries], axis=0),
            ref=np.sum([s.ref for s in summaries], axis=0),
            syn=np.sum([s.syn for s in summaries if s.syn is not None], axis=0),
        )

    def visualise(self, data: dict[str, Summaries[np.ndarray]]):
        if self.disabled:
            return
        keys = list(data.keys())
        # wrk/ref are identical across entries; take them from the first one.
        splits = {"wrk": data[keys[0]].wrk, "ref": data[keys[0]].ref}
        for name, split in data.items():
            assert split.syn is not None, f"Received null syn split for split {name}."
            splits[name] = split.syn

        load_matplotlib_style()
        v = _gen_hist(
            self.meta,
            self.col.capitalize(),
            self.bins,
            splits,
        )
        mlflow_log_hists(self.table, self.col, v)
class CategoricalHist(ColumnMetric[Summaries[np.ndarray], Summaries[np.ndarray]]):
    """Bar chart of category frequencies for the wrk/ref/syn splits.

    The category list is fixed in ``fit`` (most frequent first) and merged
    across partitions in ``reduce``. Summaries are per-category counts in
    that order.
    """

    name = "categorical"

    def fit(self, table: str, col: str, meta: ColumnMeta, data: pd.Series):
        self.meta, self.table, self.col = meta, table, col
        by_frequency = data.value_counts().sort_values(ascending=False)
        self.cols = list(by_frequency.index)

    def reduce(self, other: "CategoricalHist"):
        self.cols = list_unique(self.cols, other.cols)

    def _process(self, data: pd.Series):
        """Count ``data`` per known category; unseen categories count 0."""
        tally = data.value_counts()
        aligned = tally.reindex(self.cols, fill_value=0)
        return aligned.to_numpy()

    def _combine(self, summaries: list[np.ndarray]) -> np.ndarray:
        return np.sum(summaries, axis=0)

    def preprocess(self, wrk: Series, ref: Series):
        wrk_counts = self._process(wrk)
        ref_counts = self._process(ref)
        return Summaries(wrk_counts, ref_counts)

    def process(
        self, wrk: Series, ref: Series, syn: Series, pre: Summaries[np.ndarray]
    ):
        return pre.replace(syn=self._process(syn))

    def combine(self, summaries: list[Summaries[ndarray]]) -> Summaries[ndarray]:
        wrk_parts = [part.wrk for part in summaries]
        ref_parts = [part.ref for part in summaries]
        syn_parts = [part.syn for part in summaries if part.syn is not None]
        return Summaries(
            wrk=self._combine(wrk_parts),
            ref=self._combine(ref_parts),
            syn=self._combine(syn_parts),
        )

    def visualise(self, data: dict[str, Summaries[np.ndarray]]):
        head = list(data.values())[0]
        splits = {"wrk": head.wrk, "ref": head.ref}
        for name, split in data.items():
            assert split.syn is not None, f"Received null syn split for split {name}."
            splits[name] = split.syn

        load_matplotlib_style()
        figure = _gen_bar(
            self.meta,
            self.col.capitalize(),
            self.cols,
            splits,
        )
        mlflow_log_hists(self.table, self.col, figure)
class OrdinalHist(CategoricalHist):
    """Categorical histogram whose categories are kept in sorted order."""

    name = "ordinal"

    def fit(self, table: str, col: str, meta: ColumnMeta, data: pd.Series):
        super().fit(table, col, meta, data)
        # value_counts (used by _process) drops NA, so an NA category would
        # always plot as an empty bar; only sort the non-NA values.
        self.cols = pd.Index(np.sort(data.dropna().unique()))

    def reduce(self, other: "CategoricalHist"):
        # The inherited reduce appends unseen values at the end, which breaks
        # the ordinal order across partitions; merge, dedupe, then re-sort.
        merged = pd.Index(self.cols).append(pd.Index(other.cols)).unique()
        self.cols = pd.Index(np.sort(merged.to_numpy()))
class FixedHist(ColumnMetric[Any, Any]):
    """No-op metric for fixed-value columns.

    A column with a fixed value has nothing to plot. Registering this metric
    only removes the warning such columns would otherwise raise; every hook
    does nothing.
    """

    name = "fixed"

    def fit(self, table: str, col: str, meta: ColumnMeta, data: pd.Series):
        pass

    def reduce(self, other: AbstractColumnMetric):
        pass

    def process(
        self,
        wrk: Series | DataFrame,
        ref: Series | DataFrame,
        syn: Series | DataFrame,
        pre: Any,
    ) -> Any:
        nothing: list = []
        return nothing

    def combine(self, summaries: list[Any]) -> Any:
        nothing: list = []
        return nothing
[docs]class DateData(NamedTuple): span: np.ndarray | None = None weeks: np.ndarray | None = None days: np.ndarray | None = None na: np.ndarray | None = None
[docs]class DateHist(RefColumnMetric[Summaries[DateData], Summaries[DateData]]): name = "date"
[docs] def fit(self, table: str, col: str | tuple[str], meta: ColumnMeta, data: RefColumnData): ref = data['ref'] data = data['data'] self.table = table self.col = col self.meta = meta if "main_param" in meta.args: self.span = meta.args["main_param"].split(".")[0] elif "span" in meta.args: self.span = meta.args["span"].split(".")[0] else: self.span = "year" self.weeks53 = self.span == "year53" if self.weeks53: self.span = "year" self.max_len = meta.args.get("max_len", None) self.bin_n = meta.args.get("bins", 20) if ref is None: self.ref = data.min() else: self.ref = None # Find histogram bin edges if self.ref is None: assert ref is not None mask = ~pd.isna(data) & ~pd.isna(ref) data = data[mask] rf_dt = ref[mask].dt else: data = data[~pd.isna(data)] rf_dt = self.ref match self.span: case "year": segs = data.dt.year - rf_dt.year case "week": segs = ( (data.dt.normalize() - rf_dt.normalize()).dt.days + rf_dt.day_of_week ) // 7 case "day": segs = ( data.dt.normalize() - rf_dt.normalize() ).dt.days + rf_dt.day_of_week case _: assert False, f"Span {self.span} not supported by DateHist" segs = segs.astype("int16") if self.max_len is None: self.max_len = float(np.percentile(segs, 90)) if self.max_len < self.bin_n: self.bin_n = int(self.max_len - 1) self.bins = np.histogram_bin_edges( segs, bins=self.bin_n, range=(0, self.max_len) ) self.nullable = meta.args.get("nullable", False)
[docs] def reduce(self, other: "DateHist"): self.max_len = max(self.max_len, other.max_len) # type: ignore self.bin_n = max(self.bin_n, other.bin_n) if self.max_len < self.bin_n: self.bin_n = int(self.max_len - 1) self.bins = np.linspace(0, self.max_len, self.bin_n + 1)
def _process(self, data: pd.Series, ref: pd.Series | None = None) -> DateData: assert self.ref is not None or ref is not None # Based on date transformer if self.ref is None: assert ref is not None mask = ~pd.isna(data) & ~pd.isna(ref) data = data[mask] rf_dt = ref[mask].dt else: mask = ~pd.isna(data) data = data[mask] rf_dt = self.ref iso = data.dt.isocalendar() iso_rf = rf_dt.isocalendar() if self.ref is not None: rf_year = iso_rf.year rf_day = iso_rf.weekday else: rf_year = iso_rf["year"] rf_day = iso_rf["day"] weeks = iso["week"].astype("int16") - 1 days = iso["day"].astype("int16") - 1 # Push week 53 to next year if not self.weeks53: m = weeks == 52 weeks[m] = 0 match self.span: case "year": span = iso["year"] - rf_year if not self.weeks53: span[m] += 1 # type: ignore span = np.histogram(span, bins=self.bins, density=True)[0] case "week": span = ((data.dt.normalize() - rf_dt.normalize()).dt.days + rf_day) // 7 span = np.histogram(span, bins=self.bins, density=True)[0] case "day": span = (data.dt.normalize() - rf_dt.normalize()).dt.days + rf_day span = np.histogram(span, bins=self.bins, density=True)[0] case _: assert False, f"Span {self.span} not supported by DateHist" weeks = ( weeks.value_counts() .reindex(range(53 if self.weeks53 else 52), fill_value=0) .to_numpy() ) days = days.value_counts().reindex(range(7), fill_value=0).to_numpy() na = None if self.nullable: non_na_rate = np.sum(mask) / len(mask) # type: ignore na = np.array([non_na_rate, 1 - non_na_rate]) return DateData(span, weeks, days, na) def _combine(self, summaries: list[DateData]) -> DateData: return DateData( span=np.sum([s.span for s in summaries if s.span is not None], axis=0), weeks=np.sum([s.weeks for s in summaries if s.weeks is not None], axis=0), days=np.sum([s.days for s in summaries if s.days is not None], axis=0), na=np.sum([s.na for s in summaries if s.na is not None], axis=0), )
[docs] def preprocess( self, wrk: RefColumnData, ref: RefColumnData ) -> Summaries[DateData] | None: return Summaries( self._process(wrk["data"], wrk["ref"]), # type: ignore self._process(ref["data"], ref["ref"]), # type: ignore )
[docs] def process( self, wrk: RefColumnData, ref: RefColumnData, syn: RefColumnData, pre: Summaries[DateData], ) -> Summaries[DateData]: return pre.replace(syn=self._process(syn["data"], syn["ref"])) # type: ignore
[docs] def combine(self, summaries: list[Summaries[DateData]]) -> Summaries[DateData]: return Summaries( wrk=self._combine([s.wrk for s in summaries]), ref=self._combine([s.ref for s in summaries]), syn=self._combine([s.syn for s in summaries if s.syn is not None]), )
def _viz_days(self, data: dict[str, DateData]): return _gen_bar( meta=self.meta, title=f"{self.col.capitalize()} Weekday", cols=[ "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday", ], counts={n: d.days for n, d in data.items() if d.days is not None}, ) def _viz_weeks(self, data: dict[str, DateData]): bins = np.array(range(54 if self.weeks53 else 53)) - 0.5 return _gen_hist( meta=self.meta, title=f"{self.col.capitalize()} Season", bins=bins, heights={n: d.weeks for n, d in data.items() if d.weeks is not None}, xticks_x=[2, 15, 28, 41], xticks_label=["Winter", "Spring", "Summer", "Autumn"], ) def _viz_binned(self, data: dict[str, DateData]): return _gen_hist( self.meta, f"{self.col.capitalize()} {self.span.capitalize()}s", self.bins, {n: d.span for n, d in data.items() if d.span is not None}, ) def _viz_na(self, data: dict[str, DateData]): return _gen_bar( self.meta, f"{self.col.capitalize()} NA", ["Val", "NA"], {n: d.na for n, d in data.items() if d.na is not None}, ) def _visualise(self, data: dict[str, Summaries[DateData]]) -> dict[str, Figure]: keys = list(data.keys()) splits = {"wrk": data[keys[0]].wrk, "ref": data[keys[0]].ref} for name, split in data.items(): assert split.syn is not None, f"Received null syn split for split {name}." splits[name] = split.syn s = self.span charts = { f"n{s}s": self._viz_binned(splits), "weeks": self._viz_weeks(splits), "days": self._viz_days(splits), } if self.nullable: charts["na"] = self._viz_na(splits) return charts
[docs] def visualise(self, data: dict[str, Summaries[DateData]]): load_matplotlib_style() v = self._visualise(data) mlflow_log_hists(self.table, self.col, v)
class TimeHist(ColumnMetric[Summaries[np.ndarray], Summaries[np.ndarray]]):
    """Time-of-day histogram in hourly or half-hourly segments.

    The segment size comes from the last dotted component of the
    ``main_param`` or ``span`` argument. Any value other than ``"hour"``
    uses half hours (48 segments).
    """

    name = "time"

    def fit(self, table: str, col: str, meta: ColumnMeta, data: pd.Series):
        self.meta = meta
        self.table = table
        self.col = col
        args = meta.args
        for key in ("main_param", "span"):
            if key in args:
                self.span = args[key].split(".")[-1]
                break
        else:
            self.span = "halfhour"

    def reduce(self, other: AbstractColumnMetric):
        pass

    def _process(self, data: pd.Series):
        """Count non-NA timestamps per hour or half-hour segment."""
        valid = data[~pd.isna(data)]
        hours = valid.dt.hour
        if self.span == "hour":
            n_segments, segments = 24, hours
        else:
            # Two segments per hour: minutes 0-29 and 30-59.
            n_segments, segments = 48, 2 * hours + (valid.dt.minute > 29)
        per_segment = segments.value_counts().reindex(range(n_segments), fill_value=0)
        return per_segment.to_numpy()

    def _combine(self, summaries: list[np.ndarray]) -> np.ndarray:
        return np.sum(summaries, axis=0)

    def preprocess(self, wrk: Series, ref: Series):
        return Summaries(self._process(wrk), self._process(ref))

    def process(
        self, wrk: Series, ref: Series, syn: Series, pre: Summaries[np.ndarray]
    ):
        return pre.replace(syn=self._process(syn))

    def combine(self, summaries: list[Summaries[ndarray]]) -> Summaries[ndarray]:
        return Summaries(
            wrk=self._combine([part.wrk for part in summaries]),
            ref=self._combine([part.ref for part in summaries]),
            syn=self._combine([part.syn for part in summaries if part.syn is not None]),
        )

    def _visualise(self, data: dict[str, Summaries[np.ndarray]]) -> Figure:
        head = list(data.values())[0]
        splits = {"wrk": head.wrk, "ref": head.ref}
        for name, split in data.items():
            assert split.syn is not None, f"Received null syn split for split {name}."
            splits[name] = split.syn

        per_hour = 1 if self.span == "hour" else 2
        n_segments = 24 * per_hour
        # Unit-wide bins centred on each integer segment index.
        edges = np.arange(n_segments + 1) - 0.5
        tick_hours = [0, 3, 6, 9, 12, 15, 18, 21, 24]
        return _gen_hist(
            meta=self.meta,
            title=f"{self.col.capitalize()} Time",
            bins=edges,
            heights=splits,
            xticks_x=per_hour * np.array(tick_hours),
            xticks_label=[f"{hour:02d}:00" for hour in tick_hours],
        )

    def visualise(self, data: dict[str, Summaries[np.ndarray]]):
        load_matplotlib_style()
        figure = self._visualise(data)
        mlflow_log_hists(self.table, self.col, figure)
class DatetimeData(NamedTuple):
    """Summary pair for a datetime column: its date part and its time part."""

    # Date-part summaries (span, week-of-year, weekday, NA).
    date: DateData
    # Time-of-day segment counts.
    time: np.ndarray
class DatetimeHist(
    RefColumnMetric[
        tuple[Summaries[DateData], Summaries[ndarray]],
        tuple[Summaries[DateData], Summaries[ndarray]],
    ]
):
    """Datetime visualiser composed of a DateHist and a TimeHist.

    The date part is summarised against the reference like a date column.
    The time-of-day part is summarised like a time column. Every hook
    forwards to both sub-metrics.
    """

    name = "datetime"

    def __init__(self, *args, _from_factory: bool = False, **kwargs) -> None:
        super().__init__(*args, _from_factory=_from_factory, **kwargs)
        # Both sub-metrics receive the same construction arguments.
        self.date = DateHist(*args, _from_factory=_from_factory, **kwargs)
        self.time = TimeHist(*args, _from_factory=_from_factory, **kwargs)

    def fit(self, table: str, col: str, meta: ColumnMeta, data: RefColumnData):
        self.table = table
        self.col = col
        self.date.fit(table=table, col=col, meta=meta, data=data)
        # The time part only needs the values, not the reference column.
        self.time.fit(
            table=table, col=col, meta=meta, data=cast(pd.Series, data["data"])
        )

    def reduce(self, other: "DatetimeHist"):
        # Forward to the sub-metrics. Without this, the DateHist bins fitted
        # on different partitions were never merged.
        self.date.reduce(other.date)
        self.time.reduce(other.time)

    def preprocess(
        self, wrk: RefColumnData, ref: RefColumnData
    ) -> tuple[Summaries[DateData], Summaries[ndarray]] | None:
        return (
            self.date.preprocess(wrk, ref),
            self.time.preprocess(wrk["data"], ref["data"]),  # type: ignore
        )

    def process(
        self,
        wrk: RefColumnData,
        ref: RefColumnData,
        syn: RefColumnData,
        pre: tuple[Summaries[DateData], Summaries[ndarray]],
    ) -> tuple[Summaries[DateData], Summaries[ndarray]]:
        return (
            self.date.process(wrk, ref, syn, pre[0]),
            self.time.process(wrk["data"], ref["data"], syn["data"], pre[1]),  # type: ignore
        )

    def combine(
        self, summaries: list[tuple[Summaries[DateData], Summaries[ndarray]]]
    ) -> tuple[Summaries[DateData], Summaries[ndarray]]:
        return (
            self.date.combine([s[0] for s in summaries]),
            self.time.combine([s[1] for s in summaries]),
        )

    def visualise(
        self, data: dict[str, tuple[Summaries[DateData], Summaries[ndarray]]]
    ):
        load_matplotlib_style()
        date_fig = self.date._visualise({n: c[0] for n, c in data.items()})
        time_fig = self.time._visualise({n: c[1] for n, c in data.items()})
        mlflow_log_hists(self.table, self.col, {**date_fig, "time": time_fig})