#!/usr/bin/env python
# -*- coding: utf-8; -*-
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
import json
import os
import shlex
import tempfile
import time
from typing import Dict, Union
from ads.common.auth import AuthContext, create_signer, AuthType
from ads.common.oci_client import OCIClientFactory
from ads.jobs import (
    DataFlow,
    DataFlowNotebookRuntime,
    DataFlowRun,
    DataFlowRuntime,
    Job,
)
from ads.opctl import logger
from ads.opctl.backend.base import Backend, RuntimeFactory
from ads.opctl.constants import OPERATOR_MODULE_PATH
from ads.opctl.decorator.common import print_watch_command
from ads.opctl.operator.common.const import ENV_OPERATOR_ARGS
from ads.opctl.operator.common.operator_loader import OperatorInfo, OperatorLoader

REQUIRED_FIELDS = [
    "compartment_id",
    "driver_shape",
    "executor_shape",
    "logs_bucket_uri",
    "script_bucket",
]


class DataFlowBackend(Backend):
    def __init__(self, config: Dict) -> None:
        """
        Initialize a DataFlowBackend object given a config dictionary.

        Parameters
        ----------
        config: dict
            dictionary of configurations
        """
        self.config = config
        self.oci_auth = create_signer(
            config["execution"].get("auth"),
            config["execution"].get("oci_config", None),
            config["execution"].get("oci_profile", None),
        )
        self.auth_type = config["execution"].get("auth")
        self.profile = config["execution"].get("oci_profile", None)
        self.client = OCIClientFactory(**self.oci_auth).dataflow

    def init(
        self,
        uri: Union[str, None] = None,
        overwrite: bool = False,
        runtime_type: Union[str, None] = None,
        **kwargs: Dict,
    ) -> Union[str, None]:
        """Generates a starter YAML specification for a Data Flow Application.

        Parameters
        ----------
        overwrite: (bool, optional). Defaults to False.
            Overwrites the result specification YAML if it exists.
        uri: (str, optional). Defaults to None.
            The filename to save the resulting specification template YAML.
        runtime_type: (str, optional). Defaults to None.
            The resource runtime type.
        **kwargs: Dict
            The optional arguments.

        Returns
        -------
        Union[str, None]
            The YAML specification for the given resource if `uri` was not provided.
            `None` otherwise.
        """
        conda_slug = kwargs.get(
            "conda_slug", self.config["execution"].get("conda_slug", "conda_slug")
        ).lower()

        # If the conda slug contains '/', the assumption is that it is a custom conda pack,
        # so the conda pack object storage prefix needs to be prepended.
        if "/" in conda_slug:
            conda_slug = os.path.join(
                self.config["execution"].get(
                    "conda_pack_os_prefix", "oci://bucket@namespace/conda_environments"
                ),
                conda_slug,
            )
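
        # Default runtime kwargs keyed by runtime type, used to pre-populate the
        # generated specification with the conda pack slug and the script bucket.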
        RUNTIME_KWARGS_MAP = {
            DataFlowRuntime().type: {
                "conda_slug": conda_slug,
                "script_bucket": f"{self.config['infrastructure'].get('script_bucket','').rstrip('/')}",
            },
        }

        with AuthContext(auth=self.auth_type, profile=self.profile):
            # define a job
            job = (
                Job()
                .with_name("{Job name. For MLflow and Operator will be auto generated}")
                .with_infrastructure(
                    DataFlow(**(self.config.get("infrastructure", {}) or {})).init()
                )
                .with_runtime(
                    DataFlowRuntimeFactory.get_runtime(
                        key=runtime_type or DataFlowRuntime().type
                    ).init(
                        **{
                            **kwargs,
                            **RUNTIME_KWARGS_MAP[runtime_type or DataFlowRuntime().type],
                        }
                    )
                )
            )

            note = (
                "# This YAML specification was auto generated by the `ads opctl init` command.\n"
                "# More details about the jobs YAML specification can be found in the ADS documentation:\n"
                "# https://accelerated-data-science.readthedocs.io/en/latest/user_guide/apachespark/dataflow.html \n\n"
            )

            return job.to_yaml(
                uri=uri,
                overwrite=overwrite,
                note=note,
                filter_by_attribute_map=False,
                **kwargs,
            )

    def apply(self) -> Dict:
        """
        Create DataFlow and DataFlow Run from YAML.
        """
        # TODO: add the logic to build a Data Flow application and run from YAML.
        raise NotImplementedError("`apply` is not yet supported for Data Flow.")

    @print_watch_command
    def run(self) -> Dict:
        """
        Create DataFlow and DataFlow Run from OCID or CLI parameters.
        """
        with AuthContext(auth=self.auth_type, profile=self.profile):
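            # If an application OCID is provided, trigger a new run of the existing
            # Data Flow application; otherwise build the application from CLI parameters.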
if self.config["execution"].get("ocid", None):
job_id = self.config["execution"]["ocid"]
run_id = Job.from_dataflow_job(job_id).run().id
else:
infra = self.config.get("infrastructure", {})
if any(k not in infra for k in REQUIRED_FIELDS):
missing = [k for k in REQUIRED_FIELDS if k not in infra]
raise ValueError(
f"Following fields are missing but are required for OCI DataFlow Jobs: {missing}. Please run `ads opctl configure`."
)
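                # No OCID: assemble the runtime spec (entrypoint script, optional
                # arguments, and archive) from the execution parameters.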
                rt_spec = {}
                rt_spec["scriptPathURI"] = os.path.join(
                    self.config["execution"]["source_folder"],
                    self.config["execution"]["entrypoint"],
                )
                if "script_bucket" in infra:
                    rt_spec["scriptBucket"] = infra.pop("script_bucket")
                if self.config["execution"].get("command"):
                    rt_spec["args"] = shlex.split(self.config["execution"]["command"])
                if self.config["execution"].get("archive"):
                    rt_spec["archiveUri"] = self.config["execution"]["archive"]
                    rt_spec["archiveBucket"] = infra.pop("archive_bucket", None)
                rt = DataFlowRuntime(rt_spec)
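                # `configuration` is provided as a JSON string; parse it into a dict
                # before passing it to the DataFlow infrastructure spec.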
if "configuration" in infra:
infra["configuration"] = json.loads(infra["configuration"])
df = Job(infrastructure=DataFlow(spec=infra), runtime=rt)
df.create(overwrite=self.config["execution"].get("overwrite", False))
job_id = df.id
run_id = df.run().id
print("DataFlow App ID:", job_id)
print("DataFlow Run ID:", run_id)
return {"job_id": job_id, "run_id": run_id}

    def cancel(self):
        """
        Cancel DataFlow Run from OCID.
        """
        if not self.config["execution"].get("run_id"):
            raise ValueError("Can only cancel a DataFlow run.")
        run_id = self.config["execution"]["run_id"]
        with AuthContext(auth=self.auth_type, profile=self.profile):
            DataFlowRun.from_ocid(run_id).delete()

    def delete(self):
        """
        Delete DataFlow or DataFlow Run from OCID.
        """
        if self.config["execution"].get("id"):
            data_flow_id = self.config["execution"]["id"]
            with AuthContext(auth=self.auth_type, profile=self.profile):
                Job.from_dataflow_job(data_flow_id).delete()
        elif self.config["execution"].get("run_id"):
            run_id = self.config["execution"]["run_id"]
            with AuthContext(auth=self.auth_type, profile=self.profile):
                DataFlowRun.from_ocid(run_id).delete()

    def watch(self):
        """
        Watch DataFlow Run from OCID.
        """
        run_id = self.config["execution"]["run_id"]
        interval = self.config["execution"].get("interval")
        with AuthContext(auth=self.auth_type, profile=self.profile):
            run = DataFlowRun.from_ocid(run_id)
            run.watch(interval=interval)


class DataFlowOperatorBackend(DataFlowBackend):
    """
    Backend class to run an operator on a Data Flow Application.

    Attributes
    ----------
    runtime_config: (Dict)
        The runtime config for the operator.
    operator_config: (Dict)
        The operator specification config.
    operator_type: str
        The type of the operator.
    operator_version: str
        The version of the operator.
    job: Job
        The Data Science Job.
    """

    def __init__(self, config: Dict, operator_info: OperatorInfo = None) -> None:
        """
        Instantiates the operator backend.

        Parameters
        ----------
        config: (Dict)
            The configuration containing the operator's specification details and execution section.
        operator_info: (OperatorInfo, optional)
            The operator's detailed information extracted from the operator.__init__ file.
            Will be extracted from the operator type if not provided.
        """
        super().__init__(config=config or {})

        self.job = None
        self.runtime_config = self.config.get("runtime", {})
        self.operator_config = {
            **{
                key: value
                for key, value in self.config.items()
                if key not in ("runtime", "infrastructure", "execution")
            }
        }
        self.operator_type = self.operator_config.get("type", "unknown")
        self.operator_version = self.operator_config.get("version", "unknown")
        self.operator_info = operator_info

    def _adjust_common_information(self):
        """Adjusts common information of the application."""
        if self.job.name.lower().startswith("{job"):
            self.job.with_name(
                f"job_{self.operator_info.type.lower()}"
                f"_{self.operator_version.lower()}"
            )
        self.job.runtime.with_maximum_runtime_in_minutes(
            self.config["execution"].get("max_wait_time", 1200)
        )

        temp_dir = tempfile.mkdtemp()

        # prepare run.py file to run the operator
        script_file = os.path.join(
            temp_dir, f"{self.operator_info.type}_{int(time.time())}_run.py"
        )

        operator_module = f"{OPERATOR_MODULE_PATH}.{self.operator_type}"
        with open(script_file, "w") as fp:
            fp.writelines(
                "\n".join(
                    [
                        "import runpy",
                        f"runpy.run_module('{operator_module}', run_name='__main__')",
                    ]
                )
            )
        self.job.runtime.with_script_uri(script_file)

        # propagate environment variables to the runtime config
        env_vars = {
            "OCI_IAM_TYPE": AuthType.RESOURCE_PRINCIPAL,
            "OCIFS_IAM_TYPE": AuthType.RESOURCE_PRINCIPAL,
            ENV_OPERATOR_ARGS: json.dumps(self.operator_config),
            **(self.job.runtime.envs or {}),
        }
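
        # Data Flow passes environment variables through Spark configuration keys
        # (`spark.driverEnv.*` / `spark.executorEnv.*`); only add variables that are
        # not already present in the existing runtime configuration.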
        runtime_config = self.job.runtime.configuration or dict()
        existing_env_keys = {
            key.upper()
            .replace("SPARK.EXECUTORENV.", "")
            .replace("SPARK.DRIVERENV.", "")
            for key in runtime_config
            if "SPARK.EXECUTORENV" in key.upper() or "SPARK.DRIVERENV" in key.upper()
        }
        for env_key, env_value in (env_vars or {}).items():
            if env_key.upper() not in existing_env_keys:
                runtime_config[f"spark.driverEnv.{env_key}"] = env_value
        self.job.runtime.with_configuration(runtime_config)

    @print_watch_command
    def run(self, **kwargs: Dict) -> Union[Dict, None]:
        """
        Runs the operator on the Data Flow service.
        """
        if not self.operator_info:
            self.operator_info = OperatorLoader.from_uri(self.operator_type).load()
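
        # Build the Data Flow job definition from the `runtime` section of the operator config.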
        self.job = Job.from_dict(self.runtime_config).build()

        # adjust the job's common information
        self._adjust_common_information()

        # run the job only if it is not in dry run mode
        if not self.config["execution"].get("dry_run"):
            job = self.job.create()
            logger.info(f"{'*' * 50} Data Flow Application {'*' * 50}")
            logger.info(job)

            job_run = job.run()
            logger.info(f"{'*' * 50} Data Flow Application Run {'*' * 50}")
            logger.info(job_run)

            return {"job_id": job.id, "run_id": job_run.id}
        else:
            logger.info(f"{'*' * 50} Data Flow Application (Dry Run Mode) {'*' * 50}")
            logger.info(self.job)


class DataFlowRuntimeFactory(RuntimeFactory):
    """Data Flow runtime factory."""

    _MAP = {
        DataFlowRuntime().type: DataFlowRuntime,
        DataFlowNotebookRuntime().type: DataFlowNotebookRuntime,
    }
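
# Example (illustrative): `DataFlowBackend.init()` above resolves the runtime class
# through this factory by its type key, e.g.
#   DataFlowRuntimeFactory.get_runtime(key=DataFlowRuntime().type)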