# ssk / app.py
# Last commit 4c0998a (verified) by ssk3232: "Update app.py"
# Size: 1.75 kB · security scan: no virus
from io import StringIO
import tempfile

import gradio as gr
import pandas as pd

# Import the OneClass class
from oneclass import OneClass
def predict_and_download(positive_csv_file, unlabelled_csv_file, n, Hyperparameter_nu):
    """Select the top-``n`` papers with OneClass and export them as a CSV file.

    Args:
        positive_csv_file: Uploaded CSV of positive papers, forwarded unchanged
            to ``OneClass.select_top_n_papers``. The exact object depends on
            the Gradio version (a temp-file path or a file wrapper).
        unlabelled_csv_file: Uploaded CSV of unlabelled papers, forwarded the
            same way.
        n: Number of papers to select. A Gradio ``Number`` field may deliver a
            float (e.g. ``10.0``), so it is converted to ``int`` before use.
        Hyperparameter_nu: ``nu`` value forwarded to ``select_top_n_papers``.

    Returns:
        A ``(selected_paper_info, csv_path)`` tuple: the DataFrame produced by
        ``select_top_n_papers`` (shown on screen) and the path of a temporary
        ``.csv`` file holding the same rows without the index (offered for
        download).

    Raises:
        ValueError: If ``n`` is missing or cannot be converted to an integer.
    """
    if n is None:
        raise ValueError("Number of papers to select is required.")
    # NOTE(review): assumes select_top_n_papers expects an integer count
    # (a float like 10.0 is a common result of gr.Number) — confirm in oneclass.py.
    n_papers = int(n)

    oc = OneClass()
    selected_paper_info = oc.select_top_n_papers(
        n_papers, positive_csv_file, unlabelled_csv_file, Hyperparameter_nu
    )

    # Gradio file/download outputs need a path on disk, not an in-memory
    # buffer, so write the CSV to a temp file. delete=False keeps the file
    # after the handle closes so Gradio can read it once we return.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".csv",
        prefix="selected_papers_",
        delete=False,
        newline="",  # pandas requires newline="" for text file handles
        encoding="utf-8",
    ) as csv_file:
        selected_paper_info.to_csv(csv_file, index=False)

    return selected_paper_info, csv_file.name
# Create the interface.
# Uses the top-level component classes (gr.File, gr.Number, gr.Dataframe):
# the legacy gr.inputs / gr.outputs namespaces were removed in Gradio 4, and
# gr.outputs never provided a DownloadButton.
iface = gr.Interface(
    fn=predict_and_download,
    inputs=[
        # NOTE(review): predict_and_download forwards whatever gr.File yields
        # (a temp-file path in Gradio 4+, a file wrapper in 3.x) to OneClass —
        # confirm oneclass.py accepts it.
        gr.File(label="Positive CSV File", file_types=[".csv"]),
        gr.File(label="Unlabelled CSV File", file_types=[".csv"]),
        # precision=0 makes Gradio deliver an int rather than a float.
        gr.Number(label="Number of Papers to Select", value=10, precision=0),
        gr.Number(label="Hyperparameter nu", value=0.5),
    ],
    outputs=[
        gr.Dataframe(label="Selected Papers"),
        # A file output renders a download link for the CSV path returned by fn.
        gr.File(label="Download CSV"),
    ],
    title="Paper Prediction",
    description="Enter the number of papers to select and upload CSV files for labelled and unlabelled data.",
    article="This interface uses the OneClass algorithm to select the top N papers based on the input CSV files. The Hyperparameter nu controls the sensitivity of the algorithm.",
    theme="default",
    allow_flagging="never",  # Disable flagging feature
)

# Start the server only when run as a script, so importing this module
# (tests, reload tooling) does not launch it as a side effect.
if __name__ == "__main__":
    iface.launch()