Source code for pasteur.extras.synth.pgm.common
import logging
from typing import cast
from mbi import Dataset, Domain
from ....attribute import IdxValue
from ....marginal import AttrSelector, MarginalOracle
logger = logging.getLogger(__name__)
[docs]class OracleDataset(Dataset):
def __init__(
self,
o: MarginalOracle,
domain: "Domain | None" = None,
force_cache: bool = True,
cache: dict = {},
):
self.o = o
self.attrs = o.attrs
self.cache = cache
self.force_cache = force_cache
if domain is not None:
self.domain = domain
else:
names = []
domains = []
for attr in self.attrs.values():
for val in attr.vals.values():
names.append(val.name)
domains.append(cast(IdxValue, val).get_domain(0))
self.domain = Domain(names, domains)
[docs] def project(self, cols):
"""project dataset onto a subset of columns"""
if type(cols) in [str, int]:
cols = [cols]
domain = self.domain.project(cols)
return OracleDataset(
self.o, domain, force_cache=self.force_cache, cache=self.cache
)
[docs] def drop(self, cols):
proj = [c for c in self.domain if c not in cols]
return self.project(proj)
@property
def records(self):
return self.o.get_shape()[0]
[docs] def datavector(self, flatten=True):
"""return the database in vector-of-counts form"""
assert flatten
cols = tuple(self.domain.attrs)
if cols in self.cache:
return self.cache[cols]
else:
assert not self.force_cache, "You set to force use cache and marginal is not in cache."
req = [{col: AttrSelector(col, 0, {col: 0}) for col in self.domain.attrs}]
return self.o.process(req, "", normalize=False)[0]
[docs] def cache_marginals(self, requests: list[tuple[str]]):
non_cached_req = [req for req in requests if req not in self.cache]
marginals = self.o.process(
[
{col: AttrSelector(col, 0, {col: 0}) for col in cols}
for cols in non_cached_req
],
normalize=False,
)
for req, mar in zip(non_cached_req, marginals):
self.cache[req] = mar