xyc8888's picture
Second model version
190e2ab
raw
history blame
3.62 kB
#!/usr/bin/python
# -*- ecoding: utf-8 -*-
# @ModuleName: predict.py
# @Function:
# @Author: zhaokh
# @Time: 2021/9/7 10:55
import torch
import numpy as np
from roformer import RoFormerForCausalLM, RoFormerConfig
from transformers import BertTokenizer
class Model(object):
def __init__(self, pretrain_model_path):
# 模型配置
pretrained_model = pretrain_model_path
# 建立分词器
self.tokenizer = BertTokenizer.from_pretrained(pretrained_model)
# 加载模型
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
config = RoFormerConfig.from_pretrained(pretrained_model)
config.is_decoder = True
config.eos_token_id = self.tokenizer.sep_token_id
config.pooler_activation = "linear"
self.model = RoFormerForCausalLM.from_pretrained(pretrained_model, config=config)
self.model.to(self.device)
self.model.eval()
def encoder_predict(self, r):
inputs2 = self.tokenizer(r, padding=True, return_tensors="pt")
with torch.no_grad():
inputs2.to(self.device)
outputs = self.model(**inputs2)
Z = outputs.pooler_output.cpu().numpy()
return Z
def read_txt_file(file_path):
data = []
file = open(file_path, 'r',encoding='utf-8') # 打开文件
file_data = file.readlines() # 读取所有行
for row in file_data:
data.append(row.split('\n')[0])
return data
if __name__ == '__main__':
###########################任务1###########################
# 请在问号出填入模型地址,使模型可以运行。
print("开始加载模型...")
encoder = Model("../roformer_chinese_sim_char_base")
# 请在问号处填入待输入文本
print("预测任务1句向量...")
text = "北京"
pred_emb = encoder.encoder_predict([text])
# 请在问号处填入向量维度,可使用numpy函数的shape属性
print(pred_emb.shape)
###########################任务2###########################
# 请在问号处填入待输入文本
print("预测任务2句向量...")
text1 = "什么是网站域名?"
text2 = "网站所拥有的域名具体是指代什么呢?"
pred_emb_multi = encoder.encoder_predict([text1, text2])
# 使用以下语句进行向量归一化。请在括号中输入操作维度及其它参数
pred_emb_multi /= (pred_emb_multi ** 2).sum(axis=1, keepdims=True) ** 0.5
# 可使用np.dot()计算余弦相似度
similarity_score = np.dot(pred_emb_multi[1], pred_emb_multi[0])
print(similarity_score.shape)
print(similarity_score)
#print("arg sort :", similarity_score.argsort())
###########################任务3###########################
# 读取语料,补充read_txt_file函数
print("获取语料及待预测数据...")
text3 = "小度插上电源黑屏怎么回事?"
corpus_data = read_txt_file("D://ChromeDownLoad//YOLO_v3_PyTorch-master//code//corpus.txt")
#
# # 获取待测数据及语料的句向量
corpus_data = [text3]+corpus_data
pred_emb_corpus = encoder.encoder_predict(corpus_data)
#
# # 使用以下语句进行向量归一化。请参考上述操作。
pred_emb_corpus /= (pred_emb_corpus ** 2).sum(axis=1, keepdims=True) ** 0.5
#
# # 计算分数
argsort = np.dot(pred_emb_corpus[1:], -pred_emb_corpus[0]).argsort()
score = np.dot(pred_emb_corpus[argsort[0]+1], -pred_emb_corpus[0])
print([corpus_data[i + 1] for i in argsort[:1]])
#
# # 输出最相似的语句
print(-score)