""" This module contains functions for calculating the markov chain combinations
of different tables in a dataset.
"""
from collections import defaultdict
from itertools import product
from typing import Generic, NamedTuple, Sequence, TypeVar
import pandas as pd
from ..attribute import Attributes, SeqValue
from ..utils import LazyFrame, lazy_load_tables
[docs]
class TableVersion(NamedTuple):
"""Contains the parameters required to construct a markov chain for table
with name `name`, by recursing on its parents."""
name: str
rows: int
children: int | None
max_len: int | None
partitions: tuple[int, ...] | None
unrolls: tuple[int, ...] | None
parents: tuple["TableVersion | TablePartition", ...]
seq_repeat: bool
[docs]
class TablePartition(NamedTuple):
    """A `TableVersion` restricted to a subset of its partition values.

    When used as a parent, only the rows of `table` whose partition column
    holds one of `partitions` are taken into account."""

    # Partition values that are kept.
    partitions: tuple[int, ...]
    # The table version being restricted.
    table: TableVersion
[docs]
def get_parents(
    ids: dict[str, LazyFrame], table: str | None = None, _reverse: bool = False
):
    """Builds the nested graph of parent relationships found in the id tables.

    The columns of each id frame name the tables it references. Links that are
    already implied by a referenced table are dropped, so each table only
    points at its closest parents.

    I.e. if there is a dataset with tables A, B, C, D, where C -> B -> A and
    D -> B, the function returns `{'D': {'B': {'A': {}}}, 'C': {'B': {'A': {}}}}`.

    If `table` is given (and non-empty), only the sub-graph below that table is
    returned. With `_reverse` the links are inverted first, which yields the
    graph of children instead."""
    # Table name -> names of the tables its id frame references
    links = {}
    for name, frame in ids.items():
        links[name] = set(frame.sample().columns)
    names = list(links)

    if _reverse:
        # Flip every "child references parent" link
        inverted = {name: set() for name in names}
        for child, referenced in links.items():
            for parent in referenced:
                inverted[parent].add(child)
        links = inverted

    # Drop links that are reachable through another link, until nothing changes
    changed = True
    while changed:
        changed = False
        for inner in names:
            for outer in names:
                inner_links, outer_links = links[inner], links[outer]
                if inner not in outer_links:
                    continue
                if inner_links and inner_links <= outer_links:
                    # `outer` reaches these tables through `inner` already
                    outer_links -= inner_links
                    changed = True

    # Nest the per-table dictionaries into each other to form the tree
    tree = {name: {} for name in names}
    referenced_tables = set()
    for name, targets in links.items():
        referenced_tables |= targets
        for target in targets:
            tree[name][target] = tree[target]

    if table:
        return tree[table]

    # Top level only keeps the tables that no other table links to
    top_level = {}
    for name, subtree in tree.items():
        if name not in referenced_tables:
            top_level[name] = subtree
    return top_level
[docs]
def get_children(ids: dict[str, LazyFrame]):
    """Returns the nested graph of children relationships in the dataset.

    This is `get_parents` with the links reversed: the top level holds the
    tables that are nobody's child, and each level below lists the children.
    I.e. if there is a dataset with tables A, B, C, D, where C -> B -> A and
    D -> B, the function returns `{'A': {'B': {'C': {}, 'D': {}}}}`."""
    child_graph = get_parents(ids, table=None, _reverse=True)
    return child_graph
def _calculate_parent_mask(
    name: str,
    parents: tuple[TableVersion | TablePartition, ...],
    meta: dict[str, TableMeta],
    get_table,
    get_ids,
    cache,
):
    """Combines the id restrictions of `parents` into one mask for table `name`.

    For every ancestor that restricts its rows (see `_calculate_included_ids`),
    the matching column of the id frame of `name` is tested against the
    allowed ids, and the individual tests are and-ed together.

    Returns `None` when no ancestor restricts anything, otherwise the combined
    boolean mask. Per-column masks are stored in `cache` under
    `(name, ids_key)` and reused on later calls."""
    # Ancestor table name -> (cache key of its ids, allowed ids)
    restrictions = {}
    for version in parents:
        restrictions |= _calculate_included_ids(
            version, meta, get_table, get_ids, cache
        )

    combined = True
    for column, (ids_key, allowed) in restrictions.items():
        mask_key = (name, ids_key)
        if mask_key not in cache:
            cache[mask_key] = get_ids(name)[column].isin(allowed)
        combined &= cache[mask_key]

    # `combined` is still the literal `True` if there was nothing to apply
    return None if combined is True else combined
def _calculate_included_ids(
    ver: TableVersion | TablePartition,
    meta: dict[str, TableMeta],
    get_table,
    get_ids,
    cache,
):
    """Collects the ids to keep for `ver` and for all the parents it builds on.

    Only a `TablePartition` excludes rows: the ids kept are the unique index
    values of the table's id frame on the rows whose partition column matches
    one of the requested partition values. Plain `TableVersion`s contribute
    nothing for themselves, unless `cache` already holds an entry for them.

    Returns a dictionary of `table name -> (cache key, ids)`.
    `cache` maps `(table name, partitions)` to previously computed ids and is
    updated in place."""
    if isinstance(ver, TablePartition):
        wanted, ver = ver.partitions, ver.table
    else:
        wanted = None

    # Start from whatever the parents of this version require
    collected = {}
    for ancestor in ver.parents:
        collected |= _calculate_included_ids(
            ancestor, meta, get_table, get_ids, cache
        )

    table_name = ver.name
    cache_key = (table_name, wanted)
    if cache_key not in cache:
        # Without partition values there is nothing to filter on
        if not wanted:
            return collected

        column = get_table(table_name)[meta[table_name].partition]
        if len(wanted) > 1:
            keep = column.isin(wanted)
        else:
            keep = column == wanted[0]
        cache[cache_key] = get_ids(table_name).index[keep].unique()

    collected[table_name] = (cache_key, cache[cache_key])
    return collected
# Module-private alias used by `calculate_table_chains` to obtain the per-table
# metadata from the attributes.
# NOTE(review): `calculate_stripped_meta` is neither defined nor imported in the
# visible part of this module - confirm where it is expected to come from.
_calculate_stripped_meta = calculate_stripped_meta
def _max_rows_per_id_combination(fids) -> int:
    """Returns the largest number of rows of `fids` that share the same values
    in all of its columns. `fids` must contain at least one row."""
    sid_name = "nid_jsdi78"
    id_cols = list(fids.columns)
    # Give every distinct id combination a running number ...
    numbered = (
        fids.drop_duplicates()
        .reset_index(drop=True)
        .reset_index(names=sid_name)
        .set_index(id_cols)
    )
    # ... look that number up for each row, and count the rows per number
    sid = fids.join(numbered, on=id_cols).drop(columns=id_cols)
    return int(sid.groupby(sid_name).size().max())


def _calculate_chains_of_table(
    name: str,
    meta: dict[str, TableMeta],
    ids: dict[str, LazyFrame],
    tables: dict[str, LazyFrame],
    parents: dict[str, tuple[TableVersion, ...]],
) -> tuple[TableVersion, ...]:
    """Calculates the possible markov chains for the table with `name`, by
    iterating over its parent combinations.

    Each parent contributes its versions, where a version with partitions is
    split into one `TablePartition` per partition value. The combinations are
    the cartesian product of those choices, with the parents ordered by name.
    For every combination the table is filtered to the rows allowed by the
    parents and the row count, largest number of rows per parent id
    combination, and the observed unroll/partition values are collected.

    Returns a tuple with one `TableVersion` per combination, in the order the
    combinations were generated."""

    def _expand_versions(vers: tuple[TableVersion, ...]):
        expanded = []
        for v in vers:
            if v.partitions:
                expanded.extend(TablePartition((p,), v) for p in v.partitions)
            else:
                expanded.append(v)
        return expanded

    # `product()` of nothing yields one empty combination; the fallback covers
    # the case where some parent has no versions and the product is empty.
    combos = list(
        product(*(_expand_versions(parents[p]) for p in sorted(parents)))
    ) or [()]

    table_meta = meta[name]
    partition = table_meta.partition
    unroll = table_meta.unroll
    max_len = table_meta.max_len
    seq_repeat = table_meta.seq_repeat

    # Unrolling and Partitioning have values that are extracted from data so
    # have to run per partition
    rows_per_combo = defaultdict(int)
    children_per_combo: dict[tuple, int | None] = defaultdict(lambda: None)
    unrolls_per_combo = defaultdict(set)
    partitions_per_combo = defaultdict(set)

    for ptables, pids in LazyFrame.zip_values([tables, ids]):
        get_table = lazy_load_tables(ptables)
        get_ids = lazy_load_tables(pids)
        # Shared by all combinations of this iteration; keyed by the tuples
        # built in `_calculate_included_ids` and `_calculate_parent_mask`.
        included_ids_cache: dict = {}

        for combo in combos:
            table = get_table(name)
            pmask = _calculate_parent_mask(
                name, combo, meta, get_table, get_ids, included_ids_cache
            )
            if pmask is not None:
                table = table.loc[pmask]
            rows_per_combo[combo] += len(table)

            if parents:
                fids = get_ids(name)
                if pmask is not None:
                    fids = fids.loc[pmask]
                # An empty selection has no groups, so its maximum would be
                # NaN. Storing NaN would also block every later update, since
                # `NaN < x` is always False. Skip it instead.
                if len(fids):
                    new_children = _max_rows_per_id_combination(fids)
                    children = children_per_combo[combo]
                    if children is None or children < new_children:
                        children_per_combo[combo] = new_children

            if unroll:
                unrolls_per_combo[combo].update(table[unroll].unique())

            if partition:
                counts = table[partition].value_counts().to_dict()
                partitions_per_combo[combo].update(counts.keys())

    # Assemble one version per combination from the collected data
    return tuple(
        TableVersion(
            name=name,
            rows=rows_per_combo[combo],
            children=children_per_combo[combo],
            max_len=max_len,
            partitions=(
                tuple(sorted(partitions_per_combo[combo])) if partition else None
            ),
            unrolls=tuple(sorted(unrolls_per_combo[combo])) if unroll else None,
            parents=combo,
            seq_repeat=seq_repeat,
        )
        for combo in combos
    )
[docs]
def calculate_table_chains(
    attrs: dict[str, Attributes],
    ids: dict[str, LazyFrame],
    tables: dict[str, LazyFrame],
    return_all_tables=True,
    _parents=None,
    _cache=None,
) -> dict[str, tuple[TableVersion, ...]]:
    """Returns all possible chain combinations for the tables in the provided
    view, as a dictionary of `table -> chains`.

    The parent graph is walked recursively, so the chains of a table are
    calculated after those of its parents. With `return_all_tables` the result
    covers every table visited; otherwise only the top-level tables of the
    graph are included. `_parents` and `_cache` carry the current sub-graph
    and the already calculated chains through the recursion."""
    graph = get_parents(ids) if _parents is None else _parents
    memo = {} if _cache is None else _cache

    # NOTE(review): this is recomputed on every recursive call; it could be
    # hoisted if `calculate_stripped_meta` has no side effects - confirm.
    meta = _calculate_stripped_meta(attrs)

    result = {}
    for name, ancestors in graph.items():
        if name not in memo:
            # Parents first: their versions are the input for this table
            parent_versions = calculate_table_chains(
                attrs, ids, tables, False, ancestors, memo
            )
            memo[name] = _calculate_chains_of_table(
                name, meta, ids, tables, parent_versions
            )
        result[name] = memo[name]

    return memo if return_all_tables else result
# Public names of this module, i.e. what `from <module> import *` exposes.
__all__ = [
    "get_parents",
    "get_children",
    "calculate_table_chains",
]