import gradio as gr
import torch
from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    BlipForConditionalGeneration,
    Blip2ForConditionalGeneration,
    InstructBlipForConditionalGeneration,
)
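# GIT and BLIP are moved to this device explicitly; BLIP-2 and InstructBLIP
# are placed automatically via device_map="auto" when loaded below.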
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')
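
# Load each captioning model together with its matching processor.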
git_processor_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
git_model_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco").to(device)
blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)
blip2_processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-6.7b-coco")
blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-6.7b-coco", device_map="auto", load_in_4bit=True, torch_dtype=torch.float16)
instructblip_processor = AutoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
instructblip_model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto", load_in_4bit=True, torch_dtype=torch.float16)
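# Note: load_in_4bit=True requires the bitsandbytes and accelerate packages;
# device_map="auto" lets accelerate shard the model across available devices.
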
def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
    inputs = processor(images=image, return_tensors="pt").to(device)

    if use_float_16:
        inputs = inputs.to(torch.float16)

    generated_ids = model.generate(pixel_values=inputs.pixel_values, num_beams=3, max_length=20, min_length=5)

    if tokenizer is not None:
        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return generated_caption

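# BLIP-2 and InstructBLIP additionally condition generation on a text prompt.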
def generate_caption_blip2(processor, model, image, prompt, replace_token=False):
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device=model.device, dtype=torch.float16)

    generated_ids = model.generate(
        **inputs,
        num_beams=5,
        max_length=50,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.5,
        length_penalty=1.0,
        temperature=1,
    )

    if replace_token:
        # TODO remove once https://github.com/huggingface/transformers/pull/24492 is merged:
        # map token id 0 to id 2 so decoding doesn't emit spurious text for these ids.
        generated_ids[generated_ids == 0] = 2

    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

def generate_captions(image):
caption_git_large_coco = generate_caption(git_processor_large_coco, git_model_large_coco, image)
caption_blip_large = generate_caption(blip_processor_large, blip_model_large, image)
caption_blip2 = generate_caption_blip2(blip2_processor, blip2_model, image, prompt="A photo of").strip()
caption_instructblip = generate_caption_blip2(instructblip_processor, instructblip_model, image, prompt="Generate a caption for the image:", replace_token=True)
return caption_git_large_coco, caption_blip_large, caption_blip2, caption_instructblip
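
# Gradio interface definition.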
examples = [["cats.jpg"], ["stop_sign.png"], ["astronaut.jpg"]]
outputs = [
    gr.Textbox(label="Caption generated by GIT-large fine-tuned on COCO"),
    gr.Textbox(label="Caption generated by BLIP-large"),
    gr.Textbox(label="Caption generated by BLIP-2 OPT 6.7b"),
    gr.Textbox(label="Caption generated by InstructBLIP"),
]
title = "Interactive demo: comparing image captioning models"
description = "Gradio Demo to compare GIT, BLIP, BLIP-2 and InstructBLIP, 4 state-of-the-art vision+language models. To use it, simply upload your image and click 'submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a></p>"
interface = gr.Interface(
    fn=generate_captions,
    inputs=gr.Image(type="pil"),
    outputs=outputs,
    examples=examples,
    title=title,
    description=description,
    article=article,
)
interface.queue().launch(debug=True)