Source code for ads.catalog.model

#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2020, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import warnings

warnings.warn(
    (
        "The `ads.catalog.model` is deprecated in `oracle-ads 2.6.9` and will be removed in `oracle-ads 3.0`. "
        "Use framework specific Model utility class for saving and deploying model. "
        "Check https://accelerated-data-science.readthedocs.io/en/latest/user_guide/model_registration/quick_start.html"
    ),
    DeprecationWarning,
    stacklevel=2,
)
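# Migration sketch (editor's addition, not part of the original module): the
# deprecation notice above points to the framework-specific model classes. A
# minimal replacement flow, assuming the `GenericModel` API from the linked
# quick-start guide; the estimator, artifact directory, and conda environment
# slug below are placeholders:
#
#     from ads.model.generic_model import GenericModel
#
#     generic_model = GenericModel(estimator=my_estimator, artifact_dir="/tmp/model_artifact")
#     generic_model.prepare(inference_conda_env="generalml_p38_cpu_v1", force_overwrite=True)
#     generic_model.save(display_name="my-model")  # registers the model in the catalog
#     generic_model.deploy()                       # creates a model deployment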

import json
import os
import shutil
import tempfile
import time
import uuid
from typing import Dict, Optional, Union
from zipfile import ZipFile

import pandas as pd
import yaml
from ads.catalog.summary import SummaryList
from ads.common import auth, logger, oci_client, utils
from ads.common.decorator.deprecate import deprecated
from ads.common.decorator.runtime_dependency import (
    runtime_dependency,
    OptionalDependency,
)
from ads.common.model_artifact import ConflictStrategy, ModelArtifact
from ads.model.model_metadata import (
    METADATA_SIZE_LIMIT,
    MetadataSizeTooLarge,
    ModelCustomMetadata,
    ModelTaxonomyMetadata,
)
from ads.common.object_storage_details import ObjectStorageDetails
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
from ads.config import (
    NB_SESSION_COMPARTMENT_OCID,
    OCI_ODSC_SERVICE_ENDPOINT,
    OCI_REGION_METADATA,
    PROJECT_OCID,
)
from ads.dataset.progress import TqdmProgressBar
from ads.feature_engineering.schema import Schema
from ads.model.model_version_set import ModelVersionSet, _extract_model_version_set_id
from ads.model.deployment.model_deployer import ModelDeployer
from oci.data_science.data_science_client import DataScienceClient
from oci.data_science.models import (
    ArtifactExportDetailsObjectStorage,
    ArtifactImportDetailsObjectStorage,
    CreateModelDetails,
    ExportModelArtifactDetails,
    ImportModelArtifactDetails,
)
from oci.data_science.models import Model as OCIModel
from oci.data_science.models import ModelSummary, WorkRequest
from oci.data_science.models.model_provenance import ModelProvenance
from oci.data_science.models.update_model_details import UpdateModelDetails
from oci.exceptions import ServiceError
from oci.identity import IdentityClient

_UPDATE_MODEL_DETAILS_ATTRIBUTES = [
    "display_name",
    "description",
    "freeform_tags",
    "defined_tags",
    "model_version_set_id",
    "version_label",
]
_MODEL_PROVENANCE_ATTRIBUTES = ModelProvenance().swagger_types.keys()
_ETAG_KEY = "ETag"
_MAX_ARTIFACT_SIZE_IN_BYTES = 2147483648  # 2GB
_WORK_REQUEST_INTERVAL_IN_SEC = 3


class ModelWithActiveDeploymentError(Exception):  # pragma: no cover
    pass
class ModelArtifactSizeError(Exception):  # pragma: no cover
    def __init__(self, max_artifact_size: str):
        super().__init__(
            f"The model artifacts size is greater than `{max_artifact_size}`. "
            "The `bucket_uri` needs to be specified to "
            "copy artifacts to the object storage bucket. "
            "Example: `bucket_uri=oci://<bucket_name>@<namespace>/prefix/`"
        )
def _get_etag(response) -> Optional[str]:
    """Gets the ETag from the response, if present."""
    if _ETAG_KEY in response.headers:
        return response.headers[_ETAG_KEY].split("--")[0]
    return None
class ModelSummaryList(SummaryList):
    """Model Summary List which represents a list of Model objects.

    Methods
    -------
    sort_by(self, columns, reverse=False)
        Performs a multi-key sort on a particular set of columns and returns the sorted ModelSummaryList.
        Results are listed in a descending order by default.
    filter(self, selection, instance=None)
        Filters the model list according to a lambda filter function, or list comprehension.
    """

    @deprecated(
        "2.6.6",
        details="Use the framework-specific Model utility classes for saving and deploying models. Check https://accelerated-data-science.readthedocs.io/en/latest/user_guide/model_registration/quick_start.html",
    )
    def __init__(
        self,
        model_catalog,
        model_list,
        response=None,
        datetime_format=utils.date_format,
    ):
        super(ModelSummaryList, self).__init__(
            model_list, datetime_format=datetime_format
        )
        self.mc = model_catalog
        self.response = response

    def __add__(self, rhs):
        return ModelSummaryList(
            self.mc, list.__add__(self, rhs), datetime_format=self.datetime_format
        )

    def __getitem__(self, item):
        return self.mc.get_model(super(ModelSummaryList, self).__getitem__(item).id)
    def sort_by(self, columns, reverse=False):
        """
        Performs a multi-key sort on a particular set of columns and returns the sorted ModelSummaryList.
        Results are listed in a descending order by default.

        Parameters
        ----------
        columns: List of string
            A list of columns which are provided to sort on.
        reverse: Boolean (defaults to False)
            If you'd like to reverse the results (for example, to get ascending instead of descending results).

        Returns
        -------
        ModelSummaryList: A sorted ModelSummaryList
        """
        return ModelSummaryList(
            self.mc,
            self._sort_by(columns, reverse=reverse),
            datetime_format=self.datetime_format,
        )
    def filter(self, selection, instance=None):
        """
        Filters the model list according to a lambda filter function, or list comprehension.

        Parameters
        ----------
        selection: lambda function filtering model instances, or a list-comprehension
            function of list filtering models
        instance: list, optional
            List to filter. Optional, defaults to self.

        Returns
        -------
        ModelSummaryList: A filtered ModelSummaryList
        """
        instance = instance if instance is not None else self
        if callable(selection):
            res = list(filter(selection, instance))  # lambda filtering
            if len(res) == 0:
                print("No models found")
                return
            return ModelSummaryList(self.mc, res, datetime_format=self.datetime_format)
        elif isinstance(selection, list):  # list comprehension
            if len(selection) == 0:
                print("No models found")
                return
            return ModelSummaryList(
                self.mc, selection, datetime_format=self.datetime_format
            )
        else:
            raise ValueError("Filter selection must be a function or a list")
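# Usage sketch (editor's addition): a `ModelSummaryList` is normally obtained from
# `ModelCatalog.list_models()` and can be narrowed with `filter` and ordered with
# `sort_by`. The compartment OCID below is a placeholder:
#
#     mc = ModelCatalog(compartment_id="ocid1.compartment.oc1..<unique_id>")
#     models = mc.list_models()
#     active = models.filter(lambda m: m.lifecycle_state == "ACTIVE")
#     newest_first = active.sort_by(["time_created"])  # descending by default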
class Model:
    """Class that represents the ADS implementation of a model catalog item.
    Converts the metadata and schema from the OCI implementation to the ADS implementation.

    Methods
    -------
    to_dataframe
        Converts model to dataframe format.
    show_in_notebook
        Shows model in the notebook in dataframe or YAML representation.
    activate
        Activates model.
    deactivate
        Deactivates model.
    commit
        Commits the changes made to the model.
    rollback
        Rolls back the changes made to the model.
    load_model
        Loads the model from the model catalog based on model ID.
    """

    _FIELDS_TO_DECORATE = [
        "schema_input",
        "schema_output",
        "metadata_custom",
        "metadata_taxonomy",
    ]

    _NEW_ATTRIBUTES = [
        "input_schema",
        "output_schema",
        "custom_metadata_list",
        "defined_metadata_list",
    ]

    @deprecated(
        "2.6.6",
        details="Use the framework-specific Model utility classes for saving and deploying models. Check https://accelerated-data-science.readthedocs.io/en/latest/user_guide/model_registration/quick_start.html",
    )
    def __init__(
        self,
        model: OCIModel,
        model_etag: str,
        provenance_metadata: ModelProvenance,
        provenance_etag: str,
        ds_client: DataScienceClient,
        identity_client: IdentityClient,
    ) -> None:
        """Initializes the Model.

        Parameters
        ----------
        model: OCIModel
            The OCI model object.
        model_etag: str
            The model ETag.
        provenance_metadata: ModelProvenance
            The model provenance metadata.
        provenance_etag: str
            The model provenance metadata ETag.
        ds_client: DataScienceClient
            The Oracle DataScience client.
        identity_client: IdentityClient
            The Oracle Identity Service Client.
        """
        self.ds_client = ds_client
        self.identity_client = identity_client
        self.user_name = ""
        self._etag = model_etag
        self._provenance_metadata_etag = provenance_etag
        self.provenance_metadata = provenance_metadata
        self._extract_oci_model(model)
        self._extract_user_name(model)

    def _extract_oci_model(self, model: OCIModel) -> None:
        """Extracts the model information from the OCI model."""
        for key in model.swagger_types.keys():
            if key not in self._NEW_ATTRIBUTES:
                val = getattr(model, key)
                setattr(self, key, val)
        self.schema_input = self._extract_schema("input_schema", model)
        self.schema_output = self._extract_schema("output_schema", model)
        self.metadata_custom = self._extract_metadata_custom(model)
        self.metadata_taxonomy = self._extract_metadata_taxonomy(model)
        self.swagger_types = model.swagger_types
        self.lifecycle_state = model.lifecycle_state

    def _validate_metadata(self):
        self.metadata_custom.validate()
        self.metadata_taxonomy.validate()
        total_size = self.metadata_custom.size() + self.metadata_taxonomy.size()
        if total_size > METADATA_SIZE_LIMIT:
            raise MetadataSizeTooLarge(total_size)
        return True

    def _extract_user_name(self, model: OCIModel) -> None:
        try:
            user = self.identity_client.get_user(model.created_by)
            self.user_name = user.data.name
        except Exception:
            pass

    @staticmethod
    def _extract_schema(key, model):
        """Extracts the input and output schema."""
        schema = Schema()
        if hasattr(model, key):
            try:
                schema = (
                    Schema.from_dict(json.loads(getattr(model, key)))
                    if getattr(model, key)
                    else Schema()
                )
            except Exception as e:
                logger.warning(str(e))
        return schema

    @staticmethod
    def _extract_metadata_taxonomy(model):
        """Extracts the taxonomy metadata."""
        metadata_taxonomy = ModelTaxonomyMetadata()
        if hasattr(model, "defined_metadata_list"):
            try:
                metadata_taxonomy = ModelTaxonomyMetadata._from_oci_metadata(
                    model.defined_metadata_list
                )
            except Exception as e:
                logger.warning(str(e))
        return metadata_taxonomy

    @staticmethod
    def _extract_metadata_custom(model):
        """Extracts the custom metadata."""
        metadata_custom = ModelCustomMetadata()
        if hasattr(model, "custom_metadata_list"):
            try:
                metadata_custom = ModelCustomMetadata._from_oci_metadata(
                    model.custom_metadata_list
                )
            except Exception as e:
                logger.warning(str(e))
        return metadata_custom

    def _to_dict(self):
        """Converts the model attributes to dictionary format."""
        attributes = {}
        for key in _UPDATE_MODEL_DETAILS_ATTRIBUTES:
            if hasattr(self, key):
                attributes[key] = getattr(self, key)
        if self.provenance_metadata is not None:
            attributes.update(
                {
                    key: getattr(self.provenance_metadata, key)
                    for key in _MODEL_PROVENANCE_ATTRIBUTES
                }
            )
        for field in self._FIELDS_TO_DECORATE:
            attributes[field] = getattr(self, field).to_dict()
        return attributes

    def _to_yaml(self):
        """Converts the model attributes to yaml format."""
        attributes = self._to_dict()
        return yaml.safe_dump(attributes)
    def to_dataframe(self) -> pd.DataFrame:
        """
        Converts the model to dataframe format.

        Returns
        -------
        pandas.DataFrame
            Pandas dataframe.
        """
        attributes = self._to_dict()
        df = pd.DataFrame.from_dict(attributes, orient="index", columns=[""]).dropna()
        return df
    @runtime_dependency(module="IPython", install_from=OptionalDependency.NOTEBOOK)
    def show_in_notebook(self, display_format: str = "dataframe") -> None:
        """Shows model in dataframe or yaml format.
        Supported formats: `dataframe` and `yaml`. Defaults to dataframe format.

        Returns
        -------
        None
            Nothing.
        """
        if display_format == "dataframe":
            from IPython.core.display import display

            display(self.to_dataframe())
        elif display_format == "yaml":
            print(self._to_yaml())
        else:
            raise NotImplementedError(
                "`display_format` is not supported. Choose 'dataframe' or 'yaml'"
            )
    def _repr_html_(self):
        """Shows model in dataframe format."""
        return (
            self.to_dataframe().style.set_properties(**{"margin-left": "0px"}).render()
        )

    def __repr__(self):
        """Shows model in dataframe format."""
        return (
            self.to_dataframe().style.set_properties(**{"margin-left": "0px"}).render()
        )
    def activate(self) -> None:
        """Activates model.

        Returns
        -------
        None
            Nothing.
        """
        self.lifecycle_state = OCIModel.LIFECYCLE_STATE_ACTIVE
    def deactivate(self) -> None:
        """Deactivates model.

        Returns
        -------
        None
            Nothing.
        """
        self.lifecycle_state = OCIModel.LIFECYCLE_STATE_INACTIVE
    def commit(self, force: bool = True) -> None:
        """Commits model changes.

        Parameters
        ----------
        force: bool
            If True, any remote changes on this model would be lost.

        Returns
        -------
        None
            Nothing.
        """
        self._validate_metadata()
        attributes = {
            key: getattr(self, key) for key in _UPDATE_MODEL_DETAILS_ATTRIBUTES
        }
        if hasattr(self, "metadata_custom"):
            attributes["custom_metadata_list"] = self.metadata_custom._to_oci_metadata()
        if hasattr(self, "metadata_taxonomy"):
            attributes[
                "defined_metadata_list"
            ] = self.metadata_taxonomy._to_oci_metadata()
        update_model_details = UpdateModelDetails(**attributes)
        # freeform_tags=self._model.freeform_tags, defined_tags=self._model.defined_tags)

        # update model
        # https://docs.oracle.com/en-us/iaas/Content/API/Concepts/usingapi.htm#eleven
        # The API supports ETags for the purposes of optimistic concurrency control.
        # The GET and POST calls return an ETag response header with a value you should store.
        # When you later want to update or delete the resource, set the if-match header to the ETag
        # you received for the resource. The resource will then be updated or deleted
        # only if the ETag you provide matches the current value of that resource's ETag.
        kwargs = {}
        if not force:
            kwargs["if_match"] = self._etag
        self.ds_client.update_model(
            self.id, update_model_details=update_model_details, **kwargs
        )

        # store the lifecycle status, as updating the model will delete info not included in "update_model_details"
        lifecycle_status = self.lifecycle_state
        self.__dict__.update(self._load_model().__dict__)
        self.lifecycle_state = lifecycle_status

        # update model state
        if not force:
            kwargs["if_match"] = self._etag
        if self.lifecycle_state == OCIModel.LIFECYCLE_STATE_ACTIVE:
            self.ds_client.activate_model(self.id, **kwargs)
        elif self.lifecycle_state == OCIModel.LIFECYCLE_STATE_INACTIVE:
            self.ds_client.deactivate_model(self.id, **kwargs)
        self.__dict__.update(self._load_model().__dict__)

        # update model provenance
        if self.provenance_metadata != ModelProvenance():
            if not force:
                kwargs["if_match"] = self._provenance_metadata_etag
            response = self.ds_client.update_model_provenance(
                self.id, self.provenance_metadata, **kwargs
            )
            # get the model ETag again, as updating model provenance changes it
            self.__dict__.update(self._load_model().__dict__)
    @staticmethod
    def _get_provenance_metadata(ds_client: DataScienceClient, model_id: str):
        """Gets provenance information for the specified model."""
        try:
            provenance_response = ds_client.get_model_provenance(model_id)
        except ServiceError as e:
            if e.status == 404:
                try:
                    provenance_response = ds_client.create_model_provenance(
                        model_id, ModelProvenance()
                    )
                except ServiceError as e2:
                    raise e2
            elif e.status == 409:
                print("The model has been deleted.")
                raise e
            else:
                raise e
        return provenance_response
    @classmethod
    def load_model(
        cls,
        ds_client: DataScienceClient,
        identity_client: IdentityClient,
        model_id: str,
    ) -> "Model":
        """Loads the model from the model catalog based on model ID.

        Parameters
        ----------
        ds_client: DataScienceClient
            The Oracle DataScience client.
        identity_client: IdentityClient
            The Oracle Identity Service Client.
        model_id: str
            The model ID.

        Returns
        -------
        Model
            The ADS model catalog item.

        Raises
        ------
        ServiceError
            If an error occurs while getting the model from the server.
        KeyError
            If the model is not found.
        ValueError
            If an error occurs while getting the model provenance metadata from the server.
        """
        try:
            model_response = ds_client.get_model(model_id)
        except ServiceError as e:
            if e.status == 404:
                raise KeyError(e.message) from e
            raise e

        try:
            provenance_response = cls._get_provenance_metadata(ds_client, model_id)
        except Exception as e:
            raise ValueError(
                f"Unable to fetch model provenance metadata for model {model_id}"
            ) from e

        return cls(
            model_response.data,
            _get_etag(model_response),
            provenance_response.data,
            _get_etag(provenance_response),
            ds_client,
            identity_client,
        )
    def _load_model(self):
        """Loads the model from the model catalog."""
        return self.load_model(self.ds_client, self.identity_client, self.id)
    def rollback(self) -> None:
        """Rolls back the changes made to the model.

        Returns
        -------
        None
            Nothing.
        """
        self.__dict__.update(self._load_model().__dict__)
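# Usage sketch (editor's addition): a typical edit-commit-rollback cycle on a
# catalog entry. The clients and model OCID are placeholders:
#
#     model = Model.load_model(ds_client, identity_client, "ocid1.datasciencemodel.oc1..<unique_id>")
#     model.display_name = "new-name"
#     model.description = "Updated description"
#     model.commit()    # pushes the local changes to the model catalog
#     model.rollback()  # or: discard local changes and reload the remote state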
class ModelCatalog:
    """
    Allows you to list, load, update, download, upload and delete models from the model catalog.

    Methods
    -------
    get_model(self, model_id)
        Loads the model from the model catalog based on model_id.
    list_models(self, project_id=None, include_deleted=False, datetime_format=utils.date_format, **kwargs)
        Lists all models in a given compartment, or in the current project if project_id is specified.
    list_model_deployment(self, model_id, config=None, tenant_id=None, limit=500, page=None, **kwargs)
        Gets the list of model deployments by model ID across the compartments.
    update_model(self, model_id, update_model_details=None, **kwargs)
        Updates a model with given model_id, using the provided update data.
    delete_model(self, model, **kwargs)
        Deletes the model based on model_id.
    download_model(self, model_id, target_dir, force_overwrite=False, install_libs=False, conflict_strategy=ConflictStrategy.IGNORE)
        Downloads the model from model_dir to target_dir based on model_id.
    upload_model(self, model_artifact, provenance_metadata=None, project_id=None, display_name=None, description=None)
        Uploads the model artifact to cloud storage.
    """

    @deprecated(
        "2.6.6",
        details="Use the framework-specific Model utility classes for saving and deploying models. Check https://accelerated-data-science.readthedocs.io/en/latest/user_guide/model_registration/quick_start.html",
    )
    def __init__(
        self,
        compartment_id: Optional[str] = None,
        ds_client_auth: Optional[dict] = None,
        identity_client_auth: Optional[dict] = None,
        timeout: Optional[int] = None,
        ds_client: Optional[DataScienceClient] = None,
        identity_client: Optional[IdentityClient] = None,
    ):
        """Initializes the model catalog instance.

        Parameters
        ----------
        compartment_id: (str, optional). Defaults to None.
            Model compartment OCID. If `None`, the `config.NB_SESSION_COMPARTMENT_OCID` would be used.
        ds_client_auth: (dict, optional). Defaults to None.
            The default authentication is set using the `ads.set_auth` API. If you need to override
            the default, use `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to
            create an appropriate authentication signer and the kwargs required to instantiate a
            DataScienceClient object.
        identity_client_auth: (dict, optional). Defaults to None.
            The default authentication is set using the `ads.set_auth` API. If you need to override
            the default, use `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to
            create an appropriate authentication signer and the kwargs required to instantiate an
            IdentityClient object.
        timeout: (int, optional). Defaults to 10 seconds.
            The connection timeout in seconds for the client.
        ds_client: DataScienceClient
            The Oracle DataScience client.
        identity_client: IdentityClient
            The Oracle Identity Service Client.

        Raises
        ------
        ValueError
            If the compartment ID is not specified.
        TypeError
            If the timeout is not an integer.
        """
        self.compartment_id = (
            NB_SESSION_COMPARTMENT_OCID if compartment_id is None else compartment_id
        )
        if self.compartment_id is None:
            raise ValueError("compartment_id needs to be specified.")

        if timeout and not isinstance(timeout, int):
            raise TypeError("Timeout must be an integer.")

        self.ds_client_auth = ds_client_auth
        self.identity_client_auth = identity_client_auth
        self.ds_client = ds_client
        self.identity_client = identity_client

        if not self.ds_client:
            self.ds_client_auth = (
                ds_client_auth
                if ds_client_auth
                else auth.default_signer(
                    {"service_endpoint": OCI_ODSC_SERVICE_ENDPOINT}
                )
            )
            if timeout:
                if not self.ds_client_auth.get("client_kwargs"):
                    self.ds_client_auth["client_kwargs"] = {}
                self.ds_client_auth["client_kwargs"]["timeout"] = timeout
            self.ds_client = oci_client.OCIClientFactory(
                **self.ds_client_auth
            ).data_science

        if not self.identity_client:
            self.identity_client_auth = (
                identity_client_auth
                if identity_client_auth
                else auth.default_signer(
                    {"service_endpoint": OCI_ODSC_SERVICE_ENDPOINT}
                )
            )
            if timeout:
                if not self.identity_client_auth.get("client_kwargs"):
                    self.identity_client_auth["client_kwargs"] = {}
                self.identity_client_auth["client_kwargs"]["timeout"] = timeout
            self.identity_client = oci_client.OCIClientFactory(
                **self.identity_client_auth
            ).identity

        self.short_id_index = {}

    def __getitem__(self, model_id):  # pragma: no cover
        return self.get_model(model_id)

    def __contains__(self, model_id):  # pragma: no cover
        try:
            return self.get_model(model_id) is not None
        except KeyError:
            return False
        except Exception:
            raise

    def __iter__(self):  # pragma: no cover
        return self.list_models().__iter__()

    def __len__(self):  # pragma: no cover
        return len(self.list_models())
    def get_model(self, model_id):
        """
        Loads the model from the model catalog based on model_id.

        Parameters
        ----------
        model_id: str, required
            The model ID.

        Returns
        -------
        ads.catalog.Model
            The ads.catalog.Model with the matching ID.
        """
        if not model_id.startswith("ocid"):
            model_id = self.short_id_index[model_id]
        self.id = model_id
        return Model.load_model(self.ds_client, self.identity_client, model_id)
    def list_models(
        self,
        project_id: str = None,
        include_deleted: bool = False,
        datetime_format: str = utils.date_format,
        **kwargs,
    ):
        """
        Lists all models in a given compartment, or in the current project if project_id is specified.

        Parameters
        ----------
        project_id: str
            The project_id of model.
        include_deleted: bool, optional, default=False
            Whether to include deleted models in the returned list.
        datetime_format: str, optional, default: '%Y-%m-%d %H:%M:%S'
            Change format for date time fields.

        Returns
        -------
        ModelSummaryList
            A list of models.
        """
        try:
            list_models_response = self.ds_client.list_models(
                self.compartment_id, project_id=project_id, **kwargs
            )
            if list_models_response.data is None or len(list_models_response.data) == 0:
                print("No model found.")
                return
        except ServiceError as se:
            if se.status == 404:
                raise KeyError(se.message) from se
            else:
                raise

        model_list_filtered = [
            Model(
                model=model,
                model_etag=None,
                provenance_metadata=None,
                provenance_etag=None,
                ds_client=self.ds_client,
                identity_client=self.identity_client,
            )
            for model in list_models_response.data
            if include_deleted
            or model.lifecycle_state != ModelSummary.LIFECYCLE_STATE_DELETED
        ]

        # handle empty list
        if model_list_filtered is None or len(model_list_filtered) == 0:
            print("No model found.")
            return []

        msl = ModelSummaryList(
            self,
            model_list_filtered,
            list_models_response,
            datetime_format=datetime_format,
        )
        self.short_id_index.update(msl.short_id_index)
        return msl
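    # Usage sketch (editor's addition): listing models, including deleted ones,
    # scoped to a project. The project OCID is a placeholder:
    #
    #     models = mc.list_models(
    #         project_id="ocid1.datascienceproject.oc1..<unique_id>",
    #         include_deleted=True,
    #     )
    #     for m in models:
    #         print(m.id, m.lifecycle_state)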
    def list_model_deployment(
        self,
        model_id: str,
        config: dict = None,
        tenant_id: str = None,
        limit: int = 500,
        page: str = None,
        **kwargs,
    ):
        """
        Gets the list of model deployments by model ID across the compartments.

        Parameters
        ----------
        model_id: str
            The model ID.
        config: dict (optional)
            Configuration keys and values as per SDK and Tool Configuration.
            The from_file() method can be used to load configuration from a file.
            Alternatively, a dict can be passed. You can validate the dict using
            validate_config(). Defaults to None.
        tenant_id: str (optional)
            The tenancy ID, which can be used to specify a different tenancy
            (for cross-tenancy authorization) when searching for resources in
            a different tenancy. Defaults to None.
        limit: int (optional)
            The maximum number of items to return. The value must be between
            1 and 1000. Defaults to 500.
        page: str (optional)
            The page at which to start retrieving results.

        Returns
        -------
        The list of model deployments.
        """
        query = f"query datasciencemodeldeployment resources where ModelId='{model_id}'"
        return OCIResource.search(
            query,
            type=SEARCH_TYPE.STRUCTURED,
            config=config,
            tenant_id=tenant_id,
            limit=limit,
            page=page,
            **kwargs,
        )
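    # Usage sketch (editor's addition): deployments are looked up via the resource
    # search service, so results may span compartments. The model OCID is a placeholder:
    #
    #     deployments = mc.list_model_deployment("ocid1.datasciencemodel.oc1..<unique_id>")
    #     active = [d for d in deployments if d.lifecycle_state == "ACTIVE"]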
    def update_model(self, model_id, update_model_details=None, **kwargs) -> Model:
        """
        Updates a model with given model_id, using the provided update data.

        Parameters
        ----------
        model_id: str
            The model ID.
        update_model_details: UpdateModelDetails
            Contains the update model details data to apply. Mandatory unless kwargs are supplied.
        kwargs: dict, optional
            Update model details can be supplied instead as kwargs.

        Returns
        -------
        Model
            The ads.catalog.Model with the matching ID.
        """
        if not model_id.startswith("ocid"):
            model_id = self.short_id_index[model_id]
        if update_model_details is None:
            update_model_details = UpdateModelDetails(
                **{
                    k: v
                    for k, v in kwargs.items()
                    if k in _UPDATE_MODEL_DETAILS_ATTRIBUTES
                }
            )
            update_model_details.compartment_id = self.compartment_id
            # filter kwargs, removing the used keys
            kwargs = {
                k: v
                for k, v in kwargs.items()
                if k not in _UPDATE_MODEL_DETAILS_ATTRIBUTES
            }
        update_model_response = self.ds_client.update_model(
            model_id, update_model_details, **kwargs
        )
        provenance_response = Model._get_provenance_metadata(self.ds_client, model_id)
        return Model(
            model=update_model_response.data,
            model_etag=_get_etag(update_model_response),
            provenance_metadata=provenance_response.data,
            provenance_etag=_get_etag(provenance_response),
            ds_client=self.ds_client,
            identity_client=self.identity_client,
        )
    def delete_model(self, model: Union[str, "ads.catalog.Model"], **kwargs) -> bool:
        """
        Deletes the model from the Model Catalog.

        Parameters
        ----------
        model: Union[str, "ads.catalog.Model"]
            The OCID of the model to delete as a string, or an `ads.catalog.Model` instance.
        kwargs:
            delete_associated_model_deployment: (bool, optional). Defaults to `False`.
                Whether associated model deployments need to be deleted or not.

        Returns
        -------
        bool
            `True` if the model was successfully deleted.

        Raises
        ------
        ModelWithActiveDeploymentError
            If the model has active model deployments and the input attribute
            `delete_associated_model_deployment` is set to `False`.
        """
        model_id = (
            model.id
            if isinstance(model, Model)
            else self.short_id_index[model]
            if not model.startswith("ocid")
            else model
        )

        delete_associated_model_deployment = kwargs.pop(
            "delete_associated_model_deployment", None
        )

        active_deployments = tuple(
            item
            for item in self.list_model_deployment(model_id)
            if item.lifecycle_state == "ACTIVE"
        )

        if len(active_deployments) > 0:
            if not delete_associated_model_deployment:
                raise ModelWithActiveDeploymentError(
                    f"The model `{model_id}` has active model deployments: "
                    f"{[item.identifier for item in active_deployments]}. "
                    "Delete associated model deployments before deleting the model or "
                    "set the `delete_associated_model_deployment` attribute to `True`."
                )

            logger.info(
                f"Deleting model deployments associated with the model `{model_id}`."
            )
            for oci_model_deployment in active_deployments:
                (
                    ModelDeployer(config=self.ds_client_auth)
                    .get_model_deployment(oci_model_deployment.identifier)
                    .delete(wait_for_completion=True)
                )

        logger.info(f"Deleting model `{model_id}`.")
        self.ds_client.delete_model(model_id, **kwargs)
        return True
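    # Usage sketch (editor's addition): deleting a model that still has active
    # deployments requires opting in to cascade deletion. The OCID is a placeholder:
    #
    #     mc.delete_model(
    #         "ocid1.datasciencemodel.oc1..<unique_id>",
    #         delete_associated_model_deployment=True,
    #     )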
    def _download_artifact(
        self,
        model_id: str,
        target_dir: str,
        force_overwrite: Optional[bool] = False,
        bucket_uri: Optional[str] = None,
        remove_existing_artifact: Optional[bool] = True,
    ) -> None:
        """
        Downloads the model artifacts from the model catalog to target_dir based on `model_id`.

        Parameters
        ----------
        model_id: str
            The OCID of the model to download.
        target_dir: str
            The target location of model artifacts.
        force_overwrite: (bool, optional). Defaults to `False`.
            Overwrite the target directory if it exists.
        bucket_uri: (str, optional). Defaults to None.
            The OCI Object Storage URI where model artifacts will be copied to.
            The `bucket_uri` is only necessary for downloading large artifacts with
            size greater than 2GB. Example: `bucket_uri=oci://<bucket_name>@<namespace>/prefix/`.
        remove_existing_artifact: (bool, optional). Defaults to `True`.
            Whether artifacts uploaded to the object storage bucket need to be removed or not.

        Raises
        ------
        ValueError
            If the target directory already contains files and `force_overwrite` is not set.
        KeyError
            If the model ID is not found.

        Returns
        -------
        None
            Nothing.
        """
        if os.path.exists(target_dir) and os.listdir(target_dir):
            if not force_overwrite:
                raise ValueError(
                    f"The `{target_dir}` directory already exists. "
                    "Set `force_overwrite` to `True` if you wish to overwrite."
                )
            shutil.rmtree(target_dir)

        with utils.get_progress_bar(6) as progress:
            progress.update("Getting information about model artifacts")
            # If the size of the artifacts is greater than 2GB, then the artifacts
            # need to be imported to the Object Storage bucket first.
            try:
                model_artifact_info = self.ds_client.head_model_artifact(
                    model_id=model_id
                ).headers
            except ServiceError as ex:
                if ex.status == 404:
                    raise KeyError(f"The model `{model_id}` not found.") from ex
                raise

            artifact_size = int(model_artifact_info.get("content-length"))
            if bucket_uri or artifact_size > _MAX_ARTIFACT_SIZE_IN_BYTES:
                if not bucket_uri:
                    raise ModelArtifactSizeError(
                        utils.human_size(_MAX_ARTIFACT_SIZE_IN_BYTES)
                    )
                self._download_large_artifact(
                    model_id=model_id,
                    target_dir=target_dir,
                    bucket_uri=os.path.join(bucket_uri, f"{model_id}.zip"),
                    progress=progress,
                    remove_existing_artifact=remove_existing_artifact,
                )
            else:
                self._download_small_artifact(
                    model_id=model_id, target_dir=target_dir, progress=progress
                )
            progress.update()
            progress.update("Done")

    def _download_large_artifact(
        self,
        model_id: str,
        target_dir: str,
        bucket_uri: str,
        progress: TqdmProgressBar,
        remove_existing_artifact: Optional[bool] = True,
    ) -> None:
        """
        Downloads the model artifacts from the model catalog to target_dir based on `model_id`.
        This method is used for artifacts with size greater than 2GB.

        Parameters
        ----------
        model_id: str
            The OCID of the model to download.
        target_dir: str
            The target location of model artifacts.
        bucket_uri: str
            The OCI Object Storage URI where model artifacts will be copied to.
            The `bucket_uri` is only necessary for downloading large artifacts with
            size greater than 2GB. Example: `bucket_uri=oci://<bucket_name>@<namespace>/prefix/`.
        progress: TqdmProgressBar
            The progress bar.
        remove_existing_artifact: (bool, optional). Defaults to `True`.
            Whether artifacts uploaded to the object storage bucket need to be removed or not.

        Returns
        -------
        None
            Nothing.
        """
        progress.update("Importing model artifacts from the model catalog")
        self._import_model_artifact(model_id=model_id, bucket_uri=bucket_uri)

        progress.update("Copying model artifacts to the artifact directory")
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_file_path = os.path.join(temp_dir, f"{str(uuid.uuid4())}.zip")
            zip_file_path = utils.copy_file(
                uri_src=bucket_uri,
                uri_dst=zip_file_path,
                progressbar_description="Copying model artifacts to the artifact directory",
            )

            progress.update("Extracting model artifacts")
            with ZipFile(zip_file_path) as zip_file:
                zip_file.extractall(target_dir)

        if remove_existing_artifact:
            progress.update("Removing temporary artifacts")
            utils.remove_file(bucket_uri, self.ds_client_auth)
        else:
            progress.update()

    def _import_model_artifact(
        self,
        model_id: str,
        bucket_uri: str,
    ):
        """Imports the model artifact from the model catalog to the object storage bucket.
        This method is used for the case when the artifact size is greater than 2GB.
        """
        bucket_details = ObjectStorageDetails.from_path(bucket_uri)
        response = self.ds_client.import_model_artifact(
            model_id=model_id,
            import_model_artifact_details=ImportModelArtifactDetails(
                artifact_import_details=ArtifactImportDetailsObjectStorage(
                    namespace=bucket_details.namespace,
                    destination_bucket=bucket_details.bucket,
                    destination_object_name=bucket_details.filepath,
                    destination_region=self._region,
                )
            ),
        )
        self._wait_for_work_request(
            response=response,
            first_step_description="Preparing to import model artifacts from the model catalog",
            num_steps=3,
        )

    def _download_small_artifact(
        self,
        model_id: str,
        target_dir: str,
        progress: TqdmProgressBar,
    ) -> None:
        """
        Downloads the model artifacts from the model catalog to target_dir based on `model_id`.
        This method is used for artifacts with size less than 2GB.

        Parameters
        ----------
        model_id: str
            The OCID of the model to download.
        target_dir: str
            The target location of model artifacts.
        progress: TqdmProgressBar
            The progress bar.

        Returns
        -------
        None
            Nothing.
        """
        progress.update("Importing model artifacts from catalog")
        try:
            zip_contents = self.ds_client.get_model_artifact_content(
                model_id
            ).data.content
        except ServiceError as ex:
            if ex.status == 404:
                raise KeyError(ex.message) from ex
            raise

        with tempfile.TemporaryDirectory() as temp_dir:
            progress.update("Copying model artifacts to the artifact directory")
            zip_file_path = os.path.join(temp_dir, f"{str(uuid.uuid4())}.zip")
            with open(zip_file_path, "wb") as zip_file:
                zip_file.write(zip_contents)

            progress.update("Extracting model artifacts")
            with ZipFile(zip_file_path) as zip_file:
                zip_file.extractall(target_dir)
[docs] @deprecated( "2.5.9", details="Instead use `ads.common.model_artifact.ModelArtifact.from_model_catalog()`.", ) def download_model( self, model_id: str, target_dir: str, force_overwrite: bool = False, install_libs: bool = False, conflict_strategy=ConflictStrategy.IGNORE, bucket_uri: Optional[str] = None, remove_existing_artifact: Optional[bool] = True, ): """ Downloads the model from model_dir to target_dir based on model_id. Parameters ---------- model_id: str The OCID of the model to download. target_dir: str The target location of model after download. force_overwrite: bool Overwrite target_dir if exists. install_libs: bool, default: False Install the libraries specified in ds-requirements.txt which are missing in the current environment. conflict_strategy: ConflictStrategy, default: IGNORE Determines how to handle version conflicts between the current environment and requirements of model artifact. Valid values: "IGNORE", "UPDATE" or ConflictStrategy. IGNORE: Use the installed version in case of conflict UPDATE: Force update dependency to the version required by model artifact in case of conflict bucket_uri: (str, optional). Defaults to None. The OCI Object Storage URI where model artifacts will be copied to. The `bucket_uri` is only necessary for downloading large artifacts with size is greater than 2GB. Example: `oci://<bucket_name>@<namespace>/prefix/`. remove_existing_artifact: (bool, optional). Defaults to `True`. Whether artifacts uploaded to object storage bucket need to be removed or not. Returns ------- ModelArtifact A ModelArtifact instance. """ self._download_artifact( model_id, target_dir, force_overwrite, bucket_uri=bucket_uri, remove_existing_artifact=remove_existing_artifact, ) result = ModelArtifact( target_dir, conflict_strategy=conflict_strategy, install_libs=install_libs, reload=False, ) try: model_response = self.ds_client.get_model(model_id) except ServiceError as e: if e.status == 404: raise KeyError(e.message) from e raise e if hasattr(model_response.data, "custom_metadata_list"): try: result.metadata_custom = ModelCustomMetadata._from_oci_metadata( model_response.data.custom_metadata_list ) except: result.metadata_custom = ModelCustomMetadata() if hasattr(model_response.data, "defined_metadata_list"): try: result.metadata_taxonomy = ModelTaxonomyMetadata._from_oci_metadata( model_response.data.defined_metadata_list ) except: result.metadata_taxonomy = ModelTaxonomyMetadata() if hasattr(model_response.data, "input_schema"): try: result.schema_input = Schema.from_dict( json.loads(model_response.data.input_schema) if model_response.data.input_schema != "" else Schema() ) except: result.schema_input = Schema() if hasattr(model_response.data, "output_schema"): try: result.schema_output = Schema.from_dict( json.loads(model_response.data.output_schema) if model_response.data.output_schema != "" else Schema() ) except: result.schema_output = Schema() if not install_libs: logger.warning( "Libraries in `ds-requirements.txt` were not installed. " "Use `install_requirements()` to install the required dependencies." ) return result
[docs] @deprecated( "2.6.6", details="Use framework specific Model utility class for saving and deploying model. Check https://accelerated-data-science.readthedocs.io/en/latest/user_guide/model_registration/quick_start.html", ) def upload_model( self, model_artifact: ModelArtifact, provenance_metadata: Optional[ModelProvenance] = None, project_id: Optional[str] = None, display_name: Optional[str] = None, description: Optional[str] = None, freeform_tags: Optional[Dict[str, Dict[str, object]]] = None, defined_tags: Optional[Dict[str, Dict[str, object]]] = None, bucket_uri: Optional[str] = None, remove_existing_artifact: Optional[bool] = True, overwrite_existing_artifact: Optional[bool] = True, model_version_set: Optional[Union[str, ModelVersionSet]] = None, version_label: Optional[str] = None, ): """ Uploads the model artifact to cloud storage. Parameters ---------- model_artifact: Union[ModelArtifact, GenericModel] The model artifacts or generic model instance. provenance_metadata: (ModelProvenance, optional). Defaults to None. Model provenance gives data scientists information about the origin of their model. This information allows data scientists to reproduce the development environment in which the model was trained. project_id: (str, optional). Defaults to None. The project_id of model. display_name: (str, optional). Defaults to None. The name of model. If a display_name is not provided, a randomly generated easy to remember name with timestamp will be generated, like 'strange-spider-2022-08-17-23:55.02'. description: (str, optional). Defaults to None. The description of model. freeform_tags: (Dict[str, str], optional). Defaults to None. Freeform tags for the model, by default None defined_tags: (Dict[str, dict[str, object]], optional). Defaults to None. Defined tags for the model, by default None. bucket_uri: (str, optional). Defaults to None. The OCI Object Storage URI where model artifacts will be copied to. The `bucket_uri` is only necessary for uploading large artifacts which size greater than 2GB. Example: `oci://<bucket_name>@<namespace>/prefix/`. remove_existing_artifact: (bool, optional). Defaults to `True`. Whether artifacts uploaded to object storage bucket need to be removed or not. overwrite_existing_artifact: (bool, optional). Defaults to `True`. Overwrite target bucket artifact if exists. model_version_set: (Union[str, ModelVersionSet], optional). Defaults to None. The Model version set OCID, or name, or `ModelVersionSet` instance. version_label: (str, optional). Defaults to None. The model version label. Returns ------- ads.catalog.Model The ads.catalog.Model with the matching ID. """ project_id = project_id or PROJECT_OCID if not project_id: raise ValueError("`project_id` needs to be specified.") copy_artifact_to_os = False if ( bucket_uri or utils.folder_size(model_artifact.artifact_dir) > _MAX_ARTIFACT_SIZE_IN_BYTES ): if not bucket_uri: raise ValueError( f"The model artifacts size is greater than `{utils.human_size(_MAX_ARTIFACT_SIZE_IN_BYTES)}`. " "The `bucket_uri` needs to be specified to " "copy artifacts to the object storage bucket. " "Example: `bucket_uri=oci://<bucket_name>@<namespace>/prefix/`" ) copy_artifact_to_os = True # extract model_version_set_id from model_version_set attribute or environment # variables in case of saving model in context of model version set. 
model_version_set_id = _extract_model_version_set_id(model_version_set) # Set default display_name if not specified - randomly generated easy to remember name generated display_name = display_name or utils.get_random_name_for_resource() with utils.get_progress_bar(5) as progress: project_id = PROJECT_OCID if project_id is None else project_id if project_id is None: raise ValueError("project_id needs to be specified.") schema_file = os.path.join(model_artifact.artifact_dir, "schema.json") if os.path.exists(schema_file): with open(schema_file, "r") as schema: metadata = json.load(schema) freeform_tags = {"problem_type": metadata["problem_type"]} progress.update("Saving model in the model catalog") create_model_details = CreateModelDetails( display_name=display_name, description=description, project_id=project_id, compartment_id=self.compartment_id, custom_metadata_list=model_artifact.metadata_custom._to_oci_metadata() if model_artifact.metadata_custom is not None else [], defined_metadata_list=model_artifact.metadata_taxonomy._to_oci_metadata() if model_artifact.metadata_taxonomy is not None else [], input_schema=model_artifact.schema_input.to_json() if model_artifact.schema_input is not None else '{"schema": []}', output_schema=model_artifact.schema_output.to_json() if model_artifact.schema_output is not None else '{"schema": []}', freeform_tags=freeform_tags, defined_tags=defined_tags, model_version_set_id=model_version_set_id, version_label=version_label, ) model = self.ds_client.create_model(create_model_details) if provenance_metadata is not None: progress.update("Saving provenance metadata") self.ds_client.create_model_provenance( model.data.id, provenance_metadata ) else: progress.update() # if the model artifact size greater than 2GB then export function # needs to be used instead of upload. The export function will copy # model artifacts to the OS bucket at first and then will upload # the artifacts to the model catalog. if copy_artifact_to_os: self._export_model_artifact( model_id=model.data.id, model_artifact=model_artifact, progress=progress, bucket_uri=bucket_uri, remove_existing_artifact=remove_existing_artifact, overwrite_existing_artifact=overwrite_existing_artifact, ) else: self._upload_model_artifact( model_id=model.data.id, model_artifact=model_artifact, progress=progress, ) progress.update() progress.update("Done") return self.get_model(model.data.id)
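    # Usage sketch (editor's addition): registering an artifact in the catalog; large
    # artifacts (>2GB) are routed through the Object Storage bucket given by
    # `bucket_uri`. The names and URI below are placeholders:
    #
    #     model = mc.upload_model(
    #         model_artifact,
    #         display_name="my-model",
    #         description="Trained on the October snapshot",
    #         bucket_uri="oci://<bucket_name>@<namespace>/prefix/",
    #         version_label="v1",
    #     )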
    def _prepare_model_artifact(self, model_artifact, progress: TqdmProgressBar) -> str:
        """Prepares model artifacts to save in the Model Catalog.

        Returns
        -------
        str
            The path to the model artifact zip archive.
        """
        progress.update("Preparing model artifacts zip")
        files_to_upload = model_artifact._get_files()
        artifact_path = "/tmp/saved_model_" + str(uuid.uuid4()) + ".zip"
        zf = ZipFile(artifact_path, "w")
        for matched_file in files_to_upload:
            zf.write(
                os.path.join(model_artifact.artifact_dir, matched_file),
                arcname=matched_file,
            )
        zf.close()
        return artifact_path

    def _upload_model_artifact(self, model_id, model_artifact, progress):
        """Uploads the model artifact to the model catalog.
        This method can be used only if the size of the model artifact is less than 2GB.
        For artifacts with size greater than 2GB, the `_export_model_artifact`
        method should be used instead.
        """
        artifact_zip_path = self._prepare_model_artifact(model_artifact, progress)
        progress.update("Uploading model artifacts to the catalog")
        with open(artifact_zip_path, "rb") as file_data:
            bytes_content = file_data.read()
            self.ds_client.create_model_artifact(
                model_id,
                bytes_content,
                content_disposition=f'attachment; filename="{model_id}.zip"',
            )
        os.remove(artifact_zip_path)
        progress.update()

    def _export_model_artifact(
        self,
        model_id: str,
        model_artifact: ModelArtifact,
        bucket_uri: str,
        progress,
        remove_existing_artifact: Optional[bool] = True,
        overwrite_existing_artifact: Optional[bool] = True,
    ):
        """Exports the model artifact to the model catalog.
        This method is used for the case when the artifact size is greater than 2GB.

        1. Archives the model artifact.
        2. Copies the artifact to the object storage bucket.
        3. Exports the artifact from the user's object storage bucket to the system one.
        """
        artifact_zip_path = self._prepare_model_artifact(model_artifact, progress)
        progress.update("Copying model artifact to the Object Storage bucket")

        try:
            bucket_uri_file_name = os.path.basename(bucket_uri)
            if not bucket_uri_file_name:
                bucket_uri = os.path.join(bucket_uri, f"{model_id}.zip")
            elif not bucket_uri.lower().endswith(".zip"):
                bucket_uri = f"{bucket_uri}.zip"

            bucket_file_name = utils.copy_file(
                artifact_zip_path,
                bucket_uri,
                force_overwrite=overwrite_existing_artifact,
                auth=self.ds_client_auth,
                progressbar_description="Copying model artifact to the Object Storage bucket",
            )
        except FileExistsError:
            raise FileExistsError(
                f"The `{bucket_uri}` exists. Please use a new file name or "
                "set `overwrite_existing_artifact` to `True` if you wish to overwrite."
            )
        os.remove(artifact_zip_path)

        progress.update("Exporting model artifact to the model catalog")
        bucket_details = ObjectStorageDetails.from_path(bucket_file_name)
        response = self.ds_client.export_model_artifact(
            model_id=model_id,
            export_model_artifact_details=ExportModelArtifactDetails(
                artifact_export_details=ArtifactExportDetailsObjectStorage(
                    namespace=bucket_details.namespace,
                    source_bucket=bucket_details.bucket,
                    source_object_name=bucket_details.filepath,
                    source_region=self._region,
                )
            ),
        )
        self._wait_for_work_request(
            response=response,
            first_step_description="Preparing to export model artifact to the model catalog",
            num_steps=4,
        )

        if remove_existing_artifact:
            progress.update(
                "Removing temporary model artifact from the Object Storage bucket"
            )
            utils.remove_file(bucket_file_name, self.ds_client_auth)
        else:
            progress.update()

    @property
    def _region(self):
        """Gets the current region."""
        if "region" in self.ds_client_auth.get("config", {}):
            return self.ds_client_auth["config"]["region"]
        return json.loads(OCI_REGION_METADATA)["regionIdentifier"]

    def _wait_for_work_request(
        self, response, first_step_description: str = "", num_steps=4
    ):
        """Waits for the work request to be completed."""
        STOP_STATE = (
            WorkRequest.STATUS_SUCCEEDED,
            WorkRequest.STATUS_CANCELED,
            WorkRequest.STATUS_CANCELING,
            WorkRequest.STATUS_FAILED,
        )
        work_request_id = response.headers["opc-work-request-id"]
        work_request_logs = None

        i = 0
        with utils.get_progress_bar(num_steps) as progress:
            progress.update(first_step_description)
            while not work_request_logs or len(work_request_logs) < num_steps:
                time.sleep(_WORK_REQUEST_INTERVAL_IN_SEC)
                work_request = self.ds_client.get_work_request(work_request_id)
                work_request_logs = self.ds_client.list_work_request_logs(
                    work_request_id
                ).data
                if work_request_logs:
                    new_work_request_logs = work_request_logs[i:]
                    for wr_item in new_work_request_logs:
                        progress.update(wr_item.message)
                        i += 1

                if work_request.data.status in STOP_STATE:
                    if work_request.data.status != WorkRequest.STATUS_SUCCEEDED:
                        if work_request_logs:
                            raise Exception(work_request_logs[-1].message)
                        else:
                            raise Exception(
                                "An error occurred in attempt to perform the operation. "
                                "Check the service logs to get more details."
                            )
                    else:
                        break
        return work_request