kedro_datasets_experimental.pytorch.PyTorchDataset

class kedro_datasets_experimental.pytorch.PyTorchDataset(*, filepath, load_args=None, save_args=None, version=None, credentials=None, fs_args=None, metadata=None)

PyTorchDataset loads and saves a PyTorch model’s state_dict using PyTorch’s recommended zipfile serialization protocol. This approach is intended to avoid the security issues associated with Pickle.

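Example usage for the YAML API:

model:
  type: pytorch.PyTorchDataset
  filepath: data/06_models/model.pt

Example usage for the Python API:

import tempfile
from pathlib import Path

import torch
from kedro_datasets_experimental.pytorch import PyTorchDataset

model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
tmp_path = Path(tempfile.mkdtemp())  # any writable directory works
dataset = PyTorchDataset(filepath=str(tmp_path / "model.pt"))
dataset.save(model)  # only the model's state_dict is written

# The architecture itself is not saved, so rebuild the same model first.
reloaded = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
reloaded.load_state_dict(dataset.load())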
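Because the dataset stores a state_dict rather than the model object, loading always requires rebuilding the architecture in code. The following sketch shows a roughly equivalent round trip in plain PyTorch, assuming the dataset forwards to torch.save and torch.load. It writes to an in-memory buffer instead of a file, and passing weights_only=True to torch.load is an extra safeguard chosen here (it restricts unpickling to tensors and plain containers), not something the dataset is documented to do.

```python
import io

import torch

# Build a small model and serialize only its weights.
model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
buffer = io.BytesIO()
# torch.save uses the zipfile-based format by default (PyTorch >= 1.6).
torch.save(model.state_dict(), buffer)

# Rewind and read back, allowing only tensors and primitive containers.
buffer.seek(0)
state_dict = torch.load(buffer, weights_only=True)

# The architecture is not stored, so it must be rebuilt before loading weights.
reloaded = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
reloaded.load_state_dict(state_dict)

# Compare every parameter tensor of the original and reloaded models.
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), reloaded.state_dict().values())
)
print(same)
```

The ReLU layer has no parameters, so the state_dict only contains the Linear layer's weight and bias.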

Attributes

DEFAULT_LOAD_ARGS

DEFAULT_SAVE_ARGS

Methods

exists()

Checks whether a dataset's output already exists by calling the provided _exists() method.

from_config(name, config[, load_version, ...])

Create a dataset instance using the configuration provided.

load()

Loads data by delegation to the provided load method.

release()

Release any cached data.

resolve_load_version()

Compute the version the dataset should be loaded with.

resolve_save_version()

Compute the version the dataset should be saved with.

save(data)

Saves data by delegation to the provided save method.
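The exists, load and save descriptions refer to a delegation pattern: the public method calls an underlying implementation and converts its failures into DatasetError. A minimal, hypothetical sketch of that pattern follows; DelegatingDataset and the local DatasetError class are illustrative stand-ins, not Kedro's AbstractDataset or kedro.io.DatasetError.

```python
from typing import Any, Callable


class DatasetError(Exception):
    """Hypothetical local stand-in for Kedro's DatasetError."""


class DelegatingDataset:
    """Hypothetical dataset whose public methods delegate to provided callables."""

    def __init__(
        self,
        load: Callable[[], Any],
        save: Callable[[Any], None],
        exists: Callable[[], bool],
    ) -> None:
        self._load = load
        self._save = save
        self._exists = exists

    def load(self) -> Any:
        try:
            return self._load()
        except Exception as exc:
            # Wrap any failure, keeping the original exception as the cause.
            raise DatasetError(f"Failed while loading data: {exc}") from exc

    def save(self, data: Any) -> None:
        try:
            self._save(data)
        except Exception as exc:
            raise DatasetError(f"Failed while saving data: {exc}") from exc

    def exists(self) -> bool:
        try:
            return self._exists()
        except Exception as exc:
            raise DatasetError(f"Failed during exists check: {exc}") from exc


# Usage: an in-memory store plays the role of the underlying storage.
store: dict[str, Any] = {}
ds = DelegatingDataset(
    load=lambda: store["value"],
    save=lambda data: store.__setitem__("value", data),
    exists=lambda: "value" in store,
)
print(ds.exists())  # False
ds.save({"weight": 1.0})
print(ds.exists(), ds.load())  # True {'weight': 1.0}
```

Wrapping errors in one exception type lets callers handle every storage failure the same way, while the chained cause still shows what went wrong underneath.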

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {}
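Both defaults are empty. Many Kedro datasets combine class defaults with the user's load_args and save_args by dictionary unpacking, so user keys override defaults. Assuming PyTorchDataset follows the same convention, the merge behaves like the sketch below; merge_args is a hypothetical helper, and map_location is used only as an example of a torch.load keyword argument.

```python
from typing import Any, Optional

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {}


def merge_args(
    defaults: dict[str, Any], user_args: Optional[dict[str, Any]]
) -> dict[str, Any]:
    """Hypothetical helper: user-supplied keys override class defaults."""
    return {**defaults, **(user_args or {})}


# With empty defaults, the user's arguments pass through unchanged.
load_args = merge_args(DEFAULT_LOAD_ARGS, {"map_location": "cpu"})
save_args = merge_args(DEFAULT_SAVE_ARGS, None)
print(load_args)  # {'map_location': 'cpu'}
print(save_args)  # {}
```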
exists()

Checks whether a dataset’s output already exists by calling the provided _exists() method.

Return type:

bool

Returns:

Flag indicating whether the output already exists.

Raises:

DatasetError – When the underlying exists method raises an error.

classmethod from_config(name, config, load_version=None, save_version=None)

Create a dataset instance using the configuration provided.

Parameters:
  • name (str) – Dataset name.

  • config (dict[str, Any]) – Dataset config dictionary.

  • load_version (Optional[str]) – Version string to be used for load operation if the dataset is versioned. Has no effect on the dataset if versioning was not enabled.

  • save_version (Optional[str]) – Version string to be used for save operation if the dataset is versioned. Has no effect on the dataset if versioning was not enabled.

Return type:

AbstractDataset

Returns:

An instance of an AbstractDataset subclass.

Raises:

DatasetError – When the function fails to create the dataset from its config.

load()

Loads data by delegation to the provided load method.

Return type:

Any

Returns:

Data returned by the provided load method; for this dataset, the saved model state_dict.

Raises:

DatasetError – When the underlying load method raises an error.

release()

Release any cached data.

Raises:

DatasetError – When the underlying release method raises an error.

Return type:

None

resolve_load_version()

Compute the version the dataset should be loaded with.

Return type:

Optional[str]

resolve_save_version()

Compute the version the dataset should be saved with.

Return type:

Optional[str]
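When versioning is enabled, Kedro, as we understand its conventions, uses a UTC timestamp as the version string and stores each version at <filepath>/<version>/<filename>; with no explicit load version, the most recent version is loaded. The standard-library sketch below imitates that layout so the resolved values are easier to picture. The format string and path structure are assumptions to check against your Kedro release, and make_version and versioned_path are hypothetical helpers, not Kedro functions.

```python
from datetime import datetime, timezone
from pathlib import PurePosixPath

# Assumed version format: UTC timestamp with millisecond precision.
VERSION_FORMAT = "%Y-%m-%dT%H.%M.%S.%fZ"


def make_version(now: datetime) -> str:
    """Hypothetical helper: format a timestamp, trimming microseconds to milliseconds."""
    stamp = now.strftime(VERSION_FORMAT)
    # %f yields six digits; drop the last three but keep the trailing "Z".
    return stamp[:-4] + stamp[-1:]


def versioned_path(filepath: str, version: str) -> PurePosixPath:
    """Hypothetical helper: build <filepath>/<version>/<filename>."""
    path = PurePosixPath(filepath)
    return path / version / path.name


version = make_version(datetime(2024, 5, 1, 12, 30, 45, 123456, tzinfo=timezone.utc))
print(version)  # 2024-05-01T12.30.45.123Z
print(versioned_path("data/06_models/model.pt", version))
# data/06_models/model.pt/2024-05-01T12.30.45.123Z/model.pt

# The format is zero-padded from year down to milliseconds, so string order
# matches chronological order and the newest version compares greatest.
saved = ["2024-05-01T12.30.45.123Z", "2024-04-30T08.00.00.000Z"]
print(max(saved))  # 2024-05-01T12.30.45.123Z
```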

save(data)

Saves data by delegation to the provided save method.

Parameters:

data (Module) – the PyTorch model whose state_dict is saved by the provided save method.

Raises:
  • DatasetError – When the underlying save method raises an error.

  • FileNotFoundError – When the save method is given a file instead of a directory, on Windows.

  • NotADirectoryError – When the save method is given a file instead of a directory, on Unix.

Return type:

None
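The FileNotFoundError and NotADirectoryError cases listed for save arise when a path component that should be a directory is actually a regular file, for instance a file already sitting where a versioned layout expects a directory. The standard-library sketch below reproduces that situation with plain open(); it does not call the dataset, and the file and version names are illustrative only.

```python
import tempfile
from pathlib import Path

error_name = None
with tempfile.TemporaryDirectory() as tmp:
    # A regular file occupies the path that should be a directory.
    blocker = Path(tmp) / "model.pt"
    blocker.write_bytes(b"not a directory")

    target = blocker / "2024-05-01T12.30.45.123Z" / "model.pt"
    try:
        with open(target, "wb") as f:
            f.write(b"weights")
    except (NotADirectoryError, FileNotFoundError) as exc:
        # Unix reports NotADirectoryError; Windows reports FileNotFoundError.
        error_name = type(exc).__name__

print(error_name)
```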