Source code for pasteur.kedro.dataset.multi
import warnings
from copy import deepcopy
from typing import Any, Callable
from kedro.io.core import (
VERSION_KEY,
VERSIONED_FLAG_KEY,
AbstractDataset,
DatasetError,
parse_dataset_definition,
)
from kedro.io.partitioned_dataset import S3_PROTOCOLS
from urllib.parse import urlparse
[docs]class Multiset(AbstractDataset):
# noqa: too-many-instance-attributes,protected-access
"""Simplified version of the partitioned dataset. Is not lazy."""
def __init__( # noqa: too-many-arguments
self,
path: str,
dataset: str | type[AbstractDataset] | dict[str, Any],
filepath_arg: str = "filepath",
filename_suffix: str = "",
credentials: dict[str, Any] | None = None,
load_args: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
):
# noqa: import-outside-toplevel
from fsspec.utils import infer_storage_options # for performance reasons
super().__init__()
self._path = path
self._filename_suffix = filename_suffix
self._protocol = infer_storage_options(self._path)["protocol"]
self.metadata = metadata
dataset = dataset if isinstance(dataset, dict) else {"type": dataset}
self._dataset_type, self._dataset_config = parse_dataset_definition(dataset)
if VERSION_KEY in self._dataset_config:
raise DatasetError(
f"'{self.__class__.__name__}' does not support versioning of the "
f"underlying dataset. Please remove '{VERSIONED_FLAG_KEY}' flag from "
f"the dataset definition."
)
self._credentials = deepcopy(credentials) or {}
self._filepath_arg = filepath_arg
if self._filepath_arg in self._dataset_config:
warnings.warn(
f"'{self._filepath_arg}' key must not be specified in the dataset "
f"definition as it will be overwritten by partition path"
)
self._load_args = deepcopy(load_args) or {}
self._sep = self._filesystem.sep
# since some filesystem implementations may implement a global cache
self._invalidate_caches()
@property
def _filesystem(self):
# for performance reasons
import fsspec # noqa: import-outside-toplevel
protocol = "s3" if self._protocol in S3_PROTOCOLS else self._protocol
return fsspec.filesystem(protocol, **self._credentials)
@property
def _normalized_path(self) -> str:
if self._protocol in S3_PROTOCOLS:
return urlparse(self._path)._replace(scheme="s3").geturl()
return self._path
def _list_partitions(self) -> list[str]:
if not self._filesystem.isdir(self._normalized_path, **self._load_args):
# If the path does not exist, ie no datasets were saved before
# return no partitions instead of crashing
return []
return [
path["name"]
for path in self._filesystem.listdir(
self._normalized_path, **self._load_args
)
if path["name"].endswith(self._filename_suffix)
]
def _join_protocol(self, path: str) -> str:
protocol_prefix = f"{self._protocol}://"
if self._path.startswith(protocol_prefix) and not path.startswith(
protocol_prefix
):
return f"{protocol_prefix}{path}"
return path
def _partition_to_path(self, path: str):
dir_path = self._path.rstrip(self._sep)
path = path.lstrip(self._sep)
full_path = self._sep.join([dir_path, path]) + self._filename_suffix
return full_path
def _path_to_partition(self, path: str) -> str:
dir_path = self._filesystem._strip_protocol(self._normalized_path)
path = path.split(dir_path, 1).pop().lstrip(self._sep)
if self._filename_suffix and path.endswith(self._filename_suffix):
path = path[: -len(self._filename_suffix)]
return path
def _load(self) -> dict[str, Callable[[], Any]]:
partitions = {}
for partition in self._list_partitions():
kwargs = deepcopy(self._dataset_config)
# join the protocol back since PySpark may rely on it
kwargs[self._filepath_arg] = self._join_protocol(partition)
dataset = self._dataset_type(**kwargs) # type: ignore
partition_id = self._path_to_partition(partition)
partitions[partition_id] = dataset.load()
return partitions
def _save(self, data: dict[str, Any]) -> None:
for partition_id, partition_data in sorted(data.items()):
kwargs = deepcopy(self._dataset_config)
partition = self._partition_to_path(partition_id)
# join the protocol back since tools like PySpark may rely on it
kwargs[self._filepath_arg] = self._join_protocol(partition)
dataset = self._dataset_type(**kwargs) # type: ignore
if callable(partition_data):
partition_data = partition_data() # noqa: redefined-loop-name
dataset.save(partition_data)
self._invalidate_caches()
def _describe(self) -> dict[str, Any]:
clean_dataset_config = (
{k: v for k, v in self._dataset_config.items()}
if isinstance(self._dataset_config, dict)
else self._dataset_config
)
return {
"path": self._path,
"dataset_type": self._dataset_type.__name__,
"dataset_config": clean_dataset_config,
}
def _invalidate_caches(self):
self._filesystem.invalidate_cache(self._normalized_path)
[docs] def reset(self):
"""Removes the dataset from disk so that there are no stray partitions in subsequent runs."""
if self._filesystem.exists(self._normalized_path):
self._filesystem.rm(self._normalized_path, recursive=True, maxdepth=1)
def _release(self) -> None:
super()._release()
self._invalidate_caches()