Skip to content

TensorFlowModelDataset

TensorFlowModelDataset loads and saves TensorFlow models.

kedro_datasets.tensorflow.TensorFlowModelDataset

TensorFlowModelDataset(
    *,
    filepath,
    load_args=None,
    save_args=None,
    version=None,
    credentials=None,
    fs_args=None,
    metadata=None
)

Bases: AbstractVersionedDataset[Model, Model]

TensorFlowModelDataset loads and saves TensorFlow models. The underlying functionality is supported by, and passes input arguments through to, TensorFlow 2.X load_model and save_model methods.

TensorFlow does not currently support Python 3.14.

Examples:

Using the YAML API:

tensorflow_model:
  type: tensorflow.TensorFlowModelDataset
  filepath: data/06_models/tensorflow_model.h5
  load_args:
    compile: False
  save_args:
    overwrite: True
    include_optimizer: False
  credentials: tf_creds

Using the Python API:

>>> import numpy as np
>>> import tensorflow as tf
>>> from kedro_datasets.tensorflow import TensorFlowModelDataset
>>>
>>> model = tf.keras.Sequential(
...     [tf.keras.layers.Dense(5, input_shape=(3,)), tf.keras.layers.Softmax()]
... )
>>> # x = tf.random.uniform((10, 3))
>>> # predictions = model.predict(x)
>>>
>>> dataset = TensorFlowModelDataset(
...     filepath=tmp_path / "data/06_models/tensorflow_model.h5"
... )
>>> dataset.save(model)
>>> loaded_model = dataset.load()

Parameters:

  • filepath (str) –

    Filepath in POSIX format to a TensorFlow model file (for example a .keras or .h5 file), optionally prefixed with a protocol like s3://. If no prefix is provided, the file protocol (local filesystem) will be used. The prefix should be any protocol supported by fsspec. Note: http(s) doesn't support versioning.

  • load_args (dict[str, Any] | None, default: None ) –

    TensorFlow options for loading models. Here you can find all available arguments: https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model All defaults are preserved.

  • save_args (dict[str, Any] | None, default: None ) –

    TensorFlow options for saving models. Here you can find all available arguments: https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model All defaults are preserved. If "save_format" is set to "h5", the model is handled as an HDF5 file; otherwise the .keras format is assumed.

  • version (Version | None, default: None ) –

    If specified, should be an instance of kedro.io.core.Version. If its load attribute is None, the latest version will be loaded. If its save attribute is None, save version will be autogenerated.

  • credentials (dict[str, Any] | None, default: None ) –

    Credentials required to get access to the underlying filesystem. E.g. for GCSFileSystem it should look like {'token': None}.

  • fs_args (dict[str, Any] | None, default: None ) –

    Extra arguments to pass into underlying filesystem class constructor (e.g. {"project": "my-project"} for GCSFileSystem).

  • metadata (dict[str, Any] | None, default: None ) –

    Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins.

Source code in kedro_datasets/tensorflow/tensorflow_model_dataset.py
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def __init__(  # noqa: PLR0913
    self,
    *,
    filepath: str,
    load_args: dict[str, Any] | None = None,
    save_args: dict[str, Any] | None = None,
    version: Version | None = None,
    credentials: dict[str, Any] | None = None,
    fs_args: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Creates a new instance of ``TensorFlowModelDataset``.

    Args:
        filepath: Filepath in POSIX format to a TensorFlow model, optionally
            prefixed with a protocol like `s3://`. If no prefix is provided, the
            `file` protocol (local filesystem) will be used. The prefix should be
            any protocol supported by ``fsspec``.
            Note: `http(s)` doesn't support versioning.
        load_args: TensorFlow options for loading models.
            Here you can find all available arguments:
            https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model
            The given values are merged over ``DEFAULT_LOAD_ARGS``.
        save_args: TensorFlow options for saving models.
            Here you can find all available arguments:
            https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model
            The given values are merged over ``DEFAULT_SAVE_ARGS``. When the merged
            ``"save_format"`` equals ``"h5"`` the dataset flags itself as HDF5
            (``self._is_h5``).
        version: If specified, should be an instance of
            ``kedro.io.core.Version``. If its ``load`` attribute is
            None, the latest version will be loaded. If its ``save``
            attribute is None, save version will be autogenerated.
        credentials: Credentials required to get access to the underlying filesystem.
            E.g. for ``GCSFileSystem`` it should look like `{'token': None}`.
        fs_args: Extra arguments to pass into underlying filesystem class constructor
            (e.g. `{"project": "my-project"}` for ``GCSFileSystem``).
        metadata: Any arbitrary metadata.
            This is ignored by Kedro, but may be consumed by users or external plugins.
    """
    # Deep-copy the caller's dictionaries so later tweaks (e.g. ``setdefault``
    # below) never leak back into the objects that were passed in.
    fs_options = copy.deepcopy(fs_args) or {}
    fs_credentials = copy.deepcopy(credentials) or {}

    protocol, path = get_protocol_and_path(filepath, version)
    if protocol == "file":
        # Default ``auto_mkdir`` to True for local paths unless the caller set it.
        fs_options.setdefault("auto_mkdir", True)

    self._protocol = protocol
    self._fs = fsspec.filesystem(protocol, **fs_credentials, **fs_options)

    self.metadata = metadata

    # The versioned base class receives the filesystem's own existence and
    # glob callables.
    super().__init__(
        filepath=PurePosixPath(path),
        version=version,
        exists_function=self._fs.exists,
        glob_function=self._fs.glob,
    )

    # Prefix given to the temporary directories created by ``load``/``save``.
    self._tmp_prefix = "kedro_tensorflow_tmp"

    # Class-level defaults first, then user-supplied overrides on top.
    merged_load_args = dict(self.DEFAULT_LOAD_ARGS)
    merged_load_args.update(load_args or {})
    self._load_args = merged_load_args

    merged_save_args = dict(self.DEFAULT_SAVE_ARGS)
    merged_save_args.update(save_args or {})
    self._save_args = merged_save_args

    # Only an explicit ``save_format="h5"`` marks the dataset as HDF5.
    self._is_h5 = merged_save_args.get("save_format") == "h5"

DEFAULT_LOAD_ARGS class-attribute instance-attribute

DEFAULT_LOAD_ARGS = {}

DEFAULT_SAVE_ARGS class-attribute instance-attribute

DEFAULT_SAVE_ARGS = {}

_fs instance-attribute

_fs = filesystem(_protocol, **_credentials, **_fs_args)

_is_h5 instance-attribute

_is_h5 = _save_args.get('save_format') == 'h5'

_load_args instance-attribute

_load_args = {**DEFAULT_LOAD_ARGS, **(load_args or {})}

_protocol instance-attribute

_protocol = protocol

_save_args instance-attribute

_save_args = {**DEFAULT_SAVE_ARGS, **(save_args or {})}

_tmp_prefix instance-attribute

_tmp_prefix = 'kedro_tensorflow_tmp'

metadata instance-attribute

metadata = metadata

_describe

_describe()
Source code in kedro_datasets/tensorflow/tensorflow_model_dataset.py
177
178
179
180
181
182
183
184
def _describe(self) -> dict[str, Any]:
    """Return a dictionary summarising how this dataset is configured.

    Returns:
        A mapping with the keys ``filepath``, ``protocol``, ``load_args``,
        ``save_args`` and ``version``, each taken from the private attribute
        of the same name (prefixed with an underscore).
    """
    described_fields = ("filepath", "protocol", "load_args", "save_args", "version")
    # Each public key mirrors the private attribute ``_<key>`` on the instance.
    return {field: getattr(self, f"_{field}") for field in described_fields}

_exists

_exists()
Source code in kedro_datasets/tensorflow/tensorflow_model_dataset.py
170
171
172
173
174
175
def _exists(self) -> bool:
    """Report whether the dataset's load path exists on the filesystem.

    Returns:
        ``False`` when resolving the load path raises ``DatasetError``;
        otherwise whatever the filesystem's ``exists`` check returns for
        that path.
    """
    try:
        resolved = self._get_load_path()
        target = get_filepath_str(resolved, self._protocol)
    except DatasetError:
        # No load path could be resolved, so there is nothing to find.
        return False
    else:
        return self._fs.exists(target)

_invalidate_cache

_invalidate_cache()

Invalidate underlying filesystem caches.

Source code in kedro_datasets/tensorflow/tensorflow_model_dataset.py
190
191
192
193
def _invalidate_cache(self) -> None:
    """Invalidate underlying filesystem caches."""
    # Build the protocol-qualified path string and hand it straight to the
    # filesystem's cache invalidation.
    self._fs.invalidate_cache(get_filepath_str(self._filepath, self._protocol))

_release

_release()
Source code in kedro_datasets/tensorflow/tensorflow_model_dataset.py
186
187
188
def _release(self) -> None:
    """Release the dataset and invalidate its filesystem cache."""
    # Run the parent class's release logic first ...
    super()._release()
    # ... then invalidate this dataset's filesystem cache via ``_invalidate_cache``.
    self._invalidate_cache()

load

load()
Source code in kedro_datasets/tensorflow/tensorflow_model_dataset.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def load(self) -> tf.keras.Model:
    """Fetch the stored model into a temporary directory and load it with Keras.

    The model is copied from the dataset's filesystem to a local temporary
    file (named for HDF5 when ``self._is_h5`` is set, otherwise for the
    ``.keras`` format) and then passed to ``tf.keras.models.load_model``
    together with the configured load arguments.

    The optional ``"tf_device"`` entry of the load arguments is not forwarded
    to Keras; when it is set, loading happens inside ``tf.device(<value>)``.

    Returns:
        The model returned by ``tf.keras.models.load_model``.
    """
    load_path = get_filepath_str(self._get_load_path(), self._protocol)

    # Pop "tf_device" from a per-call copy. Popping it from ``self._load_args``
    # directly would remove it from the instance, so every load after the first
    # would silently ignore the requested device and ``_describe`` would change.
    load_args = dict(self._load_args)
    device_name = load_args.pop("tf_device", None)

    with tempfile.TemporaryDirectory(prefix=self._tmp_prefix) as tempdir:
        # Anything that is not flagged as HDF5 is assumed to be .keras.
        temp_file = TEMPORARY_H5_FILE if self._is_h5 else TEMPORARY_KERAS_FILE
        path = str(PurePath(tempdir) / temp_file)

        self._fs.get(load_path, path)

        # Pass the local temporary file path to keras.load_model
        if device_name:
            with tf.device(device_name):
                model = tf.keras.models.load_model(path, **load_args)
        else:
            model = tf.keras.models.load_model(path, **load_args)
        return model

save

save(data)
Source code in kedro_datasets/tensorflow/tensorflow_model_dataset.py
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def save(self, data: tf.keras.Model) -> None:
    """Serialise ``data`` to a local temporary file and upload it to the save path.

    Args:
        data: The Keras model handed to ``tf.keras.models.save_model`` along
            with the configured save arguments.
    """
    destination = get_filepath_str(self._get_save_path(), self._protocol)

    with tempfile.TemporaryDirectory(prefix=self._tmp_prefix) as scratch_dir:
        # Anything that is not flagged as HDF5 is assumed to be .keras.
        temp_name = TEMPORARY_H5_FILE if self._is_h5 else TEMPORARY_KERAS_FILE
        local_path = str(PurePath(scratch_dir) / temp_name)

        tf.keras.models.save_model(data, local_path, **self._save_args)

        # fsspec copies the local temporary file onto the dataset's
        # (possibly remote) filesystem.
        self._fs.put(local_path, destination)