#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
import os
import tempfile
from typing import Any, Dict, Optional
from zipfile import ZipFile

import yaml

from ads.common import utils

DEPRECATE_AS_ONNX_WARNING = "This attribute `as_onnx` will be deprecated in the future. You can choose specific format by setting `model_save_serializer`."
DEPRECATE_USE_TORCH_SCRIPT_WARNING = "This attribute `use_torch_script` will be deprecated in the future. You can choose specific format by setting `model_save_serializer`."


def _extract_locals(
    locals: Dict[str, Any], filter_out_nulls: Optional[bool] = True
) -> Dict[str, Any]:
    """Extract arguments from local variables.

    Entries of a nested `kwargs` dictionary are merged into the result,
    while the `kwargs` key itself is not included.

    Parameters
    ----------
    locals: Dict[str, Any]
        A dictionary, the result of `locals()` method. However can be any dictionary.
    filter_out_nulls: (bool, optional). Defaults to `True`.
        Whether `None` values should be filtered out from the result dict or not.

    Returns
    -------
    Dict[str, Any]
        A new dictionary with the result values.
    """
    keys_to_filter_out = ("kwargs",)
    # Top-level entries win over same-named entries inside `kwargs`.
    # `or {}` tolerates `kwargs` being present but set to None.
    consolidated_dict = {**(locals.get("kwargs") or {}), **locals}
    return {
        key: value
        for key, value in consolidated_dict.items()
        if key not in keys_to_filter_out
        and not (filter_out_nulls and value is None)
    }


def _is_json_serializable(data: Any) -> bool:
    """Check whether the data input is JSON serializable.

    Parameters
    ----------
    data: (Any)
        data to be passed to model for prediction.

    Returns
    -------
    bool
        Whether data is json serializable.
    """
    try:
        json.dumps(data)
    except Exception:
        # json.dumps may raise TypeError, ValueError (e.g. circular reference),
        # OverflowError or RecursionError; any of these means "not serializable".
        # KeyboardInterrupt / SystemExit are deliberately allowed to propagate.
        return False
    return True
def fetch_manifest_from_conda_location(env_location: str):
    """
    Convenience method to fetch the manifest file from a conda environment.

    Looks for a file ending in ``_manifest.yaml`` directly inside
    ``env_location`` and returns the value stored under its ``manifest`` key.

    :param env_location: Absolute path to the environment.
    :type env_location: str
    :return: The ``manifest`` entry of the parsed manifest YAML file.
    :raises Exception: If no ``*_manifest.yaml`` file exists in ``env_location``.
    """
    dir_listing = os.listdir(env_location)
    # Sort so the chosen manifest does not depend on filesystem listing order
    # when more than one ``*_manifest.yaml`` is present.
    manifest_name = next(
        (name for name in sorted(dir_listing) if name.endswith("_manifest.yaml")),
        None,
    )
    if manifest_name is None:
        raise Exception(
            f"Could not locate manifest file in the provided conda environment: {env_location}. Dir Listing - "
            f"{dir_listing}"
        )
    manifest_location = os.path.join(env_location, manifest_name)
    # NOTE(review): yaml.safe_load is the recommended loader for untrusted
    # input; FullLoader is kept here to preserve existing parsing behavior.
    with open(manifest_location, encoding="utf-8") as mlf:
        env = yaml.load(mlf, Loader=yaml.FullLoader)
    return env["manifest"]
def zip_artifact(artifact_dir: str) -> str:
    """Prepares model artifacts ZIP archive.

    Parameters
    ----------
    artifact_dir: str
        Path to the model artifact folder.

    Returns
    -------
    str
        Path to the model artifact ZIP archive file. The archive is a temporary
        file that is not deleted automatically; the caller owns it.

    Raises
    ------
    ValueError
        If `artifact_dir` is empty, does not exist, or is not a folder.
    """
    if not artifact_dir:
        raise ValueError("The `artifact_dir` must be provided.")
    if not os.path.exists(artifact_dir):
        raise ValueError(f"The {artifact_dir} not exists.")
    if not os.path.isdir(artifact_dir):
        raise ValueError("The `artifact_dir` must be a folder.")

    files_to_upload = utils.get_files(artifact_dir)

    # delete=False keeps the file on disk after the handle is closed, so it can
    # be reopened by path for writing below and returned to the caller.
    artifact = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
    # Close the handle immediately; ZipFile reopens the file by its path.
    artifact.close()
    artifact_zip_path = artifact.name

    try:
        with ZipFile(artifact_zip_path, "w") as zf:
            for matched_file in files_to_upload:
                zf.write(
                    os.path.join(artifact_dir, matched_file),
                    arcname=matched_file,
                )
    except BaseException:
        # Don't leak the temporary archive if zipping fails part-way;
        # cleanup is best-effort and the original error is re-raised.
        try:
            os.remove(artifact_zip_path)
        except OSError:
            pass
        raise
    return artifact_zip_path