import logging
from collections import Counter, defaultdict
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import penman
import seaborn as sns
import torch
import tqdm
from matplotlib.figure import Figure
from networkx.drawing.nx_agraph import pygraphviz_layout
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

logger = logging.getLogger(__name__)


def _short_name(description):
    """Return the display name of a ``"<Name ...>: <explanation>"`` description.

    Takes the first word before the colon and drops a trailing comma, so
    ``"Legality, constitutionality and jurisprudence: ..."`` becomes
    ``"Legality"``.
    """
    return description.split(":", 1)[0].split(" ", 1)[0].rstrip(",")


class FramingLabels:
    """Multi-label zero-shot classifier that scores text against framing labels.

    Args:
        base_model: Model name/path for the Hugging Face
            ``zero-shot-classification`` pipeline.
        candidate_labels: Label descriptions of the form ``"Name: explanation"``;
            the full description is what the model scores against.
        batch_size: Batch size forwarded to the pipeline.
    """

    def __init__(self, base_model, candidate_labels, batch_size=16):
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.base_pipeline = pipeline("zero-shot-classification", model=base_model, device=device)
        self.candidate_labels = candidate_labels
        # multi_label=True scores every label independently, so several
        # frames can apply to the same text.
        self.classifier = partial(
            self.base_pipeline,
            candidate_labels=candidate_labels,
            multi_label=True,
            batch_size=batch_size,
        )

    def order_scores(self, dic):
        """Return one pipeline result's scores in ``candidate_labels`` order.

        The pipeline returns its labels sorted by score; this undoes that
        sorting so positions line up with ``candidate_labels``.
        """
        score_by_label = dict(zip(dic["labels"], dic["scores"]))
        return [score_by_label[label] for label in self.candidate_labels]

    def get_ordered_scores(self, sequence_to_classify):
        """Classify one text or a list of texts.

        Returns:
            For a single string: one score per candidate label.
            For a list of strings: one entry per candidate label, each being
            the list of that label's scores across the inputs.
        """
        if isinstance(sequence_to_classify, list):
            res = list(tqdm.tqdm(self.classifier(sequence_to_classify)))
        else:
            res = self.classifier(sequence_to_classify)

        if isinstance(res, list):
            per_text = [self.order_scores(out) for out in res]
            # Transpose from text-major to label-major.
            return [list(column) for column in zip(*per_text)]
        return self.order_scores(res)

    def get_label_names(self):
        """Return the short display name of every candidate label."""
        return [_short_name(label) for label in self.candidate_labels]

    def __call__(self, sequence_to_classify):
        """Return ``{label_name: score(s)}`` for one text or a list of texts."""
        scores = self.get_ordered_scores(sequence_to_classify)
        return dict(zip(self.get_label_names(), scores))

    def visualize(self, name_to_score_dict, threshold=0.5, **kwargs):
        """Plot label scores as a horizontal bar chart.

        Bars scoring above ``threshold`` use the first seaborn palette colour,
        the others the second. Extra ``kwargs`` are passed to ``Axes.barh``;
        per-bar ``xerr``/``yerr`` arrays are reordered together with the bars.

        Returns:
            ``(fig, ax)``.
        """
        palette = sns.color_palette()
        label_names = list(name_to_score_dict.keys())
        scores = list(name_to_score_dict.values())
        colors = [palette[0] if s > threshold else palette[1] for s in scores]

        # barh puts the first category at the bottom, so everything is
        # reversed to show the first label on top. Per-bar error values have
        # to follow the same order or they end up on the wrong bars.
        for key in ("xerr", "yerr"):
            if kwargs.get(key) is not None:
                err = np.asarray(kwargs[key])
                if err.ndim > 0:
                    kwargs[key] = err[..., ::-1]

        # A standalone Figure (not pyplot-managed) is not kept alive by
        # pyplot after the request is served.
        fig = Figure()
        ax = fig.subplots()
        ax.barh(label_names[::-1], scores[::-1], color=colors[::-1], **kwargs)
        ax.set_xlim(left=0)
        fig.tight_layout()
        return fig, ax


class FramingDimensions:
    """Score text along bipolar axes built from embedded dimension descriptions.

    Each axis is the embedding of the first pole's description minus the
    embedding of the second pole's description. A positive score therefore
    means the text embedding has a larger dot product with the first pole
    than with the second.

    Args:
        base_model: SentenceTransformer model name/path.
        dimensions: Dimension descriptions of the form ``"Name: explanation"``.
        pole_names: ``(first, second)`` pairs of dimension names, one per axis.

    Raises:
        ValueError: if a pole name does not match any dimension name.
    """

    def __init__(self, base_model, dimensions, pole_names):
        self.encoder = SentenceTransformer(base_model)
        self.dimensions = dimensions
        self.dim_embs = self.encoder.encode(dimensions)
        self.pole_names = pole_names
        self.axis_names = [f"{first}/{second}" for first, second in pole_names]

        names = self.get_dimension_names()
        self.axis_embs = np.stack([
            self.dim_embs[names.index(first)] - self.dim_embs[names.index(second)]
            for first, second in pole_names
        ])

    def get_dimension_names(self):
        """Return the short name of every dimension description."""
        return [_short_name(dimension) for dimension in self.dimensions]

    def __call__(self, sequence_to_align):
        """Return ``{(first, second): score(s)}`` for every axis.

        For a single string each value is a numpy scalar; for a list of
        strings it is an array with one score per input.
        """
        embs = self.encoder.encode(sequence_to_align)
        scores = embs @ self.axis_embs.T
        return dict(zip(self.pole_names, scores.T))

    def visualize(self, align_scores_df, **kwargs):
        """Plot the mean score of every axis.

        ``align_scores_df`` has one column per ``(first, second)`` axis and one
        row per text. The x position is the column mean, the marker size grows
        with the column variance, and the marker is blue for a positive mean
        and red otherwise. The second pole is labelled on the left axis and
        the first pole on the right axis.

        Returns:
            The matplotlib figure.
        """
        name_left = align_scores_df.columns.map(lambda pair: pair[1])
        name_right = align_scores_df.columns.map(lambda pair: pair[0])
        bias = align_scores_df.mean()
        color = ["b" if x > 0 else "r" for x in bias]
        # A single row has NaN variance; fillna(0) keeps a minimal marker size.
        inten = (align_scores_df.var().fillna(0) + 0.001) * 50_000
        bounds = bias.abs().max() * 1.1

        fig = Figure()
        ax = fig.add_subplot(111)
        ax.scatter(x=bias, y=name_left, s=inten, c=color)
        ax.axvline(0)
        ax.set_xlim(-bounds, bounds)
        ax.invert_yaxis()
        # A twin axis carries the opposite pole names on the right-hand side.
        axi = ax.twinx()
        axi.set_ylim(ax.get_ylim())
        axi.set_yticks(ax.get_yticks(), labels=name_right)
        fig.tight_layout()
        return fig


class FramingStructure:
    """Parse text into PENMAN graphs and plot their aggregated structure.

    Args:
        base_model: Model name/path for the ``text2text-generation`` pipeline;
            its generated text is decoded with ``penman.decode``.
        roles: Unused; kept for backward compatibility.
    """

    def __init__(self, base_model, roles=None):
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.translator = pipeline("text2text-generation", base_model, device=device, max_length=300)

    def __call__(self, sequence_to_translate):
        """Generate and decode graphs for one text or a list of texts.

        Outputs that cannot be decoded are skipped (and logged at DEBUG
        level), so the result may be shorter than the input.
        """
        graphs = []
        for out in self.translator(sequence_to_translate):
            try:
                graphs.append(penman.decode(out["generated_text"]))
            except Exception:  # best-effort: skip outputs that fail to decode
                logger.debug("Skipping output that is not a decodable PENMAN graph: %r", out)
        return graphs

    def visualize(self, decoded_graphs, min_node_threshold=1, **kwargs):
        """Plot the union of ``decoded_graphs`` as one directed graph.

        Identical ``(source, role, target)`` triples are counted across all
        graphs. A node's weight is the sum of the counts of the triples that
        touch it (a self-loop counts twice). Nodes with weight below
        ``min_node_threshold`` are hidden. Edges for ARG0-ARG3 are coloured;
        all other roles are drawn black.

        Returns:
            The matplotlib figure.
        """
        def trim_distinction_end(name):
            # Keep only the part before the first "_" (presumably a
            # disambiguating suffix produced by the parser; TODO confirm).
            return name.split("_")[0]

        triple_counts = Counter()
        for graph in decoded_graphs:
            for source, role, target in graph.triples:
                if target is None:
                    continue
                key = (trim_distinction_end(source), role.replace(":", ""), trim_distinction_end(target))
                triple_counts[key] += 1

        color_map = defaultdict(lambda: "k", {"ARG0": "y", "ARG1": "r", "ARG2": "g", "ARG3": "b"})
        G = nx.DiGraph()
        for (source, role, target), num in triple_counts.items():
            for node in (source, target):
                if not G.has_node(node):
                    G.add_node(node, weight=0)
            G.nodes[source]["weight"] += num
            G.nodes[target]["weight"] += num
            # A DiGraph holds one edge per (source, target) pair, so a later
            # triple with another role overwrites these attributes.
            G.add_edge(source, target, role=role, weight=num, color=color_map[role])

        G_sub = nx.subgraph_view(G, filter_node=lambda n: G.nodes[n]["weight"] >= min_node_threshold)
        node_sizes = [weight * 100 for weight in nx.get_node_attributes(G_sub, "weight").values()]
        edge_colors = list(nx.get_edge_attributes(G_sub, "color").values())

        fig = Figure()
        ax = fig.add_subplot()
        pos = pygraphviz_layout(G_sub, prog="dot")
        # draw_networkx already draws the node labels (with_labels=True).
        nx.draw_networkx(G_sub, pos, ax=ax, node_size=node_sizes, edge_color=edge_colors)
        nx.draw_networkx_edge_labels(G_sub, pos, edge_labels=nx.get_edge_attributes(G_sub, "role"), ax=ax)
        fig.tight_layout()
        return fig


# Specify the models
base_model_1 = "facebook/bart-large-mnli"
base_model_2 = "all-mpnet-base-v2"
base_model_3 = "Iseratho/model_parse_xfm_bart_base-v0_1_0"

# https://homes.cs.washington.edu/~nasmith/papers/card+boydstun+gross+resnik+smith.acl15.pdf
candidate_labels = [
    "Economic: costs, benefits, or other financial implications",
    "Capacity and resources: availability of physical, human or financial resources, and capacity of current systems",
    "Morality: religious or ethical implications",
    "Fairness and equality: balance or distribution of rights, responsibilities, and resources",
    "Legality, constitutionality and jurisprudence: rights, freedoms, and authority of individuals, corporations, and government",
    "Policy prescription and evaluation: discussion of specific policies aimed at addressing problems",
    "Crime and punishment: effectiveness and implications of laws and their enforcement",
    "Security and defense: threats to welfare of the individual, community, or nation",
    "Health and safety: health care, sanitation, public safety",
    "Quality of life: threats and opportunities for the individual’s wealth, happiness, and well-being",
    "Cultural identity: traditions, customs, or values of a social group in relation to a policy issue",
    "Public opinion: attitudes and opinions of the general public, including polling and demographics",
    "Political: considerations related to politics and politicians, including lobbying, elections, and attempts to sway voters",
    "External regulation and reputation: international reputation or foreign policy of the U.S.",
    "Other: any coherent group of frames not covered by the above categories",
]

# https://osf.io/xakyw
dimensions = [
    "Care: ...acted with kindness, compassion, or empathy, or nurtured another person.",
    "Harm: ...acted with cruelty, or hurt or harmed another person/animal and caused suffering.",
    "Fairness: ...acted in a fair manner, promoting equality, justice, or rights.",
    "Cheating: ...was unfair or cheated, or caused an injustice or engaged in fraud.",
    "Loyalty: ...acted with fidelity, or as a team player, or was loyal or patriotic.",
    "Betrayal: ...acted disloyal, betrayed someone, was disloyal, or was a traitor.",
    "Authority: ...obeyed, or acted with respect for authority or tradition.",
    "Subversion: ...disobeyed or showed disrespect, or engaged in subversion or caused chaos.",
    "Sanctity: ...acted in a way that was wholesome or sacred, or displayed purity or sanctity.",
    "Degradation: ...was depraved, degrading, impure, or unnatural.",
]

pole_names = [
    ("Care", "Harm"),
    ("Fairness", "Cheating"),
    ("Loyalty", "Betrayal"),
    ("Authority", "Subversion"),
    ("Sanctity", "Degradation"),
]

framing_label_model = FramingLabels(base_model_1, candidate_labels)
framing_dimen_model = FramingDimensions(base_model_2, dimensions, pole_names)
framing_struc_model = FramingStructure(base_model_3)


def framing_multi(texts, min_node_threshold=1):
    """Analyse several texts and return the (label, dimension, structure) figures.

    The label chart shows per-label means with standard-error bars.
    """
    res1 = pd.DataFrame(framing_label_model(texts))
    fig1, _ = framing_label_model.visualize(res1.mean().to_dict(), xerr=res1.sem())
    fig2 = framing_dimen_model.visualize(pd.DataFrame(framing_dimen_model(texts)))
    fig3 = framing_struc_model.visualize(framing_struc_model(texts), min_node_threshold=min_node_threshold)
    return fig1, fig2, fig3


def framing_single(text, min_node_threshold=1):
    """Analyse one text and return the (label, dimension, structure) figures."""
    fig1, _ = framing_label_model.visualize(framing_label_model(text))
    # Wrap each scalar score in a list to get a one-row DataFrame.
    fig2 = framing_dimen_model.visualize(pd.DataFrame({k: [v] for k, v in framing_dimen_model(text).items()}))
    fig3 = framing_struc_model.visualize(framing_struc_model(text), min_node_threshold=min_node_threshold)
    return fig1, fig2, fig3


async def framing_textbox(text, split, min_node_threshold):
    """Gradio handler for the text-box tab.

    With ``split`` enabled and more than one non-blank line, every non-blank
    line is analysed as its own text; otherwise the whole input is one text.

    Raises:
        gr.Error: if the input contains no text.
    """
    # NOTE(review): the models run synchronously inside this coroutine, which
    # blocks the event loop for the whole request. Kept as-is because it
    # presumably also serialises access to the shared model pipelines.
    lines = [line.strip() for line in (text or "").split("\n") if line.strip()]
    if not lines:
        raise gr.Error("Please enter some text to analyze.")
    if split and len(lines) > 1:
        return framing_multi(lines, min_node_threshold)
    return framing_single(text.strip(), min_node_threshold)


async def framing_file(file_obj, min_node_threshold):
    """Gradio handler for the file tab: one text per non-blank line.

    ``file_obj`` may be a tempfile-like object exposing ``.name`` or a plain
    path string, depending on the Gradio version.

    Raises:
        gr.Error: if no file is given or it contains no text.
    """
    if file_obj is None:
        raise gr.Error("Please upload a text file to analyze.")
    path = getattr(file_obj, "name", file_obj)
    with open(path, "r", encoding="utf-8") as f:
        texts = [stripped for stripped in (line.strip() for line in f) if stripped]
    if not texts:
        raise gr.Error("The uploaded file contains no text to analyze.")
    if len(texts) > 1:
        return framing_multi(texts, min_node_threshold)
    # A single text must be passed as a string, not a one-element list, so
    # the single-text code path gets scalar scores.
    return framing_single(texts[0], min_node_threshold)


example_list = [
    ["In 2010, CFCs were banned internationally due to their harmful effect on the ozone layer.", False, 1],
    ["In 2021, doctors prevented the spread of the virus by vaccinating with Pfizer.", False, 1],
    ["We must fight for our freedom.", False, 1],
    ["The government prevents our freedom.", False, 1],
    ["They prevent the spread.", False, 1],
    ["We fight the virus.", False, 1],
    ["I believe that we should act now.\nThere is no time to waste.", True, 1],
]

DESCRIPTION = "A simple tool that helps you find (discover and detect) frames in text."
ARTICLE = (
    "Check out the preliminary article in the "
    "[Web Conference Symposium](https://dl.acm.org/doi/pdf/10.1145/3543873.3587534); "
    "this link will be updated to the article currently under review once it is published."
)


def _plot_outputs():
    """Return a fresh set of the three output plots (components cannot be shared)."""
    return [gr.Plot(label="Label"), gr.Plot(label="Dimensions"), gr.Plot(label="Structure")]


textbox_interface = gr.Interface(
    fn=framing_textbox,
    inputs=[
        gr.Textbox(label="Text to analyze."),
        gr.Checkbox(True, label="Split on newlines? (To enter newlines type shift+Enter)"),
        gr.Number(1, label="Min node threshold for framing structure."),
    ],
    description=DESCRIPTION,
    examples=example_list,
    article=ARTICLE,
    outputs=_plot_outputs(),
)

file_interface = gr.Interface(
    fn=framing_file,
    inputs=[
        gr.File(label="File of texts to analyze."),
        gr.Number(1, label="Min node threshold for framing structure."),
    ],
    description=DESCRIPTION,
    article=ARTICLE,
    outputs=_plot_outputs(),
)

demo = gr.TabbedInterface(
    [textbox_interface, file_interface],
    tab_names=["Single Mode", "File Mode"],
    title="FrameFinder",
)

if __name__ == "__main__":
    demo.launch()