from typing import Optional
import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
import plotly.io as pio
from .color_themes import get_theme
from .plotting_constants import *
# Set the default Plotly template to "plotly_white" for better aesthetics.
# NOTE: this is a module-level side effect. It runs when this module is
# imported and reassigns plotly.io's default template, so it is not limited
# to the figures built by the functions in this file.
pio.templates.default = "plotly_white"
[docs]
def plot_3d_stem(
    x: npt.NDArray[np.float64],
    y: npt.NDArray[np.float64],
    z: npt.NDArray[np.float64],
    title: Optional[str] = None,
    z_label: Optional[str] = None,
    theme: Optional[str] = None,
    font_config: Optional["FontConfig"] = None,
    layout_config: Optional["LayoutConfig"] = None,
    scene_config: Optional["SceneConfig"] = None,
) -> go.Figure:
    """
    Create a 3D stem plot with vertical lines from a baseline to each data point.

    Every point ``(x[i], y[i], z[i])`` is drawn as a marker, with a vertical
    line (stem) running from the baseline up to the marker. The baseline is the
    smallest finite value in ``z``.

    Parameters
    ----------
    x : NDArray[float64]
        X coordinates of data points
    y : NDArray[float64]
        Y coordinates of data points
    z : NDArray[float64]
        Z values (heights) of data points
    title : str, optional
        Plot title (default: None)
    z_label : str, optional
        Z-axis label (default: None). When omitted, the axis is titled
        "Value" and the hover text uses "z".
    theme : str, optional
        The color theme to use (default: None)
    font_config : FontConfig, optional
        Font configuration object (default: None, uses DEFAULT_FONT_CONFIG)
    layout_config : LayoutConfig, optional
        Layout configuration object (default: None, uses DEFAULT_LAYOUT_CONFIG)
    scene_config : SceneConfig, optional
        Scene configuration object for 3D view (default: None, uses DEFAULT_SCENE_CONFIG)

    Returns
    -------
    go.Figure
        Plotly figure object with 3D stem plot

    Raises
    ------
    ValueError
        If ``x``, ``y`` and ``z`` do not contain the same number of points.

    Notes
    -----
    Inputs are converted to flat float64 arrays, so sequences and same-shaped
    2D grids are accepted as well as 1D arrays. Empty input yields a figure
    with empty traces instead of failing inside ``np.min``.
    """
    # Use provided configs or defaults
    theme_colors = get_theme(theme)
    font_config = font_config or DEFAULT_FONT_CONFIG
    layout_config = layout_config or DEFAULT_LAYOUT_CONFIG
    scene_config = scene_config or DEFAULT_SCENE_CONFIG

    # Normalize inputs once so the stem and marker traces see the same points.
    x_pts = np.asarray(x, dtype=np.float64).ravel()
    y_pts = np.asarray(y, dtype=np.float64).ravel()
    z_pts = np.asarray(z, dtype=np.float64).ravel()
    if not (x_pts.size == y_pts.size == z_pts.size):
        raise ValueError(
            "x, y and z must contain the same number of points "
            f"(got {x_pts.size}, {y_pts.size} and {z_pts.size})"
        )

    baseline = _stem_baseline(z_pts)

    fig = go.Figure()

    # All stems go into ONE trace: each stem is a two-point segment followed by
    # a None separator, which breaks the line between consecutive stems. This
    # is far cheaper than adding one trace per stem.
    x_values = x_pts.tolist()
    y_values = y_pts.tolist()
    fig.add_trace(
        go.Scatter3d(
            x=_stem_segments(x_values, x_values),
            y=_stem_segments(y_values, y_values),
            z=_stem_segments([baseline] * z_pts.size, z_pts.tolist()),
            mode="lines",
            line=dict(
                color=GRID_COLOR,
                width=GRID_WIDTH,
            ),
            showlegend=False,
            hoverinfo="skip",
        )
    )

    # Markers at the top of each stem; these carry the hover text.
    hover_z_label = z_label if z_label else "z"
    fig.add_trace(
        go.Scatter3d(
            x=x_pts,
            y=y_pts,
            z=z_pts,
            mode="markers",
            marker=dict(
                size=MARKER_SIZE,
                color=theme_colors["light"],
            ),
            showlegend=False,
            hovertemplate=(
                f"x: %{{x:{FLOAT_PRECISION}}}<br>"
                f"y: %{{y:{FLOAT_PRECISION}}}<br>"
                f"{hover_z_label}: %{{z:{FLOAT_PRECISION}}}<extra></extra>"
            ),
        )
    )

    # Configure layout
    fig.update_layout(
        title={
            "text": title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
            "yanchor": layout_config.title_yanchor,
        },
        scene=dict(
            xaxis_title="x position (m)",
            yaxis_title="y position (m)",
            zaxis_title=z_label or "Value",
            xaxis=_scene_axis_style(layout_config),
            yaxis=_scene_axis_style(layout_config),
            zaxis=_scene_axis_style(layout_config),
            aspectratio=dict(
                x=scene_config.aspect_ratio_x,
                y=scene_config.aspect_ratio_y,
                z=scene_config.aspect_ratio_z,
            ),
        ),
        width=layout_config.width,
        height=layout_config.height,
        margin=layout_config.margin,
    )

    # Set camera angle: the eye sits on a circle of radius camera_distance in
    # the x-y plane, rotated by default_view_angle (degrees), at height camera_z.
    view_angle = np.radians(scene_config.default_view_angle)
    fig.update_layout(
        scene_camera=dict(
            center=dict(x=0, y=0, z=0),
            eye=dict(
                x=scene_config.camera_distance * np.cos(view_angle),
                y=scene_config.camera_distance * np.sin(view_angle),
                z=scene_config.camera_z,
            ),
            up=dict(
                x=scene_config.camera_up_x,
                y=scene_config.camera_up_y,
                z=scene_config.camera_up_z,
            ),
        )
    )
    return fig


def _stem_baseline(z_values):
    """Return the smallest finite value in ``z_values``, or 0.0 if there is none."""
    # Ignore NaN/inf so one bad sample cannot turn every stem's baseline into
    # NaN (which would hide all stems); this also covers empty input.
    finite = z_values[np.isfinite(z_values)]
    return float(finite.min()) if finite.size else 0.0


def _stem_segments(starts, ends):
    """Interleave ``starts[i], ends[i], None`` into one flat list of line segments."""
    segments = []
    for start, end in zip(starts, ends):
        segments.extend((start, end, None))
    return segments


def _scene_axis_style(layout_config):
    """Build the styling dict shared by the x, y and z scene axes."""
    return dict(
        gridcolor=GRID_COLOR,
        showbackground=True,
        backgroundcolor=layout_config.scene_bgcolor,
        tickformat=FLOAT_PRECISION,
    )