from __future__ import annotations
from itertools import chain
import logging
from math import ceil
from typing import Any, Sequence, cast
import pandas as pd
import numpy as np
from ....attribute import Attributes, CatValue, DatasetAttributes, SeqAttributes
from ....hierarchy import rebalance_attributes
from ....mare.synth import MareModel
from ....marginal import MarginalOracle, PostprocessFun, PreprocessFun
from ....marginal.numpy import TableSelector
from ....marginal.oracle import counts_preprocess
from ....synth import Synth, make_deterministic
from ....utils import LazyFrame, data_to_tables, tables_to_data
from .implementation import (
MAX_EPSILON,
Node,
calc_noisy_marginals,
greedy_bayes,
print_tree,
sample_rows,
)
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class PrivBayesMare(MareModel):
    """PrivBayes exposed as a ``MareModel``.

    ``fit`` learns the network structure with ``greedy_bayes`` (budget ``e1``)
    and then measures the network's marginals with ``calc_noisy_marginals``
    (budget ``e2``). ``sample`` draws rows from those marginals with
    ``sample_rows``.
    """

    def __init__(
        self,
        *,
        etotal: float | None = None,
        ep: float | None = None,
        e1: float = 0.3,
        e2: float = 0.7,
        theta: float = 4,
        use_r: bool = True,
        seed: float | None = None,
        unbounded_dp: bool = False,
        random_init: bool = False,
        skip_zero_counts: bool = True,
        minimum_cutoff: int | None = 3,
        rake: bool = True,
        **kwargs,
    ) -> None:
        """Store the model hyper-parameters.

        ``ep``, ``e1`` and ``e2`` are multiplied by ``etotal`` (1 when not
        given), so they act as fractions of a total budget. Any extra keyword
        arguments are kept in ``self.kwargs``.
        """
        if etotal is None:
            etotal = 1
        self.ep = ep * etotal if ep is not None else None
        self.e1 = e1 * etotal
        self.e2 = e2 * etotal
        self.theta = theta
        self.use_r = use_r
        self.seed = seed
        self.random_init = random_init
        self.unbounded_dp = unbounded_dp
        self.skip_zero_counts = skip_zero_counts
        self.rake = rake
        self.kwargs = kwargs
        self.minimum_cutoff = minimum_cutoff

    @make_deterministic
    def fit(self, n: int, table: str, attrs: DatasetAttributes, oracle: MarginalOracle):
        """Learn the network structure and its noisy marginals.

        Args:
            n: row count passed to ``greedy_bayes`` and used to scale the
                noise of the marginals.
            table: passed to ``greedy_bayes`` as ``prefer_table``.
            attrs: dataset attributes; the values of ``attrs[None]`` determine
                how many ways the ``e2`` budget is split.
            oracle: marginal oracle used both for structure learning and for
                measuring the marginals.
        """
        nodes, t = greedy_bayes(
            oracle,
            attrs,
            n,
            self.e1,
            self.e2,
            self.theta,
            self.use_r,
            self.unbounded_dp,
            self.random_init,
            prefer_table=table,
            rake=self.rake,
        )
        # ``t`` and ``nodes`` are also read by ``__str__`` (``print_tree``).
        self.t = t
        self.nodes = nodes
        self.attrs = attrs
        logger.info(self)

        # ``d`` is the total number of values across the attributes of
        # ``attrs[None]``; the ``e2`` budget is divided by it.
        d = sum(len(attr.vals) for attr in cast(Attributes, attrs[None]).values())
        if self.e2 > MAX_EPSILON:
            logger.warning(f"Considering e2={self.e2} unbounded, sampling without DP.")
            noise = 0
        else:
            # Factor 1 under unbounded DP, 2 otherwise.
            noise = (1 if self.unbounded_dp else 2) * d / self.e2 / n
        self.marginals = calc_noisy_marginals(
            oracle,
            self.nodes,
            noise,
            self.skip_zero_counts,
            minimum_cutoff=self.minimum_cutoff,
        )

    def sample(
        self, index: pd.Index, hist: dict[TableSelector, pd.DataFrame]
    ) -> pd.DataFrame:
        """Sample one row per entry of ``index`` from the fitted marginals.

        ``hist`` is forwarded to ``sample_rows`` unchanged.
        """
        return sample_rows(index, self.attrs, hist, self.nodes, self.marginals)

    def __str__(self) -> str:
        """Render the fitted network with ``print_tree``."""
        return print_tree(
            self.attrs,
            self.nodes,
            self.e1,
            self.e2,
            self.theta,
            self.t,
            minimum_cutoff=self.minimum_cutoff,
        )
class PrivBayesSynth(Synth):
    """Single-table PrivBayes synthesizer.

    ``preprocess`` records the table and (optionally) rebalances attributes,
    ``bake`` learns the network structure with ``greedy_bayes`` (budget
    ``e1``), ``fit`` measures the network's marginals with
    ``calc_noisy_marginals`` (budget ``e2``), and ``sample_partition`` draws
    rows from them.
    """

    name = "privbayes"
    type = "idx"
    tabular = True
    multimodal = False
    timeseries = False
    parallel = True

    def __init__(
        self,
        ep: float | None = None,
        e1: float = 0.3,
        e2: float = 0.7,
        etotal: float | None = None,
        theta: float = 4,
        use_r: bool = True,
        seed: float | None = None,
        rebalance: bool = False,
        unbounded_dp: bool = False,
        random_init: bool = False,
        marginal_mode: MarginalOracle.MODES = "out_of_core",
        marginal_worker_mult: int = 1,
        marginal_min_chunk: int = 100,
        skip_zero_counts: bool = True,
        minimum_cutoff: int | None = 3,
        **kwargs,
    ) -> None:
        """Store the synthesizer hyper-parameters.

        ``ep``, ``e1`` and ``e2`` are multiplied by ``etotal`` (1 when not
        given), so they act as fractions of a total budget. The
        ``marginal_*`` options configure every ``MarginalOracle`` this class
        creates. Extra keyword arguments are kept in ``self.kwargs`` and are
        forwarded to ``rebalance_attributes`` when ``rebalance`` is set.
        """
        if etotal is None:
            etotal = 1
        self.ep = ep * etotal if ep is not None else None
        self.e1 = e1 * etotal
        self.e2 = e2 * etotal
        self.theta = theta
        self.use_r = use_r
        self.seed = seed
        self.random_init = random_init
        self.unbounded_dp = unbounded_dp
        self.rebalance = rebalance
        self.marginal_mode: MarginalOracle.MODES = marginal_mode
        self.marginal_min_chunk = marginal_min_chunk
        self.marginal_worker_mult = marginal_worker_mult
        self.skip_zero_counts = skip_zero_counts
        self.minimum_cutoff = minimum_cutoff
        self.kwargs = kwargs

    @make_deterministic
    def preprocess(
        self, meta: dict[str | None, Attributes], data: dict[str, LazyFrame]
    ):
        """Record the (first) table and prepare its attributes.

        When ``rebalance`` is set, per-table counts are computed with a
        ``MarginalOracle`` and each table's attributes are replaced by the
        output of ``rebalance_attributes``. Otherwise ``meta`` is used as is.
        """
        attrs = meta
        _, tables = data_to_tables(data)
        table_name = next(iter(tables.keys()))
        table = tables[table_name]
        self._n = table.shape[0]
        self._partitions = len(table)
        self.original_attrs = attrs
        self.table_name = table_name

        if self.rebalance:
            with MarginalOracle(
                data,  # type: ignore
                attrs,
                mode=self.marginal_mode,
                min_chunk_size=self.marginal_min_chunk,
                max_worker_mult=self.marginal_worker_mult,
                preprocess=counts_preprocess,
            ) as o:
                counts = o.get_counts(desc="Calculating counts for column rebalancing")

            # TODO: Add noise and remove save support
            self.counts = counts
            self.attrs = {
                k: rebalance_attributes(
                    counts[k],
                    v,
                    unbounded_dp=self.unbounded_dp,
                    **self.kwargs,
                )
                for k, v in attrs.items()
            }
        else:
            self.attrs = attrs

        self.table_attrs: DatasetAttributes = {None: self.attrs[table_name]}

    @make_deterministic
    def bake(self, data: dict[str, LazyFrame]):
        """Learn the Bayesian network structure (spends the ``e1`` budget).

        Sets ``self.n``/``self.d`` to the table's shape and stores the learned
        ``nodes`` and ``t``.
        """
        _, tables = data_to_tables(data)
        assert len(tables) == 1, "Only tabular data supported for now"
        table_name = next(iter(tables.keys()))
        table = tables[table_name]

        with MarginalOracle(
            data,  # type: ignore
            self.table_attrs,
            mode=self.marginal_mode,
            min_chunk_size=self.marginal_min_chunk,
            max_worker_mult=self.marginal_worker_mult,
        ) as oracle:
            self.n, self.d = table.shape
            nodes, t = greedy_bayes(
                oracle,
                self.table_attrs,
                table.shape[0],
                self.e1,
                self.e2,
                self.theta,
                self.use_r,
                self.unbounded_dp,
                self.random_init,
            )

        # ``t`` and ``nodes`` are also read by ``__str__`` (``print_tree``).
        self.table_name = table_name
        self.t = t
        self.nodes = nodes
        logger.info(self)

    @make_deterministic
    def fit(self, data: dict[str, LazyFrame]):
        """Measure the network's marginals (spends the ``e2`` budget).

        Also sets ``self.partitions`` and ``self.n``. After ``fit``,
        ``self.n`` is the per-partition row count that ``sample_partition``
        uses by default.
        """
        _, tables = data_to_tables(data)
        table = tables[self.table_name]
        n_rows = table.shape[0]
        self.partitions = len(table)
        self.n = ceil(n_rows / self.partitions)

        if self.e2 > MAX_EPSILON:
            logger.warning(f"Considering e2={self.e2} unbounded, sampling without DP.")
            noise = 0
        else:
            # The oracle below measures the marginals over all of ``data``, so
            # the noise is scaled by the total row count (as ``bake`` does for
            # ``greedy_bayes``), not by the per-partition sample size.
            # Factor 1 under unbounded DP, 2 otherwise.
            noise = (1 if self.unbounded_dp else 2) * self.d / self.e2 / n_rows

        with MarginalOracle(
            data,  # type: ignore
            self.table_attrs,
            mode=self.marginal_mode,
            min_chunk_size=self.marginal_min_chunk,
            max_worker_mult=self.marginal_worker_mult,
        ) as o:
            self.marginals = calc_noisy_marginals(
                o,
                self.nodes,
                noise,
                self.skip_zero_counts,
                minimum_cutoff=self.minimum_cutoff,
            )

    @make_deterministic("i")
    def sample_partition(self, *, n: int | None = None, i: int = 0) -> dict[str, Any]:
        """Sample ``n`` rows (default: ``self.n`` set by ``fit``).

        ``i`` is not read by the body; it is the argument named in
        ``make_deterministic("i")``.
        """
        if n is None:
            n = self.n

        tables = {
            self.table_name: sample_rows(
                pd.RangeIndex(n),
                {None: self.attrs[self.table_name]},
                {},
                self.nodes,
                self.marginals,
            )
        }
        ids = {self.table_name: pd.DataFrame()}
        return tables_to_data(ids, tables)

    def __str__(self) -> str:
        """Render the learned network with ``print_tree``."""
        return print_tree(
            {None: self.attrs[self.table_name]},
            self.nodes,
            self.e1,
            self.e2,
            self.theta,
            self.t,
            minimum_cutoff=self.minimum_cutoff,
        )
def derive_graph_from_nodes(
    nodes: Sequence[Node], attrs: DatasetAttributes, prune: bool = True
):
    """Build a directed graph linking attribute values to the network's nodes.

    Each attribute's common value and each categorical value get one graph
    node per height, with an edge from height ``h`` to ``h - 1``. A value of
    an attribute that has a common value is also linked from the common
    value's height-0 node. Then, for every PrivBayes node, each parent adds an
    edge from the parent's value/height node to the node's height-0 node.

    Args:
        nodes: PrivBayes nodes; ``node.attr``, ``node.value`` and the parent
            selectors in ``node.p`` are read.
        attrs: dataset attributes used to enumerate values and heights.
        prune: when true, remove graph nodes with height > 0 whose neighbours
            are all heights of the same value (or, at the value's top height,
            the attribute's common value), reconnecting the higher-height
            neighbour to the lower-height one.

    Returns:
        A ``networkx.DiGraph`` whose nodes carry ``table``, ``order``,
        ``attr``, ``value`` and ``height`` data.
    """
    import networkx as nx

    def get_name(table, order, attr, val, height):
        """Return the graph node name ``[table[order]_]attr.val[height]``."""
        out = ""
        if table:
            out += table
            if order is not None:
                out += f"[{order}]"
            out += "_"
        out += f"{attr}.{val}[{height}]"
        return out

    g = nx.DiGraph()
    # (table, order, attr) -> name of that attribute's common value
    commons = {}
    # (table, order, attr, value) -> highest height node created for the value
    max_heights = {}

    for table, tattrs in attrs.items():
        if isinstance(tattrs, SeqAttributes):
            attr_sets = {**tattrs.hist, None: tattrs.attrs}
        else:
            attr_sets = {None: tattrs}

        for order, attr_set in attr_sets.items():
            for name, attr in attr_set.items():
                cmn = attr.common
                if cmn:
                    commons[(table, order, name)] = cmn.name
                    for h in range(cmn.height):
                        g.add_node(
                            get_name(table, order, name, cmn.name, h),
                            table=table,
                            order=order,
                            attr=name,
                            value=cmn.name,
                            height=h,
                        )
                        if h:
                            g.add_edge(
                                get_name(table, order, name, cmn.name, h),
                                get_name(table, order, name, cmn.name, h - 1),
                            )

                for v in attr.vals.values():
                    if not isinstance(v, CatValue):
                        continue

                    # With a common value, the value's top height is omitted
                    # and the common value's height-0 node links in instead.
                    h_range = v.height if cmn is None else v.height - 1
                    max_heights[(table, order, name, v.name)] = h_range - 1
                    for h in range(h_range):
                        g.add_node(
                            get_name(table, order, name, v.name, h),
                            table=table,
                            order=order,
                            attr=name,
                            value=v.name,
                            height=h,
                        )
                        if h:
                            g.add_edge(
                                get_name(table, order, name, v.name, h),
                                get_name(table, order, name, v.name, h - 1),
                            )

                    if cmn:
                        g.add_edge(
                            get_name(table, order, name, cmn.name, 0),
                            get_name(table, order, name, v.name, v.height - 2),
                        )

    for node in nodes:
        # The node's own name depends only on the node, not on the parent.
        node_name = get_name(None, None, node.attr, node.value, 0)
        for parent in node.p:
            order = None
            if len(parent) == 3:
                table, aname, sel = parent
                if isinstance(table, tuple):
                    order = table[1]
                    table = table[0]
            else:
                table = None
                aname, sel = parent

            if isinstance(sel, int):
                # An integer selector refers to a height of the common value.
                if table and order is not None:
                    cmn = cast(SeqAttributes, attrs[table]).hist[order][aname].common
                else:
                    tattrs = attrs[table]
                    if isinstance(tattrs, SeqAttributes):
                        assert tattrs.attrs
                        cmn = tattrs.attrs[aname].common
                    else:
                        cmn = tattrs[aname].common
                assert cmn
                cmn_name = get_name(table, order, aname, cmn.name, sel)
                g.add_edge(cmn_name, node_name)
            else:
                for k, v in sel.items():
                    other_name = get_name(table, order, aname, k, v)
                    g.add_edge(other_name, node_name)

    if prune:
        for node, d in list(g.nodes(data=True)):
            if not d["height"]:
                continue  # keep all height 0 nodes

            next_neighbor = None
            prev_neighbor = None
            prune_node = True
            for neighbor in chain(g.successors(node), g.predecessors(node)):
                nd = g.nodes[neighbor]
                # Prune all nodes where their only neighbor is a different
                # height of their value
                if (
                    d["table"] != nd["table"]
                    or d["order"] != nd["order"]
                    or d["attr"] != nd["attr"]
                ):
                    prune_node = False
                elif d["value"] != nd["value"]:
                    # Only the common value may link to a value's top height.
                    if (
                        commons.get((d["table"], d["order"], d["attr"]), None)
                        == nd["value"]
                        and d["height"]
                        == max_heights[(d["table"], d["order"], d["attr"], d["value"])]
                    ):
                        prev_neighbor = neighbor
                    else:
                        prune_node = False
                elif d["height"] < nd["height"]:
                    prev_neighbor = neighbor
                else:
                    next_neighbor = neighbor

            if prune_node:
                g.remove_node(node)
                # Bridge the gap left by the removed height.
                if next_neighbor is not None and prev_neighbor is not None:
                    g.add_edge(prev_neighbor, next_neighbor)

    return g
def derive_obs_from_model(
    nodes: Sequence[Node], attrs: DatasetAttributes, marginals: Sequence[np.ndarray]
):
    """Convert each node's marginal into a ``LinearObservation``.

    For every ``(node, marginal)`` pair, one ``AttrMeta`` is built per parent
    (plus one for the node unless it is merged into a same-table parent of
    the same attribute). The marginal's axes are transposed into the sorted
    ``AttrMeta`` order, the axes of each attribute are flattened into one,
    and multi-value attributes are mapped from the naive (product) layout to
    the attribute's domain via ``get_naive_mapping``/``get_mapping``.

    Args:
        nodes: PrivBayes nodes; ``node.attr``, ``node.value`` and ``node.p``
            are read.
        attrs: dataset attributes, looked up through ``get_attrs``.
        marginals: one array per node, paired with ``nodes`` by position.

    Returns:
        A list with one ``LinearObservation`` per ``(node, marginal)`` pair.
    """
    from math import prod

    from ....graph.hugin import AttrMeta, get_attrs
    from ....graph.loss import LinearObservation

    lin_obs = []
    for node, obs in zip(nodes, marginals):
        # ``out`` collects one AttrMeta per attribute of the marginal; ``orig``
        # collects one key per axis of ``obs``, in the order the keys are
        # appended below.
        out = []
        orig = []
        used_parent = False
        node_key = (None, None, node.attr, node.value)

        for s in node.p:
            if len(s) == 3:
                table_sel, attr_name, sel = s
            else:
                table_sel = None
                attr_name, sel = s

            if isinstance(table_sel, tuple):
                table, order = table_sel[0], table_sel[1]
            else:
                table, order = table_sel, None

            attr = get_attrs(attrs, table, order)[attr_name]
            if isinstance(sel, int):
                new_sel = sel
                orig.append((table, order, attr_name, None))
            else:
                cmn = attr.common.name if attr.common else None
                pairs = []
                for val, h in sel.items():
                    if val == cmn:
                        continue  # skip common: no axis key is recorded for it
                    pairs.append((val, h))
                    orig.append((table, order, attr_name, val))
                # Merge the node into a parent of the same attribute only when
                # that parent is from the node's own table (table/order None).
                # Merging into a history/other-table parent would create the key
                # (table, order, attr, node.value), which never appears in
                # ``orig`` and made ``orig.index`` below raise ValueError.
                if table is None and order is None and node.attr == attr_name:
                    pairs.append((node.value, 0))
                    used_parent = True
                new_sel = tuple(sorted(pairs))
            out.append(AttrMeta(table, order, attr_name, new_sel))

        if not used_parent:
            out.append(AttrMeta(None, None, node.attr, ((node.value, 0),)))
        # The node's own axis is recorded whether or not it was merged into a
        # parent. NOTE(review): assumes it is the last axis of ``obs`` (as in
        # the non-merged case) - confirm against calc_noisy_marginals.
        orig.append(node_key)

        # Sort the attributes and list their per-value axis keys in that order.
        source = tuple(sorted(out, key=lambda x: x[:-1]))
        vals = list(
            chain.from_iterable(
                (
                    [(a.table, a.order, a.attr, None)]
                    if isinstance(a.sel, int)
                    else [(a.table, a.order, a.attr, v[0]) for v in a.sel]
                )
                for a in source
            )
        )

        # Transpose ``obs`` so its axes follow ``vals``.
        new_obs = obs.astype("float32").transpose([orig.index(v) for v in vals])

        # Flatten each attribute's axes into a single axis.
        new_dom = []
        start = 0
        for a in source:
            width = 1 if isinstance(a.sel, int) else len(a.sel)
            new_dom.append(prod(new_obs.shape[start : start + width]))
            start += width
        new_obs = new_obs.reshape(new_dom)

        # Align naive and compressed representations
        for i, a in enumerate(source):
            if isinstance(a.sel, int):
                continue  # Skip single dimension

            attr = get_attrs(attrs, a.table, a.order)[a.attr]
            sel_dict = dict(a.sel)
            ndim = len(new_obs.shape)
            i_map = tuple(
                attr.get_naive_mapping(sel_dict) if j == i else slice(None)
                for j in range(ndim)
            )
            o_map = tuple(
                attr.get_mapping(sel_dict) if j == i else slice(None)
                for j in range(ndim)
            )
            tmp = np.zeros(
                [
                    attr.get_domain(sel_dict) if j == i else d
                    for j, d in enumerate(new_obs.shape)
                ]
            )
            # Unbuffered add so cells that share a target index accumulate.
            np.add.at(tmp, o_map, new_obs[i_map])  # type: ignore
            new_obs = tmp

        lin_obs.append(
            LinearObservation(
                source,
                None,
                new_obs,
                1,
            )
        )

    return lin_obs