from typing import Callable, Dict, List, Optional, Union
import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from scipy.differentiate import derivative
from scipy.interpolate import interp1d
from ..plotting.generic_plot import plot2D
from ..plotting.grid_plot import plot_grid_2D
from .constants import STANDARD_GRID_PLOT, TIME_GRID_PLOT, TITLE_MAPPING
from .models import SweepData1D
from .types import SweepDatatype
class SweepResults1Var:
    """Query and plot the results of a one-variable parameter sweep.

    Wraps a ``SweepData1D`` record. The attributes read from it here are
    ``sweep_values`` (the x-axis), ``var_name`` (used in default labels) and
    one attribute per ``SweepDatatype``, looked up via the enum member's
    ``value``.

    Parameters
    ----------
    sweep_data : SweepData1D
        Sweep values and the per-datatype results.
    dependencies : dict of str to callable, optional
        Dependent parameters of the sweep. Only the keys (names) are used by
        this class, to annotate plot titles. Defaults to an empty dict that is
        created fresh for each instance.
    """

    def __init__(
        self,
        sweep_data: SweepData1D,
        dependencies: Optional[Dict[str, Callable]] = None,
    ) -> None:
        self.sweep_data = sweep_data
        # A literal ``{}`` default would be a single dict shared by every
        # instance constructed without ``dependencies``; make a new one.
        self.dependencies: Dict[str, Callable] = (
            {} if dependencies is None else dependencies
        )

    def _interpolator(self, y_var: SweepDatatype) -> interp1d:
        """Build a cubic interpolant of ``y_var`` against the sweep values.

        Evaluating it outside the sweep range raises ``ValueError``
        (``bounds_error=True``).
        """
        y_list = getattr(self.sweep_data, y_var.value)
        return interp1d(
            self.sweep_data.sweep_values, y_list, kind="cubic", bounds_error=True
        )

    def _dependency_names(self) -> str:
        """Return the dependent-parameter names joined by single spaces."""
        return " ".join(self.dependencies.keys())

    def y_at_x(
        self,
        x: Union[float, npt.NDArray[np.float64]],
        y_var: SweepDatatype = SweepDatatype.TOTAL_PTS,
    ) -> Union[float, npt.NDArray[np.float64]]:
        """Interpolate ``y_var`` at ``x`` with a cubic spline through the sweep.

        Parameters
        ----------
        x : float or ndarray
            Point(s) at which to evaluate; must lie within the sweep range.
        y_var : SweepDatatype, optional
            Result to interpolate (default is SweepDatatype.TOTAL_PTS).

        Returns
        -------
        float or ndarray
            Interpolated value(s), as returned by ``interp1d``.

        Raises
        ------
        ValueError
            If any ``x`` lies outside the range of the sweep values.
        """
        return self._interpolator(y_var)(x)

    def dydx_at_x(
        self,
        x: Union[float, npt.NDArray[np.float64]],
        y_var: SweepDatatype = SweepDatatype.TOTAL_PTS,
    ) -> Union[float, npt.NDArray[np.float64]]:
        """Numerically differentiate the cubic interpolant of ``y_var`` at ``x``.

        Uses ``scipy.differentiate.derivative`` on the same interpolant as
        :meth:`y_at_x`. The initial finite-difference step is 0.1% of the
        sweep span.

        Parameters
        ----------
        x : float or ndarray
            Point(s) at which to differentiate; must lie within the sweep range.
        y_var : SweepDatatype, optional
            Result to differentiate (default is SweepDatatype.TOTAL_PTS).

        Returns
        -------
        float or ndarray
            The ``df`` estimate reported by ``derivative``.

        Raises
        ------
        ValueError
            If any ``x`` lies outside the range of the sweep values.
        """
        f = self._interpolator(y_var)
        sweep = np.asarray(self.sweep_data.sweep_values, dtype=np.float64)
        lo, hi = sweep.min(), sweep.max()
        # Use max - min rather than last - first: the step must be positive
        # even for descending sweeps (interp1d sorts x internally, and
        # derivative() takes the direction separately via step_direction).
        dx = (hi - lo) * 0.001
        x_arr = np.asarray(x, dtype=np.float64)
        # The stencil reaches up to ``dx`` from ``x``. Within ``dx`` of an
        # edge, a central stencil would step outside the interpolation range,
        # and interp1d would raise. Switch to a one-sided stencil pointing
        # into the range there (+1 = rightward, -1 = leftward, 0 = central).
        direction = np.where(
            x_arr - lo < dx, 1, np.where(hi - x_arr < dx, -1, 0)
        )
        return derivative(f, x, initial_step=dx, step_direction=direction).df

    def plot(
        self,
        y_var: SweepDatatype = SweepDatatype.TOTAL_PTS,
        *,
        title: Optional[str] = None,
        subtitle: Optional[str] = None,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        data_label: Optional[str] = None,
        **kwargs,
    ) -> go.Figure:
        """
        Plot a single sweep datatype against the sweep values.

        Parameters
        ----------
        y_var : SweepDatatype, optional
            SweepDatatype to plot on the y-axis (default is SweepDatatype.TOTAL_PTS).
        title : str, optional
            Plot title. Defaults to "{var_name} Sweep" when None or empty.
        subtitle : str, optional
            Plot subtitle. When None and dependencies exist, it becomes
            "Dependent Params: <names>".
        x_label : str, optional
            X-axis label. Defaults to var_name when None or empty.
        y_label : str, optional
            Y-axis label. Defaults to TITLE_MAPPING[y_var] when None or empty.
        data_label : str, optional
            Label for the data series in the legend. Defaults to "Interpolated".
        **kwargs : dict
            Additional arguments passed through to plot2D (h_line, v_line,
            fit_curve, show_points, theme, font_config, layout_config,
            smoothing_config).

        Returns
        -------
        go.Figure
            Plotly figure object.
        """
        y_list = getattr(self.sweep_data, y_var.value)

        # An explicitly passed subtitle (even "") wins over the auto-generated one.
        if subtitle is None and self.dependencies:
            subtitle = f"Dependent Params: {self._dependency_names()}"

        return plot2D(
            x_list=self.sweep_data.sweep_values,
            y_list=y_list,
            title=title or f"{self.sweep_data.var_name} Sweep",
            x_axis=x_label or self.sweep_data.var_name,
            y_axis=y_label or TITLE_MAPPING[y_var],
            subtitle=subtitle,
            data_label=data_label or "Interpolated",
            **kwargs,
        )

    def grid_plot(
        self,
        show_event_times: bool = False,
        rows: int = 2,
        cols: int = 3,
        *,
        title: Optional[str] = None,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        **kwargs,
    ) -> go.Figure:
        """
        Create a grid of subplots, one per datatype, for the 1D sweep.

        Parameters
        ----------
        show_event_times : bool, optional
            Plot the TIME_GRID_PLOT datatypes instead of the
            STANDARD_GRID_PLOT ones (default is False).
        rows : int, optional
            Number of rows in the grid (default is 2).
        cols : int, optional
            Number of columns in the grid (default is 3).
        title : str, optional
            Overall grid title. Defaults to "{var_name} Sweep Grid Plot" when
            None or empty. If dependencies exist, a smaller
            "Dependent Params: <names>" line is appended to the title,
            including a caller-supplied one.
        x_label : str, optional
            Common x-axis label. Defaults to var_name.
        y_label : str, optional
            Common y-axis label. Defaults to "".
        **kwargs : dict
            Additional arguments passed through to plot_grid_2D (fit_curve,
            theme, font_config, layout_config, smoothing_config).

        Returns
        -------
        go.Figure
            Plotly figure object.
        """
        variables = TIME_GRID_PLOT if show_event_times else STANDARD_GRID_PLOT

        subplot_titles = [TITLE_MAPPING[var] for var in variables]
        y_data_dict = {
            var.value: getattr(self.sweep_data, var.value) for var in variables
        }

        final_title = title or f"{self.sweep_data.var_name} Sweep Grid Plot"
        if self.dependencies:
            dep_str = self._dependency_names()
            final_title += f"<br><span style='font-size: 18px; color: #606060;'>Dependent Params: {dep_str}</span>"

        return plot_grid_2D(
            x_list=self.sweep_data.sweep_values,
            y_data_dict=y_data_dict,
            subplot_titles=subplot_titles,
            title=final_title,
            x_label=x_label or self.sweep_data.var_name,
            y_label=y_label or "",
            rows=rows,
            cols=cols,
            **kwargs,
        )