from typing import Dict, List
import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from .plotting_constants import *
from .utils import prepare_smooth_data_2D
[docs]
def overlay_grid_plot_2D(
    x_data_dict: Dict[str, Dict[str, npt.NDArray[np.number]]],
    y_data_dict: Dict[str, Dict[str, npt.NDArray[np.number]]],
    subplot_titles: List[str],
    title: str,
    x_label: str,
    y_label: str,
    rows: int,
    cols: int,
    *,
    show_points: bool = False,
    show_legend: bool = False,
    fit_curve: bool = False,
    layout_config: LayoutConfig = DEFAULT_LAYOUT_CONFIG,
    font_config: FontConfig = DEFAULT_FONT_CONFIG,
    smoothing_config: SmoothingConfig = DEFAULT_SMOOTHING_CONFIG,
) -> go.Figure:
    """
    Generic 2D overlay grid plotting utility for comparing data sources across multiple subplots.

    Each subplot contains one trace per data source (trace label), allowing easy comparison
    of trends across different categories. A trace label is drawn with the same color in
    every subplot; a subplot that has no data for a given label simply omits that trace.
    Subplots are filled row-major, in the key order of ``y_data_dict``.

    Parameters
    ----------
    x_data_dict : Dict[str, Dict[str, npt.NDArray[np.number]]]
        {subplot_key: {trace_label: x_data_array}}
        - subplot_key: identifies which subplot the data belongs to (must match order in subplot_titles)
        - trace_label: identifies which trace/series (must match keys in y_data_dict[subplot_key])
    y_data_dict : Dict[str, Dict[str, npt.NDArray[np.number]]]
        {subplot_key: {trace_label: y_data_array}}
        - subplot_key: identifies which subplot the data belongs to (must match order in subplot_titles)
        - trace_label: identifies which trace/series (must match keys in x_data_dict[subplot_key])
    subplot_titles : List[str]
        List of titles for each subplot. The title is also used as the y name in the
        hover text of the point markers; if a subplot has no title, ``y_label`` is used.
    title : str
        Overall title for the grid
    x_label : str
        Label for the x-axis of every subplot
    y_label : str
        Label for the y-axis of every subplot
    rows : int
        Number of rows in the grid
    cols : int
        Number of columns in the grid
    show_points : bool, optional
        If True, overlays the original (unsmoothed) data points as markers.
    show_legend : bool, optional
        If True, shows a legend with one entry per trace label. Each entry is taken from
        the first subplot in which that label appears. Clicking an entry toggles the
        label's lines and points in every subplot.
    fit_curve : bool, optional
        Whether to draw the lines with Plotly's "spline" shape instead of "linear".
    layout_config : LayoutConfig, optional
        Layout configuration object (spacing, size, background, margins, title position)
    font_config : FontConfig, optional
        Font configuration object (sizes for the figure title and axis titles)
    smoothing_config : SmoothingConfig, optional
        Smoothing configuration passed to ``prepare_smooth_data_2D`` for the line traces

    Returns
    -------
    go.Figure
        Plotly figure object

    Raises
    ------
    ValueError
        If ``y_data_dict`` has more subplots than the ``rows x cols`` grid can hold.
    KeyError
        If ``x_data_dict`` lacks a subplot key, or a trace label, that is present in
        ``y_data_dict``.
    """
    subplot_keys = list(y_data_dict.keys())
    if len(subplot_keys) > rows * cols:
        raise ValueError(
            f"{len(subplot_keys)} subplots do not fit in a {rows}x{cols} grid"
        )

    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=subplot_titles,
        horizontal_spacing=layout_config.grid_horizontal_spacing,
        vertical_spacing=layout_config.grid_vertical_spacing,
        shared_xaxes=False,
        shared_yaxes=False,
    )

    # Assign colors over the union of labels from ALL subplots. This keeps a label's
    # color the same everywhere, even when some subplots lack some labels.
    all_trace_labels = set()
    for subplot_key in subplot_keys:
        all_trace_labels.update(y_data_dict[subplot_key].keys())
    trace_labels = sorted(all_trace_labels)
    colors = _color_cycle(len(trace_labels))

    axis_title_size = font_config.medium
    # Labels that already have a legend entry, so each label is listed exactly once.
    labels_in_legend = set()

    for subplot_idx, subplot_key in enumerate(subplot_keys):
        # make_subplots numbers cells row-major, starting at (1, 1).
        r = subplot_idx // cols + 1
        c = subplot_idx % cols + 1
        x_data_for_subplot = x_data_dict[subplot_key]
        y_data_for_subplot = y_data_dict[subplot_key]

        # Fall back to the y-axis label when fewer titles than subplots were given.
        hover_y_name = (
            subplot_titles[subplot_idx] if subplot_idx < len(subplot_titles) else y_label
        )
        points_hovertemplate = (
            f"{x_label}: %{{x}}<br>{hover_y_name}: %{{y:.2f}}<extra></extra>"
        )

        for trace_label, color in zip(trace_labels, colors):
            # This subplot has no data for this label: skip it rather than fail.
            if trace_label not in y_data_for_subplot:
                continue
            if trace_label not in x_data_for_subplot:
                raise KeyError(
                    f"x_data_dict[{subplot_key!r}] has no entry for trace {trace_label!r}"
                )
            x_vals = x_data_for_subplot[trace_label]
            y_vals = y_data_for_subplot[trace_label]
            x_dense, y_dense = prepare_smooth_data_2D(x_vals, y_vals, smoothing_config)

            in_legend = show_legend and trace_label not in labels_in_legend
            if in_legend:
                labels_in_legend.add(trace_label)

            # Interpolated / smoothed line
            fig.add_trace(
                go.Scatter(
                    x=x_dense,
                    y=y_dense,
                    mode="lines",
                    line=dict(
                        color=color,
                        width=2,
                        shape=("spline" if fit_curve else "linear"),
                    ),
                    name=trace_label,
                    legendgroup=trace_label,
                    showlegend=in_legend,
                ),
                row=r,
                col=c,
            )

            # Original points. They share the line's legendgroup, so toggling the
            # legend entry hides the markers as well as the line.
            if show_points:
                fig.add_trace(
                    go.Scatter(
                        x=x_vals,
                        y=y_vals,
                        mode="markers",
                        marker=dict(color=color, size=MARKER_SIZE, opacity=0.85),
                        name=trace_label,
                        legendgroup=trace_label,
                        showlegend=False,
                        hovertemplate=points_hovertemplate,
                    ),
                    row=r,
                    col=c,
                )

        # Axis titles for this subplot, addressed by grid cell rather than by axis name.
        fig.update_xaxes(
            title={"text": x_label, "font": {"size": axis_title_size}}, row=r, col=c
        )
        fig.update_yaxes(
            title={"text": y_label, "font": {"size": axis_title_size}}, row=r, col=c
        )

    # Layout and styling
    fig.update_layout(
        title={
            "text": title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        width=max(layout_config.width, layout_config.grid_width_per_col * cols),
        height=max(layout_config.height, layout_config.grid_height_per_row * rows),
        plot_bgcolor=layout_config.plot_bgcolor,
        showlegend=show_legend,
        legend=dict(
            orientation="v",
            yanchor="top",
            xanchor="left",
        ),
        margin=layout_config.margin,
    )

    # Subtle gridlines and zero lines on every subplot axis
    axis_style = dict(
        showgrid=True,
        gridwidth=1,
        gridcolor="rgba(211, 211, 211, 0.3)",
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor="rgba(211, 211, 211, 0.8)",
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    return fig