Skip to content

StudyDataset

StudyDataset loads and saves data from/to an Optuna study.

kedro_datasets_experimental.optuna.StudyDataset

StudyDataset(
    *,
    backend,
    database,
    study_name,
    load_args=None,
    version=None,
    credentials=None,
    metadata=None
)

Bases: AbstractVersionedDataset[Study, Study]

StudyDataset loads/saves data from/to an optuna Study.

Examples:

Using the YAML API:

review_prediction_study:
  type: kedro_datasets_experimental.optuna.StudyDataset
  backend: sqlite
  database: data/05_model_input/review_prediction_study.db
  study_name: review_prediction_study
  load_args:
    sampler:
      class: TPESampler
      n_startup_trials: 10
      n_ei_candidates: 5
    pruner:
      class: NopPruner
  versioned: true

price_prediction_study:
  type: kedro_datasets_experimental.optuna.StudyDataset
  backend: postgresql
  database: optuna_db
  study_name: price_prediction_study
  credentials: dev_optuna_postgresql

Using the Python API:

>>> from kedro_datasets_experimental.optuna import StudyDataset
>>> from optuna.distributions import FloatDistribution
>>> import optuna
>>>
>>> study = optuna.create_study()
>>> trial = optuna.trial.create_trial(
...     params={"x": 2.0},
...     distributions={"x": FloatDistribution(0, 10)},
...     value=4.0,
... )
>>> study.add_trial(trial)
>>>
>>> dataset = StudyDataset(backend="sqlite", database="optuna.db", study_name="my_study")
>>> dataset.save(study)
>>> reloaded = dataset.load()
>>> assert len(reloaded.trials) == 1
>>> assert reloaded.trials[0].params["x"] == 2.0

Parameters:

  • backend (str) –

    Name of the database backend. This name should correspond to an SQLAlchemy dialect (for example sqlite or postgresql).

  • database (str) –

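    Name of the database. For the sqlite backend, this is the path to the database file, which must have a file extension (or be :memory:).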

  • study_name (str) –

    Name of the optuna Study.

  • load_args (dict[str, Any] | None, default: None ) –

    Optuna options for loading studies. Accepts a sampler and a pruner. If either is provided, it must name a matching Optuna sampler (respectively, pruner) class under a class key, optionally together with that class's arguments. Here you can find all available samplers and pruners and their arguments: - https://optuna.readthedocs.io/en/stable/reference/samplers/index.html - https://optuna.readthedocs.io/en/stable/reference/pruners.html All defaults are preserved.

  • version (Version, default: None ) –

    If specified, should be an instance of kedro.io.core.Version. If its load attribute is None, the latest version will be loaded. If its save attribute is None, save version will be autogenerated.

  • credentials (dict[str, Any] | None, default: None ) –

    Credentials required to access the underlying RDB. They can include username, password, host, and port. They are ignored for the sqlite backend.

  • metadata (dict[str, Any] | None, default: None ) –

    Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins.

Source code in kedro_datasets_experimental/optuna/study_dataset.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def __init__(  # noqa: PLR0913
    self,
    *,
    backend: str,
    database: str,
    study_name: str,
    load_args: dict[str, Any] | None = None,
    version: Version = None,
    credentials: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Point a new ``StudyDataset`` at one optuna Study in a relational database.

    Args:
        backend: SQLAlchemy dialect name of the database backend (for example
            ``"sqlite"`` or ``"postgresql"``).
        database: Name of the database. For the ``sqlite`` backend this is the
            path of the database file, which must carry a file extension
            (``":memory:"`` is also accepted).
        study_name: Name of the optuna Study.
        load_args: Optuna options applied when loading the study. Accepts a
            ``sampler`` and a ``pruner`` entry. Each entry must name the optuna
            class under a ``class`` key and may add that class's keyword
            arguments. Available samplers, pruners and their arguments:
            - https://optuna.readthedocs.io/en/stable/reference/samplers/index.html
            - https://optuna.readthedocs.io/en/stable/reference/pruners.html
            Entries that are not given default to ``None``.
        version: If specified, an instance of ``kedro.io.core.Version``. If its
            ``load`` attribute is None, the latest version is loaded; if its
            ``save`` attribute is None, the save version is autogenerated.
        credentials: Credentials for the underlying RDB: any of ``username``,
            ``password``, ``host`` and ``port``. Ignored for ``sqlite``.
        metadata: Arbitrary metadata. Kedro ignores it, but users or external
            plugins may consume it.

    Raises:
        ValueError: If ``backend``, ``database``, ``study_name`` or
            ``credentials`` fail validation.
    """
    # Each validator raises on bad input, so the checks run in this order.
    self._backend = self._validate_backend(backend=backend)
    self._database = self._validate_database(backend=backend, database=database)
    self._study_name = self._validate_study_name(study_name=study_name)

    url_credentials = self._validate_credentials(
        backend=backend, credentials=credentials
    )
    self._storage_url = URL.create(
        drivername=backend, database=database, **url_credentials
    )
    self.metadata = metadata

    # Only the sqlite backend is backed by a local file; others get no filepath.
    sqlite_file = (
        PurePosixPath(os.path.realpath(database)) if backend == "sqlite" else None
    )

    super().__init__(
        filepath=sqlite_file,
        version=version,
        exists_function=self._study_name_exists,
        glob_function=self._study_name_glob,
    )

    # User-supplied load args override the defaults key by key.
    self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}

DEFAULT_LOAD_ARGS class-attribute instance-attribute

DEFAULT_LOAD_ARGS = {'sampler': None, 'pruner': None}

_backend instance-attribute

_backend = _validate_backend(backend=backend)

_database instance-attribute

_database = _validate_database(
    backend=backend, database=database
)

_load_args instance-attribute

_load_args = {
    **DEFAULT_LOAD_ARGS,
    **(load_args or {}),
}

_storage_url instance-attribute

_storage_url = storage_url

_study_name instance-attribute

_study_name = _validate_study_name(study_name=study_name)

metadata instance-attribute

metadata = metadata

_describe

_describe()
Source code in kedro_datasets_experimental/optuna/study_dataset.py
224
225
226
227
228
229
230
231
def _describe(self) -> dict[str, Any]:
    """Return the configuration values that describe this dataset."""
    described = dict(
        backend=self._backend,
        database=self._database,
        study_name=self._study_name,
        load_args=self._load_args,
        version=self._version,
    )
    return described

_exists

_exists()
Source code in kedro_datasets_experimental/optuna/study_dataset.py
342
343
344
345
346
347
348
def _exists(self) -> bool:
    """Report whether the study that ``load`` would read is present.

    A ``DatasetError`` raised while resolving the load study name is treated
    as "does not exist" and yields ``False`` instead of propagating.
    """
    try:
        name_to_check = self._get_load_study_name()
    except DatasetError:
        return False
    else:
        return self._study_name_exists(name_to_check)

_get_load_path

_get_load_path()
Source code in kedro_datasets_experimental/optuna/study_dataset.py
192
193
194
def _get_load_path(self) -> PurePosixPath:
    """Return the dataset's file path unchanged.

    Versioning is applied to the study name rather than to the file, so the
    load path never carries a version component.
    """
    unversioned_path = self._filepath
    return unversioned_path

_get_load_study_name

_get_load_study_name()
Source code in kedro_datasets_experimental/optuna/study_dataset.py
196
197
198
199
200
201
202
def _get_load_study_name(self) -> str:
    """Return the name of the study that ``load`` should read.

    With versioning enabled, this is the versioned name built from the
    resolved load version. Otherwise it is the configured study name.
    """
    if self._version:
        return str(self._get_versioned_path(self.resolve_load_version()))
    # Versioning disabled: read the study under its configured name.
    return self._study_name

_get_pruner

_get_pruner(pruner_config)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
def _get_pruner(self, pruner_config):
    """Instantiate an optuna pruner from a ``load_args["pruner"]`` config.

    Args:
        pruner_config: ``None`` or a mapping whose ``class`` entry names a class
            in ``optuna.pruners``. Every other entry is passed to that class as
            a keyword argument. For ``PatientPruner`` the ``wrapped_pruner``
            entry is itself a pruner config and is built recursively.

    Returns:
        The pruner instance, or ``None`` when ``pruner_config`` is ``None``.

    Raises:
        ValueError: If a pruner config has no ``class`` entry.
        AttributeError: If ``class`` does not name an ``optuna.pruners`` class.
        KeyError: If a ``PatientPruner`` config has no ``wrapped_pruner`` entry.
    """
    if pruner_config is None:
        return None

    if "class" not in pruner_config:
        raise ValueError(
            "Optuna `pruner` 'class' should be specified when trying to load study "
            f"named '{self._study_name}' with a `pruner`."
        )

    # Work on a shallow copy so the caller's config is never mutated by the
    # ``pop`` calls below.
    kwargs = dict(pruner_config)
    pruner_class_name = kwargs.pop("class")
    if pruner_class_name == "PatientPruner":
        kwargs["wrapped_pruner"] = self._get_pruner(kwargs.pop("wrapped_pruner"))

    pruner_class = getattr(optuna.pruners, pruner_class_name)

    return pruner_class(**kwargs)

_get_sampler

_get_sampler(sampler_config)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
def _get_sampler(self, sampler_config):
    """Instantiate an optuna sampler from a ``load_args["sampler"]`` config.

    Args:
        sampler_config: ``None`` or a mapping whose ``class`` entry names a class
            in ``optuna.samplers``. Every other entry is passed to that class as
            a keyword argument. Nested sampler configs are built recursively:
            ``independent_sampler`` for ``QMCSampler``, ``CmaEsSampler`` and
            ``GPSampler`` (optional, defaulting to ``None``), and
            ``base_sampler`` for ``PartialFixedSampler`` (required).

    Returns:
        The sampler instance, or ``None`` when ``sampler_config`` is ``None``.

    Raises:
        ValueError: If a sampler config has no ``class`` entry.
        AttributeError: If ``class`` does not name an ``optuna.samplers`` class.
        KeyError: If a ``PartialFixedSampler`` config has no ``base_sampler``.
    """
    if sampler_config is None:
        return None

    if "class" not in sampler_config:
        raise ValueError(
            "Optuna `sampler` 'class' should be specified when trying to load study "
            f"named '{self._study_name}' with a `sampler`."
        )

    # Work on a shallow copy so the caller's config is never mutated by the
    # ``pop`` calls below.
    kwargs = dict(sampler_config)
    sampler_class_name = kwargs.pop("class")
    if sampler_class_name in ["QMCSampler", "CmaEsSampler", "GPSampler"]:
        # ``independent_sampler`` is optional for these samplers (optuna
        # defaults it to ``None``), so a missing entry must not raise KeyError.
        kwargs["independent_sampler"] = self._get_sampler(
            kwargs.pop("independent_sampler", None)
        )

    if sampler_class_name == "PartialFixedSampler":
        kwargs["base_sampler"] = self._get_sampler(kwargs.pop("base_sampler"))

    sampler_class = getattr(optuna.samplers, sampler_class_name)

    return sampler_class(**kwargs)

_get_save_path

_get_save_path()
Source code in kedro_datasets_experimental/optuna/study_dataset.py
204
205
206
def _get_save_path(self) -> PurePosixPath:
    """Return the dataset's file path unchanged.

    Saved versions are distinguished by study name, not by file, so the
    save path is the same for every version.
    """
    unversioned_path = self._filepath
    return unversioned_path

_get_save_study_name

_get_save_study_name()
Source code in kedro_datasets_experimental/optuna/study_dataset.py
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def _get_save_study_name(self) -> str:
    """Return the name under which ``save`` should write the study.

    Without versioning this is the configured study name. With versioning it
    is the versioned name for the resolved save version.

    Raises:
        DatasetError: If versioning is enabled and the versioned study name
            already exists.
    """
    if not self._version:
        return self._study_name

    candidate = str(self._get_versioned_path(self.resolve_save_version()))
    # A versioned save must never overwrite an existing version.
    if not self._exists_function(candidate):
        return candidate
    raise DatasetError(
        f"Study name '{candidate}' for {self!s} must not exist if "
        f"versioning is enabled."
    )

_get_versioned_path

_get_versioned_path(version)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
180
181
182
def _get_versioned_path(self, version: str) -> PurePosixPath:
    """Build the versioned study name ``<study_name>/<version>/<study_name>``."""
    base_name = self._study_name
    return PurePosixPath(base_name, version, base_name)

_study_name_exists

_study_name_exists(study_name)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
325
326
327
328
329
330
331
332
def _study_name_exists(self, study_name) -> bool:
    """Return whether ``study_name`` is one of the studies in the storage.

    For sqlite, a missing database file means no study can exist, so
    ``False`` is returned without opening the storage.
    """
    missing_sqlite_file = self._backend == "sqlite" and not os.path.isfile(
        self._database
    )
    if missing_sqlite_file:
        return False

    url = self._storage_url.render_as_string(hide_password=False)
    known_names = optuna.study.get_all_study_names(
        storage=optuna.storages.RDBStorage(url=url)
    )
    return study_name in known_names

_study_name_glob

_study_name_glob(pattern)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
334
335
336
337
338
339
340
def _study_name_glob(self, pattern):
    """Yield every study name in the storage that matches ``pattern``.

    Matching follows :func:`fnmatch.fnmatch`. Because this is a generator,
    the storage is only opened once iteration begins.
    """
    url = self._storage_url.render_as_string(hide_password=False)
    all_names = optuna.study.get_all_study_names(
        storage=optuna.storages.RDBStorage(url=url)
    )
    yield from (name for name in all_names if fnmatch.fnmatch(name, pattern))

_validate_backend

_validate_backend(backend)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
139
140
141
142
143
144
145
def _validate_backend(self, backend):
    """Check that ``backend`` is an accepted SQLAlchemy dialect name and return it.

    Accepted names are the keys of ``registry.impls`` plus ``mssql``,
    ``mysql``, ``oracle``, ``postgresql`` and ``sqlite``.

    Raises:
        ValueError: If ``backend`` is not an accepted name.
    """
    builtin_dialects = ("mssql", "mysql", "oracle", "postgresql", "sqlite")
    # A tuple (not a set) keeps equality-based membership, as with the list.
    accepted = (*registry.impls.keys(), *builtin_dialects)
    if backend in accepted:
        return backend
    raise ValueError(
        f"Requested `backend` '{backend}' is not registered as an SQLAlchemy dialect."
    )

_validate_credentials

_validate_credentials(backend, credentials)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
167
168
169
170
171
172
173
174
175
176
177
178
def _validate_credentials(self, backend, credentials):
    """Return a private deep copy of ``credentials`` to pass to ``URL.create``.

    ``{}`` is returned when the backend is ``sqlite`` or when no credentials
    are given.

    Raises:
        ValueError: If ``credentials`` contains keys other than ``username``,
            ``password``, ``host`` and ``port``.
    """
    if credentials is None or backend == "sqlite":
        return {}

    allowed_keys = {"username", "password", "host", "port"}
    if set(credentials.keys()).issubset(allowed_keys):
        return deepcopy(credentials)
    raise ValueError(
        "Incorrect `credentials`. Provided `credentials` should contain "
        "`'username'`, `'password'`, `'host'`, and/or `'port'`. It contains "
        f"{set(credentials.keys())}."
    )

_validate_database

_validate_database(backend, database)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def _validate_database(self, backend, database):
    """Check ``database`` and return it unchanged.

    ``database`` must be a string. For the ``sqlite`` backend it must also
    have a file extension, unless it is ``":memory:"``.

    Raises:
        ValueError: If ``database`` is not a string, or is an extension-less
            sqlite file name.
    """
    if not isinstance(database, str):
        raise ValueError(f"`database` '{database}' is not a string.")

    needs_extension = backend == "sqlite" and database != ":memory:"
    if needs_extension and not os.path.splitext(database)[1]:
        raise ValueError(f"The sqlite file `database` '{database}' does not have an extension.")

    return database

_validate_study_name

_validate_study_name(study_name)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
162
163
164
165
def _validate_study_name(self, study_name):
    """Return ``study_name`` unchanged after checking that it is a ``str``.

    Raises:
        ValueError: If ``study_name`` is not a string.
    """
    if isinstance(study_name, str):
        return study_name
    raise ValueError(f"`study_name` '{study_name}' is not a string.")

load

load()
Source code in kedro_datasets_experimental/optuna/study_dataset.py
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
def load(self) -> optuna.Study:
    """Load the study from the configured storage.

    The sampler and pruner are built from a deep copy of ``load_args`` on
    every call and attached to the loaded study. The study name comes from
    ``_get_load_study_name``, so versioned datasets read the resolved version.
    """
    config = deepcopy(self._load_args)
    sampler = self._get_sampler(config.pop("sampler"))
    pruner = self._get_pruner(config.pop("pruner"))

    rdb_storage = optuna.storages.RDBStorage(
        url=self._storage_url.render_as_string(hide_password=False)
    )
    return optuna.load_study(
        storage=rdb_storage,
        study_name=self._get_load_study_name(),
        sampler=sampler,
        pruner=pruner,
    )

resolve_load_version

resolve_load_version()

Compute the version the dataset should be loaded with.

Source code in kedro_datasets_experimental/optuna/study_dataset.py
184
185
186
187
188
189
190
def resolve_load_version(self) -> str | None:
    """Compute the version the dataset should be loaded with."""
    if not self._version:
        return None
    if self._version.load:
        return self._version.load
    return self._fetch_latest_load_version()

save

save(study)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
def save(self, study: optuna.Study) -> None:
    """Write ``study`` into the configured storage under the save study name.

    Any study already stored under the save study name is deleted first, then
    ``study`` is copied there with ``optuna.copy_study``.

    Args:
        study: The optuna study to persist.

    Raises:
        DatasetError: Propagated from ``_get_save_study_name`` when versioning
            is enabled and the versioned study name already exists.
    """
    save_study_name = self._get_save_study_name()

    storage_url_str = self._storage_url.render_as_string(hide_password=False)
    if self._backend == "sqlite":
        # The sqlite file's parent directory must exist before connecting.
        # No placeholder study is created: opening RDBStorage below creates
        # the database file and optuna's tables, and a placeholder would leave
        # an extra auto-named study behind in the database.
        os.makedirs(os.path.dirname(self._filepath), exist_ok=True)

    storage = optuna.storages.RDBStorage(url=storage_url_str)

    source_study_name = study.study_name
    source_storage = study._storage

    # To overwrite an existing study, we need to first delete it if it exists.
    if self._study_name_exists(save_study_name):
        if source_study_name == save_study_name:
            # ``study`` may itself be the study about to be deleted (same name,
            # same database). Snapshot it in memory first so the delete below
            # cannot destroy the data being saved.
            snapshot_storage = optuna.storages.InMemoryStorage()
            optuna.copy_study(
                from_study_name=source_study_name,
                from_storage=source_storage,
                to_storage=snapshot_storage,
                to_study_name=source_study_name,
            )
            source_storage = snapshot_storage
        optuna.delete_study(
            storage=storage,
            study_name=save_study_name,
        )

    optuna.copy_study(
        from_study_name=source_study_name,
        from_storage=source_storage,
        to_storage=storage,
        to_study_name=save_study_name,
    )