qianmuuq committed on
Commit
47d6c70
1 Parent(s): 419df27

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +167 -0
main.py CHANGED
@@ -1,11 +1,178 @@
1
  from fastapi import FastAPI
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import FileResponse
 
 
 
 
 
 
 
 
 
 
 
 
4
 
 
 
 
5
  from transformers import pipeline
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  app = FastAPI()
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
10
 
11
  @app.get("/infer_t5")
 
1
  from fastapi import FastAPI
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import FileResponse
4
+ import torch
5
+ import os
6
+ import json
7
+ import random
8
+ import numpy as np
9
+ from torch import nn
10
+ import argparse
11
+ import logging
12
+ from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config
13
+ from transformers import BertTokenizerFast
14
+ import torch.nn.functional as F
15
+ from transformers import AutoTokenizer, AutoConfig, get_linear_schedule_with_warmup, AdamW, BertModel
16
 
17
+ import requests
18
+ import uvicorn
19
+ from pydantic import BaseModel
20
  from transformers import pipeline
21
 
22
def set_args():
    """Build and parse the command-line arguments for inference.

    Returns:
        argparse.Namespace carrying device selection, file paths and
        generation settings (max_len, max_history_len, repetition_penalty,
        no_cuda).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0', type=str, required=False, help='生成设备')
    parser.add_argument('--log_path', default='interact.log', type=str, required=False, help='interact日志存放位置')
    parser.add_argument('--model_path', default='./pathology_extra/result/12/model.pth', type=str, required=False, help='对话模型路径')
    # BUG FIX: the help text here was copy-pasted from --model_path
    # ("对话模型路径" / dialogue-model path); this flag actually points at the
    # BERT vocabulary file, so describe it as such.
    parser.add_argument('--vocab_path', default='/app/bert-base-zh/vocab.txt', type=str, required=False,
                        help='词表路径')
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
                        help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")
    parser.add_argument('--max_len', type=int, default=25, help='每个utterance的最大长度,超过指定长度则进行截断')
    parser.add_argument('--max_history_len', type=int, default=3, help="dialogue history的最大长度")
    parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测')
    return parser.parse_args()
41
+
42
+
43
def create_logger(args):
    """Create a logger that writes to both a log file and the console.

    Args:
        args: parsed arguments; only ``args.log_path`` is read.

    Returns:
        logging.Logger at INFO level with a file handler (INFO) and a
        console handler (DEBUG).
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # BUG FIX: ``logging.getLogger(__name__)`` returns the same object on every
    # call, so the original attached a fresh pair of handlers each time this
    # function ran, emitting every record multiple times. Return early if the
    # logger is already configured.
    if logger.handlers:
        return logger

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')

    # Handler that appends records to the log file.
    file_handler = logging.FileHandler(
        filename=args.log_path)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)

    # Handler that mirrors records to the console.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)

    return logger
67
+
68
class Word_BERT(nn.Module):
    """BERT encoder with four prediction heads.

    One per-token head (``out``) scores each position of the sequence, and
    three pooled-output heads classify the whole input: cancer type,
    metastasis, and lymph metastasis.
    """

    def __init__(self, seq_label=1, cancer_label=8, transfer_label=2, ly_transfer=2):
        super(Word_BERT, self).__init__()
        # Checkpoint path is baked in for the deployment image.
        self.bert = BertModel.from_pretrained('/app/bert-base-zh')

        def _head(n_out):
            # Every head is the same shape: dropout for regularisation,
            # then a linear projection from the 768-dim encoder features.
            return nn.Sequential(
                nn.Dropout(0.1),
                nn.Linear(768, n_out)
            )

        self.out = _head(seq_label)            # per-token span scores
        self.cancer = _head(cancer_label)      # cancer-type classifier
        self.transfer = _head(transfer_label)  # metastasis classifier
        self.ly_transfer = _head(ly_transfer)  # lymph-metastasis classifier

    def forward(self, word_input, masks):
        """Encode ``word_input`` and apply all four heads.

        Args:
            word_input: token-id tensor fed to BERT.
            masks: attention mask passed through to BERT.

        Returns:
            Tuple ``(out, cancer, transfer, ly_transfer)`` — the token-level
            head applied to the last hidden states, and the three
            classification heads applied to the pooled output.
        """
        encoded = self.bert(word_input, attention_mask=masks)
        token_states = encoded.last_hidden_state
        pooled = encoded.pooler_output
        return (
            self.out(token_states),
            self.cancer(pooled),
            self.transfer(pooled),
            self.ly_transfer(pooled),
        )
104
+
105
def getChat(text):
    """Run the span/classification model on one input string.

    Tokenises ``text`` character-by-character between [CLS] and [SEP], feeds
    it through the module-level ``model`` (Word_BERT), and decodes:
    per-character span predictions thresholded at 0.4, plus three
    human-readable labels (cancer type, metastasis, lymph metastasis).

    Relies on module-level globals ``tokenizer``, ``model`` and ``device``.

    Returns:
        (size_list, cancer, transfer, ly_transfer): ``size_list`` is a list of
        (start, end) half-open index pairs into the [CLS]-prefixed token
        sequence; the other three are label strings.
    """
    # Leftover scaffolding from the interactive dialogue script this was
    # adapted from:
    # while True:
    # if True:
    # text = input("user:")
    # text = "你好"
    # if args.save_samples_path:
    # samples_file.write("user:{}\n".format(text))

    # Character-level tokenisation: each character becomes its own token.
    text = ['[CLS]']+[i for i in text]+['[SEP]']
    # print(text)
    text_ids = tokenizer.convert_tokens_to_ids(text)
    # print(text_ids)

    input_ids = torch.tensor(text_ids).long().to(device)
    input_ids = input_ids.unsqueeze(0)  # add the batch dimension
    mask_input = torch.ones_like(input_ids).long().to(device)
    # print(input_ids.size())
    response = []  # NOTE(review): unused — leftover from the dialogue script
    # (original comment said: "response generated from context, at most max_len tokens")
    with torch.no_grad():
        out, cancer, transfer, ly_transfer = model(input_ids, mask_input)
        # Token-level logits -> probabilities; squeeze the size-1 label dim.
        out = F.sigmoid(out).squeeze(2).cpu()
        out = out.numpy().tolist()
        # The three classification heads are decoded by argmax.
        cancer = cancer.argmax(dim=-1).cpu().numpy().tolist()
        transfer = transfer.argmax(dim=-1).cpu().numpy().tolist()
        ly_transfer = ly_transfer.argmax(dim=-1).cpu().numpy().tolist()
        # print(out)
        # print(cancer,transfer,ly_transfer)

    # Binarise the span probabilities at a fixed 0.4 threshold.
    pred_thresold = [[1 if jj > 0.4 else 0 for jj in ii] for ii in out]
    size_list = []
    start,end = 0,0
    # Collect maximal runs of 1s as (start, end) half-open spans. The
    # ``start == end`` test doubles as a "not currently inside a run" flag.
    # NOTE(review): a run starting at index 0 would confuse that flag
    # (start stays equal to end), and a run reaching the final token is
    # never appended — presumably the [CLS]/[SEP] positions are never
    # predicted as 1, which would make both cases unreachable; confirm.
    for i,j in enumerate(pred_thresold[0]):
        if j==1 and start==end:
            start = i
        elif j!=1 and start!=end:
            end = i
            size_list.append((start,end))
            start = end
    print(size_list)

    # Map each head's argmax index back to its label string.
    cancer_dict = {'腺癌': 0, '肺良性疾病': 1, '鳞癌': 2, '无法判断组织分型': 3, '复合型': 4, '转移癌': 5, '小细胞癌': 6, '大细胞癌': 7}
    id_cancer = {j:i for i,j in cancer_dict.items()}
    transfer_id = {'无': 0, '转移': 1}
    id_transfer = {j:i for i,j in transfer_id.items()}
    lymph_transfer_id = {'无': 0, '淋巴转移': 1}
    id_lymph_transfer = {j: i for i, j in lymph_transfer_id.items()}
    # print(cancer)
    cancer = id_cancer[cancer[0]]
    transfer = id_transfer[transfer[0]]
    ly_transfer = id_lymph_transfer[ly_transfer[0]]
    print(cancer,transfer,ly_transfer)

    return size_list,cancer,transfer,ly_transfer
158
+
159
# ---- application setup (module-level side effects) ----
app = FastAPI()

args = set_args()
logger = create_logger(args)
# BUG FIX: CUDA_VISIBLE_DEVICES must be exported BEFORE the first CUDA query.
# The original set it after torch.cuda.is_available(), by which point the
# CUDA context has already been initialised, so the setting had no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
# Use the GPU only when one is available and the user did not opt out.
args.cuda = torch.cuda.is_available() and not args.no_cuda
device = 'cuda' if args.cuda else 'cpu'
logger.info('using device:{}'.format(device))
# Tokenizer built directly from the configured vocab file.
tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
# tokenizer = BertTokenizer(vocab_file=args.voca_path)
model = Word_BERT()
# NOTE(review): trained weights are never loaded — the line below is
# commented out, so the model serves with randomly initialised heads.
# If re-enabling, note that load_state_dict returns a result object,
# not the model:
#     model.load_state_dict(torch.load(args.model_path, map_location=device))
# model = model.load_state_dict(torch.load(args.model_path))
model = model.to(device)
model.eval()
print('初始化完成')

pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
177
 
178
  @app.get("/infer_t5")