|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import numpy as np |
|
from roformer import RoFormerForCausalLM, RoFormerConfig |
|
from transformers import BertTokenizer |
|
|
|
|
|
class Model:
    """Sentence encoder built on a RoFormer-Sim causal-LM checkpoint.

    The tokenizer, config and weights are all loaded from the same
    ``pretrain_model_path``. :meth:`encoder_predict` turns a batch of
    sentences into the model's ``pooler_output`` as a NumPy array.
    """

    def __init__(self, pretrain_model_path):
        """Load tokenizer, config and model from ``pretrain_model_path``.

        Args:
            pretrain_model_path: Local directory (or hub id) accepted by
                ``from_pretrained`` for both the tokenizer and the model.
        """
        # Prefer the first GPU; fall back to the CPU when CUDA is absent.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = BertTokenizer.from_pretrained(pretrain_model_path)

        cfg = RoFormerConfig.from_pretrained(pretrain_model_path)
        # Load the checkpoint as a decoder whose end-of-sequence id is the
        # tokenizer's [SEP] token, with a "linear" pooler activation.
        cfg.is_decoder = True
        cfg.eos_token_id = self.tokenizer.sep_token_id
        cfg.pooler_activation = "linear"

        network = RoFormerForCausalLM.from_pretrained(pretrain_model_path, config=cfg)
        network.to(self.device)
        # Inference only: switch to evaluation mode (e.g. disables dropout).
        network.eval()
        self.model = network

    def encoder_predict(self, r):
        """Encode sentences and return the model's pooled outputs.

        Args:
            r: Input accepted by the tokenizer; callers in this file pass a
                list of strings. The batch is padded to its longest entry.

        Returns:
            ``numpy.ndarray`` built from ``pooler_output``, one row per input
            sentence (presumably ``(len(r), hidden_size)`` -- confirm against
            the checkpoint).
        """
        batch = self.tokenizer(r, padding=True, return_tensors="pt")
        with torch.no_grad():
            # BatchEncoding.to moves every tensor to the device in place.
            batch.to(self.device)
            pooled = self.model(**batch).pooler_output
        # Computed under no_grad, so no autograd graph needs detaching.
        return pooled.cpu().numpy()
|
|
|
|
|
def read_txt_file(file_path, encoding='utf-8'):
    """Read a text file and return its lines without the trailing newline.

    Args:
        file_path: Path of the file to read.
        encoding: Text encoding used to decode the file (default UTF-8).

    Returns:
        A list with one string per line, in file order. Empty lines are kept
        as ``''``; a final newline does not add an extra empty entry.
    """
    # ``with`` guarantees the handle is closed (the original leaked it).
    with open(file_path, 'r', encoding=encoding) as file:
        # Text mode uses universal newlines, so every line ends in at most one
        # '\n'; stripping it matches the original ``split('\n')[0]``.
        return [line.rstrip('\n') for line in file]
|
|
|
|
|
if __name__ == '__main__':
    import sys

    def _l2_normalize(emb):
        """Return ``emb`` with every row scaled to unit Euclidean length.

        After this, a dot product between two rows is their cosine
        similarity.
        """
        return emb / (emb ** 2).sum(axis=1, keepdims=True) ** 0.5

    print("开始加载模型...")
    encoder = Model("../roformer_chinese_sim_char_base")

    # Task 1: embed a single sentence and report the embedding shape.
    print("预测任务1句向量...")
    text = "北京"
    pred_emb = encoder.encoder_predict([text])
    print(pred_emb.shape)

    # Task 2: cosine similarity between two paraphrased questions.
    print("预测任务2句向量...")
    text1 = "什么是网站域名?"
    text2 = "网站所拥有的域名具体是指代什么呢?"
    pred_emb_multi = _l2_normalize(encoder.encoder_predict([text1, text2]))
    similarity_score = np.dot(pred_emb_multi[1], pred_emb_multi[0])
    print(similarity_score.shape)
    print(similarity_score)

    # Task 3: find the corpus sentence most similar to a query.
    print("获取语料及待预测数据...")
    text3 = "小度插上电源黑屏怎么回事?"
    # The corpus path can be passed as the first CLI argument; the default is
    # the original author's machine-specific location.
    if len(sys.argv) > 1:
        corpus_path = sys.argv[1]
    else:
        corpus_path = "D://ChromeDownLoad//YOLO_v3_PyTorch-master//code//corpus.txt"
    corpus_data = read_txt_file(corpus_path)
    if not corpus_data:
        # Without candidates there is nothing to rank (previously IndexError).
        sys.exit(f"corpus file is empty: {corpus_path}")

    # Row 0 of the embedding matrix is the query; rows 1.. are the corpus.
    pred_emb_corpus = _l2_normalize(encoder.encoder_predict([text3] + corpus_data))
    # Cosine similarity of every corpus sentence to the query.
    scores = np.dot(pred_emb_corpus[1:], pred_emb_corpus[0])
    best = int(np.argmax(scores))
    print([corpus_data[best]])
    print(scores[best])
|
|
|
|
|
|