Source code for ads.opctl.config.merger

#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import os
from string import Template
from typing import Dict
import json

import yaml

from ads.common.auth import AuthType, ResourcePrincipal
from ads.opctl import logger
from ads.opctl.config.base import ConfigProcessor
from ads.opctl.config.utils import read_from_ini, _DefaultNoneDict
from ads.opctl.utils import is_in_notebook_session, get_service_pack_prefix
from ads.opctl.constants import (
    DEFAULT_PROFILE,
    DEFAULT_OCI_CONFIG_FILE,
    DEFAULT_CONDA_PACK_FOLDER,
    DEFAULT_ADS_CONFIG_FOLDER,
    ADS_JOBS_CONFIG_FILE_NAME,
    ADS_CONFIG_FILE_NAME,
    ADS_ML_PIPELINE_CONFIG_FILE_NAME,
    ADS_DATAFLOW_CONFIG_FILE_NAME,
    ADS_LOCAL_BACKEND_CONFIG_FILE_NAME,
    ADS_MODEL_DEPLOYMENT_CONFIG_FILE_NAME,
    DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
    BACKEND_NAME,
)


class ConfigMerger(ConfigProcessor):
    """Merge configurations from command line args, YAML, ini and default configs.

    The order of precedence is
    command line args > YAML + conf.ini > .ads_ops/config.ini +
    .ads_ops/ml_job_config.ini or .ads_ops/dataflow_config.ini

    Detailed examples can be found at the last section on this page:
    https://bitbucket.oci.oraclecorp.com/projects/ODSC/repos/advanced-ds/browse/ads/opctl/README.md?at=refs%2Fheads%2Fads-ops-draft
    """

    def process(self, **kwargs) -> "ConfigMerger":
        """Run the full merge pipeline on ``self.config``.

        Parameters
        ----------
        **kwargs
            Command line arguments. Values that are ``None``, empty or
            otherwise falsy (except booleans) are ignored.

        Returns
        -------
        ConfigMerger
            This instance, with ``self.config`` updated in place.
        """
        # Expand ``$VAR`` placeholders with environment variables; unknown
        # placeholders are left untouched by ``safe_substitute``.
        config_string = Template(json.dumps(self.config)).safe_substitute(os.environ)
        self.config = json.loads(config_string)

        if "runtime" not in self.config:
            self.config["runtime"] = {}

        # 1. merge and overwrite values from command line args
        self._merge_config_with_cmd_args(kwargs)
        # 1.5 merge environment variables
        # TODO
        # 2. fill in values from conf file
        self._fill_config_from_conf()

        ads_config_path = os.path.abspath(
            os.path.expanduser(
                self.config["execution"].pop("ads_config", DEFAULT_ADS_CONFIG_FOLDER)
            )
        )
        # 3. fill in values from default files under ~/.ads_ops
        self._fill_config_with_defaults(ads_config_path)
        self._config_flex_shape_details()
        return self

    @staticmethod
    def _has_value(value) -> bool:
        """Return True when a command line value should be kept.

        Booleans are always kept (so an explicit ``False`` survives); any
        other value is kept only when it is truthy, which drops ``None``,
        empty strings and empty collections.
        """
        return isinstance(value, bool) or bool(value)

    def _merge_config_with_cmd_args(self, cmd_args: Dict) -> None:
        """Overwrite config entries with command line args.

        Keys already present anywhere in the (nested) config are overwritten
        in place; every remaining argument with a value is stored under the
        ``execution`` section. ``cmd_args`` is consumed: matched keys are
        popped from it.
        """
        has_value = self._has_value

        def _overwrite(cfg, args):
            for key, current in cfg.items():
                if isinstance(current, dict):
                    _overwrite(current, args)
                elif key in args:
                    # The argument is removed even when it is ignored, so it
                    # is not copied into "execution" afterwards.
                    candidate = args.pop(key)
                    if has_value(candidate):
                        cfg[key] = candidate

        _overwrite(self.config, cmd_args)

        # save everything else from command line in "execution" section
        # (a missing or null section is replaced by an empty dict)
        if not self.config.get("execution"):
            self.config["execution"] = {}
        execution = self.config["execution"]
        for key, value in cmd_args.items():
            if has_value(value):
                execution[key] = value

    def _fill_config_from_conf(self) -> None:
        """Substitute ``$variable`` placeholders using a profile of the conf file.

        Does nothing unless ``execution.conf_file`` is set. Environment
        variables are available as well; values from the conf file profile
        take precedence over them.
        """
        execution = self.config["execution"]
        conf_file = execution.get("conf_file")
        if not conf_file:
            return

        conf_profile = execution.get("conf_profile", DEFAULT_PROFILE)
        logger.info(f"Reading from {conf_file} using profile {conf_profile}")
        parser = read_from_ini(conf_file)

        extra_configs = dict(os.environ)
        extra_configs.update(parser[conf_profile])

        config_string = yaml.dump(self.config)
        # _DefaultNoneDict is used so that if $variable is not found in a
        # section in .ini, None value is filled in.
        self.config = yaml.safe_load(
            Template(config_string).substitute(_DefaultNoneDict(**extra_configs))
        )

    def _fill_config_with_defaults(self, ads_config_path: str) -> None:
        """Fill unset ``execution`` / ``infrastructure`` values from defaults.

        Defaults come from the built-in values, overridden by the non-None
        entries of the ADS config file and completed by the backend specific
        config file found under ``ads_config_path``. Values already present
        in ``self.config`` are never overwritten.
        """
        exec_config = self._get_default_config()
        exec_config_from_conf = self._get_config_from_config_ini(ads_config_path)
        exec_config.update(
            {k: v for k, v in exec_config_from_conf.items() if v is not None}
        )
        execution = self.config["execution"]

        # set default auth
        if not execution.get("auth", None):
            fallback_auth = (
                AuthType.RESOURCE_PRINCIPAL
                if ResourcePrincipal.supported()
                else AuthType.API_KEY
            )
            execution["auth"] = exec_config.get("auth") or fallback_auth

        # determine profile
        if execution["auth"] in (
            AuthType.RESOURCE_PRINCIPAL,
            AuthType.INSTANCE_PRINCIPAL,
        ):
            # Principal based auth does not use an OCI profile; the upper
            # cased auth name selects the section of the service config.
            profile = execution["auth"].upper()
            exec_config.pop("oci_profile", None)
            execution["oci_profile"] = None
        else:
            profile = execution.get("oci_profile") or exec_config.get("oci_profile")
            execution["oci_profile"] = profile

        # loading config for corresponding profile
        infra_config = self._get_service_config(profile, ads_config_path)
        if infra_config.get("conda_pack_os_prefix"):
            # this is a field that appeared both in config.ini and ml_job_config.ini
            exec_config["conda_pack_os_prefix"] = infra_config.pop(
                "conda_pack_os_prefix"
            )

        for key, value in exec_config.items():
            if value and not execution.get(key):
                execution[key] = value

        if not self.config.get("infrastructure"):
            self.config["infrastructure"] = {}
        infrastructure = self.config["infrastructure"]
        for key, value in infra_config.items():
            if value and not infrastructure.get(key):
                infrastructure[key] = value

    def _get_default_config(self) -> Dict:
        """Return the built-in execution defaults for the current environment."""
        if is_in_notebook_session():
            conda_pack_os_prefix = get_service_pack_prefix()
            return {
                "oci_config": DEFAULT_OCI_CONFIG_FILE,
                "oci_profile": DEFAULT_PROFILE,
                "conda_pack_folder": DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
                "conda_pack_os_prefix": conda_pack_os_prefix,
            }
        return {
            "oci_config": DEFAULT_OCI_CONFIG_FILE,
            "oci_profile": DEFAULT_PROFILE,
            "conda_pack_folder": DEFAULT_CONDA_PACK_FOLDER,
        }

    @staticmethod
    def _get_config_from_config_ini(ads_config_folder: str) -> Dict:
        """Read execution settings from the ADS config file, if it exists.

        Returns an empty dict when the file is absent; otherwise a dict whose
        values are ``None`` for options missing from the ``OCI`` / ``CONDA``
        sections.
        """
        config_path = os.path.join(ads_config_folder, ADS_CONFIG_FILE_NAME)
        if not os.path.exists(config_path):
            # Report the path that was actually checked.
            logger.debug(f"{config_path} does not exist. No config loaded.")
            return {}

        parser = read_from_ini(config_path)
        oci_section = parser["OCI"]
        conda_section = parser["CONDA"]
        return {
            "oci_config": oci_section.get("oci_config"),
            "oci_profile": oci_section.get("oci_profile"),
            "auth": oci_section.get("auth"),
            "conda_pack_folder": conda_section.get("conda_pack_folder"),
            "conda_pack_os_prefix": conda_section.get("conda_pack_os_prefix"),
        }

    def _get_service_config(self, oci_profile: str, ads_config_folder: str) -> Dict:
        """Load the backend specific config section for ``oci_profile``.

        The file is chosen from ``execution.backend`` (the jobs config file
        is used for an unknown or missing backend). The section named
        ``oci_profile`` is preferred, then the default profile; an empty dict
        is returned when neither the file nor a matching section exists.
        """
        backend = self.config["execution"].get("backend", None)
        backend_config = {
            BACKEND_NAME.JOB.value: ADS_JOBS_CONFIG_FILE_NAME,
            BACKEND_NAME.DATAFLOW.value: ADS_DATAFLOW_CONFIG_FILE_NAME,
            BACKEND_NAME.PIPELINE.value: ADS_ML_PIPELINE_CONFIG_FILE_NAME,
            BACKEND_NAME.LOCAL.value: ADS_LOCAL_BACKEND_CONFIG_FILE_NAME,
            BACKEND_NAME.MODEL_DEPLOYMENT.value: ADS_MODEL_DEPLOYMENT_CONFIG_FILE_NAME,
        }
        config_file = backend_config.get(backend, ADS_JOBS_CONFIG_FILE_NAME)
        config_path = os.path.join(ads_config_folder, config_file)

        if os.path.exists(config_path):
            parser = read_from_ini(config_path)
            if oci_profile in parser:
                return parser[oci_profile]
            if DEFAULT_PROFILE in parser:
                return parser[DEFAULT_PROFILE]
        else:
            logger.debug(f"{config_path} does not exist. No config loaded.")
        return {}

    def _config_flex_shape_details(self):
        """Group flex shape parameters into nested shape config entries.

        For the job and model deployment backends, ``ocpus`` and
        ``memory_in_gbs`` are moved into ``shape_config_details`` when
        ``shape_name`` ends with ``.Flex``. For the dataflow backend the
        driver / executor values are moved into ``driver_shape_config`` and
        ``executor_shape_config`` when either shape ends with ``.Flex``.

        Raises
        ------
        ValueError
            If a flex shape is requested but a required parameter is missing.
        """
        infrastructure = self.config["infrastructure"]
        backend = self.config["execution"].get("backend", None)

        if backend in (BACKEND_NAME.JOB.value, BACKEND_NAME.MODEL_DEPLOYMENT.value):
            # ``or ""`` also covers a key that is present with a None value.
            shape_name = infrastructure.get("shape_name") or ""
            if shape_name.endswith(".Flex"):
                if (
                    "ocpus" not in infrastructure
                    or "memory_in_gbs" not in infrastructure
                ):
                    raise ValueError(
                        "Parameters `ocpus` and `memory_in_gbs` must be provided for using flex shape. "
                        "Call `ads opctl configure` to specify."
                    )
                infrastructure["shape_config_details"] = {
                    "ocpus": infrastructure.pop("ocpus"),
                    "memory_in_gbs": infrastructure.pop("memory_in_gbs"),
                }
        elif backend == BACKEND_NAME.DATAFLOW.value:
            executor_shape = infrastructure.get("executor_shape") or ""
            driver_shape = infrastructure.get("driver_shape") or ""
            data_flow_shape_config_details = [
                "driver_shape_memory_in_gbs",
                "driver_shape_ocpus",
                "executor_shape_memory_in_gbs",
                "executor_shape_ocpus",
            ]
            # executor_shape and driver_shape must be the same shape family
            if executor_shape.endswith(".Flex") or driver_shape.endswith(".Flex"):
                for parameter in data_flow_shape_config_details:
                    if parameter not in infrastructure:
                        raise ValueError(
                            f"Parameters {parameter} must be provided for using flex shape. "
                            "Call `ads opctl configure` to specify."
                        )
                infrastructure["driver_shape_config"] = {
                    "ocpus": infrastructure.pop("driver_shape_ocpus"),
                    "memory_in_gbs": infrastructure.pop("driver_shape_memory_in_gbs"),
                }
                infrastructure["executor_shape_config"] = {
                    "ocpus": infrastructure.pop("executor_shape_ocpus"),
                    "memory_in_gbs": infrastructure.pop("executor_shape_memory_in_gbs"),
                }