"""kedro is a CLI for managing Kedro projects.
This module implements commands available from the kedro CLI for creating
projects.
"""
from __future__ import annotations
import os
import re
import shutil
import stat
import tempfile
import warnings
from collections import OrderedDict
from itertools import groupby
from pathlib import Path
from typing import Any, Callable
import click
import yaml
from attrs import define, field
import kedro
from kedro import KedroDeprecationWarning
from kedro import __version__ as version
from kedro.framework.cli.utils import (
CONTEXT_SETTINGS,
KedroCliError,
_clean_pycache,
_get_entry_points,
_safe_load_entry_point,
command_with_verbosity,
)
# Directory that contains the imported ``kedro`` package.
KEDRO_PATH = Path(kedro.__file__).parent
# Project template bundled with the package; ``kedro new`` falls back to it
# when no starter is requested.
TEMPLATE_PATH = KEDRO_PATH / "templates" / "project"
# Repository URL used as the template path of every official starter.
_STARTERS_REPO = "git+https://github.com/kedro-org/kedro-starters.git"
# Starter aliases treated as deprecated: ``kedro new`` emits a warning when one
# of them is selected and ``kedro starter list`` labels them "(deprecated)".
_DEPRECATED_STARTERS = [
    "pandas-iris",
    "pyspark-iris",
    "pyspark",
    "standalone-datacatalog",
]
@define(order=True)
class KedroStarterSpec:  # noqa: too-few-public-methods
    """Specification of custom kedro starter template

    Args:
        alias: alias of the starter which shows up on `kedro starter list` and is used
            by the starter argument of `kedro new`
        template_path: path to a directory or a URL to a remote VCS repository supported
            by `cookiecutter`
        directory: optional directory inside the repository where the starter resides.
        origin: reserved field used by kedro internally to determine where the starter
            comes from, users do not need to provide this field.
    """

    alias: str
    template_path: str
    directory: str | None = None
    # Not an ``__init__`` argument: kedro assigns it after construction. It
    # defaults to ``None`` so that ``repr()`` and comparisons work on a spec
    # whose origin has not been assigned yet, instead of raising
    # ``AttributeError`` for the unset attribute.
    origin: str | None = field(init=False, default=None)
# Official starters. The name is first bound to a list so that every spec can
# be tagged with its origin, and is then rebound to an alias -> spec mapping.
_OFFICIAL_STARTER_SPECS = [
    KedroStarterSpec("astro-airflow-iris", _STARTERS_REPO, "astro-airflow-iris"),
    # The `astro-iris` starter was renamed to `astro-airflow-iris`, but old
    # (external) documentation and tutorials still refer to `astro-iris`. This
    # extra alias points at the `astro-airflow-iris` directory, so a user who
    # enters the old name still gets the renamed starter.
    KedroStarterSpec("astro-iris", _STARTERS_REPO, "astro-airflow-iris"),
    KedroStarterSpec(
        "standalone-datacatalog", _STARTERS_REPO, "standalone-datacatalog"
    ),
    KedroStarterSpec("pandas-iris", _STARTERS_REPO, "pandas-iris"),
    KedroStarterSpec("pyspark", _STARTERS_REPO, "pyspark"),
    KedroStarterSpec("pyspark-iris", _STARTERS_REPO, "pyspark-iris"),
    KedroStarterSpec("spaceflights", _STARTERS_REPO, "spaceflights"),
    KedroStarterSpec("databricks-iris", _STARTERS_REPO, "databricks-iris"),
]
# Set the origin for official starters
for starter_spec in _OFFICIAL_STARTER_SPECS:
    starter_spec.origin = "kedro"
# From here on `_OFFICIAL_STARTER_SPECS` is a dict keyed by starter alias.
_OFFICIAL_STARTER_SPECS = {spec.alias: spec for spec in _OFFICIAL_STARTER_SPECS}
# Help text shown for the `--config` / `-c` option of `kedro new`.
CONFIG_ARG_HELP = """Non-interactive mode, using a configuration yaml file. This file
must supply the keys required by the template's prompts.yml. When not using a starter,
these are `project_name`, `repo_name` and `python_package`."""
# Help text shown for the `--starter` / `-s` option of `kedro new`.
STARTER_ARG_HELP = """Specify the starter template to use when creating the project.
This can be the path to a local directory, a URL to a remote VCS repository supported
by `cookiecutter` or one of the aliases listed in ``kedro starter list``.
"""
# Help text shown for the `--checkout` option of `kedro new`.
CHECKOUT_ARG_HELP = (
    "An optional tag, branch or commit to checkout in the starter repository."
)
# Help text shown for the `--directory` option of `kedro new`.
DIRECTORY_ARG_HELP = (
    "An optional directory inside the repository where the starter resides."
)
# noqa: unused-argument
def _remove_readonly(func: Callable, path: Path, excinfo: tuple): # pragma: no cover
"""Remove readonly files on Windows
See: https://docs.python.org/3/library/shutil.html?highlight=shutil#rmtree-example
"""
os.chmod(path, stat.S_IWRITE)
func(path)
def _get_starters_dict() -> dict[str, KedroStarterSpec]:
    """This function lists all the starter aliases declared in
    the core repo and in plugins entry points.

    Official starters are registered first, so a plugin starter whose alias is
    already taken is reported and ignored. Entries that are not
    ``KedroStarterSpec`` instances are reported and ignored as well.

    For example, the output for official kedro starters looks like:
    {"astro-airflow-iris":
        KedroStarterSpec(
            alias="astro-airflow-iris",
            template_path="git+https://github.com/kedro-org/kedro-starters.git",
            directory="astro-airflow-iris",
            origin="kedro"
        ),
    "astro-iris":
        KedroStarterSpec(
            alias="astro-iris",
            template_path="git+https://github.com/kedro-org/kedro-starters.git",
            directory="astro-airflow-iris",
            origin="kedro"
        ),
    }

    Returns:
        A new dictionary mapping each starter alias to its specification.
    """
    # Work on a copy: adding plugin starters to the module-level registry would
    # make them appear as "official" elsewhere and would make a second call
    # report every plugin alias as a conflict with itself.
    starter_specs = dict(_OFFICIAL_STARTER_SPECS)
    for starter_entry_point in _get_entry_points(name="starters"):
        # The top-level package of the entry point identifies the plugin.
        origin = starter_entry_point.module.split(".")[0]
        specs = _safe_load_entry_point(starter_entry_point) or []
        for spec in specs:
            if not isinstance(spec, KedroStarterSpec):
                click.secho(
                    f"The starter configuration loaded from module {origin} "
                    f"should be a 'KedroStarterSpec', got '{type(spec)}' instead",
                    fg="red",
                )
            elif spec.alias in starter_specs:
                click.secho(
                    f"Starter alias `{spec.alias}` from `{origin}` "
                    f"has been ignored as it is already defined by "
                    f"`{starter_specs[spec.alias].origin}`",
                    fg="red",
                )
            else:
                spec.origin = origin
                starter_specs[spec.alias] = spec
    return starter_specs
def _starter_spec_to_dict(
starter_specs: dict[str, KedroStarterSpec]
) -> dict[str, dict[str, str]]:
"""Convert a dictionary of starters spec to a nicely formatted dictionary"""
format_dict: dict[str, dict[str, str]] = {}
for alias, spec in starter_specs.items():
if alias in _DEPRECATED_STARTERS:
key = alias + " (deprecated)"
else:
key = alias
format_dict[key] = {} # Each dictionary represent 1 starter
format_dict[key]["template_path"] = spec.template_path
if spec.directory:
format_dict[key]["directory"] = spec.directory
return format_dict
# noqa: missing-function-docstring
# Root click group of this module, registered under the name "Kedro". The `new`
# command and the `starter` group are attached to it.
# NOTE(review): presumably left without a docstring on purpose, since click
# would use a docstring as the group's help text - confirm before adding one.
@click.group(context_settings=CONTEXT_SETTINGS, name="Kedro")
def create_cli():  # pragma: no cover
    pass
@command_with_verbosity(create_cli, short_help="Create a new kedro project.")
@click.option(
    "--config",
    "-c",
    "config_path",
    type=click.Path(exists=True),
    help=CONFIG_ARG_HELP,
)
@click.option("--starter", "-s", "starter_alias", help=STARTER_ARG_HELP)
@click.option("--checkout", help=CHECKOUT_ARG_HELP)
@click.option("--directory", help=DIRECTORY_ARG_HELP)
def new(config_path, starter_alias, checkout, directory, **kwargs):
    """Create a new kedro project."""
    # NOTE: the docstring above doubles as the command's help text, so the
    # longer description lives in comments.
    #
    # Steps: resolve the template (starter alias, custom starter or built-in
    # template), read its prompts.yml, obtain the configuration from a file or
    # interactively, then hand everything over to cookiecutter.
    # `kwargs` only absorbs extra options injected by the command decorator.
    if starter_alias in _DEPRECATED_STARTERS:
        warnings.warn(
            f"The starter '{starter_alias}' has been deprecated and will be archived from Kedro 0.19.0.",
            KedroDeprecationWarning,
        )
    click.secho(
        "From Kedro 0.19.0, the command `kedro new` will come with the option of interactively selecting add-ons "
        "for your project such as linting, testing, custom logging, and more. The selected add-ons will add the "
        "basic setup for the utilities selected to your projects.",
        fg="green",
    )
    if checkout and not starter_alias:
        raise KedroCliError("Cannot use the --checkout flag without a --starter value.")
    if directory and not starter_alias:
        raise KedroCliError(
            "Cannot use the --directory flag without a --starter value."
        )

    starters_dict = _get_starters_dict()
    if starter_alias in starters_dict:
        if directory:
            raise KedroCliError(
                "Cannot use the --directory flag with a --starter alias."
            )
        spec = starters_dict[starter_alias]
        template_path = spec.template_path
        # "directory" is an optional key for starters from plugins, so if the key is
        # not present we will use "None".
        directory = spec.directory
        checkout = checkout or version
    elif starter_alias is not None:
        # Not a known alias: treat the value as a path or repository URL.
        template_path = starter_alias
        checkout = checkout or version
    else:
        template_path = str(TEMPLATE_PATH)

    # Get prompts.yml to find what information the user needs to supply as config.
    # Ideally we would want to be able to use tempfile.TemporaryDirectory() context manager
    # but it causes an issue with readonly files on windows
    # see: https://bugs.python.org/issue26660.
    # So on error, we will attempt to clear the readonly bits and re-attempt the cleanup
    tmpdir = tempfile.mkdtemp()
    cookiecutter_context = None
    try:
        cookiecutter_dir = _get_cookiecutter_dir(
            template_path, checkout, directory, tmpdir
        )
        prompts_required = _get_prompts_required(cookiecutter_dir)
        # We only need to make cookiecutter_context if interactive prompts are needed.
        if not config_path:
            cookiecutter_context = _make_cookiecutter_context_for_prompts(
                cookiecutter_dir
            )
    finally:
        # Always remove the temporary directory, including when resolving the
        # template or reading its prompts fails.
        shutil.rmtree(tmpdir, onerror=_remove_readonly)

    # Obtain config, either from a file or from interactive user prompts.
    if config_path:
        config = _fetch_config_from_file(config_path)
        if prompts_required:
            _validate_config_file(config, prompts_required)
        elif config is None:
            # An empty config file is acceptable when the template asks for
            # nothing; fall back to an empty configuration.
            config = {}
    elif prompts_required:
        config = _fetch_config_from_user_prompts(prompts_required, cookiecutter_context)
    else:
        config = {}

    cookiecutter_args = _make_cookiecutter_args(config, checkout, directory)
    _create_project(template_path, cookiecutter_args)
# Sub-group of `create_cli` that holds the starter-related commands (the `list`
# command is registered on it). The docstring is left untouched because click
# uses it as the group's help text.
@create_cli.group()
def starter():
    """Commands for working with project starters."""
@starter.command("list")
def list_starters():
    """List all official project starters available."""
    # NOTE: the docstring above doubles as the command's help text, so the
    # longer description lives in comments.
    starters_dict = _get_starters_dict()

    # Group all specs by origin. A plain dict is used instead of
    # itertools.groupby, which only merges *adjacent* items and would split an
    # origin whose starters are not contiguous.
    starters_by_origin: dict[str, dict[str, KedroStarterSpec]] = {}
    for alias, spec in starters_dict.items():
        starters_by_origin.setdefault(spec.origin, {})[alias] = spec

    # Sort the starters of each origin by alias and ensure kedro starters are
    # listed first. `False` sorts before `True` and the sort is stable, so the
    # remaining origins keep the order in which they were discovered.
    sorted_starters_dict: dict[str, dict[str, KedroStarterSpec]] = dict(
        sorted(
            (
                (origin, dict(sorted(specs.items())))
                for origin, specs in starters_by_origin.items()
            ),
            key=lambda item: item[0] != "kedro",
        )
    )

    warnings.warn(
        f"The starters {_DEPRECATED_STARTERS} are deprecated and will be archived in Kedro 0.19.0."
    )
    for origin, starters_spec in sorted_starters_dict.items():
        click.secho(f"\nStarters from {origin}\n", fg="yellow")
        click.echo(
            yaml.safe_dump(_starter_spec_to_dict(starters_spec), sort_keys=False)
        )
def _fetch_config_from_file(config_path: str) -> dict[str, str]:
    """Load the configuration for a new kedro project from a YAML file.

    Args:
        config_path: The path of the config.yml which should contain the data
            required by ``prompts.yml``.

    Returns:
        The parsed configuration. It is passed as ``extra_context`` to
        cookiecutter and overrides the cookiecutter.json defaults.

    Raises:
        KedroCliError: If the file cannot be read or parsed.
    """
    try:
        with open(config_path, encoding="utf-8") as stream:
            loaded = yaml.safe_load(stream)
        # In verbose mode, echo what was read to help with troubleshooting.
        if KedroCliError.VERBOSE_ERROR:
            click.echo(config_path + ":")
            click.echo(yaml.dump(loaded, default_flow_style=False))
    except Exception as exc:
        failure = f"Failed to generate project: could not load config at {config_path}."
        raise KedroCliError(failure) from exc
    return loaded
def _make_cookiecutter_args(
    config: dict[str, str],
    checkout: str,
    directory: str,
) -> dict[str, Any]:
    """Assemble the keyword arguments for the cookiecutter call.

    Args:
        config: Configuration for starting a new project. It is forwarded as
            ``extra_context`` and overrides the cookiecutter.json defaults.
            A ``kedro_version`` key is added to it when missing.
        checkout: The tag, branch or commit in the starter repository to checkout.
            Maps directly to cookiecutter's ``checkout`` argument. Relevant only
            when using a starter.
        directory: The directory of a specific starter inside a repository
            containing multiple starters. Maps directly to cookiecutter's
            ``directory`` argument. Relevant only when using a starter.
            https://cookiecutter.readthedocs.io/en/1.7.2/advanced/directories.html

    Returns:
        Arguments to pass to cookiecutter.
    """
    config.setdefault("kedro_version", version)

    # Projects are generated in the current working directory unless the
    # configuration names another output directory.
    current_dir = str(Path.cwd().resolve())
    arguments: dict[str, Any] = {
        "output_dir": config.get("output_dir", current_dir),
        "no_input": True,
        "extra_context": config,
    }

    # The starter-specific arguments are only forwarded when they are set.
    optional = (("checkout", checkout), ("directory", directory))
    arguments.update({key: value for key, value in optional if value})
    return arguments
def _create_project(template_path: str, cookiecutter_args: dict[str, Any]):
    """Generate the project with cookiecutter and print the follow-up guidance.

    Args:
        template_path: The path to the cookiecutter template to create the project.
            It could either be a local directory or a remote VCS repository
            supported by cookiecutter. For more details, please see:
            https://cookiecutter.readthedocs.io/en/latest/usage.html#generate-your-project
        cookiecutter_args: Arguments to pass to cookiecutter.

    Raises:
        KedroCliError: If it fails to generate a project.
    """
    # noqa: import-outside-toplevel
    from cookiecutter.main import cookiecutter  # for performance reasons

    try:
        result_path = cookiecutter(template=template_path, **cookiecutter_args)
    except Exception as exc:
        failure = "Failed to generate project when running cookiecutter."
        raise KedroCliError(failure) from exc

    _clean_pycache(Path(result_path))

    context = cookiecutter_args["extra_context"]
    project_name = context.get("project_name", "New Kedro Project")
    # Package name derived from the project name, used when none was configured.
    derived_package = project_name.lower().replace(" ", "_").replace("-", "_")
    python_package = context.get("python_package", derived_package)

    summary = (
        f"\nThe project name '{project_name}' has been applied to: "
        f"\n- The project title in {result_path}/README.md "
        f"\n- The folder created for your project in {result_path} "
        f"\n- The project's python package in {result_path}/src/{python_package}"
    )
    click.secho(summary)

    advice = (
        "\nA best-practice setup includes initialising git and creating "
        "a virtual environment before running 'pip install -r src/requirements.txt' to install "
        "project-specific dependencies. Refer to the Kedro documentation: "
        "https://kedro.readthedocs.io/"
    )
    click.secho(advice)

    next_step = (
        f"\nChange directory to the project generated in {result_path} by "
        f"entering 'cd {result_path}'"
    )
    click.secho(next_step, fg="green")
def _get_cookiecutter_dir(
    template_path: str, checkout: str, directory: str, tmpdir: str
) -> Path:
    """Locate the cookiecutter template directory for ``template_path``.

    If template_path is a repo then it is cloned to ``tmpdir``; if template_path
    is a file path then that path is used directly without copying anything.

    Raises:
        KedroCliError: If the template repository cannot be found or cloned.
    """
    # noqa: import-outside-toplevel
    from cookiecutter.exceptions import RepositoryCloneFailed, RepositoryNotFound
    from cookiecutter.repository import determine_repo_dir  # for performance reasons

    clone_target = Path(tmpdir).resolve()
    try:
        located_dir, _ = determine_repo_dir(
            template=template_path,
            abbreviations={},
            clone_to_dir=clone_target,
            checkout=checkout,
            no_input=True,
            directory=directory,
        )
    except (RepositoryNotFound, RepositoryCloneFailed) as exc:
        error_message = f"Kedro project template not found at {template_path}."
        if checkout:
            # Help the user pick a valid tag when a specific one was requested.
            tag_hint = f" Specified tag {checkout}. The following tags are available: "
            error_message = error_message + tag_hint + ", ".join(
                _get_available_tags(template_path)
            )
        alias_listing = yaml.safe_dump(
            sorted(_OFFICIAL_STARTER_SPECS), sort_keys=False
        )
        # NOTE(review): without a checkout, error_message already ends with "."
        # so the final text contains ".." - kept as is to preserve the message.
        raise KedroCliError(
            f"{error_message}. The aliases for the official Kedro starters are: \n"
            f"{alias_listing}"
        ) from exc
    return Path(located_dir)
def _get_prompts_required(cookiecutter_dir: Path) -> dict[str, Any] | None:
"""Finds the information a user must supply according to prompts.yml."""
prompts_yml = cookiecutter_dir / "prompts.yml"
if not prompts_yml.is_file():
return None
try:
with prompts_yml.open("r") as prompts_file:
return yaml.safe_load(prompts_file)
except Exception as exc:
raise KedroCliError(
"Failed to generate project: could not load prompts.yml."
) from exc
def _fetch_config_from_user_prompts(
    prompts: dict[str, Any], cookiecutter_context: OrderedDict
) -> dict[str, str]:
    """Ask the user for every variable declared in prompts.yml.

    Args:
        prompts: Prompts from prompts.yml.
        cookiecutter_context: Cookiecutter context generated from cookiecutter.json.

    Returns:
        Configuration for starting a new project. This is passed as
        ``extra_context`` to cookiecutter and will overwrite the
        cookiecutter.json defaults. Variables for which the user gave no value
        are left out.
    """
    # noqa: import-outside-toplevel
    from cookiecutter.environment import StrictEnvironment
    from cookiecutter.prompt import read_user_variable, render_variable

    answers: dict[str, str] = {}
    for name, prompt_spec in prompts.items():
        prompt = _Prompt(**prompt_spec)

        # Render the variable's default, taking the answers so far into account.
        rendered_default = render_variable(
            env=StrictEnvironment(context=cookiecutter_context),
            raw=cookiecutter_context.get(name),
            cookiecutter_dict=answers,
        )

        # Read the user's input for the variable.
        reply = read_user_variable(str(prompt), rendered_default)
        if not reply:
            continue
        prompt.validate(reply)
        answers[name] = reply
    return answers
def _make_cookiecutter_context_for_prompts(cookiecutter_dir: Path):
    """Build the cookiecutter context from the template's ``cookiecutter.json``.

    Args:
        cookiecutter_dir: Directory of the cookiecutter template.

    Returns:
        The value stored under the "cookiecutter" key of the generated
        context, or an empty dict when that key is absent.
    """
    # noqa: import-outside-toplevel
    from cookiecutter.generate import generate_context

    context_file = cookiecutter_dir / "cookiecutter.json"
    generated = generate_context(context_file)
    return generated.get("cookiecutter", {})
class _Prompt:
"""Represent a single CLI prompt for `kedro new`"""
def __init__(self, *args, **kwargs) -> None: # noqa: unused-argument
try:
self.title = kwargs["title"]
except KeyError as exc:
raise KedroCliError(
"Each prompt must have a title field to be valid."
) from exc
self.text = kwargs.get("text", "")
self.regexp = kwargs.get("regex_validator", None)
self.error_message = kwargs.get("error_message", "")
def __str__(self) -> str:
title = self.title.strip().title()
title = click.style(title + "\n" + "=" * len(title), bold=True)
prompt_lines = [title] + [self.text]
prompt_text = "\n".join(str(line).strip() for line in prompt_lines)
return f"\n{prompt_text}\n"
def validate(self, user_input: str) -> None:
"""Validate a given prompt value against the regex validator"""
if self.regexp and not re.match(self.regexp, user_input):
message = f"'{user_input}' is an invalid value for {self.title}."
click.secho(message, fg="red", err=True)
click.secho(self.error_message, fg="red", err=True)
raise ValueError(message, self.error_message)
def _get_available_tags(template_path: str) -> list:
    """Return the sorted, de-duplicated tag names of the template repository.

    Args:
        template_path: Repository location; a "git+" marker is removed before
            it is handed to git.

    Returns:
        The sorted tag names, or an empty list when the git command fails.
    """
    # Not at top level so that kedro CLI works without a working git executable.
    # noqa: import-outside-toplevel
    import git

    repo_url = template_path.replace("git+", "")
    try:
        listing = git.cmd.Git().ls_remote("--tags", repo_url)
    except git.GitCommandError:
        return []

    # Remove git ref "^{}" and duplicates. For example,
    # tags: ['/tags/version', '/tags/version^{}']
    # tag_names: {'version'}
    tag_names = set()
    for ref_line in listing.split("\n"):
        tag_names.add(ref_line.split("/")[-1].replace("^{}", ""))
    return sorted(tag_names)
def _validate_config_file(config: dict[str, str], prompts: dict[str, Any]):
"""Checks that the configuration file contains all needed variables.
Args:
config: The config as a dictionary.
prompts: Prompts from prompts.yml.
Raises:
KedroCliError: If the config file is empty or does not contain all the keys
required in prompts, or if the output_dir specified does not exist.
"""
if config is None:
raise KedroCliError("Config file is empty.")
missing_keys = set(prompts) - set(config)
if missing_keys:
click.echo(yaml.dump(config, default_flow_style=False))
raise KedroCliError(f"{', '.join(missing_keys)} not found in config file.")
if "output_dir" in config and not Path(config["output_dir"]).exists():
raise KedroCliError(
f"'{config['output_dir']}' is not a valid output directory. "
"It must be a relative or absolute path to an existing directory."
)