Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import os | |
import shutil | |
from omnibin import generate_binary_classification_report, ColorScheme | |
# Define results directory | |
RESULTS_DIR = "/tmp/results" | |
# Map string color schemes to enum values | |
COLOR_SCHEME_MAP = { | |
"DEFAULT": ColorScheme.DEFAULT, | |
"MONOCHROME": ColorScheme.MONOCHROME, | |
"VIBRANT": ColorScheme.VIBRANT | |
} | |
def process_csv(csv_file, n_bootstrap=1000, dpi=72, color_scheme="DEFAULT"): | |
# Convert string color scheme to enum | |
color_scheme_enum = COLOR_SCHEME_MAP[color_scheme] | |
# Read the CSV file | |
df = pd.read_csv(csv_file.name) | |
# Check if required columns exist | |
required_columns = ['y_true', 'y_pred'] | |
if not all(col in df.columns for col in required_columns): | |
raise ValueError("CSV file must contain 'y_true' and 'y_pred' columns") | |
# Clean up results directory if it exists | |
if os.path.exists(RESULTS_DIR): | |
shutil.rmtree(RESULTS_DIR) | |
# Create fresh results directory | |
os.makedirs(RESULTS_DIR, exist_ok=True) | |
# Generate the report | |
report_path = generate_binary_classification_report( | |
y_true=df['y_true'].values, | |
y_scores=df['y_pred'].values, | |
output_path=os.path.join(RESULTS_DIR, "classification_report.pdf"), | |
n_bootstrap=n_bootstrap, | |
random_seed=42, | |
dpi=dpi, | |
color_scheme=color_scheme_enum | |
) | |
# Get paths to individual plots | |
plots_dir = os.path.join(RESULTS_DIR, "plots") | |
plot_paths = { | |
"ROC and PR Curves": os.path.join(plots_dir, "roc_pr.png"), | |
"Metrics vs Threshold": os.path.join(plots_dir, "metrics_threshold.png"), | |
"Confusion Matrix": os.path.join(plots_dir, "confusion_matrix.png"), | |
"Calibration Plot": os.path.join(plots_dir, "calibration.png"), | |
"Prediction Distribution": os.path.join(plots_dir, "prediction_distribution.png"), | |
"Metrics Summary": os.path.join(plots_dir, "metrics_summary.png") | |
} | |
# Return both the PDF and the plot images | |
return report_path, *plot_paths.values() | |
# Create the Gradio interface | |
iface = gr.Interface( | |
fn=process_csv, | |
inputs=[ | |
gr.File(label="Upload CSV file with 'y_true' and 'y_pred' columns"), | |
gr.Number(label="Number of Bootstrap Iterations", value=1000, minimum=100, maximum=10000), | |
gr.Number(label="DPI", value=72, minimum=50, maximum=300), | |
gr.Dropdown(label="Color Scheme", choices=["DEFAULT", "MONOCHROME", "VIBRANT"], value="DEFAULT") | |
], | |
outputs=[ | |
gr.File(label="Classification Report PDF"), | |
gr.Image(label="ROC and PR Curves"), | |
gr.Image(label="Metrics vs Threshold"), | |
gr.Image(label="Confusion Matrix"), | |
gr.Image(label="Calibration Plot"), | |
gr.Image(label="Prediction Distribution"), | |
gr.Image(label="Metrics Summary") | |
], | |
title="Binary Classification Report Generator", | |
description="Upload a CSV file containing 'y_true' and 'y_pred' columns to generate a binary classification report.\n\n" | |
"'y_true': reference standard (0s or 1s).\n\n" | |
"'y_pred': model prediction (continuous value between 0 and 1).\n\n" | |
"This application takes approximately 35 seconds to generate the report.\n", | |
examples=[["scores.csv", 1000, 72, "DEFAULT"]], | |
cache_examples=False | |
) | |
if __name__ == "__main__": | |
iface.launch() | |