qianmuuq commited on
Commit
fe257c0
1 Parent(s): 8fc82a6

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +18 -0
main.py CHANGED
@@ -160,6 +160,24 @@ app = FastAPI()
160
 
161
  pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  @app.get("/infer_t5")
164
  def t5(input):
165
  output = pipe_flan(input)
 
160
 
161
  pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
162
 
163
+ def model_init():
164
+ args = set_args()
165
+ logger = create_logger(args)
166
+ # 当用户使用GPU,并且GPU可用时
167
+ args.cuda = torch.cuda.is_available() and not args.no_cuda
168
+ device = 'cuda' if args.cuda else 'cpu'
169
+ logger.info('using device:{}'.format(device))
170
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.device
171
+ tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
172
+ # tokenizer = BertTokenizer(vocab_file=args.voca_path)
173
+ model = Word_BERT()
174
+ # model = model.load_state_dict(torch.load(args.model_path))
175
+ model = model.to(device)
176
+ model.eval()
177
+ return model
178
+
179
+ model = model_init()
180
+
181
  @app.get("/infer_t5")
182
  def t5(input):
183
  output = pipe_flan(input)