#!/usr/bin/env python

from __future__ import annotations

import os
from typing import Union

import gradio as gr
import PIL.Image
import spaces
import torch
from transformers import AutoProcessor, BlipForConditionalGeneration

DESCRIPTION = "# Image Captioning with LongCap"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)

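# Load the BLIP long-caption processor and model once at startup and move the model to the selected device.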
model_id = "unography/blip-long-cap"
processor = AutoProcessor.from_pretrained(model_id)
model = BlipForConditionalGeneration.from_pretrained(model_id).to(device)

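# Download sample images to local files for the Examples gallery below.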
torch.hub.download_url_to_file("https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg", "demo.jpg")
torch.hub.download_url_to_file(
    "https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png", "stop_sign.png"
)
torch.hub.download_url_to_file(
    "https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg", "astronaut.jpg"
)

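# On Hugging Face Spaces, spaces.GPU() requests a GPU only for the duration of each call (ZeroGPU).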
@spaces.GPU()
def run(image: Union[str, PIL.Image.Image]) -> str:
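    # Examples may arrive as file paths, while the image widget provides PIL images.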
    if isinstance(image, str):
        image = PIL.Image.open(image)
    inputs = processor(images=image, return_tensors="pt").to(device)
    out = model.generate(pixel_values=inputs.pixel_values, num_beams=3, repetition_penalty=2.5, max_length=300)
    generated_caption = processor.decode(out[0], skip_special_tokens=True)
    return generated_caption


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    input_image = gr.Image(type="pil")
    run_button = gr.Button("Caption")
    output = gr.Textbox(label="Result")
    gr.Examples(
        examples=[
            "demo.jpg",
            "stop_sign.png",
            "astronaut.jpg",
        ],
        inputs=input_image,
        outputs=output,
        fn=run,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )

    run_button.click(
        fn=run,
        inputs=input_image,
        outputs=output,
        api_name="caption",
    )
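
    # Because api_name="caption" is set, the endpoint can also be called programmatically.
    # A sketch (assuming the gradio_client package and a deployed Space, here "<user>/<space>"):
    #   from gradio_client import Client, handle_file
    #   Client("<user>/<space>").predict(handle_file("demo.jpg"), api_name="/caption")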

if __name__ == "__main__":
    demo.queue(max_size=20).launch()