import logging
from copy import deepcopy
from typing import Any
import mlflow
from kedro.config import MissingConfigException
from kedro.framework.context import KedroContext
from kedro.framework.hooks import hook_impl
from kedro.pipeline.node import Node
from mlflow.entities import RunStatus
from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH
from ...utils.logging import MlflowHandler
from ...utils.parser import merge_dicts
from ...utils.perf import PerformanceTracker
from .base import flatten_dict, get_run_id, get_run_name, sanitize_name, get_git_suffix
from .config import KedroMlflowConfig
logger = logging.getLogger(__name__)
[docs]
class MlflowTrackingHook:
    """Kedro hook that records pipeline runs and their parameters in MLflow.

    Pipeline names are expected to look like ``<view>.<alg>[.<misc>]``: the
    first segment is used as the MLflow experiment name (when the configured
    name is still ``"Default"``) and the second selects the algorithm whose
    parameters are logged. Pipelines that are listed as disabled in the
    configuration, contain ``"ingest"`` in their name, or have fewer than two
    dot-separated segments are run without MLflow tracking.
    """

    def __init__(self):
        # Defaults used until ``before_pipeline_run`` loads the real values
        # from the mlflow configuration.
        self._is_mlflow_enabled = True
        self.recursive = True
        self.sep = "."
        # ``_log_param`` reads ``long_params_strategy``; it must exist even if
        # a parameter is logged before ``before_pipeline_run`` has been called.
        self.long_params_strategy = "fail"
        # Legacy attribute name, kept so external readers do not break.
        self.long_parameters_strategy = "fail"

    @property
    def _logger(self) -> logging.Logger:
        """Return the logger named after this module."""
        return logging.getLogger(__name__)

    def _log_param(self, name: str, value: dict | int | bool | str) -> None:
        """Log one parameter to MLflow, applying the long-value strategy.

        Values whose string form fits in ``MAX_PARAM_VAL_LENGTH`` are logged
        as-is. Longer values are handled according to
        ``self.long_params_strategy``:

        * ``"fail"``: raise ``ValueError``;
        * ``"tag"``: store the value as a tag instead of a parameter;
        * ``"truncate"``: log only the first ``MAX_PARAM_VAL_LENGTH``
          characters of the string form.

        Any other strategy value logs nothing for the over-long parameter.

        Raises:
            ValueError: if the value is too long and the strategy is "fail".
        """
        str_value = str(value)
        str_value_length = len(str_value)
        if str_value_length <= MAX_PARAM_VAL_LENGTH:
            return mlflow.log_param(name, value)

        strategy = self.long_params_strategy
        if strategy == "fail":
            raise ValueError(
                f"Parameter '{name}' length is {str_value_length}, "
                f"while mlflow forces it to be lower than '{MAX_PARAM_VAL_LENGTH}'. "
                "If you want to bypass it, try to change 'long_params_strategy' to"
                " 'tag' or 'truncate' in the 'mlflow.yml'configuration file."
            )
        if strategy == "tag":
            self._logger.warning(
                f"Parameter '{name}' (value length {str_value_length}) is set as a tag."
            )
            mlflow.set_tag(name, value)
        elif strategy == "truncate":
            self._logger.warning(
                f"Parameter '{name}' (value length {str_value_length}) is truncated to its {MAX_PARAM_VAL_LENGTH} first characters."
            )
            mlflow.log_param(name, str_value[0:MAX_PARAM_VAL_LENGTH])

    @hook_impl
    def after_context_created(
        self,
        context: KedroContext,
    ) -> None:
        """Load the mlflow configuration and attach this hook to the context.

        Registers ``mlflow*`` config patterns on the config loader when it
        exposes ``config_patterns`` and none are set for ``"mlflow"``, then
        parses the loaded configuration into a ``KedroMlflowConfig``, calls its
        ``setup`` with the context, and exposes the hook as ``context.mlflow``.
        """
        try:
            patterns = getattr(context.config_loader, "config_patterns", {})
            if "mlflow" not in patterns:
                patterns["mlflow"] = ["mlflow*", "mlflow*/**"]
            conf_mlflow_yml = context.config_loader.get("mlflow")
        except MissingConfigException:
            logger.warning(
                "No 'mlflow.yml' config file found in environment. Default configuration will be used. Use ``kedro mlflow init`` command in CLI to customize the configuration."
            )
            # Use an empty dict so the behaviour matches a fully commented-out
            # mlflow.yml, which yields an empty dict rather than raising
            # MissingConfigException.
            conf_mlflow_yml = {}

        self.mlflow_config = KedroMlflowConfig.parse_obj(conf_mlflow_yml)
        self.mlflow_config.setup(context)
        self.context = context
        setattr(context, "mlflow", self)

    def get_experiment_id(self, view: str | None = None):
        """Activate the configured experiment and return its id.

        Args:
            view: used as the experiment name when one is given and the
                configured name is still ``"Default"``.
        """
        experiment_cfg = self.mlflow_config.tracking.experiment
        if view and experiment_cfg.name == "Default":
            experiment_cfg.name = view
        self.mlflow_config.set_experiment()
        return self.mlflow_config.tracking.experiment._experiment.experiment_id

    @hook_impl
    def before_pipeline_run(self, run_params: dict[str, Any]) -> None:
        """Decide whether to track this run and, if so, open the MLflow run(s).

        When tracking is enabled this sets the experiment, optionally opens or
        creates a parent run (if the ``_mlflow_parent_name`` parameter is set),
        replaces any unfinished run found by ``get_run_id`` for the same
        name/parent/git suffix, tags the new run and logs the merged parameters.

        Args:
            run_params: the Kedro run parameters; ``pipeline_name`` and
                ``runtime_params`` are read, and every truthy entry is set as
                a tag on the run.
        """
        self.params = self.context.params.copy()
        self.parent_name = self.params.pop("_mlflow_parent_name", "")

        # ``pipeline_name`` may be None (e.g. a run that does not name a
        # pipeline); fall back to Kedro's default pipeline key so the string
        # checks below do not raise TypeError. Such a name has a single
        # segment, so tracking ends up disabled with a warning.
        pipeline_name = run_params["pipeline_name"] or "__default__"

        # Disable tracking for pipelines that don't meet criteria.
        self._is_mlflow_enabled = True
        disabled_pipelines = self.mlflow_config.tracking.disable_tracking.pipelines
        if pipeline_name in disabled_pipelines:
            self._is_mlflow_enabled = False
            logger.info(
                f"Disabled mlflow logging for blacklisted pipeline {pipeline_name}."
            )
            return

        if "ingest" in pipeline_name:
            self._is_mlflow_enabled = False
            logger.info(f"Disabled mlflow logging for ingest pipeline {pipeline_name}.")

        # Expected shape is <view>.<alg>.<misc>; anything shorter is not tracked.
        pipe_seg = pipeline_name.split(".")
        if len(pipe_seg) < 2:
            self._is_mlflow_enabled = False
            self._logger.warning(
                "Running ingest dataset/view pipeline, disabling mlflow"
            )
            return
        current_view, alg = pipe_seg[0], pipe_seg[1]

        # Exit if tracking was switched off by the "ingest" check above.
        if not self._is_mlflow_enabled:
            return

        # Setup global mlflow configuration with view as experiment name.
        if self.mlflow_config.tracking.experiment.name == "Default":
            self.mlflow_config.tracking.experiment.name = current_view
        self.mlflow_config.set_experiment()
        experiment_id = self.mlflow_config.tracking.experiment._experiment.experiment_id

        # Parameter-logging settings reused by ``_log_param`` and the flattening.
        params_cfg = self.mlflow_config.tracking.params
        self.flatten = params_cfg.dict_params.flatten
        self.recursive = params_cfg.dict_params.recursive
        self.sep = params_cfg.dict_params.sep
        self.long_params_strategy = params_cfg.long_params_strategy

        run_name = get_run_name(pipeline_name, run_params["runtime_params"])
        git = get_git_suffix()

        if self.parent_name:
            self._start_parent_run(experiment_id, git)

        # An unfinished run with the same identity is deleted, not resumed.
        run_id = get_run_id(run_name, self.parent_name, git, finished=False)
        if run_id:
            logger.info("Removing existing mlflow run.")
            mlflow.delete_run(run_id)

        mlflow.start_run(
            experiment_id=experiment_id,
            run_name=run_name,
            nested=bool(self.parent_name),
        )
        mlflow.set_tag("pasteur_id", run_name)
        # TODO : this does not take into account not committed files, so it
        # does not ensure reproducibility. Define what to do.
        mlflow.set_tag("pasteur_git", git)
        if self.parent_name:
            mlflow.set_tag("pasteur_pid", self.parent_name)

        # Set tags only for run parameters that have values.
        mlflow.set_tags({k: v for k, v in run_params.items() if v})

        self._log_pipeline_params(current_view, alg)

    def _start_parent_run(self, experiment_id, git) -> None:
        """Resume the tagged parent run for ``self.parent_name`` or create it."""
        # NOTE(review): the search uses sanitize_name(parent_name) while the tag
        # is written with the raw name; verify both agree for names that
        # sanitize_name alters.
        query = f"tags.pasteur_id = '{sanitize_name(self.parent_name)}' and tags.pasteur_parent = '1' and tags.pasteur_git = '{git}'"
        parent_runs = mlflow.search_runs(
            experiment_ids=[experiment_id],
            filter_string=query,
        )

        if len(parent_runs):
            parent_run_id = parent_runs["run_id"][0]  # type: ignore
            logger.info(f"Nesting mlflow run under:\n{self.parent_name}")
            mlflow.start_run(
                parent_run_id,
            )
        else:
            logger.info(f"Creating mlflow parent run:\n{self.parent_name}")
            mlflow.start_run(
                run_name=self.parent_name,
                experiment_id=experiment_id,
            )
            mlflow.set_tag("pasteur_id", self.parent_name)
            mlflow.set_tag("pasteur_parent", "1")
            mlflow.set_tag("pasteur_git", git)

    def _log_pipeline_params(self, current_view: str, alg: str) -> None:
        """Merge the parameter namespaces and log them to the active run."""
        # We use 3 namespaces:
        # the unbounded namespace with highest priority, which is used for overrides
        # the `<view>` namespace that sets the parameters for the specific view
        # the `default` namespace that sets a baseline of parameters
        override_params = self.params.copy()
        if "_views" in override_params:
            for view in override_params.pop("_views"):
                override_params.pop(view, None)
        else:
            logger.warning(
                '"_views" key not found in params, view parameters won\'t be stripped from mlflow params.'
            )

        default_params = deepcopy(self.params.get("default", {}))
        view_params = deepcopy(self.params.get(current_view, {}))

        # NOTE(review): precedence between the three namespaces is decided by
        # merge_dicts; the argument order is kept exactly as before.
        merged_params = merge_dicts(view_params, default_params, override_params)
        merged_params.pop("default", {})

        # Build the subset that is logged as individual mlflow parameters.
        params = deepcopy(merged_params)
        params.pop("tables", {})
        ratios = params.pop("ratios", {})
        algs = params.pop("algs", {})
        alg_overrides = params.pop("alg", {})

        params["alg._name"] = alg
        alg_params = merge_dicts(algs.get(alg, {}), alg_overrides)
        # "venv" and "dir" entries are left out of the logged parameters.
        params["alg"] = {
            k: v for k, v in alg_params.items() if k not in ("venv", "dir")
        }
        params["view"] = current_view

        flattened_params = flatten_dict(
            d=params, recursive=self.recursive, sep=self.sep
        )

        # Log parameters based on the defined long-value strategy.
        for k, v in flattened_params.items():
            self._log_param(k, v)

        # Ratios are logged as one parameter with keys in sorted order.
        if ratios:
            self._log_param("ratios", dict(sorted(ratios.items())))

        # And for good measure, store all parameters as yml artifacts.
        mlflow.log_dict(self.params, "_raw/params_all.yml")
        mlflow.log_dict(merged_params, "_raw/params_run.yml")

    @hook_impl
    def on_pipeline_error(self):
        """Close every active MLflow run as FAILED when the pipeline errors."""
        if not self._is_mlflow_enabled:
            return

        MlflowHandler.reset_all()
        PerformanceTracker.log(on_fail=True)

        # Runs may be nested (child under parent), so end until none is active.
        while mlflow.active_run():
            mlflow.end_run(RunStatus.to_string(RunStatus.FAILED))

    @hook_impl
    def after_pipeline_run(self) -> None:
        """Close every active MLflow run after a successful pipeline run."""
        if not self._is_mlflow_enabled:
            return

        MlflowHandler.reset_all()
        PerformanceTracker.log()

        # Runs may be nested (child under parent), so end until none is active.
        while mlflow.active_run():
            mlflow.end_run()