from collections import defaultdict
from copy import copy
from functools import partial
from typing import Mapping, NamedTuple, cast
import numpy as np
import pandas as pd
from ..attribute import (
Attribute,
Attributes,
CommonValue,
DatasetAttributes,
GenAttribute,
Grouping,
SeqAttribute,
SeqAttributes,
SeqValue,
StratifiedValue,
get_dtype,
)
from ..hierarchy import RebalancedValue
from ..marginal.numpy import TableSelector
from ..marginal.oracle import PreprocessFun
from ..utils import LazyChunk
from ..utils.data import LazyDataset, LazyPartition, data_to_tables
from .chains import TablePartition, TableVersion, calculate_table_chains
from .reduce import (
TableMeta,
TablePartition,
TableVersion,
_calculate_stripped_meta,
merge_versions,
merge_versions_heuristic,
)
def _unroll_sequence(
    seq_name: str,
    order: int,
    ids: pd.DataFrame,
    data: pd.DataFrame,
    seq: pd.Series | None = None,
    stable: str | None = None,
    ptable: str | None = None,
):
    """Collect, for every lag below ``order``, the earlier row each row follows.

    For lag ``i`` a row is paired with the row that shares its key columns
    and whose sequence value plus ``i + 1`` equals its own sequence value.
    The paired row's columns (minus ``seq_name``) are returned re-indexed
    onto ``ids.index``.

    Args:
        seq_name: Name of the sequence column.
        order: Number of lags to generate (keys ``0 .. order - 1``).
        ids: Id frame; its columns are the match keys unless ``ptable`` is set.
        data: Frame whose rows are looked up as the "previous" rows.
        seq: Optional external sequence series, merged into ``ids`` through
            the ``stable`` column. When omitted, ``data[seq_name]`` is used.
        stable: Column of ``ids`` used to merge ``seq``; required with ``seq``.
        ptable: When set, rows are matched on this single id column only.

    Returns:
        Dict mapping lag index to a frame indexed like ``ids``. Integer
        columns are shifted by ``1 + lag`` with ``0`` for unmatched rows
        (then overwritten for rows early in the sequence); the remaining
        columns are left as ``NaN`` where no previous row exists.
    """
    row_key = "_id_lkjijk"
    prev_key = "_id_zdjwk"

    if seq is None:
        seq = data[seq_name]
        current = ids.join(seq, how="right")
    else:
        assert stable
        current = ids.merge(seq, left_on=stable, right_index=True, how="right")
        seq = current[seq_name]
    current = current.reset_index(names=row_key)

    # Here, if we use a lower table, the merge below fails
    match_cols = [ptable] if ptable else ids.columns
    merge_keys = [*match_cols, seq_name]

    out = {}
    for lag in range(order):
        # Create join with previous seq
        shifted = ids.join(seq + lag + 1, how="right")
        shifted = shifted.reset_index(names=prev_key)

        links = current.merge(shifted, on=merge_keys, how="inner")
        links = links.set_index(row_key)[[prev_key]]

        prev_rows = links.join(data, on=prev_key).drop(columns=[prev_key])
        if seq_name in prev_rows:
            prev_rows = prev_rows.drop(columns=[seq_name])

        # Rebase discrete columns to new stratified structure with offset
        int_cols = [
            col
            for col, dtype in prev_rows.dtypes.items()
            if pd.api.types.is_integer_dtype(dtype)
        ]
        widened = {col: get_dtype(prev_rows[col].max() + 2 + lag) for col in int_cols}
        int_part = prev_rows[int_cols].astype(widened) + 1 + lag
        int_part = int_part.reindex(ids.index, fill_value=0)

        # Fill value should be set depending on the history of each column
        for step in range(lag):
            int_part.loc[seq == step + 1] = lag - step

        # Re-index continuous cols with NaN
        other_cols = [col for col in prev_rows.columns if col not in int_cols]
        other_part = prev_rows[other_cols].reindex(ids.index)

        # Concat and fix
        frame = pd.concat([int_part, other_part], axis=1)
        frame.index.name = data.index.name
        out[lag] = frame
    return out
def _gen_history(
    ver: TableVersion | TablePartition,
    tables: dict[str, LazyChunk],
    ids: dict[str, LazyChunk],
    meta: dict[str, TableMeta],
    _out: dict[str | tuple[str, int], pd.DataFrame] | None = None,
) -> dict[str | tuple[str, int], pd.DataFrame]:
    """Recursively load ``ver`` and its ancestors into ``_out``.

    Args:
        ver: Table version, or a partition wrapping one.
        tables: Table name -> callable returning the table frame.
        ids: Table name -> callable returning the id frame.
        meta: Table name -> metadata (``partition``, ``sequence``, ``order``).
        _out: Accumulator shared across the recursion; created when omitted.

    Returns:
        ``_out``, keyed by table name, plus ``(table name, lag)`` entries
        for tables that have both a sequence column and an order.
    """
    if _out is None:
        _out = {}

    # Handle table partitions
    if isinstance(ver, TablePartition):
        base = ver.table.name
        if base not in _out:
            _gen_history(ver.table, tables, ids, meta, _out)
        part_col = _out[base][meta[base].partition]
        if len(ver.partitions) == 1:
            # Faster than isin for single partition
            keep = part_col == ver.partitions[0]
        else:
            keep = part_col.isin(ver.partitions)
        _out[base] = _out[base][keep]
        return _out

    if ver.name in _out:
        return _out

    sequence = meta[ver.name].sequence
    order = meta[ver.name].order
    loaded = tables[ver.name]()
    _out[ver.name] = loaded.copy()

    if order and sequence:
        lagged = _unroll_sequence(sequence, order, ids[ver.name](), loaded)
        for lag, frame in lagged.items():
            _out[(ver.name, lag)] = frame
        # The stored copy keeps the sequence column, capped at ``order``.
        _out[ver.name][sequence] = loaded[sequence].clip(upper=order)

    for parent in ver.parents:
        _gen_history(parent, tables, ids, meta, _out)
    return _out
[docs]
def gen_history(
    parents: tuple[TableVersion | TablePartition, ...],
    tables: dict[str, LazyChunk],
    ids: dict[str, LazyChunk],
    meta: dict[str, TableMeta],
):
    """Gather the history frames of every entry in ``parents`` into one dict.

    Each parent is expanded with ``_gen_history`` into a shared accumulator,
    so a table reachable through several parents is loaded once.

    Returns:
        The accumulated dict; empty when ``parents`` is empty.
    """
    history = {}
    for parent in parents:
        _gen_history(parent, tables, ids, meta, history)
    return history
def _recurse_unroll_groups(
    unroll_ofs: tuple[int, ...],
    cmn: Grouping,
    groups: dict[str | tuple[str, ...], Grouping],
    out,
    cmn_ofs=0,
    ofs=None,
):
    """Walk the common grouping and record per-value offsets for unrolled leaves.

    ``cmn`` and every grouping in ``groups`` are traversed in lockstep.
    Each string leaf of ``cmn`` consumes one common offset. When that offset
    is listed in ``unroll_ofs``, ``out[offset]`` receives, per group name,
    either the running offset (the group has a string leaf there) or a tuple
    ``(offset, StratifiedValue)`` wrapping the group's sub-grouping.

    Args:
        unroll_ofs: Common offsets that should be unrolled.
        cmn: Common grouping being traversed.
        groups: Name -> grouping, traversed in parallel with ``cmn``.
        out: Dict filled in place; a fresh dict is used if ``None``.
        cmn_ofs: Common offset at which this (sub)tree starts.
        ofs: Running per-name offsets shared across the recursion.

    Returns:
        The common offset after the last leaf of ``cmn``.
    """
    if ofs is None:
        ofs = defaultdict(int)
    if out is None:
        out = {}

    for pos, node in enumerate(cmn):
        if not isinstance(node, str):
            # Inner node: descend with the matching child of every group.
            children = {n: cast(Grouping, g[pos]) for n, g in groups.items()}
            cmn_ofs = _recurse_unroll_groups(
                unroll_ofs, node, children, out, cmn_ofs, ofs
            )
            continue

        # If this common val is unrolled, update output
        if cmn_ofs in unroll_ofs:
            entry = out[cmn_ofs] = {}
            for name, g in groups.items():
                child = g[pos]
                if isinstance(child, str):
                    entry[name] = ofs[name]
                else:
                    unrolled_val = StratifiedValue(
                        f"{name}.o{cmn_ofs:03d}", Grouping("cat", [None, child])
                    )
                    entry[name] = (ofs[name], unrolled_val)

        # Always update offsets
        for name, g in groups.items():
            child = g[pos]
            ofs[name] += 1 if isinstance(child, str) else child.get_domain(0)
        cmn_ofs += 1

    return cmn_ofs
[docs]
def recurse_unroll_attr(unrolls: tuple[int, ...], attrs: Attributes):
    """Split the unrolled attribute of ``attrs`` into one attribute per offset.

    The first attribute with ``unroll`` set is located, together with the
    attributes named in its ``along``. Their stratified values are walked in
    parallel with the common value, and for every common offset listed in
    ``unrolls`` a new attribute is produced.

    Args:
        unrolls: Common-value offsets to unroll.
        attrs: Attributes of the table; exactly one is expected to unroll.

    Returns:
        A 4-tuple ``(new_attrs, cmn, cols, ofs)``:

        * ``new_attrs``: ``(*base_name, offset)`` -> new ``Attribute``.
        * ``cmn``: offset -> name of the generated common value.
        * ``cols``: offset -> {source value name: unrolled value name}.
        * ``ofs``: offset -> {source value name: offset in the source value}.

    Raises:
        AssertionError: If no attribute unrolls, if no common value is given
            and the attribute has more than one value, or if a value is not
            a ``StratifiedValue``.
    """
    attr = None
    for attr in attrs.values():
        if attr.unroll:
            break
    assert attr and attr.unroll

    if attr.common is not None:
        cval = attr.common
    else:
        assert (
            len(attr.vals) == 1
        ), "If a common val is not provided, there should only be a single value."
        # ``vals`` is a mapping: take its single *value*, not its key. Using
        # the key (a str) made the isinstance check below always fail.
        cval = next(iter(attr.vals.values()))
    assert isinstance(
        cval, StratifiedValue
    ), "Unrolling is supported only for stratified values for now."

    cmn = cval.head
    groups = {}
    for along in [attr, *[attrs[n] for n in attr.along]]:
        for name, val in along.vals.items():
            if name == cval.name:
                continue  # skip common val if attr.common was none
            assert isinstance(
                val, StratifiedValue
            ), "For unrolling, all values should be StratifiedValues"
            groups[name] = val.head

    unrolled = {}
    _recurse_unroll_groups(unrolls, cmn, groups, unrolled)

    new_attrs = {}
    cmn = {}
    cols = defaultdict(dict)
    ofs = defaultdict(dict)
    base_name = (attr.name,) if isinstance(attr.name, str) else attr.name
    for cmn_ofs, vals in unrolled.items():
        new_name = (*base_name, cmn_ofs)
        new_vals = []
        for name, t in vals.items():
            if isinstance(t, tuple):
                # (offset, StratifiedValue): the value gets its own column.
                ofs[cmn_ofs][name] = t[0]
                cols[cmn_ofs][name] = t[1].name
                new_vals.append(t[1])
            else:
                ofs[cmn_ofs][name] = t

        cmn_name = f"{'.'.join(base_name)}.o{cmn_ofs:03d}"
        cmn[cmn_ofs] = cmn_name
        if new_vals:
            new_attrs[new_name] = Attribute(
                new_name, new_vals, CommonValue(cmn_name, na=True)
            )
        else:
            # No unrolled values at this offset: the common value stands alone.
            new_attrs[new_name] = Attribute(new_name, [CommonValue(cmn_name, na=True)])
    return new_attrs, cmn, cols, ofs
[docs]
def SeqCommonValue(name: str, order: int):
    """Build the stratified common value describing sequence positions.

    The grouping is a right-nested chain of ``"cat"`` groupings
    ``[O0, [O1, [... [O{order-1}, O{order}]]]]``, each titled after its
    first leaf.

    Args:
        name: Name of the resulting value.
        order: Depth of the chain.

    Returns:
        A ``StratifiedValue`` named ``name`` holding the chain.
    """
    node = f"O{order}"
    for level in range(order - 1, -1, -1):
        label = f"O{level}"
        node = Grouping("cat", [label, node], title=label)
    return StratifiedValue(name, cast(Grouping, node))
[docs]
def convert_to_seq_val(s: StratifiedValue, order: int):
    """Wrap the head of ``s`` in ``order + 1`` history levels.

    The original head ends up innermost, nested as
    ``[H0, [H1, [... [H{order}, head]]]]`` using ``"cat"`` groupings.

    Returns:
        A new ``StratifiedValue`` with the same name as ``s``.
    """
    head = s.head
    for level in range(order, -1, -1):
        head = Grouping("cat", [f"H{level}", head])
    return StratifiedValue(s.name, head)
[docs]
def convert_rebalanced_to_seq_val(s: RebalancedValue, order: int):
    """Return a shallow copy of ``s`` with ``order + 1`` leading slots added.

    The copy gets ``order + 1`` zero counts prepended, ``order + 1`` zero
    columns prepended to ``grouping`` (whose existing entries are shifted up
    by ``order + 1``), ``order + 1`` ones prepended to every ``common_sizes``
    row, and every domain increased by ``order + 1``. ``s`` itself is not
    modified.
    """
    pad = order + 1
    reb = copy(s)

    reb.counts = np.concatenate([[0] * pad, s.counts])

    zero_cols = [[0] * pad for _ in range(s.grouping.shape[0])]
    shifted = s.grouping.astype(np.uint16) + pad
    reb.grouping = np.concatenate([zero_cols, shifted], axis=1)

    # TODO: verify this works
    reb.common_sizes = [[1] * pad + sizes for sizes in s.common_sizes]
    reb.common_groups = [list(row) for row in reb.grouping]
    reb.domains = [dom + pad for dom in s.domains]
    return reb
[docs]
def convert_to_seq_attr(attrs: Attributes, order: int) -> Attributes:
    """Convert every attribute in ``attrs`` to its history form for ``order``.

    ``SeqValue`` entries are dropped; rebalanced and stratified values (and
    the common value, when present) are converted with the matching
    ``convert_*_to_seq_val`` helper. Attributes left without values are
    omitted from the result.

    Raises:
        AssertionError: If a value is neither rebalanced nor stratified.
    """

    def _as_seq(v):
        # RebalancedValue is tested first, as in the value-specific helpers.
        if isinstance(v, RebalancedValue):
            return convert_rebalanced_to_seq_val(v, order)
        if isinstance(v, StratifiedValue):
            return convert_to_seq_val(v, order)
        assert (
            False
        ), f"Attr. '{v.name}' is of type '{type(v)}', which is not supported."

    out = {}
    for key, attr in attrs.items():
        converted = [
            _as_seq(v) for v in attr.vals.values() if not isinstance(v, SeqValue)
        ]
        common = _as_seq(attr.common) if attr.common else None
        if len(converted):
            out[key] = Attribute(attr.name, converted, common, attr.unroll, attr.along)
    return out
[docs]
def strip_seq_vals(attrs: Attributes):
    """Return ``attrs`` rebuilt without any ``SeqValue`` entries.

    Each attribute is recreated with its remaining values and its original
    ``common``, ``unroll`` and ``along``. Attributes that end up with no
    values are left out.
    """
    out = {}
    for key, attr in attrs.items():
        kept = [v for v in attr.vals.values() if not isinstance(v, SeqValue)]
        if kept:
            out[key] = Attribute(attr.name, kept, attr.common, attr.unroll, attr.along)
    return out
def _gen_history_attributes(
    ver: TableVersion | TablePartition,
    attrs: dict[str, Attributes],
    _out: dict[str, SeqAttributes | Attributes] | None = None,
):
    """Recursively collect the history attributes of ``ver`` and its ancestors.

    Tables containing a ``SeqValue`` are stored as ``SeqAttributes`` (one
    converted attribute set per lag below the sequence order); all other
    tables are stored with their attributes unchanged.

    Args:
        ver: Table version, or a partition wrapping one.
        attrs: Table name -> attributes.
        _out: Accumulator shared across the recursion; created when omitted.

    Returns:
        ``_out``, keyed by table name.
    """
    if _out is None:
        _out = {}

    if isinstance(ver, TablePartition):
        # Partitions are resolved through their underlying table.
        _gen_history_attributes(ver.table, attrs, _out)
        return _out
    if ver.name in _out:
        return _out

    table_attrs = attrs[ver.name]

    # Within one attribute the first SeqValue wins; across attributes the
    # last attribute that has one wins.
    seq = None
    for attr in table_attrs.values():
        hit = next((v for v in attr.vals.values() if isinstance(v, SeqValue)), None)
        if hit is not None:
            seq = hit

    if seq is None:
        _out[ver.name] = table_attrs
    else:
        order = seq.order
        stripped = strip_seq_vals(table_attrs)
        assert order is not None, f"Seq value's '{seq.name}' order is None"
        lagged = {lag: convert_to_seq_attr(table_attrs, lag) for lag in range(order)}
        _out[ver.name] = SeqAttributes(
            order, SeqCommonValue(seq.name, order), stripped, lagged
        )

    for parent in ver.parents:
        _gen_history_attributes(parent, attrs, _out)
    return _out
[docs]
def get_parents(ver):
    """Return the names of the first parent, grandparent and great-grandparent.

    The chain is followed through ``parents[0]`` at each level. An entry
    exposing a ``table`` attribute is replaced by that table before its name
    is read and before the walk continues.

    Returns:
        ``(parent, gparent, ggparent)``; positions past the end of the chain
        are ``None``.
    """
    # Get parents for building sequences
    lineage = []
    node = ver
    while len(lineage) < 3 and node.parents:
        node = node.parents[0]
        if hasattr(node, "table"):
            node = node.table
        lineage.append(node.name)

    parent, gparent, ggparent = lineage + [None] * (3 - len(lineage))
    return parent, gparent, ggparent
[docs]
def gen_history_attributes(
    parents: tuple[TableVersion | TablePartition, ...],
    attrs: dict[str, Attributes],
):
    """Gather the history attributes of every entry in ``parents``.

    Each parent is expanded with ``_gen_history_attributes`` into a shared
    accumulator, so a table reachable through several parents is processed
    once.

    Returns:
        The accumulated dict; empty when ``parents`` is empty.
    """
    collected = {}
    for parent in parents:
        _gen_history_attributes(parent, attrs, collected)
    return collected
[docs]
# Build the attribute sets used to fit the model identified by (`ver`, `ctx`).
#
# Returns None when `ctx` is true and `ver` has no parents. Otherwise returns
# a dict that starts from the parents' history attributes (or from an empty
# dict when `no_hist` is true). The entry under key None holds the attributes
# of the model itself, and an entry under `ver.name` may be added for the
# table's own unroll/sequence history.
#
# NOTE(review): the indentation of this file was lost in extraction, so the
# nesting of some statements below cannot be confirmed from here. In
# particular, check whether `seq_repeat = attr.seq_repeat` sits inside
# `if attr.unroll:` or at loop level.
def generate_fit_attrs(
ver: TableVersion, attrs: dict[str, Attributes], ctx: bool, no_hist: bool = False
) -> DatasetAttributes | None:
# Don't generate context tables for top level tables
if not ver.parents and ctx:
return None
meta = _calculate_stripped_meta(attrs)
# Scan the table's own attributes for an unroll attribute and a SeqValue.
unroll = None
seq = None
seq_repeat = False
for attr in attrs[ver.name].values():
if attr.unroll:
unroll = attr.name
seq_repeat = attr.seq_repeat
for v in attr.vals.values():
if isinstance(v, SeqValue):
seq = v
assert not (unroll and seq), f"Both unrolling and sequence found on the same table."
# History of the parents; skipped entirely when no_hist is set.
if no_hist:
hist = {}
else:
hist = gen_history_attributes(ver.parents, attrs)
if unroll:
if ctx:
# Context model of an unrolled table: only the first element of the
# recurse_unroll_attr result (the new attributes) is used here.
assert ver.unrolls
synth = recurse_unroll_attr(ver.unrolls, attrs[ver.name])
hist[None] = synth[0]
else:
# Series model of an unrolled table: the unroll attribute and the
# attributes in its `along` go under ver.name, the rest under None.
unroll_attrs = {unroll: attrs[ver.name][unroll]}
other_attrs = {}
along = attrs[ver.name][unroll].along
for name, attr in attrs[ver.name].items():
if name in along or name == unroll:
unroll_attrs[name] = attr
else:
other_attrs[name] = attr
hist[ver.name] = unroll_attrs
hist[None] = other_attrs
elif ctx:
# Context tables for normal tables and sequence tables same
# A single generated "<table>_n" attribute, sized by seq.max when that is
# set and truthy, otherwise by ver.children.
assert ver.children is not None
hist[None] = {
f"{ver.name}_n": GenAttribute(
f"{ver.name}_n", seq.max if seq and seq.max else ver.children
)
}
elif seq:
# Series model of a sequence table: one converted attribute set per lag
# below the sequence order, plus the SeqValue-free attributes under None.
order = seq.order
ahist = {}
assert order is not None, f"Table '{ver.name}'s order is None"
for i in range(order):
ahist[i] = convert_to_seq_attr(attrs[ver.name], i)
assert ver.children is not None
hist[ver.name] = SeqAttributes(
order,
SeqCommonValue(seq.name, order),
{seq.name: SeqAttribute(seq.name, order=order, max=ver.children)},
ahist,
)
hist[None] = strip_seq_vals(attrs[ver.name])
else:
hist[None] = attrs[ver.name]
# Add sequentiality to context tables through parent
# With seq_repeat the lookup moves one level up: the grandparent is the
# target and the great-grandparent must exist instead of the grandparent.
parent, gparent, ggparent = get_parents(ver)
target = gparent if seq_repeat else parent
ptarget = ggparent if seq_repeat else gparent
if ctx and target and ptarget and meta[target].sequence:
sequence = meta[target].sequence
# A missing or zero order on the target falls back to 1.
norder = meta[target].order or 1
assert sequence
ahist = {}
for i in range(norder):
ahist[i] = convert_to_seq_attr(hist[None], i)
hist[ver.name] = SeqAttributes(
norder,
SeqCommonValue(sequence, norder),
{sequence: SeqAttribute(sequence, order=norder, max=ver.children)},
ahist,
)
return hist
[docs]
# Load the data frames used to fit the model identified by (`ver`, `ctx`).
#
# Returns a dict of frames. The entry under key None is the model's own table
# (`synth`). The other entries are history frames keyed by table name or by
# (table name, lag).
#
# NOTE(review): the indentation of this file was lost in extraction, so the
# nesting of some statements below cannot be confirmed from here (for example
# whether the final `new_hist[ver.name] = new_hist[parent]` belongs to the
# last `if` block).
def generate_fit_tables(
data: Mapping[str, LazyPartition],
attrs: dict[str, Attributes],
ver: TableVersion,
ctx: bool,
) -> dict[TableSelector, pd.DataFrame]:
ids, tables = data_to_tables(data)
# Get history
meta = _calculate_stripped_meta(attrs)
hist = gen_history(ver.parents, tables, ids, meta)
# Prune ids that are not used in this model, ex. due to partitioning
# Inner-joining on each history table's (column-less) index drops the id rows
# whose parent is missing from that table.
fids = ids[ver.name]()
for name, table in hist.items():
if isinstance(name, tuple):
name = name[0]
fids = fids.join(table[[]], on=name, how="inner")
_tmp = tables[ver.name]()
try:
table = _tmp.loc[fids.index]
except Exception:
# FIXME: Resolve id situation
# If a context table is joined to a parent without child rows
# that parent is pruned.
# NOTE(review): broad except; presumably meant for the lookup failure
# raised by .loc when some ids are absent — confirm and narrow.
table = _tmp.join(fids, how="inner")
# If no parents, assume normal table and return it
if not ver.parents:
tmeta = meta[ver.name]
assert not tmeta.unroll and not tmeta.sequence
return {None: table}
new_hist = {}
unroll = meta[ver.name].unroll
sequence = meta[ver.name].sequence
order = meta[ver.name].order
max_len = meta[ver.name].max_len
seq_repeat = meta[ver.name].seq_repeat
parent, gparent, ggparent = get_parents(ver)
# Create new id that is unique per sequence
SID_NAME = "nid_jsdi78"
if seq_repeat and parent:
# If seq_repeat, skip parent
fltr = [c for c in fids.columns if c != parent]
else:
fltr = list(fids.columns)
# Create index with just parents, add a column that acts as an
# index, join back to fids
sid = fids[fltr].groupby(fltr).first()[[]]
sid[SID_NAME] = range(len(sid))
sid = fids[fltr].join(sid, on=fltr).drop(columns=fltr)
if ctx:
# common operation for indexing to parent for context tables
fids = fids.join(sid.drop_duplicates(), how="inner").set_index(SID_NAME)
if unroll:
if ctx:
assert ver.unrolls
_, cmn, cols, ofs = recurse_unroll_attr(ver.unrolls, attrs[ver.name])
# One frame per unrolled offset: the rows whose unroll column equals that
# offset, reduced to one row per SID_NAME and indexed by it.
udfs = []
for u in cmn.keys():
udf = (
table.loc[table[unroll] == u, list(cols[u])]
.join(sid)
.drop_duplicates([SID_NAME])
.set_index(SID_NAME)
.convert_dtypes()
)
# Column named after the common value, set to 1 on every row of this frame.
udf[cmn[u]] = pd.Series(1, dtype="UInt8", index=udf.index)
# Rebase each column by its offset, shifted by one; the fillna(0) below
# then uses 0 for cells missing after the concat.
for c, o in ofs[u].items():
udf[c] -= o - 1
udfs.append(udf.rename(columns=cols[u]))
utab = pd.concat(udfs, axis=1).fillna(0)
# Cast each column to the dtype named by the lower-cased string of its
# current dtype.
synth = utab.astype(
dtype={k: str(v).lower() for k, v in utab.dtypes.to_dict().items()} # type: ignore
)
else:
# Series model: the unroll column and its `along` columns become history
# under ver.name; the remaining columns form the model table.
unroll_cols = [unroll]
along = meta[ver.name].along
if along:
unroll_cols.extend(along)
new_hist[ver.name] = table[unroll_cols]
synth = table.drop(columns=unroll_cols)
elif sequence:
if ctx:
# Context model: number of rows per SID_NAME, capped at max_len when set.
lens = sid.groupby(SID_NAME).size().rename(f"{ver.name}_n")
if max_len is not None:
lens = lens.clip(upper=max_len)
synth = pd.DataFrame(lens)
else:
assert order is not None
seq_hist = _unroll_sequence(sequence, order, fids, table)
for o, data in seq_hist.items():
new_hist[(ver.name, o)] = data
new_hist[ver.name] = pd.DataFrame(table[sequence].clip(upper=order))
synth = table
else:
if ctx:
synth = pd.DataFrame(sid.groupby(SID_NAME).size().rename(f"{ver.name}_n"))
else:
synth = table
# With the new ids, prune the history
# Each history frame is re-indexed to fids by joining on its table's id column.
for idx, table in hist.items():
if isinstance(idx, tuple):
name = idx[0]
else:
name = idx
new_hist[idx] = (
fids[[name]].join(table, on=name, how="inner").drop(columns=name)
)
# Apply sequentiality to context tables. If seq_repeat, to grand_parent. Else to parent
target = gparent if seq_repeat else parent
ptarget = ggparent if seq_repeat else gparent
if ctx and target and ptarget and meta[target].sequence:
seq = meta[target].sequence
# A missing or zero order on the target falls back to 1.
norder = meta[target].order or 1
assert seq
seq_hist = _unroll_sequence(
seq,
norder,
fids,
synth,
seq=tables[target]()[seq],
stable=target,
ptable=ptarget,
)
for o, data in seq_hist.items():
new_hist[(ver.name, o)] = data
new_hist[ver.name] = new_hist[parent]
return {**new_hist, None: synth}
[docs]
class ModelVersion(NamedTuple):
    """Key identifying one model: a table version plus a context flag."""

    # Table version the model is built for.
    ver: TableVersion
    # True for the context model, False for the series model.
    ctx: bool
[docs]
def calculate_model_versions(
    attrs: dict[str, Attributes],
    data: Mapping[str, LazyDataset],
    max_vers: int,
    no_hist: bool = False,
) -> dict[ModelVersion, tuple[DatasetAttributes, PreprocessFun]]:
    """Enumerate the models to fit, with their attributes and data loaders.

    Tables with an unroll attribute get one context model per version kept by
    ``merge_versions_heuristic`` (falling back to all versions when it
    returns nothing) and one series model for the merged version. Every other
    table gets a context and a series model for its merged version, skipping
    whichever ``generate_fit_attrs`` declines by returning ``None``.

    Args:
        attrs: Table name -> attributes.
        data: Table name -> lazy dataset.
        max_vers: Passed to ``merge_versions_heuristic`` for unrolled tables.
        no_hist: Forwarded to ``generate_fit_attrs``.

    Returns:
        ``ModelVersion`` -> ``(attributes, loader)``, where the loader is
        ``generate_fit_tables`` with ``attrs``, ``ver`` and ``ctx`` bound.
    """

    def _unroll_set(version):
        return set(version.unrolls) if version.unrolls else set()

    def _union(a, b):
        return a.union(b)

    def _distance(a, b):
        return len(a.symmetric_difference(b))

    ids, tables = data_to_tables(data)
    chains = calculate_table_chains(attrs, ids, tables)
    meta = _calculate_stripped_meta(attrs)

    out: dict[ModelVersion, tuple[DatasetAttributes, PreprocessFun]] = {}
    for name, vers in chains.items():
        assert vers, f"Table {name} has 0 versions."

        if not meta[name].unroll:
            # Apart from unroll, create one ctx model and one series model
            # for each table
            merged = merge_versions(vers)
            for ctx in (True, False):
                fit_attrs = generate_fit_attrs(merged, attrs, ctx, no_hist=no_hist)
                if fit_attrs is None:
                    continue
                loader = partial(generate_fit_tables, attrs=attrs, ver=merged, ctx=ctx)
                out[ModelVersion(merged, ctx)] = fit_attrs, loader
            continue

        # Unroll context models
        ctx_vers = (
            merge_versions_heuristic(vers, max_vers, _unroll_set, _union, _distance)
            or vers
        )
        for ctx_ver in ctx_vers:
            fit_attrs = generate_fit_attrs(ctx_ver, attrs, True, no_hist=no_hist)
            assert fit_attrs is not None
            loader = partial(generate_fit_tables, attrs=attrs, ver=ctx_ver, ctx=True)
            out[ModelVersion(ctx_ver, True)] = fit_attrs, loader

        # Unroll series model
        merged = merge_versions(vers)
        fit_attrs = generate_fit_attrs(merged, attrs, False, no_hist=no_hist)
        assert fit_attrs is not None
        loader = partial(generate_fit_tables, attrs=attrs, ver=merged, ctx=False)
        out[ModelVersion(merged, False)] = fit_attrs, loader
    return out