ssk3232 commited on
Commit
4c0998a
1 Parent(s): da1f30c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -6
app.py CHANGED
@@ -1,10 +1,16 @@
1
- import oneclass
2
  import gradio as gr
3
  import pandas as pd
4
  from io import StringIO
5
 
6
- def predict_and_download(positive_csv_file, unlabelled_csv_file, n,Hyperparameter_nu):
7
- selected_paper_info = oneclass.select_top_n_papers(n, positive_csv_file, unlabelled_csv_file,Hyperparameter_nu)
 
 
 
 
 
 
 
8
 
9
  # Create a StringIO object to store CSV data
10
  csv_buffer = StringIO()
@@ -22,10 +28,20 @@ def predict_and_download(positive_csv_file, unlabelled_csv_file, n,Hyperparamete
22
  # Create the interface
23
  iface = gr.Interface(
24
  fn=predict_and_download,
25
- inputs=["file", "file", "number","number"],
26
- outputs=[gr.DataFrame(label="Selected Papers"), gr.DownloadButton(label="Download CSV")],
 
 
 
 
 
 
 
 
27
  title="Paper Prediction",
28
- description="Enter text and upload CSV files for labelled and unlabelled data.",
 
 
29
  allow_flagging='never' # Disable flagging feature
30
  )
31
 
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  from io import StringIO
4
 
5
+ # Import the OneClass class
6
+ from oneclass import OneClass
7
+
8
+ def predict_and_download(positive_csv_file, unlabelled_csv_file, n, Hyperparameter_nu):
9
+ # Create an instance of the OneClass class
10
+ oc = OneClass()
11
+
12
+ # Call the select_top_n_papers method
13
+ selected_paper_info = oc.select_top_n_papers(n, positive_csv_file, unlabelled_csv_file, Hyperparameter_nu)
14
 
15
  # Create a StringIO object to store CSV data
16
  csv_buffer = StringIO()
 
28
  # Create the interface
29
  iface = gr.Interface(
30
  fn=predict_and_download,
31
+ inputs=[
32
+ gr.inputs.File(type="csv", label="Positive CSV File"),
33
+ gr.inputs.File(type="csv", label="Unlabelled CSV File"),
34
+ gr.inputs.Number(label="Number of Papers to Select", default=10),
35
+ gr.inputs.Number(label="Hyperparameter nu", default=0.5)
36
+ ],
37
+ outputs=[
38
+ gr.outputs.Dataframe(label="Selected Papers", formats=["csv", "json"]),
39
+ gr.outputs.DownloadButton(label="Download CSV")
40
+ ],
41
  title="Paper Prediction",
42
+ description="Enter the number of papers to select and upload CSV files for labelled and unlabelled data.",
43
+ article="This interface uses the OneClass algorithm to select the top N papers based on the input CSV files. The Hyperparameter nu controls the sensitivity of the algorithm.",
44
+ theme="default",
45
  allow_flagging='never' # Disable flagging feature
46
  )
47