from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from ..plotting.color_themes import get_theme
from .plotting_constants import *
from .utils import (
DataType3D,
_validate_data,
prepare_smooth_data_2D,
prepare_smooth_data_3D,
prepare_smooth_data_3D_scatter,
)
# Make Plotly's built-in "plotly_white" template the default for every figure
# created after this module is imported.
# NOTE: this assigns process-wide Plotly state at import time, so it also
# changes the default look of figures created by unrelated code.
pio.templates.default = "plotly_white"
[docs]
def plot2D(
    x_list,
    y_list,
    title,
    x_axis,
    y_axis,
    h_line: Optional[ReferenceLineConfig] = None,
    v_line: Optional[ReferenceLineConfig] = None,
    fit_curve=False,
    show_points=False,
    subtitle=None,
    data_label=None,
    theme=None,
    font_config: Optional[FontConfig] = None,
    layout_config: Optional[LayoutConfig] = None,
    smoothing_config: Optional[SmoothingConfig] = None,
):
    """
    Generic 2D plotting utility.

    Draws ``y_list`` against ``x_list`` as a spline-shaped line (optionally
    smoothed first), with optional markers for the raw points and optional
    horizontal / vertical reference lines. When a reference line is shown,
    the matching axis range is widened so both the data and the line fit.

    Parameters
    ----------
    x_list : numpy.ndarray
        Numpy array of x coordinates
    y_list : numpy.ndarray
        Numpy array of y coordinates
    title : str
        Plot title
    x_axis : str
        X-axis label
    y_axis : str
        Y-axis label
    h_line : ReferenceLineConfig, optional
        ReferenceLineConfig for horizontal reference line (None = no line)
    v_line : ReferenceLineConfig, optional
        ReferenceLineConfig for vertical reference line (None = no line)
    fit_curve : bool, optional
        Whether to pass the data through ``prepare_smooth_data_2D`` before
        drawing the line
    show_points : bool, optional
        Whether to show individual (raw) data points as markers
    subtitle : str, optional
        Optional subtitle
    data_label : str, optional
        Curve label used in legend
    theme : str, optional
        Color theme name, see plotting.color_theme
    font_config : FontConfig, optional
        FontConfig object for font settings
    layout_config : LayoutConfig, optional
        LayoutConfig object for layout settings
    smoothing_config : SmoothingConfig, optional
        SmoothingConfig object for smoothing settings

    Returns
    -------
    go.Figure
        Plotly figure object
    """
    font_config = font_config or DEFAULT_FONT_CONFIG
    layout_config = layout_config or DEFAULT_LAYOUT_CONFIG
    smoothing_config = smoothing_config or DEFAULT_SMOOTHING_CONFIG

    if fit_curve:
        x_dense, y_dense = prepare_smooth_data_2D(
            x_list, y_list, smoothing_config=smoothing_config
        )
    else:
        x_dense, y_dense = x_list, y_list

    # A reference line is drawn only when a config AND a value are supplied.
    has_h_line = h_line is not None and h_line.value is not None
    has_v_line = v_line is not None and v_line.value is not None
    if has_h_line or has_v_line:
        # Curve extents, NaN-tolerant (same approach as the 3D plot helpers).
        x_min, x_max = np.nanmin(x_dense), np.nanmax(x_dense)
        y_min, y_max = np.nanmin(y_dense), np.nanmax(y_dense)

    fig = go.Figure()
    theme_colors = get_theme(theme)

    # Main curve (interpolated when fit_curve=True, raw data otherwise)
    fig.add_trace(
        go.Scatter(
            x=x_dense,
            y=y_dense,
            mode="lines",
            line=dict(
                color=theme_colors["light"],
                width=LINE_WIDTH,
                shape="spline",
            ),
            name=data_label,
            hovertemplate=f"{x_axis}: %{{x:{FLOAT_PRECISION}}}<br>{y_axis}: %{{y:{FLOAT_PRECISION}}}<extra></extra>",
        )
    )

    # Raw data points on top of the curve
    if show_points:
        fig.add_trace(
            go.Scatter(
                x=x_list,
                y=y_list,
                mode="markers",
                marker=dict(color=theme_colors["dark"], size=MARKER_SIZE_LARGE),
                showlegend=False,
                hovertemplate=None,
            )
        )

    # Horizontal reference line spanning the curve's x extent
    if has_h_line:
        fig.add_trace(
            go.Scatter(
                x=[x_min, x_max],
                y=[h_line.value, h_line.value],
                mode="lines",
                line=dict(
                    color=h_line.color,
                    width=h_line.width,
                    dash=h_line.dash,
                ),
                name=f"{h_line.label}",
                hoverinfo="name",
            )
        )

    # Vertical reference line spanning the curve's y extent
    if has_v_line:
        fig.add_trace(
            go.Scatter(
                x=[v_line.value, v_line.value],
                y=[y_min, y_max],
                mode="lines",
                line=dict(
                    color=v_line.color,
                    width=v_line.width,
                    dash=v_line.dash,
                ),
                name=f"{v_line.label}",
                hoverinfo="name",
            )
        )

    # Title, Axes, Label, Legend, Size
    full_title = title
    if subtitle is not None:
        full_title += f"<br><span style='font-size: {font_config.medium}px; color: {TEXT_COLOR_LIGHT};'>{subtitle}</span>"
    fig.update_layout(
        title={
            "text": full_title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        xaxis_title={
            "text": x_axis,
            "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
        },
        yaxis_title={
            "text": y_axis,
            "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
        },
        legend_title="Legend",
        showlegend=(data_label is not None) or has_h_line or has_v_line,
        legend_title_font=dict(size=font_config.medium),
        legend_font=dict(size=font_config.small),
        hovermode=HOVER_MODE,
        plot_bgcolor=layout_config.plot_bgcolor,
        width=layout_config.width,
        height=layout_config.height,
        margin=layout_config.margin,
    )

    # Subtle grid lines; identical styling for both axes
    axis_style = dict(
        showgrid=True,
        gridwidth=GRID_WIDTH,
        gridcolor=GRID_COLOR,
        zeroline=True,
        zerolinewidth=ZEROLINE_WIDTH,
        zerolinecolor=ZEROLINE_COLOR,
        tickfont=dict(size=font_config.small),
        tickformat=FLOAT_PRECISION,
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    # Widen the axis a reference line lives on so the line is never clipped.
    if has_h_line:
        fig.update_yaxes(
            range=_padded_range(
                min(y_min, h_line.value), max(y_max, h_line.value), RANGE_PADDING
            )
        )
    if has_v_line:
        fig.update_xaxes(
            range=_padded_range(
                min(x_min, v_line.value), max(x_max, v_line.value), RANGE_PADDING
            )
        )
    return fig


def _padded_range(lo, hi, padding_fraction):
    """Return ``[lower, upper]`` enclosing ``[lo, hi]`` with relative padding.

    Each bound is moved away from the data by a factor of
    ``1 + padding_fraction``: a non-negative lower bound is divided by the
    factor and a negative one multiplied (mirrored for the upper bound), so
    the result always contains ``[lo, hi]`` regardless of sign. A bound of
    exactly 0 stays at 0.
    """
    factor = 1 + padding_fraction
    lower = lo / factor if lo >= 0 else lo * factor
    upper = hi * factor if hi >= 0 else hi / factor
    return [lower, upper]
[docs]
def plot3D_surface(
x_list,
y_list,
z_list,
title,
x_axis,
y_axis,
z_axis,
subtitle=None,
theme=None,
font_config: Optional[FontConfig] = None,
layout_config: Optional[LayoutConfig] = None,
colorbar_config: Optional[ColorbarConfig] = None,
smoothing_config: Optional[SmoothingConfig] = None,
scene_config: Optional[SceneConfig] = None,
):
"""
Generic 3D surface plotting utility.
Handles both grid data (2D z array) and scatter data (1D x, y, z arrays).
For grid data: Creates interpolated surface plot with contour lines.
For scatter data: Creates 3D scatter plot with points colored by z values.
Parameters
----------
x_list : numpy.ndarray
X coordinates (1D for grid, 1D for scatter)
y_list : numpy.ndarray
Y coordinates (1D for grid, 1D for scatter)
z_list : numpy.ndarray
Z values (2D for grid, 1D for scatter)
title : str
Plot title
x_axis : str
X-axis label
y_axis : str
Y-axis label
z_axis : str
Z-axis label
subtitle : str, optional
Optional subtitle
theme : str, optional
Color theme name
font_config : FontConfig, optional
FontConfig object for font settings
layout_config : LayoutConfig, optional
LayoutConfig object for layout settings
colorbar_config : ColorbarConfig, optional
ColorbarConfig object for colorbar settings
smoothing_config : SmoothingConfig, optional
SmoothingConfig object for smoothing settings
scene_config : SceneConfig, optional
SceneConfig object for 3D scene settings
Returns
-------
go.Figure
Plotly figure object
"""
font_config = font_config or DEFAULT_FONT_CONFIG
layout_config = layout_config or DEFAULT_LAYOUT_CONFIG
colorbar_config = colorbar_config or DEFAULT_COLORBAR_CONFIG
smoothing_config = smoothing_config or DEFAULT_SMOOTHING_CONFIG
scene_config = scene_config or DEFAULT_SCENE_CONFIG
# Detect data type
data_type = _validate_data(x_list, y_list, z_list)
# Create figure and add trace
fig = go.Figure()
theme_colors = get_theme(theme)
# Prepare smoothed/interpolated data based on input type
match data_type:
case DataType3D.GridInput:
x_dense, y_dense, z_dense = prepare_smooth_data_3D(
x_list, y_list, z_list, smoothing_config=smoothing_config
)
case DataType3D.ScatterInput:
x_dense, y_dense, z_dense = prepare_smooth_data_3D_scatter(
x_list, y_list, z_list, smoothing_config=smoothing_config
)
# Calculate contour size (use nanmin/nanmax to handle potential NaN values)
z_min, z_max = np.nanmin(z_dense), np.nanmax(z_dense)
contour_size = (z_max - z_min) / NUM_CONTOURS if z_max > z_min else 1
# Add surface plot
fig.add_trace(
go.Surface(
x=np.array([x_dense] * len(x_dense)),
y=np.array([y_dense] * len(x_dense)).T,
z=z_dense,
colorscale=theme_colors["colorscale"],
contours={
"z": {
"show": True,
"start": z_min,
"end": z_max,
"size": contour_size,
"width": LINE_WIDTH,
"color": CONTOUR_LABEL_COLOR,
}
},
colorbar=dict(
title=dict(
text=z_axis,
font=dict(size=font_config.medium),
),
thickness=colorbar_config.thickness,
len=colorbar_config.length,
tickfont=dict(size=font_config.small),
),
hovertemplate=(
f"{x_axis}: %{{x:{FLOAT_PRECISION}}}<br>"
f"{y_axis}: %{{y:{FLOAT_PRECISION}}}<br>"
f"{z_axis}: %{{z:{FLOAT_PRECISION}}}<extra></extra>"
),
)
)
# Overlay original scatter points if input was scatter data
if data_type == DataType3D.ScatterInput:
fig.add_trace(
go.Scatter3d(
x=x_list,
y=y_list,
z=z_list,
mode="markers",
marker=dict(
color="black",
size=MARKER_SIZE / 2,
opacity=0.5,
),
showlegend=False,
hovertemplate=(
f"Original point<br>"
f"{x_axis}: %{{x:{FLOAT_PRECISION}}}<br>"
f"{y_axis}: %{{y:{FLOAT_PRECISION}}}<br>"
f"{z_axis}: %{{z:{FLOAT_PRECISION}}}<extra></extra>"
),
)
)
# Configure layout
full_title = title
if subtitle is not None:
full_title += f"<br><span style='font-size: {font_config.medium}px; color: {TEXT_COLOR_LIGHT};'>{subtitle}</span>"
fig.update_layout(
title={
"text": full_title,
"font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
"x": layout_config.title_x,
"xanchor": layout_config.title_xanchor,
"yanchor": layout_config.title_yanchor,
},
scene=dict(
xaxis_title=x_axis,
yaxis_title=y_axis,
zaxis_title=z_axis,
xaxis=dict(
gridcolor=GRID_COLOR,
showbackground=True,
backgroundcolor=layout_config.scene_bgcolor,
tickformat=FLOAT_PRECISION,
),
yaxis=dict(
gridcolor=GRID_COLOR,
showbackground=True,
backgroundcolor=layout_config.scene_bgcolor,
tickformat=FLOAT_PRECISION,
),
zaxis=dict(
gridcolor=GRID_COLOR,
showbackground=True,
backgroundcolor=layout_config.scene_bgcolor,
tickformat=FLOAT_PRECISION,
),
aspectratio=dict(
x=scene_config.aspect_ratio_x,
y=scene_config.aspect_ratio_y,
z=scene_config.aspect_ratio_z,
),
),
width=layout_config.width,
height=layout_config.height,
margin=layout_config.margin,
)
fig.update_layout(
scene_camera=dict(
center=dict(x=0, y=0, z=0),
eye=dict(
x=scene_config.camera_distance
* np.cos(np.radians(scene_config.default_view_angle)),
y=scene_config.camera_distance
* np.sin(np.radians(scene_config.default_view_angle)),
z=scene_config.camera_z,
),
up=dict(
x=scene_config.camera_up_x,
y=scene_config.camera_up_y,
z=scene_config.camera_up_z,
),
)
)
return fig
[docs]
def plot3D_contour(
x_list,
y_list,
z_list,
title,
x_axis,
y_axis,
z_axis,
subtitle=None,
theme=None,
font_config: Optional[FontConfig] = None,
layout_config: Optional[LayoutConfig] = None,
colorbar_config: Optional[ColorbarConfig] = None,
smoothing_config: Optional[SmoothingConfig] = None,
):
"""
Generic 3D contour plotting utility.
Handles both grid data (2D z array) and scatter data (1D x, y, z arrays).
For grid data: Creates interpolated contour plot with contour lines.
For scatter data: Creates scatter plot with points colored by z values.
Parameters
----------
x_list : numpy.ndarray
X coordinates (1D for grid, 1D for scatter)
y_list : numpy.ndarray
Y coordinates (1D for grid, 1D for scatter)
z_list : numpy.ndarray
Z values (2D for grid, 1D for scatter)
title : str
Plot title
x_axis : str
X-axis label
y_axis : str
Y-axis label
z_axis : str
Z-axis label
subtitle : str, optional
Optional subtitle
theme : str, optional
Color theme name
font_config : FontConfig, optional
FontConfig object for font settings
layout_config : LayoutConfig, optional
LayoutConfig object for layout settings
colorbar_config : ColorbarConfig, optional
ColorbarConfig object for colorbar settings
smoothing_config : SmoothingConfig, optional
SmoothingConfig object for smoothing settings
Returns
-------
go.Figure
Plotly figure object
"""
if font_config is None:
font_config = DEFAULT_FONT_CONFIG
if layout_config is None:
layout_config = DEFAULT_LAYOUT_CONFIG
if colorbar_config is None:
colorbar_config = DEFAULT_COLORBAR_CONFIG
if smoothing_config is None:
smoothing_config = DEFAULT_SMOOTHING_CONFIG
# Detect data type
data_type = _validate_data(x_list, y_list, z_list)
theme_colors = get_theme(theme)
fig = go.Figure()
# Prepare smoothed/interpolated data based on input type
match data_type:
case DataType3D.GridInput:
x_dense, y_dense, z_dense = prepare_smooth_data_3D(
x_list, y_list, z_list, smoothing_config=smoothing_config
)
case DataType3D.ScatterInput:
x_dense, y_dense, z_dense = prepare_smooth_data_3D_scatter(
x_list, y_list, z_list, smoothing_config=smoothing_config
)
fig.add_trace(
go.Scatter(
x=x_list,
y=y_list,
mode="markers",
marker=dict(
color="black",
size=MARKER_SIZE,
opacity=0.5,
symbol="x",
),
showlegend=False,
hovertemplate=(
f"Original point<br>"
f"{x_axis}: %{{x:{FLOAT_PRECISION}}}<br>"
f"{y_axis}: %{{y:{FLOAT_PRECISION}}}<br>"
f"{z_axis}: %{{customdata:{FLOAT_PRECISION}}}<extra></extra>"
),
customdata=z_list,
)
)
# Calculate contour size (use nanmin/nanmax to handle potential NaN values)
z_min, z_max = np.nanmin(z_dense), np.nanmax(z_dense)
contour_size = int((z_max - z_min) / NUM_CONTOURS) if z_max > z_min else 1
# Add contour plot
fig.add_trace(
go.Contour(
x=x_dense,
y=y_dense,
z=z_dense,
colorscale=theme_colors["colorscale"],
contours=dict(
showlabels=True,
labelfont=dict(
size=font_config.small,
color=CONTOUR_LABEL_COLOR,
),
start=int(z_min),
end=int(z_max),
size=contour_size,
labelformat=FLOAT_PRECISION,
),
colorbar=dict(
title=dict(
text=z_axis,
font=dict(size=font_config.medium),
),
thickness=colorbar_config.thickness,
len=colorbar_config.length,
tickfont=dict(size=font_config.small),
tickformat=FLOAT_PRECISION,
),
hovertemplate=(
f"{x_axis}: %{{x:{FLOAT_PRECISION}}}<br>"
f"{y_axis}: %{{y:{FLOAT_PRECISION}}}<br>"
f"{z_axis}: %{{z:{FLOAT_PRECISION}}}<extra></extra>"
),
)
)
full_title = title
if subtitle is not None:
full_title += f"<br><span style='font-size: {font_config.medium}px; color: {TEXT_COLOR_LIGHT};'>{subtitle}</span>"
fig.update_layout(
title={
"text": full_title,
"font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
"x": layout_config.title_x,
"xanchor": layout_config.title_xanchor,
},
xaxis_title={
"text": x_axis,
"font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
},
yaxis_title={
"text": y_axis,
"font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
},
width=layout_config.width,
height=layout_config.height,
margin=layout_config.margin,
plot_bgcolor=layout_config.plot_bgcolor,
)
fig.update_xaxes(
tickfont=dict(size=font_config.small),
tickformat=FLOAT_PRECISION,
)
fig.update_yaxes(
tickfont=dict(size=font_config.small),
tickformat=FLOAT_PRECISION,
)
return fig