Source code for ads.model.common.utils

#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
import os
import tempfile
import yaml
from typing import Any, Dict, Optional
from zipfile import ZipFile
from ads.common import utils


# Deprecation notices for the `as_onnx` and `use_torch_script` attributes,
# both of which point users to `model_save_serializer` instead.
DEPRECATE_AS_ONNX_WARNING = (
    "This attribute `as_onnx` will be deprecated in the future. "
    "You can choose specific format by setting `model_save_serializer`."
)
DEPRECATE_USE_TORCH_SCRIPT_WARNING = (
    "This attribute `use_torch_script` will be deprecated in the future. "
    "You can choose specific format by setting `model_save_serializer`."
)


def _extract_locals(
    locals: Dict[str, Any], filter_out_nulls: Optional[bool] = True
) -> Dict[str, Any]:
    """Extract arguments from local variables.
    If input dictionary contains `kwargs`, then it will not be included to
    the result dictionary.

    Properties
    ----------
    locals: Dict[str, Any]
        A dictionary, the result of `locals()` method.
        However can be any dictionary.
    filter_out_nulls: (bool, optional). Defaults to `True`.
        Whether `None` values should be filtered out from the result dict or not.

    Returns
    -------
    Dict[str, Any]
        A new dictionary with the result values.
    """
    result = {}
    keys_to_filter_out = ("kwargs",)
    consolidated_dict = {**locals.get("kwargs", {}), **locals}
    for key, value in consolidated_dict.items():
        if key not in keys_to_filter_out and not (filter_out_nulls and value is None):
            result[key] = value
    return result


def _is_json_serializable(data: Any) -> bool:
    """Check is data input is json serialization.

    Parameters
    ----------
    data: (Any)
        data to be passed to model for prediction.

    Returns
    -------
    bool
        Whether data is json serializable.
    """
    result = True
    try:
        json.dumps(data)
    except:
        result = False
    return result


def fetch_manifest_from_conda_location(env_location: str):
    """
    Convenience method to fetch the manifest file from a conda environment.

    Searches ``env_location`` (non-recursively) for an entry whose name ends
    with ``_manifest.yaml``, parses it as YAML and returns the value stored
    under its top-level ``manifest`` key. If several entries match, the first
    one in sorted name order is used.

    :param env_location: Absolute path to the environment.
    :type env_location: str
    :return: The ``manifest`` section of the environment's manifest file.
    :raises Exception: If no ``*_manifest.yaml`` entry exists in ``env_location``.
    :raises KeyError: If the parsed manifest mapping has no ``manifest`` key.
    """
    # List the directory once and reuse it in the error message, so the
    # reported listing is exactly the one that was searched. Sorting makes
    # the choice deterministic when more than one manifest is present.
    dir_listing = sorted(os.listdir(env_location))
    manifest_file = next(
        (name for name in dir_listing if name.endswith("_manifest.yaml")), None
    )
    if manifest_file is None:
        raise Exception(
            f"Could not locate manifest file in the provided conda environment: {env_location}. Dir Listing - "
            f"{dir_listing}"
        )
    manifest_location = os.path.join(env_location, manifest_file)
    # NOTE(review): FullLoader can construct some Python objects; if manifest
    # files may come from untrusted sources, consider yaml.safe_load instead.
    with open(manifest_location, encoding="utf-8") as mlf:
        env = yaml.load(mlf, Loader=yaml.FullLoader)
    return env["manifest"]
def zip_artifact(artifact_dir: str) -> str:
    """Prepares model artifacts ZIP archive.

    Parameters
    ----------
    artifact_dir: str
        Path to the model artifact.

    Returns
    -------
    str
        Path to the model artifact ZIP archive file. The archive is created
        with ``delete=False`` and is therefore not removed automatically.

    Raises
    ------
    ValueError
        If `artifact_dir` is empty, does not exist, or is not a directory.
    """
    if not artifact_dir:
        raise ValueError("The `artifact_dir` must be provided.")
    if not os.path.exists(artifact_dir):
        raise ValueError(f"The {artifact_dir} not exists.")
    if not os.path.isdir(artifact_dir):
        raise ValueError("The `artifact_dir` must be a folder.")

    files_to_upload = utils.get_files(artifact_dir)

    # delete=False keeps the archive on disk after its handle is closed;
    # otherwise it would be deleted when download_artifact() closes the file.
    artifact = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
    # NamedTemporaryFile() opens the file; close it so ZipFile can reopen it by name.
    artifact.close()
    artifact_zip_path = artifact.name

    try:
        with ZipFile(artifact_zip_path, "w") as zf:
            for matched_file in files_to_upload:
                zf.write(
                    os.path.join(artifact_dir, matched_file),
                    arcname=matched_file,
                )
    except BaseException:
        # Because of delete=False nothing else will clean up a partially
        # written archive; remove it before propagating the original error.
        try:
            os.remove(artifact_zip_path)
        except OSError:
            pass
        raise
    return artifact_zip_path