Source code for ads.dataset.correlation_plot

#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2020, 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import os

import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import pandas as pd
from ads.common import utils as au
from ads.common.decorator.runtime_dependency import (
    runtime_dependency,
    OptionalDependency,
)
from ads.dataset.correlation import _validate_correlation_methods
from ads.dataset.helper import deprecate_default_value


[docs] class BokehHeatMap(object): """ Generate a HeatMap or horizontal bar plot to compare features. """ @runtime_dependency(module="bokeh", install_from=OptionalDependency.VIZ) def __init__(self, ds): from bokeh.io import output_notebook from bokeh.palettes import BuPu output_notebook() self.ds = ds self.colormap = cm.get_cmap("BuPu") self.bokehpalette = [ mpl.colors.rgb2hex(m) for m in self.colormap(np.arange(self.colormap.N)) ]
[docs] def debug(self): """ Return True if in debug mode, otherwise False. """ return au.is_debug_mode()
[docs] def flatten_corr_matrix(self, corr_matrix): """ Flatten a correlation matrix into a pandas Dataframe. Parameters ---------- corr_matrix: Pandas Dataframe The correlation matrix to be flattened. Returns ------- corr_flatten: Pandas DataFrame The flattened correlation matrix. """ rows = corr_matrix.index.values.tolist() columns = corr_matrix.columns corr_flatten = pd.DataFrame( [(r, c, corr_matrix[r][c]) for c in columns for r in rows], columns=["x", "y", "corr"], ) return corr_flatten
[docs] @runtime_dependency(module="bokeh", install_from=OptionalDependency.VIZ) def plot_heat_map( self, matrix, xrange: list, yrange: list, low: float = 1, high=1, title: str = None, tool_tips: list = None, ): """ Plots a matrix as a heatmap. Parameters ---------- matrix: Pandas Dataframe The dataframe to be plotted. xrange: List of floats The range of x values to plot. yrange: List of floats The range of y values to plot. low: float, Defaults to 1 The color mapping value for "low" points. high: float, Defaults to 1 The color mapping value for "high" points. title: str, Defaults to None The optional title of the heat map. tool_tips: list of str, Defaults to None An optional list of tool tips to include with the plot. Returns ------- fig: matplotlib Figure A matplotlib heatmap figure object. """ if self.debug(): print(matrix) mapper = bokeh.models.LinearColorMapper( palette=self.bokehpalette, low=low, high=high ) source = bokeh.models.ColumnDataSource(matrix) from bokeh.plotting import figure p = figure( title=title, x_range=xrange, y_range=yrange, toolbar_location="below", toolbar_sticky=False, width=600, height=600, ) p.rect( x="x", y="y", width=1, height=1, source=source, fill_color={"field": "corr", "transform": mapper}, line_color=None, ) p.xaxis.major_label_orientation = "vertical" if tool_tips: p.add_tools(bokeh.models.HoverTool(tooltips=tool_tips)) color_bar = bokeh.models.ColorBar( color_mapper=mapper, major_label_text_font_size="5pt", ticker=bokeh.models.BasicTicker(desired_num_ticks=8), formatter=bokeh.models.PrintfTickFormatter(format="%0.2f"), label_standoff=6, border_line_color=None, location=(0, 0), ) p.add_layout(color_bar, "right") return p
[docs] @runtime_dependency(module="bokeh", install_from=OptionalDependency.VIZ) def plot_hbar( self, matrix, low: float = 1, high=1, title: str = None, tool_tips: list = None, column_name: str = None, ): """ Plots a histogram bar-graph. Parameters ---------- matrix: Pandas Dataframe The dataframe to be plotted. low: float, Defaults to 1 The color mapping value for "low" points. high: float, Defaults to 1 The color mapping value for "high" points. title: str, Defaults to None The optional title of the heat map. tool_tips: list of str, Defaults to None An optional list of tool tips to include with the plot. column_name: str, Defaults to None The name of the column which is being plotted. Returns ------- fig: matplotlib Figure A matplotlib heatmap figure object. """ mapper = bokeh.models.LinearColorMapper( palette=self.bokehpalette, low=low, high=high ) source = bokeh.models.ColumnDataSource(matrix) from bokeh.plotting import figure p = figure( title=f"{title} ({column_name})", x_range=(low, high), y_range=(0, len(matrix["Y"]) + 1), toolbar_location="below", toolbar_sticky=False, width=600, height=600, ) p.hbar( y="Y", height=0.5, left=0, source=source, right="corr", fill_color={"field": "corr", "transform": mapper}, line_color=None, ) p.add_tools(bokeh.models.HoverTool(tooltips=tool_tips)) color_bar = bokeh.models.ColorBar( color_mapper=mapper, major_label_text_font_size="5pt", ticker=bokeh.models.BasicTicker(desired_num_ticks=8), formatter=bokeh.models.PrintfTickFormatter(format="%0.2f"), label_standoff=6, border_line_color=None, location=(0, 0), ) labels = bokeh.models.LabelSet( x="label", y="Y", text="X", level="glyph", y_offset=-5, source=source, ) p.add_layout(color_bar, "right") p.add_layout(labels) return p
[docs] @runtime_dependency(module="bokeh", install_from=OptionalDependency.VIZ) def generate_heatmap( self, corr_matrix, title: str, msg: str, correlation_threshold: float ): """ Generate a heatmap from a correlation matrix. Parameters ---------- corr_matrix: Pandas Dataframe The dataframe to be used for heatmap generation. title: str title of the heatmap. msg: str An additional msg to include in the plot. correlation_threshold: float A float between 0 and 1 which is used for excluding correlations which are not intense enough from the plot. Returns ------- tab: matplotlib Panel A matplotlib Panel object which includes a plotted heatmap """ from bokeh.plotting import figure if len(corr_matrix) == 0: tab = bokeh.models.TabPanel( child=figure(title=msg + ", nothing to display"), title=title, ) return tab corr_matrix = _corr_filter(correlation_threshold, corr_matrix) corr_flatten = self.flatten_corr_matrix(corr_matrix) low = -1 if corr_matrix.min().min() < 0 else 0 high = 1 if self.debug(): print(f"{title} : Min is {low}, Max is: {high}") p = self.plot_heat_map( corr_flatten, corr_matrix.index.values.tolist(), corr_matrix.columns.values.tolist(), low, high, title, tool_tips=[("X", "@x"), ("Y", "@y"), ("Corr", "@corr")], ) tab = bokeh.models.TabPanel(child=p, title=title) return tab
[docs] @runtime_dependency(module="bokeh", install_from=OptionalDependency.VIZ) def generate_target_heatmap( self, corr_matrix, title: str, correlation_target: str, msg: str, correlation_threshold: float, ): """ Generate a heatmap from a correlation matrix and its targets. Parameters ---------- corr_matrix: Pandas Dataframe The dataframe to be used for heatmap generation. title: str title of the heatmap. correlation_target: str The target column name for computing correlations against. msg: str An additional msg to include in the plot. correlation_threshold: float A float between 0 and 1 which is used for excluding correlations which are not intense enough from the plot. Returns ------- tab: matplotlib Panel A matplotlib Panel object which includes a plotted heatmap. """ from bokeh.plotting import figure if len(corr_matrix) == 0: tab = bokeh.models.TabPanel( child=figure(title=msg + ", nothing to display"), title=title, ) return tab corr_matrix = _corr_filter(correlation_threshold, corr_matrix) assert correlation_target, "Correlation target is required for this plot" if correlation_target not in corr_matrix.columns: tab = bokeh.models.TabPanel( child=figure(title="No Data to display"), title=title ) return tab corr_flatten = {} corr_clean = corr_matrix[correlation_target].dropna() if self.debug(): print(corr_matrix[correlation_target].values.tolist()) print(len(corr_matrix[correlation_target].values.tolist())) print("*" * 30) corr_flatten["X"] = corr_clean.index.values.tolist() if self.debug(): print(corr_clean.index.values.tolist()) print(len(corr_clean.index.values.tolist())) print("*" * 30) corr_flatten["Y"] = [ i for i in range(1, len(corr_clean.index.values.tolist()) + 1) ] corr_flatten["corr"] = corr_clean.values.tolist() corr_flatten["label"] = [ x + 0.01 if x > 0 else 0.01 for x in corr_clean.values.tolist() ] corr_flatten = pd.DataFrame.from_dict(data=corr_flatten) if self.debug(): print(corr_flatten) low = -1 if corr_clean.min() < 0 else 0 high = 1 if self.debug(): print(f"{title} : Min is {low}, Max is: {high}") p = self.plot_hbar( corr_flatten, low, high, title, tool_tips=[("Feature", "@X"), ("Corr", "@corr")], column_name=correlation_target, ) tab = bokeh.models.TabPanel(child=p, title=title) return tab
[docs] @runtime_dependency(module="bokeh", install_from=OptionalDependency.VIZ) def plot_correlation_heatmap( self, ds, plot_type: str = "heatmap", correlation_target: str = None, correlation_threshold=-1, correlation_methods: str = "pearson", **kwargs, ): """ Plots a correlation heatmap. Parameters ---------- ds: Pandas Slice A data slice or file plot_type: str Defaults to "heatmap" The type of plot - "bar" is another option. correlation_target: str, Defaults to None the target column for correlation calculations. correlation_threshold: float, Defaults to -1 the threshold for computing correlation heatmap elements. correlation_methods: str, Defaults to "pearson" the way to compute correlations, other options are "cramers v" and "correlation ratio" """ assert self.ds or ds, "Expecting input for ds or file" plot_ds = None if ds: plot_ds = ds else: plot_ds = self.ds frac = kwargs.get("frac", 1) frac = deprecate_default_value( frac, None, 1, f"<code>frac=None</code> is deprecated. Use <code>frac=1.0</code> instead.", FutureWarning, ) force_recompute = kwargs.get("force_recompute", False) nan_threshold = kwargs.get("nan_threshold", 0.8) cts_cts = pd.DataFrame() cat_cat = pd.DataFrame() cat_cts = pd.DataFrame() correlation_list = plot_ds.corr( frac=frac, nan_threshold=nan_threshold, force_recompute=force_recompute, correlation_methods=correlation_methods, ) correlation_methods = _validate_correlation_methods(correlation_methods) for method in correlation_methods: if method == "pearson": cts_cts = plot_ds._pearson elif method == "cramers v": cat_cat = plot_ds._cramers_v elif method == "correlation ratio": cat_cts = plot_ds._correlation_ratio else: raise ValueError(f"This {method} is not supported.") # generate the msg cts_cts_msg = _generate_msg( "pearson", correlation_methods, correlation_target, "continuous" ) cat_cat_msg = _generate_msg( "cramers v", correlation_methods, correlation_target, "categorical" ) cat_cts_msg = _generate_msg( "correlation ratio", correlation_methods, correlation_target, "continuous and categorical", ) panel_items = [ {"data": cts_cts, "title": "Continuous vs Continuous", "msg": cts_cts_msg}, {"data": cat_cat, "title": "Category vs Category", "msg": cat_cat_msg}, {"data": cat_cts, "title": "Category vs Continuous", "msg": cat_cts_msg}, ] if self.debug(): print("-" * 50, "Continuous vs Continuous", "-" * 50) print(cts_cts.head()) print("-" * 50, "Category vs Category", "-" * 50) print(cat_cat.head()) print("-" * 50, "Category vs Continuous", "-" * 50) print(cat_cts.head()) if plot_type == "heatmap": tabs = [ self.generate_heatmap( item["data"], item["title"], item["msg"], correlation_threshold ) for item in panel_items ] elif plot_type == "bar": tabs = [ self.generate_target_heatmap( item["data"], item["title"], correlation_target, item["msg"], correlation_threshold, ) for item in panel_items ] else: raise ValueError("Only supported plot types are heatmap and bar") bokeh_tabs = bokeh.models.Tabs(tabs=tabs) from bokeh.io import show show(bokeh_tabs)
def _generate_msg( method: str, correlation_methods: list, correlation_target: str, feature_type: str ): if method in correlation_methods or "all" in correlation_methods: if correlation_target is not None: msg = ( correlation_target + f"is not of the {feature_type} type or not enough of the {feature_type} type(s)" ) else: msg = f"Not enough of the {feature_type} type(s)" else: msg = f'Pass in "{method}" to show the plot' return msg def _has_categorical_features(ds): df_type = pd.DataFrame.from_dict(ds.feature_types).T return ( "categorical" in df_type["type"].values or "zipcode" in df_type["type"].values )
[docs] def plot_correlation_heatmap(ds=None, **kwargs) -> None: """ Plots a correlation heatmap. Parameters ---------- ds: Pandas Slice A data slice or file """ corr_heatmap_helper = BokehHeatMap(ds) corr_heatmap_helper.plot_correlation_heatmap(ds, **kwargs)
def _corr_filter(correlation_threshold, corr): if correlation_threshold >= 1 or correlation_threshold < -1: raise ValueError( "The correlation_threshold value should within the range from -1 to 1." ) new_corr = corr.copy(deep=True) np.fill_diagonal(new_corr.values, np.nan) new_corr = ( new_corr[new_corr >= correlation_threshold] .dropna(axis=0, how="all") .dropna(axis=1, how="all") ) # fill the diagonal with 1 to effectively shrink the correlation matrix np.fill_diagonal(new_corr.values, 1) return new_corr