File size: 13,655 Bytes
ac736ed 5c91758 bc50d7d f7a8863 a2d8109 01b8a52 ebf50a4 b7111b8 3d8cd48 83e90d7 ac736ed d914cbe dbbbac5 d914cbe 60a8335 d914cbe 1aa7dda 3d81019 60a8335 dc9ff0b 3d81019 d914cbe 3d81019 60a8335 dc9ff0b fa29176 d914cbe 3be15aa fbce538 3d81019 3be15aa d914cbe 80744c0 3d81019 60a8335 a74fa0d 3d8cd48 5914cea 18634d6 5914cea 3d8cd48 bf52bfd 18634d6 3d8cd48 fa29176 3d81019 4c18d69 60a8335 a74fa0d 169e7aa a74fa0d 857dba3 a74fa0d 0cea6d5 a74fa0d 4b2b2d4 a74fa0d 857dba3 4b2b2d4 fa29176 4b2b2d4 0cea6d5 4b2b2d4 a74fa0d f82dac8 1091141 f82dac8 857dba3 3d81019 60a8335 cbcad17 7837592 4c18d69 b8aeb00 94e1f37 60a8335 b8aeb00 7837592 94e1f37 b8aeb00 4c18d69 60a8335 cbcad17 4c18d69 7837592 4c18d69 340e947 a3ad8d6 7837592 cb96047 dfc41ea 94e1f37 b14ffe5 bec805a b6c17e6 7837592 60a8335 4c18d69 f7a8863 fea8f2e f7a8863 b81fd11 f7a8863 d2568a6 b28bcdc 9edc447 b28bcdc 9edc447 b28bcdc 9edc447 b28bcdc d2568a6 4c18d69 18634d6 b28bcdc dbbbac5 27b4c27 83e90d7 9822204 3577a57 9822204 3577a57 d914cbe 3577a57 9822204 3577a57 4c18d69 b23060e 4aae01c b23060e 4aae01c b23060e 4aae01c 0e51129 83e90d7 1f1627a 5c91758 4b2b2d4 0e51129 1f1627a 5f6c4ef 94e1f37 5fb8341 b28bcdc 4b2b2d4 5c91758 53c7e98 e72b522 b4303dc 1f1627a b4303dc 83e90d7 3577a57 4b2b2d4 7837592 ffa1ad8 7837592 4b2b2d4 aafb64e 4b2b2d4 aafb64e 4b2b2d4 5c91758 b23060e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 |
import html

import pandas as pd
import spacy
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel, AutoConfig
from transformers.models.bert.modeling_bert import BertForMaskedLM

from models.spabert.datasets.osm_sample_loader import PbfMapDataset
from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
# Module-level setup: everything below runs once when the script is loaded and
# creates the globals (device, nlp, BERT, SpaBERT, dataset, lookup table) that
# the helper functions further down read.
# All tensors and models are placed on the CPU.
device = torch.device('cpu')
# When True, the helpers below write extra debugging output to the Streamlit page.
dev_mode = False
#Spacy Initialization Section
# spaCy pipeline loaded from a local directory; used for named-entity recognition.
nlp = spacy.load("./models/en_core_web_sm")
#BERT Initialization Section
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")
#SpaBERT Initialization Section
data_file_path = 'models/spabert/datasets/SpaBERTPivots.json' #Sample file otherwise this model will take too long on CPU.
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'
config = SpatialBertConfig()
# Hidden states are requested so the last layer can be read as the embedding.
config.output_hidden_states = True
# NOTE(review): the model is constructed directly and never switched to eval()
# in this section - confirm the inference code disables dropout.
spaBERT_model = SpatialBertForMaskedLM(config)
# NOTE(review): torch.load unpickles the checkpoint; only load files from a trusted source.
pre_trained_model = torch.load(pretrained_model_path, map_location=torch.device('cpu'))
# Two-step weight load: first the plain BERT weights, then the SpaBERT
# checkpoint on top. strict=False tolerates keys that do not match.
spaBERT_model.load_state_dict(bert_model.state_dict(), strict = False)
spaBERT_model.load_state_dict(pre_trained_model, strict=False)
#Load data using SpatialDataset
spatialDataset = PbfMapDataset(data_file_path = data_file_path,
tokenizer = bert_tokenizer,
max_token_len = 256, #Originally 300
#max_token_len = max_seq_length, #Originally 300
distance_norm_factor = 0.0001,
spatial_dist_fill = 20,
with_type = False,
sep_between_neighbors = True,
label_encoder = None,
mode = None) #If set to None it will use the full dataset for mlm
# batch_size=1 and shuffle=False: batches come out one entity at a time, in dataset order.
data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False)
# Create a dictionary to map entity names to indices
# (if a pivot name occurs more than once, the later index overwrites the earlier one)
entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialDataset)}
# Ensure names are stored in lowercase for case-insensitive matching
entity_index_dict = {name.lower(): index for name, index in entity_index_dict.items()}
#Pre-acquire the SpaBERT embeddings for all geo-entities within our dataset
def process_entity(batch, model, device):
    """Run one data-loader batch through SpaBERT and return its [CLS] embedding.

    Args:
        batch: Mapping produced by the data loader. The keys read here are
            'masked_input', 'attention_mask', 'norm_lng_list',
            'norm_lat_list', 'sent_position_ids' and 'pseudo_sentence'.
        model: SpaBERT model to run. It must return ``hidden_states``
            (i.e. be configured with ``output_hidden_states=True``).
        device: torch device the batch tensors are moved to.

    Returns:
        Tuple ``(embedding, input_ids, pseudo_sentence_decoded)`` where
        ``embedding`` is the last-layer hidden state of the first token,
        ``input_ids`` is the batch's 'masked_input' tensor and
        ``pseudo_sentence_decoded`` is the decoded text of the first
        pseudo-sentence in the batch (special tokens kept).
    """
    input_ids = batch['masked_input'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    position_list_x = batch['norm_lng_list'].to(device)
    position_list_y = batch['norm_lat_list'].to(device)
    sent_position_ids = batch['sent_position_ids'].to(device)
    pseudo_sentence = batch['pseudo_sentence'].to(device)

    # Convert tensor to list of token IDs, and decode them into a readable sentence
    pseudo_sentence_decoded = bert_tokenizer.decode(pseudo_sentence[0].tolist(), skip_special_tokens=False)

    # Disable dropout: a directly constructed nn.Module starts in training mode.
    model.eval()
    with torch.no_grad():
        # NOTE(review): the unmasked pseudo-sentence is fed as input_ids instead
        # of the masked input (the original had `input_ids=input_ids` commented
        # out). The keyword names are assumed to match the SpaBERT forward()
        # signature - TODO confirm against models/spabert.
        outputs = model(input_ids=pseudo_sentence,
                        attention_mask=attention_mask,
                        sent_position_ids=sent_position_ids,
                        position_list_x=position_list_x,
                        position_list_y=position_list_y)

    # Last hidden layer, then the [CLS] token embedding (first token)
    last_hidden_state = outputs.hidden_states[-1].to(device)
    spaBERT_embedding = last_hidden_state[:, 0, :].detach()  # [batch_size, hidden_size]

    return spaBERT_embedding, input_ids, pseudo_sentence_decoded
# Pre-compute one SpaBERT embedding and one decoded pseudo-sentence per dataset
# entry. The data loader uses batch_size=1 and shuffle=False, so position i in
# these lists corresponds to dataset index i, which is what entity_index_dict
# stores.
spaBERT_embeddings = []
pseudo_sentences = []
for batch in data_loader:
    spaBERT_embedding, input_ids, pseudo_sentence = process_entity(batch, spaBERT_model, device)
    # Without these appends the lists stay empty and every later lookup by
    # entity index fails.
    spaBERT_embeddings.append(spaBERT_embedding)
    pseudo_sentences.append(pseudo_sentence)
embedding_cache = {}
#Get BERT Embedding for review
def get_bert_embedding(review_text):
    """Return the BERT [CLS] embedding of ``review_text``.

    The text is tokenized with padding and truncation enabled, run through
    ``bert_model`` without gradient tracking, and the first-token vector of
    the last hidden state is returned, detached from the graph.
    """
    # Tokenize and move the encoded inputs to the working device
    encoded = bert_tokenizer(review_text, return_tensors='pt', padding=True, truncation=True).to(device)

    # Inference only: no gradients needed
    with torch.no_grad():
        model_output = bert_model(**encoded)

    # First position of every sequence is the [CLS] token
    cls_vector = model_output.last_hidden_state[:, 0, :]
    return cls_vector.detach()
#Get SpaBERT Embedding for geo-entity
def get_spaBert_embedding(entity,current_pseudo_sentences):
    """Return the cached SpaBERT embedding for ``entity``, or a BERT fallback.

    Args:
        entity: Entity text as found in the review; matched case-insensitively
            against ``entity_index_dict``.
        current_pseudo_sentences: List that is extended in place with the
            decoded pseudo-sentence of the entity when one is cached.

    Returns:
        The pre-computed SpaBERT embedding when the entity is known and
        cached; otherwise the plain BERT embedding of the entity text.
    """
    entity_index = entity_index_dict.get(entity.lower())

    # Fall back to BERT when the entity is unknown, or when the pre-computed
    # cache has no entry at that index (avoids an IndexError).
    if entity_index is None or entity_index >= len(spaBERT_embeddings):
        if dev_mode:
            st.write("Got Bert embedding for: ", entity)
        return get_bert_embedding(entity) #Fallback in-case SpaBERT could not resolve entity to retrieve embedding. Rare-cases only.

    # Record the pseudo-sentence so the caller can display it.
    if entity_index < len(pseudo_sentences):
        current_pseudo_sentences.append(pseudo_sentences[entity_index])

    if dev_mode:
        st.write("Got SpaBert embedding for: ", entity)
    return spaBERT_embeddings[entity_index]
#Go through each review, identify all geo-entities, then extract their SpaBERT embeddings
def processSpatialEntities(review, nlp):
    """Average the embeddings of all geo-entities that spaCy finds in a review.

    Args:
        review: Review text to analyse.
        nlp: spaCy pipeline used for named-entity recognition.

    Returns:
        Tuple ``(processed_embedding, current_pseudo_sentences)``.
        ``processed_embedding`` is the mean of the per-entity embeddings, or
        an all-zero tensor when the review contains no geo-entity.
        ``current_pseudo_sentences`` is the list filled by
        ``get_spaBert_embedding`` for the entities that were processed.
    """
    geo_labels = ('FAC', 'ORG', 'LOC', 'GPE')  # Filter to geo-entities
    doc = nlp(review)

    token_embeddings = []
    current_pseudo_sentences = []

    # Iterate over each entity and process only geo entities
    for ent in doc.ents:
        if ent.label_ not in geo_labels:
            continue
        text = ent.text
        if dev_mode:
            st.write("Text found:", text)
        spaBert_emb = get_spaBert_embedding(text,current_pseudo_sentences)
        token_embeddings.append(spaBert_emb)
        if dev_mode:
            st.write("Geo-Entity Found in review: ", text)

    if not token_embeddings:
        # torch.stack raises on an empty list. Use a neutral zero vector so the
        # caller can still concatenate it with the review embedding.
        # NOTE(review): zeros chosen as "no spatial signal" - confirm this
        # matches how the discriminator was trained.
        processed_embedding = torch.zeros((1, bert_model.config.hidden_size), device=device)
        return processed_embedding,current_pseudo_sentences

    # Stack along a new leading axis and average over the entities.
    # Presumably each embedding is (1, hidden), giving a (1, hidden) mean - TODO confirm.
    processed_embedding = torch.stack(token_embeddings, dim=0).mean(dim=0)
    return processed_embedding,current_pseudo_sentences
#Initialize discriminator module
class Discriminator(nn.Module):
    """Feed-forward classifier head with one extra output for fake/real.

    Args:
        input_size: Width of the input representation.
        hidden_sizes: Sequence with the width of each hidden layer.
        num_labels: Number of class labels; the output layer has
            ``num_labels + 1`` units.
        dropout_rate: Dropout probability for the input and after every
            hidden layer.
    """

    def __init__(self, input_size=512, hidden_sizes=(512,), num_labels=2, dropout_rate=0.1):
        super().__init__()
        # Copy into a fresh list so any sequence type is accepted and the
        # caller's object (or the default) is never shared.
        layer_sizes = [input_size] + list(hidden_sizes)

        self.input_dropout = nn.Dropout(p=dropout_rate)
        layers = []
        for in_features, out_features in zip(layer_sizes, layer_sizes[1:]):
            layers.extend([nn.Linear(in_features, out_features), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(dropout_rate)])

        self.layers = nn.Sequential(*layers)  # for the flatten
        self.logit = nn.Linear(layer_sizes[-1], num_labels + 1)  # +1 for the probability of this sample being fake/real.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_rep):
        """Return ``(last_rep, logits, probs)`` for a batch of representations.

        ``last_rep`` is the output of the hidden stack, ``logits`` the raw
        output-layer scores and ``probs`` their softmax over the last axis.
        """
        input_rep = self.input_dropout(input_rep)
        last_rep = self.layers(input_rep)
        logits = self.logit(last_rep)
        probs = self.softmax(logits)
        return last_rep, logits, probs
# Discriminator configuration: the hidden width is taken from the
# bert-base-uncased config. The input is twice that width because the review
# embedding and the spatial embedding are concatenated before classification.
dConfig = AutoConfig.from_pretrained("bert-base-uncased")
hidden_size = int(dConfig.hidden_size)
num_hidden_layers_d = 2
hidden_levels_d = [hidden_size] * num_hidden_layers_d
label_list = ["1", "0"]
out_dropout_rate = 0.5
discriminator = Discriminator(input_size=hidden_size*2, hidden_sizes=hidden_levels_d, num_labels=len(label_list), dropout_rate=out_dropout_rate).to(device)
discriminator_weights = ('data/datasets/discriminator_weights.pth')
# Load the trained weights; without this the discriminator keeps its random
# initialisation and its predictions are meaningless.
# NOTE(review): assumes the file holds a plain state_dict for Discriminator -
# TODO confirm. torch.load unpickles data; only load trusted files.
discriminator.load_state_dict(torch.load(discriminator_weights, map_location=device))
# Inference only: switch off dropout (p=0.5 would otherwise randomise outputs).
discriminator.eval()
def get_prediction(embeddings):
    """Classify embeddings with the discriminator and return label indices.

    Args:
        embeddings: Tensor fed to ``discriminator``.

    Returns:
        NumPy array with, for each row, the index of the largest logit among
        all outputs except the last one (the fake/real slot).
    """
    # Make sure dropout is disabled; otherwise repeated calls on the same
    # input can give different labels.
    discriminator.eval()
    with torch.no_grad():
        # Forward pass through the discriminator to get the logits and probabilities
        last_rep, logits, probs = discriminator(embeddings)

    # Ignore the last output column (the extra fake/real unit)
    class_logits = logits[:, :-1]

    # Index of the highest remaining logit is the predicted label
    _, predicted_labels = torch.max(class_logits, dim=-1)

    return predicted_labels.cpu().numpy()
# Function to read reviews from a text file
def load_reviews_from_file(file_path):
    """Read one review per line from a UTF-8 text file.

    Args:
        file_path: Path of the text file to read.

    Returns:
        Dict mapping ``"Review <n>"`` to the stripped line text, where ``n``
        is the 1-based line number in the file. Blank lines are skipped but
        still counted, so numbering can have gaps. If the file does not exist
        an error is shown via Streamlit and an empty dict is returned.
    """
    reviews = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            for line_number, line in enumerate(file, start=1):
                line = line.strip()
                if line:  # Ensure the line is not empty
                    reviews[f"Review {line_number}"] = line
    except FileNotFoundError:
        st.error(f"File not found: {file_path}")
    return reviews
#Demo Section
st.title("SpaGAN Demo")
st.write("This demo lets you explore a curated list of sample reviews, containing a real and fake example.")
st.write("Upon selecting a review, any identified geo-entities will be color coded for easy visualization.")
st.write("For each geo-entity found, the model will generate a contextual pseudo-sentence, highlighting its closest neighbors from our dataset.")
st.write("Finally, the entire review is embedded and enriched with spatial embeddings, enabling the model to determine whether the review is authentic or fake.")

# Define a color map and descriptions for different entity types.
# Keys are spaCy entity labels; values are (display color, description).
COLOR_MAP = {
    'FAC': ('red', 'Facilities (e.g., buildings, airports)'),
    'ORG': ('blue', 'Organizations (e.g., companies, institutions)'),
    'LOC': ('purple', 'Locations (e.g., mountain ranges, water bodies)'),
    'GPE': ('green', 'Geopolitical Entities (e.g., countries, cities)'),
}

# Display the color key
st.write("**Color Key:**")
for label, (color, description) in COLOR_MAP.items():
    st.markdown(f"- **{label}**: <span style='color:{color}'>{color}</span> - {description}", unsafe_allow_html=True)
review_file_path = "models/spabert/datasets/SampleReviews.txt"
example_reviews = load_reviews_from_file(review_file_path)

# Define labels
review_labels = {
    "Review 1": "Real",
    "Review 2": "Spam",
}

# Create options with labels for the dropdown
dropdown_options = [f"{key} ({review_labels.get(key, 'Unknown')})" for key in example_reviews]

# Dropdown for selecting an example review
user_selection = st.selectbox("Select an example review", options=dropdown_options)

# Extract the original review key from the selected option (drop the label part).
# With no options (e.g. the review file is missing) the selectbox yields None,
# so fall back to an empty review instead of crashing.
selected_key = user_selection.split(" (")[0] if user_selection else None
selected_review = example_reviews.get(selected_key, "")

#Optional textbox for interactivity
user_input_review = st.text_area("Or type your own review here","")

# A typed review takes precedence over the dropdown selection
review_to_process = user_input_review if user_input_review.strip() else selected_review
lower_case_review = review_to_process.lower()
def _highlight_entities(text, ents):
    """Return ``text`` as HTML with entities listed in COLOR_MAP wrapped in colored spans.

    ``ents`` must be ordered, non-overlapping spans exposing ``start_char``,
    ``end_char``, ``text`` and ``label_``. All review text is HTML-escaped
    because the result is rendered with ``unsafe_allow_html=True``.
    """
    pieces = []
    cursor = 0
    for ent in ents:
        if ent.label_ not in COLOR_MAP:
            continue
        color = COLOR_MAP[ent.label_][0]
        pieces.append(html.escape(text[cursor:ent.start_char]))
        pieces.append(f"<span style='color:{color}; font-weight:bold'>{html.escape(ent.text)}</span>")
        cursor = ent.end_char
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


# Process the text when the button is clicked
if st.button("Process Review"):
    if lower_case_review.strip():
        # Review embedding (BERT) plus averaged geo-entity embedding (SpaBERT),
        # concatenated along the feature axis for the discriminator.
        bert_embedding = get_bert_embedding(lower_case_review)
        spaBert_embedding, current_pseudo_sentences = processSpatialEntities(review_to_process,nlp)
        combined_embedding = torch.cat((bert_embedding, spaBert_embedding), dim=-1)

        if dev_mode:
            st.write("Review Embedding Shape:", bert_embedding.shape)
            st.write("Geo-Entities embedding shape: ", spaBert_embedding.shape)
            st.write("Concatenated Embedding Shape:", combined_embedding.shape)
            st.write("Concatenated Embedding:", combined_embedding)

        prediction = get_prediction(combined_embedding)

        # Process the text using spaCy. This must be the same text that gets
        # highlighted, otherwise the entity character offsets do not line up.
        doc = nlp(review_to_process)

        # Highlight geo-entities with different colors
        highlighted_text = _highlight_entities(review_to_process, doc.ents)

        # Display the highlighted text with HTML support
        st.markdown(highlighted_text, unsafe_allow_html=True)

        #Display pseudo sentences found
        for sentence in current_pseudo_sentences:
            clean_sentence = sentence.replace("[PAD]", "").strip()
            st.write("Pseudo-Sentence:", clean_sentence)

        #Display the models prediction
        predicted_label = int(prediction[0])  # single review -> single label
        if predicted_label == 0:
            st.markdown("<h3 style='color:green;'>✅ Prediction: Not Spam</h3>", unsafe_allow_html=True)
        elif predicted_label == 1:
            st.markdown("<h3 style='color:red;'>❌ Prediction: Spam</h3>", unsafe_allow_html=True)
        else:
            st.markdown("<h3 style='color:orange;'>⚠️ Error during prediction</h3>", unsafe_allow_html=True)
    else:
        st.error("Please select a review.")