ruslanmv committed
Commit
ecbc33a
1 Parent(s): 36c27b0

Update main.py

Files changed (1)
  1. main.py +113 -1
main.py CHANGED
@@ -1,4 +1,116 @@
 
+ import gradio as gr
+ from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, VisionEncoderDecoderModel
+ import torch
+
+ torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
+ torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
+ torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')
+
+ git_processor_base = AutoProcessor.from_pretrained("microsoft/git-base-coco")
+ git_model_base = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
+
+ git_processor_large = AutoProcessor.from_pretrained("microsoft/git-large-coco")
+ git_model_large = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
+
+ blip_processor_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ blip_model_base = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+
+ blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+ blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+
+ vitgpt_processor = AutoImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ vitgpt_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ vitgpt_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ git_model_base.to(device)
+ blip_model_base.to(device)
+ git_model_large.to(device)
+ blip_model_large.to(device)
+ vitgpt_model.to(device)
+
+ def generate_caption(processor, model, image, tokenizer=None):
+     inputs = processor(images=image, return_tensors="pt").to(device)
+
+     generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
+
+     if tokenizer is not None:
+         generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     else:
+         generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+     return generated_caption
+
+ def generate_captions(image):
+     caption_git_base = generate_caption(git_processor_base, git_model_base, image)
+
+     caption_git_large = generate_caption(git_processor_large, git_model_large, image)
+
+     caption_blip_base = generate_caption(blip_processor_base, blip_model_base, image)
+
+     caption_blip_large = generate_caption(blip_processor_large, blip_model_large, image)
+
+     caption_vitgpt = generate_caption(vitgpt_processor, vitgpt_model, image, vitgpt_tokenizer)
+
+     return caption_git_base, caption_git_large, caption_blip_base, caption_blip_large, caption_vitgpt
+
+ examples = [["cats.jpg"], ["stop_sign.png"], ["astronaut.jpg"]]
+ outputs = [gr.outputs.Textbox(label="Caption generated by GIT-base"), gr.outputs.Textbox(label="Caption generated by GIT-large"), gr.outputs.Textbox(label="Caption generated by BLIP-base"), gr.outputs.Textbox(label="Caption generated by BLIP-large"), gr.outputs.Textbox(label="Caption generated by ViT+GPT-2")]
+
+ title = "Interactive demo: comparing image captioning models"
+ description = "Gradio Demo to compare GIT, BLIP and ViT+GPT2, 3 state-of-the-art vision+language models. To use it, simply upload your image and click 'submit', or click one of the examples to load them. Read more at the links below."
+ article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a></p>"
+
+ css = """
+ body {
+     background-color: #f2f2f2;
+     font-family: Arial, sans-serif;
+ }
+
+ .title {
+     color: #333333;
+     font-size: 24px;
+     font-weight: bold;
+     margin-bottom: 20px;
+ }
+
+ .description {
+     color: #666666;
+     font-size: 16px;
+     margin-bottom: 20px;
+ }
+
+ .article {
+     color: #666666;
+     font-size: 14px;
+     margin-bottom: 20px;
+     text-align: center;
+ }
+
+ .input {
+     margin-bottom: 20px;
+ }
+
+ .output {
+     margin-bottom: 20px;
+ }
+ """
+
+ interface = gr.Interface(fn=generate_captions,
+                          inputs=gr.inputs.Image(type="pil"),
+                          outputs=outputs,
+                          examples=examples,
+                          title=title,
+                          description=description,
+                          article=article,
+                          css=css,
+                          enable_queue=True)
+ interface.launch(debug=True)
+
+ '''
+
  import gradio as gr
  import numpy as np
  from PIL import Image

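Note on the hunk above: it inserts the captioning demo as the active code at the top of main.py, and the ''' it adds at the end of the new block, together with the ''' removed in the second hunk below, leaves the previously active app inside a module-level string. For a quick sanity check of the captioning pattern outside the Gradio UI, here is a minimal sketch, not part of the commit, reduced to the single BLIP-base checkpoint; the model name, the max_length=50 setting and the cats.jpg example come from the diff, the rest is illustrative.

# Minimal sketch: one model, one image, no Gradio UI.
# Assumes cats.jpg was already fetched by the torch.hub.download_url_to_file call in main.py.
import torch
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

image = Image.open("cats.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt").to(device)
generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
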
@@ -42,7 +154,7 @@ iface = gr.Interface(
  iface.launch(server_name="0.0.0.0", server_port=7860)


- '''
+
  import gradio as gr
  import subprocess

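The committed interface uses Gradio's legacy gr.inputs / gr.outputs namespaces and the enable_queue argument, which later Gradio releases deprecate and, as of Gradio 4, remove. Below is a hedged sketch, not part of the commit, of the same wiring against the current component API; it assumes Gradio 4.x and reuses the generate_captions, title, description, article and css objects defined in the first hunk of main.py.

# Drop-in variant of the interface block above for current Gradio versions.
import gradio as gr

outputs = [
    gr.Textbox(label="Caption generated by GIT-base"),
    gr.Textbox(label="Caption generated by GIT-large"),
    gr.Textbox(label="Caption generated by BLIP-base"),
    gr.Textbox(label="Caption generated by BLIP-large"),
    gr.Textbox(label="Caption generated by ViT+GPT-2"),
]

interface = gr.Interface(
    fn=generate_captions,          # defined earlier in main.py
    inputs=gr.Image(type="pil"),
    outputs=outputs,
    examples=[["cats.jpg"], ["stop_sign.png"], ["astronaut.jpg"]],
    title=title,
    description=description,
    article=article,
    css=css,
)
interface.queue().launch(debug=True)   # queueing via .queue() replaces enable_queue=True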