|
import gradio as gr |
|
import numpy as np |
|
import pickle |
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
|
|
|
# Label names for the classifier output. get_categories() looks labels up by
# position, so presumably this order must match the classifier's output
# columns -- verify against the training code before reordering.
categories = [
    "Censorship",
    "Development",
    "Digital Activism",
    "Disaster",
    "Economics & Business",
    "Education",
    "Environment",
    "Governance",
    "Health",
    "History",
    "Humanitarian Response",
    "International Relations",
    "Law",
    "Media & Journalism",
    "Migration & Immigration",
    "Politics",
    "Protest",
    "Religion",
    "Sport",
    "Travel",
    "War & Conflict",
    "Technology_Science",
    "Women&Gender_LGBTQ+_Youth",
    "Freedom_of_Speech_Human_Rights",
    "Literature_Arts&Culture",
]

# Sentence encoder used by get_embedding(); loaded once at import time.
model = SentenceTransformer('sentence-transformers/LaBSE')

# Classifier object unpickled from disk, also loaded once at import time.
# NOTE: pickle.load can execute arbitrary code contained in the file, so this
# path must only ever point at a trusted model artifact.
with open('models/MLP_classifier_average_en.pkl', 'rb') as f:
    classifier = pickle.load(f)
|
|
|
def get_embedding(text):
    """Encode *text* with the module-level sentence encoder.

    A ``None`` input is replaced by the empty string before encoding.
    Returns whatever ``model.encode`` produces for that text.
    """
    return model.encode("" if text is None else text)
|
|
|
def get_categories(y_pred, labels=None):
    """Map a binary prediction vector to the names of the selected labels.

    Args:
        y_pred: Iterable of per-label indicators; position ``i`` refers to
            ``labels[i]``. Only entries equal to ``1`` count as selected.
        labels: Optional sequence of label names. Defaults to the
            module-level ``categories`` list.

    Returns:
        List of label names whose indicator equals 1, in label order.

    Raises:
        IndexError: If a selected position lies beyond the end of ``labels``.
    """
    if labels is None:
        labels = categories
    return [labels[idx] for idx, value in enumerate(y_pred) if value == 1]
|
|
|
def generate_output(article):
    """Classify an article and return the two values shown in the UI.

    The article is split on newlines, each piece is embedded with
    ``get_embedding``, and the element-wise average of those embeddings is
    passed to the classifier as a single sample.

    Args:
        article: Article text. ``None`` is treated as an empty string.

    Returns:
        Tuple ``(classes, placeholder)``: ``classes`` is the list of
        predicted category names from ``get_categories`` and ``placeholder``
        is the fixed string "clustering tbd".
    """
    # Mirror get_embedding's None handling so a missing input does not
    # crash on .split() before it ever reaches the encoder.
    if article is None:
        article = ""

    # NOTE(review): blank lines yield empty-string paragraphs that are
    # embedded and averaged like any other -- confirm this matches how the
    # classifier's training inputs were built before filtering them out.
    paragraphs = article.split("\n")
    embeddings = [get_embedding(par) for par in paragraphs]
    embedding = np.average(embeddings, axis=0)

    # reshape(1, -1) builds the single-sample batch without hard-coding the
    # encoder's embedding width.
    y_pred = classifier.predict(embedding.reshape(1, -1))
    classes = get_categories(y_pred.flatten())

    return (classes, "clustering tbd")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Input component: a multi-line textbox for the article text.
article_box = gr.Textbox(
    lines=6,
    placeholder="Insert text of the article here...",
    label="Article",
)

# Output components, in the same order as the tuple returned by
# generate_output: predicted categories first, then the topic placeholder.
result_boxes = [
    gr.Textbox(lines=1, label="Category"),
    gr.Textbox(lines=5, label="Topic discovery"),
]

demo = gr.Interface(
    fn=generate_output,
    inputs=article_box,
    outputs=result_boxes,
    title="Article classification & topic discovery demo",
    flagging_options=["Incorrect"],
    theme=gr.themes.Base(),
)

# Start the app as soon as the module is executed.
demo.launch()
|
|