Iseratho commited on
Commit
d2ddecb
1 Parent(s): 7d6b0b0

Multiline, file mode, node threshold and plot improvments

Browse files
Files changed (1) hide show
  1. app.py +56 -14
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import gradio as gr
2
 
3
- from transformers import pipeline
4
  from functools import partial
5
 
 
6
  import numpy as np
7
 
8
  import matplotlib.pyplot as plt
@@ -63,7 +63,8 @@ class FramingLabels:
63
 
64
  colors = [cp[0] if s > 0.5 else cp[1] for s in scores_ordered]
65
  ax.barh(label_names[::-1], scores_ordered[::-1], color=colors[::-1], **kwargs)
66
-
 
67
  return fig, ax
68
 
69
  class FramingDimensions:
@@ -108,6 +109,7 @@ class FramingDimensions:
108
  axi = ax.twinx()
109
  axi.set_ylim(ax.get_ylim())
110
  axi.set_yticks(ax.get_yticks(), labels=name_right)
 
111
  return fig
112
 
113
  class FramingStructure:
@@ -121,6 +123,7 @@ class FramingStructure:
121
  try:
122
  return penman.decode(x["generated_text"])
123
  except:
 
124
  return None
125
  graphs = list(filter(lambda item: item is not None, [try_decode(x) for x in res]))
126
  return graphs
@@ -167,6 +170,7 @@ class FramingStructure:
167
  nx.draw_networkx(G_sub, pos, node_size=node_sizes, edge_color=edge_colors)
168
  nx.draw_networkx_labels(G_sub, pos)
169
  nx.draw_networkx_edge_labels(G_sub, pos, edge_labels=nx.get_edge_attributes(G_sub, "role"))
 
170
  return fig
171
 
172
  # Specify the models
@@ -217,26 +221,49 @@ framing_label_model = FramingLabels(base_model_1, candidate_labels)
217
  framing_dimen_model = FramingDimensions(base_model_2, dimensions, pole_names)
218
  framing_struc_model = FramingStructure(base_model_3)
219
 
220
- import pandas as pd
 
 
 
 
221
 
222
- async def framing_single(text):
 
 
223
  fig1, _ = framing_label_model.visualize(framing_label_model(text))
224
  fig2 = framing_dimen_model.visualize(pd.DataFrame({k: [v] for k, v in framing_dimen_model(text).items()}))
225
- fig3 = framing_struc_model.visualize(framing_struc_model(text))
226
 
227
  return fig1, fig2, fig3
228
 
229
- example_list = ["In 2021, doctors prevented the spread of the virus by vaccinating with Pfizer.",
230
- "We must fight for our freedom.",
231
- "The government prevents our freedom.",
232
- "They prevent the spread.",
233
- "We fight the virus.",
234
- "I believe that we should act now. There is no time to waste."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  ]
236
 
237
- demo = gr.Interface(fn=framing_single,
238
- title="FrameFinder",
239
- inputs=gr.Textbox(label="Text to analyze."),
 
 
 
240
  description="A simple tool that helps you find (discover and detect) frames in text.",
241
  examples=example_list,
242
  article="Check out the preliminary article in the [Web Conference Symposium](https://dl.acm.org/doi/pdf/10.1145/3543873.3587534), will be updated to currently in review article after publication.",
@@ -245,4 +272,19 @@ demo = gr.Interface(fn=framing_single,
245
  gr.Plot(label="Structure")
246
  ])
247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  demo.launch()
 
1
  import gradio as gr
2
 
 
3
  from functools import partial
4
 
5
+ import pandas as pd
6
  import numpy as np
7
 
8
  import matplotlib.pyplot as plt
 
63
 
64
  colors = [cp[0] if s > 0.5 else cp[1] for s in scores_ordered]
65
  ax.barh(label_names[::-1], scores_ordered[::-1], color=colors[::-1], **kwargs)
66
+ plt.xlim(left=0)
67
+ plt.tight_layout()
68
  return fig, ax
69
 
70
  class FramingDimensions:
 
109
  axi = ax.twinx()
110
  axi.set_ylim(ax.get_ylim())
111
  axi.set_yticks(ax.get_yticks(), labels=name_right)
112
+ plt.tight_layout()
113
  return fig
114
 
115
  class FramingStructure:
 
123
  try:
124
  return penman.decode(x["generated_text"])
125
  except:
126
+ # print(f"Decode error for {res}")
127
  return None
128
  graphs = list(filter(lambda item: item is not None, [try_decode(x) for x in res]))
129
  return graphs
 
170
  nx.draw_networkx(G_sub, pos, node_size=node_sizes, edge_color=edge_colors)
171
  nx.draw_networkx_labels(G_sub, pos)
172
  nx.draw_networkx_edge_labels(G_sub, pos, edge_labels=nx.get_edge_attributes(G_sub, "role"))
173
+ plt.tight_layout()
174
  return fig
175
 
176
  # Specify the models
 
221
  framing_dimen_model = FramingDimensions(base_model_2, dimensions, pole_names)
222
  framing_struc_model = FramingStructure(base_model_3)
223
 
224
+ def framing_multi(texts, min_node_threshold=1):
225
+ res1 = pd.DataFrame(framing_label_model(texts))
226
+ fig1, _ = framing_label_model.visualize(res1.mean().to_dict(), xerr=res1.sem())
227
+ fig2 = framing_dimen_model.visualize(pd.DataFrame(framing_dimen_model(texts)))
228
+ fig3 = framing_struc_model.visualize(framing_struc_model(texts), min_node_threshold=min_node_threshold)
229
 
230
+ return fig1, fig2, fig3
231
+
232
+ def framing_single(text, min_node_threshold=1):
233
  fig1, _ = framing_label_model.visualize(framing_label_model(text))
234
  fig2 = framing_dimen_model.visualize(pd.DataFrame({k: [v] for k, v in framing_dimen_model(text).items()}))
235
+ fig3 = framing_struc_model.visualize(framing_struc_model(text), min_node_threshold=min_node_threshold)
236
 
237
  return fig1, fig2, fig3
238
 
239
+ async def framing_textbox(text, split, min_node_threshold):
240
+ texts = text.split("\n")
241
+ if split and len(texts) > 1:
242
+ return framing_multi(texts, min_node_threshold)
243
+ return framing_single(text, min_node_threshold)
244
+
245
+ async def framing_file(file_obj, min_node_threshold):
246
+ with open(file_obj.name, "r") as f:
247
+ texts = f.readlines()
248
+ if len(texts) > 1:
249
+ return framing_multi(texts, min_node_threshold)
250
+ return framing_single(texts, min_node_threshold)
251
+
252
+ example_list = [["In 2010, CFCs were banned internationally due to their harmful effect on the ozone layer.", False, 1],
253
+ ["In 2021, doctors prevented the spread of the virus by vaccinating with Pfizer.", False, 1],
254
+ ["We must fight for our freedom.", False, 1],
255
+ ["The government prevents our freedom.", False, 1],
256
+ ["They prevent the spread.", False, 1],
257
+ ["We fight the virus.", False, 1],
258
+ ["I believe that we should act now. There is no time to waste.", True, 1],
259
  ]
260
 
261
+ textbox_inferface = gr.Interface(fn=framing_textbox,
262
+ inputs=[
263
+ gr.Textbox(label="Text to analyze."),
264
+ gr.Checkbox(True, label="Split on newlines? (To enter newlines type shift+Enter)"),
265
+ gr.Number(1, label="Min node threshold for framing structure.")
266
+ ],
267
  description="A simple tool that helps you find (discover and detect) frames in text.",
268
  examples=example_list,
269
  article="Check out the preliminary article in the [Web Conference Symposium](https://dl.acm.org/doi/pdf/10.1145/3543873.3587534), will be updated to currently in review article after publication.",
 
272
  gr.Plot(label="Structure")
273
  ])
274
 
275
+ file_interface = gr.Interface(fn=framing_file,
276
+ inputs=[
277
+ gr.File(label="File of texts to analyze."),
278
+ gr.Number(1, label="Min node threshold for framing structure."),
279
+ ],
280
+ description="A simple tool that helps you find (discover and detect) frames in text.",
281
+ article="Check out the preliminary article in the [Web Conference Symposium](https://dl.acm.org/doi/pdf/10.1145/3543873.3587534), will be updated to currently in review article after publication.",
282
+ outputs=[gr.Plot(label="Label"),
283
+ gr.Plot(label="Dimensions"),
284
+ gr.Plot(label="Structure")])
285
+
286
+ demo = gr.TabbedInterface([textbox_inferface, file_interface],
287
+ tab_names=["Single Mode", "File Mode"],
288
+ title="FrameFinder",)
289
+
290
  demo.launch()