svjack commited on
Commit
6e20f5b
1 Parent(s): 4d15793

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -17,13 +17,17 @@ en_model_path = "question_generator_by_en_on_pic"
17
  #zh_model_path = "question_generator_by_zh_on_pic"
18
 
19
  task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
20
- en_pretrained_model = DonutModel.from_pretrained(en_model_path)
21
  #zh_pretrained_model = DonutModel.from_pretrained(zh_model_path)
 
22
 
23
  if torch.cuda.is_available():
24
  en_pretrained_model.half()
25
  device = torch.device("cuda")
26
  en_pretrained_model.to(device)
 
 
 
27
 
28
  '''
29
  if torch.cuda.is_available():
 
17
  #zh_model_path = "question_generator_by_zh_on_pic"
18
 
19
  task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
20
+ #en_pretrained_model = DonutModel.from_pretrained(en_model_path)
21
  #zh_pretrained_model = DonutModel.from_pretrained(zh_model_path)
22
+ en_pretrained_model = DonutModel.from_pretrained(en_model_path, ignore_mismatched_sizes=True)
23
 
24
  if torch.cuda.is_available():
25
  en_pretrained_model.half()
26
  device = torch.device("cuda")
27
  en_pretrained_model.to(device)
28
+ else:
29
+ import torch
30
+ en_pretrained_model.encoder.to(torch.bfloat16)
31
 
32
  '''
33
  if torch.cuda.is_available():