kedro_datasets_experimental.darts.DartsTorchModelDataset

class kedro_datasets_experimental.darts.DartsTorchModelDataset(*, filepath, model_class, load_args=None, save_args=None, version=None, credentials=None, fs_args=None, metadata=None)[source]

DartsTorchModelDataset loads and saves Darts TorchForecastingModel instances. The underlying functionality is supported by, and passes arguments through to, the Darts library’s model load and save methods.

Example usage for the YAML API (https://kedro.readthedocs.io/en/stable/data/data_catalog_yaml.html):

darts_model:
  type: kedro_datasets_experimental.darts.DartsTorchModelDataset
  filepath: data/06_models/darts_model.pt
  model_class: RNNModel
  load_args:
    load_method: load
  save_args:
    save_model: true
  versioned: true

Example usage for the Python API (https://kedro.readthedocs.io/en/stable/data/data_catalog_api.html):

from kedro_datasets_experimental.darts import DartsTorchModelDataset
from darts.models import RNNModel

# Initialize the dataset
dataset = DartsTorchModelDataset(
    filepath="data/06_models/darts_model.pt",
    model_class=RNNModel
)

# Create an RNNModel instance to persist
model = RNNModel(input_chunk_length=12, output_chunk_length=6)

# Save the model
dataset.save(model)

# Load the model
loaded_model = dataset.load()

Attributes

DEFAULT_LOAD_ARGS

DEFAULT_SAVE_ARGS

Methods

exists()

Checks whether a dataset's output already exists by calling the provided _exists() method.

from_config(name, config[, load_version, ...])

Create a dataset instance using the configuration provided.

load()

Loads a TorchForecastingModel using the specified load method.

release()

Release any cached data.

resolve_load_version()

Compute the version the dataset should be loaded with.

resolve_save_version()

Compute the version the dataset should be saved with.

save(data)

Saves a TorchForecastingModel using the specified save method.

to_config()

Converts the dataset instance into a dictionary-based configuration for serialization.

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {}
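
Kedro datasets commonly combine class-level defaults like these with the arguments passed by the user, so that user values take precedence. The sketch below illustrates that pattern with plain dictionaries. `merge_args` is a hypothetical helper, and the merge shown is an assumption about the usual Kedro convention, not a confirmed excerpt of this class's implementation.

```python
from typing import Any, Optional

# Stand-ins for the class attributes documented above (both empty by default).
DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {}


def merge_args(
    defaults: dict[str, Any], user_args: Optional[dict[str, Any]]
) -> dict[str, Any]:
    """Return a new dict in which user-supplied values override the defaults."""
    return {**defaults, **(user_args or {})}


# "load_method" is taken from the YAML example; None means "no user arguments".
load_args = merge_args(DEFAULT_LOAD_ARGS, {"load_method": "load"})
save_args = merge_args(DEFAULT_SAVE_ARGS, None)
```

Because the merge builds a new dictionary, the class-level defaults are never mutated by a single dataset instance.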
__init__(*, filepath, model_class, load_args=None, save_args=None, version=None, credentials=None, fs_args=None, metadata=None)[source]

Creates a new instance of DartsTorchModelDataset.

Parameters:
  • filepath (str) – Filepath in POSIX format to a model file or directory, optionally prefixed with a protocol such as s3://. If no prefix is provided, the file protocol (local filesystem) is used. The prefix can be any protocol supported by fsspec. Note: http(s) does not support versioning.

  • model_class (str | type[TorchForecastingModel]) – The class of the model to load/save. Can be a string (name of the class in darts.models) or the class itself.

  • load_args (Optional[dict[str, Any]]) – Darts options for loading models. The available arguments depend on the load_method specified. Options that are not set keep their default values.

  • save_args (Optional[dict[str, Any]]) – Darts options for saving models. The available arguments depend on the save method. Options that are not set keep their default values.

  • version (Optional[Version]) – If specified, should be an instance of kedro.io.core.Version. If its load attribute is None, the latest version will be loaded. If its save attribute is None, the save version will be autogenerated.

  • credentials (Optional[dict[str, Any]]) – Credentials required to access the underlying filesystem.

  • fs_args (Optional[dict[str, Any]]) – Extra arguments to pass into the underlying filesystem class constructor (e.g., {"project": "my-project"} for GCSFileSystem).

  • metadata (Optional[dict[str, Any]]) – Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins.
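
The model_class parameter accepts either a class or the name of a class in darts.models. A minimal sketch of how such a value could be resolved is shown below. `resolve_class` is a hypothetical helper written for illustration, not the dataset's actual code, and the demonstration uses the standard-library `collections` module as a stand-in so that it runs without Darts installed.

```python
import importlib
from collections import OrderedDict
from typing import Union


def resolve_class(model_class: Union[str, type], module_name: str) -> type:
    """Return the class itself, or look it up by name in the given module."""
    if isinstance(model_class, type):
        # Already a class: nothing to resolve.
        return model_class
    module = importlib.import_module(module_name)
    resolved = getattr(module, model_class, None)
    if not isinstance(resolved, type):
        raise ValueError(f"{model_class!r} is not a class in {module_name!r}")
    return resolved


# Stand-in demonstration: both forms resolve to the same class object.
by_name = resolve_class("OrderedDict", "collections")
by_type = resolve_class(OrderedDict, "collections")
```

With Darts installed, the analogous call would be `resolve_class("RNNModel", "darts.models")`.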

exists()[source]

Checks whether a dataset’s output already exists by calling the provided _exists() method.

Return type:

bool

Returns:

Flag indicating whether the output already exists.

Raises:

DatasetError – When the underlying exists method raises an error.

classmethod from_config(name, config, load_version=None, save_version=None)[source]

Create a dataset instance using the configuration provided.

Parameters:
  • name (str) – Dataset name.

  • config (dict[str, Any]) – Dataset config dictionary.

  • load_version (Optional[str]) – Version string to be used for load operation if the dataset is versioned. Has no effect on the dataset if versioning was not enabled.

  • save_version (Optional[str]) – Version string to be used for save operation if the dataset is versioned. Has no effect on the dataset if versioning was not enabled.

Return type:

AbstractDataset

Returns:

An instance of an AbstractDataset subclass.

Raises:

DatasetError – When the function fails to create the dataset from its config.

load()[source]

Loads a TorchForecastingModel using the specified load method.

Return type:

TorchForecastingModel

Returns:

An instance of TorchForecastingModel.

release()[source]

Release any cached data.

Raises:

DatasetError – When the underlying release method raises an error.

Return type:

None

resolve_load_version()[source]

Compute the version the dataset should be loaded with.

Return type:

Optional[str]

resolve_save_version()[source]

Compute the version the dataset should be saved with.

Return type:

Optional[str]
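
The two resolve methods follow the rules given for the version parameter: a load value of None means "use the latest version", and a save value of None means "generate one". The sketch below models that behaviour with the standard library only. The `Version` tuple mirrors the load and save fields of kedro.io.core.Version, while the helper functions, the list of available versions, and the timestamp format are assumptions made for illustration.

```python
from collections import namedtuple
from datetime import datetime, timezone
from typing import Optional

# Mirrors the two fields of kedro.io.core.Version described in the parameters.
Version = namedtuple("Version", ["load", "save"])


def generate_timestamp() -> str:
    """Build a sortable UTC timestamp (assumed format, millisecond precision)."""
    now = datetime.now(tz=timezone.utc)
    return now.strftime("%Y-%m-%dT%H.%M.%S.%f")[:-3] + "Z"


def resolve_load_version(
    version: Optional[Version], available: list[str]
) -> Optional[str]:
    """Return None if unversioned, the pinned version, or the latest available."""
    if version is None:
        return None
    if version.load is not None:
        return version.load
    if not available:
        raise LookupError("no saved versions to load")
    # Timestamps in this format sort lexicographically in chronological order.
    return max(available)


def resolve_save_version(version: Optional[Version]) -> Optional[str]:
    """Return None if unversioned, the pinned version, or a fresh timestamp."""
    if version is None:
        return None
    return version.save if version.save is not None else generate_timestamp()
```

Both functions return None for an unversioned dataset, which matches the Optional[str] return type documented above.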

save(data)[source]

Saves a TorchForecastingModel using the specified save method.

Parameters:

data (TorchForecastingModel) – The TorchForecastingModel instance to save.

Return type:

None

to_config()[source]

Converts the dataset instance into a dictionary-based configuration for serialization. Subclass-specific details are handled, with additional versioning and caching logic applied for CachedDataset.

The method adds a key for the dataset’s type, built from its module and class name, and includes the initialization arguments.

For CachedDataset, it extracts the underlying dataset’s configuration, handles the versioned flag, and removes unnecessary metadata. It also ensures the embedded dataset’s configuration is appropriately flattened or transformed.

If the dataset has a version key, it sets the versioned flag in the configuration.

Removes the metadata key from the configuration if present.

Return type:

dict[str, Any]

Returns:

A dictionary containing the dataset’s type and initialization arguments.
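
The transformation described for to_config() can be approximated with a few dictionary operations. The function below is a simplified sketch under stated assumptions: `to_config_sketch` is a hypothetical name, the key names "type", "version", "versioned" and "metadata" are inferred from the description above, and the CachedDataset handling is omitted.

```python
from typing import Any


def to_config_sketch(dataset_type: str, init_args: dict[str, Any]) -> dict[str, Any]:
    """Build a serializable config from a type path and initialization arguments."""
    config: dict[str, Any] = {"type": dataset_type, **init_args}
    # A version entry is replaced by a boolean "versioned" flag.
    version = config.pop("version", None)
    if version:
        config["versioned"] = True
    # Metadata is not part of the serialized configuration.
    config.pop("metadata", None)
    return config


config = to_config_sketch(
    "kedro_datasets_experimental.darts.DartsTorchModelDataset",
    {
        "filepath": "data/06_models/darts_model.pt",
        "model_class": "RNNModel",
        "version": ("2024-01-01T00.00.00.000Z", None),
        "metadata": {"owner": "forecasting-team"},  # illustrative value
    },
)
```

The resulting dictionary has the same shape as the YAML catalog entry shown earlier, which is what makes the configuration round-trippable through from_config().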