import numba
import numpy as np
import numpy.typing as npt
from plotly import graph_objects as go
from tqdm import tqdm
from ..constants import DELIMITER, title_block
from ..core_data_structures.data_instance import DataInstance
from ..core_data_structures.single_run_data import SingleRunData
from ..plotting.diff_plotter import plot_diff_bars
from ..plotting.plotting_constants import *
from ..units import _to_seconds
def _pct(num: float, den: float) -> float:
"""Return num/den as a percentage, or 0.0 if den is non-positive."""
return 0.0 if den <= 0 else (num / den) * 100.0
@numba.njit(cache=True)
def _get_diff_timestamps_core(
    ts_a: npt.NDArray[np.float64],
    va: npt.NDArray[np.float64],
    ts_b: npt.NDArray[np.float64],
    vb: npt.NDArray[np.float64],
    tol: np.float64,
    diff_rtol: float,
    diff_atol: float,
) -> tuple[
    numba.typed.List,
    numba.typed.List,
    numba.typed.List,
    numba.typed.List,
    int,
    int,
]:
    """JIT-compiled two-pointer core; returns typed lists and final (i, j) indices.

    Walks ``ts_a`` and ``ts_b`` (expected to be sorted ascending) in lock-step.
    A point whose timestamp is more than ``tol`` below the other series'
    current timestamp is recorded as an extra for its own series; otherwise the
    two current points are paired and their values compared with
    ``np.isclose(a, b, rtol=diff_rtol, atol=diff_atol, equal_nan=True)``
    semantics.

    The walk stops as soon as either series is exhausted; the leftover tails
    ``ts_a[i:]`` / ``ts_b[j:]`` are left for the caller to classify.

    ``va``/``vb`` must be at least as long as ``ts_a``/``ts_b``: Numba does not
    bounds-check array indexing by default.

    Returns
    -------
    tuple
        ``(rpi_extra, server_extra, diff_ts, matched_ts, i, j)``: four typed
        lists of float64 timestamps, plus the final indices into ``ts_a`` and
        ``ts_b``. Paired points are recorded under ``ts_a``'s timestamp.
    """
    rpi_extra = numba.typed.List.empty_list(numba.float64)
    server_extra = numba.typed.List.empty_list(numba.float64)
    diff_ts = numba.typed.List.empty_list(numba.float64)
    matched_ts = numba.typed.List.empty_list(numba.float64)
    i = 0
    j = 0
    n = len(ts_a)
    m = len(ts_b)
    while i < n and j < m:
        if ts_a[i] < ts_b[j] - tol:
            # Base point is too early to pair with anything left in ts_b.
            rpi_extra.append(ts_a[i])
            i += 1
            continue
        if ts_b[j] < ts_a[i] - tol:
            # Incoming point is too early to pair with anything left in ts_a.
            server_extra.append(ts_b[j])
            j += 1
            continue
        # Matched pair within timestamp tolerance -- inline np.isclose
        a_val = va[i]
        b_val = vb[j]
        a_nan = a_val != a_val  # noqa: PLR0124
        b_nan = b_val != b_val  # noqa: PLR0124
        if a_nan and b_nan:
            # equal_nan=True: two NaNs compare as close.
            values_close = True
        elif a_nan or b_nan:
            values_close = False
        elif np.isinf(a_val) or np.isinf(b_val):
            # np.isclose treats non-finite values as close only when exactly
            # equal. The tolerance formula alone is wrong here: inf - inf is
            # NaN (so inf vs inf would fail), and atol + rtol * inf is inf (so
            # any value would "match" an infinite b_val).
            values_close = a_val == b_val
        else:
            values_close = abs(a_val - b_val) <= diff_atol + diff_rtol * abs(b_val)
        if not values_close:
            diff_ts.append(ts_a[i])
        else:
            matched_ts.append(ts_a[i])
        i += 1
        j += 1
    return rpi_extra, server_extra, diff_ts, matched_ts, i, j
def _get_diff_timestamps(
base_ts: npt.NDArray[np.float64],
base_vals: npt.NDArray[np.float64],
incom_ts: npt.NDArray[np.float64],
incom_vals: npt.NDArray[np.float64],
timestamp_tolerance_s: float = 0.002,
diff_rtol: float = 1e-3,
diff_atol: float = 1e-3,
) -> tuple[
npt.NDArray[np.float64],
npt.NDArray[np.float64],
npt.NDArray[np.float64],
npt.NDArray[np.float64],
]:
"""Compare two time-series point-by-point and classify each point.
Uses a two-pointer walk over sorted timestamps. Points within
``timestamp_tolerance_s`` of each other are considered matched; their
values are then compared with ``diff_rtol``/``diff_atol``. Points with no
match in the other series are counted as extras.
Parameters
----------
base_ts:
Sorted float64 timestamp array for the base run (seconds).
base_vals:
float64 value array aligned with ``base_ts``.
incom_ts:
Sorted float64 timestamp array for the incoming run (seconds).
incom_vals:
float64 value array aligned with ``incom_ts``.
timestamp_tolerance_s:
Maximum timestamp difference (seconds) to treat two points as matching.
diff_rtol:
Relative tolerance for value comparison.
diff_atol:
Absolute tolerance for value comparison.
Returns
-------
tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]
(base_extra_ts, incom_extra_ts, value_mismatch_ts, matched_ts)
"""
ts_a = base_ts.astype(np.float64, copy=False)
va = base_vals.astype(np.float64, copy=False)
ts_b = incom_ts.astype(np.float64, copy=False)
vb = incom_vals.astype(np.float64, copy=False)
# Fast exact match shortcut
if np.array_equal(ts_a, ts_b) and np.array_equal(va, vb, equal_nan=True):
empty = np.array([], dtype=np.float64)
return empty, empty, empty, ts_a.copy()
tol = np.float64(max(timestamp_tolerance_s, 0.0))
rpi_extra_list, server_extra_list, diff_ts_list, matched_ts_list, i, j = (
_get_diff_timestamps_core(ts_a, va, ts_b, vb, tol, diff_rtol, diff_atol)
)
rpi_extra = np.array(rpi_extra_list, dtype=np.float64)
server_extra = np.array(server_extra_list, dtype=np.float64)
diff_ts = np.array(diff_ts_list, dtype=np.float64)
matched_ts = np.array(matched_ts_list, dtype=np.float64)
tail_a = ts_a[i:]
tail_b = ts_b[j:]
return (
np.concatenate([rpi_extra, tail_a]),
np.concatenate([server_extra, tail_b]),
diff_ts,
np.concatenate([matched_ts, tail_a, tail_b]),
)
def _concat_or_empty(
    arrays: list[npt.NDArray[np.float64]],
) -> npt.NDArray[np.float64]:
    """Concatenate ``arrays``, or return an empty float64 array if there are none."""
    if not arrays:
        return np.array([], dtype=np.float64)
    return np.concatenate(arrays)


def diff(
    rpi_data: SingleRunData,
    server_data: SingleRunData,
    timestamp_tolerance_s: float = 0.002,
    diff_rtol: float = 1e-3,
    diff_atol: float = 1e-3,
    diff_plot_config: DiffPlotConfig = DEFAULT_DIFF_PLOT_CONFIG,
    layout_config: LayoutConfig = DEFAULT_LAYOUT_CONFIG,
    font_config: FontConfig = DEFAULT_FONT_CONFIG,
) -> go.Figure:
    """Compare two SingleRunData objects and report differences.

    Performs a three-stage comparison:

    1. Variable-name alignment: reports C++ names present in one run but not
       the other.
    2. Point-level diff: for each common variable, classifies every data point
       as a base-only extra, incoming-only extra, value mismatch, or match.
    3. Summary + plot: prints a diff summary table and builds a Plotly bar
       chart of per-bucket diff counts (via ``plot_diff_bars``).

    Timestamps from each run are converted to seconds using the run's
    ``timestamp_unit`` metadata before comparison, so runs with different
    logging units are handled correctly.

    Side effects: prints the variable-name report and the summary table to
    stdout, and shows a tqdm progress bar while comparing variables.

    Parameters
    ----------
    rpi_data:
        The reference (baseline) run.
    server_data:
        The incoming run to compare against the baseline.
    timestamp_tolerance_s:
        Maximum timestamp delta (seconds) to consider two points as matching.
        Defaults to 0.002 (2 ms).
    diff_rtol:
        Relative tolerance for ``np.isclose``-style value comparison.
        Defaults to 1e-3.
    diff_atol:
        Absolute tolerance for ``np.isclose``-style value comparison.
        Defaults to 1e-3.
    diff_plot_config, layout_config, font_config:
        Plot configuration forwarded unchanged to ``plot_diff_bars``.

    Returns
    -------
    go.Figure
        The figure produced by ``plot_diff_bars``; it is returned, not shown.
    """
    base_cpp_name_to_id = rpi_data.cpp_name_to_id
    base_id_to_instance = rpi_data.id_to_instance
    server_cpp_name_to_id = server_data.cpp_name_to_id
    server_id_to_instance = server_data.id_to_instance

    # ===== Stage 1: Compare Variables =====
    in_rpi_not_in_server: list[str] = []
    # Shared C++ names -> (base instance, incoming instance)
    shared_cpp_name_to_instances: dict[str, tuple[DataInstance, DataInstance]] = {}
    for rpi_cpp_name, base_id in base_cpp_name_to_id.items():
        if rpi_cpp_name in server_cpp_name_to_id:
            shared_cpp_name_to_instances[rpi_cpp_name] = (
                base_id_to_instance[base_id],
                server_id_to_instance[server_cpp_name_to_id[rpi_cpp_name]],
            )
        else:
            in_rpi_not_in_server.append(rpi_cpp_name)
    in_server_not_in_rpi = [
        name for name in server_cpp_name_to_id if name not in base_cpp_name_to_id
    ]

    if in_rpi_not_in_server or in_server_not_in_rpi:
        n_mismatched = len(in_rpi_not_in_server) + len(in_server_not_in_rpi)
        print(title_block(f"Mismatched Variables: {n_mismatched}"))
        if in_rpi_not_in_server:
            print(f"  Only in RPI: {in_rpi_not_in_server}")
        if in_server_not_in_rpi:
            print(f"  Only in server: {in_server_not_in_rpi}")
        print()
    else:
        print(title_block("All C++ names match") + "\n")

    # ===== Stage 2: Compare DataInstances =====
    rpi_extra_ts_list: list[npt.NDArray[np.float64]] = []
    server_extra_ts_list: list[npt.NDArray[np.float64]] = []
    diff_ts_list: list[npt.NDArray[np.float64]] = []
    matched_ts_list: list[npt.NDArray[np.float64]] = []
    total_rpi_entries = 0
    total_server_entries = 0
    total_rpi_extra_entries = 0
    total_server_extra_entries = 0
    total_matched_entries = 0
    total_diff_entries = 0

    pbar = tqdm(
        shared_cpp_name_to_instances.values(),
        desc="Comparing matching variables",
        unit=" vars",
        total=len(shared_cpp_name_to_instances),
    )
    for rpi_di, server_di in pbar:
        rpi_ts_s = _to_seconds(
            rpi_di.timestamp_np.astype(np.float64), rpi_data.timestamp_unit
        )
        server_ts_s = _to_seconds(
            server_di.timestamp_np.astype(np.float64), server_data.timestamp_unit
        )
        rpi_extra_ts, server_extra_ts, var_diff_ts, var_matched_ts = (
            _get_diff_timestamps(
                rpi_ts_s,
                rpi_di.value_np,
                server_ts_s,
                server_di.value_np,
                timestamp_tolerance_s=timestamp_tolerance_s,
                diff_rtol=diff_rtol,
                diff_atol=diff_atol,
            )
        )
        rpi_extra_ts_list.append(rpi_extra_ts)
        server_extra_ts_list.append(server_extra_ts)
        diff_ts_list.append(var_diff_ts)
        matched_ts_list.append(var_matched_ts)

        total_rpi_entries += rpi_di.timestamp_np.size
        total_server_entries += server_di.timestamp_np.size
        total_rpi_extra_entries += rpi_extra_ts.size
        total_server_extra_entries += server_extra_ts.size
        total_matched_entries += var_matched_ts.size
        total_diff_entries += var_diff_ts.size
        # Running count of raw points examined from both runs so far.
        pbar.set_postfix({"timestamps": total_rpi_entries + total_server_entries})
    pbar.clear()
    pbar.close()

    rpi_extra_all = _concat_or_empty(rpi_extra_ts_list)
    server_extra_all = _concat_or_empty(server_extra_ts_list)
    diff_all = _concat_or_empty(diff_ts_list)
    total_present_all = _concat_or_empty(
        rpi_extra_ts_list + server_extra_ts_list + matched_ts_list + diff_ts_list
    )

    # ===== Stage 3: Summary + Plot =====
    rows = [
        ("Matched variables:", str(len(shared_cpp_name_to_instances))),
        ("Total RPI entries:", str(total_rpi_entries)),
        ("Total server entries:", str(total_server_entries)),
        ("Matched entries:", str(total_matched_entries)),
        ("RPI-only entries:", str(total_rpi_extra_entries)),
        ("Server-only entries:", str(total_server_extra_entries)),
        ("Value mismatch entries:", str(total_diff_entries)),
        (
            "RPI entries lost:",
            f"{_pct(total_rpi_extra_entries, total_rpi_entries):.3f}%",
        ),
        (
            "Server entries lost:",
            f"{_pct(total_server_extra_entries, total_server_entries):.3f}%",
        ),
    ]
    col_width = max(len(label) for label, _ in rows) + 2
    print(title_block("Diff Summary"))
    for label, value in rows:
        print(f"{label:<{col_width}} {value}")
    print(DELIMITER)

    fig = plot_diff_bars(
        base_extra_ts=rpi_extra_all,
        incom_extra_ts=server_extra_all,
        value_mismatch_ts=diff_all,
        total_present_ts=total_present_all,
        title="Diff Counts Over Time",
        diff_plot_config=diff_plot_config,
        layout_config=layout_config,
        font_config=font_config,
    )
    return fig