#!/usr/bin/env python
# -*- coding: utf-8; -*-
# Copyright (c) 2020, 2022 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
from __future__ import print_function, absolute_import
import random
from collections import defaultdict
from math import pi
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors as mcolors
from ads.dataset.helper import _log_yscale_not_set
from ads.common.decorator.runtime_dependency import (
runtime_dependency,
OptionalDependency,
)
from ads.common.utils import _log_plot_high_cardinality_warning, MAX_DISPLAY_VALUES
from ads.type_discovery.latlon_detector import LatLonDetector
from ads.type_discovery.typed_feature import (
ContinuousTypedFeature,
DateTimeTypedFeature,
ConstantTypedFeature,
DiscreteTypedFeature,
CreditCardTypedFeature,
ZipcodeTypedFeature,
OrdinalTypedFeature,
CategoricalTypedFeature,
GISTypedFeature,
)
from ads.dataset import logger
class Plotting:
    """Pick and render a plot for one feature, or a pair of features, of a dataframe.

    The plot is chosen from a table keyed by the class names of the typed
    features of ``x`` and ``y`` (see ``_get_plot_method``). Each table entry is
    a ``(renderer, plot_method, plot_kwargs)`` tuple, where ``renderer`` is
    either ``_matplot`` or ``_generic_plot``.

    Parameters
    ----------
    df:
        The dataframe holding the columns to plot.
    feature_types:
        Mapping from column name to its typed-feature object.
    x:
        Name of the first column.
    y:
        Name of the optional second column.
    plot_type: str
        ``"infer"`` to take the first candidate, otherwise a (case-insensitive)
        prefix of the wanted plot method's ``__name__``.
    yscale:
        Optional y-axis scale forwarded to the count plots.
    """

    def __init__(self, df, feature_types, x, y=None, plot_type="infer", yscale=None):
        self.df = df
        self.feature_types = feature_types
        self.x = x
        self.y = y
        self.x_type = self.feature_types[self.x]
        self.y_type = self.feature_types[self.y] if self.y is not None else None
        self.plot_type = plot_type
        self.yscale = yscale

    def __repr__(self):
        """Render the plot as a side effect and return an empty string.

        When several plot types are possible, the recommended one and the
        full list are logged first.
        """
        choices = self._get_plot_method()
        if len(choices) > 1:
            logger.info(f"Recommended plot type is {choices[0][1].__name__}.")
            # A single %-style message: passing the pieces as extra positional
            # arguments without placeholders makes logging fail to format them.
            logger.info(
                "Available plot types are %s.",
                ", ".join(choice[1].__name__ for choice in choices),
            )
        self.show_in_notebook()
        return ""

    def select_best_plot(self):
        """
        Returns the best plot for a given dataset

        Returns
        -------
        tuple
            ``(renderer, plot_method, plot_kwargs)``. The first candidate is
            returned when ``plot_type`` is ``"infer"``; otherwise the first
            candidate whose plot method name starts with ``plot_type``.

        Raises
        ------
        ValueError
            When ``plot_type`` matches none of the candidates.
        NotImplementedError
            Propagated from ``_get_plot_method`` for unsupported feature types.
        """
        choices = self._get_plot_method()
        # Element 1 of each choice is the plot method; element 0 is the
        # renderer, which is the same for every candidate and so not useful
        # to show to the user.
        available = ", ".join(choice[1].__name__ for choice in choices)

        if len(choices) > 1:
            logger.info(
                "select_best_plot (%s, %s) called, possible plot types are %s",
                self.x_type.meta_data["type"],
                self.y_type.meta_data["type"] if self.y_type is not None else "",
                available,
            )

        if self.plot_type == "infer":
            return choices[0]

        wanted = self.plot_type.lower()
        for choice in choices:
            if choice[1].__name__.lower().startswith(wanted):
                return choice

        logger.info("invalid plot_type: {}".format(self.plot_type))
        raise ValueError(
            "plot_type: '%s' invalid, use one of: %s" % (self.plot_type, available)
        )

    def show_in_notebook(self, **kwargs):
        """
        Visualizes the dataset by plotting the distribution of a feature or relationship between two features.

        Parameters
        ----------
        figsize: tuple
            defines the size of the fig
        **kwargs
            Forwarded, together with the selected plot's own keyword
            arguments, to the selected renderer.
        """
        plotlib_type, plot_method, plot_kwargs = self.select_best_plot()
        plotlib_type(plot_method, **plot_kwargs, **kwargs)

    @staticmethod
    def _add_identity(ax, *line_args, **line_kwargs):
        """Draw a diagonal from the lower-left to the upper-right corner of ``ax``.

        The line is drawn in axes coordinates. Extra positional and keyword
        arguments are forwarded to ``ax.plot``; a dashed dark-grey line is used
        for whatever styling the caller does not specify.
        """
        # Only fill in a default when neither spelling of the alias was given,
        # because matplotlib rejects receiving both aliases at once.
        if "c" not in line_kwargs and "color" not in line_kwargs:
            line_kwargs["color"] = ".3"
        if "ls" not in line_kwargs and "linestyle" not in line_kwargs:
            line_kwargs["ls"] = "--"
        line_kwargs.setdefault("transform", ax.transAxes)
        ax.plot([0, 1], [0, 1], *line_args, **line_kwargs)

    @staticmethod
    def _build_plot_key(x_type, y_type=None):
        """Return the lookup key for a feature-type class or pair of classes."""
        if y_type is None:
            return x_type.__name__
        return x_type.__name__ + "," + y_type.__name__

    @staticmethod
    def _top_categories(x, data, limit=10):
        """Return the index of the ``limit`` most frequent values of ``data[x]``.

        A high-cardinality warning is logged when values had to be dropped.
        """
        counts = data[x].value_counts()
        if len(counts) > limit:
            _log_plot_high_cardinality_warning(x, len(counts))
            return counts.iloc[:limit].index
        return counts.index

    @staticmethod
    def _ordered_ordinal_values(x, data):
        """Return the distinct values of ``data[x]`` in ascending order.

        At most ``MAX_DISPLAY_VALUES`` values are returned; a high-cardinality
        warning is logged when values had to be dropped.
        """
        # dict.fromkeys de-duplicates in one pass while keeping sorted order.
        distinct = list(dict.fromkeys(np.sort(np.array(data[x].values))))
        if len(distinct) > MAX_DISPLAY_VALUES:
            _log_plot_high_cardinality_warning(x, len(distinct))
            return distinct[:MAX_DISPLAY_VALUES]
        return distinct

    @staticmethod
    def _integer_bins(values, bins):
        """Cut ``values`` into ``bins`` intervals whose edges are truncated to ints."""
        binned = pd.cut(values, bins, precision=0)
        return binned.apply(lambda k: pd.Interval(int(k.left), int(k.right)))

    @staticmethod
    @runtime_dependency(module="scipy", install_from=OptionalDependency.VIZ)
    def _gaussian_heatmap(x, y, data, s=10, edgecolor="white", cmap=plt.cm.jet):
        """
        Generate a scatter plot and assign a color to each data point based on the local density (gaussian kernel) of
        points.

        Parameters
        ----------
        x: str
            name of the feature
        y: str
            name of the feature
        data: object
            The dataframe that contains x and y
        s: int
            area of each marker
        edgecolor: str
            edge color of each point. string value, e.g. 'blue'
        cmap: object
            color map for the heatmap

        Returns
        -------
        The colorbar attached to the density-colored scatter plot, or, when the
        density estimate cannot be computed, a plain ``plt.scatter`` of the
        two columns.
        """
        # Import the submodule explicitly so that ``scipy.stats`` resolves even
        # if importing ``scipy`` alone does not expose it.
        import scipy.stats

        _x = np.array(data[x])
        _y = np.array(data[y])
        try:
            xy = np.vstack([_x, _y])
            z = scipy.stats.gaussian_kde(xy)(xy)
        except Exception:
            # Best effort: fall back to an uncolored scatter plot instead of
            # failing the whole visualization.
            logger.debug("Density estimate failed; drawing a plain scatter plot.")
            return plt.scatter(_x, _y)

        sc = plt.scatter(_x, _y, c=z, s=s, edgecolor=edgecolor, cmap=cmap)
        plt.xlabel(x)
        plt.ylabel(y)
        return plt.colorbar(sc)

    @staticmethod
    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _categorical_vs_continuous_violin_plot(x, y, data):
        """Violin plot of ``y`` per category of ``x``, most frequent categories first."""
        idxes = Plotting._top_categories(x, data)
        seaborn.violinplot(x=x, y=y, data=data[data[x].isin(idxes)], order=idxes)

    @staticmethod
    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _ordinal_vs_continuous_violin_plot(x, y, data):
        """Violin plot of ``y`` per value of ``x``, with ``x`` in ascending order."""
        idxes = Plotting._ordered_ordinal_values(x, data)
        seaborn.violinplot(x=x, y=y, data=data[data[x].isin(idxes)], order=idxes)

    @staticmethod
    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _categorical_vs_continuous_horizontal_violin_plot(x, y, data):
        """Horizontal violin plot of ``y`` per category of ``x``."""
        idxes = Plotting._top_categories(x, data)
        seaborn.violinplot(
            x=y, y=x, data=data[data[x].isin(idxes)], order=idxes, orient="h"
        )

    @staticmethod
    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _ordinal_vs_continuous_horizontal_violin_plot(x, y, data):
        """Horizontal violin plot of ``y`` per value of ``x``, ``x`` ascending."""
        idxes = Plotting._ordered_ordinal_values(x, data)
        seaborn.violinplot(
            x=y, y=x, data=data[data[x].isin(idxes)], order=idxes, orient="h"
        )

    @staticmethod
    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _categorical_vs_continuous_box_plot(x, y, data):
        """Box plot of ``y`` per category of ``x``, most frequent categories first."""
        idxes = Plotting._top_categories(x, data)
        seaborn.boxplot(x=x, y=y, data=data[data[x].isin(idxes)], order=idxes)

    @staticmethod
    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _ordinal_vs_continuous_box_plot(x, y, data):
        """Box plot of ``y`` per value of ``x``, with ``x`` in ascending order."""
        idxes = Plotting._ordered_ordinal_values(x, data)
        seaborn.boxplot(x=x, y=y, data=data[data[x].isin(idxes)], order=idxes)

    @staticmethod
    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _ordinal_vs_continuous_horizontal_box_plot(x, y, data):
        """Horizontal box plot of ``y`` per value of ``x``, ``x`` ascending."""
        idxes = Plotting._ordered_ordinal_values(x, data)
        seaborn.boxplot(
            x=y, y=x, data=data[data[x].isin(idxes)], order=idxes, orient="h"
        )

    @staticmethod
    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _count_plot(x, hue, data, yscale=None):
        """Count plot of ``x`` split by ``hue``.

        Only the five most frequent ``hue`` values keep their own color; the
        rest are merged into ``"all_other_categories"``. When ``x`` has more
        than ten distinct values it is binned into ten integer-edged intervals.
        """
        if not yscale:
            _log_yscale_not_set()

        # Work on a copy because both columns may be rewritten below.
        data = data.copy()
        data[hue] = data[hue].astype("object")

        hue_counts = data[hue].value_counts()
        if len(hue_counts) > 5:
            top_categoricals = hue_counts.iloc[:5].index
            data[hue] = np.where(
                data[hue].isin(top_categoricals), data[hue], "all_other_categories"
            )
            cat_index = data[hue].value_counts().index
        else:
            cat_index = hue_counts.index

        if len(data[x].value_counts()) > 10:
            data[x] = Plotting._integer_bins(data[x], 10)

        g = seaborn.countplot(x=x, hue=hue, data=data, hue_order=cat_index)
        if yscale:
            g.set_yscale(yscale)

    @staticmethod
    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _ordinal_count_plot(x, data, yscale=None):
        """Count plot of ``x``; binned into 20 intervals above 20 distinct values."""
        if not yscale:
            _log_yscale_not_set()

        if len(data[x].value_counts()) > 20:
            intervals = Plotting._integer_bins(data[x], 20)
        else:
            intervals = data[x]

        g = seaborn.countplot(x=intervals, color="#1f77b4")
        if yscale:
            g.set_yscale(yscale)

    @staticmethod
    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _ordinal_vs_constant_count_plot(x, hue, data, yscale=None):
        """Count plot of ``x`` binned into 20 intervals, colored by ``hue``."""
        if not yscale:
            _log_yscale_not_set()

        intervals = Plotting._integer_bins(data[x], 20)
        g = seaborn.countplot(x=intervals, hue=hue, data=data)
        if yscale:
            g.set_yscale(yscale)

    @staticmethod
    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _single_column_count_plot(x, data, yscale=None):
        """Count plot of the 24 most frequent values of ``x``."""
        if not yscale:
            _log_yscale_not_set()
        order = data[x].value_counts().iloc[:24].index
        g = seaborn.countplot(x=x, data=data, order=order)
        if yscale:
            g.set_yscale(yscale)

    @staticmethod
    @runtime_dependency(module="IPython", install_from=OptionalDependency.NOTEBOOK)
    @runtime_dependency(module="folium", install_from=OptionalDependency.VIZ)
    def _folium_map(x, data):
        """Display a folium heat map of the coordinates extracted from ``data[x]``."""
        import folium.plugins
        from IPython.display import display

        df = LatLonDetector.extract_x_y(data[x])
        lat_min, lat_max = df.Y.min(), df.Y.max()
        long_min, long_max = df.X.min(), df.X.max()

        # NOTE(review): confirm the "Stamen Terrain" tile set is still available
        # with the folium version in use.
        m = folium.Map(tiles="Stamen Terrain", zoom_control=False)
        folium.plugins.HeatMap(df[["Y", "X"]]).add_to(m)
        m.fit_bounds([[lat_min, long_min], [lat_max, long_max]])
        display(m)

    @staticmethod
    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _single_pdf(x, y, data):
        """Shaded kernel density estimate of ``data[x]``; ``y`` is not used."""
        # NOTE(review): `shade`/`shade_lowest` are older seaborn spellings;
        # confirm against the supported seaborn versions.
        seaborn.kdeplot(data[x], shade=True, shade_lowest=False)

    @staticmethod
    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _multiple_pdf(x, y, data):
        """One shaded density curve of ``x`` per distinct value of ``y``.

        Each curve gets a color picked at random from matplotlib's hex-coded
        named colors, so two curves may end up with the same color.
        """
        palette = dict(mcolors.BASE_COLORS, **mcolors.CSS4_COLORS)
        hues = [
            value
            for value in palette.values()
            if isinstance(value, str) and value.startswith("#")
        ]
        for cat in data[y].unique():
            s = data.loc[data[y] == cat][x]
            color = random.choice(hues)
            # NOTE(review): `shade`/`shade_lowest` are older seaborn spellings;
            # confirm against the supported seaborn versions.
            seaborn.kdeplot(s, color=color, shade=True, shade_lowest=False, label=cat)
        plt.xlabel(x)
        plt.ylabel(y)

    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _matplot(self, plot_method, figsize=(4, 3), **kwargs):
        """Render ``plot_method`` on a new matplotlib figure with title and labels.

        Parameters
        ----------
        plot_method: callable
            Called as ``plot_method(**kwargs, data=self.df)``.
        figsize: tuple
            Size of the created figure.
        **kwargs
            Keyword arguments for ``plot_method``.
        """
        # Newer matplotlib releases only ship the seaborn styles under a
        # "seaborn-v0_8-" prefix, so fall back when the old name is missing.
        style_name = "seaborn-white"
        if style_name not in plt.style.available:
            style_name = "seaborn-v0_8-white"
        plt.style.use(style_name)
        plt.rc("xtick", labelsize="x-small")
        plt.rc("ytick", labelsize="x-small")
        plt.rc("font", size=8)
        plt.rc("figure", dpi=144)
        fig = plt.figure(figsize=figsize)

        # generate a title for the plot
        text = '{}, "{}" ({})'.format(
            plot_method.__name__.upper(), self.x, self.feature_types[self.x].type
        )
        if self.y is not None:
            text = '{} vs "{}" ({})'.format(
                text, self.y, self.feature_types[self.y].type
            )
        plt.title(text, y=1.08)
        plt.grid(linestyle="dotted")

        # Draw a dashed diagonal indicating equality for scatter-style plots
        # when the span of x equals the span of y.
        is_scatter = (
            plot_method is plt.scatter or plot_method is Plotting._gaussian_heatmap
        )
        if is_scatter and (
            np.ptp(self.df[kwargs["x"]].values) == np.ptp(self.df[kwargs["y"]].values)
        ):
            Plotting._add_identity(fig.axes[0], color="grey", ls="--")

        plot_method(**kwargs, data=self.df)

        # set x and y axis label
        if self.y is not None:
            plt.ylabel(self.y)
        if self.x is not None:
            plt.xlabel(self.x)

        # rename the y-axis label and x-axis label when "count" is the y-axis label
        if self.y == "count":
            plt.xlabel("Column: {} values ".format(self.x))
            plt.ylabel("instance count")

        # add y-axis label as "count" when plot type is hist
        if plot_method is plt.hist:
            plt.ylabel("instance count")
            values = self.df[kwargs["x"]].values
            if len(values):
                lowest, highest = min(values), max(values)
                # The tick step is a 30th of the value span. Skip custom ticks
                # when the span is empty (constant column), since a zero step
                # cannot produce a tick range.
                if highest > lowest:
                    plt.xticks(
                        np.arange(lowest, highest + 1, (highest - lowest) / 30)
                    )

        # override y-axis label as "count" when plot type is _count_plot or countplot
        if plot_method is Plotting._count_plot or plot_method is seaborn.countplot:
            plt.ylabel("count")
        plt.xticks(rotation=90)

    def _generic_plot(self, plot_method, **kwargs):
        """Call ``plot_method`` with ``kwargs`` and the dataframe, without any styling."""
        plot_method(**kwargs, data=self.df)

    @runtime_dependency(module="seaborn", install_from=OptionalDependency.VIZ)
    def _get_plot_method(self):
        """Return the candidate plots for the feature types of ``x`` and ``y``.

        Returns
        -------
        list of tuple
            Non-empty list of ``(renderer, plot_method, plot_kwargs)`` tuples,
            the first being the recommended plot.

        Raises
        ------
        NotImplementedError
            When no plot is registered for the feature type (combination).
        """
        key = Plotting._build_plot_key
        xy = {"x": self.x, "y": self.y}
        yx = {"x": self.y, "y": self.x}
        marker_area = pi / 10 * (matplotlib.rcParams["lines.markersize"] ** 2)
        thin_white_edge = {"edgecolor": "white", "linewidths": 0.1}

        #
        # combos contains a dictionary with the key being a composite of the x and y types, the value will
        # always be a list, possibly an empty list, indicating no match for combination
        #
        combos = defaultdict(list)

        combos[key(CategoricalTypedFeature, ContinuousTypedFeature)] = [
            (self._matplot, Plotting._categorical_vs_continuous_violin_plot, dict(xy)),
            (self._matplot, Plotting._categorical_vs_continuous_box_plot, dict(xy)),
        ]
        combos[key(OrdinalTypedFeature, ContinuousTypedFeature)] = [
            (self._matplot, Plotting._ordinal_vs_continuous_violin_plot, dict(xy)),
            (self._matplot, Plotting._ordinal_vs_continuous_box_plot, dict(xy)),
        ]
        # x and y are swapped here: the plot helpers expect the ordinal column
        # as their first argument and draw it on the vertical axis.
        combos[key(ContinuousTypedFeature, OrdinalTypedFeature)] = [
            (
                self._matplot,
                Plotting._ordinal_vs_continuous_horizontal_violin_plot,
                dict(yx),
            ),
            (
                self._matplot,
                Plotting._ordinal_vs_continuous_horizontal_box_plot,
                dict(yx),
            ),
        ]
        combos[key(ContinuousTypedFeature, CategoricalTypedFeature)] = [
            (self._matplot, Plotting._multiple_pdf, dict(xy))
        ]
        combos[key(ConstantTypedFeature, ContinuousTypedFeature)] = [
            (self._matplot, seaborn.barplot, dict(xy))
        ]
        combos[key(ContinuousTypedFeature, ConstantTypedFeature)] = [
            (self._matplot, Plotting._single_pdf, dict(xy))
        ]
        combos[key(ConstantTypedFeature, DiscreteTypedFeature)] = [
            (self._matplot, seaborn.barplot, dict(xy))
        ]
        combos[key(DiscreteTypedFeature, ConstantTypedFeature)] = [
            (
                self._matplot,
                Plotting._ordinal_vs_constant_count_plot,
                {"x": self.x, "hue": self.y, "yscale": self.yscale},
            )
        ]
        combos[key(DateTimeTypedFeature, ContinuousTypedFeature)] = [
            (self._matplot, plt.scatter, dict(xy, s=marker_area, **thin_white_edge))
        ]
        combos[key(DateTimeTypedFeature, OrdinalTypedFeature)] = [
            (self._matplot, plt.scatter, dict(xy, **thin_white_edge))
        ]
        combos[key(ContinuousTypedFeature, ContinuousTypedFeature)] = [
            (self._matplot, Plotting._gaussian_heatmap, dict(xy, s=marker_area)),
            (self._matplot, plt.scatter, dict(xy, **thin_white_edge)),
        ]
        combos[key(OrdinalTypedFeature, OrdinalTypedFeature)] = [
            (
                self._matplot,
                seaborn.scatterplot,
                dict(xy, s=marker_area, **thin_white_edge),
            )
        ]
        combos[key(OrdinalTypedFeature, DiscreteTypedFeature)] = [
            (self._matplot, seaborn.countplot, {"x": self.x, "hue": self.y})
        ]
        combos[key(OrdinalTypedFeature, CategoricalTypedFeature)] = [
            (
                self._matplot,
                Plotting._count_plot,
                {"x": self.x, "hue": self.y, "yscale": self.yscale},
            )
        ]
        combos[key(CategoricalTypedFeature, OrdinalTypedFeature)] = [
            (
                self._matplot,
                Plotting._count_plot,
                {"x": self.y, "hue": self.x, "yscale": self.yscale},
            )
        ]
        combos[key(DiscreteTypedFeature, OrdinalTypedFeature)] = [
            (
                self._matplot,
                seaborn.countplot,
                {
                    "x": self.x,
                    "hue": self.y,
                    "order": self.df[self.x]
                    .value_counts(ascending=True)
                    .iloc[:10]
                    .index,
                },
            )
        ]
        combos[key(DiscreteTypedFeature, CategoricalTypedFeature)] = [
            (
                self._matplot,
                seaborn.countplot,
                {
                    "x": self.x,
                    "hue": self.y,
                    "order": self.df[self.x].value_counts().iloc[:10].index,
                },
            )
        ]
        combos[key(DateTimeTypedFeature, DateTimeTypedFeature)] = [
            (self._matplot, plt.scatter, dict(xy, s=marker_area, **thin_white_edge))
        ]
        combos[key(ContinuousTypedFeature, None)] = [
            (self._matplot, plt.hist, {"x": self.x})
        ]
        combos[key(CategoricalTypedFeature, None)] = [
            (
                self._matplot,
                Plotting._single_column_count_plot,
                {"x": self.x, "yscale": self.yscale},
            )
        ]
        combos[key(OrdinalTypedFeature, None)] = [
            (
                self._matplot,
                Plotting._ordinal_count_plot,
                {"x": self.x, "yscale": self.yscale},
            )
        ]
        combos[key(ConstantTypedFeature, None)] = [
            (self._matplot, seaborn.countplot, {"x": self.x})
        ]
        combos[key(DateTimeTypedFeature, None)] = [
            (self._matplot, plt.hist, {"x": self.x, "bins": 10, "color": "#1f77b4"})
        ]
        combos[key(GISTypedFeature, None)] = [
            (self._generic_plot, Plotting._folium_map, {"x": self.x})
        ]

        x_class = self.x_type.__class__
        y_class = None if self.y_type is None else self.y_type.__class__
        new_x_type = Plotting._change_type(self.x_type)
        new_y_type = Plotting._change_type(self.y_type)

        # Try the exact types first, then progressively more general ones.
        keys_to_check = [
            key(x_class, y_class),
            key(new_x_type, y_class),
            key(x_class, new_y_type),
            key(new_x_type, new_y_type),
        ]
        for candidate in keys_to_check:
            # .get() avoids inserting empty entries into the defaultdict.
            matches = combos.get(candidate)
            if matches:
                return matches

        if y_class is not None:
            raise NotImplementedError(
                "Plotting for the feature combination ({0} vs {1}) is not yet supported.".format(
                    self.x_type.meta_data["type"], self.y_type.meta_data["type"]
                )
            )
        raise NotImplementedError(
            "Plotting for feature type {0} is not supported".format(
                self.x_type.meta_data["type"]
            )
        )

    @staticmethod
    def _change_type(feature_type):
        """Return the more general class used as a fallback plot-lookup key.

        Discrete, credit-card, zipcode and ordinal features map to
        ``DiscreteTypedFeature``; any other feature maps to its own class and
        ``None`` maps to ``None``.
        """
        if feature_type is None:
            return None
        discrete_like = (
            DiscreteTypedFeature,
            CreditCardTypedFeature,
            ZipcodeTypedFeature,
            OrdinalTypedFeature,
        )
        if isinstance(feature_type, discrete_like):
            return DiscreteTypedFeature
        return feature_type.__class__