Source code for suboptimumg.plotting.overlay_plots

from typing import Dict, Iterable, List, Optional

import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from plotly.colors import qualitative
from plotly.subplots import make_subplots

from suboptimumg.plotting.generic_plot import prepare_smooth_data_2D

from .plotting_constants import *


def _color_cycle(n: int) -> Iterable[str]:
    """
    Get n distinct colors from Plotly's qualitative palette.

    Parameters
    ----------
    n : int
        Number of colors needed

    Returns
    -------
    Iterable[str]
        Iterable of color strings, repeating if n exceeds palette size
    """
    palette: Iterable[str] = qualitative.Plotly
    if n <= len(palette):
        return palette[:n]
    times = n // len(palette) + 1
    return (palette * times)[:n]


def _data_extent(arrays: Iterable[npt.ArrayLike]) -> Optional[List[float]]:
    """Return ``[min, max]`` over the finite values of *arrays*, or None if there are none."""
    flat = [np.asarray(a, dtype=float).ravel() for a in arrays]
    if not flat:
        return None
    values = np.concatenate(flat)
    values = values[np.isfinite(values)]
    if values.size == 0:
        return None
    return [float(values.min()), float(values.max())]


def _pad_range(lo: float, hi: float) -> List[float]:
    """
    Widen ``[lo, hi]`` by the factor ``1 + RANGE_PADDING``, scaling each bound away from zero.

    A lower bound >= 0 is divided by the factor and a negative one is multiplied
    by it; the upper bound is treated symmetrically. The range therefore always
    grows, including for negative data, which plain multiplication would clip.
    """
    factor = 1 + RANGE_PADDING
    lower = lo / factor if lo >= 0 else lo * factor
    upper = hi * factor if hi >= 0 else hi / factor
    return [lower, upper]


def overlay_plot2D(
    x_data_dict: Dict[str, npt.NDArray[np.number]],
    y_data_dict: Dict[str, npt.NDArray[np.number]],
    title: str,
    x_axis: str,
    y_axis: str,
    *,
    h_line: Optional[ReferenceLineConfig] = None,
    v_line: Optional[ReferenceLineConfig] = None,
    fit_curve: bool = False,
    show_points: bool = False,
    subtitle: Optional[str] = None,
    font_config: Optional[FontConfig] = None,
    layout_config: Optional[LayoutConfig] = None,
    smoothing_config: Optional[SmoothingConfig] = None,
) -> go.Figure:
    """
    Generic 2D overlay plotting utility for comparing multiple series.

    One line trace is drawn per key of ``y_data_dict``, in sorted key order,
    each with its own color.

    Parameters
    ----------
    x_data_dict : Dict[str, npt.NDArray[np.number]]
        Dictionary mapping trace labels to x data arrays (must contain every
        key of ``y_data_dict``)
    y_data_dict : Dict[str, npt.NDArray[np.number]]
        Dictionary mapping trace labels to y data arrays
    title : str
        Plot title
    x_axis : str
        X-axis label
    y_axis : str
        Y-axis label
    h_line : ReferenceLineConfig, optional
        Horizontal reference line (None, or a config whose ``value`` is None,
        means no line). It spans the x-range of the data, and the y-axis range
        is widened to include it.
    v_line : ReferenceLineConfig, optional
        Vertical reference line (None, or a config whose ``value`` is None,
        means no line). It spans the y-range of the data, and the x-axis range
        is widened to include it.
    fit_curve : bool, optional
        If True, each series is resampled via ``prepare_smooth_data_2D`` and
        drawn as a spline. If False, the raw points are joined with straight
        segments.
    show_points : bool, optional
        Whether to also draw the raw data points as markers
    subtitle : str, optional
        Optional subtitle rendered under the title
    font_config : FontConfig, optional
        Font settings (defaults to DEFAULT_FONT_CONFIG)
    layout_config : LayoutConfig, optional
        Layout settings (defaults to DEFAULT_LAYOUT_CONFIG)
    smoothing_config : SmoothingConfig, optional
        Smoothing settings used when ``fit_curve`` is True (defaults to
        DEFAULT_SMOOTHING_CONFIG)

    Returns
    -------
    go.Figure
        Plotly figure object
    """
    font_config = font_config or DEFAULT_FONT_CONFIG
    layout_config = layout_config or DEFAULT_LAYOUT_CONFIG
    smoothing_config = smoothing_config or DEFAULT_SMOOTHING_CONFIG

    fig = go.Figure()

    trace_labels = sorted(y_data_dict)
    colors = _color_cycle(len(trace_labels))
    line_hover = (
        f"{x_axis}: %{{x:{FLOAT_PRECISION}}}<br>"
        f"{y_axis}: %{{y:{FLOAT_PRECISION}}}<extra></extra>"
    )

    for trace_label, color in zip(trace_labels, colors):
        x_vals = x_data_dict[trace_label]
        y_vals = y_data_dict[trace_label]

        if fit_curve:
            x_line, y_line = prepare_smooth_data_2D(
                x_vals, y_vals, smoothing_config=smoothing_config
            )
        else:
            x_line, y_line = x_vals, y_vals

        fig.add_trace(
            go.Scatter(
                x=x_line,
                y=y_line,
                mode="lines",
                line=dict(
                    color=color,
                    width=LINE_WIDTH,
                    # Only curve the line when a fit was requested; otherwise
                    # show the data as-is (matches overlay_grid_plot_2D).
                    shape="spline" if fit_curve else "linear",
                ),
                name=trace_label,
                legendgroup=trace_label,
                hovertemplate=line_hover,
            )
        )

        if show_points:
            fig.add_trace(
                go.Scatter(
                    x=x_vals,
                    y=y_vals,
                    mode="markers",
                    marker=dict(color=color, size=MARKER_SIZE_LARGE),
                    # Same legend group, so toggling the legend entry also
                    # hides/shows the markers.
                    legendgroup=trace_label,
                    showlegend=False,
                    hovertemplate=None,
                )
            )

    # Extents of the raw (unsmoothed) data, shared by both reference lines.
    x_extent = _data_extent(x_data_dict[label] for label in trace_labels)
    y_extent = _data_extent(y_data_dict[label] for label in trace_labels)
    has_h_line = h_line is not None and h_line.value is not None
    has_v_line = v_line is not None and v_line.value is not None

    # The line's length comes from the data, so it is skipped when there is no data.
    if has_h_line and x_extent is not None:
        fig.add_trace(
            go.Scatter(
                x=x_extent,
                y=[h_line.value, h_line.value],
                mode="lines",
                line=dict(color=h_line.color, width=h_line.width, dash=h_line.dash),
                name=f"{h_line.label}",
                hoverinfo="name",
            )
        )

    if has_v_line and y_extent is not None:
        fig.add_trace(
            go.Scatter(
                x=[v_line.value, v_line.value],
                y=y_extent,
                mode="lines",
                line=dict(color=v_line.color, width=v_line.width, dash=v_line.dash),
                name=f"{v_line.label}",
                hoverinfo="name",
            )
        )

    # Title, Axes, Label, Legend, Size
    full_title = title
    if subtitle is not None:
        full_title += f"<br><span style='font-size: {font_config.medium}px; color: {TEXT_COLOR_LIGHT};'>{subtitle}</span>"

    fig.update_layout(
        title={
            "text": full_title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        xaxis_title={
            "text": x_axis,
            "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
        },
        yaxis_title={
            "text": y_axis,
            "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
        },
        legend_title="Legend",
        showlegend=True,
        legend_title_font=dict(size=font_config.medium),
        legend_font=dict(size=font_config.small),
        hovermode=HOVER_MODE,
        plot_bgcolor=layout_config.plot_bgcolor,
        width=layout_config.width,
        height=layout_config.height,
        margin=layout_config.margin,
    )

    # Subtle grid lines, identical on both axes
    axis_style = dict(
        showgrid=True,
        gridwidth=GRID_WIDTH,
        gridcolor=GRID_COLOR,
        zeroline=True,
        zerolinewidth=ZEROLINE_WIDTH,
        zerolinecolor=ZEROLINE_COLOR,
        tickfont=dict(size=font_config.small),
        tickformat=FLOAT_PRECISION,
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    # Make sure reference lines are inside the visible range, with padding.
    if has_h_line:
        lo, hi = y_extent if y_extent is not None else (h_line.value, h_line.value)
        fig.update_yaxes(range=_pad_range(min(lo, h_line.value), max(hi, h_line.value)))

    if has_v_line:
        lo, hi = x_extent if x_extent is not None else (v_line.value, v_line.value)
        fig.update_xaxes(range=_pad_range(min(lo, v_line.value), max(hi, v_line.value)))

    return fig
def overlay_grid_plot_2D(
    x_data_dict: Dict[str, Dict[str, npt.NDArray[np.number]]],
    y_data_dict: Dict[str, Dict[str, npt.NDArray[np.number]]],
    subplot_titles: List[str],
    title: str,
    x_label: str,
    y_label: str,
    rows: int,
    cols: int,
    *,
    show_points: bool = False,
    show_legend: bool = False,
    fit_curve: bool = False,
    layout_config: Optional[LayoutConfig] = None,
    font_config: Optional[FontConfig] = None,
    smoothing_config: Optional[SmoothingConfig] = None,
) -> go.Figure:
    """
    Generic 2D overlay grid plotting utility for comparing multiple series.

    Subplots are filled row-major in the key order of ``y_data_dict``. A given
    trace label gets the same color in every subplot.

    Parameters
    ----------
    x_data_dict : Dict[str, Dict[str, npt.NDArray[np.number]]]
        {subplot_key: {trace_label: x_data_array}}
        - subplot_key: identifies which subplot the data belongs to (must match order in subplot_titles)
        - trace_label: identifies which trace/series (must cover the keys in y_data_dict[subplot_key])
    y_data_dict : Dict[str, Dict[str, npt.NDArray[np.number]]]
        {subplot_key: {trace_label: y_data_array}}
        - subplot_key: identifies which subplot the data belongs to (must match order in subplot_titles)
        - trace_label: identifies which trace/series. A label may be absent from
          some subplots; it is then simply not drawn there.
    subplot_titles : List[str]
        List of titles for each subplot (also used as the y label in point hovers;
        falls back to ``y_label`` when there are fewer titles than subplots)
    title : str
        Overall title for the grid
    x_label : str
        Label for every subplot's x-axis
    y_label : str
        Label for every subplot's y-axis
    rows : int
        Number of rows in the grid
    cols : int
        Number of columns in the grid
    show_points : bool, optional
        Whether to also draw the raw data points as markers
    show_legend : bool, optional
        If True, shows one legend entry per trace label (from its first subplot)
    fit_curve : bool, optional
        If True, each series is resampled via ``prepare_smooth_data_2D`` and drawn
        as a spline. If False, the raw points are joined with straight segments.
    layout_config : LayoutConfig, optional
        Layout configuration object (defaults to DEFAULT_LAYOUT_CONFIG)
    font_config : FontConfig, optional
        Font configuration object (defaults to DEFAULT_FONT_CONFIG)
    smoothing_config : SmoothingConfig, optional
        Smoothing configuration used when ``fit_curve`` is True (defaults to
        DEFAULT_SMOOTHING_CONFIG)

    Returns
    -------
    go.Figure
        Plotly figure object

    Raises
    ------
    ValueError
        If there are more subplots than ``rows * cols`` grid cells.
    """
    layout_config = layout_config or DEFAULT_LAYOUT_CONFIG
    font_config = font_config or DEFAULT_FONT_CONFIG
    smoothing_config = smoothing_config or DEFAULT_SMOOTHING_CONFIG

    subplot_keys = list(y_data_dict)
    if len(subplot_keys) > rows * cols:
        raise ValueError(
            f"{len(subplot_keys)} subplots do not fit in a {rows}x{cols} grid"
        )

    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=subplot_titles,
        horizontal_spacing=layout_config.grid_horizontal_spacing,
        vertical_spacing=layout_config.grid_vertical_spacing,
        shared_xaxes=False,
        shared_yaxes=False,
    )

    # Assign colors over the union of labels, so a label keeps its color in every subplot.
    all_trace_labels = set()
    for per_subplot in y_data_dict.values():
        all_trace_labels.update(per_subplot)
    trace_labels = sorted(all_trace_labels)
    colors = _color_cycle(len(trace_labels))

    # Labels that already have a legend entry (one entry per label, grid-wide).
    legend_labels = set()

    for subplot_idx, subplot_key in enumerate(subplot_keys):
        r = subplot_idx // cols + 1
        c = subplot_idx % cols + 1
        x_data_for_subplot = x_data_dict[subplot_key]
        y_data_for_subplot = y_data_dict[subplot_key]
        hover_y_name = (
            subplot_titles[subplot_idx]
            if subplot_idx < len(subplot_titles)
            else y_label
        )

        for trace_label, color in zip(trace_labels, colors):
            if trace_label not in y_data_for_subplot:
                continue  # this series is not present in this subplot

            x_vals = x_data_for_subplot[trace_label]
            y_vals = y_data_for_subplot[trace_label]

            if fit_curve:
                x_line, y_line = prepare_smooth_data_2D(
                    x_vals, y_vals, smoothing_config=smoothing_config
                )
            else:
                x_line, y_line = x_vals, y_vals

            first_appearance = trace_label not in legend_labels
            legend_labels.add(trace_label)

            fig.add_trace(
                go.Scatter(
                    x=x_line,
                    y=y_line,
                    mode="lines",
                    line=dict(
                        color=color,
                        width=2,
                        shape=("spline" if fit_curve else "linear"),
                    ),
                    name=trace_label,
                    legendgroup=trace_label,
                    showlegend=(show_legend and first_appearance),
                ),
                row=r,
                col=c,
            )

            if show_points:
                fig.add_trace(
                    go.Scatter(
                        x=x_vals,
                        y=y_vals,
                        mode="markers",
                        marker=dict(color=color, size=MARKER_SIZE, opacity=0.85),
                        name=trace_label,
                        legendgroup=trace_label,
                        showlegend=False,
                        hovertemplate=f"{x_label}: %{{x}}<br>{hover_y_name}: %{{y:.2f}}<extra></extra>",
                    ),
                    row=r,
                    col=c,
                )

        # Axis titles for this subplot's cell
        fig.update_xaxes(
            title_text=x_label,
            title_font=dict(size=font_config.medium),
            row=r,
            col=c,
        )
        fig.update_yaxes(
            title_text=y_label,
            title_font=dict(size=font_config.medium),
            row=r,
            col=c,
        )

    # Layout and styling
    fig.update_layout(
        title={
            "text": title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        width=max(layout_config.width, layout_config.grid_width_per_col * cols),
        height=max(layout_config.height, layout_config.grid_height_per_row * rows),
        plot_bgcolor=layout_config.plot_bgcolor,
        showlegend=show_legend,
        legend=dict(
            orientation="v",
            yanchor="top",
            xanchor="left",
        ),
        margin=layout_config.margin,
    )

    # Subtle gridlines and zero lines on every subplot
    grid_style = dict(
        showgrid=True,
        gridwidth=1,
        gridcolor="rgba(211, 211, 211, 0.3)",
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor="rgba(211, 211, 211, 0.8)",
    )
    fig.update_xaxes(**grid_style)
    fig.update_yaxes(**grid_style)

    return fig