from typing import Dict, Iterable, List, Optional
import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from plotly.colors import qualitative
from plotly.subplots import make_subplots
from suboptimumg.plotting.generic_plot import prepare_smooth_data_2D
from .plotting_constants import *
def _color_cycle(n: int) -> Iterable[str]:
"""
Get n distinct colors from Plotly's qualitative palette.
Parameters
----------
n : int
Number of colors needed
Returns
-------
Iterable[str]
Iterable of color strings, repeating if n exceeds palette size
"""
palette: Iterable[str] = qualitative.Plotly
if n <= len(palette):
return palette[:n]
times = n // len(palette) + 1
return (palette * times)[:n]
[docs]
def overlay_plot2D(
    x_data_dict: Dict[str, npt.NDArray[np.number]],
    y_data_dict: Dict[str, npt.NDArray[np.number]],
    title: str,
    x_axis: str,
    y_axis: str,
    *,
    h_line: Optional[ReferenceLineConfig] = None,
    v_line: Optional[ReferenceLineConfig] = None,
    fit_curve: bool = False,
    show_points: bool = False,
    subtitle: Optional[str] = None,
    font_config: Optional[FontConfig] = None,
    layout_config: Optional[LayoutConfig] = None,
    smoothing_config: Optional[SmoothingConfig] = None,
) -> go.Figure:
    """
    Generic 2D overlay plotting utility for comparing multiple series.

    One line trace is drawn per key of ``y_data_dict``, in sorted key order,
    and each trace gets its own color. Optional horizontal and vertical
    reference lines span the extent of the raw data. The axis a reference
    line crosses is widened so the reference value stays visible.

    Parameters
    ----------
    x_data_dict : Dict[str, npt.NDArray[np.number]]
        Dictionary mapping trace labels to x data arrays. Must contain every
        key of ``y_data_dict``.
    y_data_dict : Dict[str, npt.NDArray[np.number]]
        Dictionary mapping trace labels to y data arrays. Its keys determine
        which traces are drawn.
    title : str
        Plot title
    x_axis : str
        X-axis label (also used in the hover text)
    y_axis : str
        Y-axis label (also used in the hover text)
    h_line : ReferenceLineConfig, optional
        Horizontal reference line. Ignored when None or when its ``value`` is None.
    v_line : ReferenceLineConfig, optional
        Vertical reference line. Ignored when None or when its ``value`` is None.
    fit_curve : bool, optional
        If True, each series is passed through ``prepare_smooth_data_2D`` and
        drawn with spline shape. Otherwise the raw points are joined with
        straight segments.
    show_points : bool, optional
        Whether to overlay the raw data points as markers
    subtitle : str, optional
        Optional subtitle rendered under the title
    font_config : FontConfig, optional
        Font settings; falls back to DEFAULT_FONT_CONFIG
    layout_config : LayoutConfig, optional
        Layout settings; falls back to DEFAULT_LAYOUT_CONFIG
    smoothing_config : SmoothingConfig, optional
        Smoothing settings used when ``fit_curve`` is True; falls back to
        DEFAULT_SMOOTHING_CONFIG

    Returns
    -------
    go.Figure
        Plotly figure object

    Raises
    ------
    KeyError
        If a label of ``y_data_dict`` is missing from ``x_data_dict``.
    ValueError
        If a reference line is requested but there is no data to size it.
    """
    font_config = font_config or DEFAULT_FONT_CONFIG
    layout_config = layout_config or DEFAULT_LAYOUT_CONFIG
    smoothing_config = smoothing_config or DEFAULT_SMOOTHING_CONFIG

    fig = go.Figure()

    # Sorted labels give a deterministic trace order and color assignment
    trace_labels = sorted(y_data_dict)
    colors = _color_cycle(len(trace_labels))
    line_hovertemplate = (
        f"{x_axis}: %{{x:{FLOAT_PRECISION}}}<br>"
        f"{y_axis}: %{{y:{FLOAT_PRECISION}}}<extra></extra>"
    )

    for trace_label, color in zip(trace_labels, colors):
        x_vals = x_data_dict[trace_label]
        y_vals = y_data_dict[trace_label]
        if fit_curve:
            x_line, y_line = prepare_smooth_data_2D(
                x_vals, y_vals, smoothing_config=smoothing_config
            )
        else:
            x_line, y_line = x_vals, y_vals

        fig.add_trace(
            go.Scatter(
                x=x_line,
                y=y_line,
                mode="lines",
                line=dict(
                    color=color,
                    width=LINE_WIDTH,
                    # Only round the line off when a curve fit was requested;
                    # a spline through raw points would misrepresent the data.
                    shape="spline" if fit_curve else "linear",
                ),
                name=trace_label,
                hovertemplate=line_hovertemplate,
            )
        )
        if show_points:
            fig.add_trace(
                go.Scatter(
                    x=x_vals,
                    y=y_vals,
                    mode="markers",
                    marker=dict(color=color, size=MARKER_SIZE_LARGE),
                    showlegend=False,
                    hovertemplate=None,
                )
            )

    has_h_line = h_line is not None and h_line.value is not None
    has_v_line = v_line is not None and v_line.value is not None

    if has_h_line or has_v_line:
        if not trace_labels:
            raise ValueError(
                "Reference lines need at least one data series to size them"
            )
        all_x = np.concatenate([np.asarray(x_data_dict[lbl]) for lbl in trace_labels])
        all_y = np.concatenate([np.asarray(y_data_dict[lbl]) for lbl in trace_labels])
        # NaN-aware extrema: builtin min()/max() return NaN or an arbitrary
        # value depending on where a NaN sits in the array.
        x_min, x_max = np.nanmin(all_x), np.nanmax(all_x)
        y_min, y_max = np.nanmin(all_y), np.nanmax(all_y)

    # Horizontal reference line spanning the x extent of the raw data
    if has_h_line:
        fig.add_trace(
            go.Scatter(
                x=[x_min, x_max],
                y=[h_line.value, h_line.value],
                mode="lines",
                line=dict(
                    color=h_line.color,
                    width=h_line.width,
                    dash=h_line.dash,
                ),
                name=f"{h_line.label}",
                hoverinfo="name",
            )
        )

    # Vertical reference line spanning the y extent of the raw data
    if has_v_line:
        fig.add_trace(
            go.Scatter(
                x=[v_line.value, v_line.value],
                y=[y_min, y_max],
                mode="lines",
                line=dict(
                    color=v_line.color,
                    width=v_line.width,
                    dash=v_line.dash,
                ),
                name=f"{v_line.label}",
                hoverinfo="name",
            )
        )

    # Title, Axes, Label, Legend, Size
    full_title = title
    if subtitle is not None:
        full_title += f"<br><span style='font-size: {font_config.medium}px; color: {TEXT_COLOR_LIGHT};'>{subtitle}</span>"
    fig.update_layout(
        title={
            "text": full_title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        xaxis_title={
            "text": x_axis,
            "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
        },
        yaxis_title={
            "text": y_axis,
            "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
        },
        legend_title="Legend",
        showlegend=True,
        legend_title_font=dict(size=font_config.medium),
        legend_font=dict(size=font_config.small),
        hovermode=HOVER_MODE,
        plot_bgcolor=layout_config.plot_bgcolor,
        width=layout_config.width,
        height=layout_config.height,
        margin=layout_config.margin,
    )

    # Subtle grid lines, identical on both axes
    axis_style = dict(
        showgrid=True,
        gridwidth=GRID_WIDTH,
        gridcolor=GRID_COLOR,
        zeroline=True,
        zerolinewidth=ZEROLINE_WIDTH,
        zerolinecolor=ZEROLINE_COLOR,
        tickfont=dict(size=font_config.small),
        tickformat=FLOAT_PRECISION,
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    # Widen the axis each reference line crosses so its value is in view
    if has_h_line:
        fig.update_yaxes(
            range=_padded_axis_range(
                min(y_min, h_line.value), max(y_max, h_line.value), RANGE_PADDING
            )
        )
    if has_v_line:
        fig.update_xaxes(
            range=_padded_axis_range(
                min(x_min, v_line.value), max(x_max, v_line.value), RANGE_PADDING
            )
        )
    return fig


def _padded_axis_range(lo: float, hi: float, padding_fraction: float) -> List[float]:
    """
    Return ``[lower, upper]`` padded outward by a factor of ``1 + padding_fraction``.

    Each bound's magnitude is scaled so that it moves away from the data.
    Non-negative lower bounds are divided by the factor; negative ones are
    multiplied. The mirror rule applies to the upper bound. A bound of exactly
    zero is left unpadded.
    """
    factor = 1 + padding_fraction
    lower = lo / factor if lo >= 0 else lo * factor
    upper = hi * factor if hi >= 0 else hi / factor
    return [lower, upper]
[docs]
def overlay_grid_plot_2D(
    x_data_dict: Dict[str, Dict[str, npt.NDArray[np.number]]],
    y_data_dict: Dict[str, Dict[str, npt.NDArray[np.number]]],
    subplot_titles: List[str],
    title: str,
    x_label: str,
    y_label: str,
    rows: int,
    cols: int,
    *,
    show_points: bool = False,
    show_legend: bool = False,
    fit_curve: bool = False,
    layout_config: LayoutConfig = DEFAULT_LAYOUT_CONFIG,
    font_config: FontConfig = DEFAULT_FONT_CONFIG,
    smoothing_config: SmoothingConfig = DEFAULT_SMOOTHING_CONFIG,
) -> go.Figure:
    """
    Generic 2D overlay grid plotting utility for comparing multiple series.

    Each key of ``y_data_dict`` gets its own subplot. Subplots are filled
    row-major (left to right, then top to bottom) in the dictionary's
    iteration order. Within every subplot, one line is drawn per trace label,
    and a given label uses the same color in all subplots.

    Parameters
    ----------
    x_data_dict : Dict[str, Dict[str, npt.NDArray[np.number]]]
        {subplot_key: {trace_label: x_data_array}}
        - subplot_key: must contain every subplot key of y_data_dict
        - trace_label: must contain every trace label of y_data_dict[subplot_key]
    y_data_dict : Dict[str, Dict[str, npt.NDArray[np.number]]]
        {subplot_key: {trace_label: y_data_array}}
        - subplot_key: identifies which subplot the data belongs to (its
          position in the dict must match the order of subplot_titles)
        - trace_label: identifies the series; a label may be absent from
          some subplots, in which case it is simply not drawn there
    subplot_titles : List[str]
        Titles for the subplots, in the same order as the keys of
        ``y_data_dict``. Also used as the y name in the marker hover text.
        ``y_label`` is used instead for a subplot without a title.
    title : str
        Overall title for the grid
    x_label : str
        Label for every subplot's x-axis
    y_label : str
        Label for every subplot's y-axis
    rows : int
        Number of rows in the grid
    cols : int
        Number of columns in the grid
    show_points : bool, optional
        Whether to overlay the raw data points as markers
    show_legend : bool, optional
        If True, shows a legend with one entry per trace label
    fit_curve : bool, optional
        If True, each series is passed through ``prepare_smooth_data_2D`` and
        drawn with spline shape. Otherwise the raw points are joined with
        straight segments.
    layout_config : LayoutConfig, optional
        Layout configuration object
    font_config : FontConfig, optional
        Font configuration object
    smoothing_config : SmoothingConfig, optional
        Smoothing configuration object (used when ``fit_curve`` is True)

    Returns
    -------
    go.Figure
        Plotly figure object

    Raises
    ------
    ValueError
        If ``y_data_dict`` has more subplots than the ``rows * cols`` grid holds.
    KeyError
        If a subplot key or trace label of ``y_data_dict`` is missing from
        ``x_data_dict``.
    """
    subplot_keys = list(y_data_dict)
    if len(subplot_keys) > rows * cols:
        raise ValueError(
            f"Cannot place {len(subplot_keys)} subplots in a {rows}x{cols} grid"
        )

    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=subplot_titles,
        horizontal_spacing=layout_config.grid_horizontal_spacing,
        vertical_spacing=layout_config.grid_vertical_spacing,
        shared_xaxes=False,
        shared_yaxes=False,
    )

    # Colors are keyed on the sorted union of labels, so a series keeps the
    # same color in every subplot even if some subplots omit it.
    trace_labels = sorted(
        {label for series in y_data_dict.values() for label in series}
    )
    color_of = dict(zip(trace_labels, _color_cycle(len(trace_labels))))
    axis_title_font = {"size": font_config.medium}
    # Labels that already have a legend entry; each label is listed only once.
    legend_labels = set()

    for subplot_idx, subplot_key in enumerate(subplot_keys):
        # make_subplots numbers cells row-major, starting at 1
        row = subplot_idx // cols + 1
        col = subplot_idx % cols + 1
        x_series = x_data_dict[subplot_key]
        y_series = y_data_dict[subplot_key]
        hover_y_name = (
            subplot_titles[subplot_idx]
            if subplot_idx < len(subplot_titles)
            else y_label
        )

        for trace_label in trace_labels:
            if trace_label not in y_series:
                continue
            color = color_of[trace_label]
            x_vals = x_series[trace_label]
            y_vals = y_series[trace_label]
            if fit_curve:
                x_line, y_line = prepare_smooth_data_2D(
                    x_vals, y_vals, smoothing_config=smoothing_config
                )
            else:
                x_line, y_line = x_vals, y_vals

            in_legend = show_legend and trace_label not in legend_labels
            if in_legend:
                legend_labels.add(trace_label)

            fig.add_trace(
                go.Scatter(
                    x=x_line,
                    y=y_line,
                    mode="lines",
                    line=dict(
                        color=color,
                        width=2,
                        shape=("spline" if fit_curve else "linear"),
                    ),
                    name=trace_label,
                    legendgroup=trace_label,
                    showlegend=in_legend,
                ),
                row=row,
                col=col,
            )
            # Original points; sharing the legend group lets a legend click
            # hide them together with their line.
            if show_points:
                fig.add_trace(
                    go.Scatter(
                        x=x_vals,
                        y=y_vals,
                        mode="markers",
                        marker=dict(color=color, size=MARKER_SIZE, opacity=0.85),
                        name=trace_label,
                        legendgroup=trace_label,
                        showlegend=False,
                        hovertemplate=f"{x_label}: %{{x}}<br>{hover_y_name}: %{{y:.2f}}<extra></extra>",
                    ),
                    row=row,
                    col=col,
                )

        # Axis titles for this subplot
        fig.update_xaxes(
            title={"text": x_label, "font": axis_title_font}, row=row, col=col
        )
        fig.update_yaxes(
            title={"text": y_label, "font": axis_title_font}, row=row, col=col
        )

    # Layout and styling
    fig.update_layout(
        title={
            "text": title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        width=max(layout_config.width, layout_config.grid_width_per_col * cols),
        height=max(layout_config.height, layout_config.grid_height_per_row * rows),
        plot_bgcolor=layout_config.plot_bgcolor,
        showlegend=show_legend,
        legend=dict(
            orientation="v",
            yanchor="top",
            xanchor="left",
        ),
        margin=layout_config.margin,
    )

    # Subtle gridlines and zero lines on every subplot axis
    grid_style = dict(
        showgrid=True,
        gridwidth=1,
        gridcolor="rgba(211, 211, 211, 0.3)",
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor="rgba(211, 211, 211, 0.8)",
    )
    fig.update_xaxes(**grid_style)
    fig.update_yaxes(**grid_style)
    return fig