# Source code for plot_utils.helper

# -*- coding: utf-8 -*-

import collections
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pylab as pl
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

#%%----------------------------------------------------------------------------
# Tuples of types intended to be used as the second argument of isinstance().
_array_like = (list, np.ndarray, pd.Series)  # containers accepted wherever an "array-like" argument is expected
_scalar_like = (int, float, np.number)  # Python and NumPy numbers (note: bool is a subclass of int)

#%%----------------------------------------------------------------------------
class LengthError(Exception):
    '''
    Raised when an argument has an unacceptable number of elements, or when
    two arguments that must have the same length do not.
    '''

class DimensionError(Exception):
    '''
    Exception type for dimension-related problems.

    NOTE(review): nothing in the visible part of this module raises it;
    presumably it signals an array with the wrong number of dimensions --
    confirm against the callers.
    '''

#%%============================================================================
def assert_type(something, desired_type, name='something'):
    '''
    Check that ``something`` is an instance of ``desired_type``.

    Parameters
    ----------
    something :
        Any Python object.
    desired_type : type or tuple<type>
        A valid Python type, such as float, or a tuple of Python types, such
        as (float, int). It is passed as-is to ``isinstance()``.
    name : str
        The name of ``something`` to show in the error message (if applicable).

    Raises
    ------
    TypeError
        If ``something`` is not an instance of ``desired_type``.
    '''
    if isinstance(something, desired_type):
        return  # nothing to report

    actual_type = type(something)
    raise TypeError(
        '"%s" must be %s, rather than %s.' % (name, desired_type, actual_type)
    )

#%%============================================================================
def assert_element_type(some_iterable, desired_element_type, name='some_iterable'):
    '''
    Assert that all elements of ``some_iterable`` are of ``desired_element_type``.

    When ``desired_element_type`` is a tuple of types, the check passes only
    if *all* elements are instances of one and the same member of the tuple
    (i.e., the elements must be homogeneous with respect to the listed types).

    Parameters
    ----------
    some_iterable : Python iterable
        An iterable object, such as a list, numpy array. One-shot iterators
        (e.g., generators) are accepted as well; they are consumed by the
        check.
    desired_element_type : type or tuple<type>
        Desired element type.
    name : str
        The name of ``some_iterable`` to show in the error message (if
        applicable).

    Raises
    ------
    TypeError
        If ``some_iterable`` is not iterable, if ``desired_element_type`` is
        neither a type nor a tuple of types, or if the elements do not
        satisfy the type requirement.
    '''
    assert_type(some_iterable, collections.abc.Iterable, name=name)

    if isinstance(desired_element_type, type):
        candidate_types = (desired_element_type,)
    elif isinstance(desired_element_type, tuple):
        candidate_types = desired_element_type
    else:
        raise TypeError('`desired_element_type` must be a type or a tuple of types.')
    # END IF-ELSE

    # Materialize the elements once: a one-shot iterator would otherwise be
    # exhausted by the first candidate type, making every later candidate
    # pass vacuously on an empty sequence.
    elements = list(some_iterable)

    homogeneous = any(
        all(isinstance(element, this_type) for element in elements)
        for this_type in candidate_types
    )
    if not homogeneous:
        msg = 'All elements of "%s" must be %s.' % (name, desired_element_type)
        raise TypeError(msg)
    # END IF

#%%============================================================================
def _process_fig_ax_objects(fig, ax, figsize=None, dpi=None, ax_proj=None):
    '''
    Resolve the figure and axes objects to plot on.

    A missing (``None``) ``fig`` and/or ``ax`` is created here; objects that
    are passed in are returned unchanged.

    Parameters
    ----------
    fig : matplotlib.figure.Figure or ``None``
        Figure object. If None, a new figure is created with ``figsize`` and
        ``dpi``. Otherwise the given figure is made the current figure.
    ax : matplotlib.axes._subplots.AxesSubplot or ``None``
        Axes object. If None, new axes are created (with projection
        ``ax_proj``) via ``plt.axes``.
    figsize: (float, float)
        Figure size in inches, as a tuple of two numbers. Only used when
        ``fig`` is ``None``.
    dpi : float
        Figure resolution. Only used when ``fig`` is ``None``.
    ax_proj : {None, 'aitoff', 'hammer', 'lambert', 'mollweide', 'polar', 'rectilinear', str}
        The projection type of the axes. Only used when ``ax`` is ``None``.
        The default None results in a 'rectilinear' projection.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure object being created or being passed into this function.
    ax : matplotlib.axes._subplots.AxesSubplot
        The axes object being created or being passed into this function.
    '''
    if fig is not None:
        pl.figure(fig.number)  # activate the caller's figure
    else:
        fig = pl.figure(figsize=figsize,dpi=dpi)  # nothing supplied: start a new figure

    if ax is None:
        ax = plt.axes(projection=ax_proj)  # new axes when none were supplied

    return fig, ax

#%%============================================================================
def _upcast_dtype(x):
    '''
    Cast dtype of x (a pandas Series) as string or float in-place.

    Parameter
    ---------
    x : pandas.Series
        An array whose elements are to be upcast.

    Returns
    -------
    x : pandas.Series
        The array whose elements are now upcast.
    '''
    assert(type(x) == pd.Series)

    if x.dtype.name in ['category', 'bool', 'datetime64[ns]', 'datetime64[ns, tz]']:
        x = x.astype(str)

    if x.dtype.name == 'timedelta[ns]':
        x = x.astype(float)

    return x

#%%============================================================================
def _find_axes_lim(data_limit, tick_base_unit, direction='upper'):
    '''
    Return a "whole" number to be used as the upper or lower limit of axes.

    For example, if the maximum x value of the data is 921.5, and you would
    like the upper x_limit to be a multiple of 50, then this function returns
    950.

    The lower limit is the largest multiple of ``tick_base_unit`` that does
    not exceed the data value; the upper limit is the smallest multiple that
    is strictly greater than the data value (so 950 with a base of 50 gives
    an upper limit of 1000). Both rules also hold for negative data values.

    Parameters
    ----------
    data_limit : float, int, list<float>, list<int>, tuple<float>, tuple<int>
        The upper and/or lower limit(s) of data. (1) If a tuple (or list) of
        two elements is provided, then the upper and lower axis limits are
        automatically determined. (The order of the two elements does not
        matter.) (2) If a float or an int is provided, then the axis limit is
        determined based on the ``direction`` provided. A numpy array of one
        or two elements is treated like a scalar or a list, respectively.
    tick_base_unit : float
        For example, if you want your axis limit(s) to be a multiple of 20
        (such as 80, 120, 2020, etc.), then use 20.
    direction : {'upper', 'lower'}
        The direction of the limit to be found. For example, if the maximum
        of the data is 127, and ``tick_base_unit`` is 50, then a ``direction``
        of lower yields a result of 100. This parameter is effective only when
        ``data_limit`` is a scalar (or holds a single element).

    Returns
    -------
    axes_lim : list<float> or float
        If ``data_limit`` is a list/tuple of length 2, return a list:
        [min_limit, max_limit] (always ordered no matter what the order of
        ``data_limit`` is). If ``data_limit`` is a scalar, return the axis
        limit according to ``direction``.

    Raises
    ------
    ValueError
        If ``data_limit`` is a scalar and ``direction`` is neither 'upper'
        nor 'lower'.
    LengthError
        If ``data_limit`` is an empty list/tuple, or holds more than two
        elements.
    TypeError
        If ``data_limit`` is of an unsupported type, or is an empty numpy
        array.
    '''
    def _n_whole_units(value):
        # Floor rather than truncate: int() rounds toward zero, which would
        # place the "lower" limit of negative data above the data itself.
        return int(np.floor(value / tick_base_unit))

    def _lower_limit(value):
        return tick_base_unit * _n_whole_units(value)

    def _upper_limit(value):
        return tick_base_unit * (_n_whole_units(value) + 1)

    if isinstance(data_limit, _scalar_like):
        if direction == 'upper':
            return _upper_limit(data_limit)
        if direction == 'lower':
            return _lower_limit(data_limit)
        raise ValueError(
            "`direction` should be either 'upper' or 'lower', rather than %r."
            % (direction,)
        )

    if isinstance(data_limit, (tuple, list)):
        n_elements = len(data_limit)
        if n_elements > 2:
            raise LengthError('Length of `data_limit` should be at most 2.')
        if n_elements == 2:
            return [_lower_limit(min(data_limit)), _upper_limit(max(data_limit))]
        if n_elements == 1:  # such as [2.14]
            return _find_axes_lim(data_limit[0], tick_base_unit, direction)
        # An empty sequence carries no limit at all; fail loudly rather than
        # silently returning None.
        raise LengthError('Length of `data_limit` should be at least 1.')

    if isinstance(data_limit, np.ndarray):
        flat = data_limit.flatten()  # convert np.array(2.5) into np.array([2.5])
        if flat.size == 1:
            return _find_axes_lim(flat[0], tick_base_unit, direction)
        if flat.size == 2:
            return _find_axes_lim(list(flat), tick_base_unit, direction)
        if flat.size >= 3:
            raise LengthError('Length of `data_limit` should be at most 2.')
        raise TypeError(
            '`data_limit` should be a scalar or a tuple/list of length 2.'
        )

    raise TypeError(
        '`data_limit` should be a scalar or a tuple/list of length 2.'
    )
#%%============================================================================
class _FixedOrderFormatter(mpl.ticker.ScalarFormatter):
    '''
    Scalar tick formatter that uses scientific notation with a fixed,
    user-chosen order of magnitude.

    (Reference: https://stackoverflow.com/a/3679918)

    Note: this class is not currently being used.
    '''
    def __init__(self, order_of_mag=0, useOffset=True, useMathText=True):
        # Store the requested exponent before the base class is initialized.
        self._order_of_mag = order_of_mag
        super().__init__(useOffset=useOffset, useMathText=useMathText)

    def _set_orderOfMagnitude(self, range):
        """Pin orderOfMagnitude to the stored value; ``range`` is ignored."""
        self.orderOfMagnitude = self._order_of_mag

#%%============================================================================
def _calc_bar_width(width):
    '''
    Calculate width (in points) of bar plot from figure width (in inches).

    The figure width is multiplied by a factor that depends on which range
    the width falls into.
    '''
    # (inclusive upper bound of figure width, multiplier);
    # these numbers are manually fine-tuned
    width_bands = ((7, 3.35), (9, 2.60), (10, 2.10))
    for upper_bound, multiplier in width_bands:
        if width <= upper_bound:
            return width * multiplier

    return width * 1.2  # figures wider than every band above

#%%============================================================================
def _get_ax_size(fig, ax, unit='inches'):
    '''
    Get size of axes within a figure, given fig and ax objects.

    Returns a ``(width, height)`` tuple. The values are in inches, unless
    ``unit`` is 'pixels', in which case they are multiplied by ``fig.dpi``.

    https://stackoverflow.com/questions/19306510/determine-matplotlib-axis-size-in-pixels
    '''
    window_extent = ax.get_window_extent()
    extent_in_inches = window_extent.transformed(fig.dpi_scale_trans.inverted())
    size = (extent_in_inches.width, extent_in_inches.height)

    if unit == 'pixels':
        size = tuple(dim * fig.dpi for dim in size)  # convert from inches to pixels

    return size

#%%============================================================================
def _calc_r2_score(y_true, y_pred):
    '''
    Calculate the coefficient of determination between two arrays.

    The score is ``1 - SS_res / SS_tot``, where ``SS_res`` is the sum of
    squared differences between ``y_true`` and ``y_pred``, and ``SS_tot`` is
    the sum of squared deviations of ``y_true`` from its own mean.

    The best possible value is 1.0. The result can be negative, because the
    model predicted value (``y_pred``) can be arbitrarily bad. A naive
    prediction that always equals the mean value of ``y_true`` yields a
    score of 0.

    The division is not guarded: if every value in ``y_true`` is the same,
    ``SS_tot`` is zero.

    Parameters
    ----------
    y_true : list, numpy.ndarray, or pandas.Series
        The "true values", or the dependent variable, or "y axis".
    y_pred : list, numpy.ndarray, or pandas.Series
        The "predicted values", or the independent variable, or "x axis".

    Returns
    -------
    r2_score : float
        The coefficient of determination.

    Raises
    ------
    TypeError
        If ``y_true`` or ``y_pred`` is not a list, numpy array, or pandas
        Series.
    LengthError
        If ``y_true`` and ``y_pred`` have different lengths.

    References
    ----------
    .. [1] `Wikipedia entry on the Coefficient of determination
       <https://en.wikipedia.org/wiki/Coefficient_of_determination>`_
    '''
    if not isinstance(y_true, _array_like):
        raise TypeError('`y_true` needs to be a list, numpy array, or pandas Series.')

    if not isinstance(y_pred, _array_like):
        raise TypeError('`y_pred` needs to be a list, numpy array, or pandas Series.')

    if len(y_true) != len(y_pred):
        raise LengthError('`y_true` and `y_pred` should have the same length.')

    predicted = np.array(y_pred)
    observed = np.array(y_true)

    total_sum_of_squares = np.sum((observed - np.mean(observed))**2.0)
    residual_sum_of_squares = np.sum((observed - predicted)**2.0)

    return 1 - residual_sum_of_squares / total_sum_of_squares

#%%============================================================================
def __axes_styling_helper(
        ax, vert, rot, data_names, n_datasets, data_ax_label,
        name_ax_label, title,
):
    '''
    Helper function. Used by _violin_plot_helper() and _multi_hist_helper().

    Turns on a dotted grid drawn below the plot elements, places one tick per
    data set (at positions 1, 2, ..., ``n_datasets``) along the "name axis",
    and applies the tick labels, axis labels, and title.

    Parameters
    ----------
    ax : matplotlib.axes._subplots.AxesSubplot
        Matplotlib axes object.
    vert : bool
        Whether to show the violins or the "base" of the histograms as
        vertical. If true, the "name axis" is the x axis; otherwise it is the
        y axis.
    rot : float
        The rotation (in degrees) of the data_names when shown as the tick
        labels. If ``vert`` is ``False``, ``rot`` has no effect.
    data_names : list<str>
        The names of each data set, to be shown as the axis tick label of
        each data set. They are applied as given.
    n_datasets : int
        Number of sets of data.
    data_ax_label : str
        The labels of the "data axis". ("Data axis" is the axis along which
        the data values are presented.) Skipped if empty or ``None``.
    name_ax_label : str
        The label of the "name axis". ("Name axis" is the axis along which
        different violins are presented.) Skipped if empty or ``None``.
    title : str
        The title of the plot. Skipped if empty or ``None``.

    Returns
    -------
    ax : matplotlib.axes._subplots.AxesSubplot
        Matplotlib axes object.
    '''
    ax.grid(ls=':')
    ax.set_axisbelow(True)

    unit_locator = mpl.ticker.MultipleLocator(base=1.0)
    tick_positions = np.arange(n_datasets) + 1

    if vert:
        ax.xaxis.set_major_locator(unit_locator)
        ax.set_xticks(tick_positions)
        if 0 <= rot < 30 or rot == 90:
            ha = 'center'
        else:
            ha = 'right'  # anchor slanted labels at their right end
        ax.set_xticklabels(data_names, rotation=rot, ha=ha)
        set_data_ax_label, set_name_ax_label = ax.set_ylabel, ax.set_xlabel
    else:
        ax.yaxis.set_major_locator(unit_locator)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(data_names)
        set_data_ax_label, set_name_ax_label = ax.set_xlabel, ax.set_ylabel

    if data_ax_label:
        set_data_ax_label(data_ax_label)

    if name_ax_label:
        set_name_ax_label(name_ax_label)

    if title:
        ax.set_title(title)

    return ax

#%%============================================================================
class _MidpointNormalize(Normalize):
    '''
    Auxiliary class definition: a ``Normalize`` that maps ``vmin``,
    ``midpoint`` and ``vmax`` to 0, 0.5 and 1 respectively, interpolating
    linearly in between.

    Copied from:
    https://stackoverflow.com/questions/20144529/shifted-colorbar-matplotlib/20146989#20146989
    '''
    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        # Masked values and all kinds of edge cases are ignored here, to keep
        # the implementation simple; ``clip`` is not used.
        anchor_values = [self.vmin, self.midpoint, self.vmax]
        anchor_fractions = [0, 0.5, 1]
        return np.ma.masked_array(np.interp(value, anchor_values, anchor_fractions))