from typing import Callable, Dict, List, Optional, Union
import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from scipy.differentiate import derivative
from scipy.interpolate import RegularGridInterpolator
from ..plotting.generic_plot import plot3D_contour, plot3D_surface
from ..plotting.grid_plot import PlotType, plot_grid_3D
from ..plotting.plotting_constants import *
from .constants import STANDARD_GRID_PLOT, TIME_GRID_PLOT, TITLE_MAPPING
from .models import SweepData2D
from .types import SweepDatatype
class SweepResults2Var:
    """Results of a two-variable sweep, with interpolation and plotting helpers.

    Parameters
    ----------
    sweep_data : SweepData2D
        Sweep results. This class reads ``var_list_1`` / ``var_list_2`` (the
        swept values, used as the interpolation grid and as plot axes),
        ``var_name_1`` / ``var_name_2`` (default axis labels and titles), and
        one z-data attribute per ``SweepDatatype`` value (looked up by
        ``z_var.value``).
    dependencies : dict of str to callable, optional
        Dependent parameters of the sweep. Only the keys are used here, to
        annotate plot subtitles/titles. Defaults to a new empty dict.
    """

    def __init__(
        self,
        sweep_data: SweepData2D,
        dependencies: Optional[Dict[str, Callable]] = None,
    ) -> None:
        self.sweep_data = sweep_data
        # Build a fresh dict per instance: a ``{}`` default in the signature is
        # created once and would be shared by every instance that omits it.
        self.dependencies = {} if dependencies is None else dependencies

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _get_z_list(
        self,
        z_var: SweepDatatype = SweepDatatype.TOTAL_PTS,
    ) -> npt.NDArray[np.float64]:
        """Return the z-data stored on ``sweep_data`` under ``z_var.value``."""
        return getattr(self.sweep_data, z_var.value)

    def _interpolator(self, z_var: SweepDatatype) -> RegularGridInterpolator:
        """Build a cubic interpolator of ``z_var`` over (var_list_1, var_list_2).

        ``bounds_error=True`` makes any query outside the swept ranges raise
        ``ValueError`` instead of extrapolating.
        """
        return RegularGridInterpolator(
            (self.sweep_data.var_list_1, self.sweep_data.var_list_2),
            self._get_z_list(z_var),
            method="cubic",
            bounds_error=True,
        )

    def _partial_derivative(
        self,
        x: float,
        y: float,
        z_var: SweepDatatype,
        wrt_x: bool,
    ) -> float:
        """Finite-difference partial derivative of the cubic interpolant.

        Differentiates with respect to x when ``wrt_x`` is True, otherwise with
        respect to y, holding the other coordinate fixed.
        """
        f = self._interpolator(z_var)
        axis_values = (
            self.sweep_data.var_list_1 if wrt_x else self.sweep_data.var_list_2
        )
        held_value = y if wrt_x else x
        start = x if wrt_x else y

        def along_axis(t: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
            # With preserve_shape=True and a scalar start point, derivative()
            # passes arrays of shape () or (n,). Evaluate every entry with the
            # other coordinate held fixed and return an array of t's shape.
            t = np.asarray(t, dtype=float)
            held = np.full_like(t, held_value)
            pair = (t, held) if wrt_x else (held, t)
            pts = np.stack(pair, axis=-1).reshape(-1, 2)
            return f(pts).reshape(t.shape)

        # Initial step: 0.1% of the swept range along the differentiation axis.
        step = (axis_values[-1] - axis_values[0]) * 0.001
        result = derivative(along_axis, start, initial_step=step, preserve_shape=True)
        # ``df`` is 0-d for a scalar start point; ``[()]`` unwraps it to a scalar.
        return result.df[()]

    def _default_title(self, suffix: str) -> str:
        """Return ``"<var_name_1> and <var_name_2> <suffix>"``."""
        return f"{self.sweep_data.var_name_1} and {self.sweep_data.var_name_2} {suffix}"

    def _dependency_names(self) -> str:
        """Space-separated names of the dependent parameters."""
        return " ".join(self.dependencies.keys())

    def _single_plot_args(
        self,
        z_var: SweepDatatype,
        title: Optional[str],
        subtitle: Optional[str],
        x_label: Optional[str],
        y_label: Optional[str],
        z_label: Optional[str],
    ) -> Dict[str, object]:
        """Resolve the data and labels shared by plot_contour and plot_surface.

        Title and axis labels fall back to their defaults when the caller passes
        None or an empty string. The subtitle falls back only when it is None,
        and only if there are dependencies to list.
        """
        if subtitle is None and self.dependencies:
            subtitle = f"Dependent Params: {self._dependency_names()}"
        return {
            "x_list": self.sweep_data.var_list_1,
            "y_list": self.sweep_data.var_list_2,
            "z_list": self._get_z_list(z_var),
            "title": title or self._default_title("Sweep"),
            "x_axis": x_label or self.sweep_data.var_name_1,
            "y_axis": y_label or self.sweep_data.var_name_2,
            # Readable z-axis label from TITLE_MAPPING, keyed by the enum.
            "z_axis": z_label or TITLE_MAPPING[z_var],
            "subtitle": subtitle,
        }

    # ------------------------------------------------------------------
    # Interpolation
    # ------------------------------------------------------------------
    def z_at_xy(
        self,
        x: Union[float, npt.NDArray[np.float64]],
        y: Union[float, npt.NDArray[np.float64]],
        z_var: SweepDatatype = SweepDatatype.TOTAL_PTS,
    ) -> Union[float, npt.NDArray[np.float64]]:
        """Interpolate ``z_var`` at the point(s) (x, y) using cubic interpolation.

        Parameters
        ----------
        x, y : float or array_like
            Coordinates along var_list_1 and var_list_2. Scalars, lists and
            arrays are accepted as long as they broadcast against each other.
        z_var : SweepDatatype, optional
            Quantity to interpolate (default is SweepDatatype.TOTAL_PTS).

        Returns
        -------
        float or ndarray
            A scalar when both x and y are scalars, otherwise an array with the
            broadcast shape of x and y.

        Raises
        ------
        ValueError
            If x and y cannot be broadcast together, or any point lies outside
            the swept ranges.
        """
        f = self._interpolator(z_var)
        # Broadcast so any combination of scalars/lists/arrays pairs up
        # element-wise, then feed (N, 2) points to the interpolator.
        xb, yb = np.broadcast_arrays(
            np.asarray(x, dtype=float), np.asarray(y, dtype=float)
        )
        pts = np.stack((xb.ravel(), yb.ravel()), axis=-1)
        values = f(pts).reshape(xb.shape)
        # Both inputs scalar -> 0-d result; unwrap to a scalar value.
        return values[()] if values.ndim == 0 else values

    def dzdx_at_xy(
        self,
        x: float,
        y: float,
        z_var: SweepDatatype = SweepDatatype.TOTAL_PTS,
    ) -> float:
        """Partial derivative dz/dx of the cubic interpolant of ``z_var`` at (x, y).

        Computed with ``scipy.differentiate.derivative``, starting from a step of
        0.1% of the var_list_1 range. The convergence status it reports is not
        checked.

        Raises
        ------
        ValueError
            If the finite-difference stencil evaluates outside the swept ranges
            (e.g. when x lies at or very close to an edge of var_list_1).
        """
        return self._partial_derivative(x, y, z_var, wrt_x=True)

    def dzdy_at_xy(
        self,
        x: float,
        y: float,
        z_var: SweepDatatype = SweepDatatype.TOTAL_PTS,
    ) -> float:
        """Partial derivative dz/dy of the cubic interpolant of ``z_var`` at (x, y).

        Computed with ``scipy.differentiate.derivative``, starting from a step of
        0.1% of the var_list_2 range. The convergence status it reports is not
        checked.

        Raises
        ------
        ValueError
            If the finite-difference stencil evaluates outside the swept ranges
            (e.g. when y lies at or very close to an edge of var_list_2).
        """
        return self._partial_derivative(x, y, z_var, wrt_x=False)

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------
    def plot_contour(
        self,
        z_var: SweepDatatype = SweepDatatype.TOTAL_PTS,
        *,
        title: Optional[str] = None,
        subtitle: Optional[str] = None,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        z_label: Optional[str] = None,
        **kwargs,
    ) -> go.Figure:
        """
        Generates a single contour plot.

        Parameters
        ----------
        z_var : SweepDatatype, optional
            SweepDatatype to plot on z-axis (default is SweepDatatype.TOTAL_PTS)
        title : str, optional
            Plot title (auto-generated if None)
        subtitle : str, optional
            Plot subtitle (auto-generated from dependencies if None)
        x_label : str, optional
            X-axis label (defaults to var_name_1)
        y_label : str, optional
            Y-axis label (defaults to var_name_2)
        z_label : str, optional
            Z-axis label (defaults to TITLE_MAPPING)
        **kwargs : dict
            Additional arguments passed to plot3D_contour (theme, font_config,
            layout_config, colorbar_config, smoothing_config)

        Returns
        -------
        go.Figure
            Plotly figure object
        """
        plot_args = self._single_plot_args(
            z_var, title, subtitle, x_label, y_label, z_label
        )
        return plot3D_contour(**plot_args, **kwargs)

    def plot_surface(
        self,
        z_var: SweepDatatype = SweepDatatype.TOTAL_PTS,
        *,
        title: Optional[str] = None,
        subtitle: Optional[str] = None,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        z_label: Optional[str] = None,
        **kwargs,
    ) -> go.Figure:
        """
        Generates a single surface plot.

        Parameters
        ----------
        z_var : SweepDatatype, optional
            SweepDatatype to plot on z-axis (default is SweepDatatype.TOTAL_PTS)
        title : str, optional
            Plot title (auto-generated if None)
        subtitle : str, optional
            Plot subtitle (auto-generated from dependencies if None)
        x_label : str, optional
            X-axis label (defaults to var_name_1)
        y_label : str, optional
            Y-axis label (defaults to var_name_2)
        z_label : str, optional
            Z-axis label (defaults to TITLE_MAPPING)
        **kwargs : dict
            Additional arguments passed to plot3D_surface (theme, font_config,
            layout_config, colorbar_config, smoothing_config, scene_config)

        Returns
        -------
        go.Figure
            Plotly figure object
        """
        plot_args = self._single_plot_args(
            z_var, title, subtitle, x_label, y_label, z_label
        )
        return plot3D_surface(**plot_args, **kwargs)

    def grid_plot(
        self,
        plot_type: PlotType,
        show_event_times: bool = False,
        rows: int = 2,
        cols: int = 3,
        *,
        title: Optional[str] = None,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        **kwargs,
    ) -> go.Figure:
        """
        Creates a dynamic grid plot for event results.

        Parameters
        ----------
        plot_type : PlotType
            PlotType enum (CONTOUR or SURFACE). Required.
        show_event_times : bool, optional
            Whether to show event times on the plot instead of points (default is False)
        rows : int, optional
            Number of rows in the grid (default is 2)
        cols : int, optional
            Number of columns in the grid (default is 3)
        title : str, optional
            Overall grid title. If empty, a default is generated. When there
            are dependencies, a "Dependent Params" line is appended to the
            title either way.
        x_label : str, optional
            Common X-axis label. Defaults to var_name_1.
        y_label : str, optional
            Common Y-axis label. Defaults to var_name_2.
        **kwargs : dict
            Additional arguments passed to plot_grid_3D (theme, font_config,
            layout_config, colorbar_config, smoothing_config)

        Returns
        -------
        go.Figure
            Plotly figure object
        """
        variables = TIME_GRID_PLOT if show_event_times else STANDARD_GRID_PLOT
        # Subplot titles, z-data and z-labels are all keyed off the same
        # SweepDatatype list (data/labels by enum value for plot_grid_3D).
        subplot_titles = [TITLE_MAPPING[var] for var in variables]
        z_data_dict = {var.value: self._get_z_list(var) for var in variables}
        z_label_dict = {var.value: TITLE_MAPPING[var] for var in variables}

        final_title = title or self._default_title("Sweep Grid Plot")
        if self.dependencies:
            dep_str = self._dependency_names()
            final_title += f"<br><span style='font-size: 18px; color: #606060;'>Dependent Params: {dep_str}</span>"

        return plot_grid_3D(
            x_list=self.sweep_data.var_list_1,
            y_list=self.sweep_data.var_list_2,
            z_data_dict=z_data_dict,
            subplot_titles=subplot_titles,
            title=final_title,
            x_label=x_label or self.sweep_data.var_name_1,
            y_label=y_label or self.sweep_data.var_name_2,
            z_label_dict=z_label_dict,
            rows=rows,
            cols=cols,
            plot_type=plot_type,
            **kwargs,
        )