Source code for ads.opctl.backend.ads_dataflow

#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


import json
import os
import shlex
import tempfile
import time
from typing import Dict, Union

from ads.common.auth import AuthContext, create_signer, AuthType
from ads.common.oci_client import OCIClientFactory
from ads.jobs import (
    DataFlow,
    DataFlowNotebookRuntime,
    DataFlowRun,
    DataFlowRuntime,
    Job,
)
from ads.opctl import logger
from ads.opctl.backend.base import Backend, RuntimeFactory
from ads.opctl.constants import OPERATOR_MODULE_PATH
from ads.opctl.decorator.common import print_watch_command
from ads.opctl.operator.common.const import ENV_OPERATOR_ARGS
from ads.opctl.operator.common.operator_loader import OperatorInfo, OperatorLoader

REQUIRED_FIELDS = [
    "compartment_id",
    "driver_shape",
    "executor_shape",
    "logs_bucket_uri",
    "script_bucket",
]


class DataFlowBackend(Backend):
    def __init__(self, config: Dict) -> None:
        """
        Initialize a DataFlowBackend object given a config dictionary.

        Parameters
        ----------
        config: dict
            dictionary of configurations
        """
        self.config = config
        self.oci_auth = create_signer(
            config["execution"].get("auth"),
            config["execution"].get("oci_config", None),
            config["execution"].get("oci_profile", None),
        )
        self.auth_type = config["execution"].get("auth")
        self.profile = config["execution"].get("oci_profile", None)
        self.client = OCIClientFactory(**self.oci_auth).dataflow

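    # Illustrative usage sketch (not part of the original module): a minimal
    # config dictionary a caller might pass when constructing the backend.
    # Only the "execution" keys are read in __init__; the "infrastructure"
    # values below are placeholders that mirror REQUIRED_FIELDS.
    #
    #   backend = DataFlowBackend(
    #       config={
    #           "execution": {"auth": "api_key", "oci_profile": "DEFAULT"},
    #           "infrastructure": {
    #               "compartment_id": "ocid1.compartment.oc1..<unique_id>",
    #               "driver_shape": "VM.Standard.E4.Flex",
    #               "executor_shape": "VM.Standard.E4.Flex",
    #               "logs_bucket_uri": "oci://bucket@namespace/dataflow-logs",
    #               "script_bucket": "oci://bucket@namespace/scripts",
    #           },
    #       }
    #   )
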
    def init(
        self,
        uri: Union[str, None] = None,
        overwrite: bool = False,
        runtime_type: Union[str, None] = None,
        **kwargs: Dict,
    ) -> Union[str, None]:
        """Generates a starter YAML specification for a Data Flow Application.

        Parameters
        ----------
        uri: (str, optional). Defaults to None.
            The filename to save the resulting specification template YAML.
        overwrite: (bool, optional). Defaults to False.
            Overwrites the resulting specification YAML if it exists.
        runtime_type: (str, optional). Defaults to None.
            The resource runtime type.
        **kwargs: Dict
            The optional arguments.

        Returns
        -------
        Union[str, None]
            The YAML specification for the given resource if `uri` was not provided.
            `None` otherwise.
        """
        conda_slug = kwargs.get(
            "conda_slug", self.config["execution"].get("conda_slug", "conda_slug")
        ).lower()

        # If the conda slug contains a "/", it is assumed to be a custom conda pack,
        # so the conda pack object storage prefix needs to be prepended.
        if "/" in conda_slug:
            conda_slug = os.path.join(
                self.config["execution"].get(
                    "conda_pack_os_prefix", "oci://bucket@namespace/conda_environments"
                ),
                conda_slug,
            )

        RUNTIME_KWARGS_MAP = {
            DataFlowRuntime().type: {
                "conda_slug": conda_slug,
                "script_bucket": self.config["infrastructure"]
                .get("script_bucket", "")
                .rstrip("/"),
            },
        }

        with AuthContext(auth=self.auth_type, profile=self.profile):
            # define a job
            job = (
                Job()
                .with_name("{Job name. For MLflow and Operator will be auto generated}")
                .with_infrastructure(
                    DataFlow(**(self.config.get("infrastructure", {}) or {})).init()
                )
                .with_runtime(
                    DataFlowRuntimeFactory.get_runtime(
                        key=runtime_type or DataFlowRuntime().type
                    ).init(
                        # fall back to the default runtime key when `runtime_type`
                        # is not provided
                        **{
                            **kwargs,
                            **RUNTIME_KWARGS_MAP[runtime_type or DataFlowRuntime().type],
                        }
                    )
                )
            )

            note = (
                "# This YAML specification was auto generated by the `ads opctl init` command.\n"
                "# More details about the jobs YAML specification can be found in the ADS documentation:\n"
                "# https://accelerated-data-science.readthedocs.io/en/latest/user_guide/apachespark/dataflow.html\n\n"
            )

            return job.to_yaml(
                uri=uri,
                overwrite=overwrite,
                note=note,
                filter_by_attribute_map=False,
                **kwargs,
            )

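    # Illustrative usage sketch (not from the original source): generating a
    # starter Data Flow job spec and writing it to a file. The file name and
    # conda slug values are hypothetical.
    #
    #   backend.init(
    #       uri="dataflow_job.yaml",
    #       overwrite=True,
    #       runtime_type=DataFlowRuntime().type,
    #       conda_slug="pyspark32_p38_cpu_v3",
    #   )
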
    def apply(self) -> Dict:
        """
        Create DataFlow and DataFlow Run from YAML.
        """
        # TODO: add the logic to build a Data Flow application and run from YAML.
        raise NotImplementedError("`apply` is not supported for Data Flow yet.")

    @print_watch_command
    def run(self) -> Dict:
        """
        Create DataFlow and DataFlow Run from OCID or CLI parameters.
        """
        with AuthContext(auth=self.auth_type, profile=self.profile):
            if self.config["execution"].get("ocid", None):
                job_id = self.config["execution"]["ocid"]
                run_id = Job.from_dataflow_job(job_id).run().id
            else:
                infra = self.config.get("infrastructure", {})
                if any(k not in infra for k in REQUIRED_FIELDS):
                    missing = [k for k in REQUIRED_FIELDS if k not in infra]
                    raise ValueError(
                        f"The following fields are missing but are required for OCI Data Flow jobs: {missing}. "
                        "Please run `ads opctl configure`."
                    )
                rt_spec = {}
                rt_spec["scriptPathURI"] = os.path.join(
                    self.config["execution"]["source_folder"],
                    self.config["execution"]["entrypoint"],
                )
                if "script_bucket" in infra:
                    rt_spec["scriptBucket"] = infra.pop("script_bucket")
                if self.config["execution"].get("command"):
                    rt_spec["args"] = shlex.split(self.config["execution"]["command"])
                if self.config["execution"].get("archive"):
                    rt_spec["archiveUri"] = self.config["execution"]["archive"]
                    rt_spec["archiveBucket"] = infra.pop("archive_bucket", None)
                rt = DataFlowRuntime(rt_spec)
                if "configuration" in infra:
                    infra["configuration"] = json.loads(infra["configuration"])
                df = Job(infrastructure=DataFlow(spec=infra), runtime=rt)
                df.create(overwrite=self.config["execution"].get("overwrite", False))
                job_id = df.id
                run_id = df.run().id
            print("DataFlow App ID:", job_id)
            print("DataFlow Run ID:", run_id)
            return {"job_id": job_id, "run_id": run_id}

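    # Illustrative sketch (not from the original source): the two ways `run`
    # can be driven, mirroring the branches above. OCIDs and paths are
    # placeholders.
    #
    #   # 1) Run an existing Data Flow application by OCID:
    #   DataFlowBackend(
    #       {"execution": {"auth": "api_key",
    #                      "ocid": "ocid1.dataflowapplication.oc1..<unique_id>"}}
    #   ).run()
    #
    #   # 2) Build the application from CLI-style parameters; the hypothetical
    #   #    `infrastructure_spec` dict must contain every key in REQUIRED_FIELDS.
    #   DataFlowBackend(
    #       {"execution": {"auth": "api_key",
    #                      "source_folder": "oci://bucket@namespace/src",
    #                      "entrypoint": "main.py",
    #                      "command": "--key value"},
    #        "infrastructure": infrastructure_spec}
    #   ).run()
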
    def cancel(self):
        """
        Cancel DataFlow Run from OCID.
        """
        if not self.config["execution"].get("run_id"):
            raise ValueError("Can only cancel a DataFlow run.")
        run_id = self.config["execution"]["run_id"]
        with AuthContext(auth=self.auth_type, profile=self.profile):
            DataFlowRun.from_ocid(run_id).delete()

    def delete(self):
        """
        Delete DataFlow or DataFlow Run from OCID.
        """
        if self.config["execution"].get("id"):
            data_flow_id = self.config["execution"]["id"]
            with AuthContext(auth=self.auth_type, profile=self.profile):
                Job.from_dataflow_job(data_flow_id).delete()
        elif self.config["execution"].get("run_id"):
            run_id = self.config["execution"]["run_id"]
            with AuthContext(auth=self.auth_type, profile=self.profile):
                DataFlowRun.from_ocid(run_id).delete()

    def watch(self):
        """
        Watch DataFlow Run from OCID.
        """
        run_id = self.config["execution"]["run_id"]
        interval = self.config["execution"].get("interval")
        with AuthContext(auth=self.auth_type, profile=self.profile):
            run = DataFlowRun.from_ocid(run_id)
            run.watch(interval=interval)

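    # Illustrative sketch (not from the original source): monitoring and
    # tearing down a run by OCID. The run OCID is a placeholder.
    #
    #   cfg = {"execution": {"auth": "api_key",
    #                        "run_id": "ocid1.dataflowrun.oc1..<unique_id>",
    #                        "interval": 10}}
    #   DataFlowBackend(cfg).watch()   # stream the run status until it finishes
    #   DataFlowBackend(cfg).cancel()  # or cancel (delete) the run instead

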
class DataFlowOperatorBackend(DataFlowBackend):
    """
    Backend class to run an operator on a Data Flow Application.

    Attributes
    ----------
    runtime_config: (Dict)
        The runtime config for the operator.
    operator_config: (Dict)
        The operator specification config.
    operator_type: str
        The type of the operator.
    operator_version: str
        The version of the operator.
    job: Job
        The Data Science Job.
    """

    def __init__(self, config: Dict, operator_info: OperatorInfo = None) -> None:
        """
        Instantiates the operator backend.

        Parameters
        ----------
        config: (Dict)
            The configuration file containing the operator's specification details
            and execution section.
        operator_info: (OperatorInfo, optional)
            The operator's detailed information extracted from the operator.__init__ file.
            Will be extracted from the operator type if not provided.
        """
        super().__init__(config=config or {})

        self.job = None

        self.runtime_config = self.config.get("runtime", {})
        self.operator_config = {
            key: value
            for key, value in self.config.items()
            if key not in ("runtime", "infrastructure", "execution")
        }
        self.operator_type = self.operator_config.get("type", "unknown")
        self.operator_version = self.operator_config.get("version", "unknown")
        self.operator_info = operator_info

    def _adjust_common_information(self):
        """Adjusts common information of the application."""
        if self.job.name.lower().startswith("{job"):
            self.job.with_name(
                f"job_{self.operator_info.type.lower()}"
                f"_{self.operator_version.lower()}"
            )
        self.job.runtime.with_maximum_runtime_in_minutes(
            self.config["execution"].get("max_wait_time", 1200)
        )

        temp_dir = tempfile.mkdtemp()

        # prepare a run.py entry script that launches the operator module
        script_file = os.path.join(
            temp_dir, f"{self.operator_info.type}_{int(time.time())}_run.py"
        )

        operator_module = f"{OPERATOR_MODULE_PATH}.{self.operator_type}"
        with open(script_file, "w") as fp:
            fp.writelines(
                "\n".join(
                    [
                        "import runpy",
                        f"runpy.run_module('{operator_module}', run_name='__main__')",
                    ]
                )
            )
        self.job.runtime.with_script_uri(script_file)

        # propagate environment variables to the runtime configuration
        env_vars = {
            "OCI_IAM_TYPE": AuthType.RESOURCE_PRINCIPAL,
            "OCIFS_IAM_TYPE": AuthType.RESOURCE_PRINCIPAL,
            ENV_OPERATOR_ARGS: json.dumps(self.operator_config),
            **(self.job.runtime.envs or {}),
        }

        runtime_config = self.job.runtime.configuration or dict()
        existing_env_keys = {
            key.upper()
            .replace("SPARK.EXECUTORENV.", "")
            .replace("SPARK.DRIVERENV.", "")
            for key in runtime_config
            if "SPARK.EXECUTORENV" in key.upper() or "SPARK.DRIVERENV" in key.upper()
        }
        for env_key, env_value in (env_vars or {}).items():
            if env_key.upper() not in existing_env_keys:
                runtime_config[f"spark.driverEnv.{env_key}"] = env_value

        self.job.runtime.with_configuration(runtime_config)

    @print_watch_command
    def run(self, **kwargs: Dict) -> Union[Dict, None]:
        """
        Runs the operator on the Data Flow service.
        """
        if not self.operator_info:
            self.operator_info = OperatorLoader.from_uri(self.operator_type).load()

        self.job = Job.from_dict(self.runtime_config).build()

        # adjust the job's common information
        self._adjust_common_information()

        # run the job only if it is not in dry run mode
        if not self.config["execution"].get("dry_run"):
            job = self.job.create()
            logger.info(f"{'*' * 50} Data Flow Application {'*' * 50}")
            logger.info(job)

            job_run = job.run()
            logger.info(f"{'*' * 50} Data Flow Application Run {'*' * 50}")
            logger.info(job_run)

            return {"job_id": job.id, "run_id": job_run.id}
        else:
            logger.info(f"{'*' * 50} Data Flow Application (Dry Run Mode) {'*' * 50}")
            logger.info(self.job)

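# Illustrative sketch (not from the original source): running an operator
# config through this backend. The operator type, version, and runtime dict
# are hypothetical; in practice `ads opctl` assembles this config.
#
#   backend = DataFlowOperatorBackend(
#       config={
#           "type": "forecast",
#           "version": "v1",
#           "execution": {"auth": "resource_principal", "dry_run": True},
#           "runtime": runtime_spec,  # a job spec dict accepted by Job.from_dict
#       },
#   )
#   backend.run()

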
class DataFlowRuntimeFactory(RuntimeFactory):
    """Data Flow runtime factory."""

    _MAP = {
        DataFlowRuntime().type: DataFlowRuntime,
        DataFlowNotebookRuntime().type: DataFlowNotebookRuntime,
    }
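

# Illustrative sketch (not from the original source): the factory resolves a
# runtime from its type string, mirroring the lookup used in
# `DataFlowBackend.init` above. A notebook runtime can be requested the same way.
#
#   notebook_runtime = DataFlowRuntimeFactory.get_runtime(
#       key=DataFlowNotebookRuntime().type
#   )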