qianmuuq commited on
Commit
297e563
1 Parent(s): ea27bda

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +10 -12
main.py CHANGED
@@ -19,6 +19,7 @@ import uvicorn
19
  from pydantic import BaseModel
20
  from transformers import pipeline
21
 
 
22
  def set_args():
23
  """
24
  Sets up the arguments.
@@ -161,21 +162,18 @@ app = FastAPI()
161
  pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
162
 
163
  def model_init():
164
- args = set_args()
165
- # logger = create_logger(args)
166
- # # 当用户使用GPU,并且GPU可用时
167
- # args.cuda = torch.cuda.is_available() and not args.no_cuda
168
- # device = 'cuda' if args.cuda else 'cpu'
169
- # logger.info('using device:{}'.format(device))
170
- # os.environ["CUDA_VISIBLE_DEVICES"] = args.device
171
- # tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
172
  # # tokenizer = BertTokenizer(vocab_file=args.voca_path)
173
- # model = Word_BERT()
174
  # # model = model.load_state_dict(torch.load(args.model_path))
175
  # model = model.to(device)
176
- # model.eval()
177
- # return model
178
- return None
179
 
180
  model11 = model_init()
181
  print(os.getcwd())
 
19
  from pydantic import BaseModel
20
  from transformers import pipeline
21
 
22
+ extra_args = {}
23
  def set_args():
24
  """
25
  Sets up the arguments.
 
162
  pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
163
 
164
  def model_init():
165
+ # args = set_args()
166
+ acuda = torch.cuda.is_available() and not args.no_cuda
167
+ device = 'cuda' if acuda else 'cpu'
168
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.device
169
+ tokenizer = BertTokenizerFast(vocab_file='/app/bert-base-zh/vocab.txt', sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
 
 
 
170
  # # tokenizer = BertTokenizer(vocab_file=args.voca_path)
171
+ model = Word_BERT()
172
  # # model = model.load_state_dict(torch.load(args.model_path))
173
  # model = model.to(device)
174
+ model.eval()
175
+ return model
176
+ # return None
177
 
178
  model11 = model_init()
179
  print(os.getcwd())