from collections import defaultdict
import logging
from typing import (
Callable,
Generic,
Literal,
Mapping,
Protocol,
Sequence,
TypeGuard,
TypeVar,
cast,
overload,
)
import numpy as np
import pandas as pd
from ..utils.data import LazyDataset, LazyPartition
from ..attribute import (
Attribute,
Attributes,
CatValue,
DatasetAttributes,
SeqAttributes,
)
from ..utils import LazyChunk, LazyFrame
from ..utils.progress import PROGRESS_STEP_NS, piter, process_in_parallel
from .memory import load_from_memory, map_to_memory, merge_memory
from .numpy import (
AttrName,
AttrSelectors,
CalculationInfo,
ChildSelector,
CommonSelector,
TableSelector,
expand_table,
)
logger = logging.getLogger(__name__)

# Prefer the native marginal kernel. Any failure while importing it (not just
# ImportError) is logged and the numpy implementation is used instead.
try:
    from .native_py import calc_marginal
except Exception as native_err:
    logger.error(
        f"Failed importing native marginal implementation, using numpy instead (2-8x slower). Error:\n{native_err}"
    )
    from .numpy import calc_marginal
A = TypeVar("A", covariant=True)

# Selector for the values of one attribute inside a marginal request.
_AttrValSelector = ChildSelector | CommonSelector

# A marginal request is a sequence of attribute selections, each given either
# as ``(table, attr_name, selector)`` or as ``(attr_name, selector)``; the
# short form is expanded with a ``None`` table by ``convert_reqs``.
MarginalRequest = Sequence[
    tuple[TableSelector, AttrName, _AttrValSelector]
    | tuple[AttrName, _AttrValSelector]
]
def convert_reqs(
    reqs: list[MarginalRequest],
) -> list[AttrSelectors]:
    """Normalize marginal requests to explicit ``(table, attr, selector)`` form.

    Every selection whose length is not 3 (normally ``(attr, selector)``) gets
    ``None`` prepended as its table; 3-element selections are kept as-is.
    Returns a new outer list of new inner lists; ``reqs`` is not modified.
    """
    normalized = []
    for request in reqs:
        entries = []
        for sel in request:
            entries.append(sel if len(sel) == 3 else (None, *sel))
        normalized.append(entries)
    return normalized
class PreprocessFun(Protocol):
    """Structural type of a preprocess function.

    It receives a mapping of table name to lazy partition and returns a
    mapping of table selector to the loaded ``pd.DataFrame``.
    """

    def __call__(
        self,
        data: Mapping[str, LazyPartition],
    ) -> dict[TableSelector, pd.DataFrame]:
        """Load (and optionally transform) the partitions in ``data``."""
class PostprocessFun(Protocol, Generic[A]):
    """Structural type of a marginal postprocess function.

    Called with the selector list of a request, the marginal array computed
    for it and the calculation info; returns a value of type ``A``.
    """

    def __call__(
        self, req: AttrSelectors, mar: np.ndarray, info: CalculationInfo
    ) -> A:
        """Transform the marginal ``mar`` computed for ``req``."""
def counts_preprocess(
    data: Mapping[str, LazyPartition]
) -> dict[TableSelector, pd.DataFrame]:
    """Load every table whose name does not start with ``"ids_"``.

    Each kept partition is materialized by calling it. Tables whose name
    starts with ``"ids_"`` are skipped without being called. Insertion order
    of ``data`` is preserved.
    """
    loaded = {}
    for name, partition in data.items():
        if name.startswith("ids_"):
            continue
        loaded[name] = partition()
    return loaded
def _tabular_load(
data: Mapping[str, LazyPartition],
) -> dict[TableSelector, pd.DataFrame]:
return {None: next(iter(v for d, v in data.items() if "ids_" not in d))()}
def sequential_load(
    attrs: Mapping[str | None, Attributes | SeqAttributes],
    data: Mapping[str, LazyPartition],
    preprocess: PreprocessFun,
):
    """Load ``data`` in the current process and store it with ``map_to_memory``.

    ``preprocess`` turns ``data`` into tables, ``expand_table`` expands them
    into columns plus calculation info, and the columns are handed to
    ``map_to_memory``.

    Returns:
        ``(mem_arr, mem_info, info)``: the two values returned by
        ``map_to_memory`` followed by the info from ``expand_table``.
    """
    tables = preprocess(data)
    columns, calc_info = expand_table(attrs, tables)
    arr, arr_info = map_to_memory(columns)
    return arr, arr_info, calc_info
def parallel_load(
    attrs: Mapping[str | None, Attributes | SeqAttributes],
    data: Mapping[str, LazyPartition],
    preprocess: PreprocessFun,
):
    """Run ``sequential_load`` once per element of ``LazyFrame.zip_values(data)``.

    The calls go through ``process_in_parallel``, and their memory outputs
    are combined with ``merge_memory``. The calculation info of the first
    call is returned.

    Returns:
        ``(mem_arr, mem_info, info)``, the same layout as ``sequential_load``.
    """
    partition_args = [
        {"data": partition} for partition in LazyFrame.zip_values(data)
    ]
    shared_args = {"attrs": attrs, "preprocess": preprocess}
    results = process_in_parallel(
        sequential_load, partition_args, shared_args, desc="Loading data"
    )
    first_info = results[0][-1]
    merged_arr, merged_info = merge_memory(results)
    return merged_arr, merged_info, first_info
def get_info(
    attrs: Mapping[str | None, Attributes | SeqAttributes],
    data: Mapping[str, LazyPartition],
    preprocess: PreprocessFun,
):
    """Compute the calculation info for ``attrs`` from a sample of ``data``.

    Preprocess functions receive a mapping of zero-argument loaders and call
    each value themselves (as ``counts_preprocess`` and ``_tabular_load``
    do). ``LazyDataset`` values are therefore swapped for their ``sample``
    loader, and every other value is passed through uncalled. Only the info
    part of ``expand_table``'s result is returned.
    """
    # Non-dataset values must not be called here: calling them would hand
    # ``preprocess`` an already-loaded DataFrame, which it would then try to
    # call again (TypeError).
    # NOTE(review): assumes ``LazyDataset.sample`` is a zero-argument callable
    # (the dataset branch already relied on this); confirm in ..utils.data.
    sampled = {
        name: part.sample if isinstance(part, LazyDataset) else part
        for name, part in data.items()
    }
    _, info = expand_table(attrs, preprocess(sampled))
    return info
def _marginal_initializer(base_args, per_call_args):
    """Worker initializer used by ``_process_inmemory``.

    Loads the in-memory dataset from ``base_args["mem_arr"]`` and
    ``base_args["mem_info"]`` with ``load_from_memory``, forwarding the
    ``base_args["copy"]`` flag. Returns a shallow copy of ``base_args`` with
    the loaded data added under ``"data"``, and ``per_call_args`` unchanged.
    """
    # Pass ``copy`` by keyword: other callers pass ``range=``/``copy=`` by
    # name, so relying on its positional slot would break if the parameter
    # order differs.
    data = load_from_memory(
        base_args["mem_arr"], base_args["mem_info"], copy=base_args["copy"]
    )
    new_base_args = base_args.copy()
    new_base_args["data"] = data
    return new_base_args, per_call_args
def _marginal_worker(
    data,
    info,
    req: AttrSelectors,
    postprocess: PostprocessFun | None,
    **_,
) -> np.ndarray:
    """Compute the marginal for ``req`` and optionally postprocess it.

    Additional keyword arguments (the remainder of ``base_args``) are
    ignored. Returns ``postprocess(req, marginal, info)`` when ``postprocess``
    is given, otherwise the raw result of ``calc_marginal``.
    """
    marginal = calc_marginal(data, info, req)
    return marginal if postprocess is None else postprocess(req, marginal, info)
def _calc_marginals_with_progress(
    data,
    info,
    requests: list,
    progress_lock,
    progress_send,
) -> list[np.ndarray]:
    """Run ``calc_marginal`` for every request while streaming progress counts.

    The number of requests finished since the last report is sent through
    ``progress_send`` (while holding ``progress_lock``) at most once every
    ``PROGRESS_STEP_NS`` nanoseconds. Any remainder is flushed at the end, so
    the counts sent always sum to ``len(requests)``.

    Returns:
        One marginal per request, in request order.
    """
    from time import time_ns

    results = []
    unreported = 0
    last_report = time_ns()
    for req in requests:
        results.append(calc_marginal(data, info, req))
        unreported += 1
        if (now := time_ns()) - last_report > PROGRESS_STEP_NS:
            last_report = now
            with progress_lock:
                progress_send.send(unreported)
            unreported = 0
    if unreported > 0:
        with progress_lock:
            progress_send.send(unreported)
    return results


def _marginal_batch_worker_inmem(
    mem_arr,
    mem_info,
    info,
    arange,
    requests: list,
    progress_lock,
    progress_send,
) -> list[np.ndarray]:
    """Compute all ``requests`` on the ``arange`` slice of the in-memory data.

    The slice is loaded with ``load_from_memory(..., range=arange, copy=True)``.
    Progress is reported as described in ``_calc_marginals_with_progress``.
    """
    data = load_from_memory(mem_arr, mem_info, range=arange, copy=True)
    return _calc_marginals_with_progress(
        data, info, requests, progress_lock, progress_send
    )


def _marginal_batch_worker_load(
    attrs: dict[str | None, Attributes | SeqAttributes],
    data: dict[str, LazyPartition],
    preprocess: PreprocessFun,
    requests: list,
    progress_lock,
    progress_send,
) -> list[np.ndarray]:
    """Load one partition and compute all ``requests`` on it.

    The partition is loaded with ``preprocess`` and expanded with
    ``expand_table``. Progress is reported as described in
    ``_calc_marginals_with_progress``.
    """
    tables = preprocess(data)
    cols, info = expand_table(attrs, tables)
    return _calc_marginals_with_progress(
        cols, info, requests, progress_lock, progress_send
    )
def _is_attributes(a) -> TypeGuard[Attributes]:
if not len(a):
return False
return isinstance(next(iter(a.values())), Attribute)
class MarginalOracle:
    """Computes (and optionally postprocesses) marginals over a dataset.

    Modes:
        ``out_of_core``: every worker loads one data partition with
            ``preprocess`` and computes all requests on it; the partial
            marginals are summed. Falls back to ``inmemory_copy`` when the
            data is not partitioned.
        ``inmemory`` / ``inmemory_copy``: the data is loaded into memory once
            and each request is computed by a worker; workers load the
            buffers with ``copy=True``.
        ``inmemory_shared``: like ``inmemory_copy`` but with ``copy=False``.
        ``inmemory_batched``: the data is loaded into memory once, split into
            at most ``repartitions`` ranges and handled like ``out_of_core``.
            Falls back to ``inmemory_copy`` when only one partition can be
            inferred.

    Use it as a context manager (or call :meth:`close`) so that the loaded
    memory buffer is closed and unlinked.
    """

    MODES = Literal[
        "out_of_core",
        "inmemory",
        "inmemory_shared",
        "inmemory_copy",
        "inmemory_batched",
    ]

    def __init__(
        self,
        data: Mapping[str, LazyPartition],
        attrs: DatasetAttributes | Attributes,
        preprocess: PreprocessFun = _tabular_load,
        mode: "MarginalOracle.MODES" = "out_of_core",
        *,
        min_chunk_size: int = 1,
        max_worker_mult: int = 1,
        repartitions: int | None = None,
        log: bool = True,
    ) -> None:
        """Create an oracle over ``data``.

        Args:
            data: mapping of table name to lazy partition(s).
            attrs: dataset attributes, or a single table's ``Attributes``,
                which is wrapped as ``{None: attrs}``.
            preprocess: default preprocess function used by :meth:`process`.
            mode: processing mode (see the class docstring).
            min_chunk_size: forwarded to ``process_in_parallel`` in the
                in-memory modes.
            max_worker_mult: forwarded to ``process_in_parallel``.
            repartitions: number of chunks for ``inmemory_batched``; defaults
                to the number of data partitions (1 if not partitioned).
            log: log the number of processed marginals on :meth:`close`.
        """
        if _is_attributes(attrs):
            self.attrs: DatasetAttributes = {None: attrs}
        else:
            self.attrs = cast(DatasetAttributes, attrs)
        self.data = data
        self.preprocess = preprocess

        partitioned = LazyFrame.are_partitioned(data)
        if mode == "out_of_core" and not partitioned:
            logger.info("Data is not partitioned, switching to mode `inmemory_copy`.")
            self.mode = "inmemory_copy"
        elif mode == "inmemory":
            # inmemory is an alias for inmemory_copy
            self.mode = "inmemory_copy"
        else:
            self.mode = mode

        data_partitions = 1
        if partitioned:
            data_partitions = len(LazyFrame.zip_values(data))
        self.repartitions = repartitions or data_partitions
        if self.repartitions == 1 and mode == "inmemory_batched":
            logger.info(
                "Data is not partitioned and `repartitions` is not provided. Can't infer partition number, switching to mode `inmemory_copy`."
            )
            self.mode = "inmemory_copy"

        self.min_chunk_size = min_chunk_size
        self.max_worker_mult = max_worker_mult
        self.counts = None
        self.log = log
        self._load_id = None
        # The preprocess function whose output is currently loaded. Holding a
        # reference keeps it alive, so its id cannot be recycled by another
        # function while the data stays loaded.
        self._loaded_preprocess = None
        self.info = None
        self.marginal_count = 0

    def load_data(self, preprocess: PreprocessFun):
        """Load the dataset into memory with ``preprocess``.

        Does nothing if data produced by an equal preprocess function is
        already loaded. If data from a different function is loaded, it is
        unloaded first. Partitioned data is loaded with ``parallel_load``,
        otherwise with ``sequential_load``. Sets ``mem_arr``, ``mem_info`` and
        ``info``.
        """
        if self._load_id:
            # Compare the retained function rather than a bare id: ids are
            # reused once an object is freed, which could silently serve data
            # produced by a different preprocess. ``==`` also treats two bound
            # methods of the same object/function as the same preprocess.
            if self._loaded_preprocess == preprocess:
                return
            self.unload_data()

        if LazyFrame.are_partitioned(self.data):
            loader = parallel_load
        else:
            loader = sequential_load
        self.mem_arr, self.mem_info, self.info = loader(
            self.attrs, self.data, preprocess
        )
        self._load_id = id(preprocess)
        self._loaded_preprocess = preprocess

    def unload_data(self):
        """Close and unlink the loaded memory buffer, if any."""
        if not self._load_id:
            return
        self.mem_arr.close()
        self.mem_arr.unlink()
        self._load_id = None
        self._loaded_preprocess = None

    def _process_inmemory(
        self,
        requests: list[AttrSelectors],
        desc: str,
        preprocess: PreprocessFun,
        postprocess: PostprocessFun | None,
    ):
        """Compute each request in a worker against the in-memory dataset.

        The data is loaded once with :meth:`load_data`. Workers attach to it
        through ``_marginal_initializer`` (with ``copy=True`` for
        ``inmemory_copy``) and run ``_marginal_worker`` once per request.
        """
        assert self.mode in ("inmemory_shared", "inmemory_copy")
        if len(requests) == 0:
            return []

        self.load_data(preprocess)
        base_args = {
            "mem_arr": self.mem_arr,
            "mem_info": self.mem_info,
            "info": self.info,
            "copy": self.mode == "inmemory_copy",
            "postprocess": postprocess,
        }
        per_call_args = [{"req": req} for req in requests]
        return process_in_parallel(
            _marginal_worker,
            per_call_args,
            base_args,
            min_chunk_size=self.min_chunk_size,
            max_worker_mult=self.max_worker_mult,
            desc=desc,
            initializer=_marginal_initializer,
        )

    def _process_batched(
        self,
        requests: list[AttrSelectors],
        desc: str,
        preprocess: PreprocessFun,
        postprocess: PostprocessFun | None,
    ):
        """Compute requests by running all of them on each partition/chunk.

        Each worker handles one data partition (``out_of_core``) or one range
        of the in-memory buffers (``inmemory_batched``) and returns a partial
        marginal per request. These are summed per request and then optionally
        postprocessed. Workers report progress through a pipe that a
        background thread reads to drive a progress bar.
        """
        assert self.mode in ("inmemory_batched", "out_of_core")
        from multiprocessing import Pipe
        from threading import Lock, Thread

        from ..utils.progress import MULTIPROCESS_ENABLE, get_manager

        if len(requests) == 0:
            return []

        progress_recv, progress_send = Pipe(duplex=False)
        if MULTIPROCESS_ENABLE:
            progress_lock = get_manager().Lock()
        else:
            # Use a thread lock to prevent launching a pool with multiprocess
            # disabled
            progress_lock = Lock()

        base_args = {
            "progress_send": progress_send,
            "progress_lock": progress_lock,
            "requests": requests,
        }
        if self.mode == "out_of_core":
            partitions = LazyFrame.zip_values(self.data)
            base_args.update({"attrs": self.attrs, "preprocess": preprocess})
            per_call_args = [{"data": chunks} for chunks in partitions]
            total = len(requests) * len(partitions)
            fun = _marginal_batch_worker_load
        else:
            self.load_data(preprocess)
            base_args.update(
                {"mem_arr": self.mem_arr, "mem_info": self.mem_info, "info": self.info}
            )
            # Split the first axis of the buffers into at most
            # ``self.repartitions`` contiguous (start, end) ranges.
            n = next(iter(self.mem_info.values()))[0].shape[0]
            chunk_n_suggestion = min(n, self.repartitions)
            chunk_len = (n - 1) // chunk_n_suggestion + 1
            chunk_n = (n - 1) // chunk_len + 1
            chunk_ranges = [
                (chunk_len * j, min(chunk_len * (j + 1), n)) for j in range(chunk_n)
            ]
            per_call_args = [{"arange": chunk_range} for chunk_range in chunk_ranges]
            total = len(requests) * chunk_n
            fun = _marginal_batch_worker_inmem

        def track_progress():
            # Sum the progress counts sent by the workers into one bar. Stops
            # once ``total`` units have arrived or the ``None`` sentinel is read.
            pbar = None
            done = 0
            while done < total and (u := progress_recv.recv()) is not None:
                if pbar is None:
                    # Start pbar after the partition pbar has started
                    pbar = piter(desc="Calculating submarginals", total=total, leave=False)
                done += u
                pbar.update(u)

        tracker = Thread(target=track_progress)
        tracker.start()
        try:
            res = process_in_parallel(
                fun,
                per_call_args,
                base_args,
                desc=desc,
                max_worker_mult=self.max_worker_mult,
            )
        finally:
            # Wake the tracker if it is still blocked in ``recv()`` and wait for
            # it to exit *before* closing the pipe. Closing the receiving end
            # while the tracker may still be reading from it makes its
            # ``recv()`` fail on a closed handle.
            progress_send.send(None)
            tracker.join()
            progress_send.close()
            progress_recv.close()

        if len(res) == 0:
            return []

        if postprocess is not None and not self.info:
            # out_of_core never loads the full dataset, so the info needed by
            # postprocess is computed from a sample instead.
            self.info = get_info(self.attrs, self.data, preprocess)

        out = []
        for i in piter(
            range(len(requests)),
            desc="Postprocessing partitioned marginals",
            leave=False,
        ):
            # Each worker returned a partial marginal per request; sum them.
            mar = np.sum([batch[i] for batch in res], axis=0)
            if postprocess is not None:
                out.append(postprocess(requests[i], mar, self.info))
            else:
                out.append(mar)
        return out

    @overload
    def process(
        self,
        requests: list[MarginalRequest],
        desc: str = ...,
        preprocess: PreprocessFun | None = ...,
    ) -> list[np.ndarray]: ...

    @overload
    def process(
        self,
        requests: list[MarginalRequest],
        desc: str = ...,
        preprocess: PreprocessFun | None = ...,
        postprocess: PostprocessFun[A] = ...,  # type: ignore
    ) -> list[A]: ...

    def process(
        self,
        requests: list[MarginalRequest],
        desc: str = "Processing partition",
        preprocess: PreprocessFun | None = None,
        postprocess: PostprocessFun[A] | None = None,
    ) -> list[np.ndarray] | list[A]:
        """Compute the marginals for ``requests``.

        Args:
            requests: selections in ``(table, attr, selector)`` or
                ``(attr, selector)`` form (normalized with ``convert_reqs``).
            desc: progress-bar description.
            preprocess: overrides the preprocess function given at
                construction; that one is used when this is falsy.
            postprocess: applied to every marginal as
                ``postprocess(req, marginal, info)``.

        Returns:
            One result per request, in order: raw marginal arrays, or the
            postprocessed values when ``postprocess`` is given.
        """
        self.marginal_count += len(requests)
        if not preprocess:
            preprocess = self.preprocess
        if self.mode in ("inmemory_batched", "out_of_core"):
            logger.debug(
                f"Processing {len(requests)} marginals by loading partitions in parallel."
            )
            return self._process_batched(
                convert_reqs(requests), desc, preprocess, postprocess
            )
        logger.debug(
            f"Processing {len(requests)} marginals by loading dataset in memory."
        )
        return self._process_inmemory(
            convert_reqs(requests), desc, preprocess, postprocess
        )

    def get_counts(
        self, desc: str = "Calculating counts"
    ) -> dict[str | None, dict[str, np.ndarray]]:
        """Return the count marginals of every attribute value (cached).

        For each attribute, one request is issued for its ``common`` value
        (when it has one) and one for each ``CatValue``. Results are keyed by
        table selector and then by value (or common) name. A ``SeqAttributes``
        table contributes its history tables under ``(table, k)`` and, when
        present, its ``attrs`` under the table name. The result is computed
        once and every call returns a ``dict`` copy of the outer mapping.
        """
        if self.counts is not None:
            return dict(self.counts)

        cols = []
        reqs: list[AttrSelectors] = []
        for table, table_attrs in self.attrs.items():
            attrs_dict: dict[TableSelector, Attributes]
            if isinstance(table_attrs, SeqAttributes):
                assert table is not None
                attrs_dict = {(table, k): v for k, v in table_attrs.hist.items()}
                if table_attrs.attrs is not None:
                    attrs_dict[table] = table_attrs.attrs
            else:
                attrs_dict = {table: table_attrs}

            for table_name, attrs in attrs_dict.items():
                for attr in attrs.values():
                    if attr.common:
                        reqs.append([(table_name, attr.name, 0)])
                        cols.append((table_name, attr.common.name))
                    for val_name, val in attr.vals.items():
                        if isinstance(val, CatValue):
                            reqs.append([(table_name, attr.name, {val_name: 0})])
                            cols.append((table_name, val_name))

        count_arr = self.process(reqs, desc=desc)  # type: ignore
        counts = defaultdict(dict)
        for (table, name), count in zip(cols, count_arr):
            counts[table][name] = count
        self.counts = counts
        return dict(counts)

    def close(self):
        """Log the processed-marginal count (if ``log``) and unload the data."""
        if self.log:
            logger.info(f"Processed {self.marginal_count} marginals.")
        self.unload_data()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()