"""This module provides a set of classes which underpin the data loading and
saving functionality provided by ``kedro.io``.
"""
from __future__ import annotations
import abc
import copy
import logging
import os
import re
import warnings
from collections import namedtuple
from datetime import datetime, timezone
from functools import partial
from glob import iglob
from operator import attrgetter
from pathlib import Path, PurePath, PurePosixPath
from typing import Any, Callable, Generic, TypeVar
from urllib.parse import urlsplit
from cachetools import Cache, cachedmethod
from cachetools.keys import hashkey
from kedro.utils import load_obj
# strftime pattern used to render version strings.
VERSION_FORMAT = "%Y-%m-%dT%H.%M.%S.%fZ"
# Catalog config key that turns versioning on for a dataset.
VERSIONED_FLAG_KEY = "versioned"
# Constructor keyword under which a ``Version`` is handed to a dataset.
VERSION_KEY = "version"
# URL schemes treated as plain HTTP(S) locations.
HTTP_PROTOCOLS = ("http", "https")
# Separator between a protocol and the remainder of a URL.
PROTOCOL_DELIMITER = "://"
# Protocols whose URL netloc (bucket / container) is folded into the path.
CLOUD_PROTOCOLS = (
    "s3",
    "s3n",
    "s3a",
    "gcs",
    "gs",
    "adl",
    "abfs",
    "abfss",
    "gdrive",
)
class DatasetError(Exception):
    """``DatasetError`` raised by ``AbstractDataset`` implementations
    in case of failure of input/output methods.

    ``AbstractDataset`` implementations should provide instructive
    information in case of failure.
    """
class DatasetNotFoundError(DatasetError):
    """``DatasetNotFoundError`` raised by ``DataCatalog`` class in case of
    trying to use a non-existing data set.
    """
class DatasetAlreadyExistsError(DatasetError):
    """``DatasetAlreadyExistsError`` raised by ``DataCatalog`` class in case
    of trying to add a data set which already exists in the ``DataCatalog``.
    """
class VersionNotFoundError(DatasetError):
    """Signals that an ``AbstractVersionedDataset`` implementation could not
    find any load version available for the data set.
    """
_DI = TypeVar("_DI")
_DO = TypeVar("_DO")
class AbstractDataset(abc.ABC, Generic[_DI, _DO]):
    """``AbstractDataset`` is the base class for all data set implementations.

    All data set implementations should extend this abstract class
    and implement the methods marked as abstract.
    If a specific dataset implementation cannot be used in conjunction with
    the ``ParallelRunner``, such user-defined dataset should have the
    attribute `_SINGLE_PROCESS = True`.

    Example:
    ::

        >>> from pathlib import Path, PurePosixPath
        >>> import pandas as pd
        >>> from kedro.io import AbstractDataset
        >>>
        >>>
        >>> class MyOwnDataset(AbstractDataset[pd.DataFrame, pd.DataFrame]):
        >>>     def __init__(self, filepath, param1, param2=True):
        >>>         self._filepath = PurePosixPath(filepath)
        >>>         self._param1 = param1
        >>>         self._param2 = param2
        >>>
        >>>     def _load(self) -> pd.DataFrame:
        >>>         return pd.read_csv(self._filepath)
        >>>
        >>>     def _save(self, df: pd.DataFrame) -> None:
        >>>         df.to_csv(str(self._filepath))
        >>>
        >>>     def _exists(self) -> bool:
        >>>         return Path(self._filepath.as_posix()).exists()
        >>>
        >>>     def _describe(self):
        >>>         return dict(param1=self._param1, param2=self._param2)

    Example catalog.yml specification:
    ::

        my_dataset:
            type: <path-to-my-own-dataset>.MyOwnDataset
            filepath: data/01_raw/my_data.csv
            param1: <param1-value> # param1 is a required argument
            # param2 will be True by default
    """

    # Datasets are persistent by default. User-defined datasets that are not
    # made to be persistent, such as instances of `MemoryDataset`, need to
    # change the `_EPHEMERAL` attribute to True.
    _EPHEMERAL = False

    @classmethod
    def from_config(
        cls: type,
        name: str,
        config: dict[str, Any],
        load_version: str | None = None,
        save_version: str | None = None,
    ) -> AbstractDataset:
        """Create a data set instance using the configuration provided.

        Args:
            name: Data set name.
            config: Data set config dictionary.
            load_version: Version string to be used for ``load`` operation if
                the data set is versioned. Has no effect on the data set
                if versioning was not enabled.
            save_version: Version string to be used for ``save`` operation if
                the data set is versioned. Has no effect on the data set
                if versioning was not enabled.

        Returns:
            An instance of an ``AbstractDataset`` subclass.

        Raises:
            DatasetError: When the function fails to create the data set
                from its config.
        """
        try:
            class_obj, config = parse_dataset_definition(
                config, load_version, save_version
            )
        except Exception as exc:
            raise DatasetError(
                f"An exception occurred when parsing config "
                f"for dataset '{name}':\n{str(exc)}"
            ) from exc

        try:
            dataset = class_obj(**config)
        except TypeError as err:
            # A TypeError here most often means the config carries keyword
            # arguments the dataset constructor does not accept.
            raise DatasetError(
                f"\n{err}.\nDataset '{name}' must only contain arguments valid for the "
                f"constructor of '{class_obj.__module__}.{class_obj.__qualname__}'."
            ) from err
        except Exception as err:
            raise DatasetError(
                f"\n{err}.\nFailed to instantiate dataset '{name}' "
                f"of type '{class_obj.__module__}.{class_obj.__qualname__}'."
            ) from err
        return dataset

    @property
    def _logger(self) -> logging.Logger:
        return logging.getLogger(__name__)

    def load(self) -> _DO:
        """Loads data by delegation to the provided load method.

        Returns:
            Data returned by the provided load method.

        Raises:
            DatasetError: When underlying load method raises error.
        """
        self._logger.debug("Loading %s", str(self))

        try:
            return self._load()
        except DatasetError:
            raise
        except Exception as exc:
            # This exception handling is by design as the composed data sets
            # can throw any type of exception.
            message = (
                f"Failed while loading data from data set {str(self)}.\n{str(exc)}"
            )
            raise DatasetError(message) from exc

    def save(self, data: _DI) -> None:
        """Saves data by delegation to the provided save method.

        Args:
            data: the value to be saved by provided save method.

        Raises:
            DatasetError: when underlying save method raises error.
            FileNotFoundError: when save method got file instead of dir, on Windows.
            NotADirectoryError: when save method got file instead of dir, on Unix.
        """
        if data is None:
            raise DatasetError("Saving 'None' to a 'Dataset' is not allowed")

        try:
            self._logger.debug("Saving %s", str(self))
            self._save(data)
        except DatasetError:
            raise
        except (FileNotFoundError, NotADirectoryError):
            # Propagated unwrapped so callers can react to the specific
            # file-vs-directory conflict.
            raise
        except Exception as exc:
            message = f"Failed while saving data to data set {str(self)}.\n{str(exc)}"
            raise DatasetError(message) from exc

    def __str__(self) -> str:
        def _to_str(obj: Any, is_root: bool = False) -> str:
            """Returns a string representation where
            1. The root level (i.e. the Dataset.__init__ arguments) are
            formatted like Dataset(key=value).
            2. Dictionaries have the keys alphabetically sorted recursively.
            3. None values are not shown.
            """

            fmt = "{}={}" if is_root else "'{}': {}"  # 1

            if isinstance(obj, dict):
                sorted_dict = sorted(obj.items(), key=lambda pair: str(pair[0]))  # 2

                text = ", ".join(
                    fmt.format(key, _to_str(value))  # 2
                    for key, value in sorted_dict
                    if value is not None  # 3
                )

                return text if is_root else "{" + text + "}"  # 1

            # not a dictionary
            return str(obj)

        return f"{type(self).__name__}({_to_str(self._describe(), True)})"

    @abc.abstractmethod
    def _load(self) -> _DO:
        raise NotImplementedError(
            f"'{self.__class__.__name__}' is a subclass of AbstractDataset and "
            f"it must implement the '_load' method"
        )

    @abc.abstractmethod
    def _save(self, data: _DI) -> None:
        raise NotImplementedError(
            f"'{self.__class__.__name__}' is a subclass of AbstractDataset and "
            f"it must implement the '_save' method"
        )

    @abc.abstractmethod
    def _describe(self) -> dict[str, Any]:
        raise NotImplementedError(
            f"'{self.__class__.__name__}' is a subclass of AbstractDataset and "
            f"it must implement the '_describe' method"
        )

    def exists(self) -> bool:
        """Checks whether a data set's output already exists by calling
        the provided _exists() method.

        Returns:
            Flag indicating whether the output already exists.

        Raises:
            DatasetError: when underlying exists method raises error.
        """
        try:
            self._logger.debug("Checking whether target of %s exists", str(self))
            return self._exists()
        except Exception as exc:
            message = (
                f"Failed during exists check for data set {str(self)}.\n{str(exc)}"
            )
            raise DatasetError(message) from exc

    def _exists(self) -> bool:
        # Default used when a subclass does not override ``_exists``.
        self._logger.warning(
            "'exists()' not implemented for '%s'. Assuming output does not exist.",
            self.__class__.__name__,
        )
        return False

    def release(self) -> None:
        """Release any cached data.

        Raises:
            DatasetError: when underlying release method raises error.
        """
        try:
            self._logger.debug("Releasing %s", str(self))
            self._release()
        except Exception as exc:
            message = f"Failed during release for data set {str(self)}.\n{str(exc)}"
            raise DatasetError(message) from exc

    def _release(self) -> None:
        pass

    def _copy(self, **overwrite_params: Any) -> AbstractDataset:
        """Return a deep copy of this dataset with the given attributes overwritten."""
        dataset_copy = copy.deepcopy(self)
        for name, value in overwrite_params.items():
            setattr(dataset_copy, name, value)
        return dataset_copy
def generate_timestamp() -> str:
    """Build the version string for "now" in UTC.

    ``VERSION_FORMAT`` renders six fractional-second digits followed by a
    ``Z``. Only the first three digits (milliseconds) are kept, and the
    trailing ``Z`` is preserved.

    Returns:
        String representation of the current timestamp.
    """
    stamp = datetime.now(tz=timezone.utc).strftime(VERSION_FORMAT)
    # Strip the last three digits plus "Z", then re-attach the "Z".
    head, tail = stamp[:-4], stamp[-1:]
    return head + tail
class Version(namedtuple("Version", ["load", "save"])):
    """This namedtuple is used to provide load and save versions for versioned
    data sets. If ``Version.load`` is None, then the latest available version
    is loaded. If ``Version.save`` is None, then save version is formatted as
    YYYY-MM-DDThh.mm.ss.sssZ of the current timestamp.
    """

    # Keep instances as lightweight as the underlying tuple (no __dict__).
    __slots__ = ()
_CONSISTENCY_WARNING = (
"Save version '{}' did not match load version '{}' for {}. This is strongly "
"discouraged due to inconsistencies it may cause between 'save' and "
"'load' operations. Please refrain from setting exact load version for "
"intermediate data sets where possible to avoid this warning."
)
_DEFAULT_PACKAGES = ["kedro.io.", "kedro_datasets.", ""]
def parse_dataset_definition(
    config: dict[str, Any],
    load_version: str | None = None,
    save_version: str | None = None,
) -> tuple[type[AbstractDataset], dict[str, Any]]:
    """Parse and instantiate a dataset class using the configuration provided.

    Args:
        config: Data set config dictionary. It *must* contain the `type` key
            with fully qualified class name or the class object.
        load_version: Version string to be used for ``load`` operation if
            the data set is versioned. Has no effect on the data set
            if versioning was not enabled.
        save_version: Version string to be used for ``save`` operation if
            the data set is versioned. Has no effect on the data set
            if versioning was not enabled.

    Raises:
        DatasetError: If the function fails to parse the configuration provided.

    Returns:
        2-tuple: (Dataset class object, configuration dictionary)
    """
    save_version = save_version or generate_timestamp()
    # Work on a copy so the caller's dictionary is never mutated by the pops below.
    config = copy.deepcopy(config)

    if "type" not in config:
        raise DatasetError("'type' is missing from dataset catalog configuration")

    dataset_type = config.pop("type")
    class_obj = None
    if isinstance(dataset_type, str):
        if len(dataset_type.strip(".")) != len(dataset_type):
            raise DatasetError(
                "'type' class path does not support relative "
                "paths or paths ending with a dot."
            )

        # Try the path under each default package prefix, in order; the first
        # prefix that resolves wins.
        class_paths = (prefix + dataset_type for prefix in _DEFAULT_PACKAGES)

        for class_path in class_paths:
            tmp = _load_obj(class_path)
            if tmp is not None:
                class_obj = tmp
                break
        else:
            raise DatasetError(f"Class '{dataset_type}' not found, is this a typo?")

    if class_obj is None:
        # ``type`` was given directly as an object rather than a dotted path.
        class_obj = dataset_type

    # ``issubclass`` raises a bare TypeError when its first argument is not a
    # class (e.g. a dotted path to a function, or ``type: null``). Report such
    # values as DatasetError, as promised in the docstring.
    if not (isinstance(class_obj, type) and issubclass(class_obj, AbstractDataset)):
        if isinstance(class_obj, type):
            type_name = f"{class_obj.__module__}.{class_obj.__qualname__}"
        else:
            type_name = repr(class_obj)
        raise DatasetError(
            f"Dataset type '{type_name}' "
            f"is invalid: all data set types must extend 'AbstractDataset'."
        )

    if VERSION_KEY in config:
        # remove "version" key so that it's not passed
        # to the "unversioned" data set constructor
        message = (
            "'%s' attribute removed from data set configuration since it is a "
            "reserved word and cannot be directly specified"
        )
        logging.getLogger(__name__).warning(message, VERSION_KEY)
        del config[VERSION_KEY]

    # dataset is either versioned explicitly by the user or versioned is set to true by default
    # on the dataset
    if config.pop(VERSIONED_FLAG_KEY, False) or getattr(
        class_obj, VERSIONED_FLAG_KEY, False
    ):
        config[VERSION_KEY] = Version(load_version, save_version)

    return class_obj, config
def _load_obj(class_path: str) -> Any | None:
    """Import the object at ``class_path``, returning None if it cannot be found.

    If the parent module lists the name in its ``__all__`` but importing it
    still fails, the module exists and its dependencies are presumably
    missing. In that case a ``DatasetError`` pointing at the installation docs
    is raised instead of returning None.
    """
    mod_path, _, class_name = class_path.rpartition(".")

    # Check if the module exists and what it exports.
    try:
        exported_names = load_obj(f"{mod_path}.__all__")
    except (ModuleNotFoundError, AttributeError, ValueError):
        # ModuleNotFoundError: `mod_path` does not exist under this prefix
        # (expected, since several prefixes are tried).
        # AttributeError: the module loaded but defines no `__all__`.
        # ValueError: also tolerated; presumably raised by `load_obj` for a
        # malformed or empty module path — TODO confirm.
        exported_names = None

    try:
        return load_obj(class_path)
    except (ModuleNotFoundError, ValueError, AttributeError) as exc:
        if not exported_names or class_name not in exported_names:
            return None
        # Module exists and advertises the class, so dependencies are missing.
        raise DatasetError(
            f"{exc}. Please see the documentation on how to "
            f"install relevant dependencies for {class_path}:\n"
            f"https://docs.kedro.org/en/stable/kedro_project_setup/"
            f"dependencies.html#install-dependencies-related-to-the-data-catalog"
        ) from exc
def _local_exists(local_filepath: str) -> bool: # SKIP_IF_NO_SPARK
filepath = Path(local_filepath)
return filepath.exists() or any(par.is_file() for par in filepath.parents)
class AbstractVersionedDataset(AbstractDataset[_DI, _DO], abc.ABC):
    """
    ``AbstractVersionedDataset`` is the base class for all versioned data set
    implementations. All data sets that implement versioning should extend this
    abstract class and implement the methods marked as abstract.

    Example:
    ::

        >>> from pathlib import Path, PurePosixPath
        >>> import pandas as pd
        >>> from kedro.io import AbstractVersionedDataset
        >>>
        >>>
        >>> class MyOwnDataset(AbstractVersionedDataset):
        >>>     def __init__(self, filepath, version, param1, param2=True):
        >>>         super().__init__(PurePosixPath(filepath), version)
        >>>         self._param1 = param1
        >>>         self._param2 = param2
        >>>
        >>>     def _load(self) -> pd.DataFrame:
        >>>         load_path = self._get_load_path()
        >>>         return pd.read_csv(load_path)
        >>>
        >>>     def _save(self, df: pd.DataFrame) -> None:
        >>>         save_path = self._get_save_path()
        >>>         df.to_csv(str(save_path))
        >>>
        >>>     def _exists(self) -> bool:
        >>>         path = self._get_load_path()
        >>>         return Path(path.as_posix()).exists()
        >>>
        >>>     def _describe(self):
        >>>         return dict(version=self._version, param1=self._param1, param2=self._param2)

    Example catalog.yml specification:
    ::

        my_dataset:
            type: <path-to-my-own-dataset>.MyOwnDataset
            filepath: data/01_raw/my_data.csv
            versioned: true
            param1: <param1-value> # param1 is a required argument
            # param2 will be True by default
    """

    def __init__(
        self,
        filepath: PurePosixPath,
        version: Version | None,
        exists_function: Callable[[str], bool] | None = None,
        glob_function: Callable[[str], list[str]] | None = None,
    ):
        """Creates a new instance of ``AbstractVersionedDataset``.

        Args:
            filepath: Filepath in POSIX format to a file.
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, save version will be autogenerated.
            exists_function: Function that is used for determining whether
                a path exists in a filesystem.
            glob_function: Function that is used for finding all paths
                in a filesystem, which match a given pattern.
        """
        self._filepath = filepath
        self._version = version
        self._exists_function = exists_function or _local_exists
        self._glob_function = glob_function or iglob
        # 1 entry for load version, 1 for save version
        self._version_cache = Cache(maxsize=2)  # type: Cache

    # 'key' is set to prevent cache key overlapping for load and save:
    # https://cachetools.readthedocs.io/en/stable/#cachetools.cachedmethod
    @cachedmethod(cache=attrgetter("_version_cache"), key=partial(hashkey, "load"))
    def _fetch_latest_load_version(self) -> str:
        # When load version is unpinned, fetch the most recent existing
        # version from the given path.
        pattern = str(self._get_versioned_path("*"))
        try:
            # Version names are timestamps, so a reverse lexical sort puts the
            # newest first.
            version_paths = sorted(self._glob_function(pattern), reverse=True)
        except Exception as exc:
            message = (
                f"Did not find any versions for {self}. This could be "
                f"due to insufficient permission. Exception: {exc}"
            )
            raise VersionNotFoundError(message) from exc
        most_recent = next(
            (path for path in version_paths if self._exists_function(path)), None
        )
        if not most_recent:
            message = f"Did not find any versions for {self}"
            raise VersionNotFoundError(message)
        # Paths look like <filepath>/<version>/<name>; the version is the parent dir.
        return PurePath(most_recent).parent.name

    # 'key' is set to prevent cache key overlapping for load and save:
    # https://cachetools.readthedocs.io/en/stable/#cachetools.cachedmethod
    @cachedmethod(cache=attrgetter("_version_cache"), key=partial(hashkey, "save"))
    def _fetch_latest_save_version(self) -> str:
        """Generate and cache the current save version"""
        return generate_timestamp()

    def resolve_load_version(self) -> str | None:
        """Compute the version the dataset should be loaded with."""
        if not self._version:
            return None
        if self._version.load:
            return self._version.load  # type: ignore[no-any-return]
        return self._fetch_latest_load_version()

    def _get_load_path(self) -> PurePosixPath:
        if not self._version:
            # When versioning is disabled, load from original filepath
            return self._filepath

        load_version = self.resolve_load_version()
        return self._get_versioned_path(load_version)  # type: ignore[arg-type]

    def resolve_save_version(self) -> str | None:
        """Compute the version the dataset should be saved with."""
        if not self._version:
            return None
        if self._version.save:
            return self._version.save  # type: ignore[no-any-return]
        return self._fetch_latest_save_version()

    def _get_save_path(self) -> PurePosixPath:
        if not self._version:
            # When versioning is disabled, return original filepath
            return self._filepath

        save_version = self.resolve_save_version()
        versioned_path = self._get_versioned_path(save_version)  # type: ignore[arg-type]

        # Versions are write-once: never overwrite an existing versioned file.
        if self._exists_function(str(versioned_path)):
            raise DatasetError(
                f"Save path '{versioned_path}' for {str(self)} must not exist if "
                f"versioning is enabled."
            )

        return versioned_path

    def _get_versioned_path(self, version: str) -> PurePosixPath:
        return self._filepath / version / self._filepath.name

    def load(self) -> _DO:
        """Loads data by delegating to ``AbstractDataset.load``."""
        return super().load()

    def save(self, data: _DI) -> None:
        """Saves data under a freshly resolved save version.

        Emits a warning if, after saving, the resolved load version differs
        from the save version.

        Raises:
            DatasetError: when saving fails, including when a non-versioned
                file with the same name blocks the versioned directory.
        """
        # Drop cached versions so this save gets a fresh save version and the
        # following load-version lookup sees the newly written data.
        self._version_cache.clear()
        save_version = self.resolve_save_version()  # Make sure last save version is set
        try:
            super().save(data)
        except (FileNotFoundError, NotADirectoryError) as err:
            # FileNotFoundError raised in Win, NotADirectoryError raised in Unix
            _default_version = "YYYY-MM-DDThh.mm.ss.sssZ"
            raise DatasetError(
                f"Cannot save versioned dataset '{self._filepath.name}' to "
                f"'{self._filepath.parent.as_posix()}' because a file with the same "
                f"name already exists in the directory. This is likely because "
                f"versioning was enabled on a dataset already saved previously. Either "
                f"remove '{self._filepath.name}' from the directory or manually "
                f"convert it into a versioned dataset by placing it in a versioned "
                f"directory (e.g. with default versioning format "
                f"'{self._filepath.as_posix()}/{_default_version}/{self._filepath.name}"
                f"')."
            ) from err

        load_version = self.resolve_load_version()
        if load_version != save_version:
            warnings.warn(
                _CONSISTENCY_WARNING.format(save_version, load_version, str(self))
            )

    def exists(self) -> bool:
        """Checks whether a data set's output already exists by calling
        the provided _exists() method.

        Returns:
            Flag indicating whether the output already exists.

        Raises:
            DatasetError: when underlying exists method raises error.
        """
        self._logger.debug("Checking whether target of %s exists", str(self))
        try:
            return self._exists()
        except VersionNotFoundError:
            # No version on disk yet simply means "does not exist".
            return False
        except Exception as exc:  # SKIP_IF_NO_SPARK
            message = (
                f"Failed during exists check for data set {str(self)}.\n{str(exc)}"
            )
            raise DatasetError(message) from exc

    def _release(self) -> None:
        super()._release()
        self._version_cache.clear()
def _parse_filepath(filepath: str) -> dict[str, str]:
    """Split a filepath into its protocol and path.

    Based on `fsspec.utils.infer_storage_options`.

    Args:
        filepath: Either local absolute file path or URL (s3://bucket/file.csv)

    Returns:
        Parsed filepath as a dict with ``protocol`` and ``path`` keys.
    """
    windows_drive = re.match(r"^[a-zA-Z]:[\\/]", filepath)
    has_scheme = re.match(r"^[a-zA-Z0-9]+://", filepath) is not None
    # Windows drive paths and scheme-less strings are local files.
    if windows_drive or not has_scheme:
        return {"protocol": "file", "path": filepath}

    parts = urlsplit(filepath)
    protocol = parts.scheme or "file"

    # HTTP(S) URLs are kept whole.
    if protocol in HTTP_PROTOCOLS:
        return {"protocol": protocol, "path": filepath}

    path = parts.path
    if protocol == "file":
        # file:///C:/... or file:///C|/... -> C:/...
        drive_match = re.match(r"^/([a-zA-Z])[:|]([\\/].*)$", path)
        if drive_match:
            path = ":".join(drive_match.groups())

    if parts.netloc and protocol in CLOUD_PROTOCOLS:
        # Drop any credentials and port, keeping the bare host (bucket name).
        host = parts.netloc.rsplit("@", 1)[-1].rsplit(":", 1)[0]
        path = host + path
        # Azure Data Lake Storage Gen2 URIs can store the container name in the
        # 'username' field of a URL (@ syntax), so we need to add it to the path
        if protocol == "abfss" and parts.username:
            path = parts.username + "@" + path

    return {"protocol": protocol, "path": path}
def get_protocol_and_path(
    filepath: str | os.PathLike, version: Version | None = None
) -> tuple[str, str]:
    """Split a filepath into its protocol and path.

    .. warning::
        Versioning is not supported for HTTP protocols.

    Args:
        filepath: raw filepath e.g.: ``gcs://bucket/test.json``.
        version: instance of ``kedro.io.core.Version`` or None.

    Returns:
        Protocol and path.

    Raises:
        DatasetError: when protocol is http(s) and version is not None.
    """
    parsed = _parse_filepath(str(filepath))
    protocol, path = parsed["protocol"], parsed["path"]

    if protocol not in HTTP_PROTOCOLS:
        return protocol, path

    if version is not None:
        raise DatasetError(
            "Versioning is not supported for HTTP protocols. "
            "Please remove the `versioned` flag from the dataset configuration."
        )
    # HTTP paths still carry their scheme; strip everything up to "://".
    return protocol, path.split(PROTOCOL_DELIMITER, 1)[-1]
def get_filepath_str(raw_path: PurePath, protocol: str) -> str:
    """Render a path as a POSIX string, re-attaching the protocol for HTTP(s).

    Args:
        raw_path: filepath without protocol.
        protocol: protocol.

    Returns:
        Filepath string.
    """
    posix_path = raw_path.as_posix()
    if protocol not in HTTP_PROTOCOLS:
        return posix_path
    return f"{protocol}{PROTOCOL_DELIMITER}{posix_path}"
def validate_on_forbidden_chars(**kwargs: Any) -> None:
    """Validate that string values contain neither whitespace nor ``;``.

    Any whitespace character recognised by ``str.isspace`` (space, tab,
    newline, ...) is rejected, not only the ASCII space.

    Args:
        **kwargs: Names mapped to the string values to check; the name is
            used in the error message.

    Raises:
        DatasetError: If any value contains whitespace or a semicolon.
    """
    for key, value in kwargs.items():
        if ";" in value or any(char.isspace() for char in value):
            raise DatasetError(
                f"Neither white-space nor semicolon are allowed in '{key}'."
            )