Source code for suboptimumg.plotting.overlay_grid_plot_2d

from typing import Dict, List

import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from .plotting_constants import *
from .utils import prepare_smooth_data_2D


def overlay_grid_plot_2D(
    x_data_dict: Dict[str, Dict[str, npt.NDArray[np.number]]],
    y_data_dict: Dict[str, Dict[str, npt.NDArray[np.number]]],
    subplot_titles: List[str],
    title: str,
    x_label: str,
    y_label: str,
    rows: int,
    cols: int,
    *,
    show_points: bool = False,
    show_legend: bool = False,
    fit_curve: bool = False,
    layout_config: LayoutConfig = DEFAULT_LAYOUT_CONFIG,
    font_config: FontConfig = DEFAULT_FONT_CONFIG,
    smoothing_config: SmoothingConfig = DEFAULT_SMOOTHING_CONFIG,
) -> go.Figure:
    """
    Generic 2D overlay grid plotting utility for comparing data sources across
    multiple subplots.

    Each subplot contains one trace per data source (trace label), allowing
    easy comparison of trends across different categories. A given trace
    label is drawn in the same color in every subplot.

    Parameters
    ----------
    x_data_dict : Dict[str, Dict[str, npt.NDArray[np.number]]]
        ``{subplot_key: {trace_label: x_data_array}}``

        - subplot_key: identifies which subplot the data belongs to. Must
          contain every subplot key of ``y_data_dict``.
        - trace_label: identifies which trace/series. Must contain every
          trace label present in ``y_data_dict[subplot_key]``.
    y_data_dict : Dict[str, Dict[str, npt.NDArray[np.number]]]
        ``{subplot_key: {trace_label: y_data_array}}``

        - subplot_key: identifies which subplot the data belongs to.
          Subplots are filled row-major in this dict's iteration order, so
          that order should match ``subplot_titles``.
        - trace_label: identifies which trace/series. A label may be absent
          from some subplots; it is then simply not drawn there.
    subplot_titles : List[str]
        Titles for each subplot, in the same order as ``y_data_dict``.
    title : str
        Overall title for the grid.
    x_label : str
        Label for every subplot's x-axis.
    y_label : str
        Label for every subplot's y-axis.
    rows : int
        Number of rows in the grid.
    cols : int
        Number of columns in the grid.
    show_points : bool, optional
        If True, also draws the original (pre-smoothing) data points as
        markers on top of each line.
    show_legend : bool, optional
        If True, shows a legend with one entry per trace label.
    fit_curve : bool, optional
        If True, lines are drawn with Plotly's ``"spline"`` line shape
        instead of ``"linear"``.
    layout_config : LayoutConfig, optional
        Layout configuration object (spacing, size, title placement,
        background, margins).
    font_config : FontConfig, optional
        Font configuration object (title and axis-title sizes).
    smoothing_config : SmoothingConfig, optional
        Smoothing configuration passed to ``prepare_smooth_data_2D`` to build
        the dense line data.

    Returns
    -------
    go.Figure
        Plotly figure object.

    Raises
    ------
    ValueError
        If ``y_data_dict`` has more subplot keys than the grid has cells
        (``rows * cols``).
    KeyError
        If ``x_data_dict`` is missing a subplot key or trace label that is
        present in ``y_data_dict``.
    """
    subplot_keys = list(y_data_dict.keys())

    # Fail early with a clear message instead of an out-of-range error deep
    # inside plotly when a subplot has no cell to go into.
    n_cells = rows * cols
    if len(subplot_keys) > n_cells:
        raise ValueError(
            f"{len(subplot_keys)} subplots requested but the grid only has "
            f"{rows} x {cols} = {n_cells} cells"
        )

    # Build figure
    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=subplot_titles,
        horizontal_spacing=layout_config.grid_horizontal_spacing,
        vertical_spacing=layout_config.grid_vertical_spacing,
        shared_xaxes=False,
        shared_yaxes=False,
    )

    # Union of all trace labels across subplots, sorted so that each label is
    # paired with the same color in every subplot.
    trace_labels = sorted(
        {label for traces in y_data_dict.values() for label in traces}
    )
    # NOTE(review): `_color_cycle` is underscore-prefixed, so it only reaches
    # this module through `from .plotting_constants import *` if that module
    # lists it in `__all__` — confirm it is exported there.
    colors = _color_cycle(len(trace_labels))

    # Labels that already have a legend entry; each label gets exactly one.
    labels_in_legend = set()

    # Create traces for each subplot
    for subplot_idx, subplot_key in enumerate(subplot_keys):
        # Row-major placement, 1-based as plotly expects.
        row = subplot_idx // cols + 1
        col = subplot_idx % cols + 1

        x_data_for_subplot = x_data_dict[subplot_key]
        y_data_for_subplot = y_data_dict[subplot_key]

        # Name used for y in the marker hover text; falls back to the y-axis
        # label if fewer titles than subplots were supplied.
        hover_y_name = (
            subplot_titles[subplot_idx]
            if subplot_idx < len(subplot_titles)
            else y_label
        )

        for trace_label, color in zip(trace_labels, colors):
            if trace_label not in y_data_for_subplot:
                # This data source has nothing for this subplot.
                continue

            x_vals = x_data_for_subplot[trace_label]
            y_vals = y_data_for_subplot[trace_label]
            x_dense, y_dense = prepare_smooth_data_2D(
                x_vals, y_vals, smoothing_config
            )

            # Emit the legend entry at the label's first appearance, so labels
            # absent from the first subplot still appear in the legend.
            first_appearance = trace_label not in labels_in_legend
            labels_in_legend.add(trace_label)

            # Interpolated line
            fig.add_trace(
                go.Scatter(
                    x=x_dense,
                    y=y_dense,
                    mode="lines",
                    line=dict(
                        color=color,
                        width=2,
                        shape=("spline" if fit_curve else "linear"),
                    ),
                    name=trace_label,
                    legendgroup=trace_label,
                    showlegend=(show_legend and first_appearance),
                ),
                row=row,
                col=col,
            )

            # Original points
            if show_points:
                fig.add_trace(
                    go.Scatter(
                        x=x_vals,
                        y=y_vals,
                        mode="markers",
                        marker=dict(color=color, size=MARKER_SIZE, opacity=0.85),
                        name=trace_label,
                        # Same group as the line so toggling the legend entry
                        # hides/shows the points together with their line.
                        legendgroup=trace_label,
                        showlegend=False,
                        hovertemplate=f"{x_label}: %{{x}}<br>{hover_y_name}: %{{y:.2f}}<extra></extra>",
                    ),
                    row=row,
                    col=col,
                )

        # Axis titles for this subplot
        fig.update_xaxes(
            title={"text": x_label, "font": {"size": font_config.medium}},
            row=row,
            col=col,
        )
        fig.update_yaxes(
            title={"text": y_label, "font": {"size": font_config.medium}},
            row=row,
            col=col,
        )

    # Layout and styling
    fig.update_layout(
        title={
            "text": title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        # Grow the canvas with the grid, but never below the configured size.
        width=max(layout_config.width, layout_config.grid_width_per_col * cols),
        height=max(layout_config.height, layout_config.grid_height_per_row * rows),
        plot_bgcolor=layout_config.plot_bgcolor,
        showlegend=show_legend,
        legend=dict(
            orientation="v",
            yanchor="top",
            xanchor="left",
        ),
        margin=layout_config.margin,
    )

    # Subtle gridlines and zero lines on every subplot axis
    grid_color = "rgba(211, 211, 211, 0.3)"
    zero_line_color = "rgba(211, 211, 211, 0.8)"
    axis_style = dict(
        showgrid=True,
        gridwidth=1,
        gridcolor=grid_color,
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor=zero_line_color,
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    return fig