# Source code for suboptimumg.plotting.grid_plot

from enum import Enum
from typing import Callable, Dict, List, Optional

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

from .color_themes import get_theme
from .plotting_constants import *
from .utils import prepare_smooth_data_2D, prepare_smooth_data_3D

# Use Plotly's "plotly_white" template as the default look for the figures
# built in this module. NOTE: this assigns plotly.io's global template
# default at import time, so it also applies to Plotly figures created
# elsewhere in the same process that do not set their own template.
_DEFAULT_PLOTLY_TEMPLATE = "plotly_white"
pio.templates.default = _DEFAULT_PLOTLY_TEMPLATE


def plot_grid_2D(
    x_list,
    y_data_dict: Dict[str, np.ndarray],
    subplot_titles: List[str],
    title: str,
    x_label: str,
    y_label: str,
    rows: int,
    cols: int,
    *,
    theme=None,
    fit_curve: bool = False,
    font_config: Optional[FontConfig] = None,
    layout_config: Optional[LayoutConfig] = None,
    smoothing_config: Optional[SmoothingConfig] = None,
):
    """
    Generic 2D grid plotting utility.

    Each entry of ``y_data_dict`` gets its own subplot, filled in row-major
    order (left to right, then top to bottom). Every subplot shows a line
    through the smoothed data plus markers at the raw data points.

    Parameters
    ----------
    x_list : numpy.ndarray
        Numpy array of x coordinates (shared across all subplots)
    y_data_dict : Dict[str, np.ndarray]
        Dictionary mapping event keys to y-data arrays. Its iteration order
        determines subplot placement.
    subplot_titles : List[str]
        List of subplot titles in order
    title : str
        Overall grid title
    x_label : str
        X-axis label for all subplots
    y_label : str
        Y-axis label for all subplots
    rows : int
        Number of rows in the grid
    cols : int
        Number of columns in the grid
    theme : str, optional
        Color theme name
    fit_curve : bool, optional
        If True, draw the smoothed line with Plotly's ``"spline"`` line shape
        instead of straight ``"linear"`` segments. No regression is fitted.
    font_config : FontConfig, optional
        FontConfig object for font settings
    layout_config : LayoutConfig, optional
        LayoutConfig object for layout settings
    smoothing_config : SmoothingConfig, optional
        SmoothingConfig object for smoothing settings

    Returns
    -------
    go.Figure
        Plotly figure object

    Raises
    ------
    ValueError
        If ``y_data_dict`` has more entries than a ``rows`` x ``cols`` grid
        can hold.
    """
    if font_config is None:
        font_config = DEFAULT_FONT_CONFIG
    if layout_config is None:
        layout_config = DEFAULT_LAYOUT_CONFIG
    if smoothing_config is None:
        smoothing_config = DEFAULT_SMOOTHING_CONFIG

    # Fail early with a clear message rather than a generic Plotly
    # "(row, col) out of range" error after part of the figure is built.
    capacity = rows * cols
    if len(y_data_dict) > capacity:
        raise ValueError(
            f"y_data_dict has {len(y_data_dict)} entries but a {rows}x{cols} "
            f"grid only holds {capacity} subplots"
        )

    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=subplot_titles,
        horizontal_spacing=layout_config.grid_horizontal_spacing,
        vertical_spacing=layout_config.grid_vertical_spacing,
        shared_xaxes=False,
        shared_yaxes=False,
    )
    theme_colors = get_theme(theme)
    axis_title_font = {"size": font_config.medium}

    for idx, (event_key, y_list) in enumerate(y_data_dict.items()):
        # Row-major placement inside the grid (1-based rows/cols).
        r = idx // cols + 1
        c = idx % cols + 1

        x_dense, y_dense = prepare_smooth_data_2D(
            x_list, y_list, smoothing_config=smoothing_config
        )

        # Smoothed / interpolated line
        fig.add_trace(
            go.Scatter(
                x=x_dense,
                y=y_dense,
                mode="lines",
                line=dict(
                    color=theme_colors["light"],
                    width=LINE_WIDTH,
                    shape=("spline" if fit_curve else "linear"),
                ),
                name=event_key,
                showlegend=False,
            ),
            row=r,
            col=c,
        )
        # Raw data points
        fig.add_trace(
            go.Scatter(
                x=x_list,
                y=y_list,
                mode="markers",
                marker=dict(color=theme_colors["dark"], size=MARKER_SIZE),
                name=f"{event_key} Data",
                showlegend=False,
            ),
            row=r,
            col=c,
        )

        # Label this cell's axes. Selecting by row/col targets the same
        # xaxisN/yaxisN objects make_subplots created for cell (r, c).
        fig.update_xaxes(
            title={"text": x_label, "font": axis_title_font}, row=r, col=c
        )
        fig.update_yaxes(
            title={"text": y_label, "font": axis_title_font}, row=r, col=c
        )

    # Overall layout
    fig.update_layout(
        title={
            "text": title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        width=max(layout_config.width, layout_config.grid_width_per_col * cols),
        height=max(layout_config.height, layout_config.grid_height_per_row * rows),
        plot_bgcolor=layout_config.plot_bgcolor,
        showlegend=False,
        margin=layout_config.margin,
    )

    # Identical gridline / zeroline / tick styling for both axis families.
    axis_style = dict(
        showgrid=True,
        gridwidth=GRID_WIDTH,
        gridcolor=GRID_COLOR,
        zeroline=True,
        zerolinewidth=ZEROLINE_WIDTH,
        zerolinecolor=ZEROLINE_COLOR,
        tickformat=FLOAT_PRECISION,
        tickfont=dict(size=font_config.small),
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    return fig
class PlotType(str, Enum):
    """Which kind of 3D grid plot ``plot_grid_3D`` should draw.

    Because members also subclass ``str``, they compare equal to their plain
    string values, e.g. ``PlotType.SURFACE == "surface"``.
    """

    # Flat contour subplots (drawn with go.Contour).
    CONTOUR = "contour"
    # 3D surface subplots (drawn with go.Surface inside scenes).
    SURFACE = "surface"
def plot_grid_3D(
    x_list,
    y_list,
    z_data_dict: Dict[str, np.ndarray],
    subplot_titles: List[str],
    title: str,
    x_label: str,
    y_label: str,
    z_label_dict: Dict[str, str],
    rows: int,
    cols: int,
    plot_type: PlotType,
    theme=None,
    font_config: Optional[FontConfig] = None,
    layout_config: Optional[LayoutConfig] = None,
    colorbar_config: Optional[ColorbarConfig] = None,
    smoothing_config: Optional[SmoothingConfig] = None,
):
    """
    Generic 3D grid plotting utility for contour or surface plots.

    Requires grid-based input data (2D z arrays). Each entry of
    ``z_data_dict`` gets its own subplot, filled in row-major order.

    Parameters
    ----------
    x_list : numpy.ndarray
        1D numpy array of x coordinates defining the grid x-axis
    y_list : numpy.ndarray
        1D numpy array of y coordinates defining the grid y-axis
    z_data_dict : Dict[str, np.ndarray]
        Dictionary mapping event keys to z-data 2D arrays. Each z array must
        have shape (len(x_list), len(y_list)).
        NOTE(review): the surface helper builds x/y coordinate grids of shape
        (len(y), len(x)), i.e. Plotly's z[y_index][x_index] convention; confirm
        which orientation prepare_smooth_data_3D expects and returns.
    subplot_titles : List[str]
        List of subplot titles in order
    title : str
        Overall grid title
    x_label : str
        X-axis label
    y_label : str
        Y-axis label
    z_label_dict : Dict[str, str]
        Dictionary mapping event keys to z-axis labels (falls back to the
        event key itself when missing)
    rows : int
        Number of rows in the grid
    cols : int
        Number of columns in the grid
    plot_type : PlotType
        PlotType enum (CONTOUR or SURFACE)
    theme : str, optional
        Color theme name
    font_config : FontConfig, optional
        FontConfig object for font settings
    layout_config : LayoutConfig, optional
        LayoutConfig object for layout settings
    colorbar_config : ColorbarConfig, optional
        ColorbarConfig object for colorbar settings
    smoothing_config : SmoothingConfig, optional
        SmoothingConfig object for smoothing settings

    Returns
    -------
    go.Figure
        Plotly figure object

    Raises
    ------
    ValueError
        If ``z_data_dict`` has more entries than a ``rows`` x ``cols`` grid
        can hold.
    """
    if font_config is None:
        font_config = DEFAULT_FONT_CONFIG
    if layout_config is None:
        layout_config = DEFAULT_LAYOUT_CONFIG
    if colorbar_config is None:
        colorbar_config = DEFAULT_COLORBAR_CONFIG
    if smoothing_config is None:
        smoothing_config = DEFAULT_SMOOTHING_CONFIG

    # Fail early with a clear message rather than a generic Plotly
    # "(row, col) out of range" error after part of the figure is built.
    capacity = rows * cols
    if len(z_data_dict) > capacity:
        raise ValueError(
            f"z_data_dict has {len(z_data_dict)} entries but a {rows}x{cols} "
            f"grid only holds {capacity} subplots"
        )

    is_surface = plot_type == PlotType.SURFACE
    # Surface traces need 3D scene subplots; contours use the default 2D
    # subplots (specs=None).
    specs = (
        [[{"type": "surface"} for _ in range(cols)] for _ in range(rows)]
        if is_surface
        else None
    )

    fig = make_subplots(
        rows=rows,
        cols=cols,
        specs=specs,
        subplot_titles=subplot_titles,
        horizontal_spacing=layout_config.grid_horizontal_spacing,
        vertical_spacing=layout_config.grid_vertical_spacing,
        shared_xaxes=False,
        shared_yaxes=True,
    )
    theme_colors = get_theme(theme)
    colorscale = theme_colors["colorscale"]
    add_subplot_trace = _add_surface_trace if is_surface else _add_contour_trace
    axis_title_font = {"size": font_config.medium}

    for idx, (event_key, z_list) in enumerate(z_data_dict.items()):
        # Row-major placement inside the grid (1-based rows/cols).
        r = idx // cols + 1
        c = idx % cols + 1
        z_label = z_label_dict.get(event_key, event_key)

        x_interp, y_interp, z_smooth = prepare_smooth_data_3D(
            x_list, y_list, z_list, smoothing_config=smoothing_config
        )
        add_subplot_trace(
            fig,
            r,
            c,
            x_interp,
            y_interp,
            z_smooth,
            colorscale,
            z_label,
            font_config=font_config,
            colorbar_config=colorbar_config,
        )

        if is_surface:
            # 3D cells are labelled through their scene axes; writing
            # cartesian xaxisN/yaxisN titles here would only add layout
            # entries that no subplot uses.
            fig.update_scenes(
                xaxis=dict(title=dict(text=x_label), tickformat=FLOAT_PRECISION),
                yaxis=dict(title=dict(text=y_label), tickformat=FLOAT_PRECISION),
                zaxis=dict(title=dict(text=z_label), tickformat=FLOAT_PRECISION),
                row=r,
                col=c,
            )
        else:
            fig.update_xaxes(
                title={"text": x_label, "font": axis_title_font}, row=r, col=c
            )
            fig.update_yaxes(
                title={"text": y_label, "font": axis_title_font}, row=r, col=c
            )

    # Overall layout
    fig.update_layout(
        title={
            "text": title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        width=max(layout_config.width, layout_config.grid_width_per_col * cols),
        height=max(layout_config.height, layout_config.grid_height_per_row * rows),
        plot_bgcolor=layout_config.plot_bgcolor,
        showlegend=False,
        margin=layout_config.margin,
    )

    if not is_surface:
        # Framed 2D axes with light gridlines for the contour cells.
        axis_style = dict(
            showline=True,
            linewidth=1,
            linecolor="lightgrey",
            mirror=True,
            showgrid=True,
            gridwidth=GRID_WIDTH,
            gridcolor=GRID_COLOR,
            tickformat=FLOAT_PRECISION,
            tickfont=dict(size=font_config.small),
        )
        fig.update_xaxes(**axis_style)
        fig.update_yaxes(**axis_style)

        # One contour trace per subplot, added in subplot order, so the
        # 1-based trace position is also the subplot number.
        for subplot_num, trace in enumerate(fig.data, start=1):
            _position_colorbar(trace, fig, subplot_num)

    return fig
def _contour_levels(z_smooth):
    """
    Compute evenly spaced contour levels spanning the range of ``z_smooth``.

    Parameters
    ----------
    z_smooth : numpy.ndarray
        Smoothed z values. NaN entries are ignored when finding the range.

    Returns
    -------
    tuple of (float, float, float)
        ``(start, end, size)`` for Plotly contour settings. ``size`` is the
        data range divided by ``NUM_CONTOURS``, or 1 when the range is zero
        or undefined, so Plotly never receives a zero step.
    """
    z_min = float(np.nanmin(z_smooth))
    z_max = float(np.nanmax(z_smooth))
    size = (z_max - z_min) / NUM_CONTOURS if z_max > z_min else 1
    return z_min, z_max, size


def _add_contour_trace(
    fig,
    row,
    col,
    x_interp,
    y_interp,
    z_smooth,
    colorscale,
    z_label,
    font_config: FontConfig,
    colorbar_config: ColorbarConfig,
):
    """
    Helper function to add a labelled contour trace to figure.

    Parameters
    ----------
    fig : go.Figure
        Plotly figure to add trace to
    row : int
        Row position in subplot grid
    col : int
        Column position in subplot grid
    x_interp : numpy.ndarray
        Interpolated x coordinates
    y_interp : numpy.ndarray
        Interpolated y coordinates
    z_smooth : numpy.ndarray
        Smoothed z values
    colorscale : str
        Colorscale name
    z_label : str
        Label for z-axis (used in the hover text; used as the colorbar title
        only when ``row`` is None)
    font_config : FontConfig
        Font configuration object
    colorbar_config : ColorbarConfig
        Colorbar configuration object
    """
    # Levels stay floats. Truncating them to int would collapse fractional
    # ranges (0.1..0.9 -> start=end=0) and drop the lowest levels of
    # negative data (int(-5.7) == -5).
    start, end, size = _contour_levels(z_smooth)
    trace = go.Contour(
        x=x_interp,
        y=y_interp,
        z=z_smooth,
        colorscale=colorscale,
        contours=dict(
            showlabels=True,
            labelfont=dict(
                size=font_config.small,
                color=CONTOUR_LABEL_COLOR,
            ),
            start=start,
            end=end,
            size=size,
            labelformat=FLOAT_PRECISION,
        ),
        colorbar=dict(
            title=dict(
                # Inside a grid the subplot title already names the data,
                # so the colorbar title is only shown for row=None.
                text=z_label if row is None else None,
                font=dict(size=font_config.medium),
            ),
            thickness=colorbar_config.thickness,
            len=colorbar_config.length,
            tickfont=dict(size=font_config.small),
        ),
        hovertemplate=(
            f"X: %{{x:{FLOAT_PRECISION}}}<br>"
            f"Y: %{{y:{FLOAT_PRECISION}}}<br>"
            f"{z_label}: %{{z:{FLOAT_PRECISION}}}<extra></extra>"
        ),
        showscale=True,
    )
    fig.add_trace(trace, row=row, col=col)


def _add_surface_trace(
    fig,
    row,
    col,
    x_interp,
    y_interp,
    z_smooth,
    colorscale,
    z_label,
    font_config: FontConfig,
    colorbar_config: ColorbarConfig,
):
    """
    Helper function to add a surface trace (with projected z contour lines)
    to figure. The colorbar is configured but hidden (``showscale=False``).

    Parameters
    ----------
    fig : go.Figure
        Plotly figure to add trace to
    row : int
        Row position in subplot grid
    col : int
        Column position in subplot grid
    x_interp : numpy.ndarray
        Interpolated x coordinates (1D)
    y_interp : numpy.ndarray
        Interpolated y coordinates (1D)
    z_smooth : numpy.ndarray
        Smoothed z values, indexed as z[y_index][x_index] to match the
        coordinate grids built below
    colorscale : str
        Colorscale name
    z_label : str
        Label for z-axis
    font_config : FontConfig
        Font configuration object
    colorbar_config : ColorbarConfig
        Colorbar configuration object
    """
    start, end, size = _contour_levels(z_smooth)
    # Both grids have shape (len(y_interp), len(x_interp)): rows repeat
    # x_interp, columns repeat y_interp.
    x_grid, y_grid = np.meshgrid(x_interp, y_interp)
    trace = go.Surface(
        x=x_grid,
        y=y_grid,
        z=z_smooth,
        colorscale=colorscale,
        contours={
            "z": {
                "show": True,
                "start": start,
                "end": end,
                "size": size,
                "width": LINE_WIDTH,
                "color": CONTOUR_LABEL_COLOR,
            }
        },
        colorbar=dict(
            title=dict(
                text=z_label if row is None else None,
                font=dict(size=font_config.medium),
            ),
            thickness=colorbar_config.thickness,
            len=colorbar_config.length,
            tickfont=dict(size=font_config.small),
        ),
        hovertemplate=(
            f"X: %{{x:{FLOAT_PRECISION}}}<br>"
            f"Y: %{{y:{FLOAT_PRECISION}}}<br>"
            f"{z_label}: %{{z:{FLOAT_PRECISION}}}<extra></extra>"
        ),
        showscale=False,
    )
    fig.add_trace(trace, row=row, col=col)


def _position_colorbar(trace, fig, subplot_num):
    """
    Dynamically position the colorbar for a subplot in a grid.

    The colorbar is anchored just right of the subplot's x-domain and centred
    on its y-domain. Its length becomes 80% of the subplot height, replacing
    whatever length the trace was created with. Traces without a colorbar,
    or subplots whose axes are missing from the layout, are left untouched.

    Parameters
    ----------
    trace : plotly trace object
        The trace containing the colorbar to position
    fig : go.Figure
        The figure containing the subplot
    subplot_num : int
        The subplot number (1-indexed)
    """
    if not (hasattr(trace, "colorbar") and trace.colorbar is not None):
        return

    # Plotly names the first subplot's axes "xaxis"/"yaxis" (no suffix).
    suffix = str(subplot_num) if subplot_num > 1 else ""
    xaxis_name = f"xaxis{suffix}"
    yaxis_name = f"yaxis{suffix}"
    if xaxis_name not in fig.layout or yaxis_name not in fig.layout:
        return

    x_dom = fig.layout[xaxis_name].domain
    y_dom = fig.layout[yaxis_name].domain

    trace.colorbar.x = x_dom[1]
    trace.colorbar.xanchor = "left"
    trace.colorbar.y = 0.5 * (y_dom[0] + y_dom[1])
    trace.colorbar.yanchor = "middle"
    trace.colorbar.len = 0.8 * (y_dom[1] - y_dom[0])
    trace.showscale = True