astra / app.py
suryadev1's picture
progress bar
ee40bd7
raw
history blame
9.19 kB
import gradio as gr
from huggingface_hub import hf_hub_download
import pickle
from gradio import Progress
import numpy as np
import subprocess
import shutil
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
def _select_model_name(inc_val):
    """Map the "Schools number" slider value to a fine-tune task name.

    All three tiers currently resolve to the same task name; the tiered
    structure is kept so a different task can be assigned per range later.
    """
    if inc_val < 5:
        return "highGRschool10"
    # Chained comparison replaces `inc_val>=5 & inc_val<10`. `&` binds tighter
    # than the comparison operators, so the old test was really
    # `inc_val >= (5 & inc_val) < 10` and never checked the upper bound.
    if 5 <= inc_val < 10:
        return "highGRschool10"
    return "highGRschool10"


def _read_result_file(path):
    """Parse ``key: value`` lines from *path* into a dict.

    The ``epoch`` entry is kept as a string; every other value is converted
    to ``float``. Blank lines are skipped.
    """
    metrics = {}
    with open(path, 'r') as handle:
        for raw_line in handle:
            entry = raw_line.strip()
            if not entry:
                continue
            key, value = entry.split(': ', 1)
            metrics[key] = value if key == 'epoch' else float(value)
    return metrics


def process_file(file, label, info, inc_val, progress=Progress(track_tqdm=True)):
    """Run the fine-tuned model on the uploaded files and report the results.

    Args:
        file: Uploaded test dataset; its ``.name`` is the path on disk.
        label: Uploaded test labels; its ``.name`` is the path on disk.
        info: Uploaded test info; its ``.name`` is the path on disk.
        inc_val: Value of the "Schools number" slider.
        progress: Gradio progress tracker used to report coarse progress.

    Returns:
        A ``(text_output, plot_path)`` tuple: a text summary of the parsed
        metrics and the path of the saved ROC plot image.

    Raises:
        gr.Error: If any of the three uploads is missing.
        subprocess.CalledProcessError: If the evaluation script exits with a
            non-zero status.
        ValueError: If a non-blank line of ``result.txt`` is not in
            ``key: value`` form or a metric value is not numeric.
    """
    if file is None or label is None or info is None:
        raise gr.Error("Please upload the test file, the test labels and the test info file.")

    progress(0, desc="Starting the processing")

    # Copy the uploads to fixed paths in the working directory.
    # Presumably new_test_saved_finetuned_model.py reads them from there by
    # default -- confirm against that script.
    # NOTE: the paths are constants, so overlapping requests overwrite each
    # other's inputs and outputs.
    saved_test_dataset = "train.txt"
    saved_test_label = "train_label.txt"
    saved_train_info = "train_info.txt"
    shutil.copyfile(file.name, saved_test_dataset)
    shutil.copyfile(label.name, saved_test_label)
    shutil.copyfile(info.name, saved_train_info)

    model_name = _select_model_name(inc_val)

    # NOTE: the checkpoint path is hard-coded to the highGRschool10 model and
    # does not follow model_name.
    # check=True makes a failed run raise instead of silently falling through
    # to whatever result.txt / roc_data.pkl a previous run left behind.
    subprocess.run([
        "python", "new_test_saved_finetuned_model.py",
        "-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
        "-finetune_task", model_name,
        "-finetuned_bert_classifier_checkpoint",
        "ratio_proportion_change3_2223/sch_largest_100-coded/output/highGRschool10/bert_fine_tuned.model.ep42",
        "-e", str(1),
        "-b", str(1000)
    ], check=True)
    progress(0.6, desc="Model execution completed")

    result = _read_result_file("result.txt")

    # roc_data.pkl is unpickled as a 3-item sequence (fpr, tpr, <ignored>).
    # SECURITY: pickle executes arbitrary code on load; this is only acceptable
    # while the file comes from the local evaluation script, never from a user.
    with open("roc_data.pkl", "rb") as f:
        fpr, tpr, _ = pickle.load(f)
    roc_auc = auc(fpr, tpr)

    # Create the ROC plot with the chance diagonal for reference.
    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title=f'ROC Curve: {model_name}')
    ax.legend(loc="lower right")
    ax.grid()

    # Save the plot to a file and release the figure.
    plot_path = "plot.png"
    fig.savefig(plot_path)
    plt.close(fig)
    progress(1.0)

    # Prepare text output
    text_output = f"Model: {model_name}\nResult:\n{result}"
    return text_output, plot_path
# Fine-tune task names intended for a task-selection dropdown. The interface
# does not currently expose that dropdown, so the list is kept for reference.
models = [
    "highGRschool10",
    "lowGRschoolAll",
    "fullTest",
]
# Custom stylesheet for the interface.
# Changes from the previous inline string: the final rule is now closed with
# `}`, and the invalid `background: ##7897b4!important;` declaration on odd
# table rows was removed. Browsers already dropped that declaration, so the
# effective row colour (#aca7b2) is unchanged.
CUSTOM_CSS = """
body {
background-color: #1e1e1e!important;
font-family: 'Arial', sans-serif;
color: #f5f5f5!important;;
}
.gradio-container {
max-width: 850px!important;
margin: 0 auto!important;;
padding: 20px!important;;
background-color: #292929!important;
border-radius: 10px;
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2);
}
.gradio-container-4-44-0 .prose h1 {
font-size: var(--text-xxl);
color: #ffffff!important;
}
#title {
color: white!important;
font-size: 2.3em;
font-weight: bold;
text-align: center!important;
margin-bottom: 20px;
}
.description {
text-align: center;
font-size: 1.1em;
color: #bfbfbf;
margin-bottom: 30px;
}
.file-box {
max-width: 180px;
padding: 5px;
background-color: #444!important;
border: 1px solid #666!important;
border-radius: 6px;
height: 80px!important;;
margin: 0 auto!important;;
text-align: center;
color: transparent;
}
.file-box span {
color: #f5f5f5!important;
font-size: 1em;
line-height: 45px; /* Vertically center text */
}
.dropdown-menu {
max-width: 220px;
margin: 0 auto!important;
background-color: #444!important;
color:#444!important;
border-radius: 6px;
padding: 8px;
font-size: 1.1em;
border: 1px solid #666;
}
.button {
background-color: #4CAF50!important;
color: white!important;
font-size: 1.1em;
padding: 10px 25px;
border-radius: 6px;
cursor: pointer;
transition: background-color 0.2s ease-in-out;
}
.button:hover {
background-color: #45a049!important;
}
.output-text {
background-color: #333!important;
padding: 12px;
border-radius: 8px;
border: 1px solid #666;
font-size: 1.1em;
}
.footer {
text-align: center;
margin-top: 50px;
font-size: 0.9em;
color: #b0b0b0;
}
.svelte-12ioyct .wrap {
display: none !important;
}
.file-label-text {
display: none !important;
}
div.svelte-sfqy0y {
display: flex;
flex-direction: inherit;
flex-wrap: wrap;
gap: var(--form-gap-width);
box-shadow: var(--block-shadow);
border: var(--block-border-width) solid var(--border-color-primary);
border-radius: var(--block-radius);
background: #1f2937!important;
overflow-y: hidden;
}
.block.svelte-12cmxck {
position: relative;
margin: 0;
box-shadow: var(--block-shadow);
border-width: var(--block-border-width);
border-color: var(--block-border-color);
border-radius: var(--block-radius);
background: #1f2937!important;
width: 100%;
line-height: var(--line-sm);
}
.svelte-12ioyct .wrap {
display: none !important;
}
.file-label-text {
display: none !important;
}
input[aria-label="file upload"] {
display: none !important;
}
gradio-app .gradio-container.gradio-container-4-44-0 .contain .file-box span {
font-size: 1em;
line-height: 45px;
color: #1f2937 !important;
}
.wrap.svelte-12ioyct {
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
min-height: var(--size-60);
color: #1f2937 !important;
line-height: var(--line-md);
height: 100%;
padding-top: var(--size-3);
text-align: center;
margin: auto var(--spacing-lg);
}
span.svelte-1gfkn6j:not(.has-info) {
margin-bottom: var(--spacing-lg);
color: white!important;
}
label.float.svelte-1b6s6s {
position: relative!important;
top: var(--block-label-margin);
left: var(--block-label-margin);
}
label.svelte-1b6s6s {
display: inline-flex;
align-items: center;
z-index: var(--layer-2);
box-shadow: var(--block-label-shadow);
border: var(--block-label-border-width) solid var(--border-color-primary);
border-top: none;
border-left: none;
border-radius: var(--block-label-radius);
background: rgb(120 151 180)!important;
padding: var(--block-label-padding);
pointer-events: none;
color: #1f2937!important;
font-weight: var(--block-label-text-weight);
font-size: var(--block-label-text-size);
line-height: var(--line-sm);
}
.file.svelte-18wv37q.svelte-18wv37q {
display: block!important;
width: var(--size-full);
}
tbody.svelte-18wv37q>tr.svelte-18wv37q:nth-child(odd) {
color: white;
background: #aca7b2;
}
.gradio-container-4-31-4 .prose h1, .gradio-container-4-31-4 .prose h2, .gradio-container-4-31-4 .prose h3, .gradio-container-4-31-4 .prose h4, .gradio-container-4-31-4 .prose h5 {
color: white;
}
"""

# Create the Gradio interface
with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown("<h1 id='title'>ASTRA</h1>", elem_id="title")
    # The description matches the controls that are actually shown: three
    # uploads and a slider (the task dropdown is disabled below).
    gr.Markdown("<p class='description'>Upload the test file, its labels and its info file (.txt), then choose the number of schools.</p>")

    # Three uploads, passed to process_file in this order.
    with gr.Row():
        file_input = gr.File(label="Upload a test file", file_types=['.txt'], elem_classes="file-box")
        label_input = gr.File(label="Upload test labels", file_types=['.txt'], elem_classes="file-box")
        info_input = gr.File(label="Upload test info", file_types=['.txt'], elem_classes="file-box")

    # NOTE(review): the original nesting of the slider and the button was not
    # recoverable; they are placed at the top level of the layout here --
    # confirm against the deployed app.
    # model_dropdown = gr.Dropdown(choices=models, label="Select Finetune Task", elem_classes="dropdown-menu")
    increment_slider = gr.Slider(minimum=1, maximum=50, step=5, label="Schools number", value=1)

    # Outputs: the text summary and the ROC plot image returned by process_file.
    with gr.Row():
        output_text = gr.Textbox(label="Output Text")
        output_image = gr.Image(label="Output Plot")

    btn = gr.Button("Submit")
    btn.click(
        fn=process_file,
        inputs=[file_input, label_input, info_input, increment_slider],
        outputs=[output_text, output_image],
    )

# Launch the app
demo.launch()