import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from compute_curve import compute_curve
from distributions import BinaryYFDistribution

# Markdown shown under each tab title. Kept at module level so the text starts
# at column 0 and is not indented inside the rendered Markdown.
_F_BASED_DESCRIPTION = """
Analyze relative variance curves of PPI and PPI++ (with Cross-Fitting), as compared to using the empirical mean of Y.

**Inputs:**
- P(F = 1): Prior probability of the binary pseudo-label
- P(Y = 1 | F = 1): Conditional probability of label given pseudo-label = 1
- P(Y = 1 | F = 0): Conditional probability of label given pseudo-label = 0
"""

_Y_BASED_DESCRIPTION = """
Analyze relative variance curves of PPI and PPI++ (with Cross-Fitting), as compared to using the empirical mean of Y.

**Inputs:**
- P(Y = 1): Prior probability of the binary label
- P(F = 1 | Y = 1): Conditional probability of pseudo-label given label = 1
- P(F = 1 | Y = 0): Conditional probability of pseudo-label given label = 0
"""


def convert_y_params_to_f_params(
    p_y: float, p_f_given_y1: float, p_f_given_y0: float
) -> tuple[float, float, float]:
    """Convert Y-based parameters to F-based parameters using Bayes' theorem.

    Args:
        p_y: P(Y = 1).
        p_f_given_y1: P(F = 1 | Y = 1).
        p_f_given_y0: P(F = 1 | Y = 0).

    Returns:
        A tuple ``(p_f, p_y_given_f1, p_y_given_f0)`` holding P(F = 1),
        P(Y = 1 | F = 1) and P(Y = 1 | F = 0). A conditional whose
        conditioning event has probability zero is reported as 0.0 instead
        of dividing by zero.
    """
    # P(F = 1) = P(F = 1 | Y = 1) * P(Y = 1) + P(F = 1 | Y = 0) * P(Y = 0)
    p_f = p_f_given_y1 * p_y + p_f_given_y0 * (1 - p_y)

    # P(Y = 1 | F = 1) = P(F = 1 | Y = 1) * P(Y = 1) / P(F = 1)
    p_y_given_f1 = (p_f_given_y1 * p_y) / p_f if p_f > 0 else 0.0

    # P(Y = 1 | F = 0) = P(F = 0 | Y = 1) * P(Y = 1) / P(F = 0)
    p_y_given_f0 = ((1 - p_f_given_y1) * p_y) / (1 - p_f) if p_f < 1 else 0.0

    return p_f, p_y_given_f1, p_y_given_f0


def _mark_baseline_crossings(ax, n_values, relative_var) -> None:
    """Mark on ``ax`` each sample size where ``relative_var`` crosses 1."""
    # np.diff of the sign sequence is non-zero at index i exactly when the
    # sign of (relative_var - 1) differs between samples i and i + 1, so every
    # index returned here is a valid position in n_values (at most len - 2).
    crossing_indices = np.flatnonzero(np.diff(np.sign(relative_var - 1)))
    for idx in crossing_indices:
        # The marker sits on sample i, the last grid point before the change.
        n_at_crossing = n_values[idx]
        ax.axvline(x=n_at_crossing, color="red", linestyle="--", alpha=0.7)
        ax.text(
            n_at_crossing,
            0.55,  # just above the lower y-limit (0.5) set by the caller
            f"n={n_at_crossing}",
            rotation=0,
            ha="center",
            va="bottom",
            color="red",
        )


def plot_custom_probabilities(
    p_f: float, p_y_given_f1: float, p_y_given_f0: float, show_ppi: bool
) -> plt.Figure:
    """Plot relative-variance curves for a binary (Y, F) distribution.

    Builds a ``BinaryYFDistribution`` from the F-based parameters, evaluates
    ``compute_curve`` on even sample sizes 4, 6, ..., 198, and draws the
    ``relative_var_cf`` column (PPI++) and, optionally, the
    ``relative_var_ppi`` column (PPI) against ``n``. Reference lines are added
    at the sample-mean baseline (y = 1) and at the asymptotic value, and the
    sample sizes where the PPI++ curve crosses the baseline are marked.

    Args:
        p_f: P(F = 1).
        p_y_given_f1: P(Y = 1 | F = 1).
        p_y_given_f0: P(Y = 1 | F = 0).
        show_ppi: Whether to also draw the plain PPI curve.

    Returns:
        The matplotlib figure containing the plot.
    """
    # Close pyplot's current figure so repeated calls (the interfaces run in
    # live mode) do not keep accumulating open figures.
    plt.close()
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    binary_dist = BinaryYFDistribution(
        p_f=p_f, p_y_given_f1=p_y_given_f1, p_y_given_f0=p_y_given_f0
    )

    n_range = np.arange(4, 200, 2)
    data = pd.DataFrame(compute_curve(binary_dist, n_range))

    corr = binary_dist.correlation()

    ax.set_ylim(0.5, 2)

    # (Var(Y) - Cov(F, Y)^2 / Var(F)) / Var(Y): the value drawn below as the
    # "Asymptotic Performance" reference line.
    var_y = binary_dist.variance_y()
    best_rel_perf = (
        var_y - binary_dist.covariance_f_y() ** 2 / binary_dist.variance_f()
    ) / var_y

    # Always show PPI++
    sns.lineplot(
        data,
        x="n",
        y="relative_var_cf",
        ax=ax,
        label="Relative Variance PPI++",
        linewidth=2,
    )

    # Conditionally show PPI
    if show_ppi:
        sns.lineplot(
            data,
            x="n",
            y="relative_var_ppi",
            ax=ax,
            label="Relative Variance PPI",
            linewidth=2,
            linestyle="dotted",
        )

    # Horizontal line at y=1: the sample-mean baseline
    ax.axhline(
        y=1, color="k", linestyle="dotted", alpha=0.7, label="Baseline (Sample Mean)"
    )

    # Horizontal line at the asymptotic relative variance
    ax.axhline(
        y=best_rel_perf,
        color="b",
        linestyle="dotted",
        alpha=0.7,
        label="Asymptotic Performance",
    )

    # Find and mark where PPI++ crosses y=1
    _mark_baseline_crossings(ax, data["n"].values, data["relative_var_cf"].values)

    ax.set_xlabel("Sample Size (n)")
    ax.set_ylabel("Relative Variance")
    ax.set_title(
        f"Relative Variance Analysis (Correlation of Pseudo-Label: {corr:.2f})"
    )
    ax.legend()
    ax.grid(True, alpha=0.3)

    return fig


def plot_y_based_probabilities(
    p_y: float, p_f_given_y1: float, p_f_given_y0: float, show_ppi: bool
) -> plt.Figure:
    """Plot using Y-based parameters by converting to F-based parameters.

    Args:
        p_y: P(Y = 1).
        p_f_given_y1: P(F = 1 | Y = 1).
        p_f_given_y0: P(F = 1 | Y = 0).
        show_ppi: Whether to also draw the plain PPI curve.

    Returns:
        The figure produced by ``plot_custom_probabilities`` for the
        converted parameters.
    """
    # Convert Y-based parameters to F-based parameters
    p_f, p_y_given_f1, p_y_given_f0 = convert_y_params_to_f_params(
        p_y, p_f_given_y1, p_f_given_y0
    )

    # Use the existing plotting function with converted parameters
    return plot_custom_probabilities(p_f, p_y_given_f1, p_y_given_f0, show_ppi)


# Create interface for F-based parameters (original)
f_based_interface = gr.Interface(
    fn=plot_custom_probabilities,
    inputs=[
        gr.Slider(0.05, 0.95, value=0.50, step=0.05, label="P(F = 1)"),
        gr.Slider(0.05, 0.95, value=0.60, step=0.05, label="P(Y = 1 | F = 1)"),
        gr.Slider(0.05, 0.95, value=0.40, step=0.05, label="P(Y = 1 | F = 0)"),
        gr.Checkbox(value=True, label="Show PPI curve"),
    ],
    outputs=gr.Plot(label="PPI++ Analysis", format="png"),
    title="Example: Specify in terms of pseudo-label prevalence and Y | F distribution",
    description=_F_BASED_DESCRIPTION,
    live=True,
    flagging_mode="never",
)

# Create interface for Y-based parameters
y_based_interface = gr.Interface(
    fn=plot_y_based_probabilities,
    inputs=[
        gr.Slider(0.05, 0.95, value=0.1, step=0.05, label="P(Y = 1)"),
        gr.Slider(0.05, 0.95, value=0.9, step=0.05, label="P(F = 1 | Y = 1)"),
        gr.Slider(0.05, 0.95, value=0.1, step=0.05, label="P(F = 1 | Y = 0)"),
        gr.Checkbox(value=True, label="Show PPI curve"),
    ],
    outputs=gr.Plot(label="PPI++ Analysis", format="png"),
    title="Example: Specify in terms of prevalence and F | Y distribution",
    description=_Y_BASED_DESCRIPTION,
    live=True,
    flagging_mode="never",
)

# Create tabbed interface
demo = gr.TabbedInterface(
    [f_based_interface, y_based_interface],
    ["F-based Parameters", "Y-based Parameters"],
    title="Sample Size Analysis (Binary Case)",
)

if __name__ == "__main__":
    demo.launch(share=True)