qianmuuq commited on
Commit
0bb4e62
1 Parent(s): 4339350

Delete app_2.py

Browse files
Files changed (1) hide show
  1. app_2.py +0 -244
app_2.py DELETED
@@ -1,244 +0,0 @@
1
- import transformers
2
- import torch
3
- import os
4
- import json
5
- import random
6
- import numpy as np
7
- from torch import nn
8
- import argparse
9
- from torch.utils.tensorboard import SummaryWriter
10
- from datetime import datetime
11
- from tqdm import tqdm
12
- from torch.nn import DataParallel
13
- import logging
14
- from transformers import BertTokenizerFast
15
- # from transformers import BertTokenizer
16
- from os.path import join, exists
17
- from itertools import zip_longest, chain
18
- # from chatbot.model import DialogueGPT2Model
19
- # from dataset import MyDataset
20
- from torch.utils.data import Dataset, DataLoader
21
- from torch.nn import CrossEntropyLoss
22
- from sklearn.model_selection import train_test_split
23
- import torch.nn.functional as F
24
- from transformers import AutoTokenizer, AutoConfig, get_linear_schedule_with_warmup, AdamW, BertModel
25
-
26
# Padding token and its vocabulary id. NOTE(review): neither constant is
# referenced elsewhere in this file; presumably kept for callers or older
# code paths — confirm before removing.
PAD = '[PAD]'
pad_id = 0
28
-
29
-
30
def set_args():
    """
    Build and parse the command-line arguments for the service.

    Returns:
        argparse.Namespace with device/model/tokenizer paths and
        generation-related knobs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0', type=str, required=False, help='生成设备')
    parser.add_argument('--log_path', default='interact.log', type=str, required=False, help='interact日志存放位置')
    parser.add_argument('--model_path', default='./pathology_extra/result/12/model.pth', type=str, required=False, help='对话模型路径')
    # Fix: the original default mixed a POSIX prefix with a Windows
    # backslash ('./bert-base-zh\\vocab.txt'), which fails on Linux/macOS.
    # Use a forward slash, consistent with the './bert-base-zh' model dir
    # loaded elsewhere in this file.
    parser.add_argument('--vocab_path', default='./bert-base-zh/vocab.txt', type=str, required=False,
                        help='对话模型路径')
    parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="保存聊天记录的文件路径")
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
                        help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")
    parser.add_argument('--max_len', type=int, default=25, help='每个utterance的最大长度,超过指定长度则进行截断')
    parser.add_argument('--max_history_len', type=int, default=3, help="dialogue history的最大长度")
    parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测')
    return parser.parse_args()
50
-
51
-
52
def create_logger(args):
    """
    Create a logger that writes to both the log file (args.log_path) and
    the console.

    Fix: the original attached a new file handler and console handler on
    every call; because logging.getLogger(__name__) returns the same
    logger object each time, repeated calls would duplicate every log
    record. Return the logger untouched when it is already configured.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    if logger.handlers:
        # Already configured by a previous call — avoid duplicate handlers.
        return logger

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')

    # Handler that writes log records to the log file.
    file_handler = logging.FileHandler(
        filename=args.log_path)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)

    # Handler that echoes log records to the console.
    # NOTE: its DEBUG level is effectively capped at INFO by the logger
    # level set above.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)

    return logger
76
-
77
class Word_BERT(nn.Module):
    """BERT backbone with one token-level head and three pooled heads.

    Heads:
      - out:         per-token logits, ``seq_label`` outputs (default 1)
      - cancer:      sequence-level cancer-type logits (default 8 classes)
      - transfer:    sequence-level metastasis logits (default 2 classes)
      - ly_transfer: sequence-level lymph-metastasis logits (default 2 classes)
    """

    def __init__(self, seq_label=1, cancer_label=8, transfer_label=2, ly_transfer=2):
        super(Word_BERT, self).__init__()
        self.bert = BertModel.from_pretrained('./bert-base-zh')
        # All four heads share the same shape: Dropout(0.1) followed by a
        # linear projection from BERT's 768-dim hidden size.
        self.out = self._make_head(seq_label)
        self.cancer = self._make_head(cancer_label)
        self.transfer = self._make_head(transfer_label)
        self.ly_transfer = self._make_head(ly_transfer)

    @staticmethod
    def _make_head(n_out):
        """Return a Dropout(0.1) + Linear(768 -> n_out) head."""
        return nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(768, n_out)
        )

    def forward(self, word_input, masks):
        """Encode ``word_input`` with BERT and apply all four heads.

        Returns:
            (token_logits, cancer_logits, transfer_logits, ly_transfer_logits)
        """
        encoded = self.bert(word_input, attention_mask=masks)
        token_states = encoded.last_hidden_state  # per-token hidden states
        pooled = encoded.pooler_output            # pooled [CLS] representation
        token_logits = self.out(token_states)
        cancer_logits = self.cancer(pooled)
        transfer_logits = self.transfer(pooled)
        ly_transfer_logits = self.ly_transfer(pooled)
        return token_logits, cancer_logits, transfer_logits, ly_transfer_logits
113
-
114
# Inverse label maps for the three sequence-level heads. Built once at
# module import instead of on every request.
_ID_TO_CANCER = {v: k for k, v in {'腺癌': 0, '肺良性疾病': 1, '鳞癌': 2, '无法判断组织分型': 3,
                                   '复合型': 4, '转移癌': 5, '小细胞癌': 6, '大细胞癌': 7}.items()}
_ID_TO_TRANSFER = {v: k for k, v in {'无': 0, '转移': 1}.items()}
_ID_TO_LY_TRANSFER = {v: k for k, v in {'无': 0, '淋巴转移': 1}.items()}


def _extract_spans(token_probs, threshold=0.4):
    """Turn per-token probabilities into [start, end) index spans.

    A span opens at the first token whose probability exceeds *threshold*
    and closes at the next token at or below it.

    Fixes over the original inline loop:
      * a span still open at the last token is now closed at
        len(token_probs) instead of being silently dropped;
      * an explicit None sentinel replaces the ``start == end`` trick,
        which mis-handled a span starting at index 0.
    """
    spans = []
    start = None  # index where the open span began; None when no span is open
    for i, p in enumerate(token_probs):
        if p > threshold:
            if start is None:
                start = i
        elif start is not None:
            spans.append((start, i))
            start = None
    if start is not None:
        spans.append((start, len(token_probs)))
    return spans


def getChat(text: str, userid: int):
    """Run the extraction model on *text* and decode its four outputs.

    Args:
        text: raw input string; wrapped in [CLS]/[SEP] and tokenized
            character by character.
        userid: caller identifier; currently unused, kept for endpoint
            compatibility.

    Returns:
        (spans, cancer_label, transfer_label, ly_transfer_label) where
        spans is a list of (start, end) token-index pairs.
    """
    tokens = ['[CLS]'] + list(text) + ['[SEP]']
    token_ids = tokenizer.convert_tokens_to_ids(tokens)

    input_ids = torch.tensor(token_ids).long().to(device).unsqueeze(0)
    mask_input = torch.ones_like(input_ids).long().to(device)

    with torch.no_grad():
        out, cancer, transfer, ly_transfer = model(input_ids, mask_input)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        probs = torch.sigmoid(out).squeeze(2).cpu().numpy().tolist()
        cancer = cancer.argmax(dim=-1).cpu().numpy().tolist()
        transfer = transfer.argmax(dim=-1).cpu().numpy().tolist()
        ly_transfer = ly_transfer.argmax(dim=-1).cpu().numpy().tolist()

    size_list = _extract_spans(probs[0])
    print(size_list)

    cancer = _ID_TO_CANCER[cancer[0]]
    transfer = _ID_TO_TRANSFER[transfer[0]]
    ly_transfer = _ID_TO_LY_TRANSFER[ly_transfer[0]]
    print(cancer, transfer, ly_transfer)

    return size_list, cancer, transfer, ly_transfer
167
-
168
-
169
- import requests
170
-
171
-
172
def testFunc():
    """Smoke-test the local /getChat endpoint with a fixed JSON payload."""
    url = "http://localhost:7788/getChat"
    body = {
        "userid": 602099768,
        "context": "我就试一下"
    }
    response = requests.request(
        "POST",
        url,
        headers={'Content-Type': 'application/json'},
        data=json.dumps(body),
    )
    print(response.text)
183
-
184
-
185
import uvicorn
from pydantic import BaseModel
from fastapi import FastAPI

app = FastAPI()
# import intel_extension_for_pytorch as ipex

# ---- Module-level initialisation: runs once on import. ----
args = set_args()
logger = create_logger(args)
# Use the GPU only when the user allowed it and one is available.
args.cuda = torch.cuda.is_available() and not args.no_cuda
device = 'cuda' if args.cuda else 'cpu'
logger.info('using device:{}'.format(device))
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
# tokenizer = BertTokenizer(vocab_file=args.voca_path)
model = Word_BERT()
# NOTE(review): the checkpoint load below is commented out, so the heads
# run with randomly initialised weights (only the pretrained BERT body is
# loaded in Word_BERT.__init__) — confirm this is intentional. Also note
# load_state_dict returns a key-report object, not the model, so the
# commented assignment form would be wrong if re-enabled.
# model = model.load_state_dict(torch.load(args.model_path))
model = model.to(device)
# model = ipex.optimize(model, dtype=torch.float32)
model.eval()
# if args.save_samples_path:
#     if not os.path.exists(args.save_samples_path):
#         os.makedirs(args.save_samples_path)
#     samples_file = open(args.save_samples_path + '/samples.txt', 'a', encoding='utf8')
#     samples_file.write("聊天记录{}:\n".format(datetime.now()))
# Chat history store, one entry per utterance as token ids (currently unused).
# history = []
Allhistory = {}
print('初始化完成')
215
-
216
if __name__ == '__main__':
    # getChat("测试一下", 0)
    # main()
    # NOTE(review): uvicorn is pointed at module "main_2", but this file
    # was named app_2.py — confirm the import string matches the actual
    # module name, otherwise the server will fail to import the app.
    uvicorn.run(app='main_2:app', host="localhost",
                port=7788, reload=False)
    # testFunc()
222
-
223
-
224
class Items1(BaseModel):
    """Request body schema for the POST /getChat endpoint."""

    # Text to run the extraction model on.
    context: str
    # Caller identifier; passed to getChat(), which does not use it.
    userid: int
    # must: bool
228
-
229
-
230
import time

# Millisecond timestamp of the last reply. Only written inside the
# commented-out rate-limit logic in get_Chat, so it currently stays 0.
lastReplyTime = 0
233
-
234
-
235
- @app.post("/getChat")
236
- async def get_Chat(item1: Items1):
237
- global lastReplyTime
238
- tempReplyTime = int(time.time() * 1000)
239
- # if tempReplyTime % 10 == 0 or item1.must == True or tempReplyTime - lastReplyTime < 30000:
240
- # if item1.must == True:
241
- # lastReplyTime = tempReplyTime
242
- result = getChat(
243
- item1.context, item1.userid)
244
- return {"res": result}