Source code for ads.model.runtime.model_provenance_details
#!/usr/bin/env python
# -*- coding: utf-8 -*--
# Copyright (c) 2022 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
from dataclasses import dataclass, field
from typing import Dict
from ads.common.serializer import DataClassSerializable
from ads.model.runtime.env_info import TrainingEnvInfo
from ads.model.runtime.utils import MODEL_PROVENANCE_SCHEMA_PATH, SchemaValidator
[docs]
@dataclass(repr=False)
class TrainingCode(DataClassSerializable):
    """Training-code section of a model's provenance details.

    Attributes
    ----------
    artifact_directory: str
        Artifact directory of the training code. Defaults to an empty string.
    """

    # NOTE(review): presumably serialized to/from the `ARTIFACT_DIRECTORY`
    # key checked in `_validate_dict` — confirm against DataClassSerializable.
    artifact_directory: str = ""

    @classmethod
    def _validate_dict(cls, obj_dict: Dict) -> bool:
        """Validate the ``training_code`` section of an input dictionary.

        Parameters
        ----------
        obj_dict: (Dict)
            Dictionary content to validate.

        Returns
        -------
        bool
            ``True`` when validation passes.

        Raises
        ------
        AssertionError
            If ``obj_dict`` is empty/falsy or has no ``ARTIFACT_DIRECTORY`` key.
        """
        # Raise explicitly instead of using an `assert` statement: asserts are
        # removed when Python runs with -O, which would silently disable this
        # check. AssertionError is kept so existing callers that catch it still
        # work.
        if not obj_dict or "ARTIFACT_DIRECTORY" not in obj_dict:
            raise AssertionError(
                "`training_code` must have `ARTIFACT_DIRECTORY` field."
            )
        return True
[docs]
@dataclass(repr=False)
class ModelProvenanceDetails(DataClassSerializable):
    """Provenance metadata for a model.

    Every field has a default, so the class can be instantiated without
    arguments. Nested sections default to fresh ``TrainingCode`` and
    ``TrainingEnvInfo`` instances via ``default_factory``, so instances
    never share them.
    """

    project_ocid: str = ""
    tenancy_ocid: str = ""
    training_code: TrainingCode = field(default_factory=TrainingCode)
    training_compartment_ocid: str = ""
    training_conda_env: TrainingEnvInfo = field(default_factory=TrainingEnvInfo)
    training_region: str = ""
    training_resource_ocid: str = ""
    user_ocid: str = ""
    vm_image_internal_id: str = ""

    @classmethod
    def _validate_dict(cls, obj_dict: Dict) -> bool:
        """Check a provenance dictionary against the model provenance schema.

        Parameters
        ----------
        obj_dict: (Dict)
            Provenance content (e.g. loaded from a yaml file) to validate.

        Returns
        -------
        bool
            The result reported by ``SchemaValidator.validate`` for
            ``obj_dict``.
        """
        schema_path = MODEL_PROVENANCE_SCHEMA_PATH
        return SchemaValidator(schema_file_path=schema_path).validate(obj_dict)