Source code for pasteur.kedro.mlflow.base

from typing import Any

import mlflow
from mlflow.entities import Run, RunStatus

from ...utils.parser import dict_to_flat_params


[docs] def flatten_dict(d: dict, recursive: bool = True, sep: str = ".") -> dict: def expand(key, value): if isinstance(value, dict): new_value = ( flatten_dict(value, recursive=recursive, sep=sep) if recursive else value ) return [(f"{key}{sep}{k}", v) for k, v in new_value.items()] else: return [(f"{key}", value)] items = [item for k, v in d.items() for item in expand(k, v)] return dict(items)
_git_id = None
[docs] def get_git_suffix(): # FIXME: Dirty global var global _git_id if _git_id is not None: return _git_id try: import git repo = git.Repo(search_parent_directories=True) sha = repo.head.object.hexsha return sha[:8] except Exception: return ""
[docs] def get_run_name(pipeline: str, params: dict[str, Any]): run_name = pipeline for param, val in dict_to_flat_params(params).items(): if param.startswith("_"): continue run_name += f" {param}={val}" return run_name
[docs] def get_parent_name( pipeline: str, algs: list[str], hyperparams: list[str], iterators: list[str], params: list[str], ): algs_str = "" if algs: algs_str = " -a [" + ", ".join(algs) + "]" hyper_str = "".join(map(lambda x: f" -h {x}", hyperparams)) iter_str = "".join(map(lambda x: f" -i {x}", iterators)) param_str = "".join(map(lambda x: f" {x}", params)) return f"{pipeline}{algs_str}{hyper_str}{iter_str}{param_str}"
[docs] def sanitize_name(name: str): # todo: properly escape return name.replace("'", "\\'")
[docs] def get_run_id(name: str, parent: str | None, git: str | None, finished: bool = True): filter_string = f"tags.pasteur_id = '{sanitize_name(name)}'" if parent: filter_string += f" and tags.pasteur_pid = '{sanitize_name(parent)}'" if git: filter_string += f" and tags.pasteur_git = '{git}'" if finished: filter_string += ( f" and attribute.status = '{RunStatus.to_string(RunStatus.FINISHED)}'" ) tmp = mlflow.search_runs( experiment_ids=[exp.experiment_id for exp in mlflow.search_experiments()], filter_string=filter_string, ) if len(tmp): return tmp["run_id"][0] return None
[docs] def check_run_done(name: str, parent: str | None, git: str | None): return bool(get_run_id(name, parent, git))
[docs] def get_run(name: str, parent: str | None, git: str | None) -> Run: return mlflow.get_run(get_run_id(name, parent, git))
[docs] def remove_runs(parent: str, delete_parent: bool = False): """Removes runs with provided parent""" # Delete children runs = mlflow.search_runs( search_all_experiments=True, filter_string=f"tags.pasteur_pid = '{sanitize_name(parent)}'", ) for id in runs["run_id"]: mlflow.delete_run(id) # Delete parent if not delete_parent: return git = get_git_suffix() runs = mlflow.search_runs( search_all_experiments=True, filter_string=f"tags.pasteur_id = '{sanitize_name(parent)}' and tags.pasteur_parent = '1' and tags.pasteur_git = '{git}'", ) for id in runs["run_id"]: mlflow.delete_run(id)