StudyDataset(
*,
backend,
database,
study_name,
load_args=None,
version=None,
credentials=None,
metadata=None
)
Bases: AbstractVersionedDataset[Study, Study]
StudyDataset loads/saves data from/to an optuna Study.
Examples:
Using the YAML API:
review_prediction_study:
type: kedro_datasets_experimental.optuna.StudyDataset
backend: sqlite
database: data/05_model_input/review_prediction_study.db
study_name: review_prediction_study
load_args:
sampler:
class: TPESampler
n_startup_trials: 10
n_ei_candidates: 5
pruner:
class: NopPruner
versioned: true
price_prediction_study:
type: kedro_datasets_experimental.optuna.StudyDataset
backend: postgresql
database: optuna_db
study_name: price_prediction_study
credentials: dev_optuna_postgresql
Using the Python API:
>>> from kedro_datasets_experimental.optuna import StudyDataset
>>> from optuna.distributions import FloatDistribution
>>> import optuna
>>>
>>> study = optuna.create_study()
>>> trial = optuna.trial.create_trial(
... params={"x": 2.0},
... distributions={"x": FloatDistribution(0, 10)},
... value=4.0,
... )
>>> study.add_trial(trial)
>>>
>>> dataset = StudyDataset(backend="sqlite", database="optuna.db", study_name="optuna_study")
>>> dataset.save(study)
>>> reloaded = dataset.load()
>>> assert len(reloaded.trials) == 1
>>> assert reloaded.trials[0].params["x"] == 2.0
Parameters:
-
backend
(str)
–
Name of the database backend. This name should correspond to a dialect
registered in SQLAlchemy (e.g. sqlite, postgresql, mysql).
-
database
(str)
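–
Name of the database. For the sqlite backend, this is the path to the
database file, which must have a file extension unless it is ":memory:".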
-
study_name
(str)
–
Name of the optuna Study.
-
load_args
(dict[str, Any] | None, default:
None
)
–
Optuna options for loading studies. Accepts a sampler and a
pruner. If either is provided, a class matching any Optuna sampler,
respectively pruner, class name should be provided, optionally with
their arguments. Here you can find all available samplers and pruners
and their arguments:
- https://optuna.readthedocs.io/en/stable/reference/samplers/index.html
- https://optuna.readthedocs.io/en/stable/reference/pruners.html
All defaults are preserved.
-
version
(Version, default:
None
)
–
If specified, should be an instance of
kedro.io.core.Version. If its load attribute is
None, the latest version will be loaded. If its save
attribute is None, save version will be autogenerated.
-
credentials
(dict[str, Any] | None, default:
None
)
–
Credentials required to get access to the underlying RDB.
They can include username, password, host, and port.
-
metadata
(dict[str, Any] | None, default:
None
)
–
Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
Source code in kedro_datasets_experimental/optuna/study_dataset.py
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
def __init__(  # noqa: PLR0913
    self,
    *,
    backend: str,
    database: str,
    study_name: str,
    load_args: dict[str, Any] | None = None,
    version: Version | None = None,
    credentials: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Creates a new instance of ``StudyDataset`` pointing to a concrete optuna
    Study on a specific relational database.

    Args:
        backend: Name of the database backend. This name should correspond to a
            dialect registered in ``SQLAlchemy`` (e.g. ``sqlite``, ``postgresql``).
        database: Name of the database. For the ``sqlite`` backend this is the
            path to the database file, which must have a file extension unless
            it is ``":memory:"``.
        study_name: Name of the optuna Study.
        load_args: Optuna options for loading studies. Accepts a `sampler` and a
            `pruner`. If either is provided, a `class` matching any Optuna
            `sampler`, respectively `pruner`, class name should be provided,
            optionally with their arguments. Here you can find all available
            samplers and pruners and their arguments:
            - https://optuna.readthedocs.io/en/stable/reference/samplers/index.html
            - https://optuna.readthedocs.io/en/stable/reference/pruners.html
            All defaults are preserved.
        version: If specified, should be an instance of
            ``kedro.io.core.Version``. If its ``load`` attribute is
            None, the latest version will be loaded. If its ``save``
            attribute is None, save version will be autogenerated.
        credentials: Credentials required to get access to the underlying RDB.
            They can include `username`, `password`, `host`, and `port`.
            They are ignored for the ``sqlite`` backend.
        metadata: Any arbitrary metadata.
            This is ignored by Kedro, but may be consumed by users or external plugins.

    Raises:
        ValueError: If ``backend``, ``database``, ``study_name`` or
            ``credentials`` fail validation.
    """
    self._backend = self._validate_backend(backend=backend)
    self._database = self._validate_database(backend=backend, database=database)
    self._study_name = self._validate_study_name(study_name=study_name)
    credentials = self._validate_credentials(backend=backend, credentials=credentials)
    # The validated credential keys (username/password/host/port) are passed
    # straight through as ``URL.create`` keyword arguments.
    self._storage_url = URL.create(
        drivername=backend,
        database=database,
        **credentials,
    )
    self.metadata = metadata

    # Only a sqlite database is a local file; other backends have no filepath.
    filepath = None
    if backend == "sqlite":
        filepath = PurePosixPath(os.path.realpath(database))

    super().__init__(
        filepath=filepath,
        version=version,
        # Versions are encoded in study names inside the database, so existence
        # checks and globbing query the storage rather than the filesystem.
        exists_function=self._study_name_exists,
        glob_function=self._study_name_glob,
    )

    # Merge the user-provided load arguments over the defaults.
    self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
|
DEFAULT_LOAD_ARGS
class-attribute
instance-attribute
DEFAULT_LOAD_ARGS = {'sampler': None, 'pruner': None}
_backend
instance-attribute
_backend = _validate_backend(backend=backend)
_database
instance-attribute
_database = _validate_database(
backend=backend, database=database
)
_load_args
instance-attribute
_load_args = {**DEFAULT_LOAD_ARGS, **(load_args or {})}
_storage_url
instance-attribute
_storage_url = storage_url
_study_name
instance-attribute
_study_name = _validate_study_name(study_name=study_name)
_describe
Source code in kedro_datasets_experimental/optuna/study_dataset.py
224
225
226
227
228
229
230
def _describe(self) -> dict[str, Any]:
    """Return a dictionary describing this dataset's configuration."""
    description: dict[str, Any] = dict(
        backend=self._backend,
        database=self._database,
        study_name=self._study_name,
        load_args=self._load_args,
        version=self._version,
    )
    return description
|
_exists
Source code in kedro_datasets_experimental/optuna/study_dataset.py
342
343
344
345
346
347
def _exists(self) -> bool:
    """Return whether the study that ``load`` would read exists in storage.

    If the load study name cannot be resolved (a ``DatasetError`` is raised),
    the dataset is reported as not existing.
    """
    try:
        target_name = self._get_load_study_name()
    except DatasetError:
        return False
    else:
        return self._study_name_exists(target_name)
|
_get_load_path
Source code in kedro_datasets_experimental/optuna/study_dataset.py
def _get_load_path(self) -> PurePosixPath:
    """Return the database file path used for loading.

    Versioning is expressed through study names, so the path never changes.
    """
    load_path = self._filepath
    return load_path
|
_get_load_study_name
Source code in kedro_datasets_experimental/optuna/study_dataset.py
196
197
198
199
200
201
def _get_load_study_name(self) -> str:
    """Return the name of the study to load, taking versioning into account."""
    if self._version:
        # Versioned: the study name embeds the resolved load version.
        resolved = self.resolve_load_version()
        return str(self._get_versioned_path(resolved))
    # Unversioned: load the study under its configured name.
    return self._study_name
|
_get_pruner
_get_pruner(pruner_config)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
def _get_pruner(self, pruner_config):
    """Build an Optuna pruner from its ``load_args`` configuration.

    Args:
        pruner_config: ``None`` or a mapping whose ``class`` key names a class in
            ``optuna.pruners``; the remaining keys are passed to that class's
            constructor. For ``PatientPruner`` the ``wrapped_pruner`` entry is
            itself a pruner configuration and is built recursively.

    Returns:
        The instantiated pruner, or ``None`` when ``pruner_config`` is ``None``.

    Raises:
        ValueError: If ``pruner_config`` has no ``class`` key.
    """
    if pruner_config is None:
        return None
    if "class" not in pruner_config:
        raise ValueError(
            "Optuna `pruner` 'class' should be specified when trying to load study "
            f"named '{self._study_name}' with a `pruner`."
        )
    # Shallow copy: popping ``class`` must not mutate the caller's configuration.
    config = dict(pruner_config)
    pruner_class_name = config.pop("class")
    # Only rebuild the nested pruner when it was configured; a missing entry is
    # left to the pruner constructor instead of failing here with a bare KeyError.
    if pruner_class_name == "PatientPruner" and "wrapped_pruner" in config:
        config["wrapped_pruner"] = self._get_pruner(config["wrapped_pruner"])
    pruner_class = getattr(optuna.pruners, pruner_class_name)
    return pruner_class(**config)
|
_get_sampler
_get_sampler(sampler_config)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
def _get_sampler(self, sampler_config):
    """Build an Optuna sampler from its ``load_args`` configuration.

    Args:
        sampler_config: ``None`` or a mapping whose ``class`` key names a class
            in ``optuna.samplers``; the remaining keys are passed to that
            class's constructor. Nested sampler arguments (``independent_sampler``
            for ``QMCSampler``, ``CmaEsSampler`` and ``GPSampler``;
            ``base_sampler`` for ``PartialFixedSampler``) are themselves sampler
            configurations and are built recursively.

    Returns:
        The instantiated sampler, or ``None`` when ``sampler_config`` is ``None``.

    Raises:
        ValueError: If ``sampler_config`` has no ``class`` key.
    """
    if sampler_config is None:
        return None
    if "class" not in sampler_config:
        raise ValueError(
            "Optuna `sampler` 'class' should be specified when trying to load study "
            f"named '{self._study_name}' with a `sampler`."
        )
    # Shallow copy: popping ``class`` must not mutate the caller's configuration.
    config = dict(sampler_config)
    sampler_class_name = config.pop("class")
    if sampler_class_name in ("QMCSampler", "CmaEsSampler", "GPSampler"):
        nested_key = "independent_sampler"
    elif sampler_class_name == "PartialFixedSampler":
        nested_key = "base_sampler"
    else:
        nested_key = None
    # Only rebuild the nested sampler when it was configured. Omitting it keeps
    # the sampler constructor's own default instead of raising a bare KeyError.
    if nested_key is not None and nested_key in config:
        config[nested_key] = self._get_sampler(config[nested_key])
    sampler_class = getattr(optuna.samplers, sampler_class_name)
    return sampler_class(**config)
|
_get_save_path
Source code in kedro_datasets_experimental/optuna/study_dataset.py
def _get_save_path(self) -> PurePosixPath:
    """Return the database file path used for saving.

    Versions are encoded in the study name, not the file location, so the path
    is the same regardless of versioning.
    """
    save_path = self._filepath
    return save_path
|
_get_save_study_name
Source code in kedro_datasets_experimental/optuna/study_dataset.py
208
209
210
211
212
213
214
215
216
217
218
219
220
221
def _get_save_study_name(self) -> str:
    """Return the study name to save under, taking versioning into account.

    Raises:
        DatasetError: If versioning is enabled and the versioned study name
            already exists in the storage.
    """
    if not self._version:
        # Unversioned: save under the configured study name.
        return self._study_name
    versioned_study_name = self._get_versioned_path(self.resolve_save_version())
    candidate = str(versioned_study_name)
    if not self._exists_function(candidate):
        return candidate
    raise DatasetError(
        f"Study name '{versioned_study_name}' for {self!s} must not exist if "
        f"versioning is enabled."
    )
|
_get_versioned_path
_get_versioned_path(version)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
def _get_versioned_path(self, version: str) -> PurePosixPath:
    """Return ``<study_name>/<version>/<study_name>`` as a POSIX path."""
    base = PurePosixPath(self._study_name)
    return base.joinpath(version, base)
|
_study_name_exists
_study_name_exists(study_name)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
325
326
327
328
329
330
331
def _study_name_exists(self, study_name) -> bool:
    """Return whether a study called ``study_name`` exists in the storage.

    For the sqlite backend, a missing database file means no study can exist,
    so the storage is not opened at all in that case.
    """
    sqlite_file_missing = self._backend == "sqlite" and not os.path.isfile(self._database)
    if sqlite_file_missing:
        return False
    url = self._storage_url.render_as_string(hide_password=False)
    rdb_storage = optuna.storages.RDBStorage(url=url)
    return study_name in optuna.study.get_all_study_names(storage=rdb_storage)
|
_study_name_glob
_study_name_glob(pattern)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
334
335
336
337
338
339
def _study_name_glob(self, pattern):
    """Yield the names of studies in the storage that match ``pattern``.

    Args:
        pattern: An ``fnmatch``-style pattern matched against study names.

    Yields:
        Matching study names, in the order the storage reports them.
    """
    # Mirror ``_study_name_exists``: a missing sqlite file holds no studies, and
    # opening the storage on it would create an empty database file as a side
    # effect of merely looking up versions.
    if self._backend == "sqlite" and not os.path.isfile(self._database):
        return
    storage_url_str = self._storage_url.render_as_string(hide_password=False)
    storage = optuna.storages.RDBStorage(url=storage_url_str)
    study_names = optuna.study.get_all_study_names(storage=storage)
    yield from (name for name in study_names if fnmatch.fnmatch(name, pattern))
|
_validate_backend
_validate_backend(backend)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
139
140
141
142
143
144
def _validate_backend(self, backend):
    """Return ``backend`` if it names a known SQLAlchemy dialect.

    Known dialects are those registered in SQLAlchemy's dialect registry plus
    the built-in ones listed below.

    Raises:
        ValueError: If ``backend`` is not a known dialect.
    """
    known_dialects = [*registry.impls, *("mssql", "mysql", "oracle", "postgresql", "sqlite")]
    if backend in known_dialects:
        return backend
    raise ValueError(
        f"Requested `backend` '{backend}' is not registered as an SQLAlchemy dialect."
    )
|
_validate_credentials
_validate_credentials(backend, credentials)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
167
168
169
170
171
172
173
174
175
176
177
def _validate_credentials(self, backend, credentials):
    """Return a deep copy of ``credentials`` after checking its keys.

    Credentials are ignored (an empty dict is returned) for the ``sqlite``
    backend or when none are given.

    Raises:
        ValueError: If ``credentials`` contains keys other than ``username``,
            ``password``, ``host`` and ``port``.
    """
    if credentials is None or backend == "sqlite":
        return {}
    allowed_keys = {"username", "password", "host", "port"}
    provided_keys = set(credentials.keys())
    if provided_keys.issubset(allowed_keys):
        return deepcopy(credentials)
    raise ValueError(
        "Incorrect `credentials`. Provided `credentials` should contain "
        "`'username'`, `'password'`, `'host'`, and/or `'port'`. It contains "
        f"{provided_keys}."
    )
|
_validate_database
_validate_database(backend, database)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
147
148
149
150
151
152
153
154
155
156
157
158
159
def _validate_database(self, backend, database):
    """Return ``database`` after checking it is usable for ``backend``.

    ``database`` must be a string. For the ``sqlite`` backend it must also
    have a file extension, except for the special ``":memory:"`` database.

    Raises:
        ValueError: If either check fails.
    """
    if not isinstance(database, str):
        raise ValueError(f"`database` '{database}' is not a string.")
    needs_extension = backend == "sqlite" and database != ":memory:"
    if needs_extension and not os.path.splitext(database)[1]:
        raise ValueError(f"The sqlite file `database` '{database}' does not have an extension.")
    return database
|
_validate_study_name
_validate_study_name(study_name)
Source code in kedro_datasets_experimental/optuna/study_dataset.py
def _validate_study_name(self, study_name):
    """Return ``study_name`` unchanged if it is a string.

    Raises:
        ValueError: If ``study_name`` is not a string.
    """
    if isinstance(study_name, str):
        return study_name
    raise ValueError(f"`study_name` '{study_name}' is not a string.")
|
load
Source code in kedro_datasets_experimental/optuna/study_dataset.py
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
def load(self) -> optuna.Study:
    """Load the configured study from the RDB storage.

    The sampler and pruner are built from ``load_args`` and attached to the
    loaded study. ``load_args`` is deep-copied first so the stored
    configuration is never mutated.
    """
    options = deepcopy(self._load_args)
    sampler = self._get_sampler(options.pop("sampler"))
    pruner = self._get_pruner(options.pop("pruner"))
    rdb_storage = optuna.storages.RDBStorage(
        url=self._storage_url.render_as_string(hide_password=False)
    )
    return optuna.load_study(
        storage=rdb_storage,
        study_name=self._get_load_study_name(),
        sampler=sampler,
        pruner=pruner,
    )
|
resolve_load_version
Compute the version the dataset should be loaded with.
Source code in kedro_datasets_experimental/optuna/study_dataset.py
184
185
186
187
188
189
190 | def resolve_load_version(self) -> str | None:
"""Compute the version the dataset should be loaded with."""
if not self._version:
return None
if self._version.load:
return self._version.load
return self._fetch_latest_load_version()
|
save
Source code in kedro_datasets_experimental/optuna/study_dataset.py
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
def save(self, study: optuna.Study) -> None:
    """Copy ``study`` into the configured storage under the save study name.

    A study with the same name in the target storage is deleted first, so
    saving overwrites it.

    Args:
        study: The study to persist. It is copied from its own storage into
            this dataset's storage.

    Raises:
        DatasetError: If versioning is enabled and the versioned study name
            already exists.
    """
    save_study_name = self._get_save_study_name()
    storage_url_str = self._storage_url.render_as_string(hide_password=False)
    if self._backend == "sqlite":
        # Ensure the parent directory exists. The database file and its schema
        # are created when the storage is opened below; the previous explicit
        # ``optuna.create_study`` call left a stray, auto-named empty study in
        # every newly created database.
        os.makedirs(os.path.dirname(self._filepath), exist_ok=True)
    storage = optuna.storages.RDBStorage(url=storage_url_str)
    # To overwrite an existing study, we need to first delete it if it exists.
    if self._study_name_exists(save_study_name):
        optuna.delete_study(
            storage=storage,
            study_name=save_study_name,
        )
    optuna.copy_study(
        from_study_name=study.study_name,
        from_storage=study._storage,
        to_storage=storage,
        to_study_name=save_study_name,
    )
|