# Spaces: Sleeping (web-page header residue captured with the source; not part of the module)
import unicodedata | |
import torch | |
import utils | |
import parameters | |
import json | |
from sentence_transformers import SentenceTransformer | |
import os | |
import torch.nn.functional as F | |
# All tensors and the embedding model in this module are placed on the CPU.
# (A GPU could be selected here instead, optionally restricting the visible
# devices through the CUDA_VISIBLE_DEVICES environment variable first.)
device = torch.device('cpu')
class Siameser:
    """Match free-form address strings against a table of standard addresses.

    A raw address is embedded with a SentenceTransformer model and compared,
    by cosine similarity, with pre-computed embeddings of the standard
    addresses.  Two embedding matrices are loaded from the same file: the one
    stored under ``'accent_matrix'`` and the one stored under
    ``'noaccent_matrix'``.  The latter is used whenever the cleaned input is
    left unchanged by ``utils.remove_accent``.
    """

    def __init__(self, model_name=None, stadard_scope=None):
        """Load the embedding model, the standard addresses and their embeddings.

        Args:
            model_name: Accepted for backward compatibility but unused; the
                model is always taken from ``parameters.embedding_model``.
            stadard_scope: ``'all'`` selects the ``*_ALL_1`` resource files;
                any other value (including ``None``) selects the ``*_HN_HCM``
                ones.  The misspelled name is kept because callers may pass
                it by keyword.
        """
        print("Load sentence embedding model (If this is the first time you run this repo, "
              "It could be take time to download sentence embedding model)")
        # Best-match cosine similarity below which the input is rejected.
        self.threshold = 0.61
        self.embedding_model = SentenceTransformer(parameters.embedding_model).to(device)

        # The two scopes differ only in which files are read.
        if stadard_scope == 'all':
            norm_adds_file = parameters.NORM_ADDS_FILE_ALL_1
            std_embedding_file = parameters.STD_EMBEDDING_FILE_ALL_1
        else:
            norm_adds_file = parameters.NORM_ADDS_FILE_HN_HCM
            std_embedding_file = parameters.STD_EMBEDDING_FILE_HN_HCM

        print('Load standard address')
        with open(file=norm_adds_file, mode='r', encoding='utf-8') as f:
            self.NORM_ADDS = json.load(fp=f)

        print('Load standard address matrix')
        # map_location keeps the load working on hosts without the device the
        # tensors were saved from (this module runs on `device` only).
        embedding = torch.load(std_embedding_file, map_location=device)
        self.std_embeddings = embedding['accent_matrix'].to(device)
        self.NT_std_embeddings = embedding['noaccent_matrix'].to(device)
        self.num_std_add = self.std_embeddings.shape[0]
        print('Done')

    @staticmethod
    def _cosine_scores(query, matrix):
        """Return the cosine similarity of `query` with every row of `matrix`.

        The query is viewed as a single row and broadcast against the matrix,
        which avoids materialising one copy of the query per matrix row.
        """
        return F.cosine_similarity(query.reshape(1, -1), matrix, dim=1)

    def _score(self, raw_add_):
        """Clean and embed `raw_add_`, then score it against the standard table.

        Returns:
            A 1-D tensor with one cosine similarity per standard address.
        """
        raw_add = unicodedata.normalize('NFC', raw_add_).lower()
        raw_add = utils.remove_punctuation(raw_add)
        query = self.embedding_model.encode(raw_add, convert_to_tensor=True).to(device)
        # Input that remove_accent leaves unchanged is compared with the
        # 'noaccent_matrix' embeddings, everything else with 'accent_matrix'.
        if raw_add == utils.remove_accent(raw_add):
            matrix = self.NT_std_embeddings
        else:
            matrix = self.std_embeddings
        return self._cosine_scores(query, matrix)

    def standardize(self, raw_add_):
        """Return the best-matching standard address for `raw_add_`.

        Returns:
            ``{'Format Error': <message>}`` when the best similarity is below
            ``self.threshold``; otherwise the value of
            ``utils.get_full_result(raw_add_, std_add, score)`` for the best
            match, with the score rounded to 4 decimals.
        """
        best_scores, best_indices = self._score(raw_add_).topk(1)
        s, idx = best_scores.tolist()[0], best_indices.tolist()[0]
        if s < self.threshold:
            return {'Format Error': 'Xâu truyền vào không phải địa chỉ, mời nhập lại.'}
        # NORM_ADDS comes from JSON, so its keys are strings.
        std_add = self.NORM_ADDS[str(idx)]
        return utils.get_full_result(raw_add_, std_add, round(s, 4))

    def get_top_k(self, raw_add_, k):
        """Return the `k` best-matching standard addresses for `raw_add_`.

        Args:
            raw_add_: The raw address string.
            k: Number of candidates wanted; values larger than the number of
                standard addresses are clamped to that number.

        Returns:
            ``(error_dict, {})`` when the best similarity is below
            ``self.threshold``; otherwise ``(best, candidates)`` where
            ``candidates`` is the list of ``utils.get_full_result`` values in
            decreasing score order and ``best`` is its first element.

        Raises:
            ValueError: If `k` is smaller than 1.
        """
        if k < 1:
            raise ValueError(f'k must be at least 1, got {k}')
        scores = self._score(raw_add_)
        # topk raises when k exceeds the number of scored rows.
        k = min(k, scores.shape[0])
        top_scores, top_indices = scores.topk(k)
        s, top_k = top_scores.tolist(), top_indices.tolist()
        if s[0] < self.threshold:
            return {'Format Error': 'Dường như xâu truyền vào không phải địa chỉ, mời nhập lại.'}, {}
        top_std_adds = [
            utils.get_full_result(raw_add_, self.NORM_ADDS[str(idx)], round(match_score, 4))
            for match_score, idx in zip(s, top_k)
        ]
        return top_std_adds[0], top_std_adds