# Source code for suboptimumg.plotting.generic_plot

from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio

from ..plotting.color_themes import get_theme
from .plotting_constants import *
from .utils import (
    DataType3D,
    _validate_data,
    prepare_smooth_data_2D,
    prepare_smooth_data_3D,
    prepare_smooth_data_3D_scatter,
)

# Set default plotly template for better aesthetics.
# NOTE: this assigns plotly's global default template at import time, so it
# applies to every figure created afterwards in the process (including figures
# built outside this module), not only to the plots defined below.
pio.templates.default = "plotly_white"


def _padded_range(lo, hi):
    """
    Return ``[lower, upper]`` widened outward by a factor of ``1 + RANGE_PADDING``.

    The factor is applied according to each bound's sign so the interval always
    grows. A positive lower bound is divided and a positive upper bound is
    multiplied, as before. A negative lower bound is multiplied and a negative
    upper bound is divided. Applying ``1 / padding`` to a negative lower bound
    (the old behaviour) pulled it toward zero and clipped data. A bound of
    exactly 0 stays at 0.
    """
    padding = 1 + RANGE_PADDING
    lower = lo / padding if lo > 0 else lo * padding
    upper = hi * padding if hi > 0 else hi / padding
    return [lower, upper]


def plot2D(
    x_list,
    y_list,
    title,
    x_axis,
    y_axis,
    h_line: Optional[ReferenceLineConfig] = None,
    v_line: Optional[ReferenceLineConfig] = None,
    fit_curve=False,
    show_points=False,
    subtitle=None,
    data_label=None,
    theme=None,
    font_config: Optional[FontConfig] = None,
    layout_config: Optional[LayoutConfig] = None,
    smoothing_config: Optional[SmoothingConfig] = None,
):
    """
    Generic 2D plotting utility.

    Draws ``y_list`` against ``x_list`` as a spline-shaped line. The data is
    optionally smoothed first, and raw markers and horizontal/vertical
    reference lines can be overlaid.

    Parameters
    ----------
    x_list : numpy.ndarray
        Numpy array of x coordinates
    y_list : numpy.ndarray
        Numpy array of y coordinates
    title : str
        Plot title
    x_axis : str
        X-axis label
    y_axis : str
        Y-axis label
    h_line : ReferenceLineConfig, optional
        ReferenceLineConfig for horizontal reference line (None = no line)
    v_line : ReferenceLineConfig, optional
        ReferenceLineConfig for vertical reference line (None = no line)
    fit_curve : bool, optional
        Whether to fit a curve to the data points
    show_points : bool, optional
        Whether to show individual data points
    subtitle : str, optional
        Optional subtitle
    data_label : str, optional
        Curve label used in legend
    theme : str, optional
        Color theme name, see plotting.color_theme
    font_config : FontConfig, optional
        FontConfig object for font settings
    layout_config : LayoutConfig, optional
        LayoutConfig object for layout settings
    smoothing_config : SmoothingConfig, optional
        SmoothingConfig object for smoothing settings

    Returns
    -------
    go.Figure
        Plotly figure object
    """
    font_config = font_config or DEFAULT_FONT_CONFIG
    layout_config = layout_config or DEFAULT_LAYOUT_CONFIG
    smoothing_config = smoothing_config or DEFAULT_SMOOTHING_CONFIG

    # A reference line is drawn only when it was supplied *and* has a value.
    has_h_line = h_line is not None and h_line.value is not None
    has_v_line = v_line is not None and v_line.value is not None

    if fit_curve:
        x_dense, y_dense = prepare_smooth_data_2D(
            x_list, y_list, smoothing_config=smoothing_config
        )
    else:
        x_dense, y_dense = x_list, y_list

    fig = go.Figure()
    theme_colors = get_theme(theme)

    # Main curve (smoothed when fit_curve is True, raw data otherwise)
    fig.add_trace(
        go.Scatter(
            x=x_dense,
            y=y_dense,
            mode="lines",
            line=dict(
                color=theme_colors["light"],
                width=LINE_WIDTH,
                shape="spline",
            ),
            name=data_label,
            hovertemplate=f"{x_axis}: %{{x:{FLOAT_PRECISION}}}<br>{y_axis}: %{{y:{FLOAT_PRECISION}}}<extra></extra>",
        )
    )

    if show_points:
        fig.add_trace(
            go.Scatter(
                x=x_list,
                y=y_list,
                mode="markers",
                marker=dict(color=theme_colors["dark"], size=MARKER_SIZE_LARGE),
                showlegend=False,
                hovertemplate=None,
            )
        )

    # Horizontal reference line spanning the curve's x extent
    if has_h_line:
        x_min, x_max = min(x_dense), max(x_dense)
        fig.add_trace(
            go.Scatter(
                x=[x_min, x_max],
                y=[h_line.value, h_line.value],
                mode="lines",
                line=dict(
                    color=h_line.color,
                    width=h_line.width,
                    dash=h_line.dash,
                ),
                name=f"{h_line.label}",
                hoverinfo="name",
            )
        )

    # Vertical reference line spanning the curve's y extent
    if has_v_line:
        y_min, y_max = min(y_dense), max(y_dense)
        fig.add_trace(
            go.Scatter(
                x=[v_line.value, v_line.value],
                y=[y_min, y_max],
                mode="lines",
                line=dict(
                    color=v_line.color,
                    width=v_line.width,
                    dash=v_line.dash,
                ),
                name=f"{v_line.label}",
                hoverinfo="name",
            )
        )

    # Title, Axes, Label, Legend, Size
    full_title = title
    if subtitle is not None:
        full_title += f"<br><span style='font-size: {font_config.medium}px; color: {TEXT_COLOR_LIGHT};'>{subtitle}</span>"

    fig.update_layout(
        title={
            "text": full_title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        xaxis_title={
            "text": x_axis,
            "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
        },
        yaxis_title={
            "text": y_axis,
            "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
        },
        legend_title="Legend",
        # Only show a legend when there is at least one named trace.
        showlegend=(data_label is not None) or has_h_line or has_v_line,
        legend_title_font=dict(size=font_config.medium),
        legend_font=dict(size=font_config.small),
        hovermode=HOVER_MODE,
        plot_bgcolor=layout_config.plot_bgcolor,
        width=layout_config.width,
        height=layout_config.height,
        margin=layout_config.margin,
    )

    # Add subtle grid lines
    fig.update_xaxes(
        showgrid=True,
        gridwidth=GRID_WIDTH,
        gridcolor=GRID_COLOR,
        zeroline=True,
        zerolinewidth=ZEROLINE_WIDTH,
        zerolinecolor=ZEROLINE_COLOR,
        tickfont=dict(size=font_config.small),
        tickformat=FLOAT_PRECISION,
    )
    fig.update_yaxes(
        showgrid=True,
        gridwidth=GRID_WIDTH,
        gridcolor=GRID_COLOR,
        zeroline=True,
        zerolinewidth=ZEROLINE_WIDTH,
        zerolinecolor=ZEROLINE_COLOR,
        tickfont=dict(size=font_config.small),
        tickformat=FLOAT_PRECISION,
    )

    # A reference line may lie outside the data. Set explicit axis ranges so
    # that both the data and the line remain visible.
    if has_h_line:
        y_bounds = [min(y_dense), max(y_dense), h_line.value]
        if show_points:
            # Raw markers are drawn as well; include them in case smoothing
            # does not pass through every point.
            y_bounds += [min(y_list), max(y_list)]
        fig.update_yaxes(range=_padded_range(min(y_bounds), max(y_bounds)))
    if has_v_line:
        x_bounds = [min(x_dense), max(x_dense), v_line.value]
        if show_points:
            x_bounds += [min(x_list), max(x_list)]
        fig.update_xaxes(range=_padded_range(min(x_bounds), max(x_bounds)))

    return fig
[docs] def plot3D_surface( x_list, y_list, z_list, title, x_axis, y_axis, z_axis, subtitle=None, theme=None, font_config: Optional[FontConfig] = None, layout_config: Optional[LayoutConfig] = None, colorbar_config: Optional[ColorbarConfig] = None, smoothing_config: Optional[SmoothingConfig] = None, scene_config: Optional[SceneConfig] = None, ): """ Generic 3D surface plotting utility. Handles both grid data (2D z array) and scatter data (1D x, y, z arrays). For grid data: Creates interpolated surface plot with contour lines. For scatter data: Creates 3D scatter plot with points colored by z values. Parameters ---------- x_list : numpy.ndarray X coordinates (1D for grid, 1D for scatter) y_list : numpy.ndarray Y coordinates (1D for grid, 1D for scatter) z_list : numpy.ndarray Z values (2D for grid, 1D for scatter) title : str Plot title x_axis : str X-axis label y_axis : str Y-axis label z_axis : str Z-axis label subtitle : str, optional Optional subtitle theme : str, optional Color theme name font_config : FontConfig, optional FontConfig object for font settings layout_config : LayoutConfig, optional LayoutConfig object for layout settings colorbar_config : ColorbarConfig, optional ColorbarConfig object for colorbar settings smoothing_config : SmoothingConfig, optional SmoothingConfig object for smoothing settings scene_config : SceneConfig, optional SceneConfig object for 3D scene settings Returns ------- go.Figure Plotly figure object """ font_config = font_config or DEFAULT_FONT_CONFIG layout_config = layout_config or DEFAULT_LAYOUT_CONFIG colorbar_config = colorbar_config or DEFAULT_COLORBAR_CONFIG smoothing_config = smoothing_config or DEFAULT_SMOOTHING_CONFIG scene_config = scene_config or DEFAULT_SCENE_CONFIG # Detect data type data_type = _validate_data(x_list, y_list, z_list) # Create figure and add trace fig = go.Figure() theme_colors = get_theme(theme) # Prepare smoothed/interpolated data based on input type match data_type: case DataType3D.GridInput: 
x_dense, y_dense, z_dense = prepare_smooth_data_3D( x_list, y_list, z_list, smoothing_config=smoothing_config ) case DataType3D.ScatterInput: x_dense, y_dense, z_dense = prepare_smooth_data_3D_scatter( x_list, y_list, z_list, smoothing_config=smoothing_config ) # Calculate contour size (use nanmin/nanmax to handle potential NaN values) z_min, z_max = np.nanmin(z_dense), np.nanmax(z_dense) contour_size = (z_max - z_min) / NUM_CONTOURS if z_max > z_min else 1 # Add surface plot fig.add_trace( go.Surface( x=np.array([x_dense] * len(x_dense)), y=np.array([y_dense] * len(x_dense)).T, z=z_dense, colorscale=theme_colors["colorscale"], contours={ "z": { "show": True, "start": z_min, "end": z_max, "size": contour_size, "width": LINE_WIDTH, "color": CONTOUR_LABEL_COLOR, } }, colorbar=dict( title=dict( text=z_axis, font=dict(size=font_config.medium), ), thickness=colorbar_config.thickness, len=colorbar_config.length, tickfont=dict(size=font_config.small), ), hovertemplate=( f"{x_axis}: %{{x:{FLOAT_PRECISION}}}<br>" f"{y_axis}: %{{y:{FLOAT_PRECISION}}}<br>" f"{z_axis}: %{{z:{FLOAT_PRECISION}}}<extra></extra>" ), ) ) # Overlay original scatter points if input was scatter data if data_type == DataType3D.ScatterInput: fig.add_trace( go.Scatter3d( x=x_list, y=y_list, z=z_list, mode="markers", marker=dict( color="black", size=MARKER_SIZE / 2, opacity=0.5, ), showlegend=False, hovertemplate=( f"Original point<br>" f"{x_axis}: %{{x:{FLOAT_PRECISION}}}<br>" f"{y_axis}: %{{y:{FLOAT_PRECISION}}}<br>" f"{z_axis}: %{{z:{FLOAT_PRECISION}}}<extra></extra>" ), ) ) # Configure layout full_title = title if subtitle is not None: full_title += f"<br><span style='font-size: {font_config.medium}px; color: {TEXT_COLOR_LIGHT};'>{subtitle}</span>" fig.update_layout( title={ "text": full_title, "font": dict(size=font_config.large, color=TEXT_COLOR_DARK), "x": layout_config.title_x, "xanchor": layout_config.title_xanchor, "yanchor": layout_config.title_yanchor, }, scene=dict( xaxis_title=x_axis, 
yaxis_title=y_axis, zaxis_title=z_axis, xaxis=dict( gridcolor=GRID_COLOR, showbackground=True, backgroundcolor=layout_config.scene_bgcolor, tickformat=FLOAT_PRECISION, ), yaxis=dict( gridcolor=GRID_COLOR, showbackground=True, backgroundcolor=layout_config.scene_bgcolor, tickformat=FLOAT_PRECISION, ), zaxis=dict( gridcolor=GRID_COLOR, showbackground=True, backgroundcolor=layout_config.scene_bgcolor, tickformat=FLOAT_PRECISION, ), aspectratio=dict( x=scene_config.aspect_ratio_x, y=scene_config.aspect_ratio_y, z=scene_config.aspect_ratio_z, ), ), width=layout_config.width, height=layout_config.height, margin=layout_config.margin, ) fig.update_layout( scene_camera=dict( center=dict(x=0, y=0, z=0), eye=dict( x=scene_config.camera_distance * np.cos(np.radians(scene_config.default_view_angle)), y=scene_config.camera_distance * np.sin(np.radians(scene_config.default_view_angle)), z=scene_config.camera_z, ), up=dict( x=scene_config.camera_up_x, y=scene_config.camera_up_y, z=scene_config.camera_up_z, ), ) ) return fig
[docs] def plot3D_contour( x_list, y_list, z_list, title, x_axis, y_axis, z_axis, subtitle=None, theme=None, font_config: Optional[FontConfig] = None, layout_config: Optional[LayoutConfig] = None, colorbar_config: Optional[ColorbarConfig] = None, smoothing_config: Optional[SmoothingConfig] = None, ): """ Generic 3D contour plotting utility. Handles both grid data (2D z array) and scatter data (1D x, y, z arrays). For grid data: Creates interpolated contour plot with contour lines. For scatter data: Creates scatter plot with points colored by z values. Parameters ---------- x_list : numpy.ndarray X coordinates (1D for grid, 1D for scatter) y_list : numpy.ndarray Y coordinates (1D for grid, 1D for scatter) z_list : numpy.ndarray Z values (2D for grid, 1D for scatter) title : str Plot title x_axis : str X-axis label y_axis : str Y-axis label z_axis : str Z-axis label subtitle : str, optional Optional subtitle theme : str, optional Color theme name font_config : FontConfig, optional FontConfig object for font settings layout_config : LayoutConfig, optional LayoutConfig object for layout settings colorbar_config : ColorbarConfig, optional ColorbarConfig object for colorbar settings smoothing_config : SmoothingConfig, optional SmoothingConfig object for smoothing settings Returns ------- go.Figure Plotly figure object """ if font_config is None: font_config = DEFAULT_FONT_CONFIG if layout_config is None: layout_config = DEFAULT_LAYOUT_CONFIG if colorbar_config is None: colorbar_config = DEFAULT_COLORBAR_CONFIG if smoothing_config is None: smoothing_config = DEFAULT_SMOOTHING_CONFIG # Detect data type data_type = _validate_data(x_list, y_list, z_list) theme_colors = get_theme(theme) fig = go.Figure() # Prepare smoothed/interpolated data based on input type match data_type: case DataType3D.GridInput: x_dense, y_dense, z_dense = prepare_smooth_data_3D( x_list, y_list, z_list, smoothing_config=smoothing_config ) case DataType3D.ScatterInput: x_dense, y_dense, z_dense = 
prepare_smooth_data_3D_scatter( x_list, y_list, z_list, smoothing_config=smoothing_config ) fig.add_trace( go.Scatter( x=x_list, y=y_list, mode="markers", marker=dict( color="black", size=MARKER_SIZE, opacity=0.5, symbol="x", ), showlegend=False, hovertemplate=( f"Original point<br>" f"{x_axis}: %{{x:{FLOAT_PRECISION}}}<br>" f"{y_axis}: %{{y:{FLOAT_PRECISION}}}<br>" f"{z_axis}: %{{customdata:{FLOAT_PRECISION}}}<extra></extra>" ), customdata=z_list, ) ) # Calculate contour size (use nanmin/nanmax to handle potential NaN values) z_min, z_max = np.nanmin(z_dense), np.nanmax(z_dense) contour_size = int((z_max - z_min) / NUM_CONTOURS) if z_max > z_min else 1 # Add contour plot fig.add_trace( go.Contour( x=x_dense, y=y_dense, z=z_dense, colorscale=theme_colors["colorscale"], contours=dict( showlabels=True, labelfont=dict( size=font_config.small, color=CONTOUR_LABEL_COLOR, ), start=int(z_min), end=int(z_max), size=contour_size, labelformat=FLOAT_PRECISION, ), colorbar=dict( title=dict( text=z_axis, font=dict(size=font_config.medium), ), thickness=colorbar_config.thickness, len=colorbar_config.length, tickfont=dict(size=font_config.small), tickformat=FLOAT_PRECISION, ), hovertemplate=( f"{x_axis}: %{{x:{FLOAT_PRECISION}}}<br>" f"{y_axis}: %{{y:{FLOAT_PRECISION}}}<br>" f"{z_axis}: %{{z:{FLOAT_PRECISION}}}<extra></extra>" ), ) ) full_title = title if subtitle is not None: full_title += f"<br><span style='font-size: {font_config.medium}px; color: {TEXT_COLOR_LIGHT};'>{subtitle}</span>" fig.update_layout( title={ "text": full_title, "font": dict(size=font_config.large, color=TEXT_COLOR_DARK), "x": layout_config.title_x, "xanchor": layout_config.title_xanchor, }, xaxis_title={ "text": x_axis, "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK), }, yaxis_title={ "text": y_axis, "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK), }, width=layout_config.width, height=layout_config.height, margin=layout_config.margin, plot_bgcolor=layout_config.plot_bgcolor, ) 
fig.update_xaxes( tickfont=dict(size=font_config.small), tickformat=FLOAT_PRECISION, ) fig.update_yaxes( tickfont=dict(size=font_config.small), tickformat=FLOAT_PRECISION, ) return fig