Source code for pasteur.kedro.mlflow.parent

import json
import logging
import os
import pickle
from typing import Any

import mlflow
from mlflow.entities import Run

from ...utils.mlflow import ARTIFACT_DIR, mlflow_log_perf
from .base import get_run, sanitize_name

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


def get_run_artifacts(run: Run):
    """Download a run's artifacts and load them into a nested dictionary.

    The contents of ``ARTIFACT_DIR`` for ``run`` are downloaded with
    ``mlflow.artifacts.download_artifacts`` and walked recursively. Every
    sub-directory that (transitively) holds files becomes a nested dictionary
    keyed by the directory name, and every ``.json`` or ``.pkl`` file becomes
    an entry keyed by its file name without the extension. Files with any
    other extension are ignored.

    Args:
        run: The MLflow run whose artifacts should be loaded.

    Returns:
        A dictionary mirroring the artifact directory layout: a file stored at
        ``<p1>/<p2>/name.json`` is available as ``result[p1][p2]["name"]``.
    """
    artifact_dir = mlflow.artifacts.download_artifacts(
        run_id=run.info.run_id, artifact_path=ARTIFACT_DIR
    )

    artifacts = {}
    for root, _, files in os.walk(artifact_dir):
        if not files:
            continue

        # Descend to the dictionary that mirrors `root`. The relative path is
        # split on the platform separator so the nesting is also correct on
        # systems where paths are not "/"-delimited.
        sub_dict = artifacts
        rel_root = os.path.relpath(root, artifact_dir)
        if rel_root != os.curdir:
            for sub in rel_root.split(os.sep):
                sub_dict = sub_dict.setdefault(sub, {})

        # Only pickle and json are supported for now; skip anything else
        # before opening it.
        for name in files:
            if name.endswith(".json"):
                load = json.load
            elif name.endswith(".pkl"):
                # NOTE: unpickling can execute arbitrary code. Only load
                # artifacts from tracking servers you trust.
                load = pickle.load
            else:
                continue

            with open(os.path.join(root, name), "rb") as f:
                art = load(f)

            # Key the artifact by its file name minus the final extension.
            sub_dict[name.rpartition(".")[0]] = art

    return artifacts
def get_artifacts(runs: dict[str, Run]):
    """Load the artifacts of every run in ``runs``.

    Args:
        runs: Mapping of run name to MLflow run.

    Returns:
        A dictionary with the same keys as ``runs`` where each value is the
        result of ``get_run_artifacts`` for that run.
    """
    loaded = {}
    for run_name, run in runs.items():
        loaded[run_name] = get_run_artifacts(run)
    return loaded
def prettify_run_names(run_params: dict[str, dict[str, Any]]):
    """Generate short, column-aligned run names for use in graphs.

    The parameters of the first run decide which parameters are printed and
    in which order; every other run must provide a value for each of them.
    Per parameter, the following rules apply:

    - Only the part of the parameter name after its last ``.`` is printed.
    - Parameters whose full name starts with ``_`` get only their value
      printed.
    - Parameters whose value is a ``bool`` in the first run print their name
      when the run's value is truthy, and an equally wide blank otherwise.
    - Every other parameter is printed as ``name=value``.

    Values are left-justified to the widest value of that parameter across
    all runs, so the parameters of different runs line up. The resulting name
    is stripped on the right to drop trailing padding; inner padding is kept
    to maintain the alignment.

    Ex. for a single run, ``{"_alg": "privbayes", "e1": "abc"}`` becomes
    ``privbayes e1=abc``.

    TODO: rendering a trailing number as a subscript (``e1`` -> ``e_1``) is
    not implemented.

    Args:
        run_params: Mapping of run name to that run's parameters.

    Returns:
        Mapping of run name to its prettified name. Empty if ``run_params``
        is empty.

    Raises:
        KeyError: If a run lacks a parameter that the first run defines.
    """
    if not run_params:
        return {}

    ref_run = next(iter(run_params.values()))
    cells_by_run: dict[str, list[str]] = {name: [] for name in run_params}

    for param, ref_val in ref_run.items():
        # Drop any dotted prefix; a name without "." is kept whole.
        label = param.rpartition(".")[2]
        is_flag = isinstance(ref_val, bool)
        value_only = param.startswith("_")
        # Column width: the longest printed value across all runs.
        width = max(len(str(run[param])) for run in run_params.values())

        for name, params in run_params.items():
            if is_flag:
                cell = label if params[param] else " " * len(label)
            else:
                text = str(params[param]).ljust(width)
                cell = text if value_only else f"{label}={text}"
            cells_by_run[name].append(cell)

    return {name: " ".join(cells).rstrip() for name, cells in cells_by_run.items()}
def log_parent_run(parent: str, run_params: dict[str, dict[str, Any]]):
    """Log a combined report of several child runs into their parent run.

    The existing parent run (tagged with the sanitized ``parent`` name and
    ``pasteur_parent = '1'``) is looked up and restarted. Then, using the
    runs returned by ``get_run`` for every name in ``run_params``:

    - every parameter that exists with the same value in all runs is logged
      to the parent run,
    - the ``"perf"`` artifact of each run is passed to ``mlflow_log_perf``
      under the run's prettified name,
    - for each entry in the first run's ``"metrics"`` artifacts, the
      ``"wrk"`` and ``"ref"`` data of the first run plus the ``"syn"`` data
      of every run are handed to the metric's ``visualise`` and
      ``summarize`` methods in comparison mode.

    The parent run started here is left active when the function returns.

    Args:
        parent: Name of the parent run.
        run_params: Mapping of child run name to that run's parameters.

    Raises:
        AssertionError: If ``run_params`` is empty or no parent run exists
            for ``parent``.
    """
    # Validate explicitly instead of with `assert`, which is stripped under
    # `python -O`, and do so before touching MLflow.
    if not run_params:
        raise AssertionError("At least one run is required to create a combined report.")

    query = f"tags.pasteur_id = '{sanitize_name(parent)}' and tags.pasteur_parent = '1'"
    parent_runs = mlflow.search_runs(filter_string=query, search_all_experiments=True)
    if not len(parent_runs):
        # TODO: Perhaps this should not be true
        raise AssertionError(
            f"Parent run {parent} should exist to create combined report."
        )

    parent_run_id = parent_runs["run_id"][0]  # type: ignore
    logger.info(f"Relaunching parent run for logging:\n{parent}")
    mlflow.start_run(
        parent_run_id,
    )

    runs = {name: get_run(name, parent) for name in run_params}
    artifacts = get_artifacts(runs)
    pretty = prettify_run_names(run_params)

    # Log a param only if it exists and is the same in all runs.
    ref_params = next(iter(runs.values())).data.params
    for name, val in ref_params.items():
        if all(
            name in run.data.params and run.data.params[name] == val
            for run in runs.values()
        ):
            mlflow.log_param(name, val)

    ref_artifacts = next(iter(artifacts.values()))

    perfs = {pretty[n]: a["perf"] for n, a in artifacts.items()}
    mlflow_log_perf(**perfs)

    for name, folder in ref_artifacts["metrics"].items():
        metric = folder["metric"]
        # FIXME: remove wrk, ref hardcoding
        splits = {"wrk": folder["wrk"], "ref": folder["ref"]}
        for alg_name, artifact in artifacts.items():
            splits[pretty[alg_name]] = artifact["metrics"][name]["syn"]

        metric.visualise(data=splits, comparison=True, wrk_set="wrk", ref_set="ref")
        metric.summarize(data=splits, comparison=True, wrk_set="wrk", ref_set="ref")