from copy import copy
from typing import Any, cast
import numpy as np
import pandas as pd
from pasteur.attribute import Attribute
from pasteur.metadata import Metadata
from pasteur.utils import LazyPartition
from ..attribute import (
Attribute,
CatValue,
NumValue,
_create_strat_value_ord,
get_dtype,
)
from ..encode import AttributeEncoder, PostprocessEncoder
class IdxEncoder(AttributeEncoder[Attribute]):
    """Attribute encoder that replaces numerical values via discretizers.

    Every `NumValue` of the attribute gets its own
    `DiscretizationColumnTransformer` (presumably binning it into discrete
    indices — TODO confirm against the transformer); every other value is
    passed through untouched.
    """

    name = "idx"

    def fit(self, attr: Attribute, data: pd.DataFrame):
        """Fit one transformer per numerical value of `attr`.

        Args:
            attr: The attribute to encode.
            data: Raw data; must contain a column for every numerical value
                name in `attr.vals`.

        Sets `self.in_attr` (the attribute as received), `self.transformers`
        (input column name -> fitted transformer) and `self.attr` (a copy of
        `attr` whose numerical values are replaced by what each transformer's
        `fit` returned).
        """
        # Remember the input columns: a transformer may replace its column
        # with differently named ones, so `self.attr` alone cannot drive
        # encode/decode.
        self.in_attr = attr
        self.transformers: dict[str, DiscretizationColumnTransformer] = {}
        # FIXME: not out-of-core
        cols = {}
        for name, col_attr in attr.vals.items():
            if not isinstance(col_attr, NumValue):
                cols[name] = col_attr
                continue

            t = DiscretizationColumnTransformer()
            new_attr = t.fit(col_attr, data[name])
            if isinstance(new_attr, dict):
                # One input column became several output values.
                cols.update(new_attr)
            else:
                cols[name] = new_attr
            self.transformers[name] = t

        self.attr = copy(attr)
        self.attr.update_vals(cols)

    def _input_names(self) -> list[str]:
        """Return the value (column) names of the attribute passed to `fit`."""
        # Encoders fitted before `in_attr` was recorded only carry `attr`;
        # its names equal the input names unless a transformer expanded a
        # column.
        src = self.in_attr if hasattr(self, "in_attr") else self.attr
        return list(src.vals)

    def encode(self, data: pd.DataFrame) -> pd.DataFrame:
        """Encode the input columns of `data`.

        Columns with a fitted transformer are replaced by that transformer's
        output; the rest are copied unchanged. Pieces are concatenated in
        input-column order with an inner join on the index.
        """
        names = self._input_names()
        if not names:
            return pd.DataFrame(index=data.index)

        out_cols = []
        for name in names:
            t = self.transformers.get(name)
            out_cols.append(data[name] if t is None else t.encode(data[name]))
        return pd.concat(out_cols, axis=1, copy=False, join="inner")

    def decode(self, enc: pd.DataFrame) -> pd.DataFrame:
        """Rebuild the input columns from the encoded frame `enc`.

        Each transformer receives the whole encoded frame and returns its
        input column; pass-through columns are copied by name.
        """
        dec = pd.DataFrame(index=enc.index)
        for name in self._input_names():
            t = self.transformers.get(name)
            dec[name] = enc[name] if t is None else t.decode(enc)
        return dec
class NumEncoder(AttributeEncoder[Attribute]):
    """Attribute encoder that produces purely numerical columns.

    - Common values get indicator columns `{attr.name}_cmn_{i}` (float32,
      0/1), except when the attribute is a single ordinal `CatValue`.
    - `NumValue`s and ordinal `CatValue`s are passed through unchanged.
    - Other `CatValue`s are one-hot encoded over their non-common values as
      columns `{name}_{i}` (boolean dtype).

    Decoding is not supported.
    """

    name = "num"

    @staticmethod
    def _skip_common(attr: Attribute) -> bool:
        """Return True if `attr` gets no common-value indicator columns.

        That is the case when the attribute holds exactly one value and that
        value is an ordinal `CatValue`, which is passed through unchanged.
        """
        if len(attr.vals) != 1:
            return False
        v = next(iter(attr.vals.values()))
        # `is_ordinal` is a method: referencing it without calling it would
        # always be truthy.
        return isinstance(v, CatValue) and v.is_ordinal()

    def fit(self, attr: Attribute, data: pd.DataFrame):
        """Compute the output columns for `attr`.

        Args:
            attr: The attribute to encode.
            data: Unused; accepted for interface compatibility.

        Sets `self.in_attr` (the attribute as received) and `self.attr` (a
        copy whose values describe the encoded numerical columns).
        """
        self.in_attr = attr
        cols = {}
        common = attr.common

        if not self._skip_common(attr):
            for i in range(common):
                cols[f"{attr.name}_cmn_{i}"] = NumValue()

        for name, col in attr.vals.items():
            if isinstance(col, NumValue):
                cols[name] = col
            elif isinstance(col, CatValue):
                if col.is_ordinal():
                    cols[name] = NumValue()
                else:
                    assert col.common == common
                    for i in range(col.get_domain(0) - col.common):
                        cols[f"{name}_{i}"] = NumValue()

        self.attr = copy(attr)
        self.attr.update_vals(cols)

    def encode(self, data: pd.DataFrame) -> pd.DataFrame:
        """Encode `data` into the columns described by `self.attr`."""
        a = self.in_attr
        if len(a.vals) == 0:
            return pd.DataFrame(index=data.index)

        cols = []
        only_has_na = a.common == 1 and a.na

        # Handle common values
        n_common = 0 if self._skip_common(a) else a.common
        for i in range(n_common):
            cmn_col = pd.Series(
                False, index=data.index, name=f"{a.name}_cmn_{i}", dtype=np.float32
            )
            for name, col in a.vals.items():
                if isinstance(col, CatValue):
                    cmn_col += data[name] == i
                elif isinstance(col, NumValue) and only_has_na:
                    # Numerical values are expected to be NA for all common
                    # values, so they only mark the common value when
                    # `common == 1 and a.na`, i.e. the only common value is NA.
                    cmn_col += pd.isna(data[name])
            # Several values may flag the same row; keep the indicator 0/1.
            cols.append(cmn_col.clip(0, 1, inplace=False))

        # Add other columns
        for name, col in a.vals.items():
            if isinstance(col, NumValue):
                cols.append(data[name])
            elif isinstance(col, CatValue):
                # TODO add proper encodings other than one hot
                if col.is_ordinal():
                    cols.append(data[name])
                else:
                    # One hot encode the non-common values
                    for i in range(col.get_domain(0) - col.common):
                        cols.append(
                            (data[name] == i + col.common).rename(f"{name}_{i}")
                        )
        return pd.concat(cols, axis=1, copy=False, join="inner")

    def decode(self, enc: pd.DataFrame) -> pd.DataFrame:
        """Not supported: the numerical encoding is not reversed."""
        # An `assert False` would be stripped under `python -O` and silently
        # return None.
        raise NotImplementedError("NumEncoder does not support decoding")
class MareEncoder(IdxEncoder, PostprocessEncoder[Attribute]):
    """`IdxEncoder` that is also a `PostprocessEncoder`.

    Encoding comes from `IdxEncoder`; `finalize` and `undo` forward their
    arguments, unchanged, to the next implementation in the method
    resolution order.
    """

    name = "mare"

    def finalize(
        self,
        meta: dict[str, dict[tuple[str, ...] | str, Attribute]],
        tables: dict[str, pd.DataFrame],
        ids: dict[str, pd.DataFrame],
    ) -> dict[str, Any]:
        """Forward to the parent `finalize` and return its result."""
        parent = super()
        return parent.finalize(meta, tables, ids)

    def undo(
        self,
        meta: dict[str, dict[tuple[str, ...] | str, Attribute]],
        data: dict[str, LazyPartition],
    ) -> tuple[dict[str, pd.DataFrame], dict[str, pd.DataFrame]]:
        """Forward to the parent `undo` and return its result."""
        parent = super()
        return parent.undo(meta, data)