File size: 1,148 Bytes
61a185c b1c8440 61a185c 25d5cc9 61a185c b1c8440 61a185c b1c8440 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import io
import tempfile

import gradio as gr
import pandas as pd

import oneclass
def predict_and_download(positive_csv_file, unlabelled_csv_file, n, text):
    """Select the top-``n`` papers and export them as a downloadable CSV file.

    Args:
        positive_csv_file: Upload from the first Gradio file input; passed
            unchanged to ``oneclass.select_top_n_papers``.
        unlabelled_csv_file: Upload from the second Gradio file input; passed
            unchanged to ``oneclass.select_top_n_papers``.
        n: Number of papers to select; passed unchanged as the first argument
            of ``oneclass.select_top_n_papers``.
        text: Value of the text box. Currently unused by this function.

    Returns:
        A ``(selected_paper_info, csv_path)`` tuple. ``selected_paper_info`` is
        whatever ``oneclass.select_top_n_papers`` returns (presumably a pandas
        DataFrame, since ``.to_csv`` is called on it). ``csv_path`` is the path
        of a temporary ``.csv`` file holding that table without the index.
        ``gr.DownloadButton`` serves files by path, not by raw CSV text.
    """
    selected_paper_info = oneclass.select_top_n_papers(
        n, positive_csv_file, unlabelled_csv_file
    )

    # delete=False keeps the file on disk after the handle is closed so
    # Gradio can read and serve it. newline="" is the mode pandas recommends
    # for text handles passed to to_csv, to avoid doubled line endings.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".csv",
        prefix="selected_papers_",
        delete=False,
        newline="",
        encoding="utf-8",
    ) as csv_file:
        selected_paper_info.to_csv(csv_file, index=False)

    return selected_paper_info, csv_file.name
# Build the web UI. The inputs map positionally onto predict_and_download's
# parameters: (positive_csv_file, unlabelled_csv_file, n, text).
iface = gr.Interface(
    fn=predict_and_download,
    inputs=[
        gr.File(label="Positive (labelled) papers CSV", file_types=[".csv"]),
        gr.File(label="Unlabelled papers CSV", file_types=[".csv"]),
        # precision=0 makes Gradio deliver an int instead of a float. A
        # "top n" count presumably needs an int; confirm against oneclass.
        gr.Number(label="Number of papers to select", precision=0),
        # NOTE(review): predict_and_download currently ignores this value.
        gr.Textbox(label="Text"),
    ],
    outputs=[
        gr.DataFrame(label="Selected Papers"),
        gr.DownloadButton(label="Download CSV"),
    ],
    title="Paper Prediction",
    description="Enter text and upload CSV files for labelled and unlabelled data.",
    allow_flagging='never',  # Disable flagging feature
)
iface.launch()