# Source code for ads.common.object_storage_details

#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
import os
import re
from dataclasses import dataclass
from typing import Dict, List
from urllib.parse import urlparse


import oci
from ads.common import auth as authutil
from ads.common import oci_client
from ads.dataset.progress import TqdmProgressBar
from concurrent.futures import ThreadPoolExecutor, as_completed

# Upper bound on the number of worker threads used by
# ObjectStorageDetails.bulk_download_from_object_storage for parallel downloads.
THREAD_POOL_MAX_WORKERS = 10


class InvalidObjectStoragePath(Exception):  # pragma: no cover
    """Error raised for a URI that does not look like an Object Storage path.

    Raised by ``ObjectStorageDetails.is_valid_uri`` when the URI does not
    follow the ``oci://<bucket_name>@<namespace>/<prefix>`` pattern.
    """
@dataclass
class ObjectStorageDetails:
    """Class that represents the Object Storage bucket URI details.

    Attributes
    ----------
    bucket: str
        The Object Storage bucket name.
    namespace: (str, optional). Defaults to empty string.
        The Object Storage namespace. Will be extracted automatically if not provided.
    filepath: (str, optional). Defaults to empty string.
        The path to the object.
    version: (str, optional). Defaults to None.
        The version of the object.
    auth: (Dict, optional). Defaults to None.
        ADS auth dictionary for OCI authentication.
        This can be generated by calling ads.common.auth.api_keys() or ads.common.auth.resource_principal()
        If this is None, ads.common.auth.default_signer() will be used.
    """

    bucket: str
    namespace: str = ""
    filepath: str = ""
    version: str = None
    auth: Dict = None

    # Default ``fields`` projection requested by list_objects/list_object_versions.
    # It is not annotated, so the dataclass machinery does not treat it as a field.
    _DEFAULT_LIST_FIELDS = (
        "name,etag,size,timeCreated,md5,timeModified,storageTier,archivalState"
    )

    def __post_init__(self):
        if not self.auth:
            self.auth = authutil.default_signer()
        # Extract OS namespace if not provided.
        if not self.namespace:
            self.namespace = self.os_client.get_namespace().data

    @property
    def os_client(self):
        """Object Storage client for this instance, created lazily and cached.

        The client is built from ``self.auth`` on first access and reused after
        that, so changing ``auth`` later does not rebuild it.
        """
        # A single-underscore name is used on purpose: a double-underscore
        # attribute is name-mangled, so a string lookup such as
        # hasattr(self, "__client") would never find it and the cache would miss.
        # NOTE(review): bulk_download_from_object_storage shares this cached
        # client across its worker threads; confirm the OCI client is safe for
        # concurrent use.
        client = getattr(self, "_os_client", None)
        if client is None:
            client = self.create_os_client()
            self._os_client = client
        return client

    def create_os_client(self):
        """Create a new Object Storage client from ``self.auth``."""
        return oci_client.OCIClientFactory(**self.auth).object_storage

    def __repr__(self):
        return self.path

    @property
    def path(self):
        """Full object storage path of this file."""
        return os.path.join(
            "oci://",
            self.bucket + "@" + self.namespace,
            self.filepath.lstrip("/") if self.filepath else "",
        )

    @classmethod
    def from_path(cls, env_path: str) -> "ObjectStorageDetails":
        """Construct an ObjectStorageDetails instance from a conda pack path.

        Parameters
        ----------
        env_path: (str)
            Conda pack Object Storage path, ``oci://<bucket-name>@<namespace>/object_path``.

        Raises
        ------
        Exception:
            OCI conda url path not properly configured. The original error is
            chained as ``__cause__``.

        Returns
        -------
        ObjectStorageDetails
            An ObjectStorageDetails instance.
        """
        try:
            url_parse = urlparse(env_path)
            # In oci://bucket@namespace/..., urlparse exposes the bucket as the
            # "username" and the namespace as the "hostname".
            bucket_name = url_parse.username
            namespace = url_parse.hostname
            object_name = url_parse.path.lstrip("/")
            return cls(bucket=bucket_name, namespace=namespace, filepath=object_name)
        except Exception as ex:
            # Catch Exception rather than using a bare except, so that
            # KeyboardInterrupt/SystemExit still propagate.
            raise Exception(
                "OCI path is not properly configured. "
                "It should follow the pattern `oci://<bucket-name>@<namespace>/object_path`."
            ) from ex

    def to_tuple(self):
        """Returns the values of the fields of ObjectStorageDetails class."""
        return self.bucket, self.namespace, self.filepath

    def fetch_metadata_of_object(self) -> Dict:
        """Fetches the manifest metadata from the object storage of a conda pack.

        Returns
        -------
        Dict
            The metadata in dictionary format, parsed from the
            ``opc-meta-manifest`` response header.
        """
        # NOTE(review): get_object opens the full object body only to read a
        # header; head_object may be a cheaper alternative — confirm before switching.
        res = self.os_client.get_object(self.namespace, self.bucket, self.filepath)
        metadata = res.data.headers["opc-meta-manifest"]
        metadata_json = json.loads(metadata)
        return metadata_json

    @staticmethod
    def is_valid_uri(uri: str) -> bool:
        """Validates the Object Storage URI.

        Parameters
        ----------
        uri: str
            The URI to validate.

        Returns
        -------
        bool
            True when the URI starts with ``oci://<bucket_name>@<namespace>``.

        Raises
        ------
        InvalidObjectStoragePath
            If the URI does not follow the expected pattern.
        """
        # Require a non-empty bucket (no "/" or "@") followed by "@" and a
        # non-empty namespace. The previous pattern, "oci://*@*", matched any
        # string that starts with "oci:/".
        if not re.match(r"oci://[^/@]+@[^/]+", uri):
            raise InvalidObjectStoragePath(
                f"The `{uri}` is not a valid Object Storage path. "
                "It must follow the pattern `oci://<bucket_name>@<namespace>/<prefix>`."
            )
        return True

    @staticmethod
    def is_oci_path(uri: str = None) -> bool:
        """Check if the given path is oci object storage uri.

        Parameters
        ----------
        uri: str
            The URI of the target.

        Returns
        -------
        bool: return True if the path is oci object storage uri.
        """
        if not uri:
            return False
        return uri.lower().startswith("oci://")

    def is_bucket_versioned(self) -> bool:
        """Check if the given bucket is versioned.

        Returns
        -------
        bool: return True if the bucket is versioned.
        """
        res = self.os_client.get_bucket(
            namespace_name=self.namespace, bucket_name=self.bucket
        ).data
        return res.versioning == "Enabled"

    def list_objects(self, **kwargs):
        """Lists objects in a given oss path.

        Parameters
        ----------
        **kwargs:
            namespace, bucket, filepath are set by the class. By default, fields gets all values. For other supported
            parameters, check https://docs.oracle.com/iaas/api/#/en/objectstorage/20160918/Object/ListObjects

        Returns
        -------
        Object of type oci.object_storage.models.ListObjects
        """
        fields = kwargs.pop("fields", self._DEFAULT_LIST_FIELDS)
        # list_call_get_all_results follows pagination and aggregates every page.
        objects = oci.pagination.list_call_get_all_results(
            self.os_client.list_objects,
            namespace_name=self.namespace,
            bucket_name=self.bucket,
            prefix=self.filepath,
            fields=fields,
            **kwargs,
        ).data
        return objects

    def list_object_versions(
        self,
        **kwargs,
    ):
        """Lists object versions in a given oss path.

        Parameters
        ----------
        **kwargs:
            namespace, bucket, filepath are set by the class. By default, fields gets all values. For other supported
            parameters, check https://docs.oracle.com/iaas/api/#/en/objectstorage/20160918/Object/ListObjectVersions

        Returns
        -------
        Object of type oci.object_storage.models.ObjectVersionCollection
        """
        fields = kwargs.pop("fields", self._DEFAULT_LIST_FIELDS)
        objects = oci.pagination.list_call_get_all_results(
            self.os_client.list_object_versions,
            namespace_name=self.namespace,
            bucket_name=self.bucket,
            prefix=self.filepath,
            fields=fields,
            **kwargs,
        ).data
        return objects

    def download_from_object_storage(
        self,
        path: "ObjectStorageDetails",
        target_dir: str,
        progress_bar: "TqdmProgressBar" = None,
    ):
        """Downloads the file described by ``path``, honouring its object version.

        Parameters
        ----------
        path:
            OSS path along with a value of file version id.
            If version_id is not available, download the latest version.
        target_dir:
            Local directory to save the files. The object is written to
            ``<target_dir>/<path.bucket>/<path.filepath>``.
        progress_bar:
            an instance of the TqdmProgressBar, can update description in the calling progress bar

        Returns
        -------
        None
        """
        if progress_bar:
            progress_bar.update(
                description=f"Copying model artifacts by reference from {path.path} to {target_dir}",
                n=0,
            )
        res = self.os_client.get_object(
            namespace_name=path.namespace,
            bucket_name=path.bucket,
            object_name=path.filepath,
            version_id=path.version,
        )
        # NOTE(review): object names containing ".." would resolve outside
        # target_dir; acceptable only if bucket contents are trusted.
        local_filepath = os.path.join(target_dir, path.bucket, path.filepath)
        os.makedirs(os.path.dirname(local_filepath), exist_ok=True)
        # Stream the body to disk in chunks instead of loading it into memory.
        with open(local_filepath, "wb") as _file:
            for chunk in res.data.iter_content(chunk_size=4096):
                _file.write(chunk)

    def bulk_download_from_object_storage(
        self,
        paths: List["ObjectStorageDetails"],
        target_dir: str,
        progress_bar: "TqdmProgressBar" = None,
    ):
        """Downloads the files with object versions set in the paths list in parallel.

        Parameters
        ----------
        paths:
            Contains a list of OSS paths along with a value of file version id.
            If version_id is not available, download the latest version.
        target_dir:
            Local directory to save the files
        progress_bar:
            an instance of the TqdmProgressBar, can update description in the calling progress bar

        Returns
        -------
        None

        Raises
        ------
        Exception
            The first exception raised by any download, re-raised by ``future.result()``.
        """
        with ThreadPoolExecutor(max_workers=THREAD_POOL_MAX_WORKERS) as pool:
            futures = [
                pool.submit(
                    self.download_from_object_storage, path, target_dir, progress_bar
                )
                for path in paths
            ]
            # result() re-raises any exception from a worker thread.
            for future in as_completed(futures):
                future.result()