TensorFlowModelDataset
TensorFlowModelDataset loads and saves TensorFlow models.
kedro_datasets.tensorflow.TensorFlowModelDataset
TensorFlowModelDataset(
*,
filepath,
load_args=None,
save_args=None,
version=None,
credentials=None,
fs_args=None,
metadata=None
)
Bases: AbstractVersionedDataset[Model, Model]
The underlying functionality is supported by, and passes input arguments through to, the TensorFlow 2.x tf.keras.models.load_model and tf.keras.models.save_model functions.
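The argument pass-through can be pictured with a minimal pure-Python sketch. It assumes the dataset merges user-supplied arguments over its defaults (an empty dict for loading, {"save_format": "tf"} for saving, per the save_args description) and then unpacks the result into the TensorFlow calls. merge_args, DEFAULT_LOAD_ARGS and DEFAULT_SAVE_ARGS are hypothetical names used for illustration, not the library's API.

```python
from typing import Any, Optional

# Assumed defaults: only "save_format" is set for saving; loading keeps TF defaults.
DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {"save_format": "tf"}


def merge_args(defaults: dict[str, Any], user_args: Optional[dict[str, Any]]) -> dict[str, Any]:
    """Hypothetical helper: user-supplied keys win; every other default is preserved."""
    return {**defaults, **(user_args or {})}


# load_args / save_args taken from the YAML example below.
load_kwargs = merge_args(DEFAULT_LOAD_ARGS, {"compile": False})
save_kwargs = merge_args(DEFAULT_SAVE_ARGS, {"overwrite": True, "include_optimizer": False})

print(load_kwargs)  # {'compile': False}
print(save_kwargs)  # {'save_format': 'tf', 'overwrite': True, 'include_optimizer': False}

# These dicts would then be forwarded as keyword arguments, e.g.
#   tf.keras.models.save_model(model, path, **save_kwargs)
#   tf.keras.models.load_model(path, **load_kwargs)
```

Building a new dict rather than calling update() on the defaults keeps the shared default mapping unchanged between dataset instances.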
Examples:
Using the YAML API:
tensorflow_model:
  type: tensorflow.TensorFlowModelDataset
  filepath: data/06_models/tensorflow_model.h5
  load_args:
    compile: False
  save_args:
    overwrite: True
    include_optimizer: False
  credentials: tf_creds
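The mapping from this YAML entry to constructor arguments can be sketched in plain Python. The sketch assumes, as in a Kedro data catalog, that the type key selects the dataset class and that credentials: tf_creds is the name of an entry in a separate credentials configuration, which is substituted in before the dataset is built. resolve_entry and the placeholder credentials values are hypothetical.

```python
from typing import Any


def resolve_entry(
    entry: dict[str, Any], credentials: dict[str, dict[str, Any]]
) -> tuple[str, dict[str, Any]]:
    """Hypothetical helper: split a catalog entry into (type name, constructor kwargs)."""
    kwargs = dict(entry)  # shallow copy so the caller's top-level entry is not mutated
    dataset_type = kwargs.pop("type")
    cred_name = kwargs.get("credentials")
    if isinstance(cred_name, str):
        # Replace the credentials *name* with the mapping it refers to.
        kwargs["credentials"] = credentials[cred_name]
    return dataset_type, kwargs


# The YAML example above, as the dict a YAML parser would produce.
catalog_entry = {
    "type": "tensorflow.TensorFlowModelDataset",
    "filepath": "data/06_models/tensorflow_model.h5",
    "load_args": {"compile": False},
    "save_args": {"overwrite": True, "include_optimizer": False},
    "credentials": "tf_creds",
}
# Placeholder credentials; real values would come from your credentials config.
all_credentials = {"tf_creds": {"token": None}}

dataset_type, kwargs = resolve_entry(catalog_entry, all_credentials)
print(dataset_type)           # tensorflow.TensorFlowModelDataset
print(kwargs["credentials"])  # {'token': None}
# kwargs now matches the keyword-only constructor parameters (filepath, load_args, ...).
```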
Using the Python API:
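>>> import tempfile
>>> from pathlib import Path
>>>
>>> import numpy as np
>>> import tensorflow as tf
>>> from kedro_datasets.tensorflow import TensorFlowModelDataset
>>>
>>> tmp_path = Path(tempfile.mkdtemp())  # scratch directory for this example
>>>
>>> model = tf.keras.Sequential(
...     [tf.keras.layers.Dense(5, input_shape=(3,)), tf.keras.layers.Softmax()]
... )
>>> x = tf.random.uniform((10, 3))
>>> predictions = model.predict(x, verbose=0)  # verbose=0 suppresses the progress bar
>>>
>>> dataset = TensorFlowModelDataset(
...     filepath=tmp_path / "data/06_models/tensorflow_model.h5"
... )
>>> dataset.save(model)
>>> loaded_model = dataset.load()
>>> # The reloaded model should reproduce the original predictions.
>>> new_predictions = loaded_model.predict(x, verbose=0)
>>> np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)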
Parameters:
- filepath (str) – Filepath in POSIX format to a TensorFlow model directory, optionally prefixed with a protocol like s3://. If no prefix is provided, the file protocol (local filesystem) is used. The prefix can be any protocol supported by fsspec. Note: http(s) doesn't support versioning.
- load_args (dict[str, Any] | None, default: None) – TensorFlow options for loading models. All available arguments are listed at https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model. All defaults are preserved.
- save_args (dict[str, Any] | None, default: None) – TensorFlow options for saving models. All available arguments are listed at https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model. All defaults are preserved, except for "save_format", which is set to "tf".
- version (Version | None, default: None) – If specified, should be an instance of kedro.io.core.Version. If its load attribute is None, the latest version will be loaded. If its save attribute is None, the save version will be autogenerated.
- credentials (dict[str, Any] | None, default: None) – Credentials required to access the underlying filesystem. For example, for GCSFileSystem it should look like {'token': None}.
- fs_args (dict[str, Any] | None, default: None) – Extra arguments to pass into the underlying filesystem class constructor (e.g. {"project": "my-project"} for GCSFileSystem).
- metadata (dict[str, Any] | None, default: None) – Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins.
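The filepath and version parameters together decide where a model ends up. The sketch below assumes fsspec-style "protocol://path" strings with "file" as the fallback protocol, and a Kedro-style versioned layout of <filepath>/<version>/<filename>, where versions are zero-padded UTC timestamps so that the newest one is simply the largest string. The timestamp format and the helpers split_protocol, generate_version, versioned_path and latest_version are assumptions for illustration, not the library's API.

```python
from datetime import datetime, timezone
from pathlib import PurePosixPath
from typing import Optional

# Assumed timestamp format for autogenerated versions; zero-padded fields make
# the strings sort chronologically as plain text.
VERSION_FORMAT = "%Y-%m-%dT%H.%M.%S.%fZ"


def split_protocol(filepath: str) -> tuple[str, str]:
    """Hypothetical helper: split 'proto://path'; fall back to 'file' (local filesystem)."""
    if "://" in filepath:
        protocol, path = filepath.split("://", 1)
        return protocol, path
    return "file", filepath


def generate_version(now: Optional[datetime] = None) -> str:
    """Hypothetical helper: autogenerate a save version from the current UTC time."""
    now = now or datetime.now(tz=timezone.utc)
    return now.strftime(VERSION_FORMAT)


def versioned_path(path: str, version: str) -> str:
    """Hypothetical layout <path>/<version>/<filename>, for a protocol-free path."""
    p = PurePosixPath(path)
    return str(p / version / p.name)


def latest_version(existing: list[str]) -> str:
    """When no load version is given, pick the newest (lexicographically largest) one."""
    return max(existing)


protocol, path = split_protocol("s3://my-bucket/data/06_models/tensorflow_model.h5")
print(protocol, path)  # s3 my-bucket/data/06_models/tensorflow_model.h5

save_version = generate_version(datetime(2024, 1, 2, 3, 4, 5, 678000, tzinfo=timezone.utc))
print(versioned_path(path, save_version))
# my-bucket/data/06_models/tensorflow_model.h5/2024-01-02T03.04.05.678000Z/tensorflow_model.h5

print(latest_version(["2023-12-31T23.59.59.000000Z", save_version]))
# 2024-01-02T03.04.05.678000Z
```

Using sortable timestamp strings as version names means "load the latest" needs no date parsing: a plain max() over the existing version directories is enough.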
Source code in kedro_datasets/tensorflow/tensorflow_model_dataset.py
_describe
_describe()
_exists
_exists()
_invalidate_cache
_invalidate_cache()
Invalidate underlying filesystem caches.
_release
_release()
load
load()
save
save(data)