Source code for pasteur.kedro.dataset.auto

import logging
import os
from copy import deepcopy
from functools import wraps
from io import BytesIO
from pathlib import PurePosixPath
from typing import Any, Callable

import fsspec
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from kedro.io.core import (
    PROTOCOL_DELIMITER,
    AbstractVersionedDataset,
    DatasetError,
    Version,
    get_filepath_str,
    get_protocol_and_path,
)

from ...utils import LazyDataset, LazyFrame, LazyPartition
from ...utils.progress import get_node_name, process, process_in_parallel, DEBUG

logger = logging.getLogger(__name__)


def _wrap_retry(f):
    if DEBUG:
        # Skip wrapping if DEBUG is on so exceptions break correctly
        return f

    @wraps(f)
    def _wrap(*args, **kwargs):
        ex = None
        for i in range(3):
            try:
                return f(*args, **kwargs)
            except Exception as e:
                import time

                # Prevents pipeline crashing on unreliable network shares
                logger.warn(
                    f"Failed fs function '{f.__name__}(path={kwargs.get('path', str(args[0]) if args else 'None')})' (attempt {i + 1}/3). Waiting 1 second and retrying..."
                )
                time.sleep(1)
                ex = e
        if ex:
            raise ex

    return _wrap


@_wrap_retry
def _save_worker(
    pid: str | None,
    path: str,
    chunk: Callable[..., pd.DataFrame] | pd.DataFrame,
    protocol,
    fs,
    save_args,
):
    if pid:
        logging.debug(f"Saving chunk {pid}...")

    if callable(chunk):
        chunk = chunk()
        if callable(chunk):
            logger.error(
                f"Callable `chunk()` got double wrapped (`to_chunked()` bug).\n{str(chunk)[:50]}"
            )
            chunk = chunk()

    from inspect import isgenerator

    # Handle data partitioning using generators, to avoid storing the whole partition in ram
    # or having to use pd.concat()
    if isgenerator(chunk):
        # Grab first chunk with content
        p0 = None
        try:
            while p0 is None or len(p0) == 0:
                p0 = next(chunk)
        except:
            logger.error(f"Generator {chunk} returned no data.")
            return

        old_schema = pa.Schema.from_pandas(p0, preserve_index=True)

        # FIXME: Schema inference for pyarrow
        # null columns will lead to invalid schema
        # int8 dictionaries in first chunk which become int16 will lead to invalid schema
        # try to fix both
        fields = []
        dtypes = p0.dtypes
        for field in old_schema:
            if (
                isinstance(field.type, pa.DictionaryType)
                and field.type.index_type.bit_width == 8
            ):
                # Expand uint8 dictionaries to uint16
                fields.append(
                    pa.field(
                        field.name,
                        pa.dictionary(pa.int16(), field.type.value_type),
                        field.nullable,
                        field.metadata,
                    )
                )
            elif field.name in dtypes and field.type == pa.null():
                # Fix missing types based on pandas dtype
                # might produce larger than required types, but better than failing.
                # If field is not in dtype, assume it's related to pyarrow and skip.
                match (dtypes[field.name].name):
                    case "int64":
                        pa_type = pa.int64()
                    case other:
                        logger.warning(
                            f"Could not infer type for empty column `{field.name}`"
                            + f" with pandas type `{other}` to generate parquet"
                            + "schema. If there's a chunk who's column contains"
                            + "values, saving will crash. Fill in the code for your type."
                        )
                        pa_type = pa.null()

                fields.append(
                    pa.field(
                        field.name,
                        pa_type,
                        field.nullable,
                        field.metadata,
                    )
                )
            else:
                fields.append(field)
        schema = pa.schema(fields, old_schema.metadata)

        # Use parquet writer to write chunks
        with pq.ParquetWriter(path, schema, filesystem=fs) as w:
            w.write(pa.Table.from_pandas(p0, schema=schema))
            del p0

            for p in chunk:  # type: ignore
                try:
                    w.write(pa.Table.from_pandas(p, schema=schema, preserve_index=True))
                except Exception as e:
                    logger.error(f"Error writing chunk:\n{e}")
    else:
        if protocol == "file":
            if fs.isdir(path):
                fs.rm(path, recursive=True, maxdepth=1)

            with fs.open(path, mode="wb") as fs_file:
                chunk.to_parquet(fs_file, **save_args)
        else:
            bytes_buffer = BytesIO()
            chunk.to_parquet(bytes_buffer, **save_args)

            with fs.open(path, mode="wb") as fs_file:
                fs_file.write(bytes_buffer.getvalue())


@_wrap_retry
def _load_worker(
    path: str,
    protocol: str,
    storage_options,
    load_args: dict,
    columns: list[str] | None = None,
):
    if protocol == "file":
        # file:// protocol seems to misbehave on Windows
        # (<urlopen error file not on local host>),
        # so we don't join that back to the filepath;
        # storage_options also don't work with local paths
        return pd.read_parquet(path, **load_args)

    load_path = f"{protocol}{PROTOCOL_DELIMITER}{path}"

    if columns is not None:
        load_args = load_args.copy()
        load_args["columns"] = columns

    return pd.read_parquet(load_path, storage_options=storage_options, **load_args)


@_wrap_retry
def _load_merged_worker(
    load_path: str, filesystem, load_args, columns: list[str] | None = None
):
    if columns is not None:
        load_args = load_args.copy()
        load_args["columns"] = columns

    data = pq.ParquetDataset(load_path, filesystem=filesystem, use_legacy_dataset=False)
    table = data.read_pandas(**load_args)

    # Grab categorical columns from metadata
    # null columns that are specified as categorical in pandas metadata
    # will become objects after loading, ballooning dataset size
    # the following code will remake the column as categorical
    try:
        import json

        categorical = []
        for field in json.loads(table.schema.metadata[b"pandas"])["columns"]:
            if (field["pandas_type"]) == "categorical":
                categorical.append(field["name"])

        dtypes = {name: "category" for name in categorical}
    except:
        dtypes = None

    # Try to avoid double allocation
    out = table.to_pandas(split_blocks=True, self_destruct=True)
    del table

    # restore categorical dtypes
    if dtypes:
        return out.astype(dtypes)
    return out


@_wrap_retry
def _load_shape_worker(load_path: str, filesystem, *_, **__):
    # TODO: verify this returns correct numbers (esp. columns)
    data = pq.ParquetDataset(load_path, filesystem=filesystem, use_legacy_dataset=False)
    rows = 0
    for frag in data.fragments:
        rows += frag.count_rows()

    pm = data.schema.pandas_metadata  # type: ignore
    cols = len(pm["columns"]) - len(
        [c for c in pm["index_columns"] if isinstance(c, str)]
    )

    return (rows, cols)


[docs]class AutoDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): """Modified kedro parquet dataset that acts similarly to a partitioned dataset and implements lazy loading. In the future, this dataset will automatically handle pickling, pyarrow Tables, DataFrames, and Tensors automatically based on what is saved. `save()` data can be a table, a callable, or a dictionary combination of both. If its a table or a callable, this class acts exactly as ParquetDataset. If its a dictionary, each callable function is called and saved in parallel in a different parquet file, making the provided path a directory. Parallelism is achieved by using Pasteur's common process pool. `load()` returns a dictionary with parquet file names and callables that will load each one. In addition, `load()` will include an entry `_all` that will load and concatenate all partitions, with memory optimisations. If `save()` was called with a single dataframe/callable, then `load()` will return a callable instead. All callables can receive as input the columns they want to be loaded from the dataframe.""" DEFAULT_LOAD_ARGS: dict[str, Any] = {} DEFAULT_SAVE_ARGS: dict[str, Any] = {} def __init__( self, filepath: str, load_args: dict[str, Any] | None = None, save_args: dict[str, Any] | None = None, version: Version | None = None, credentials: dict[str, Any] | None = None, fs_args: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None, ) -> None: _fs_args = deepcopy(fs_args) or {} _credentials = deepcopy(credentials) or {} protocol, path = get_protocol_and_path(filepath, version) # type: ignore if protocol == "file": _fs_args.setdefault("auto_mkdir", True) self._protocol = protocol self._storage_options = {**_credentials, **_fs_args} self._fs = fsspec.filesystem(self._protocol, **self._storage_options) self.metadata = metadata super().__init__( filepath=PurePosixPath(path), # type: ignore version=version, exists_function=self._fs.exists, glob_function=self._fs.glob, ) # Handle default load and save arguments self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) if load_args is not None: self._load_args.update(load_args) self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) if save_args is not None: self._save_args.update(save_args) if "storage_options" in self._save_args or "storage_options" in self._load_args: logger.warning( "Dropping 'storage_options' for %s, " "please specify them under 'fs_args' or 'credentials'.", self._filepath, ) self._save_args.pop("storage_options", None) self._load_args.pop("storage_options", None) def _describe(self) -> dict[str, Any]: return { "filepath": self._filepath, "protocol": self._protocol, "load_args": self._load_args, "save_args": self._save_args, "version": self._version, } def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) except DatasetError: return False return self._fs.exists(load_path) def _release(self) -> None: super()._release() self._invalidate_cache() def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) def _load(self) -> LazyFrame: load_path = get_filepath_str(self._get_load_path(), self._protocol) if not self._fs.isdir(load_path): return LazyDataset( LazyPartition( _load_merged_worker, _load_shape_worker, load_path, self._fs, self._load_args, ) ) partitions = {} for fn in self._fs.listdir(load_path): partition_id = fn["name"].split("/")[-1].split("\\")[-1].replace(".pq", "") partition_data = LazyPartition( _load_merged_worker, _load_shape_worker, fn["name"], self._fs, self._load_args, ) partitions[partition_id] = partition_data merged_partition = LazyPartition( _load_merged_worker, _load_shape_worker, load_path, self._fs, self._load_args, ) return LazyDataset(merged_partition, partitions) def _get_save_path(self): if not self._version: # When versioning is disabled, return original filepath return self._filepath save_version = self.resolve_save_version() versioned_path = self._get_versioned_path(save_version) # type: ignore # TODO; Redo check that respects partitioning # if self._exists_function(str(versioned_path)): # raise DatasetError( # f"Save path '{versioned_path}' for {str(self)} must not exist if " # f"versioning is enabled." # ) return versioned_path def _save(self, data: pd.DataFrame) -> None: save_path = get_filepath_str(self._get_save_path(), self._protocol) if (not isinstance(data, dict) and not isinstance(data, LazyDataset)) or ( isinstance(data, LazyDataset) and not data.partitioned ): process( _save_worker, protocol=self._protocol, fs=self._fs, save_args=self._save_args, pid=None, path=save_path, chunk=data, ) return base_args = { "protocol": self._protocol, "fs": self._fs, "save_args": self._save_args, } jobs = [] for pid, partition_data in sorted(data.items()): chunk_save_path = os.path.join( save_path, pid if pid.endswith(".pq") else pid + ".pq" # type: ignore ) jobs.append({"pid": pid, "path": chunk_save_path, "chunk": partition_data}) if not jobs: return self._fs.mkdirs(save_path, exist_ok=True) process_in_parallel( _save_worker, jobs, base_args, 1, f"Processing chunks ({get_node_name():>25s})", ) self._invalidate_cache()
[docs] def reset(self): save_path = get_filepath_str(self._get_save_path(), self._protocol) if self._fs.exists(save_path): self._fs.rm(save_path, recursive=True, maxdepth=1)