Delete app_2.py
Browse files
app_2.py
DELETED
@@ -1,244 +0,0 @@
|
|
1 |
-
import transformers
|
2 |
-
import torch
|
3 |
-
import os
|
4 |
-
import json
|
5 |
-
import random
|
6 |
-
import numpy as np
|
7 |
-
from torch import nn
|
8 |
-
import argparse
|
9 |
-
from torch.utils.tensorboard import SummaryWriter
|
10 |
-
from datetime import datetime
|
11 |
-
from tqdm import tqdm
|
12 |
-
from torch.nn import DataParallel
|
13 |
-
import logging
|
14 |
-
from transformers import BertTokenizerFast
|
15 |
-
# from transformers import BertTokenizer
|
16 |
-
from os.path import join, exists
|
17 |
-
from itertools import zip_longest, chain
|
18 |
-
# from chatbot.model import DialogueGPT2Model
|
19 |
-
# from dataset import MyDataset
|
20 |
-
from torch.utils.data import Dataset, DataLoader
|
21 |
-
from torch.nn import CrossEntropyLoss
|
22 |
-
from sklearn.model_selection import train_test_split
|
23 |
-
import torch.nn.functional as F
|
24 |
-
from transformers import AutoTokenizer, AutoConfig, get_linear_schedule_with_warmup, AdamW, BertModel
|
25 |
-
|
26 |
-
# BERT padding token and the vocabulary id it maps to.
# NOTE(review): neither name is referenced in the visible code — presumably
# leftovers from the training script this file was derived from; confirm
# before removing.
PAD = '[PAD]'
pad_id = 0
29 |
-
|
30 |
-
def set_args():
    """Build and parse the command-line arguments for the inference service.

    Returns:
        argparse.Namespace: parsed arguments (device, log_path, model_path,
        vocab_path, save_samples_path, repetition_penalty, max_len,
        max_history_len, no_cuda).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0', type=str, required=False, help='生成设备')
    # parser.add_argument('--model_config', default='config/model_config_dialogue_small.json', type=str, required=False,
    #                     help='模型参数')
    parser.add_argument('--log_path', default='interact.log', type=str, required=False, help='interact日志存放位置')
    parser.add_argument('--model_path', default='./pathology_extra/result/12/model.pth', type=str, required=False, help='对话模型路径')
    # Fix: the original default used a Windows-only backslash
    # ('./bert-base-zh\\vocab.txt'); a forward slash works on Windows too
    # and makes the default portable to POSIX systems.
    parser.add_argument('--vocab_path', default='./bert-base-zh/vocab.txt', type=str, required=False,
                        help='对话模型路径')
    parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="保存聊天记录的文件路径")
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
                        help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")
    # parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的')
    parser.add_argument('--max_len', type=int, default=25, help='每个utterance的最大长度,超过指定长度则进行截断')
    parser.add_argument('--max_history_len', type=int, default=3, help="dialogue history的最大长度")
    parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测')
    return parser.parse_args()
|
50 |
-
|
51 |
-
|
52 |
-
def create_logger(args):
    """Route log records both to the file at ``args.log_path`` and to the console.

    Args:
        args: namespace with a ``log_path`` attribute naming the log file.

    Returns:
        logging.Logger: the module logger with both handlers attached.
    """
    log = logging.getLogger(__name__)
    log.setLevel(logging.INFO)

    fmt = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')

    # Handler that writes INFO-and-above records to the log file.
    to_file = logging.FileHandler(
        filename=args.log_path)
    to_file.setLevel(logging.INFO)
    to_file.setFormatter(fmt)

    # Handler that echoes records to the console.
    to_console = logging.StreamHandler()
    to_console.setFormatter(fmt)
    to_console.setLevel(logging.DEBUG)

    for handler in (to_file, to_console):
        log.addHandler(handler)

    return log
|
76 |
-
|
77 |
-
class Word_BERT(nn.Module):
    """BERT encoder with four linear prediction heads.

    One per-token head (``out``, width ``seq_label``) is applied to the
    sequence output; three classification heads (``cancer``, ``transfer``,
    ``ly_transfer``) are applied to the pooled [CLS] output.
    """

    def __init__(self, seq_label=1, cancer_label=8, transfer_label=2, ly_transfer=2):
        super(Word_BERT, self).__init__()
        # NOTE(review): requires a local ./bert-base-zh checkpoint directory.
        self.bert = BertModel.from_pretrained('./bert-base-zh')

        def make_head(width):
            # Every head shares the same shape: dropout then a 768 -> width linear.
            return nn.Sequential(
                nn.Dropout(0.1),
                nn.Linear(768, width)
            )

        self.out = make_head(seq_label)
        self.cancer = make_head(cancer_label)
        self.transfer = make_head(transfer_label)
        self.ly_transfer = make_head(ly_transfer)

    def forward(self, word_input, masks):
        """Encode ``word_input`` and return all four head outputs.

        Returns:
            tuple: (per-token logits, cancer logits, transfer logits,
            lymph-transfer logits).
        """
        encoded = self.bert(word_input, attention_mask=masks)
        token_states = encoded.last_hidden_state
        pooled = encoded.pooler_output
        per_token = self.out(token_states)
        cancer_logits = self.cancer(pooled)
        transfer_logits = self.transfer(pooled)
        ly_transfer_logits = self.ly_transfer(pooled)
        return per_token, cancer_logits, transfer_logits, ly_transfer_logits
|
113 |
-
|
114 |
-
def getChat(text: str, userid: int):
    """Run the pathology extraction model over ``text``.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        text: raw input sentence; tokenised character by character.
        userid: caller identifier — currently unused, kept for API
            compatibility with the /getChat route.

    Returns:
        tuple: (list of (start, end) half-open token spans, cancer-type
        label, transfer label, lymph-transfer label) — the three labels are
        Chinese strings from the lookup tables below.
    """
    # Character-level tokenisation wrapped in BERT special tokens.
    tokens = ['[CLS]'] + [ch for ch in text] + ['[SEP]']
    token_ids = tokenizer.convert_tokens_to_ids(tokens)

    input_ids = torch.tensor(token_ids).long().to(device)
    input_ids = input_ids.unsqueeze(0)  # add batch dimension (batch size 1)
    mask_input = torch.ones_like(input_ids).long().to(device)

    with torch.no_grad():
        out, cancer, transfer, ly_transfer = model(input_ids, mask_input)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        out = torch.sigmoid(out).squeeze(2).cpu().numpy().tolist()
        cancer = cancer.argmax(dim=-1).cpu().numpy().tolist()
        transfer = transfer.argmax(dim=-1).cpu().numpy().tolist()
        ly_transfer = ly_transfer.argmax(dim=-1).cpu().numpy().tolist()

    # Binarise per-token probabilities at 0.4 and collect contiguous runs
    # of 1s as (start, end) half-open spans.
    pred_threshold = [[1 if p > 0.4 else 0 for p in row] for row in out]
    size_list = []
    span_start = None  # None marks "no span currently open"
    for idx, flag in enumerate(pred_threshold[0]):
        if flag == 1 and span_start is None:
            span_start = idx
        elif flag != 1 and span_start is not None:
            size_list.append((span_start, idx))
            span_start = None
    if span_start is not None:
        # Bug fix: the original loop used start == end as the "no open span"
        # marker, which silently dropped a span that ran to the end of the
        # sequence and any span beginning at index 0.
        size_list.append((span_start, len(pred_threshold[0])))
    print(size_list)

    cancer_dict = {'腺癌': 0, '肺良性疾病': 1, '鳞癌': 2, '无法判断组织分型': 3, '复合型': 4, '转移癌': 5, '小细胞癌': 6, '大细胞癌': 7}
    id_cancer = {j: i for i, j in cancer_dict.items()}
    transfer_id = {'无': 0, '转移': 1}
    id_transfer = {j: i for i, j in transfer_id.items()}
    lymph_transfer_id = {'无': 0, '淋巴转移': 1}
    id_lymph_transfer = {j: i for i, j in lymph_transfer_id.items()}

    # Batch size is 1, so index 0 picks the single prediction per head.
    cancer = id_cancer[cancer[0]]
    transfer = id_transfer[transfer[0]]
    ly_transfer = id_lymph_transfer[ly_transfer[0]]
    print(cancer, transfer, ly_transfer)

    return size_list, cancer, transfer, ly_transfer
|
167 |
-
|
168 |
-
|
169 |
-
import requests
|
170 |
-
|
171 |
-
|
172 |
-
def testFunc():
    """Smoke-test the local /getChat endpoint with a fixed request body."""
    endpoint = "http://localhost:7788/getChat"
    body = json.dumps({"userid": 602099768, "context": "我就试一下"})
    resp = requests.post(
        endpoint,
        headers={'Content-Type': 'application/json'},
        data=body,
    )
    print(resp.text)
|
183 |
-
|
184 |
-
|
185 |
-
import uvicorn
from pydantic import BaseModel
from fastapi import FastAPI

app = FastAPI()
# import intel_extension_for_pytorch as ipex

# --- One-time service initialisation (runs on import) ---
args = set_args()
logger = create_logger(args)
# Use the GPU when the user asked for it and one is available.
args.cuda = torch.cuda.is_available() and not args.no_cuda
device = 'cuda' if args.cuda else 'cpu'
logger.info('using device:{}'.format(device))
# NOTE(review): CUDA_VISIBLE_DEVICES is set AFTER torch.cuda.is_available()
# has already initialised CUDA, so it may have no effect here — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
# tokenizer = BertTokenizer(vocab_file=args.voca_path)
model = Word_BERT()
# NOTE(review): the load_state_dict call below is commented out, so the
# model runs with the randomly-initialised/pretrained-BERT weights only —
# args.model_path is never loaded. Confirm this is intentional.
# model = model.load_state_dict(torch.load(args.model_path))
model = model.to(device)
# model = ipex.optimize(model, dtype=torch.float32)
model.eval()
# if args.save_samples_path:
#     if not os.path.exists(args.save_samples_path):
#         os.makedirs(args.save_samples_path)
#     samples_file = open(args.save_samples_path + '/samples.txt', 'a', encoding='utf8')
#     samples_file.write("聊天记录{}:\n".format(datetime.now()))
# Chat history store, keyed per user; each utterance kept as token ids.
# history = []
# NOTE(review): Allhistory is never read or written in the visible code.
Allhistory = {}
print('初始化完成')
|
215 |
-
|
216 |
-
if __name__ == '__main__':
    # getChat("测试一下", 0)
    # main()
    # NOTE(review): the app string targets module 'main_2', but this file is
    # app_2.py — uvicorn will import main_2.py, not this file. Confirm that
    # main_2.py exists, or change the string to '<this_module>:app'.
    # NOTE(review): this guard sits before Items1 and the /getChat route are
    # defined; that only works because uvicorn re-imports the target module
    # by name (which defines everything) rather than using the objects
    # created so far.
    uvicorn.run(app='main_2:app', host="localhost",
                port=7788, reload=False)
    # testFunc()
|
222 |
-
|
223 |
-
|
224 |
-
class Items1(BaseModel):
    """Request body schema for the POST /getChat endpoint."""
    # context: the raw text to analyse.
    context: str
    # userid: caller identifier, forwarded to getChat (currently unused there).
    userid: int
    # must: bool
|
228 |
-
|
229 |
-
|
230 |
-
import time

# Millisecond timestamp of the last reply; only used by the commented-out
# rate-limiting logic inside get_Chat below.
lastReplyTime = 0


@app.post("/getChat")
async def get_Chat(item1: Items1):
    """HTTP endpoint: run getChat on the posted context and return its result.

    Returns a JSON object {"res": (spans, cancer, transfer, ly_transfer)}.
    """
    global lastReplyTime
    # Current time in ms; with the rate-limit checks below commented out,
    # this value is currently unused.
    tempReplyTime = int(time.time() * 1000)
    # if tempReplyTime % 10 == 0 or item1.must == True or tempReplyTime - lastReplyTime < 30000:
    # if item1.must == True:
    # lastReplyTime = tempReplyTime
    result = getChat(
        item1.context, item1.userid)
    return {"res": result}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|