|
import gradio as gr |
|
import numpy as np |
|
import pandas as pd |
|
import pickle |
|
import sklearn |
|
import plotly.express as px |
|
from sentence_transformers import SentenceTransformer |
|
from sklearn.cluster import MiniBatchKMeans |
|
from learn_multi_doc_model import Model |
|
|
|
|
|
|
|
# Custom CSS passed to gr.Blocks: tints the button whose element id is
# "component-8".
css_code = 'button#component-8{background-color: rgb(158,202,225);}'

# Expose Model as an attribute of __main__ before any pickle below is loaded.
# NOTE(review): presumably the pickled models reference ``__main__.Model`` --
# confirm against how they were saved.
import __main__

__main__.Model = Model

# Category labels; get_categories() indexes into this list by position.
categories = [
    "Censorship",
    "Development",
    "Digital Activism",
    "Disaster",
    "Economics & Business",
    "Education",
    "Environment",
    "Governance",
    "Health",
    "History",
    "Humanitarian Response",
    "International Relations",
    "Law",
    "Media & Journalism",
    "Migration & Immigration",
    "Politics",
    "Protest",
    "Religion",
    "Sport",
    "Travel",
    "War & Conflict",
    "Technology + Science",
    "Women & Gender + LGBTQ + Youth",
    "Freedom of Speech + Human Rights",
    "Literature + Arts & Culture",
]

# Key file read by get_words(): each line holds "<lang> <path-to-pickle>".
input_cvect_key_file = 'topic_discovery/cvects.key'

# Sentence encoder used by get_embedding().
model_labse = SentenceTransformer('sentence-transformers/LaBSE')

# NOTE(security): pickle.load can execute arbitrary code; both model files
# must come from a trusted source.
with open('models/MLP_classifier_average_en.pkl', 'rb') as f:
    classifier = pickle.load(f)

# Bound to None first, then replaced by the unpickled multilingual model.
mul_model = None
with open('models/model_0.0001_100.pkl', 'rb') as f:
    mul_model = pickle.load(f)
|
|
|
def get_embedding(text):
    """Encode ``text`` with the module-level LaBSE model.

    A ``None`` input is treated as the empty string. The return value is
    whatever ``model_labse.encode`` produces for the text.
    """
    return model_labse.encode("" if text is None else text)
|
|
|
def get_categories(y_pred):
    """Return the category names selected by a 0/1 prediction vector.

    Args:
        y_pred: iterable of flags; position ``i`` corresponds to
            ``categories[i]``.

    Returns:
        List of names from ``categories`` at every position whose flag
        equals 1, in positional order.
    """
    return [categories[pos] for pos, flag in enumerate(y_pred) if flag == 1]
|
|
|
# Per-language vocabularies keyed by key-file path, so the pickled vectorizers
# are read from disk once instead of on every request.
_VOCAB_CACHE = {}


def _load_vocabularies(key_file):
    """Return ``{lang: vocabulary}`` for the vectorizers listed in ``key_file``."""
    cached = _VOCAB_CACHE.get(key_file)
    if cached is not None:
        return cached

    vocab = {}
    with open(key_file, "r") as key_fp:
        for line in key_fp:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline) in the key file.
                continue
            lang, fpath = fields
            # NOTE(security): pickle.load can execute arbitrary code; the
            # files listed in the key file must come from a trusted source.
            with open(fpath, "rb") as cvect_fp:
                cvect = pickle.load(cvect_fp)
            # get_feature_names() was deprecated in scikit-learn 1.0 and
            # removed in 1.2; prefer the replacement when it is available.
            if hasattr(cvect, "get_feature_names_out"):
                vocab[lang] = cvect.get_feature_names_out()
            else:
                vocab[lang] = cvect.get_feature_names()

    _VOCAB_CACHE[key_file] = vocab
    return vocab


def get_words(doc_emb, topn=10):
    """Return the highest-scoring vocabulary words per language for a document.

    For every language key in ``mul_model.E`` the document embedding is
    scored against ``mul_model.E[lang]`` with a matrix product, and the
    ``topn`` best-scoring vocabulary entries are returned in descending
    score order.

    Args:
        doc_emb: document embedding array; it is flattened to 1-D before use.
        topn: number of words to return per language (default 10).

    Returns:
        Dict mapping each language key to a list of words.
    """
    vocab = _load_vocabularies(input_cvect_key_file)
    doc_emb = doc_emb.flatten()

    words_dict = {}
    for lang in mul_model.E.keys():
        # assumes mul_model.E[lang] is a (vocab_size, dim) array whose rows
        # line up with vocab[lang] -- TODO confirm against the Model class.
        scores = mul_model.E[lang] @ doc_emb.T
        # squeeze() collapses a single selected index to a 0-d array, which is
        # not iterable; atleast_1d keeps the loop below valid in that case.
        k_ixs = np.atleast_1d(np.argsort(scores)[::-1][:topn].squeeze())
        words_dict[lang] = [vocab[lang][i] for i in k_ixs]

    return words_dict
|
|
|
|
|
def generate_output(article):
    """Classify an article and list its most representative words.

    The article is split on newlines, each piece is embedded with
    ``get_embedding`` and the embeddings are averaged. The average is fed to
    ``classifier.predict_proba``; categories whose probability is at least
    0.5 are reported. The same embedding is passed to ``get_words``.

    Args:
        article: article text. ``None`` is treated as the empty string.

    Returns:
        Tuple ``(classes_string, fig, words_string)``: the comma-separated
        predicted categories (or a "not found" message), a horizontal Plotly
        bar chart of per-category probabilities, and one
        ``"<lang>: w1, w2, ...\\n"`` line per language.
    """
    # The Clear handler sets the textbox value to None, so the text may
    # arrive here as None; treat that like an empty article instead of
    # failing on ``None.split``.
    paragraphs = (article or "").split("\n")

    # NOTE(review): blank lines become empty-string paragraphs and are
    # included in the average -- confirm this matches how the classifier's
    # training embeddings were built.
    embedding = np.average([get_embedding(par) for par in paragraphs], axis=0)

    # One row, width taken from the encoder output rather than hard-coded.
    reshaped = embedding.reshape(1, -1)

    # Flatten to one probability per category; reshape also validates that
    # the classifier produced exactly len(categories) values.
    y_prob = classifier.predict_proba(reshaped).reshape(len(categories))
    y_pred = [1 if prob >= 0.5 else 0 for prob in y_prob]

    classes = get_categories(y_pred)
    classes_string = ', '.join(classes) if classes else 'No category was found.'

    data = pd.DataFrame({'Category': categories, 'Probability': y_prob})
    fig = px.bar(data, x='Probability', y='Category', orientation='h', height=600)
    fig.update_xaxes(range=[0, 1])
    fig.update_layout(margin=dict(l=5, r=5, t=20, b=5))
    fig.update_traces(marker_color='rgb(158,202,225)')

    words_dict = get_words(reshaped)
    words_string = "".join(
        f"{lang}: {', '.join(words)}\n" for lang, words in words_dict.items()
    )

    return (classes_string, fig, words_string)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build the two-column demo page.
# NOTE(review): css_code styles ``button#component-8``, which looks like an
# auto-generated element id -- presumably it depends on the order in which
# components are created, so that order is kept exactly as it was.
with gr.Blocks(
    css=css_code,
    theme=gr.themes.Base(),
    title="Article classification & topic discovery demo",
) as demo:
    # Page heading.
    with gr.Row():
        my_title = gr.HTML("<h1 align='center'>Article classification & topic discovery demo</h1>")

    with gr.Row():
        # Left column: article input and action buttons.
        with gr.Column():
            input_text = gr.Textbox(
                lines=22,
                placeholder="Insert text of the article here...",
                label="Article",
            )
            with gr.Row():
                clear_button = gr.Button("Clear")
                submit_button = gr.Button("Submit")

        # Right column: one tab per kind of output.
        with gr.Column(), gr.Tabs():
            with gr.TabItem("Classification"):
                category_text = gr.Textbox(lines=1, label="Category")
                category_plot = gr.Plot()
            with gr.TabItem("Topic discovery"):
                topic_text = gr.Textbox(lines=22, label="The most representative words")

    # Submit runs the full pipeline and fills all three outputs.
    submit_button.click(
        fn=generate_output,
        inputs=input_text,
        outputs=[category_text, category_plot, topic_text],
    )
    # Clear resets the article textbox by setting its value to None.
    clear_button.click(fn=lambda: None, inputs=None, outputs=input_text, queue=False)

demo.launch()
|
|