oscmansan committed
Commit 91755d5
1 Parent(s): c8b5d27

Create app.py

Files changed (1)
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
+ import torch
+ from PIL import Image
+ import gradio as gr
+ 
+ 
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ 
+ model = torch.hub.load('mair-lab/mapl-private', 'mapl')
+ model.eval()
+ model.to(device)
+ 
+ 
+ def predict(image: Image.Image, question: str) -> str:
+     pixel_values = model.image_transform(image).unsqueeze(0).to(device)
+ 
+     input_ids = None
+     if question:
+         text = f"Please answer the question. Question: {question} Answer:" if '?' in question else question
+         input_ids = model.text_transform(text).input_ids.to(device)
+ 
+     with torch.autocast(device_type=device, dtype=torch.float16):
+         generated_ids = model.generate(
+             pixel_values=pixel_values,
+             input_ids=input_ids,
+             max_new_tokens=50,
+             num_beams=5
+         )
+ 
+     answer = model.text_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+ 
+     return answer
+ 
+ 
+ image = gr.components.Image(type='pil')
+ question = gr.components.Textbox(value="What is this?", label="Question")
+ answer = gr.components.Textbox(label="Answer")
+ 
+ interface = gr.Interface(
+     fn=predict,
+     inputs=[image, question],
+     outputs=answer,
+     allow_flagging='never')
+ interface.launch()
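
For a quick local check outside the web UI, the predict function above can also be called directly (before interface.launch(), which blocks the process). A minimal sketch, assuming the model loaded successfully and a local test image named example.jpg (hypothetical filename):

    from PIL import Image

    # hypothetical test image; any RGB image works
    img = Image.open("example.jpg").convert("RGB")

    # a prompt containing '?' is wrapped in the QA template inside predict;
    # any other non-empty string is passed to the model as-is
    print(predict(img, "What is this?"))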