File size: 3,208 Bytes
6a34fd4
 
 
 
 
 
 
debb3aa
 
6a34fd4
debb3aa
6a34fd4
 
debb3aa
 
6a34fd4
 
 
debb3aa
6a34fd4
debb3aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445302e
 
 
 
 
 
 
 
 
debb3aa
 
 
5fdf2ba
 
 
debb3aa
 
 
 
 
 
 
 
 
 
 
 
445302e
debb3aa
 
 
 
6a34fd4
 
debb3aa
6a34fd4
 
 
445302e
6a34fd4
 
 
 
debb3aa
6a34fd4
 
debb3aa
 
 
6a34fd4
 
debb3aa
6a34fd4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
from huggingface_hub import hf_hub_download
import pickle
import gradio as gr
import numpy as np
import subprocess
import shutil
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
# Fine-tuned checkpoint evaluated for each model choice offered in the UI.
_CHECKPOINTS = {
    "FS": "ratio_proportion_change3/output/FS/bert_fine_tuned.model.ep32",
    "IS": "ratio_proportion_change3/output/IS/bert_fine_tuned.model.ep14",
    "CORRECTNESS": "ratio_proportion_change3/output/correctness/bert_fine_tuned.model.ep48",
    "EFFECTIVENESS": "ratio_proportion_change3/output/effectiveness/bert_fine_tuned.model.ep28",
}


def _upload_path(upload):
    """Return the filesystem path of an upload (its ``.name``, or itself if a plain path)."""
    return getattr(upload, "name", upload)


def _read_metrics(path):
    """Parse ``key: value`` lines from *path*; ``epoch`` stays a string, the rest become floats."""
    metrics = {}
    with open(path, "r") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                # Tolerate blank/trailing lines instead of failing the unpack below.
                continue
            key, value = line.split(": ", 1)
            metrics[key] = value if key == "epoch" else float(value)
    return metrics


def process_file(file, label, model_name):
    """Evaluate the selected fine-tuned model on the uploaded files and plot its ROC curve.

    The uploads are copied to ``train.txt`` and ``train_label.txt`` in the
    working directory, ``src/test_saved_model.py`` is run with the checkpoint
    that corresponds to ``model_name``, and the metrics in ``result.txt`` and
    the ROC data in ``roc_data.pkl`` are then read back (both files are
    presumably written by that script -- confirm against
    ``src/test_saved_model.py``).

    Args:
        file: Uploaded dataset file; an object exposing its path as ``.name``
            or a plain path string.
        label: Uploaded label file, same form as ``file``.
        model_name: One of ``"FS"``, ``"IS"``, ``"CORRECTNESS"``,
            ``"EFFECTIVENESS"``.

    Returns:
        A ``(text_output, plot_path)`` tuple: a text summary containing the
        parsed metrics, and the path of the saved ROC plot (``plot.png``).

    Raises:
        ValueError: If ``model_name`` is not a known model or an upload is
            missing.
        subprocess.CalledProcessError: If the evaluation script exits with a
            non-zero status.
    """
    # Validate before touching the filesystem so a bad request has no side effects.
    checkpoint = _CHECKPOINTS.get(model_name)
    if checkpoint is None:
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {sorted(_CHECKPOINTS)}"
        )
    if file is None or label is None:
        raise ValueError("Both a dataset file and a label file must be uploaded.")

    saved_test_dataset = "train.txt"
    saved_test_label = "train_label.txt"

    # Save the uploaded files to the locations used for the evaluation run.
    shutil.copyfile(_upload_path(file), saved_test_dataset)
    shutil.copyfile(_upload_path(label), saved_test_label)

    print(checkpoint)
    # check=True: a failed run must not fall through to reading stale
    # result.txt / roc_data.pkl left over from a previous request.
    subprocess.run(
        ["python", "src/test_saved_model.py",
         "--finetuned_bert_checkpoint", checkpoint],
        check=True,
    )

    result = _read_metrics("result.txt")

    # NOTE(security): pickle.load executes arbitrary code from the file. This is
    # only acceptable while roc_data.pkl comes from our own evaluation script;
    # never point this at user-supplied data.
    with open("roc_data.pkl", "rb") as f:
        fpr, tpr, _ = pickle.load(f)

    # Create the ROC plot.
    roc_auc = auc(fpr, tpr)
    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    # Diagonal reference line from (0, 0) to (1, 1).
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title=f'ROC Curve: {model_name}')
    ax.legend(loc="lower right")
    ax.grid()

    # Save plot to a file and release the figure.
    plot_path = "plot.png"
    fig.savefig(plot_path)
    plt.close(fig)

    # Prepare text output
    text_output = f"Model: {model_name}\nResult:\n{result}"

    return text_output, plot_path

# Model names offered in the dropdown; each must be a name process_file recognises.
models = ["FS", "IS", "CORRECTNESS", "EFFECTIVENESS"]

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# ASTRA")
    gr.Markdown("Upload the test dataset and its label file (both .txt), then select a model from the dropdown menu.")

    with gr.Row():
        # The two uploads play different roles (process_file copies the first to
        # train.txt and the second to train_label.txt), so they need distinct labels.
        file_input = gr.File(label="Upload the test dataset (.txt)", file_types=['.txt'])
        label_input = gr.File(label="Upload the test labels (.txt)", file_types=['.txt'])
        model_dropdown = gr.Dropdown(choices=models, label="Select a model")

    with gr.Row():
        output_text = gr.Textbox(label="Output Text")
        output_image = gr.Image(label="Output Plot")

    btn = gr.Button("Submit")
    # Input order must match process_file's (file, label, model_name) parameters.
    btn.click(
        fn=process_file,
        inputs=[file_input, label_input, model_dropdown],
        outputs=[output_text, output_image],
    )

# Launch the app
demo.launch()