import tempfile
from io import StringIO

import gradio as gr
import pandas as pd

from oneclass import OneClass


def predict_and_download(positive_csv_file, unlabelled_csv_file, n, Hyperparameter_nu):
    """Select the top ``n`` papers and return them as a table plus a CSV file.

    Args:
        positive_csv_file: Uploaded CSV of positive examples; forwarded
            unchanged to ``OneClass.select_top_n_papers``.
        unlabelled_csv_file: Uploaded CSV of unlabelled candidates; forwarded
            unchanged to ``OneClass.select_top_n_papers``.
        n: Number of papers to select. Converted to ``int`` before use because
            the Gradio ``Number`` input delivers a float.
        Hyperparameter_nu: Forwarded unchanged to
            ``OneClass.select_top_n_papers``.

    Returns:
        A ``(selected_paper_info, csv_path)`` tuple. ``selected_paper_info`` is
        whatever ``select_top_n_papers`` returned (it must provide ``to_csv``;
        presumably a pandas DataFrame -- confirm against ``oneclass``), and
        ``csv_path`` is the path of a temporary ``.csv`` file holding the same
        rows without the index column.
    """
    oc = OneClass()

    # NOTE(review): assumes select_top_n_papers expects an integer count --
    # confirm against the OneClass implementation.
    selected_paper_info = oc.select_top_n_papers(
        int(n), positive_csv_file, unlabelled_csv_file, Hyperparameter_nu
    )

    # A file output component serves a file from disk, so the CSV has to be
    # written to a real path rather than kept in an in-memory buffer.
    # delete=False keeps the file alive after this function returns, because
    # Gradio reads it afterwards; the file is left in the temp directory.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".csv",
        prefix="selected_papers_",
        delete=False,
        newline="",
        encoding="utf-8",
    ) as csv_file:
        selected_paper_info.to_csv(csv_file, index=False)

    return selected_paper_info, csv_file.name


# NOTE: gr.inputs / gr.outputs are Gradio's legacy component namespaces; they
# are kept here so the script stays on the API generation it was written for.
iface = gr.Interface(
    fn=predict_and_download,
    inputs=[
        # The legacy File input only accepts type "file" or "bytes".
        gr.inputs.File(type="file", label="Positive CSV File"),
        gr.inputs.File(type="file", label="Unlabelled CSV File"),
        gr.inputs.Number(label="Number of Papers to Select", default=10),
        gr.inputs.Number(label="Hyperparameter nu", default=0.5),
    ],
    outputs=[
        gr.outputs.Dataframe(label="Selected Papers"),
        # The File output turns the returned CSV path into a download link.
        gr.outputs.File(label="Download CSV"),
    ],
    title="Paper Prediction",
    description="Enter the number of papers to select and upload CSV files for labelled and unlabelled data.",
    article="This interface uses the OneClass algorithm to select the top N papers based on the input CSV files. The Hyperparameter nu controls the sensitivity of the algorithm.",
    theme="default",
    allow_flagging='never',
)

iface.launch()