kedro_datasets_experimental.darts.DartsTorchModelDataset
- class kedro_datasets_experimental.darts.DartsTorchModelDataset(*, filepath, model_class, load_args=None, save_args=None, version=None, credentials=None, fs_args=None, metadata=None)
DartsTorchModelDataset loads and saves Darts TorchForecastingModel instances. The underlying functionality is supported by, and passes arguments through to, the Darts library’s model load and save methods.
Example usage for the YAML API (https://kedro.readthedocs.io/en/stable/data/data_catalog_yaml.html):

    darts_model:
      type: kedro_datasets_experimental.darts.DartsTorchModelDataset
      filepath: data/06_models/darts_model.pt
      model_class: RNNModel
      load_args:
        load_method: load
      save_args:
        save_model: true
      versioned: true
Example usage for the Python API (https://kedro.readthedocs.io/en/stable/data/data_catalog_api.html):

    from darts.models import RNNModel
    from kedro_datasets_experimental.darts import DartsTorchModelDataset

    # Initialize the dataset
    dataset = DartsTorchModelDataset(
        filepath="data/06_models/darts_model.pt", model_class=RNNModel
    )

    # Create the model to persist (an RNNModel here)
    model = RNNModel(input_chunk_length=12, output_chunk_length=6)

    # Save the model
    dataset.save(model)

    # Load the model
    loaded_model = dataset.load()
Methods
- exists() – Checks whether a dataset's output already exists by calling the provided _exists() method.
- from_config(name, config[, load_version, ...]) – Create a dataset instance using the configuration provided.
- load() – Loads a TorchForecastingModel using the specified load method.
- release() – Release any cached data.
- Compute the version the dataset should be loaded with.
- Compute the version the dataset should be saved with.
- save(data) – Saves a TorchForecastingModel using the specified save method.
- to_config() – Converts the dataset instance into a dictionary-based configuration for serialization.
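The two version-computing methods listed above implement the rules described under the version parameter: an explicit load version is used as given, otherwise the latest saved version is loaded; an explicit save version is used as given, otherwise one is autogenerated. The following is a minimal stdlib sketch of those rules, not Kedro's code. The helper names (generate_timestamp, pick_load_version, pick_save_version, versioned_path), the timestamp format, and the <filepath>/<version>/<filename> layout are assumptions modelled on Kedro's conventions.

```python
from datetime import datetime, timezone
from pathlib import PurePosixPath
from typing import Iterable, Optional

# Assumed version format, modelled on Kedro-style version strings such as
# "2024-05-01T12.30.45.123Z"; sorting these strings sorts them chronologically.
VERSION_FORMAT = "%Y-%m-%dT%H.%M.%S.%fZ"


def generate_timestamp(now: Optional[datetime] = None) -> str:
    """Return a UTC timestamp with microseconds truncated to milliseconds."""
    if now is None:
        now = datetime.now(tz=timezone.utc)
    stamp = now.strftime(VERSION_FORMAT)  # e.g. "...45.123456Z"
    return stamp[:-4] + stamp[-1:]        # drop 3 digits, keep the trailing "Z"


def pick_load_version(requested: Optional[str], existing: Iterable[str]) -> str:
    """Use the requested load version, else the most recent saved one."""
    if requested is not None:
        return requested
    versions = sorted(existing)
    if not versions:
        raise FileNotFoundError("no saved versions to load")
    return versions[-1]


def pick_save_version(requested: Optional[str]) -> str:
    """Use the requested save version, else autogenerate one."""
    return requested if requested is not None else generate_timestamp()


def versioned_path(filepath: str, version: str) -> str:
    """Place each version in its own subdirectory: <filepath>/<version>/<filename>."""
    path = PurePosixPath(filepath)
    return str(path / version / path.name)


saved = ["2024-01-01T00.00.00.000Z", "2024-03-01T00.00.00.000Z"]
print(pick_load_version(None, saved))  # 2024-03-01T00.00.00.000Z
print(versioned_path("data/06_models/darts_model.pt", saved[-1]))
# data/06_models/darts_model.pt/2024-03-01T00.00.00.000Z/darts_model.pt
```

Because the assumed format is fixed-width and ordered from year down to milliseconds, a plain string sort is enough to find the latest version.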
- __init__(*, filepath, model_class, load_args=None, save_args=None, version=None, credentials=None, fs_args=None, metadata=None)
Creates a new instance of DartsTorchModelDataset.
- Parameters:
  - filepath (str) – Filepath in POSIX format to a model file or directory, optionally prefixed with a protocol like s3://. If no prefix is provided, the file protocol (local filesystem) is used. The prefix can be any protocol supported by fsspec. Note: http(s) doesn't support versioning.
  - model_class (str | type[TorchForecastingModel]) – The class of the model to load and save. Can be a string (the name of a class in darts.models) or the class itself.
  - load_args (Optional[dict[str, Any]]) – Darts options for loading models. Available arguments depend on the load_method specified in load_args. All defaults are preserved.
  - save_args (Optional[dict[str, Any]]) – Darts options for saving models. Available arguments depend on the save method. All defaults are preserved.
  - version (Optional[Version]) – If specified, should be an instance of kedro.io.core.Version. If its load attribute is None, the latest version is loaded. If its save attribute is None, the save version is autogenerated.
  - credentials (Optional[dict[str, Any]]) – Credentials required to access the underlying filesystem.
  - fs_args (Optional[dict[str, Any]]) – Extra arguments to pass into the underlying filesystem class constructor (e.g., {"project": "my-project"} for GCSFileSystem).
  - metadata (Optional[dict[str, Any]]) – Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins.
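The model_class parameter accepts either a class or the name of a class in darts.models. A minimal sketch of that resolution follows, assuming a plain attribute lookup on the module; resolve_model_class and the placeholder RNNModel are hypothetical, and a stand-in module replaces darts.models so the example runs without Darts installed.

```python
import types
from typing import Union


def resolve_model_class(model_class: Union[str, type], namespace: types.ModuleType) -> type:
    """Return model_class unchanged if it is a class, else look the name up in namespace."""
    if isinstance(model_class, type):
        return model_class
    resolved = getattr(namespace, model_class, None)
    if not isinstance(resolved, type):
        raise ValueError(f"{model_class!r} is not a class in {namespace.__name__}")
    return resolved


# Stand-in for darts.models so the sketch runs without Darts installed; with
# Darts available you would pass importlib.import_module("darts.models").
fake_models = types.ModuleType("darts.models")


class RNNModel:  # hypothetical placeholder, not the real Darts RNNModel
    pass


fake_models.RNNModel = RNNModel

print(resolve_model_class("RNNModel", fake_models) is RNNModel)  # True
print(resolve_model_class(RNNModel, fake_models) is RNNModel)    # True
```

Accepting a string keeps YAML catalog entries (model_class: RNNModel) free of Python imports, while passing the class directly suits the Python API.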
- exists()
Checks whether a dataset’s output already exists by calling the provided _exists() method.
- Return type:
bool
- Returns:
Flag indicating whether the output already exists.
- Raises:
DatasetError – When the underlying exists method raises an error.
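exists() delegates to the dataset's own _exists() check and converts any failure into DatasetError. The sketch below reproduces that wrapping pattern with a hypothetical SketchDataset and a local DatasetError stand-in; the real dataset presumably checks its path through the configured filesystem rather than calling os.path.exists.

```python
import os
import tempfile


class DatasetError(Exception):
    """Local stand-in for kedro.io.core.DatasetError."""


class SketchDataset:
    """Hypothetical dataset showing how exists() wraps a subclass's _exists()."""

    def __init__(self, filepath: str) -> None:
        self._filepath = filepath

    def _exists(self) -> bool:
        # Subclass-specific check; a plain local-filesystem test here.
        return os.path.exists(self._filepath)

    def exists(self) -> bool:
        try:
            return self._exists()
        except Exception as exc:
            # Any failure in the underlying check surfaces as DatasetError.
            raise DatasetError(f"Failed during exists check for {self._filepath}") from exc


with tempfile.TemporaryDirectory() as tmp:
    print(SketchDataset(tmp).exists())                                  # True
    print(SketchDataset(os.path.join(tmp, "darts_model.pt")).exists())  # False
```

Keeping the error translation in exists() means subclasses only implement the storage-specific check, and callers handle a single exception type.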
- classmethod from_config(name, config, load_version=None, save_version=None)
Create a dataset instance using the configuration provided.
- Parameters:
  - name (str) – Dataset name.
  - config (dict[str, Any]) – Dataset configuration dictionary.
  - load_version (Optional[str]) – Version string to be used for the load operation if the dataset is versioned. Has no effect on the dataset if versioning was not enabled.
  - save_version (Optional[str]) – Version string to be used for the save operation if the dataset is versioned. Has no effect on the dataset if versioning was not enabled.
- Return type:
AbstractDataset
- Returns:
An instance of an AbstractDataset subclass.
- Raises:
DatasetError – When the function fails to create the dataset from its config.
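With versioning enabled in the catalog (versioned: true, as in the YAML example), Kedro in essence turns the flag into a version argument built from load_version and save_version; without the flag, both are ignored. The simplified sketch below shows that translation for a single catalog entry. catalog_entry_to_kwargs is a hypothetical helper, Version is a local stand-in for kedro.io.core.Version, and class resolution from the type key is omitted.

```python
from typing import Any, NamedTuple, Optional


class Version(NamedTuple):
    """Local stand-in mirroring kedro.io.core.Version (load, save)."""
    load: Optional[str]
    save: Optional[str]


def catalog_entry_to_kwargs(
    config: dict[str, Any],
    load_version: Optional[str] = None,
    save_version: Optional[str] = None,
) -> dict[str, Any]:
    """Turn a catalog entry into constructor keyword arguments (simplified)."""
    kwargs = dict(config)       # copy so the caller's config is not mutated
    kwargs.pop("type", None)    # the real factory uses "type" to pick the class
    if kwargs.pop("versioned", False):
        kwargs["version"] = Version(load_version, save_version)
    return kwargs


entry = {
    "type": "kedro_datasets_experimental.darts.DartsTorchModelDataset",
    "filepath": "data/06_models/darts_model.pt",
    "model_class": "RNNModel",
    "load_args": {"load_method": "load"},
    "versioned": True,
}
kwargs = catalog_entry_to_kwargs(entry, load_version="2024-03-01T00.00.00.000Z")
print(kwargs["version"])  # Version(load='2024-03-01T00.00.00.000Z', save=None)
```

Leaving save as None here means a fresh save version would be autogenerated, matching the version parameter's description.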
- load()
Loads a TorchForecastingModel using the specified load method.
- Return type:
TorchForecastingModel
- Returns:
An instance of TorchForecastingModel.
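The load_args description says the available arguments depend on the chosen load_method, and the YAML example sets load_method: load. One plausible reading is that load_method names a classmethod on the model class (Darts' TorchForecastingModel does provide load and load_from_checkpoint) and the remaining load_args are forwarded to it. The stdlib sketch below illustrates that dispatch with a hypothetical StubModel and load_model helper; the default of "load" and passing the file path as the first argument are assumptions, not documented behavior of this dataset.

```python
from typing import Any, Optional


class StubModel:
    """Hypothetical stand-in for a TorchForecastingModel subclass."""

    def __init__(self, loaded_by: str, target: str, kwargs: dict[str, Any]) -> None:
        self.loaded_by = loaded_by  # which classmethod produced this instance
        self.target = target        # path (or name) the loader received
        self.kwargs = kwargs        # remaining load_args forwarded to the loader

    @classmethod
    def load(cls, path: str, **kwargs: Any) -> "StubModel":
        return cls("load", path, kwargs)

    @classmethod
    def load_from_checkpoint(cls, model_name: str, **kwargs: Any) -> "StubModel":
        return cls("load_from_checkpoint", model_name, kwargs)


def load_model(model_class: type, path: str, load_args: Optional[dict[str, Any]] = None) -> Any:
    """Call the classmethod named by load_args["load_method"] (assumed default: "load")."""
    args = dict(load_args or {})                # never mutate the configured dict
    method_name = args.pop("load_method", "load")
    loader = getattr(model_class, method_name)
    return loader(path, **args)                 # everything else passes through


model = load_model(StubModel, "data/06_models/darts_model.pt", {"load_method": "load"})
print(model.loaded_by, model.target)  # load data/06_models/darts_model.pt
```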
- release()
Release any cached data.
- Raises:
DatasetError – When the underlying release method raises an error.
- Return type:
None
- save(data)
Saves a TorchForecastingModel using the specified save method.
- Parameters:
  - data (TorchForecastingModel) – The TorchForecastingModel instance to save.
- Return type:
None
- to_config()
Converts the dataset instance into a dictionary-based configuration for serialization. Ensures that any subclass-specific details are handled, with additional logic for versioning and caching implemented for CachedDataset.
Adds a key for the dataset’s type using its module and class name and includes the initialization arguments.
For CachedDataset, it extracts the underlying dataset’s configuration, handles the versioned flag, and removes unnecessary metadata. It also ensures the embedded dataset’s configuration is appropriately flattened or transformed.
If the dataset has a version key, it sets the versioned flag in the configuration.
Removes the metadata key from the configuration if present.
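The steps above can be sketched for a simple, non-cached dataset. This is a minimal stdlib sketch assuming the dataset keeps its constructor arguments in a dict (here a hypothetical _init_args attribute on a hypothetical SketchVersionedDataset); it skips the CachedDataset branch and uses a local Version stand-in for kedro.io.core.Version.

```python
from typing import Any, NamedTuple, Optional


class Version(NamedTuple):
    """Local stand-in mirroring kedro.io.core.Version (load, save)."""
    load: Optional[str]
    save: Optional[str]


class SketchVersionedDataset:
    """Hypothetical dataset that records its constructor arguments."""

    def __init__(
        self,
        filepath: str,
        version: Optional[Version] = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        self._init_args = {"filepath": filepath, "version": version, "metadata": metadata}

    def to_config(self) -> dict[str, Any]:
        cls = type(self)
        # 1. Type key built from the module and class name.
        config: dict[str, Any] = {"type": f"{cls.__module__}.{cls.__name__}"}
        # 2. Initialization arguments.
        config.update(self._init_args)
        # 3. A version becomes the catalog's boolean "versioned" flag.
        if config.pop("version", None) is not None:
            config["versioned"] = True
        # 4. Metadata is not serialized.
        config.pop("metadata", None)
        return config


cfg = SketchVersionedDataset("data/06_models/darts_model.pt", version=Version(None, None)).to_config()
print(cfg["versioned"], "version" in cfg)  # True False
```

Replacing the concrete Version with a boolean flag keeps the output in the same shape as a YAML catalog entry, so it can be fed back through from_config.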