nickmuchi committed on
Commit
b633e3e
·
1 Parent(s): 8951082

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +251 -27
app.py CHANGED
@@ -8,7 +8,9 @@ import validators, re
8
  from fake_useragent import UserAgent
9
  from bs4 import BeautifulSoup
10
  import streamlit as st
11
- from transformers import pipeline
 
 
12
  import time
13
  import base64
14
  import requests
@@ -158,19 +160,235 @@ def summary_downloader(raw_text):
158
  st.markdown("#### Download Summary as a File ###")
159
  href = f'<a href="data:file/txt;base64,{b64}" download="{new_filename}">Click to Download!!</a>'
160
  st.markdown(href,unsafe_allow_html=True)
 
 
 
161
 
162
- @st.cache(allow_output_mutation=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  def facebook_model():
164
 
165
  summarizer = pipeline('summarization',model='facebook/bart-large-cnn')
166
  return summarizer
167
 
168
- @st.cache(allow_output_mutation=True)
169
  def schleifer_model():
170
 
171
  summarizer = pipeline('summarization',model='sshleifer/distilbart-cnn-12-6')
172
  return summarizer
173
 
 
 
 
 
 
 
 
 
 
 
174
  #Streamlit App
175
 
176
  st.title("Article Text and Link Extractive Summarizer 📝")
@@ -211,6 +429,14 @@ st.markdown("---")
211
 
212
  url_text = st.text_input("Please Enter a url here")
213
 
 
 
 
 
 
 
 
 
214
 
215
  st.markdown(
216
  "<h3 style='text-align: center; color: red;'>OR</h3>",
@@ -228,27 +454,19 @@ upload_doc = st.file_uploader(
228
  "Upload a .txt, .pdf, .docx file for summarization"
229
  )
230
 
231
- is_url = validators.url(url_text)
232
-
233
- if is_url:
234
- # complete text, chunks to summarize (list of sentences for long docs)
235
- article_title,chunks = article_text_extractor(url=url_text)
236
 
237
  elif upload_doc:
 
238
 
239
- clean_text = chunk_clean_text(preprocess_plain_text(extract_text_from_file(upload_doc)))
240
-
241
- else:
242
-
243
- clean_text = chunk_clean_text(preprocess_plain_text(plain_text))
244
-
245
  summarize = st.button("Summarize")
246
 
247
  # called on toggle button [summarize]
248
  if summarize:
249
  if model_type == "Facebook-Bart":
250
- if is_url:
251
- text_to_summarize = chunks
252
  else:
253
  text_to_summarize = clean_text
254
 
@@ -260,8 +478,8 @@ if summarize:
260
  summarized_text = ' '.join([summ['summary_text'] for summ in summarized_text])
261
 
262
  elif model_type == "Sshleifer-DistilBart":
263
- if is_url:
264
- text_to_summarize = chunks
265
  else:
266
  text_to_summarize = clean_text
267
 
@@ -270,19 +488,25 @@ if summarize:
270
  ):
271
  summarizer_model = schleifer_model()
272
  summarized_text = summarizer_model(text_to_summarize, max_length=max_len, min_length=min_len)
273
- summarized_text = ' '.join([summ['summary_text'] for summ in summarized_text])
274
-
275
- # final summarized output
276
- st.subheader("Summarized text")
277
 
278
- if is_url:
279
 
280
- # view summarized text (expander)
281
- st.markdown(f"Article title: {article_title}")
 
 
 
 
 
 
 
 
 
282
 
283
- st.write(summarized_text)
284
 
285
- summary_downloader(summarized_text)
286
 
287
 
288
  st.markdown("""
 
8
  from fake_useragent import UserAgent
9
  from bs4 import BeautifulSoup
10
  import streamlit as st
11
+ from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
12
+ from sentence_transformers import SentenceTransformer
13
+ import en_core_web_lg
14
  import time
15
  import base64
16
  import requests
 
160
  st.markdown("#### Download Summary as a File ###")
161
  href = f'<a href="data:file/txt;base64,{b64}" download="{new_filename}">Click to Download!!</a>'
162
  st.markdown(href,unsafe_allow_html=True)
163
+
164
def get_all_entities_per_sentence(text):
    """Collect the named entities found in every sentence of ``text``.

    ``text`` is split into sentences by the spaCy pipeline ``nlp``. For each
    sentence the result holds the spaCy entities of that sentence followed by
    the ``"word"`` field of every hit the transformer pipeline ``ner_model``
    returns for the sentence string.

    Returns:
        A list with one list of entity strings per sentence, in document
        order.
    """
    entities_by_sentence = []
    for sent in nlp(text).sents:
        # spaCy entities come first, in the order spaCy reports them.
        found = [str(ent) for ent in sent.ents]
        # Followed by the entities from the transformer NER pipeline.
        found.extend(str(hit["word"]) for hit in ner_model(str(sent)))
        entities_by_sentence.append(found)
    return entities_by_sentence
191
+
192
def get_all_entities(text):
    """Return every entity found in ``text`` as one flat list.

    The per-sentence lists produced by ``get_all_entities_per_sentence`` are
    concatenated in sentence order.
    """
    return [
        entity
        for sentence_entities in get_all_entities_per_sentence(text)
        for entity in sentence_entities
    ]
195
+
196
def _drop_contained_entities(entities):
    """Return ``entities`` without those contained in another entity.

    Containment is tested case-insensitively. Entities that differ only in
    letter case contain each other; for those only the first occurrence is
    kept, so the entity does not vanish from the result altogether.
    """
    kept = []
    for index, entity in enumerate(entities):
        lowered = entity.lower()
        redundant = False
        for other_index, other in enumerate(entities):
            if other_index == index:
                continue
            other_lowered = other.lower()
            if lowered == other_lowered:
                # Case-only duplicate: drop every occurrence but the first.
                if other_index < index:
                    redundant = True
                    break
            elif lowered in other_lowered:
                redundant = True
                break
        if not redundant:
            kept.append(entity)
    return kept


def get_and_compare_entities(article_content, summary_output):
    """Split the summary's entities into matched and unmatched ones.

    A summary entity counts as matched when it is a case-insensitive
    substring of some article entity, or when the inner product of its
    sentence embedding with the embedding of some article entity exceeds
    0.9. Everything else is unmatched.

    Args:
        article_content: Text of the original article.
        summary_output: Text of the generated summary.

    Returns:
        A ``(matched_entities, unmatched_entities)`` tuple of lists. Each
        list is de-duplicated (first occurrence order) and entities that are
        contained in another entity of the same list are removed.
    """
    entities_article = list(itertools.chain.from_iterable(
        get_all_entities_per_sentence(article_content)))
    entities_summary = list(itertools.chain.from_iterable(
        get_all_entities_per_sentence(summary_output)))

    article_lowered = [art_entity.lower() for art_entity in entities_article]
    # Article embeddings are computed at most once, and only if some summary
    # entity actually needs the (expensive) embedding comparison.
    article_embeddings = None

    matched_entities = []
    unmatched_entities = []
    # The verdict depends only on the entity string, so de-duplicating up
    # front yields the same lists as de-duplicating afterwards.
    for entity in dict.fromkeys(entities_summary):
        lowered = entity.lower()
        if any(lowered in art_lowered for art_lowered in article_lowered):
            matched_entities.append(entity)
            continue

        if not entities_article:
            unmatched_entities.append(entity)
            continue

        if article_embeddings is None:
            article_embeddings = [
                sentence_embedding_model.encode(art_entity, show_progress_bar=False)
                for art_entity in dict.fromkeys(entities_article)
            ]
        entity_embedding = sentence_embedding_model.encode(entity, show_progress_bar=False)
        if any(np.inner(entity_embedding, art_embedding) > 0.9
               for art_embedding in article_embeddings):
            matched_entities.append(entity)
        else:
            unmatched_entities.append(entity)

    return (_drop_contained_entities(matched_entities),
            _drop_contained_entities(unmatched_entities))
242
+
243
def highlight_entities(article_content, summary_output):
    """Return the summary as HTML with its entities colour-highlighted.

    Entities of the summary that ``get_and_compare_entities`` reports as
    matched against the article are wrapped in a green ``<mark>`` tag, the
    unmatched ones in a red ``<mark>`` tag.

    Args:
        article_content: Text of the original article.
        summary_output: Text of the generated summary.

    Returns:
        ``HTML_WRAPPER`` formatted with the parsed, highlighted summary.
    """
    markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
    markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
    markdown_end = "</mark>"

    matched_entities, unmatched_entities = get_and_compare_entities(article_content, summary_output)

    # Start from the plain summary and accumulate every replacement on the
    # running result. Replacing on ``summary_output`` each time would keep
    # only the last highlight and leave the name unbound when there are no
    # entities at all.
    summary_content = summary_output

    for entity in matched_entities:
        if entity:  # replacing "" would insert markup between all characters
            summary_content = summary_content.replace(
                entity, markdown_start_green + entity + markdown_end)

    for entity in unmatched_entities:
        if entity:
            summary_content = summary_content.replace(
                entity, markdown_start_red + entity + markdown_end)

    # NOTE(review): an entity whose text also occurs inside the inserted
    # markup (e.g. "mark") would be replaced there as well - confirm whether
    # that needs guarding.
    soup = BeautifulSoup(summary_content, features="html.parser")
    return HTML_WRAPPER.format(soup)
258
+
259
+
260
def render_dependency_parsing(text: Dict):
    """Write the markup rendered for ``text`` to the Streamlit page.

    ``text`` and the spaCy pipeline ``nlp`` are handed to
    ``render_sentence_custom``; doubled newlines in its output are collapsed
    to single ones before the result is wrapped by ``get_svg`` and written
    with HTML enabled.
    """
    markup = render_sentence_custom(text, nlp).replace("\n\n", "\n")
    st.write(get_svg(markup), unsafe_allow_html=True)
264
+
265
+
266
def check_dependency(article: bool):
    """Collect entity-related ``amod`` / ``pobj`` dependencies per sentence.

    The text is read from ``st.session_state.article_text`` when ``article``
    is true, otherwise from ``st.session_state.summary_output``. A token is
    reported when its dependency label is ``amod``, or ``pobj`` with the head
    word "in" (case-insensitive), and either the token text or its head text
    is one of the entities found for the same sentence.

    Args:
        article: Selects the article text (True) or the summary (False).

    Returns:
        A list of dicts with the keys ``dep``, ``cur_word_index``,
        ``target_word_index`` (both relative to the sentence start),
        ``identifier`` and ``sentence``.
    """
    if article:
        text = st.session_state.article_text
    else:
        text = st.session_state.summary_output
    all_entities = get_all_entities_per_sentence(text)

    doc = nlp(text)
    tok_l = doc.to_json()['tokens']
    test_list_dict_output = []

    for i, sentence in enumerate(doc.sents):
        sentence_entities = all_entities[i]
        # The span end is exclusive, so the token at index ``sentence.end``
        # already belongs to the next sentence and must not be visited here.
        # Token ids are positions in ``tok_l`` (``t['head']`` is used as an
        # index below), which makes a slice equivalent to filtering on id.
        for t in tok_l[sentence.start:sentence.end]:
            if t['dep'] not in ('amod', 'pobj'):
                continue
            head = tok_l[t['head']]
            object_here = text[t['start']:t['end']]
            object_target = text[head['start']:head['end']]
            if t['dep'] == "pobj" and object_target.lower() != "in":
                continue
            # ONE NEEDS TO BE ENTITY
            if object_here in sentence_entities or object_target in sentence_entities:
                identifier = object_here + t['dep'] + object_target
                test_list_dict_output.append({
                    "dep": t['dep'],
                    "cur_word_index": (t['id'] - sentence.start),
                    "target_word_index": (t['head'] - sentence.start),
                    "identifier": identifier,
                    "sentence": str(sentence),
                })
    return test_list_dict_output
304
+
305
+
306
def render_svg(svg_file):
    """Return an HTML ``<img>`` tag that embeds an SVG file as base64 data.

    Args:
        svg_file: Path of the SVG file to read (opened in text mode).

    Returns:
        The ``<img>`` element as a string, with the file content encoded as
        a ``data:image/svg+xml;base64`` URI.
    """
    with open(svg_file, "r") as handle:
        svg_markup = handle.read()

    encoded = base64.b64encode(svg_markup.encode("utf-8")).decode("utf-8")
    return f'<img src="data:image/svg+xml;base64,{encoded}"/>'
315
+
316
+
317
def generate_abstractive_summary(text, type, min_len=120, max_len=512, **kwargs):
    """Summarize ``text`` with ``summarization_model`` using a named strategy.

    Args:
        text: Text to summarize; it is stripped and its newlines are replaced
            by spaces before being passed to the model.
        type: One of ``"top_p"``, ``"greedy"``, ``"top_k"`` or ``"beam"``.
        min_len: Passed to the model as ``min_length``.
        max_len: Passed to the model as ``max_length``.
        **kwargs: Forwarded unchanged to ``summarization_model``.

    Returns:
        The ``summary_text`` of the first result, with ``"<n>"`` markers
        replaced by spaces.

    Raises:
        ValueError: If ``type`` is not one of the supported strategies.
    """
    # Extra generation arguments per strategy; "greedy" and "beam" add none.
    # NOTE(review): "top_p"/"top_k" do not set do_sample - presumably the
    # caller supplies it via kwargs when sampling is wanted; confirm.
    strategy_options = {
        "top_p": {"top_k": 50, "top_p": 0.95},
        "greedy": {},
        "top_k": {"top_k": 50},
        "beam": {},
    }
    # Fail early and clearly: without this check an unknown strategy skipped
    # the model call and crashed later while indexing the raw input string.
    if type not in strategy_options:
        raise ValueError(
            f"Unknown summarization type {type!r}; expected one of "
            f"{sorted(strategy_options)}"
        )

    text = text.strip().replace("\n", " ")
    result = summarization_model(text, min_length=min_len, max_length=max_len,
                                 clean_up_tokenization_spaces=True, truncation=True,
                                 **strategy_options[type], **kwargs)
    return result[0]['summary_text'].replace("<n>", " ")
335
+
336
def clean_text(text, doc=False, plain_text=False, url=False):
    """Return ``(title, cleaned_text)`` for one of the supported sources.

    Exactly one of the flags selects how ``text`` is interpreted; they are
    checked in the order ``url``, ``doc``, ``plain_text``.

    Args:
        text: A URL, an uploaded document, or raw text, depending on the flag.
        doc: Treat ``text`` as an uploaded file to extract text from.
        plain_text: Treat ``text`` as raw text.
        url: Treat ``text`` as the URL of an article.

    Returns:
        A 2-tuple. For a valid URL it is the title and chunks returned by
        ``article_text_extractor``; for a document or plain text it is
        ``(None, cleaned_chunks)``. ``(None, None)`` is returned when the URL
        is not valid or no flag is set, so callers can always unpack the
        result.
    """
    if url:
        if validators.url(text):
            # Extract from the URL that was passed in, not a module global.
            article_title, chunks = article_text_extractor(url=text)
            return article_title, chunks
        return None, None

    if doc:
        return None, chunk_clean_text(preprocess_plain_text(extract_text_from_file(text)))

    if plain_text:
        return None, chunk_clean_text(preprocess_plain_text(text))

    return None, None
359
+
360
@st.experimental_singleton
def get_spacy():
    """Load and return the ``en_core_web_lg`` spaCy pipeline."""
    nlp = en_core_web_lg.load()
    return nlp


@st.experimental_singleton
def facebook_model():
    """Return a summarization pipeline for ``facebook/bart-large-cnn``."""
    summarizer = pipeline('summarization',model='facebook/bart-large-cnn')
    return summarizer


@st.experimental_singleton
def schleifer_model():
    """Return a summarization pipeline for ``sshleifer/distilbart-cnn-12-6``."""
    summarizer = pipeline('summarization',model='sshleifer/distilbart-cnn-12-6')
    return summarizer


@st.experimental_singleton
def get_sentence_embedding_model():
    """Return the ``all-MiniLM-L6-v2`` sentence-embedding model."""
    return SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')


@st.experimental_singleton
def get_ner_pipeline():
    """Return a grouped-entity NER pipeline for the XLM-RoBERTa CoNLL-03 model."""
    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
    model = AutoModelForTokenClassification.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
    return pipeline("ner", model=model, tokenizer=tokenizer, grouped_entities=True)


# Load all different models (cached) at start time of the huggingface space.
# These calls must come after the loader definitions above: the script runs
# top to bottom, so calling a loader before its ``def`` raises NameError.
sentence_embedding_model = get_sentence_embedding_model()
ner_model = get_ner_pipeline()
nlp = get_spacy()
391
+
392
  #Streamlit App
393
 
394
  st.title("Article Text and Link Extractive Summarizer 📝")
 
429
 
430
  url_text = st.text_input("Please Enter a url here")
431
 
432
# Default so the later ``if article_title:`` check does not hit an unbound
# name when no URL was entered.
article_title = None

if url_text:
    # NOTE(review): this assignment rebinds the module-level name
    # ``clean_text`` from the helper function to its result, so the helper
    # cannot be called again later in the same run - confirm and rename one
    # of the two together with the code that reads the variable.
    article_title, clean_text = clean_text(url_text, url=True)

    article_text = st.text_area(
        label='Full Article Text',
        value=clean_text,
        height=250
    )
440
 
441
  st.markdown(
442
  "<h3 style='text-align: center; color: red;'>OR</h3>",
 
454
  "Upload a .txt, .pdf, .docx file for summarization"
455
  )
456
 
457
if plain_text:
    # ``None`` is not a valid assignment target (SyntaxError); the unused
    # title slot of the returned pair goes to the throwaway name ``_``.
    _, clean_text = clean_text(plain_text, plain_text=True)

elif upload_doc:
    # Clean the uploaded document itself; ``plain_text`` is empty in this
    # branch.
    _, clean_text = clean_text(upload_doc, doc=True)
462
 
 
 
 
 
 
 
463
  summarize = st.button("Summarize")
464
 
465
  # called on toggle button [summarize]
466
  if summarize:
467
  if model_type == "Facebook-Bart":
468
+ if url_text:
469
+ text_to_summarize = url_clean_text
470
  else:
471
  text_to_summarize = clean_text
472
 
 
478
  summarized_text = ' '.join([summ['summary_text'] for summ in summarized_text])
479
 
480
  elif model_type == "Sshleifer-DistilBart":
481
+ if url_text:
482
+ text_to_summarize = url_clean_text
483
  else:
484
  text_to_summarize = clean_text
485
 
 
488
  ):
489
  summarizer_model = schleifer_model()
490
  summarized_text = summarizer_model(text_to_summarize, max_length=max_len, min_length=min_len)
491
+ summarized_text = ' '.join([summ['summary_text'] for summ in summarized_text])
 
 
 
492
 
493
+ with st.spinner("Calculating and matching entities, this takes a few seconds..."):
494
 
495
+ entity_match_html = highlight_entities(clean_text,summarized_text)
496
+ st.subheader("Summarized text with matched entities in Green and mismatched entities in Red relative to the original text")
497
+ st.markdown("####")
498
+
499
+ if article_title:
500
+
501
+ # view summarized text (expander)
502
+ st.markdown(f"Article title: {article_title}")
503
+
504
+ st.markdown("####")
505
+ st.write(entity_match_html, unsafe_allow_html=True)
506
 
507
+ st.markdown("####")
508
 
509
+ summary_downloader(summarized_text)
510
 
511
 
512
  st.markdown("""