Iseratho committed on
Commit
0a1cbe4
1 Parent(s): 64bd930

Add application files

Browse files
Files changed (3) hide show
  1. app.py +264 -0
  2. packages.txt +1 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ from functools import partial
3
+
4
+ import numpy as np
5
+
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+
9
+ from sentence_transformers import SentenceTransformer
10
+ import torch
11
+ import tqdm
12
+
13
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
14
+ import penman
15
+ from collections import Counter, defaultdict
16
+ import networkx as nx
17
+ from networkx.drawing.nx_agraph import pygraphviz_layout
18
+
19
+ from transformers import pipeline
20
+ from functools import partial
21
+
22
+ import numpy as np
23
+
24
+ import matplotlib.pyplot as plt
25
+ import seaborn as sns
26
+
27
+ from sentence_transformers import SentenceTransformer
28
+ import torch
29
+ import tqdm
30
+
31
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
32
+ import penman
33
+ from collections import Counter, defaultdict
34
+ import networkx as nx
35
+ from networkx.drawing.nx_agraph import pygraphviz_layout
36
+
37
class FramingLabels:
    """Zero-shot, multi-label frame classifier built on a Hugging Face pipeline.

    Scores are always reported in the order of ``candidate_labels``. The
    display name of a label is the first space-separated word of the text in
    front of its first ``:``.
    """

    def __init__(self, base_model, candidate_labels, batch_size=16):
        """Load the zero-shot pipeline and bind the classification options.

        Args:
            base_model: Model identifier handed to ``transformers.pipeline``.
            candidate_labels: Label descriptions scored against every input.
            batch_size: Batch size forwarded to the pipeline on each call.
        """
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.base_pipeline = pipeline("zero-shot-classification", model=base_model, device=device)
        self.candidate_labels = candidate_labels
        self.classifier = partial(
            self.base_pipeline,
            candidate_labels=candidate_labels,
            multi_label=True,
            batch_size=batch_size,
        )

    def order_scores(self, dic):
        """Return the scores of one result, reordered to ``candidate_labels``.

        Args:
            dic: Mapping with parallel ``"labels"`` and ``"scores"`` sequences.

        Returns:
            A list of scores, one per candidate label, in candidate order.

        Raises:
            ValueError: If a candidate label is missing from ``dic["labels"]``.
        """
        indices_order = [dic["labels"].index(label) for label in self.candidate_labels]
        return np.array(dic["scores"])[indices_order].tolist()

    def get_ordered_scores(self, sequence_to_classify):
        """Classify the input and return scores in candidate-label order.

        Args:
            sequence_to_classify: A single text or a list of texts.

        Returns:
            If the classifier yields a single result: one score per label.
            If it yields a list of results: one list per label, each holding
            the scores of all inputs (label-major layout).
        """
        if isinstance(sequence_to_classify, list):
            res = list(tqdm.tqdm(self.classifier(sequence_to_classify)))
        else:
            res = self.classifier(sequence_to_classify)

        if isinstance(res, list):
            per_input = [self.order_scores(item) for item in res]
            # Transpose from one row per input to one row per label.
            return [list(column) for column in zip(*per_input)]
        return self.order_scores(res)

    def get_label_names(self):
        """Return the short display name of every candidate label."""
        return [label.split(":")[0].split(" ")[0] for label in self.candidate_labels]

    def __call__(self, sequence_to_classify):
        """Classify the input and return a ``{label name: scores}`` dict."""
        scores = self.get_ordered_scores(sequence_to_classify)
        label_names = self.get_label_names()
        return dict(zip(label_names, scores))

    def visualize(self, name_to_score_dict, threshold=0.5, **kwargs):
        """Draw a horizontal bar chart of the label scores.

        Args:
            name_to_score_dict: Mapping of label name to a single score.
            threshold: Bars with a score strictly above this value get the
                first palette colour, all others the second.
            **kwargs: Forwarded to ``Axes.barh``.

        Returns:
            The ``(figure, axes)`` pair of the chart.
        """
        fig, ax = plt.subplots()

        cp = sns.color_palette()

        scores_ordered = list(name_to_score_dict.values())
        label_names = list(name_to_score_dict.keys())

        # Honour the caller-supplied threshold instead of a fixed 0.5.
        colors = [cp[0] if score > threshold else cp[1] for score in scores_ordered]
        # Reverse so the first label ends up at the top of the chart.
        ax.barh(label_names[::-1], scores_ordered[::-1], color=colors[::-1], **kwargs)

        return fig, ax
84
+
85
class FramingDimensions:
    """Project text embeddings onto axes spanned by pairs of opposing poles.

    Every dimension description is embedded once at construction time. Each
    axis is the embedding of the first pole minus that of the second pole.
    """

    def __init__(self, base_model, dimensions, pole_names):
        """Embed the dimensions and build one axis vector per pole pair.

        Args:
            base_model: Model identifier handed to ``SentenceTransformer``.
            dimensions: Dimension descriptions; the first word before the
                first ``:`` is the dimension's name.
            pole_names: Pairs ``(pole, opposite_pole)`` of dimension names.

        Raises:
            ValueError: If a pole name is not among the dimension names.
        """
        self.encoder = SentenceTransformer(base_model)
        self.dimensions = dimensions
        self.dim_embs = self.encoder.encode(dimensions)
        self.pole_names = pole_names
        self.axis_names = [pair[0] + "/" + pair[1] for pair in pole_names]

        # Resolve the names once instead of re-deriving them for every pole.
        names = self.get_dimension_names()
        axis_embs = []
        for pole1, pole2 in pole_names:
            axis_emb = self.dim_embs[names.index(pole1)] - self.dim_embs[names.index(pole2)]
            axis_embs.append(axis_emb)
        self.axis_embs = np.stack(axis_embs)

    def get_dimension_names(self):
        """Return the short name of every dimension description."""
        return [dim.split(":")[0].split(" ")[0] for dim in self.dimensions]

    def __call__(self, sequence_to_align):
        """Score the input against every axis.

        Args:
            sequence_to_align: Input handed to the encoder's ``encode``.

        Returns:
            Dict keyed by the ``pole_names`` pairs; each value is the dot
            product of the input embedding(s) with that pair's axis vector.
        """
        embs = self.encoder.encode(sequence_to_align)
        scores = embs @ self.axis_embs.T
        return dict(zip(self.pole_names, scores.T))

    def visualize(self, align_scores_df, **kwargs):
        """Plot the mean alignment per axis as a scatter around zero.

        Args:
            align_scores_df: DataFrame whose column labels are
                ``(pole, opposite_pole)`` pairs, one row per scored text.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The matplotlib figure.
        """
        name_left = align_scores_df.columns.map(lambda x: x[1])
        name_right = align_scores_df.columns.map(lambda x: x[0])
        bias = align_scores_df.mean()
        color = ["b" if x > 0 else "r" for x in bias]
        # Marker size grows with the variance; a single row yields NaN
        # variance, which is treated as zero.
        inten = (align_scores_df.var().fillna(0) + 0.001) * 50_000
        bounds = bias.abs().max() * 1.1

        fig = plt.figure()
        ax = fig.add_subplot(111)
        # Draw on the axes created above rather than on pyplot's implicit
        # "current axes".
        ax.scatter(x=bias, y=name_left, s=inten, c=color)
        ax.axvline(0)
        ax.set_xlim(-bounds, bounds)
        ax.invert_yaxis()
        # A twin y-axis carries the opposite pole names on the right side.
        axi = ax.twinx()
        axi.set_ylim(ax.get_ylim())
        axi.set_yticks(ax.get_yticks(), labels=name_right)
        return fig
128
+
129
class FramingStructure:
    """Translate text into PENMAN graphs and draw their aggregated structure."""

    def __init__(self, base_model, roles=None):
        """Load the text-to-text pipeline used as translator.

        Args:
            base_model: Model identifier handed to ``transformers.pipeline``.
            roles: Accepted for interface compatibility; not used.
        """
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Must be stored on the instance: ``__call__`` reads
        # ``self.translator``.
        self.translator = pipeline("text2text-generation", base_model, device=device, max_length=300)

    def __call__(self, sequence_to_translate):
        """Translate the input and decode every generated text with penman.

        Args:
            sequence_to_translate: Input handed to the translation pipeline.

        Returns:
            List of decoded graphs. Outputs that cannot be decoded are
            skipped on purpose.
        """
        res = self.translator(sequence_to_translate)

        def try_decode(x):
            try:
                return penman.decode(x["generated_text"])
            except Exception:
                # Best effort: generated text that cannot be decoded is
                # dropped rather than failing the whole request.
                return None

        decoded = (try_decode(x) for x in res)
        return [graph for graph in decoded if graph is not None]

    def visualize(self, decoded_graphs, min_node_threshold=1, **kwargs):
        """Draw the triples of all graphs as one weighted directed graph.

        Args:
            decoded_graphs: Objects exposing ``triples`` as
                ``(source, role, target)`` tuples.
            min_node_threshold: Nodes whose accumulated weight is below this
                value are hidden.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The matplotlib figure.
        """

        def trim_distinction_end(x):
            # Keep only the part before the first underscore.
            return x.split("_")[0]

        cnt = Counter()
        for graph in decoded_graphs:
            cnt.update(
                (trim_distinction_end(source), role.replace(":", ""), trim_distinction_end(target))
                for source, role, target in graph.triples
                if target is not None
            )

        G = nx.DiGraph()

        # Roles without an explicit entry are drawn in black ("k").
        color_map = defaultdict(lambda: "k", {
            "ARG0": "y",
            "ARG1": "r",
            "ARG2": "g",
            "ARG3": "b",
        })

        for (source, role, target), num in cnt.items():
            # A node's weight is the number of triple occurrences it is in.
            for node in (source, target):
                if not G.has_node(node):
                    G.add_node(node, weight=0)
                G.nodes[node]["weight"] += num
            G.add_edge(source, target, role=role, weight=num, color=color_map[role])

        G_sub = nx.subgraph_view(G, filter_node=lambda n: G.nodes[n]["weight"] >= min_node_threshold)

        node_sizes = [x * 100 for x in nx.get_node_attributes(G_sub, 'weight').values()]
        edge_colors = nx.get_edge_attributes(G_sub, 'color').values()

        fig = plt.figure()

        pos = pygraphviz_layout(G_sub, prog="dot")
        nx.draw_networkx(G_sub, pos, node_size=node_sizes, edge_color=edge_colors)
        nx.draw_networkx_labels(G_sub, pos)
        nx.draw_networkx_edge_labels(G_sub, pos, edge_labels=nx.get_edge_attributes(G_sub, "role"))
        return fig
187
+
188
# Specify the models
# base_model_1 is passed to FramingLabels, base_model_2 to FramingDimensions
# and base_model_3 to FramingStructure further down in this file.
base_model_1 = "facebook/bart-large-mnli"
base_model_2 = 'all-mpnet-base-v2'
base_model_3 = "Iseratho/model_parse_xfm_bart_base-v0_1_0"
# https://homes.cs.washington.edu/~nasmith/papers/card+boydstun+gross+resnik+smith.acl15.pdf
# Each entry is "<name>: <description>". FramingLabels uses the first word
# before the colon as the display name of the label.
candidate_labels = [
    "Economic: costs, benefits, or other financial implications",
    "Capacity and resources: availability of physical, human or financial resources, and capacity of current systems",
    "Morality: religious or ethical implications",
    "Fairness and equality: balance or distribution of rights, responsibilities, and resources",
    "Legality, constitutionality and jurisprudence: rights, freedoms, and authority of individuals, corporations, and government",
    "Policy prescription and evaluation: discussion of specific policies aimed at addressing problems",
    "Crime and punishment: effectiveness and implications of laws and their enforcement",
    "Security and defense: threats to welfare of the individual, community, or nation",
    "Health and safety: health care, sanitation, public safety",
    "Quality of life: threats and opportunities for the individual’s wealth, happiness, and well-being",
    "Cultural identity: traditions, customs, or values of a social group in relation to a policy issue",
    "Public opinion: attitudes and opinions of the general public, including polling and demographics",
    "Political: considerations related to politics and politicians, including lobbying, elections, and attempts to sway voters",
    "External regulation and reputation: international reputation or foreign policy of the U.S.",
    "Other: any coherent group of frames not covered by the above categories",
]
210
+
211
# https://osf.io/xakyw
# Each entry is "<name>: <description>". The name (first word before the
# colon) is what ``pole_names`` below refers to, so both lists must spell
# every name identically.
dimensions = [
    "Care: ...acted with kindness, compassion, or empathy, or nurtured another person.",
    "Harm: ...acted with cruelty, or hurt or harmed another person/animal and caused suffering.",
    "Fairness: ...acted in a fair manner, promoting equality, justice, or rights.",
    "Cheating: ...was unfair or cheated, or caused an injustice or engaged in fraud.",
    "Loyalty: ...acted with fidelity, or as a team player, or was loyal or patriotic.",
    "Betrayal: ...acted disloyal, betrayed someone, was disloyal, or was a traitor.",
    "Authority: ...obeyed, or acted with respect for authority or tradition.",
    "Subversion: ...disobeyed or showed disrespect, or engaged in subversion or caused chaos.",
    "Sanctity: ...acted in a way that was wholesome or sacred, or displayed purity or sanctity.",
    "Degradation: ...was depraved, degrading, impure, or unnatural.",
]
# Opposing pairs of dimension names; each pair defines one axis.
pole_names = [
    ("Care", "Harm"),
    ("Fairness", "Cheating"),
    ("Loyalty", "Betrayal"),
    ("Authority", "Subversion"),
    ("Sanctity", "Degradation"),
]
231
+
232
# Build the three analysis models once at import time; framing_single below
# reuses these module-level instances for every request.
framing_label_model = FramingLabels(base_model_1, candidate_labels)
framing_dimen_model = FramingDimensions(base_model_2, dimensions, pole_names)
framing_struc_model = FramingStructure(base_model_3)
235
+
236
+ import pandas as pd
237
+
238
async def framing_single(text):
    """Run the label, dimension and structure analyses on *text*.

    Returns:
        Tuple ``(label_figure, dimension_figure, structure_figure)``.
    """
    label_scores = framing_label_model(text)
    label_fig, _ = framing_label_model.visualize(label_scores)

    # Wrap every score in a one-element list so the frame has a single row.
    dimension_scores = framing_dimen_model(text)
    dimension_frame = pd.DataFrame({pole: [score] for pole, score in dimension_scores.items()})
    dimension_fig = framing_dimen_model.visualize(dimension_frame)

    graphs = framing_struc_model(text)
    structure_fig = framing_struc_model.visualize(graphs)

    return label_fig, dimension_fig, structure_fig
244
+
245
# Example inputs offered in the demo UI (passed as ``examples`` to the
# interface below).
example_list = ["In 2021, doctors prevented the spread of the virus by vaccinating with Pfizer.",
                "We must fight for our freedom.",
                "The government prevents our freedom.",
                "They prevent the spread.",
                "We fight the virus.",
                "I believe that we should act now. There is no time to waste."
                ]
252
+
253
# ``gr`` is used below but gradio is not imported anywhere else in this file,
# so import it here to keep the name defined.
import gradio as gr

# One text box in, three plots out (labels, dimensions, structure).
demo = gr.Interface(
    fn=framing_single,
    title="FrameFinder",
    inputs=gr.Textbox(label="Text to analyze."),
    description="A simple tool that helps you find (discover and detect) frames in text.",
    examples=example_list,
    article="Check out the preliminary article in the [Web Conference Symposium](https://dl.acm.org/doi/pdf/10.1145/3543873.3587534), will be updated to currently in review article after publication.",
    outputs=[
        gr.Plot(label="Label"),
        gr.Plot(label="Dimensions"),
        gr.Plot(label="Structure"),
    ],
)

demo.launch()
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ libgraphviz-dev
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ sentence_transformers
3
+ penman
4
+ pygraphviz