from typing import Dict, List, Optional
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from .color_themes import get_theme
from .plotting_constants import *
from .utils import prepare_smooth_data_2D
# Make "plotly_white" the default template for figures built after this
# module is imported. pio.templates.default is a global plotly setting,
# so this also affects figures created elsewhere in the same process.
pio.templates.default = "plotly_white"
def plot_grid_2D(
    x_list,
    y_data_dict: Dict[str, np.ndarray],
    subplot_titles: List[str],
    title: str,
    x_label: str,
    y_label: str,
    rows: int,
    cols: int,
    *,
    theme=None,
    fit_curve: bool = False,
    font_config: Optional[FontConfig] = None,
    layout_config: Optional[LayoutConfig] = None,
    smoothing_config: Optional[SmoothingConfig] = None,
):
    """
    Generic 2D grid plotting utility.

    One subplot is drawn per entry of ``y_data_dict``. Cells are filled row
    by row (left to right, then top to bottom) in the dict's iteration
    order. Each subplot contains:

    * an interpolated line built by ``prepare_smooth_data_2D``, and
    * the raw ``(x_list, y)`` points as markers.

    Only subplots that receive data get axis titles. Grid lines, zero lines
    and tick formatting are applied to every axis in the figure.

    Parameters
    ----------
    x_list : array-like
        x coordinates shared across all subplots. Used both as input to the
        smoothing helper and for the marker traces.
    y_data_dict : Dict[str, np.ndarray]
        Mapping from event key to its y-data. Iteration order decides
        which cell each series is placed in.
    subplot_titles : List[str]
        Subplot titles, in the same row-major order as the cells.
    title : str
        Overall grid title.
    x_label : str
        X-axis label applied to every populated subplot.
    y_label : str
        Y-axis label applied to every populated subplot.
    rows : int
        Number of rows in the grid.
    cols : int
        Number of columns in the grid.
    theme : str, optional
        Color theme name passed to ``get_theme``. The theme's ``"light"``
        color is used for lines and its ``"dark"`` color for markers.
    fit_curve : bool, optional
        If True, the interpolated line uses Plotly's ``"spline"`` line
        shape; otherwise it uses ``"linear"``.
    font_config : FontConfig, optional
        Font settings. Defaults to ``DEFAULT_FONT_CONFIG``.
    layout_config : LayoutConfig, optional
        Layout settings. Defaults to ``DEFAULT_LAYOUT_CONFIG``.
    smoothing_config : SmoothingConfig, optional
        Smoothing settings forwarded to ``prepare_smooth_data_2D``.
        Defaults to ``DEFAULT_SMOOTHING_CONFIG``.

    Returns
    -------
    go.Figure
        Plotly figure object.

    Raises
    ------
    ValueError
        If ``y_data_dict`` has more entries than the grid has cells
        (``rows * cols``). Invalid ``rows``/``cols`` values are rejected
        by ``make_subplots`` itself.
    """
    if font_config is None:
        font_config = DEFAULT_FONT_CONFIG
    if layout_config is None:
        layout_config = DEFAULT_LAYOUT_CONFIG
    if smoothing_config is None:
        smoothing_config = DEFAULT_SMOOTHING_CONFIG

    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=subplot_titles,
        horizontal_spacing=layout_config.grid_horizontal_spacing,
        vertical_spacing=layout_config.grid_vertical_spacing,
        shared_xaxes=False,
        shared_yaxes=False,
    )

    # Fail up front with a clear message instead of adding some traces and
    # then hitting plotly's generic "(row, col) pair is out of range" error.
    n_cells = rows * cols
    if len(y_data_dict) > n_cells:
        raise ValueError(
            f"y_data_dict has {len(y_data_dict)} entries but a {rows}x{cols} "
            f"grid only has {n_cells} subplots"
        )

    # Loop-invariant settings shared by every subplot.
    theme_colors = get_theme(theme)
    line_shape = "spline" if fit_curve else "linear"
    x_title = {"text": x_label, "font": {"size": font_config.medium}}
    y_title = {"text": y_label, "font": {"size": font_config.medium}}

    for idx, (event_key, y_list) in enumerate(y_data_dict.items()):
        # Row-major placement; plotly rows/cols are 1-based.
        r = idx // cols + 1
        c = idx % cols + 1

        x_dense, y_dense = prepare_smooth_data_2D(
            x_list, y_list, smoothing_config=smoothing_config
        )

        # Interpolated line
        fig.add_trace(
            go.Scatter(
                x=x_dense,
                y=y_dense,
                mode="lines",
                line=dict(
                    color=theme_colors["light"],
                    width=LINE_WIDTH,
                    shape=line_shape,
                ),
                name=event_key,
                showlegend=False,
            ),
            row=r,
            col=c,
        )

        # Raw data points
        fig.add_trace(
            go.Scatter(
                x=x_list,
                y=y_list,
                mode="markers",
                marker=dict(color=theme_colors["dark"], size=MARKER_SIZE),
                name=f"{event_key} Data",
                showlegend=False,
            ),
            row=r,
            col=c,
        )

        # Address the subplot's axes by grid position rather than by
        # hand-built layout keys ("xaxis", "xaxis2", ...), so this does not
        # depend on plotly's internal axis numbering.
        fig.update_xaxes(title=x_title, row=r, col=c)
        fig.update_yaxes(title=y_title, row=r, col=c)

    # Overall layout
    fig.update_layout(
        title={
            "text": title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        # Grow the canvas with the grid, but never shrink it below the
        # configured base size.
        width=max(layout_config.width, layout_config.grid_width_per_col * cols),
        height=max(layout_config.height, layout_config.grid_height_per_row * rows),
        plot_bgcolor=layout_config.plot_bgcolor,
        showlegend=False,
        margin=layout_config.margin,
    )

    # Identical grid/zero-line/tick styling for every x and y axis.
    axis_style = dict(
        showgrid=True,
        gridwidth=GRID_WIDTH,
        gridcolor=GRID_COLOR,
        zeroline=True,
        zerolinewidth=ZEROLINE_WIDTH,
        zerolinecolor=ZEROLINE_COLOR,
        tickformat=FLOAT_PRECISION,
        tickfont=dict(size=font_config.small),
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    return fig