Source code for pasteur.kedro.mlflow.config

import os
from logging import getLogger
from pathlib import Path, PurePath
from typing import List, Optional
from urllib.parse import urlparse

import mlflow
from kedro.framework.context import KedroContext
from mlflow.entities import Experiment
from mlflow.tracking.client import MlflowClient
from pydantic import BaseModel, PrivateAttr, StrictBool
from typing_extensions import Literal

LOGGER = getLogger(__name__)


[docs] class MlflowServerOptions(BaseModel): # mutable default is ok for pydantic : https://stackoverflow.com/questions/63793662/how-to-give-a-pydantic-list-field-a-default-value mlflow_tracking_uri: Optional[str] = None credentials: Optional[str] = None _mlflow_client: MlflowClient = PrivateAttr()
[docs] class Config: extra = "forbid"
[docs] class DisableTrackingOptions(BaseModel): # mutable default is ok for pydantic : https://stackoverflow.com/questions/63793662/how-to-give-a-pydantic-list-field-a-default-value pipelines: List[str] = []
[docs] class Config: extra = "forbid"
[docs] class ExperimentOptions(BaseModel): name: str = "Default" restore_if_deleted: StrictBool = True _experiment: Experiment = PrivateAttr() # do not create _experiment immediately to avoid creating # a database connection when creating the object # it will be instantiated on setup() call
[docs] class Config: extra = "forbid"
[docs] class DictParamsOptions(BaseModel): flatten: StrictBool = False recursive: StrictBool = True sep: str = "."
[docs] class Config: extra = "forbid"
[docs] class MlflowParamsOptions(BaseModel): dict_params: DictParamsOptions = DictParamsOptions() long_params_strategy: Literal["fail", "truncate", "tag"] = "fail"
[docs] class Config: extra = "forbid"
[docs] class MlflowTrackingOptions(BaseModel): # mutable default is ok for pydantic : https://stackoverflow.com/questions/63793662/how-to-give-a-pydantic-list-field-a-default-value disable_tracking: DisableTrackingOptions = DisableTrackingOptions() experiment: ExperimentOptions = ExperimentOptions() params: MlflowParamsOptions = MlflowParamsOptions()
[docs] class Config: extra = "forbid"
[docs] class UiOptions(BaseModel): port: str = "5000" host: str = "127.0.0.1"
[docs] class Config: extra = "forbid"
[docs] class KedroMlflowConfig(BaseModel): server: MlflowServerOptions = MlflowServerOptions() tracking: MlflowTrackingOptions = MlflowTrackingOptions() ui: UiOptions = UiOptions()
[docs] class Config: # force triggering type control when setting value instead of init validate_assignment = True # raise an error if an unknown key is passed to the constructor extra = "forbid"
[docs] def setup(self, context): """Setup all the mlflow configuration""" self.server.mlflow_tracking_uri = _validate_mlflow_tracking_uri( project_path=context.project_path, uri=self.server.mlflow_tracking_uri ) # init after validating the uri, else mlflow creates a mlruns folder at the root self.server._mlflow_client = MlflowClient( tracking_uri=self.server.mlflow_tracking_uri ) self._export_credentials(context) # we set the configuration now: it takes priority # if it has already be set in export_credentials mlflow.set_tracking_uri(self.server.mlflow_tracking_uri)
def _export_credentials(self, context: KedroContext): conf_creds = context._get_config_credentials() mlflow_creds = conf_creds.get(self.server.credentials, {}) for key, value in mlflow_creds.items(): os.environ[key] = value
[docs] def set_experiment(self): """Best effort to get the experiment associated to the configuration Returns: mlflow.entities.Experiment -- [description] """ # we retrieve the experiment manually to check if it exsits mlflow_experiment = self.server._mlflow_client.get_experiment_by_name( name=self.tracking.experiment.name ) # Deal with two side case when retrieving the experiment if mlflow_experiment is not None: if ( self.tracking.experiment.restore_if_deleted and mlflow_experiment.lifecycle_stage == "deleted" ): # the experiment was created, then deleted : we have to restore it manually before setting it as the active one self.server._mlflow_client.restore_experiment( mlflow_experiment.experiment_id ) # this creates the experiment if it does not exists # and creates a global variable with the experiment # but returns nothing mlflow.set_experiment(experiment_name=self.tracking.experiment.name) # we do not use "experiment" variable directly but we fetch again from the database # because if it did not exists at all, it was created by previous command self.tracking.experiment._experiment = ( self.server._mlflow_client.get_experiment_by_name( name=self.tracking.experiment.name ) )
def _validate_mlflow_tracking_uri(project_path: str, uri: Optional[str]) -> str: """Format the uri provided to match mlflow expectations. Arguments: uri {Union[None, str]} -- A valid filepath for mlflow uri Returns: str -- A valid mlflow_tracking_uri """ # this is a special reserved keyword for mlflow which should not be converted to a path # se: https://mlflow.org/docs/latest/tracking.html#where-runs-are-recorded if uri is None: # do not use mlflow.get_tracking_uri() because if there is no env var, # it resolves to 'Path.cwd() / "mlruns"' # but we want 'project_path / "mlruns"' uri = os.environ.get("MLFLOW_TRACKING_URI", "mlruns") if uri == "databricks": return uri # if no tracking uri is provided, we register the runs locally at the root of the project pathlib_uri = PurePath(uri) if pathlib_uri.is_absolute(): valid_uri = pathlib_uri.as_uri() else: parsed = urlparse(uri) if parsed.scheme == "": # if it is a local relative path, make it absolute # .resolve() does not work well on windows # .absolute is undocumented and have known bugs # Path.cwd() / uri is the recommend way by core developpers. # See : https://discuss.python.org/t/pathlib-absolute-vs-resolve/2573/6 valid_uri = (Path(project_path) / uri).as_uri() LOGGER.info( f"The 'mlflow_tracking_uri' key in mlflow.yml is relative ('server.mlflow_tracking_uri = {uri}'). It is converted to a valid uri: '{valid_uri}'" ) else: # else assume it is an uri valid_uri = uri return valid_uri