from functools import reduce
from typing import (
TYPE_CHECKING,
Any,
Generic,
NamedTuple,
Protocol,
Sequence,
TypeVar,
)
import numpy as np
from ..attribute import DatasetAttributes, get_dtype
from .hugin import CliqueMeta, get_attrs
A = TypeVar("A")
B = TypeVar("B", covariant=True)
[docs]
def deduplicate(arr1, arr2):
    """Pair ``arr1`` with ``arr2`` element-wise and drop repeated pairs.

    Only the first occurrence of each ``(arr1[k], arr2[k])`` pair is kept,
    and the surviving pairs stay in the order they were first seen.

    Returns:
        A 2-row array (``np.stack(pairs).T``): row 0 holds the ``arr1``
        halves of the unique pairs, row 1 the matching ``arr2`` halves.
    """
    # dict keys are unique and keep insertion order, which gives exactly
    # "first occurrence wins, original order preserved".
    unique_pairs = dict.fromkeys(zip(arr1, arr2))
    return np.stack(list(unique_pairs)).T
[docs]
class IndexArg(NamedTuple):
    """Index operation for one attribute shared by two cliques ``a`` and ``b``.

    ``create_messages`` builds one of these when the same attribute is
    selected differently (``sel``) in the sending and the receiving clique.
    """

    # Result of the attribute's ``get_mapping`` for clique a's selection.
    a_map: tuple[int, ...]
    # Result of ``get_mapping`` for clique b's selection; ``create_messages``
    # asserts it has the same length as ``a_map``.
    b_map: tuple[int, ...]
    # Result of ``get_domain`` for clique a's selection; every entry of
    # ``a_map`` is asserted to be smaller than this value.
    a_dom: int
    # Result of ``get_domain`` for clique b's selection; every entry of
    # ``b_map`` is asserted to be smaller than this value.
    b_dom: int
[docs]
class Message:
    """Recipe for a single message sent from clique ``a`` to clique ``b``.

    Cliques use an alphabetical canonical order for their attributes, so the
    attributes common to ``a`` and ``b`` appear in the same order in both.

    Turning the contents of ``a`` into something broadcastable to ``b``
    takes three steps, each described by one field:

    1. ``sum_dims``: sum out the dimensions of ``a`` that ``b`` lacks.
    2. ``index_args``: re-index dimensions whose domain differs between
       the two cliques.
    3. ``reshape_dims``: insert singleton dimensions for the attributes of
       ``b`` that ``a`` lacks.
    """

    # Sending clique.
    a: CliqueMeta
    # Receiving clique.
    b: CliqueMeta

    # True when this message belongs to the forward pass (a -> b is sent
    # before b -> a), False when it belongs to the backward pass.
    # A backward message needs the forward message (b -> a) that was
    # already sent subtracted from it, so a broadcastable version of the
    # forward message has to be stored during the forward pass.
    forward: bool

    # Dimensions of ``a`` to sum out: the broadcastable version only keeps
    # a subset of the dims.
    sum_dims: tuple[int, ...]

    # The domain of a shared attribute may differ between the two cliques;
    # indexing operations are used to move between them.
    # After summing, only the message attributes are left. In the order
    # they appear in, this holds either an index op or None for a no-op.
    index_args: tuple[IndexArg | None, ...]

    # Finally, singleton dims are added so that the resulting message is
    # broadcastable to clique ``b``. One flag per dimension of ``b``:
    # False: add a singleton dim with len 1, True: keep the original dim.
    # `len(reshape_dims) == len(b.shape)`, `sum(reshape_dims) == len(msg)`
    reshape_dims: tuple[bool, ...]

    def __init__(self, a, b, forward, sum_dims, index_args, reshape_dims):
        """Store the given values on the instance unchanged."""
        self.a, self.b = a, b
        self.forward = forward
        self.sum_dims, self.index_args = sum_dims, index_args
        self.reshape_dims = reshape_dims
[docs]
def is_same(c, d):
    """Tell whether ``c`` and ``d`` describe the same attribute.

    Two entries match when their ``table``, ``order`` and ``attr`` fields
    are all equal; no other field takes part in the comparison.
    """
    identity_fields = ("table", "order", "attr")
    return all(getattr(c, name) == getattr(d, name) for name in identity_fields)
[docs]
def has_attr(cl, a):
    """Return True if clique ``cl`` has an entry matching ``a`` per ``is_same``."""
    return any(is_same(candidate, a) for candidate in cl)
[docs]
def convert_sel(sel):
    """Normalise a selection value before handing it to an attribute.

    An ``int`` selection is returned untouched; anything else is passed to
    ``dict()`` and the resulting new dictionary is returned.
    """
    return sel if isinstance(sel, int) else dict(sel)
[docs]
def create_messages(
    generations: Sequence[Sequence[tuple[CliqueMeta, CliqueMeta]]],
    attrs: DatasetAttributes,
) -> Sequence[Message]:
    """Build one ``Message`` per ``(a, b)`` clique pair, in iteration order.

    For every pair, each attribute of the sender ``a`` is classified:

    * absent from ``b``: its position goes into ``sum_dims``;
    * present in ``b`` with an equal ``sel``: ``None`` (no-op) is appended
      to ``index_args``;
    * present in ``b`` with a different ``sel``: an ``IndexArg`` holding
      both mappings and both domains is appended to ``index_args``.

    ``reshape_dims`` gets one flag per attribute of ``b``, True when that
    attribute also occurs in ``a``. A message is flagged ``forward`` when
    the reverse pair ``(b, a)`` has not been processed earlier.

    Args:
        generations: Groups of ``(sender, receiver)`` clique pairs.
        attrs: Attribute definitions, looked up through ``get_attrs``.

    Returns:
        The list of created messages.
    """
    sent = set()
    result = []
    for generation in generations:
        for a, b in generation:
            summed_out = []
            index_ops = []
            for pos, src in enumerate(a):
                # First entry of b describing the same attribute, if any.
                dst = next((cand for cand in b if is_same(src, cand)), None)
                if dst is None:
                    # b does not know this attribute: sum it out.
                    summed_out.append(pos)
                    continue

                if src.sel != dst.sel:
                    # Same attribute, different selection: record how to
                    # translate between the two domains.
                    attr = get_attrs(attrs, src.table, src.order)[src.attr]
                    src_map = attr.get_mapping(convert_sel(src.sel))
                    src_dom = attr.get_domain(convert_sel(src.sel))
                    dst_map = attr.get_mapping(convert_sel(dst.sel))
                    dst_dom = attr.get_domain(convert_sel(dst.sel))
                    assert len(src_map) == len(dst_map)
                    assert np.max(src_map) < src_dom and np.max(dst_map) < dst_dom
                    index_ops.append(
                        IndexArg(tuple(src_map), tuple(dst_map), src_dom, dst_dom)
                    )
                else:
                    index_ops.append(None)

            keep_dims = tuple(has_attr(a, target) for target in b)

            result.append(
                Message(
                    a,
                    b,
                    # Forward pass unless the opposite direction came first.
                    (b, a) not in sent,
                    tuple(summed_out),
                    tuple(index_ops),
                    keep_dims,
                )
            )
            sent.add((a, b))
    return result
[docs]
def get_clique_shapes(cliques: Sequence[CliqueMeta], attrs: DatasetAttributes):
    """Collect the domain of every attribute of every clique.

    Returns:
        One list per clique, holding for each of its entries the value of
        ``get_domain`` for that entry's selection. ``numpy_create_cliques``
        uses each inner list as an array shape.
    """
    return [
        [
            get_attrs(attrs, meta.table, meta.order)[meta.attr].get_domain(
                convert_sel(meta.sel)
            )
            for meta in clique
        ]
        for clique in cliques
    ]
[docs]
def get_clique_weights(cliques: Sequence[CliqueMeta], attrs: DatasetAttributes):
    """Compute a weight vector for every attribute of every clique.

    For one attribute, the weight at position ``k`` is the number of times
    ``k`` occurs in the attribute's mapping divided by the mapping's length.

    Returns:
        One list per clique containing one 1-D array per clique entry.
    """

    def attr_weight(meta):
        attr = get_attrs(attrs, meta.table, meta.order)[meta.attr]
        mapped = attr.get_mapping(convert_sel(meta.sel))
        counts = np.zeros(attr.get_domain(convert_sel(meta.sel)))
        # Unbuffered add, so repeated indices in `mapped` are all counted.
        np.add.at(counts, mapped, 1)
        counts /= len(mapped)
        return counts

    return [[attr_weight(meta) for meta in clique] for clique in cliques]
[docs]
def numpy_create_cliques(cliques: Sequence[CliqueMeta], attrs: DatasetAttributes):
    """Allocate one zero-filled array per clique, shaped by ``get_clique_shapes``."""
    shapes = get_clique_shapes(cliques, attrs)
    return list(map(np.zeros, shapes))
[docs]
def numpy_gen_multi_index(messages: Sequence[Message]):
    """Build the per-dimension index pairs of every message.

    Returns:
        A tuple with one inner tuple per message. The inner tuple mirrors
        ``message.index_args``: ``None`` stays ``None``, and each index op is
        replaced by the 2-row array ``deduplicate(a_map, b_map)``.
    """

    def unique_pairs(op):
        if op is None:
            return None
        # deduplicate keeps the pairs in first-seen order.
        return deduplicate(op.a_map, op.b_map)

    return tuple(
        tuple(unique_pairs(op) for op in msg.index_args) for msg in messages
    )
[docs]
class NumpyIndexArgs(NamedTuple):
    """Precomputed index data for one message, as built by ``numpy_gen_args``."""

    # Permutation of the message dimensions: positions that carry an index
    # op come first, followed by the positions that do not.
    transpose: tuple[int, ...]
    # Inverse of ``transpose``: entry ``i`` is the position of ``i`` in it.
    transpose_undo: tuple[int, ...]
    # 1-D array combining the a-side indices of all indexed dimensions into
    # a single flat index.
    idx_a: np.ndarray
    # 1-D array combining the b-side indices of all indexed dimensions into
    # a single flat index; built with the same shape as ``idx_a``.
    idx_b: np.ndarray
    # ``b_dom`` of each indexed dimension, in order.
    b_doms: tuple[int, ...]
[docs]
def numpy_gen_args(messages: Sequence[Message]):
    """Precompute flat index arrays for the messages that need re-indexing.

    For every message, the dimensions that carry an index op are gathered.
    Their unique ``(a, b)`` index pairs are expanded over a grid (one axis
    per indexed dimension) and each grid point is collapsed into a single
    row-major flat index, once over the a-side domains and once over the
    b-side domains.

    Returns:
        A tuple with one entry per message: ``None`` if none of its
        ``index_args`` is set, otherwise a ``NumpyIndexArgs``.
    """
    per_message = []
    for msg in messages:
        indexed_dims = []  # positions that carry an index op
        plain_dims = []  # positions that are passed through untouched
        a_cols, a_doms = [], []
        b_cols, b_doms = [], []
        for pos, op in enumerate(msg.index_args):
            if not op:
                plain_dims.append(pos)
                continue
            # deduplicate keeps the pairs in first-seen order.
            pairs = deduplicate(op.a_map, op.b_map)
            a_cols.append(pairs[0, :])
            a_doms.append(op.a_dom)
            b_cols.append(pairs[1, :])
            b_doms.append(op.b_dom)
            indexed_dims.append(pos)

        # Without any index op there is nothing to precompute. Skip.
        if not indexed_dims:
            per_message.append(None)
            continue

        # The dtype is chosen from the product of the domains, i.e. the
        # number of distinct flat indices.
        dtype_a = get_dtype(reduce(lambda x, y: x * y, a_doms))
        dtype_b = get_dtype(reduce(lambda x, y: x * y, b_doms))

        grid_a = np.meshgrid(*a_cols, indexing="ij")
        grid_b = np.meshgrid(*b_cols, indexing="ij")

        flat_a = np.zeros_like(grid_a[0], dtype=dtype_a)
        flat_b = np.zeros_like(grid_a[0], dtype=dtype_b)
        stride_a = 1
        stride_b = 1
        # Walk the axes from last to first so the last axis has stride 1.
        for axis in reversed(range(len(a_doms))):
            flat_a += stride_a * grid_a[axis].astype(dtype_a)
            flat_b += stride_b * grid_b[axis].astype(dtype_b)
            stride_a *= a_doms[axis]
            stride_b *= b_doms[axis]

        order = tuple(indexed_dims + plain_dims)
        # Inverse permutation: inverse[d] is where dimension d sits in order.
        inverse = [0] * len(order)
        for new_pos, old_pos in enumerate(order):
            inverse[old_pos] = new_pos

        per_message.append(
            NumpyIndexArgs(
                order,
                tuple(inverse),
                flat_a.reshape(-1),
                flat_b.reshape(-1),
                tuple(b_doms),
            )
        )
    return tuple(per_message)