from typing import Dict, List, Optional

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio

from .color_themes import get_theme
from .plotting_constants import *
# Make "plotly_white" the default template for every figure created after this
# module is imported. This mutates a global plotly.io setting, so it also
# affects figures built elsewhere in the same process.
pio.templates.default = "plotly_white"
def _collect_points(
    data_by_variable: List[Dict],
    percent_steps: np.ndarray,
    plausible: bool,
):
    """
    Flatten the sweep points of every variable selected by plausibility.

    Parameters
    ----------
    data_by_variable : List[Dict]
        Same structure as in ``plot_bar_chart``: dicts with 'var_name',
        'points' and 'plausible'.
    percent_steps : numpy.ndarray
        Step values; indexed with each variable's plausibility mask, so it is
        expected to have the same length as each 'points' array.
    plausible : bool
        If True, keep points whose mask is True; otherwise keep the rest.

    Returns
    -------
    tuple of (list, list, list)
        ``(x, y, steps)``: the variable name repeated once per kept point,
        the kept point values, and the step value of each kept point.
    """
    steps_array = np.asarray(percent_steps)
    xs, ys, steps = [], [], []
    for var_data in data_by_variable:
        points = np.asarray(var_data["points"])
        # Coerce to bool so that ``~`` is a logical NOT. On a list it would
        # fail, and on a 0/1 int array it would produce -1/-2 indices.
        mask = np.asarray(var_data["plausible"], dtype=bool)
        if not plausible:
            mask = ~mask
        selected = points[mask]
        xs.extend([var_data["var_name"]] * len(selected))
        ys.extend(selected.tolist())
        steps.extend(steps_array[mask].tolist())
    return xs, ys, steps


def plot_bar_chart(
    data_by_variable: List[Dict],
    percent_steps: np.ndarray,
    baseline_value: float,
    title: str,
    y_label: str,
    theme: str = "chart_default",
    font_config: Optional[FontConfig] = None,
    layout_config: Optional[LayoutConfig] = None,
):
    """
    Plot each variable's sweep points as markers in one x-axis category.

    Despite the name, no bars are drawn. Every variable becomes a category on
    the x-axis, and each of its sweep points is a marker in that category.
    Plausible points are colored by their percent step, with a colorbar.
    Implausible points are drawn in a single dark color with a distinct
    symbol. A dashed trace marks ``baseline_value``.

    Parameters
    ----------
    data_by_variable : List[Dict]
        List of dicts with keys: 'var_name' (str), 'points' (array-like of
        numbers), 'plausible' (array-like of booleans, same length as
        'points').
    percent_steps : numpy.ndarray
        Step values used in the sweep; one per point of each variable.
        Formatted as percentages, so 0.1 is shown as 10%.
    baseline_value : float
        Value drawn as a horizontal dashed baseline.
    title : str
        Plot title.
    y_label : str
        Y-axis label.
    theme : str, optional
        Color theme name; its "colorscale" entry colors the plausible points.
    font_config : FontConfig, optional
        Font sizes; defaults to DEFAULT_FONT_CONFIG.
    layout_config : LayoutConfig, optional
        Layout settings; defaults to DEFAULT_LAYOUT_CONFIG. Its height is not
        used; the figure height is fixed at 800 px.

    Returns
    -------
    go.Figure
        Plotly figure object.
    """
    font_config = font_config or DEFAULT_FONT_CONFIG
    layout_config = layout_config or DEFAULT_LAYOUT_CONFIG
    colorscale = get_theme(theme)["colorscale"]

    fig = go.Figure()

    # Unique variable names in first-seen order; the baseline spans these.
    var_names = list(
        dict.fromkeys(var_data["var_name"] for var_data in data_by_variable)
    )
    # NOTE(review): a "lines"-mode trace needs at least two x values, so with a
    # single variable this baseline is not rendered; consider fig.add_hline.
    fig.add_trace(
        go.Scatter(
            x=var_names,
            y=[baseline_value] * len(var_names),
            mode="lines",
            line=dict(color=ZEROLINE_COLOR, width=LINE_WIDTH, dash="dash"),
            name="Baseline",
            showlegend=True,
        )
    )

    # All plausible points in one trace, colored by raw step value; the
    # colorbar's "+.0%" tickformat shows those values as signed percentages.
    plausible_x, plausible_y, plausible_steps = _collect_points(
        data_by_variable, percent_steps, plausible=True
    )
    fig.add_trace(
        go.Scatter(
            x=plausible_x,
            y=plausible_y,
            mode="markers",
            marker=dict(
                size=MARKER_SIZE_LARGE,
                color=plausible_steps,
                colorscale=colorscale,
                colorbar=dict(
                    title="Percent Change",
                    tickformat="+.0%",
                    y=0,
                    yanchor="bottom",
                    len=0.7,
                ),
            ),
            text=[f"Step: {step:.0%}" for step in plausible_steps],
            hovertemplate="%{x}<br>Points: %{y:.1f}<br>%{text}<extra></extra>",
            name="Plausible Points",
            # Same group as the legend-only trace below, so that clicking the
            # legend entry toggles these markers too.
            legendgroup="plausible",
            showlegend=False,
        )
    )
    # Legend-only entry: a single-color marker instead of a colorscaled one.
    fig.add_trace(
        go.Scatter(
            x=[None],
            y=[None],
            mode="markers",
            marker=dict(
                size=MARKER_SIZE_LARGE,
                color=TEXT_COLOR_DARK,
            ),
            name="Plausible Points",
            legendgroup="plausible",
            showlegend=True,
        )
    )

    # Implausible points get their own trace, added only when any exist.
    implausible_x, implausible_y, implausible_steps = _collect_points(
        data_by_variable, percent_steps, plausible=False
    )
    if implausible_x:
        fig.add_trace(
            go.Scatter(
                x=implausible_x,
                y=implausible_y,
                mode="markers",
                marker=dict(
                    size=MARKER_SIZE_LARGE,
                    color=TEXT_COLOR_DARK,
                    symbol=BAR_CHART_IMPLAUSIBLE_SYMBOL,
                ),
                text=[f"Step: {step:.0%}" for step in implausible_steps],
                hovertemplate="%{x}<br>Points: %{y:.1f}<br>%{text} (Implausible)<extra></extra>",
                name="Implausible Points",
            )
        )

    # Title, axis labels, legend and sizing.
    fig.update_layout(
        title={
            "text": title,
            "font": dict(size=font_config.large, color=TEXT_COLOR_DARK),
            "x": layout_config.title_x,
            "xanchor": layout_config.title_xanchor,
        },
        xaxis_title={
            "text": "Variable Names",
            "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
        },
        yaxis_title={
            "text": y_label,
            "font": dict(size=font_config.medium, color=TEXT_COLOR_DARK),
        },
        legend_title="Legend",
        hovermode=HOVER_MODE,
        plot_bgcolor=layout_config.plot_bgcolor,
        width=layout_config.width,
        height=800,  # this chart needs more vertical space than the default
        margin=dict(l=80, r=120, t=100, b=80),
    )

    # Subtle grid lines; a zero line on the y-axis.
    fig.update_xaxes(
        showgrid=True,
        gridwidth=GRID_WIDTH,
        gridcolor=GRID_COLOR,
        tickangle=BAR_CHART_XAXIS_TICKANGLE,
    )
    fig.update_yaxes(
        showgrid=True,
        gridwidth=GRID_WIDTH,
        gridcolor=GRID_COLOR,
        zeroline=True,
        zerolinewidth=ZEROLINE_WIDTH,
        zerolinecolor=ZEROLINE_COLOR,
    )
    return fig