|
import transformers |
|
import torch |
|
import os |
|
import json |
|
import random |
|
import numpy as np |
|
from torch import nn |
|
import argparse |
|
from torch.utils.tensorboard import SummaryWriter |
|
from datetime import datetime |
|
from tqdm import tqdm |
|
from torch.nn import DataParallel |
|
import logging |
|
from transformers import BertTokenizerFast |
|
|
|
from os.path import join, exists |
|
from itertools import zip_longest, chain |
|
|
|
|
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.nn import CrossEntropyLoss |
|
from sklearn.model_selection import train_test_split |
|
import torch.nn.functional as F |
|
from transformers import AutoTokenizer, AutoConfig, get_linear_schedule_with_warmup, AdamW, BertModel |
|
|
|
# Padding token string; the same literal is passed as `pad_token` when the
# tokenizer is constructed later in this file.
PAD = '[PAD]'

# Integer id used for padding positions.
pad_id = 0
|
|
|
|
|
def set_args(argv=None):
    """
    Build the command-line parser for the service and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``device``, ``log_path``,
        ``model_path``, ``vocab_path``, ``save_samples_path``,
        ``repetition_penalty``, ``max_len``, ``max_history_len`` and
        ``no_cuda``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--device', default='0', type=str, required=False, help='生成设备')
    parser.add_argument('--log_path', default='interact.log', type=str, required=False,
                        help='interact日志存放位置')
    parser.add_argument('--model_path', default='./pathology_extra/result/12/model.pth', type=str,
                        required=False, help='对话模型路径')
    # Build the default with os.path.join so the separator matches the host
    # platform; a hard-coded backslash only resolves on Windows.
    parser.add_argument('--vocab_path', default=os.path.join('./bert-base-zh', 'vocab.txt'), type=str,
                        required=False, help='词表文件路径')
    parser.add_argument('--save_samples_path', default="sample/", type=str, required=False,
                        help="保存聊天记录的文件路径")
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
                        help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")

    parser.add_argument('--max_len', type=int, default=25, help='每个utterance的最大长度,超过指定长度则进行截断')
    parser.add_argument('--max_history_len', type=int, default=3, help="dialogue history的最大长度")
    parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测')

    return parser.parse_args(argv)
|
|
|
|
|
def create_logger(args):
    """
    Create the module logger that writes to both a log file and the console.

    Calling this more than once is safe: handlers left over from an earlier
    call are removed first, so each record is emitted once per destination.

    Args:
        args: Object exposing ``log_path``, the path of the log file.

    Returns:
        The ``logging.Logger`` named after this module, set to INFO, carrying
        exactly one file handler and one console handler.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # getLogger returns the same object for the same name, so handlers added by
    # a previous call are still attached; drop them to avoid duplicate output.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Explicit UTF-8 so non-ASCII log messages are written the same way
    # regardless of the platform's default encoding.
    file_handler = logging.FileHandler(filename=args.log_path, encoding='utf-8')
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)

    # Console handler; the logger's own INFO level still filters records
    # before they reach it.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)

    return logger
|
|
|
class Word_BERT(nn.Module):
    """
    BERT encoder with one token-level head and three sentence-level heads.

    ``out`` is applied to every position of the encoder's last hidden state;
    ``cancer``, ``transfer`` and ``ly_transfer`` are applied to the encoder's
    pooled output. Each head is ``Sequential(Dropout, Linear)``.
    """

    def __init__(self, seq_label=1, cancer_label=8, transfer_label=2, ly_transfer=2,
                 bert_path='./bert-base-zh', dropout=0.1):
        """
        Args:
            seq_label: Output width of the token-level head.
            cancer_label: Output width of the ``cancer`` head.
            transfer_label: Output width of the ``transfer`` head.
            ly_transfer: Output width of the ``ly_transfer`` head.
            bert_path: Directory or model name handed to
                ``BertModel.from_pretrained``.
            dropout: Dropout probability used in front of every head.
        """
        super(Word_BERT, self).__init__()
        self.bert = BertModel.from_pretrained(bert_path)

        # Take the head input width from the loaded encoder instead of
        # assuming 768, so encoders of another size also work.
        hidden_size = self.bert.config.hidden_size

        # Heads are created in a fixed order (out, cancer, transfer,
        # ly_transfer) so parameter initialisation consumes the RNG in a
        # stable sequence.
        self.out = self._make_head(hidden_size, seq_label, dropout)
        self.cancer = self._make_head(hidden_size, cancer_label, dropout)
        self.transfer = self._make_head(hidden_size, transfer_label, dropout)
        self.ly_transfer = self._make_head(hidden_size, ly_transfer, dropout)

    @staticmethod
    def _make_head(hidden_size, num_labels, dropout):
        """Return a Dropout -> Linear head mapping ``hidden_size`` to ``num_labels``."""
        return nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels)
        )

    def forward(self, word_input, masks):
        """
        Run the encoder and all four heads.

        Args:
            word_input: Token id tensor passed to BERT as its input ids
                (presumably shaped (batch, seq_len) -- confirm with callers).
            masks: Tensor passed to BERT as ``attention_mask``.

        Returns:
            Tuple ``(out, cancer, transfer, ly_transfer)``: the token-level
            head applied to ``last_hidden_state`` followed by the three heads
            applied to ``pooler_output``.
        """
        output = self.bert(word_input, attention_mask=masks)
        sequence_output = output.last_hidden_state
        pool = output.pooler_output

        out = self.out(sequence_output)
        cancer = self.cancer(pool)
        transfer = self.transfer(pool)
        ly_transfer = self.ly_transfer(pool)
        return out, cancer, transfer, ly_transfer
|
|
|
# Class names of the three sentence-level heads, indexed by predicted class id.
_CANCER_LABELS = ('腺癌', '肺良性疾病', '鳞癌', '无法判断组织分型', '复合型', '转移癌', '小细胞癌', '大细胞癌')
_TRANSFER_LABELS = ('无', '转移')
_LYMPH_TRANSFER_LABELS = ('无', '淋巴转移')


def _extract_spans(flags):
    """Return half-open ``(start, end)`` index pairs, one per run of 1s in ``flags``."""
    spans = []
    run_start = None  # explicit sentinel so a run may legitimately start at index 0
    for position, flag in enumerate(flags):
        if flag == 1:
            if run_start is None:
                run_start = position
        elif run_start is not None:
            spans.append((run_start, position))
            run_start = None
    # A run that reaches the final position has no trailing 0 to close it.
    if run_start is not None:
        spans.append((run_start, len(flags)))
    return spans


def getChat(text: str, userid: int, threshold: float = 0.4):
    """
    Run the model on one piece of text and decode its four predictions.

    The text is split into single characters and wrapped in ``[CLS]`` /
    ``[SEP]``. Span indices in the result therefore refer to that token
    sequence: character ``i`` of ``text`` sits at token index ``i + 1``.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        text: Input text to analyse.
        userid: Accepted for interface compatibility; not used here.
        threshold: A token is flagged when its sigmoid score is strictly
            greater than this value.

    Returns:
        Tuple ``(size_list, cancer, transfer, ly_transfer)`` where
        ``size_list`` is a list of half-open ``(start, end)`` token index
        pairs covering each run of flagged tokens, and the other three are
        the predicted class names of the corresponding heads.
    """
    # NOTE(review): no length limit is applied to `text` here -- confirm the
    # encoder's maximum sequence length is enforced by callers.
    tokens = ['[CLS]'] + list(text) + ['[SEP]']
    token_ids = tokenizer.convert_tokens_to_ids(tokens)

    # Single-example batch: add the leading batch dimension.
    input_ids = torch.tensor(token_ids).long().to(device).unsqueeze(0)
    # ones_like inherits dtype and device from input_ids; every token is real.
    mask_input = torch.ones_like(input_ids)

    with torch.no_grad():
        out, cancer, transfer, ly_transfer = model(input_ids, mask_input)
        # Drop the size-1 label dimension and keep the only batch row.
        token_scores = torch.sigmoid(out).squeeze(2)[0].cpu().tolist()
        cancer_id = cancer.argmax(dim=-1)[0].item()
        transfer_id = transfer.argmax(dim=-1)[0].item()
        ly_transfer_id = ly_transfer.argmax(dim=-1)[0].item()

    flags = [1 if score > threshold else 0 for score in token_scores]
    size_list = _extract_spans(flags)
    print(size_list)

    cancer = _CANCER_LABELS[cancer_id]
    transfer = _TRANSFER_LABELS[transfer_id]
    ly_transfer = _LYMPH_TRANSFER_LABELS[ly_transfer_id]
    print(cancer,transfer,ly_transfer)

    return size_list, cancer, transfer, ly_transfer
|
|
|
|
|
import requests |
|
|
|
|
|
def testFunc():
    """
    Manual smoke test: POST one sample request to the local /getChat endpoint
    and print the raw response body.
    """
    endpoint = "http://localhost:7788/getChat"
    sample_request = {
        "userid": 602099768,
        "context": "我就试一下"
    }
    # NOTE(review): no timeout is passed, so this call waits for as long as the
    # server takes to answer.
    reply = requests.request(
        "POST",
        endpoint,
        headers={'Content-Type': 'application/json'},
        data=json.dumps(sample_request),
    )
    print(reply.text)
|
|
|
|
|
import uvicorn |
|
from pydantic import BaseModel |
|
from fastapi import FastAPI |
|
|
|
# FastAPI application object; the /getChat route is registered on it further
# down in this file.
app = FastAPI()

# Parse the command line and set up logging at import time.
args = set_args()
logger = create_logger(args)

# Select the visible GPUs *before* the first CUDA query. The CUDA runtime reads
# CUDA_VISIBLE_DEVICES when it is first initialised, so setting it after
# torch.cuda.is_available() has run would leave --device without effect.
os.environ["CUDA_VISIBLE_DEVICES"] = args.device

args.cuda = torch.cuda.is_available() and not args.no_cuda
device = 'cuda' if args.cuda else 'cpu'
logger.info('using device:{}'.format(device))

tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")

model = Word_BERT()
# NOTE(review): args.model_path is parsed but no checkpoint is loaded into
# `model` in this section, so the classification heads keep their initial
# weights -- confirm whether a load step is missing.
model = model.to(device)
# Inference only: switch dropout layers to evaluation behaviour.
model.eval()

# Module-level dict; nothing in the visible code reads or writes it.
Allhistory = {}
print('初始化完成')
|
|
|
if __name__ == '__main__':
    # Script entry point: hand uvicorn the import string of the app object.
    # NOTE(review): the string assumes this file is importable as `main_2` --
    # confirm it matches the actual file name.
    uvicorn.run(
        'main_2:app',
        host="localhost",
        port=7788,
        reload=False,
    )
|
|
|
|
|
|
|
# Request body model for the /getChat endpoint. Documented with comments
# rather than a docstring so the generated schema is left untouched.
class Items1(BaseModel):
    # Text to analyse; forwarded to getChat as its first argument.
    context: str

    # Caller-supplied identifier; forwarded to getChat as its second argument.
    userid: int
|
|
|
|
|
|
|
import time |
|
|
|
# Module-level value kept so the name stays available to anything that refers
# to it; the request handler below does not read or write it.
lastReplyTime = 0


@app.post("/getChat")
async def get_Chat(item1: Items1):
    """
    Handle POST /getChat.

    Args:
        item1: Parsed request body carrying ``context`` and ``userid``.

    Returns:
        ``{"res": result}`` where ``result`` is whatever
        ``getChat(item1.context, item1.userid)`` returns.
    """
    # NOTE(review): getChat is an ordinary synchronous function, so the event
    # loop is occupied for the whole call -- consider offloading it if
    # concurrent requests matter.
    result = getChat(item1.context, item1.userid)
    return {"res": result}