import gradio as gr
import numpy as np
import pandas as pd
import pickle
import sklearn
import plotly.express as px
from sentence_transformers import SentenceTransformer
from sklearn.cluster import MiniBatchKMeans
from learn_multi_doc_model import Model

# css_code = 'body {background-image:url("https://picsum.photos/seed/picsum/200/300");} div.gradio-container {background: white;}, button#component-8{background-color: rgb(158,202,225);}'
css_code = 'button#component-8{background-color: rgb(158,202,225);}'

# Register Model under __main__ so the pickled multilingual model can be unpickled.
import __main__
setattr(__main__, "Model", Model)

categories = ["Censorship", "Development", "Digital Activism", "Disaster", "Economics & Business",
              "Education", "Environment", "Governance", "Health", "History", "Humanitarian Response",
              "International Relations", "Law", "Media & Journalism", "Migration & Immigration",
              "Politics", "Protest", "Religion", "Sport", "Travel", "War & Conflict",
              "Technology + Science", "Women & Gender + LGBTQ + Youth",
              "Freedom of Speech + Human Rights", "Literature + Arts & Culture"]

input_cvect_key_file = 'topic_discovery/cvects.key'

# Multilingual sentence encoder used to embed article paragraphs.
model_labse = SentenceTransformer('sentence-transformers/LaBSE')

# Multi-label MLP category classifier trained on averaged LaBSE embeddings.
with open('models/MLP_classifier_average_en.pkl', 'rb') as f:
    classifier = pickle.load(f)

# Multilingual topic-discovery model with per-language word embedding matrices (mul_model.E).
mul_model = None
with open('models/model_0.0001_100.pkl', 'rb') as f:
    mul_model = pickle.load(f)


def get_embedding(text):
    """Return the LaBSE embedding of a piece of text (empty string if None)."""
    if text is None:
        text = ""
    return model_labse.encode(text)


def get_categories(y_pred):
    """Map a binary prediction vector to the list of predicted category names."""
    indices = []
    for idx, value in enumerate(y_pred):
        if value == 1:
            indices.append(idx)
    cats = [categories[i] for i in indices]
    return cats


def get_words(doc_emb):
    """Return the most representative words per language for a document embedding."""
    # Load the CountVectorizers and the vocabulary of words for each language.
    cvects = {}
    vocab = {}
    with open(input_cvect_key_file, "r") as fpr:
        for line in fpr:
            lang, fpath = line.strip().split()
            with open(fpath, "rb") as fvec:
                cvects[lang] = pickle.load(fvec)
                vocab[lang] = cvects[lang].get_feature_names()

    topn = 10  # top N words per cluster

    doc_emb = doc_emb.flatten()
    words_dict = {}
    for lang in mul_model.E.keys():
        # Score every word of this language against the document embedding,
        # sort the scores in descending order and pick the topn words.
        scores = mul_model.E[lang].detach().numpy() @ doc_emb.T
        k_ixs = np.argsort(scores)[::-1][:topn].squeeze()
        tmp = []
        for i in k_ixs:
            tmp.append(vocab[lang][i])
        words_dict[lang] = tmp
    return words_dict


def generate_output(article):
    """Classify an article and return (categories string, probability plot, top-words string)."""
    # Embed each paragraph separately and average the paragraph embeddings.
    paragraphs = article.split("\n")
    embdds = []
    for par in paragraphs:
        embdds.append(get_embedding(par))
    embedding = np.average(embdds, axis=0)
    reshaped = embedding.reshape(1, 768)

    # Per-category probabilities; a category is predicted if its probability is >= 0.5.
    y_prob = classifier.predict_proba(reshaped)
    y_prob = y_prob.reshape(len(categories), 1)
    y_pred = [1 if x >= 0.5 else 0 for x in y_prob]

    classes = get_categories(y_pred)
    if len(classes) > 1:
        classes_string = ', '.join(classes)
    elif len(classes) == 1:
        classes_string = classes[0]
    else:
        classes_string = 'No category was found.'
    # Horizontal bar chart of per-category probabilities.
    data = pd.DataFrame()
    data['Category'] = categories
    data['Probability'] = y_prob

    fig = px.bar(data, x='Probability', y='Category', orientation='h', height=600)  # , title="Category probability")
    fig.update_xaxes(range=[0, 1])
    fig.update_layout(margin=dict(l=5, r=5, t=20, b=5))  # , paper_bgcolor="LightSteelBlue")
    fig.update_traces(marker_color='rgb(158,202,225)')

    # One line per language with its most representative words.
    words_dict = get_words(reshaped)
    words_string = ""
    for lang, w in words_dict.items():
        words_string += f"{lang}: "
        words_string += ', '.join(w)
        words_string += "\n"

    return (classes_string, fig, words_string)


# Alternative single-Interface version of the demo:
# demo = gr.Interface(fn=generate_output,
#                     inputs=gr.Textbox(lines=6, placeholder="Insert text of the article here...", label="Article"),
#                     outputs=[gr.Textbox(lines=1, label="Category"),
#                              gr.Plot(label="Category probability"),
#                              gr.Textbox(lines=5, label="Topic discovery")],
#                     title="Article classification & topic discovery demo",
#                     flagging_options=["Incorrect"],
#                     theme=gr.themes.Base())  # css=css_code)

demo = gr.Blocks(css=css_code, theme=gr.themes.Base(),
                 title="Article classification & topic discovery demo")

with demo:
    with gr.Row():
        my_title = gr.HTML("<h1 style='text-align: center'>Article classification & topic discovery demo</h1>")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=22, placeholder="Insert text of the article here...", label="Article")
            with gr.Row():
                clear_button = gr.Button("Clear")
                submit_button = gr.Button("Submit")
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Classification"):
                    category_text = gr.Textbox(lines=1, label="Category")
                    category_plot = gr.Plot()
                with gr.TabItem("Topic discovery"):
                    topic_text = gr.Textbox(lines=22, label="The most representative words")

    submit_button.click(generate_output, inputs=input_text,
                        outputs=[category_text, category_plot, topic_text])
    clear_button.click(lambda: None, None, input_text, queue=False)

demo.launch()