Source code for ads.model.extractor.keras_extractor

#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import logging
from ads.model.extractor.model_info_extractor import ModelInfoExtractor
from ads.model.model_metadata import Framework
from ads.common.decorator.runtime_dependency import (
    runtime_dependency,
    OptionalDependency,
)


class KerasExtractor(ModelInfoExtractor):
    """Class that extracts model metadata from keras models.

    Attributes
    ----------
    model: object
        The model to extract metadata from.
    """

    def __init__(self, model):
        self.model = model

    @property
    def framework(self):
        """Extracts the framework of the model.

        Returns
        ----------
        str:
           The framework of the model.
        """
        return Framework.KERAS

    @property
    def algorithm(self):
        """Extracts the algorithm of the model.

        Returns
        ----------
        str:
           The class name of the model, used as its algorithm name.
        """
        return self.model.__class__.__name__

    @property
    @runtime_dependency(
        module="tensorflow",
        install_from=OptionalDependency.TENSORFLOW,
    )
    def version(self):
        """Extracts the framework version of the model.

        Returns
        ----------
        str:
           The version of the installed ``tensorflow.keras`` package.
        """
        # Imported lazily so tensorflow is only required when the version is requested;
        # the runtime_dependency decorator handles the case where it is missing.
        from tensorflow import keras

        return keras.__version__

    @property
    def hyperparameter(self):
        """Extracts the hyperparameters of the model.

        Returns
        ----------
        dict:
           The configuration returned by the model's ``get_config()``, or an
           empty dict (with a logged warning) when it cannot be obtained.
        """
        get_config = getattr(self.model, "get_config", None)
        if callable(get_config):
            try:
                return get_config()
            except NotImplementedError:
                # Some Keras versions raise NotImplementedError from
                # get_config() on subclassed models; fall back to the same
                # warn-and-return-empty path as models without get_config.
                pass
        logging.warning(
            "Cannot extract the hyperparameters from this model automatically."
        )
        return {}