andufkova committed on
Commit
c87255c
1 Parent(s): 0d50fc2

first classification try

Browse files
app.py CHANGED
@@ -1,7 +1,64 @@
1
  import gradio as gr
 
 
 
2
 
3
def greet(name):
    """Return a greeting for *name* plus a placeholder clustering message.

    The result is a 2-tuple: the greeting string and the fixed text
    "clustering tbd".
    """
    # str.join keeps the original contract: a non-string name raises TypeError.
    greeting = "".join(("Hello ", name, "!!"))
    return (greeting, "clustering tbd")
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs=["text", "text"])
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import numpy as np
3
+ import pickle
4
+ from sentence_transformers import SentenceTransformer
5
 
 
 
6
 
7
+ #css_code='body {background-image:url("https://picsum.photos/seed/picsum/200/300");} div.gradio-container {background: white;}'
8
+
9
+
10
# Category labels. get_categories() maps position i of the classifier's
# prediction vector to categories[i], so the order of this list matters.
categories = [
    "Censorship",
    "Development",
    "Digital Activism",
    "Disaster",
    "Economics & Business",
    "Education",
    "Environment",
    "Governance",
    "Health",
    "History",
    "Humanitarian Response",
    "International Relations",
    "Law",
    "Media & Journalism",
    "Migration & Immigration",
    "Politics",
    "Protest",
    "Religion",
    "Sport",
    "Travel",
    "War & Conflict",
    "Technology_Science",
    "Women&Gender_LGBTQ+_Youth",
    "Freedom_of_Speech_Human_Rights",
    "Literature_Arts&Culture",
]

# Sentence encoder used to turn paragraphs into embedding vectors.
model = SentenceTransformer('sentence-transformers/LaBSE')

# Pre-trained classifier shipped with the repository.
# NOTE(review): pickle.load executes arbitrary code from the file; only ever
# load a model file that comes from a trusted source.
with open('models/MLP_classifier_average_en.pkl', 'rb') as f:
    classifier = pickle.load(f)
14
+
15
def get_embedding(text):
    """Encode *text* with the module-level sentence encoder.

    A ``None`` input is treated as the empty string so the encoder always
    receives a string.
    """
    return model.encode("" if text is None else text)
19
+
20
def get_categories(y_pred):
    """Return the category names whose entry in *y_pred* equals 1.

    Position ``i`` of *y_pred* corresponds to ``categories[i]``; names are
    returned in that order.
    """
    return [categories[pos] for pos, flag in enumerate(y_pred) if flag == 1]
27
+
28
def generate_output(article):
    """Classify an article and return its categories plus a placeholder.

    The article is split on newlines, each piece is embedded with
    ``get_embedding``, and the element-wise average of those embeddings is
    passed to the module-level ``classifier``. The prediction is flattened
    and mapped to category names by ``get_categories``.

    Args:
        article: Text of the article. ``None`` is treated as an empty
            article, matching how ``get_embedding`` treats ``None``.

    Returns:
        A 2-tuple ``(classes, "clustering tbd")`` where ``classes`` is the
        list of predicted category names.
    """
    # get_embedding() already tolerates None; guard here too so that
    # article.split() does not raise AttributeError before we get there.
    if article is None:
        article = ""

    # NOTE(review): blank lines become empty-string paragraphs and take part
    # in the average - confirm this matches how the classifier was trained.
    paragraph_vectors = [get_embedding(par) for par in article.split("\n")]
    embedding = np.average(paragraph_vectors, axis=0)

    # y_pred = classifier.predict_proba(embedding.reshape(1, -1))
    # reshape(1, -1) builds the single-row batch without hard-coding the
    # embedding width (the original used a fixed 768).
    y_pred = classifier.predict(embedding.reshape(1, -1)).flatten()
    classes = get_categories(y_pred)

    return (classes, "clustering tbd")
41
+
42
+ # with gr.Blocks() as demo:
43
+ # with gr.Row():
44
+ # # column for input
45
+ # with gr.Column():
46
+ # input_text = gr.Textbox(lines=6, placeholder="Insert text of the article here...", label="Article"),
47
+ # submit_button = gr.Button("Submit")
48
+ # clear_button = gr.Button("Clear")
49
+
50
+ # # column for output
51
+ # with gr.Column():
52
+ # output_classification = gr.Textbox(lines=1, label="Article category")
53
+ # output_topic_discovery = gr.Textbox(lines=5, label="Topic discovery")
54
+
55
+ #submit_button.click(generate_output, inputs=input_text, outputs=[output_classification, output_topic_discovery])
56
# Single-function UI: one article textbox in, two textboxes out (the
# predicted categories and the topic-discovery placeholder).
demo = gr.Interface(
    fn=generate_output,
    inputs=gr.Textbox(
        lines=6,
        placeholder="Insert text of the article here...",
        label="Article",
    ),
    outputs=[
        gr.Textbox(lines=1, label="Category"),
        gr.Textbox(lines=5, label="Topic discovery"),
    ],
    title="Article classification & topic discovery demo",
    flagging_options=["Incorrect"],
    theme=gr.themes.Base(),
    # css=css_code,  # disabled together with the css_code definition above
)

demo.launch()
flagged/log.csv ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name,output 0,output 1,flag,username,timestamp
2
+ test,Hello test!!,clustering tbd,,,2023-04-13 09:59:56.971579
3
+ ,,,,,2023-04-14 18:44:46.346018
4
+ ,,,Incorrect,,2023-04-14 18:48:06.029759
models/MLP_classifier_average_en.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ee563cb660f18a2d58d4b8790f02b68bc33e6c98c90cf890d1191a27c5b25a9
3
+ size 354644961