import os
import numpy as np
import json
#from sklearn.metrics.pairwise import cosine_similarity

import torch

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def write_to_csv(out_dir, map_name, match_list):
    # Despite the name, this writes JSON Lines: one match dict per line,
    # saved under <out_dir>/<map_name>.json.
    out_path = os.path.join(out_dir, map_name + '.json')
    with open(out_path, 'w') as f:
        for match_dict in match_list:
            json.dump(match_dict, f)
            f.write('\n')


def load_spatial_bert_pretrained_weights(model, weight_path):
    # Load pretrained weights from a SpatialBertLM checkpoint into a
    # SpatialBertModel. The LM checkpoint prefixes encoder weights with
    # 'bert.', so keys are matched by prepending that prefix.
    pre_trained_model = torch.load(weight_path, map_location=device)

    cnt_layers = 0
    cur_model_kvpair = model.state_dict()
    for key in cur_model_kvpair:
        if 'bert.' + key in pre_trained_model:
            cur_model_kvpair[key] = pre_trained_model['bert.' + key]
            cnt_layers += 1
        else:
            print("No weight for", key)
    print(cnt_layers, 'layers loaded')

    model.load_state_dict(cur_model_kvpair)
    return model


def get_spatialbert_embedding(entity, model, use_distance=True, agg='mean'):
    # Run one entity's pseudo-sentence through SpatialBERT and pool the
    # hidden states of the pivot tokens (which follow [CLS]) into a vector.
    pseudo_sentence = entity['pseudo_sentence'][None, :].to(device)
    attention_mask = entity['attention_mask'][None, :].to(device)
    sent_position_ids = entity['sent_position_ids'][None, :].to(device)
    pivot_token_len = entity['pivot_token_len']

    if 'norm_lng_list' in entity and use_distance:
        position_list_x = entity['norm_lng_list'][None, :].to(device)
        position_list_y = entity['norm_lat_list'][None, :].to(device)
    else:
        position_list_x = []
        position_list_y = []

    outputs = model(input_ids=pseudo_sentence,
                    attention_mask=attention_mask,
                    sent_position_ids=sent_position_ids,
                    position_list_x=position_list_x,
                    position_list_y=position_list_y)

    embeddings = outputs.last_hidden_state
    pivot_embed = embeddings[0][1:1 + pivot_token_len]  # skip [CLS]

    if agg == 'mean':
        pivot_embed = torch.mean(pivot_embed, dim=0).detach().cpu().numpy()  # (768,)
    elif agg == 'sum':
        pivot_embed = torch.sum(pivot_embed, dim=0).detach().cpu().numpy()  # (768,)
    else:
        raise NotImplementedError

    return pivot_embed


def get_bert_embedding(entity, model, agg='mean'):
    # Same pivot pooling as get_spatialbert_embedding, but for a vanilla
    # BERT model that takes no spatial inputs.
    pseudo_sentence = entity['pseudo_sentence'].unsqueeze(0).to(device)
    attention_mask = entity['attention_mask'].unsqueeze(0).to(device)
    pivot_token_len = entity['pivot_token_len']

    outputs = model(input_ids=pseudo_sentence, attention_mask=attention_mask)

    embeddings = outputs.last_hidden_state
    pivot_embed = embeddings[0][1:1 + pivot_token_len]  # skip [CLS]

    if agg == 'mean':
        pivot_embed = torch.mean(pivot_embed, dim=0).detach().cpu().numpy()  # (768,)
    elif agg == 'sum':
        pivot_embed = torch.sum(pivot_embed, dim=0).detach().cpu().numpy()  # (768,)
    else:
        raise NotImplementedError

    return pivot_embed
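

# Illustrative helper, not part of the original module: the commented-out
# sklearn cosine_similarity import above suggests pivot embeddings are
# compared by cosine similarity downstream. This is a minimal numpy
# stand-in for scoring a single pair of embedding vectors.
def cosine_similarity_1d(a, b):
    # a, b: 1-D numpy arrays as returned by the embedding functions above.
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))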
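

if __name__ == '__main__':
    # Illustrative usage sketch, not part of the original module. It
    # exercises get_bert_embedding with a plain Hugging Face BertModel;
    # the entity-dict layout is inferred from the field accesses above.
    # The SpatialBERT path would additionally need sent_position_ids and,
    # optionally, norm_lng_list / norm_lat_list tensors.
    from transformers import BertModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert = BertModel.from_pretrained('bert-base-uncased').to(device).eval()

    def make_entity(pivot_name, context):
        # Pivot tokens must come first in the pseudo-sentence so that
        # embeddings[0][1:1+pivot_token_len] slices them out after [CLS].
        pivot_ids = tokenizer(pivot_name, add_special_tokens=False)['input_ids']
        encoded = tokenizer(pivot_name + ' ' + context,
                            padding='max_length', max_length=64,
                            truncation=True, return_tensors='pt')
        return {'pseudo_sentence': encoded['input_ids'][0],
                'attention_mask': encoded['attention_mask'][0],
                'pivot_token_len': len(pivot_ids)}

    with torch.no_grad():
        emb_a = get_bert_embedding(make_entity('Los Angeles', 'Pasadena Long Beach'), bert)
        emb_b = get_bert_embedding(make_entity('San Francisco', 'Oakland Berkeley'), bert)

    print('embedding shape:', emb_a.shape)  # expected: (768,)
    print('cosine similarity:', cosine_similarity_1d(emb_a, emb_b))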