File size: 4,721 Bytes
2773523
76c8f3a
2773523
c57b6d0
2773523
 
3a227f4
800cbf3
2773523
05c2134
 
2773523
05c2134
 
 
 
 
 
 
 
2773523
c62a436
 
76c8f3a
c62a436
c57b6d0
 
 
 
 
a808d84
 
05c2134
 
 
 
51d259a
c57b6d0
a808d84
76c8f3a
38ab1e3
2773523
a808d84
76c8f3a
 
d6b2a16
76c8f3a
 
2773523
 
 
 
c57b6d0
 
 
 
 
 
2773523
05c2134
 
 
 
 
2773523
05c2134
2773523
76c8f3a
c62a436
c57b6d0
 
 
2773523
 
800cbf3
c57b6d0
2773523
329d18e
6334863
 
2773523
329d18e
2773523
05c2134
2773523
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, VisionEncoderDecoderModel
import torch
import open_clip

# Sample images offered under "Examples" in the UI: (source URL, local file name).
_EXAMPLE_IMAGE_SOURCES = (
    ('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg'),
    ('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png'),
    ('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg'),
)
for _src_url, _local_name in _EXAMPLE_IMAGE_SOURCES:
    torch.hub.download_url_to_file(_src_url, _local_name)

# Hugging Face Hub repository IDs for each transformers-based captioner.
_GIT_BASE_REPO = "microsoft/git-base-coco"
_GIT_LARGE_REPO = "microsoft/git-large-coco"
_BLIP_BASE_REPO = "Salesforce/blip-image-captioning-base"
_BLIP_LARGE_REPO = "Salesforce/blip-image-captioning-large"
_VITGPT_REPO = "nlpconnect/vit-gpt2-image-captioning"

# GIT (base and large): generic processor + causal-LM head.
git_processor_base = AutoProcessor.from_pretrained(_GIT_BASE_REPO)
git_model_base = AutoModelForCausalLM.from_pretrained(_GIT_BASE_REPO)
git_processor_large = AutoProcessor.from_pretrained(_GIT_LARGE_REPO)
git_model_large = AutoModelForCausalLM.from_pretrained(_GIT_LARGE_REPO)

# BLIP (base and large): generic processor + conditional-generation head.
blip_processor_base = AutoProcessor.from_pretrained(_BLIP_BASE_REPO)
blip_model_base = BlipForConditionalGeneration.from_pretrained(_BLIP_BASE_REPO)
blip_processor_large = AutoProcessor.from_pretrained(_BLIP_LARGE_REPO)
blip_model_large = BlipForConditionalGeneration.from_pretrained(_BLIP_LARGE_REPO)

# ViT+GPT-2: the image processor only prepares pixels, so a separate
# tokenizer is loaded to decode the generated token ids.
vitgpt_processor = AutoImageProcessor.from_pretrained(_VITGPT_REPO)
vitgpt_model = VisionEncoderDecoderModel.from_pretrained(_VITGPT_REPO)
vitgpt_tokenizer = AutoTokenizer.from_pretrained(_VITGPT_REPO)

# CoCa via open_clip; the training-time transform is discarded and only the
# evaluation transform is kept.
# NOTE(review): `pretrained` is a .pt file name rather than a registered
# open_clip tag; presumably a checkpoint shipped next to this script — confirm.
coca_model, _, coca_transform = open_clip.create_model_and_transforms(
    "coca_ViT-L-14",
    pretrained="laion2B-s13B-b90k-mscoco-2014.pt",
)

# Prefer the GPU when one is visible to torch; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Move every captioning model onto the selected device.
git_model_base.to(device)
blip_model_base.to(device)
git_model_large.to(device)
blip_model_large.to(device)
vitgpt_model.to(device)
coca_model.to(device)

# transformers' from_pretrained already puts its models in eval mode, but
# open_clip.create_model_and_transforms returns the model in training mode.
# Switch CoCa to inference mode so train-only behaviour (dropout, stochastic
# depth, batch-norm statistic updates) is disabled while captioning.
coca_model.eval()

def generate_caption(processor, model, image, tokenizer=None, max_length=50):
    """Generate one caption for *image* with a Hugging Face captioning model.

    Args:
        processor: Callable that turns ``image`` into model inputs. Its output
            must expose ``pixel_values`` and support ``.to(device)``. When no
            ``tokenizer`` is given, it is also used to decode the output.
        model: Model whose ``generate`` accepts ``pixel_values``.
        image: The input image (a PIL image when called from the Gradio UI).
        tokenizer: Optional object used to decode the generated ids. The
            ViT+GPT-2 path passes its tokenizer here because its processor is
            an image-only processor.
        max_length: Maximum generation length passed to ``model.generate``.
            It defaults to 50, the value previously hard-coded.

    Returns:
        The decoded caption (special tokens removed) for the first item of the
        batch; a single image yields a batch of one.
    """
    inputs = processor(images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=max_length)

    # Decode with the explicit tokenizer if one was supplied, else the processor.
    decoder = processor if tokenizer is None else tokenizer
    return decoder.batch_decode(generated_ids, skip_special_tokens=True)[0]


def generate_caption_coca(model, transform, image, seq_len=20):
    """Generate one caption for *image* with an open_clip CoCa model.

    Args:
        model: open_clip CoCa model exposing ``generate``.
        transform: Preprocessing transform returned by
            ``open_clip.create_model_and_transforms``. It must return a tensor.
        image: The input image (a PIL image when called from the Gradio UI).
        seq_len: Generation length passed to ``model.generate``. It defaults
            to 20, the value previously hard-coded.

    Returns:
        The decoded caption with the ``<start_of_text>`` marker removed,
        everything from ``<end_of_text>`` onward dropped, and surrounding
        whitespace stripped.
    """
    # unsqueeze(0) adds a leading batch dimension of size 1.
    im = transform(image).unsqueeze(0).to(device)
    generated = model.generate(im, seq_len=seq_len)
    text = open_clip.decode(generated[0].detach())
    caption = text.split("<end_of_text>")[0].replace("<start_of_text>", "")
    # open_clip's detokenizer emits a space after each word token, which leaves
    # a trailing space before <end_of_text>. Strip it so the result matches the
    # clean strings returned by the Hugging Face decoders.
    return caption.strip()


def generate_captions(image):
    """Caption *image* with every loaded model.

    Returns a 6-tuple of strings in this order: GIT-base, GIT-large,
    BLIP-base, BLIP-large, ViT+GPT-2, CoCa. The order matches the output
    textboxes of the Gradio interface.
    """
    # (processor, model, tokenizer) for each transformers-based captioner;
    # a None tokenizer means "decode with the processor".
    hf_captioners = (
        (git_processor_base, git_model_base, None),
        (git_processor_large, git_model_large, None),
        (blip_processor_base, blip_model_base, None),
        (blip_processor_large, blip_model_large, None),
        (vitgpt_processor, vitgpt_model, vitgpt_tokenizer),
    )
    results = [
        generate_caption(proc, mdl, image, tok)
        for proc, mdl, tok in hf_captioners
    ]
    # CoCa uses open_clip's own preprocessing and decoding path.
    results.append(generate_caption_coca(coca_model, coca_transform, image))
    return tuple(results)

   
examples = [["cats.jpg"], ["stop_sign.png"], ["astronaut.jpg"]]

# One textbox per model, in the same order generate_captions returns captions.
# NOTE(review): gr.inputs / gr.outputs and enable_queue are the legacy Gradio
# API (deprecated in 3.x, removed in 4.x); kept as-is to match the Gradio
# version this script targets.
outputs = [
    gr.outputs.Textbox(label=f"Caption generated by {model_name}")
    for model_name in ("GIT-base", "GIT-large", "BLIP-base", "BLIP-large", "ViT+GPT-2", "CoCa")
]

title = "Interactive demo: comparing image captioning models"
# The demo runs four model families (GIT, BLIP, ViT+GPT2, CoCa). The previous
# text said "3", which was wrong once CoCa was added.
description = "Gradio Demo to compare GIT, BLIP, ViT+GPT2 and CoCa, 4 state-of-the-art vision+language models. To use it, simply upload your image and click 'submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a> | <a href='https://huggingface.co/nlpconnect/vit-gpt2-image-captioning' target='_blank'>ViT+GPT2 model card</a> | <a href='https://github.com/mlfoundations/open_clip' target='_blank'>CoCa (OpenCLIP)</a></p>"

interface = gr.Interface(fn=generate_captions,
                         inputs=gr.inputs.Image(type="pil"),
                         outputs=outputs,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article,
                         enable_queue=True)
interface.launch(debug=True)