"""This module provides user-friendly functions for creating nodes as parts
of Kedro pipelines.
"""
from __future__ import annotations
import copy
import inspect
import logging
import re
from collections import Counter
from typing import Any, Callable, Iterable
from warnings import warn
from more_itertools import spy, unzip
from .transcoding import _strip_transcoding
[docs]
class Node:
"""``Node`` is an auxiliary class facilitating the operations required to
run user-provided functions as part of Kedro pipelines.
"""
[docs]
def __init__( # noqa: PLR0913
self,
func: Callable,
inputs: str | list[str] | dict[str, str] | None,
outputs: str | list[str] | dict[str, str] | None,
*,
name: str | None = None,
tags: str | Iterable[str] | None = None,
confirms: str | list[str] | None = None,
namespace: str | None = None,
):
"""Create a node in the pipeline by providing a function to be called
along with variable names for inputs and/or outputs.
Args:
func: A function that corresponds to the node logic.
The function should have at least one input or output.
inputs: The name or the list of the names of variables used as
inputs to the function. The number of names should match
the number of arguments in the definition of the provided
function. When dict[str, str] is provided, variable names
will be mapped to function argument names.
outputs: The name or the list of the names of variables used
as outputs of the function. The number of names should match
the number of outputs returned by the provided function.
When dict[str, str] is provided, variable names will be mapped
to the named outputs the function returns.
name: Optional node name to be used when displaying the node in
logs or any other visualisations. Valid node name must contain
only letters, digits, hyphens, underscores and/or fullstops.
tags: Optional set of tags to be applied to the node. Valid node tag must
contain only letters, digits, hyphens, underscores and/or fullstops.
confirms: Optional name or the list of the names of the datasets
that should be confirmed. This will result in calling
``confirm()`` method of the corresponding data set instance.
Specified dataset names do not necessarily need to be present
in the node ``inputs`` or ``outputs``.
namespace: Optional node namespace.
Raises:
ValueError: Raised in the following cases:
a) When the provided arguments do not conform to
the format suggested by the type hint of the argument.
b) When the node produces multiple outputs with the same name.
c) When an input has the same name as an output.
d) When the given node name violates the requirements:
it must contain only letters, digits, hyphens, underscores
and/or fullstops.
"""
if not callable(func):
raise ValueError(
_node_error_message(
f"first argument must be a function, not '{type(func).__name__}'."
)
)
if inputs and not isinstance(inputs, (list, dict, str)):
raise ValueError(
_node_error_message(
f"'inputs' type must be one of [String, List, Dict, None], "
f"not '{type(inputs).__name__}'."
)
)
for _input in _to_list(inputs):
if not isinstance(_input, str):
raise ValueError(
_node_error_message(
f"names of variables used as inputs to the function "
f"must be of 'String' type, but {_input} from {inputs} "
f"is '{type(_input)}'."
)
)
if outputs and not isinstance(outputs, (list, dict, str)):
raise ValueError(
_node_error_message(
f"'outputs' type must be one of [String, List, Dict, None], "
f"not '{type(outputs).__name__}'."
)
)
for _output in _to_list(outputs):
if not isinstance(_output, str):
raise ValueError(
_node_error_message(
f"names of variables used as outputs of the function "
f"must be of 'String' type, but {_output} from {outputs} "
f"is '{type(_output)}'."
)
)
if not inputs and not outputs:
raise ValueError(
_node_error_message("it must have some 'inputs' or 'outputs'.")
)
self._validate_inputs(func, inputs)
self._func = func
self._inputs = inputs
# The type of _outputs is picked up as possibly being None, however the checks above prevent that
# ever being the case. Mypy doesn't get that though, so it complains about the assignment of outputs to
# _outputs with different types.
self._outputs: str | list[str] | dict[str, str] = outputs # type: ignore[assignment]
if name and not re.match(r"[\w\.-]+$", name):
raise ValueError(
f"'{name}' is not a valid node name. It must contain only "
f"letters, digits, hyphens, underscores and/or fullstops."
)
self._name = name
self._namespace = namespace
self._tags = set(_to_list(tags))
for tag in self._tags:
if not re.match(r"[\w\.-]+$", tag):
raise ValueError(
f"'{tag}' is not a valid node tag. It must contain only "
f"letters, digits, hyphens, underscores and/or fullstops."
)
self._validate_unique_outputs()
self._validate_inputs_dif_than_outputs()
self._confirms = confirms
def _copy(self, **overwrite_params: Any) -> Node:
"""
Helper function to copy the node, replacing some values.
"""
params = {
"func": self._func,
"inputs": self._inputs,
"outputs": self._outputs,
"name": self._name,
"namespace": self._namespace,
"tags": self._tags,
"confirms": self._confirms,
}
params.update(overwrite_params)
return Node(**params) # type: ignore[arg-type]
@property
def _logger(self) -> logging.Logger:
return logging.getLogger(__name__)
@property
def _unique_key(self) -> tuple[Any, Any] | Any | tuple:
def hashable(value: Any) -> tuple[Any, Any] | Any | tuple:
if isinstance(value, dict):
# we sort it because a node with inputs/outputs
# {"arg1": "a", "arg2": "b"} is equivalent to
# a node with inputs/outputs {"arg2": "b", "arg1": "a"}
return tuple(sorted(value.items()))
if isinstance(value, list):
return tuple(value)
return value
return self.name, hashable(self._inputs), hashable(self._outputs)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Node):
return NotImplemented
return self._unique_key == other._unique_key
def __lt__(self, other: Any) -> bool:
if not isinstance(other, Node):
return NotImplemented
return self._unique_key < other._unique_key
def __hash__(self) -> int:
return hash(self._unique_key)
def __str__(self) -> str:
def _set_to_str(xset: set | list[str]) -> str:
return f"[{';'.join(xset)}]"
out_str = _set_to_str(self.outputs) if self._outputs else "None"
in_str = _set_to_str(self.inputs) if self._inputs else "None"
prefix = self._name + ": " if self._name else ""
return prefix + f"{self._func_name}({in_str}) -> {out_str}"
def __repr__(self) -> str: # pragma: no cover
return (
f"Node({self._func_name}, {self._inputs!r}, {self._outputs!r}, "
f"{self._name!r})"
)
def __call__(self, **kwargs: Any) -> dict[str, Any]:
return self.run(inputs=kwargs)
@property
def _func_name(self) -> str:
name = _get_readable_func_name(self._func)
if name == "<partial>":
warn(
f"The node producing outputs '{self.outputs}' is made from a 'partial' function. "
f"Partial functions do not have a '__name__' attribute: consider using "
f"'functools.update_wrapper' for better log messages."
)
return name
@property
def func(self) -> Callable:
"""Exposes the underlying function of the node.
Returns:
Return the underlying function of the node.
"""
return self._func
@func.setter
def func(self, func: Callable) -> None:
"""Sets the underlying function of the node.
Useful if user wants to decorate the function in a node's Hook implementation.
Args:
func: The new function for node's execution.
"""
self._func = func
@property
def tags(self) -> set[str]:
"""Return the tags assigned to the node.
Returns:
Return the set of all assigned tags to the node.
"""
return set(self._tags)
[docs]
def tag(self, tags: str | Iterable[str]) -> Node:
"""Create a new ``Node`` which is an exact copy of the current one,
but with more tags added to it.
Args:
tags: The tags to be added to the new node.
Returns:
A copy of the current ``Node`` object with the tags added.
"""
return self._copy(tags=self.tags | set(_to_list(tags)))
@property
def name(self) -> str:
"""Node's name.
Returns:
Node's name if provided or the name of its function.
"""
node_name = self._name or str(self)
if self.namespace:
return f"{self.namespace}.{node_name}"
return node_name
@property
def short_name(self) -> str:
"""Node's name.
Returns:
Returns a short, user-friendly name that is not guaranteed to be unique.
The namespace is stripped out of the node name.
"""
if self._name:
return self._name
return self._func_name.replace("_", " ").title()
@property
def namespace(self) -> str | None:
"""Node's namespace.
Returns:
String representing node's namespace, typically from outer to inner scopes.
"""
return self._namespace
@property
def inputs(self) -> list[str]:
"""Return node inputs as a list, in the order required to bind them properly to
the node's function.
Returns:
Node input names as a list.
"""
if isinstance(self._inputs, dict):
return _dict_inputs_to_list(self._func, self._inputs)
return _to_list(self._inputs)
@property
def outputs(self) -> list[str]:
"""Return node outputs as a list preserving the original order
if possible.
Returns:
Node output names as a list.
"""
return _to_list(self._outputs)
@property
def confirms(self) -> list[str]:
"""Return dataset names to confirm as a list.
Returns:
Dataset names to confirm as a list.
"""
return _to_list(self._confirms)
[docs]
def run(self, inputs: dict[str, Any] | None = None) -> dict[str, Any]:
"""Run this node using the provided inputs and return its results
in a dictionary.
Args:
inputs: Dictionary of inputs as specified at the creation of
the node.
Raises:
ValueError: In the following cases:
a) The node function inputs are incompatible with the node
input definition.
Example 1: node definition input is a list of 2
DataFrames, whereas only 1 was provided or 2 different ones
were provided.
b) The node function outputs are incompatible with the node
output definition.
Example 1: node function definition is a dictionary,
whereas function returns a list.
Example 2: node definition output is a list of 5
strings, whereas the function returns a list of 4 objects.
Exception: Any exception thrown during execution of the node.
Returns:
All produced node outputs are returned in a dictionary, where the
keys are defined by the node outputs.
"""
self._logger.info("Running node: %s", str(self))
outputs = None
if not (inputs is None or isinstance(inputs, dict)):
raise ValueError(
f"Node.run() expects a dictionary or None, "
f"but got {type(inputs)} instead"
)
try:
inputs = {} if inputs is None else inputs
if not self._inputs:
outputs = self._run_with_no_inputs(inputs)
elif isinstance(self._inputs, str):
outputs = self._run_with_one_input(inputs, self._inputs)
elif isinstance(self._inputs, list):
outputs = self._run_with_list(inputs, self._inputs)
elif isinstance(self._inputs, dict):
outputs = self._run_with_dict(inputs, self._inputs)
return self._outputs_to_dictionary(outputs)
# purposely catch all exceptions
except Exception as exc:
self._logger.error(
"Node %s failed with error: \n%s",
str(self),
str(exc),
extra={"markup": True},
)
raise exc
def _run_with_no_inputs(self, inputs: dict[str, Any]) -> Any:
if inputs:
raise ValueError(
f"Node {self!s} expected no inputs, "
f"but got the following {len(inputs)} input(s) instead: "
f"{sorted(inputs.keys())}."
)
return self._func()
def _run_with_one_input(self, inputs: dict[str, Any], node_input: str) -> Any:
if len(inputs) != 1 or node_input not in inputs:
raise ValueError(
f"Node {self!s} expected one input named '{node_input}', "
f"but got the following {len(inputs)} input(s) instead: "
f"{sorted(inputs.keys())}."
)
return self._func(inputs[node_input])
def _run_with_list(self, inputs: dict[str, Any], node_inputs: list[str]) -> Any:
# Node inputs and provided run inputs should completely overlap
if set(node_inputs) != set(inputs.keys()):
raise ValueError(
f"Node {self!s} expected {len(node_inputs)} input(s) {node_inputs}, "
f"but got the following {len(inputs)} input(s) instead: "
f"{sorted(inputs.keys())}."
)
# Ensure the function gets the inputs in the correct order
return self._func(*(inputs[item] for item in node_inputs))
def _run_with_dict(
self, inputs: dict[str, Any], node_inputs: dict[str, str]
) -> Any:
# Node inputs and provided run inputs should completely overlap
if set(node_inputs.values()) != set(inputs.keys()):
raise ValueError(
f"Node {self!s} expected {len(set(node_inputs.values()))} input(s) "
f"{sorted(set(node_inputs.values()))}, "
f"but got the following {len(inputs)} input(s) instead: "
f"{sorted(inputs.keys())}."
)
kwargs = {arg: inputs[alias] for arg, alias in node_inputs.items()}
return self._func(**kwargs)
def _outputs_to_dictionary(self, outputs: Any) -> dict[str, Any]:
def _from_dict() -> dict[str, Any]:
result, iterator = outputs, None
# generator functions are lazy and we need a peek into their first output
if inspect.isgenerator(outputs):
(result,), iterator = spy(outputs)
# The type of _outputs is picked up as possibly not being a dict, but _from_dict is only called when
# it is a dictionary and so the calls to .keys and .values will work even though Mypy doesn't pick that up.
keys = list(self._outputs.keys()) # type: ignore[union-attr]
names = list(self._outputs.values()) # type: ignore[union-attr]
if not isinstance(result, dict):
raise ValueError(
f"Failed to save outputs of node {self}.\n"
f"The node output is a dictionary, whereas the "
f"function output is {type(result)}."
)
if set(keys) != set(result.keys()):
raise ValueError(
f"Failed to save outputs of node {self!s}.\n"
f"The node's output keys {set(result.keys())} "
f"do not match with the returned output's keys {set(keys)}."
)
if iterator:
exploded = map(lambda x: tuple(x[k] for k in keys), iterator)
result = unzip(exploded)
else:
# evaluate this eagerly so we can reuse variable name
result = tuple(result[k] for k in keys)
return dict(zip(names, result))
def _from_list() -> dict:
result, iterator = outputs, None
# generator functions are lazy and we need a peek into their first output
if inspect.isgenerator(outputs):
(result,), iterator = spy(outputs)
if not isinstance(result, (list, tuple)):
raise ValueError(
f"Failed to save outputs of node {self!s}.\n"
f"The node definition contains a list of "
f"outputs {self._outputs}, whereas the node function "
f"returned a '{type(result).__name__}'."
)
if len(result) != len(self._outputs):
raise ValueError(
f"Failed to save outputs of node {self!s}.\n"
f"The node function returned {len(result)} output(s), "
f"whereas the node definition contains {len(self._outputs)} "
f"output(s)."
)
if iterator:
result = unzip(iterator)
return dict(zip(self._outputs, result))
if self._outputs is None:
return {}
if isinstance(self._outputs, str):
return {self._outputs: outputs}
if isinstance(self._outputs, dict):
return _from_dict()
return _from_list()
def _validate_inputs(
self, func: Callable, inputs: None | str | list[str] | dict[str, str]
) -> None:
# inspect does not support built-in Python functions written in C.
# Thus we only validate func if it is not built-in.
if not inspect.isbuiltin(func):
args, kwargs = self._process_inputs_for_bind(inputs)
try:
inspect.signature(func, follow_wrapped=False).bind(*args, **kwargs)
except Exception as exc:
func_args = inspect.signature(
func, follow_wrapped=False
).parameters.keys()
func_name = _get_readable_func_name(func)
raise TypeError(
f"Inputs of '{func_name}' function expected {list(func_args)}, "
f"but got {inputs}"
) from exc
def _validate_unique_outputs(self) -> None:
cnt = Counter(self.outputs)
diff = {k for k in cnt if cnt[k] > 1}
if diff:
raise ValueError(
f"Failed to create node {self} due to duplicate "
f"output(s) {diff}.\nNode outputs must be unique."
)
def _validate_inputs_dif_than_outputs(self) -> None:
common_in_out = set(map(_strip_transcoding, self.inputs)).intersection(
set(map(_strip_transcoding, self.outputs))
)
if common_in_out:
raise ValueError(
f"Failed to create node {self}.\n"
f"A node cannot have the same inputs and outputs even if they are transcoded: "
f"{common_in_out}"
)
@staticmethod
def _process_inputs_for_bind(
inputs: str | list[str] | dict[str, str] | None,
) -> tuple[list[str], dict[str, str]]:
# Safeguard that we do not mutate list inputs
inputs = copy.copy(inputs)
args: list[str] = []
kwargs: dict[str, str] = {}
if isinstance(inputs, str):
args = [inputs]
elif isinstance(inputs, list):
args = inputs
elif isinstance(inputs, dict):
kwargs = inputs
return args, kwargs
def _node_error_message(msg: str) -> str:
return (
f"Invalid Node definition: {msg}\n"
f"Format should be: node(function, inputs, outputs)"
)
[docs]
def node( # noqa: PLR0913
func: Callable,
inputs: str | list[str] | dict[str, str] | None,
outputs: str | list[str] | dict[str, str] | None,
*,
name: str | None = None,
tags: str | Iterable[str] | None = None,
confirms: str | list[str] | None = None,
namespace: str | None = None,
) -> Node:
"""Create a node in the pipeline by providing a function to be called
along with variable names for inputs and/or outputs.
Args:
func: A function that corresponds to the node logic. The function
should have at least one input or output.
inputs: The name or the list of the names of variables used as inputs
to the function. The number of names should match the number of
arguments in the definition of the provided function. When
dict[str, str] is provided, variable names will be mapped to
function argument names.
outputs: The name or the list of the names of variables used as outputs
to the function. The number of names should match the number of
outputs returned by the provided function. When dict[str, str]
is provided, variable names will be mapped to the named outputs the
function returns.
name: Optional node name to be used when displaying the node in logs or
any other visualisations.
tags: Optional set of tags to be applied to the node.
confirms: Optional name or the list of the names of the datasets
that should be confirmed. This will result in calling ``confirm()``
method of the corresponding data set instance. Specified dataset
names do not necessarily need to be present in the node ``inputs``
or ``outputs``.
namespace: Optional node namespace.
Returns:
A Node object with mapped inputs, outputs and function.
Example:
::
>>> import pandas as pd
>>> import numpy as np
>>>
>>> def clean_data(cars: pd.DataFrame,
>>> boats: pd.DataFrame) -> dict[str, pd.DataFrame]:
>>> return dict(cars_df=cars.dropna(), boats_df=boats.dropna())
>>>
>>> def halve_dataframe(data: pd.DataFrame) -> List[pd.DataFrame]:
>>> return np.array_split(data, 2)
>>>
>>> nodes = [
>>> node(clean_data,
>>> inputs=['cars2017', 'boats2017'],
>>> outputs=dict(cars_df='clean_cars2017',
>>> boats_df='clean_boats2017')),
>>> node(halve_dataframe,
>>> 'clean_cars2017',
>>> ['train_cars2017', 'test_cars2017']),
>>> node(halve_dataframe,
>>> dict(data='clean_boats2017'),
>>> ['train_boats2017', 'test_boats2017'])
>>> ]
"""
return Node(
func,
inputs,
outputs,
name=name,
tags=tags,
confirms=confirms,
namespace=namespace,
)
def _dict_inputs_to_list(
func: Callable[[Any], Any], inputs: dict[str, str]
) -> list[str]:
"""Convert a dict representation of the node inputs to a list, ensuring
the appropriate order for binding them to the node's function.
"""
sig = inspect.signature(func, follow_wrapped=False).bind(**inputs)
return [*sig.args, *sig.kwargs.values()]
def _to_list(element: str | Iterable[str] | dict[str, str] | None) -> list[str]:
"""Make a list out of node inputs/outputs.
Returns:
list[str]: Node input/output names as a list to standardise.
"""
if element is None:
return []
if isinstance(element, str):
return [element]
if isinstance(element, dict):
return list(element.values())
return list(element)
def _get_readable_func_name(func: Callable) -> str:
"""Get a user-friendly readable name of the function provided.
Returns:
str: readable name of the provided callable func.
"""
if hasattr(func, "__name__"):
return func.__name__
name = repr(func)
if "functools.partial" in name:
name = "<partial>"
return name