File size: 5,878 Bytes
aba0616
c87255c
48032f9
c87255c
48032f9
 
c87255c
48032f9
 
aba0616
 
48032f9
 
c87255c
48032f9
 
c87255c
48032f9
 
 
c87255c
 
48032f9
 
 
c87255c
 
 
 
48032f9
c87255c
 
 
 
 
 
 
 
 
48032f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c87255c
 
 
 
 
 
 
 
48032f9
 
 
 
 
 
 
 
 
c87255c
48032f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c87255c
48032f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c87255c
 
48032f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c87255c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import functools
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import sklearn
from sentence_transformers import SentenceTransformer
from sklearn.cluster import MiniBatchKMeans

from learn_multi_doc_model import Model


# Custom CSS for the Gradio page. It colours the element rendered with the
# auto-assigned id ``component-8``.
# NOTE(review): that id depends on the order in which UI components are
# created; re-check it whenever the layout changes.
css_code='button#component-8{background-color: rgb(158,202,225);}'

# Publish `Model` on the __main__ module.
# NOTE(review): presumably the pickled multilingual model was saved from a
# script in which `Model` lived in __main__, so unpickling has to find it
# there. Confirm against how models/model_*.pkl was produced.
import __main__
__main__.Model = Model

# Label names, in the same order as the classifier's output columns.
categories = [
    "Censorship",
    "Development",
    "Digital Activism",
    "Disaster",
    "Economics & Business",
    "Education",
    "Environment",
    "Governance",
    "Health",
    "History",
    "Humanitarian Response",
    "International Relations",
    "Law",
    "Media & Journalism",
    "Migration & Immigration",
    "Politics",
    "Protest",
    "Religion",
    "Sport",
    "Travel",
    "War & Conflict",
    "Technology + Science",
    "Women & Gender + LGBTQ + Youth",
    "Freedom of Speech + Human Rights",
    "Literature + Arts & Culture",
]

# Text file listing "<lang> <path-to-pickled-CountVectorizer>" per line.
input_cvect_key_file = 'topic_discovery/cvects.key'

# Sentence encoder used to embed article paragraphs.
model_labse = SentenceTransformer('sentence-transformers/LaBSE')

# NOTE: pickle.load can execute arbitrary code, so these model files must
# come from a trusted source.
with open('models/MLP_classifier_average_en.pkl', 'rb') as f:
    classifier = pickle.load(f)
with open('models/model_0.0001_100.pkl', 'rb') as f:
    mul_model = pickle.load(f)

def get_embedding(text):
    """Encode *text* with the LaBSE sentence encoder.

    ``None`` is encoded as the empty string, so callers need not
    special-case missing input.
    """
    source = "" if text is None else text
    return model_labse.encode(source)

def get_categories(y_pred, labels=None):
    """Return the names of the categories whose prediction flag equals 1.

    Args:
        y_pred: Sequence of 0/1 flags, one per label, in label order.
        labels: Label names indexed like ``y_pred``. Defaults to the
            module-level ``categories`` list.

    Returns:
        List of selected label names, in label order. An ``IndexError`` is
        raised if a flag is set at a position beyond ``labels``.
    """
    if labels is None:
        labels = categories
    return [labels[i] for i, flag in enumerate(y_pred) if flag == 1]

def _feature_names(cvect):
    """Return *cvect*'s vocabulary across scikit-learn versions.

    ``get_feature_names`` was removed in scikit-learn 1.2 in favour of
    ``get_feature_names_out`` (available since 1.0).
    """
    if hasattr(cvect, "get_feature_names_out"):
        return cvect.get_feature_names_out()
    return cvect.get_feature_names()


@functools.lru_cache(maxsize=None)
def _load_vocabularies(key_file):
    """Load the per-language vocabularies listed in *key_file*, once per process.

    Each non-blank line of *key_file* is ``<lang> <path>``, where ``<path>``
    is a pickled CountVectorizer. The result is cached, so later calls with
    the same path do not touch the disk. The returned dict is shared and
    must not be mutated.

    Returns:
        Dict mapping language code to that vectorizer's feature names
        (indexable by vocabulary column).
    """
    vocab = {}
    with open(key_file, "r") as key_fp:
        for line in key_fp:
            line = line.strip()
            if not line:
                continue  # tolerate blank or trailing lines
            lang, fpath = line.split()
            # NOTE: pickle.load can execute arbitrary code, so the files
            # listed in the key file must be trusted.
            with open(fpath, "rb") as cvect_fp:
                cvect = pickle.load(cvect_fp)
            vocab[lang] = _feature_names(cvect)
    return vocab


def get_words(doc_emb, topn=10):
    """Return the *topn* highest-scoring vocabulary words per language.

    Every word of a language is scored by the dot product of its row in
    ``mul_model.E[lang]`` with the flattened document embedding.

    Args:
        doc_emb: Document embedding. It is flattened to 1-D before scoring
            (the caller passes a single-row 2-D array).
        topn: Number of words returned per language (default 10).

    Returns:
        Dict mapping each language key of ``mul_model.E`` to a list of its
        ``topn`` best-scoring words, best first.
    """
    vocab = _load_vocabularies(input_cvect_key_file)
    doc_vec = np.asarray(doc_emb).flatten()

    words_dict = {}
    for lang in mul_model.E.keys():
        scores = np.asarray(mul_model.E[lang] @ doc_vec).ravel()
        # argsort is ascending: reverse for best-first, then keep the top N.
        # No squeeze(), so topn == 1 still yields an iterable index array.
        top_ixs = np.argsort(scores)[::-1][:topn]
        words_dict[lang] = [vocab[lang][i] for i in top_ixs]
    return words_dict


def generate_output(article):
    """Classify an article and list the words most representative of it.

    The article is split on newline characters. Every paragraph is embedded
    with ``get_embedding``, and the paragraph embeddings are averaged into
    one document embedding. That embedding feeds both the category
    classifier and ``get_words``.

    Args:
        article: Article text from the UI textbox. ``None`` is treated as an
            empty article; the Clear button sets the textbox to ``None``.

    Returns:
        Tuple ``(classes_string, fig, words_string)``:

        * ``classes_string``: predicted categories joined with ``", "``, or
          ``'No category was found.'`` when no probability reaches 0.5.
        * ``fig``: horizontal plotly bar chart of every category's
          probability.
        * ``words_string``: one ``"<lang>: w1, w2, ..."`` line per language.
    """
    if article is None:
        article = ""

    # NOTE(review): blank lines become empty-string "paragraphs" and take
    # part in the average. Confirm this matches how the classifier was trained.
    paragraph_embs = [get_embedding(par) for par in article.split("\n")]
    embedding = np.average(paragraph_embs, axis=0)
    # Single-sample 2-D batch, shape (1, embedding_dim), for the sklearn API.
    reshaped = embedding.reshape(1, -1)

    # One probability per category. The result is flattened to 1-D so it
    # lines up with `categories`; reshape also fails loudly on a count mismatch.
    y_prob = np.asarray(classifier.predict_proba(reshaped)).reshape(len(categories))
    y_pred = [1 if p >= 0.5 else 0 for p in y_prob]

    classes = get_categories(y_pred)
    # join() of one item is that item, so no separate single-class branch.
    classes_string = ', '.join(classes) if classes else 'No category was found.'

    data = pd.DataFrame({'Category': categories, 'Probability': y_prob})
    fig = px.bar(data, x='Probability', y='Category', orientation='h', height=600)
    fig.update_xaxes(range=[0, 1])
    fig.update_layout(margin=dict(l=5, r=5, t=20, b=5))
    fig.update_traces(marker_color='rgb(158,202,225)')

    words_dict = get_words(reshaped)
    words_string = "".join(
        f"{lang}: {', '.join(words)}\n" for lang, words in words_dict.items()
    )

    return (classes_string, fig, words_string)

# demo = gr.Interface(fn=generate_output, 
#     inputs=gr.Textbox(lines=6, placeholder="Insert text of the article here...", label="Article"), 
#     outputs=[gr.Textbox(lines=1, label="Category"), gr.Plot(label="Category probability"), gr.Textbox(lines=5, label="Topic discovery")], 
#     title="Article classification & topic discovery demo",
#     flagging_options=["Incorrect"],
#     theme=gr.themes.Base())
    #css=css_code)

# Build the Gradio UI.
# NOTE(review): `css_code` styles `button#component-8`, an id Gradio assigns
# from component creation order. Presumably it targets one of the buttons
# below. Verify before adding, removing, or reordering any component in
# this block.
demo = gr.Blocks(css=css_code, theme=gr.themes.Base(), title="Article classification & topic discovery demo")

with demo:
    # Page header.
    with gr.Row():
        my_title = gr.HTML("<h1 align='center'>Article classification & topic discovery demo</h1>")
    # Two columns: article input on the left, result tabs on the right.
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=22, placeholder="Insert text of the article here...", label="Article")
            with gr.Row():
                clear_button = gr.Button("Clear")
                submit_button = gr.Button("Submit")
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Classification"):
                    category_text = gr.Textbox(lines=1, label="Category")
                    category_plot = gr.Plot()
                with gr.TabItem("Topic discovery"):
                    topic_text = gr.Textbox(lines=22, label="The most representative words")

    # generate_output returns (classes_string, fig, words_string), which
    # fills the three output components in this order.
    submit_button.click(generate_output, inputs=input_text, outputs=[category_text, category_plot, topic_text])
    # Clear resets only the article textbox (to None); the result
    # components keep their previous values.
    clear_button.click(lambda: None, None, input_text, queue=False)

# Start the Gradio server. There is no __main__ guard, so this also runs
# whenever the module is imported.
demo.launch()