Update main.py
Browse files
main.py
CHANGED
@@ -19,6 +19,7 @@ import uvicorn
|
|
19 |
from pydantic import BaseModel
|
20 |
from transformers import pipeline
|
21 |
|
|
|
22 |
def set_args():
|
23 |
"""
|
24 |
Sets up the arguments.
|
@@ -161,21 +162,18 @@ app = FastAPI()
|
|
161 |
pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
|
162 |
|
163 |
def model_init():
|
164 |
-
args = set_args()
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
# logger.info('using device:{}'.format(device))
|
170 |
-
# os.environ["CUDA_VISIBLE_DEVICES"] = args.device
|
171 |
-
# tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
|
172 |
# # tokenizer = BertTokenizer(vocab_file=args.voca_path)
|
173 |
-
|
174 |
# # model = model.load_state_dict(torch.load(args.model_path))
|
175 |
# model = model.to(device)
|
176 |
-
|
177 |
-
|
178 |
-
return None
|
179 |
|
180 |
model11 = model_init()
|
181 |
print(os.getcwd())
|
|
|
19 |
from pydantic import BaseModel
|
20 |
from transformers import pipeline
|
21 |
|
22 |
+
extra_args = {}
|
23 |
def set_args():
|
24 |
"""
|
25 |
Sets up the arguments.
|
|
|
162 |
pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
|
163 |
|
164 |
def model_init():
|
165 |
+
# args = set_args()
|
166 |
+
acuda = torch.cuda.is_available() and not args.no_cuda
|
167 |
+
device = 'cuda' if acuda else 'cpu'
|
168 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
|
169 |
+
tokenizer = BertTokenizerFast(vocab_file='/app/bert-base-zh/vocab.txt', sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
|
|
|
|
|
|
|
170 |
# # tokenizer = BertTokenizer(vocab_file=args.voca_path)
|
171 |
+
model = Word_BERT()
|
172 |
# # model = model.load_state_dict(torch.load(args.model_path))
|
173 |
# model = model.to(device)
|
174 |
+
model.eval()
|
175 |
+
return model
|
176 |
+
# return None
|
177 |
|
178 |
model11 = model_init()
|
179 |
print(os.getcwd())
|