#!/usr/bin/env python
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
import json
import os
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Dict, List
from urllib.parse import urlparse
import oci
from ads.common import auth as authutil
from ads.common import oci_client
from ads.dataset.progress import TqdmProgressBar
# Upper bound on the number of threads used by parallel (bulk) downloads.
THREAD_POOL_MAX_WORKERS = 10


class InvalidObjectStoragePath(Exception):  # pragma: no cover
    """Raised when a URI does not look like a valid Object Storage path."""
@dataclass
class ObjectStorageDetails:
    """Class that represents the Object Storage bucket URI details.

    Attributes
    ----------
    bucket: str
        The Object Storage bucket name.
    namespace: (str, optional). Defaults to empty string.
        The Object Storage namespace. Will be extracted automatically if not provided.
    filepath: (str, optional). Defaults to empty string.
        The path to the object.
    version: (str, optional). Defaults to None.
        The version of the object.
    auth: (Dict, optional). Defaults to None.
        ADS auth dictionary for OCI authentication.
        This can be generated by calling ads.common.auth.api_keys() or ads.common.auth.resource_principal()
        If this is None, ads.common.auth.default_signer() will be used.
    """

    bucket: str
    namespace: str = ""
    filepath: str = ""
    version: str = None
    auth: Dict = None

    # Default ``fields`` requested by list_objects/list_object_versions when the
    # caller does not supply its own. It has no annotation, so it is a plain
    # class attribute and not a dataclass field.
    _DEFAULT_LIST_FIELDS = (
        "name,etag,size,timeCreated,md5,timeModified,storageTier,archivalState"
    )

    def __post_init__(self):
        if not self.auth:
            self.auth = authutil.default_signer()
        # Extract OS namespace if not provided (this calls the Object Storage service).
        if not self.namespace:
            self.namespace = self.os_client.get_namespace().data

    @property
    def os_client(self):
        """Object Storage client, created on first access and cached on the instance."""
        # Read through ``self.__client`` so the lookup uses the same
        # name-mangled attribute (``_ObjectStorageDetails__client``) that the
        # assignment creates; a string-based ``hasattr(self, "__client")``
        # check is not mangled and would never find the cached client.
        try:
            return self.__client
        except AttributeError:
            self.__client = self.create_os_client()
            return self.__client

    def create_os_client(self):
        """Create a new Object Storage client from ``self.auth``."""
        return oci_client.OCIClientFactory(**self.auth).object_storage

    def __repr__(self):
        return self.path

    @property
    def path(self):
        """Full object storage path of this file.

        Returns ``oci://<bucket>@<namespace>/<filepath>``; leading slashes of
        ``filepath`` are dropped and an empty filepath yields a trailing "/".
        """
        # Join with "/" explicitly: os.path.join would insert "\\" on Windows,
        # which is not a valid separator inside an oci:// URI.
        filepath = self.filepath.lstrip("/") if self.filepath else ""
        return f"oci://{self.bucket}@{self.namespace}/{filepath}"

    @classmethod
    def from_path(cls, env_path: str) -> "ObjectStorageDetails":
        """Construct an ObjectStorageDetails instance from an Object Storage path
        (for example, a conda pack path).

        Parameters
        ----------
        env_path: (str)
            Object Storage path of the form ``oci://<bucket-name>@<namespace>/<object_path>``.

        Raises
        ------
        Exception: OCI path is not properly configured (it cannot be parsed or
            has no ``<bucket-name>@`` part). Errors raised while constructing the
            instance (e.g. auth or service errors when resolving the namespace)
            propagate unchanged.

        Returns
        -------
        ObjectStorageDetails
            An ObjectStorageDetails instance.
        """
        error_message = (
            "OCI path is not properly configured. "
            "It should follow the pattern `oci://<bucket-name>@<namespace>/object_path`."
        )
        try:
            url_parse = urlparse(env_path)
            bucket_name = url_parse.username
            namespace = url_parse.hostname
            object_name = url_parse.path.lstrip("/")
        except Exception as ex:
            raise Exception(error_message) from ex
        # Without "<bucket>@" urlparse reports no username, which would
        # otherwise silently produce an instance with bucket=None.
        if not bucket_name:
            raise Exception(error_message)
        # Construct outside the try block so genuine failures from
        # __post_init__ are not reported as a path-format problem.
        return cls(bucket=bucket_name, namespace=namespace or "", filepath=object_name)

    def to_tuple(self):
        """Returns the values of the fields of ObjectStorageDetails class."""
        return self.bucket, self.namespace, self.filepath

    @staticmethod
    def is_valid_uri(uri: str) -> bool:
        """Validates the Object Storage URI.

        Parameters
        ----------
        uri: str
            The URI to validate.

        Returns
        -------
        bool: True when ``uri`` has the form ``oci://<bucket_name>@<namespace>[/<prefix>]``.

        Raises
        ------
        InvalidObjectStoragePath: if ``uri`` does not have that form.
        """
        # "*" in a regex is a quantifier, not a glob wildcard: spell out a
        # non-empty bucket and namespace (neither containing "@" or "/"),
        # followed by "/" or the end of the string.
        if not re.match(r"oci://[^@/]+@[^@/]+(?:/|$)", uri):
            raise InvalidObjectStoragePath(
                f"The `{uri}` is not a valid Object Storage path. "
                "It must follow the pattern `oci://<bucket_name>@<namespace>/<prefix>`."
            )
        return True

    @staticmethod
    def is_oci_path(uri: str = None) -> bool:
        """Check if the given path is oci object storage uri.

        Parameters
        ----------
        uri: str
            The URI of the target.

        Returns
        -------
        bool: return True if the path starts with ``oci://`` (case-insensitive).
        """
        if not uri:
            return False
        return uri.lower().startswith("oci://")

    def is_bucket_versioned(self) -> bool:
        """Check if the given bucket is versioned.

        Returns
        -------
        bool: return True if the bucket's versioning status is "Enabled".
        """
        res = self.os_client.get_bucket(
            namespace_name=self.namespace, bucket_name=self.bucket
        ).data
        return res.versioning == "Enabled"

    def list_objects(self, **kwargs):
        """Lists objects in a given oss path

        Parameters
        -------
        **kwargs:
            namespace, bucket, filepath are set by the class. By default, fields gets all values. For other supported
            parameters, check https://docs.oracle.com/iaas/api/#/en/objectstorage/20160918/Object/ListObjects

        Returns
        -------
        Object of type oci.object_storage.models.ListObjects
        """
        fields = kwargs.pop("fields", self._DEFAULT_LIST_FIELDS)
        objects = oci.pagination.list_call_get_all_results(
            self.os_client.list_objects,
            namespace_name=self.namespace,
            bucket_name=self.bucket,
            prefix=self.filepath,
            fields=fields,
            **kwargs,
        ).data
        return objects

    def list_object_versions(
        self,
        **kwargs,
    ):
        """Lists object versions in a given oss path

        Parameters
        -------
        **kwargs:
            namespace, bucket, filepath are set by the class. By default, fields gets all values. For other supported
            parameters, check https://docs.oracle.com/iaas/api/#/en/objectstorage/20160918/Object/ListObjectVersions

        Returns
        -------
        Object of type oci.object_storage.models.ObjectVersionCollection
        """
        fields = kwargs.pop("fields", self._DEFAULT_LIST_FIELDS)
        objects = oci.pagination.list_call_get_all_results(
            self.os_client.list_object_versions,
            namespace_name=self.namespace,
            bucket_name=self.bucket,
            prefix=self.filepath,
            fields=fields,
            **kwargs,
        ).data
        return objects

    def download_from_object_storage(
        self,
        path: "ObjectStorageDetails",
        target_dir: str,
        progress_bar: "TqdmProgressBar" = None,
    ):
        """Downloads a single object (optionally a specific version) to a local directory.

        Parameters
        ----------
        path:
            OSS path along with a value of file version id.
            If version_id is not available, download the latest version.
        target_dir:
            Local directory to save the file. The object is written to
            ``<target_dir>/<bucket>/<object path>``.
        progress_bar:
            an instance of the TqdmProgressBar, can update description in the calling progress bar

        Returns
        -------
        None
        """
        if progress_bar:
            progress_bar.update(
                description=f"Copying model artifacts by reference from {path.path} to {target_dir}",
                n=0,
            )
        res = self.os_client.get_object(
            namespace_name=path.namespace,
            bucket_name=path.bucket,
            object_name=path.filepath,
            version_id=path.version,
        )
        # Strip leading "/" for the local path only: an absolute component
        # would make os.path.join discard target_dir entirely.
        local_filepath = os.path.join(
            target_dir, path.bucket, path.filepath.lstrip("/")
        )
        os.makedirs(os.path.dirname(local_filepath), exist_ok=True)
        with open(local_filepath, "wb") as _file:
            for chunk in res.data.iter_content(chunk_size=4096):
                _file.write(chunk)

    def bulk_download_from_object_storage(
        self,
        paths: List["ObjectStorageDetails"],
        target_dir: str,
        progress_bar: "TqdmProgressBar" = None,
    ):
        """Downloads the given objects (with their versions, if set) in parallel.

        Parameters
        ----------
        paths:
            Contains a list of OSS paths along with a value of file version id.
            If version_id is not available, download the latest version.
        target_dir:
            Local directory to save the files
        progress_bar:
            an instance of the TqdmProgressBar, can update description in the calling progress bar

        Returns
        -------
        None
        """
        with ThreadPoolExecutor(max_workers=THREAD_POOL_MAX_WORKERS) as pool:
            futures = [
                pool.submit(
                    self.download_from_object_storage, path, target_dir, progress_bar
                )
                for path in paths
            ]
            # result() re-raises the first failure in completion order; leaving
            # the ``with`` block still waits for the remaining downloads.
            for future in as_completed(futures):
                future.result()