# Source code for ads.model.extractor.tensorflow_extractor

#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from ads.model.extractor.model_info_extractor import ModelInfoExtractor
from ads.model.model_metadata import Framework
from ads.common.decorator.runtime_dependency import (
    runtime_dependency,
    OptionalDependency,
)


class TensorflowExtractor(ModelInfoExtractor):
    """Extract model metadata from a TensorFlow model.

    Attributes
    ----------
    model: object
        The model to extract metadata from. This is the only attribute
        assigned by this class.

    Methods
    -------
    framework(self)
        Returns the framework of the model (``Framework.TENSORFLOW``).
    algorithm(self) -> str
        Returns the algorithm of the model, i.e. the model's class name.
    version(self) -> str
        Returns the version of the framework of the model.
    hyperparameter(self) -> None
        Returns the hyperparameters of the model; currently always ``None``.
    """

    def __init__(self, model):
        """Store the model whose metadata will be extracted.

        Parameters
        ----------
        model: object
            The model to extract metadata from.
        """
        self.model = model

    @property
    def framework(self):
        """Extracts the framework of the model.

        Returns
        ----------
        The ``Framework.TENSORFLOW`` constant.
        """
        return Framework.TENSORFLOW

    @property
    def algorithm(self) -> str:
        """Extracts the algorithm of the model.

        Returns
        ----------
        str:
            The class name of the wrapped model object.
        """
        return self.model.__class__.__name__

    @property
    @runtime_dependency(
        module="tensorflow", short_name="tf", install_from=OptionalDependency.TENSORFLOW
    )
    def version(self) -> str:
        """Extracts the framework version of the model.

        Returns
        ----------
        str:
            The value of ``tf.__version__``.
        """
        # NOTE(review): `tf` is not imported in this file; presumably the
        # `runtime_dependency` decorator makes the `tensorflow` module
        # available under the short name `tf` -- confirm against the decorator.
        return tf.__version__

    @property
    def hyperparameter(self) -> None:
        """Extracts the hyperparameters of the model.

        Returns
        ----------
        None:
            No hyperparameters are extracted by this implementation; the
            property always returns ``None``.
        """
        return None