Source code for perda.plotting.subplots

from typing import List, Tuple

import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from ..core_data_structures.data_instance import DataInstance
from ..core_data_structures.single_run_data import SingleRunData
from ..units import Timescale, _to_seconds
from .plotting_constants import *


def data_instance_subplots(
    rows: List[List[DataInstance]],
    title: str | None = None,
    row_y_labels: List[str | None] | None = None,
    show_legend: bool = True,
    layout_config: LayoutConfig = DEFAULT_LAYOUT_CONFIG,
    font_config: FontConfig = DEFAULT_FONT_CONFIG,
    timestamp_unit: Timescale = Timescale.MS,
) -> go.Figure:
    """Plot groups of DataInstances as stacked subplots on a shared time axis.

    Each element of ``rows`` becomes one subplot row. Multiple DataInstances
    in the same row are overlaid on that row's y-axis, which is useful for
    comparing signals that share the same units or scale.

    Parameters
    ----------
    rows : List[List[DataInstance]]
        Outer list defines subplot rows (top to bottom); inner list defines
        the DataInstances overlaid within that row.
    title : str | None, optional
        Figure-level title. Default is None (no title).
    row_y_labels : List[str | None] | None, optional
        Label for each row, used both as the subplot title and as the y-axis
        title. Rows whose entry is ``None`` (or that have no entry because the
        list is shorter than ``rows``) fall back to auto-labelling from the
        DataInstance labels in that row. Default is None (all rows
        auto-label).
    show_legend : bool, optional
        Whether to show the figure legend. Default is True.
    layout_config : LayoutConfig, optional
        Dimensions, spacing, and style for the subplot grid.
    font_config : FontConfig, optional
        Font sizes for title, axis labels, tick labels, and legend.
    timestamp_unit : Timescale, optional
        Timestamp unit of the underlying DataInstances. Converted to seconds
        for x-axis display. Default is Timescale.MS.

    Returns
    -------
    go.Figure
        Plotly figure containing the stacked subplot grid. An empty figure is
        returned (after printing a warning) when ``rows`` is empty.

    Examples
    --------
    >>> fig = data_instance_subplots(
    ...     rows=[[speed_di], [torque_di, motor_di]],
    ...     title="Run Overview",
    ...     row_y_labels=["Speed (mph)", "Torque / Motor"],
    ... )
    >>> fig.show()
    """
    if not rows:
        print("Warning: No rows provided for subplot figure")
        return go.Figure()

    n = len(rows)

    # Resolve every row's label exactly once so the subplot title and the
    # y-axis title can never disagree. An explicit entry wins; a missing or
    # None entry falls back to the comma-joined DataInstance labels.
    row_labels: List[str | None] = []
    for i, row_dis in enumerate(rows):
        explicit = (
            row_y_labels[i] if row_y_labels and i < len(row_y_labels) else None
        )
        if explicit is not None:
            row_labels.append(explicit)
            continue
        auto_labels = [di.label for di in row_dis if di.label]
        row_labels.append(", ".join(auto_labels) if auto_labels else None)

    fig = make_subplots(
        rows=n,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=layout_config.grid_vertical_spacing,
        # Subplot titles use "" rather than None for unlabeled rows.
        subplot_titles=[label or "" for label in row_labels],
    )

    for row_idx, row_dis in enumerate(rows, start=1):
        for di in row_dis:
            if len(di) == 0:
                print(f"Warning: No data points in DataInstance for {di.label}")
                continue
            timestamps_s = _to_seconds(di.timestamp_np.astype(float), timestamp_unit)
            fig.add_trace(
                go.Scattergl(
                    x=timestamps_s,
                    y=di.value_np,
                    mode="lines",
                    name=di.label,
                    # Traces in the same subplot row share one legend group.
                    legendgroup=str(row_idx),
                ),
                row=row_idx,
                col=1,
            )

        fig.update_yaxes(
            title_text=row_labels[row_idx - 1],
            title_font=dict(size=font_config.medium),
            tickfont=dict(size=font_config.small),
            row=row_idx,
            col=1,
        )

    # Only the bottom row carries the x-axis title; the x-axes are shared.
    fig.update_xaxes(
        title_text="Time (s)",
        title_font=dict(size=font_config.medium),
        tickfont=dict(size=font_config.small),
        row=n,
        col=1,
    )

    fig.update_layout(
        title=dict(
            text=title,
            x=0.5,
            xanchor="center",
            yanchor="top",
            font=dict(size=font_config.large),
        ),
        height=layout_config.grid_height_per_row * n,
        width=layout_config.width,
        margin=layout_config.margin,
        plot_bgcolor=layout_config.plot_bgcolor,
        showlegend=show_legend,
        hovermode="x unified",
        legend=dict(font=dict(size=font_config.small)),
    )
    return fig
def stride_downsample(
    timestamps_s: npt.NDArray[np.float64],
    values: npt.NDArray[np.float64],
    max_display_resolution: float,
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
    """Thin out a series by keeping every k-th sample.

    The native sample rate is estimated as ``len(timestamps_s)`` divided by
    the span between the first and last timestamp. When that rate exceeds
    *max_display_resolution*, the stride ``k`` is the rounded ratio of the two
    rates (never less than 1) and samples ``0, k, 2k, ...`` are kept.

    Parameters
    ----------
    timestamps_s : np.ndarray
        Time values in seconds.
    values : np.ndarray
        Corresponding sample values.
    max_display_resolution : float
        Target display sample rate, in samples per second.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Downsampled ``(timestamps_s, values)``. The input arrays themselves
        are returned when there are fewer than two samples, when the time span
        is not positive, or when the native rate is already at or below the
        target (a notice is printed in that last case).
    """
    untouched = (timestamps_s, values)
    sample_count = len(timestamps_s)

    # Zero or one sample: there is no rate to measure.
    if sample_count <= 1:
        return untouched

    span_s = float(timestamps_s[-1] - timestamps_s[0])
    # A non-positive span would make the rate estimate meaningless.
    if span_s <= 0:
        return untouched

    native_hz = sample_count / span_s
    if native_hz <= max_display_resolution:
        print(
            "Data resolution is below the max display resolution, skipping downsampling."
        )
        return untouched

    stride = int(round(native_hz / max_display_resolution))
    if stride < 1:
        stride = 1

    # Index with an integer array so the results are fresh arrays.
    keep = np.arange(0, sample_count, stride)
    return timestamps_s[keep], values[keep]
def plot_multi_log_subplots(
    logs: List[SingleRunData],
    var_names: List[str],
    row_y_labels: List[str | None] | None = None,
    title: str | None = None,
    layout_config: LayoutConfig = DEFAULT_LAYOUT_CONFIG,
    font_config: FontConfig = DEFAULT_FONT_CONFIG,
    timestamp_unit: Timescale = Timescale.MS,
) -> go.Figure:
    """Plot the same variables across multiple logs as stacked subplots.

    Each variable gets its own subplot row; every log that contains data for
    that variable contributes one trace to the row.

    Parameters
    ----------
    logs : List[SingleRunData]
        Logs to compare. Each is looked up by variable name.
    var_names : List[str]
        Ordered list of variable names to plot, one per subplot row. A
        variable that is absent from a log, or has no data points in it, is
        silently skipped for that log.
    row_y_labels : List[str | None] | None, optional
        Y-axis label for each row. Must match the length of ``var_names``;
        a list of any other length is ignored with a printed warning.
    title : str | None, optional
        Figure-level title. Default is None (no title).
    layout_config : LayoutConfig, optional
        Dimensions, spacing, and style for the subplot grid. When its
        ``max_display_resolution`` is set (truthy), every trace is
        stride-downsampled to roughly that sample rate before plotting.
    font_config : FontConfig, optional
        Font sizes for title, axis labels, tick labels, and legend.
    timestamp_unit : Timescale, optional
        Timestamp unit of the underlying DataInstances. Converted to seconds
        for x-axis display. Default is Timescale.MS.

    Returns
    -------
    go.Figure
        The combined figure with all subplots. An empty figure is returned
        (after printing a notice) when ``logs`` or ``var_names`` is empty.

    Examples
    --------
    >>> fig = plot_multi_log_subplots(
    ...     logs=[aly1.data, aly2.data],
    ...     var_names=[
    ...         "pcm.wheelSpeeds.frontLeft",
    ...         "ludwig.steeringWheel.angle",
    ...         "pcm.pedals.brakePressure.front",
    ...     ],
    ...     title="Speed Comparison",
    ... )
    >>> fig.show()
    """
    if not logs or not var_names:
        print("No logs or variable names provided.")
        return go.Figure()

    n_vars = len(var_names)
    fig = make_subplots(
        rows=n_vars,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=layout_config.grid_vertical_spacing,
        subplot_titles=var_names,
    )

    for var_idx, var_name in enumerate(var_names, start=1):
        # None (not 0) marks "no trace drawn yet for this row": 0 is itself a
        # valid log index, so using it as the sentinel would let a second log
        # also claim to be the first and produce a duplicate legend entry.
        first_valid_log_idx: int | None = None
        for log_idx, srd in enumerate(logs):
            if var_name not in srd:
                continue
            di = srd[var_name]
            if len(di) == 0:
                continue
            if first_valid_log_idx is None:
                first_valid_log_idx = log_idx

            timestamps_s = _to_seconds(di.timestamp_np.astype(float), timestamp_unit)
            values = di.value_np
            if layout_config.max_display_resolution:
                timestamps_s, values = stride_downsample(
                    timestamps_s, values, layout_config.max_display_resolution
                )

            fig.add_trace(
                go.Scattergl(
                    x=timestamps_s,
                    y=values,
                    mode="lines",
                    line=dict(width=1),
                    name=var_name,
                    legendgroup="Log " + str(log_idx),
                    # One legend entry per row: only the first log with data.
                    showlegend=(log_idx == first_valid_log_idx),
                ),
                row=var_idx,
                col=1,
            )

    if row_y_labels:
        if len(row_y_labels) == n_vars:
            for var_idx, y_label in enumerate(row_y_labels, start=1):
                fig.update_yaxes(
                    title_text=y_label,
                    title_font=dict(size=font_config.medium),
                    tickfont=dict(size=font_config.small),
                    row=var_idx,
                    col=1,
                )
        else:
            # Labels are still skipped on a length mismatch, but say so
            # instead of dropping them without a trace.
            print(
                f"Warning: row_y_labels has {len(row_y_labels)} entries but "
                f"{n_vars} variables were given; ignoring row_y_labels"
            )

    # Only the bottom row carries the x-axis title; the x-axes are shared.
    fig.update_xaxes(
        title_text="Time (s)",
        title_font=dict(size=font_config.medium),
        tickfont=dict(size=font_config.small),
        row=n_vars,
        col=1,
    )

    fig.update_layout(
        title=dict(
            text=title,
            x=0.5,
            xanchor="center",
            yanchor="top",
            font=dict(size=font_config.large),
        ),
        height=layout_config.grid_height_per_row * n_vars,
        width=layout_config.width,
        margin=layout_config.margin,
        plot_bgcolor=layout_config.plot_bgcolor,
        hovermode="x unified",
        legend=dict(font=dict(size=font_config.small)),
    )
    return fig