from enum import Enum
from typing import Callable, Dict, List, Optional
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from .color_themes import get_theme
from .plotting_constants import *
from .utils import prepare_smooth_data_2D, prepare_smooth_data_3D
# Set default plotly template for better aesthetics.
# NOTE: this assigns plotly's process-wide default template at import time,
# so it presumably also changes figures built elsewhere in the same process
# after this module is imported; pass template=... per figure if that matters.
pio.templates.default = "plotly_white"
[docs]
def plot_grid_2D(
    x_list,
    y_data_dict: Dict[str, np.ndarray],
    subplot_titles: List[str],
    title: str,
    x_label: str,
    y_label: str,
    rows: int,
    cols: int,
    *,
    theme=None,
    fit_curve: bool = False,
    font_config: Optional[FontConfig] = None,
    layout_config: Optional[LayoutConfig] = None,
    smoothing_config: Optional[SmoothingConfig] = None,
):
    """
    Generic 2D grid plotting utility.

    One subplot is drawn per entry of ``y_data_dict``, filled row-major
    (left to right, then top to bottom). Each subplot shows the line produced
    by ``prepare_smooth_data_2D`` plus the raw ``(x_list, y)`` points as
    markers.

    Parameters
    ----------
    x_list : numpy.ndarray
        Numpy array of x coordinates (shared across all subplots)
    y_data_dict : Dict[str, np.ndarray]
        Dictionary mapping event keys to y-data arrays; iteration order
        decides subplot placement
    subplot_titles : List[str]
        List of subplot titles in order
    title : str
        Overall grid title
    x_label : str
        X-axis label for all subplots
    y_label : str
        Y-axis label for all subplots
    rows : int
        Number of rows in the grid
    cols : int
        Number of columns in the grid
    theme : str, optional
        Color theme name, passed to ``get_theme``
    fit_curve : bool, optional
        If True, the smoothed line is drawn with plotly's "spline" line
        shape; otherwise with straight ("linear") segments
    font_config : FontConfig, optional
        FontConfig object for font settings (defaults to DEFAULT_FONT_CONFIG)
    layout_config : LayoutConfig, optional
        LayoutConfig object for layout settings (defaults to
        DEFAULT_LAYOUT_CONFIG)
    smoothing_config : SmoothingConfig, optional
        SmoothingConfig object for smoothing settings (defaults to
        DEFAULT_SMOOTHING_CONFIG)

    Returns
    -------
    go.Figure
        Plotly figure object

    Raises
    ------
    ValueError
        If ``y_data_dict`` has more entries than the ``rows x cols`` grid
        can hold.
    """
    # Fail fast with a clear message instead of plotly's generic
    # "(row, col) pair ... is out of range" error from add_trace.
    n_panels = len(y_data_dict)
    if n_panels > rows * cols:
        raise ValueError(
            f"y_data_dict has {n_panels} entries but a {rows}x{cols} grid "
            f"only holds {rows * cols} subplots"
        )

    if font_config is None:
        font_config = DEFAULT_FONT_CONFIG
    if layout_config is None:
        layout_config = DEFAULT_LAYOUT_CONFIG
    if smoothing_config is None:
        smoothing_config = DEFAULT_SMOOTHING_CONFIG

    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=subplot_titles,
        horizontal_spacing=layout_config.grid_horizontal_spacing,
        vertical_spacing=layout_config.grid_vertical_spacing,
        shared_xaxes=False,
        shared_yaxes=False,
    )
    theme_colors = get_theme(theme)
    line_shape = "spline" if fit_curve else "linear"

    for idx, (event_key, y_list) in enumerate(y_data_dict.items()):
        # Row-major placement; plotly grid coordinates are 1-based.
        r = idx // cols + 1
        c = idx % cols + 1

        x_dense, y_dense = prepare_smooth_data_2D(
            x_list, y_list, smoothing_config=smoothing_config
        )

        # Smoothed/interpolated line.
        fig.add_trace(
            go.Scatter(
                x=x_dense,
                y=y_dense,
                mode="lines",
                line=dict(
                    color=theme_colors["light"],
                    width=LINE_WIDTH,
                    shape=line_shape,
                ),
                name=event_key,
                showlegend=False,
            ),
            row=r,
            col=c,
        )
        # Raw data points on top of the line.
        fig.add_trace(
            go.Scatter(
                x=x_list,
                y=y_list,
                mode="markers",
                marker=dict(color=theme_colors["dark"], size=MARKER_SIZE),
                name=f"{event_key} Data",
                showlegend=False,
            ),
            row=r,
            col=c,
        )

        # Address the axes by grid cell rather than by hand-built
        # "xaxisN"/"yaxisN" names.
        fig.update_yaxes(
            title={"text": y_label, "font": {"size": font_config.medium}},
            row=r,
            col=c,
        )
        fig.update_xaxes(
            title={"text": x_label, "font": {"size": font_config.medium}},
            row=r,
            col=c,
        )

    fig.update_layout(
        title={
            "text": title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        # Grow the canvas with the grid, but never below the configured size.
        width=max(layout_config.width, layout_config.grid_width_per_col * cols),
        height=max(layout_config.height, layout_config.grid_height_per_row * rows),
        plot_bgcolor=layout_config.plot_bgcolor,
        showlegend=False,
        margin=layout_config.margin,
    )

    # Identical gridline / zero-line / tick styling for every x and y axis.
    axis_style = dict(
        showgrid=True,
        gridwidth=GRID_WIDTH,
        gridcolor=GRID_COLOR,
        zeroline=True,
        zerolinewidth=ZEROLINE_WIDTH,
        zerolinecolor=ZEROLINE_COLOR,
        tickformat=FLOAT_PRECISION,
        tickfont=dict(size=font_config.small),
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    return fig
[docs]
class PlotType(str, Enum):
    """
    Kind of plot drawn in each cell by ``plot_grid_3D``.

    Because the enum mixes in ``str``, every member compares equal to its
    plain string value, e.g. ``PlotType.SURFACE == "surface"``.

    Members
    -------
    CONTOUR
        2D contour plot on ordinary x/y axes.
    SURFACE
        3D surface drawn in its own scene.
    """

    # Definition order (CONTOUR, then SURFACE) is the iteration order.
    CONTOUR = "contour"
    SURFACE = "surface"
[docs]
def plot_grid_3D(
    x_list,
    y_list,
    z_data_dict: Dict[str, np.ndarray],
    subplot_titles: List[str],
    title: str,
    x_label: str,
    y_label: str,
    z_label_dict: Dict[str, str],
    rows: int,
    cols: int,
    plot_type: PlotType,
    theme=None,
    font_config: Optional[FontConfig] = None,
    layout_config: Optional[LayoutConfig] = None,
    colorbar_config: Optional[ColorbarConfig] = None,
    smoothing_config: Optional[SmoothingConfig] = None,
):
    """
    Generic 3D grid plotting utility for contour or surface plots.

    Requires grid-based input data (2D z arrays). One subplot is drawn per
    entry of ``z_data_dict``, filled row-major. Contour grids use ordinary
    x/y axes (y axes shared per row); surface grids use one 3D scene per
    cell.

    Parameters
    ----------
    x_list : numpy.ndarray
        1D numpy array of x coordinates defining the grid x-axis
    y_list : numpy.ndarray
        1D numpy array of y coordinates defining the grid y-axis
    z_data_dict : Dict[str, np.ndarray]
        Dictionary mapping event keys to z-data 2D arrays. Each array is
        handed to ``prepare_smooth_data_3D`` together with ``x_list`` and
        ``y_list``; documented shape is (len(x_list), len(y_list)).
        NOTE(review): plotly's Contour/Surface index z as z[y_index][x_index];
        confirm that ``prepare_smooth_data_3D`` produces that orientation.
    subplot_titles : List[str]
        List of subplot titles in order
    title : str
        Overall grid title
    x_label : str
        X-axis label
    y_label : str
        Y-axis label
    z_label_dict : Dict[str, str]
        Dictionary mapping event keys to z-axis labels; the event key itself
        is used when a key is missing
    rows : int
        Number of rows in the grid
    cols : int
        Number of columns in the grid
    plot_type : PlotType
        PlotType enum (CONTOUR or SURFACE)
    theme : str, optional
        Color theme name, passed to ``get_theme``
    font_config : FontConfig, optional
        FontConfig object for font settings (defaults to DEFAULT_FONT_CONFIG)
    layout_config : LayoutConfig, optional
        LayoutConfig object for layout settings (defaults to
        DEFAULT_LAYOUT_CONFIG)
    colorbar_config : ColorbarConfig, optional
        ColorbarConfig object for colorbar settings (defaults to
        DEFAULT_COLORBAR_CONFIG)
    smoothing_config : SmoothingConfig, optional
        SmoothingConfig object for smoothing settings (defaults to
        DEFAULT_SMOOTHING_CONFIG)

    Returns
    -------
    go.Figure
        Plotly figure object

    Raises
    ------
    ValueError
        If ``z_data_dict`` has more entries than the ``rows x cols`` grid
        can hold.
    """
    # Fail fast with a clear message instead of plotly's generic
    # "(row, col) pair ... is out of range" error from add_trace.
    n_panels = len(z_data_dict)
    if n_panels > rows * cols:
        raise ValueError(
            f"z_data_dict has {n_panels} entries but a {rows}x{cols} grid "
            f"only holds {rows * cols} subplots"
        )

    if font_config is None:
        font_config = DEFAULT_FONT_CONFIG
    if layout_config is None:
        layout_config = DEFAULT_LAYOUT_CONFIG
    if colorbar_config is None:
        colorbar_config = DEFAULT_COLORBAR_CONFIG
    if smoothing_config is None:
        smoothing_config = DEFAULT_SMOOTHING_CONFIG

    is_surface = plot_type == PlotType.SURFACE
    # Surface traces need "surface" (3D scene) cells; contours use the
    # default xy cells, which is what specs=None gives.
    specs = (
        [[{"type": "surface"} for _ in range(cols)] for _ in range(rows)]
        if is_surface
        else None
    )
    fig = make_subplots(
        rows=rows,
        cols=cols,
        specs=specs,
        subplot_titles=subplot_titles,
        horizontal_spacing=layout_config.grid_horizontal_spacing,
        vertical_spacing=layout_config.grid_vertical_spacing,
        shared_xaxes=False,
        shared_yaxes=True,
    )
    theme_colors = get_theme(theme)
    add_trace = _add_surface_trace if is_surface else _add_contour_trace

    for idx, (event_key, z_list) in enumerate(z_data_dict.items()):
        # Row-major placement; plotly grid coordinates are 1-based.
        r = idx // cols + 1
        c = idx % cols + 1

        x_interp, y_interp, z_smooth = prepare_smooth_data_3D(
            x_list, y_list, z_list, smoothing_config=smoothing_config
        )
        z_label = z_label_dict.get(event_key, event_key)
        add_trace(
            fig,
            r,
            c,
            x_interp,
            y_interp,
            z_smooth,
            theme_colors["colorscale"],
            z_label,
            font_config=font_config,
            colorbar_config=colorbar_config,
        )

        if is_surface:
            # Surface cells are 3D scenes: label the scene's own axes.
            # Cartesian xaxis/yaxis entries would never be used here.
            fig.update_scenes(
                xaxis={"title": {"text": x_label}, "tickformat": FLOAT_PRECISION},
                yaxis={"title": {"text": y_label}, "tickformat": FLOAT_PRECISION},
                zaxis={"title": {"text": z_label}, "tickformat": FLOAT_PRECISION},
                row=r,
                col=c,
            )
        else:
            fig.update_yaxes(
                title={"text": y_label, "font": {"size": font_config.medium}},
                row=r,
                col=c,
            )
            fig.update_xaxes(
                title={"text": x_label, "font": {"size": font_config.medium}},
                row=r,
                col=c,
            )

    fig.update_layout(
        title={
            "text": title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        # Grow the canvas with the grid, but never below the configured size.
        width=max(layout_config.width, layout_config.grid_width_per_col * cols),
        height=max(layout_config.height, layout_config.grid_height_per_row * rows),
        plot_bgcolor=layout_config.plot_bgcolor,
        showlegend=False,
        margin=layout_config.margin,
    )

    if not is_surface:
        # Boxed axes with light gridlines, shared by every contour cell.
        axis_style = dict(
            showline=True,
            linewidth=1,
            linecolor="lightgrey",
            mirror=True,
            showgrid=True,
            gridwidth=GRID_WIDTH,
            gridcolor=GRID_COLOR,
            tickformat=FLOAT_PRECISION,
            tickfont=dict(size=font_config.small),
        )
        fig.update_xaxes(**axis_style)
        fig.update_yaxes(**axis_style)
        # One contour trace per cell, added in cell order, so trace i sits
        # in subplot i + 1. Put each colorbar beside its own subplot.
        for subplot_num, trace in enumerate(fig.data, start=1):
            _position_colorbar(trace, fig, subplot_num)
    return fig
def _add_contour_trace(
    fig,
    row,
    col,
    x_interp,
    y_interp,
    z_smooth,
    colorscale,
    z_label,
    font_config: FontConfig,
    colorbar_config: ColorbarConfig,
):
    """
    Helper function to add a labelled contour trace to one grid cell.

    Contour levels run from the smallest to the largest finite value of
    ``z_smooth`` in ``NUM_CONTOURS`` equal float steps. A step of 1 is used
    when that range is degenerate (constant data, or no finite values).

    Parameters
    ----------
    fig : go.Figure
        Plotly figure to add trace to (modified in place)
    row : int
        Row position in subplot grid
    col : int
        Column position in subplot grid
    x_interp : numpy.ndarray
        Interpolated x coordinates
    y_interp : numpy.ndarray
        Interpolated y coordinates
    z_smooth : numpy.ndarray
        Smoothed z values; passed to plotly unchanged
    colorscale : str
        Colorscale name
    z_label : str
        Label for z-axis, used in the hover text
    font_config : FontConfig
        Font configuration object
    colorbar_config : ColorbarConfig
        Colorbar configuration object
    """
    # Compute the level range from finite values only, so NaN/inf gaps in the
    # smoothed grid cannot poison min()/max() (int(nan) used to raise here).
    z_values = np.asarray(z_smooth, dtype=float)
    finite_z = z_values[np.isfinite(z_values)]
    if finite_z.size:
        z_min, z_max = float(finite_z.min()), float(finite_z.max())
    else:
        z_min = z_max = 0.0
    # Keep levels as floats: int() truncation gave a step of 0 (and collapsed
    # start/end) whenever the z range was smaller than NUM_CONTOURS, e.g. for
    # data in [0, 1]. This matches the surface helper's level computation.
    contour_size = (z_max - z_min) / NUM_CONTOURS if z_max > z_min else 1

    trace = go.Contour(
        x=x_interp,
        y=y_interp,
        z=z_smooth,
        colorscale=colorscale,
        contours=dict(
            showlabels=True,
            labelfont=dict(
                size=font_config.small,
                color=CONTOUR_LABEL_COLOR,
            ),
            start=z_min,
            end=z_max,
            size=contour_size,
            labelformat=FLOAT_PRECISION,
        ),
        colorbar=dict(
            title=dict(
                # A title is only shown when no grid row is given. Grid
                # callers pass a row, so there the subplot title names the
                # field instead.
                text=z_label if row is None else None,
                font=dict(size=font_config.medium),
            ),
            thickness=colorbar_config.thickness,
            len=colorbar_config.length,
            tickfont=dict(size=font_config.small),
        ),
        hovertemplate=(
            f"X: %{{x:{FLOAT_PRECISION}}}<br>"
            f"Y: %{{y:{FLOAT_PRECISION}}}<br>"
            f"{z_label}: %{{z:{FLOAT_PRECISION}}}<extra></extra>"
        ),
        showscale=True,
    )
    fig.add_trace(trace, row=row, col=col)
def _add_surface_trace(
    fig,
    row,
    col,
    x_interp,
    y_interp,
    z_smooth,
    colorscale,
    z_label,
    font_config: FontConfig,
    colorbar_config: ColorbarConfig,
):
    """
    Helper function to add a surface trace with z-contour lines to one cell.

    Z contour lines run from the smallest to the largest finite value of
    ``z_smooth`` in ``NUM_CONTOURS`` equal steps. A step of 1 is used when
    that range is degenerate (constant data, or no finite values). The trace's
    colour scale bar is hidden (``showscale=False``).

    Parameters
    ----------
    fig : go.Figure
        Plotly figure to add trace to (modified in place)
    row : int
        Row position in subplot grid
    col : int
        Column position in subplot grid
    x_interp : numpy.ndarray
        Interpolated x coordinates
    y_interp : numpy.ndarray
        Interpolated y coordinates
    z_smooth : numpy.ndarray
        Smoothed z values; passed to plotly unchanged
    colorscale : str
        Colorscale name
    z_label : str
        Label for z-axis, used in the hover text
    font_config : FontConfig
        Font configuration object
    colorbar_config : ColorbarConfig
        Colorbar configuration object
    """
    # Derive the contour range from finite values only. A NaN would otherwise
    # become start/end (and an empty array would raise on .max()).
    z_values = np.asarray(z_smooth, dtype=float)
    finite_z = z_values[np.isfinite(z_values)]
    if finite_z.size:
        z_min, z_max = float(finite_z.min()), float(finite_z.max())
    else:
        z_min = z_max = 0.0
    contour_size = (z_max - z_min) / NUM_CONTOURS if z_max > z_min else 1

    # Expand the axes into 2D coordinate grids with one row per y value.
    # Assuming 1D inputs, x_grid[i, j] == x_interp[j] and
    # y_grid[i, j] == y_interp[i].
    x_grid = np.array([x_interp] * len(y_interp))
    y_grid = np.array([y_interp] * len(x_interp)).T

    trace = go.Surface(
        x=x_grid,
        y=y_grid,
        z=z_smooth,
        colorscale=colorscale,
        contours={
            "z": {
                "show": True,
                "start": z_min,
                "end": z_max,
                "size": contour_size,
                "width": LINE_WIDTH,
                "color": CONTOUR_LABEL_COLOR,
            }
        },
        colorbar=dict(
            title=dict(
                # A title is only set when no grid row is given. Grid callers
                # pass a row.
                text=z_label if row is None else None,
                font=dict(size=font_config.medium),
            ),
            thickness=colorbar_config.thickness,
            len=colorbar_config.length,
            tickfont=dict(size=font_config.small),
        ),
        hovertemplate=(
            f"X: %{{x:{FLOAT_PRECISION}}}<br>"
            f"Y: %{{y:{FLOAT_PRECISION}}}<br>"
            f"{z_label}: %{{z:{FLOAT_PRECISION}}}<extra></extra>"
        ),
        showscale=False,
    )
    fig.add_trace(trace, row=row, col=col)
def _position_colorbar(trace, fig, subplot_num):
"""
Dynamically position the colorbar for a subplot in a grid.
Parameters
----------
trace : plotly trace object
The trace containing the colorbar to position
fig : go.Figure
The figure containing the subplot
subplot_num : int
The subplot number (1-indexed)
"""
if hasattr(trace, "colorbar") and trace.colorbar is not None:
xaxis_name = f"xaxis{subplot_num}" if subplot_num > 1 else "xaxis"
yaxis_name = f"yaxis{subplot_num}" if subplot_num > 1 else "yaxis"
if xaxis_name in fig.layout and yaxis_name in fig.layout:
x_dom = fig.layout[xaxis_name].domain
y_dom = fig.layout[yaxis_name].domain
trace.colorbar.x = x_dom[1]
trace.colorbar.xanchor = "left"
mid_y = 0.5 * (y_dom[0] + y_dom[1])
trace.colorbar.y = mid_y
trace.colorbar.yanchor = "middle"
height_fraction = y_dom[1] - y_dom[0]
trace.colorbar.len = 0.8 * height_fraction
trace.showscale = True