Source code for ads.evaluations.evaluation_plot

#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2020, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from __future__ import print_function, absolute_import, division

import base64
from io import BytesIO
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from matplotlib.ticker import FormatStrFormatter
import numpy as np
import math
from ads.common import logger
from ads.common.decorator.runtime_dependency import (
    runtime_dependency,
    OptionalDependency,
)
import itertools
import pandas as pd

MAX_TITLE_LEN = 20
MAX_LEGEND_LEN = 10
MAX_PLOTS_PER_ROW = 2
# Maximum number of classes supported by evaluation plotting for multiclass problems
MAX_PLOTTING_CLASSES = 10
# Maximum number of characters of a class label shown before the label is truncated
MAX_CHARACTERS_LEN = 13


def _fig_to_html(fig):
    tmpfile = BytesIO()
    fig.savefig(tmpfile, format="png")
    encoded = base64.b64encode(tmpfile.getvalue()).decode("utf-8")

    html = "<img src='data:image/png;base64,{}'>".format(encoded)
    return html
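
# Illustrative usage of the helper above (a minimal sketch; the figure and variable
# names are hypothetical): render a Matplotlib figure to an in-memory PNG and embed
# it as a base64 data URI, so the string can be concatenated into an HTML report
# without writing any file to disk.
#
#   fig, ax = plt.subplots()
#   ax.plot([0, 1], [0, 1])
#   html_img = _fig_to_html(fig)   # "<img src='data:image/png;base64,...'>"
#   plt.close(fig)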


class EvaluationPlot:
    """EvaluationPlot holds data and methods for plots and is used to output them

    Attributes
    ----------
    baseline (bool):
        whether to plot the null model or zero information model
    baseline_kwargs (dict):
        keyword arguments for the baseline plot
    color_wheel (list):
        color information used by the plot
    font_sz (dict):
        dictionary of font sizes used in the plots
    perfect (bool):
        determines whether a "perfect" classifier curve is displayed
    perfect_kwargs (dict):
        parameters for the perfect classifier for precision/recall curves
    prob_type (str):
        model type, i.e. classification or regression

    Methods
    -------
    get_legend_labels(legend_labels)
        Renders the legend labels on the plot
    plot(evaluation, plots, num_classes, perfect, baseline, legend_labels)
        Generates the evaluation plot
    """

    # dict of font sizes used in the plots
    font_sz = {
        "xl": 16,  # Plot type title
        "l": 14,  # Individual plot title (name of model)
        "m": 12,  # Axis titles
        "s": 10,  # Axis labels
        "xs": 8,
    }  # text within the plot

    baseline_kwargs = {"ls": "--", "c": ".2"}  # lw = ??
    perfect_kwargs = {"label": "Perfect Classifier", "ls": "--", "color": "gold"}
    perfect = None
    baseline = None
    prob_type = None
    color_wheel = ["teal", "blueviolet", "forestgreen", "peru", "y", "dodgerblue", "r"]

    _pretty_titles_map = {
        "normalized_confusion_matrix": "Normalized Confusion Matrix",
        "lift_chart": "Lift Chart",
        "gain_chart": "Gain Chart",
        "ks_statistics": "KS Statistics",
        "residuals_qq": "Residuals Q-Q Plot",
        "residuals_vs_predicted": "Residuals vs Predicted",
        "residuals_vs_observed": "Residuals vs Observed",
        "observed_vs_predicted": "Observed vs Predicted",
        "precision_by_label": "Precision by Label",
        "recall_by_label": "Recall by Label",
        "f1_by_label": "F1 by Label",
        "jaccard_by_label": "Jaccard by Label",
        "pr_curve": "PR Curve",
        "roc_curve": "ROC Curve",
        "pr_and_roc_curve": "PR Curve, ROC Curve",
        "lift_and_gain_chart": "Lift Chart, Gain Chart",
    }
    _ugly_titles_map = {v: k for k, v in _pretty_titles_map.items()}

    double_overlay_plots = ["pr_and_roc_curve", "lift_and_gain_chart"]
    single_overlay_plots = ["lift_chart", "gain_chart", "roc_curve", "pr_curve"]
    _bin_plots = [
        "pr_curve",
        "roc_curve",
        "lift_chart",
        "gain_chart",
        "normalized_confusion_matrix",
    ]
    _multi_plots = [
        "normalized_confusion_matrix",
        "roc_curve",
        "pr_curve",
        "precision_by_label",
        "recall_by_label",
        "f1_by_label",
        "jaccard_by_label",
    ]
    _reg_plots = [
        "observed_vs_predicted",
        "residuals_qq",
        "residuals_vs_predicted",
        "residuals_vs_observed",
    ]

    # Detailed descriptions of each plot type for every classification type;
    # can be extended when adding more metrics.
    _bin_plots_details = """
    In pattern recognition, information retrieval and binary classification, precision (also called
    positive predictive value) is the fraction of relevant instances among the retrieved instances,
    while recall (also known as sensitivity) is the fraction of relevant instances that have been
    retrieved over the total amount of relevant instances. \n
    A receiver operating characteristic curve, or ROC curve, is a graphical plot that illustrates
    the diagnostic ability of a binary classifier system as its discrimination threshold is varied. \n
    In data mining and association rule learning, lift is a measure of the performance of a targeting
    model (association rule) at predicting or classifying cases as having an enhanced response (with
    respect to the population as a whole), measured against a random choice targeting model. \n
    A gain graph is a graph whose edges are labelled 'invertibly', or 'orientably', by elements of a
    group G. \n
    In the field of machine learning and specifically the problem of statistical classification, a
    confusion matrix, also known as an error matrix, is a specific table layout that allows
    visualization of the performance of an algorithm, typically a supervised learning one
    (in unsupervised learning it is usually called a matching matrix).
    """

    # Removed for now:
    # Kuiper's test (ks_statistics) is used in statistics to test whether a given distribution,
    # or family of distributions, is contradicted by evidence from a sample of data. \n

    _multi_plots_details = """
    In the field of machine learning and specifically the problem of statistical classification, a
    confusion matrix, also known as an error matrix, is a specific table layout that allows
    visualization of the performance of an algorithm, typically a supervised learning one
    (in unsupervised learning it is usually called a matching matrix). \n
    A receiver operating characteristic curve, or ROC curve, is a graphical plot that illustrates
    the diagnostic ability of a binary classifier system as its discrimination threshold is varied. \n
    In pattern recognition, information retrieval and binary classification, precision (also called
    positive predictive value) is the fraction of relevant instances among the retrieved instances,
    while recall (also known as sensitivity) is the fraction of relevant instances that have been
    retrieved over the total amount of relevant instances. \n
    In statistical analysis of binary classification, the F1 score (also F-score or F-measure) is a
    measure of a test's accuracy. It considers both the precision p and the recall r of the test to
    compute the score: p is the number of correct positive results divided by the number of all
    positive results returned by the classifier, and r is the number of correct positive results
    divided by the number of all relevant samples (all samples that should have been identified as
    positive). The F1 score is the harmonic mean of the precision and recall, where an F1 score
    reaches its best value at 1 (perfect precision and recall) and worst at 0. \n
    The Jaccard index, also known as Intersection over Union and the Jaccard similarity coefficient,
    is a statistic used for gauging the similarity and diversity of sample sets. The Jaccard
    coefficient measures similarity between finite sample sets, and is defined as the size of the
    intersection divided by the size of the union of the sample sets.
    """

    _reg_plots_details = """
    In statistics, a Q–Q (quantile-quantile) plot is a probability plot, which is a graphical method
    for comparing two probability distributions by plotting their quantiles against each other. \n
    In statistics and optimization, errors and residuals are two closely related and easily confused
    measures of the deviation of an observed value of an element of a statistical sample from its
    "theoretical value". The error (or disturbance) of an observed value is the deviation of the
    observed value from the (unobservable) true value of a quantity of interest (for example, a
    population mean), and the residual of an observed value is the difference between the observed
    value and the estimated value of the quantity of interest (for example, a sample mean).
    """

    @classmethod
    def _get_formatted_title(cls, title, max_len=MAX_TITLE_LEN):
        return title if len(title) < max_len + 3 else title[:max_len] + "..."
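
    # Illustrative behaviour of the truncation above, with MAX_TITLE_LEN = 20
    # (the model names are hypothetical):
    #
    #   cls._get_formatted_title("GradientBoosting")           -> "GradientBoosting"
    #   cls._get_formatted_title("AVeryLongModelNameForDemo")  -> "AVeryLongModelNameFo..."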
    @classmethod
    def get_legend_labels(cls, legend_labels):
        """Gets the legend labels, resolves any conflicts such as length, and renders the labels for the plot

        Parameters
        ----------
        legend_labels (dict):
            key/value dictionary containing legend label data

        Returns
        -------
        Nothing

        Examples
        --------
        EvaluationPlot.get_legend_labels({'class_0': 'green', 'class_1': 'yellow', 'class_2': 'red'})
        """

        @runtime_dependency(module="IPython", install_from=OptionalDependency.NOTEBOOK)
        @runtime_dependency(
            module="ipywidgets", object="HTML", install_from=OptionalDependency.NOTEBOOK
        )
        def render_legend_labels(label_dict):
            encodings = pd.DataFrame(
                pd.Series(label_dict, index=label_dict.keys()),
                columns=["Shortened labels"],
            )

            from IPython.core.display import display, HTML

            display(
                HTML(
                    encodings.style.format(precision=4)
                    .set_properties(**{"text-align": "center"})
                    .set_table_styles(
                        [dict(selector="", props=[("text-align", "center")])]
                    )
                    .set_table_attributes("class=table")
                    .set_caption(
                        '<div align="left"><b style="font-size:20px;">'
                        + "Legend for labels of the target feature:</b></div>"
                    )
                    .to_html()
                )
            )

        if legend_labels is not None:
            # CAUTION: cls.classes is a list of strings. Make sure users know that
            # labels are all converted to strings.
            if isinstance(legend_labels, dict) and set(cls.classes).issubset(
                set(legend_labels.keys())
            ):
                render_legend_labels(legend_labels)
                cls.legend_labels = legend_labels
            else:
                logger.error(
                    "The provided `legend_labels` is either not a Python dict or does not possess all possible class labels."
                )
            return

        # try to remove leading words
        def _check_for_redundant_words(label_vec, prefix, max_len=MAX_LEGEND_LEN):
            words_found = False
            classes = [lab for lab in label_vec]
            while len(set([lab.split()[0] for lab in classes])) <= 1:
                # remove that word from vec, and add it to prefix
                additional_prefix = classes[0].split()[0]
                prefix = (
                    prefix[:-3] + additional_prefix + " ..."
                    if words_found
                    else additional_prefix + " ..."
                )
                classes = [lab[len(additional_prefix) + 1 :] for lab in classes]
                words_found = True
            classes = (
                classes if words_found else [label[max_len:] for label in label_vec]
            )
            return classes, prefix

        # returns mapping from real labels to pseudo-labels, when pseudo-labels are
        # not the first X letters
        def _resolve_conflict(label_vec, prefix, max_len=MAX_LEGEND_LEN):
            classes, prefix = _check_for_redundant_words(label_vec, prefix)
            label_dict = _get_labels(classes, max_len=max_len)
            resolved = {}
            for orig, new in label_dict.items():
                resolved[prefix[:-3] + orig] = "..." + new if new[:3] != "..." else new
            return resolved

        # returns mapping from provided list of strings, to short, unique substrings
        def _get_labels(classes, max_len=MAX_LEGEND_LEN):
            conflict_dict = {}
            for label in classes:
                prefix = label if len(label) < max_len + 3 else label[:max_len] + "..."
                if conflict_dict.get(prefix, None) is None:
                    conflict_dict[prefix] = [label]
                else:
                    conflict_dict[prefix].append(label)
            out = {}
            for k, v in conflict_dict.items():
                if len(v) == 1:
                    out[v[0]] = k
                else:
                    resolved = _resolve_conflict(v, k, max_len=MAX_LEGEND_LEN)
                    out.update(resolved)
            return out

        cls.legend_labels = _get_labels(cls.classes)
        if set(cls.legend_labels.keys()) != set(cls.legend_labels.values()):
            logger.info(
                f"Class labels greater than {MAX_CHARACTERS_LEN} characters have been truncated. "
                "Use the `legend_labels` parameter to define labels."
            )
        render_legend_labels(cls.legend_labels)
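
    # Illustrative example of the default shortening performed above when no
    # `legend_labels` dict is supplied (the class names are hypothetical), with
    # MAX_LEGEND_LEN = 10 and MAX_CHARACTERS_LEN = 13:
    #
    #   cls.classes = ["dog", "a_very_long_class_label"]
    #   cls.legend_labels -> {"dog": "dog",
    #                         "a_very_long_class_label": "a_very_lon..."}
    #
    # Because the shortened values differ from the original labels, the truncation
    # notice above is logged, pointing users to the `legend_labels` parameter.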
    # evaluation is a DataFrame with models as columns and metrics as rows
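    # For illustration (the column name "LogisticRegression" is hypothetical), the
    # row index is expected to carry the metric keys accessed by the plotting
    # methods below, e.g. "y_true", "y_pred", "y_score", "classes",
    # "positive_class", "precision_values", "recall_values", "false_positive_rate",
    # "true_positive_rate", "auc", "percentages", "lift", "cumulative_gain",
    # "confusion_matrix" and "raw_confusion_matrix":
    #
    #   evaluation["LogisticRegression"]["auc"]      # scalar AUC for a binary model
    #   evaluation["LogisticRegression"]["classes"]  # list of class labels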
    @classmethod
    def plot(
        cls,
        evaluation,
        plots,
        num_classes,
        perfect=False,
        baseline=True,
        legend_labels=None,
    ):
        """Generates the evaluation plot

        Parameters
        ----------
        evaluation (DataFrame):
            DataFrame with models as columns and metrics as rows.
        plots (list):
            The plot types to render, based on the class attribute `prob_type`.
        num_classes (int):
            The number of classes for the model.
        perfect (bool, optional):
            Whether to display the curve of a perfect classifier. Default value is `False`.
        baseline (bool, optional):
            Whether to display the curve of the baseline, featureless model. Default value is `True`.
        legend_labels (dict, optional):
            Legend labels dictionary. Default value is `None`. If `legend_labels` is not
            specified, class names will be used for plots.

        Returns
        -------
        Nothing
        """
        cls.perfect = perfect
        cls.baseline = baseline

        # get plots to show
        if num_classes == 2:
            cls.prob_type = "_bin"
        elif num_classes > 2:
            cls.prob_type = "_multi"
        else:
            cls.prob_type = "_reg"
        plot_details = getattr(cls, cls.prob_type + "_plots_details")
        if plots is None:
            plots = getattr(cls, cls.prob_type + "_plots")
            logger.info(
                "Showing plot types: {}.".format(
                    ", ".join(
                        [
                            "{}".format(EvaluationPlot._pretty_titles_map[str(p)])
                            for p in plots
                        ]
                    ),
                    ", ".join(["{}".format(x) for x in map(str, plots)]),
                )
            )
            logger.info(plot_details)

        if cls.prob_type == "_bin":
            if "lift_chart" in plots and "gain_chart" in plots:
                plots.remove("lift_chart")
                plots.remove("gain_chart")
                plots.insert(0, "lift_and_gain_chart")
            if "roc_curve" in plots and "pr_curve" in plots:
                plots.remove("roc_curve")
                plots.remove("pr_curve")
                plots.insert(0, "pr_and_roc_curve")
        elif cls.prob_type == "_multi":
            if (
                "normalized_confusion_matrix" in plots
                and len(evaluation[evaluation.columns[0]]["classes"])
                >= MAX_PLOTTING_CLASSES
            ):
                logger.error(
                    f"Evaluation plotting is not yet supported for multiclass problems with {MAX_PLOTTING_CLASSES} or more classes."
                )
                plots = []

        classes = evaluation[evaluation.columns[0]]["classes"]
        if classes is not None:
            # CAUTION: class labels are converted to strings here.
            # If users are passing in legend_labels, they have to use strings as well.
            # Otherwise get_legend_labels() will complain.
            # If users are not passing in legend_labels, get_legend_labels() generates
            # them from cls.classes.
            # cls.legend_labels are assigned/created in cls.get_legend_labels() and
            # contain only strings as keys.
            cls.classes = [str(c) for c in classes]
            cls.get_legend_labels(legend_labels)

        mpl.style.use("default")
        html_raw = []

        for i, plot_type in enumerate(plots):
            fig_title, fig = None, None
            try:
                fig_title, ax_title = plt.subplots(1, 1, figsize=(18, 0.5), dpi=144)
                ax_title.text(
                    0.5,
                    0.5,
                    cls._pretty_titles_map[plot_type],
                    fontsize=16,
                    fontweight="semibold",
                    horizontalalignment="center",
                    verticalalignment="center",
                    transform=ax_title.transAxes,
                )
                ax_title.axis("off")
                html_raw.append(_fig_to_html(fig_title))
                if cls.prob_type == "_bin" and plot_type in cls.double_overlay_plots:
                    fig, ax = plt.subplots(1, 2, figsize=(12, 4.5), dpi=144)
                elif cls.prob_type == "_bin" and plot_type in ["roc_curve", "pr_curve"]:
                    fig, axs = plt.subplots(1, 2, figsize=(12, 4.5), dpi=144)
                    axs[1].axis("off")
                    ax = [axs[0]]
                elif cls.prob_type == "_bin" and plot_type in [
                    "lift_chart",
                    "gain_chart",
                ]:
                    fig, axs = plt.subplots(1, 2, figsize=(12, 4.5), dpi=144)
                    axs[1].axis("off")
                    ax = axs[0]
                else:
                    nrows = math.ceil(len(evaluation.columns) / MAX_PLOTS_PER_ROW)
                    fig, ax = plt.subplots(
                        nrows, MAX_PLOTS_PER_ROW, figsize=(10, 4 * nrows), dpi=144
                    )  # 10, 3.5
                    ax = ax.flatten()
                getattr(cls, "_" + plot_type)(ax, evaluation)
                fig.tight_layout()
                html_raw.append(_fig_to_html(fig))
            except KeyError as e:
                try:
                    if fig_title:
                        plt.close(fig=fig_title)
                    if fig:
                        plt.close(fig=fig)
                except:
                    pass
                logger.warning(
                    f"Evaluator was not able to plot "
                    f"{cls._pretty_titles_map.get(plot_type, plot_type)}, because the relevant "
                    f"metrics had complications. Ensure that `predict` and `predict_proba` "
                    f"are valid."
                )
        return html_raw
    @classmethod
    def _lift_and_gain_chart(cls, ax, evaluation):
        cls._lift_chart(ax[0], evaluation)
        cls._gain_chart(ax[1], evaluation)

    @classmethod
    def _lift_chart(cls, ax, evaluation):
        for mod_name, col in evaluation.items():
            if col["y_score"] is not None:
                ax.plot(
                    col["percentages"][1:],
                    [1] + list(col["lift"]),
                    label=cls._get_formatted_title(mod_name),
                )
        if cls.baseline:
            ax.plot([-10, 110], [1, 1], **cls.baseline_kwargs)
        if cls.perfect:
            perf_idx = next(
                idx
                for idx, scores in enumerate(evaluation.loc["y_score"])
                if scores is not None
            )
            ax.plot(
                evaluation.loc["percentages"][perf_idx][1:],
                [1] + list(evaluation.loc["perfect_lift"][perf_idx]),
                **cls.perfect_kwargs,
            )
        ax.legend(loc="upper right", frameon=False)
        ax.set_xlabel("Percentage of Population", fontsize=12)
        ax.set_ylabel("Lift", fontsize=12)
        ax.set_title("Lift Chart", y=1.08, fontsize=14)
        ax.grid(linewidth=0.2, which="both")
        ax.set_xlim([-10, 110])

    @classmethod
    def _gain_chart(cls, ax, evaluation):
        for mod_name, col in evaluation.items():
            if col["y_score"] is not None:
                ax.plot(
                    col["percentages"],
                    list(col["cumulative_gain"]),
                    label=cls._get_formatted_title(mod_name),
                )
        if cls.baseline:
            ax.plot([-10, 110], [-10, 110], **cls.baseline_kwargs)
        if cls.perfect:
            perf_idx = next(
                idx
                for idx, scores in enumerate(evaluation.loc["y_score"])
                if scores is not None
            )
            ax.plot(
                evaluation.loc["percentages"][perf_idx],
                evaluation.loc["perfect_gain"][perf_idx],
                **cls.perfect_kwargs,
            )
        ax.legend(loc="lower right", frameon=False)
        ax.set_xlabel("Percentage of Population", fontsize=12)
        ax.set_ylabel("Percentage of Positive Class", fontsize=12)
        ax.set_title("Gain Chart", y=1.08, fontsize=14)
        ax.grid(linewidth=0.2, which="both")
        ax.set_xlim([-10, 110])
        ax.set_ylim([-10, 110])

    @classmethod
    def _pr_and_roc_curve(cls, ax, evaluation):
        cls._pr_curve([ax[0]], evaluation)
        cls._roc_curve([ax[1]], evaluation)

    @classmethod
    def _pr_curve(cls, axs, evaluation):
        n_models = len(evaluation.columns)
        for i, ax in enumerate(axs):
            if i >= n_models:
                ax.axis("off")
                return
            if cls.prob_type == "_bin":
                for mod_name, col in evaluation.items():
                    if col["y_score"] is not None:
                        ax.plot(
                            col["recall_values"],
                            col["precision_values"],
                            label="%s (Precision: %s)"
                            % (
                                cls._get_formatted_title(mod_name),
                                "{:.3f}".format(col["precision"]),
                            ),
                        )
                        ax.plot(
                            *col["pr_best_model_score"],
                            color=ax.get_lines()[-1].get_color(),
                            marker="*",
                        )
            else:
                model_name = evaluation.columns[i]
                mod = evaluation[model_name]
                if mod["y_score"] is not None:
                    for j, lab in enumerate(mod.classes):
                        # cls.legend_labels contains only strings as keys.
                        lab = str(lab)
                        ax.plot(
                            mod["recall_values"][j],
                            mod["precision_values"][j],
                            label="%s (Precision: %s)"
                            % (
                                cls.legend_labels[lab],
                                "{:.3f}".format(mod["precision_by_label"][j]),
                            ),
                        )
                        ax.plot(
                            *mod["pr_best_model_score"][j],
                            color=ax.get_lines()[-1].get_color(),
                            marker="*",
                        )
            ax.set_xlabel("Recall", fontsize=12)
            ax.set_ylabel("Precision", fontsize=12)
            ax.set_title("Precision Recall Curve", y=1.08, fontsize=14)
            ax.grid(linewidth=0.2, which="both")
            ax.set_xlim([-0.1, 1.1])
            ax.set_ylim([-0.1, 1.1])
            handles, labels = ax.get_legend_handles_labels()
            star = mlines.Line2D(
                [],
                [],
                color="black",
                marker="*",
                linestyle="None",
                markersize=5,
                label="Minimum Error Rate",
            )
            handles.append(star)
            labels.append("Minimum Error Rate")
            ax.legend(
                loc="upper right",
                labels=labels,
                handles=handles,
                frameon=False,
                fontsize="x-small",
            )

    @classmethod
    def _roc_curve(cls, axs, evaluation):
        n_models = len(evaluation.columns)
        for i, ax in enumerate(axs):
            if i >= n_models:
                ax.axis("off")
                return
            if cls.prob_type == "_bin":
                for mod_name, col in evaluation.items():
                    if col["y_score"] is not None:
                        ax.plot(
                            col["false_positive_rate"],
                            col["true_positive_rate"],
                            label="%s (AUC: %s)"
                            % (
                                cls._get_formatted_title(mod_name),
                                "{:.3f}".format(col["auc"]),
                            ),
                        )
                        ax.plot(
                            *col["roc_best_model_score"],
                            color=ax.get_lines()[-1].get_color(),
                            marker="*",
                        )
            else:
                model_name = evaluation.columns[i]
                mod = evaluation[model_name]
                if mod["y_score"] is not None:
                    for j, lab in enumerate(mod.classes):
                        # cls.legend_labels contains only strings as keys.
                        lab = str(lab)
                        ax.plot(
                            mod["fpr_by_label"][j],
                            mod["tpr_by_label"][j],
                            label="%s (AUC: %s)"
                            % (cls.legend_labels[lab], "{:.3f}".format(mod["auc"][j])),
                        )
                        ax.plot(
                            *mod["roc_best_model_score"][j],
                            color=ax.get_lines()[-1].get_color(),
                            marker="*",
                        )
            if cls.baseline:
                ax.plot([-0.1, 1.1], [-0.1, 1.1], **cls.baseline_kwargs)
            ax.set_xlabel("False Positive Rate", fontsize=12)
            ax.set_ylabel("True Positive Rate", fontsize=12)
            ax.set_title("ROC Curve", y=1.08, fontsize=14)
            ax.grid(linewidth=0.2, which="both")
            ax.set_xlim([-0.1, 1.1])
            ax.set_ylim([-0.1, 1.1])
            handles, labels = ax.get_legend_handles_labels()
            star = mlines.Line2D(
                [],
                [],
                color="black",
                marker="*",
                linestyle="None",
                markersize=5,
                label="Youden's J Statistic",
            )
            handles.append(star)
            labels.append("Youden's J Statistic")
            ax.legend(
                loc="lower right",
                labels=labels,
                handles=handles,
                frameon=False,
                fontsize="x-small",
            )

    @classmethod
    def _ks_statistics(cls, axs, evaluation):
        n_models = len(evaluation.columns)
        for i, ax in enumerate(axs):
            if i >= n_models:
                ax.axis("off")
                return
            model_name = evaluation.columns[i]
            mod = evaluation[model_name]
            ax.set_title(model_name, fontsize=14)
            if mod["y_score"] is not None:
                ax.plot(
                    mod["ks_thresholds"],
                    mod["ks_pct1"],
                    lw=3,
                    label=mod["ks_labels"][0],
                )
                ax.plot(
                    mod["ks_thresholds"],
                    mod["ks_pct2"],
                    lw=3,
                    label=mod["ks_labels"][1],
                )
                if cls.baseline:
                    idx = np.where(mod["ks_thresholds"] == mod["max_distance_at"])[0][0]
                    ax.axvline(
                        mod["max_distance_at"],
                        *sorted([mod["ks_pct1"][idx], mod["ks_pct2"][idx]]),
                        label="KS Statistic: {:.3f} at {:.3f}".format(
                            mod["ks_statistic"], mod["max_distance_at"]
                        ),
                        linestyle="--",
                        color=".2",
                    )
                ax.set_xlim([0.0, 1.0])
                ax.set_ylim([0.0, 1.0])
                ax.set_xlabel("Threshold", fontsize=12)
                ax.set_ylabel("Percentage below threshold", fontsize=12)
                ax.tick_params(labelsize=10)
                ax.legend(loc="lower right", fontsize=8)

    @classmethod
    def _pretty_barh(
        cls, ax, x, y, axis_labels=None, title=None, axis_lim=None, plot_kwargs=None
    ):
        # cls.legend_labels contains only strings as keys.
        new_lab = [cls.legend_labels[str(item)] for item in x]
        ax.barh(
            new_lab,
            y,
            color=["teal", "blueviolet", "forestgreen", "peru", "y", "dodgerblue", "r"],
        )
        for j, v in enumerate(y):
            ax.annotate("{:.3f}".format(v), xy=(v / 2, j), va="center", ha="left")
        if axis_labels:
            if axis_labels[0]:
                ax.set_xlabel(axis_labels[0], fontsize=12)
            if axis_labels[1]:
                ax.set_ylabel(axis_labels[1], fontsize=12)
        if title:
            title = cls._get_formatted_title(title)
            ax.set_title(title, y=1.08, fontsize=14)
        if axis_lim:
            ax.set_xlim(axis_lim)

    @classmethod
    def _precision_by_label(cls, axs, evaluation):
        n_models = len(evaluation.columns)
        for i, ax in enumerate(axs):
            if i < n_models:
                col = evaluation.columns[i]
                cls._pretty_barh(
                    ax,
                    evaluation[col]["classes"],
                    evaluation[col]["precision_by_label"],
                    axis_lim=[0, 1],
                    axis_labels=["Precision", None],
                    title=col,
                )
            else:
                ax.axis("off")

    @classmethod
    def _recall_by_label(cls, axs, evaluation):
        n_models = len(evaluation.columns)
        for i, ax in enumerate(axs):
            if i < n_models:
                col = evaluation.columns[i]
                cls._pretty_barh(
                    ax,
                    evaluation[col]["classes"],
                    evaluation[col]["recall_by_label"],
                    axis_lim=[0, 1],
                    axis_labels=["Recall", None],
                    title=col,
                )
            else:
                ax.axis("off")

    @classmethod
    def _f1_by_label(cls, axs, evaluation):
        n_models = len(evaluation.columns)
        for i, ax in enumerate(axs):
            if i < n_models:
                col = evaluation.columns[i]
                cls._pretty_barh(
                    ax,
                    evaluation[col]["classes"],
                    evaluation[col]["f1_by_label"],
                    axis_lim=[0, 1],
                    axis_labels=["F1 Score", None],
                    title=col,
                )
            else:
                ax.axis("off")

    @classmethod
    def _jaccard_by_label(cls, axs, evaluation):
        n_models = len(evaluation.columns)
        for i, ax in enumerate(axs):
            if i < n_models:
                col = evaluation.columns[i]
                cls._pretty_barh(
                    ax,
                    evaluation[col]["classes"],
                    evaluation[col]["jaccard_by_label"],
                    axis_lim=[0, 1],
                    axis_labels=["Jaccard Score", None],
                    title=col,
                )
            else:
                ax.axis("off")

    @classmethod
    def _pretty_scatter(
        cls,
        ax,
        x,
        y,
        s=5,
        alpha=1.0,
        title=None,
        legend=False,
        axis_labels=None,
        axis_lim=None,
        grid=True,
        label=None,
        plot_kwargs=None,
    ):
        if plot_kwargs is None:
            plot_kwargs = {}
        ax.scatter(x, y, s=s, label=label, marker="o", alpha=alpha, **plot_kwargs)
        ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
        ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
        if legend:
            ax.legend(frameon=False)
        if axis_labels:
            ax.set_xlabel(axis_labels[0])
            ax.set_ylabel(axis_labels[1])
        if title:
            ax.set_title(title, y=1.08, fontsize=14)
        if grid:
            ax.grid(linewidth=0.2)
        if axis_lim:
            if axis_lim[0]:
                ax.set_xlim(axis_lim[0])
            if axis_lim[1]:
                ax.set_ylim(axis_lim[1])

    @classmethod
    def _top_2_features(cls, axs, evaluation):
        pass

    @classmethod
    def _residuals_qq(cls, axs, evaluation):
        n_models = len(evaluation.columns)
        for i, ax in enumerate(axs):
            if i >= n_models:
                ax.axis("off")
                return
            model_name = evaluation.columns[i]
            mod = evaluation[model_name]
            # getattr(ax, self.plot_method)(self.x, y, **self.plot_kwargs, color='#4a91c2', label=label)
            cls._pretty_scatter(
                ax,
                mod["norm_quantiles"],
                mod["residual_quantiles"],
                title=model_name,
                axis_lim=[(-2.7, 2.7), (-3.1, 3.1)],
                axis_labels=["Theoretical Quantiles", "Sample Quantiles"],
            )
            if cls.baseline:
                ax.plot((-100, 100), (-100, 100), **cls.baseline_kwargs)

    @classmethod
    def _residuals_vs_predicted(cls, axs, evaluation):
        n_models = len(evaluation.columns)
        for i, ax in enumerate(axs):
            if i >= n_models:
                ax.axis("off")
                return
            model_name = evaluation.columns[i]
            mod = evaluation[model_name]
            y_pred = np.asarray(mod["y_pred"])
            resid = np.asarray(mod["residuals"])
            cls._pretty_scatter(
                ax,
                y_pred,
                resid,
                s=4,
                alpha=0.5,
                title=model_name,
                axis_labels=["Predicted Values", "Residuals"],
            )
            x_lim = (
                y_pred.min() - y_pred.min() * 0.05,
                y_pred.max() + y_pred.max() * 0.05,
            )
            if cls.baseline:
                ax.plot(x_lim, (0, 0), **cls.baseline_kwargs)
            ax.set_xlim(x_lim)
            ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
            ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))

    @classmethod
    def _residuals_vs_observed(cls, axs, evaluation):
        n_models = len(evaluation.columns)
        for i, ax in enumerate(axs):
            if i >= n_models:
                ax.axis("off")
                return
            model_name = evaluation.columns[i]
            mod = evaluation[model_name]
            y_true = np.asarray(mod["y_true"])
            cls._pretty_scatter(
                ax,
                y_true,
                mod["residuals"],
                s=4,
                alpha=0.5,
                title=model_name,
                axis_labels=["Observed Values", "Residuals"],
            )
            x_lim = (
                y_true.min() - y_true.min() * 0.05,
                y_true.max() + y_true.max() * 0.05,
            )
            if cls.baseline:
                ax.plot(x_lim, (0, 0), **cls.baseline_kwargs)
            ax.set_xlim(x_lim)
            ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
            ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))

    @classmethod
    def _observed_vs_predicted(cls, axs, evaluation):
        n_models = len(evaluation.columns)
        for i, ax in enumerate(axs):
            if i >= n_models:
                ax.axis("off")
                return
            model_name = evaluation.columns[i]
            mod = evaluation[model_name]
            ax.scatter(mod["y_true"], mod["y_pred"], s=4, marker="o", alpha=0.5)
            y_true = np.asarray(mod["y_true"])
            yt_min = y_true.min()
            yt_max = y_true.max()
            x_lim = (yt_min - yt_min * 0.05, yt_max + yt_max * 0.05)
            if cls.baseline:
                ax.plot(x_lim, x_lim, **cls.baseline_kwargs)
            ax.set_xlabel("Observed Values")
            ax.set_ylabel("Predicted Values")
            ax.set_title(model_name, y=1.08, fontsize=14)
            ax.grid(linewidth=0.2)
            ax.set_xlim(x_lim)
            ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
            ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))

    @classmethod
    def _normalized_confusion_matrix(cls, axs, evaluation):
        for model_num, ax in enumerate(axs):
            if model_num >= len(evaluation.columns):
                ax.axis("off")
                return
            model_name = evaluation.columns[model_num]
            mod = evaluation[model_name]
            if cls.prob_type == "_bin":
                labels = [str(lab == mod["positive_class"]) for lab in mod["classes"]]
            else:
                labels = cls.legend_labels.values()
            raw_cm = mod["raw_confusion_matrix"]
            cm = np.asarray(mod["confusion_matrix"])
            ax.set_title(
                "%s\n" % cls._get_formatted_title(model_name), y=1.08, fontsize=14
            )
            ax.imshow(cm, interpolation="nearest", cmap="BuGn")
            x_tick_marks = np.arange(len(labels))
            y_tick_marks = np.arange(len(labels))
            ax.set_xticks(x_tick_marks)
            ax.set_yticks(y_tick_marks)
            ax.set_xticklabels(labels, rotation=90, fontsize=10)
            ax.set_yticklabels(labels, fontsize=10)
            for i, j in itertools.product(
                range(raw_cm.shape[0]), range(raw_cm.shape[1])
            ):
                ax.text(
                    j,
                    i,
                    "%s [%s]" % (round(cm[i][j], 3), raw_cm[i, j]),
                    horizontalalignment="center",
                    verticalalignment="center",
                    rotation=45,
                    fontsize=max(3, 10 - max(raw_cm.shape[0], raw_cm.shape[1])),
                    color="white" if cm[i, j] > 0.5 else "black",
                )
            ax.set_ylabel("True label", fontsize=10)
            ax.set_xlabel("Predicted label", fontsize=10)
            ax.grid(False)
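
# Illustrative end-to-end sketch (the `evaluation` frame is assumed to be produced
# by the evaluator elsewhere in `ads.evaluations`, with one column per model and one
# row per metric as noted above; it is not normally constructed by hand):
#
#   html_fragments = EvaluationPlot.plot(
#       evaluation,
#       plots=None,      # None -> default plot set for the detected problem type
#       num_classes=2,   # 2 -> binary, >2 -> multiclass, otherwise regression
#   )
#   report = "".join(html_fragments)  # base64-embedded <img> tags for an HTML report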