qianmuuq commited on
Commit
930237b
1 Parent(s): 4e02653

Upload 5 files

Browse files
bert-base-zh/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "directionality": "bidi",
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.1,
9
+ "hidden_size": 768,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 3072,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 512,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 12,
16
+ "num_hidden_layers": 12,
17
+ "pad_token_id": 0,
18
+ "pooler_fc_size": 768,
19
+ "pooler_num_attention_heads": 12,
20
+ "pooler_num_fc_layers": 3,
21
+ "pooler_size_per_head": 128,
22
+ "pooler_type": "first_token_transform",
23
+ "type_vocab_size": 2,
24
+ "vocab_size": 21128
25
+ }
bert-base-zh/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a693db616eaf647ed2bfe531e1fa446637358fc108a8bf04e8d4db17e837ee9
3
+ size 411577189
bert-base-zh/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "do_lower_case": false
3
+ }
bert-base-zh/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
main_2.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import torch
3
+ import os
4
+ import json
5
+ import random
6
+ import numpy as np
7
+ from torch import nn
8
+ import argparse
9
+ from torch.utils.tensorboard import SummaryWriter
10
+ from datetime import datetime
11
+ from tqdm import tqdm
12
+ from torch.nn import DataParallel
13
+ import logging
14
+ from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config
15
+ from transformers import BertTokenizerFast
16
+ # from transformers import BertTokenizer
17
+ from os.path import join, exists
18
+ from itertools import zip_longest, chain
19
+ # from chatbot.model import DialogueGPT2Model
20
+ # from dataset import MyDataset
21
+ from torch.utils.data import Dataset, DataLoader
22
+ from torch.nn import CrossEntropyLoss
23
+ from sklearn.model_selection import train_test_split
24
+ import torch.nn.functional as F
25
+ from transformers import AutoTokenizer, AutoConfig, get_linear_schedule_with_warmup, AdamW, BertModel
26
+
27
+ PAD = '[PAD]'
28
+ pad_id = 0
29
+
30
+
31
def set_args():
    """
    Parse and return the command-line arguments for the inference service.

    All arguments are optional and carry defaults, so ``set_args()`` works
    with an empty command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0', type=str, required=False, help='生成设备')
    parser.add_argument('--log_path', default='interact.log', type=str, required=False, help='interact日志存放位置')
    parser.add_argument('--model_path', default='./pathology_extra/result/12/model.pth', type=str, required=False, help='对话模型路径')
    # Fix: the help text below previously duplicated the model-path help
    # ('对话模型路径') — it actually points at the BERT vocabulary file.
    parser.add_argument('--vocab_path', default='D:\\transformerFileDownload\\Pytorch\\bert-base-zh\\vocab.txt', type=str, required=False,
                        help='词表路径')
    parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="保存聊天记录的文件路径")
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
                        help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")
    parser.add_argument('--max_len', type=int, default=25, help='每个utterance的最大长度,超过指定长度则进行截断')
    parser.add_argument('--max_history_len', type=int, default=3, help="dialogue history的最大长度")
    parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测')
    return parser.parse_args()
51
+
52
+
53
def create_logger(args):
    """
    Build a logger that writes records both to ``args.log_path`` and the console.

    The original version attached new handlers on every call, so invoking it
    twice duplicated every log line; the guard below makes it idempotent.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Already configured (e.g. module imported twice by a reloader) — reuse.
    if logger.handlers:
        return logger

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')

    # Handler writing to the log file.
    file_handler = logging.FileHandler(
        filename=args.log_path)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)

    # Handler echoing to the console.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)

    return logger
77
+
78
class Word_BERT(nn.Module):
    """
    BERT encoder with four task-specific heads for pathology-report extraction.

    Heads:
      - ``out``        : per-token logits over ``seq_label`` outputs (span tagging)
      - ``cancer``     : pooled-output classifier over ``cancer_label`` classes
      - ``transfer``   : pooled-output classifier over ``transfer_label`` classes
      - ``ly_transfer``: pooled-output classifier over ``ly_transfer`` classes
    """

    def __init__(self, seq_label=1, cancer_label=8, transfer_label=2, ly_transfer=2,
                 bert_path='D:\\transformerFileDownload\\Pytorch\\bert-base-zh'):
        """
        Args:
            seq_label: output width of the per-token head.
            cancer_label: number of cancer-type classes.
            transfer_label: number of metastasis classes.
            ly_transfer: number of lymph-metastasis classes.
            bert_path: checkpoint directory for the pretrained BERT encoder.
                Generalizes the previously hard-coded Windows path (kept as the
                default for backward compatibility).
        """
        super(Word_BERT, self).__init__()
        self.bert = BertModel.from_pretrained(bert_path)
        # 768 matches bert-base hidden_size (see bert-base-zh/config.json).
        self.out = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(768, seq_label)
        )
        self.cancer = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(768, cancer_label)
        )
        self.transfer = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(768, transfer_label)
        )
        self.ly_transfer = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(768, ly_transfer)
        )

    def forward(self, word_input, masks):
        """
        Run the encoder and all four heads.

        Args:
            word_input: token-id tensor (caller passes (1, seq_len) — see getChat).
            masks: attention mask, same shape as ``word_input``.

        Returns:
            (per-token logits, cancer logits, transfer logits, lymph-transfer logits).
        """
        output = self.bert(word_input, attention_mask=masks)
        sequence_output = output.last_hidden_state
        pool = output.pooler_output
        out = self.out(sequence_output)
        cancer = self.cancer(pool)
        transfer = self.transfer(pool)
        ly_transfer = self.ly_transfer(pool)
        return out, cancer, transfer, ly_transfer
114
+
115
def getChat(text: str, userid: int):
    """
    Run the extraction model on one pathology-report string.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        text: raw report text; split into characters (the vocabulary is
              character-level for Chinese) and wrapped in [CLS]/[SEP].
        userid: currently unused; kept for API compatibility with the caller.

    Returns:
        (size_list, cancer, transfer, ly_transfer) where ``size_list`` is a
        list of (start, end) half-open token spans predicted above threshold,
        and the other three are decoded class labels (Chinese strings).
    """
    tokens = ['[CLS]'] + [ch for ch in text] + ['[SEP]']
    text_ids = tokenizer.convert_tokens_to_ids(tokens)

    input_ids = torch.tensor(text_ids).long().to(device)
    input_ids = input_ids.unsqueeze(0)  # add batch dimension -> (1, seq_len)
    mask_input = torch.ones_like(input_ids).long().to(device)

    with torch.no_grad():
        out, cancer, transfer, ly_transfer = model(input_ids, mask_input)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        out = torch.sigmoid(out).squeeze(2).cpu()
        out = out.numpy().tolist()
        cancer = cancer.argmax(dim=-1).cpu().numpy().tolist()
        transfer = transfer.argmax(dim=-1).cpu().numpy().tolist()
        ly_transfer = ly_transfer.argmax(dim=-1).cpu().numpy().tolist()

    # Threshold per-token probabilities at 0.4, then collect runs of 1s as
    # (start, end) spans.  The original start==end bookkeeping dropped a span
    # beginning at index 0 and never closed a span still open at the sequence
    # end; an explicit open-span marker fixes both.
    pred_thresold = [[1 if jj > 0.4 else 0 for jj in ii] for ii in out]
    size_list = []
    start = None  # index where the currently open span began, or None
    for i, j in enumerate(pred_thresold[0]):
        if j == 1 and start is None:
            start = i
        elif j != 1 and start is not None:
            size_list.append((start, i))
            start = None
    if start is not None:
        # Close a span that runs through the last token.
        size_list.append((start, len(pred_thresold[0])))
    print(size_list)

    # Decode class indices back to their Chinese label strings.
    cancer_dict = {'腺癌': 0, '肺良性疾病': 1, '鳞癌': 2, '无法判断组织分型': 3, '复合型': 4, '转移癌': 5, '小细胞癌': 6, '大细胞癌': 7}
    id_cancer = {j: i for i, j in cancer_dict.items()}
    transfer_id = {'无': 0, '转移': 1}
    id_transfer = {j: i for i, j in transfer_id.items()}
    lymph_transfer_id = {'无': 0, '淋巴转移': 1}
    id_lymph_transfer = {j: i for i, j in lymph_transfer_id.items()}
    cancer = id_cancer[cancer[0]]
    transfer = id_transfer[transfer[0]]
    ly_transfer = id_lymph_transfer[ly_transfer[0]]
    print(cancer, transfer, ly_transfer)

    return size_list, cancer, transfer, ly_transfer
168
+
169
+
170
+ import requests
171
+
172
+
173
def testFunc():
    """Smoke-test the local /getChat endpoint with one fixed request."""
    endpoint = "http://localhost:7788/getChat"
    body = json.dumps({
        "userid": 602099768,
        "context": "我就试一下"
    })
    resp = requests.request(
        "POST",
        endpoint,
        headers={'Content-Type': 'application/json'},
        data=body,
    )
    print(resp.text)
184
+
185
+
186
import uvicorn
from pydantic import BaseModel
from fastapi import FastAPI

app = FastAPI()
# import intel_extension_for_pytorch as ipex

# --- one-time service initialization (runs on import) ---
args = set_args()
logger = create_logger(args)
# Use the GPU only when the user allows it and CUDA is actually available.
args.cuda = torch.cuda.is_available() and not args.no_cuda
device = 'cuda' if args.cuda else 'cpu'
logger.info('using device:{}'.format(device))
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
# tokenizer = BertTokenizer(vocab_file=args.voca_path)
model = Word_BERT()
# NOTE(review): the fine-tuned weights at args.model_path are never loaded —
# the load_state_dict call below is commented out, so the service runs on the
# raw pretrained BERT plus randomly initialized heads. Confirm intent.
# model = model.load_state_dict(torch.load(args.model_path))
model = model.to(device)
# model = ipex.optimize(model, dtype=torch.float32)
model.eval()
# if args.save_samples_path:
# if not os.path.exists(args.save_samples_path):
# os.makedirs(args.save_samples_path)
# samples_file = open(args.save_samples_path + '/samples.txt', 'a', encoding='utf8')
# samples_file.write("聊天记录{}:\n".format(datetime.now()))
# Chat history store, each utterance kept as token ids (currently unused).
# history = []
Allhistory = {}
print('初始化完成')
216
+
217
if __name__ == '__main__':
    # getChat("测试一下", 0)
    # main()
    # Serve the app by import string: uvicorn re-imports main_2, so the route
    # handlers defined *below* this block are registered in the module copy
    # that uvicorn actually runs.
    uvicorn.run(app='main_2:app', host="localhost",
                port=7788, reload=False)
    # testFunc()
223
+
224
+
225
class Items1(BaseModel):
    """Request body for POST /getChat: the user's text plus a numeric user id."""
    context: str
    userid: int
    # must: bool
229
+
230
+
231
import time

# Millisecond timestamp of the last reply; only read/written by the
# commented-out rate-limiting logic inside get_Chat.
lastReplyTime = 0
234
+
235
+
236
+ @app.post("/getChat")
237
+ async def get_Chat(item1: Items1):
238
+ global lastReplyTime
239
+ tempReplyTime = int(time.time() * 1000)
240
+ # if tempReplyTime % 10 == 0 or item1.must == True or tempReplyTime - lastReplyTime < 30000:
241
+ # if item1.must == True:
242
+ # lastReplyTime = tempReplyTime
243
+ result = getChat(
244
+ item1.context, item1.userid)
245
+ return {"res": result}