# frame-finder / app.py
# Last commit by Iseratho: "Add definitions" (a2d34d2)
import gradio as gr
from functools import partial
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sentence_transformers import SentenceTransformer
import torch
import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import penman
from collections import Counter, defaultdict
import networkx as nx
from networkx.drawing.nx_agraph import pygraphviz_layout
class FramingLabels:
    """Zero-shot, multi-label frame classifier.

    Wraps a Hugging Face ``zero-shot-classification`` pipeline and scores
    every text against a fixed list of candidate labels.  Each candidate
    label is expected to look like ``"Name and more: description"``; the
    first word before the colon is used as the short display name.
    """

    def __init__(self, base_model, candidate_labels, batch_size=16):
        """Load the pipeline and bind the classification options.

        Args:
            base_model: Model identifier handed to ``pipeline(model=...)``.
            candidate_labels: List of label strings to score against.
            batch_size: Batch size forwarded to the pipeline on each call.
        """
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.base_pipeline = pipeline("zero-shot-classification", model=base_model, device=device)
        self.candidate_labels = candidate_labels
        # Pre-bind the options so callers only have to pass the text(s).
        self.classifier = partial(
            self.base_pipeline,
            candidate_labels=candidate_labels,
            multi_label=True,
            batch_size=batch_size,
        )

    def order_scores(self, dic):
        """Return the scores of one result in ``candidate_labels`` order.

        Args:
            dic: A single classifier result with parallel ``"labels"`` and
                ``"scores"`` lists (in whatever order the classifier chose).

        Returns:
            List of scores, one per entry of ``self.candidate_labels``.
        """
        score_by_label = dict(zip(dic["labels"], dic["scores"]))
        return [score_by_label[label] for label in self.candidate_labels]

    def get_ordered_scores(self, sequence_to_classify):
        """Classify one text or a list of texts.

        Args:
            sequence_to_classify: A single string or a list of strings.

        Returns:
            For a single result: a flat list with one score per label.
            For several results: one list per label, each holding that
            label's score for every text (i.e. label-major order).
        """
        if isinstance(sequence_to_classify, list):
            res = self.classifier(sequence_to_classify)
            if isinstance(res, dict):
                # A lone result may come back unwrapped; iterating a dict
                # would walk its keys, so normalise to a list of results.
                res = [res]
            res = list(tqdm.tqdm(res))
        else:
            res = self.classifier(sequence_to_classify)
        if isinstance(res, list):
            per_text = [self.order_scores(item) for item in res]
            # Transpose text-major rows into label-major columns.
            return [list(column) for column in zip(*per_text)]
        return self.order_scores(res)

    def get_label_names(self):
        """Return the short name of every candidate label.

        The short name is the first word before the colon.  Trailing commas
        are dropped so that e.g. ``"Legality, constitutionality ...: ..."``
        becomes ``"Legality"`` rather than ``"Legality,"``.
        """
        return [
            label.split(":")[0].split(" ")[0].rstrip(",")
            for label in self.candidate_labels
        ]

    def __call__(self, sequence_to_classify):
        """Return a ``{short label name: score(s)}`` mapping for the input."""
        scores = self.get_ordered_scores(sequence_to_classify)
        label_names = self.get_label_names()
        return dict(zip(label_names, scores))

    def visualize(self, name_to_score_dict, threshold=0.5, **kwargs):
        """Draw a horizontal bar chart of label scores.

        Args:
            name_to_score_dict: Mapping of label name to a single score.
            threshold: Scores strictly above this value get the first
                palette colour, all others the second.
            **kwargs: Forwarded to ``Axes.barh`` (e.g. ``xerr``).  Per-bar
                values must be given in the same order as the mapping.

        Returns:
            The ``(figure, axes)`` pair.
        """
        fig, ax = plt.subplots()
        palette = sns.color_palette()
        label_names = list(name_to_score_dict.keys())
        scores = list(name_to_score_dict.values())
        colors = [palette[0] if score > threshold else palette[1] for score in scores]
        # Plot in the given order and flip the axis instead of reversing the
        # data: the first label still ends up on top, and per-bar kwargs
        # such as ``xerr`` stay aligned with their bars.
        ax.barh(label_names, scores, color=colors, **kwargs)
        ax.invert_yaxis()
        ax.set_xlim(left=0)
        fig.tight_layout()
        return fig, ax
class FramingDimensions:
    """Score texts along bipolar axes built from sentence embeddings.

    Every dimension description is embedded once; each axis is the
    difference between the embeddings of its two poles.  A text's score on
    an axis is the dot product of its embedding with that axis vector.
    """

    def __init__(self, base_model, dimensions, pole_names):
        """Embed the dimensions and derive one axis vector per pole pair.

        Args:
            base_model: Name handed to ``SentenceTransformer``.
            dimensions: Strings of the form ``"Name: description"``.
            pole_names: Pairs ``(first, second)`` of dimension short names;
                the axis vector is ``emb(first) - emb(second)``.
        """
        self.encoder = SentenceTransformer(base_model)
        self.dimensions = dimensions
        self.dim_embs = self.encoder.encode(dimensions)
        self.pole_names = pole_names
        self.axis_names = [pair[0] + "/" + pair[1] for pair in pole_names]

        # Resolve the short names once; positions index into ``dim_embs``.
        short_names = self.get_dimension_names()
        self.axis_embs = np.stack([
            self.dim_embs[short_names.index(first)] - self.dim_embs[short_names.index(second)]
            for first, second in pole_names
        ])

    def get_dimension_names(self):
        """Return the first word before the colon of every dimension."""
        short_names = []
        for description in self.dimensions:
            head = description.split(":")[0]
            short_names.append(head.split(" ")[0])
        return short_names

    def __call__(self, sequence_to_align):
        """Project the input onto every axis.

        Args:
            sequence_to_align: Whatever the encoder's ``encode`` accepts
                (the app passes a string or a list of strings).

        Returns:
            Dict keyed by the ``pole_names`` pairs, holding the score(s)
            of the input on the corresponding axis.
        """
        text_embs = self.encoder.encode(sequence_to_align)
        projections = text_embs @ self.axis_embs.T
        # Transpose so that iteration yields one entry per axis.
        return dict(zip(self.pole_names, projections.T))

    def visualize(self, align_scores_df, **kwargs):
        """Plot the mean score per axis as a bubble on a pole-to-pole line.

        Args:
            align_scores_df: DataFrame whose columns are the pole pairs and
                whose rows are the scored texts.
            **kwargs: Accepted but not used.

        Returns:
            The matplotlib figure.
        """
        columns = align_scores_df.columns
        left_labels = columns.map(lambda pair: pair[1])
        right_labels = columns.map(lambda pair: pair[0])

        bias = align_scores_df.mean()
        x_limit = bias.abs().max() * 1.1
        # Bubble area grows with the variance; the offset keeps a bubble
        # visible when the variance is zero or undefined (single row).
        bubble_sizes = (align_scores_df.var().fillna(0) + 0.001) * 50_000
        bubble_colors = []
        for value in bias:
            bubble_colors.append("b" if value > 0 else "r")

        fig = plt.figure()
        ax = fig.add_subplot(111)
        plt.scatter(x=bias, y=left_labels, s=bubble_sizes, c=bubble_colors)
        plt.axvline(0)
        plt.xlim(-x_limit, x_limit)
        plt.gca().invert_yaxis()

        # Mirror the y axis on the right-hand side to name the other pole.
        right_axis = ax.twinx()
        right_axis.set_ylim(ax.get_ylim())
        right_axis.set_yticks(ax.get_yticks(), labels=right_labels)
        plt.tight_layout()
        return fig
class FramingStructure:
    """Parse texts into PENMAN graphs and draw their aggregated structure.

    A ``text2text-generation`` pipeline produces a PENMAN string per text;
    strings that decode successfully become graphs whose triples are
    counted and drawn as one weighted directed graph.
    """

    def __init__(self, base_model, roles=None):
        """Load the generation pipeline.

        Args:
            base_model: Model identifier handed to ``pipeline``.
            roles: Accepted for interface compatibility; currently unused.
        """
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.translator = pipeline("text2text-generation", base_model, device=device, max_length=300)

    def __call__(self, sequence_to_translate):
        """Generate and decode one graph per pipeline output.

        Args:
            sequence_to_translate: Text(s) handed to the pipeline.

        Returns:
            List of decoded graphs.  Outputs that cannot be decoded are
            skipped, so the list may be shorter than the input.
        """
        graphs = []
        for output in self.translator(sequence_to_translate):
            try:
                generated = output["generated_text"]
                graphs.append(penman.decode(generated))
            except Exception:
                # Best effort: a malformed generation is dropped.  Catching
                # ``Exception`` (not a bare ``except``) lets
                # KeyboardInterrupt and SystemExit propagate.
                continue
        return graphs

    def visualize(self, decoded_graphs, min_node_threshold=1, **kwargs):
        """Draw the triples of all graphs as one weighted directed graph.

        Args:
            decoded_graphs: Objects exposing ``.triples`` as
                ``(source, role, target)`` tuples.
            min_node_threshold: Nodes whose accumulated weight is below
                this value are hidden.
            **kwargs: Accepted but not used.

        Returns:
            The matplotlib figure.
        """

        def trim_distinction_end(name):
            # Keep only the part before the first underscore.
            return name.split("_")[0]

        cnt = Counter()
        for graph in decoded_graphs:
            cnt.update(
                (trim_distinction_end(source), role.replace(":", ""), trim_distinction_end(target))
                for source, role, target in graph.triples
                if target is not None
            )

        G = nx.DiGraph()
        # Edge colour per role; any role not listed falls back to black.
        color_map = defaultdict(lambda: "k", {
            "ARG0": "y",
            "ARG1": "r",
            "ARG2": "g",
            "ARG3": "b",
        })
        for (source, role, target), num in cnt.items():
            for node in (source, target):
                if not G.has_node(node):
                    G.add_node(node, weight=0)
            # A node's weight is the number of triple occurrences touching it.
            G.nodes[source]["weight"] += num
            G.nodes[target]["weight"] += num
            # NOTE: a DiGraph keeps one edge per (source, target) pair, so a
            # later role between the same two nodes replaces these attributes.
            G.add_edge(source, target, role=role, weight=num, color=color_map[role])

        G_sub = nx.subgraph_view(G, filter_node=lambda n: G.nodes[n]["weight"] >= min_node_threshold)
        node_sizes = [weight * 100 for weight in nx.get_node_attributes(G_sub, "weight").values()]
        edge_colors = list(nx.get_edge_attributes(G_sub, "color").values())

        fig = plt.figure()
        pos = pygraphviz_layout(G_sub, prog="dot")
        # draw_networkx already renders node labels, so no separate
        # draw_networkx_labels pass is needed.
        nx.draw_networkx(G_sub, pos, node_size=node_sizes, edge_color=edge_colors)
        nx.draw_networkx_edge_labels(G_sub, pos, edge_labels=nx.get_edge_attributes(G_sub, "role"))
        plt.tight_layout()
        return fig
# Specify the models
# base_model_1 is passed to FramingLabels (zero-shot-classification pipeline).
base_model_1 = "facebook/bart-large-mnli"
# base_model_2 is passed to FramingDimensions (loaded via SentenceTransformer).
base_model_2 = 'all-mpnet-base-v2'
# base_model_3 is passed to FramingStructure (text2text-generation pipeline
# whose generated text is decoded with penman).
base_model_3 = "Iseratho/model_parse_xfm_bart_base-v0_1_0"
# Candidate labels handed to FramingLabels. Each entry has the form
# "Name: description"; the full string is what the classifier scores against,
# while FramingLabels.get_label_names() keeps only the first word before the
# colon as the display name.
# Reference for the label set:
# https://homes.cs.washington.edu/~nasmith/papers/card+boydstun+gross+resnik+smith.acl15.pdf
candidate_labels = [
"Economic: costs, benefits, or other financial implications",
"Capacity and resources: availability of physical, human or financial resources, and capacity of current systems",
"Morality: religious or ethical implications",
"Fairness and equality: balance or distribution of rights, responsibilities, and resources",
"Legality, constitutionality and jurisprudence: rights, freedoms, and authority of individuals, corporations, and government",
"Policy prescription and evaluation: discussion of specific policies aimed at addressing problems",
"Crime and punishment: effectiveness and implications of laws and their enforcement",
"Security and defense: threats to welfare of the individual, community, or nation",
"Health and safety: health care, sanitation, public safety",
"Quality of life: threats and opportunities for the individual’s wealth, happiness, and well-being",
"Cultural identity: traditions, customs, or values of a social group in relation to a policy issue",
"Public opinion: attitudes and opinions of the general public, including polling and demographics",
"Political: considerations related to politics and politicians, including lobbying, elections, and attempts to sway voters",
"External regulation and reputation: international reputation or foreign policy of the U.S.",
"Other: any coherent group of frames not covered by the above categories",
]
# Dimension descriptions handed to FramingDimensions. Each entry has the form
# "Name: description"; the whole string is embedded, and the first word before
# the colon is the name that pole_names refers to.
# Reference for the descriptions:
# https://osf.io/xakyw
dimensions = [
"Care: ...acted with kindness, compassion, or empathy, or nurtured another person.",
"Harm: ...acted with cruelty, or hurt or harmed another person/animal and caused suffering.",
"Fairness: ...acted in a fair manner, promoting equality, justice, or rights.",
"Cheating: ...was unfair or cheated, or caused an injustice or engaged in fraud.",
"Loyalty: ...acted with fidelity, or as a team player, or was loyal or patriotic.",
"Betrayal: ...acted disloyal, betrayed someone, was disloyal, or was a traitor.",
"Authority: ...obeyed, or acted with respect for authority or tradition.",
"Subversion: ...disobeyed or showed disrespect, or engaged in subversion or caused chaos.",
"Sanctity: ...acted in a way that was wholesome or sacred, or displayed purity or sanctity.",
"Degradation: ...was depraved, degrading, impure, or unnatural.",
]
# Pairs of dimension names that form one axis each. FramingDimensions builds
# the axis vector as embedding(first) - embedding(second), so a positive score
# leans towards the first name of the pair. Every name must match the first
# word of an entry in `dimensions`.
pole_names = [
("Care", "Harm"),
("Fairness", "Cheating"),
("Loyalty", "Betrayal"),
("Authority", "Subversion"),
("Sanctity", "Degradation"),
]
# Instantiate the three analysers once at module import; the framing_* functions
# below use these module-level objects on every call.
framing_label_model = FramingLabels(base_model_1, candidate_labels)
framing_dimen_model = FramingDimensions(base_model_2, dimensions, pole_names)
framing_struc_model = FramingStructure(base_model_3)
def framing_multi(texts, min_node_threshold=1):
    """Analyse several texts together and return the three figures.

    Args:
        texts: List of texts to analyse.
        min_node_threshold: Minimum node weight shown in the structure plot.

    Returns:
        Tuple ``(label figure, dimension figure, structure figure)``.
    """
    # Labels: bar = mean score over the texts, error bar = standard error.
    label_scores = pd.DataFrame(framing_label_model(texts))
    label_fig, _ = framing_label_model.visualize(
        label_scores.mean().to_dict(),
        xerr=label_scores.sem(),
    )

    # Dimensions: one row per text, one column per pole pair.
    dimension_scores = pd.DataFrame(framing_dimen_model(texts))
    dimension_fig = framing_dimen_model.visualize(dimension_scores)

    # Structure: all decoded graphs are merged into one drawing.
    graphs = framing_struc_model(texts)
    structure_fig = framing_struc_model.visualize(graphs, min_node_threshold=min_node_threshold)

    return label_fig, dimension_fig, structure_fig
def framing_single(text, min_node_threshold=1):
    """Analyse one text and return the three figures.

    Args:
        text: The text to analyse.
        min_node_threshold: Minimum node weight shown in the structure plot.

    Returns:
        Tuple ``(label figure, dimension figure, structure figure)``.
    """
    label_scores = framing_label_model(text)
    label_fig, _ = framing_label_model.visualize(label_scores)

    # Wrap every score in a one-element list so the frame has a single row.
    dimension_rows = {}
    for pole_pair, score in framing_dimen_model(text).items():
        dimension_rows[pole_pair] = [score]
    dimension_fig = framing_dimen_model.visualize(pd.DataFrame(dimension_rows))

    graphs = framing_struc_model(text)
    structure_fig = framing_struc_model.visualize(graphs, min_node_threshold=min_node_threshold)

    return label_fig, dimension_fig, structure_fig
async def framing_textbox(text, split, min_node_threshold):
    """Analyse text typed into the textbox.

    Args:
        text: The raw textbox content.
        split: When true, every non-blank line is treated as its own text.
        min_node_threshold: Minimum node weight shown in the structure plot.

    Returns:
        The three figures produced by ``framing_multi`` (two or more
        non-blank lines with ``split`` enabled) or ``framing_single``.
    """
    if split:
        # Blank lines (e.g. from a trailing newline) carry no content and
        # would otherwise be scored as empty texts, distorting the averages.
        texts = [line for line in text.split("\n") if line.strip()]
        if len(texts) > 1:
            return framing_multi(texts, min_node_threshold)
    return framing_single(text, min_node_threshold)
async def framing_file(file_obj, split, min_node_threshold):
    """Analyse the text content of an uploaded file.

    Args:
        file_obj: Either an object exposing the file path as ``.name`` or
            the path itself.
        split: When true, every non-blank line is treated as its own text.
        min_node_threshold: Minimum node weight shown in the structure plot.

    Returns:
        The three figures produced by ``framing_multi`` (two or more
        non-blank lines with ``split`` enabled) or ``framing_single``.

    Raises:
        ValueError: If no file was given, or if ``split`` is enabled and the
            file holds no non-blank line.
    """
    if file_obj is None:
        raise ValueError("No file was provided.")
    path = getattr(file_obj, "name", file_obj)

    # Read everything first so the file is closed before the models run.
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()

    if split:
        # Drop blank lines; splitting also removes the trailing newline
        # that readlines() would have kept on every line.
        texts = [line for line in text.split("\n") if line.strip()]
        if not texts:
            raise ValueError("The file contains no text to analyze.")
        if len(texts) > 1:
            return framing_multi(texts, min_node_threshold)
        text = texts[0]
    return framing_single(text, min_node_threshold)
# Example rows for the textbox interface. Each row supplies the three inputs of
# framing_textbox in order: [text, split on newlines?, min node threshold].
example_list = [["In 2010, CFCs were banned internationally due to their harmful effect on the ozone layer.", False, 1],
["In 2021, doctors prevented the spread of the virus by vaccinating with Pfizer.", False, 1],
["We must fight for our freedom.", False, 1],
["The government prevents our freedom.", False, 1],
["They prevent the spread.", False, 1],
["We fight the virus.", False, 1],
["I believe that we should act now.\nThere is no time to waste.", True, 1],
]
# Short blurb passed as `description=` to both gr.Interface instances below.
description = """A simple tool that helps you find (discover and detect) frames in text.
Note that due to the computation time required for underlying Transformer models, only short texts are recommended."""
# Markdown/HTML text passed as `article=` to both gr.Interface instances below.
# It repeats the label and dimension definitions in collapsible sections.
# The string must open with exactly three quotes: a fourth one would become a
# stray '"' at the start of the rendered text.
article = """Check out the preliminary article in the [Web Conference Symposium](https://dl.acm.org/doi/pdf/10.1145/3543873.3587534), will be updated to currently in review article after publication.
<details>
<summary>Explanation of labels:</summary>
<ul>
<li>Economic: costs, benefits, or other financial implications</li>
<li>Capacity and resources: availability of physical, human or financial resources, and capacity of current systems</li>
<li>Morality: religious or ethical implications</li>
<li>Fairness and equality: balance or distribution of rights, responsibilities, and resources</li>
<li>Legality, constitutionality and jurisprudence: rights, freedoms, and authority of individuals, corporations, and government</li>
<li>Policy prescription and evaluation: discussion of specific policies aimed at addressing problems</li>
<li>Crime and punishment: effectiveness and implications of laws and their enforcement</li>
<li>Security and defense: threats to welfare of the individual, community, or nation</li>
<li>Health and safety: health care, sanitation, public safety</li>
<li>Quality of life: threats and opportunities for the individual’s wealth, happiness, and well-being</li>
<li>Cultural identity: traditions, customs, or values of a social group in relation to a policy issue</li>
<li>Public opinion: attitudes and opinions of the general public, including polling and demographics</li>
<li>Political: considerations related to politics and politicians, including lobbying, elections, and attempts to sway voters</li>
<li>External regulation and reputation: international reputation or foreign policy of the U.S.</li>
<li>Other: any coherent group of frames not covered by the above categories</li>
</ul>
</details>
<details>
<summary>Explanation of dimensions: </summary>
<ul>
<li>Care: ...acted with kindness, compassion, or empathy, or nurtured another person.</li>
<li>Harm: ...acted with cruelty, or hurt or harmed another person/animal and caused suffering.</li>
<li>Fairness: ...acted in a fair manner, promoting equality, justice, or rights.</li>
<li>Cheating: ...was unfair or cheated, or caused an injustice or engaged in fraud.</li>
<li>Loyalty: ...acted with fidelity, or as a team player, or was loyal or patriotic.</li>
<li>Betrayal: ...acted disloyal, betrayed someone, was disloyal, or was a traitor.</li>
<li>Authority: ...obeyed, or acted with respect for authority or tradition.</li>
<li>Subversion: ...disobeyed or showed disrespect, or engaged in subversion or caused chaos.</li>
<li>Sanctity: ...acted in a way that was wholesome or sacred, or displayed purity or sanctity.</li>
<li>Degradation: ...was depraved, degrading, impure, or unnatural.</li>
</ul>
</details>
Document of structure (AMR) explanation: [AMR Specification](https://github.com/amrisi/amr-guidelines/blob/master/amr.md)
"""
# Interface for typed text: the three inputs map positionally onto
# framing_textbox(text, split, min_node_threshold) and the three plot outputs
# onto the figures it returns (labels, dimensions, structure).
textbox_inferface = gr.Interface(fn=framing_textbox,
inputs=[
gr.Textbox(label="Text to analyze."),
gr.Checkbox(True, label="Split on newlines? (To enter newlines type shift+Enter)"),
gr.Number(1, label="Min node threshold for framing structure.")
],
description=description,
examples=example_list,
article=article,
outputs=[gr.Plot(label="Label"),
gr.Plot(label="Dimensions"),
gr.Plot(label="Structure")
])
# Interface for uploaded files: the three inputs map positionally onto
# framing_file(file_obj, split, min_node_threshold); outputs are the same
# three plots as in the textbox interface. No examples are attached here.
file_interface = gr.Interface(fn=framing_file,
inputs=[
gr.File(label="File of texts to analyze."),
gr.Checkbox(True, label="Split on newlines?"),
gr.Number(1, label="Min node threshold for framing structure."),
],
description=description,
article=article,
outputs=[gr.Plot(label="Label"),
gr.Plot(label="Dimensions"),
gr.Plot(label="Structure")])
# Combine both interfaces as tabs of a single app and start it.
# NOTE(review): launch() runs at import time (there is no __main__ guard), so
# importing this module also starts the app.
demo = gr.TabbedInterface([textbox_inferface, file_interface],
tab_names=["Single Mode", "File Mode"],
title="FrameFinder",)
demo.launch()