from enum import Enum
import numpy as np
from scipy.interpolate import RegularGridInterpolator, griddata, interp1d
from scipy.ndimage import gaussian_filter, gaussian_filter1d
from .plotting_constants import *
[docs]
class DataType3D(Enum):
    """
    Layout of a 3D (x, y, z) dataset, as classified by ``_validate_data``.

    Members
    -------
    GridInput
        ``z`` is a 2D array whose shape is ``(len(x), len(y))``, i.e. one value
        per (x, y) node of the grid spanned by the two coordinate axes.
    ScatterInput
        ``x``, ``y`` and ``z`` are 1D arrays of equal length, one entry per point.
    """

    # 2D z sampled on the x/y axes
    GridInput = "GridInput"
    # parallel 1D x, y, z arrays
    ScatterInput = "ScatterInput"
[docs]
def prepare_smooth_data_2D(
    x_list,
    y_list,
    smoothing_config: SmoothingConfig,
):
    """
    Helper function to prepare smoothed data for plots.

    Optionally resamples ``y`` onto a denser, evenly spaced ``x`` axis and then
    optionally applies a 1D Gaussian filter, as controlled by
    ``smoothing_config``.

    Parameters
    ----------
    x_list : array_like
        1D x coordinates (a numpy array or any sequence convertible to one)
    y_list : array_like
        1D y coordinates, one per entry of ``x_list``
    smoothing_config : SmoothingConfig
        SmoothingConfig object. Its ``interp_factor`` (> 1 enables
        resampling), ``interp_method`` (passed to ``interp1d`` as ``kind``)
        and ``smoothing_sigma`` (> 0 enables Gaussian smoothing) are read.

    Returns
    -------
    x_dense : numpy.ndarray
        ``int(len(x_list) * interp_factor)`` evenly spaced x values between
        min and max of ``x_list``, or ``x_list`` itself when
        ``interp_factor <= 1``
    y_dense : numpy.ndarray
        Smoothed/interpolated y coordinates. Smoothing is performed in
        floating point, so integer input is not truncated.
    """
    if smoothing_config.interp_factor > 1:
        x = np.asarray(x_list)
        # np.linspace requires an integer sample count; int() also lets a
        # fractional interp_factor (e.g. 1.5) work instead of raising TypeError.
        num_samples = int(len(x) * smoothing_config.interp_factor)
        x_dense = np.linspace(np.min(x), np.max(x), num_samples)
        f = interp1d(
            x,
            y_list,
            kind=smoothing_config.interp_method,
            bounds_error=False,
            fill_value="extrapolate",
        )
        y_dense = f(x_dense)
    else:
        x_dense, y_dense = x_list, y_list
    if smoothing_config.smoothing_sigma > 0:
        y_dense = np.asarray(y_dense)
        # scipy.ndimage filters return an array of the input dtype by default,
        # so integer data would be silently truncated after smoothing.
        if not np.issubdtype(y_dense.dtype, np.inexact):
            y_dense = y_dense.astype(float)
        y_dense = gaussian_filter1d(y_dense, sigma=smoothing_config.smoothing_sigma)
    return x_dense, y_dense
def _validate_data(x_list, y_list, z_list) -> DataType3D:
    """
    Determine if the input data is grid-based or scatter-based.

    Parameters
    ----------
    x_list : array_like
        X coordinates
    y_list : array_like
        Y coordinates
    z_list : array_like
        Z values (a numpy array or any nested sequence convertible to one).
        A 2D ``z_list`` is grid data and must have shape
        ``(len(x_list), len(y_list))``. A 1D ``z_list`` is scatter data and
        must have the same length as ``x_list`` and ``y_list``.

    Returns
    -------
    DataType3D
        ``DataType3D.GridInput`` for valid 2D ``z_list``,
        ``DataType3D.ScatterInput`` for valid 1D ``z_list``

    Raises
    ------
    ValueError
        If ``z_list`` is neither 1D nor 2D, or if its size does not match
        the coordinate arrays.
    """
    # np.shape works on any array_like (e.g. nested lists); the ``.shape``
    # attribute only exists on array objects.
    z_shape = np.shape(z_list)
    z_ndim = len(z_shape)
    if z_ndim == 2:
        # Grid data - validate dimensions
        if z_shape != (len(x_list), len(y_list)):
            raise ValueError(
                f"Grid data dimension mismatch: z_list.shape {z_shape} "
                f"does not match (len(x_list), len(y_list)) = ({len(x_list)}, {len(y_list)})"
            )
        return DataType3D.GridInput
    if z_ndim == 1:
        # Scatter data - validate all have same length
        if not (len(x_list) == len(y_list) == len(z_list)):
            raise ValueError(
                f"Scatter data length mismatch: len(x_list)={len(x_list)}, "
                f"len(y_list)={len(y_list)}, len(z_list)={len(z_list)}. "
                "All must be equal for scatter data."
            )
        return DataType3D.ScatterInput
    raise ValueError(
        f"Invalid z_list dimensions: expected 1D or 2D array, got {z_ndim}D"
    )
[docs]
def prepare_smooth_data_3D(
    x_list,
    y_list,
    z_list,
    smoothing_config: Optional[SmoothingConfig] = None,
):
    """
    Resample gridded z data onto a denser regular grid, optionally smoothing it.

    The input must already be grid data: ``z_list[i, j]`` is the value at
    ``(x_list[i], y_list[j])``.

    Parameters
    ----------
    x_list : numpy.ndarray
        1D coordinates of the input grid's x-axis
    y_list : numpy.ndarray
        1D coordinates of the input grid's y-axis
    z_list : numpy.ndarray
        2D values with shape (len(x_list), len(y_list))
    smoothing_config : SmoothingConfig, optional
        Supplies ``interp_factor``, ``interp_method`` (passed to
        ``RegularGridInterpolator``) and ``smoothing_sigma``;
        DEFAULT_SMOOTHING_CONFIG is used when None

    Returns
    -------
    x_interp : numpy.ndarray
        ``len(x_list) * interp_factor`` evenly spaced values from min to max
        of ``x_list``
    y_interp : numpy.ndarray
        ``len(y_list) * interp_factor`` evenly spaced values from min to max
        of ``y_list``
    z_interp : numpy.ndarray
        Interpolated (and, if ``smoothing_sigma > 0``, Gaussian-smoothed)
        values with shape (len(y_interp), len(x_interp)), matching the
        default ``np.meshgrid(x_interp, y_interp)`` layout used for plotting
    """
    cfg = DEFAULT_SMOOTHING_CONFIG if smoothing_config is None else smoothing_config
    factor = cfg.interp_factor

    def _dense_axis(axis):
        # ``factor`` times as many evenly spaced samples across the axis range.
        return np.linspace(np.min(axis), np.max(axis), len(axis) * factor)

    x_interp = _dense_axis(x_list)
    y_interp = _dense_axis(y_list)
    xx, yy = np.meshgrid(x_interp, y_interp, indexing="ij")

    # Query points outside the input grid yield NaN instead of raising.
    interpolator = RegularGridInterpolator(
        (x_list, y_list),
        z_list,
        method=cfg.interp_method,
        bounds_error=False,
    )
    query = np.column_stack((xx.ravel(), yy.ravel()))
    z_interp = interpolator(query).reshape(xx.shape)

    sigma = cfg.smoothing_sigma
    if sigma > 0:
        z_interp = gaussian_filter(z_interp, sigma=sigma)

    # Computed on an "ij" grid (rows follow x); transpose so rows follow y.
    return x_interp, y_interp, z_interp.T
[docs]
def _nan_gaussian_filter(values, sigma):
    """
    Gaussian-smooth a 2D array while treating NaN cells as missing data.

    Uses normalized convolution. NaN cells contribute zero weight, and each
    non-NaN cell is divided by the total Gaussian weight of its non-NaN
    neighbours, so cells next to NaN regions are not pulled toward zero.
    NaN cells stay NaN. If every cell is NaN, ``values`` is returned as-is.
    """
    valid = ~np.isnan(values)
    if not np.any(valid):
        return values
    filled = np.where(valid, values, 0.0)
    weighted_sum = gaussian_filter(filled, sigma=sigma)
    weight = gaussian_filter(valid.astype(float), sigma=sigma)
    # Only NaN cells can end up with zero weight; they are overwritten below.
    with np.errstate(divide="ignore", invalid="ignore"):
        smoothed = weighted_sum / weight
    smoothed[~valid] = np.nan
    return smoothed


def prepare_smooth_data_3D_scatter(
    x_list,
    y_list,
    z_list,
    smoothing_config: Optional[SmoothingConfig] = None,
    min_resolution: int = 50,
):
    """
    Helper function to interpolate scattered 3D data onto a regular grid.

    Requires scatter-based input data (1D x, y, z arrays).

    Parameters
    ----------
    x_list : numpy.ndarray
        1D array of x coordinates
    y_list : numpy.ndarray
        1D array of y coordinates
    z_list : numpy.ndarray
        1D array of z values
    smoothing_config : SmoothingConfig, optional
        SmoothingConfig object (uses DEFAULT_SMOOTHING_CONFIG if None). Its
        ``interp_factor`` scales the grid resolution, and its
        ``smoothing_sigma`` (> 0) enables Gaussian smoothing.
    min_resolution : int, optional
        Lower bound on the number of grid samples per axis (default 50)

    Returns
    -------
    x_interp : numpy.ndarray
        Regular grid x coordinates
    y_interp : numpy.ndarray
        Regular grid y coordinates
    z_interp : numpy.ndarray
        Interpolated z values on the regular grid, shaped
        (len(y_interp), len(x_interp)) for surface/contour plots. Grid nodes
        outside the convex hull of the input points are NaN.
    """
    if smoothing_config is None:
        smoothing_config = DEFAULT_SMOOTHING_CONFIG
    # Per-axis resolution grows with sqrt(number of points), scaled by
    # interp_factor, but never drops below min_resolution.
    num_points = len(x_list)
    grid_resolution = max(
        int(np.sqrt(num_points) * smoothing_config.interp_factor), min_resolution
    )
    # Create regular grid for interpolation
    x_interp = np.linspace(np.min(x_list), np.max(x_list), grid_resolution)
    y_interp = np.linspace(np.min(y_list), np.max(y_list), grid_resolution)
    x_grid, y_grid = np.meshgrid(x_interp, y_interp, indexing="ij")
    points = np.column_stack([x_list, y_list])
    # Cubic gives smoother surfaces. Fall back to linear if it fails for a
    # particular point distribution (scipy's QhullError subclasses
    # RuntimeError). The fallback does not catch KeyboardInterrupt/SystemExit.
    try:
        z_interp = griddata(
            points, z_list, (x_grid, y_grid), method="cubic", fill_value=np.nan
        )
    except (ValueError, RuntimeError):
        z_interp = griddata(
            points, z_list, (x_grid, y_grid), method="linear", fill_value=np.nan
        )
    # Smooth in 2D while keeping the NaN region (outside the hull) intact.
    if smoothing_config.smoothing_sigma > 0:
        z_interp = _nan_gaussian_filter(z_interp, smoothing_config.smoothing_sigma)
    # Transpose z so it matches x/y meshgrid indexing for plotting
    return x_interp, y_interp, z_interp.T