xinyu1205 committed on
Commit
8962d34
1 Parent(s): 0ad877d

Update app.py

Files changed (1)
  1. app.py +70 -9
app.py CHANGED
@@ -1,21 +1,82 @@
from PIL import Image
import requests
+ import torch
+ from torchvision import transforms
+ from torchvision.transforms.functional import InterpolationMode
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+
+
+
import gradio as gr

+ # from models.blip import blip_decoder
from transformers import BlipProcessor, BlipForConditionalGeneration

model_id = "Salesforce/blip-image-captioning-base"

model = BlipForConditionalGeneration.from_pretrained(model_id)
- processor = BlipProcessor.from_pretrained(model_id)

- ##

- def launch(input):
-     image = Image.open(requests.get(input, stream=True).raw).convert('RGB')
-     inputs = processor(image, return_tensors="pt")
-     out = model.generate(**inputs)
-     return processor.decode(out[0], skip_special_tokens=True)
+ image_size = 384
+ transform = transforms.Compose([
+     transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
+     transforms.ToTensor(),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+     ])
+
+ # model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
+
+ # model = blip_decoder(pretrained=model_url, image_size=384, vit='large')
+ model.eval()
+ model = model.to(device)
+
+
+ # from models.blip_vqa import blip_vqa
+
+ # image_size_vq = 480
+ # transform_vq = transforms.Compose([
+ #     transforms.Resize((image_size_vq,image_size_vq),interpolation=InterpolationMode.BICUBIC),
+ #     transforms.ToTensor(),
+ #     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ #     ])
+
+ # model_url_vq = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'
+
+ # model_vq = blip_vqa(pretrained=model_url_vq, image_size=480, vit='base')
+ # model_vq.eval()
+ # model_vq = model_vq.to(device)
+
+
+
+ def inference(raw_image, model_n, question, strategy):
+     if model_n == 'Image Captioning':
+         image = transform(raw_image).unsqueeze(0).to(device)
+         with torch.no_grad():
+             if strategy == "Beam search":
+                 caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
+             else:
+                 caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
+             return 'caption: '+caption[0]
+
+     else:
+         image_vq = transform_vq(raw_image).unsqueeze(0).to(device)
+         with torch.no_grad():
+             answer = model_vq(image_vq, question, train=False, inference='generate')
+             return 'answer: '+answer[0]
+
+ # inputs = [gr.inputs.Image(type='pil'),gr.inputs.Radio(choices=['Image Captioning',"Visual Question Answering"], type="value", default="Image Captioning", label="Task"),gr.inputs.Textbox(lines=2, label="Question"),gr.inputs.Radio(choices=['Beam search','Nucleus sampling'], type="value", default="Nucleus sampling", label="Caption Decoding Strategy")]
+ inputs = [gr.inputs.Image(type='pil'),gr.inputs.Radio(choices=['Image Captioning'], type="value", default="Image Captioning", label="Task"),gr.inputs.Textbox(lines=2, label="Question"),gr.inputs.Radio(choices=['Beam search','Nucleus sampling'], type="value", default="Nucleus sampling", label="Caption Decoding Strategy")]
+
+ outputs = gr.outputs.Textbox(label="Output")
+
+ title = "BLIP"
+
+ description = "Gradio demo for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation (Salesforce Research). To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
+
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.12086' target='_blank'>BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation</a> | <a href='https://github.com/salesforce/BLIP' target='_blank'>Github Repo</a></p>"
+

- iface = gr.Interface(launch, inputs="text", outputs="text")
- iface.launch()
+ gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['starrynight.jpeg',"Image Captioning","None","Nucleus sampling"]]).launch(enable_queue=True)
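The new inference function feeds a raw pixel tensor to model.generate using the original BLIP repository's sample/num_beams keyword style, while the model actually loaded is transformers' BlipForConditionalGeneration (the blip_decoder path is commented out). For comparison, here is a minimal captioning sketch in the style of the removed code, using BlipProcessor from transformers with the same checkpoint; the caption() helper is hypothetical and the Gradio wiring is omitted.

from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_id = "Salesforce/blip-image-captioning-base"

processor = BlipProcessor.from_pretrained(model_id)  # handles resizing and normalization
model = BlipForConditionalGeneration.from_pretrained(model_id).to(device).eval()

def caption(raw_image, strategy="Nucleus sampling"):
    # hypothetical helper: the processor replaces the manual torchvision transform above
    inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to(device)
    with torch.no_grad():
        if strategy == "Beam search":
            out = model.generate(**inputs, num_beams=3, max_length=20, min_length=5)
        else:  # nucleus sampling
            out = model.generate(**inputs, do_sample=True, top_p=0.9, max_length=20, min_length=5)
    return processor.decode(out[0], skip_special_tokens=True)

# Example usage with the image already listed in examples:
# print(caption(Image.open('starrynight.jpeg'), "Beam search"))

In the transformers API, do_sample=True with top_p corresponds to the demo's "Nucleus sampling" option and num_beams to "Beam search".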