JasonTPhillipsJr committed on
Commit
7cc56e3
1 Parent(s): 90d656e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -40
app.py CHANGED
@@ -146,7 +146,13 @@ def processSpatialEntities(review, nlp):
146
  token_embeddings.append(spaBert_emb)
147
  if(dev_mode == True):
148
  st.write("Geo-Entity Found in review: ", text)
149
-
 
 
 
 
 
 
150
  token_embeddings = torch.stack(token_embeddings, dim=0)
151
  processed_embedding = token_embeddings.mean(dim=0) # Shape: (768)
152
  #processed_embedding = processed_embedding.unsqueeze(0) # Shape: (1, 768)
@@ -273,7 +279,7 @@ user_input_review = st.text_area("Or type your own review here","")
273
  st.info(f"Please include one of the following entities in your review:\n {', '.join(california_entities)}")
274
 
275
  review_to_process = user_input_review if user_input_review.strip() else selected_review
276
- st.write("Selected Review: ", review_to_process)
277
  lower_case_review = review_to_process.lower()
278
 
279
  # Process the text when the button is clicked
@@ -281,45 +287,49 @@ if st.button("Process Review"):
281
  if lower_case_review.strip():
282
  bert_embedding = get_bert_embedding(lower_case_review)
283
  spaBert_embedding, current_pseudo_sentences = processSpatialEntities(review_to_process,nlp)
284
- combined_embedding = torch.cat((bert_embedding,spaBert_embedding),dim=-1)
285
-
286
- if(dev_mode == True):
287
- st.write("Review Embedding Shape:", bert_embedding.shape)
288
- st.write("Geo-Entities embedding shape: ", spaBert_embedding.shape)
289
- st.write("Concatenated Embedding Shape:", combined_embedding.shape)
290
- st.write("Concatenated Embedding:", combined_embedding)
291
-
292
- prediction = get_prediction(combined_embedding)
293
-
294
- # Process the text using spaCy
295
- doc = nlp(review_to_process)
296
-
297
- # Highlight geo-entities with different colors
298
- highlighted_text = review_to_process
299
- for ent in reversed(doc.ents):
300
- if ent.label_ in COLOR_MAP:
301
- color = COLOR_MAP[ent.label_][0]
302
- highlighted_text = (
303
- highlighted_text[:ent.start_char] +
304
- f"<span style='color:{color}; font-weight:bold'>{ent.text}</span>" +
305
- highlighted_text[ent.end_char:]
306
- )
307
-
308
- # Display the highlighted text with HTML support
309
- st.markdown(highlighted_text, unsafe_allow_html=True)
310
-
311
- #Display pseudo sentences found
312
- for sentence in current_pseudo_sentences:
313
- clean_sentence = sentence.replace("[PAD]", "").strip()
314
- st.write("Pseudo-Sentence:", clean_sentence)
315
-
316
- #Display the models prediction
317
- if prediction == 0:
318
- st.markdown("<h3 style='color:green;'>✅ Prediction: Not Spam</h3>", unsafe_allow_html=True)
319
- elif prediction == 1:
320
- st.markdown("<h3 style='color:red;'>❌ Prediction: Spam</h3>", unsafe_allow_html=True)
321
  else:
322
- st.markdown("<h3 style='color:orange;'>⚠️ Error during prediction</h3>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
  else:
325
  st.error("Please select a review.")
 
146
  token_embeddings.append(spaBert_emb)
147
  if(dev_mode == True):
148
  st.write("Geo-Entity Found in review: ", text)
149
+
150
+ # Handle the case where no geo-entities are found
151
+ if not token_embeddings:
152
+ st.warning("No geo-entities found in the review. Please include one from the list.")
153
+ # Return a zero vector as a fallback if no entities are found
154
+ return torch.zeros(bert_model.config.hidden_size), []
155
+
156
  token_embeddings = torch.stack(token_embeddings, dim=0)
157
  processed_embedding = token_embeddings.mean(dim=0) # Shape: (768)
158
  #processed_embedding = processed_embedding.unsqueeze(0) # Shape: (1, 768)
 
279
  st.info(f"Please include one of the following entities in your review:\n {', '.join(california_entities)}")
280
 
281
  review_to_process = user_input_review if user_input_review.strip() else selected_review
282
+ #st.write("Selected Review: ", review_to_process)
283
  lower_case_review = review_to_process.lower()
284
 
285
  # Process the text when the button is clicked
 
287
  if lower_case_review.strip():
288
  bert_embedding = get_bert_embedding(lower_case_review)
289
  spaBert_embedding, current_pseudo_sentences = processSpatialEntities(review_to_process,nlp)
290
+ # Check if SpaBERT embedding is valid
291
+ if spaBert_embedding is None or spaBert_embedding.sum() == 0:
292
+ st.error("Unable to process the review. Please include at least one valid geo-entity.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  else:
294
+ combined_embedding = torch.cat((bert_embedding,spaBert_embedding),dim=-1)
295
+
296
+ if(dev_mode == True):
297
+ st.write("Review Embedding Shape:", bert_embedding.shape)
298
+ st.write("Geo-Entities embedding shape: ", spaBert_embedding.shape)
299
+ st.write("Concatenated Embedding Shape:", combined_embedding.shape)
300
+ st.write("Concatenated Embedding:", combined_embedding)
301
+
302
+ prediction = get_prediction(combined_embedding)
303
+
304
+ # Process the text using spaCy
305
+ doc = nlp(review_to_process)
306
+
307
+ # Highlight geo-entities with different colors
308
+ highlighted_text = review_to_process
309
+ for ent in reversed(doc.ents):
310
+ if ent.label_ in COLOR_MAP:
311
+ color = COLOR_MAP[ent.label_][0]
312
+ highlighted_text = (
313
+ highlighted_text[:ent.start_char] +
314
+ f"<span style='color:{color}; font-weight:bold'>{ent.text}</span>" +
315
+ highlighted_text[ent.end_char:]
316
+ )
317
+
318
+ # Display the highlighted text with HTML support
319
+ st.markdown(highlighted_text, unsafe_allow_html=True)
320
+
321
+ #Display pseudo sentences found
322
+ for sentence in current_pseudo_sentences:
323
+ clean_sentence = sentence.replace("[PAD]", "").strip()
324
+ st.write("Pseudo-Sentence:", clean_sentence)
325
+
326
+ #Display the models prediction
327
+ if prediction == 0:
328
+ st.markdown("<h3 style='color:green;'>✅ Prediction: Not Spam</h3>", unsafe_allow_html=True)
329
+ elif prediction == 1:
330
+ st.markdown("<h3 style='color:red;'>❌ Prediction: Spam</h3>", unsafe_allow_html=True)
331
+ else:
332
+ st.markdown("<h3 style='color:orange;'>⚠️ Error during prediction</h3>", unsafe_allow_html=True)
333
 
334
  else:
335
  st.error("Please select a review.")