Source code for pasteur.kedro.cli

"""In this module, all Pasteur related cli commands are defined.

You can access them through `pasteur <command>` or `kedro <command>`."""

import logging
from typing import Any, Iterable

import click
from kedro.framework.session import KedroSession

from pasteur.kedro.mlflow.base import get_git_suffix

from ..utils.parser import eval_params, merge_params, str_params_to_dict
from ..utils.progress import init_pool
from .runner import SimpleRunner

logger = logging.getLogger(__name__)


def create_session(
    cls,
    project_path: str | None = None,
    save_on_close: bool = True,
    env: str | None = None,
    runtime_params: dict[str, Any] | None = None,
    conf_source: str | None = None,
    session_id: str | None = None,
) -> KedroSession:
    """Create a Kedro session, optionally reusing an existing session id.

    This is a stub of the Kedro session factory whose only purpose is to
    allow the caller to choose the ``session_id``.

    Args:
        cls: Session class to instantiate (called with ``project_path``,
            ``session_id``, ``save_on_close`` and ``conf_source``).
        project_path: Forwarded to ``cls``.
        save_on_close: Forwarded to ``cls``.
        env: Environment name; falls back to the ``KEDRO_ENV`` environment
            variable when not provided.
        runtime_params: Stored in the session data when non-empty.
        conf_source: Forwarded to ``cls``.
        session_id: Session id to reuse. When ``None`` or empty, a fresh
            timestamp is generated instead.

    Returns:
        The newly created session, with its store populated.
    """
    # We have to stub this to change session_id
    import getpass
    import os

    from kedro.framework.project import validate_settings
    from kedro.framework.session.session import _describe_git, _jsonify_cli_context
    from kedro.io.core import generate_timestamp

    if session_id:
        logger.info(f'Reusing session id: "{session_id}"')

    validate_settings()

    session = cls(
        project_path=project_path,
        # Any falsy id (None or "") gets a fresh timestamp, so the log message
        # above and the id actually used can never disagree.
        session_id=session_id or generate_timestamp(),
        save_on_close=save_on_close,
        conf_source=conf_source,
    )

    # have to explicitly type session_data otherwise mypy will complain
    # possibly related to this: https://github.com/python/mypy/issues/1430
    session_data: dict[str, Any] = {
        "project_path": session._project_path,
        "session_id": session.session_id,
    }

    ctx = click.get_current_context(silent=True)
    if ctx:
        session_data["cli"] = _jsonify_cli_context(ctx)

    env = env or os.getenv("KEDRO_ENV")
    if env:
        session_data["env"] = env

    if runtime_params:
        session_data["runtime_params"] = runtime_params

    try:
        session_data["username"] = getpass.getuser()
    except Exception as exc:
        # Deliberate best-effort: a missing username must not prevent a run.
        logger.debug("Unable to get username. Full exception: %s", exc)

    session_data.update(**_describe_git(session._project_path))
    session._store.update(session_data)

    return session
@click.command @click.argument("pipeline", type=str, default=None) @click.argument( "params", nargs=-1, type=str, ) @click.option( "--all", is_flag=True, help="Also runs dataset ingestion, which is skipped by default.", ) @click.option("--pre", is_flag=True, help="Only runs split preprocessing.") @click.option( "--synth", is_flag=True, help="Skips running split preprocessing, only runs synthesis.", ) @click.option( "--metrics", is_flag=True, help="Useful for testing metrics, runs only metrics." ) @click.option( "-r", "--refresh-processes", type=int, default=1, help="Restarts processes after `n` tasks. Lower numbers help with memory leaks but slower. Set to 0 to disable. Check `pasteur.utils.leaks` to fix.", ) @click.option("-w", "--max-workers", type=int, default=None) @click.option( "-c", "--continue-from", type=str, default=None, help="Node name to continue from, all previous nodes are skipped (nodes in the same topological generation are also skipped).", ) @click.option( "-s", "--session-id", type=str, default=None, help="Session ID to use. 
Allows reusing artifacts from a previous run.", ) @click.pass_context def pipe( ctx, pipeline, params, all, pre, synth, metrics, max_workers, refresh_processes, continue_from, session_id, ): """pipe(line) is a modified version of run with minified logging and shorter syntax""" from .pipelines.meta import ( TAG_ALWAYS, TAG_CHANGES_HYPERPARAMETER, TAG_CHANGES_PER_ALGORITHM, TAG_METRICS, ) assert sum([all, pre, synth, metrics]) <= 1 param_dict = str_params_to_dict(params) cmd: str = ctx.info_name if cmd.startswith("i"): match cmd: case "iv" | "ingest_view": pipeline = f"ingest_view.{pipeline}" case "id" | "ingest_dataset": pipeline = f"ingest_dataset.{pipeline}" case "i" | "ingest": pipeline = f"{pipeline}.ingest" with create_session( KedroSession, runtime_params=param_dict, env="base", session_id=session_id ) as session: if "ingest" in pipeline: logger.debug("Skipping tags for ingest pipeline.") tags = [] elif all: logger.info("Nodes for ingesting the dataset will be run.") tags = [] elif pre: logger.info("Only nodes for preprocessing the view will be run.") tags = [TAG_ALWAYS, TAG_CHANGES_HYPERPARAMETER] elif synth: logger.warning( "Skipping ingest nodes which are affected by hyperparameters, results may be invalid." ) tags = [TAG_ALWAYS, TAG_CHANGES_PER_ALGORITHM] elif metrics: logger.warning("Only running metrics nodes.") tags = [TAG_METRICS] # Disable load versions for metrics, due to missing .e.g, models from pasteur.kedro.hooks import pasteur if pasteur: pasteur.load_any = True else: logger.debug( "Skipping dataset ingestion. In case of error, run the pipeline with the name of the dataset." 
) tags = [TAG_ALWAYS, TAG_CHANGES_HYPERPARAMETER, TAG_CHANGES_PER_ALGORITHM] # TODO: Allow for using a config value if refresh_processes == 0: refresh_processes = None session.run( tags=tags, runner=SimpleRunner( pipeline, " ".join(params), max_workers=max_workers, refresh_processes=refresh_processes, resume_node=continue_from, ), # SequentialRunner(True), node_names="", from_nodes="", to_nodes="", from_inputs="", to_outputs="", load_versions={}, pipeline_name=pipeline, ) def _process_iterables(iterables: dict[str, Iterable]): sentinel = object() iterator_dict = {n: iter(v) for n, v in iterables.items()} value_dict = {n: next(v, None) for n, v in iterator_dict.items()} has_combs = True while has_combs: yield value_dict has_combs = False for name, it in iterator_dict.items(): val: Any = next(it, sentinel) if val is sentinel: new_it = iter(iterables[name]) iterator_dict[name] = new_it value_dict[name] = next(new_it, None) else: value_dict[name] = val has_combs = True break @click.command() @click.argument("pipeline", type=str, default=None) @click.option("--alg", "-a", multiple=True) @click.option("--iterator", "-i", multiple=True) @click.option("--hyperparameter", "-h", multiple=True) @click.option("--clear-cache", "-c", is_flag=True) @click.option("--skip-parent", "-p", is_flag=True) @click.argument( "params", nargs=-1, type=str, ) @click.option( "-r", "--refresh-processes", type=int, default=1, help="Restarts processes after `n` tasks. Lower numbers help with memory leaks but slower. Set to 0 to disable. Check `pasteur.utils.leaks` to fix.", ) @click.option("-w", "--max-workers", type=int, default=None) @click.pass_context def sweep( ctx, pipeline, alg, iterator, hyperparameter, skip_parent, params, clear_cache, max_workers, refresh_processes, ): """Similar to pipe, sweep allows in addition a hyperparameter sweep. By using `-i` an iterator can be defined (e.g., `-i i="range(5)"`), which will make the pipeline run for each value of i. 
Then i can be used in expressions with other variables that are passed as arguments (ex. `j="0.2*i"`). If an iterator is also a hyperparameter (e.g., `-h e1="[0.1,0.2,0.3]"`) then `-h` can be used, which will both sweep and pass the variable as an override at the same time (it is equal to `-i val=<iterable> val=val`). If `alg` is provided, `pipeline` is treated as the sweep view and the algorithms provided are sweeped. Example: ``` s tab_adult -a privbayes -a aim ``` runs the following for each parameter combination: ``` tab_adult.aim tab_adult.privbayes ``` Where the first algorithm runs for all nodes that are affected by hyperparameters, ie executing preprocessing. Ingest is ran for each parameter combination, so if a parameter override changes the view it is honored.""" from .mlflow import ( check_run_done, get_parent_name, get_run_name, log_parent_run, remove_runs, ) from .pipelines.meta import ( TAG_ALWAYS, TAG_CHANGES_HYPERPARAMETER, TAG_CHANGES_PER_ALGORITHM, ) # Create pipelines if alg: view = pipeline pipelines_tags = [ ( f"{view}.{alg[0]}", [TAG_ALWAYS, TAG_CHANGES_HYPERPARAMETER, TAG_CHANGES_PER_ALGORITHM], ), *[ (f"{view}.{a}", [TAG_ALWAYS, TAG_CHANGES_PER_ALGORITHM]) for a in alg[1:] ], ] else: pipelines_tags = [ ( pipeline, [TAG_ALWAYS, TAG_CHANGES_HYPERPARAMETER, TAG_CHANGES_PER_ALGORITHM], ) ] # Configure iterators iterable_dict = eval_params(iterator) hyperparam_dict = eval_params(hyperparameter) # Configure parent parent_name = get_parent_name(pipeline, alg, hyperparameter, iterator, params) mlflow_dict = { "_mlflow_parent_name": parent_name, } if clear_cache: with KedroSession.create(env="base") as session: session.load_context() logger.warning(f"Removing runs from mlflow with parent:\n{parent_name}") remove_runs(parent_name, delete_parent=False) # TODO: Allow for using a config value if refresh_processes == 0: refresh_processes = None runs = {} ingested = False runtime_params = {} for iters in _process_iterables(iterable_dict | hyperparam_dict): 
param_dict = eval_params(params, iters) hyper_dict = {n: iters[n] for n in hyperparam_dict} vals = param_dict | hyper_dict runtime_params = merge_params(vals | mlflow_dict) alg_only_hyper = all([n.startswith("alg") for n in vals]) for i, (pipeline, tags) in enumerate(pipelines_tags): tags = list(tags) params_skipped = False if (alg_only_hyper and ingested) or i: params_skipped = True if TAG_CHANGES_HYPERPARAMETER in tags: tags.remove(TAG_CHANGES_HYPERPARAMETER) with KedroSession.create( runtime_params=runtime_params, env="base" ) as session: session.load_context() run_name = get_run_name(pipeline, runtime_params) if alg: # if alg exists add its name runs[run_name] = {"_alg": alg, **vals} elif alg is not None: # ingest pipeline has None and should be skipped from cross-eval runs[run_name] = vals if check_run_done(run_name, None if skip_parent else parent_name, None if skip_parent else get_git_suffix()): logger.warning(f"Run '{run_name}' is complete, skipping...") continue if params_skipped: logger.warning( "Skipping ingestion since hyperparameters are the same" ) session.run( tags=tags, runner=SimpleRunner( pipeline, " ".join(f"{n}={v}" for n, v in vals.items()), max_workers=max_workers, refresh_processes=refresh_processes, ), node_names="", from_nodes="", to_nodes="", from_inputs="", to_outputs="", load_versions={}, pipeline_name=pipeline, ) ingested = True if len(runs) <= 1: logger.info("Only 1 run executed, skipping summary.") return with KedroSession.create(runtime_params=runtime_params, env="base") as session: ctx = session.load_context() experiment_id = getattr(ctx, "mlflow").get_experiment_id(pipeline.split(".")[0]) log_parent_run( parent_name, runs, skip_parent=skip_parent, experiment_id=experiment_id ) @click.command() @click.option("--user", "-u", type=str, default=None) @click.option( "--download-dir", "-d", type=str, default=None, help="Specify a different download dir. 
By default `raw_location` is used.", ) @click.argument( "datasets", nargs=-1, type=str, ) @click.option( "--accept", "-a", is_flag=True, help="By passing this option, you accept to the terms of the data that will be downloaded. Pasteur doesn't provide or license any of the datasets.", ) def download( user: str | None, download_dir: str | None, datasets: list[str], accept: bool = False, ): """Downloads all Pasteur datasets from their creators, provided the user agrees to their access requirements, and has credentials, if required. Uses `wget` and `boto3` to download files. Only downloads missing files, can be ran to verify dataset is downloaded correctly. """ from ..dataset import Dataset from ..extras.download import datasets as EXTRA_DATASETS from ..module import get_module_dict from ..utils.download import get_description, main # Setup logging and params with kedro with KedroSession.create(env="base") as session: ctx = session.load_context() dataset_modules = get_module_dict(Dataset, getattr(ctx, "pasteur").modules) all_datasets = dict(EXTRA_DATASETS) for name, ds in dataset_modules.items(): if isinstance(ds.raw_sources, dict): all_datasets.update(ds.raw_sources) elif ds.raw_sources is not None: all_datasets[name] = ds.raw_sources logger.info(get_description(all_datasets)) if not datasets: return sel_datasets = {} for ds in datasets: if ds not in all_datasets: logger.error(f"Raw sources for {ds} not found.") return sel_datasets[ds] = all_datasets[ds] download_dir = download_dir or getattr(ctx, "pasteur").raw_location assert download_dir, f"Download dir is empty" if not accept: logger.error( "You have to accept to the license of the data stores you're about to download from (--accept/-a)." 
) else: main(download_dir, sel_datasets, user) @click.command() @click.argument( "datasets", nargs=-1, type=str, ) def bootstrap( datasets: tuple[str, ...], ): """Preprocesses downloaded datasets which require it so they can be loaded by kedro.""" from os import path from ..dataset import Dataset from ..kedro.pipelines.main import NAME_LOCATION from ..module import get_module_dict from ..utils.progress import logging_redirect_pbar # Setup logging and params with kedro with KedroSession.create(env="base") as session: ctx = session.load_context() dataset_modules = get_module_dict(Dataset, ctx.pasteur.modules) # type: ignore if not datasets: datasets = tuple(dataset_modules) for dataset in datasets: if dataset not in dataset_modules: logger.error(f"Module for dataset `{dataset}` not currently loaded.") return locations = ctx.config_loader.get("locations") raw_location = locations["raw"] base_location = locations["base"] bootstrap_location = locations.get( "bootstrap", path.join(base_location, "bootstrap") ) with logging_redirect_pbar(), init_pool(): for name in datasets: ds = dataset_modules[name] if not ds.bootstrap: continue assert ( ds.folder_name ), "Folder name for a dataset shouldn't be null when bootstrap is supplied." ds_raw_location = locations.get( NAME_LOCATION.format(ds.folder_name), path.join(raw_location, ds.folder_name), ) bootstrap_location_ds = path.join(bootstrap_location, ds.folder_name) logger.info( f'Initializing dataset "{name}" in:\n{bootstrap_location_ds}' ) ds.bootstrap(ds_raw_location, bootstrap_location_ds) @click.command() @click.argument( "dataset", type=str, ) @click.argument( "output", type=str, ) def export( dataset: str, output: str, ): """Exports a kedro dataset with name `dataset` into file `output`. 
Format is chosen based on filename.""" import pyarrow as pa import pyarrow.csv as csv # Setup logging and params with kedro with KedroSession.create(env="base") as session: ctx = session.load_context() ds = ctx.catalog.load(dataset) if callable(ds): ds = ds() logger.info(f"Starting export of '{dataset}' to '{output}'.") if output.endswith(".csv.gz"): table = pa.Table.from_pandas(ds) with pa.CompressedOutputStream(output, "gzip") as out: csv.write_csv(table, out) elif output.endswith(".csv"): table = pa.Table.from_pandas(ds) csv.write_csv(table, output) elif output.endswith(".pq") or output.endswith(".parquet"): ds.to_parquet(output) else: assert ( False ), f"Unsupported file format: '{output[output.index('.'):]}' of file '{output}'" logger.info("Finished export.") @click.group(name="Pasteur") def cli(): """Command line tools for manipulating a Kedro project.""" cli.add_command(bootstrap) cli.add_command(export) cli.add_command(download) cli.add_command(pipe) cli.add_command(sweep) cli.add_command(download, "dl") cli.add_command(pipe, "p") cli.add_command(sweep, "s") # TODO: fix styling in help menu cli.add_command(pipe, "ingest_dataset") cli.add_command(pipe, "id") cli.add_command(pipe, "ingest_view") cli.add_command(pipe, "iv") cli.add_command(pipe, "ingest") cli.add_command(pipe, "i")