from typing import Dict, Optional
import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from .plotting_constants import *
from .utils import color_cycle, prepare_smooth_data_2D
[docs]
def overlay_plot2D(
    x_data_dict: Dict[str, npt.NDArray[np.number]],
    y_data_dict: Dict[str, npt.NDArray[np.number]],
    title: str,
    x_axis: str,
    y_axis: str,
    *,
    h_line: Optional[ReferenceLineConfig] = None,
    v_line: Optional[ReferenceLineConfig] = None,
    fit_curve: bool = False,
    show_points: bool = False,
    subtitle: Optional[str] = None,
    font_config: FontConfig = DEFAULT_FONT_CONFIG,
    layout_config: LayoutConfig = DEFAULT_LAYOUT_CONFIG,
    smoothing_config: SmoothingConfig = DEFAULT_SMOOTHING_CONFIG,
) -> go.Figure:
    """
    Generic 2D overlay plotting utility for plotting multiple series on the same graph.

    One line trace is drawn per key of ``y_data_dict`` (in sorted key order),
    paired with the x array stored under the same key in ``x_data_dict``.

    Parameters
    ----------
    x_data_dict : Dict[str, npt.NDArray[np.number]]
        Dictionary mapping trace labels to x data arrays
    y_data_dict : Dict[str, npt.NDArray[np.number]]
        Dictionary mapping trace labels to y data arrays
    title : str
        Plot title
    x_axis : str
        X-axis label
    y_axis : str
        Y-axis label
    h_line : ReferenceLineConfig, optional
        ReferenceLineConfig for horizontal reference line (None = no line)
    v_line : ReferenceLineConfig, optional
        ReferenceLineConfig for vertical reference line (None = no line)
    fit_curve : bool, optional
        Whether to draw each series as the smoothed curve produced by
        ``prepare_smooth_data_2D`` instead of the raw data
    show_points : bool, optional
        Whether to overlay the raw data points as markers
    subtitle : str, optional
        Optional subtitle
    font_config : FontConfig, optional
        FontConfig object for font settings
    layout_config : LayoutConfig, optional
        LayoutConfig object for layout settings
    smoothing_config : SmoothingConfig, optional
        SmoothingConfig object for smoothing settings

    Returns
    -------
    go.Figure
        Plotly figure object

    Raises
    ------
    KeyError
        If a key of ``y_data_dict`` has no matching entry in ``x_data_dict``.

    Notes
    -----
    When a reference line is given, the axis perpendicular to it gets an
    explicit range that covers every plotted (finite) value and the line's
    value. Each bound is scaled by ``1 + RANGE_PADDING`` in the direction that
    widens the range; a bound of exactly zero stays at zero. A reference line
    is only drawn when there is finite data to span.
    """
    fig = go.Figure()
    # Identical hover text for every series line, so build it once.
    hover_template = (
        f"{x_axis}: %{{x:{FLOAT_PRECISION}}}<br>"
        f"{y_axis}: %{{y:{FLOAT_PRECISION}}}<extra></extra>"
    )

    # Get trace labels and assign colors
    trace_labels = sorted(y_data_dict)
    colors = color_cycle(len(trace_labels))

    # Every x / y array that ends up on the figure; used to size the
    # reference lines and the explicit axis ranges below.
    plotted_x = []
    plotted_y = []

    # Create traces for each series
    for trace_label, color in zip(trace_labels, colors):
        if trace_label not in x_data_dict:
            raise KeyError(f"x_data_dict has no entry for trace {trace_label!r}")
        x_vals = x_data_dict[trace_label]
        y_vals = y_data_dict[trace_label]
        plotted_x.append(x_vals)
        plotted_y.append(y_vals)

        if fit_curve:
            x_dense, y_dense = prepare_smooth_data_2D(
                x_vals, y_vals, smoothing_config=smoothing_config
            )
            # The smoothed curve is not guaranteed to stay within the raw
            # data's range, so it contributes to the extents as well.
            plotted_x.append(x_dense)
            plotted_y.append(y_dense)
        else:
            x_dense, y_dense = x_vals, y_vals

        # Series line; legendgroup ties it to its markers so toggling the
        # legend entry hides both.
        fig.add_trace(
            go.Scatter(
                x=x_dense,
                y=y_dense,
                mode="lines",
                line=dict(
                    color=color,
                    width=LINE_WIDTH,
                    shape="spline",
                ),
                name=trace_label,
                legendgroup=trace_label,
                hovertemplate=hover_template,
            )
        )
        if show_points:
            fig.add_trace(
                go.Scatter(
                    x=x_vals,
                    y=y_vals,
                    mode="markers",
                    marker=dict(color=color, size=MARKER_SIZE_LARGE),
                    name=trace_label,
                    legendgroup=trace_label,
                    showlegend=False,
                    hovertemplate=None,
                )
            )

    has_h_line = h_line is not None and h_line.value is not None
    has_v_line = v_line is not None and v_line.value is not None
    x_extent = _finite_extent(plotted_x)
    y_extent = _finite_extent(plotted_y)

    # Horizontal reference line spans the x extent of all plotted data
    if has_h_line and x_extent is not None:
        fig.add_trace(
            go.Scatter(
                x=[x_extent[0], x_extent[1]],
                y=[h_line.value, h_line.value],
                mode="lines",
                line=dict(
                    color=h_line.color,
                    width=h_line.width,
                    dash=h_line.dash,
                ),
                name=f"{h_line.label}",
                hoverinfo="name",
            )
        )
    # Vertical reference line spans the y extent of all plotted data
    if has_v_line and y_extent is not None:
        fig.add_trace(
            go.Scatter(
                x=[v_line.value, v_line.value],
                y=[y_extent[0], y_extent[1]],
                mode="lines",
                line=dict(
                    color=v_line.color,
                    width=v_line.width,
                    dash=v_line.dash,
                ),
                name=f"{v_line.label}",
                hoverinfo="name",
            )
        )

    # Title, Axes, Label, Legend, Size
    full_title = title
    if subtitle is not None:
        full_title += f"<br><span style='font-size: {font_config.medium}px; color: {TEXT_COLOR_LIGHT};'>{subtitle}</span>"
    fig.update_layout(
        title={
            "text": full_title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        xaxis_title={
            "text": x_axis,
            "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
        },
        yaxis_title={
            "text": y_axis,
            "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
        },
        legend_title="Legend",
        showlegend=True,
        legend_title_font=dict(size=font_config.medium),
        legend_font=dict(size=font_config.small),
        hovermode=HOVER_MODE,
        plot_bgcolor=layout_config.plot_bgcolor,
        width=layout_config.width,
        height=layout_config.height,
        margin=layout_config.margin,
    )

    # Add subtle grid lines
    fig.update_xaxes(
        showgrid=True,
        gridwidth=GRID_WIDTH,
        gridcolor=GRID_COLOR,
        zeroline=True,
        zerolinewidth=ZEROLINE_WIDTH,
        zerolinecolor=ZEROLINE_COLOR,
        tickfont=dict(size=font_config.small),
        tickformat=FLOAT_PRECISION,
    )
    fig.update_yaxes(
        showgrid=True,
        gridwidth=GRID_WIDTH,
        gridcolor=GRID_COLOR,
        zeroline=True,
        zerolinewidth=ZEROLINE_WIDTH,
        zerolinecolor=ZEROLINE_COLOR,
        tickfont=dict(size=font_config.small),
        tickformat=FLOAT_PRECISION,
    )

    # Make sure each reference line is inside the visible axis range.
    if has_h_line:
        y_range = _padded_range(y_extent, h_line.value, RANGE_PADDING)
        if y_range is not None:
            fig.update_yaxes(range=y_range)
    if has_v_line:
        x_range = _padded_range(x_extent, v_line.value, RANGE_PADDING)
        if x_range is not None:
            fig.update_xaxes(range=x_range)
    return fig


def _finite_extent(arrays: list) -> Optional[tuple]:
    """Return ``(min, max)`` over the finite values of all ``arrays``, or None if there are none."""
    if not arrays:
        return None
    values = np.concatenate([np.ravel(np.asarray(a, dtype=float)) for a in arrays])
    # Drop NaN/inf so they cannot poison the extent.
    values = values[np.isfinite(values)]
    if values.size == 0:
        return None
    return float(values.min()), float(values.max())


def _padded_range(
    extent: Optional[tuple], reference: float, padding_fraction: float
) -> Optional[list]:
    """
    Return ``[lower, upper]`` axis bounds covering ``extent`` and ``reference``.

    Each bound is scaled by ``1 + padding_fraction`` in the direction that
    widens the range (dividing a positive lower bound, multiplying a negative
    one, and vice versa for the upper bound), so negative data is never
    clipped. Returns None when the result would be empty or degenerate
    (e.g. everything is zero), leaving the axis on autorange.
    """
    factor = 1 + padding_fraction
    if extent is None:
        low = high = reference
    else:
        low = min(extent[0], reference)
        high = max(extent[1], reference)
    lower = low / factor if low > 0 else low * factor
    upper = high * factor if high > 0 else high / factor
    # `not <` also rejects NaN bounds.
    if not lower < upper:
        return None
    return [lower, upper]