TSA / inference.py
QINGCHE's picture
fix bugs
bff547d
raw history blame
No virus
2.48 kB
import os
import numpy as np
import transformers
import torch
import torch.nn as nn
from torch import cuda
from transformers import BertTokenizer
from BERT_inference import BertClassificationModel
def encoder(max_len,text):
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
tokenizer = tokenizer(
text,
padding = True,
truncation = True,
max_length = max_len,
return_tensors='pt'
)
input_ids = tokenizer['input_ids']
token_type_ids = tokenizer['token_type_ids']
attention_mask = tokenizer['attention_mask']
return input_ids,token_type_ids,attention_mask
def predict(model,device,text):
model.to(device)
model.eval()
with torch.no_grad():
input_ids,token_type_ids,attention_mask = encoder(512,text)
input_ids,token_type_ids,attention_mask=input_ids.to(device),token_type_ids.to(device),attention_mask.to(device)
out_put = model(input_ids,token_type_ids,attention_mask)
# pre_numpy = out_put.cpu().numpy().tolist()
probs = torch.nn.functional.softmax(out_put).detach().cpu().numpy().tolist()
# print(probs)
return probs[0][1]
def inference_matrix(topics):
device = torch.device('cuda' if cuda.is_available() else 'cpu')
load_path = "bert_model.pkl"
model = torch.load(load_path,map_location=torch.device(device))
matrix = np.zeros([len(topics),len(topics)],dtype=float)
for i,i_text in enumerate(topics):
for j,j_text in enumerate(topics):
if(i == j):
matrix[i][j] = 0
else:
test = i_text+" 是否包含 "+j_text
outputs = predict(model,device,test)
# outputs = model(ids, mask,token_type_ids)
# print(outputs)
matrix[i][j] = outputs
return matrix
if __name__ == "__main__":
print("yes")
topics = ['在本次报告中我将介绍分布式并行加速算法模型架构内存和计算优化以及集群架构等关键技术', '在现代机器学习任务中大模型训练已成为解决复杂问题的重要手段', '首先分布式并行加速策略包括数据并行模型并行流水线并行和张量并行等四种方式', '选择合适的集群架构是实现大模型的分布式训练的关键', '这些策略帮助我们将训练数据和模型分布到多个设备上以加速大模型训练过程']
print(inference_matrix(topics))