Calibration and Uncertainty

Converted from theory.ipynb for web reading.

Calibration asks whether confidence matches correctness; uncertainty methods decide when a model should answer, abstain, or return a set of plausible answers.
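
One standard way to state the condition formally: among all answers to which the model assigns confidence p, a fraction p should be correct, i.e. P(correct | stated confidence = p) = p for every confidence level p the model actually emits.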

This notebook is the executable companion to notes.md. It uses synthetic data throughout, so every computation runs anywhere without external files or trained models.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Code cell 3

import math

COLORS = {
    "primary":   "#0077BB",
    "secondary": "#EE7733",
    "tertiary":  "#009988",
    "error":     "#CC3311",
    "neutral":   "#555555",
    "highlight": "#EE3377",
}

def header(title):
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)

def check_true(condition, message):
    print(f"{'PASS' if bool(condition) else 'FAIL'} - {message}")
    assert bool(condition)

def check_close(actual, expected, tol=1e-8, message="values close"):
    ok = abs(actual - expected) <= tol
    print(f"{'PASS' if ok else 'FAIL'} - {message}: actual={actual:.6f}, expected={expected:.6f}")
    assert ok

def bootstrap_mean_ci(values, B=1000, alpha=0.05):
    values = np.asarray(values, dtype=float)
    idx = np.random.randint(0, len(values), size=(B, len(values)))
    boot = values[idx].mean(axis=1)
    lo, hi = np.quantile(boot, [alpha / 2, 1 - alpha / 2])
    return float(values.mean()), float(lo), float(hi)

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

print("Evaluation helper functions loaded.")

Demo 1: Confidence should match correctness

This cell simulates an overconfident classifier, draws its reliability diagram with ten equal-width bins, and accumulates the expected calibration error (ECE) as the bin-weighted gap between mean confidence and empirical accuracy. The curve sits below the diagonal: confidence outruns accuracy almost everywhere.

Code cell 5

header("Demo 1 - Confidence should match correctness: reliability diagram and ECE")
n = 1200
# Synthetic overconfident model: stated confidence runs ~0.12 above true accuracy.
confidence = np.random.beta(4, 2, size=n)
true_prob = np.clip(confidence - 0.12 + np.random.normal(0, 0.03, size=n), 0.01, 0.99)
correct = (np.random.rand(n) < true_prob).astype(float)
# ECE over ten equal-width bins: weight each bin's |accuracy - confidence| gap
# by the fraction of examples it holds.
bins = np.linspace(0, 1, 11)
ece = 0.0
centers, accs, confs = [], [], []
for lo, hi in zip(bins[:-1], bins[1:]):
    mask = (confidence >= lo) & (confidence < hi)
    if mask.any():
        acc = correct[mask].mean()
        conf = confidence[mask].mean()
        weight = mask.mean()
        ece += weight * abs(acc - conf)
        centers.append((lo + hi) / 2)
        accs.append(acc)
        confs.append(conf)
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], color=COLORS["neutral"], linestyle="--", label="perfect calibration")
ax.plot(confs, accs, marker="o", color=COLORS["primary"], label="model")
ax.set_title("Reliability diagram")
ax.set_xlabel("Mean confidence")
ax.set_ylabel("Empirical accuracy")
ax.legend()
fig.tight_layout()
plt.show()
print(f"ECE={ece:.4f}")
check_true(ece >= 0, "ECE is nonnegative")

Demo 2: High accuracy can still be unsafe

This cell sweeps a temperature T applied to synthetic logits and picks the value that minimises validation NLL. Because dividing logits by T preserves the argmax, accuracy is untouched while confidence changes, which is exactly how a model can stay accurate yet unsafe to act on; the sketch after the cell makes that gap concrete.

Code cell 7

header("Demo 2 - High accuracy can still be unsafe: temperature scaling")
logits = np.random.normal(size=(500, 4)) * 2.0
# Labels come from noisier logits, so the raw logits are overconfident about them.
labels = np.argmax(logits + np.random.normal(scale=1.2, size=logits.shape), axis=1)
def nll_for_T(T):
    # Numerically stable log-softmax of logits / T, evaluated at the observed labels.
    z = logits / T
    z = z - z.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()
Ts = np.linspace(0.5, 3.0, 31)
nlls = np.array([nll_for_T(T) for T in Ts])
best = Ts[nlls.argmin()]
fig, ax = plt.subplots()
ax.plot(Ts, nlls, color=COLORS["primary"], label="validation NLL")
ax.axvline(best, color=COLORS["secondary"], linestyle="--", label=f"best T={best:.2f}")
ax.set_title("Temperature scaling validation objective")
ax.set_xlabel("Temperature T")
ax.set_ylabel("NLL")
ax.legend()
fig.tight_layout()
plt.show()
print(f"best temperature={best:.3f}, best NLL={nlls.min():.3f}")
check_true(0.5 <= best <= 3.0, "best temperature is inside search range")
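
A minimal sketch of the headline claim, on assumed synthetic data: two predictors with identical 90% accuracy, one honest and one overconfident, separate sharply under NLL, and the overconfident one concentrates its errors at high confidence.

correct = np.random.rand(10_000) < 0.90   # both predictors are right 90% of the time
def binary_nll(p):
    # NLL when the model reports a constant confidence p in its chosen answer.
    return -np.mean(np.where(correct, np.log(p), np.log(1 - p)))
for name, p in [("honest (p=0.90)", 0.90), ("overconfident (p=0.999)", 0.999)]:
    print(f"{name}: accuracy={correct.mean():.3f}, NLL={binary_nll(p):.3f}")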

Demo 3: Selective prediction and abstention

This cell runs split conformal calibration: the ceil((n+1)(1-alpha))/n empirical quantile of calibration scores becomes a threshold, and a fresh test set confirms that about 1 - alpha of scores fall below it. Answers whose nonconformity score exceeds the threshold are the natural candidates for abstention.

Code cell 9

header("Demo 3 - Selective prediction and abstention: split conformal coverage")
n_cal = 500
n_test = 800
scores_cal = np.random.beta(2, 8, size=n_cal)
alpha = 0.1
# Split conformal threshold: the ceil((n+1)(1-alpha))/n empirical quantile of the
# calibration scores gives >= 1 - alpha coverage for exchangeable data.
q = np.quantile(scores_cal, np.ceil((n_cal + 1) * (1 - alpha)) / n_cal, method="higher")
scores_test = np.random.beta(2, 8, size=n_test)
covered = (scores_test <= q).mean()
print(f"threshold={q:.3f}, empirical coverage={covered:.3f}, target={1-alpha:.3f}")
check_true(covered > 0.85, "coverage is near the target in this simulation")

Demo 4: Epistemic and aleatoric uncertainty

This cell bootstraps the mean of a synthetic score sample into a 95% interval. That interval expresses epistemic uncertainty about the estimate, not the aleatoric noise in individual outcomes; the sketch after the cell separates the two with a toy ensemble.

Code cell 11

header("Demo 4 - Epistemic and aleatoric uncertainty: bootstrap uncertainty")
values = np.clip(np.random.normal(0.68, 0.18, size=500), 0, 1)
mean, lo, hi = bootstrap_mean_ci(values, B=800)
print(f"mean={mean:.3f}, bootstrap 95% CI=[{lo:.3f}, {hi:.3f}]")
check_true(hi > lo, "bootstrap interval has positive width")
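
A hedged sketch of the named decomposition, with an assumed toy ensemble standing in for independently trained models: total predictive entropy splits into the average per-member entropy (aleatoric) plus the Jensen gap (epistemic, the members' disagreement).

# Five assumed ensemble members, each reporting a probability for 200 binary cases.
members = np.clip(np.random.normal(0.7, 0.15, size=(5, 200)), 1e-3, 1 - 1e-3)

def bin_entropy(p):
    return -(p * np.log(p) + (1 - p) * np.log(1 - p))

total = bin_entropy(members.mean(axis=0))      # entropy of the averaged prediction
aleatoric = bin_entropy(members).mean(axis=0)  # average per-member entropy
epistemic = total - aleatoric                  # >= 0 because entropy is concave
print(f"mean total={total.mean():.3f}, aleatoric={aleatoric.mean():.3f}, "
      f"epistemic={epistemic.mean():.3f}")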

Demo 5: Why LLM verbal confidence is unreliable

This cell plots one evaluation metric across four system variants; each bar is a measured point estimate. Verbalized confidence offers no such grounding, and the hedged sketch after the cell shows stated confidence levels sitting well above empirical accuracy on synthetic data.

Code cell 13

header("Demo 5 - Why LLM verbal confidence is unreliable: metric comparison plot")
names = np.array(["base", "prompt", "retrieval", "tool"])
scores = np.array([0.62, 0.68, 0.73, 0.70])
fig, ax = plt.subplots()
bars = ax.bar(names, scores, color=[COLORS["primary"], COLORS["secondary"], COLORS["tertiary"], COLORS["highlight"]])
ax.set_title("Evaluation metric across system variants")
ax.set_xlabel("System variant")
ax.set_ylabel("Score")
ax.set_ylim(0, 1)
for bar, val in zip(bars, scores):
    ax.text(bar.get_x() + bar.get_width()/2, val + 0.02, f"{val:.2f}", ha="center")
fig.tight_layout()
plt.show()
print("plotted metric comparison for four variants")
check_true(scores.max() > scores.min(), "variants differ in measured score")
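
A hedged sketch of the title's failure, with assumed synthetic "verbalized" confidences: stated levels cluster on round numbers while actual accuracy sits systematically lower (the 0.15 overconfidence offset here is invented for illustration), so stated confidence cannot be read at face value.

stated = np.random.choice([0.7, 0.8, 0.9, 0.95, 0.99], size=4000,
                          p=[0.10, 0.20, 0.40, 0.20, 0.10])
hit = np.random.rand(4000) < np.clip(stated - 0.15, 0.01, 0.99)  # assumed offset
for s in [0.7, 0.8, 0.9, 0.95, 0.99]:
    m = stated == s
    print(f"stated {s:.2f} -> empirical accuracy {hit[m].mean():.3f} (n={m.sum()})")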

Demo 6: Calibration condition

This cell attaches a normal-approximation 95% interval to an accuracy estimate: any empirical check of the calibration condition is made on finite data and inherits this sampling error. A Wilson-interval variant, better behaved near 0 and 1, is sketched after the cell.

Code cell 15

header("Demo 6 - Calibration condition: finite-sample accuracy interval")
n = 600
y = (np.random.rand(n) < 0.74).astype(float)
mean = y.mean()
# Normal-approximation (Wald) interval: mean +/- 1.96 * sqrt(p(1 - p) / n).
se = np.sqrt(mean * (1 - mean) / n)
lo, hi = mean - 1.96 * se, mean + 1.96 * se
print(f"accuracy={mean:.3f}, 95% CI=[{lo:.3f}, {hi:.3f}], n={n}")
check_true(lo <= mean <= hi, "point estimate lies inside its interval")
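
The Wald interval above can misbehave when accuracy sits near 0 or 1 or n is small. A hedged sketch of the Wilson score interval, reusing y and n from the cell above; wilson_interval is a hand-rolled helper implementing the standard formula, not a library call.

def wilson_interval(k, n, z=1.96):
    # Wilson score interval for a binomial proportion; always stays inside [0, 1].
    p = k / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = (z / denom) * np.sqrt(p * (1 - p) / n + z**2 / (4 * n**2))
    return center - half, center + half

w_lo, w_hi = wilson_interval(y.sum(), n)
print(f"Wilson 95% CI=[{w_lo:.3f}, {w_hi:.3f}]")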

Demo 7: Reliability function

This cell reports per-slice accuracy with 1.96-standard-error bars. The title's reliability function, accuracy as a function of confidence, is plotted in Demo 1; here the same estimate-plus-interval discipline is applied across task slices, where aggregate numbers can hide weak subpopulations.

Code cell 17

header("Demo 7 - Reliability function: slice metrics")
slices = np.array(["short", "long", "code", "math", "multilingual"])
acc = np.array([0.82, 0.71, 0.66, 0.58, 0.62])
n = np.array([400, 260, 180, 160, 140])
se = np.sqrt(acc * (1 - acc) / n)
for name, a, e in zip(slices, acc, se):
    print(f"slice={name:12s} accuracy={a:.3f} +/- {1.96*e:.3f}")
check_true(acc.min() < acc.max(), "slices reveal heterogeneous performance")

Demo 8: Expected and maximum calibration error

This cell measures tail risk with CVaR, the mean loss beyond the 90th percentile. The title's pair makes the same move for calibration: ECE averages the per-bin gaps while MCE takes the worst one; the sketch after the cell computes both on fresh synthetic confidences.

Code cell 19

header("Demo 8 - Expected and maximum calibration error: tail risk via CVaR")
losses = np.random.lognormal(mean=-2.0, sigma=0.8, size=1200)
alpha = 0.9
threshold = np.quantile(losses, alpha)
# CVaR at level 0.9: the mean loss over the worst 10% of cases.
cvar = losses[losses >= threshold].mean()
print(f"mean loss={losses.mean():.4f}, 90% tail threshold={threshold:.4f}, CVaR={cvar:.4f}")
check_true(cvar >= losses.mean(), "tail risk exceeds average risk")
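
A hedged sketch of the title's pair on fresh synthetic data: ECE is the bin-weighted average gap, MCE the largest single-bin gap.

conf = np.random.beta(4, 2, size=2000)
corr = (np.random.rand(2000) < np.clip(conf - 0.12, 0.01, 0.99)).astype(float)
edges = np.linspace(0, 1, 11)
gaps, weights = [], []
for b_lo, b_hi in zip(edges[:-1], edges[1:]):
    m = (conf >= b_lo) & (conf < b_hi)
    if m.any():
        gaps.append(abs(corr[m].mean() - conf[m].mean()))
        weights.append(m.mean())
print(f"ECE={np.dot(weights, gaps):.4f}, MCE={max(gaps):.4f}")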

Demo 9: Brier score and negative log-likelihood

This cell reruns the reliability-diagram experiment from Demo 1. The diagram diagnoses the same miscalibration that the title's proper scores penalise; the sketch after the cell computes the Brier score and NLL on these very predictions.

Code cell 21

header("Demo 9 - Brier score and negative log-likelihood: reliability diagram and ECE")
n = 1200
confidence = np.random.beta(4, 2, size=n)
true_prob = np.clip(confidence - 0.12 + np.random.normal(0, 0.03, size=n), 0.01, 0.99)
correct = (np.random.rand(n) < true_prob).astype(float)
bins = np.linspace(0, 1, 11)
ece = 0.0
centers, accs, confs = [], [], []
for lo, hi in zip(bins[:-1], bins[1:]):
    mask = (confidence >= lo) & (confidence < hi)
    if mask.any():
        acc = correct[mask].mean()
        conf = confidence[mask].mean()
        weight = mask.mean()
        ece += weight * abs(acc - conf)
        centers.append((lo + hi) / 2)
        accs.append(acc)
        confs.append(conf)
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], color=COLORS["neutral"], linestyle="--", label="perfect calibration")
ax.plot(confs, accs, marker="o", color=COLORS["primary"], label="model")
ax.set_title("Reliability diagram")
ax.set_xlabel("Mean confidence")
ax.set_ylabel("Empirical accuracy")
ax.legend()
fig.tight_layout()
plt.show()
print(f"ECE={ece:.4f}")
check_true(ece >= 0, "ECE is nonnegative")
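
A hedged sketch computing the two scores the title names, reusing confidence and correct from the cell above. Both scores are proper, so the ~0.12 overconfidence costs measurable score.

p = np.clip(confidence, 1e-12, 1 - 1e-12)       # clip to keep the log finite
brier = np.mean((p - correct) ** 2)
nll = -np.mean(correct * np.log(p) + (1 - correct) * np.log(1 - p))
print(f"Brier={brier:.4f}, NLL={nll:.4f}")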

Demo 10: Coverage, set size, and selective risk

This cell reruns the temperature-scaling sweep from Demo 2. Rescaled confidences change which answers clear any confidence threshold, and with them coverage and selective risk; the sketch after the cell traces a risk-coverage curve on synthetic data.

Code cell 23

header("Demo 10 - Coverage, set size, and selective risk: temperature scaling")
logits = np.random.normal(size=(500, 4)) * 2.0
labels = np.argmax(logits + np.random.normal(scale=1.2, size=logits.shape), axis=1)
def nll_for_T(T):
    z = logits / T
    z = z - z.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()
Ts = np.linspace(0.5, 3.0, 31)
nlls = np.array([nll_for_T(T) for T in Ts])
best = Ts[nlls.argmin()]
fig, ax = plt.subplots()
ax.plot(Ts, nlls, color=COLORS["primary"], label="validation NLL")
ax.axvline(best, color=COLORS["secondary"], linestyle="--", label=f"best T={best:.2f}")
ax.set_title("Temperature scaling validation objective")
ax.set_xlabel("Temperature T")
ax.set_ylabel("NLL")
ax.legend()
fig.tight_layout()
plt.show()
print(f"best temperature={best:.3f}, best NLL={nlls.min():.3f}")
check_true(0.5 <= best <= 3.0, "best temperature is inside search range")
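
A hedged sketch of the selective-risk side of the title, on assumed synthetic data: answer only the most confident fraction of queries, and risk falls as coverage shrinks.

conf = np.random.beta(4, 2, size=2000)
corr = (np.random.rand(2000) < np.clip(conf - 0.12, 0.01, 0.99)).astype(float)
order = np.argsort(-conf)                            # most confident first
cum_risk = 1 - np.cumsum(corr[order]) / np.arange(1, len(corr) + 1)
for cov in (0.25, 0.50, 1.00):
    k = int(cov * len(corr))
    print(f"coverage={cov:.0%}: selective risk={cum_risk[k - 1]:.3f}")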

Demo 11: Reliability diagrams

This cell reruns the split conformal check from Demo 3. A reliability diagram (Demo 1) audits calibration across all confidence levels; the conformal coverage check is its set-valued counterpart at a single target level.

Code cell 25

header("Demo 11 - Reliability diagrams: split conformal coverage")
n_cal = 500
n_test = 800
scores_cal = np.random.beta(2, 8, size=n_cal)
alpha = 0.1
q = np.quantile(scores_cal, np.ceil((n_cal + 1) * (1 - alpha)) / n_cal, method="higher")
scores_test = np.random.beta(2, 8, size=n_test)
covered = (scores_test <= q).mean()
print(f"threshold={q:.3f}, empirical coverage={covered:.3f}, target={1-alpha:.3f}")
check_true(covered > 0.85, "coverage is near the target in this simulation")

Demo 12: Confidence histograms

This cell reuses the bootstrap interval from Demo 4. The title's object, the raw histogram of confidences that every binned calibration metric summarises, is plotted in the sketch after the cell.

Code cell 27

header("Demo 12 - Confidence histograms: bootstrap uncertainty")
values = np.clip(np.random.normal(0.68, 0.18, size=500), 0, 1)
mean, lo, hi = bootstrap_mean_ci(values, B=800)
print(f"mean={mean:.3f}, bootstrap 95% CI=[{lo:.3f}, {hi:.3f}]")
check_true(hi > lo, "bootstrap interval has positive width")
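
A hedged sketch of the histogram the title names, on fresh synthetic confidences. Where this distribution is thin, binned metrics rest on very few examples.

conf = np.random.beta(4, 2, size=2000)
fig, ax = plt.subplots()
ax.hist(conf, bins=20, color=COLORS["primary"], alpha=0.85)
ax.set_title("Confidence histogram (synthetic)")
ax.set_xlabel("Confidence")
ax.set_ylabel("Count")
fig.tight_layout()
plt.show()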

Demo 13: Binning bias and adaptive bins

This cell reuses the metric comparison from Demo 5. The title's issue is separate: equal-width confidence bins can be nearly empty where confidence mass is thin, which biases ECE; the sketch after the cell contrasts equal-width with equal-mass (adaptive) bins.

Code cell 29

header("Demo 13 - Binning bias and adaptive bins: metric comparison plot")
names = np.array(["base", "prompt", "retrieval", "tool"])
scores = np.array([0.62, 0.68, 0.73, 0.70])
fig, ax = plt.subplots()
bars = ax.bar(names, scores, color=[COLORS["primary"], COLORS["secondary"], COLORS["tertiary"], COLORS["highlight"]])
ax.set_title("Evaluation metric across system variants")
ax.set_xlabel("System variant")
ax.set_ylabel("Score")
ax.set_ylim(0, 1)
for bar, val in zip(bars, scores):
    ax.text(bar.get_x() + bar.get_width()/2, val + 0.02, f"{val:.2f}", ha="center")
fig.tight_layout()
plt.show()
print("plotted metric comparison for four variants")
check_true(scores.max() > scores.min(), "variants differ in measured score")
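
A hedged sketch of the binning question in the title: the same predictions scored with equal-width bins and with equal-mass (quantile) bins, which put the same number of examples in each bin.

conf = np.random.beta(4, 2, size=2000)
corr = (np.random.rand(2000) < np.clip(conf - 0.12, 0.01, 0.99)).astype(float)

def ece_with_edges(edges):
    # Assign each confidence to a bin, then accumulate the weighted gaps.
    idx = np.clip(np.searchsorted(edges, conf, side="right") - 1, 0, len(edges) - 2)
    total = 0.0
    for b in range(len(edges) - 1):
        m = idx == b
        if m.any():
            total += m.mean() * abs(corr[m].mean() - conf[m].mean())
    return total

equal_width = np.linspace(0, 1, 11)
equal_mass = np.quantile(conf, np.linspace(0, 1, 11))
print(f"ECE equal-width={ece_with_edges(equal_width):.4f}, "
      f"equal-mass={ece_with_edges(equal_mass):.4f}")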

Demo 14: Bootstrap uncertainty for ECE

This cell reuses the accuracy interval from Demo 6. The same discipline applies to ECE, which is also an estimate with sampling error; the sketch after the cell bootstraps a 95% interval for it.

Code cell 31

header("Demo 14 - Bootstrap uncertainty for ECE: finite-sample accuracy interval")
n = 600
y = (np.random.rand(n) < 0.74).astype(float)
mean = y.mean()
se = np.sqrt(mean * (1 - mean) / n)
lo, hi = mean - 1.96 * se, mean + 1.96 * se
print(f"accuracy={mean:.3f}, 95% CI=[{lo:.3f}, {hi:.3f}], n={n}")
check_true(lo <= mean <= hi, "point estimate lies inside its interval")
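
A hedged sketch of the title's estimator: resample (confidence, correctness) pairs and recompute ECE to get a percentile interval, the same treatment the cell above gives accuracy.

def ece_of(conf, corr, n_bins=10):
    edges = np.linspace(0, 1, n_bins + 1)
    total = 0.0
    for b_lo, b_hi in zip(edges[:-1], edges[1:]):
        m = (conf >= b_lo) & (conf < b_hi)
        if m.any():
            total += m.mean() * abs(corr[m].mean() - conf[m].mean())
    return total

conf = np.random.beta(4, 2, size=1000)
corr = (np.random.rand(1000) < np.clip(conf - 0.12, 0.01, 0.99)).astype(float)
boot = [ece_of(conf[i], corr[i])
        for i in np.random.randint(0, 1000, size=(500, 1000))]
b_lo, b_hi = np.quantile(boot, [0.025, 0.975])
print(f"ECE={ece_of(conf, corr):.4f}, bootstrap 95% CI=[{b_lo:.4f}, {b_hi:.4f}]")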

Demo 15: Prompt and slice-level calibration

This cell reuses the slice metrics from Demo 7: calibration, like accuracy, should be audited per prompt family and per slice, since a model can be calibrated on average yet miscalibrated on the slices that matter.

Code cell 33

header("Demo 15 - Prompt and slice-level calibration: slice metrics")
slices = np.array(["short", "long", "code", "math", "multilingual"])
acc = np.array([0.82, 0.71, 0.66, 0.58, 0.62])
n = np.array([400, 260, 180, 160, 140])
se = np.sqrt(acc * (1 - acc) / n)
for name, a, e in zip(slices, acc, se):
    print(f"slice={name:12s} accuracy={a:.3f} +/- {1.96*e:.3f}")
check_true(acc.min() < acc.max(), "slices reveal heterogeneous performance")

Demo 16: Log score

This cell reuses the CVaR computation from Demo 8. The title's quantity, the pointwise log score whose average is the NLL used throughout this notebook, is spelled out in the sketch after the cell.

Code cell 35

header("Demo 16 - Log score: tail risk via CVaR")
losses = np.random.lognormal(mean=-2.0, sigma=0.8, size=1200)
alpha = 0.9
threshold = np.quantile(losses, alpha)
cvar = losses[losses >= threshold].mean()
print(f"mean loss={losses.mean():.4f}, 90% tail threshold={threshold:.4f}, CVaR={cvar:.4f}")
check_true(cvar >= losses.mean(), "tail risk exceeds average risk")
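
A hedged sketch of the pointwise log score on a few assumed synthetic forecasts; its mean over examples is exactly the NLL computed elsewhere in this notebook.

p = np.clip(np.random.beta(4, 2, size=6), 1e-12, 1 - 1e-12)  # forecast probabilities
y_out = (np.random.rand(6) < p).astype(float)                # binary outcomes
log_score = -(y_out * np.log(p) + (1 - y_out) * np.log(1 - p))  # per-example penalty
print("per-example:", np.round(log_score, 3), f"mean (NLL): {log_score.mean():.3f}")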

Demo 17: Brier score

This cell reruns the reliability-diagram experiment from Demo 1; the Brier score for exactly these synthetic predictions is computed in the sketch following Demo 9.

Code cell 37

header("Demo 17 - Brier score: reliability diagram and ECE")
n = 1200
confidence = np.random.beta(4, 2, size=n)
true_prob = np.clip(confidence - 0.12 + np.random.normal(0, 0.03, size=n), 0.01, 0.99)
correct = (np.random.rand(n) < true_prob).astype(float)
bins = np.linspace(0, 1, 11)
ece = 0.0
centers, accs, confs = [], [], []
for lo, hi in zip(bins[:-1], bins[1:]):
    mask = (confidence >= lo) & (confidence < hi)
    if mask.any():
        acc = correct[mask].mean()
        conf = confidence[mask].mean()
        weight = mask.mean()
        ece += weight * abs(acc - conf)
        centers.append((lo + hi) / 2)
        accs.append(acc)
        confs.append(conf)
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], color=COLORS["neutral"], linestyle="--", label="perfect calibration")
ax.plot(confs, accs, marker="o", color=COLORS["primary"], label="model")
ax.set_title("Reliability diagram")
ax.set_xlabel("Mean confidence")
ax.set_ylabel("Empirical accuracy")
ax.legend()
fig.tight_layout()
plt.show()
print(f"ECE={ece:.4f}")
check_true(ece >= 0, "ECE is nonnegative")

Demo 18: Sharpness versus calibration

This cell reruns the temperature-scaling sweep from Demo 2. Calibration alone is easy to achieve (always predicting the base rate is calibrated but uninformative); the sketch after the cell shows a sharper calibrated forecaster beating a blunt one on Brier score.

Code cell 39

header("Demo 18 - Sharpness versus calibration: temperature scaling")
logits = np.random.normal(size=(500, 4)) * 2.0
labels = np.argmax(logits + np.random.normal(scale=1.2, size=logits.shape), axis=1)
def nll_for_T(T):
    z = logits / T
    z = z - z.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()
Ts = np.linspace(0.5, 3.0, 31)
nlls = np.array([nll_for_T(T) for T in Ts])
best = Ts[nlls.argmin()]
fig, ax = plt.subplots()
ax.plot(Ts, nlls, color=COLORS["primary"], label="validation NLL")
ax.axvline(best, color=COLORS["secondary"], linestyle="--", label=f"best T={best:.2f}")
ax.set_title("Temperature scaling validation objective")
ax.set_xlabel("Temperature T")
ax.set_ylabel("NLL")
ax.legend()
fig.tight_layout()
plt.show()
print(f"best temperature={best:.3f}, best NLL={nlls.min():.3f}")
check_true(0.5 <= best <= 3.0, "best temperature is inside search range")
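
A hedged sketch of the trade-off in the title, on assumed synthetic outcomes: both forecasters are calibrated by construction (outcomes are drawn at the stated rates), but the sharper one, whose confidence sits far from the base rate, earns the better Brier score.

n = 4000
p_blunt = np.full(n, 0.7)                        # always predicts the base rate
p_sharp = np.random.choice([0.4, 0.95], size=n)  # committal, still calibrated below
for name, p in [("blunt", p_blunt), ("sharp", p_sharp)]:
    y_out = (np.random.rand(n) < p).astype(float)  # outcomes drawn at the stated rate
    print(f"{name}: sharpness(var)={p.var():.3f}, Brier={np.mean((p - y_out) ** 2):.4f}")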

Demo 19: Strict propriety

This cell reruns the split conformal check from Demo 3. The title's property, that an honest probability report uniquely minimises expected score, is verified numerically for the log score in the sketch after the cell.

Code cell 41

header("Demo 19 - Strict propriety: split conformal coverage")
n_cal = 500
n_test = 800
scores_cal = np.random.beta(2, 8, size=n_cal)
alpha = 0.1
q = np.quantile(scores_cal, np.ceil((n_cal + 1) * (1 - alpha)) / n_cal, method="higher")
scores_test = np.random.beta(2, 8, size=n_test)
covered = (scores_test <= q).mean()
print(f"threshold={q:.3f}, empirical coverage={covered:.3f}, target={1-alpha:.3f}")
check_true(covered > 0.85, "coverage is near the target in this simulation")
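
A hedged numeric check of the title's claim for the log score: with outcomes occurring at true rate q, the expected score E[-log p] over a grid of reported values p is minimised exactly at p = q.

q = 0.7
p_grid = np.linspace(0.01, 0.99, 99)
expected_score = -(q * np.log(p_grid) + (1 - q) * np.log(1 - p_grid))
print(f"argmin over grid: p={p_grid[expected_score.argmin()]:.2f} (true q={q})")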

Demo 20: Scoring rules for LLM outputs

This cell reuses the bootstrap interval from Demo 4. For LLM outputs the scored object is typically a token sequence; the sketch after the cell turns token log-probabilities into a sequence log score and a per-token perplexity.

Code cell 43

header("Demo 20 - Scoring rules for LLM outputs: bootstrap uncertainty")
values = np.clip(np.random.normal(0.68, 0.18, size=500), 0, 1)
mean, lo, hi = bootstrap_mean_ci(values, B=800)
print(f"mean={mean:.3f}, bootstrap 95% CI=[{lo:.3f}, {hi:.3f}]")
check_true(hi > lo, "bootstrap interval has positive width")
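
A hedged sketch of scoring a whole LLM output, with invented stand-in token probabilities in place of real model output: sum token log-probs for a sequence-level log score, and exponentiate the negative mean for a length-normalised perplexity.

token_probs = np.random.uniform(0.2, 0.9, size=12)  # stand-in for model token probs
token_logprobs = np.log(token_probs)
seq_logprob = token_logprobs.sum()                  # sequence-level log score
perplexity = np.exp(-token_logprobs.mean())         # length-normalised difficulty
print(f"sequence log-prob={seq_logprob:.2f}, per-token perplexity={perplexity:.2f}")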

Demo 21: Temperature scaling

This cell studies Temperature scaling through a small executable experiment. Focus on the estimator, the uncertainty statement, and the failure mode.

Code cell 45

header("Demo 21 - Temperature scaling: metric comparison plot")
names = np.array(["base", "prompt", "retrieval", "tool"])
scores = np.array([0.62, 0.68, 0.73, 0.70])
fig, ax = plt.subplots()
bars = ax.bar(names, scores, color=[COLORS["primary"], COLORS["secondary"], COLORS["tertiary"], COLORS["highlight"]])
ax.set_title("Evaluation metric across system variants")
ax.set_xlabel("System variant")
ax.set_ylabel("Score")
ax.set_ylim(0, 1)
for bar, val in zip(bars, scores):
    ax.text(bar.get_x() + bar.get_width()/2, val + 0.02, f"{val:.2f}", ha="center")
fig.tight_layout()
plt.show()
print("plotted metric comparison for four variants")
check_true(scores.max() > scores.min(), "variants differ in measured score")

Demo 22: Platt scaling

This cell studies Platt scaling through a small executable experiment. Focus on the estimator, the uncertainty statement, and the failure mode.

Code cell 47

header("Demo 22 - Platt scaling: finite-sample accuracy interval")
n = 600
y = (np.random.rand(n) < 0.74).astype(float)
mean = y.mean()
se = np.sqrt(mean * (1 - mean) / n)
lo, hi = mean - 1.96 * se, mean + 1.96 * se
print(f"accuracy={mean:.3f}, 95% CI=[{lo:.3f}, {hi:.3f}], n={n}")
check_true(lo <= mean <= hi, "point estimate lies inside its interval")
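
A hedged from-scratch sketch of the title's method: fit p = sigmoid(a*s + b) by gradient descent on the NLL (the sigmoid helper comes from the setup cell). The scores s, labels y_cal, and generating coefficients 0.8 and -0.3 are invented for illustration.

s = np.random.normal(0, 2, size=2000)                           # raw model scores
y_cal = (np.random.rand(2000) < sigmoid(0.8 * s - 0.3)).astype(float)
a, b = 1.0, 0.0
for _ in range(2000):                     # the objective is convex, so plain GD works
    p = sigmoid(a * s + b)
    a -= 0.1 * np.mean((p - y_cal) * s)   # dNLL/da
    b -= 0.1 * np.mean(p - y_cal)         # dNLL/db
print(f"fitted a={a:.2f}, b={b:.2f} (generating values: 0.8, -0.3)")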

Demo 23: Isotonic regression

This cell reuses the slice metrics from Demo 7. Isotonic regression, named in the title, replaces Platt's parametric map with a monotone nonparametric fit; a pool-adjacent-violators sketch follows the cell.

Code cell 49

header("Demo 23 - Isotonic regression: slice metrics")
slices = np.array(["short", "long", "code", "math", "multilingual"])
acc = np.array([0.82, 0.71, 0.66, 0.58, 0.62])
n = np.array([400, 260, 180, 160, 140])
se = np.sqrt(acc * (1 - acc) / n)
for name, a, e in zip(slices, acc, se):
    print(f"slice={name:12s} accuracy={a:.3f} +/- {1.96*e:.3f}")
check_true(acc.min() < acc.max(), "slices reveal heterogeneous performance")
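
A hedged sketch of the title's method on synthetic data: pool-adjacent-violators (PAVA) fits the best monotone curve of correctness against sorted confidence, the core of isotonic calibration. The pava function is a hand-rolled helper, not a library call.

def pava(y):
    # Pool-adjacent-violators: merge neighbouring blocks until block means are monotone.
    vals, wts = [], []
    for v in map(float, y):
        vals.append(v)
        wts.append(1.0)
        while len(vals) > 1 and vals[-2] > vals[-1]:
            w = wts[-2] + wts[-1]
            vals[-2:] = [(vals[-2] * wts[-2] + vals[-1] * wts[-1]) / w]
            wts[-2:] = [w]
    return np.repeat(vals, np.array(wts, dtype=int))

conf_sorted = np.sort(np.random.beta(4, 2, size=300))
hit = (np.random.rand(300) < np.clip(conf_sorted - 0.12, 0.01, 0.99)).astype(float)
fitted = pava(hit)   # monotone estimate of accuracy as a function of confidence
check_true(np.all(np.diff(fitted) >= 0), "isotonic fit is monotone")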

Demo 24: Validation split discipline

This cell reuses the CVaR computation from Demo 8. Whatever the metric, the title's rule stands: fit calibration parameters on a validation split that is disjoint from the split you report on. A minimal split sketch follows the cell.

Code cell 51

header("Demo 24 - Validation split discipline: tail risk via CVaR")
losses = np.random.lognormal(mean=-2.0, sigma=0.8, size=1200)
alpha = 0.9
threshold = np.quantile(losses, alpha)
cvar = losses[losses >= threshold].mean()
print(f"mean loss={losses.mean():.4f}, 90% tail threshold={threshold:.4f}, CVaR={cvar:.4f}")
check_true(cvar >= losses.mean(), "tail risk exceeds average risk")
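
A hedged sketch of the discipline in the title: shuffle once, carve disjoint fit / calibration / test index sets, fit quantities like temperature on the calibration slice only, and report on the untouched test slice. The 600/200/200 split is an assumed example.

idx = np.random.permutation(1000)
fit_idx, cal_idx, test_idx = idx[:600], idx[600:800], idx[800:]
assert not set(cal_idx) & set(test_idx), "calibration and test must be disjoint"
print(f"fit={len(fit_idx)}, calibrate={len(cal_idx)}, test={len(test_idx)}")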

Demo 25: Calibration under shift

This cell reruns the reliability-diagram experiment from Demo 1. Calibration achieved in-distribution can evaporate under distribution shift; the sketch after the cell shows ECE inflating when accuracy drops while confidence stays put.

Code cell 53

header("Demo 25 - Calibration under shift: reliability diagram and ECE")
n = 1200
confidence = np.random.beta(4, 2, size=n)
true_prob = np.clip(confidence - 0.12 + np.random.normal(0, 0.03, size=n), 0.01, 0.99)
correct = (np.random.rand(n) < true_prob).astype(float)
bins = np.linspace(0, 1, 11)
ece = 0.0
centers, accs, confs = [], [], []
for lo, hi in zip(bins[:-1], bins[1:]):
    mask = (confidence >= lo) & (confidence < hi)
    if mask.any():
        acc = correct[mask].mean()
        conf = confidence[mask].mean()
        weight = mask.mean()
        ece += weight * abs(acc - conf)
        centers.append((lo + hi) / 2)
        accs.append(acc)
        confs.append(conf)
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], color=COLORS["neutral"], linestyle="--", label="perfect calibration")
ax.plot(confs, accs, marker="o", color=COLORS["primary"], label="model")
ax.set_title("Reliability diagram")
ax.set_xlabel("Mean confidence")
ax.set_ylabel("Empirical accuracy")
ax.legend()
fig.tight_layout()
plt.show()
print(f"ECE={ece:.4f}")
check_true(ece >= 0, "ECE is nonnegative")
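
A hedged sketch of the failure in the title, on assumed synthetic data: the same confidences are nearly calibrated in-distribution, but when a shift drops accuracy while reported confidence stays put, ECE inflates.

def ece_of(conf, corr, n_bins=10):
    edges = np.linspace(0, 1, n_bins + 1)
    total = 0.0
    for b_lo, b_hi in zip(edges[:-1], edges[1:]):
        m = (conf >= b_lo) & (conf < b_hi)
        if m.any():
            total += m.mean() * abs(corr[m].mean() - conf[m].mean())
    return total

conf = np.random.beta(4, 2, size=2000)
corr_in = (np.random.rand(2000) < np.clip(conf - 0.02, 0.01, 0.99)).astype(float)
corr_shift = (np.random.rand(2000) < np.clip(conf - 0.25, 0.01, 0.99)).astype(float)
print(f"ECE in-distribution={ece_of(conf, corr_in):.4f}, "
      f"ECE under shift={ece_of(conf, corr_shift):.4f}")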