Skip to content

safetensors.SafetensorsDataset

kedro_datasets_experimental.safetensors.SafetensorsDataset

SafetensorsDataset(
    *,
    filepath,
    backend="numpy",
    version=None,
    credentials=None,
    fs_args=None,
    metadata=None
)

Bases: AbstractVersionedDataset[Any, Any]

SafetensorsDataset loads/saves data from/to a Safetensors file using an underlying filesystem (e.g., local, S3, GCS). The underlying functionality is supported by the specified backend library (defaults to the numpy library), so it supports all allowed options for loading and saving Safetensors files.

Example usage for the YAML API
test_model:
    type: safetensors.SafetensorsDataset
    filepath: data/07_model_output/test_model.safetensors
Example usage for the Python API
from kedro_datasets_experimental.safetensors import SafetensorsDataset
import numpy as np

# Tensors are passed as a mapping of tensor name -> array.
data = {
    "embedding": np.zeros((512, 1024)),
    "attention": np.zeros((256, 256))
}
# No ``backend`` is given, so the default "numpy" backend is used.
dataset = SafetensorsDataset(
    filepath="test.safetensors",
)
dataset.save(data)
reloaded = dataset.load()
# Round-trip check: every saved array must be read back unchanged.
assert all(np.array_equal(data[key], reloaded[key]) for key in data)

SafetensorsDataset supports custom backends to serialise/deserialise objects.

The following backends are supported:
  • numpy
  • torch
  • tensorflow
  • paddle
  • flax

Parameters:

  • filepath (str) –

    Filepath in POSIX format to a Safetensors file prefixed with a protocol like s3://. If prefix is not provided, file protocol (local filesystem) will be used. The prefix should be any protocol supported by fsspec. Note: http(s) doesn't support versioning.

  • backend (str, default: 'numpy' ) –

    The backend library to use for serialising/deserialising objects. The default backend is 'numpy'.

  • version (Version | None, default: None ) –

    If specified, should be an instance of kedro.io.core.Version. If its load attribute is None, the latest version will be loaded. If its save attribute is None, save version will be autogenerated.

  • credentials (dict[str, Any] | None, default: None ) –

    Credentials required to get access to the underlying filesystem. E.g. for GCSFileSystem it should look like {"token": None}.

  • fs_args (dict[str, Any] | None, default: None ) –

    Extra arguments to pass into underlying filesystem class constructor (e.g. {"project": "my-project"} for GCSFileSystem), as well as to pass to the filesystem's open method through nested keys open_args_load and open_args_save. Here you can find all available arguments for open: https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open All defaults are preserved, except mode, which is set to wb when saving.

  • metadata (dict[str, Any] | None, default: None ) –

    Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins.

Raises:

  • ImportError

    If the backend module could not be imported.

Source code in kedro_datasets_experimental/safetensors/safetensors_dataset.py
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def __init__(  # noqa: PLR0913
    self,
    *,
    filepath: str,
    backend: str = "numpy",
    version: Version | None = None,
    credentials: dict[str, Any] | None = None,
    fs_args: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Set up a ``SafetensorsDataset`` bound to one Safetensors file on a given filesystem.

    Serialisation and deserialisation are delegated to a pluggable backend.

    Supported backends:
        * `numpy`
        * `torch`
        * `tensorflow`
        * `paddle`
        * `flax`

    Args:
        filepath: POSIX-style path to the Safetensors file, optionally prefixed with a
            protocol such as `s3://`. Without a prefix the `file` protocol (local
            filesystem) is used. Any protocol supported by ``fsspec`` is accepted.
            Note: `http(s)` doesn't support versioning.
        backend: Name of the backend library used to serialise/deserialise objects.
            Defaults to 'numpy'.
        version: Optional ``kedro.io.core.Version`` instance. A ``load`` attribute of
            None means the latest version is loaded; a ``save`` attribute of None
            means the save version is autogenerated.
        credentials: Credentials needed to access the underlying filesystem,
            e.g. `{"token": None}` for ``GCSFileSystem``.
        fs_args: Extra arguments for the underlying filesystem class constructor
            (e.g. `{"project": "my-project"}` for ``GCSFileSystem``). The nested keys
            `open_args_load` and `open_args_save` are forwarded to the filesystem's
            `open` method instead. Available `open` arguments are listed at:
            https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
            All defaults are preserved, except `mode`, which is set to `wb` when saving.
        metadata: Arbitrary metadata. Kedro ignores it, but users or external plugins
            may consume it.

    Raises:
        ImportError: If the ``backend`` module could not be imported.
    """
    # Fail fast when the requested backend binding is unavailable.
    try:
        importlib.import_module(f"safetensors.{backend}")
    except ImportError as exc:
        raise ImportError(
            f"Selected backend '{backend}' could not be imported. "
            "Make sure it is installed and importable."
        ) from exc

    # Work on copies so the caller's dictionaries are never mutated.
    fs_options = deepcopy(fs_args) or {}
    load_open_args = fs_options.pop("open_args_load", {})
    save_open_args = fs_options.pop("open_args_save", {})
    fs_credentials = deepcopy(credentials) or {}

    protocol, path = get_protocol_and_path(filepath, version)
    if protocol == "file":
        # Local filesystem: create missing parent directories unless told otherwise.
        fs_options.setdefault("auto_mkdir", True)

    self._protocol = protocol
    self._fs = fsspec.filesystem(protocol, **fs_credentials, **fs_options)

    self.metadata = metadata

    # The filesystem must exist before this call: its methods are handed to the parent.
    super().__init__(
        filepath=PurePosixPath(path),
        version=version,
        exists_function=self._fs.exists,
        glob_function=self._fs.glob,
    )

    self._backend = backend

    # User-supplied open arguments take precedence over the class defaults.
    fs_defaults = self.DEFAULT_FS_ARGS
    self._fs_open_args_load = {
        **fs_defaults.get("open_args_load", {}),
        **(load_open_args or {}),
    }
    self._fs_open_args_save = {
        **fs_defaults.get("open_args_save", {}),
        **(save_open_args or {}),
    }

DEFAULT_FS_ARGS class-attribute instance-attribute

DEFAULT_FS_ARGS = {'open_args_save': {'mode': 'wb'}}

DEFAULT_LOAD_ARGS class-attribute instance-attribute

DEFAULT_LOAD_ARGS = {}

DEFAULT_SAVE_ARGS class-attribute instance-attribute

DEFAULT_SAVE_ARGS = {}

_backend instance-attribute

_backend = backend

_fs instance-attribute

_fs = filesystem(_protocol, **_credentials, **_fs_args)

_fs_open_args_load instance-attribute

_fs_open_args_load = {
    **DEFAULT_FS_ARGS.get("open_args_load", {}),
    **(_fs_open_args_load or {}),
}

_fs_open_args_save instance-attribute

_fs_open_args_save = {
    **DEFAULT_FS_ARGS.get("open_args_save", {}),
    **(_fs_open_args_save or {}),
}

_protocol instance-attribute

_protocol = protocol

metadata instance-attribute

metadata = metadata

_describe

_describe()
Source code in kedro_datasets_experimental/safetensors/safetensors_dataset.py
162
163
164
165
166
167
168
def _describe(self) -> dict[str, Any]:
    """Summarise this dataset's configuration.

    Returns:
        A dictionary holding the dataset's ``filepath``, ``backend``,
        ``protocol`` and ``version``.
    """
    return dict(
        filepath=self._filepath,
        backend=self._backend,
        protocol=self._protocol,
        version=self._version,
    )

_exists

_exists()
Source code in kedro_datasets_experimental/safetensors/safetensors_dataset.py
170
171
172
173
174
175
176
def _exists(self) -> bool:
    """Report whether the file at the dataset's load path exists on the filesystem.

    Returns:
        False when resolving the load path raises ``DatasetError``; otherwise
        the result of the filesystem's ``exists`` check for that path.
    """
    try:
        resolved_path = self._get_load_path()
        target = get_filepath_str(resolved_path, self._protocol)
    except DatasetError:
        # The load path could not be resolved, so treat the dataset as absent.
        return False
    else:
        return self._fs.exists(target)

_invalidate_cache

_invalidate_cache()

Invalidate underlying filesystem caches.

Source code in kedro_datasets_experimental/safetensors/safetensors_dataset.py
182
183
184
185
def _invalidate_cache(self) -> None:
    """Ask the underlying filesystem to drop cached state for this dataset's path."""
    self._fs.invalidate_cache(get_filepath_str(self._filepath, self._protocol))

_release

_release()
Source code in kedro_datasets_experimental/safetensors/safetensors_dataset.py
178
179
180
def _release(self) -> None:
    """Release the dataset via the parent class, then invalidate the filesystem cache."""
    super()._release()
    # Also drop filesystem-level cached state for this dataset's path.
    self._invalidate_cache()

load

load()
Source code in kedro_datasets_experimental/safetensors/safetensors_dataset.py
141
142
143
144
145
146
def load(self) -> Any:
    """Read the Safetensors file and deserialise it with the configured backend.

    Returns:
        Whatever the backend's ``load`` function produces from the file's raw bytes.
    """
    source = get_filepath_str(self._get_load_path(), self._protocol)

    with self._fs.open(source, **self._fs_open_args_load) as handle:
        backend_module = importlib.import_module(f"safetensors.{self._backend}")
        raw_bytes = handle.read()
        return backend_module.load(raw_bytes)

save

save(data)
Source code in kedro_datasets_experimental/safetensors/safetensors_dataset.py
148
149
150
151
152
153
154
155
156
157
158
159
160
def save(self, data: Any) -> None:
    """Serialise ``data`` with the configured backend and write it to the save path.

    The payload is written through the file handle opened on the dataset's
    filesystem, so it reaches whichever filesystem the dataset was configured
    with rather than a local path.

    Args:
        data: The object to serialise. It is passed unchanged to the backend's
            ``save`` function (presumably a mapping of tensor names to
            tensors; see the backend's documentation).

    Raises:
        DatasetError: If importing the backend, serialising ``data`` or
            writing the bytes fails.
    """
    save_path = get_filepath_str(self._get_save_path(), self._protocol)

    with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
        try:
            imported_backend = importlib.import_module(f"safetensors.{self._backend}")
            # Serialise to bytes and write through the opened handle. Writing by
            # ``fs_file.name`` would target a local path and skip the handle,
            # which breaks non-local filesystems.
            fs_file.write(imported_backend.save(data))
        except Exception as exc:
            raise DatasetError(
                f"{data.__class__} was not serialised due to: {exc}"
            ) from exc

    self._invalidate_cache()