"""This module contains heuristics for simplifying the chain combinations of a
dataset."""
from collections import defaultdict
from functools import reduce
from heapq import heappop, heappush
from itertools import combinations, product
from typing import (
Generator,
Generic,
Protocol,
Sequence,
TypeVar,
cast,
)
from ..attribute import Attributes
from .chains import TableMeta, TablePartition, TableVersion, _calculate_stripped_meta
# Result type of a preprocessing function. Covariant because it only appears
# as the return type of `PreprocFn.__call__`.
A = TypeVar("A", covariant=True)
# Summary type that `MergeFn` both consumes and produces, hence invariant.
B = TypeVar("B")
# Summary type that `ScoreFn` only consumes as arguments, hence contravariant.
C = TypeVar("C", contravariant=True)
[docs]
class PreprocFn(Protocol, Generic[A]):
    """Callable that maps a `TableVersion` to a summary value of type `A`."""

    def __call__(self, v: TableVersion) -> A: ...
[docs]
class MergeFn(Protocol, Generic[B]):
    """Callable that combines two values of type `B` into one value of type `B`."""

    def __call__(self, a: B, b: B) -> B: ...
[docs]
class ScoreFn(Protocol, Generic[C]):
    """Callable that returns an integer score for a pair of values of type `C`.

    `merge_versions_heuristic` keeps scored pairs in a min-heap, so pairs with
    lower scores are merged first.
    """

    def __call__(self, a: C, b: C) -> int: ...
[docs]
class Pair:
    """Heap entry describing a candidate merge of two keys ``a`` and ``b``.

    Instances are ordered by ``score`` alone, which makes them usable with
    ``heapq``. Iterating (or unpacking) an instance yields ``name``, ``a`` and
    ``b``; the score is not part of the unpacked triple.
    """

    def __init__(self, name, a, b, score):
        self.name, self.a, self.b = name, a, b
        self.score = score

    def __lt__(self, other: "Pair"):
        # Only the score takes part in ordering; equal scores are "not less".
        mine, theirs = self.score, other.score
        return mine < theirs

    def __iter__(self):
        members = (self.name, self.a, self.b)
        return iter(members)
def _get_partitions_names(ver: TableVersion):
    """Collect the table names of every partitioned parent in ``ver``'s chain.

    Parents are walked depth-first in declaration order. A ``TablePartition``
    parent contributes the name of its table before that table's own parents
    are visited; plain parents contribute nothing themselves but are still
    descended into.
    """
    names = []
    for parent in ver.parents:
        if isinstance(parent, TablePartition):
            table = parent.table
            names.append(table.name)
        else:
            table = parent
        names += _get_partitions_names(table)
    return names
def _get_partitions(ver: TableVersion, out):
    """Append the partitions of every partitioned parent in ``ver``'s chain.

    The chain is walked depth-first in declaration order. Each
    ``TablePartition`` parent appends its ``partitions`` to ``out`` before its
    table's own parents are visited. ``out`` is mutated in place and also
    returned.
    """
    for parent in ver.parents:
        is_partitioned = isinstance(parent, TablePartition)
        if is_partitioned:
            out.append(parent.partitions)
        _get_partitions(parent.table if is_partitioned else parent, out)
    return out
[docs]
def get_combos(
    vers: Sequence[TableVersion], tables: Sequence[str] | None = None
) -> tuple[Sequence[str], dict[tuple[tuple[int, ...], ...], TableVersion]]:
    """Index versions by the partitions they use for a set of tables.

    Args:
        vers: Versions to index. The partitioned table names are read from the
            first version only, and the same positions are then used for every
            other version.
        tables: Table names to index by. When omitted (or empty), the sorted
            unique names of the partitioned tables of ``vers[0]`` are used.

    Returns:
        A pair ``(names, combos)``. ``names`` is the list of table names used
        and ``combos`` maps a tuple holding one partitions entry per name (in
        the order of ``names``) to the version carrying it. If two versions
        produce the same key, the later one replaces the earlier one. When
        ``vers`` is empty or no table names are available, ``([], {})`` is
        returned.

    Raises:
        ValueError: If a name in ``tables`` is not a partitioned table of
            ``vers[0]``.
    """
    # Without this guard `vers[0]` raises IndexError; treat "no versions" like
    # the existing "no partitioned tables" case.
    if not vers:
        return [], {}

    names = _get_partitions_names(vers[0])
    unique_names = tables or sorted(set(names))
    if not unique_names:
        return [], {}

    # A table may appear several times in the chain; `index` picks its first
    # occurrence.
    name_idx = [names.index(u) for u in unique_names]

    candidates = {}
    for ver in vers:
        parts = _get_partitions(ver, [])
        candidates[tuple(parts[i] for i in name_idx)] = ver
    return unique_names, candidates
[docs]
def merge_versions_heuristic(
    vers: Sequence[TableVersion],
    max_vers: int,
    preproc_fn: PreprocFn[B],
    merge_fn: MergeFn[B],
    score_fn: ScoreFn[B],
    tables: Sequence[str] | None = None,
):
    """Greedily merge partition keys until at most ``max_vers`` versions remain.

    For every table returned by ``get_combos``, the versions are grouped by
    their partition key for that table, each group is merged and summarised
    with ``preproc_fn``. Every pair of keys of the same table is scored with
    ``score_fn`` and the lowest-scoring pair is merged first. Merging a pair
    combines the two summaries with ``merge_fn``, scores the new key against
    the remaining keys of that table and merges the affected versions.

    Args:
        vers: Versions to reduce.
        max_vers: Merging stops once the number of versions is at or below
            this value.
        preproc_fn: Produces the summary for a merged group of versions.
        merge_fn: Combines the summaries of two keys being merged.
        score_fn: Scores a pair of summaries; lower scores are merged earlier.
        tables: Optional table names, forwarded to ``get_combos``.

    Returns:
        A tuple with the remaining versions. It can hold more than
        ``max_vers`` entries when the candidate pairs run out first.
    """
    names, combos = get_combos(vers, tables)

    # Per table: partition key -> summary of all versions carrying that key.
    lookups = {}
    for pos, table in enumerate(names):
        grouped = defaultdict(list)
        for combo, version in combos.items():
            grouped[combo[pos]].append(version)
        lookups[table] = {
            key: preproc_fn(merge_versions(group)) for key, group in grouped.items()
        }

    # Keys that were already merged away; heap entries mentioning them are stale.
    used = {table: set() for table in lookups}

    heap = []
    for table, summaries in lookups.items():
        for left, right in combinations(summaries, 2):
            pair_score = score_fn(summaries[left], summaries[right])
            heappush(heap, Pair(table, left, right, pair_score))

    while heap and len(combos) > max_vers:
        name, a, b = heappop(heap)
        retired = used[name]
        if a in retired or b in retired:
            continue

        # Replace the two keys with their sorted union.
        c = tuple(sorted(set(a + b)))
        summaries = lookups[name]
        summaries[c] = merge_fn(summaries[a], summaries[b])
        del summaries[a]
        del summaries[b]
        retired.update((a, b))

        # Make the new key available for further merging.
        for other in summaries:
            if other == c:
                continue
            pair_score = score_fn(summaries[c], summaries[other])
            heappush(heap, Pair(name, c, other, pair_score))

        # Find the combos using either key, indexed by their remaining keys.
        i = names.index(name)
        with_a, with_b = {}, {}
        for combo in combos:
            if combo[i] == a:
                bucket = with_a
            elif combo[i] == b:
                bucket = with_b
            else:
                continue
            bucket[combo[:i] + combo[i + 1 :]] = combo

        # Combos that only differ in a/b collapse into one merged version;
        # unmatched ones are just re-keyed to the new key.
        for key in sorted(set(with_a) | set(with_b)):
            combo_a = with_a.get(key)
            combo_b = with_b.get(key)
            if combo_a is not None and combo_b is not None:
                template = combo_a
                merged = merge_versions([combos.pop(combo_a), combos.pop(combo_b)])
            else:
                template = combo_a if combo_a is not None else combo_b
                merged = combos.pop(template)
            combos[template[:i] + (c,) + template[i + 1 :]] = merged

    return tuple(combos.values())
def _estimate_columns_for_chain_recurse(
ver: TableVersion, smeta: dict[str, TableMeta], meta: dict[str, Attributes]
):
base_cols = 0
for attr in meta[ver.name].values():
base_cols += attr.common is not None
base_cols += len(attr.vals)
if smeta[ver.name].order is not None:
mult = smeta[ver.name].order
if mult is not None:
base_cols = mult * (base_cols - 1)
along = smeta[ver.name].along
if along and ver.unrolls:
base_cols -= len(along) + 1
base_cols += len(along) * len(ver.unrolls)
out = {ver.name: base_cols}
for p in ver.parents:
if isinstance(p, TablePartition):
p = p.table
out.update(_estimate_columns_for_chain_recurse(p, smeta, meta))
return out
[docs]
def estimate_columns_for_chain(ver: TableVersion, meta: dict[str, Attributes]):
    """Estimates the number of columns for the provided chain based on metadata.

    The stripped metadata is derived from ``meta``, the per-table estimates of
    ``ver`` and all of its parents are computed, and their sum is returned.
    """
    stripped = _calculate_stripped_meta(meta)
    per_table = _estimate_columns_for_chain_recurse(ver, stripped, meta)
    return sum(per_table.values())
[docs]
def calc_model_num(combos: dict[str, tuple[set[int]]]):
    """Return the product of the number of entries of every value in ``combos``.

    An empty mapping yields ``1``.
    """
    return reduce(lambda total, groups: total * len(groups), combos.values(), 1)
[docs]
def get_parents(ver: TableVersion) -> Generator[TablePartition, None, None]:
    """Yield every ``TablePartition`` found above ``ver``, depth-first.

    A partitioned parent is yielded before the parents of its table. Plain
    ``TableVersion`` parents are not yielded, but their own parents are
    still visited.
    """
    for parent in ver.parents:
        if isinstance(parent, TablePartition):
            yield parent
            table = parent.table
        else:
            table = parent
        yield from get_parents(table)
[docs]
def tuple_unique(a, b):
    """Return the sorted tuple of the distinct elements of ``a`` and ``b``.

    If either argument is empty (or otherwise falsy), the other one is
    returned unchanged, without sorting or de-duplication.
    """
    if a and b:
        distinct = set(a + b)
        return tuple(sorted(distinct))
    # At most one of the two is non-empty here; prefer `a` when it is.
    return a or b
[docs]
def merge_versions(vers: Sequence[TableVersion]):
    """Merge several versions of the same table into one ``TableVersion``.

    The first version acts as the reference: its name, ``seq_repeat`` and
    parent layout are used, and it decides which optional fields are merged.

    * Parents are merged position by position, recursively. For partitioned
      parents the partitions of all versions are unioned and sorted.
    * ``children`` is the maximum over all versions if the reference has one,
      otherwise ``None``.
    * ``partitions`` and ``unrolls`` are the sorted union over all versions if
      the reference has a non-empty value, otherwise ``None``.
    * ``rows`` is the sum over all versions.
    * ``max_len`` is the maximum of the values that are not ``None``, or
      ``None`` if there are none.
    """
    ref = vers[0]

    merged_parents = []
    for idx, ref_parent in enumerate(ref.parents):
        if not isinstance(ref_parent, TablePartition):
            plain = [cast(TableVersion, ver.parents[idx]) for ver in vers]
            merged_parents.append(merge_versions(plain))
            continue

        part_ids = set()
        tables = []
        for ver in vers:
            ids, table = cast(TablePartition, ver.parents[idx])
            part_ids.update(ids)
            tables.append(table)
        merged_parents.append(
            TablePartition(tuple(sorted(part_ids)), merge_versions(tables))
        )

    children = (
        max(cast(int, v.children) for v in vers) if ref.children is not None else None
    )

    partitions = None
    if ref.partitions:
        partitions = tuple(sorted(set().union(*(v.partitions for v in vers))))  # type: ignore

    unrolls = None
    if ref.unrolls:
        unrolls = tuple(sorted(set().union(*(v.unrolls for v in vers))))  # type: ignore

    known_lens = (v.max_len for v in vers if v.max_len is not None)

    return TableVersion(
        name=ref.name,
        rows=sum(v.rows for v in vers),
        children=children,
        max_len=max(known_lens, default=None),
        partitions=partitions,
        unrolls=unrolls,
        parents=tuple(merged_parents),
        seq_repeat=ref.seq_repeat,
    )
[docs]
def calc_rows_cols(
    combo: dict[str, tuple[set[int]]],
    chains: tuple[TableVersion, ...],
    rows: dict[TableVersion | TablePartition, int],
    meta: dict[str, Attributes],
) -> list[tuple[TableVersion, int, int]]:
    """Compute the merged version, row count and column estimate per selection.

    Every selection takes one partition set per table from ``combo``. A chain
    belongs to a selection when, for each of its partitioned parents, the
    first partition of that parent is in the set selected for the parent's
    table. Selections without any chain are skipped.

    Returns:
        One ``(merged_version, row_count, col_count)`` tuple per non-empty
        selection, where ``row_count`` sums ``rows`` over the matching chains
        and ``col_count`` comes from ``estimate_columns_for_chain``.
    """
    results = []
    for choice in product(*combo.values()):
        allowed = dict(zip(combo, choice))
        matching = [
            ver
            for ver in chains
            if all(
                p.partitions[0] in allowed[p.table.name] for p in get_parents(ver)
            )
        ]
        if matching:
            merged = merge_versions(matching)
            row_count = sum(rows[v] for v in matching)
            col_count = estimate_columns_for_chain(merged, meta)
            results.append((merged, row_count, col_count))
    return results