Skip to content

pytorch.PyTorchDataset

kedro_datasets_experimental.pytorch.PyTorchDataset

PyTorchDataset(
    *,
    filepath,
    load_args=None,
    save_args=None,
    version=None,
    credentials=None,
    fs_args=None,
    metadata=None
)

Bases: AbstractVersionedDataset[Any, Any]

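PyTorchDataset loads and saves the state_dict of PyTorch models using PyTorch's recommended zipfile serialization protocol to avoid security issues with Pickle. Note that save() expects a torch.nn.Module and persists only its state_dict, while load() returns that state_dict, which must then be loaded into a model instance of the same architecture.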

Example usage for the YAML API
model:
    type: pytorch.PyTorchDataset
    filepath: data/06_models/model.pt
Example usage for the Python API
from kedro_datasets_experimental.pytorch import PyTorchDataset
import torch

# Define your model
model: torch.nn.Module = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())

# Save model state dict (only the state_dict is written, not the module itself)
dataset = PyTorchDataset(filepath="data/06_models/model.pt")
dataset.save(model)

# Reload model state dict: load() returns the state_dict, so first rebuild a
# model with the same architecture and then load the saved weights into it.
reloaded = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
reloaded.load_state_dict(dataset.load())
Source code in kedro_datasets_experimental/pytorch/pytorch_dataset.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def __init__(  # noqa: PLR0913
    self,
    *,
    filepath: str,
    load_args: dict[str, Any] | None = None,
    save_args: dict[str, Any] | None = None,
    version: Version | None = None,
    credentials: dict[str, Any] | None = None,
    fs_args: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Create a ``PyTorchDataset`` for the file at ``filepath``.

    Args:
        filepath: Location of the file. It is split into a protocol and a
            path with ``get_protocol_and_path``.
        load_args: Optional overrides merged on top of a copy of
            ``DEFAULT_LOAD_ARGS`` and stored as ``_load_args``.
        save_args: Optional overrides merged on top of a copy of
            ``DEFAULT_SAVE_ARGS`` and stored as ``_save_args``.
        version: Optional version, forwarded to ``get_protocol_and_path``
            and to the base-class constructor.
        credentials: Optional keyword arguments forwarded to
            ``fsspec.filesystem``.
        fs_args: Optional keyword arguments forwarded to
            ``fsspec.filesystem``. The keys ``open_args_load`` and
            ``open_args_save`` are removed first and stored separately as
            ``_fs_open_args_load`` / ``_fs_open_args_save``.
        metadata: Arbitrary metadata, stored unchanged on ``self.metadata``.
    """
    # Work on copies so the caller's dictionaries are never mutated by the
    # ``pop``/``setdefault`` calls below.
    _fs_args = deepcopy(fs_args) or {}
    _fs_open_args_load = _fs_args.pop("open_args_load", {})
    _fs_open_args_save = _fs_args.pop("open_args_save", {})
    _credentials = deepcopy(credentials) or {}

    protocol, path = get_protocol_and_path(filepath, version)
    if protocol == "file":
        # Local filesystem: default to creating missing parent directories.
        _fs_args.setdefault("auto_mkdir", True)

    self._protocol = protocol
    self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)

    self.metadata = metadata

    super().__init__(
        filepath=PurePosixPath(path),
        version=version,
        exists_function=self._fs.exists,
        glob_function=self._fs.glob,
    )

    # Handle default load and save arguments
    self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
    if load_args is not None:
        self._load_args.update(load_args)
    self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
    if save_args is not None:
        self._save_args.update(save_args)

    self._fs_open_args_load = _fs_open_args_load
    self._fs_open_args_save = _fs_open_args_save

DEFAULT_LOAD_ARGS class-attribute instance-attribute

DEFAULT_LOAD_ARGS = {}

DEFAULT_SAVE_ARGS class-attribute instance-attribute

DEFAULT_SAVE_ARGS = {}

_fs instance-attribute

_fs = filesystem(_protocol, **_credentials, **_fs_args)

_fs_open_args_load instance-attribute

_fs_open_args_load = _fs_open_args_load

_fs_open_args_save instance-attribute

_fs_open_args_save = _fs_open_args_save

_load_args instance-attribute

_load_args = deepcopy(DEFAULT_LOAD_ARGS)

_protocol instance-attribute

_protocol = protocol

_save_args instance-attribute

_save_args = deepcopy(DEFAULT_SAVE_ARGS)

metadata instance-attribute

metadata = metadata

_describe

_describe()
Source code in kedro_datasets_experimental/pytorch/pytorch_dataset.py
 96
 97
 98
 99
100
101
102
103
def _describe(self) -> dict[str, Any]:
    """Return the attributes that describe this dataset instance.

    Returns:
        A dictionary with the dataset's file path, protocol, load
        arguments, save arguments and version.
    """
    description = dict(
        filepath=self._filepath,
        protocol=self._protocol,
        load_args=self._load_args,
        save_args=self._save_args,
        version=self._version,
    )
    return description

_exists

_exists()
Source code in kedro_datasets_experimental/pytorch/pytorch_dataset.py
115
116
117
118
119
120
121
def _exists(self):
    """Report whether the dataset's load path exists on the filesystem.

    Returns ``False`` when resolving the load path raises ``DatasetError``;
    otherwise returns the result of ``self._fs.exists`` for that path.
    """
    try:
        resolved = self._get_load_path()
        target = get_filepath_str(resolved, self._protocol)
    except DatasetError:
        # The load path could not be resolved, so treat it as missing.
        return False
    else:
        return self._fs.exists(target)

_invalidate_cache

_invalidate_cache()

Invalidate underlying filesystem caches.

Source code in kedro_datasets_experimental/pytorch/pytorch_dataset.py
127
128
129
130
def _invalidate_cache(self) -> None:
    """Drop the filesystem's cached entries for this dataset's path."""
    self._fs.invalidate_cache(get_filepath_str(self._filepath, self._protocol))

_release

_release()
Source code in kedro_datasets_experimental/pytorch/pytorch_dataset.py
123
124
125
def _release(self) -> None:
    """Release the dataset, then invalidate the filesystem cache for its path."""
    # Let the base class run its own release logic first.
    super()._release()
    # Then drop cached filesystem state via ``_invalidate_cache``.
    self._invalidate_cache()

load

load()
Source code in kedro_datasets_experimental/pytorch/pytorch_dataset.py
105
106
107
def load(self) -> Any:
    """Load the saved object from the resolved load path with ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns for the file at the load path.
    """
    load_path = get_filepath_str(self._get_load_path(), self._protocol)
    # ``load_args`` were accepted by the constructor but never reached
    # ``torch.load``. Forward them here. The ``open_args_load`` entries are
    # still passed through so existing configurations keep working, with
    # explicit ``load_args`` taking precedence on key collisions.
    torch_load_kwargs = {**self._fs_open_args_load, **self._load_args}
    # NOTE(review): ``torch.load`` may unpickle arbitrary objects depending on
    # the installed torch version's ``weights_only`` default - consider setting
    # ``weights_only=True`` through ``load_args`` for untrusted files.
    return torch.load(load_path, **torch_load_kwargs)  #nosec: B614

save

save(data)
Source code in kedro_datasets_experimental/pytorch/pytorch_dataset.py
109
110
111
112
113
def save(self, data: torch.nn.Module) -> None:
    """Save the ``state_dict`` of ``data`` to the resolved save path.

    Args:
        data: The module whose ``state_dict()`` is written with
            ``torch.save``. The module object itself is not serialized.
    """
    save_path = get_filepath_str(self._get_save_path(), self._protocol)
    # ``save_args`` were accepted by the constructor but never reached
    # ``torch.save``. Forward them here. The ``open_args_save`` entries are
    # still passed through so existing configurations keep working, with
    # explicit ``save_args`` taking precedence on key collisions.
    torch_save_kwargs = {**self._fs_open_args_save, **self._save_args}
    torch.save(data.state_dict(), save_path, **torch_save_kwargs)  #nosec: B614

    # The file changed on disk, so cached filesystem state may be stale.
    self._invalidate_cache()