Source code for ads.opctl.distributed.common.cluster_config_helper

#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2022 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


OCI__RUNTIME_TYPE = "OCI__RUNTIME_TYPE"
OCI__RUNTIME_URI = "OCI__RUNTIME_URI"
OCI__RUNTIME_PYTHON_PATH = "OCI__RUNTIME_PYTHON_PATH"
OCI__RUNTIME_GIT_BRANCH = "OCI__RUNTIME_GIT_BRANCH"
OCI__RUNTIME_GIT_COMMIT = "OCI__RUNTIME_GIT_COMMIT"
OCI__RUNTIME_GIT_SECRET_ID = "OCI__RUNTIME_GIT_SECRET_ID"
OCI__CODE_DIR = "OCI__CODE_DIR"


[docs] class ClusterConfigToJobSpecConverter: def __init__(self, cluster_info): self.cluster_info = cluster_info
[docs] def job_def_info(self): job = {} job["infrastructure"] = self.cluster_info.infrastructure cluster = self.cluster_info.cluster job["image"] = cluster.image or cluster.main.image or cluster.worker.image job["name"] = cluster.name or self.cluster_info.infrastructure.get( "displayName" ) job["envVars"] = cluster.config.envVars job["envVars"]["OCI__WORK_DIR"] = cluster.work_dir job["envVars"]["OCI__EPHEMERAL"] = cluster.ephemeral job["envVars"]["OCI__CLUSTER_TYPE"] = cluster.type.upper() job["envVars"]["OCI__WORKER_COUNT"] = ( cluster.worker.replicas if cluster.worker is not None else 0 ) if cluster.ps is not None: job["envVars"]["OCI__PS_COUNT"] = cluster.ps.replicas job["envVars"]["OCI__START_ARGS"] = cluster.config.cmd_args.strip() job["envVars"]["OCI__ENTRY_SCRIPT"] = self.cluster_info.runtime.entry_point if self.cluster_info.runtime.type: optional_runtime_fields = [ {"attr": "type", "env": OCI__RUNTIME_TYPE}, {"attr": "uri", "env": OCI__RUNTIME_URI}, {"attr": "branch", "env": OCI__RUNTIME_GIT_BRANCH}, {"attr": "commit", "env": OCI__RUNTIME_GIT_COMMIT}, {"attr": "git_secret_id", "env": OCI__RUNTIME_GIT_SECRET_ID}, {"attr": "code_dir", "env": OCI__CODE_DIR}, {"attr": "python_path", "env": OCI__RUNTIME_PYTHON_PATH}, ] for field in optional_runtime_fields: val = getattr(self.cluster_info.runtime, field["attr"], None) if val: job["envVars"][field["env"]] = val runtime_args = self.cluster_info.runtime.args if isinstance(self.cluster_info.runtime.args, list): runtime_args = " ".join([str(v) for v in self.cluster_info.runtime.args]) if runtime_args: job["envVars"]["OCI__ENTRY_SCRIPT_ARGS"] = runtime_args if self.cluster_info.runtime.kwargs: job["envVars"][ "OCI__ENTRY_SCRIPT_KWARGS" ] = self.cluster_info.runtime.kwargs job["envVars"].update(self.cluster_info.runtime.envVars) job["envVars"] = {k: str(job["envVars"][k]) for k in job["envVars"]} if self.cluster_info.cluster.certificate: job["envVars"][ "OCI__CERTIFICATE_OCID" ] = self.cluster_info.cluster.certificate.cert_ocid job["envVars"][ "OCI__CERTIFICATE_KEY_DOWNLOAD_LOCATION" ] = self.cluster_info.cluster.certificate.key_download_location job["envVars"][ "OCI__CERTIFICATE_DOWNLOAD_LOCATION" ] = self.cluster_info.cluster.certificate.cert_download_location job["envVars"][ "OCI__CERTIFICATE_AUTHORITY_OCID" ] = self.cluster_info.cluster.certificate.ca_ocid job["envVars"][ "OCI__CA_DOWNLOAD_LOCATION" ] = self.cluster_info.cluster.certificate.ca_download_location return job
[docs] def job_run_info(self, jobType): jobrun = {} jobTypeConfig = getattr(self.cluster_info.cluster, jobType) if jobTypeConfig is not None: jobrun["name"] = jobTypeConfig.name or jobType jobrun["envVars"] = jobTypeConfig.config.envVars if jobTypeConfig.config.cmd_args: jobrun["envVars"][ "OCI__START_ARGS" ] = jobTypeConfig.config.cmd_args.strip() jobrun["envVars"]["OCI__MODE"] = jobType.upper() jobrun["envVars"] = { k: str(jobrun["envVars"][k]) for k in jobrun["envVars"] } return jobrun