Source code for kedro.runner.runner

"""``AbstractRunner`` is the base class for all ``Pipeline`` runner
implementations.
"""

from __future__ import annotations

import inspect
import logging
import warnings
from abc import ABC, abstractmethod
from collections import deque
from typing import TYPE_CHECKING, Any

from kedro import KedroDeprecationWarning
from kedro.framework.hooks.manager import _NullPluginManager
from kedro.io import CatalogProtocol, MemoryDataset, SharedMemoryDataset
from kedro.pipeline import Pipeline
from kedro.runner.task import Task

if TYPE_CHECKING:
    from collections.abc import Collection, Iterable

    from pluggy import PluginManager

    from kedro.pipeline.node import Node


class AbstractRunner(ABC):
    """``AbstractRunner`` is the base class for all ``Pipeline`` runner
    implementations.
    """
    def __init__(
        self,
        is_async: bool = False,
        extra_dataset_patterns: dict[str, dict[str, Any]] | None = None,
    ):
        """Instantiates the runner class.

        Args:
            is_async: If True, the node inputs and outputs are loaded and saved
                asynchronously with threads. Defaults to False.
            extra_dataset_patterns: Extra dataset factory patterns to be added to the
                catalog during the run. This is used to set the default datasets on the
                Runner instances.
        """
        self._is_async = is_async
        self._extra_dataset_patterns = extra_dataset_patterns
    @property
    def _logger(self) -> logging.Logger:
        return logging.getLogger(self.__module__)
    def run(
        self,
        pipeline: Pipeline,
        catalog: CatalogProtocol,
        hook_manager: PluginManager | None = None,
        session_id: str | None = None,
    ) -> dict[str, Any]:
        """Run the ``Pipeline`` using the datasets provided by ``catalog``
        and save results back to the same objects.

        Args:
            pipeline: The ``Pipeline`` to run.
            catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data.
            hook_manager: The ``PluginManager`` to activate hooks.
            session_id: The id of the session.

        Raises:
            ValueError: Raised when ``Pipeline`` inputs cannot be satisfied.

        Returns:
            Any node outputs that cannot be processed by the catalog.
            These are returned in a dictionary, where the keys are defined
            by the node outputs.
        """
        # Check which datasets used in the pipeline are in the catalog or match
        # a pattern in the catalog, not including extra dataset patterns.
        # Run a warm-up to materialize all datasets in the catalog before the run.
        warmed_up_ds = []
        for ds in pipeline.datasets():
            if ds in catalog:
                warmed_up_ds.append(ds)
                _ = catalog._get_dataset(ds)

        # Check if there are any input datasets that aren't in the catalog and
        # don't match a pattern in the catalog.
        unsatisfied = pipeline.inputs() - set(warmed_up_ds)
        if unsatisfied:
            raise ValueError(
                f"Pipeline input(s) {unsatisfied} not found in the {catalog.__class__.__name__}"
            )

        # Register the default dataset pattern with the catalog.
        catalog = catalog.shallow_copy(
            extra_dataset_patterns=self._extra_dataset_patterns
        )

        hook_or_null_manager = hook_manager or _NullPluginManager()

        # Check which datasets used in the pipeline are in the catalog or match
        # a pattern in the catalog, including the added extra_dataset_patterns.
        registered_ds = [ds for ds in pipeline.datasets() if ds in catalog]

        if self._is_async:
            self._logger.info(
                "Asynchronous mode is enabled for loading and saving data"
            )
        self._run(pipeline, catalog, hook_or_null_manager, session_id)  # type: ignore[arg-type]

        self._logger.info("Pipeline execution completed successfully.")

        # Identify MemoryDataset and SharedMemoryDataset instances in the catalog.
        memory_datasets = {
            ds_name
            for ds_name, ds in catalog._datasets.items()
            if isinstance(ds, (MemoryDataset, SharedMemoryDataset))
        }

        # Check if there are any output datasets that aren't in the catalog and
        # don't match a pattern in the catalog, and include MemoryDatasets.
        free_outputs = pipeline.outputs() - (set(registered_ds) - memory_datasets)
        run_output = {ds_name: catalog.load(ds_name) for ds_name in free_outputs}

        return run_output
    def run_only_missing(
        self, pipeline: Pipeline, catalog: CatalogProtocol, hook_manager: PluginManager
    ) -> dict[str, Any]:
        """Run only the missing outputs from the ``Pipeline`` using the
        datasets provided by ``catalog``, and save results back to the
        same objects.

        Args:
            pipeline: The ``Pipeline`` to run.
            catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data.
            hook_manager: The ``PluginManager`` to activate hooks.

        Raises:
            ValueError: Raised when ``Pipeline`` inputs cannot be satisfied.

        Returns:
            Any node outputs that cannot be processed by the catalog.
            These are returned in a dictionary, where the keys are defined
            by the node outputs.
        """
        free_outputs = pipeline.outputs() - set(catalog.list())
        missing = {ds for ds in catalog.list() if not catalog.exists(ds)}
        to_build = free_outputs | missing
        to_rerun = pipeline.only_nodes_with_outputs(*to_build) + pipeline.from_inputs(
            *to_build
        )

        # We also need any missing datasets that are required to run the
        # `to_rerun` pipeline, including any chains of missing datasets.
        unregistered_ds = pipeline.datasets() - set(catalog.list())
        output_to_unregistered = pipeline.only_nodes_with_outputs(*unregistered_ds)
        input_from_unregistered = to_rerun.inputs() & unregistered_ds
        to_rerun += output_to_unregistered.to_outputs(*input_from_unregistered)

        return self.run(to_rerun, catalog, hook_manager)
    @abstractmethod  # pragma: no cover
    def _run(
        self,
        pipeline: Pipeline,
        catalog: CatalogProtocol,
        hook_manager: PluginManager,
        session_id: str | None = None,
    ) -> None:
        """The abstract interface for running pipelines, assuming that the
        inputs have already been checked and normalized by run().

        Args:
            pipeline: The ``Pipeline`` to run.
            catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data.
            hook_manager: The ``PluginManager`` to activate hooks.
            session_id: The id of the session.
        """
        pass

    def _suggest_resume_scenario(
        self,
        pipeline: Pipeline,
        done_nodes: Iterable[Node],
        catalog: CatalogProtocol,
    ) -> None:
        """Suggest a command to the user to resume a run after it fails.
        The run should be started from the point closest to the failure
        for which persisted input exists.

        Args:
            pipeline: the ``Pipeline`` of the run.
            done_nodes: the ``Node``s that executed successfully.
            catalog: an implemented instance of ``CatalogProtocol`` of the run.
        """
        remaining_nodes = set(pipeline.nodes) - set(done_nodes)

        postfix = ""
        if done_nodes:
            start_node_names = _find_nodes_to_resume_from(
                pipeline=pipeline,
                unfinished_nodes=remaining_nodes,
                catalog=catalog,
            )
            start_nodes_str = ",".join(sorted(start_node_names))
            postfix += f' --from-nodes "{start_nodes_str}"'

        if not postfix:
            self._logger.warning(
                "No nodes ran. Repeat the previous command to attempt a new run."
            )
        else:
            self._logger.warning(
                f"There are {len(remaining_nodes)} nodes that have not run.\n"
                "You can resume the pipeline run from the nearest nodes with "
                "persisted inputs by adding the following "
                f"argument to your previous command:\n{postfix}"
            )

    @staticmethod
    def _release_datasets(
        node: Node, catalog: CatalogProtocol, load_counts: dict, pipeline: Pipeline
    ) -> None:
        """Decrement dataset load counts and release any datasets we've finished with"""
        for dataset in node.inputs:
            load_counts[dataset] -= 1
            if load_counts[dataset] < 1 and dataset not in pipeline.inputs():
                catalog.release(dataset)
        for dataset in node.outputs:
            if load_counts[dataset] < 1 and dataset not in pipeline.outputs():
                catalog.release(dataset)
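
# Illustrative sketch (not part of the original module): a minimal concrete runner
# built on ``AbstractRunner``. It assumes ``Pipeline.nodes`` yields nodes in
# topological order and delegates each node to ``Task``, reusing the imports at the
# top of this module. The class name ``_ToySequentialRunner`` is hypothetical; Kedro
# ships its own ``SequentialRunner``, ``ThreadRunner`` and ``ParallelRunner``.
class _ToySequentialRunner(AbstractRunner):
    def _run(
        self,
        pipeline: Pipeline,
        catalog: CatalogProtocol,
        hook_manager: PluginManager,
        session_id: str | None = None,
    ) -> None:
        for node in pipeline.nodes:  # assumed to be topologically sorted
            Task(
                node=node,
                catalog=catalog,
                hook_manager=hook_manager,
                is_async=self._is_async,
                session_id=session_id,
            ).execute()


# Typical usage (``my_pipeline`` and ``catalog`` are assumed to exist already):
#     outputs = _ToySequentialRunner().run(my_pipeline, catalog)
#     # ``outputs`` maps unregistered/ephemeral pipeline output names to their data.
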
def _find_nodes_to_resume_from(
    pipeline: Pipeline, unfinished_nodes: Collection[Node], catalog: CatalogProtocol
) -> set[str]:
    """Given a collection of unfinished nodes in a pipeline using
    a certain catalog, find the node names to pass to pipeline.from_nodes()
    to cover all unfinished nodes, including any additional nodes
    that should be re-run if their outputs are not persisted.

    Args:
        pipeline: the ``Pipeline`` to find starting nodes for.
        unfinished_nodes: collection of ``Node``s that have not finished yet.
        catalog: an implemented instance of ``CatalogProtocol`` of the run.

    Returns:
        Set of node names to pass to pipeline.from_nodes() to continue
        the run.
    """
    nodes_to_be_run = _find_all_nodes_for_resumed_pipeline(
        pipeline, unfinished_nodes, catalog
    )

    # Find which of the remaining nodes would need to run first (in topological order).
    persistent_ancestors = _find_initial_node_group(pipeline, nodes_to_be_run)

    return {n.name for n in persistent_ancestors}


def _find_all_nodes_for_resumed_pipeline(
    pipeline: Pipeline, unfinished_nodes: Iterable[Node], catalog: CatalogProtocol
) -> set[Node]:
    """Breadth-first search approach to finding the complete set of
    ``Node``s which need to run to cover all unfinished nodes,
    including any additional nodes that should be re-run if their outputs
    are not persisted.

    Args:
        pipeline: the ``Pipeline`` to analyze.
        unfinished_nodes: the iterable of ``Node``s which have not finished yet.
        catalog: an implemented instance of ``CatalogProtocol`` of the run.

    Returns:
        A set containing all input unfinished ``Node``s and all remaining
        ``Node``s that need to run in case their outputs are not persisted.
    """
    nodes_to_run = set(unfinished_nodes)
    initial_nodes = _nodes_with_external_inputs(unfinished_nodes)

    queue, visited = deque(initial_nodes), set(initial_nodes)
    while queue:
        current_node = queue.popleft()
        nodes_to_run.add(current_node)
        # Look for parent nodes which produce non-persistent inputs (if those exist).
        non_persistent_inputs = _enumerate_non_persistent_inputs(current_node, catalog)
        for node in _enumerate_nodes_with_outputs(pipeline, non_persistent_inputs):
            if node in visited:
                continue
            visited.add(node)
            queue.append(node)

    # Make sure no downstream tasks are skipped.
    nodes_to_run = set(pipeline.from_nodes(*(n.name for n in nodes_to_run)).nodes)

    return nodes_to_run


def _nodes_with_external_inputs(nodes_of_interest: Iterable[Node]) -> set[Node]:
    """For given ``Node``s, find their subset which depends on external inputs of
    the ``Pipeline`` they constitute. External inputs are pipeline inputs not
    produced by other ``Node``s in the ``Pipeline``.

    Args:
        nodes_of_interest: the ``Node``s to analyze.

    Returns:
        A set of ``Node``s that depend on external inputs of nodes of interest.
    """
    p_nodes_of_interest = Pipeline(nodes_of_interest)
    p_nodes_with_external_inputs = p_nodes_of_interest.only_nodes_with_inputs(
        *p_nodes_of_interest.inputs()
    )
    return set(p_nodes_with_external_inputs.nodes)


def _enumerate_non_persistent_inputs(node: Node, catalog: CatalogProtocol) -> set[str]:
    """Enumerate non-persistent input datasets of a ``Node``.

    Args:
        node: the ``Node`` to check the inputs of.
        catalog: an implemented instance of ``CatalogProtocol`` of the run.

    Returns:
        Set of names of non-persistent inputs of the given ``Node``.
""" # We use _datasets because they pertain parameter name format catalog_datasets = catalog._datasets non_persistent_inputs: set[str] = set() for node_input in node.inputs: if node_input.startswith("params:"): continue if ( node_input not in catalog_datasets or catalog_datasets[node_input]._EPHEMERAL ): non_persistent_inputs.add(node_input) return non_persistent_inputs def _enumerate_nodes_with_outputs( pipeline: Pipeline, outputs: Collection[str] ) -> list[Node]: """For given outputs, returns a list containing nodes that generate them in the given ``Pipeline``. Args: pipeline: the ``Pipeline`` to search for nodes in. outputs: the dataset names to find source nodes for. Returns: A list of all ``Node``s that are producing ``outputs``. """ parent_pipeline = pipeline.only_nodes_with_outputs(*outputs) return parent_pipeline.nodes def _find_initial_node_group(pipeline: Pipeline, nodes: Iterable[Node]) -> list[Node]: """Given a collection of ``Node``s in a ``Pipeline``, find the initial group of ``Node``s to be run (in topological order). This can be used to define a sub-pipeline with the smallest possible set of nodes to pass to --from-nodes. Args: pipeline: the ``Pipeline`` to search for initial ``Node``s in. nodes: the ``Node``s to find initial group for. Returns: A list of initial ``Node``s to run given inputs (in topological order). """ node_names = {n.name for n in nodes} if len(node_names) == 0: return [] sub_pipeline = pipeline.only_nodes(*node_names) initial_nodes = sub_pipeline.grouped_nodes[0] return initial_nodes
def run_node(
    node: Node,
    catalog: CatalogProtocol,
    hook_manager: PluginManager,
    is_async: bool = False,
    session_id: str | None = None,
) -> Node:
    """Run a single ``Node`` with inputs from and outputs to the ``catalog``.

    Args:
        node: The ``Node`` to run.
        catalog: An implemented instance of ``CatalogProtocol`` containing the node's
            inputs and outputs.
        hook_manager: The ``PluginManager`` to activate hooks.
        is_async: If True, the node inputs and outputs are loaded and saved
            asynchronously with threads. Defaults to False.
        session_id: The session id of the pipeline run.

    Raises:
        ValueError: Raised if is_async is set to True for nodes wrapping
            generator functions.

    Returns:
        The node argument.
    """
    warnings.warn(
        "`run_node()` has been deprecated and will be removed in Kedro 0.20.0",
        KedroDeprecationWarning,
    )

    if is_async and inspect.isgeneratorfunction(node.func):
        raise ValueError(
            f"Async data loading and saving does not work with "
            f"nodes wrapping generator functions. Please make "
            f"sure you don't use `yield` anywhere "
            f"in node {node!s}."
        )

    task = Task(
        node=node,
        catalog=catalog,
        hook_manager=hook_manager,
        is_async=is_async,
        session_id=session_id,
    )
    node = task.execute()
    return node
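
# Illustrative sketch (not part of the original module): the non-deprecated
# equivalent of ``run_node``, constructing a ``Task`` directly as the wrapper above
# does internally. The helper name ``_run_single_node_example`` is hypothetical.
def _run_single_node_example(
    node: Node,
    catalog: CatalogProtocol,
    hook_manager: PluginManager,
    session_id: str | None = None,
) -> Node:
    task = Task(
        node=node,
        catalog=catalog,
        hook_manager=hook_manager,
        is_async=False,  # async loading/saving does not work with generator nodes
        session_id=session_id,
    )
    return task.execute()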