nickmuchi committed
Commit: d1dcb4e
1 Parent(s): 8a619b7

Update app.py

Files changed (1):
  app.py  +155 -3

app.py CHANGED
@@ -9,6 +9,7 @@ from optimum.onnxruntime import ORTModelForSequenceClassification
 from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
 from sentence_transformers import SentenceTransformer, CrossEncoder, util
 import streamlit as st
+import en_core_web_lg
 
 nltk.download('punkt')
 
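The helpers added in the later hunks also reference itertools, base64, re, np (NumPy), BeautifulSoup and AutoModelForTokenClassification, plus the time_str and HTML_WRAPPER globals, none of which appear in this hunk. They are presumably imported or defined in the unchanged top of app.py; if not, something along these lines would be needed (an assumption, not part of this commit):

    # Assumed imports, only needed if they are not already present near the top of app.py.
    import base64
    import itertools
    import re
    import numpy as np
    from bs4 import BeautifulSoup
    from transformers import AutoModelForTokenClassification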
@@ -50,18 +51,28 @@ auth_token = os.environ.get("auth_token")
 
 progress_bar = st.sidebar.progress(0)
 
-@st.experimental_singleton()
+@st.experimental_singleton(suppress_st_warning=True)
 def load_models():
     asr_model = whisper.load_model("small")
     q_model = ORTModelForSequenceClassification.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
+    ner_model = AutoModelForTokenClassification.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
     q_tokenizer = AutoTokenizer.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
+    ner_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
     sent_pipe = pipeline("text-classification",model=q_model, tokenizer=q_tokenizer)
     sum_pipe = pipeline("summarization",model="facebook/bart-large-cnn", tokenizer="facebook/bart-large-cnn")
+    ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, grouped_entities=True)
+    sbert = SentenceTransformer("all-mpnet-base-v2")
     cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
 
-    return asr_model, sent_pipe, sum_pipe, cross_encoder
+    return asr_model, sent_pipe, sum_pipe, ner_pipe, sbert, cross_encoder
+
+@st.experimental_singleton(suppress_st_warning=True)
+def get_spacy():
+    nlp = en_core_web_lg.load()
+    return nlp
 
-asr_model, sent_pipe, sum_pipe, cross_encoder = load_models()
+nlp = get_spacy()
+asr_model, sent_pipe, sum_pipe, ner_pipe, sbert, cross_encoder = load_models()
 
 @st.experimental_memo(suppress_st_warning=True)
 def inference(link, upload):
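With the two cached loaders above in place, every Streamlit rerun reuses the same model objects instead of reloading them. A quick probe of the new pipelines (a sketch, not part of the commit; the example sentence is made up):

    # Hypothetical probe of the cached pipelines.
    sentence = "Apple reported record revenue of $123.9 billion for the quarter."
    print(sent_pipe(sentence))  # [{'label': ..., 'score': ...}] - FinBERT tone of the sentence
    print(ner_pipe(sentence))   # [{'entity_group': ..., 'word': ..., ...}] - grouped named entities
    print(nlp(sentence).ents)   # spaCy entities from the en_core_web_lg model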
@@ -131,6 +142,147 @@ def preprocess_plain_text(text,window_size=3):
     print(f"Passages: {len(passages)}")
 
     return passages
+
+@st.experimental_memo(suppress_st_warning=True)
+def chunk_clean_text(text):
+
+    """Chunk text longer than 500 tokens"""
+
+    article = nlp(text)
+    sentences = [i.text for i in list(article.sents)]
+
+    current_chunk = 0
+    chunks = []
+
+    for sentence in sentences:
+        if len(chunks) == current_chunk + 1:
+            if len(chunks[current_chunk]) + len(sentence.split(" ")) <= 500:
+                chunks[current_chunk].extend(sentence.split(" "))
+            else:
+                current_chunk += 1
+                chunks.append(sentence.split(" "))
+        else:
+            chunks.append(sentence.split(" "))
+
+    for chunk_id in range(len(chunks)):
+        chunks[chunk_id] = " ".join(chunks[chunk_id])
+
+    return chunks
+
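A minimal usage sketch for the chunking helper above (not part of the commit): each chunk stays under roughly 500 words, which should keep it within BART's input limit when passed to sum_pipe. clean_text here is an assumed variable holding the transcript.

    # Hypothetical illustration: summarize each chunk and stitch the pieces back together.
    chunks = chunk_clean_text(clean_text)
    summaries = [sum_pipe(chunk, truncation=True)[0]["summary_text"] for chunk in chunks]
    summary_output = " ".join(summaries)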
+def summary_downloader(raw_text):
+
+    b64 = base64.b64encode(raw_text.encode()).decode()
+    new_filename = "new_text_file_{}_.txt".format(time_str)
+    st.markdown("#### Download Summary as a File ###")
+    href = f'<a href="data:file/txt;base64,{b64}" download="{new_filename}">Click to Download!!</a>'
+    st.markdown(href,unsafe_allow_html=True)
+
+def get_all_entities_per_sentence(text):
+    doc = nlp(''.join(text))
+
+    sentences = list(doc.sents)
+
+    entities_all_sentences = []
+    for sentence in sentences:
+        entities_this_sentence = []
+
+        # SPACY ENTITIES
+        for entity in sentence.ents:
+            entities_this_sentence.append(str(entity))
+
+        # FLAIR ENTITIES (CURRENTLY NOT USED)
+        # sentence_entities = Sentence(str(sentence))
+        # tagger.predict(sentence_entities)
+        # for entity in sentence_entities.get_spans('ner'):
+        #     entities_this_sentence.append(entity.text)
+
+        # XLM ENTITIES
+        entities_xlm = [entity["word"] for entity in ner_pipe(str(sentence))]
+        for entity in entities_xlm:
+            entities_this_sentence.append(str(entity))
+
+        entities_all_sentences.append(entities_this_sentence)
+
+    return entities_all_sentences
+
+def get_all_entities(text):
+    all_entities_per_sentence = get_all_entities_per_sentence(text)
+    return list(itertools.chain.from_iterable(all_entities_per_sentence))
+
+def get_and_compare_entities(article_content,summary_output):
+
+    all_entities_per_sentence = get_all_entities_per_sentence(article_content)
+    entities_article = list(itertools.chain.from_iterable(all_entities_per_sentence))
+
+    all_entities_per_sentence = get_all_entities_per_sentence(summary_output)
+    entities_summary = list(itertools.chain.from_iterable(all_entities_per_sentence))
+
+    matched_entities = []
+    unmatched_entities = []
+    for entity in entities_summary:
+        if any(entity.lower() in substring_entity.lower() for substring_entity in entities_article):
+            matched_entities.append(entity)
+        elif any(
+                np.inner(sbert.encode(entity, show_progress_bar=False),
+                         sbert.encode(art_entity, show_progress_bar=False)) > 0.9 for
+                art_entity in entities_article):
+            matched_entities.append(entity)
+        else:
+            unmatched_entities.append(entity)
+
+    matched_entities = list(dict.fromkeys(matched_entities))
+    unmatched_entities = list(dict.fromkeys(unmatched_entities))
+
+    matched_entities_to_remove = []
+    unmatched_entities_to_remove = []
+
+    for entity in matched_entities:
+        for substring_entity in matched_entities:
+            if entity != substring_entity and entity.lower() in substring_entity.lower():
+                matched_entities_to_remove.append(entity)
+
+    for entity in unmatched_entities:
+        for substring_entity in unmatched_entities:
+            if entity != substring_entity and entity.lower() in substring_entity.lower():
+                unmatched_entities_to_remove.append(entity)
+
+    matched_entities_to_remove = list(dict.fromkeys(matched_entities_to_remove))
+    unmatched_entities_to_remove = list(dict.fromkeys(unmatched_entities_to_remove))
+
+    for entity in matched_entities_to_remove:
+        matched_entities.remove(entity)
+    for entity in unmatched_entities_to_remove:
+        unmatched_entities.remove(entity)
+
+    return matched_entities, unmatched_entities
+
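One note on the comparison above: np.inner only corresponds to cosine similarity when the embeddings are unit length; util.cos_sim makes the intent explicit regardless of whether the model normalizes its outputs. A sketch (an assumption, not part of the commit) using the util module already imported at the top of app.py:

    # Hypothetical variant of the similarity test, with an explicit cosine similarity.
    def entities_similar(entity, art_entity, threshold=0.9):
        emb1 = sbert.encode(entity, convert_to_tensor=True, show_progress_bar=False)
        emb2 = sbert.encode(art_entity, convert_to_tensor=True, show_progress_bar=False)
        return util.cos_sim(emb1, emb2).item() > threshold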
+def highlight_entities(article_content,summary_output):
+
+    markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
+    markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
+    markdown_end = "</mark>"
+
+    matched_entities, unmatched_entities = get_and_compare_entities(article_content,summary_output)
+
+    print(summary_output)
+
+    for entity in matched_entities:
+        summary_output = re.sub(f'({entity})(?![^rgb\(]*\))',markdown_start_green + entity + markdown_end,summary_output)
+
+    for entity in unmatched_entities:
+        summary_output = re.sub(f'({entity})(?![^rgb\(]*\))',markdown_start_red + entity + markdown_end,summary_output)
+
+    print("")
+    print(summary_output)
+
+    print("")
+    print(summary_output)
+
+    soup = BeautifulSoup(summary_output, features="html.parser")
+
+    return HTML_WRAPPER.format(soup)
+
+nlp = get_spacy()
 
 def display_df_as_table(model,top_k,score='score'):
     '''Display the df with text and scores as a table'''
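Taken together, the new helpers form a summarize-then-verify flow: entities found in the summary are checked against the source text and colour-coded in the returned HTML. A minimal end-to-end sketch (not part of the commit; earnings_text, summary_output and the HTML_WRAPPER template are assumed to exist elsewhere in app.py):

    # Hypothetical wiring: flag summary entities as matched (green) or unmatched (red).
    highlighted = highlight_entities(earnings_text, summary_output)
    st.markdown(highlighted, unsafe_allow_html=True)
    summary_downloader(summary_output)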