Spaces:
Running
Running
File size: 3,208 Bytes
6a34fd4 debb3aa 6a34fd4 debb3aa 6a34fd4 debb3aa 6a34fd4 debb3aa 6a34fd4 debb3aa 445302e debb3aa 5fdf2ba debb3aa 445302e debb3aa 6a34fd4 debb3aa 6a34fd4 445302e 6a34fd4 debb3aa 6a34fd4 debb3aa 6a34fd4 debb3aa 6a34fd4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import gradio as gr
from huggingface_hub import hf_hub_download
import pickle
import gradio as gr
import numpy as np
import subprocess
import shutil
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
# Map each dropdown choice to the fine-tuned checkpoint that
# src/test_saved_model.py evaluates. Keys match the UI's model choices.
_CHECKPOINTS = {
    "FS": "ratio_proportion_change3/output/FS/bert_fine_tuned.model.ep32",
    "IS": "ratio_proportion_change3/output/IS/bert_fine_tuned.model.ep14",
    "CORRECTNESS": "ratio_proportion_change3/output/correctness/bert_fine_tuned.model.ep48",
    "EFFECTIVENESS": "ratio_proportion_change3/output/effectiveness/bert_fine_tuned.model.ep28",
}

# Fixed file names exchanged with src/test_saved_model.py. The script is only
# given the checkpoint path, so it presumably reads the staged inputs and
# writes the result/ROC files under these names — confirm against the script.
# NOTE(review): every request shares these paths, so two concurrent
# submissions can overwrite each other's inputs and outputs.
_SAVED_TEST_DATASET = "train.txt"
_SAVED_TEST_LABEL = "train_label.txt"
_RESULT_PATH = "result.txt"
_ROC_DATA_PATH = "roc_data.pkl"
_PLOT_PATH = "plot.png"


def _upload_path(upload):
    """Return the filesystem path of a Gradio upload.

    Accepts either a tempfile-like object exposing ``.name`` or a plain path
    string, so the handler works with both forms of ``gr.File`` output.
    """
    return getattr(upload, "name", upload)


def _parse_results(path):
    """Parse ``key: value`` lines from *path* into a dict.

    The ``epoch`` entry is kept as a string; every other value is converted
    to ``float``. Blank lines are skipped.
    """
    result = {}
    with open(path, "r") as fh:
        for raw_line in fh:
            line = raw_line.strip()
            if not line:
                continue
            key, value = line.split(": ", 1)
            result[key] = value if key == "epoch" else float(value)
    return result


def _plot_roc(roc_data_path, model_name, plot_path):
    """Plot the ROC curve stored in *roc_data_path* and save it to *plot_path*.

    The pickle is expected to hold a 3-tuple whose first two items are the
    false- and true-positive rates (presumably sklearn ``roc_curve`` output).
    Returns *plot_path*.
    """
    # SECURITY: pickle.load can execute arbitrary code. This is acceptable only
    # because the file is expected to be produced by the local evaluation
    # script; never point this at user-uploaded data.
    with open(roc_data_path, "rb") as fh:
        fpr, tpr, _ = pickle.load(fh)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots()
    try:
        ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
        # Diagonal reference line for a random classifier.
        ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate',
               title=f'ROC Curve: {model_name}')
        ax.legend(loc="lower right")
        ax.grid()
        fig.savefig(plot_path)
    finally:
        # Always release the figure so a failed save does not leak memory
        # in the long-running server process.
        plt.close(fig)
    return plot_path


def process_file(file, label, model_name):
    """Evaluate a fine-tuned model on the uploaded data and plot its ROC curve.

    Args:
        file: Gradio upload holding the test dataset (.txt).
        label: Gradio upload holding the matching labels (.txt).
        model_name: One of the keys of ``_CHECKPOINTS`` (the dropdown value).

    Returns:
        A ``(text_output, plot_path)`` tuple: a summary of the metrics parsed
        from result.txt, and the path of the saved ROC plot image.

    Raises:
        gr.Error: If an upload is missing, the model is unknown, or the
            evaluation script exits with a non-zero status. Gradio shows the
            message to the user instead of a generic error.
    """
    import sys  # only needed to locate the running interpreter

    if file is None or label is None:
        raise gr.Error("Please upload both the dataset file and the label file.")
    checkpoint = _CHECKPOINTS.get(model_name)
    if checkpoint is None:
        raise gr.Error(
            f"Unknown model {model_name!r}; choose one of: {', '.join(_CHECKPOINTS)}."
        )

    # Stage the uploads under the fixed names used by the evaluation script.
    shutil.copyfile(_upload_path(file), _SAVED_TEST_DATASET)
    shutil.copyfile(_upload_path(label), _SAVED_TEST_LABEL)

    print(checkpoint)
    # sys.executable runs the script in this app's environment. check=True
    # prevents silently reading a stale result.txt from an earlier run when
    # the script fails.
    try:
        subprocess.run(
            [sys.executable, "src/test_saved_model.py",
             "--finetuned_bert_checkpoint", checkpoint],
            check=True,
        )
    except subprocess.CalledProcessError as err:
        raise gr.Error(f"Evaluation failed (exit code {err.returncode}).") from err

    result = _parse_results(_RESULT_PATH)
    plot_path = _plot_roc(_ROC_DATA_PATH, model_name, _PLOT_PATH)

    text_output = f"Model: {model_name}\nResult:\n{result}"
    return text_output, plot_path
# Dropdown choices; each must match a checkpoint known to process_file.
models = ["FS", "IS", "CORRECTNESS", "EFFECTIVENESS"]

# Build the Gradio interface: a data upload, a label upload, a model picker,
# and text + image outputs wired to process_file.
with gr.Blocks() as demo:
    gr.Markdown("# ASTRA")
    gr.Markdown(
        "Upload a .txt data file and its matching .txt label file, "
        "then select a model from the dropdown menu."
    )
    with gr.Row():
        file_input = gr.File(label="Upload the data .txt file", file_types=['.txt'])
        label_input = gr.File(label="Upload the label .txt file", file_types=['.txt'])
        # Default to the first model so a submit without a selection does not
        # send None to process_file.
        model_dropdown = gr.Dropdown(choices=models, value=models[0], label="Select a model")
    with gr.Row():
        output_text = gr.Textbox(label="Output Text")
        output_image = gr.Image(label="Output Plot")
    btn = gr.Button("Submit")
    btn.click(
        fn=process_file,
        inputs=[file_input, label_input, model_dropdown],
        outputs=[output_text, output_image],
    )

# Launch only when run as a script, so importing this module (e.g. in tests
# or Gradio's reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()
|