import json
import logging
import os
import pickle
from typing import Any
import mlflow
from mlflow.entities import Run
from ...utils.mlflow import ARTIFACT_DIR, mlflow_log_perf
from .base import get_run, sanitize_name, get_git_suffix
logger = logging.getLogger(__name__)
[docs]
def get_run_artifacts(run: Run) -> dict[str, Any]:
    """Download the artifacts of ``run`` and load them into a nested dict.

    The ``ARTIFACT_DIR`` artifact path of the run is downloaded through mlflow
    and walked recursively. Every sub-directory becomes a nested dictionary,
    so a file found at ``<p1>/<p2>/<name>.json`` below the download directory
    is stored at ``artifacts[p1][p2][name]``.

    Only ``.json`` and ``.pkl`` files are loaded; any other file is ignored.

    Args:
        run: The mlflow run whose artifacts should be loaded.

    Returns:
        A nested dictionary mirroring the directory layout. Leaves are the
        deserialized files, keyed by file name without the final extension.
    """
    artifact_dir = mlflow.artifacts.download_artifacts(
        run_id=run.info.run_id, artifact_path=ARTIFACT_DIR
    )

    artifacts: dict[str, Any] = {}
    for root, _, files in os.walk(artifact_dir):
        if not files:
            continue

        # If root is <artifact_dir>/<p1>/<p2>, place its artifacts in
        # {p1: {p2: {...}}}. relpath + os.sep keeps this correct regardless of
        # the platform's path separator or trailing separators.
        rel_root = os.path.relpath(root, artifact_dir)
        parts = [] if rel_root == os.curdir else rel_root.split(os.sep)
        sub_dict = artifacts
        for part in parts:
            sub_dict = sub_dict.setdefault(part, {})

        # Load all files; only json and pickle are supported for now.
        for name in files:
            if name.endswith(".json"):
                load = json.load
            elif name.endswith(".pkl"):
                # NOTE(review): unpickling executes arbitrary code. This is
                # only safe if the tracking server's artifacts are trusted.
                load = pickle.load
            else:
                continue

            with open(os.path.join(root, name), "rb") as f:
                art = load(f)

            # Supported names always contain a ".", so this drops exactly the
            # final extension.
            sub_dict[name.rpartition(".")[0]] = art
    return artifacts
[docs]
def get_artifacts(runs: dict[str, Run]):
    """Load the artifacts of every run in ``runs``.

    Returns a dictionary with the same keys as ``runs``, where each value is
    the result of calling ``get_run_artifacts`` on the corresponding run.
    """
    loaded = {}
    for run_name, run in runs.items():
        loaded[run_name] = get_run_artifacts(run)
    return loaded
[docs]
def prettify_run_names(run_params: dict[str, dict[str, Any]]) -> dict[str, str]:
    """Generate short run names from run parameters, for use in graphs.

    The first run in ``run_params`` acts as the reference: only its parameters
    are considered, in its order. Each remaining parameter contributes one
    space-separated column to the name of every run:

    * Parameters that every run has with the same value are omitted.
    * Only the part of a parameter name after its last ``.`` is printed.
    * Parameters starting with ``_`` only have their value printed; other
      parameters are printed as ``name=value``.
    * Parameters that are boolean in the reference run print their name when
      truthy and an empty column otherwise.
    * ``_alg`` is special-cased: the printed value is derived from the run
      name (the text after the first ``.`` up to the first space), not from
      the parameter value.
    * ``_pretty`` is special-cased: its value replaces the generated name.
    * A parameter missing from a run yields an empty column for that run.

    The resulting name is stripped on both ends. Runs that end up with no
    non-empty column are named ``"base"``.

    Ex. with runs ``{"a": {"_m": "x", "e1": 1}, "b": {"_m": "y", "e1": 2}}``
    the names are ``{"a": "x e1=1", "b": "y e1=2"}``.

    TODO: column padding and subscript rendering (``e1`` -> ``e_1``) are not
    implemented.

    Args:
        run_params: Mapping of run name to that run's parameters.

    Returns:
        Mapping of run name to its prettified name; empty if ``run_params``
        is empty.
    """
    if not run_params:
        return {}

    ref_run = next(iter(run_params.values()))
    value_params = {k for k in ref_run if k.startswith("_")}
    bool_params = {k for k, v in ref_run.items() if isinstance(v, bool)}

    # Skip params shared by all runs
    skip_params = {
        k
        for k, v in ref_run.items()
        if all(k in run and run[k] == v for run in run_params.values())
    }

    str_params: dict[str, list[str]] = {name: [] for name in run_params}
    pretty_provided: dict[str, str] = {}
    for param in ref_run:
        if param in skip_params:
            continue

        # Drop any dotted prefix, e.g. "alg.e" is printed as "e".
        param_str = param.rsplit(".", 1)[-1]

        for name, params in run_params.items():
            if param in bool_params:
                s = param_str if params.get(param) else ""
            elif param == "_alg":
                # FIXME: dirty hack to add algorithm name
                s = name.split(".", 1)[-1].split(" ", 1)[0]
            elif param == "_pretty":
                # Overrides the generated name; never adds a column.
                if param in params:
                    pretty_provided[name] = str(params[param])
                continue
            elif param not in params:
                # Keep an empty column, like a false boolean does.
                s = ""
            else:
                val_str = str(params[param])
                s = val_str if param in value_params else f"{param_str}={val_str}"
            str_params[name].append(s)

    return {
        name: " ".join(params).strip() if params and any(params) else "base"
        for name, params in str_params.items()
    } | pretty_provided
[docs]
def log_parent_run(
    parent: str,
    run_params: dict[str, dict[str, Any]],
    skip_parent: bool = False,
    experiment_id: str | None = None,
):
    """Aggregate the results of several runs into a single mlflow parent run.

    A parent run tagged with ``pasteur_id``, ``pasteur_parent = '1'`` and the
    current ``pasteur_git`` suffix is searched for across all experiments. If
    none exists a new one is created and tagged; otherwise the first match is
    resumed. Inside the parent run this function:

    1. fetches each run named in ``run_params`` via ``get_run`` and loads its
       artifacts,
    2. logs every mlflow param that all fetched runs share with the same value,
    3. logs the ``perf`` artifacts of the runs through ``mlflow_log_perf``
       (failures are logged and ignored),
    4. for each metric of the first run, gathers the ``data`` of that metric
       from every run and passes it to the metric's ``visualise`` and
       ``summarize`` methods.

    Args:
        parent: Name of the parent run; also used as its ``pasteur_id`` tag.
        run_params: Mapping of run name to that run's parameters. The keys
            select which runs are fetched; the values are only used to build
            the pretty names.
        skip_parent: If true, ``None`` is passed to ``get_run`` in place of
            the parent name and git suffix.
        experiment_id: Experiment to create the parent run in. Only used when
            a new parent run is created.
    """
    git = get_git_suffix()
    query = f"tags.pasteur_id = '{sanitize_name(parent)}' and tags.pasteur_parent = '1' and tags.pasteur_git = '{git}'"
    parent_runs = mlflow.search_runs(filter_string=query, search_all_experiments=True)

    # NOTE(review): len() is kept on purpose. search_runs presumably returns a
    # pandas DataFrame here (its default), whose truth value is ambiguous.
    if not len(parent_runs):
        logger.info(f"Creating empty mlflow parent run:\n{parent}")
        # start_run activates the run immediately, so the tags below are set
        # on the new parent run.
        ctx_mgr = mlflow.start_run(run_name=parent, experiment_id=experiment_id)
        mlflow.set_tag("pasteur_id", parent)
        mlflow.set_tag("pasteur_parent", "1")
        mlflow.set_tag("pasteur_git", git)
    else:
        parent_run_id = parent_runs["run_id"][0]  # type: ignore
        logger.info(f"Relaunching parent run for logging:\n{parent}")
        ctx_mgr = mlflow.start_run(parent_run_id)

    with ctx_mgr:
        runs = {
            name: get_run(
                name,
                parent if not skip_parent else None,
                git if not skip_parent else None,
            )
            for name in run_params
        }
        # Check before doing any work that needs at least one run.
        assert runs, "log_parent_run requires at least one run"

        artifacts = get_artifacts(runs)
        pretty = prettify_run_names(run_params)

        # Log a param on the parent only if every run has it with the same
        # value as the first run.
        ref_params = next(iter(runs.values())).data.params
        for name, val in ref_params.items():
            if all(
                name in run.data.params and run.data.params[name] == val
                for run in runs.values()
            ):
                mlflow.log_param(name, val)

        ref_artifacts = next(iter(artifacts.values()))

        perfs = {pretty[n]: a["perf"] for n, a in artifacts.items() if "perf" in a}
        try:
            mlflow_log_perf(**perfs)
        except Exception as e:
            # Best effort: performance logging must not abort metric logging.
            logger.error(f"Error logging performance:\n{e}")

        for name, folder in ref_artifacts["metrics"].items():
            if "metric" not in folder:
                logger.error(
                    f"Metric '{name}' does not have a 'metric' executable, skipping..."
                )
                # Actually skip; indexing folder["metric"] below would raise.
                continue
            metric = folder["metric"]

            splits = {}
            for alg_name, artifact in artifacts.items():
                try:
                    splits[pretty[alg_name]] = artifact["metrics"][name]["data"]
                except Exception:
                    # Best effort: a run with a missing or malformed metric is
                    # left out of the comparison.
                    logger.error(
                        f"Split '{pretty[alg_name]}' metric '{name}' is broken."
                    )

            metric.visualise(data=splits)
            metric.summarize(data=splits)