|
import unicodedata |
|
import torch |
|
import utils |
|
import parameters |
|
import json |
|
from sentence_transformers import SentenceTransformer |
|
import os |
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
# All models and embedding matrices are pinned to CPU; switch to 'cuda'
# here to run inference on GPU (everything below routes through this handle).
device = torch.device('cpu')
|
|
|
class Siameser:
    """Standardize raw address strings against a precomputed standard-address set.

    A raw string is normalized, embedded with a SentenceTransformer model, and
    matched by cosine similarity against either an accented or an accent-stripped
    embedding matrix (chosen by whether the input itself carries accents).
    """

    def __init__(self, model_name=None, stadard_scope=None):
        """Load the embedding model, the standard-address table and its matrices.

        Args:
            model_name: accepted for backward compatibility but unused — the
                embedding model always comes from ``parameters.embedding_model``.
            stadard_scope: ``'all'`` loads the full standard-address set;
                any other value loads the HN/HCM subset. (The parameter name
                keeps its original misspelling so existing keyword callers
                don't break.)
        """
        print("Load sentence embedding model (If this is the first time you run this repo, It could be take time to download sentence embedding model)")
        # Minimum cosine similarity for the input to be accepted as an address.
        self.threshold = 0.61
        self.embedding_model = SentenceTransformer(parameters.embedding_model).to(device)

        if stadard_scope == 'all':
            self._load_standard_data(parameters.NORM_ADDS_FILE_ALL_1,
                                     parameters.STD_EMBEDDING_FILE_ALL_1)
        else:
            self._load_standard_data(parameters.NORM_ADDS_FILE_HN_HCM,
                                     parameters.STD_EMBEDDING_FILE_HN_HCM)

        self.num_std_add = self.std_embeddings.shape[0]
        print('Done')

    def _load_standard_data(self, norm_adds_file, embedding_file):
        """Load the standard-address JSON table and its two embedding matrices."""
        print('Load standard address')
        with open(file=norm_adds_file, mode='r', encoding='utf-8') as f:
            # Keys are stringified indices into the embedding matrix rows
            # (looked up as self.NORM_ADDS[str(idx)] below).
            self.NORM_ADDS = json.load(fp=f)

        print('Load standard address matrix')
        # map_location keeps loading robust if the tensors were saved on GPU.
        embedding = torch.load(embedding_file, map_location=device)
        self.std_embeddings = embedding['accent_matrix'].to(device)
        self.NT_std_embeddings = embedding['noaccent_matrix'].to(device)

    @staticmethod
    def _normalize(raw_add_):
        """NFC-normalize, lowercase and strip punctuation from a raw address."""
        raw_add = unicodedata.normalize('NFC', raw_add_).lower()
        return utils.remove_punctuation(raw_add)

    def _similarity_scores(self, raw_add):
        """Return cosine similarities between ``raw_add`` and every standard address.

        Uses the accent-stripped matrix when the input contains no accents
        (so an unaccented query is compared against unaccented references),
        otherwise the accented matrix.
        """
        raw_add_vector = self.embedding_model.encode(raw_add, convert_to_tensor=True).to(device)
        raw_add_vectors = raw_add_vector.repeat(self.num_std_add, 1)
        if raw_add == utils.remove_accent(raw_add):
            return F.cosine_similarity(raw_add_vectors, self.NT_std_embeddings)
        return F.cosine_similarity(raw_add_vectors, self.std_embeddings)

    def standardize(self, raw_add_):
        """Return the best-matching standard address for ``raw_add_``.

        Returns the dict from ``utils.get_full_result`` on success, or a
        ``{'Format Error': ...}`` dict when the best similarity is below
        ``self.threshold``.
        """
        raw_add = self._normalize(raw_add_)
        score = self._similarity_scores(raw_add)
        s, top_k = score.topk(1)

        s, idx = s.tolist()[0], top_k.tolist()[0]
        if s < self.threshold:
            return {'Format Error': 'Xâu truyền vào không phải địa chỉ, mời nhập lại.'}
        std_add = self.NORM_ADDS[str(idx)]
        return utils.get_full_result(raw_add_, std_add, round(s, 4))

    def get_top_k(self, raw_add_, k):
        """Return ``(best_match, top_k_matches)`` for ``raw_add_``.

        ``top_k_matches`` is a list of ``utils.get_full_result`` dicts ordered
        by descending similarity. When the best similarity is below
        ``self.threshold``, returns ``({'Format Error': ...}, {})``.
        """
        raw_add = self._normalize(raw_add_)
        score = self._similarity_scores(raw_add)
        # topk requires k <= number of rows; clamp instead of raising.
        k = min(k, self.num_std_add)
        s, top_k = score.topk(k)
        s, top_k = s.tolist(), top_k.tolist()

        if s[0] < self.threshold:
            return {'Format Error': 'Dường như xâu truyền vào không phải địa chỉ, mời nhập lại.'}, {}

        top_std_adds = [
            utils.get_full_result(raw_add_, self.NORM_ADDS[str(idx)], round(sc, 4))
            for sc, idx in zip(s, top_k)
        ]
        return top_std_adds[0], top_std_adds