import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from matplotlib import axis
from plotly.subplots import make_subplots
from ..compsim.models import *
from .plotting_constants import *
from .types import Axis, Metric
def _get_x_vals(results: LapsimResults, axis: Axis) -> npt.NDArray[np.float64]:
    """
    Return the x-axis samples for ``results`` along the requested axis.

    Parameters
    ----------
    results : LapsimResults
        Lapsim results to read from
    axis : Axis
        ``Axis.DISTANCE`` yields the running sum of ``results.lap_dxs``;
        ``Axis.TIME`` yields ``results.lap_t`` as stored.

    Returns
    -------
    npt.NDArray[np.float64]
        X-axis values

    Raises
    ------
    ValueError
        If ``axis`` is neither ``Axis.DISTANCE`` nor ``Axis.TIME``.
    """
    if axis == Axis.DISTANCE:
        return np.cumsum(results.lap_dxs)
    if axis == Axis.TIME:
        return results.lap_t
    raise ValueError(f"Unsupported axis: {axis}")
def _extract_data_by_key(key: str, results: LapsimResults) -> npt.NDArray[np.float64]:
"""
Extract data from LapsimResults by key name.
Supports both top-level and per-seed keys. Per-seed keys should be in the
format "{field}_seed_{n}", e.g., "v_proposal_seed_2".
Parameters
----------
key : str
Key name to extract from results
results : LapsimResults
Lapsim results object
Returns
-------
npt.NDArray[np.float64]
Extracted data array
Raises
------
ValueError
If the key is not found in either InternalData or LapsimResults
"""
if "_seed_" in key:
field_name, seed_str = key.rsplit("_seed_", 1)
seed_idx = int(seed_str)
seed_data = results.internal_data.per_seed[seed_idx]
if not hasattr(seed_data, field_name):
raise ValueError(
f"Invalid key. Per-seed data does not have field '{field_name}'."
)
return getattr(seed_data, field_name)
else:
if hasattr(results.internal_data, key):
return getattr(results.internal_data, key)
if hasattr(results, key):
return getattr(results, key)
raise ValueError(
f"Invalid key. Neither InternalData nor LapsimResults has field '{key}'."
)
def _mask_trace(
    x: np.ndarray,
    y: np.ndarray,
    grown_forward_mask: np.ndarray,
    slower_mask: np.ndarray,
    name: str,
    **scatter_kwargs,
) -> list[go.Scatter]:
    """
    Split one trace into up to four marker traces using two boolean masks.

    ``grown_forward_mask`` selects the colour (True -> red, False -> blue).
    ``slower_mask`` selects the marker symbol (True -> 'x', False -> circle).
    A combination that matches no points produces no trace.

    Parameters
    ----------
    x : np.ndarray
        X coordinates for the trace
    y : np.ndarray
        Y coordinates for the trace
    grown_forward_mask : np.ndarray
        Boolean mask choosing red (True) or blue (False)
    slower_mask : np.ndarray
        Boolean mask choosing 'x' (True) or circle (False) markers
    name : str
        Base name; each trace is named ``"{name}:{suffix}"``
    **scatter_kwargs
        Extra keyword arguments for go.Scatter. These take precedence over
        the computed x/y/mode/name/marker values.

    Returns
    -------
    list[go.Scatter]
        One Scatter per non-empty mask combination, in the order
        fwd+slower, fwd, back+slower, back.
    """
    # (grown_forward, is_slower) -> (colour, marker symbol, legend suffix).
    # Markers only: each masked subset is generally non-contiguous, so
    # joining its points with lines would imply interpolation across gaps.
    styles = {
        (True, True): ("red", "x", "fwd, is_slower"),
        (True, False): ("red", "circle", "fwd"),
        (False, True): ("blue", "x", "back, is_slower"),
        (False, False): ("blue", "circle", "back"),
    }
    out = []
    for (fwd_flag, slow_flag), (colour, symbol, suffix) in styles.items():
        selected = (grown_forward_mask == fwd_flag) & (slower_mask == slow_flag)
        if not np.any(selected):
            continue
        computed = {
            "x": x[selected],
            "y": y[selected],
            "mode": "markers",
            "name": f"{name}:{suffix}",
            "marker": {"color": colour, "symbol": symbol},
        }
        # Caller-supplied kwargs win, exactly as a trailing ** unpack would.
        out.append(go.Scatter(**{**computed, **scatter_kwargs}))
    return out
[docs]
def plot_per_seed(
    results: LapsimResults | List[LapsimResults],
    x_axis: Axis,
    *,
    include_indices: List[int] | None = None,
    include_power: bool = True,
    title: str | None = None,
    font_config: FontConfig = DEFAULT_FONT_CONFIG,
    layout_config: LayoutConfig = DEFAULT_LAYOUT_CONFIG,
) -> go.Figure:
    """
    Draws one subplot per seed index showing velocity proposals, accel, power, and masks.
    Allows overlaying seed traces across multiple lapsims.

    For every run, each subplot shows:

    - the seed's ``v_proposal``, split by ``grown_forward_mask`` and
      ``slower_mask`` and drawn as markers;
    - the ``v_max_pre`` baseline the seed started from;
    - on the secondary y-axis, the seed's ``acc_proposal`` and, optionally,
      ``p_proposal`` divided by 1000.

    Parameters
    ----------
    results : LapsimResults or List[LapsimResults]
        For consistency, these should be the same event across different runs
    x_axis : Axis
        Use distance or time for the x-axis
    include_indices : List[int], optional
        Positions in each run's ``internal_data.per_seed`` to include, e.g.
        [0, 1] only includes the first two seeds. Defaults to all seeds
        listed in the first run's ``internal_data.seed_idx_list``.
    include_power : bool, optional
        Include power proposals (kW). True by default
    title : str, optional
        Plot title
    font_config : FontConfig, optional
        Font configuration object
    layout_config : LayoutConfig, optional
        Layout configuration object

    Returns
    -------
    go.Figure
        Plotly figure object

    Raises
    ------
    ValueError
        If any entry of ``results`` has ``internal_data`` set to None.
    """
    if not isinstance(results, list):
        results = [results]
    for r in results:
        if r.internal_data is None:
            raise ValueError(
                "LapsimResults must include internal_data to plot per-seed traces."
            )

    # Subplot titles come from the first run's seed_idx_list. The data for
    # every run is looked up positionally in that run's per_seed list.
    seed_idx_list = results[0].internal_data.seed_idx_list
    seed_positions = (
        list(range(len(seed_idx_list)))
        if include_indices is None
        else list(include_indices)
    )
    # Element-wise lookup works whether seed_idx_list is a Python list or a
    # NumPy array. Indexing a plain list with a list of ints raises TypeError.
    seed_labels = [seed_idx_list[pos] for pos in seed_positions]
    num_subplots = len(seed_positions)

    fig = make_subplots(
        rows=num_subplots,
        cols=1,
        shared_xaxes=False,
        subplot_titles=[f"Seed_idx = {i}" for i in seed_labels],
        specs=[[{"secondary_y": True}] for _ in range(num_subplots)],
    )
    for row_idx, seed_pos in enumerate(seed_positions, start=1):
        for run_idx, r in enumerate(results):
            internal = r.internal_data
            target_seed_data = internal.per_seed[seed_pos]
            x_vals = _get_x_vals(r, x_axis)
            legend = f"Run {run_idx}"

            # Masked v_proposal traces. There is deliberately no ``mode``
            # override: _mask_trace uses markers because each masked subset is
            # generally non-contiguous. "lines" would join points across gaps
            # and hide the 'x'/circle symbol distinction.
            for trace in _mask_trace(
                x_vals,
                target_seed_data.v_proposal,
                target_seed_data.grown_forward_mask,
                target_seed_data.slower_mask,
                name=f"{legend}:v_proposal",
            ):
                fig.add_trace(trace, row=row_idx, col=1)

            # v_max_pre is the baseline before this seed's modifications.
            # Seed 0 uses the initial v_max_profile; seed i > 0 uses the
            # v_max_post of seed i-1.
            if seed_pos == 0:
                v_max_pre = internal.v_max_profile
            else:
                v_max_pre = internal.per_seed[seed_pos - 1].v_max_post
            fig.add_trace(
                go.Scatter(
                    x=x_vals,
                    y=v_max_pre,
                    mode="lines",
                    name=f"{legend}:v_max_pre",
                ),
                row=row_idx,
                col=1,
            )

            fig.add_trace(
                go.Scatter(
                    x=x_vals,
                    y=target_seed_data.acc_proposal,
                    mode="lines",
                    name=f"{legend}:acc",
                ),
                row=row_idx,
                col=1,
                secondary_y=True,
            )

            if include_power:
                fig.add_trace(
                    go.Scatter(
                        x=x_vals,
                        y=target_seed_data.p_proposal / 1000.0,
                        mode="lines",
                        name=f"{legend}:power(kW)",
                    ),
                    row=row_idx,
                    col=1,
                    secondary_y=True,
                )

        fig.update_yaxes(
            title_text="Velocity (m/s)", secondary_y=False, row=row_idx, col=1
        )
        fig.update_yaxes(
            title_text="Accel / Power", secondary_y=True, row=row_idx, col=1
        )
        fig.update_xaxes(title_text=x_axis.value, row=row_idx, col=1)

    fig.update_layout(
        height=layout_config.height * num_subplots,
        width=layout_config.width * 1.5,  # Intentionally wide graph for visibility
        title={
            "text": title or "Per-seed Internal Traces",
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        legend=dict(
            orientation="v",
            yanchor="top",
            xanchor="left",
        ),
        margin=layout_config.margin,
    )
    return fig
[docs]
def plot_event_traces(
    results: LapsimResults | List[LapsimResults],
    x_axis: Axis,
    *,
    labels: List[str] | None = None,
    metrics: List[Metric] | None = None,
    title: str | None = None,
    layout_config: LayoutConfig = DEFAULT_LAYOUT_CONFIG,
    font_config: FontConfig = DEFAULT_FONT_CONFIG,
) -> go.Figure:
    """
    Plot final velocity, acceleration, and power traces on a single plot.
    Allows overlaying traces across multiple lapsims for direct comparison
    of different vehicle configurations.

    Parameters
    ----------
    results : LapsimResults or List[LapsimResults]
        Results to plot. Pass a list to overlay multiple configurations.
    x_axis : Axis
        Axis to use for the x-axis (distance or time)
    labels : List[str], optional
        Labels for each result set (e.g., ["baseline", "new_config"]).
        If not provided, defaults to "run_0", "run_1", etc. A list shorter
        than ``results`` is padded with the defaults.
    metrics : List[Metric], optional
        Metrics to plot. Defaults to all three (velocity, accel, power).
        Each metric's ``value`` names the attribute read from each result;
        ``Metric.POWER`` values are divided by 1000 before plotting.
    title : str, optional
        Plot title
    layout_config : LayoutConfig, optional
        Layout configuration object
    font_config : FontConfig, optional
        Font configuration object

    Returns
    -------
    go.Figure
        Plotly figure object
    """
    # Use a None sentinel rather than a list default: a list default is a
    # single shared object created once at definition time.
    if metrics is None:
        metrics = [Metric.VELOCITY, Metric.ACCEL, Metric.POWER]
    if not isinstance(results, list):
        results = [results]
    if labels is None:
        labels = [f"run_{i}" for i in range(len(results))]
    elif len(labels) < len(results):
        labels = labels + [f"run_{i}" for i in range(len(labels), len(results))]

    fig = go.Figure()
    for run_idx, r in enumerate(results):
        x_vals = _get_x_vals(r, x_axis)
        label = labels[run_idx]
        for metric in metrics:
            y_vals = getattr(r, metric.value)
            if metric == Metric.POWER:
                y_vals = y_vals / 1000.0
            fig.add_trace(
                go.Scatter(
                    x=x_vals,
                    y=y_vals,
                    mode="lines",
                    name=f"{label}:{metric.value}",
                ),
            )
    fig.update_xaxes(title_text=x_axis.value)
    fig.update_layout(
        height=layout_config.height,
        title={
            "text": title or "Event Traces",
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        legend=dict(
            orientation="v",
            yanchor="top",
            xanchor="left",
        ),
        margin=layout_config.margin,
    )
    return fig
[docs]
def plot_competition_traces(
    competitions: CompetitionResults | List[CompetitionResults],
    x_axis: Axis,
    *,
    labels: List[str] | None = None,
    title: str | None = None,
    layout_config: LayoutConfig = DEFAULT_LAYOUT_CONFIG,
    font_config: FontConfig = DEFAULT_FONT_CONFIG,
) -> go.Figure:
    """
    Plot one subplot per competition event with velocity, accel, and power.
    Allows overlaying traces across multiple competitions.

    The events are "accel", "skidpad", "autoX" and "endurance", read as
    attributes of each competition. Velocity and acceleration share the
    primary y-axis. Power, divided by 1000, is drawn on the secondary y-axis.

    Parameters
    ----------
    competitions : CompetitionResults or List[CompetitionResults]
        Competition results to plot
    x_axis : Axis
        Axis to use for the x-axis (distance or time)
    labels : List[str], optional
        Labels for each competition (e.g., ["baseline", "new_config"]).
        If not provided, defaults to "comp_0", "comp_1", etc. A list shorter
        than ``competitions`` is padded with the defaults.
    title : str, optional
        Plot title
    layout_config : LayoutConfig, optional
        Layout configuration object
    font_config : FontConfig, optional
        Font configuration object

    Returns
    -------
    go.Figure
        Plotly figure object
    """
    comps = competitions if isinstance(competitions, list) else [competitions]

    if labels is None:
        labels = [f"comp_{i}" for i in range(len(comps))]
    elif len(labels) < len(comps):
        labels = labels + [f"comp_{i}" for i in range(len(labels), len(comps))]

    events = ["accel", "skidpad", "autoX", "endurance"]
    # (lapsim_results attribute, legend suffix, secondary y-axis?, divisor)
    series = (
        ("lap_vels", "velocity", False, None),
        ("lap_accs", "accel", False, None),
        ("lap_powers", "power", True, 1000.0),
    )

    fig = make_subplots(
        rows=len(events),
        cols=1,
        shared_xaxes=False,
        subplot_titles=events,
        specs=[[{"secondary_y": True}] for _ in events],
    )
    for row, event_name in enumerate(events, start=1):
        for label, comp in zip(labels, comps):
            lapsim = getattr(comp, event_name).lapsim_results
            xs = _get_x_vals(lapsim, x_axis)
            for attr, suffix, on_secondary, divisor in series:
                ys = getattr(lapsim, attr)
                if divisor is not None:
                    ys = ys / divisor
                fig.add_trace(
                    go.Scatter(x=xs, y=ys, mode="lines", name=f"{label}:{suffix}"),
                    row=row,
                    col=1,
                    secondary_y=on_secondary,
                )

    fig.update_yaxes(title_text="Velocity / Accel", secondary_y=False)
    fig.update_yaxes(title_text="Power (kW)", secondary_y=True)
    fig.update_xaxes(title_text=x_axis.value)

    layout_title = {
        "text": title or "Competition Traces (velocity, accel, power)",
        "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
        "x": layout_config.title_x,
        "xanchor": layout_config.title_xanchor,
    }
    fig.update_layout(
        height=layout_config.height * len(events),
        title=layout_title,
        legend=dict(orientation="v", yanchor="top", xanchor="left"),
        margin=layout_config.margin,
    )
    return fig
[docs]
def plot_internal_keys(
    results: LapsimResults | List[LapsimResults],
    keys: List[str],
    x_axis: Axis,
    *,
    title: str | None = None,
    layout_config: LayoutConfig = DEFAULT_LAYOUT_CONFIG,
    font_config: FontConfig = DEFAULT_FONT_CONFIG,
) -> go.Figure:
    """
    Plot any tracked internal variables (top-level or per-seed) by key name.
    Provided for back compatibility and more flexible access.

    One subplot is drawn per key, with shared x-axes, and one trace per run.

    Parameters
    ----------
    results : LapsimResults or List[LapsimResults]
        Lapsim results to plot. A single result is treated as a one-element
        list.
    keys : List[str]
        List of keys to plot
    x_axis : Axis
        Axis to use for the x-axis (distance or time)
    title : str, optional
        Plot title
    layout_config : LayoutConfig, optional
        Layout configuration object
    font_config : FontConfig, optional
        Font configuration object

    Returns
    -------
    go.Figure
        Plotly figure object

    Raises
    ------
    ValueError
        If a key cannot be resolved or ``x_axis`` is unsupported.

    Notes
    -----
    Possible key names:
    - Internal Data top-level keys: ``v_max_profile``, ``seed_idx_list``, ``cumulative_dist``
    - LapsimResults top-level keys: ``lap_t``, ``lap_dxs``, ``lap_vels``, ``lap_accs``, ``lap_powers``, ``lap_eff_motor_torques``
    - Per-seed keys in the form ``{field}_seed_{n}`` (e.g. ``v_proposal_seed_2``)
    """
    if not isinstance(results, list):
        results = [results]

    fig = make_subplots(
        rows=len(keys),
        cols=1,
        shared_xaxes=True,
        subplot_titles=keys,
    )
    # x-values depend only on the run, not on the key, so compute them once.
    # Pass the caller's x_axis; the module-level ``axis`` name is an unrelated
    # import and is always rejected by _get_x_vals.
    x_by_run = [_get_x_vals(r, x_axis) for r in results]
    for row_idx, key in enumerate(keys, start=1):
        for run_idx, (r, x_vals) in enumerate(zip(results, x_by_run)):
            y_vals = _extract_data_by_key(key, r)
            fig.add_trace(
                go.Scatter(
                    x=x_vals,
                    y=y_vals,
                    mode="lines",
                    name=f"run_{run_idx}:{key}",
                ),
                row=row_idx,
                col=1,
            )
        fig.update_yaxes(title_text=key, row=row_idx, col=1)
    fig.update_xaxes(title_text=x_axis.value)
    fig.update_layout(
        height=layout_config.height * len(keys),
        title={
            "text": title or "Internal variable traces",
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        legend=dict(
            orientation="v",
            yanchor="top",
            xanchor="left",
        ),
        margin=layout_config.margin,
    )
    return fig