Skip to content

MatplotlibDataset

MatplotlibDataset is used to load and save matplotlib figures.

kedro_datasets.matplotlib.MatplotlibDataset

MatplotlibDataset(
    *,
    filepath,
    fs_args=None,
    credentials=None,
    save_args=None,
    version=None,
    overwrite=False,
    metadata=None
)

Bases: AbstractVersionedDataset[Figure | list[Figure] | dict[str, Figure], NoReturn]

MatplotlibDataset saves one or more Matplotlib objects as image files to an underlying filesystem (e.g. local, S3, GCS).

Examples:

Using the YAML API:

output_plot:
  type: matplotlib.MatplotlibDataset
  filepath: data/08_reporting/output_plot.png
  save_args:
    format: png

Using the Python API:

>>> import matplotlib.pyplot as plt
>>> from kedro_datasets.matplotlib import MatplotlibDataset
>>>
>>> fig = plt.figure()
>>> plt.plot([1, 2, 3])
[<matplotlib.lines.Line2D object at 0x...>]
>>> plot_dataset = MatplotlibDataset(filepath=tmp_path / "data/08_reporting/output_plot.png")
>>> plt.close()
>>> plot_dataset.save(fig)

Saving a plot as a PDF file:

>>> import matplotlib.pyplot as plt
>>> from kedro_datasets.matplotlib import MatplotlibDataset
>>>
>>> fig = plt.figure()
>>> plt.plot([1, 2, 3])
[<matplotlib.lines.Line2D object at 0x...>]
>>> pdf_plot_dataset = MatplotlibDataset(
...     filepath=tmp_path / "data/08_reporting/output_plot.pdf", save_args={"format": "pdf"}
... )
>>> plt.close()
>>> pdf_plot_dataset.save(fig)

Saving multiple plots in a folder, using a dictionary:

>>> import matplotlib.pyplot as plt
>>> from kedro_datasets.matplotlib import MatplotlibDataset
>>>
>>> plots_dict = {}
>>> for colour in ["blue", "green", "red"]:
...     plots_dict[f"{colour}.png"] = plt.figure()
...     plt.plot([1, 2, 3], color=colour)
...
[<matplotlib.lines.Line2D object at 0x...>]
[<matplotlib.lines.Line2D object at 0x...>]
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.close("all")
>>> dict_plot_dataset = MatplotlibDataset(filepath=tmp_path / "data/08_reporting/plots")
>>> dict_plot_dataset.save(plots_dict)

Parameters:

  • filepath (str | PathLike) –

    Filepath in POSIX format to save Matplotlib objects to, prefixed with a protocol like s3://. If prefix is not provided, file protocol (local filesystem) will be used. The prefix should be any protocol supported by fsspec.

  • fs_args (dict[str, Any] | None, default: None ) –

    Extra arguments to pass into underlying filesystem class constructor (e.g. {"project": "my-project"} for GCSFileSystem), as well as to pass to the filesystem's open method through nested key open_args_save. Here you can find all available arguments for open: https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open All defaults are preserved, except mode, which is set to wb when saving.

  • credentials (dict[str, Any] | None, default: None ) –

    Credentials required to get access to the underlying filesystem. E.g. for S3FileSystem it should look like: {'key': '<id>', 'secret': '<key>'}

  • save_args (dict[str, Any] | None, default: None ) –

    Save args passed to plt.savefig. See https://matplotlib.org/api/_as_gen/matplotlib.pyplot.savefig.html

  • version (Version | None, default: None ) –

    If specified, should be an instance of kedro.io.core.Version. If its load attribute is None, the latest version will be loaded. If its save attribute is None, save version will be autogenerated.

  • overwrite (bool, default: False ) –

    If True, any existing image files will be removed. Only relevant when saving multiple Matplotlib objects at once.

  • metadata (dict[str, Any] | None, default: None ) –

    Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins.

Source code in kedro_datasets/matplotlib/matplotlib_dataset.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def __init__(  # noqa: PLR0913
    self,
    *,
    filepath: str | os.PathLike,
    fs_args: dict[str, Any] | None = None,
    credentials: dict[str, Any] | None = None,
    save_args: dict[str, Any] | None = None,
    version: Version | None = None,
    overwrite: bool = False,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Creates a new instance of ``MatplotlibDataset``.

    Args:
        filepath: Filepath in POSIX format to save Matplotlib objects to, prefixed with a
            protocol like `s3://`. If prefix is not provided, `file` protocol (local filesystem)
            will be used. The prefix should be any protocol supported by ``fsspec``.
        fs_args: Extra arguments to pass into underlying filesystem class constructor
            (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as
            to pass to the filesystem's `open` method through nested key `open_args_save`.
            Here you can find all available arguments for `open`:
            https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
            All defaults are preserved, except `mode`, which is set to `wb` when saving.
        credentials: Credentials required to get access to the underlying filesystem.
            E.g. for ``S3FileSystem`` it should look like:
            `{'key': '<id>', 'secret': '<key>'}`
        save_args: Save args passed to the figure's ``savefig`` method (the same
            options as `plt.savefig`). They are merged over ``DEFAULT_SAVE_ARGS``. See
            https://matplotlib.org/api/_as_gen/matplotlib.pyplot.savefig.html
        version: If specified, should be an instance of
            ``kedro.io.core.Version``. If its ``load`` attribute is
            None, the latest version will be loaded. If its ``save``
            attribute is None, save version will be autogenerated.
        overwrite: If True, any existing image files will be removed.
            Only relevant when saving multiple Matplotlib objects at
            once. When ``version`` is set, this flag is ignored: a warning
            is emitted and it is reset to False.
        metadata: Any arbitrary metadata.
            This is ignored by Kedro, but may be consumed by users or external plugins.
    """
    # Copy the caller's dicts: ``pop``/``setdefault`` below would otherwise
    # mutate the arguments that were passed in.
    _credentials = deepcopy(credentials) or {}
    _fs_args = deepcopy(fs_args) or {}
    # ``open_args_save`` is meant for ``fs.open``, not the filesystem
    # constructor, so it is split out of the constructor kwargs.
    _fs_open_args_save = _fs_args.pop("open_args_save", {})
    _fs_open_args_save.setdefault("mode", "wb")

    protocol, path = get_protocol_and_path(filepath, version)
    if protocol == "file":
        # fsspec LocalFileSystem option to create missing parent directories.
        _fs_args.setdefault("auto_mkdir", True)

    self._protocol = protocol
    self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)

    self.metadata = metadata

    super().__init__(
        filepath=PurePosixPath(path),
        version=version,
        exists_function=self._fs.exists,
        glob_function=self._fs.glob,
    )

    self._fs_open_args_save = _fs_open_args_save

    # Handle default save arguments
    self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}

    if overwrite and version is not None:
        warn(
            "Setting 'overwrite=True' is ineffective if versioning "
            "is enabled, since the versioned path must not already "
            "exist; overriding flag with 'overwrite=False' instead."
        )
        overwrite = False
    self._overwrite = overwrite

DEFAULT_SAVE_ARGS class-attribute instance-attribute

DEFAULT_SAVE_ARGS = {}

_fs instance-attribute

_fs = filesystem(_protocol, **_credentials, **_fs_args)

_fs_open_args_save instance-attribute

_fs_open_args_save = _fs_open_args_save

_overwrite instance-attribute

_overwrite = overwrite

_protocol instance-attribute

_protocol = protocol

_save_args instance-attribute

_save_args = {**DEFAULT_SAVE_ARGS, **(save_args or {})}

metadata instance-attribute

metadata = metadata

_describe

_describe()
Source code in kedro_datasets/matplotlib/matplotlib_dataset.py
165
166
167
168
169
170
171
def _describe(self) -> dict[str, Any]:
    """Summarise the dataset's configuration as a plain dictionary.

    The mapping exposes the target filepath, the fsspec protocol, the
    ``savefig`` arguments and the version, in that order.
    """
    summary: dict[str, Any] = dict(
        filepath=self._filepath,
        protocol=self._protocol,
        save_args=self._save_args,
        version=self._version,
    )
    return summary

_exists

_exists()
Source code in kedro_datasets/matplotlib/matplotlib_dataset.py
217
218
219
def _exists(self) -> bool:
    """Report whether the underlying filesystem has a file at the load path."""
    resolved_path = get_filepath_str(
        self._get_load_path(), self._protocol
    )
    return self._fs.exists(resolved_path)

_invalidate_cache

_invalidate_cache()

Invalidate underlying filesystem caches.

Source code in kedro_datasets/matplotlib/matplotlib_dataset.py
225
226
227
228
def _invalidate_cache(self) -> None:
    """Ask the fsspec filesystem to drop cached entries for the base filepath."""
    self._fs.invalidate_cache(
        get_filepath_str(self._filepath, self._protocol)
    )

_release

_release()
Source code in kedro_datasets/matplotlib/matplotlib_dataset.py
221
222
223
def _release(self) -> None:
    """Run the base-class release step, then invalidate filesystem caches."""
    super()._release()
    # Order matters: the parent release runs first, cache invalidation second.
    self._invalidate_cache()

_save_to_fs

_save_to_fs(full_key_path, plot)
Source code in kedro_datasets/matplotlib/matplotlib_dataset.py
210
211
212
213
214
215
def _save_to_fs(self, full_key_path: str, plot: Figure):
    """Render ``plot`` in memory and write the resulting bytes to ``full_key_path``.

    Args:
        full_key_path: Protocol-aware path of the file to write.
        plot: The Matplotlib figure to render.

    When ``save_args`` does not specify ``format``, the format is taken from
    the extension of ``full_key_path`` if the figure's canvas supports it.
    Otherwise Matplotlib's default (``rcParams["savefig.format"]``) is used.
    """
    save_args = dict(self._save_args)
    if "format" not in save_args:
        # ``savefig`` only infers the format from a *filename*. Rendering
        # into a BytesIO would otherwise always use the default format
        # (PNG), so e.g. ``plot.pdf`` would silently contain PNG bytes.
        suffix = PurePosixPath(full_key_path).suffix.lstrip(".").lower()
        if suffix and suffix in plot.canvas.get_supported_filetypes():
            save_args["format"] = suffix

    bytes_buffer = io.BytesIO()
    plot.savefig(bytes_buffer, **save_args)

    with self._fs.open(full_key_path, **self._fs_open_args_save) as fs_file:
        fs_file.write(bytes_buffer.getvalue())

load

load()

Loading is not supported for MatplotlibDataset.

Raises:

  • DatasetError

    Always, because loading is not supported for this dataset.

Returns:

  • NoReturn

    Never returns as it always raises an exception.

Source code in kedro_datasets/matplotlib/matplotlib_dataset.py
173
174
175
176
177
178
179
180
181
182
183
def load(self) -> NoReturn:
    """
    Loading is not supported for MatplotlibDataset.

    This dataset can only save figures; every call to this method raises.

    Raises:
        DatasetError: Always, on every call.

    Returns:
        Never returns as it always raises an exception.
    """
    raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'")

preview

preview()

Generates a preview of the matplotlib dataset as a base64 encoded image.

Returns:

  • ImagePreview

    The image stored at the load path, base64-encoded and wrapped in an ImagePreview.

Source code in kedro_datasets/matplotlib/matplotlib_dataset.py
230
231
232
233
234
235
236
237
238
239
240
def preview(self) -> ImagePreview:
    """
    Generates a preview of the matplotlib dataset as a base64 encoded image.

    The file at the current load path is read in binary mode. Its bytes are
    base64-encoded, decoded to a UTF-8 string and wrapped in ``ImagePreview``.

    NOTE(review): if the dataset was saved from a list or dict of figures,
    the load path is presumably a directory and opening it will fail --
    confirm whether preview should support that case.

    Returns:
        ImagePreview: The base64-encoded image stored at the load path.
    """
    load_path = get_filepath_str(self._get_load_path(), self._protocol)
    with self._fs.open(load_path, mode="rb") as img_file:
        base64_bytes = base64.b64encode(img_file.read())
    return ImagePreview(base64_bytes.decode("utf-8"))

save

save(data)
Source code in kedro_datasets/matplotlib/matplotlib_dataset.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
def save(self, data: Figure | (list[Figure] | dict[str, Figure])) -> None:
    """Save one or more Matplotlib figures under the dataset's save path.

    Args:
        data: Any of the following:

            * A single figure, written directly to the save path.
            * A list of figures, written as ``0.png``, ``1.png``, ... inside
              the save path.
            * A dict mapping file names to figures, each written to
              ``<save path>/<name>``.

    When ``overwrite`` is enabled and ``data`` is a list or dict, any
    existing output at the save path is removed recursively first.
    After saving, every open pyplot figure is closed via
    ``plt.close("all")``, not only the figures passed in ``data``.
    """
    save_path = self._get_save_path()

    if isinstance(data, dict):
        # Validate every name before any delete or write. Then an invalid key
        # cannot leave a partially written output directory, or, with
        # ``overwrite=True``, one whose previous contents were already removed.
        for plot_name in data:
            validate_sub_path(plot_name, str(save_path))

    if isinstance(data, list | dict) and self._overwrite and self._exists():
        self._fs.rm(get_filepath_str(save_path, self._protocol), recursive=True)

    if isinstance(data, list):
        for index, plot in enumerate(data):
            full_key_path = get_filepath_str(
                save_path / f"{index}.png", self._protocol
            )
            self._save_to_fs(full_key_path=full_key_path, plot=plot)
    elif isinstance(data, dict):
        for plot_name, plot in data.items():
            full_key_path = get_filepath_str(save_path / plot_name, self._protocol)
            self._save_to_fs(full_key_path=full_key_path, plot=plot)
    else:
        full_key_path = get_filepath_str(save_path, self._protocol)
        self._save_to_fs(full_key_path=full_key_path, plot=data)

    plt.close("all")

    self._invalidate_cache()