Scientific Validity of Stratified PPI++ for Mean Estimation
This notebook provides empirical evidence that GLIDE's Stratified Prediction-Powered Inference (Stratified PPI++) implementation is statistically valid.
Setup: We estimate the mean of a binary outcome (e.g., the hallucination rate of an AI system) over a population that is naturally partitioned into strata (e.g., language, domain, or data source). Within each stratum we have:
- A small set of true labels (y_true): expensive but unbiased
- A large set of proxy labels (y_proxy): cheap but potentially biased
Stratified PPI++ fits an optimal rectification weight per stratum, then combines stratum-level estimates with population-proportional weights. This yields confidence intervals that are:
- Valid: they cover the true mean at the nominal rate (e.g., about 90% of the time for 90% confidence intervals)
- Shorter: narrower than intervals built from true labels only, especially when the proxy is strongly correlated with the true labels in at least one stratum
We test these two claims empirically across a range of proxy/true correlation levels.
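For reference, the estimator described above can be written as follows. This is a sketch in our own notation, following the standard PPI++ mean estimator; GLIDE's implementation may differ in details such as which samples enter each variance estimate. In stratum $h$, let $n_h$ be the number of labeled pairs $(Y_i, f_i)$ and $N_h$ the number of unlabeled proxy labels $\tilde f_j$:

$$
\hat\theta_h(\lambda_h) = \frac{1}{n_h}\sum_{i=1}^{n_h}\left(Y_i - \lambda_h f_i\right) + \frac{\lambda_h}{N_h}\sum_{j=1}^{N_h}\tilde f_j,
\qquad
\hat\lambda_h = \frac{\widehat{\mathrm{Cov}}(Y, f)}{\left(1 + n_h/N_h\right)\widehat{\mathrm{Var}}(f)}
$$

$$
\hat\theta = \sum_{h=1}^{H} w_h\,\hat\theta_h(\hat\lambda_h),
\qquad
w_h = \frac{n_h + N_h}{\sum_{k=1}^{H}\left(n_k + N_k\right)},
\qquad
\hat\sigma^2 = \sum_{h=1}^{H} w_h^2\left[\frac{\hat\lambda_h^2\,\widehat{\mathrm{Var}}(\tilde f)}{N_h} + \frac{\widehat{\mathrm{Var}}(Y - \hat\lambda_h f)}{n_h}\right]
$$

The interval is $\hat\theta \pm z_{1-\alpha/2}\,\hat\sigma$. Setting $\lambda_h = 0$ recovers the labeled-only mean in stratum $h$, so a stratum with an uninformative proxy gets $\hat\lambda_h \approx 0$ and falls back to (approximately) its labeled-only estimate.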
```python
from functools import partial

import matplotlib.pyplot as plt
import numpy as np

from glide.estimators import StratifiedClassicalMeanEstimator, StratifiedPPIMeanEstimator
from glide.samplers import StratifiedSampler
from glide.scientific_validation import compute_hits, coverage_with_error_bar, run_monte_carlo
from glide.simulators import generate_stratified_binary_dataset, simulate_annotation

plt.rcParams.update(
    {
        "font.size": 18,
        "axes.labelsize": 18,
        "axes.titlesize": 18,
        "legend.fontsize": 16,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "figure.titlesize": 19,
    }
)
```
Experiment Parameters
We fix all parameters up front so that every section of this notebook uses the same setup. We define:
- CONFIDENCE_LEVEL: the confidence level at which confidence intervals are computed.
- N_STRATA: number of strata.
- N_TOTAL_PER_STRATUM: total number of samples in each stratum.
- BUDGET: total number of human-annotated samples across all strata.
- STRATUM_TRUE_MEANS: per-stratum true means.
- STRATUM_PROXY_MEANS: per-stratum (biased) proxy means.
- STRATUM_WEIGHTS: each stratum's share of the population, used to combine stratum-level estimates.
- TRUE_MEAN: the population-weighted true mean, used as the ground-truth target for coverage evaluation.
- STRATUM_CORRELATION_OFFSETS: per-stratum offsets added to the swept correlation between true and proxy labels.
- N_SEEDS: number of simulations in the Monte Carlo experiments.
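Note on correlation bounds: for binary labels, the Pearson correlation between true and proxy labels is bounded by their means. Depending on the values of STRATUM_TRUE_MEANS and STRATUM_PROXY_MEANS, extreme correlation values (close to -1 or 1) may therefore be infeasible. The correlation sweep, including the per-stratum offsets, is kept within feasible limits for all strata.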
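For two Bernoulli variables with means $p$ and $q$, the joint probability $P(X=1, Y=1)$ must lie in $[\max(0, p+q-1), \min(p, q)]$, and that range bounds their Pearson correlation. The sketch below uses a hypothetical helper, bernoulli_correlation_bounds, to compute the feasible range per stratum for the means in this notebook and to check the swept correlations plus offsets against it.

```python
import numpy as np


def bernoulli_correlation_bounds(p, q):
    """Feasible Pearson correlation range for X ~ Bernoulli(p), Y ~ Bernoulli(q)."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    scale = np.sqrt(p * (1 - p) * q * (1 - q))
    # P(X=1, Y=1) ranges over [max(0, p + q - 1), min(p, q)]
    p11_min = np.maximum(0.0, p + q - 1)
    p11_max = np.minimum(p, q)
    return (p11_min - p * q) / scale, (p11_max - p * q) / scale


true_means = np.array([0.55, 0.4, 0.55, 0.5])
proxy_means = np.array([0.5, 0.45, 0.6, 0.55])
offsets = np.array([-0.1, 0.0, -0.2, 0.0])
rho_min, rho_max = bernoulli_correlation_bounds(true_means, proxy_means)
print("feasible min:", np.round(rho_min, 3))
print("feasible max:", np.round(rho_max, 3))

# Every (swept value, stratum) pair after adding the per-stratum offsets
sweep = np.arange(0.1, 0.95, 0.1)
per_stratum = sweep[:, None] + offsets[None, :]
print("sweep feasible:", bool(np.all((per_stratum >= rho_min) & (per_stratum <= rho_max))))
```

With these means the upper bound is about 0.90 in every stratum, so the largest per-stratum value in the sweep (0.9) sits just inside the feasible range.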
Finally, we define the estimation methods to compare:
- True only: uses true human labels only; the gold standard for validity.
- Proxy only: uses proxy labels only; cheap but biased.
- Stratified PPI++: Stratified Prediction-Powered Inference with per-stratum power-tuning.
```python
CONFIDENCE_LEVEL = 0.9
N_STRATA = 4
N_TOTAL_PER_STRATUM = np.array([500, 750, 750, 500])  # per stratum
BUDGET = 500  # total labeled samples
STRATUM_TRUE_MEANS = np.array([0.55, 0.4, 0.55, 0.5])
STRATUM_PROXY_MEANS = np.array([0.5, 0.45, 0.6, 0.55])
STRATUM_CORRELATION_OFFSETS = np.array([-0.1, 0.0, -0.2, 0.0])
N_SEEDS = 1000

total = np.sum(N_TOTAL_PER_STRATUM)
STRATUM_WEIGHTS = N_TOTAL_PER_STRATUM / total
# Population-weighted true mean (ground truth for coverage validation)
TRUE_MEAN = np.sum(STRATUM_WEIGHTS * STRATUM_TRUE_MEANS)

METHODS = ["True only", "Proxy only", "Stratified PPI++"]

# Correlation sweep: kept within the feasible range for all strata (after offsets)
correlations = np.arange(0.1, 0.95, 0.1)
n_correlations = len(correlations)
correlations_lmh = [
    correlations[n_correlations // 4],
    correlations[n_correlations // 2],
    correlations[3 * n_correlations // 4],
]  # low, medium and high values
corr_labels = ["Low", "Medium", "High"]

print(f"Population-weighted TRUE_MEAN = {TRUE_MEAN:.3f}")
print(f"Stratum weights: {STRATUM_WEIGHTS}")
```
Population-weighted TRUE_MEAN = 0.495
Stratum weights: [0.2 0.3 0.3 0.2]
Data Simulation
We use generate_stratified_binary_dataset to simulate a stratified evaluation scenario: it generates correlated binary labels (y_true_oracle, y_proxy) for every sample in every stratum. Missing ground truth is then simulated by selecting the samples to annotate with StratifiedSampler, whose "proportional" strategy allocates the labeling budget across strata in proportion to stratum size, and masking the remaining true labels with np.nan via simulate_annotation.
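As a rough illustration of the sampling step, here is a minimal NumPy sketch of proportional allocation followed by masking, assuming that proportional allocation means each stratum's share of BUDGET matches its share of the population. The helpers proportional_allocation and mask_unlabeled are hypothetical stand-ins, not GLIDE's StratifiedSampler or simulate_annotation, and GLIDE may round fractional allocations differently.

```python
import numpy as np


def proportional_allocation(stratum_sizes, budget):
    """Split `budget` across strata in proportion to stratum size (largest-remainder rounding)."""
    sizes = np.asarray(stratum_sizes, dtype=float)
    exact = budget * sizes / sizes.sum()
    alloc = np.floor(exact).astype(int)
    # Hand the leftover units to the strata with the largest fractional parts
    leftover = int(budget - alloc.sum())
    order = np.argsort(-(exact - alloc))
    alloc[order[:leftover]] += 1
    return alloc


def mask_unlabeled(y_oracle, groups, alloc, seed=0):
    """Keep a random subset of size alloc[h] in each stratum h; set every other label to NaN."""
    rng = np.random.default_rng(seed)
    y = np.full(len(y_oracle), np.nan)
    for h, n_h in enumerate(alloc):
        idx = np.flatnonzero(groups == h)
        chosen = rng.choice(idx, size=n_h, replace=False)
        y[chosen] = y_oracle[chosen]
    return y


sizes = np.array([500, 750, 750, 500])
alloc = proportional_allocation(sizes, budget=500)
groups = np.repeat(np.arange(len(sizes)), sizes)
y_oracle = np.random.default_rng(1).integers(0, 2, size=sizes.sum()).astype(float)
y = mask_unlabeled(y_oracle, groups, alloc)
print("allocation:", alloc)
print("labeled:", int(np.sum(~np.isnan(y))))
```

With the stratum sizes and budget used in this notebook, every stratum gets exactly 20% of its samples labeled, so no rounding is needed.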
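The correlation parameter sets the Pearson correlation between true and proxy labels within each stratum. In the single illustrative dataset below, every stratum uses the same correlation (0.8); in the Monte Carlo sweep, the per-stratum STRATUM_CORRELATION_OFFSETS are added to the swept value.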
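One standard way to generate such pairs, shown here as a hypothetical sketch rather than GLIDE's generate_stratified_binary_dataset, is to solve for the joint 2×2 cell probabilities implied by the two means and the target correlation, then sample cells. The function sample_correlated_bernoulli is illustrative only. The construction also shows where the correlation bounds come from: an infeasible target makes some cell probability negative.

```python
import numpy as np


def sample_correlated_bernoulli(n, p, q, rho, seed=None):
    """Draw n pairs (X, Y) with X ~ Bernoulli(p), Y ~ Bernoulli(q) and Pearson correlation rho."""
    # P(X=1, Y=1) implied by the target correlation
    p11 = p * q + rho * np.sqrt(p * (1 - p) * q * (1 - q))
    # Cells in order: (1,1), (1,0), (0,1), (0,0)
    probs = np.array([p11, p - p11, q - p11, 1 - p - q + p11])
    if np.any(probs < -1e-12):
        raise ValueError("rho is infeasible for these marginal means")
    probs = np.clip(probs, 0.0, None)
    probs /= probs.sum()
    cells = np.random.default_rng(seed).choice(4, size=n, p=probs)
    x = np.isin(cells, [0, 1]).astype(float)  # X = 1 in cells (1,1) and (1,0)
    y = np.isin(cells, [0, 2]).astype(float)  # Y = 1 in cells (1,1) and (0,1)
    return x, y


x, y = sample_correlated_bernoulli(200_000, p=0.55, q=0.5, rho=0.8, seed=0)
print(f"mean X={x.mean():.3f}, mean Y={y.mean():.3f}, corr={np.corrcoef(x, y)[0, 1]:.3f}")
```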
```python
# Single example dataset for illustration
y_true_oracle, y_proxy, groups = generate_stratified_binary_dataset(
    n_total=N_TOTAL_PER_STRATUM.tolist(),
    true_mean=STRATUM_TRUE_MEANS.tolist(),
    proxy_mean=STRATUM_PROXY_MEANS.tolist(),
    correlation=[0.8] * N_STRATA,
    random_seed=42,
)
xi = StratifiedSampler().sample(y_proxy, groups, budget=BUDGET, strategy="proportional", random_seed=42)
y_true = simulate_annotation(y_true_oracle, xi)

n_labeled = int(np.sum(~np.isnan(y_true)))
n_unlabeled = len(y_true) - n_labeled
print(f"Total samples: {len(y_true)}")
print(f"Labeled samples: {n_labeled}")
print(f"Unlabeled samples: {n_unlabeled}")
```
Total samples: 2500
Labeled samples: 500
Unlabeled samples: 2000
Inference Results
We compare three estimation methods:
| Estimation method | Data used | Notes |
|---|---|---|
| True only | y_true | Classical CLT confidence interval; the gold standard for validity |
| Proxy only | y_proxy | Cheap but biased, so its intervals are not valid |
| Stratified PPI++ | y_true + y_proxy (rectified per stratum) | Best of both worlds: valid and efficient |
The function below simulates a dataset for a given seed and correlation level, then runs all three estimation methods on it.
It adds STRATUM_CORRELATION_OFFSETS to the swept correlation, simulating heterogeneous proxy quality across strata.
```python
def simulate_estimates(seed, correlation):
    # Per-stratum correlation = swept value + stratum-specific offset
    y_true_oracle, y_proxy, groups = generate_stratified_binary_dataset(
        n_total=N_TOTAL_PER_STRATUM.tolist(),
        true_mean=STRATUM_TRUE_MEANS.tolist(),
        proxy_mean=STRATUM_PROXY_MEANS.tolist(),
        correlation=(correlation + STRATUM_CORRELATION_OFFSETS).tolist(),
        random_seed=seed,
    )
    # Reveal BUDGET true labels (proportional allocation); the rest become NaN
    xi = StratifiedSampler().sample(y_proxy, groups, budget=BUDGET, strategy="proportional", random_seed=seed)
    y_true = simulate_annotation(y_true_oracle, xi)

    # --- Stratified PPI++ ---
    estimator = StratifiedPPIMeanEstimator()
    stratified_ppi_result = estimator.estimate(y_true, y_proxy, groups, confidence_level=CONFIDENCE_LEVEL)

    # --- Classical baselines ---
    classical_estimator = StratifiedClassicalMeanEstimator()
    true_only_result = classical_estimator.estimate(y_true, groups, confidence_level=CONFIDENCE_LEVEL)
    proxy_only_result = classical_estimator.estimate(y_proxy, groups, confidence_level=CONFIDENCE_LEVEL)

    return {
        "True only": {
            "mean": true_only_result.mean,
            "std": true_only_result.std,
            "confidence_interval": true_only_result.confidence_interval,
        },
        "Proxy only": {
            "mean": proxy_only_result.mean,
            "std": proxy_only_result.std,
            "confidence_interval": proxy_only_result.confidence_interval,
        },
        "Stratified PPI++": {
            "mean": stratified_ppi_result.mean,
            "std": stratified_ppi_result.std,
            "confidence_interval": stratified_ppi_result.confidence_interval,
            "effective_sample_size": stratified_ppi_result.effective_sample_size,
        },
    }
```
StratifiedPPIMeanEstimator splits the samples by stratum (the groups array), computes a power-tuned PPI++ estimate within each stratum, and combines the stratum estimates with population-proportional weights. StratifiedClassicalMeanEstimator computes a conventional stratified mean from the labels it is given (true labels for True only, proxy labels for Proxy only), partitioning by stratum to compute the variance.
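To make this concrete, here is a minimal NumPy sketch of a stratified PPI++ mean estimator, assuming the standard PPI++ form (a rectified labeled mean plus a power-tuned proxy mean over unlabeled samples) and a normal-approximation interval. The names ppi_pp_stratum and stratified_ppi_pp are hypothetical; this is not GLIDE's StratifiedPPIMeanEstimator, whose variance estimates and edge-case handling may differ. It assumes each stratum has at least two labeled and two unlabeled samples.

```python
import numpy as np
from statistics import NormalDist


def ppi_pp_stratum(y, f):
    """PPI++ mean estimate, variance and lambda for one stratum.

    y: true labels with NaN where unlabeled; f: proxy labels for every sample.
    """
    labeled = ~np.isnan(y)
    y_l, f_l, f_u = y[labeled], f[labeled], f[~labeled]
    n, N = len(y_l), len(f_u)
    # Power-tuning: variance-minimizing lambda; a constant proxy falls back to lambda = 0
    var_f = np.var(f, ddof=1)
    if var_f > 0:
        lam = np.cov(y_l, f_l, ddof=1)[0, 1] / ((1 + n / N) * var_f)
    else:
        lam = 0.0
    # Rectified labeled mean plus lambda-weighted unlabeled proxy mean
    theta = np.mean(y_l - lam * f_l) + lam * np.mean(f_u)
    var = lam**2 * np.var(f_u, ddof=1) / N + np.var(y_l - lam * f_l, ddof=1) / n
    return theta, var, lam


def stratified_ppi_pp(y, f, groups, confidence_level=0.9):
    """Combine per-stratum PPI++ estimates with population-proportional weights."""
    strata = np.unique(groups)
    weights = np.array([np.mean(groups == h) for h in strata])
    per_stratum = [ppi_pp_stratum(y[groups == h], f[groups == h]) for h in strata]
    theta = sum(w * est for w, (est, _, _) in zip(weights, per_stratum))
    std = np.sqrt(sum(w**2 * var for w, (_, var, _) in zip(weights, per_stratum)))
    z = NormalDist().inv_cdf(0.5 + confidence_level / 2)
    return theta, (theta - z * std, theta + z * std), [lam for _, _, lam in per_stratum]


# Synthetic usage: the proxy flips 15% of the true labels; about 20% of samples are labeled
rng = np.random.default_rng(0)
sizes = [500, 750, 750, 500]
groups = np.repeat(np.arange(len(sizes)), sizes)
y_oracle = rng.binomial(1, 0.5, size=sum(sizes)).astype(float)
f = np.where(rng.random(len(y_oracle)) < 0.15, 1 - y_oracle, y_oracle)
y = np.where(rng.random(len(y_oracle)) < 0.2, y_oracle, np.nan)
theta, ci, lambdas = stratified_ppi_pp(y, f, groups, confidence_level=0.9)
print(f"estimate={theta:.3f}, CI=({ci[0]:.3f}, {ci[1]:.3f}), lambdas={np.round(lambdas, 2)}")
```

Because the proxy here is informative in every stratum, each fitted lambda lands well above zero; with an uninformative proxy the lambdas shrink toward zero and the estimate approaches the labeled-only mean.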
Coverage Validity
A confidence interval procedure is valid if it captures the true value at (at least) the nominal rate: a 90% confidence interval is valid if, across many repetitions, about 90% or more of the resulting intervals contain the true value.
We run a Monte Carlo experiment to verify this for each method, checking that the empirical coverage tracks the nominal level across all confidence and correlation levels tested. See the Scientific Validation Methodology page for more details about the verification protocol.
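The coverage check itself is simple to sketch. Assuming that compute_hits marks whether each simulated interval contains TRUE_MEAN and that coverage_with_error_bar returns the observed coverage together with a binomial confidence interval, a minimal stand-in looks like this. The helper names are hypothetical, and the Wilson score interval is our choice for illustration; GLIDE's error bar may be computed differently.

```python
import numpy as np
from statistics import NormalDist


def compute_hits_sketch(lower, upper, true_value):
    """1 where the interval [lower, upper] contains true_value, else 0."""
    lower, upper = np.asarray(lower), np.asarray(upper)
    return ((lower <= true_value) & (true_value <= upper)).astype(int)


def wilson_interval(hits, confidence_level=0.9):
    """Observed coverage plus a Wilson score interval for the binomial proportion."""
    hits = np.asarray(hits)
    n, p_hat = len(hits), hits.mean()
    z = NormalDist().inv_cdf(0.5 + confidence_level / 2)
    denom = 1 + z**2 / n
    center = (p_hat + z**2 / (2 * n)) / denom
    half = z * np.sqrt(p_hat * (1 - p_hat) / n + z**2 / (4 * n**2)) / denom
    return p_hat, center - half, center + half


# 1,000 simulated 90% intervals for a normal mean with known standard error
rng = np.random.default_rng(0)
n_sims, true_mean, se = 1000, 0.495, 0.02
z90 = NormalDist().inv_cdf(0.95)
estimates = rng.normal(true_mean, se, size=n_sims)
hits = compute_hits_sketch(estimates - z90 * se, estimates + z90 * se, true_mean)
coverage, lo, hi = wilson_interval(hits, confidence_level=0.9)
print(f"coverage={coverage:.3f}, 90% error bar=({lo:.3f}, {hi:.3f})")
```

With 1,000 simulations (the value of N_SEEDS) and true coverage 0.90, the Monte Carlo standard error is about $\sqrt{0.9 \times 0.1 / 1000} \approx 0.0095$, so a 90% error bar spans roughly ±0.016 around the observed coverage.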
Coverage vs confidence level for three correlation levels
We sweep the confidence level from 0.55 to 0.95 and plot the observed coverage. For a valid estimation method, the dots should fall on or above the black diagonal $y = \text{confidence level}$.
We do this for low, medium and high proxy correlation.
```python
# Run Monte Carlo simulations for each correlation level
confidence_levels = np.arange(0.55, 1.00, 0.05)
confidence_levels = np.round(confidence_levels, 2)
raw_stats = {
    corr: run_monte_carlo(confidence_levels, partial(simulate_estimates, correlation=corr)) for corr in correlations
}

# Derive coverage for every confidence level at the low/medium/high correlations
coverages_confidence_intervals = {}
for correlation in correlations_lmh:
    coverages_confidence_intervals[correlation] = {}
    for confidence_level in confidence_levels:
        hits = compute_hits(raw_stats[correlation], confidence_level, TRUE_MEAN)
        coverages_confidence_intervals[correlation][confidence_level] = dict()
        for method in METHODS:
            coverage_confidence_interval = coverage_with_error_bar(hits[method], confidence_level=CONFIDENCE_LEVEL)
            coverages_confidence_intervals[correlation][confidence_level][method] = coverage_confidence_interval

fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=True)
colors = {"True only": "steelblue", "Stratified PPI++": "darkorange", "Proxy only": "red"}
for ax, correlation, label in zip(axes, correlations_lmh, corr_labels):
    ax.plot(confidence_levels, confidence_levels, color="black", lw=1.5, linestyle="--", label="Ideal")
    for method in METHODS:
        mean_ci = np.array([coverages_confidence_intervals[correlation][cl][method] for cl in confidence_levels])
        mean = mean_ci[:, 0]
        lo = mean_ci[:, 1]
        hi = mean_ci[:, 2]
        ax.plot(confidence_levels, mean, marker="o", color=colors[method], label=method)
        ax.fill_between(confidence_levels, lo, hi, alpha=0.15, color=colors[method])
    ax.set_title(f"{label} correlation (${round(correlation, 2)}$)")
    ax.set_xlabel("Target confidence level")
    ax.set_ylabel("Observed coverage")
    ax.legend(loc="lower right")
    ax.set_xlim(0.5, 1.0)
    ax.set_ylim(0.5, 1.0)
plt.tight_layout()
plt.show()
```
Both Stratified PPI++ and True only track the diagonal closely at all three correlation levels, confirming that Stratified PPI++ achieves valid coverage across the proxy qualities tested. Proxy only falls far below the diagonal: its intervals are centered on the biased proxy mean, so they rarely contain the true mean.
Coverage vs correlation for fixed confidence level
We now fix the confidence level at CONFIDENCE_LEVEL and sweep a range of proxy–true correlation levels. This checks that Stratified PPI++ validity does not degrade as the proxy becomes weaker.
```python
coverage_by_corr = {}  # {correlation: {method: observed mean coverage}}
coverage_ci_by_corr = {}  # {correlation: {method: (lower, upper) confidence interval on coverage}}
for correlation in correlations:
    hits = compute_hits(raw_stats[correlation], CONFIDENCE_LEVEL, TRUE_MEAN)
    coverage_by_corr[correlation] = {}
    coverage_ci_by_corr[correlation] = {}
    for method in METHODS:
        mean_cov, lo, hi = coverage_with_error_bar(hits[method], CONFIDENCE_LEVEL)
        coverage_by_corr[correlation][method] = mean_cov
        coverage_ci_by_corr[correlation][method] = (lo, hi)

fig, ax = plt.subplots(figsize=(8, 5))
method_colors = {"True only": "steelblue", "Stratified PPI++": "darkorange"}
for method in ["True only", "Stratified PPI++"]:
    obs = np.array([coverage_by_corr[correlation][method] for correlation in correlations])
    ci_bounds = np.array([coverage_ci_by_corr[correlation][method] for correlation in correlations])
    lo = ci_bounds[:, 0]
    hi = ci_bounds[:, 1]
    ax.plot(correlations, obs, marker="o", color=method_colors[method], label=method)
    ax.fill_between(correlations, lo, hi, alpha=0.15, color=method_colors[method])
ax.axhline(y=CONFIDENCE_LEVEL, color="red", linestyle="--", lw=2, label=f"Target coverage {CONFIDENCE_LEVEL:.0%}")
ax.set_xlabel("Proxy–true correlation")
ax.set_ylabel("Observed coverage")
ax.set_xlim(0, 0.95)
ax.set_ylim(0.8, 1.0)
ax.yaxis.set_ticks(ax.get_yticks()[1:-1:2])
ax.legend()
plt.tight_layout()
plt.show()
```
Proxy only is not plotted: because the proxy is biased (proxy mean ≠ true mean in each stratum), its coverage is far below the target, whereas Stratified PPI++ and True only remain valid across all correlation levels.
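A back-of-the-envelope calculation shows why Proxy only fails. It assumes, hypothetically, that the Proxy only interval is a normal interval centered on the population-weighted proxy mean, with a standard error computed from stratified Bernoulli variances over all 2,500 samples; sampling noise in the simulated proxy means is ignored, and expected_coverage is an illustrative helper.

```python
import numpy as np
from statistics import NormalDist

sizes = np.array([500, 750, 750, 500])
true_means = np.array([0.55, 0.4, 0.55, 0.5])
proxy_means = np.array([0.5, 0.45, 0.6, 0.55])
w = sizes / sizes.sum()

# Bias of the proxy-only estimate and its stratified standard error
bias = np.sum(w * proxy_means) - np.sum(w * true_means)
se = np.sqrt(np.sum(w**2 * proxy_means * (1 - proxy_means) / sizes))


def expected_coverage(confidence_level):
    """P(|estimate - true mean| <= z * se) when estimate ~ Normal(true mean + bias, se^2)."""
    z = NormalDist().inv_cdf(0.5 + confidence_level / 2)
    shift = bias / se
    return NormalDist().cdf(z - shift) - NormalDist().cdf(-z - shift)


print(f"bias={bias:.3f}, se={se:.4f}, expected coverage at 90%: {expected_coverage(0.9):.2f}")
```

Under these assumptions the bias (about 0.03) is roughly three standard errors, giving an expected coverage of only about 8% at the 90% level.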
Confidence Interval Width
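Coverage validity is necessary but not sufficient: we also want short intervals. Because True only and Stratified PPI++ use the same labeled samples in each simulation, any width difference between them is attributable to the proxy labels.
We compare the mean confidence interval widths of Stratified PPI++ and True only across correlation levels; the shaded bands span the 5th to 95th percentiles of the widths across simulations.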
```python
width_by_corr = {}
for correlation in correlations:
    width_by_corr[correlation] = {}
    for method in METHODS:
        lower_bound = raw_stats[correlation][method]["lower_bounds"][CONFIDENCE_LEVEL]
        upper_bound = raw_stats[correlation][method]["upper_bounds"][CONFIDENCE_LEVEL]
        width_by_corr[correlation][method] = upper_bound - lower_bound

fig, ax = plt.subplots(figsize=(9, 5))
plot_methods = ["True only", "Stratified PPI++"]
colors_w = {"True only": "steelblue", "Stratified PPI++": "darkorange"}
# Compute percentiles based on CONFIDENCE_LEVEL
lower_percentile = round(((1 - CONFIDENCE_LEVEL) / 2) * 100)
upper_percentile = 100 - lower_percentile
for method in plot_methods:
    means_w = [np.mean(width_by_corr[correlation][method]) for correlation in correlations]
    q_lower = [np.percentile(width_by_corr[correlation][method], lower_percentile) for correlation in correlations]
    q_upper = [np.percentile(width_by_corr[correlation][method], upper_percentile) for correlation in correlations]
    ax.plot(correlations, means_w, marker="o", label=method, color=colors_w[method])
    ax.fill_between(correlations, q_lower, q_upper, alpha=0.15, color=colors_w[method])
ax.set_xlabel("Proxy–true correlation")
ax.set_ylabel("Confidence Interval width")
ax.set_xlim(0.05, 0.95)
ax.yaxis.set_ticks(ax.get_yticks()[1:-1:2])
ax.legend()
plt.tight_layout()
plt.show()
```
As expected, Stratified PPI++'s interval width decreases as correlation increases. Per-stratum power-tuning lets the estimator lean on proxy data only in strata where it is informative; in principle this can give shorter intervals than a single global rectification weight, although this notebook does not run that comparison.
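The dependence on correlation can be made explicit. Assuming the standard PPI++ form with the variance-minimizing weight $\lambda_h^\star$ (population quantities, ignoring estimation error in $\lambda_h$), where $n_h$ is the labeled count, $N_h$ the unlabeled count and $\rho_h$ the true–proxy correlation in stratum $h$, substituting $\lambda_h^\star$ into the stratum variance gives:

$$
\frac{\mathrm{Var}\!\left(\hat\theta_h(\lambda_h^\star)\right)}{\mathrm{Var}\!\left(\bar Y_h\right)} = 1 - \frac{N_h}{n_h + N_h}\,\rho_h^2,
\qquad
\frac{\text{width}_{\text{Stratified PPI++}, h}}{\text{width}_{\text{True only}, h}} = \sqrt{1 - \frac{N_h}{n_h + N_h}\,\rho_h^2}
$$

If proportional allocation labels 20% of each stratum ($500 / 2500$), then $N_h / (n_h + N_h) = 0.8$. At $\rho_h = 0.8$ the variance ratio is $1 - 0.8 \times 0.64 = 0.488$, so the stratum interval is about $\sqrt{0.488} \approx 0.70$ times as wide; at $\rho_h = 0.3$ the ratio is $0.928$ and the width shrinks by only about 4%. The gain depends on $\rho_h^2$, so it grows with the strength of the correlation.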
Effective Sample Size
A natural summary of Stratified PPI++'s efficiency gain is the effective sample size (ESS): the number of true labels a True only estimator would need to match Stratified PPI++'s confidence interval width.
We report Stratified PPI++'s effective sample size across correlation levels, translating the width reduction into an equivalent number of true labels. See the Scientific Validation Methodology page for the formal definition and formula of ESS.
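GLIDE reports effective_sample_size directly, and its formal definition lives on the methodology page. Under the simplifying assumption that standard errors scale as $1/\sqrt{n}$, one common way to express ESS is $n \cdot (\sigma_{\text{True only}} / \sigma_{\text{PPI++}})^2$. The hypothetical helper below illustrates that relation, not GLIDE's exact formula, and pairs it with the per-stratum variance ratio $1 - \frac{N}{n+N}\rho^2$, assuming the same correlation in every stratum and 20% of samples labeled.

```python
import numpy as np


def effective_sample_size(n_labeled, std_true_only, std_ppi):
    """Labels a True only estimator would need to match the PPI++ standard error,
    assuming standard errors scale as 1 / sqrt(n)."""
    return n_labeled * (np.asarray(std_true_only) / np.asarray(std_ppi)) ** 2


# Worked example: common correlation rho in all strata, 500 labels out of 2,500 samples
unlabeled_fraction = 0.8
for rho in (0.3, 0.5, 0.8):
    variance_ratio = 1 - unlabeled_fraction * rho**2
    ess = effective_sample_size(500, 1.0, np.sqrt(variance_ratio))
    print(f"rho={rho}: ESS is about {float(ess):.0f}")
```

Because the experiment uses per-stratum offsets and estimated power-tuning weights, the simulated ESS values will not match these idealized numbers exactly.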
```python
ess_mean = [
    np.mean(raw_stats[correlation]["Stratified PPI++"]["effective_sample_sizes"]) for correlation in correlations
]
ess_q_lower = [
    np.percentile(raw_stats[correlation]["Stratified PPI++"]["effective_sample_sizes"], lower_percentile)
    for correlation in correlations
]
ess_q_upper = [
    np.percentile(raw_stats[correlation]["Stratified PPI++"]["effective_sample_sizes"], upper_percentile)
    for correlation in correlations
]

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(correlations, ess_mean, marker="o", color="darkorange", label="Stratified PPI++ ESS (mean)")
ax.fill_between(
    correlations,
    ess_q_lower,
    ess_q_upper,
    alpha=0.15,
    color="darkorange",
    label=f"{lower_percentile:.0f}th–{upper_percentile:.0f}th percentile",
)
ax.axhline(y=BUDGET, color="steelblue", linestyle="--", lw=2, label=f"Baseline (True only, n={BUDGET})")
ax.set_xlabel("Proxy–true correlation")
ax.set_ylabel("Effective sample size")
ax.set_xlim(0.05, 0.95)
ax.legend()
plt.tight_layout()
plt.show()
```
Summary
This notebook has empirically validated that GLIDE's Stratified PPI++ implementation satisfies two key statistical properties:
| Property | Result |
|---|---|
| Coverage validity | Stratified PPI++ achieves the nominal coverage at every correlation and confidence level tested |
| Efficiency | Stratified PPI++ produces shorter confidence intervals than True only as correlation grows, with little difference at low correlation and the gain increasing with correlation |
Crucially, the biased baseline (Proxy only) fails the coverage test. It appears precise but is systematically wrong. Stratified PPI++ avoids this by correcting for proxy bias using the labeled subset within each stratum.
The ESS analysis translates this width reduction into labels: as proxy correlation grows, Stratified PPI++ matches the precision of a True only estimator with more labeled data than the budget actually spent, a practical gain in scenarios where true annotation is expensive. By fitting a separate rectification weight per stratum, the estimator can fully exploit informative strata while gracefully degrading toward the labeled-only estimate in strata where the proxy is weak.