Source code for suboptimumg.track.track

from typing import List, Optional

import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from matplotlib import pyplot as plt
from scipy.interpolate import interp1d

from ..plotting.color_themes import get_theme
from ..plotting.plotting_constants import *
from ..plotting.stem_plot import plot_3d_stem
from ..plotting.types import NewArc
from .models import *
from .utils import *


class Track:
    """
    Represents a racing track with discrete or continuous curvature representation.

    Parameters
    ----------
    dx : NDArray[float64]
        Array of distance steps
    radius : NDArray[float64]
        Array of radius of curvature values
    cumulative_dist : NDArray[float64]
        Array of cumulative distances along the track
    x_m : NDArray[float64]
        Array of x coordinates in meters
    y_m : NDArray[float64]
        Array of y coordinates in meters
    distance_step : float
        Length of each step in meters
    continuous : bool
        If True, track uses continuous representation; if False, uses discrete
        corner-based representation
    **kwargs : dict
        Additional keyword arguments:

        - For discrete tracks: original_corners (List[List]), arcs (List[ArcData])
        - For continuous tracks: sample_dist (float), tck (B-spline representation)

    Raises
    ------
    KeyError
        If a keyword argument required by the chosen representation is missing.
    """

    def __init__(
        self,
        dx: npt.NDArray[np.float64],
        radius: npt.NDArray[np.float64],
        cumulative_dist: npt.NDArray[np.float64],
        x_m: npt.NDArray[np.float64],
        y_m: npt.NDArray[np.float64],
        distance_step: float,
        continuous: bool,
        **kwargs,
    ):
        self.continuous = continuous
        self.dx = dx
        self.radius = radius
        self.cumulative_dist = cumulative_dist
        self.x_m = x_m
        self.y_m = y_m
        self.distance_step = distance_step

        if not continuous:
            # Additional discrete representation data (only populated if continuous=False)
            self.original_corners: List[List] = kwargs["original_corners"]
            self.arcs: List[ArcData] = kwargs["arcs"]
        else:
            # Additional continuous representation data (only populated if continuous=True)
            self.sample_dist: float = kwargs["sample_dist"]
            self.tck = kwargs["tck"]  # B-spline representation

        # Seeds are derived from self.radius / self.continuous, so compute them last.
        self.seed_idx = self._compute_seed_indices()

    def to_data(self) -> TrackData:
        """
        Serialize the Track to a TrackData model for Pydantic serialization.

        Returns
        -------
        TrackData
            A ContinuousTrackData or DiscreteTrackData instance containing all track data
        """
        if self.continuous:
            return ContinuousTrackData(
                continuous=True,
                dx=self.dx,
                radius=self.radius,
                cumulative_dist=self.cumulative_dist,
                x_m=self.x_m,
                y_m=self.y_m,
                distance_step=self.distance_step,
                sample_dist=self.sample_dist,
                tck=self.tck,
                seed_idx=self.seed_idx,
            )
        return DiscreteTrackData(
            continuous=False,
            dx=self.dx,
            radius=self.radius,
            cumulative_dist=self.cumulative_dist,
            x_m=self.x_m,
            y_m=self.y_m,
            distance_step=self.distance_step,
            original_corners=self.original_corners,
            arcs=self.arcs,
            seed_idx=self.seed_idx,
        )

    def get_dx_array(self):
        """
        Get a copy of the distance step array.

        Returns
        -------
        NDArray[float64]
            Copy of the dx array
        """
        return self.dx.copy()

    def get_curvature_array(self):
        """
        Get a copy of the radius of curvature array.

        Returns
        -------
        NDArray[float64]
            Copy of the radius array
        """
        return self.radius.copy()

    def get_dist_from_start_array(self):
        """
        Get a copy of the cumulative distance array.

        Returns
        -------
        NDArray[float64]
            Copy of the cumulative distance array
        """
        return self.cumulative_dist.copy()

    def get_simulation_seeds(self):
        """
        Get a copy of the simulation seed indices.

        Returns
        -------
        NDArray[int32]
            Copy of the seed index array
        """
        return self.seed_idx.copy()

    def radius_of_curvature(self, d: float) -> float:
        """
        Get the radius of curvature at distance d along the track.

        For continuous tracks: uses Menger curvature (via ``calculate_menger_curvature``)
        For discrete tracks: returns the stored radius of the first sample whose
        cumulative distance is >= d (clamped to the last sample)

        Parameters
        ----------
        d : float
            Distance along the track in meters

        Returns
        -------
        float
            Radius of curvature in meters

        Raises
        ------
        ValueError
            For discrete tracks, if d exceeds the total track distance by more than 0.1.
        """
        if self.continuous:
            _, radius = calculate_menger_curvature(
                self.x_m, self.y_m, self.cumulative_dist, d, self.sample_dist
            )
            return radius

        # Small tolerance so d == track length (modulo float error) is still accepted.
        if d > self.cumulative_dist[-1] + 0.1:
            raise ValueError("Distance out of bounds")
        left_idx = min(
            np.searchsorted(self.cumulative_dist, d, side="left"),
            len(self.cumulative_dist) - 1,
        )
        return self.radius[left_idx]

    def _compute_seed_indices(self) -> npt.NDArray[np.int32]:
        """
        Compute simulation seed indices for this track.

        For discrete tracks, seeds are placed at the middle of each constant-radius
        segment. A segment is a run of consecutive indices with the same radius value
        (runs of equal infinite radii count as one segment).

        For continuous tracks, seeds are placed at local minima of the radius of
        curvature (i.e. the tightest points of the track). A run of samples whose
        radius stays within 1 of the run's first value is treated as a single
        plateau, and its midpoint is used as the seed.

        Returns
        -------
        npt.NDArray[np.int32]
            Array of indices where simulation seeds should be placed
        """
        radius = self.radius

        if self.continuous:
            seeds = []
            n, curr_idx, prev_radius = len(radius), 0, float("inf")
            while curr_idx < n:
                # If a segment of the track is close to constant radius (abs diff < 1), we
                # prefer to take the midpoint of the whole segment. This makes our seeding
                # less noisy.
                const_radius_start = curr_idx
                curr_radius = radius[curr_idx]
                # Always consume the segment's first sample. Otherwise an inf/NaN radius
                # (inf - inf is NaN, which fails the tolerance test) would stall the scan
                # and loop forever.
                curr_idx += 1
                while curr_idx < n and np.abs(radius[curr_idx] - curr_radius) < 1:
                    curr_idx += 1
                const_radius_end = curr_idx - 1

                # Local minimum: smaller than the previous plateau and the next sample.
                if curr_radius < prev_radius and (
                    curr_idx == n or curr_radius < radius[curr_idx]
                ):
                    seeds.append((const_radius_start + const_radius_end) // 2)
                prev_radius = curr_radius
            return np.array(seeds, dtype=np.int32)

        radius = np.asarray(radius)
        n = len(radius)
        if n == 0:
            return np.empty(0, dtype=np.int32)

        # A segment starts at index 0 and wherever the radius differs from the previous
        # sample. Direct comparison (rather than np.diff(...) != 0) keeps runs of equal
        # infinite radii together, because inf - inf is NaN.
        is_start = np.empty(n, dtype=bool)
        is_start[0] = True
        is_start[1:] = radius[1:] != radius[:-1]

        corner_starts = np.flatnonzero(is_start)
        # Each segment ends right before the next one starts; the last ends at n - 1.
        corner_ends = np.append(corner_starts[1:], n) - 1

        # Compute middle index for each corner
        return ((corner_starts + corner_ends) // 2).astype(np.int32)

    def plot_arcs(self, figsize=(20, 12), show_endpoints=True):
        """
        Plot the track using Matplotlib Arc objects.

        Only available for discrete (corner-based) tracks. A validation tool to
        visualize our corners.

        Parameters
        ----------
        figsize : tuple, optional
            Figure size as (width, height) in inches (default: (20, 12))
        show_endpoints : bool, optional
            Whether to show arc endpoints as blue dots (default: True)

        Returns
        -------
        tuple
            (fig, ax) - matplotlib figure and axes objects

        Raises
        ------
        NotImplementedError
            If the track is continuous.
        """
        if self.continuous:
            raise NotImplementedError(
                "plot_arcs() is not implemented for continuous tracks"
            )

        fig, ax = plt.subplots(figsize=figsize)
        ax.set_aspect("equal")

        # Create matplotlib NewArc objects from ArcData (matching old track_vis behavior)
        for arc_data in self.arcs:
            arc = NewArc(
                xy=arc_data.center,
                width=arc_data.width,
                height=arc_data.height,
                angle=arc_data.angle,
                theta1=arc_data.theta1,
                theta2=arc_data.theta2,
                n=arc_data.n_points,
            )
            ax.add_patch(arc)

            if show_endpoints:
                path = arc.get_path()
                transform = arc.get_patch_transform()
                vertices = transform.transform(path.vertices)

                # Vertices might be reversed
                if arc_data.theta1_original > arc_data.theta2_original:
                    vertices = vertices[::-1]

                endpoint = vertices[-1]
                ax.plot(endpoint[0], endpoint[1], marker=".", color="b")

        ax.autoscale()
        # Detach the figure from pyplot's global state so it is not kept/auto-shown;
        # the caller still receives the figure and axes objects.
        plt.close(fig)
        return fig, ax

    def _resample_to_track(self, variable) -> npt.NDArray[np.float64]:
        """
        Resample ``variable`` so it has exactly one value per track point.

        If the lengths already match, the values are returned unchanged. Otherwise
        the values are spread evenly over the track's point indices and linearly
        interpolated (extrapolating at the ends). A single value is broadcast to
        every track point.

        Raises
        ------
        ValueError
            If ``variable`` is empty while the track has points.
        """
        values = np.asarray(variable)
        n_track = len(self.x_m)
        n_var = len(values)

        if n_var == n_track:
            return values
        if n_var == 0:
            raise ValueError("variable must contain at least one value")
        if n_var == 1:
            # interp1d needs at least two samples; a single value is a constant.
            return np.full(n_track, values[0], dtype=np.float64)

        x_old = np.linspace(0, 1, n_var)
        x_new = np.linspace(0, 1, n_track)
        interp_func = interp1d(x_old, values, kind="linear", fill_value="extrapolate")
        return interp_func(x_new)

    @staticmethod
    def _xy_axis_style(title: str, font_config) -> dict:
        """Return the Plotly axis settings shared by the 2D track plots."""
        return dict(
            showgrid=True,
            gridcolor=GRID_COLOR,
            gridwidth=GRID_WIDTH,
            zeroline=True,
            zerolinecolor=ZEROLINE_COLOR,
            zerolinewidth=ZEROLINE_WIDTH,
            tickfont=dict(size=font_config.small, color=TEXT_COLOR_DARK),
            tickformat=FLOAT_PRECISION,
            title=dict(
                text=title,
                font=dict(size=font_config.medium, color=TEXT_COLOR_DARK),
            ),
        )

    def _apply_xy_layout(self, fig: go.Figure, font_config, layout_config) -> None:
        """Apply the shared 2D track layout (equal-aspect x/y axes) to ``fig``."""
        fig.update_layout(
            width=layout_config.width,
            height=layout_config.height,
            margin=layout_config.margin,
            # scaleanchor/scaleratio lock x and y to the same scale so the track
            # geometry is not distorted.
            xaxis=dict(
                scaleanchor="y",
                scaleratio=1,
                **self._xy_axis_style("x position (m)", font_config),
            ),
            yaxis=self._xy_axis_style("y position (m)", font_config),
            plot_bgcolor=layout_config.plot_bgcolor,
            hovermode=HOVER_MODE,
        )

    def plot(
        self,
        theme: Optional[str] = None,
        font_config: Optional["FontConfig"] = None,
        layout_config: Optional["LayoutConfig"] = None,
    ) -> go.Figure:
        """
        Visualizes the track.

        Parameters
        ----------
        theme : str, optional
            The color theme to use (default: None, uses default theme)
        font_config : FontConfig, optional
            Font configuration object (default: None, uses DEFAULT_FONT_CONFIG)
        layout_config : LayoutConfig, optional
            Layout configuration object (default: None, uses DEFAULT_LAYOUT_CONFIG)

        Returns
        -------
        go.Figure
            Plotly figure object
        """
        # Use provided configs or defaults
        theme_colors = get_theme(theme)
        font_config = font_config or DEFAULT_FONT_CONFIG
        layout_config = layout_config or DEFAULT_LAYOUT_CONFIG

        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=self.x_m,
                y=self.y_m,
                mode="markers",
                marker=dict(
                    size=MARKER_SIZE_LARGE,
                    color=theme_colors["dark"],
                ),
                showlegend=False,
                hovertemplate=f"x: %{{x:{FLOAT_PRECISION}}}<br>y: %{{y:{FLOAT_PRECISION}}}<extra></extra>",
            )
        )
        self._apply_xy_layout(fig, font_config, layout_config)
        return fig

    def plot_with_overlay(
        self,
        variable: npt.NDArray[np.float64],
        label: Optional[str] = None,
        theme: Optional[str] = None,
        font_config: Optional["FontConfig"] = None,
        layout_config: Optional["LayoutConfig"] = None,
        colorbar_config: Optional["ColorbarConfig"] = None,
    ) -> go.Figure:
        """
        Visualizes a scalar variable as color along the 2D track.

        Parameters
        ----------
        variable : NDArray[float64]
            Array of values to map as colors (e.g., velocity, power). If its length
            differs from the number of track points it is linearly resampled; a
            single value is broadcast to every point.
        label : str, optional
            Colorbar label (default: None)
        theme : str, optional
            The color theme to use (default: None, uses default theme)
        font_config : FontConfig, optional
            Font configuration object (default: None, uses DEFAULT_FONT_CONFIG)
        layout_config : LayoutConfig, optional
            Layout configuration object (default: None, uses DEFAULT_LAYOUT_CONFIG)
        colorbar_config : ColorbarConfig, optional
            Colorbar configuration object (default: None, uses DEFAULT_COLORBAR_CONFIG)

        Returns
        -------
        go.Figure
            Plotly figure object

        Raises
        ------
        ValueError
            If ``variable`` is empty.
        """
        # Use provided configs or defaults
        theme_colors = get_theme(theme)
        font_config = font_config or DEFAULT_FONT_CONFIG
        layout_config = layout_config or DEFAULT_LAYOUT_CONFIG
        colorbar_config = colorbar_config or DEFAULT_COLORBAR_CONFIG

        colors = self._resample_to_track(variable)

        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=self.x_m,
                y=self.y_m,
                mode="markers",
                marker=dict(
                    size=MARKER_SIZE_LARGE,
                    color=colors,
                    colorscale=theme_colors["colorscale"],
                    colorbar=dict(
                        title=dict(
                            text=label or "",
                            font=dict(size=font_config.medium),
                        ),
                        thickness=colorbar_config.thickness,
                        len=colorbar_config.length,
                        tickfont=dict(size=font_config.small),
                    ),
                    showscale=True,
                ),
                showlegend=False,
                hovertemplate=f"x: %{{x:{FLOAT_PRECISION}}}<br>y: %{{y:{FLOAT_PRECISION}}}<br>value: %{{marker.color:{FLOAT_PRECISION}}}<extra></extra>",
            )
        )
        self._apply_xy_layout(fig, font_config, layout_config)
        return fig

    def plot_3d(
        self,
        variable: npt.NDArray[np.float64],
        title: Optional[str] = None,
        z_label: Optional[str] = None,
        theme: Optional[str] = None,
        font_config: Optional["FontConfig"] = None,
        layout_config: Optional["LayoutConfig"] = None,
        scene_config: Optional["SceneConfig"] = None,
    ) -> go.Figure:
        """
        Maps a variable onto the z-axis of the track for 3D visualization using Plotly.

        Creates a 3D stem plot with vertical lines extending from baseline to data points.

        Parameters
        ----------
        variable : NDArray[float64]
            Array of values to map to z-axis (e.g., velocity, power). If its length
            differs from the number of track points it is linearly resampled; a
            single value is broadcast to every point.
        title : str, optional
            Plot title (default: None)
        z_label : str, optional
            Z-axis label (default: None)
        theme: str, optional
            The color theme to use (default: None)
        font_config : FontConfig, optional
            Font configuration object (default: None, uses DEFAULT_FONT_CONFIG)
        layout_config : LayoutConfig, optional
            Layout configuration object (default: None, uses DEFAULT_LAYOUT_CONFIG)
        scene_config : SceneConfig, optional
            Scene configuration object for 3D view (default: None, uses DEFAULT_SCENE_CONFIG)

        Returns
        -------
        go.Figure
            Plotly figure object

        Raises
        ------
        ValueError
            If ``variable`` is empty.
        """
        # Delegate to the 3D stem plot utility (it resolves config defaults itself)
        return plot_3d_stem(
            x=self.x_m,
            y=self.y_m,
            z=self._resample_to_track(variable),
            title=title,
            z_label=z_label,
            theme=theme,
            font_config=font_config,
            layout_config=layout_config,
            scene_config=scene_config,
        )