# Source code for kedro.config.omegaconf_config

"""This module provides ``kedro.config`` with the functionality to load one
or more configuration files of yaml or json type from specified paths through OmegaConf.
from __future__ import annotations

import io
import logging
import mimetypes
import typing
from pathlib import Path
from typing import Any, Callable, Iterable

import fsspec
from omegaconf import DictConfig, OmegaConf
from omegaconf.errors import InterpolationResolutionError, UnsupportedInterpolationType
from omegaconf.resolvers import oc
from yaml.parser import ParserError
from yaml.scanner import ScannerError

from kedro.config.abstract_config import AbstractConfigLoader, MissingConfigException

_config_logger = logging.getLogger(__name__)

_NO_VALUE = object()

class OmegaConfigLoader(AbstractConfigLoader):
    """Recursively scan directories (config paths) contained in ``conf_source`` for
    configuration files with a ``yaml``, ``yml`` or ``json`` extension, load and merge
    them through ``OmegaConf`` (https://omegaconf.readthedocs.io/)
    and return them in the form of a config dictionary.

    The first processed config path is the ``base`` directory inside
    ``conf_source``. The optional ``env`` argument can be used to specify a
    subdirectory of ``conf_source`` to process as a config path after ``base``.

    When the same top-level key appears in any two config files located in
    the same (sub)directory, a ``ValueError`` is raised.

    When the same key appears in any two config files located in different
    (sub)directories, the last processed config path takes precedence
    and overrides this key and any sub-keys.

    You can access the different configurations as follows:
    ::

        >>> import logging.config
        >>> from kedro.config import OmegaConfigLoader
        >>> from kedro.framework.project import settings
        >>>
        >>> conf_path = str(project_path / settings.CONF_SOURCE)
        >>> conf_loader = OmegaConfigLoader(conf_source=conf_path, env="local")
        >>>
        >>> conf_catalog = conf_loader["catalog"]
        >>> conf_params = conf_loader["parameters"]

    ``OmegaConf`` supports variable interpolation in configuration. It is
    recommended to use this instead of yaml anchors with the ``OmegaConfigLoader``.

    The built-in ``oc.env`` resolver is deactivated when the loader is constructed
    and is only re-enabled temporarily while ``credentials`` files are resolved.
    Additional resolvers can be supplied through ``custom_resolvers``.

    To use this class, change the setting for the `CONFIG_LOADER_CLASS` constant
    in ``settings.py``.

    Example:
    ::

        >>> # in settings.py
        >>> from kedro.config import OmegaConfigLoader
        >>>
        >>> CONFIG_LOADER_CLASS = OmegaConfigLoader

    """

    def __init__(  # noqa: PLR0913
        self,
        conf_source: str,
        env: str | None = None,
        runtime_params: dict[str, Any] | None = None,
        *,
        config_patterns: dict[str, list[str]] | None = None,
        base_env: str | None = None,
        default_run_env: str | None = None,
        custom_resolvers: dict[str, Callable] | None = None,
        merge_strategy: dict[str, str] | None = None,
    ):
        """Instantiates a ``OmegaConfigLoader``.

        Args:
            conf_source: Path to use as root directory for loading configuration.
                A path whose guessed MIME type is tar or zip is read as an archive.
            env: Environment that will take precedence over base.
            runtime_params: Extra parameters passed to a Kedro run.
            config_patterns: Regex patterns that specify the naming convention for configuration
                files so they can be loaded. Can be customised by supplying config_patterns as
                in `CONFIG_LOADER_ARGS` in ``settings.py``. Supplied entries are merged over
                the built-in patterns key by key.
            base_env: Name of the base environment. When omitted, an empty string is
                used, so the base config path resolves to ``conf_source`` itself.
                NOTE(review): callers are presumably expected to pass the real name
                (e.g. ``"base"``) - confirm against the project settings.
            default_run_env: Name of the default run environment, used when ``env`` is
                not given. When omitted, an empty string is used.
            custom_resolvers: A dictionary of custom resolvers to be registered. Names that
                are already registered with ``OmegaConf`` are left untouched. See the
                OmegaConf documentation on custom resolvers for details.
            merge_strategy: A dictionary that specifies the merging strategy for each
                configuration type. The accepted merging strategies are `soft` and
                `destructive`. Defaults to `destructive`.
        """
        self.base_env = base_env or ""
        self.default_run_env = default_run_env or ""
        self.merge_strategy = merge_strategy or {}

        self.config_patterns = {
            "catalog": ["catalog*", "catalog*/**", "**/catalog*"],
            "parameters": ["parameters*", "parameters*/**", "**/parameters*"],
            "credentials": ["credentials*", "credentials*/**", "**/credentials*"],
            "globals": ["globals.yml"],
        }
        self.config_patterns.update(config_patterns or {})

        # Deactivate oc.env built-in resolver for OmegaConf
        OmegaConf.clear_resolver("oc.env")
        # Register user provided custom resolvers
        if custom_resolvers:
            self._register_new_resolvers(custom_resolvers)
        # Register globals resolver
        self._register_globals_resolver()

        # Pick the fsspec protocol from the guessed MIME type of ``conf_source`` so
        # that tar/zip archives can be read the same way as a plain directory.
        file_mimetype, _ = mimetypes.guess_type(conf_source)
        if file_mimetype == "application/x-tar":
            self._protocol = "tar"
        elif file_mimetype in (
            "application/zip",
            "application/x-zip-compressed",
            "application/zip-compressed",
        ):
            self._protocol = "zip"
        else:
            self._protocol = "file"
        self._fs = fsspec.filesystem(protocol=self._protocol, fo=conf_source)

        super().__init__(
            conf_source=conf_source,
            env=env,
            runtime_params=runtime_params,
        )

        # Cache globals once up front; the ``globals`` resolver reads this attribute.
        try:
            self._globals = self["globals"]
        except MissingConfigException:
            self._globals = {}

    def __setitem__(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, keeping the cached globals in sync."""
        if key == "globals":
            # Update the cached value at self._globals since it is used by the globals resolver
            self._globals = value
        super().__setitem__(key, value)

    def __getitem__(self, key: str) -> dict[str, Any]:  # noqa: PLR0912
        """Get configuration files by key, load and merge them, and
        return them in the form of a config dictionary.

        Args:
            key: Key of the configuration type to fetch.

        Raises:
            KeyError: If key provided isn't present in the config_patterns of this
               ``OmegaConfigLoader`` instance.
            MissingConfigException: If no configuration files exist matching the patterns
                mapped to the provided key.
            ValueError: If the merge strategy configured for ``key`` is not
                ``soft`` or ``destructive``.
            UnsupportedInterpolationType: If a config file uses an unregistered
                resolver (including ``runtime_params`` inside globals).

        Returns:
            Dict[str, Any]:  A Python dictionary with the combined
               configuration from all configuration files.
        """
        # Re-register runtime params resolver incase it was previously deactivated
        self._register_runtime_params_resolver()

        # Allow bypassing of loading config from patterns if a key and value have been set
        # explicitly on the ``OmegaConfigLoader`` instance.
        if key in self:
            return super().__getitem__(key)  # type: ignore[no-any-return]

        if key not in self.config_patterns:
            raise KeyError(
                f"No config patterns were found for '{key}' in your config loader"
            )
        patterns = [*self.config_patterns[key]]

        if key == "globals":
            # "runtime_params" resolver is not allowed in globals.
            OmegaConf.clear_resolver("runtime_params")

        # Environment variables are only resolved for credentials.
        read_environment_variables = key == "credentials"

        processed_files: set[Path] = set()

        # Load base env config
        base_path = self._env_dir_path(self.base_env)
        config = self._load_dir_for_key(
            base_path, patterns, key, processed_files, read_environment_variables
        )

        # Load chosen env config
        run_env = self.env or self.default_run_env
        env_path = self._env_dir_path(run_env)
        env_config = self._load_dir_for_key(
            env_path, patterns, key, processed_files, read_environment_variables
        )

        merging_strategy = self.merge_strategy.get(key)
        if merging_strategy == "soft":
            resulting_config = self._soft_merge(config, env_config)
        elif merging_strategy == "destructive" or not merging_strategy:
            resulting_config = self._destructive_merge(config, env_config, env_path)
        else:
            raise ValueError(
                f"Merging strategy {merging_strategy} not supported. The accepted merging "
                f"strategies are `soft` and `destructive`."
            )

        # Globals are optional; every other key must match at least one file.
        if not processed_files and key != "globals":
            raise MissingConfigException(
                f"No files of YAML or JSON format found in {base_path} or {env_path} matching"
                f" the glob pattern(s): {[*self.config_patterns[key]]}"
            )

        return resulting_config  # type: ignore[no-any-return]

    def __repr__(self) -> str:  # pragma: no cover
        return (
            f"OmegaConfigLoader(conf_source={self.conf_source}, env={self.env}, "
            f"config_patterns={self.config_patterns})"
        )

    def _env_dir_path(self, env_name: str) -> str:
        """Return the path of the ``env_name`` directory on ``self._fs``."""
        if self._protocol == "file":
            return str(Path(self.conf_source) / env_name)
        # For archives the last entry listed at the archive root is used as the
        # configuration root.
        # NOTE(review): presumably the archive's single top-level folder - confirm.
        return str(Path(self._fs.ls("", detail=False)[-1]) / env_name)

    def _load_dir_for_key(  # noqa: PLR0913
        self,
        conf_path: str,
        patterns: Iterable[str],
        key: str,
        processed_files: set,
        read_environment_variables: bool,
    ) -> dict[str, Any]:
        """Load one config directory, rewording ``runtime_params`` resolver errors."""
        try:
            return self.load_and_merge_dir_config(  # type: ignore[no-untyped-call]
                conf_path, patterns, key, processed_files, read_environment_variables
            )
        except UnsupportedInterpolationType as exc:
            # ``runtime_params`` is only ever unregistered while loading globals,
            # so an unsupported ``runtime_params`` interpolation points at that.
            if "runtime_params" in str(exc):
                raise UnsupportedInterpolationType(
                    "The `runtime_params:` resolver is not supported for globals."
                ) from exc
            raise

    @typing.no_type_check
    def load_and_merge_dir_config(  # noqa: PLR0913
        self,
        conf_path: str,
        patterns: Iterable[str],
        key: str,
        processed_files: set,
        read_environment_variables: bool | None = False,
    ) -> dict[str, Any]:
        """Recursively load and merge all configuration files in a directory using OmegaConf,
        which satisfy a given list of glob patterns from a specific path.

        Args:
            conf_path: Path to configuration directory.
            patterns: List of glob patterns to match the filenames against.
            key: Key of the configuration type to fetch.
            processed_files: Set of files read for a given configuration type.
                Every file that is loaded is added to this set in place.
            read_environment_variables: Whether to resolve environment variables.

        Raises:
            MissingConfigException: If configuration path doesn't exist or isn't valid.
            ValueError: If two or more configuration files contain the same key(s).
            ParserError: If config file contains invalid YAML or JSON syntax.

        Returns:
            Resulting configuration dictionary.

        """
        # noqa: too-many-locals

        if not self._fs.isdir(Path(conf_path).as_posix()):
            raise MissingConfigException(
                f"Given configuration path either does not exist "
                f"or is not a valid directory: {conf_path}"
            )

        paths = []
        for pattern in patterns:
            for each in self._fs.glob(Path(f"{str(conf_path)}/{pattern}").as_posix()):
                if not self._is_hidden(each):
                    paths.append(Path(each))

        # Several patterns may match the same file; load each file only once.
        deduplicated_paths = set(paths)
        config_files_filtered = [
            path for path in deduplicated_paths if self._is_valid_config_path(path)
        ]

        config_per_file = {}
        for config_filepath in config_files_filtered:
            try:
                with self._fs.open(str(config_filepath.as_posix())) as open_config:
                    # As fsspec doesn't allow the file to be read as StringIO,
                    # this is a workaround to read it as a binary file and decode it back to utf8.
                    tmp_fo = io.StringIO(open_config.read().decode("utf8"))
                    config = OmegaConf.load(tmp_fo)
                    processed_files.add(config_filepath)
                if read_environment_variables:
                    self._resolve_environment_variables(config)
                config_per_file[config_filepath] = config
            except (ParserError, ScannerError) as exc:
                line = exc.problem_mark.line
                cursor = exc.problem_mark.column
                raise ParserError(
                    f"Invalid YAML or JSON file {Path(config_filepath).as_posix()},"
                    f" unable to read line {line}, position {cursor}."
                ) from exc

        seen_file_to_keys = {
            file: set(config.keys()) for file, config in config_per_file.items()
        }
        aggregate_config = config_per_file.values()
        self._check_duplicates(seen_file_to_keys)

        if not aggregate_config:
            return {}

        if key == "parameters":
            # Merge with runtime parameters only for "parameters"
            return OmegaConf.to_container(
                OmegaConf.merge(*aggregate_config, self.runtime_params), resolve=True
            )

        # For every other key, top-level entries starting with "_" are dropped
        # from the resolved result.
        return {
            k: v
            for k, v in OmegaConf.to_container(
                OmegaConf.merge(*aggregate_config), resolve=True
            ).items()
            if not k.startswith("_")
        }

    def _is_valid_config_path(self, path: Path) -> bool:
        """Check if given path is a file path and file type is yaml or json."""
        posix_path = path.as_posix()
        return self._fs.isfile(str(posix_path)) and path.suffix in [
            ".yml",
            ".yaml",
            ".json",
        ]

    def _register_globals_resolver(self) -> None:
        """Register the globals resolver"""
        OmegaConf.register_new_resolver(
            "globals",
            self._get_globals_value,
            replace=True,
        )

    def _register_runtime_params_resolver(self) -> None:
        """Register the runtime_params resolver, replacing any existing one."""
        OmegaConf.register_new_resolver(
            "runtime_params",
            self._get_runtime_value,
            replace=True,
        )

    def _get_globals_value(self, variable: str, default_value: Any = _NO_VALUE) -> Any:
        """Return the globals values to the resolver

        Raises:
            InterpolationResolutionError: If ``variable`` starts with ``_``, or if it
                is not found in the globals and no default value was provided.
        """
        if variable.startswith("_"):
            raise InterpolationResolutionError(
                "Keys starting with '_' are not supported for globals."
            )
        globals_oc = OmegaConf.create(self._globals)
        interpolated_value = OmegaConf.select(
            globals_oc, variable, default=default_value
        )
        # Identity check: the sentinel means "not found and no default given".
        if interpolated_value is not _NO_VALUE:
            return interpolated_value
        raise InterpolationResolutionError(
            f"Globals key '{variable}' not found and no default value provided."
        )

    def _get_runtime_value(self, variable: str, default_value: Any = _NO_VALUE) -> Any:
        """Return the runtime params values to the resolver

        Raises:
            InterpolationResolutionError: If ``variable`` is not found in the runtime
                parameters and no default value was provided.
        """
        runtime_oc = OmegaConf.create(self.runtime_params)
        interpolated_value = OmegaConf.select(
            runtime_oc, variable, default=default_value
        )
        # Identity check: the sentinel means "not found and no default given".
        if interpolated_value is not _NO_VALUE:
            return interpolated_value
        raise InterpolationResolutionError(
            f"Runtime parameter '{variable}' not found and no default value provided."
        )

    @staticmethod
    def _register_new_resolvers(resolvers: dict[str, Callable]) -> None:
        """Register custom resolvers"""
        for name, resolver in resolvers.items():
            # Resolvers that are already registered are left untouched.
            if not OmegaConf.has_resolver(name):
                msg = f"Registering new custom resolver: {name}"
                _config_logger.debug(msg)
                OmegaConf.register_new_resolver(name=name, resolver=resolver)

    @staticmethod
    def _check_duplicates(seen_files_to_keys: dict[Path, set[Any]]) -> None:
        """Raise ``ValueError`` if two files share a top-level key.

        Keys starting with ``_`` are ignored. All clashing file pairs are reported
        in a single error, one pair per line.
        """
        duplicates = []

        filepaths = list(seen_files_to_keys.keys())
        # Compare every unordered pair of files exactly once.
        for i, filepath1 in enumerate(filepaths, 1):
            config1 = seen_files_to_keys[filepath1]
            for filepath2 in filepaths[i:]:
                config2 = seen_files_to_keys[filepath2]

                combined_keys = config1 & config2
                overlapping_keys = {
                    key for key in combined_keys if not key.startswith("_")
                }

                if overlapping_keys:
                    sorted_keys = ", ".join(sorted(overlapping_keys))
                    # Keep the message readable when many keys clash.
                    if len(sorted_keys) > 100:  # noqa: PLR2004
                        sorted_keys = sorted_keys[:100] + "..."
                    duplicates.append(
                        f"Duplicate keys found in {filepath1} and {filepath2}: {sorted_keys}"
                    )

        if duplicates:
            dup_str = "\n".join(duplicates)
            raise ValueError(f"{dup_str}")

    @staticmethod
    def _resolve_environment_variables(config: DictConfig) -> None:
        """Use the ``oc.env`` resolver to read environment variables and replace
        them in-place, clearing the resolver after the operation is complete if
        it was not registered beforehand.

        Arguments:
            config {Dict[str, Any]} -- The configuration dictionary to resolve.
        """
        if OmegaConf.has_resolver("oc.env"):
            OmegaConf.resolve(config)
            return

        OmegaConf.register_new_resolver("oc.env", oc.env)
        try:
            OmegaConf.resolve(config)
        finally:
            # Always unregister again, even when resolution fails, so that
            # ``oc.env`` does not stay enabled for other configuration types.
            OmegaConf.clear_resolver("oc.env")

    @staticmethod
    def _destructive_merge(
        config: dict[str, Any], env_config: dict[str, Any], env_path: str
    ) -> dict[str, Any]:
        """Merge ``env_config`` into ``config`` in place and return ``config``.

        Top-level keys from ``env_config`` replace those of ``config`` wholesale;
        the overridden keys are logged at debug level.
        """
        # Destructively merge the two env dirs. The chosen env will override base.
        common_keys = config.keys() & env_config.keys()
        if common_keys:
            sorted_keys = ", ".join(sorted(common_keys))
            msg = (
                "Config from path '%s' will override the following "
                "existing top-level config keys: %s"
            )
            _config_logger.debug(msg, env_path, sorted_keys)

        config.update(env_config)
        return config

    @staticmethod
    def _soft_merge(config: dict[str, Any], env_config: dict[str, Any]) -> Any:
        """Merge the two configs through ``OmegaConf.merge`` and return a container."""
        # Soft merge the two env dirs. The chosen env will override base if keys clash.
        return OmegaConf.to_container(OmegaConf.merge(config, env_config))

    def _is_hidden(self, path_str: str) -> bool:
        """Check if path contains any hidden directory or is a hidden file"""
        path = Path(path_str)
        conf_path = Path(self.conf_source).resolve().as_posix()
        if self._protocol == "file":
            path = path.resolve()
        posix_path = path.as_posix()
        # Only components below ``conf_source`` matter, so drop that leading prefix
        # (and only the prefix - later occurrences of the same text are kept).
        if posix_path.startswith(conf_path):
            posix_path = posix_path[len(conf_path):]
        parts = posix_path.split(self._fs.sep)  # filesystem specific separator
        HIDDEN = "."
        # Check if any component (folder or file) starts with a dot (.)
        return any(part.startswith(HIDDEN) for part in parts)