Source code for ads.model.extractor.model_info_extractor_factory
#!/usr/bin/env python
# -*- coding: utf-8 -*--
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
from ads.model.model_metadata import Framework
from ads.common import logger, utils
from ads.model.extractor.sklearn_extractor import SklearnExtractor
from ads.model.extractor.xgboost_extractor import XgboostExtractor
from ads.model.extractor.lightgbm_extractor import LightgbmExtractor
from ads.model.extractor.keras_extractor import KerasExtractor
from ads.model.extractor.automl_extractor import AutoMLExtractor
from ads.model.extractor.spark_extractor import SparkExtractor
from ads.model.extractor.pytorch_extractor import PytorchExtractor
from ads.model.extractor.tensorflow_extractor import TensorflowExtractor
from ads.model.extractor.huggingface_extractor import HuggingFaceExtractor
# Framework keywords looked up in the dotted module path of each of the
# model's base classes. The order is significant: `_get_estimator` walks this
# list front to back and returns on the first keyword found, so earlier
# entries take priority when a module path contains more than one keyword.
ORDERED_FRAMEWORKS = [
    "lightgbm",
    "xgboost",
    "sklearn",
    "keras",
    "tensorflow",
    "bert",
    "transformers",
    "torch",
    "spark",
    "automl",
]
[docs]
class ModelInfoExtractorFactory:
    """Class that extracts Model Taxonomy Metadata for all supported frameworks."""

    # Framework -> extractor class used to build the taxonomy metadata.
    # A framework without an entry here is treated as unsupported.
    _estimator_map = {
        Framework.SCIKIT_LEARN: SklearnExtractor,
        Framework.XGBOOST: XgboostExtractor,
        Framework.LIGHT_GBM: LightgbmExtractor,
        Framework.KERAS: KerasExtractor,
        Framework.ORACLE_AUTOML: AutoMLExtractor,
        Framework.TENSORFLOW: TensorflowExtractor,
        Framework.PYTORCH: PytorchExtractor,
        Framework.SPARK: SparkExtractor,
        Framework.TRANSFORMERS: HuggingFaceExtractor,
    }

    # Module-path keyword -> Framework. Keys mirror ORDERED_FRAMEWORKS.
    # Note that "bert" resolves to a Framework that has no extractor in
    # `_estimator_map`, so such models are reported as unsupported.
    _module_framework_map = {
        "lightgbm": Framework.LIGHT_GBM,
        "xgboost": Framework.XGBOOST,
        "sklearn": Framework.SCIKIT_LEARN,
        "keras": Framework.KERAS,
        "tensorflow": Framework.TENSORFLOW,
        "bert": Framework.BERT,
        "transformers": Framework.TRANSFORMERS,
        "torch": Framework.PYTORCH,
        "spark": Framework.SPARK,
        "automl": Framework.ORACLE_AUTOML,
    }

    @staticmethod
    def extract_info(model):
        """Extracts model taxonomy metadata.

        Parameters
        ----------
        model: [ADS model, sklearn, xgboost, lightgbm, keras, tensorflow,
                transformers, torch, spark, oracle_automl]
            The model object. An ``ADSModel`` is unwrapped to its ``est``
            attribute before the framework is detected.

        Returns
        -------
        `ModelTaxonomyMetadata`
            The result of the selected extractor's ``info()`` call: a dictionary
            with keys of Framework, FrameworkVersion, Algorithm, Hyperparameters
            of the model. ``None`` (after logging a warning) when no extractor
            is registered for the detected framework.

        Examples
        --------
        >>> from ads.model.extractor.model_info_extractor_factory import ModelInfoExtractorFactory
        >>> metadata_taxonomy = ModelInfoExtractorFactory.extract_info(model)
        """
        # Function-scope import kept from the original code
        # (presumably avoids a circular import -- TODO confirm).
        from ads.common.model import ADSModel

        if isinstance(model, ADSModel):
            model = model.est

        factory = ModelInfoExtractorFactory
        model_framework = factory._get_estimator(
            model_bases=utils.get_base_modules(model)
        )
        extractor_cls = factory._estimator_map.get(model_framework)
        if extractor_cls is None:
            # Advertise only the keywords that actually resolve to an
            # extractor; listing all of ORDERED_FRAMEWORKS would claim
            # support for frameworks that always end up in this branch.
            supported = [
                name
                for name in ORDERED_FRAMEWORKS
                if factory._module_framework_map.get(name) in factory._estimator_map
            ]
            # `Logger.warn` is a deprecated alias that was removed in
            # Python 3.13; `warning` is the supported spelling.
            logger.warning(
                "Auto-extraction of taxonomy is not supported for the provided model. "
                f"The supported models are {', '.join(supported)}."
            )
            return None
        return extractor_cls(model).info()

    @staticmethod
    def _get_estimator(model_bases):
        """Detects the framework of a model from its base classes.

        Parameters
        ----------
        model_bases: iterable
            Objects exposing a ``__module__`` attribute (the value returned by
            ``utils.get_base_modules``). ``None`` or an empty iterable yields
            ``None``.

        Returns
        -------
        Framework or None
            The framework mapped to the first ORDERED_FRAMEWORKS keyword found
            among the dot-separated components of a base's module path.
            Bases are examined in the given order; ``None`` if nothing matches.
        """
        for model_base in model_bases or ():
            # Split once per base instead of once per candidate keyword, and
            # tolerate objects whose `__module__` is missing or None.
            module_parts = (getattr(model_base, "__module__", None) or "").split(".")
            for framework in ORDERED_FRAMEWORKS:
                if framework in module_parts:
                    return ModelInfoExtractorFactory._module_framework_map.get(
                        framework
                    )
        return None