Source code for suboptimumg.plotting.overlay_plot_2d

from typing import Dict, Optional

import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go

from .plotting_constants import *
from .utils import color_cycle, prepare_smooth_data_2D


def _data_extent(arrays: list) -> Optional[tuple]:
    """Return ``(min, max)`` over every value in ``arrays``, ignoring NaNs.

    Parameters
    ----------
    arrays : list of array-like
        Numeric arrays (any shape); they are flattened and pooled together.

    Returns
    -------
    tuple of float or None
        ``(min, max)`` of the pooled non-NaN values, or ``None`` when there are
        no arrays, every array is empty, or every value is NaN.
    """
    if not arrays:
        return None
    values = np.concatenate([np.asarray(a, dtype=float).ravel() for a in arrays])
    values = values[~np.isnan(values)]
    if values.size == 0:
        return None
    return float(values.min()), float(values.max())


def _padded_range(lo: float, hi: float, pad_fraction: float) -> Optional[list]:
    """Return axis bounds ``[lower, upper]`` enclosing ``[lo, hi]`` with a margin.

    Each bound is scaled by ``1 + pad_fraction`` in whichever direction moves it
    *away* from the data: a non-negative lower bound is divided and a negative
    one is multiplied, and the reverse for the upper bound. A plain
    ``[lo / f, hi * f]`` only widens positive ranges; for negative values it
    pulls the bounds inward and clips the data.

    Returns ``None`` when the padded span is empty (e.g. ``lo == hi == 0``) or
    non-finite, so the caller can leave the axis on Plotly's autorange.
    """
    factor = 1 + pad_fraction
    lower = lo / factor if lo >= 0 else lo * factor
    upper = hi * factor if hi >= 0 else hi / factor
    if not (np.isfinite(lower) and np.isfinite(upper) and lower < upper):
        return None
    return [lower, upper]


def _reference_line_trace(x: list, y: list, line_config: ReferenceLineConfig) -> go.Scatter:
    """Build a two-point line trace styled from a ReferenceLineConfig."""
    return go.Scatter(
        x=x,
        y=y,
        mode="lines",
        line=dict(
            color=line_config.color,
            width=line_config.width,
            dash=line_config.dash,
        ),
        name=f"{line_config.label}",
        hoverinfo="name",
    )


def overlay_plot2D(
    x_data_dict: Dict[str, npt.NDArray[np.number]],
    y_data_dict: Dict[str, npt.NDArray[np.number]],
    title: str,
    x_axis: str,
    y_axis: str,
    *,
    h_line: Optional[ReferenceLineConfig] = None,
    v_line: Optional[ReferenceLineConfig] = None,
    fit_curve: bool = False,
    show_points: bool = False,
    subtitle: Optional[str] = None,
    font_config: FontConfig = DEFAULT_FONT_CONFIG,
    layout_config: LayoutConfig = DEFAULT_LAYOUT_CONFIG,
    smoothing_config: SmoothingConfig = DEFAULT_SMOOTHING_CONFIG,
) -> go.Figure:
    """
    Generic 2D overlay plotting utility for plotting multiple series on the same graph.

    Parameters
    ----------
    x_data_dict : Dict[str, npt.NDArray[np.number]]
        Dictionary mapping trace labels to x data arrays. Must contain every
        key of ``y_data_dict``.
    y_data_dict : Dict[str, npt.NDArray[np.number]]
        Dictionary mapping trace labels to y data arrays. Its keys, sorted,
        determine which traces are drawn and in what order.
    title : str
        Plot title
    x_axis : str
        X-axis label
    y_axis : str
        Y-axis label
    h_line : ReferenceLineConfig, optional
        ReferenceLineConfig for horizontal reference line (None = no line).
        The line spans the x extent of the data, and the y-axis range is
        padded so both the data and the line are visible.
    v_line : ReferenceLineConfig, optional
        ReferenceLineConfig for vertical reference line (None = no line).
        The line spans the y extent of the data, and the x-axis range is
        padded so both the data and the line are visible.
    fit_curve : bool, optional
        Whether to fit a curve to the data points (via ``prepare_smooth_data_2D``)
    show_points : bool, optional
        Whether to show individual data points as markers
    subtitle : str, optional
        Optional subtitle
    font_config : FontConfig, optional
        FontConfig object for font settings
    layout_config : LayoutConfig, optional
        LayoutConfig object for layout settings
    smoothing_config : SmoothingConfig, optional
        SmoothingConfig object for smoothing settings

    Returns
    -------
    go.Figure
        Plotly figure object

    Raises
    ------
    KeyError
        If a label in ``y_data_dict`` is missing from ``x_data_dict``.
    """
    fig = go.Figure()

    # Sorting makes trace order and color assignment independent of dict insertion order.
    trace_labels = sorted(y_data_dict)
    colors = color_cycle(len(trace_labels))

    # Create traces for each series
    for trace_label, color in zip(trace_labels, colors):
        x_vals = x_data_dict[trace_label]
        y_vals = y_data_dict[trace_label]

        if fit_curve:
            x_dense, y_dense = prepare_smooth_data_2D(
                x_vals, y_vals, smoothing_config=smoothing_config
            )
        else:
            x_dense, y_dense = x_vals, y_vals

        # Line through the (possibly smoothed) data; Plotly's spline shape
        # interpolates between the given points.
        fig.add_trace(
            go.Scatter(
                x=x_dense,
                y=y_dense,
                mode="lines",
                line=dict(
                    color=color,
                    width=LINE_WIDTH,
                    shape="spline",
                ),
                name=trace_label,
                hovertemplate=f"{x_axis}: %{{x:{FLOAT_PRECISION}}}<br>{y_axis}: %{{y:{FLOAT_PRECISION}}}<extra></extra>",
            )
        )

        if show_points:
            # Raw (unsmoothed) samples; hidden from the legend so each series has one entry.
            fig.add_trace(
                go.Scatter(
                    x=x_vals,
                    y=y_vals,
                    mode="markers",
                    marker=dict(color=color, size=MARKER_SIZE_LARGE),
                    showlegend=False,
                    hovertemplate=None,
                )
            )

    has_h_line = h_line is not None and h_line.value is not None
    has_v_line = v_line is not None and v_line.value is not None

    # Data extents are computed once and shared by the reference-line traces
    # and the axis-range padding below. Either may be None (no usable data).
    x_extent = y_extent = None
    if has_h_line or has_v_line:
        x_extent = _data_extent([x_data_dict[label] for label in trace_labels])
        y_extent = _data_extent([y_data_dict[label] for label in trace_labels])

    # Horizontal reference line across the x range of all traces
    if has_h_line and x_extent is not None:
        fig.add_trace(
            _reference_line_trace(list(x_extent), [h_line.value, h_line.value], h_line)
        )

    # Vertical reference line across the y range of all traces
    if has_v_line and y_extent is not None:
        fig.add_trace(
            _reference_line_trace([v_line.value, v_line.value], list(y_extent), v_line)
        )

    # Title, Axes, Label, Legend, Size
    full_title = title
    if subtitle is not None:
        full_title += f"<br><span style='font-size: {font_config.medium}px; color: {TEXT_COLOR_LIGHT};'>{subtitle}</span>"

    fig.update_layout(
        title={
            "text": full_title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        xaxis_title={
            "text": x_axis,
            "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
        },
        yaxis_title={
            "text": y_axis,
            "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
        },
        legend_title="Legend",
        showlegend=True,
        legend_title_font=dict(size=font_config.medium),
        legend_font=dict(size=font_config.small),
        hovermode=HOVER_MODE,
        plot_bgcolor=layout_config.plot_bgcolor,
        width=layout_config.width,
        height=layout_config.height,
        margin=layout_config.margin,
    )

    # Add subtle grid lines
    axis_style = dict(
        showgrid=True,
        gridwidth=GRID_WIDTH,
        gridcolor=GRID_COLOR,
        zeroline=True,
        zerolinewidth=ZEROLINE_WIDTH,
        zerolinecolor=ZEROLINE_COLOR,
        tickfont=dict(size=font_config.small),
        tickformat=FLOAT_PRECISION,
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    # Widen the axis perpendicular to each reference line so the line is not
    # clipped when it lies outside the data. If there is no data, pad around
    # the line value alone; a degenerate span is left to Plotly's autorange.
    if has_h_line:
        y_lo, y_hi = y_extent if y_extent is not None else (h_line.value, h_line.value)
        y_range = _padded_range(
            min(y_lo, h_line.value), max(y_hi, h_line.value), RANGE_PADDING
        )
        if y_range is not None:
            fig.update_yaxes(range=y_range)

    if has_v_line:
        x_lo, x_hi = x_extent if x_extent is not None else (v_line.value, v_line.value)
        x_range = _padded_range(
            min(x_lo, v_line.value), max(x_hi, v_line.value), RANGE_PADDING
        )
        if x_range is not None:
            fig.update_xaxes(range=x_range)

    return fig