baixintech_zhangyiming_prod commited on
Commit
141d444
1 Parent(s): 3d7e2fe
Files changed (2) hide show
  1. app.py +30 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
import gradio.components as grc
import torch
from PIL import Image

from lavis.models import load_model_and_preprocess
from lavis.processors import load_processor

# Select GPU when available; fall back to CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

# Load the BLIP-2 image-text-matching model and its preprocessors once at
# startup so every request reuses the same weights.
# NOTE(review): the original also opened "merlion.png" and set a sample
# caption at import time, but both values were never read (predict()'s
# parameters shadow them) and the Image.open crashed startup when the file
# was missing — removed as dead code.
model, vis_processors, text_processors = load_model_and_preprocess(
    "blip2_image_text_matching", "pretrain", device=device, is_eval=True
)
def predict(raw_image, caption):
    """Score how well `caption` describes `raw_image` with BLIP-2.

    Args:
        raw_image: PIL image from the Gradio Image input.
        caption: free-text caption from the Gradio Textbox.

    Returns:
        Tuple of two strings: the image-text-matching (ITM) probability
        formatted to 3 decimals, and the image-text-contrastive (ITC)
        similarity formatted to 4 decimals.
    """
    raw_image = raw_image.convert("RGB")
    img = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
    txt = text_processors["eval"](caption)
    # Inference only: disable autograd to avoid building a graph and
    # holding activations for two forward passes per request.
    with torch.no_grad():
        itm_output = model({"image": img, "text_input": txt}, match_head="itm")
        itm_scores = torch.nn.functional.softmax(itm_output, dim=1)
        # Column 1 is the "match" class probability.
        itm_score = itm_scores[:, 1].item()
        # .item() extracts a Python float; %-formatting a tensor directly
        # is deprecated and fragile across torch versions.
        itc_score = model({"image": img, "text_input": txt}, match_head='itc').item()
    return '%.3f' % itm_score, '%.4f' % itc_score
# Wire the scorer into a simple Gradio UI: one image upload, one caption
# box, two read-only text outputs for the ITM and ITC scores.
app = gr.Interface(
    fn=predict,
    inputs=[grc.Image(type="pil"), grc.Textbox()],
    outputs=[grc.Text(label="itm score"), grc.Text(label="itc score")],
)
app.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
salesforce-lavis
pillow
torch