TensorFlowModelDataset
TensorFlowModelDataset loads and saves TensorFlow models.
kedro_datasets.tensorflow.TensorFlowModelDataset
TensorFlowModelDataset(
*,
filepath,
load_args=None,
save_args=None,
version=None,
credentials=None,
fs_args=None,
metadata=None
)
Bases: AbstractVersionedDataset[Model, Model]
The underlying functionality is supported by, and passes input arguments through to, the TensorFlow 2.x load_model and save_model functions (tf.keras.models.load_model and tf.keras.models.save_model).
TensorFlow does not currently support Python 3.14.
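The pass-through can be pictured as a plain dictionary merge: user-supplied arguments are layered over the dataset's defaults and the result is forwarded as keyword arguments. This is a minimal sketch, assuming the dataset follows that common Kedro merge pattern; DEFAULT_SAVE_ARGS mirrors the "save_format": "tf" default described under save_args, and fake_save_model is a hypothetical stand-in for tf.keras.models.save_model so the code runs without TensorFlow.

```python
from __future__ import annotations

from typing import Any

# Assumed dataset-level defaults: load_args adds nothing, save_args sets
# "save_format" to "tf" as described in the save_args parameter docs.
DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {"save_format": "tf"}


def merge_args(defaults: dict[str, Any], overrides: dict[str, Any] | None) -> dict[str, Any]:
    """Return a new dict with user-supplied values layered over the defaults."""
    return {**defaults, **(overrides or {})}


def fake_save_model(model: object, filepath: str, **kwargs: Any) -> dict[str, Any]:
    """Hypothetical stand-in for tf.keras.models.save_model that echoes its arguments."""
    return {"model": model, "filepath": filepath, **kwargs}


# save_args as in the YAML example, merged over the defaults and forwarded unchanged.
save_args = merge_args(DEFAULT_SAVE_ARGS, {"overwrite": True, "include_optimizer": False})
call = fake_save_model("my-model", "/tmp/model", **save_args)
print(call)
# → {'model': 'my-model', 'filepath': '/tmp/model', 'save_format': 'tf', 'overwrite': True, 'include_optimizer': False}
```

Because the merge builds a new dictionary, the defaults are never mutated, and a user value such as {"save_format": "h5"} simply wins over the default.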
Examples:
Using the YAML API:
tensorflow_model:
  type: tensorflow.TensorFlowModelDataset
  filepath: data/06_models/tensorflow_model.h5
  load_args:
    compile: False
  save_args:
    overwrite: True
    include_optimizer: False
  credentials: tf_creds
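In the YAML entry, credentials: tf_creds names an entry in the project's credentials configuration rather than holding the credentials themselves. The sketch below is a simplified, hypothetical illustration of how such an entry could be turned into constructor keyword arguments; entry_to_kwargs and the contents of CREDENTIALS are assumptions for illustration, not Kedro's actual catalog code.

```python
from __future__ import annotations

import copy
from typing import Any

# The YAML catalog entry above, written as the dict a YAML parser would produce.
CATALOG_ENTRY: dict[str, Any] = {
    "type": "tensorflow.TensorFlowModelDataset",
    "filepath": "data/06_models/tensorflow_model.h5",
    "load_args": {"compile": False},
    "save_args": {"overwrite": True, "include_optimizer": False},
    "credentials": "tf_creds",
}

# Hypothetical credentials config (e.g. the GCSFileSystem shape from the parameter docs).
CREDENTIALS: dict[str, dict[str, Any]] = {"tf_creds": {"token": None}}


def entry_to_kwargs(
    entry: dict[str, Any], credentials: dict[str, dict[str, Any]]
) -> tuple[str, dict[str, Any]]:
    """Split off the dataset type and replace the credentials name with its values."""
    kwargs = copy.deepcopy(entry)  # leave the original entry untouched
    dataset_type = kwargs.pop("type")
    cred_name = kwargs.get("credentials")
    if cred_name is not None:
        kwargs["credentials"] = credentials[cred_name]
    return dataset_type, kwargs


dataset_type, kwargs = entry_to_kwargs(CATALOG_ENTRY, CREDENTIALS)
print(dataset_type)
print(kwargs)
```

The remaining keys (filepath, load_args, save_args, credentials) line up with the keyword-only constructor parameters, which is why the YAML and Python APIs describe the same dataset.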
Using the Python API:
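>>> import tempfile
>>> from pathlib import Path
>>>
>>> import numpy as np
>>> import tensorflow as tf
>>> from kedro_datasets.tensorflow import TensorFlowModelDataset
>>>
>>> tmp_path = Path(tempfile.mkdtemp())  # any writable directory will do
>>> model = tf.keras.Sequential(
...     [tf.keras.layers.Dense(5, input_shape=(3,)), tf.keras.layers.Softmax()]
... )
>>> x = tf.random.uniform((10, 3))
>>> predictions = model.predict(x, verbose=0)
>>>
>>> dataset = TensorFlowModelDataset(
...     filepath=str(tmp_path / "data/06_models/tensorflow_model.h5")
... )
>>> dataset.save(model)
>>> loaded_model = dataset.load()
>>> # The reloaded model should reproduce the original predictions
>>> new_predictions = loaded_model.predict(x, verbose=0)
>>> np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)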
Parameters:
- filepath (str) – Filepath in POSIX format to a TensorFlow model directory, prefixed with a protocol like s3://. If the prefix is not provided, the file protocol (local filesystem) will be used. The prefix should be any protocol supported by fsspec. Note: http(s) doesn't support versioning.
- load_args (dict[str, Any] | None, default: None) – TensorFlow options for loading models. All available arguments are listed at https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model. All defaults are preserved.
- save_args (dict[str, Any] | None, default: None) – TensorFlow options for saving models. All available arguments are listed at https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model. All defaults are preserved, except for "save_format", which is set to "tf".
- version (Version | None, default: None) – If specified, should be an instance of kedro.io.core.Version. If its load attribute is None, the latest version will be loaded. If its save attribute is None, the save version will be autogenerated.
- credentials (dict[str, Any] | None, default: None) – Credentials required to access the underlying filesystem. For example, for GCSFileSystem it should look like {'token': None}.
- fs_args (dict[str, Any] | None, default: None) – Extra arguments to pass into the underlying filesystem class constructor (e.g. {"project": "my-project"} for GCSFileSystem).
- metadata (dict[str, Any] | None, default: None) – Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins.
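The filepath and version parameters together decide where a model is read from or written to. The sketch below is a simplified, hypothetical model of that resolution, assuming (a) a missing protocol prefix falls back to file, as stated for filepath, and (b) Kedro's usual versioned layout of <filepath>/<version>/<filename>, with autogenerated versions being UTC timestamps trimmed to millisecond precision. split_protocol, make_version_string and versioned_path are illustrative names, not Kedro APIs, and real protocol parsing (fsspec) handles more cases than a "://" check.

```python
from __future__ import annotations

from datetime import datetime, timezone
from pathlib import PurePosixPath

# Assumed timestamp format for autogenerated save versions.
VERSION_FORMAT = "%Y-%m-%dT%H.%M.%S.%fZ"


def split_protocol(filepath: str) -> tuple[str, str]:
    """Return (protocol, path); paths without a prefix fall back to "file"."""
    if "://" in filepath:
        protocol, path = filepath.split("://", 1)
        return protocol, path
    return "file", filepath


def make_version_string(now: datetime | None = None) -> str:
    """Format a UTC timestamp, trimming microseconds to milliseconds."""
    now = now or datetime.now(tz=timezone.utc)
    stamp = now.strftime(VERSION_FORMAT)  # e.g. 2024-01-02T03.04.05.678901Z
    return stamp[:-4] + stamp[-1:]  # drop the last three digits, keep the "Z"


def versioned_path(filepath: str, version: str) -> PurePosixPath:
    """Build <filepath>/<version>/<filename>, the assumed versioned layout."""
    path = PurePosixPath(filepath)
    return path / version / path.name


protocol, path = split_protocol("s3://my-bucket/data/06_models/tensorflow_model.h5")
version = make_version_string(datetime(2024, 1, 2, 3, 4, 5, 678901, tzinfo=timezone.utc))
print(protocol, versioned_path(path, version))
# → s3 my-bucket/data/06_models/tensorflow_model.h5/2024-01-02T03.04.05.678Z/tensorflow_model.h5
```

Under this layout each saved version lives in its own timestamped subdirectory, so loading "the latest version" amounts to picking the newest timestamp.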
Source code in kedro_datasets/tensorflow/tensorflow_model_dataset.py
_describe
_describe()
_exists
_exists()
_invalidate_cache
_invalidate_cache()
Invalidate underlying filesystem caches.
_release
_release()
load
load()
save
save(data)