Source code for ads.opctl.backend.ads_dataflow

#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


import os
import json
import shlex
from typing import Dict, Union

from ads.common.auth import create_signer, AuthContext
from ads.common.oci_client import OCIClientFactory

from ads.opctl.backend.base import (
    Backend,
    RuntimeFactory,
)

from ads.jobs import (
    Job,
    DataFlow,
    DataFlowRuntime,
    DataFlowNotebookRuntime,
    DataFlowRun,
)

REQUIRED_FIELDS = [
    "compartment_id",
    "driver_shape",
    "executor_shape",
    "logs_bucket_uri",
    "script_bucket",
]
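
# Illustrative (hypothetical) values for the required ``infrastructure`` fields
# above; the shape names, bucket names, and OCID below are placeholders, not
# defaults shipped with this backend:
#
#   infrastructure:
#     compartment_id: ocid1.compartment.oc1..<unique_id>
#     driver_shape: VM.Standard.E4.Flex
#     executor_shape: VM.Standard.E4.Flex
#     logs_bucket_uri: oci://<logs_bucket>@<namespace>/
#     script_bucket: oci://<scripts_bucket>@<namespace>/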


class DataFlowBackend(Backend):
    def __init__(self, config: Dict) -> None:
        """
        Initialize a DataFlowBackend object given a config dictionary.

        Parameters
        ----------
        config: dict
            dictionary of configurations
        """
        self.config = config
        self.oci_auth = create_signer(
            config["execution"].get("auth"),
            config["execution"].get("oci_config", None),
            config["execution"].get("oci_profile", None),
        )
        self.auth_type = config["execution"].get("auth")
        self.profile = config["execution"].get("oci_profile", None)
        self.client = OCIClientFactory(**self.oci_auth).dataflow

    def init(
        self,
        uri: Union[str, None] = None,
        overwrite: bool = False,
        runtime_type: Union[str, None] = None,
        **kwargs: Dict,
    ) -> Union[str, None]:
        """Generates a starter YAML specification for a Data Flow Application.

        Parameters
        ----------
        uri: (str, optional). Defaults to None.
            The filename to save the resulting specification template YAML.
        overwrite: (bool, optional). Defaults to False.
            Overwrites the resulting specification YAML if it exists.
        runtime_type: (str, optional). Defaults to None.
            The resource runtime type.
        **kwargs: Dict
            The optional arguments.

        Returns
        -------
        Union[str, None]
            The YAML specification for the given resource if `uri` was not provided.
            `None` otherwise.
        """
        with AuthContext(auth=self.auth_type, profile=self.profile):
            # Define a job with a Data Flow infrastructure and runtime.
            job = (
                Job()
                .with_name(
                    "{Job name. For MLflow, it will be replaced with the Project name}"
                )
                .with_infrastructure(
                    DataFlow(**(self.config.get("infrastructure", {}) or {})).init()
                )
                .with_runtime(
                    DataFlowRuntimeFactory.get_runtime(
                        key=runtime_type or DataFlowRuntime().type
                    ).init()
                )
            )

            note = (
                "# This YAML specification was auto-generated by the `ads opctl init` command.\n"
                "# More details about the jobs YAML specification can be found in the ADS documentation:\n"
                "# https://accelerated-data-science.readthedocs.io/en/latest/user_guide/apachespark/dataflow.html \n\n"
            )

            return job.to_yaml(
                uri=uri,
                overwrite=overwrite,
                note=note,
                filter_by_attribute_map=True,
                **kwargs,
            )

    def apply(self):
        """
        Create DataFlow and DataFlow Run from YAML.
        """
        # TODO: add the logic to build a Data Flow application and run from YAML.
        raise NotImplementedError("`apply` is not supported for Data Flow yet.")

    def run(self) -> Union[Dict, None]:
        """
        Create DataFlow and DataFlow Run from OCID or CLI parameters.
        """
        with AuthContext(auth=self.auth_type, profile=self.profile):
            if self.config["execution"].get("ocid", None):
                # Run an existing Data Flow application by its OCID.
                data_flow_id = self.config["execution"]["ocid"]
                run_id = Job.from_dataflow_job(data_flow_id).run().id
            else:
                # Build the application from the CLI/infrastructure configuration.
                infra = self.config.get("infrastructure", {})
                if any(k not in infra for k in REQUIRED_FIELDS):
                    missing = [k for k in REQUIRED_FIELDS if k not in infra]
                    raise ValueError(
                        f"The following fields are required for OCI DataFlow jobs "
                        f"but are missing: {missing}. Please run `ads opctl configure`."
                    )
                rt_spec = {}
                rt_spec["scriptPathURI"] = os.path.join(
                    self.config["execution"]["source_folder"],
                    self.config["execution"]["entrypoint"],
                )
                if "script_bucket" in infra:
                    rt_spec["scriptBucket"] = infra.pop("script_bucket")
                if self.config["execution"].get("command"):
                    rt_spec["args"] = shlex.split(self.config["execution"]["command"])
                if self.config["execution"].get("archive"):
                    rt_spec["archiveUri"] = self.config["execution"]["archive"]
                    rt_spec["archiveBucket"] = infra.pop("archive_bucket", None)
                rt = DataFlowRuntime(rt_spec)
                if "configuration" in infra:
                    infra["configuration"] = json.loads(infra["configuration"])
                df = Job(infrastructure=DataFlow(spec=infra), runtime=rt)
                df.create(overwrite=self.config["execution"].get("overwrite", False))
                job_id = df.id
                run_id = df.run().id
                print("DataFlow App ID:", job_id)
                print("DataFlow Run ID:", run_id)
                return {"job_id": job_id, "run_id": run_id}

    def cancel(self):
        """
        Cancel DataFlow Run from OCID.
        """
        if not self.config["execution"].get("run_id"):
            raise ValueError("A DataFlow run ID (`run_id`) is required to cancel a run.")
        run_id = self.config["execution"]["run_id"]
        with AuthContext(auth=self.auth_type, profile=self.profile):
            DataFlowRun.from_ocid(run_id).delete()

    def delete(self):
        """
        Delete DataFlow or DataFlow Run from OCID.
        """
        if self.config["execution"].get("id"):
            data_flow_id = self.config["execution"]["id"]
            with AuthContext(auth=self.auth_type, profile=self.profile):
                Job.from_dataflow_job(data_flow_id).delete()
        elif self.config["execution"].get("run_id"):
            run_id = self.config["execution"]["run_id"]
            with AuthContext(auth=self.auth_type, profile=self.profile):
                DataFlowRun.from_ocid(run_id).delete()

    def watch(self):
        """
        Watch DataFlow Run from OCID.
        """
        run_id = self.config["execution"]["run_id"]
        with AuthContext(auth=self.auth_type, profile=self.profile):
            run = DataFlowRun.from_ocid(run_id)
            run.watch()


class DataFlowRuntimeFactory(RuntimeFactory):
    """Data Flow runtime factory."""

    _MAP = {
        DataFlowRuntime().type: DataFlowRuntime,
        DataFlowNotebookRuntime().type: DataFlowNotebookRuntime,
    }
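

# A minimal usage sketch, assuming a valid OCI API key configuration at
# ~/.oci/config; the config values and output file name below are placeholders,
# not defaults required by this backend.
if __name__ == "__main__":
    example_config = {
        "execution": {
            "auth": "api_key",              # forwarded to create_signer() in __init__
            "oci_config": "~/.oci/config",
            "oci_profile": "DEFAULT",
        },
        # Optional overrides passed to DataFlow(...) by init(); an empty dict
        # defers to the values resolved by DataFlow(...).init().
        "infrastructure": {},
    }
    backend = DataFlowBackend(example_config)
    # Write a starter Data Flow job specification to a local YAML file.
    backend.init(uri="dataflow_job.yaml", overwrite=True)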