# FelixPhilip's picture
# Oracle
# 3388ab8
import os
import gradio as gr
from Oracle.deepfundingoracle import prepare_dataset, train_predict_weight, create_submission_csv
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import time
import io
from PIL import Image
def _figure_to_image(fig):
    """Render *fig* to an in-memory PNG, close the figure, and return a PIL image."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps a global reference to every open figure; closing here
        # (even when saving fails) stops figures accumulating across requests.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def analyze_file(file, progress=gr.Progress(track_tqdm=True)):
    """Run the full analysis pipeline on an uploaded CSV and build the UI outputs.

    Args:
        file: The upload handed over by the ``gr.File`` input. Either an object
            exposing the on-disk path as ``.name`` or a plain path string is
            accepted.
        progress: Gradio progress tracker used to report each pipeline stage.

    Returns:
        A 5-tuple ``(preview, csv_path, dist_img, corr_img, status)``:
        the first rows of the result as CSV text, the value returned by
        ``create_submission_csv`` (presumably the path of the written CSV --
        confirm against that helper), a PIL image of the per-column histograms,
        a PIL image of the correlation heatmap, and a timing message.

    Raises:
        gr.Error: If no file was uploaded.
    """
    if file is None:
        raise gr.Error("Please upload a CSV file before running the analysis.")

    start_time = time.time()
    # Depending on the Gradio version/configuration the File component yields
    # either a tempfile-like object or a path string; support both.
    input_path = getattr(file, "name", file)

    progress(0, desc="Preparing dataset...")
    df = prepare_dataset(input_path)
    progress(0.3, desc="Predicting weights...")
    df = train_predict_weight(df)
    progress(0.6, desc="Saving results to CSV...")
    csv_path = create_submission_csv(df, "submission.csv")
    progress(0.8, desc="Generating graphs...")

    numeric_cols = df.select_dtypes(include=[np.number]).columns

    # Feature distribution plot.
    # DataFrame.hist creates its own figure when no ``ax`` is given, so take the
    # figure from the returned axes instead of pre-creating one with plt.figure
    # (the pre-created figure stayed empty and the real one was never closed).
    hist_axes = df[numeric_cols].hist(
        bins=20, figsize=(15, 10), color="skyblue", edgecolor="black"
    )
    dist_fig = np.asarray(hist_axes).ravel()[0].figure
    dist_fig.suptitle("Feature Distributions", fontsize=16)
    dist_img = _figure_to_image(dist_fig)

    # Correlation matrix plot, drawn on an explicit axes so it does not depend
    # on pyplot's notion of the "current" figure.
    corr_fig, corr_ax = plt.subplots(figsize=(12, 8))
    correlation_matrix = df[numeric_cols].corr()
    sns.heatmap(
        correlation_matrix,
        annot=True,
        cmap="coolwarm",
        fmt=".2f",
        linewidths=0.5,
        ax=corr_ax,
    )
    corr_ax.set_title("Feature Correlation Matrix", fontsize=16)
    corr_img = _figure_to_image(corr_fig)

    progress(1, desc="Done!")
    elapsed = time.time() - start_time
    preview = df.head().to_csv(index=False)
    return preview, csv_path, dist_img, corr_img, f"Analysis completed in {elapsed:.2f} seconds."
# Single upload widget feeding analyze_file.
_upload_input = gr.File(label="Upload CSV")

# Output widgets, listed in the same order as the tuple analyze_file returns:
# preview text, downloadable CSV, two plots, and the timing message.
_result_outputs = [
    gr.Textbox(label="Preview of Results"),
    gr.File(label="Download CSV"),
    gr.Image(label="Feature Distributions"),
    gr.Image(label="Feature Correlation Matrix"),
    gr.Textbox(label="Status/Timing Info"),
]

# The web UI object; launched from the __main__ guard below.
iface = gr.Interface(
    title="DeepFunding Oracle",
    description="Upload a CSV of repo-parent relationships; see analysis progress, get graphs, and download results as CSV.",
    fn=analyze_file,
    inputs=_upload_input,
    outputs=_result_outputs,
    allow_flagging="never",
)
if __name__ == "__main__":
    # Use the PORT environment variable when set; otherwise fall back to 7860.
    listen_port = int(os.getenv("PORT", 7860))
    # Bind to 0.0.0.0 so the server accepts connections on all interfaces.
    iface.launch(server_port=listen_port, server_name="0.0.0.0")