#!/usr/bin/env python
# -*- coding: utf-8 -*--
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
import json
import os
import re
from dataclasses import dataclass
from typing import Dict, List
from urllib.parse import urlparse
import oci
from ads.common import auth as authutil
from ads.common import oci_client
from ads.dataset.progress import TqdmProgressBar
from concurrent.futures import ThreadPoolExecutor, as_completed
# Upper bound on the worker threads used by the parallel bulk-download helper.
THREAD_POOL_MAX_WORKERS: int = 10
class InvalidObjectStoragePath(Exception):  # pragma: no cover
    """Invalid Object Storage Path.

    Raised when a string does not follow the Object Storage URI pattern
    `oci://<bucket_name>@<namespace>/<prefix>`.
    """
@dataclass
class ObjectStorageDetails:
    """Class that represents the Object Storage bucket URI details.

    Attributes
    ----------
    bucket: str
        The Object Storage bucket name.
    namespace: (str, optional). Defaults to empty string.
        The Object Storage namespace. Will be extracted automatically if not provided.
    filepath: (str, optional). Defaults to empty string.
        The path to the object.
    version: (str, optional). Defaults to None.
        The version of the object.
    auth: (Dict, optional). Defaults to None.
        ADS auth dictionary for OCI authentication.
        This can be generated by calling ads.common.auth.api_keys() or ads.common.auth.resource_principal()
        If this is None, ads.common.auth.default_signer() will be used.
    """

    bucket: str
    namespace: str = ""
    filepath: str = ""
    version: str = None
    auth: Dict = None

    # Default ``fields`` projection requested by the list calls. Deliberately
    # left unannotated so ``@dataclass`` treats it as a plain class attribute
    # rather than as an instance field.
    _DEFAULT_LIST_FIELDS = (
        "name,etag,size,timeCreated,md5,timeModified,storageTier,archivalState"
    )

    def __post_init__(self):
        if not self.auth:
            self.auth = authutil.default_signer()
        # Extract OS namespace if not provided.
        if not self.namespace:
            self.namespace = self.os_client.get_namespace().data

    @property
    def os_client(self):
        """Object Storage client, created on first access and reused afterwards."""
        # Look the attribute up directly (EAFP) instead of using
        # hasattr(self, "__client"): the double-underscore name is mangled to
        # `_ObjectStorageDetails__client`, so a literal-string hasattr check
        # never matches and would rebuild the client on every access.
        try:
            return self.__client
        except AttributeError:
            pass
        self.__client = self.create_os_client()
        return self.__client

    def create_os_client(self):
        """Create a new Object Storage client from the ``auth`` dictionary."""
        return oci_client.OCIClientFactory(**self.auth).object_storage

    def __repr__(self):
        return self.path

    @property
    def path(self):
        """Full object storage path of this file."""
        # Built with an explicit "/" rather than os.path.join so the URI never
        # contains the platform path separator (e.g. "\\" on Windows).
        filepath = self.filepath.lstrip("/") if self.filepath else ""
        return f"oci://{self.bucket}@{self.namespace}/{filepath}"

    @classmethod
    def from_path(cls, env_path: str) -> "ObjectStorageDetails":
        """Construct an ObjectStorageDetails instance from an Object Storage path.

        Parameters
        ----------
        env_path: (str)
            Object Storage path (for example a conda pack path) of the form
            `oci://<bucket-name>@<namespace>/object_path`.

        Raises
        ------
        InvalidObjectStoragePath
            (a subclass of ``Exception``) if the path cannot be parsed or does
            not contain a bucket name.

        Returns
        -------
        ObjectStorageDetails
            An ObjectStorageDetails instance.
        """
        error_message = (
            "OCI path is not properly configured. "
            "It should follow the pattern `oci://<bucket-name>@<namespace>/object_path`."
        )
        try:
            url_parse = urlparse(env_path)
            bucket_name = url_parse.username
            # NOTE: urlparse lower-cases `hostname`, so the namespace comes back lower case.
            namespace = url_parse.hostname
            object_name = url_parse.path.lstrip("/")
        except (AttributeError, TypeError, ValueError) as err:
            raise InvalidObjectStoragePath(error_message) from err
        if not bucket_name:
            # e.g. `oci://namespace/path` -- no `<bucket>@` part.
            raise InvalidObjectStoragePath(error_message)
        # Constructed outside the try block so auth / namespace-lookup errors
        # raised by __post_init__ surface as-is instead of as a path error.
        return cls(bucket=bucket_name, namespace=namespace, filepath=object_name)

    def to_tuple(self):
        """Returns the values of the fields of ObjectStorageDetails class."""
        return self.bucket, self.namespace, self.filepath

    @staticmethod
    def is_valid_uri(uri: str) -> bool:
        """Validates the Object Storage URI.

        Returns True when `uri` starts with `oci://<bucket_name>@<namespace>`,
        otherwise raises InvalidObjectStoragePath.
        """
        # Require a non-empty bucket and namespace. In a regex `*` is a
        # quantifier, not a glob wildcard, so `oci://*@*` would accept any
        # string starting with `oci:/`.
        if not re.match(r"oci://[^/@]+@[^/@]+", uri):
            raise InvalidObjectStoragePath(
                f"The `{uri}` is not a valid Object Storage path. "
                "It must follow the pattern `oci://<bucket_name>@<namespace>/<prefix>`."
            )
        return True

    @staticmethod
    def is_oci_path(uri: str = None) -> bool:
        """Check if the given path is oci object storage uri.

        Parameters
        ----------
        uri: str
            The URI of the target.

        Returns
        -------
        bool: return True if the path is oci object storage uri.
        """
        if not uri:
            return False
        return uri.lower().startswith("oci://")

    def is_bucket_versioned(self) -> bool:
        """Check if the given bucket is versioned.

        Returns
        -------
        bool: return True if the bucket is versioned.
        """
        res = self.os_client.get_bucket(
            namespace_name=self.namespace, bucket_name=self.bucket
        ).data
        return res.versioning == "Enabled"

    def _list_all(self, list_method, **kwargs):
        """Call a paginated list method for this bucket/prefix and return every page's data."""
        fields = kwargs.pop("fields", self._DEFAULT_LIST_FIELDS)
        return oci.pagination.list_call_get_all_results(
            list_method,
            namespace_name=self.namespace,
            bucket_name=self.bucket,
            prefix=self.filepath,
            fields=fields,
            **kwargs,
        ).data

    def list_objects(self, **kwargs):
        """Lists objects in a given oss path

        Parameters
        ----------
        **kwargs:
            namespace, bucket, filepath are set by the class. By default, fields gets all values. For other supported
            parameters, check https://docs.oracle.com/iaas/api/#/en/objectstorage/20160918/Object/ListObjects

        Returns
        -------
        Object of type oci.object_storage.models.ListObjects
        """
        return self._list_all(self.os_client.list_objects, **kwargs)

    def list_object_versions(
        self,
        **kwargs,
    ):
        """Lists object versions in a given oss path

        Parameters
        ----------
        **kwargs:
            namespace, bucket, filepath are set by the class. By default, fields gets all values. For other supported
            parameters, check https://docs.oracle.com/iaas/api/#/en/objectstorage/20160918/Object/ListObjectVersions

        Returns
        -------
        Object of type oci.object_storage.models.ObjectVersionCollection
        """
        return self._list_all(self.os_client.list_object_versions, **kwargs)

    @staticmethod
    def _resolve_download_path(target_dir: str, bucket: str, filepath: str) -> str:
        """Return `target_dir/bucket/filepath`, refusing paths that escape `target_dir`."""
        local_filepath = os.path.join(target_dir, bucket, filepath)
        root = os.path.abspath(target_dir)
        resolved = os.path.abspath(local_filepath)
        # Object names may contain `..` segments or start with "/", which would
        # otherwise let a download write outside of the target directory.
        if os.path.commonpath([root, resolved]) != root:
            raise ValueError(
                f"Object `{filepath}` resolves outside of the target directory `{target_dir}`."
            )
        return local_filepath

    def download_from_object_storage(
        self,
        path: "ObjectStorageDetails",
        target_dir: str,
        progress_bar: "TqdmProgressBar" = None,
    ):
        """Downloads the files with object versions set in the paths dict.

        Parameters
        ----------
        path:
            OSS path along with a value of file version id.
            If version_id is not available, download the latest version.
        target_dir:
            Local directory to save the files
        progress_bar:
            an instance of the TqdmProgressBar, can update description in the calling progress bar

        Raises
        ------
        ValueError
            If the object name would be written outside of `target_dir`.

        Returns
        -------
        None
        """
        if progress_bar:
            progress_bar.update(
                description=f"Copying model artifacts by reference from {path.path} to {target_dir}",
                n=0,
            )
        # Validate the destination before issuing the network request.
        local_filepath = self._resolve_download_path(
            target_dir, path.bucket, path.filepath
        )
        res = self.os_client.get_object(
            namespace_name=path.namespace,
            bucket_name=path.bucket,
            object_name=path.filepath,
            version_id=path.version,
        )
        os.makedirs(os.path.dirname(local_filepath), exist_ok=True)
        with open(local_filepath, "wb") as _file:
            for chunk in res.data.iter_content(chunk_size=4096):
                _file.write(chunk)

    def bulk_download_from_object_storage(
        self,
        paths: List["ObjectStorageDetails"],
        target_dir: str,
        progress_bar: "TqdmProgressBar" = None,
    ):
        """Downloads the files with object versions set in the paths dict in parallel.

        Parameters
        ----------
        paths:
            Contains a list of OSS paths along with a value of file version id.
            If version_id is not available, download the latest version.
        target_dir:
            Local directory to save the files
        progress_bar:
            an instance of the TqdmProgressBar, can update description in the calling progress bar

        Returns
        -------
        None
        """
        # NOTE(review): all workers share the cached ``self.os_client`` and the
        # same progress bar; confirm both are safe to use from several threads.
        with ThreadPoolExecutor(max_workers=THREAD_POOL_MAX_WORKERS) as pool:
            futures = [
                pool.submit(
                    self.download_from_object_storage, path, target_dir, progress_bar
                )
                for path in paths
            ]
            # result() re-raises the first download failure seen (completion order).
            for future in as_completed(futures):
                future.result()