Source code for ads.common.object_storage_details

#!/usr/bin/env python

# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
import os
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Dict, List
from urllib.parse import urlparse

import oci
from ads.common import auth as authutil
from ads.common import oci_client
from ads.dataset.progress import TqdmProgressBar

THREAD_POOL_MAX_WORKERS = 10


class InvalidObjectStoragePath(Exception):  # pragma: no cover
    """Raised when a URI does not describe a valid Object Storage path."""
@dataclass
class ObjectStorageDetails:
    """Class that represents the Object Storage bucket URI details.

    Attributes
    ----------
    bucket: str
        The Object Storage bucket name.
    namespace: (str, optional). Defaults to empty string.
        The Object Storage namespace. Will be extracted automatically if not provided.
    filepath: (str, optional). Defaults to empty string.
        The path to the object.
    version: (str, optional). Defaults to None.
        The version of the object.
    auth: (Dict, optional). Defaults to None.
        ADS auth dictionary for OCI authentication.
        This can be generated by calling ads.common.auth.api_keys() or
        ads.common.auth.resource_principal().
        If this is None, ads.common.default_signer() will be used.
    """

    bucket: str
    namespace: str = ""
    filepath: str = ""
    version: str = None
    auth: Dict = None

    def __post_init__(self):
        """Fill in the default signer and the namespace when they are omitted."""
        if not self.auth:
            self.auth = authutil.default_signer()
        # Extract OS namespace if not provided.
        if not self.namespace:
            self.namespace = self.os_client.get_namespace().data

    @property
    def os_client(self):
        """Object Storage client, created on first access and then reused.

        The client is cached under a plain single-underscore attribute. A
        ``hasattr(self, "__client")`` check would never succeed, because
        ``self.__client`` is stored under the name-mangled attribute
        ``_ObjectStorageDetails__client``; a new client would then be built
        on every access.
        """
        client = getattr(self, "_os_client", None)
        if client is None:
            client = self._os_client = self.create_os_client()
        return client

    def create_os_client(self):
        """Build a new Object Storage client from the ``auth`` dictionary."""
        return oci_client.OCIClientFactory(**self.auth).object_storage

    def __repr__(self):
        return self.path

    @property
    def path(self):
        """Full object storage path of this file."""
        # A URI always uses forward slashes; os.path.join would emit
        # backslashes on Windows.
        object_path = self.filepath.lstrip("/") if self.filepath else ""
        return f"oci://{self.bucket}@{self.namespace}/{object_path}"

    @classmethod
    def from_path(cls, env_path: str, auth: Dict = None) -> "ObjectStorageDetails":
        """Construct an ObjectStorageDetails instance from conda pack path.

        Parameters
        ----------
        env_path: (str)
            conda pack object storage path.
        auth: (Dict, optional). Defaults to None.
            ADS auth dictionary forwarded to the constructor. When omitted,
            the constructor falls back to the default signer.

        Raises
        ------
        Exception: OCI conda url path not properly configured.

        Returns
        -------
        ObjectStorageDetails
            An ObjectStorageDetails instance.
        """
        error_message = (
            "OCI path is not properly configured. "
            "It should follow the pattern `oci://<bucket-name>@<namespace>/object_path`."
        )
        # Only the parsing is guarded, so that authentication or network
        # failures raised by the constructor surface as themselves rather
        # than being relabelled as a malformed path.
        try:
            url_parse = urlparse(env_path)
            bucket_name = url_parse.username
            namespace = url_parse.hostname
            object_name = url_parse.path.lstrip("/")
        except (AttributeError, TypeError, ValueError) as ex:
            raise Exception(error_message) from ex

        if url_parse.scheme != "oci" or not bucket_name:
            raise Exception(error_message)

        return cls(
            bucket=bucket_name,
            namespace=namespace,
            filepath=object_name,
            auth=auth,
        )

    def to_tuple(self):
        """Returns the values of the fields of ObjectStorageDetails class."""
        return self.bucket, self.namespace, self.filepath

    def fetch_metadata_of_object(self) -> Dict:
        """Fetches the manifest metadata from the object storage of a conda pack.

        Returns
        -------
        Dict
            The metadata in dictionary format.
        """
        res = self.os_client.get_object(self.namespace, self.bucket, self.filepath)
        # The manifest is stored as a JSON string in the object's
        # `opc-meta-manifest` response header.
        return json.loads(res.data.headers["opc-meta-manifest"])

    @staticmethod
    def is_valid_uri(uri: str) -> bool:
        """Validates the Object Storage URI.

        Parameters
        ----------
        uri: str
            The URI to validate.

        Raises
        ------
        InvalidObjectStoragePath
            If the URI does not start with `oci://<bucket_name>@<namespace>`.

        Returns
        -------
        bool
            True when the URI is valid.
        """
        # Require a non-empty bucket and a non-empty namespace around the `@`.
        if not isinstance(uri, str) or not re.match(r"oci://[^/@]+@[^/]+", uri):
            raise InvalidObjectStoragePath(
                f"The `{uri}` is not a valid Object Storage path. "
                "It must follow the pattern `oci://<bucket_name>@<namespace>/<prefix>`."
            )
        return True

    @staticmethod
    def is_oci_path(uri: str = None) -> bool:
        """Check if the given path is oci object storage uri.

        Parameters
        ----------
        uri: str
            The URI of the target.

        Returns
        -------
        bool: return True if the path is oci object storage uri.
        """
        if not uri:
            return False
        return uri.lower().startswith("oci://")

    def is_bucket_versioned(self) -> bool:
        """Check if the given bucket is versioned.

        Returns
        -------
        bool: return True if the bucket is versioned.
        """
        res = self.os_client.get_bucket(
            namespace_name=self.namespace, bucket_name=self.bucket
        ).data
        return res.versioning == "Enabled"

    def _list_all(self, list_operation, **kwargs):
        """Run a paginated listing call under this path and return all results."""
        fields = kwargs.pop(
            "fields",
            "name,etag,size,timeCreated,md5,timeModified,storageTier,archivalState",
        )
        return oci.pagination.list_call_get_all_results(
            list_operation,
            namespace_name=self.namespace,
            bucket_name=self.bucket,
            prefix=self.filepath,
            fields=fields,
            **kwargs,
        ).data

    def list_objects(self, **kwargs):
        """Lists objects in a given oss path

        Parameters
        -------
        **kwargs:
            namespace, bucket, filepath are set by the class. By default, fields gets all values.
            For other supported parameters, check
            https://docs.oracle.com/iaas/api/#/en/objectstorage/20160918/Object/ListObjects

        Returns
        -------
        Object of type oci.object_storage.models.ListObjects
        """
        return self._list_all(self.os_client.list_objects, **kwargs)

    def list_object_versions(
        self,
        **kwargs,
    ):
        """Lists object versions in a given oss path

        Parameters
        -------
        **kwargs:
            namespace, bucket, filepath are set by the class. By default, fields gets all values.
            For other supported parameters, check
            https://docs.oracle.com/iaas/api/#/en/objectstorage/20160918/Object/ListObjectVersions

        Returns
        -------
        Object of type oci.object_storage.models.ObjectVersionCollection
        """
        return self._list_all(self.os_client.list_object_versions, **kwargs)

    def download_from_object_storage(
        self,
        path: "ObjectStorageDetails",
        target_dir: str,
        progress_bar: "TqdmProgressBar" = None,
    ):
        """Downloads the files with object versions set in the paths dict.

        Parameters
        ----------
        path:
            OSS path along with a value of file version id.
            If version_id is not available, download the latest version.
        target_dir:
            Local directory to save the files
        progress_bar:
            an instance of the TqdmProgressBar, can update description in the calling progress bar

        Returns
        -------
        None
        """
        if progress_bar:
            progress_bar.update(
                description=f"Copying model artifacts by reference from {path.path} to {target_dir}",
                n=0,
            )
        res = self.os_client.get_object(
            namespace_name=path.namespace,
            bucket_name=path.bucket,
            object_name=path.filepath,
            version_id=path.version,
        )
        # Strip leading slashes from the object name: os.path.join discards
        # every earlier component when it meets an absolute one, which would
        # place the file outside of `target_dir`.
        local_filepath = os.path.join(
            target_dir, path.bucket, path.filepath.lstrip("/")
        )
        os.makedirs(os.path.dirname(local_filepath), exist_ok=True)

        with open(local_filepath, "wb") as _file:
            for chunk in res.data.iter_content(chunk_size=4096):
                _file.write(chunk)

    def bulk_download_from_object_storage(
        self,
        paths: List["ObjectStorageDetails"],
        target_dir: str,
        progress_bar: "TqdmProgressBar" = None,
    ):
        """Downloads the files with object versions set in the paths dict in parallel.

        Parameters
        ----------
        paths:
            Contains a list of OSS paths along with a value of file version id.
            If version_id is not available, download the latest version.
        target_dir:
            Local directory to save the files
        progress_bar:
            an instance of the TqdmProgressBar, can update description in the calling progress bar

        Returns
        -------
        None
        """
        with ThreadPoolExecutor(max_workers=THREAD_POOL_MAX_WORKERS) as pool:
            futures = {
                pool.submit(
                    self.download_from_object_storage, path, target_dir, progress_bar
                ): path
                for path in paths
            }
            # Calling result() re-raises any exception from a worker thread.
            for future in as_completed(futures):
                future.result()