Source code for ads.model.extractor.sklearn_extractor
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
import logging
import re
from collections import defaultdict
from ads.model.extractor.model_info_extractor import (
ModelInfoExtractor,
normalize_hyperparameter,
)
from ads.model.model_metadata import Framework
[docs]
class SklearnExtractor(ModelInfoExtractor):
    """Class that extracts model metadata from sklearn models.

    Attributes
    ----------
    model: object
        The model to extract metadata from.

    Methods
    -------
    framework(self) -> str
        Returns the framework of the model.
    algorithm(self) -> str
        Returns the algorithm (class name) of the model.
    version(self) -> str
        Returns the version of framework of the model.
    hyperparameter(self) -> dict
        Returns the hyperparameter of the model.
    """

    def __init__(self, model):
        self.model = model

    @property
    def framework(self):
        """Extracts the framework of the model.

        Returns
        -------
        str:
            The framework of the model (``Framework.SCIKIT_LEARN``).
        """
        return Framework.SCIKIT_LEARN

    @property
    def algorithm(self) -> str:
        """Extracts the algorithm of the model.

        Returns
        -------
        str:
            The class name of the wrapped model object.
        """
        return self.model.__class__.__name__

    @property
    def version(self) -> str:
        """Extracts the framework version of the model.

        Returns
        -------
        str:
            The version of the installed ``sklearn`` package.
        """
        # Imported lazily so that this module can be imported without sklearn.
        import sklearn

        return sklearn.__version__

    @staticmethod
    def _grid_to_builtin(grid):
        """Converts array-like values of one parameter grid to plain lists.

        Values exposing ``tolist`` (e.g. numpy arrays) are converted with it;
        every other value is passed through unchanged.

        Parameters
        ----------
        grid: dict
            A mapping of parameter name to candidate values.

        Returns
        -------
        defaultdict:
            A mapping with the same keys as ``grid``.
        """
        # defaultdict(list) is kept so the container type handed to
        # normalize_hyperparameter is the same as it has always been.
        converted = defaultdict(list)
        for name, candidates in grid.items():
            converted[name] = (
                candidates.tolist() if hasattr(candidates, "tolist") else candidates
            )
        return converted

    @property
    def hyperparameter(self):
        """Extracts the hyperparameters of the model.

        Pipelines (a ``steps`` parameter) get every step replaced by its
        string representation with parentheses removed. Grid-search style
        models (a ``param_grid`` parameter) get the estimator stringified,
        the grid converted to plain Python containers, and the best
        parameters merged in when the search has been fitted.

        Returns
        -------
        dict:
            The hyperparameters of the model, passed through
            ``normalize_hyperparameter``. An empty dict when the model does
            not expose ``get_params``.
        """
        if not hasattr(self.model, "get_params"):
            # for onnx model case.
            logging.warning(
                "Cannot extract the hyperparameters from this model automatically."
            )
            return {}

        hp_dict = self.model.get_params()
        # make shallow copy to avoid modifying the model object
        new_dict = hp_dict.copy()

        # handle sklearn pipeline case
        if "steps" in hp_dict:
            new_dict["steps"] = defaultdict(list)
            for i, (k, v) in enumerate(hp_dict["steps"]):
                # Strip parentheses from the estimator repr; computed once and
                # stored both by position and under the step name.
                step_repr = re.sub("[()]", "", str(v))
                new_dict["steps"][i] = {k: step_repr}
                new_dict[k] = step_repr
        # handle sklearn model selection case
        elif "param_grid" in hp_dict:
            new_dict["estimator"] = str(hp_dict["estimator"])
            param_grid = hp_dict["param_grid"]
            if hasattr(param_grid, "items"):
                new_dict["param_grid"] = self._grid_to_builtin(param_grid)
            else:
                # scikit-learn also accepts a sequence of grids; convert each.
                new_dict["param_grid"] = [
                    self._grid_to_builtin(grid) for grid in param_grid
                ]
            # best_params_ only exists once the search has been fitted (and
            # refit allows it); skip the merge instead of raising.
            new_dict.update(getattr(self.model, "best_params_", {}))

        return normalize_hyperparameter(new_dict)