from typing import List
import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from matplotlib import pyplot as plt
from scipy.interpolate import interp1d
from ..plotting.color_themes import get_theme
from ..plotting.plotting_constants import *
from ..plotting.stem_plot import plot_3d_stem
from ..plotting.types import NewArc
from .models import *
from .utils import *
class Track:
    """
    Represents a racing track with discrete or continuous curvature representation.

    Parameters
    ----------
    dx : NDArray[float64]
        Array of distance steps
    radius : NDArray[float64]
        Array of radius of curvature values
    cumulative_dist : NDArray[float64]
        Array of cumulative distances along the track
    x_m : NDArray[float64]
        Array of x coordinates in meters
    y_m : NDArray[float64]
        Array of y coordinates in meters
    distance_step : float
        Length of each step in meters
    continuous : bool
        If True, track uses continuous representation; if False, uses discrete
        corner-based representation
    **kwargs : dict
        Additional keyword arguments. Every key listed for the selected
        representation is required:

        - For discrete tracks: original_corners (List[List]), arcs (List[ArcData])
        - For continuous tracks: sample_dist (float), tck (B-spline representation),
          absolute_origin (tuple[float, float])

    Attributes
    ----------
    seed_idx : NDArray[int32]
        Simulation seed indices, computed once at construction time

    Raises
    ------
    KeyError
        If a keyword argument required by the selected representation is missing
    """

    def __init__(
        self,
        dx: npt.NDArray[np.float64],
        radius: npt.NDArray[np.float64],
        cumulative_dist: npt.NDArray[np.float64],
        x_m: npt.NDArray[np.float64],
        y_m: npt.NDArray[np.float64],
        distance_step: float,
        continuous: bool,
        **kwargs,
    ):
        self.continuous = continuous
        self.dx = dx
        self.radius = radius
        self.cumulative_dist = cumulative_dist
        self.x_m = x_m
        self.y_m = y_m
        self.distance_step = distance_step

        if not continuous:
            # Additional discrete representation data (only populated if continuous=False)
            self.original_corners: List[List] = kwargs["original_corners"]
            self.arcs: List[ArcData] = kwargs["arcs"]
        else:
            # Additional continuous representation data (only populated if continuous=True)
            self.sample_dist: float = kwargs["sample_dist"]
            self.tck = kwargs["tck"]  # B-spline representation
            self.absolute_origin: tuple[float, float] = kwargs["absolute_origin"]

        # Seeds depend on self.radius and self.continuous, so compute them last
        self.seed_idx = self._compute_seed_indices()

    def to_data(self) -> TrackData:
        """
        Serialize the Track to a TrackData model for Pydantic serialization.

        Returns
        -------
        TrackData
            A ContinuousTrackData or DiscreteTrackData instance containing all track data
        """
        # Fields shared by both representations
        shared = {
            "dx": self.dx,
            "radius": self.radius,
            "cumulative_dist": self.cumulative_dist,
            "x_m": self.x_m,
            "y_m": self.y_m,
            "distance_step": self.distance_step,
            "seed_idx": self.seed_idx,
        }
        if self.continuous:
            return ContinuousTrackData(
                continuous=True,
                sample_dist=self.sample_dist,
                tck=self.tck,
                absolute_origin=self.absolute_origin,
                **shared,
            )
        return DiscreteTrackData(
            continuous=False,
            original_corners=self.original_corners,
            arcs=self.arcs,
            **shared,
        )

    def get_dx_array(self) -> npt.NDArray[np.float64]:
        """
        Get a copy of the distance step array.

        Returns
        -------
        NDArray[float64]
            Copy of the dx array
        """
        return self.dx.copy()

    def get_curvature_array(self) -> npt.NDArray[np.float64]:
        """
        Get a copy of the radius of curvature array.

        Returns
        -------
        NDArray[float64]
            Copy of the radius array
        """
        return self.radius.copy()

    def get_dist_from_start_array(self) -> npt.NDArray[np.float64]:
        """
        Get a copy of the cumulative distance array.

        Returns
        -------
        NDArray[float64]
            Copy of the cumulative distance array
        """
        return self.cumulative_dist.copy()

    def get_simulation_seeds(self) -> npt.NDArray[np.int32]:
        """
        Get a copy of the simulation seed indices.

        Returns
        -------
        NDArray[int32]
            Copy of the seed index array
        """
        return self.seed_idx.copy()

    def radius_of_curvature(self, d: float) -> float:
        """
        Get the radius of curvature at distance d along the track.

        For continuous tracks: delegates to calculate_menger_curvature
        For discrete tracks: returns the discretized radius value at the first
        sample whose cumulative distance is >= d (clamped to the last sample)

        Parameters
        ----------
        d : float
            Distance along the track in meters

        Returns
        -------
        float
            Radius of curvature in meters

        Raises
        ------
        ValueError
            For discrete tracks, if d exceeds the final cumulative distance by
            more than 0.1
        """
        if self.continuous:
            _, radius = calculate_menger_curvature(
                self.x_m, self.y_m, self.cumulative_dist, d, self.sample_dist
            )
            return radius

        # Allow a small overshoot past the end of the track before rejecting
        if d > self.cumulative_dist[-1] + 0.1:
            raise ValueError("Distance out of bounds")

        # searchsorted can return len(array) for d beyond the last sample; clamp it
        left_idx = min(
            np.searchsorted(self.cumulative_dist, d, side="left"),
            len(self.cumulative_dist) - 1,
        )
        return self.radius[left_idx]

    def _compute_seed_indices(self) -> npt.NDArray[np.int32]:
        """
        Compute simulation seed indices for this track.

        For discrete tracks, seeds are placed at the middle of each constant-radius
        segment. A segment is identified by consecutive indices with the same
        radius value.

        For continuous tracks, seeds are the indices of strict local minima of the
        radius array (r[i-1] > r[i] < r[i+1]); the two end points are never seeds.

        Returns
        -------
        npt.NDArray[np.int32]
            Array of indices where simulation seeds should be placed
        """
        if self.continuous:
            r = np.asarray(self.radius)
            interior = r[1:-1]
            # Strictly smaller than both neighbours; the slices exclude both ends
            # and are all empty when the array has fewer than 3 samples
            is_local_min = (r[:-2] > interior) & (interior < r[2:])
            # +1 converts positions within the interior slice back to full indices
            return (np.flatnonzero(is_local_min) + 1).astype(np.int32)

        # NaN sentinels at both ends make the first and one-past-last positions
        # register as changes (NaN != 0), which delimits the first and last segments
        radius_changes = np.diff(self.radius, prepend=np.nan, append=np.nan) != 0
        boundaries = np.where(radius_changes)[0]
        corner_starts = boundaries[:-1]  # Exclude the trailing sentinel
        corner_ends = boundaries[1:] - 1  # Inclusive end index of each segment
        # Middle index of each segment
        return ((corner_starts + corner_ends) // 2).astype(np.int32)

    def plot_arcs(self, figsize=(20, 12), show_endpoints=True):
        """
        Plot the track using Matplotlib Arc objects. Only available for discrete
        (corner-based) tracks. A validation tool to visualize our corners.

        Parameters
        ----------
        figsize : tuple, optional
            Figure size as (width, height) in inches (default: (20, 12))
        show_endpoints : bool, optional
            Whether to show arc endpoints as blue dots (default: True)

        Returns
        -------
        tuple
            (fig, ax) - matplotlib figure and axes objects

        Raises
        ------
        NotImplementedError
            If the track is continuous
        """
        if self.continuous:
            raise NotImplementedError(
                "plot_arcs() is not implemented for continuous tracks"
            )

        fig, ax = plt.subplots(figsize=figsize)
        ax.set_aspect("equal")

        # Create matplotlib NewArc objects from ArcData (matching old track_vis behavior)
        for arc_data in self.arcs:
            arc = NewArc(
                xy=arc_data.center,
                width=arc_data.width,
                height=arc_data.height,
                angle=arc_data.angle,
                theta1=arc_data.theta1,
                theta2=arc_data.theta2,
                n=arc_data.n_points,
            )
            ax.add_patch(arc)

            if show_endpoints:
                path = arc.get_path()
                transform = arc.get_patch_transform()
                vertices = transform.transform(path.vertices)
                # Vertices might be reversed
                if arc_data.theta1_original > arc_data.theta2_original:
                    vertices = vertices[::-1]
                endpoint = vertices[-1]
                ax.plot(endpoint[0], endpoint[1], marker=".", color="b")

        ax.autoscale()
        # The figure is closed in pyplot before being returned; callers still get
        # the fig/ax objects (presumably to avoid an implicit duplicate display)
        plt.close(fig)
        return fig, ax

    def _rotate_coordinates(
        self, degrees: float
    ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
        """
        Rotate track coordinates counterclockwise around the first point.

        Parameters
        ----------
        degrees : float
            Rotation angle in degrees (counterclockwise)

        Returns
        -------
        tuple[NDArray[float64], NDArray[float64]]
            Rotated x and y coordinates
        """
        theta = np.radians(degrees)
        cos_theta = np.cos(theta)
        sin_theta = np.sin(theta)

        # Use first point as rotation center: translate, rotate, translate back
        x0, y0 = self.x_m[0], self.y_m[0]
        x_rel = self.x_m - x0
        y_rel = self.y_m - y0

        x_final = x_rel * cos_theta - y_rel * sin_theta + x0
        y_final = x_rel * sin_theta + y_rel * cos_theta + y0
        return x_final, y_final

    def _track_xy(self, rotate_by_deg: float):
        """Return the (x, y) coordinates to draw, rotated only for a non-zero angle."""
        if rotate_by_deg:
            return self._rotate_coordinates(rotate_by_deg)
        return self.x_m, self.y_m

    def _resample_to_track(self, variable) -> npt.NDArray:
        """Linearly resample ``variable`` so it has one value per track point."""
        values = np.asarray(variable)
        n_track = len(self.x_m)
        n_var = len(values)
        if n_var == n_track:
            return values

        # Both sequences are mapped onto a normalized [0, 1] axis, i.e. the
        # variable is assumed to span the whole track with uniform spacing
        x_old = np.linspace(0, 1, n_var)
        x_new = np.linspace(0, 1, n_track)
        interp_func = interp1d(
            x_old, values, kind="linear", fill_value="extrapolate"
        )
        return interp_func(x_new)

    @staticmethod
    def _axis_layout(axis_title: str, font_config: FontConfig) -> dict:
        """Build the grid/zero-line/tick/title settings shared by both 2D axes."""
        return {
            "showgrid": True,
            "gridcolor": GRID_COLOR,
            "gridwidth": GRID_WIDTH,
            "zeroline": True,
            "zerolinecolor": ZEROLINE_COLOR,
            "zerolinewidth": ZEROLINE_WIDTH,
            "tickfont": dict(size=font_config.small, color=TEXT_COLOR_DARK),
            "tickformat": FLOAT_PRECISION,
            "title": {
                "text": axis_title,
                "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
            },
        }

    def _apply_track_layout(
        self,
        fig: go.Figure,
        font_config: FontConfig,
        layout_config: LayoutConfig,
    ) -> None:
        """Apply the common size, margin and equal-aspect axis layout to ``fig``."""
        fig.update_layout(
            width=layout_config.width,
            height=layout_config.height,
            margin=layout_config.margin,
            # Anchoring x to y with ratio 1 keeps the track geometry undistorted
            xaxis={
                "scaleanchor": "y",
                "scaleratio": 1,
                **self._axis_layout("x position (m)", font_config),
            },
            yaxis=self._axis_layout("y position (m)", font_config),
            plot_bgcolor=layout_config.plot_bgcolor,
            hovermode=HOVER_MODE,
        )

    def plot(
        self,
        rotate_by_deg: float = 0.0,
        theme: str | None = "suspension",
        font_config: FontConfig | None = None,
        layout_config: LayoutConfig | None = None,
    ) -> go.Figure:
        """
        Visualizes the track, colored by radius of curvature.

        Parameters
        ----------
        rotate_by_deg : float, optional
            Rotation angle in degrees counterclockwise around the first point
            (default: 0.0)
        theme : str, optional
            The color theme to use (default: "suspension")
        font_config : FontConfig, optional
            Font configuration object (default: None, uses DEFAULT_FONT_CONFIG)
        layout_config : LayoutConfig, optional
            Layout configuration object (default: None, uses DEFAULT_LAYOUT_CONFIG)

        Returns
        -------
        go.Figure
            Plotly figure object
        """
        # Use provided configs or defaults
        theme_colors = get_theme(theme)
        font_config = font_config or DEFAULT_FONT_CONFIG
        layout_config = layout_config or DEFAULT_LAYOUT_CONFIG

        # Apply rotation if specified
        x, y = self._track_xy(rotate_by_deg)

        # Create scatter plot with color mapping
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="markers",
                marker={
                    "size": MARKER_SIZE_LARGE,
                    "color": self.radius,
                    "colorscale": theme_colors["colorscale"],
                    # Fixed color range for the radius scale
                    "cmin": 0,
                    "cmax": 200,
                    "colorbar": {"title": "Radius (m)"},
                },
                hovertext=[
                    f"Distance: {dist:.1f} m<br>Radius: {rad:.1f} m"
                    for dist, rad in zip(self.cumulative_dist, self.radius)
                ],
                hovertemplate=(
                    f"x: %{{x:{FLOAT_PRECISION}}}<br>"
                    f"y: %{{y:{FLOAT_PRECISION}}}<br>"
                    "%{hovertext}<extra></extra>"
                ),
                showlegend=False,
            )
        )
        self._apply_track_layout(fig, font_config, layout_config)
        return fig

    def plot_with_overlay(
        self,
        variable: npt.NDArray[np.float64],
        title: str | None = None,
        label: str | None = None,
        rotate_by_deg: float = 0.0,
        theme: str | None = None,
        font_config: FontConfig | None = None,
        layout_config: LayoutConfig | None = None,
        colorbar_config: ColorbarConfig | None = None,
    ) -> go.Figure:
        """
        Visualizes a scalar variable as color along the 2D track.

        Parameters
        ----------
        variable : NDArray[float64]
            Array of values to map as colors (e.g., velocity, power). If its
            length differs from the number of track points it is linearly
            resampled to match.
        title : str, optional
            Plot title; no title is set when None or empty (default: None)
        label : str, optional
            Colorbar label (default: None)
        rotate_by_deg : float, optional
            Rotation angle in degrees counterclockwise around the first point
            (default: 0.0)
        theme : str, optional
            The color theme to use (default: None, uses default theme)
        font_config : FontConfig, optional
            Font configuration object (default: None, uses DEFAULT_FONT_CONFIG)
        layout_config : LayoutConfig, optional
            Layout configuration object (default: None, uses DEFAULT_LAYOUT_CONFIG)
        colorbar_config : ColorbarConfig, optional
            Colorbar configuration object (default: None, uses DEFAULT_COLORBAR_CONFIG)

        Returns
        -------
        go.Figure
            Plotly figure object
        """
        # Use provided configs or defaults
        theme_colors = get_theme(theme)
        font_config = font_config or DEFAULT_FONT_CONFIG
        layout_config = layout_config or DEFAULT_LAYOUT_CONFIG
        colorbar_config = colorbar_config or DEFAULT_COLORBAR_CONFIG

        # Interpolate if lengths differ
        variable = self._resample_to_track(variable)

        # Apply rotation if specified
        x, y = self._track_xy(rotate_by_deg)

        # Create scatter plot with color mapping
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="markers",
                marker=dict(
                    size=MARKER_SIZE_LARGE,
                    color=variable,
                    colorscale=theme_colors["colorscale"],
                    colorbar=dict(
                        title=dict(
                            text=label or "",
                            font=dict(size=font_config.medium),
                        ),
                        thickness=colorbar_config.thickness,
                        len=colorbar_config.length,
                        tickfont=dict(size=font_config.small),
                    ),
                    showscale=True,
                ),
                showlegend=False,
                hovertemplate=(
                    f"x: %{{x:{FLOAT_PRECISION}}}<br>"
                    f"y: %{{y:{FLOAT_PRECISION}}}<br>"
                    f"value: %{{marker.color:{FLOAT_PRECISION}}}<extra></extra>"
                ),
            )
        )

        if title:
            fig.update_layout(
                title={
                    "text": title,
                    "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
                    "x": layout_config.title_x,
                    "xanchor": layout_config.title_xanchor,
                    "yanchor": layout_config.title_yanchor,
                },
            )

        self._apply_track_layout(fig, font_config, layout_config)
        return fig

    def plot_3d(
        self,
        variable: npt.NDArray[np.float64],
        title: str | None = None,
        z_label: str | None = None,
        theme: str | None = None,
        font_config: FontConfig | None = None,
        layout_config: LayoutConfig | None = None,
        scene_config: SceneConfig | None = None,
    ) -> go.Figure:
        """
        Maps a variable onto the z-axis of the track for 3D visualization using Plotly.

        Creates a 3D stem plot with vertical lines extending from baseline to data points.

        Parameters
        ----------
        variable : NDArray[float64]
            Array of values to map to z-axis (e.g., velocity, power). If its
            length differs from the number of track points it is linearly
            resampled to match.
        title : str, optional
            Plot title (default: None)
        z_label : str, optional
            Z-axis label (default: None)
        theme : str, optional
            The color theme to use (default: None)
        font_config : FontConfig, optional
            Font configuration object (default: None, uses DEFAULT_FONT_CONFIG)
        layout_config : LayoutConfig, optional
            Layout configuration object (default: None, uses DEFAULT_LAYOUT_CONFIG)
        scene_config : SceneConfig, optional
            Scene configuration object for 3D view (default: None, uses DEFAULT_SCENE_CONFIG)

        Returns
        -------
        go.Figure
            Plotly figure object
        """
        # Delegate to the 3D stem plot utility (unrotated coordinates)
        return plot_3d_stem(
            x=self.x_m,
            y=self.y_m,
            z=self._resample_to_track(variable),
            title=title,
            z_label=z_label,
            theme=theme,
            font_config=font_config,
            layout_config=layout_config,
            scene_config=scene_config,
        )