"""Contains the base definition for Synth(esizer modules).
In addition, a test Synthesizer (IdentSynth) is provided, which returns
the data it was provided as is."""
from __future__ import annotations
from functools import partial, wraps
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from pasteur.utils import LazyDataset
from .encode import ViewEncoder
from .metadata import Metadata
from .module import ModuleClass, ModuleFactory
from .utils import LazyDataset, LazyFrame
META = TypeVar("META")
import logging
logger = logging.getLogger(__name__)
[docs]
def make_deterministic(obj_func, /, *, noise_kw: str | None = None):
"""Takes an object function (with self), and if the object has a seed attribute
it fixes the np.random.seed attribute to it and prints a random number at the end.
If the algorithm sampled the same amount of numbers at the same order, then the
numbers should be the same."""
if isinstance(obj_func, str):
return partial(make_deterministic, noise_kw=obj_func)
import random
import numpy as np
@wraps(obj_func)
def wrapped(self, *args, **kwargs):
if hasattr(self, "seed") and getattr(self, "seed") is not None:
seed = getattr(self, "seed")
if noise_kw is not None:
seed += kwargs[noise_kw]
np.random.seed(seed)
random.seed(seed)
a = obj_func(self, *args, **kwargs)
if hasattr(self, "seed") and getattr(self, "seed") is not None:
noise_info = f" ('{noise_kw}': {kwargs[noise_kw]:3d})" if noise_kw else ""
logger.info(
f"Deterministic check: random number after "
+ f"{f'{type(self).__name__}.{obj_func.__name__}':>22s}(): "
+ f"<np.random> {np.random.random():7.5f} <random> {random.random():7.5f}"
+ noise_info
)
return a
return wrapped
[docs]
class SynthFactory(ModuleFactory["Synth"]):
    """Factory for `Synth` modules.

    In addition to what the base factory stores, it copies the `type`,
    `in_types` and `in_sample` attributes of the synth class onto itself, so
    they can be read from the factory without building a model."""

    def __init__(
        self, cls: type[Synth], *args, name: str | None = None, **kwargs
    ) -> None:
        super().__init__(cls, *args, name=name, **kwargs)
        # Mirror the class-level configuration of the synth on the factory.
        self.type, self.in_types, self.in_sample = (
            cls.type,
            cls.in_types,
            cls.in_sample,
        )
[docs]
class Synth(ModuleClass, Generic[META]):
    """Base class for synthesis modules.

    Subclasses implement `preprocess()`, `bake()`, `fit()` and
    `sample_partition()`. The default `sample()` splits the requested rows
    over partitions and defers the work to `sample_partition()`."""

    # If in_types is provided, it must include type, and the data provided to
    # preprocess, bake, fit will be dict[str, dict[str, LazyDataset]].
    # Otherwise, it will be dict[str, LazyDataset].
    in_types: list[str] | None = None
    type = "idx"
    # Include input data in sample()
    in_sample: bool = False
    _factory = SynthFactory

    # Fill in for `sample` function to work
    _n: int | None = None
    # Fill in for `sample` function to work
    _partitions: int | None = None

    def preprocess(
        self,
        meta: META,
        data: dict[str, LazyDataset] | dict[str, dict[str, LazyDataset]],
    ):
        """Runs any preprocessing required, such as domain reduction."""
        raise NotImplementedError()

    def bake(
        self,
        data: dict[str, LazyDataset] | dict[str, dict[str, LazyDataset]],
    ):
        """Bakes the model based on the data provided (such as creating and
        modeling a bayesian network on the data).

        Attributes provide context about the data columns, including hierarchical
        relationships, na vals, etc."""
        raise NotImplementedError()

    def fit(
        self,
        data: dict[str, LazyDataset] | dict[str, dict[str, LazyDataset]],
    ):
        """Fits the model based on the provided data.

        Data and Ids are dictionaries containing the dataframes with the data."""
        raise NotImplementedError()

    def sample_partition(self, *, n: int, i: int = 0) -> dict[str, Any]:
        """Returns synthetic data in the same format they were provided.

        `n` sets how many rows should be sampled.
        Warning: not setting `n` technically violates DP for DP-aware algorithms.

        `i` is the partition number that can be used for modifying the random state
        sampling, since deterministic sampling will always return the same data.
        """
        raise NotImplementedError()

    def sample(
        self,
        *,
        n: int | None = None,
        partitions: int | None = None,
        data: dict[str, LazyDataset] | dict[str, dict[str, LazyDataset]] | None = None,
    ):
        """Samples `n` samples across `partitions` partitions.

        The return value should be finalized to `dict[str, Any]`, which
        matches the format of `data` provided to the fitting function.

        A default implementation is provided, that packages `sample_partition()`
        in such a way that pasteur can sample and save partitions in parallel.
        It returns a set with one callable per partition. `n` and `partitions`
        fall back to `_n` and `_partitions` when not provided (or zero), and
        an assertion fails if either is still missing. `data` is not used by
        the default implementation.

        The rows are split as evenly as possible: when `n` is not divisible by
        `partitions`, the first `n % partitions` partitions sample one extra
        row, so that the partition sizes add up to `n`."""
        n = n or self._n
        partitions = partitions or self._partitions
        assert (
            n and partitions
        ), "Either `n` or `partitions` was not provided.\nFill in `_n` and `_partitions` based on `fit` data."

        # Spread the remainder over the first partitions instead of dropping it.
        base, extra = divmod(n, partitions)
        return {
            partial(self.sample_partition, i=i, n=base + (1 if i < extra else 0))
            for i in range(partitions)
        }
[docs]
def synth_fit(
    factory: SynthFactory,
    metadata: Metadata,
    encoder: ViewEncoder | dict[str, ViewEncoder],
    data: dict[str, LazyDataset],
):
    """Builds a model from `factory` and runs `preprocess`, `bake` and `fit`
    on `data`, in that order, returning the fitted model.

    The model arguments are the entry of `metadata.algs` for `factory.name`
    (if any), overridden by `metadata.alg_override`, plus `metadata.seed` as
    `seed`. The metadata given to `preprocess` comes from `get_metadata()` of
    the encoder, or is a dictionary of them when `encoder` is a dictionary.
    Each stage is wrapped in a start/stop call of the "synth" performance
    tracker."""
    from .utils.perf import PerformanceTracker

    tracker = PerformanceTracker.get("synth")
    tracker.ensemble("total", "preprocess", "bake", "fit")

    if isinstance(encoder, dict):
        meta = {t: e.get_metadata() for t, e in encoder.items()}
    else:
        meta = encoder.get_metadata()

    build_kwargs = {**metadata.algs.get(factory.name, {}), **metadata.alg_override}
    model = factory.build(**build_kwargs, seed=metadata.seed)

    # if factory.gpu:
    #     tracker.use_gpu()

    # Each stage is timed separately; only `preprocess` receives the metadata.
    stages = (
        ("preprocess", (meta, data)),
        ("bake", (data,)),
        ("fit", (data,)),
    )
    for stage, stage_args in stages:
        tracker.start(stage)
        getattr(model, stage)(*stage_args)
        tracker.stop(stage)

    return model
[docs]
def synth_sample(s: Synth, data=None):
    """Calls `s.sample()` and returns its result.

    `data` is forwarded as the `data` keyword argument only when it is not
    None; otherwise `sample()` is called without arguments."""
    sample_kwargs = {} if data is None else {"data": data}
    return s.sample(**sample_kwargs)
[docs]
class IdentSynth(Synth):
    """Samples the data it was provided."""

    name = "ident_idx"
    type = "idx"
    partitions = 1

    def preprocess(self, meta: Any, data: dict[str, LazyDataset]):
        """Does nothing, since the data is returned as is."""
        pass

    def bake(self, data: dict[str, LazyDataset]):
        """Does nothing, since there is no model to bake."""
        pass

    def fit(self, data: dict[str, LazyDataset]):
        """Keeps a reference to `data`, so that `sample()` can return it."""
        self.data = data

    def sample(
        self,
        n: int | None = None,
        *,
        partitions: int | None = None,
        data: dict[str, LazyDataset] | None = None,
    ):
        """Returns the data that was passed to `fit()`.

        All arguments are ignored. `partitions` and `data` are accepted so
        that the method can be called with the same keyword arguments as
        `Synth.sample()` (e.g., `sample(data=...)`) without raising."""
        return self.data
[docs]
class IdentSynthJson(Synth):
    """Samples the data it was provided."""

    in_types = ["json", "idx"]
    name = "ident_json"
    type = "json"
    partitions = 1

    def preprocess(self, meta: Any, data: dict[str, dict[str, LazyDataset]]):
        """Does nothing, since the data is returned as is."""
        pass

    def bake(self, data: dict[str, dict[str, LazyDataset]]):
        """Does nothing, since there is no model to bake."""
        pass

    def fit(self, data: dict[str, dict[str, LazyDataset]]):
        """Keeps a reference to the "json" entry of `data`, so that `sample()`
        can return it. The other input types are not used."""
        self.data = data["json"]

    def sample(
        self,
        n: int | None = None,
        *,
        partitions: int | None = None,
        data: dict[str, dict[str, LazyDataset]] | None = None,
    ):
        """Returns the "json" data that was passed to `fit()`.

        All arguments are ignored. `partitions` and `data` are accepted so
        that the method can be called with the same keyword arguments as
        `Synth.sample()` (e.g., `sample(data=...)`) without raising."""
        return self.data
# Names exported by `from <module> import *`.
# NOTE(review): `synth_fit` and `synth_sample` are defined in this module but
# are not listed here — confirm that leaving them out is intended.
__all__ = [
    "Synth",
    "SynthFactory",
    "IdentSynth",
    "IdentSynthJson",
    "make_deterministic",
]