Skip to content

mlrun.MLRunModel

kedro_datasets_experimental.mlrun.MLRunModel

MLRunModel(
    key=None,
    framework="sklearn",
    model_format="pkl",
    load_args=None,
    save_args=None,
)

Bases: MLRunAbstractDataset

Dataset for saving/loading models via MLRun.

Uses MLRun's log_model and get_artifact.

load_args and save_args accept any arguments supported by the corresponding MLRun API for your MLRun version; see the MLRun documentation.

Examples

Using the YAML API:

trained_model:
  type: kedro_datasets_experimental.mlrun.MLRunModel
  key: my_sklearn_model
  framework: sklearn
  model_format: pkl

Using the Python API:

from kedro_datasets_experimental.mlrun import MLRunModel

dataset = MLRunModel(
    key="my_sklearn_model",
    framework="sklearn",
    model_format="pkl",
)
dataset.save(trained_model)
loaded_model = dataset.load()

Parameters:

  • key (str | None, default: None ) –

    Artifact key for MLRun (defaults to catalog dataset name).

  • framework (str, default: 'sklearn' ) –

    ML framework name (e.g. "sklearn", "xgboost", "lightgbm").

  • model_format (str, default: 'pkl' ) –

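    File extension for the saved model file (e.g. "pkl"). The value is lower-cased and any leading dots are removed, so ".PKL" and "pkl" are equivalent. It only affects the file name: the model is always serialized with joblib.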

  • load_args (dict[str, Any] | None, default: None ) –

    Passed to MLRun when loading; see MLRun docs for your version.

  • save_args (dict[str, Any] | None, default: None ) –

    Passed to log_model; see MLRun docs for your version.

Source code in kedro_datasets_experimental/mlrun/model.py
59
60
61
62
63
64
65
66
67
68
69
def __init__( # noqa: PLR0913
    self,
    key: str | None = None,
    framework: str = "sklearn",
    model_format: str = "pkl",
    load_args: dict[str, Any] | None = None,
    save_args: dict[str, Any] | None = None,
) -> None:
    """Configure the model dataset.

    Args:
        key: Artifact key; forwarded unchanged to the base dataset.
        framework: ML framework name, later passed to ``log_model``.
        model_format: Extension for the serialized model file. It is
            lower-cased and stripped of leading dots before being stored.
        load_args: Forwarded unchanged to the base dataset.
        save_args: Forwarded unchanged to the base dataset.
    """
    super().__init__(key=key, load_args=load_args, save_args=save_args)
    self._framework = framework
    # Normalize so that ".PKL", "PKL" and "pkl" all yield "pkl".
    lowered_format = model_format.lower()
    self._model_format = lowered_format.lstrip(".")

_framework instance-attribute

_framework = framework

_model_format instance-attribute

_model_format = model_format.lower().lstrip('.')

_describe

_describe()
Source code in kedro_datasets_experimental/mlrun/model.py
90
91
92
93
94
95
def _describe(self) -> dict[str, Any]:
    """Return the base dataset description extended with model settings.

    The ``framework`` and ``model_format`` entries are added on top of
    whatever the parent class reports, overriding same-named entries.
    """
    description = {**super()._describe()}
    description["framework"] = self._framework
    description["model_format"] = self._model_format
    return description

load

load()
Source code in kedro_datasets_experimental/mlrun/model.py
83
84
85
86
87
88
def load(self) -> Any:
    """Fetch the logged model artifact and deserialize it with joblib.

    The artifact is looked up in the MLRun project by ``self.key``; its
    model file is resolved under the artifact's target path, materialized
    locally and passed to ``joblib.load``.

    Returns:
        The object produced by ``joblib.load``.
    """
    # NOTE(review): load_args are not consumed in this method — confirm
    # whether they are meant to be forwarded to one of the MLRun calls.
    artifact = self._ctx_manager.project.get_artifact(self.key)
    target_path = artifact.get_target_path()
    model_file = artifact.model_file
    # Plain concatenation only works when the target path already ends with
    # a separator; otherwise it yields e.g. ".../my_modelmodel.pkl". Insert
    # exactly one "/" when needed. os.path.join is avoided on purpose: the
    # target may be a URL (s3://, v3io://, ...) and must keep "/" on every OS.
    if model_file and not target_path.endswith("/"):
        target_path += "/"
    local_path = get_dataitem(target_path + model_file).local()
    # NOTE(security): joblib.load unpickles the file and can execute
    # arbitrary code; only load artifacts from a trusted MLRun project.
    return joblib.load(local_path)

save

save(data)
Source code in kedro_datasets_experimental/mlrun/model.py
71
72
73
74
75
76
77
78
79
80
81
def save(self, data: Any) -> None:
    """Serialize ``data`` with joblib and log it as an MLRun model.

    The model is dumped to ``model.<model_format>`` inside a temporary
    directory, then registered through the context's ``log_model`` using
    this dataset's key, framework and any extra ``save_args``.

    Args:
        data: The model object to persist.
    """
    with tempfile.TemporaryDirectory() as staging_dir:
        file_name = f"model.{self._model_format}"
        staged_model = os.path.join(staging_dir, file_name)
        joblib.dump(data, staged_model)

        # Log inside the ``with`` block: the staged file is removed as soon
        # as the temporary directory is cleaned up.
        log_model = self._ctx_manager.context.log_model
        log_model(
            key=self.key,
            model_file=staged_model,
            framework=self._framework,
            **self._save_args,
        )