oscmansan committed on
Commit e04d426
1 Parent(s): 0db66af

Update app.py

Files changed (1)
  app.py +7 -8
app.py CHANGED
@@ -7,23 +7,22 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 model = torch.hub.load('mair-lab/mapl', 'mapl')
 model.eval()
-model.to(device=device, dtype=torch.bfloat16)
+model.to(device, torch.bfloat16)
 
 
 def predict(image: Image.Image, question: str) -> str:
-    pixel_values = model.image_transform(image).unsqueeze(0).to(device)
+    pixel_values = model.image_transform(image).unsqueeze(0).to(device, torch.bfloat16)
 
     input_ids = None
     if question:
         text = f"Please answer the question. Question: {question} Answer:" if '?' in question else question
         input_ids = model.text_transform(text).input_ids.to(device)
 
-    with torch.autocast(device_type=device, dtype=torch.bfloat16):
-        generated_ids = model.generate(
-            pixel_values=pixel_values,
-            input_ids=input_ids,
-            max_new_tokens=50
-        )
+    generated_ids = model.generate(
+        pixel_values=pixel_values,
+        input_ids=input_ids,
+        max_new_tokens=50
+    )
 
     answer = model.text_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
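The change drops the torch.autocast context and instead moves both the model weights and the image tensor to bfloat16 up front, so model.generate runs directly in bfloat16. Below is a minimal usage sketch of the updated predict(); the `from app import predict` import and the image path are illustrative assumptions, not part of the commit.

# Minimal usage sketch (not part of the commit): assumes app.py from this
# repository is importable and that a local example image exists.
from PIL import Image

from app import predict  # predict() as defined in the updated app.py

image = Image.open('example.jpg').convert('RGB')  # placeholder image path
answer = predict(image, 'What is shown in the image?')
print(answer)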