Source code for ads.model.artifact_downloader

#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import os
import shutil
import tempfile
import uuid
from abc import ABC, abstractmethod
from typing import Dict, Optional
from zipfile import ZipFile

from ads.common import utils
from ads.common.utils import extract_region
from ads.model.service.oci_datascience_model import OCIDataScienceModel


class ArtifactDownloader(ABC):
    """Base class for objects that download model artifacts to a local directory.

    Concrete subclasses implement `_download` and override
    `PROGRESS_STEPS_COUNT` with their own contribution to the progress bar
    total.
    """

    # Added to the subclass' `PROGRESS_STEPS_COUNT` when the progress bar is
    # created in `download`.
    PROGRESS_STEPS_COUNT = 1

    def __init__(
        self,
        dsc_model: OCIDataScienceModel,
        target_dir: str,
        force_overwrite: Optional[bool] = False,
    ):
        """Initializes `ArtifactDownloader` instance.

        Parameters
        ----------
        dsc_model: OCIDataScienceModel
            The data science model instance.
        target_dir: str
            The target location of model after download.
        force_overwrite: bool
            Overwrite target_dir if exists.
        """
        self.dsc_model = dsc_model
        self.target_dir = target_dir
        self.force_overwrite = force_overwrite
        # Set to the active progress bar for the duration of `download`.
        self.progress = None

    def download(self):
        """Downloads model artifacts.

        A non-empty target directory is removed first when `force_overwrite`
        is set; otherwise the download is refused.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If the target directory already exists, is not empty, and
            `force_overwrite` is not set.
        """
        destination = self.target_dir
        destination_in_use = os.path.exists(destination) and os.listdir(destination)

        if destination_in_use:
            if not self.force_overwrite:
                raise ValueError(
                    f"The `{self.target_dir}` directory already exists. "
                    "Set `force_overwrite` to `True` if you wish to overwrite."
                )
            shutil.rmtree(destination)

        # The base count is read from `ArtifactDownloader` explicitly so that
        # a subclass override is added to it instead of replacing it.
        total_steps = (
            ArtifactDownloader.PROGRESS_STEPS_COUNT + self.PROGRESS_STEPS_COUNT
        )
        with utils.get_progress_bar(total_steps) as progress_bar:
            self.progress = progress_bar
            self._download()
            self.progress.update(
                "Downloading model artifacts has been successfully completed."
            )
            self.progress.update("Done.")

    @abstractmethod
    def _download(self):
        """Downloads model artifacts."""
class SmallArtifactDownloader(ArtifactDownloader):
    """Downloads model artifacts whose zipped content is fetched in one call.

    The archive bytes returned by the model are written to a temporary file
    and then extracted into `target_dir`.
    """

    PROGRESS_STEPS_COUNT = 3

    def _download(self):
        """Downloads model artifacts."""
        self.progress.update("Importing model artifacts from catalog")
        archive_bytes = self.dsc_model.get_model_artifact_content()

        # The temporary directory, and the archive inside it, is removed
        # automatically once extraction has finished.
        with tempfile.TemporaryDirectory() as scratch_dir:
            self.progress.update("Copying model artifacts to the artifact directory")
            archive_name = f"{uuid.uuid4()}.zip"
            archive_path = os.path.join(scratch_dir, archive_name)
            with open(archive_path, "wb") as archive_file:
                archive_file.write(archive_bytes)

            self.progress.update("Extracting model artifacts")
            with ZipFile(archive_path) as archive:
                archive.extractall(self.target_dir)
class LargeArtifactDownloader(ArtifactDownloader):
    """Downloads model artifacts through an intermediate Object Storage location.

    The model service is asked to place the artifact archive at `bucket_uri`
    (`import_model_artifact`); the archive is then copied to a local temporary
    directory, extracted into `target_dir`, and optionally removed from the
    bucket.
    """

    PROGRESS_STEPS_COUNT = 4

    def __init__(
        self,
        dsc_model: OCIDataScienceModel,
        target_dir: str,
        auth: Optional[Dict] = None,
        force_overwrite: Optional[bool] = False,
        region: Optional[str] = None,
        bucket_uri: Optional[str] = None,
        overwrite_existing_artifact: Optional[bool] = True,
        remove_existing_artifact: Optional[bool] = True,
    ):
        """Initializes `LargeArtifactDownloader` instance.

        Parameters
        ----------
        dsc_model: OCIDataScienceModel
            The data science model instance.
        target_dir: str
            The target location of model after download.
        auth: (Dict, optional). Defaults to `None`.
            The authentication signer and kwargs used to access Object Storage.
            When omitted, `dsc_model.auth` is used. To override it, use
            `ads.common.auth.api_keys` or `ads.common.auth.resource_principal`
            to create the appropriate signer and kwargs.
        force_overwrite: (bool, optional). Defaults to `False`.
            Overwrite target_dir if exists.
        region: (str, optional). Defaults to `None`.
            The destination Object Storage bucket region.
            When omitted, it is resolved with `extract_region` from the
            effective `auth`.
        bucket_uri: (str, optional). Defaults to None.
            The OCI Object Storage URI where model artifacts will be copied to.
            Although the parameter is optional in the signature, a value is
            required by the time `download` runs.
            Example: `oci://<bucket_name>@<namespace>/prefix/`.
        overwrite_existing_artifact: (bool, optional). Defaults to `True`.
            Overwrite target bucket artifact if exists.
        remove_existing_artifact: (bool, optional). Defaults to `True`.
            Whether artifacts uploaded to object storage bucket need to be
            removed or not.
        """
        super().__init__(
            dsc_model=dsc_model, target_dir=target_dir, force_overwrite=force_overwrite
        )
        self.auth = auth or dsc_model.auth
        self.region = region or extract_region(self.auth)
        self.bucket_uri = bucket_uri
        # NOTE(review): stored for callers, but not consulted anywhere in this
        # class.
        self.overwrite_existing_artifact = overwrite_existing_artifact
        self.remove_existing_artifact = remove_existing_artifact

    def _archive_uri(self) -> str:
        """Returns the full URI of the zip archive derived from `bucket_uri`.

        Raises
        ------
        ValueError
            If `bucket_uri` was not provided.
        """
        uri = self.bucket_uri
        if not uri:
            # Fail with a clear message instead of the `TypeError` that
            # `os.path.basename(None)` would raise.
            raise ValueError(
                "The `bucket_uri` must be provided to download large model "
                "artifacts. Example: `oci://<bucket_name>@<namespace>/prefix/`."
            )

        if not os.path.basename(uri):
            # The URI points to a "folder" (ends with a separator): name the
            # archive after the model.
            return os.path.join(uri, f"{self.dsc_model.id}.zip")
        if not uri.lower().endswith(".zip"):
            return f"{uri}.zip"
        return uri

    def _download(self):
        """Downloads model artifacts.

        Raises
        ------
        ValueError
            If `bucket_uri` was not provided.
        """
        # Resolved first so that a missing `bucket_uri` is reported before any
        # remote call is made.
        bucket_uri = self._archive_uri()

        self.progress.update("Importing model artifacts from catalog")
        self.dsc_model.import_model_artifact(bucket_uri=bucket_uri, region=self.region)

        self.progress.update("Copying model artifacts to the artifact directory")
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.zip")
            # `copy_file` returns the path that is then opened as the archive.
            zip_file_path = utils.copy_file(
                uri_src=bucket_uri,
                uri_dst=zip_file_path,
                auth=self.auth,
                progressbar_description="Copying model artifacts to the artifact directory",
            )
            self.progress.update("Extracting model artifacts")
            with ZipFile(zip_file_path) as zip_file:
                zip_file.extractall(self.target_dir)

        if self.remove_existing_artifact:
            self.progress.update(
                "Removing temporary artifacts from the Object Storage bucket"
            )
            utils.remove_file(bucket_uri)
        else:
            # Presumably advances the bar without a message so both branches
            # perform the same number of updates; confirm against the progress
            # bar implementation.
            self.progress.update()