Source code for ads.hpo.validation

#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from ads.common import logger
from ads.hpo.ads_search_space import get_model2searchspace
from ads.hpo.distributions import *
from ads.hpo.utils import _is_arraylike, _make_indexable, _num_samples, _safe_indexing
from sklearn.exceptions import NotFittedError
from sklearn.pipeline import Pipeline


# Distribution classes that ``validate_search_space`` accepts as search-space
# values without converting them (other accepted values -- numbers, strings
# and lists -- are wrapped in a ``CategoricalDistribution`` there). The names
# come from the ``from ads.hpo.distributions import *`` import above.
ADS_DISTRIBUTIONS = (
    UniformDistribution,
    LogUniformDistribution,
    DiscreteUniformDistribution,
    IntUniformDistribution,
    IntLogUniformDistribution,
    CategoricalDistribution,
)


def assert_model_is_supported(estimator):
    """Raise if no search space is registered for ``estimator``'s class.

    Parameters
    ----------
    estimator : object
        The model instance to check.

    Raises
    ------
    NotImplementedError
        If ``estimator.__class__`` is not a key of ``get_model2searchspace()``.
    """
    model_class = estimator.__class__
    if model_class in get_model2searchspace().keys():
        return
    raise NotImplementedError("{} is not supported.".format(model_class.__name__))
def assert_tuner_is_fitted(estimator, msg=None):
    """Check that ``estimator`` is a tuner that has already been fitted.

    Parameters
    ----------
    estimator : object
        The object to check. It must expose a ``tune`` attribute.
    msg : str, optional
        Message for the ``NotFittedError``. It may contain a ``%(name)s``
        placeholder, which is filled with the estimator's class name.
        Defaults to a generic "not fitted yet" message.

    Raises
    ------
    TypeError
        If ``estimator`` has no ``tune`` attribute.
    NotFittedError
        If any of ``_n_splits``, ``_sample_indices``, ``_scorer`` or
        ``_study`` is missing from ``estimator``.
    """
    if msg is None:
        msg = (
            "This %(name)s instance is not fitted yet. Call 'fit' with "
            "appropriate arguments before using this method."
        )
    if not hasattr(estimator, "tune"):
        # Wrap in a 1-tuple: formatting with a bare ``% estimator`` misformats
        # (or fails with an unrelated TypeError) when ``estimator`` is a tuple.
        raise TypeError("%s is not an estimator instance." % (estimator,))
    # Attributes the tuner is expected to carry once it has been fitted.
    attributes = ["_n_splits", "_sample_indices", "_scorer", "_study"]
    if not all(hasattr(estimator, attr) for attr in attributes):
        raise NotFittedError(msg % {"name": type(estimator).__name__})
def assert_is_estimator(estimator):
    """Check that ``estimator`` exposes a ``fit`` attribute.

    Parameters
    ----------
    estimator : object
        The object to check.

    Raises
    ------
    AssertionError
        If ``estimator`` has no ``fit`` attribute.
    """
    # Raise explicitly instead of using an ``assert`` statement so the check
    # is not stripped when Python runs with ``-O``.
    if not hasattr(estimator, "fit"):
        raise AssertionError("Estimator must implement fit")
def validate_pipeline(model):
    """Normalize ``model`` into something usable as an estimator.

    A ``list`` is treated as a sequence of pipeline steps and wrapped in a
    scikit-learn ``Pipeline``. Any other value is returned unchanged.

    Parameters
    ----------
    model : object or list of tuple
        An estimator, or a list of pipeline step tuples.

    Returns
    -------
    object
        ``Pipeline(model)`` when ``model`` is a list, otherwise ``model``.

    Raises
    ------
    AssertionError
        If ``model`` is a list containing a non-tuple element.
    """
    if not isinstance(model, list):
        return model
    # Raise explicitly instead of using an ``assert`` statement so invalid
    # step lists are still rejected when Python runs with ``-O``.
    if not all(isinstance(step, tuple) for step in model):
        raise AssertionError("Model is not a pipeline.")
    return Pipeline(model)
def validate_search_space(params, param_distributions):
    """Validate and normalize a hyperparameter search space in place.

    Parameters
    ----------
    params : container
        Names of the model's hyperparameters; tested with ``in``.
    param_distributions : dict
        Mapping from hyperparameter name to its search space. It is modified
        in place:

        * ``int``, ``float`` and ``str`` values are wrapped in a single-choice
          ``CategoricalDistribution``;
        * ``list`` values become a ``CategoricalDistribution`` over their
          elements;
        * values that are already one of ``ADS_DISTRIBUTIONS`` are kept;
        * names not in ``params`` are removed, with a warning. Their values
          are still validated before removal.

    Raises
    ------
    AssertionError
        If ``param_distributions`` is not a dict.
    ValueError
        If a value is not a number, string, list or ADS distribution.
    """
    # Raise explicitly instead of using an ``assert`` statement so the check
    # is not stripped when Python runs with ``-O``.
    if not isinstance(param_distributions, dict):
        raise AssertionError("Must pass a dictionary to <code>search_space</code>!")
    params_to_delete = []
    for param, distribution in param_distributions.items():
        if param not in params:
            logger.warning(
                f"Ignoring {param} as it is not a hyperparameter of the model."
            )
            params_to_delete.append(param)
        # Reassigning an existing key does not change the dict's size, so it is
        # safe while iterating over ``items()``; deletions are deferred below.
        if isinstance(distribution, (float, int, str)):
            param_distributions[param] = CategoricalDistribution([distribution])
        elif isinstance(distribution, list):
            param_distributions[param] = CategoricalDistribution(distribution)
        elif not isinstance(distribution, ADS_DISTRIBUTIONS):
            raise ValueError(
                "<code>search_space</code> only accept numbers, strings, list or distributions."
            )
    for param in params_to_delete:
        del param_distributions[param]
def assert_strategy_valid(param_distributions, new_strategy, old_strategy=None):
    """Check that ``new_strategy`` is an acceptable tuning strategy.

    Parameters
    ----------
    param_distributions : dict
        The current search space, mapping names to distributions.
    new_strategy : str or dict
        Either ``"perfunctory"``, ``"detailed"``, or a dict mapping names to
        distributions.
    old_strategy : optional
        The previously used strategy. When truthy and ``new_strategy`` is a
        dict, each name already in ``param_distributions`` must keep the same
        distribution class. For a ``CategoricalDistribution`` it must also
        keep the same set of choices. Names not yet in
        ``param_distributions`` are not checked.

    Raises
    ------
    AssertionError
        If any of the conditions above is violated.
    """
    # Every check raises explicitly instead of using an ``assert`` statement,
    # so validation is not stripped when Python runs with ``-O``.
    if isinstance(new_strategy, str):
        if new_strategy not in ["perfunctory", "detailed"]:
            raise AssertionError(
                "Valid values of <code>strategy</code> are `perfunctory` and `detailed`."
            )
        return
    if not isinstance(new_strategy, dict):
        raise AssertionError(
            "Valid <code>strategy</code> type are string and dictionary."
        )
    if not old_strategy:
        return
    for name, distribution in new_strategy.items():
        if name not in param_distributions.keys():
            continue
        current = param_distributions[name]
        if not isinstance(distribution, current.__class__):
            raise AssertionError("Cannot change the distribution of existing params.")
        if isinstance(distribution, CategoricalDistribution) and set(
            current.choices
        ) != set(distribution.choices):
            raise AssertionError(
                "Does not support updating the list of values for categorical distributions."
            )
def validate_fit_params(
    X,  # type: TwoDimArrayLikeType
    fit_params,  # type: Dict
    indices,  # type: OneDimArrayLikeType
):
    # type: (...) -> Dict
    """Restrict per-sample fit parameters to the rows selected by ``indices``.

    A value in ``fit_params`` that is array-like and has the same number of
    samples as ``X`` is made indexable and then subset with ``indices``.
    Every other value is passed through unchanged.

    Parameters
    ----------
    X : array-like
        The data whose sample count decides which values are per-sample.
    fit_params : dict
        Extra keyword arguments destined for ``fit``.
    indices : array-like
        Row indices to keep from each per-sample value.

    Returns
    -------
    dict
        A newly built dict with the same keys as ``fit_params``.
    """
    # NOTE Original implementation:
    # https://github.com/scikit-learn/scikit-learn/blob/ \
    # 2467e1b84aeb493a22533fa15ff92e0d7c05ed1c/sklearn/utils/validation.py#L1324-L1328

    def _select_rows(value):
        # Scikit-learn does not accept non-iterable inputs, so scalars (and
        # arrays whose length differs from X) are passed through untouched
        # to keep backward compatibility.
        # (See: https://github.com/scikit-learn/scikit-learn/issues/15805)
        is_per_sample = _is_arraylike(value) and _num_samples(value) == _num_samples(X)
        if not is_per_sample:
            return value
        return _safe_indexing(_make_indexable(value), indices)

    return {key: _select_rows(value) for key, value in fit_params.items()}
def validate_params_for_plot(params, param_distributions):
    """Check that every requested plot parameter is in the search space.

    Parameters
    ----------
    params : iterable or None
        Names of the parameters to plot. ``None`` means no names.
    param_distributions : dict
        The search space; its keys are the known parameter names.

    Raises
    ------
    ValueError
        If any name in ``params`` is not a key of ``param_distributions``.
    """
    if params is None:
        params = []
    requested = set(params)
    available = set(param_distributions.keys())
    if requested - available:
        raise ValueError("Not all the params are in the search space.")