from typing import List, Optional
import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from matplotlib import pyplot as plt
from scipy.interpolate import interp1d
from ..plotting.color_themes import get_theme
from ..plotting.plotting_constants import *
from ..plotting.stem_plot import plot_3d_stem
from ..plotting.types import NewArc
from .models import *
from .utils import *
# [docs]
class Track:
    """
    Represents a racing track with discrete or continuous curvature representation.

    Parameters
    ----------
    dx : NDArray[float64]
        Array of distance steps
    radius : NDArray[float64]
        Array of radius of curvature values
    cumulative_dist : NDArray[float64]
        Array of cumulative distances along the track
    x_m : NDArray[float64]
        Array of x coordinates in meters
    y_m : NDArray[float64]
        Array of y coordinates in meters
    distance_step : float
        Length of each step in meters
    continuous : bool
        If True, track uses continuous representation; if False, uses discrete
        corner-based representation
    **kwargs : dict
        Additional keyword arguments (required for the chosen representation):
        - For discrete tracks: original_corners (List[List]), arcs (List[ArcData])
        - For continuous tracks: sample_dist (float), tck (B-spline representation)

    Raises
    ------
    KeyError
        If a keyword argument required by the chosen representation is missing
    """

    def __init__(
        self,
        dx: npt.NDArray[np.float64],
        radius: npt.NDArray[np.float64],
        cumulative_dist: npt.NDArray[np.float64],
        x_m: npt.NDArray[np.float64],
        y_m: npt.NDArray[np.float64],
        distance_step: float,
        continuous: bool,
        **kwargs,
    ):
        self.continuous = continuous
        self.dx = dx
        self.radius = radius
        self.cumulative_dist = cumulative_dist
        self.x_m = x_m
        self.y_m = y_m
        self.distance_step = distance_step

        # Fail early, naming every missing key, instead of a bare KeyError('arcs').
        required = ("sample_dist", "tck") if continuous else ("original_corners", "arcs")
        missing = [key for key in required if key not in kwargs]
        if missing:
            representation = "continuous" if continuous else "discrete"
            raise KeyError(
                f"{representation} Track requires keyword argument(s): "
                f"{', '.join(missing)}"
            )

        if not continuous:
            # Additional discrete representation data (only populated if continuous=False)
            self.original_corners: List[List] = kwargs["original_corners"]
            self.arcs: List[ArcData] = kwargs["arcs"]
        else:
            # Additional continuous representation data (only populated if continuous=True)
            self.sample_dist: float = kwargs["sample_dist"]
            self.tck = kwargs["tck"]  # B-spline representation

        # Seed indices depend on `continuous` and `radius`, so compute them last.
        self.seed_idx = self._compute_seed_indices()

    def to_data(self) -> "TrackData":
        """
        Serialize the Track to a TrackData model for Pydantic serialization.

        Returns
        -------
        TrackData
            A ContinuousTrackData or DiscreteTrackData instance containing all track data
        """
        # Fields shared by both representations.
        common = dict(
            dx=self.dx,
            radius=self.radius,
            cumulative_dist=self.cumulative_dist,
            x_m=self.x_m,
            y_m=self.y_m,
            distance_step=self.distance_step,
            seed_idx=self.seed_idx,
        )
        if self.continuous:
            return ContinuousTrackData(
                continuous=True,
                sample_dist=self.sample_dist,
                tck=self.tck,
                **common,
            )
        return DiscreteTrackData(
            continuous=False,
            original_corners=self.original_corners,
            arcs=self.arcs,
            **common,
        )

    def get_dx_array(self) -> npt.NDArray[np.float64]:
        """
        Get a copy of the distance step array.

        Returns
        -------
        NDArray[float64]
            Copy of the dx array
        """
        return self.dx.copy()

    def get_curvature_array(self) -> npt.NDArray[np.float64]:
        """
        Get a copy of the radius of curvature array.

        Despite the method name, the returned values are the stored radii
        (``self.radius``), not their reciprocals.

        Returns
        -------
        NDArray[float64]
            Copy of the radius array
        """
        return self.radius.copy()

    def get_dist_from_start_array(self) -> npt.NDArray[np.float64]:
        """
        Get a copy of the cumulative distance array.

        Returns
        -------
        NDArray[float64]
            Copy of the cumulative distance array
        """
        return self.cumulative_dist.copy()

    def get_simulation_seeds(self) -> npt.NDArray[np.int32]:
        """
        Get a copy of the simulation seed indices.

        Returns
        -------
        NDArray[int32]
            Copy of the seed index array
        """
        return self.seed_idx.copy()

    def radius_of_curvature(self, d: float) -> float:
        """
        Get the radius of curvature at distance d along the track.

        For continuous tracks: delegates to ``calculate_menger_curvature``.
        For discrete tracks: returns the stored radius of the first sample whose
        cumulative distance is >= d (clamped to the last sample).

        Parameters
        ----------
        d : float
            Distance along the track in meters

        Returns
        -------
        float
            Radius of curvature in meters

        Raises
        ------
        ValueError
            For discrete tracks, if d exceeds the final cumulative distance by
            more than 0.1. No lower bound is enforced: distances before the first
            sample map to the first radius.
        """
        if self.continuous:
            _, radius = calculate_menger_curvature(
                self.x_m, self.y_m, self.cumulative_dist, d, self.sample_dist
            )
            return radius

        # Allow a small overshoot past the end of the track before rejecting d.
        if d > self.cumulative_dist[-1] + 0.1:
            raise ValueError("Distance out of bounds")
        # searchsorted may return len(array) for values past the end; clamp it.
        left_idx = min(
            np.searchsorted(self.cumulative_dist, d, side="left"),
            len(self.cumulative_dist) - 1,
        )
        return self.radius[left_idx]

    def _compute_seed_indices(self) -> npt.NDArray[np.int32]:
        """
        Compute simulation seed indices for this track.

        For discrete tracks, seeds are placed at the middle of each constant-radius
        segment. A segment is a run of consecutive samples with equal radius.

        For continuous tracks, seeds are the midpoints of plateaus (runs of samples
        within 1 of the plateau's first radius) that are local minima of the radius
        array.

        Non-finite radii are handled safely: equal infinite values (e.g. straights)
        are grouped into one segment/plateau, and the scan always terminates.

        Returns
        -------
        npt.NDArray[np.int32]
            Array of indices where simulation seeds should be placed
        """
        radius = np.asarray(self.radius)
        n = len(radius)

        if self.continuous:
            seeds = []
            idx = 0
            prev_radius = float("inf")
            while idx < n:
                # If a segment of the track is close to constant radius (abs diff < 1),
                # prefer the midpoint of the whole segment. This makes seeding less noisy.
                plateau_start = idx
                plateau_radius = radius[idx]
                # Always consume the first sample of the plateau. For inf/NaN radii the
                # difference test below is NaN (never < 1), which would otherwise leave
                # idx unchanged and loop forever.
                idx += 1
                while idx < n and (
                    radius[idx] == plateau_radius
                    or np.abs(radius[idx] - plateau_radius) < 1
                ):
                    idx += 1
                plateau_end = idx - 1
                # Compare with previous and next radius to decide if it is a local minimum.
                if plateau_radius < prev_radius and (
                    idx == n or plateau_radius < radius[idx]
                ):
                    seeds.append((plateau_start + plateau_end) // 2)
                prev_radius = plateau_radius
            return np.array(seeds, dtype=np.int32)

        if n == 0:
            return np.array([], dtype=np.int32)

        # Mark the first sample of every segment, plus a sentinel one past the end.
        # Neighbours are compared with != rather than diff(...) != 0 because
        # inf - inf is NaN, which would split an infinite-radius run into one
        # segment per sample.
        boundaries = np.concatenate(([True], radius[1:] != radius[:-1], [True]))
        boundary_idx = np.where(boundaries)[0]
        corner_starts = boundary_idx[:-1]  # Exclude the sentinel
        corner_ends = boundary_idx[1:] - 1  # Last index of each segment
        # Middle index of each segment
        return ((corner_starts + corner_ends) // 2).astype(np.int32)

    def _resample_to_track(self, variable) -> np.ndarray:
        """
        Return `variable` as an array with one value per track point.

        If the lengths already match the values are returned unchanged. Otherwise
        both sequences are placed on a normalized [0, 1] axis and the values are
        linearly interpolated onto the track's sample positions. A single value is
        broadcast along the whole track.

        Raises
        ------
        ValueError
            If `variable` is empty while the track is not
        """
        values = np.asarray(variable)
        n_track = len(self.x_m)
        n_var = len(values)
        if n_var == n_track:
            return values
        if n_var == 0:
            raise ValueError("variable must contain at least one value")
        if n_var == 1:
            # Nothing to interpolate between; use the constant everywhere.
            return np.full(n_track, values[0])
        source_pos = np.linspace(0, 1, n_var)
        target_pos = np.linspace(0, 1, n_track)
        interp_func = interp1d(
            source_pos, values, kind="linear", fill_value="extrapolate"
        )
        return interp_func(target_pos)

    @staticmethod
    def _axis_style(title_text: str, font_config: "FontConfig", **extra) -> dict:
        """Build the shared Plotly axis settings; `extra` adds axis-specific keys."""
        return dict(
            showgrid=True,
            gridcolor=GRID_COLOR,
            gridwidth=GRID_WIDTH,
            zeroline=True,
            zerolinecolor=ZEROLINE_COLOR,
            zerolinewidth=ZEROLINE_WIDTH,
            tickfont=dict(size=font_config.small, color=TEXT_COLOR_DARK),
            tickformat=FLOAT_PRECISION,
            title=dict(
                text=title_text,
                font=dict(size=font_config.medium, color=TEXT_COLOR_DARK),
            ),
            **extra,
        )

    def _apply_2d_layout(
        self,
        fig: "go.Figure",
        font_config: "FontConfig",
        layout_config: "LayoutConfig",
    ) -> None:
        """Apply the common 2D track layout (size, margins, equal-scale axes) to `fig`."""
        fig.update_layout(
            width=layout_config.width,
            height=layout_config.height,
            margin=layout_config.margin,
            # Anchor x to y with ratio 1 so the track is drawn with equal axis scales.
            xaxis=self._axis_style(
                "x position (m)", font_config, scaleanchor="y", scaleratio=1
            ),
            yaxis=self._axis_style("y position (m)", font_config),
            plot_bgcolor=layout_config.plot_bgcolor,
            hovermode=HOVER_MODE,
        )

    def plot_arcs(self, figsize=(20, 12), show_endpoints=True):
        """
        Plot the track using Matplotlib Arc objects. Only available for discrete
        (corner-based) tracks. A validation tool to visualize our corners.

        Parameters
        ----------
        figsize : tuple, optional
            Figure size as (width, height) in inches (default: (20, 12))
        show_endpoints : bool, optional
            Whether to show arc endpoints as blue dots (default: True)

        Returns
        -------
        tuple
            (fig, ax) - matplotlib figure and axes objects

        Raises
        ------
        NotImplementedError
            If the track uses the continuous representation
        """
        if self.continuous:
            raise NotImplementedError(
                "plot_arcs() is not implemented for continuous tracks"
            )

        fig, ax = plt.subplots(figsize=figsize)
        ax.set_aspect("equal")

        # Create matplotlib NewArc objects from ArcData (matching old track_vis behavior)
        for arc_data in self.arcs:
            arc = NewArc(
                xy=arc_data.center,
                width=arc_data.width,
                height=arc_data.height,
                angle=arc_data.angle,
                theta1=arc_data.theta1,
                theta2=arc_data.theta2,
                n=arc_data.n_points,
            )
            ax.add_patch(arc)

            if show_endpoints:
                path = arc.get_path()
                transform = arc.get_patch_transform()
                vertices = transform.transform(path.vertices)
                # Vertices might be reversed
                if arc_data.theta1_original > arc_data.theta2_original:
                    vertices = vertices[::-1]
                endpoint = vertices[-1]
                ax.plot(endpoint[0], endpoint[1], marker=".", color="b")

        ax.autoscale()
        # Close the figure in pyplot's registry; the Figure object is still returned.
        # NOTE(review): presumably avoids duplicate rendering in notebooks - confirm.
        plt.close(fig)
        return fig, ax

    def plot(
        self,
        theme: Optional[str] = None,
        font_config: Optional["FontConfig"] = None,
        layout_config: Optional["LayoutConfig"] = None,
    ) -> "go.Figure":
        """
        Visualizes the track.

        Parameters
        ----------
        theme : str, optional
            The color theme to use (default: None, uses default theme)
        font_config : FontConfig, optional
            Font configuration object (default: None, uses DEFAULT_FONT_CONFIG)
        layout_config : LayoutConfig, optional
            Layout configuration object (default: None, uses DEFAULT_LAYOUT_CONFIG)

        Returns
        -------
        go.Figure
            Plotly figure object
        """
        # Use provided configs or defaults
        theme_colors = get_theme(theme)
        font_config = font_config or DEFAULT_FONT_CONFIG
        layout_config = layout_config or DEFAULT_LAYOUT_CONFIG

        # Single-color scatter of the track points
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=self.x_m,
                y=self.y_m,
                mode="markers",
                marker=dict(
                    size=MARKER_SIZE_LARGE,
                    color=theme_colors["dark"],
                ),
                showlegend=False,
                hovertemplate=f"x: %{{x:{FLOAT_PRECISION}}}<br>y: %{{y:{FLOAT_PRECISION}}}<extra></extra>",
            )
        )
        self._apply_2d_layout(fig, font_config, layout_config)
        return fig

    def plot_with_overlay(
        self,
        variable: npt.NDArray[np.float64],
        label: Optional[str] = None,
        theme: Optional[str] = None,
        font_config: Optional["FontConfig"] = None,
        layout_config: Optional["LayoutConfig"] = None,
        colorbar_config: Optional["ColorbarConfig"] = None,
    ) -> "go.Figure":
        """
        Visualizes a scalar variable as color along the 2D track.

        Parameters
        ----------
        variable : NDArray[float64]
            Array of values to map as colors (e.g., velocity, power). If its length
            differs from the number of track points it is linearly resampled.
        label : str, optional
            Colorbar label (default: None)
        theme : str, optional
            The color theme to use (default: None, uses default theme)
        font_config : FontConfig, optional
            Font configuration object (default: None, uses DEFAULT_FONT_CONFIG)
        layout_config : LayoutConfig, optional
            Layout configuration object (default: None, uses DEFAULT_LAYOUT_CONFIG)
        colorbar_config : ColorbarConfig, optional
            Colorbar configuration object (default: None, uses DEFAULT_COLORBAR_CONFIG)

        Returns
        -------
        go.Figure
            Plotly figure object

        Raises
        ------
        ValueError
            If `variable` is empty while the track is not
        """
        # Use provided configs or defaults
        theme_colors = get_theme(theme)
        font_config = font_config or DEFAULT_FONT_CONFIG
        layout_config = layout_config or DEFAULT_LAYOUT_CONFIG
        colorbar_config = colorbar_config or DEFAULT_COLORBAR_CONFIG

        # Interpolate if lengths differ
        color_values = self._resample_to_track(variable)

        # Scatter of the track points, colored by the variable
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=self.x_m,
                y=self.y_m,
                mode="markers",
                marker=dict(
                    size=MARKER_SIZE_LARGE,
                    color=color_values,
                    colorscale=theme_colors["colorscale"],
                    colorbar=dict(
                        title=dict(
                            text=label or "",
                            font=dict(size=font_config.medium),
                        ),
                        thickness=colorbar_config.thickness,
                        len=colorbar_config.length,
                        tickfont=dict(size=font_config.small),
                    ),
                    showscale=True,
                ),
                showlegend=False,
                hovertemplate=f"x: %{{x:{FLOAT_PRECISION}}}<br>y: %{{y:{FLOAT_PRECISION}}}<br>value: %{{marker.color:{FLOAT_PRECISION}}}<extra></extra>",
            )
        )
        self._apply_2d_layout(fig, font_config, layout_config)
        return fig

    def plot_3d(
        self,
        variable: npt.NDArray[np.float64],
        title: Optional[str] = None,
        z_label: Optional[str] = None,
        theme: Optional[str] = None,
        font_config: Optional["FontConfig"] = None,
        layout_config: Optional["LayoutConfig"] = None,
        scene_config: Optional["SceneConfig"] = None,
    ) -> "go.Figure":
        """
        Maps a variable onto the z-axis of the track for 3D visualization using Plotly.
        Creates a 3D stem plot with vertical lines extending from baseline to data points.

        Parameters
        ----------
        variable : NDArray[float64]
            Array of values to map to z-axis (e.g., velocity, power). If its length
            differs from the number of track points it is linearly resampled.
        title : str, optional
            Plot title (default: None)
        z_label : str, optional
            Z-axis label (default: None)
        theme: str, optional
            The color theme to use (default: None)
        font_config : FontConfig, optional
            Font configuration object (default: None, uses DEFAULT_FONT_CONFIG)
        layout_config : LayoutConfig, optional
            Layout configuration object (default: None, uses DEFAULT_LAYOUT_CONFIG)
        scene_config : SceneConfig, optional
            Scene configuration object for 3D view (default: None, uses DEFAULT_SCENE_CONFIG)

        Returns
        -------
        go.Figure
            Plotly figure object

        Raises
        ------
        ValueError
            If `variable` is empty while the track is not
        """
        # Delegate to the 3D stem plot utility (interpolating first if lengths differ)
        return plot_3d_stem(
            x=self.x_m,
            y=self.y_m,
            z=self._resample_to_track(variable),
            title=title,
            z_label=z_label,
            theme=theme,
            font_config=font_config,
            layout_config=layout_config,
            scene_config=scene_config,
        )