import json
import os
import unicodedata

import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

import parameters
import utils

# Uncomment (and change `device` below) to pin the model to a specific GPU.
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device('cpu')


class Siameser:
    """Match a free-text address to the most similar pre-embedded standard address.

    The raw address is normalised (NFC, lower-case, punctuation removed),
    embedded with a SentenceTransformer model, and compared by cosine
    similarity against a pre-computed matrix of standard-address embeddings.
    Two matrices are kept: one for inputs that contain accents and one for
    inputs that do not (chosen by comparing the input with
    ``utils.remove_accent(input)``).
    """

    def __init__(self, model_name=None, stadard_scope=None, threshold=0.61):
        """Load the embedding model, the standard addresses and their embeddings.

        Args:
            model_name: Currently unused; the model is always
                ``parameters.embedding_model`` (the stored standard-address
                embeddings must come from that same model). Kept for
                backward compatibility.
            stadard_scope: ``'all'`` loads ``parameters.NORM_ADDS_FILE_ALL_1`` /
                ``parameters.STD_EMBEDDING_FILE_ALL_1``; any other value loads
                the ``*_HN_HCM`` files (presumably Hanoi / Ho Chi Minh City
                only -- confirm in ``parameters``).
            threshold: Minimum cosine similarity for the best match; below it
                the input is reported as not being an address.
        """
        print("Load sentence embedding model (If this is the first time you run this repo, It could be take time to download sentence embedding model)")
        self.threshold = threshold
        self.embedding_model = SentenceTransformer(parameters.embedding_model).to(device)

        if stadard_scope == 'all':
            adds_file = parameters.NORM_ADDS_FILE_ALL_1
            embedding_file = parameters.STD_EMBEDDING_FILE_ALL_1
        else:
            adds_file = parameters.NORM_ADDS_FILE_HN_HCM
            embedding_file = parameters.STD_EMBEDDING_FILE_HN_HCM

        print('Load standard address')
        # JSON object keyed by the string form of each embedding row index.
        with open(file=adds_file, mode='r', encoding='utf-8') as f:
            self.NORM_ADDS = json.load(fp=f)

        print('Load standard address matrix')
        # map_location lets a checkpoint saved from GPU tensors load on this
        # CPU-only device; `.to(device)` alone runs only after the load.
        embedding = torch.load(embedding_file, map_location=device)
        self.std_embeddings = embedding['accent_matrix'].to(device)
        self.NT_std_embeddings = embedding['noaccent_matrix'].to(device)
        self.num_std_add = self.std_embeddings.shape[0]
        print('Done')

    def _top_matches(self, raw_add_, k):
        """Return the ``k`` best (scores, indices) for ``raw_add_`` as Python lists.

        Scores are cosine similarities against the accented or unaccented
        standard matrix, depending on whether the normalised input carries
        accents. Results are sorted best-first, as returned by ``topk``.
        """
        raw_add = unicodedata.normalize('NFC', raw_add_).lower()
        raw_add = utils.remove_punctuation(raw_add)
        query = self.embedding_model.encode(raw_add, convert_to_tensor=True).to(device)

        # An input unchanged by accent removal is compared with the
        # accent-free embeddings; otherwise with the accented ones.
        if raw_add == utils.remove_accent(raw_add):
            candidates = self.NT_std_embeddings
        else:
            candidates = self.std_embeddings

        # cosine_similarity broadcasts (1, D) against (N, D), so the query
        # need not be copied N times.
        similarities = F.cosine_similarity(query.unsqueeze(0), candidates)
        values, indices = similarities.topk(k)
        return values.tolist(), indices.tolist()

    def standardize(self, raw_add_):
        """Return the single best standard-address match for ``raw_add_``.

        Returns:
            ``utils.get_full_result(raw_add_, std_add, score)`` for the best
            match (score rounded to 4 decimals), or a ``{'Format Error': ...}``
            dict (Vietnamese message: "the input string is not an address,
            please re-enter") when the best score is below ``self.threshold``.
        """
        scores, indices = self._top_matches(raw_add_, 1)
        best_score, best_idx = scores[0], indices[0]
        if best_score < self.threshold:
            return {'Format Error': 'Xâu truyền vào không phải địa chỉ, mời nhập lại.'}
        std_add = self.NORM_ADDS[str(best_idx)]
        return utils.get_full_result(raw_add_, std_add, round(best_score, 4))

    def get_top_k(self, raw_add_, k):
        """Return the best match and the ``k`` best matches for ``raw_add_``.

        ``k`` is capped at the number of standard addresses.

        Returns:
            ``(best, top_k_list)`` where each entry is a
            ``utils.get_full_result`` value, or ``({'Format Error': ...}, {})``
            (Vietnamese message: "it seems the input string is not an address,
            please re-enter") when the best score is below ``self.threshold``.

        Raises:
            ValueError: if ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError(f'k must be a positive integer, got {k!r}')
        k = min(k, self.num_std_add)

        scores, indices = self._top_matches(raw_add_, k)
        if scores[0] < self.threshold:
            return {'Format Error': 'Dường như xâu truyền vào không phải địa chỉ, mời nhập lại.'}, {}

        top_std_adds = [
            utils.get_full_result(raw_add_, self.NORM_ADDS[str(idx)], round(match_score, 4))
            for match_score, idx in zip(scores, indices)
        ]
        return top_std_adds[0], top_std_adds