# Source code for seeq.addons.correlation._heatmap

import numpy as np
import pandas as pd
import pickle
# A bug prevents the memoization library from correctly caching pandas.DataFrame arguments.
# Therefore, every function decorated with @cached must accept serialized DataFrames instead (pickle works well).
from memoization import cached
from ._config import _cache_max_items
from . import default_preprocessing_wrapper
from . import lags_coeffs
from ._heatmap_html import build_overlays_html, wrap_heatmap_html
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.transforms import Bbox
import seaborn as sns
import base64
from io import BytesIO


class _HeatmapHTML(str):
    """HTML markup of a rendered heatmap.

    This is a plain ``str`` subclass, so callers can save or embed it like any string. The
    ``_repr_html_`` hook (IPython's rich-display protocol) lets Jupyter render it as HTML when it
    is the value of a cell, without this module having to import IPython.
    """

    __slots__ = ()

    def _repr_html_(self):
        return str(self)


def heatmap(df, max_time_shift='auto', output_values='coeffs', output_type='plot', time_output_unit='auto',
            bypass_preprocessing=False):
    """
    Creates a heatmap of the cross-correlation coefficients between signals.

    The signals can be allowed to shift in time to find the maximum cross-correlation between
    signals. Alternatively, a heatmap of the time shifts that maximize the correlation of signals
    can be produced.

    Parameters
    ----------
    df: pandas.DataFrame
        A DataFrame that contains a set of signals as columns and date-time as the index.
    max_time_shift: {'auto', str, None}, default 'auto'
        Maximum time (e.g. '15s', or '1min') that the signals are allowed to slide in order to
        maximize cross-correlation. For times specified as a str, normal time units are accepted.
        If 'auto' is selected, a default maximum time shift is calculated based on the number of
        samples. If None, the raw signals are used and no time shifts are calculated.
    output_values: {'coeffs', 'time_shifts'}, default 'coeffs'
        Values to display in the heatmap: either the Pearson's coefficients or the time shifts
        applied to the signals to maximize cross-correlation.
    output_type: {'plot', 'table'}, default 'plot'
        The heatmap can be produced either as a plot (HTML with hover tooltips) or as a DataFrame.
    time_output_unit: {'auto', str}, default 'auto'
        Specifies the time unit used to display the time shifts. Valid units are the ones
        accepted by pd.Timedelta.
    bypass_preprocessing: bool, default False
        Whether the data pre-processing routine is bypassed. Setting it to True is not
        recommended unless the data has been pre-processed elsewhere.

    Returns
    -------
    One of the following:

    plot: str (HTML) or None
        For ``output_type='plot'``: the rendered heatmap as an HTML string. Jupyter displays it
        when it is the last expression of a cell. None is returned when there is nothing to plot.
    table: pandas.DataFrame
        For ``output_type='table'``: a DataFrame with either Pearson's coefficients or the time
        shifts of all signal pairs.

    Raises
    ------
    ValueError
        If ``time_output_unit`` is None, or if ``output_type``/``output_values`` is invalid.

    Examples
    --------
    Create a heatmap plot of the cross-correlation coefficients for the signals in a DataFrame,
    letting the maximum time shift be guessed automatically

    >>> seeq.addons.correlation.heatmap(df,
    ...                                 max_time_shift='auto',
    ...                                 output_values='coeffs',
    ...                                 output_type='plot')

    Create a heatmap plot of the cross-correlation coefficients for the signals in a DataFrame,
    specifying a maximum time shift between signals of 1 hour

    >>> seeq.addons.correlation.heatmap(df,
    ...                                 max_time_shift='1h',
    ...                                 output_values='coeffs',
    ...                                 output_type='plot')

    Create a table of the time shifts that maximize the cross-correlation of the signals in a
    DataFrame, specifying a maximum time shift between signals of 1 hour

    >>> seeq.addons.correlation.heatmap(df,
    ...                                 max_time_shift='1h',
    ...                                 output_values='time_shifts',
    ...                                 output_type='table')

    Create a table of the cross-correlation coefficients for the signals in a DataFrame using the
    raw data (no time shift allowed)

    >>> seeq.addons.correlation.heatmap(df,
    ...                                 max_time_shift=None,
    ...                                 output_values='coeffs',
    ...                                 output_type='table')
    """
    if time_output_unit is None:
        raise ValueError('time_output_unit cannot be None. Please specify a valid pd.Timedelta unit')

    heatmap_object = _heatmap(df, max_time_shift=max_time_shift, output_values=output_values,
                              output_type=output_type, time_output_unit=time_output_unit,
                              bypass_preprocessing=bypass_preprocessing)

    if output_type == 'plot':
        # The plot path of _heatmap yields the HTML produced by _heatmap_plot (or None when the frame
        # is empty). The previous `.show(config=...)` call was a leftover Plotly figure API and
        # raised AttributeError on a str.
        if heatmap_object is None:
            return None
        return _HeatmapHTML(heatmap_object)
    return heatmap_object
def _heatmap(df, max_time_shift='auto', output_values='coeffs', output_type='plot', time_output_unit='auto',
             bypass_preprocessing=False):
    """Compute cross-correlation data for ``df`` and return it as a table or an HTML heatmap.

    Parameters
    ----------
    df: pandas.DataFrame
        Signals as columns, date-time as the index.
    max_time_shift, time_output_unit
        Forwarded to ``lags_coeffs``.
    output_values: {'coeffs', 'time_shifts'}
        Which matrix to output or plot.
    output_type: {'plot', 'table'}
        ``'table'`` returns a DataFrame; ``'plot'`` returns the result of ``_heatmap_plot``.
    bypass_preprocessing: bool
        Forwarded to ``default_preprocessing_wrapper`` as ``bypass_processing``.

    Returns
    -------
    pandas.DataFrame, str or None
        For ``'table'``: the coefficients DataFrame, or the time-shift DataFrame whose columns are
        suffixed with ``" (<time_unit>)"``. For ``'plot'``: the HTML returned by ``_heatmap_plot``
        (None when there is nothing to plot).

    Raises
    ------
    ValueError
        If ``output_type`` or ``output_values`` is not one of the accepted values.
    """
    # Validate the cheap arguments before running the (potentially slow) preprocessing and
    # cross-correlation steps. The messages now name the argument that is actually wrong.
    if output_type not in ('plot', 'table'):
        raise ValueError('Invalid output_type: {}'.format(output_type))
    if output_values not in ('coeffs', 'time_shifts'):
        raise ValueError('Invalid output_values: {}'.format(output_values))

    # NOTE(review): an earlier comment said outliers should not be removed here and mentioned an
    # increased outlier_sensitivity, but no such argument is passed. Confirm that the defaults of
    # default_preprocessing_wrapper do what is intended.
    df = default_preprocessing_wrapper(df, consecutivenans=0.04, percent_nan=0.0,
                                       bypass_processing=bypass_preprocessing)
    lags, coeffs, sampling_time, time_unit, _maxlags = lags_coeffs(df, max_time_shift, time_output_unit)
    lags_to_time = lags * sampling_time

    coeffs_df = pd.DataFrame(data=coeffs, columns=df.columns, index=df.columns)
    time_shifts_df = pd.DataFrame(data=lags_to_time, columns=df.columns, index=df.columns)

    if output_type == 'table':
        if output_values == 'coeffs':
            return coeffs_df
        time_shifts_df.columns = [f"{x} ({time_unit})" for x in time_shifts_df.columns]
        return time_shifts_df

    # DataFrames are pickled because the @cached plotting function cannot hash them directly.
    if output_values == 'coeffs':
        return _heatmap_plot(pickle.dumps(coeffs_df), pickle.dumps(time_shifts_df),
                             time_unit=time_unit, lags_plot=False)
    return _heatmap_plot(pickle.dumps(time_shifts_df), pickle.dumps(coeffs_df),
                         time_unit=time_unit, lags_plot=True)


def rename_signals(signal_list, max_label_chars):
    """Shorten signal names longer than ``max_label_chars`` so they fit as axis labels.

    A long name keeps its first ``int(max_label_chars / 2)`` characters, then ``"..."``, then its
    last ``int(max_label_chars / 2) - 3`` characters (no tail when that count is not positive).
    If a truncated name collides with an earlier label, its first two characters are dropped and
    ``_<position>`` is appended.

    Parameters
    ----------
    signal_list: list of str
        Original signal names.
    max_label_chars: int
        Maximum label length before truncation kicks in.

    Returns
    -------
    list of str
        ``signal_list`` itself when no name needs truncating (including an empty list), otherwise
        a new list of display names in the same order.
    """
    # `default=0` lets an empty list pass through instead of raising on max() of nothing.
    if max((len(name) for name in signal_list), default=0) <= max_label_chars:
        return signal_list

    head_len = int(max_label_chars / 2)
    # Clamp at 0: `name[-0:]` would return the whole name, and a negative count would flip the
    # slice into a head slice.
    tail_len = max(head_len - 3, 0)

    new_names = []
    for i, name in enumerate(signal_list):
        if len(name) <= max_label_chars:
            new_names.append(name)
            continue
        tail = name[len(name) - tail_len:] if tail_len else ""
        truncated_name = name[:head_len] + "..." + tail
        if truncated_name in new_names:
            unique_name = f"{truncated_name[2:]}_{i}"
        else:
            unique_name = truncated_name
        new_names.append(unique_name)
    return new_names


def _compute_color_limits(primary_array: np.ndarray, lags_plot: bool):
    """Return ``(limit, cmap, center)`` for a symmetric diverging colour scale.

    Correlation coefficients use a fixed limit of 1.0. Time shifts use the largest finite
    absolute value. That limit falls back to 1.0 when every value is zero or NaN (e.g. no time
    shift allowed, or a fully masked matrix), which would otherwise give a zero-width or NaN colour
    range.
    """
    cmap = 'RdBu'
    center = 0
    limit = 1.0
    if lags_plot:
        magnitudes = np.abs(np.asarray(primary_array, dtype=float)).ravel()
        magnitudes = magnitudes[np.isfinite(magnitudes)]
        if magnitudes.size and magnitudes.max() > 0:
            limit = float(magnitudes.max())
    return limit, cmap, center


def _prepare_frames(primary_df_serialized, secondary_df_serialized, boolean_df, max_label_chars):
    """Deserialize the heatmap frames and relabel them with display-friendly signal names.

    Returns
    -------
    plot_df: pandas.DataFrame
        Frame that is drawn, masked with ``boolean_df`` when a DataFrame mask is given.
    primary_array: numpy.ndarray
        Values of ``plot_df`` (after masking), used for the colour limits.
    primary_vals, secondary_vals: pandas.DataFrame
        Unmasked values for the tooltips. ``secondary_vals`` is reordered to the primary frame's
        index/columns.
    new_names: list of str
        Display labels from ``rename_signals``.
    """
    # Both payloads are produced internally with pickle.dumps (see _heatmap). Never pass
    # untrusted bytes here: pickle.loads can execute arbitrary code.
    primary_df = pickle.loads(primary_df_serialized)
    secondary_df = pickle.loads(secondary_df_serialized)
    new_names = rename_signals(list(primary_df.columns), max_label_chars)

    if isinstance(boolean_df, pd.DataFrame):
        plot_df = primary_df[boolean_df].copy()
    else:
        plot_df = primary_df.copy()
    primary_array = plot_df.values

    def _relabel(frame):
        # Apply the display names to both axes of a (square) signal-by-signal frame.
        frame.index = new_names
        frame.columns = new_names
        return frame

    plot_df = _relabel(plot_df)
    primary_vals = _relabel(primary_df.copy())
    secondary_vals = _relabel(secondary_df.loc[primary_df.index, primary_df.columns].copy())
    return plot_df, primary_array, primary_vals, secondary_vals, new_names


def _draw_heatmap(plot_df: pd.DataFrame, limit, cmap, center):
    """Create a square figure and draw ``plot_df`` as a seaborn heatmap without a colorbar.

    The figure side is ``0.35`` inch per signal, clamped to the range [4, 8] inches.
    Returns ``(fig, ax)``.
    """
    num_signals = len(plot_df)
    base_size = max(4, min(8, num_signals * 0.35))
    fig, ax = plt.subplots(figsize=(base_size, base_size), facecolor='white')
    sns.heatmap(
        plot_df,
        annot=False,
        fmt='.2f',
        cmap=cmap,
        center=center,
        vmin=-limit,
        vmax=limit,
        square=True,
        linewidths=0.5,
        cbar=False,  # the colorbar is added separately so its size/position can be controlled
        ax=ax,
    )
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    ax.tick_params(axis='x', labelsize=8, pad=6)
    ax.tick_params(axis='y', labelsize=8, pad=4)
    return fig, ax


def _add_colorbar(fig, ax, lags_plot: bool, time_unit=None):
    """Attach a colorbar to the right of ``ax`` and return it.

    Parameters
    ----------
    lags_plot: bool
        True for a time-shift heatmap, False for correlation coefficients.
    time_unit: str, optional
        Unit shown in the time-shift label. When omitted, the label reads "Time (minutes)".
    """
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="4%", pad=0.25)
    # Assumes the heatmap mesh drawn by seaborn is the first collection on the axes.
    cbar = fig.colorbar(ax.collections[0], cax=cax)
    if lags_plot:
        label = f"Time ({time_unit})" if time_unit else "Time (minutes)"
    else:
        label = "Correlation Coefficient"
    cbar.set_label(label, rotation=270, labelpad=12)
    cbar.ax.yaxis.label.set_size(10)
    cbar.ax.yaxis.label.set_weight('bold')
    return cbar


def _compute_geometry_percentages(fig, ax, pad_inches: float = 0.02):
    """Locate the heatmap axes inside the cropped exported image, as CSS percentages.

    Returns
    -------
    tuple
        ``(tight_padded, orig_width_px, ax_left_pct, ax_top_pct, ax_width_pct, ax_height_pct)``,
        where ``tight_padded`` is the padded tight bounding box in inches and ``orig_width_px``
        is its width at the figure's DPI.
    """
    # A draw is required so the renderer knows the final text extents.
    fig.canvas.draw()
    renderer = fig.canvas.get_renderer()
    tight = fig.get_tightbbox(renderer)  # inches
    x0, y0, x1, y1 = tight.extents
    tight_padded = Bbox.from_extents(x0 - pad_inches, y0 - pad_inches, x1 + pad_inches, y1 + pad_inches)

    # Keep the original display width so the visual size is stable in HTML.
    orig_width_px = int(round(tight_padded.width * fig.dpi))

    # The axes extent is in display pixels; convert it to inches to match tight_padded.
    ab_px = ax.get_window_extent(renderer=renderer)
    dpi = fig.dpi
    ab_in = Bbox.from_extents(ab_px.x0 / dpi, ab_px.y0 / dpi, ab_px.x1 / dpi, ab_px.y1 / dpi)

    # Fractions of the cropped image. The top is measured from the upper edge because CSS
    # positions from the top while matplotlib's y axis points up.
    ax_left_pct = 100.0 * (ab_in.x0 - tight_padded.x0) / tight_padded.width
    ax_top_pct = 100.0 * (tight_padded.y1 - ab_in.y1) / tight_padded.height
    ax_width_pct = 100.0 * ab_in.width / tight_padded.width
    ax_height_pct = 100.0 * ab_in.height / tight_padded.height
    return tight_padded, orig_width_px, ax_left_pct, ax_top_pct, ax_width_pct, ax_height_pct


def _export_png_base64(fig, bbox_inches, export_scale: float = 2.0):
    """Render ``fig`` to PNG at ``fig.dpi * export_scale``, close it, and return the base64 text.

    The figure is closed even if ``savefig`` raises, so pyplot does not keep it alive.
    """
    export_dpi = int(round(fig.dpi * export_scale))
    buf = BytesIO()
    try:
        fig.savefig(buf, format="png", bbox_inches=bbox_inches, dpi=export_dpi)
    finally:
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode("ascii")
@cached(max_size=_cache_max_items)
def _heatmap_plot(primary_df_serialized, secondary_df_serialized, time_unit: str, lags_plot=False,
                  boolean_df=None, max_label_chars=30):
    """Render a heatmap as an HTML snippet: an embedded PNG plus positioned tooltip overlays.

    Parameters
    ----------
    primary_df_serialized: bytes
        Pickled DataFrame whose values are drawn.
    secondary_df_serialized: bytes
        Pickled DataFrame shown alongside the primary value in the tooltips.
    time_unit: str
        Unit of the time shifts; used in the tooltips and the time-shift colorbar label.
    lags_plot: bool
        True when the primary frame holds time shifts rather than correlation coefficients.
    boolean_df: pandas.DataFrame, optional
        Mask applied to the drawn values. NOTE(review): passing a raw DataFrame here goes through
        the @cached key, which the file header says does not handle DataFrames correctly.
    max_label_chars: int
        Axis-label length limit passed to ``rename_signals``.

    Returns
    -------
    The HTML produced by ``wrap_heatmap_html``, or None when there is nothing to plot.
    """
    plot_df, primary_array, primary_vals, secondary_vals, _ = _prepare_frames(
        primary_df_serialized, secondary_df_serialized, boolean_df, max_label_chars
    )
    if plot_df.empty:
        return None

    limit, cmap, center = _compute_color_limits(primary_array, lags_plot)
    fig, ax = _draw_heatmap(plot_df, limit, cmap, center)
    try:
        cbar = _add_colorbar(fig, ax, lags_plot)
        if lags_plot:
            # Label with the actual unit (matching the table's " (<unit>)" suffix) instead of a
            # fixed "minutes", which was wrong whenever the time shifts were in another unit.
            cbar.set_label(f"Time ({time_unit})", rotation=270, labelpad=12)

        # Geometry (DPI-agnostic percentages + original display width)
        tight_padded, orig_width_px, ax_left_pct, ax_top_pct, ax_width_pct, ax_height_pct = \
            _compute_geometry_percentages(fig, ax, pad_inches=0.02)

        overlays_html = build_overlays_html(
            plot_df=plot_df,
            primary_vals=primary_vals,
            secondary_vals=secondary_vals,
            time_unit=time_unit,
            lags_plot=lags_plot,
            ax_left_pct=ax_left_pct,
            ax_top_pct=ax_top_pct,
            ax_width_pct=ax_width_pct,
            ax_height_pct=ax_height_pct,
        )

        png_b64 = _export_png_base64(fig, bbox_inches=tight_padded, export_scale=2.0)
    finally:
        # Release the figure even if a step above raises; pyplot keeps every open figure alive.
        # Closing an already-closed figure is a no-op.
        plt.close(fig)

    return wrap_heatmap_html(png_b64=png_b64, overlays_html=overlays_html, orig_width_px=orig_width_px)