"""``DataCatalog`` stores instances of ``AbstractDataset`` implementations to
provide ``load`` and ``save`` capabilities from anywhere in the program. To
use a ``DataCatalog``, you need to instantiate it with a dictionary of data
sets. Then it will act as a single point of reference for your calls,
relaying load and save functions to the underlying data sets.
"""
from __future__ import annotations
import copy
import difflib
import logging
import re
from typing import Any, Dict
from parse import parse
from kedro.io.core import (
AbstractDataset,
AbstractVersionedDataset,
DatasetAlreadyExistsError,
DatasetError,
DatasetNotFoundError,
Version,
generate_timestamp,
)
from kedro.io.memory_dataset import MemoryDataset
Patterns = Dict[str, Dict[str, Any]]
CATALOG_KEY = "catalog"
CREDENTIALS_KEY = "credentials"
WORDS_REGEX_PATTERN = re.compile(r"\W+")
def _get_credentials(credentials_name: str, credentials: dict[str, Any]) -> Any:
"""Return a set of credentials from the provided credentials dict.
Args:
credentials_name: Credentials name.
credentials: A dictionary with all credentials.
Returns:
The set of requested credentials.
    Raises:
        KeyError: When a credentials set with the given name cannot be
            found in the provided ``credentials`` dictionary.
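
    Example:
    ::

        >>> # A minimal sketch (illustrative names): fetch the "dev_s3" entry
        >>> _get_credentials("dev_s3", {"dev_s3": {"key": "<id>"}})
        {'key': '<id>'}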
"""
try:
return credentials[credentials_name]
except KeyError as exc:
raise KeyError(
f"Unable to find credentials '{credentials_name}': check your data "
"catalog and credentials configuration. See "
"https://kedro.readthedocs.io/en/stable/kedro.io.DataCatalog.html "
"for an example."
) from exc
def _resolve_credentials(
config: dict[str, Any], credentials: dict[str, Any]
) -> dict[str, Any]:
"""Return the dataset configuration where credentials are resolved using
credentials dictionary provided.
Args:
config: Original dataset config, which may contain unresolved credentials.
credentials: A dictionary with all credentials.
Returns:
The dataset config, where all the credentials are successfully resolved.
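
    Example:
    ::

        >>> # Hedged sketch (illustrative names): a string under the
        >>> # "credentials" key is swapped for the matching credentials entry
        >>> _resolve_credentials(
        >>>     {"type": "pandas.CSVDataset", "credentials": "dev_s3"},
        >>>     {"dev_s3": {"key": "<id>"}},
        >>> )
        {'type': 'pandas.CSVDataset', 'credentials': {'key': '<id>'}}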
"""
config = copy.deepcopy(config)
def _map_value(key: str, value: Any) -> Any:
if key == CREDENTIALS_KEY and isinstance(value, str):
return _get_credentials(value, credentials)
if isinstance(value, dict):
return {k: _map_value(k, v) for k, v in value.items()}
return value
return {k: _map_value(k, v) for k, v in config.items()}
def _sub_nonword_chars(dataset_name: str) -> str:
"""Replace non-word characters in data set names since Kedro 0.16.2.
Args:
dataset_name: The data set name registered in the data catalog.
Returns:
The name used in `DataCatalog.datasets`.
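
    Example:
    ::

        >>> # e.g. a transcoded dataset name (illustrative); runs of
        >>> # non-word characters become a double underscore
        >>> _sub_nonword_chars("cars@spark")
        'cars__spark'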
"""
return re.sub(WORDS_REGEX_PATTERN, "__", dataset_name)
class _FrozenDatasets:
"""Helper class to access underlying loaded datasets."""
def __init__(
self,
*datasets_collections: _FrozenDatasets | dict[str, AbstractDataset],
):
"""Return a _FrozenDatasets instance from some datasets collections.
Each collection could either be another _FrozenDatasets or a dictionary.
"""
for collection in datasets_collections:
if isinstance(collection, _FrozenDatasets):
self.__dict__.update(collection.__dict__)
else:
# Non-word characters in dataset names are replaced with `__`
# for easy access to transcoded/prefixed datasets.
self.__dict__.update(
{
_sub_nonword_chars(dataset_name): dataset
for dataset_name, dataset in collection.items()
}
)
# Don't allow users to add/change attributes on the fly
def __setattr__(self, key: str, value: Any) -> None:
msg = "Operation not allowed! "
if key in self.__dict__:
msg += "Please change datasets through configuration."
else:
msg += "Please use DataCatalog.add() instead."
raise AttributeError(msg)
class DataCatalog:
"""``DataCatalog`` stores instances of ``AbstractDataset`` implementations
to provide ``load`` and ``save`` capabilities from anywhere in the
program. To use a ``DataCatalog``, you need to instantiate it with
a dictionary of data sets. Then it will act as a single point of reference
for your calls, relaying load and save functions
to the underlying data sets.
"""
def __init__( # noqa: PLR0913
self,
datasets: dict[str, AbstractDataset] | None = None,
feed_dict: dict[str, Any] | None = None,
dataset_patterns: Patterns | None = None,
load_versions: dict[str, str] | None = None,
save_version: str | None = None,
) -> None:
"""``DataCatalog`` stores instances of ``AbstractDataset``
implementations to provide ``load`` and ``save`` capabilities from
anywhere in the program. To use a ``DataCatalog``, you need to
instantiate it with a dictionary of data sets. Then it will act as a
single point of reference for your calls, relaying load and save
functions to the underlying data sets.
Args:
datasets: A dictionary of data set names and data set instances.
feed_dict: A feed dict with data to be added in memory.
dataset_patterns: A dictionary of data set factory patterns
and corresponding data set configuration. When fetched from catalog configuration
these patterns will be sorted by:
1. Decreasing specificity (number of characters outside the curly brackets)
2. Decreasing number of placeholders (number of curly bracket pairs)
3. Alphabetically
A pattern of specificity 0 is a catch-all pattern and will overwrite the default
pattern provided through the runners if it comes before "default" in the alphabet.
Such an overwriting pattern will emit a warning. The `"{default}"` name will
not emit a warning.
load_versions: A mapping between data set names and versions
to load. Has no effect on data sets without enabled versioning.
save_version: Version string to be used for ``save`` operations
by all data sets with enabled versioning. It must: a) be a
case-insensitive string that conforms with operating system
filename limitations, b) always return the latest version when
sorted in lexicographical order.
Example:
::
>>> from kedro_datasets.pandas import CSVDataset
>>>
>>> cars = CSVDataset(filepath="cars.csv",
>>> load_args=None,
>>> save_args={"index": False})
>>> io = DataCatalog(datasets={'cars': cars})
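
        A hedged sketch of ``dataset_patterns`` (the pattern and config
        shown are illustrative)::

            >>> io = DataCatalog(
            >>>     datasets={'cars': cars},
            >>>     dataset_patterns={
            >>>         "{name}_csv": {
            >>>             "type": "pandas.CSVDataset",
            >>>             "filepath": "data/{name}.csv",
            >>>         }
            >>>     },
            >>> )
            >>> # "boats_csv" is not registered but matches the pattern above
            >>> df = io.load("boats_csv")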
"""
self._datasets = dict(datasets or {})
self.datasets = _FrozenDatasets(self._datasets)
# Keep a record of all patterns in the catalog.
# {dataset pattern name : dataset pattern body}
self._dataset_patterns = dataset_patterns or {}
self._load_versions = load_versions or {}
self._save_version = save_version
if feed_dict:
self.add_feed_dict(feed_dict)
@property
def _logger(self) -> logging.Logger:
return logging.getLogger(__name__)
@classmethod
def from_config(
cls,
catalog: dict[str, dict[str, Any]] | None,
credentials: dict[str, dict[str, Any]] | None = None,
load_versions: dict[str, str] | None = None,
save_version: str | None = None,
) -> DataCatalog:
"""Create a ``DataCatalog`` instance from configuration. This is a
factory method used to provide developers with a way to instantiate
``DataCatalog`` with configuration parsed from configuration files.
Args:
catalog: A dictionary whose keys are the data set names and
the values are dictionaries with the constructor arguments
for classes implementing ``AbstractDataset``. The data set
                class to be loaded is specified with the key ``type`` and its
                fully qualified class name. All ``kedro.io`` data sets can be
                specified by their class name only, i.e. their module name
                can be omitted.
credentials: A dictionary containing credentials for different
                data sets. Use the ``credentials`` key in an ``AbstractDataset``
to refer to the appropriate credentials as shown in the example
below.
load_versions: A mapping between dataset names and versions
to load. Has no effect on data sets without enabled versioning.
save_version: Version string to be used for ``save`` operations
by all data sets with enabled versioning. It must: a) be a
case-insensitive string that conforms with operating system
filename limitations, b) always return the latest version when
sorted in lexicographical order.
Returns:
An instantiated ``DataCatalog`` containing all specified
data sets, created and ready to use.
Raises:
DatasetError: When the method fails to create any of the data
sets from their config.
DatasetNotFoundError: When `load_versions` refers to a dataset that doesn't
exist in the catalog.
Example:
::
>>> config = {
>>> "cars": {
>>> "type": "pandas.CSVDataset",
>>> "filepath": "cars.csv",
>>> "save_args": {
>>> "index": False
>>> }
>>> },
>>> "boats": {
>>> "type": "pandas.CSVDataset",
>>> "filepath": "s3://aws-bucket-name/boats.csv",
>>> "credentials": "boats_credentials",
>>> "save_args": {
>>> "index": False
>>> }
>>> }
>>> }
>>>
>>> credentials = {
>>> "boats_credentials": {
>>> "client_kwargs": {
>>> "aws_access_key_id": "<your key id>",
>>> "aws_secret_access_key": "<your secret>"
>>> }
>>> }
>>> }
>>>
>>> catalog = DataCatalog.from_config(config, credentials)
>>>
>>> df = catalog.load("cars")
>>> catalog.save("boats", df)
"""
datasets = {}
dataset_patterns = {}
catalog = copy.deepcopy(catalog) or {}
credentials = copy.deepcopy(credentials) or {}
save_version = save_version or generate_timestamp()
load_versions = copy.deepcopy(load_versions) or {}
for ds_name, ds_config in catalog.items():
ds_config = _resolve_credentials( # noqa: PLW2901
ds_config, credentials
)
if cls._is_pattern(ds_name):
# Add each factory to the dataset_patterns dict.
dataset_patterns[ds_name] = ds_config
else:
datasets[ds_name] = AbstractDataset.from_config(
ds_name, ds_config, load_versions.get(ds_name), save_version
)
sorted_patterns = cls._sort_patterns(dataset_patterns)
missing_keys = [
key
for key in load_versions.keys()
if not (key in catalog or cls._match_pattern(sorted_patterns, key))
]
if missing_keys:
raise DatasetNotFoundError(
f"'load_versions' keys [{', '.join(sorted(missing_keys))}] "
f"are not found in the catalog."
)
return cls(
datasets=datasets,
dataset_patterns=sorted_patterns,
load_versions=load_versions,
save_version=save_version,
)
@staticmethod
def _is_pattern(pattern: str) -> bool:
"""Check if a given string is a pattern. Assume that any name with '{' is a pattern."""
return "{" in pattern
@staticmethod
def _match_pattern(dataset_patterns: Patterns, dataset_name: str) -> str | None:
"""Match a dataset name against patterns in a dictionary."""
matches = (
pattern
for pattern in dataset_patterns.keys()
if parse(pattern, dataset_name)
)
return next(matches, None)
@classmethod
def _sort_patterns(cls, dataset_patterns: Patterns) -> dict[str, dict[str, Any]]:
"""Sort a dictionary of dataset patterns according to parsing rules.
In order:
1. Decreasing specificity (number of characters outside the curly brackets)
2. Decreasing number of placeholders (number of curly bracket pairs)
3. Alphabetically
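
        Example:
        ::

            >>> # A sketch with illustrative patterns, sorted per the rules above
            >>> DataCatalog._sort_patterns({
            >>>     "{default}": {},
            >>>     "{namespace}.companies": {},
            >>>     "france.companies": {},
            >>> })
            {'france.companies': {}, '{namespace}.companies': {}, '{default}': {}}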
"""
sorted_keys = sorted(
dataset_patterns,
key=lambda pattern: (
-(cls._specificity(pattern)),
-pattern.count("{"),
pattern,
),
)
return {key: dataset_patterns[key] for key in sorted_keys}
@staticmethod
def _specificity(pattern: str) -> int:
"""Helper function to check the length of exactly matched characters not inside brackets.
Example:
::
            >>> DataCatalog._specificity("{namespace}.companies")
            10
            >>> DataCatalog._specificity("{namespace}.{dataset}")
            1
            >>> DataCatalog._specificity("france.companies")
            16
"""
# Remove all the placeholders from the pattern and count the number of remaining chars
result = re.sub(r"\{.*?\}", "", pattern)
return len(result)
def _get_dataset(
self,
dataset_name: str,
version: Version | None = None,
suggest: bool = True,
) -> AbstractDataset:
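        """Return the dataset registered under ``dataset_name``, materialising
        it from a matching dataset factory pattern first if necessary.

        Args:
            dataset_name: The dataset name to look up.
            version: Optional version to attach to the returned dataset copy.
            suggest: Whether to include close-match suggestions in the
                error message when the dataset is not found.

        Raises:
            DatasetNotFoundError: When the name is neither registered nor
                matched by any pattern.
        """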
matched_pattern = self._match_pattern(self._dataset_patterns, dataset_name)
if dataset_name not in self._datasets and matched_pattern:
# If the dataset is a patterned dataset, materialise it and add it to
# the catalog
config_copy = copy.deepcopy(self._dataset_patterns[matched_pattern])
dataset_config = self._resolve_config(
dataset_name, matched_pattern, config_copy
)
dataset = AbstractDataset.from_config(
dataset_name,
dataset_config,
self._load_versions.get(dataset_name),
self._save_version,
)
if (
self._specificity(matched_pattern) == 0
and matched_pattern != "{default}"
):
self._logger.warning(
"Config from the dataset factory pattern '%s' in the catalog will be used to "
"override the default dataset creation for '%s'",
matched_pattern,
dataset_name,
)
self.add(dataset_name, dataset)
if dataset_name not in self._datasets:
error_msg = f"Dataset '{dataset_name}' not found in the catalog"
# Flag to turn on/off fuzzy-matching which can be time consuming and
# slow down plugins like `kedro-viz`
if suggest:
matches = difflib.get_close_matches(dataset_name, self._datasets.keys())
if matches:
suggestions = ", ".join(matches)
error_msg += f" - did you mean one of these instead: {suggestions}"
raise DatasetNotFoundError(error_msg)
dataset = self._datasets[dataset_name]
if version and isinstance(dataset, AbstractVersionedDataset):
# we only want to return a similar-looking dataset,
# not modify the one stored in the current catalog
dataset = dataset._copy(_version=version)
return dataset
def __contains__(self, dataset_name: str) -> bool:
"""Check if an item is in the catalog as a materialised dataset or pattern"""
matched_pattern = self._match_pattern(self._dataset_patterns, dataset_name)
if dataset_name in self._datasets or matched_pattern:
return True
return False
@classmethod
def _resolve_config(
cls,
dataset_name: str,
matched_pattern: str,
config: dict,
) -> dict[str, Any]:
"""Get resolved AbstractDataset from a factory config"""
result = parse(matched_pattern, dataset_name)
# Resolve the factory config for the dataset
if isinstance(config, dict):
for key, value in config.items():
config[key] = cls._resolve_config(dataset_name, matched_pattern, value)
elif isinstance(config, (list, tuple)):
config = [
cls._resolve_config(dataset_name, matched_pattern, value)
for value in config
]
elif isinstance(config, str) and "}" in config:
try:
config = str(config).format_map(result.named)
except KeyError as exc:
raise DatasetError(
f"Unable to resolve '{config}' from the pattern '{matched_pattern}'. Keys used in the configuration "
f"should be present in the dataset factory pattern."
) from exc
return config
def load(self, name: str, version: str | None = None) -> Any:
"""Loads a registered data set.
Args:
name: A data set to be loaded.
version: Optional argument for concrete data version to be loaded.
Works only with versioned datasets.
Returns:
The loaded data as configured.
Raises:
DatasetNotFoundError: When a data set with the given name
has not yet been registered.
Example:
::
>>> from kedro.io import DataCatalog
>>> from kedro_datasets.pandas import CSVDataset
>>>
>>> cars = CSVDataset(filepath="cars.csv",
>>> load_args=None,
>>> save_args={"index": False})
>>> io = DataCatalog(datasets={'cars': cars})
>>>
>>> df = io.load("cars")
"""
load_version = Version(version, None) if version else None
dataset = self._get_dataset(name, version=load_version)
self._logger.info(
"Loading data from [dark_orange]%s[/dark_orange] (%s)...",
name,
type(dataset).__name__,
extra={"markup": True},
)
result = dataset.load()
return result
def save(self, name: str, data: Any) -> None:
"""Save data to a registered data set.
Args:
name: A data set to be saved to.
data: A data object to be saved as configured in the registered
data set.
Raises:
DatasetNotFoundError: When a data set with the given name
has not yet been registered.
Example:
::
>>> import pandas as pd
>>>
>>> from kedro_datasets.pandas import CSVDataset
>>>
>>> cars = CSVDataset(filepath="cars.csv",
>>> load_args=None,
>>> save_args={"index": False})
>>> io = DataCatalog(datasets={'cars': cars})
>>>
>>> df = pd.DataFrame({'col1': [1, 2],
>>> 'col2': [4, 5],
>>> 'col3': [5, 6]})
>>> io.save("cars", df)
"""
dataset = self._get_dataset(name)
self._logger.info(
"Saving data to [dark_orange]%s[/dark_orange] (%s)...",
name,
type(dataset).__name__,
extra={"markup": True},
)
dataset.save(data)
def exists(self, name: str) -> bool:
"""Checks whether registered data set exists by calling its `exists()`
method. Raises a warning and returns False if `exists()` is not
implemented.
Args:
name: A data set to be checked.
Returns:
Whether the data set output exists.
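
        Example:
        ::

            >>> # assuming the catalog "io" from the class-level example
            >>> # and that cars.csv exists on disk
            >>> io.exists("cars")
            True
            >>> io.exists("not_registered")
            False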
"""
try:
dataset = self._get_dataset(name)
except DatasetNotFoundError:
return False
return dataset.exists()
def release(self, name: str) -> None:
"""Release any cached data associated with a data set
Args:
name: A data set to be checked.
Raises:
DatasetNotFoundError: When a data set with the given name
has not yet been registered.
"""
dataset = self._get_dataset(name)
dataset.release()
def add(
self, dataset_name: str, dataset: AbstractDataset, replace: bool = False
) -> None:
"""Adds a new ``AbstractDataset`` object to the ``DataCatalog``.
Args:
dataset_name: A unique data set name which has not been
registered yet.
dataset: A data set object to be associated with the given data
set name.
            replace: Specifies whether replacing an existing dataset
                with the same name is allowed.
Raises:
DatasetAlreadyExistsError: When a data set with the same name
has already been registered.
Example:
::
>>> from kedro_datasets.pandas import CSVDataset
>>>
>>> io = DataCatalog(datasets={
>>> 'cars': CSVDataset(filepath="cars.csv")
>>> })
>>>
>>> io.add("boats", CSVDataset(filepath="boats.csv"))
"""
if dataset_name in self._datasets:
if replace:
self._logger.warning("Replacing dataset '%s'", dataset_name)
else:
raise DatasetAlreadyExistsError(
f"Dataset '{dataset_name}' has already been registered"
)
self._datasets[dataset_name] = dataset
self.datasets = _FrozenDatasets(self.datasets, {dataset_name: dataset})
def add_all(
self, datasets: dict[str, AbstractDataset], replace: bool = False
) -> None:
"""Adds a group of new data sets to the ``DataCatalog``.
Args:
datasets: A dictionary of dataset names and dataset
instances.
            replace: Specifies whether replacing an existing dataset
                with the same name is allowed.
Raises:
DatasetAlreadyExistsError: When a data set with the same name
has already been registered.
Example:
::
>>> from kedro_datasets.pandas import CSVDataset, ParquetDataset
>>>
>>> io = DataCatalog(datasets={
>>> "cars": CSVDataset(filepath="cars.csv")
>>> })
>>> additional = {
>>> "planes": ParquetDataset("planes.parq"),
>>> "boats": CSVDataset(filepath="boats.csv")
>>> }
>>>
>>> io.add_all(additional)
>>>
>>> assert io.list() == ["cars", "planes", "boats"]
"""
for name, dataset in datasets.items():
self.add(name, dataset, replace)
def add_feed_dict(self, feed_dict: dict[str, Any], replace: bool = False) -> None:
"""Adds instances of ``MemoryDataset``, containing the data provided
through feed_dict.
Args:
feed_dict: A feed dict with data to be added in memory.
            replace: Specifies whether replacing an existing dataset
                with the same name is allowed.
Example:
::
>>> import pandas as pd
>>>
>>> df = pd.DataFrame({'col1': [1, 2],
>>> 'col2': [4, 5],
>>> 'col3': [5, 6]})
>>>
>>> io = DataCatalog()
>>> io.add_feed_dict({
>>> 'data': df
>>> }, replace=True)
>>>
>>> assert io.load("data").equals(df)
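            >>>
            >>> # entries that are already datasets are added as-is
            >>> # (hedged sketch using kedro.io.MemoryDataset)
            >>> from kedro.io import MemoryDataset
            >>> io.add_feed_dict({'data_again': MemoryDataset(data=df)}, replace=True)
            >>> assert io.load("data_again").equals(df)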
"""
for dataset_name in feed_dict:
if isinstance(feed_dict[dataset_name], AbstractDataset):
dataset = feed_dict[dataset_name]
else:
dataset = MemoryDataset(data=feed_dict[dataset_name])
self.add(dataset_name, dataset, replace)
def list(self, regex_search: str | None = None) -> list[str]:
"""
        Returns a list of all dataset names registered in the catalog.
This can be filtered by providing an optional regular expression
which will only return matching keys.
Args:
regex_search: An optional regular expression which can be provided
to limit the data sets returned by a particular pattern.
Returns:
A list of dataset names available which match the
`regex_search` criteria (if provided). All data set names are returned
by default.
Raises:
SyntaxError: When an invalid regex filter is provided.
Example:
::
>>> io = DataCatalog()
>>> # get data sets where the substring 'raw' is present
>>> raw_data = io.list(regex_search='raw')
>>> # get data sets which start with 'prm' or 'feat'
>>> feat_eng_data = io.list(regex_search='^(prm|feat)')
>>> # get data sets which end with 'time_series'
>>> models = io.list(regex_search='.+time_series$')
"""
if regex_search is None:
return list(self._datasets.keys())
if not regex_search.strip():
self._logger.warning("The empty string will not match any data sets")
return []
try:
pattern = re.compile(regex_search, flags=re.IGNORECASE)
except re.error as exc:
raise SyntaxError(
f"Invalid regular expression provided: '{regex_search}'"
) from exc
return [dset_name for dset_name in self._datasets if pattern.search(dset_name)]
def shallow_copy(
self, extra_dataset_patterns: Patterns | None = None
) -> DataCatalog:
"""Returns a shallow copy of the current object.
Returns:
Copy of the current object.
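
        Example:
        ::

            >>> # e.g. how a runner might inject a catch-all default pattern
            >>> # (illustrative config)
            >>> copied = catalog.shallow_copy(
            >>>     extra_dataset_patterns={"{default}": {"type": "MemoryDataset"}}
            >>> )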
"""
if extra_dataset_patterns:
unsorted_dataset_patterns = {
**self._dataset_patterns,
**extra_dataset_patterns,
}
dataset_patterns = self._sort_patterns(unsorted_dataset_patterns)
else:
dataset_patterns = self._dataset_patterns
return DataCatalog(
datasets=self._datasets,
dataset_patterns=dataset_patterns,
load_versions=self._load_versions,
save_version=self._save_version,
)
def __eq__(self, other) -> bool: # type: ignore[no-untyped-def]
return (self._datasets, self._dataset_patterns) == (
other._datasets,
other._dataset_patterns,
)
def confirm(self, name: str) -> None:
"""Confirm a dataset by its name.
Args:
name: Name of the dataset.
Raises:
            DatasetError: When the dataset does not have a `confirm` method.
"""
self._logger.info("Confirming dataset '%s'", name)
dataset = self._get_dataset(name)
if hasattr(dataset, "confirm"):
dataset.confirm()
else:
            raise DatasetError(f"Dataset '{name}' does not have a 'confirm' method")