Kirill
update
6d16c16
#
# Copyright 2021 Systems & Technology Research. All rights reserved.
# This software and associated documentation is subject to the use restrictions stated in the LICENSE.txt file.
#
import streamlit as st
# import _json
import pandas as pd
import json
from PIL import ImageColor
import math
import numpy as np
from colorcet import blues
from transformers import RobertaTokenizerFast, RobertaForMaskedLM
import torch
import os
import hashlib
# Device string handed to `.to(device)` for the model and the input tensors.
device = "cpu"
# Default contents of the sidebar text area (the text whose connotation is changed).
sample_text="""SAN FRANCISCO — A Facebook-appointed panel of journalists, activists and lawyers on Wednesday upheld the social network’s ban of former President Donald J. Trump, ending any immediate return by Mr. Trump to mainstream social media and renewing a debate about tech power over online speech.
Facebook’s Oversight Board, which acts as a quasi-court over the company’s content decisions, ruled the social network was right to bar Mr. Trump after the insurrection in Washington in January, saying he “created an environment where a serious risk of violence was possible.” The panel said that ongoing risk “justified” the move.
But the board also kicked the case back to Facebook and its top executives. It said that an indefinite suspension was “not appropriate” because it was not a penalty defined in Facebook’s policies and that the company should apply a standard punishment, such as a time-bound suspension or a permanent ban. The board gave Facebook six months to make a final decision on Mr. Trump’s account status.
“Our sole job is to hold this extremely powerful organization, Facebook, accountable,” Michael McConnell, co-chair of the Oversight Board, said on a call with reporters. The ban on Mr. Trump “did not meet these standards,” he said."""
# Show in the sidebar which device the app is configured to run on.
st.sidebar.success(f"running on {device}")
def get_color(norm_value, cmap):
    """Pick the colormap entry that corresponds to a normalised value.

    Args:
        norm_value: Number expected in [0, 1]; 0 selects the first entry of
            ``cmap`` and 1 selects the last. Values outside that range are
            clamped to the nearest end instead of raising IndexError (above
            1) or wrapping around to the end of the colormap (below 0).
        cmap: Non-empty indexable sequence of colours.

    Returns:
        The element of ``cmap`` at the computed position.
    """
    last = len(cmap) - 1
    idx = int(math.floor(last * norm_value))
    # Clamp so out-of-range input can neither overrun nor wrap via a
    # negative index.
    idx = max(0, min(idx, last))
    return cmap[idx]
def get_color_cat(idx, cmap):
    """Return the colour for a categorical index, cycling through ``cmap``.

    Args:
        idx: Integer category index; it is reduced modulo ``len(cmap)``.
        cmap: Non-empty indexable sequence of colours.

    Returns:
        The element of ``cmap`` at position ``idx % len(cmap)``.
    """
    n_colors = len(cmap)
    wrapped = idx % n_colors
    return cmap[wrapped]
def make_html_text_with_color(text, color):
    """Wrap ``text`` in an HTML ``<span>`` with a translucent background.

    Args:
        text: String placed inside the span exactly as given; no escaping is
            performed here.
        color: Colour specification accepted by ``PIL.ImageColor.getrgb``.

    Returns:
        A string of the form
        ``<span style="background-color: rgba(r, g, b, 0.6)">text</span>``.
    """
    # getrgb can return (r, g, b, a) for colour specs that carry their own
    # alpha; keep only RGB so that appending the fixed 0.6 opacity always
    # yields a four-component rgba() value.
    rgb = ImageColor.getrgb(color)[:3]
    rgba = "rgba" + str(rgb + (.6,))
    return f'<span style="background-color: {rgba}">{text}</span>'
def replace(text):
    """Clean a decoded token for display.

    Tokenizer special tokens become the empty string; any other text is
    returned with U+FFFD replacement characters removed.

    Args:
        text: Decoded token string.

    Returns:
        The cleaned string.
    """
    special_tokens = ('<s>', '</s>', '<unk>', '<pad>', '<mask>')
    if text in special_tokens:
        return ""
    return text.replace("�", "")
def make_full_html(tokens, values, cmap=["yellow"], bounds=None, categotical = True):
    """Build one HTML string in which tokens are highlighted by value.

    Args:
        tokens: Iterable of decoded token strings. Each one is cleaned with
            ``replace`` and then HTML-escaped, so literal ``<``, ``>`` and
            ``&`` in the text are displayed rather than parsed as markup.
        values: One value per token.
            * Categorical mode: entries ``>= 0`` are categorical colour
              indices (wrapped by ``get_color_cat``); negative entries leave
              the token unhighlighted.
            * Continuous mode: an array supporting ``.min()``/``.max()`` and
              arithmetic, rescaled to [0, 1] before colour lookup.
        cmap: Sequence of colours. The default list is never mutated here.
        bounds: Optional ``(vmin, vmax)`` for continuous mode. Values are
            clipped to this range before rescaling. When ``None`` the range
            is taken from ``values`` itself.
        categotical: (sic - spelling kept for compatibility) ``True`` for
            categorical mode, ``False`` for continuous mode.

    Returns:
        The concatenated HTML for all tokens.
    """
    # Function-scope import: `html` is stdlib but not imported at file level.
    import html

    def _clean(token):
        # Drop special tokens first, then escape what is left. quote=False
        # because the text only ever lands in element content.
        return html.escape(replace(token), quote=False)

    if not categotical:
        if bounds is None:
            vmn = values.min()
            vmx = values.max()
            # The epsilon avoids division by zero when all values are equal.
            values = (values-vmn)/(vmx-vmn+1e-6)
        else:
            vmn, vmx = bounds
            values = np.clip(values, vmn, vmx)
            span = vmx - vmn
            if span == 0:
                # Degenerate bounds would give 0/0 = NaN, which get_color
                # cannot convert to an index; map everything to the low end.
                values = np.zeros_like(values, dtype=float)
            else:
                values = (values - vmn) / span
        return "".join(
            make_html_text_with_color(_clean(t), get_color(v, cmap))
            for t, v in zip(tokens, values)
        )

    return "".join(
        make_html_text_with_color(_clean(t), get_color_cat(v, cmap)) if v >= 0 else _clean(t)
        for t, v in zip(tokens, values)
    )
# Emotion labels offered in the UI. Order matters: get_connotations uses an
# emotion's position in this list as the index into each lexicon entry's
# "Emo" vector.
emotions = ["anger", "joy", "fear", "trust", "anticipation", "sadness", "disgust", "surprise"]
# CSV lexicon read by get_connotations ("word" and JSON-encoded "conn" columns).
# NOTE(review): "conntation" is presumably the on-disk spelling - confirm before "fixing".
PATH_CONN = "noun_adj_conntation_lexicon.csv"
@st.cache(allow_output_mutation = True,hash_funcs={'_json.Scanner': hash})
def get_connotations(emotion, vocab):
    """Find the lexicon words, and the vocabulary entries, tagged with an emotion.

    Reads the CSV at ``PATH_CONN``, parses each JSON-encoded "conn" cell and
    keeps the rows whose ``"Emo"`` vector equals 1 at the position of
    ``emotion`` in the module-level ``emotions`` list.

    Args:
        emotion: One of the labels in ``emotions``.
        vocab: pandas Series of cleaned vocabulary strings (as produced by
            ``get_model``).

    Returns:
        ``(word_set, vocab_mask)``: the set of matching lexicon words, and a
        torch tensor built from ``vocab.isin(word_set)`` with one flag per
        vocabulary entry.

    Raises:
        ValueError: if ``emotion`` is not in ``emotions``.
    """
    lexicon = pd.read_csv(PATH_CONN)
    lexicon.conn = lexicon.conn.apply(json.loads)

    emotion_pos = emotions.index(emotion)
    has_emotion = lexicon.conn.apply(lambda entry: entry["Emo"][emotion_pos]==1.)

    matching_words = lexicon.loc[has_emotion, "word"].values.tolist()
    word_set = set(matching_words)

    in_lexicon = vocab.isin(word_set).values
    vocab_mask = torch.from_numpy(in_lexicon)
    return word_set, vocab_mask
@st.cache(allow_output_mutation = True)
def get_model():
    """Load the RoBERTa tokenizer and masked-LM, plus a cleaned vocabulary.

    Returns:
        ``(tokenizer, clean_vocab, model)`` where ``clean_vocab`` is a pandas
        Series indexed by token id (sorted) whose values are the decoded
        token strings, stripped and lower-cased. The model is in eval mode,
        has gradients disabled and has been moved to ``device``.
    """
    checkpoint = 'roberta-base'
    tokenizer = RobertaTokenizerFast.from_pretrained(checkpoint)

    model = RobertaForMaskedLM.from_pretrained(checkpoint)
    model.eval()
    model.requires_grad_(False)
    model = model.to(device)

    # token id -> normalised surface string
    id_to_text = {}
    for token, token_id in tokenizer.get_vocab().items():
        id_to_text[token_id] = tokenizer.convert_tokens_to_string(token).strip().lower()
    clean_vocab = pd.Series(id_to_text).sort_index()

    return tokenizer, clean_vocab, model
# Tokenizer, cleaned vocabulary Series and masked-LM from get_model (an st.cache-decorated loader).
tokenizer, clean_vocab, model = get_model()

# Bare expression statement; presumably rendered as a Markdown heading by
# Streamlit "magic" - confirm magic is enabled for this app.
f"## Change Connotation"

# Two side-by-side columns, each with its own emotion selector.
col1, col2 = st.columns(2)
emotion_source = col1.selectbox("Source Emotion", emotions, index=1)
emotion_target = col2.selectbox("Target Emotion", emotions, index=0)

# get_connotations returns (word_set, vocab_mask); only the vocabulary mask
# is kept, so despite their names these hold tensors rather than word lists.
emotion_words_source = get_connotations(emotion_source, clean_vocab)[1]
emotion_words_taget = get_connotations(emotion_target, clean_vocab)[1]

# Hard-coded on. The alternative path relies on get_articles/search_articles,
# which are not defined in the visible part of this file.
custom_input = True
if not custom_input:
    articles = get_articles()
    keyword = st.sidebar.text_input("Keywords", value="virus")
    article = search_articles(keyword, articles)
else:
    article = st.sidebar.text_area("Paste Text Here", value=sample_text, height=600)
def _ids_to_words(token_ids):
    """Decode every token id on its own into its display string."""
    return [tokenizer.convert_tokens_to_string(tok) for tok in tokenizer.convert_ids_to_tokens(token_ids)]


# Tokenize the article (truncated to 512 tokens) and keep a copy of the ids
# before any of them are overwritten.
inputs = tokenizer(article, max_length=512, truncation=True, return_tensors="pt")
original_input_ids = inputs["input_ids"][0].clone()
words = _ids_to_words(inputs["input_ids"][0])

# Flag each position whose token id is set in the source-emotion vocabulary mask.
source_token_ids = emotion_words_source.nonzero(as_tuple=False).flatten()
mask = (inputs["input_ids"][0][:, None] == source_token_ids).any(-1)

if mask.any():
    # Blank out the flagged tokens and let the masked-LM propose replacements.
    inputs["input_ids"][0][mask] = tokenizer.mask_token_id
    with torch.no_grad():
        model_inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        logits = model(**model_inputs).logits[0]
        # Only tokens carrying the target emotion may be chosen ...
        logits[:, ~emotion_words_taget] = float("-inf")
        # ... and never the token that originally stood at that position.
        logits[mask, original_input_ids[mask]] = float("-inf")
        replacement_ids = logits[mask, :].argmax(-1).cpu()
    inputs["input_ids"][0, mask] = replacement_ids
    words_mod = _ids_to_words(inputs["input_ids"][0])
    # 1 marks a replaced position, -1 everything else (negative = no highlight).
    scores = 2 * mask.numpy().astype(int) - 1
else:
    st.warning("no source words found, try another input")
    scores = -np.ones(len(words))
    words_mod = words
# Left column: original tokens, replaced positions highlighted in blue.
# Right column: rewritten tokens, the same positions highlighted in yellow.
for column, tokens, highlight in ((col1, words, "blue"), (col2, words_mod, "yellow")):
    with column:
        st.markdown(make_full_html(tokens, scores, cmap=[highlight]), unsafe_allow_html=True)