""" Parsing related utility functions. """
import logging
from typing import Any, Optional, TypeVar

from ..module import Module
# Module-level logger named after this module (``__name__``).
logger = logging.getLogger(__name__)
def _try_convert_to_numeric(value: str):
"""Taken from kedro.framework.cli.utils"""
try:
value = float(value) # type: ignore
except ValueError:
return value
return int(value) if value.is_integer() else value # type: ignore
def _try_convert_primitive(value: Any):
"""Converts string value to integer/float/bool/None."""
if value == "True":
return True
if value == "False":
return False
if value == "None":
return None
return _try_convert_to_numeric(value)
def _try_convert_eval(value: str, locals: dict[str, object]):
try:
return eval(value, {}, locals)
except Exception as e:
logger.error(f"Failed to evaluate '{value}' with locals {locals}")
raise e
def _update_value_nested_dict(
nested_dict: dict[str, Any], value: Any, walking_path: list[str]
) -> dict:
"""Taken from kedro.framework.cli.utils"""
key = walking_path.pop(0)
if not walking_path:
nested_dict[key] = value
return nested_dict
nested_dict[key] = _update_value_nested_dict(
nested_dict.get(key, {}), value, walking_path
)
return nested_dict
[docs]
def str_params_to_dict(params: list[str], locals: Optional[dict[str, Any]] = None):
    """Convert ``["a.b.c=5", "c='b'"]`` to ``{"a": {"b": {"c": 5}}, "c": "b"}``.

    Each item is split on its first ``=``. The left side is a dot-separated
    key path. The right side is evaluated as a Python expression by
    ``_try_convert_eval`` (which uses ``eval``), so only pass trusted input.

    Args:
        params: Items of the form ``"<key.path>=<expression>"``. Whitespace
            around the key and the expression is stripped.
        locals: Names made available to the evaluated expressions. Defaults
            to a fresh empty namespace for every call.

    Returns:
        A nested dictionary holding the evaluated values.

    Raises:
        ValueError: If an item has no ``=`` or its key is empty.
    """
    # A fresh dict per call: a shared default could be written to by eval
    # (for example through ``:=``) and leak names between calls.
    scope = {} if locals is None else locals
    param_dict: dict[str, Any] = {}
    for item in params:
        raw_key, separator, raw_value = item.partition("=")
        if not separator:
            raise ValueError(
                f"Invalid parameter '{item}': expected the form '<key>=<value>'"
            )
        key = raw_key.strip()
        if not key:
            raise ValueError(f"Invalid parameter '{item}': the key is empty")
        param_dict = _update_value_nested_dict(
            param_dict, _try_convert_eval(raw_value.strip(), scope), key.split(".")
        )
    return param_dict
[docs]
def eval_params(params: list[str], locals: Optional[dict[str, Any]] = None):
    """Evaluate ``["name=expression", ...]`` into ``{name: value}``.

    Each item is split on its first ``=``. The right side is evaluated as a
    Python expression by ``_try_convert_eval`` (which uses ``eval``), so only
    pass trusted input. Names are used as they are: they are neither stripped
    nor split on dots.

    Args:
        params: Items of the form ``"<name>=<expression>"``.
        locals: Names made available to the evaluated expressions. Defaults
            to a fresh empty namespace for every call.

    Returns:
        A flat dictionary mapping each name to its evaluated value.

    Raises:
        ValueError: If an item does not contain ``=``.
    """
    # A fresh dict per call: a shared default could be written to by eval
    # (for example through ``:=``) and leak names between calls.
    scope = {} if locals is None else locals
    evaluated: dict[str, Any] = {}
    for item in params:
        name, separator, expression = item.partition("=")
        if not separator:
            raise ValueError(
                f"Invalid parameter '{item}': expected the form '<name>=<value>'"
            )
        evaluated[name] = _try_convert_eval(expression, scope)
    return evaluated
[docs]
def merge_params(params: dict[str, Any]):
    """Expand dot-separated keys into nested dictionaries.

    Every key of ``params`` is split on ``.`` and its value is stored at that
    path through ``_update_value_nested_dict``. Values are stored as they are.
    """
    nested = {}
    for dotted_key in params:
        path = dotted_key.split(".")
        nested = _update_value_nested_dict(nested, params[dotted_key], path)
    return nested
[docs]
def flat_params_to_dict(params: dict[str, Any]):
    """Convert ``{"a.b.c": "5", "c": "b"}`` to ``{"a": {"b": {"c": 5}}, "c": "b"}``.

    Keys are split on ``.`` to build the nesting. Every value is passed
    through ``_try_convert_to_numeric`` before it is stored (note the number
    conversion in the example).

    Args:
        params: Flat mapping from dot-separated keys to values.

    Returns:
        The equivalent nested dictionary.

    Raises:
        ValueError: If a key is empty.
    """
    param_dict = {}
    for key, value in params.items():
        if not key:
            # The former bare ``raise`` had no active exception to re-raise
            # and only produced an unrelated RuntimeError.
            raise ValueError("Parameter keys must not be empty")
        param_dict = _update_value_nested_dict(
            param_dict, _try_convert_to_numeric(value), key.split(".")
        )
    return param_dict
[docs]
def dict_to_flat_params(params: dict[str, Any]) -> dict[str, Any]:
    """Flatten nested dictionaries into dot-separated keys.

    ``{"a": {"b": {"c": 5}}, "c": "b"}`` becomes ``{"a.b.c": 5, "c": "b"}``.

    Args:
        params: Possibly nested dictionary.

    Returns:
        A flat dictionary. Values keep their original type; only ``dict``
        values are expanded. An empty nested dict contributes no entries.
    """
    flat: dict[str, Any] = {}
    for name, value in params.items():
        if isinstance(value, dict):
            # Prefix every flattened child key with the parent name.
            for child_name, child_value in dict_to_flat_params(value).items():
                flat[f"{name}.{child_name}"] = child_value
        else:
            flat[name] = value
    return flat
# Type variable bound to ``Module``; used to annotate ``_find_subclasses``.
CLS = TypeVar("CLS", bound=Module)
def _find_subclasses(cls: type[CLS]) -> dict[str, type[CLS]]:
    """Collect every direct and indirect subclass of ``cls``, keyed by ``name``.

    Each subclass is registered under its ``name`` attribute. An entry whose
    key is ``None`` is dropped from the result.
    """
    found = {}
    for child in cls.__subclasses__():
        found[child.name] = child
        # Descendants are added after the child itself, so a descendant with
        # the same name replaces it.
        for descendant_name, descendant in _find_subclasses(child).items():
            found[descendant_name] = descendant

    if None in found:
        del found[None]
    return found
[docs]
def merge_two_dicts(a: dict, b: dict):
    """Recursively merge dictionaries ``a`` and ``b``, prioritizing ``b``.

    When a key is present in both and both values are dicts, the two values
    are merged recursively. In every other case the value from ``b`` wins.

    Args:
        a: Base dictionary.
        b: Overriding dictionary.

    Returns:
        A new dictionary; ``a`` and ``b`` are not modified. Keys keep a
        deterministic order: those of ``a`` first, then the new keys of
        ``b``. Values that are not merged are reused, not copied.
    """
    # Start from a shallow copy of ``a`` so its key order is preserved,
    # instead of iterating sets whose order varies with string hashing.
    out = dict(a)
    for key, b_value in b.items():
        if key in a and isinstance(a[key], dict) and isinstance(b_value, dict):
            out[key] = merge_two_dicts(a[key], b_value)
        else:
            out[key] = b_value
    return out
[docs]
def merge_dicts(*ds: dict):
    """Merge any number of dictionaries from left to right.

    The dictionaries are folded with ``merge_two_dicts`` starting from an
    empty dict, so later arguments are merged on top of earlier ones.
    """
    merged = {}
    for current in ds:
        merged = merge_two_dicts(merged, current)
    return merged
[docs]
def get_params_for_pipe(name: str, params: dict):
    """Return the parameters for the provided pipeline.

    The view is the part of ``name`` before its first ``.``. The result is
    the merge of three layers, later ones taking priority:

    1. ``params["default"]``: default values shared by all views,
    2. ``params[<view>]``: view-specific overrides,
    3. ``params`` itself: top-level values, so the user can override any of
       them with the ``--params`` argument without using the namespace.

    A missing ``default`` or ``<view>`` entry counts as an empty dict.
    """
    view, _, _ = name.partition(".")
    defaults = params.get("default", {})
    view_overrides = params.get(view, {})
    return merge_dicts(defaults, view_overrides, params)