Source code for suboptimumg.plotting.utils

from enum import Enum
from typing import Iterable, Optional

import numpy as np
from plotly.colors import qualitative
from scipy.interpolate import RegularGridInterpolator, griddata, interp1d
from scipy.ndimage import gaussian_filter, gaussian_filter1d

from .plotting_constants import *


class DataType3D(Enum):
    """Layout of the ``z`` data passed to the 3D plotting helpers.

    Members
    -------
    GridInput
        ``z_list`` is a 2D array laid out over the ``x_list`` / ``y_list``
        grid, i.e. with shape ``(len(x_list), len(y_list))``.
    ScatterInput
        ``z_list`` is a 1D array holding one value per scattered ``(x, y)``
        point, so ``x_list``, ``y_list`` and ``z_list`` share one length.
    """

    GridInput = "GridInput"
    ScatterInput = "ScatterInput"


def prepare_smooth_data_2D(
    x_list,
    y_list,
    smoothing_config: "SmoothingConfig",
):
    """
    Helper function to prepare smoothed data for plots.

    When ``interp_factor > 1`` the curve is resampled onto a denser, evenly
    spaced x grid with ``interp1d``. A 1D Gaussian filter is then applied
    to the y values when ``smoothing_sigma > 0``.

    Parameters
    ----------
    x_list : array_like
        1D x coordinates (lists are accepted as well as numpy arrays).
    y_list : array_like
        1D y values, one per entry of ``x_list``.
    smoothing_config : SmoothingConfig
        Configuration object. This function reads ``interp_factor``,
        ``interp_method`` and ``smoothing_sigma``.

    Returns
    -------
    x_dense : numpy.ndarray
        When ``interp_factor > 1``, ``int(len(x_list) * interp_factor)``
        evenly spaced points spanning ``[min(x_list), max(x_list)]``.
        Otherwise ``x_list`` itself, as an array.
    y_dense : numpy.ndarray
        Interpolated and/or smoothed y values. Whenever smoothing is applied
        the result is floating point, even for integer input.
    """
    x_arr = np.asarray(x_list)
    y_arr = np.asarray(y_list)

    if smoothing_config.interp_factor > 1:
        # int() so that a fractional interp_factor still yields a valid point count.
        num_points = int(len(x_arr) * smoothing_config.interp_factor)
        x_dense = np.linspace(x_arr.min(), x_arr.max(), num_points)
        f = interp1d(
            x_arr,
            y_arr,
            kind=smoothing_config.interp_method,
            bounds_error=False,
            fill_value="extrapolate",
        )
        y_dense = f(x_dense)
    else:
        x_dense, y_dense = x_arr, y_arr

    if smoothing_config.smoothing_sigma > 0:
        # scipy.ndimage filters return an array of the input dtype, so integer
        # y values would be truncated after smoothing; promote to float first.
        if not np.issubdtype(y_dense.dtype, np.inexact):
            y_dense = y_dense.astype(float)
        y_dense = gaussian_filter1d(y_dense, sigma=smoothing_config.smoothing_sigma)
    return x_dense, y_dense


def validate_data(x_list, y_list, z_list) -> "DataType3D":
    """
    Determine if the input data is grid-based or scatter-based.

    Parameters
    ----------
    x_list : array_like
        X coordinates
    y_list : array_like
        Y coordinates
    z_list : array_like
        Z values. A 2D array (or nested sequence) means grid data with shape
        ``(len(x_list), len(y_list))``. A 1D sequence means scatter data with
        one value per ``(x, y)`` point.

    Returns
    -------
    DataType3D
        Enum indicating whether data is GridInput or ScatterInput

    Raises
    ------
    ValueError
        If data dimensions are invalid or mismatched
    """
    # np.shape works on any array_like (e.g. nested lists); ``z_list.shape``
    # only exists on ndarrays and raised AttributeError for list input.
    z_shape = np.shape(z_list)
    z_ndim = len(z_shape)
    if z_ndim == 2:
        # Grid data - axis 0 indexes x, axis 1 indexes y
        if z_shape != (len(x_list), len(y_list)):
            raise ValueError(
                f"Grid data dimension mismatch: z_list.shape {z_shape} "
                f"does not match (len(x_list), len(y_list)) = ({len(x_list)}, {len(y_list)})"
            )
        return DataType3D.GridInput
    elif z_ndim == 1:
        # Scatter data - validate all have same length
        if not (len(x_list) == len(y_list) == len(z_list)):
            raise ValueError(
                f"Scatter data length mismatch: len(x_list)={len(x_list)}, "
                f"len(y_list)={len(y_list)}, len(z_list)={len(z_list)}. "
                f"All must be equal for scatter data."
            )
        return DataType3D.ScatterInput
    else:
        raise ValueError(
            f"Invalid z_list dimensions: expected 1D or 2D array, got {z_ndim}D"
        )


def prepare_smooth_data_3D(
    x_list,
    y_list,
    z_list,
    smoothing_config: Optional[SmoothingConfig] = None,
):
    """
    Resample gridded 3D data onto a denser grid and optionally smooth it.

    The input must already be grid-based: ``z_list[i, j]`` is the value at
    ``(x_list[i], y_list[j])``.

    Parameters
    ----------
    x_list : numpy.ndarray
        1D x coordinates of the input grid.
    y_list : numpy.ndarray
        1D y coordinates of the input grid.
    z_list : numpy.ndarray
        2D values of shape ``(len(x_list), len(y_list))``.
    smoothing_config : SmoothingConfig, optional
        Reads ``interp_factor``, ``interp_method`` and ``smoothing_sigma``.
        ``DEFAULT_SMOOTHING_CONFIG`` is used when omitted.

    Returns
    -------
    x_interp : numpy.ndarray
        ``len(x_list) * interp_factor`` evenly spaced x values.
    y_interp : numpy.ndarray
        ``len(y_list) * interp_factor`` evenly spaced y values.
    z_interp : numpy.ndarray
        Interpolated (and optionally Gaussian-smoothed) values, transposed
        to shape ``(len(y_interp), len(x_interp))`` for plotting.
    """
    cfg = DEFAULT_SMOOTHING_CONFIG if smoothing_config is None else smoothing_config
    factor = cfg.interp_factor

    def _dense_axis(coords):
        # Evenly spaced samples covering the same range as ``coords``.
        return np.linspace(np.min(coords), np.max(coords), len(coords) * factor)

    x_interp = _dense_axis(x_list)
    y_interp = _dense_axis(y_list)
    xx, yy = np.meshgrid(x_interp, y_interp, indexing="ij")

    interpolator = RegularGridInterpolator(
        (x_list, y_list),
        z_list,
        method=cfg.interp_method,
        bounds_error=False,
    )
    query_points = np.column_stack((xx.ravel(), yy.ravel()))
    z_interp = interpolator(query_points).reshape(xx.shape)

    sigma = cfg.smoothing_sigma
    if sigma > 0:
        z_interp = gaussian_filter(z_interp, sigma=sigma)

    # Rows become y and columns become x, matching the meshgrid layout plotting expects.
    return x_interp, y_interp, z_interp.T


def prepare_smooth_data_3D_scatter(
    x_list,
    y_list,
    z_list,
    smoothing_config: Optional["SmoothingConfig"] = None,
):
    """
    Helper function to interpolate scattered 3D data onto a regular grid.

    Requires scatter-based input data (1D x, y, z arrays).

    Parameters
    ----------
    x_list : numpy.ndarray
        1D array of x coordinates
    y_list : numpy.ndarray
        1D array of y coordinates
    z_list : numpy.ndarray
        1D array of z values
    smoothing_config : SmoothingConfig, optional
        SmoothingConfig object (uses DEFAULT_SMOOTHING_CONFIG if None).
        This function reads ``interp_factor`` and ``smoothing_sigma``.

    Returns
    -------
    x_interp : numpy.ndarray
        Regular grid x coordinates
    y_interp : numpy.ndarray
        Regular grid y coordinates
    z_interp : numpy.ndarray
        Interpolated z values on the regular grid, transposed to shape
        ``(len(y_interp), len(x_interp))`` for surface/contour plots.
        Cells that ``griddata`` cannot fill are NaN and stay NaN after
        smoothing.
    """
    if smoothing_config is None:
        smoothing_config = DEFAULT_SMOOTHING_CONFIG

    # Grid resolution grows with sqrt(number of points) scaled by interp_factor,
    # with a floor of 50 cells per axis.
    num_points = len(x_list)
    grid_resolution = int(np.sqrt(num_points) * smoothing_config.interp_factor)
    grid_resolution = max(grid_resolution, 50)

    # Create regular grid for interpolation
    x_interp = np.linspace(np.min(x_list), np.max(x_list), grid_resolution)
    y_interp = np.linspace(np.min(y_list), np.max(y_list), grid_resolution)
    x_grid, y_grid = np.meshgrid(x_interp, y_interp, indexing="ij")

    # Stack scatter points for griddata
    points = np.column_stack([x_list, y_list])

    # Prefer cubic interpolation for smooth surfaces.
    try:
        z_interp = griddata(
            points, z_list, (x_grid, y_grid), method="cubic", fill_value=np.nan
        )
    except Exception:
        # Cubic can fail for certain point distributions; fall back to linear.
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit propagate.
        z_interp = griddata(
            points, z_list, (x_grid, y_grid), method="linear", fill_value=np.nan
        )

    # Optional smoothing with a NaN-aware 2D Gaussian filter
    if smoothing_config.smoothing_sigma > 0:
        z_interp = _nan_aware_gaussian_filter(
            z_interp, smoothing_config.smoothing_sigma
        )

    # Transpose z so it matches x/y meshgrid indexing for plotting
    return x_interp, y_interp, z_interp.T


def _nan_aware_gaussian_filter(values, sigma):
    """Apply a 2D Gaussian filter that ignores NaN cells and leaves them NaN.

    Uses normalized convolution: smooth the NaN-zeroed data and the validity
    mask separately, then divide. This keeps the smoothing spatial (neighbours
    in both grid directions) without letting NaNs spread. On a NaN-free grid
    the result equals ``gaussian_filter(values, sigma)``.
    """
    valid = ~np.isnan(values)
    if not valid.any():
        return values
    numerator = gaussian_filter(np.where(valid, values, 0.0), sigma=sigma)
    denominator = gaussian_filter(valid.astype(float), sigma=sigma)
    # Far from any valid cell the denominator can be 0; those cells are NaN
    # and are masked out below, so silence the 0/0 warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        smoothed = numerator / denominator
    return np.where(valid, smoothed, np.nan)


def color_cycle(n: int, palette: Optional[Iterable[str]] = None) -> Iterable[str]:
    """
    Get n colors, cycling through a qualitative palette.

    Parameters
    ----------
    n : int
        Number of colors needed. Values <= 0 yield an empty list.
    palette : Iterable[str], optional
        Colors to cycle through. Defaults to Plotly's qualitative palette
        (``plotly.colors.qualitative.Plotly``).

    Returns
    -------
    Iterable[str]
        List of ``n`` color strings. The palette repeats from the start when
        ``n`` exceeds its size.

    Raises
    ------
    ValueError
        If ``n > 0`` and ``palette`` is empty.
    """
    colors = list(qualitative.Plotly if palette is None else palette)
    # Negative n previously sliced from the end (palette[:n]); treat it as "no colors".
    if n <= 0:
        return []
    if not colors:
        raise ValueError("palette must contain at least one color")
    return [colors[i % len(colors)] for i in range(n)]