import logging
from typing import Any, Callable, Literal, NamedTuple, overload
import numpy as np
from ..attribute import Attributes
from ..utils import LazyChunk, LazyFrame
from ..utils.progress import PROGRESS_STEP_NS, piter, process_in_parallel
from .numpy import ZERO_FILL, AttrSelector, AttrSelectors, expand_table, get_domains
from .numpy import normalize as normalize_2way
from .numpy import normalize_1way
from .memory import load_from_memory, map_to_memory, allocate_memory
logger = logging.getLogger(__name__)
from .numpy import calc_marginal as calc_marginal_np
# Prefer the native implementation of the marginal kernels. Any failure while
# importing it is logged and the numpy implementation is bound to the same
# names instead, so the rest of the module is agnostic to which one is used.
try:
    from .native_py import calc_marginal, calc_marginal_1way
except Exception as e:
    logger.error(
        f"Failed importing native marginal implementation, using numpy instead (2-8x slower). Error:\n{e}"
    )
    from .numpy import calc_marginal, calc_marginal_1way
class MarginalRequest(NamedTuple):
    """A single marginal request, as consumed by the workers in this module.

    Being a ``NamedTuple``, a request can be unpacked positionally as
    ``x, p, partial = request``; the marginal workers rely on that.
    """

    # Selector for the `x` side of the marginal. `_find_columns` explicitly
    # guards against this being None, hence the optional annotation.
    x: AttrSelector | None
    # Selectors for the `p` side of the marginal (a mapping: its `.values()`
    # are iterated when collecting the required columns).
    p: AttrSelectors
    # Forwarded as-is to the marginal kernel. When set, the workers use the
    # numpy kernel instead of the native one.
    partial: bool
def _find_columns(reqs: list[MarginalRequest]) -> list[str] | None:
    """Collect the column names referenced by a list of marginal requests.

    Args:
        reqs: Requests whose `x` selector (when not None) and `p` selectors
            are inspected; the keys of each selector's `cols` are gathered.

    Returns:
        The sorted, de-duplicated column names, or None when no column is
        referenced at all (including when `reqs` is empty).
    """
    cols: set[str] = set()
    for req in reqs:
        if req.x is not None:
            cols.update(req.x.cols.keys())
        for sel in req.p.values():
            cols.update(sel.cols.keys())
    # `sorted` accepts any iterable; an empty result collapses to None.
    return sorted(cols) or None
def _find_columns_1way(X: list[AttrSelectors]) -> list[str] | None:
    """Collect the column names referenced by a list of 1-way requests.

    Args:
        X: One selector mapping per request; the keys of every selector's
            `cols` are gathered.

    Returns:
        The sorted, de-duplicated column names, or None when no column is
        referenced at all (including when `X` is empty).
    """
    cols: set[str] = set()
    for selectors in X:
        for sel in selectors.values():
            cols.update(sel.cols.keys())
    # `sorted` accepts any iterable; an empty result collapses to None.
    return sorted(cols) or None
def sequential_load(
    chunk: LazyChunk, attrs: Attributes, columns: list[str] | None = None
):
    """Load a single chunk, expand it and map the resulting columns to memory.

    Args:
        chunk: Lazy chunk; called with `columns` to materialize the table.
        attrs: Attributes passed to `expand_table`.
        columns: Optional subset of columns to request from the chunk.

    Returns:
        The tuple `(mem_cols, info_cols, mem_noncommon, info_noncommon,
        domains)`, where each `mem_*`/`info_*` pair comes from
        `map_to_memory` and `domains` from `expand_table`.
    """
    table = chunk(columns=columns)
    expanded, expanded_noncommon, domains = expand_table(attrs, table)

    mem_cols, info_cols = map_to_memory(expanded)
    # Drop the local reference before mapping the second set of columns
    # (presumably to lower peak memory usage).
    del expanded
    mem_noncommon, info_noncommon = map_to_memory(expanded_noncommon)

    return mem_cols, info_cols, mem_noncommon, info_noncommon, domains
def _parallel_load_worker(
    mem_cols,
    info_cols,
    mem_noncommon,
    info_noncommon,
    chunk: LazyChunk,
    chunk_range: tuple[int, int],
    attrs: Attributes,
    columns: list[str] | None = None,
):
    """Expand one chunk directly into its range of the pre-allocated memory.

    The arrays for `chunk_range` are obtained with `load_from_memory` and
    handed to `expand_table` as its output buffers (`out_cols` /
    `out_noncommon`).

    Returns:
        The domains reported by `expand_table` for this chunk.
    """
    target_cols = load_from_memory(mem_cols, info_cols, range=chunk_range)
    target_noncommon = load_from_memory(
        mem_noncommon, info_noncommon, range=chunk_range
    )

    table = chunk(columns=columns)
    # Only the domains are needed; the column data lands in the out buffers.
    _, _, domains = expand_table(
        attrs, table, out_cols=target_cols, out_noncommon=target_noncommon
    )
    return domains
def parallel_load(data: LazyFrame, attrs: Attributes, columns: list[str] | None = None):
    """Allocate memory for the whole dataset and fill it chunk-by-chunk in parallel.

    Args:
        data: Lazy frame; iterated by key, one worker call per chunk.
        attrs: Attributes passed to `allocate_memory` and the workers.
        columns: Optional subset of columns each worker requests.

    Returns:
        The tuple `(mem_cols, info_cols, mem_noncommon, info_noncommon,
        domains)`, where `domains` is the value returned by the first worker.
    """
    # Both memory allocations share the range, so the ranges of the first
    # allocation are discarded.
    mem_cols, info_cols, _ = allocate_memory(data, attrs, common=False)
    mem_noncommon, info_noncommon, ranges = allocate_memory(data, attrs, common=True)

    shared_args = dict(
        mem_cols=mem_cols,
        info_cols=info_cols,
        mem_noncommon=mem_noncommon,
        info_noncommon=info_noncommon,
        attrs=attrs,
        columns=columns,
    )

    jobs = []
    for name in data.keys():
        jobs.append({"chunk": data[name], "chunk_range": ranges[name]})

    worker_domains = process_in_parallel(
        _parallel_load_worker, jobs, shared_args, desc="Loading data"
    )

    # Every worker returns domains; only the first result is kept.
    domains = worker_domains[0]
    return mem_cols, info_cols, mem_noncommon, info_noncommon, domains
def _marginal_initializer(base_args, per_call_args):
    """Resolve the memory handles in `base_args` into column arrays.

    Passed as `initializer` to `process_in_parallel`. Adds the `cols` and
    `cols_noncommon` entries expected by the marginal workers to a copy of
    `base_args`; the input dict is not mutated.

    Returns:
        `(new_base_args, per_call_args)`, with `per_call_args` untouched.
    """
    copy = base_args["copy"]
    # NOTE(review): `copy` is passed positionally here while other call sites
    # use `range=`/`copy=` keywords — confirm the parameter order.
    loaded = {
        "cols": load_from_memory(base_args["mem_cols"], base_args["info_cols"], copy),
        "cols_noncommon": load_from_memory(
            base_args["mem_noncommon"], base_args["info_noncommon"], copy
        ),
    }
    return {**base_args, **loaded}, per_call_args
def _marginal_worker(
    cols,
    cols_noncommon,
    domains,
    req: MarginalRequest,
    normalize: bool,
    zero_fill: float,
    postprocess: Callable | None,
    **_,
):
    """Compute one marginal for `req`, optionally normalizing/postprocessing it.

    Extra keyword arguments are accepted and ignored, so the shared argument
    dict can carry entries this worker does not use.

    Returns:
        The raw marginal, the result of `normalize_2way` on it when
        `normalize` is set, or whatever `postprocess` returns when provided.
        A normalized result is unpacked into `postprocess` as positional
        arguments; a raw one is passed as a single argument.
    """
    x, p, partial = req

    # Native implementation is unfinished for partial marginals, so those
    # always go through the numpy kernel.
    kernel = calc_marginal_np if partial else calc_marginal
    res = kernel(cols, cols_noncommon, domains, x, p, partial)

    if not normalize:
        return res if postprocess is None else postprocess(res)

    normalized = normalize_2way(res, zero_fill)
    if postprocess is None:
        return normalized
    return postprocess(*normalized)
def _marginal_worker_1way(
    cols,
    cols_noncommon,
    domains,
    x: AttrSelectors,
    normalize: bool,
    zero_fill: float,
    postprocess: Callable | None,
    **_,
) -> np.ndarray:
    """Compute one 1-way marginal for `x`, optionally normalizing/postprocessing it.

    Extra keyword arguments are accepted and ignored, so the shared argument
    dict can carry entries this worker does not use.

    Returns:
        The raw marginal, its `normalize_1way` form when `normalize` is set,
        or whatever `postprocess` returns for that value when provided.
    """
    marginal = calc_marginal_1way(cols, cols_noncommon, domains, x)
    if normalize:
        marginal = normalize_1way(marginal, zero_fill)
    return marginal if postprocess is None else postprocess(marginal)
def _marginal_batch_worker(
    requests: list[MarginalRequest],
    attrs: Attributes,
    chunk: LazyChunk | None,
    shared,
    progress_lock,
    progress_send,
) -> list[np.ndarray]:
    """Compute every requested marginal over a single chunk of the data.

    Args:
        requests: Marginals to compute, in order.
        attrs: Attributes used to expand the chunk (only when `chunk` is given).
        chunk: Lazy chunk to load and expand. When None, the data is taken
            from `shared` instead.
        shared: Only read when `chunk` is None. A 6-tuple of
            `(mem_cols, info_cols, mem_noncommon, info_noncommon, domains,
            chunk_range)`; the given range is loaded as a copy.
        progress_lock: Lock held while writing to `progress_send`.
        progress_send: Connection that receives the number of marginals
            completed since the previous update.

    Returns:
        One unnormalized marginal per request, in request order.
    """
    from time import time_ns

    if chunk is not None:
        # Only load the columns the requests actually reference.
        cols, cols_noncommon, domains = expand_table(
            attrs, chunk(columns=_find_columns(requests))
        )
    else:
        (
            mem_cols,
            info_cols,
            mem_noncommon,
            info_noncommon,
            domains,
            chunk_range,
        ) = shared
        cols = load_from_memory(mem_cols, info_cols, range=chunk_range, copy=True)
        cols_noncommon = load_from_memory(
            mem_noncommon, info_noncommon, range=chunk_range, copy=True
        )

    out = []
    pending = 0  # marginals finished since the last progress update
    last_updated = time_ns()
    for x, p, partial in requests:
        if partial:
            # Native implementation is unfinished
            out.append(calc_marginal_np(cols, cols_noncommon, domains, x, p, partial))
        else:
            out.append(calc_marginal(cols, cols_noncommon, domains, x, p, partial))

        pending += 1
        # Throttle progress updates to at most one per PROGRESS_STEP_NS.
        if (curr_time := time_ns()) - last_updated > PROGRESS_STEP_NS:
            last_updated = curr_time
            with progress_lock:
                progress_send.send(pending)
            pending = 0

    # Flush whatever was not reported inside the loop.
    if pending > 0:
        with progress_lock:
            progress_send.send(pending)

    return out
def _marginal_batch_worker_1way(
    X: list[AttrSelectors],
    attrs: Attributes,
    chunk: LazyChunk | None,
    shared,
    progress_lock,
    progress_send,
) -> list[np.ndarray]:
    """Compute every requested 1-way marginal over a single chunk of the data.

    When `chunk` is given it is loaded (restricted to the columns referenced
    by `X`) and expanded; otherwise `shared` must be a 6-tuple of
    `(mem_cols, info_cols, mem_noncommon, info_noncommon, domains,
    chunk_range)` and that range is loaded as a copy. The number of marginals
    completed since the previous update is periodically written to
    `progress_send` while holding `progress_lock`.

    Returns:
        One unnormalized marginal per entry of `X`, in order.
    """
    from time import time_ns

    if chunk is None:
        (
            mem_cols,
            info_cols,
            mem_noncommon,
            info_noncommon,
            domains,
            chunk_range,
        ) = shared
        cols = load_from_memory(mem_cols, info_cols, range=chunk_range, copy=True)
        cols_noncommon = load_from_memory(
            mem_noncommon, info_noncommon, range=chunk_range, copy=True
        )
    else:
        table = chunk(columns=_find_columns_1way(X))
        cols, cols_noncommon, domains = expand_table(attrs, table)

    results = []
    unreported = 0
    last_report = time_ns()
    for selectors in X:
        results.append(calc_marginal_1way(cols, cols_noncommon, domains, selectors))
        unreported += 1

        # Throttle progress updates to at most one per PROGRESS_STEP_NS.
        now = time_ns()
        if now - last_report > PROGRESS_STEP_NS:
            last_report = now
            with progress_lock:
                progress_send.send(unreported)
            unreported = 0

    # Flush whatever was not reported inside the loop.
    if unreported > 0:
        with progress_lock:
            progress_send.send(unreported)

    return results
class MarginalOracle:
    """Computes batches of marginals over a dataset, using one of several modes.

    Modes:
        - ``out_of_core``: every worker loads one partition of `data` and
          computes all requested marginals over it; the per-partition results
          are summed. Falls back to ``inmemory_copy`` if `data` is not
          partitioned.
        - ``inmemory``: alias for ``inmemory_copy``.
        - ``inmemory_shared`` / ``inmemory_copy``: the data is loaded once
          (see `load_data`) and each marginal is one unit of parallel work.
          The two differ only in the `copy` flag handed to
          `load_from_memory` by the worker initializer.
        - ``inmemory_batched``: the data is loaded once and split into row
          ranges (about `repartitions` of them); every worker computes all
          marginals over one range and the results are summed.

    The oracle can be used as a context manager; leaving the context calls
    `close`, which unloads any loaded data.
    """

    MODES = Literal[
        "out_of_core",
        "inmemory",
        "inmemory_shared",
        "inmemory_copy",
        "inmemory_batched",
    ]

    def __init__(
        self,
        attrs: Attributes,
        data: LazyFrame,
        mode: "MarginalOracle.MODES" = "out_of_core",
        *,
        min_chunk_size: int = 1,
        max_worker_mult: int = 1,
        repartitions: int | None = None,
        log: bool = True,
    ) -> None:
        """Store the configuration; no data is loaded at this point.

        Args:
            attrs: Attributes describing `data`.
            data: Lazy frame holding the dataset.
            mode: Requested processing mode (see the class docstring). It may
                be replaced by ``inmemory_copy`` as described there.
            min_chunk_size: Forwarded to `process_in_parallel` by the
                in-memory (non-batched) modes.
            max_worker_mult: Forwarded to `process_in_parallel`.
            repartitions: Number of row ranges for ``inmemory_batched``.
                Defaults to `len(data)`.
            log: Whether `close` logs the number of processed marginals.
        """
        self.attrs = attrs
        self.data = data

        if mode == "out_of_core" and not data.partitioned:
            logger.info("Data is not partitioned, switching to mode `inmemory_copy`.")
            self.mode = "inmemory_copy"
        elif mode == "inmemory":
            # inmemory is an alias for inmemory_copy
            self.mode = "inmemory_copy"
        else:
            self.mode = mode

        self.repartitions = repartitions or len(data)
        if self.repartitions == 1 and mode == "inmemory_batched":
            # A single batch gains nothing over the plain in-memory mode.
            logger.info(
                "Data is not partitioned and `repartitions` is not provided. Can't infer partition number, switching to mode `inmemory_copy`."
            )
            self.mode = "inmemory_copy"

        self.min_chunk_size = min_chunk_size
        self.max_worker_mult = max_worker_mult

        self.counts = None  # cache for `get_counts`
        self.log = log

        self._loaded = False
        self.marginal_count = 0  # total number of requests seen by `process`

    def get_domains(self):
        """Return the domains of `attrs`, as computed by `get_domains`."""
        return get_domains(self.attrs)

    def get_shape(self):
        """Return the shape of the underlying data."""
        return self.data.shape

    def load_data(self, columns: list[str] | None = None):
        """Load the dataset into memory, unless it is already loaded.

        Partitioned data is loaded with `parallel_load`, otherwise with
        `sequential_load`. `columns` only has an effect on the call that
        actually performs the load.
        """
        if self._loaded:
            return

        # Partitioned data can be loaded in parallel, one partition per call.
        loader = parallel_load if self.data.partitioned else sequential_load
        (
            self.mem_cols,
            self.info_cols,
            self.mem_noncommon,
            self.info_noncommon,
            self.domains,
        ) = loader(self.data, self.attrs, columns)

        self._loaded = True

    def unload_data(self):
        """Close and unlink the loaded memory; a no-op if nothing is loaded."""
        if not self._loaded:
            return

        self.mem_cols.close()
        self.mem_cols.unlink()
        self.mem_noncommon.close()
        self.mem_noncommon.unlink()
        self._loaded = False

    def _process_inmemory(
        self,
        requests: list[MarginalRequest] | list[AttrSelectors],
        desc: str,
        normalize: bool,
        zero_fill: float | None,
        postprocess: Callable | None,
    ):
        """Process `requests` with one unit of parallel work per marginal.

        Normalization and postprocessing happen inside the workers. Whether
        the requests are 1-way is decided from the type of the first one.
        """
        assert self.mode in ("inmemory_shared", "inmemory_copy")

        if len(requests) == 0:
            return []
        is_1way = not isinstance(requests[0], MarginalRequest)

        self.load_data()

        base_args = {
            "mem_cols": self.mem_cols,
            "info_cols": self.info_cols,
            "mem_noncommon": self.mem_noncommon,
            "info_noncommon": self.info_noncommon,
            "copy": self.mode == "inmemory_copy",
            "domains": self.domains,
            "normalize": normalize,
            "zero_fill": zero_fill,
            "postprocess": postprocess,
        }
        if is_1way:
            per_call_args = [{"x": x} for x in requests]
        else:
            per_call_args = [{"req": req} for req in requests]

        return process_in_parallel(
            _marginal_worker_1way if is_1way else _marginal_worker,
            per_call_args,
            base_args,
            min_chunk_size=self.min_chunk_size,
            max_worker_mult=self.max_worker_mult,
            desc=desc,
            initializer=_marginal_initializer,
        )

    def _process_batched(
        self,
        requests: list[MarginalRequest] | list[AttrSelectors],
        desc: str,
        normalize: bool,
        zero_fill: float | None,
        postprocess: Callable | None,
    ):
        """Process `requests` with one unit of parallel work per data chunk.

        Every worker computes all requested marginals over its chunk (a
        partition in ``out_of_core`` mode, a row range in
        ``inmemory_batched`` mode). The per-chunk results are then summed per
        request and normalized/postprocessed in the calling process. Workers
        report progress through a pipe that is drained by a helper thread.
        """
        assert self.mode in ("inmemory_batched", "out_of_core")
        from multiprocessing import Pipe
        from threading import Lock, Thread

        from pasteur.utils.progress import MULTIPROCESS_ENABLE, get_manager

        if len(requests) == 0:
            return []
        is_1way = not isinstance(requests[0], MarginalRequest)

        progress_recv, progress_send = Pipe(duplex=False)
        if MULTIPROCESS_ENABLE:
            progress_lock = get_manager().Lock()
        else:
            # Use a thread lock to prevent launching a pool with multiprocess
            # disabled
            progress_lock = Lock()

        base_args = {
            "attrs": self.attrs,
            "progress_send": progress_send,
            "progress_lock": progress_lock,
        }
        if is_1way:
            base_args["X"] = requests
        else:
            base_args["requests"] = requests

        if self.mode == "out_of_core":
            base_args["shared"] = None
            per_call_args = [{"chunk": chunk} for chunk in self.data.values()]
            total = len(requests) * len(self.data)
        else:
            self.load_data()
            base_args["chunk"] = None
            shared_base = (
                self.mem_cols,
                self.info_cols,
                self.mem_noncommon,
                self.info_noncommon,
                self.domains,
            )

            # Split the rows into equally sized ranges (ceil division); the
            # chunk count is re-derived since rounding may reduce it.
            n = self.data.shape[0]
            chunk_n_suggestion = min(n, self.repartitions)
            chunk_len = (n - 1) // chunk_n_suggestion + 1
            chunk_n = (n - 1) // chunk_len + 1
            chunk_ranges = [
                (chunk_len * j, min(chunk_len * (j + 1), n)) for j in range(chunk_n)
            ]
            per_call_args = [
                {"shared": (*shared_base, chunk_range)}
                for chunk_range in chunk_ranges
            ]
            total = len(requests) * chunk_n

        def track_progress():
            # Runs until all updates arrived or the None sentinel is received.
            # NOTE(review): the progress bar is never explicitly closed here —
            # confirm whether `piter` objects need it.
            pbar = None
            received = 0
            while received < total and (u := progress_recv.recv()) is not None:
                if pbar is None:
                    # Start pbar after the partition pbar has started
                    pbar = piter(
                        desc="Calculating submarginals", total=total, leave=False
                    )
                received += u
                pbar.update(u)

        t = Thread(target=track_progress)
        try:
            t.start()
            res = process_in_parallel(
                _marginal_batch_worker_1way if is_1way else _marginal_batch_worker,
                per_call_args,
                base_args,
                desc=desc,
                max_worker_mult=self.max_worker_mult,
            )
        finally:
            # The sentinel guarantees the tracker exits. Join it *before*
            # closing the receiving end, so it can never call `recv()` on a
            # closed connection while still draining queued updates.
            progress_send.send(None)
            progress_send.close()
            t.join()
            progress_recv.close()

        if len(res) == 0:
            return []

        out = []
        for i in piter(
            range(len(requests)),
            desc="Postprocessing partitioned marginals",
            leave=False,
        ):
            # Counts are additive across chunks.
            mar = np.sum([batch[i] for batch in res], axis=0)

            if not normalize:
                out.append(mar if postprocess is None else postprocess(mar))
            elif is_1way:
                normed = normalize_1way(mar, zero_fill)
                out.append(normed if postprocess is None else postprocess(normed))
            else:
                # The 2-way normalized result is unpacked into `postprocess`.
                normed = normalize_2way(mar, zero_fill)
                out.append(normed if postprocess is None else postprocess(*normed))

        return out

    # The `Literal[True]` overload comes first: overloads are matched in
    # order, and `normalize` defaults to True at runtime, so a call that omits
    # it must resolve to the normalized (tuple) return type.
    @overload
    def process(
        self,
        requests: list[MarginalRequest],
        desc: str = ...,
        normalize: Literal[True] = ...,
        zero_fill: float | None = ...,
        postprocess: None = ...,
    ) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]:
        ...

    @overload
    def process(
        self,
        requests: list[MarginalRequest],
        desc: str = ...,
        normalize: Literal[False] = ...,
        zero_fill: float | None = ...,
        postprocess: None = ...,
    ) -> list[np.ndarray]:
        ...

    @overload
    def process(
        self,
        requests: list[MarginalRequest],
        desc: str = ...,
        normalize: bool = ...,
        zero_fill: float | None = ...,
        postprocess: Callable = ...,
    ) -> list[Any]:
        ...

    @overload
    def process(
        self,
        requests: list[AttrSelectors],
        desc: str = ...,
        normalize: bool = ...,
        zero_fill: float | None = ...,
        postprocess: None = ...,
    ) -> list[np.ndarray]:
        ...

    @overload
    def process(
        self,
        requests: list[AttrSelectors],
        desc: str = ...,
        normalize: bool = ...,
        zero_fill: float | None = ...,
        postprocess: Callable = ...,
    ) -> list[Any]:
        ...

    def process(
        self,
        requests: list[MarginalRequest] | list[AttrSelectors],
        desc: str = "Processing partition",
        normalize: bool = True,
        zero_fill: float | None = ZERO_FILL,
        postprocess: Callable | None = None,
    ) -> list[np.ndarray] | list[tuple[np.ndarray, np.ndarray, np.ndarray]] | list[Any]:
        """Compute the requested marginals using the configured mode.

        Args:
            requests: Either `MarginalRequest`s or 1-way selector mappings.
                The kind is decided from the first element.
            desc: Description used for the progress display.
            normalize: Whether to normalize each marginal
                (`normalize_1way` for 1-way requests, `normalize_2way`
                otherwise).
            zero_fill: Forwarded to the normalization function.
            postprocess: Optional callable applied to every result. It
                receives the marginal as one argument, except for normalized
                `MarginalRequest` results, which are unpacked into it.

        Returns:
            One result per request, in request order.
        """
        self.marginal_count += len(requests)

        if self.mode in ("inmemory_batched", "out_of_core"):
            logger.debug(
                f"Processing {len(requests)} marginals by loading partitions in parallel."
            )
            return self._process_batched(
                requests, desc, normalize, zero_fill, postprocess
            )

        logger.debug(
            f"Processing {len(requests)} marginals by loading dataset in memory."
        )
        return self._process_inmemory(requests, desc, normalize, zero_fill, postprocess)  # type: ignore

    def get_counts(self, desc: str = "Calculating counts"):
        """Return the 1-way result for every value of every attribute.

        One 1-way request is issued per entry of each attribute's `vals`,
        using the default arguments of `process`. The result is cached on
        the instance and returned as a dict keyed by value name.
        """
        # `is not None` rather than truthiness: an empty result is still a
        # valid cached value and must not trigger a recomputation.
        if self.counts is not None:
            return self.counts

        cols = []
        reqs = []
        for name, attr in self.attrs.items():
            for val in attr.vals:
                cols.append(val)
                reqs.append({name: AttrSelector(name, attr.common, {val: 0})})

        count_arr = self.process(reqs, desc=desc)

        self.counts = dict(zip(cols, count_arr))
        return self.counts

    def close(self):
        """Log the number of processed marginals (if enabled) and unload the data."""
        if self.log:
            logger.info(f"Processed {self.marginal_count} marginals.")
        self.unload_data()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Returns None, so exceptions raised inside the context propagate.
        self.close()