ktrapeznikov committed on
Commit
0d3b8f7
1 Parent(s): 31ef2b4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -0
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright 2021 Systems & Technology Research. All rights reserved.
3
+ # This software and associated documentation is subject to the use restrictions stated in the LICENSE.txt file.
4
+ #
5
+
6
+
7
+ import streamlit as st
8
+ import pandas as pd
9
+ import json
10
+ from PIL import ImageColor
11
+ import math
12
+ import numpy as np
13
+ from colorcet import blues
14
+ from transformers import RobertaTokenizerFast, RobertaForMaskedLM
15
+ import torch
16
+ import os
17
+ import hashlib
18
+ device = "cpu"
19
+
20
+ sample_text="""SAN FRANCISCO — A Facebook-appointed panel of journalists, activists and lawyers on Wednesday upheld the social network’s ban of former President Donald J. Trump, ending any immediate return by Mr. Trump to mainstream social media and renewing a debate about tech power over online speech.
21
+ Facebook’s Oversight Board, which acts as a quasi-court over the company’s content decisions, ruled the social network was right to bar Mr. Trump after the insurrection in Washington in January, saying he “created an environment where a serious risk of violence was possible.” The panel said that ongoing risk “justified” the move.
22
+ But the board also kicked the case back to Facebook and its top executives. It said that an indefinite suspension was “not appropriate” because it was not a penalty defined in Facebook’s policies and that the company should apply a standard punishment, such as a time-bound suspension or a permanent ban. The board gave Facebook six months to make a final decision on Mr. Trump’s account status.
23
+ “Our sole job is to hold this extremely powerful organization, Facebook, accountable,” Michael McConnell, co-chair of the Oversight Board, said on a call with reporters. The ban on Mr. Trump “did not meet these standards,” he said."""
24
+
25
+
26
+
27
+
28
+
29
+ st.sidebar.success(f"running on {device}")
30
+
31
def get_color(norm_value, cmap):
    """Pick the colour from an ordered colormap for a normalised value.

    Args:
        norm_value: Position along the colormap. Values outside ``[0, 1]``
            are clamped to the nearest end.
        cmap: Indexable sequence of colours, ordered from low to high.

    Returns:
        The element of ``cmap`` at index ``floor((len(cmap) - 1) * norm_value)``.
    """
    # Clamp first: without it a slightly negative value floors to -1 and
    # silently wraps around to the *last* colour, and values above 1 run
    # past the end of the colormap.
    clamped = min(max(norm_value, 0.0), 1.0)
    idx = int(math.floor((len(cmap) - 1) * clamped))
    return cmap[idx]
34
+
35
def get_color_cat(idx, cmap):
    """Return the colour for category ``idx``, cycling through ``cmap``.

    The index is taken modulo the palette length, so any integer category
    maps onto an existing palette entry.
    """
    palette_size = len(cmap)
    slot = idx % palette_size
    return cmap[slot]
37
+
38
def make_html_text_with_color(text, color):
    """Wrap ``text`` in a ``<span>`` with a semi-transparent background.

    Args:
        text: Text to display. It is HTML-escaped before being embedded, so
            characters such as ``<`` or ``&`` in the input show up literally
            instead of being interpreted as markup.
        color: Any colour specification accepted by ``ImageColor.getrgb``.

    Returns:
        An HTML string of the form
        ``<span style="background-color: rgba(r, g, b, 0.6)">text</span>``.
    """
    # Local import keeps this block self-contained.
    from html import escape

    # getrgb may return a 4-tuple when the colour carries its own alpha;
    # only the RGB channels are used so the CSS always has four components.
    red, green, blue = ImageColor.getrgb(color)[:3]
    rgba = f"rgba({red}, {green}, {blue}, 0.6)"
    return f'<span style="background-color: {rgba}">{escape(text, quote=False)}</span>'
41
+
42
def replace(text):
    """Return the displayable form of a decoded token.

    The special tokens ``<s>``, ``</s>``, ``<unk>``, ``<pad>`` and ``<mask>``
    become the empty string; any other text is returned with the Unicode
    replacement character removed.
    """
    special_tokens = ('<s>', '</s>', '<unk>', '<pad>', '<mask>')
    if text in special_tokens:
        return ""
    return text.replace("�", "")
46
+
47
def make_full_html(tokens, values, cmap=("yellow",), bounds=None, categotical=True):
    """Render tokens as one HTML string with per-token background colours.

    Args:
        tokens: Iterable of token strings; each is passed through ``replace``
            before being emitted.
        values: One value per token. In categorical mode a value ``>= 0``
            selects a palette entry and a negative value leaves the token
            un-highlighted. In continuous mode the values are normalised to
            ``[0, 1]`` and mapped onto ``cmap``.
        cmap: Sequence of colours (palette or ordered colormap).
        bounds: Optional ``(min, max)`` for continuous mode. When omitted the
            observed minimum and maximum of ``values`` are used.
        categotical: ``True`` for categorical mode, ``False`` for continuous
            mode. (The spelling is kept because it is the keyword callers use.)

    Returns:
        The concatenated HTML for all tokens.
    """
    # Local import keeps this block self-contained.
    from html import escape

    if categotical:
        # Un-highlighted tokens are emitted as plain text, so they are escaped
        # here; highlighted ones are delegated to make_html_text_with_color.
        return "".join(
            make_html_text_with_color(replace(t), get_color_cat(int(v), cmap))
            if v >= 0
            else escape(replace(t), quote=False)
            for t, v in zip(tokens, values)
        )

    values = np.asarray(values, dtype=float)
    if values.size == 0:
        # min()/max() of an empty array raise; there is nothing to render.
        return ""

    if bounds is None:
        vmn = values.min()
        vmx = values.max()
        # The epsilon keeps the division defined when all values are equal.
        values = (values - vmn) / (vmx - vmn + 1e-6)
    else:
        vmn, vmx = bounds
        span = vmx - vmn
        values = np.clip(values, vmn, vmx)
        # Zero-width bounds would divide by zero; map everything to the
        # low end of the colormap instead.
        values = (values - vmn) / span if span else np.zeros_like(values)

    return "".join(
        make_html_text_with_color(replace(t), get_color(v, cmap))
        for t, v in zip(tokens, values)
    )
60
+
61
+
62
# Emotion labels offered in the UI. get_connotations uses a label's position
# in this list to index each lexicon entry's "Emo" vector, so the order
# matters (presumably it mirrors the lexicon's own ordering - confirm before
# reordering).
emotions = [
    "anger",
    "joy",
    "fear",
    "trust",
    "anticipation",
    "sadness",
    "disgust",
    "surprise",
]

# CSV file holding the connotation lexicon, resolved relative to the working
# directory.
PATH_CONN = "noun_adj_conntation_lexicon.csv"
65
+
66
@st.cache(allow_output_mutation=True)
def get_connotations(emotion, vocab):
    """Collect the lexicon words flagged with ``emotion`` and mark them in ``vocab``.

    Args:
        emotion: One of the labels in the module-level ``emotions`` list.
        vocab: pandas Series of vocabulary strings (anything with ``.isin``).

    Returns:
        A ``(word_set, vocab_mask)`` pair: the set of lexicon words whose
        ``"Emo"`` entry for this emotion equals 1, and a torch boolean tensor
        with one element per ``vocab`` entry that is true where the entry is
        in that set.

    Raises:
        ValueError: If ``emotion`` is not in ``emotions``.
    """
    lexicon = pd.read_csv(PATH_CONN)
    # The "conn" column stores one JSON document per row.
    connotations = lexicon["conn"].apply(json.loads)

    emotion_idx = emotions.index(emotion)
    has_emotion = connotations.apply(lambda entry: entry["Emo"][emotion_idx] == 1.)

    word_set = set(lexicon.loc[has_emotion, "word"].tolist())
    in_word_set = vocab.isin(word_set).values
    vocab_mask = torch.from_numpy(in_word_set)
    return word_set, vocab_mask
75
+
76
@st.cache(allow_output_mutation=True)
def get_articles():
    """Load the first 10000 rows of the news CSV, dropping incomplete rows.

    Returns:
        A DataFrame without rows whose ``title`` or ``article`` is missing.
    """
    frame = pd.read_csv("/proj/semafor/datasets/all-the-news-2-1.csv", nrows=10000)
    return frame.dropna(subset=["title", "article"])
80
+
81
+
82
@st.cache(allow_output_mutation=True)
def search_articles(keyword, article):
    """Pick one article whose title contains ``keyword`` (case-insensitive).

    Args:
        keyword: Search pattern. It is tried as a regular expression first;
            if it is not a valid pattern it is matched as a literal substring.
        article: DataFrame with ``title`` and ``article`` columns.

    Returns:
        The title and body of one randomly sampled matching row, joined by a
        single space.

    Raises:
        ValueError: If no title matches ``keyword``.
    """
    # Local import keeps this block self-contained.
    import re

    titles = article.title.str
    try:
        # na=False: rows with a missing title count as "no match" instead of
        # putting NaN into the boolean mask.
        hits = titles.contains(keyword, case=False, na=False)
    except re.error:
        # Input such as "(" is not a valid regex; treat it literally.
        hits = titles.contains(keyword, case=False, regex=False, na=False)

    matches = article.loc[hits]
    if matches.empty:
        # Sampling from an empty frame fails with an unhelpful message.
        raise ValueError(f"no article title contains {keyword!r}")

    picked = matches.sample().iloc[0]
    return picked["title"] + " " + picked["article"]
87
+
88
@st.cache(allow_output_mutation=True)
def get_model():
    """Load the ``roberta-base`` tokenizer and masked-LM plus a cleaned vocabulary.

    Returns:
        A ``(tokenizer, clean_vocab, model)`` triple. ``clean_vocab`` is a
        pandas Series indexed by token id (sorted ascending) whose values are
        the tokens converted to strings, stripped and lower-cased. ``model``
        is in eval mode with gradients disabled and lives on ``device``.
    """
    tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')

    model = RobertaForMaskedLM.from_pretrained('roberta-base')
    model = model.eval().requires_grad_(False).to(device)

    # Map token id -> normalised surface string.
    id_to_text = {
        token_id: tokenizer.convert_tokens_to_string(token).strip().lower()
        for token, token_id in tokenizer.get_vocab().items()
    }
    clean_vocab = pd.Series(id_to_text).sort_index()
    return tokenizer, clean_vocab, model
94
+
95
+
96
# ---------------------------------------------------------------------------
# App body: replace words carrying the source emotion with words carrying the
# target emotion, using the masked language model to choose the replacements.
# ---------------------------------------------------------------------------
tokenizer, clean_vocab, model = get_model()

# Bare string expression; presumably rendered as a heading by Streamlit's
# "magic" handling of top-level expressions.
f"## Change Connotation"

col1, col2 = st.columns(2)

emotion_source = col1.selectbox("Source Emotion", emotions, index=1)
emotion_target = col2.selectbox("Target Emotion", emotions, index=0)

# Boolean masks over the vocabulary: which token ids belong to each emotion.
_, emotion_words_source = get_connotations(emotion_source, clean_vocab)
_, emotion_words_target = get_connotations(emotion_target, clean_vocab)

# Hard-wired to pasted text. The other branch searches the news CSV that
# get_articles reads from an absolute path.
custom_input = True

if custom_input:
    article = st.sidebar.text_area("Paste Text Here", value=sample_text, height=600)
else:
    articles = get_articles()
    keyword = st.sidebar.text_input("Keywords", value="virus")
    article = search_articles(keyword, articles)

inputs = tokenizer(article, max_length=512, truncation=True, return_tensors="pt")
# Keep an untouched copy: the ids inside `inputs` are overwritten below.
original_input_ids = inputs["input_ids"][0].clone()
words = [
    tokenizer.convert_tokens_to_string(s)
    for s in tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
]

# Positions whose token id is one of the source-emotion vocabulary ids.
source_token_ids = emotion_words_source.nonzero(as_tuple=False).flatten()
mask = (inputs["input_ids"][0][:, None] == source_token_ids).any(-1)

if not mask.any():
    st.warning("no source words found, try another input")
    scores = -np.ones(len(words))
    words_mod = words
else:
    # Mask out the source words and let the model predict replacements.
    inputs["input_ids"][0][mask] = tokenizer.mask_token_id

    with torch.no_grad():
        logits = model(**{k: v.to(device) for k, v in inputs.items()}).logits[0]
        position_mask = mask.to(logits.device)
        # Only target-emotion tokens are admissible...
        logits[:, ~emotion_words_target.to(logits.device)] = float("-inf")
        # ...and a word may not be "replaced" by itself.
        logits[position_mask, original_input_ids[mask].to(logits.device)] = float("-inf")
        candidates = logits[position_mask, :]
        # A row with no finite logit has no admissible replacement; argmax
        # would return token id 0 for it.
        has_candidate = torch.isfinite(candidates).any(-1).cpu()
        idx = candidates.argmax(-1).cpu()

    if not has_candidate.all():
        st.warning("no target words available for some source words; they were left unchanged")
    # Keep the original token wherever no replacement exists.
    idx = torch.where(has_candidate, idx, original_input_ids[mask])

    inputs["input_ids"][0, mask] = idx
    words_mod = [
        tokenizer.convert_tokens_to_string(s)
        for s in tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    ]

    # 1 marks a source-word position (highlighted), -1 everything else.
    scores = np.where(mask.numpy(), 1, -1)

# Left column: original text; right column: modified text. The same positions
# are highlighted in both.
with col1:
    html_str = make_full_html(words, scores, cmap=["blue"])
    st.markdown(html_str, unsafe_allow_html=True)

with col2:
    html_str = make_full_html(words_mod, scores, cmap=["yellow"])
    st.markdown(html_str, unsafe_allow_html=True)
172
+
173
+
174
+
175
+