Source code for ads.model.runtime.utils

#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
import os
from functools import lru_cache
from typing import Dict, Tuple
import logging

import fsspec
import requests
import yaml
from cerberus import DocumentError, Validator

from ads.common.object_storage_details import ObjectStorageDetails
from ads.common.utils import PAR_LINK

MODEL_PROVENANCE_SCHEMA_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "schemas",
    "model_provenance_schema.yaml",
)
INFERENCE_ENV_SCHEMA_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "schemas",
    "inference_env_info_schema.yaml",
)
TRAINING_ENV_SCHEMA_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "schemas",
    "training_env_info_schema.yaml",
)
SERVICE_PACKS = "service_packs"


[docs] class SchemaValidator: """ Base Schema Validator which validate yaml file. """ def __init__(self, schema_file_path: str) -> "SchemaValidator": """Initiate a SchemaValidator instance. Parameters ---------- schema_file_path: (str) schema file path. The schema is used to validate the yaml file. Returns ------- SchemaValidator A SchemaValidator instance. """ self.schema_file_path = schema_file_path self._schema_validator = self._load_schema_validator()
[docs] def validate(self, document: Dict) -> bool: """Validate the schema. Parameters ---------- document: (Dict) yaml file content to validate. Raises ------ DocumentError: Raised when the validation schema is missing, has the wrong format or contains errors. Returns ------- bool validation result. """ v = Validator() res = v.validate(document, self._schema_validator) if not res: raise DocumentError(v.errors) return res
def _load_schema_validator(self): """load the schema validator to validate the schema. Returns ------- dict Schema validator. """ with open(self.schema_file_path, encoding="utf-8") as schema_file: ext = os.path.splitext(self.schema_file_path)[-1].lower() if ext in [".yaml", ".yml"]: schema_validator = yaml.load(schema_file, yaml.FullLoader) elif ext in [".json"]: schema_validator = json.load(schema_file) else: raise NotImplementedError(f"{ext} format schema is not supported.") return schema_validator
def _get_index_json_through_bucket( namespace: str, bucketname: str, auth: dict = None ) -> list: """get the index json from the object storage. Parameters ---------- namespace: str The Object Storage namespace. bucketname: str The Object Storage bucketname. auth: (Dict, optional). Defaults to None. The default authentication is set using `ads.set_auth` API. If you need to override the default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate authentication signer and kwargs required to instantiate IdentityClient object. Returns ------- list: A list of dictionary which contains service packs information. """ auth = auth or {} service_pack_list = [] try: uri = f"oci://{bucketname}@{namespace}/service_pack/index.json" with fsspec.open(uri, "r", **auth) as f: service_packs = json.loads(f.read()) service_pack_list = service_packs.get(SERVICE_PACKS) except Exception as e: logging.warn(e) logging.warn( "Failed to retrieve the full conda pack path from slug. Pass conda pack path 'oci://<bucketname>@<namespace>/<path_to_conda>' instead of slug." ) return service_pack_list
[docs] def get_service_packs( namespace: str, bucketname: str, auth: dict = None ) -> Tuple[Dict, Dict]: """Get the service pack path mapping and service pack slug mapping. Note: deprecated packs are also included. Parameters ---------- namespace: str namespace of the service pack. bucketname: str bucketname of the service pack. auth: (Dict, optional). Defaults to None. The default authentication is set using `ads.set_auth` API. If you need to override the default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate authentication signer and kwargs required to instantiate IdentityClient object. Returns ------- (Dict, Dict) Service pack path mapping(service pack path -> (slug, python version)) and the service pack slug mapping(service pack slug -> (pack path, python version)). """ service_pack_path_mapping = {} service_pack_slug_mapping = {} try: response = requests.request("GET", PAR_LINK) # if there is internet if response.ok: service_pack_list = response.json().get(SERVICE_PACKS) # response not good. else: service_pack_list = _get_index_json_through_bucket( namespace=namespace, bucketname=bucketname, auth=auth ) except Exception as e: # not internet service_pack_list = _get_index_json_through_bucket( namespace=namespace, bucketname=bucketname, auth=auth ) for service_pack in service_pack_list: # Here we need to replace the namespace and bucketname # with the bucket and namespace of the region that # user is in. The reason is that the mapping is generated # from the index.json file which has static namespace # and bucket of prod. however, namespace will change based # on the region. also, dev has different bucketname. pack_path = ObjectStorageDetails( bucket=bucketname, namespace=namespace, filepath=ObjectStorageDetails.from_path( service_pack.get("pack_path") ).filepath, ).path service_pack_path_mapping[pack_path] = ( service_pack.get("slug"), service_pack.get("python"), ) service_pack_slug_mapping[service_pack.get("slug")] = ( pack_path, service_pack.get("python"), ) return service_pack_path_mapping, service_pack_slug_mapping