"""
This module contains ``CachedDataset``, a dataset wrapper which caches saved data in memory,
so that the user avoids I/O operations with slow storage media.
"""
from __future__ import annotations
import logging
import warnings
from typing import Any
from kedro import KedroDeprecationWarning
from kedro.io.core import VERSIONED_FLAG_KEY, AbstractDataset, Version
from kedro.io.memory_dataset import MemoryDataset
# Declares the deprecated ``CachedDataSet`` name so linters know it exists (see
# the pylint issue linked below). It is a bare annotation, so no value is bound
# at runtime. Actual access to ``CachedDataSet`` is served lazily by the
# module-level ``__getattr__`` at the bottom of this file, which returns
# ``CachedDataset`` and emits a ``KedroDeprecationWarning``. The forward
# reference to ``CachedDataset`` is fine because
# ``from __future__ import annotations`` keeps annotations unevaluated.
# https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901
CachedDataSet: type[CachedDataset]
class CachedDataset(AbstractDataset):
    """``CachedDataset`` is a dataset wrapper which caches saved data in memory,
    so that the user avoids I/O operations with slow storage media.

    You can also specify a ``CachedDataset`` in catalog.yml:

    .. code-block:: yaml

        test_ds:
          type: CachedDataset
          versioned: true
          dataset:
            type: pandas.CSVDataset
            filepath: example.csv

    Please note that if your dataset is versioned, this should be indicated in
    the wrapper class as shown above, not in the wrapped dataset.
    """

    # this dataset cannot be used with ``ParallelRunner``,
    # therefore it has the attribute ``_SINGLE_PROCESS = True``
    # for parallelism please consider ``ThreadRunner`` instead
    _SINGLE_PROCESS = True

    def __init__(
        self,
        dataset: AbstractDataset | dict,
        version: Version | None = None,
        copy_mode: str | None = None,
        metadata: dict[str, Any] | None = None,
    ):
        """Creates a new instance of ``CachedDataset`` wrapping the given dataset.

        Args:
            dataset: A Kedro Dataset object, or a dictionary (its dict/YAML
                representation) describing the dataset to wrap and cache.
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, save version will be autogenerated.
                Only used when ``dataset`` is given as a dict.
            copy_mode: The copy mode used to copy the data. Possible
                values are: "deepcopy", "copy" and "assign". If not
                provided, it is inferred based on the data type.
            metadata: Any arbitrary metadata.
                This is ignored by Kedro, but may be consumed by users or external plugins.

        Raises:
            ValueError: If the provided dataset is not a valid dict/YAML
                representation of a dataset or an actual dataset, or if the
                dict itself carries the versioned flag (``VERSIONED_FLAG_KEY``),
                which must be set on the wrapper instead.
        """
        if isinstance(dataset, dict):
            self._dataset = self._from_config(dataset, version)
        elif isinstance(dataset, AbstractDataset):
            self._dataset = dataset
        else:
            raise ValueError(
                "The argument type of 'dataset' should be either a dict/YAML "
                "representation of the dataset, or the actual dataset object."
            )
        # In-memory layer that sits in front of the wrapped dataset.
        self._cache = MemoryDataset(copy_mode=copy_mode)
        self.metadata = metadata

    def _release(self) -> None:
        """Release both the in-memory cache and the wrapped dataset."""
        self._cache.release()
        self._dataset.release()

    @staticmethod
    def _from_config(config: dict, version: Version | None) -> AbstractDataset:
        """Build the wrapped dataset from its dict/YAML representation.

        Args:
            config: The wrapped dataset's configuration. It is not modified.
            version: Optional ``Version``. When given, the wrapped dataset is
                built as versioned, using ``version.load``/``version.save``.

        Returns:
            The dataset instance created by ``AbstractDataset.from_config``.

        Raises:
            ValueError: If ``config`` itself declares the versioned flag.
        """
        if VERSIONED_FLAG_KEY in config:
            raise ValueError(
                "Cached datasets should specify that they are versioned in the "
                "'CachedDataset', not in the wrapped dataset."
            )
        if not version:
            return AbstractDataset.from_config("_cached", config)
        # Build a new dict instead of writing the flag into the caller's one.
        # Mutating it would make a reused config trip the check above on the
        # next versioned construction.
        versioned_config = {**config, VERSIONED_FLAG_KEY: True}
        return AbstractDataset.from_config(
            "_cached", versioned_config, version.load, version.save
        )

    def _describe(self) -> dict[str, Any]:
        """Describe the wrapped dataset and the cache side by side."""
        return {
            "dataset": self._dataset._describe(),  # noqa: protected-access
            "cache": self._cache._describe(),  # noqa: protected-access
        }

    def _load(self) -> Any:
        """Return cached data if present.

        Otherwise load from the wrapped dataset and populate the cache.
        """
        if self._cache.exists():
            return self._cache.load()
        data = self._dataset.load()
        self._cache.save(data)
        return data

    def _save(self, data: Any) -> None:
        """Persist ``data`` to the wrapped dataset, then store it in the cache."""
        # Saving to the wrapped dataset first means the cache is only updated
        # if the persistent save did not raise.
        self._dataset.save(data)
        self._cache.save(data)

    def _exists(self) -> bool:
        """Check the cache first and fall back to the wrapped dataset.

        Short-circuiting on the cache avoids querying slow storage when the
        data is already held in memory.
        """
        return self._cache.exists() or self._dataset.exists()

    def __getstate__(self):
        """Release the cache before pickling and return the instance ``__dict__``.

        The release happens on the live instance, so pickling also clears its
        in-memory cache as a side effect.
        """
        # clearing the cache can be prevented by modifying
        # how parallel runner handles datasets (not trivial!)
        logging.getLogger(__name__).warning("%s: clearing cache to pickle.", str(self))
        self._cache.release()
        return self.__dict__
def __getattr__(name):
    """Module-level attribute hook (PEP 562) serving deprecated aliases.

    Python calls this only when normal lookup on the module fails. The one
    recognised name is ``CachedDataSet``: it resolves to ``CachedDataset`` and
    emits a ``KedroDeprecationWarning`` attributed to the caller
    (``stacklevel=2``). Any other name raises ``AttributeError``, mirroring the
    interpreter's default behaviour.
    """
    if name != "CachedDataSet":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    renamed_to = CachedDataset
    warnings.warn(
        f"{name!r} has been renamed to {renamed_to.__name__!r}, "
        f"and the alias will be removed in Kedro 0.19.0",
        KedroDeprecationWarning,
        stacklevel=2,
    )
    return renamed_to