"""``ThreadRunner`` is an ``AbstractRunner`` implementation. It can
be used to run the ``Pipeline`` in parallel groups formed by toposort
using threads.
"""
from __future__ import annotations
import warnings
from collections import Counter
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from itertools import chain
from typing import Any
from pluggy import PluginManager
from kedro.io import DataCatalog
from kedro.pipeline import Pipeline
from kedro.pipeline.node import Node
from kedro.runner.runner import AbstractRunner, run_node
class ThreadRunner(AbstractRunner):
"""``ThreadRunner`` is an ``AbstractRunner`` implementation. It can
be used to run the ``Pipeline`` in parallel groups formed by toposort
using threads.
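
    Example (illustrative; assumes ``pipeline`` and ``catalog`` objects have
    already been created, e.g. inside a Kedro project session):
    ::

        >>> from kedro.runner import ThreadRunner
        >>>
        >>> runner = ThreadRunner(max_workers=4)
        >>> runner.run(pipeline, catalog)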
"""
def __init__(
self,
max_workers: int | None = None,
is_async: bool = False,
extra_dataset_patterns: dict[str, dict[str, Any]] | None = None,
):
"""
Instantiates the runner.
Args:
max_workers: Number of worker processes to spawn. If not set,
calculated automatically based on the pipeline configuration
and CPU core count.
is_async: If True, set to False, because `ThreadRunner`
doesn't support loading and saving the node inputs and
outputs asynchronously with threads. Defaults to False.
extra_dataset_patterns: Extra dataset factory patterns to be added to the DataCatalog
during the run. This is used to set the default datasets to MemoryDataset
for `ThreadRunner`.
Raises:
ValueError: bad parameters passed
"""
        if is_async:
            warnings.warn(
                "'ThreadRunner' doesn't support loading and saving the "
                "node inputs and outputs asynchronously with threads. "
                "Setting 'is_async' to False."
            )

        default_dataset_pattern = {"{default}": {"type": "MemoryDataset"}}
        self._extra_dataset_patterns = extra_dataset_patterns or default_dataset_pattern
        super().__init__(
            is_async=False, extra_dataset_patterns=self._extra_dataset_patterns
        )

        if max_workers is not None and max_workers <= 0:
            raise ValueError("max_workers should be positive")

        self._max_workers = max_workers
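
    # Note: with the default ``{"{default}": {"type": "MemoryDataset"}}``
    # pattern registered above, any dataset not explicitly defined in the
    # catalog is, in effect, created as an in-memory ``MemoryDataset`` for
    # the duration of the run.
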
    def _get_required_workers_count(self, pipeline: Pipeline) -> int:
        """
        Calculate the maximum number of threads required for the pipeline.
        """
        # The number of nodes is a safe upper-bound estimate.
        # It's also safe to reduce it by the number of toposort layers minus
        # one, because each additional layer means some nodes depend on
        # other nodes and therefore cannot run in parallel with them.
        # It might not be a perfect estimate, but it's good enough and simple.
required_threads = len(pipeline.nodes) - len(pipeline.grouped_nodes) + 1
return (
min(required_threads, self._max_workers)
if self._max_workers
else required_threads
)
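
    # For illustration (hypothetical numbers): a pipeline of 10 nodes spread
    # over 3 toposorted layers needs at most 10 - 3 + 1 = 8 threads, because
    # at least one node in each of the 2 non-initial layers has to wait for
    # a dependency in an earlier layer to finish.
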
def _run(
self,
pipeline: Pipeline,
catalog: DataCatalog,
hook_manager: PluginManager,
session_id: str | None = None,
) -> None:
"""The abstract interface for running pipelines.
Args:
pipeline: The ``Pipeline`` to run.
catalog: The ``DataCatalog`` from which to fetch data.
hook_manager: The ``PluginManager`` to activate hooks.
session_id: The id of the session.
Raises:
Exception: in case of any downstream node failure.
"""
        nodes = pipeline.nodes
        load_counts = Counter(chain.from_iterable(n.inputs for n in nodes))
        node_dependencies = pipeline.node_dependencies
        todo_nodes = set(node_dependencies.keys())
        done_nodes: set[Node] = set()
        futures = set()
        done = None
        max_workers = self._get_required_workers_count(pipeline)

        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            while True:
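                # A node becomes ready once every one of its dependencies
                # is in ``done_nodes``; this walks the toposorted groups
                # without materialising them explicitly.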
ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes}
todo_nodes -= ready
for node in ready:
futures.add(
pool.submit(
run_node,
node,
catalog,
hook_manager,
self._is_async,
session_id,
)
)
if not futures:
assert not todo_nodes, (todo_nodes, done_nodes, ready, done) # noqa: S101
break
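
                # Block until at least one in-flight node finishes, then
                # process every future that has completed in the meantime.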
done, futures = wait(futures, return_when=FIRST_COMPLETED)
for future in done:
try:
node = future.result()
except Exception:
self._suggest_resume_scenario(pipeline, done_nodes, catalog)
raise
done_nodes.add(node)
self._logger.info("Completed node: %s", node.name)
self._logger.info(
"Completed %d out of %d tasks", len(done_nodes), len(nodes)
)
# Decrement load counts, and release any datasets we
# have finished with.
for dataset in node.inputs:
load_counts[dataset] -= 1
if (
load_counts[dataset] < 1
and dataset not in pipeline.inputs()
):
catalog.release(dataset)
for dataset in node.outputs:
if (
load_counts[dataset] < 1
and dataset not in pipeline.outputs()
):
catalog.release(dataset)
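

# Scheduling sketch (illustrative only; nodes ``a``..``d`` are hypothetical):
# the loop in ``_run`` is a ready-set topological scheduler. Given
#
#     node_dependencies = {a: set(), b: {a}, c: {a}, d: {b, c}}
#
# the successive ready sets are {a}, then {b, c} (which may run on separate
# threads), and finally {d}.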