File size: 4,897 Bytes
09b15be
 
 
 
 
f160eaf
09b15be
 
 
 
 
0391a1c
f160eaf
 
09b15be
 
0391a1c
 
09b15be
3fbe084
f160eaf
 
 
 
 
 
 
 
 
 
09b15be
 
f160eaf
 
09b15be
 
 
0391a1c
09b15be
 
 
0391a1c
 
 
 
 
 
09b15be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cc174c
 
 
09b15be
 
 
 
 
 
f160eaf
09b15be
 
 
 
 
 
5f27df1
 
09b15be
 
 
 
 
 
 
 
 
 
0391a1c
 
 
09b15be
 
 
 
0391a1c
 
09b15be
 
 
 
 
 
0391a1c
09b15be
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
import torch
from transformers import FuyuForCausalLM, AutoTokenizer
from transformers.models.fuyu.processing_fuyu import FuyuProcessor
from transformers.models.fuyu.image_processing_fuyu import FuyuImageProcessor
from PIL import Image

# Hugging Face Hub checkpoint used for the tokenizer and the model below.
model_id = "adept/fuyu-8b"
# Weights are loaded in this dtype; predict() also casts floating-point inputs to it.
dtype = torch.bfloat16
# Device that predict() moves its input tensors to.
# NOTE(review): the model itself is placed by device_map="auto" (L247); this assumes it
# ends up on the default CUDA device — confirm on multi-GPU or CPU-only hosts.
device = "cuda"


# Loaded once at import time and shared by every request handled by this process.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = FuyuForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=dtype)
# Combines a default-configured Fuyu image processor with the tokenizer loaded above.
processor = FuyuProcessor(image_processor=FuyuImageProcessor(), tokenizer=tokenizer)

# Prompt used by caption() when detailed captioning is off (short caption).
CAPTION_PROMPT = "Generate a coco-style caption.\n"
# Prompt used by caption() when detailed captioning is on (longer description).
DETAILED_CAPTION_PROMPT = "What is happening in this image?"

def resize_to_max(image, max_width=1080, max_height=1080, resample=None):
    """Downscale ``image`` so it fits inside ``max_width`` x ``max_height``.

    The aspect ratio is preserved. Images that already fit are returned
    unchanged (the very same object); larger ones are resized.

    Args:
        image: PIL image (anything with a ``size`` tuple and a ``resize`` method).
        max_width: Maximum width in pixels of the result.
        max_height: Maximum height in pixels of the result.
        resample: Resampling filter passed to ``image.resize``; ``None`` means
            ``Image.LANCZOS`` (the previous hard-coded choice).

    Returns:
        The original image, or a resized copy that fits the bounds.
    """
    width, height = image.size
    if width <= max_width and height <= max_height:
        return image

    scale = min(max_width / width, max_height / height)
    # round() rather than int(): truncation can lose a pixel to float error
    # (e.g. 1081 * (1080 / 1081) -> 1079). max(1, ...): a very elongated image
    # would otherwise get a 0-pixel side, which PIL refuses to resize to.
    new_size = (max(1, round(width * scale)), max(1, round(height * scale)))

    if resample is None:
        resample = Image.LANCZOS
    return image.resize(new_size, resample)

def predict(image, prompt):
    """Ask Fuyu-8B ``prompt`` about ``image`` and return only the newly generated text.

    Args:
        image: PIL image from a ``gr.Image(type="pil")`` component, or ``None``
            when the user pressed the button without uploading an image.
        prompt: Text sent to the model together with the image.

    Returns:
        The decoded continuation (at most 50 new tokens), with the prompt
        tokens removed and special tokens skipped.

    Raises:
        gr.Error: If ``image`` is ``None``. Gradio shows this message in the UI
            instead of the AttributeError that ``resize_to_max(None)`` would raise.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    image = resize_to_max(image)

    model_inputs = processor(text=prompt, images=[image])
    # Cast floating-point tensors to the model dtype, keep every other tensor's
    # dtype unchanged, and move all of them to `device`.
    model_inputs = {
        k: v.to(dtype=dtype if torch.is_floating_point(v) else v.dtype, device=device)
        for k, v in model_inputs.items()
    }

    generation_output = model.generate(**model_inputs, max_new_tokens=50)
    # The generated sequence starts with the prompt tokens; drop them before decoding.
    prompt_len = model_inputs["input_ids"].shape[-1]
    return tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True)

def caption(image, detailed_captioning):
    """Caption ``image`` with Fuyu, optionally using the detailed prompt.

    Args:
        image: Image forwarded unchanged to ``predict``.
        detailed_captioning: Truthy selects ``DETAILED_CAPTION_PROMPT``;
            falsy selects the short ``CAPTION_PROMPT``.

    Returns:
        The model's text with leading whitespace stripped.
    """
    chosen_prompt = DETAILED_CAPTION_PROMPT if detailed_captioning else CAPTION_PROMPT
    generated = predict(image, chosen_prompt)
    return generated.lstrip()

def set_example_image(example: list) -> dict:
    """Return a Gradio update that shows the first entry of an example row.

    Args:
        example: One example row; element 0 is the image value (e.g. a file path).

    Returns:
        A Gradio update dict that sets the target component's ``value``.

    Uses the generic ``gr.update`` rather than ``gr.Image.update``: the
    component-specific ``.update`` methods were removed in Gradio 4, while
    ``gr.update`` is available in both 3.x and 4.x.
    """
    # NOTE(review): not wired to any event in the visible code; kept for external callers.
    return gr.update(value=example[0])



# Extra stylesheet passed to gr.Blocks below: the element with id "mkd" gets a
# fixed 500px height, scrolls when its content overflows, and has a light grey border.
# NOTE(review): no component in the visible code sets elem_id="mkd" — confirm this
# rule is still needed.
css = """
  #mkd {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

with gr.Blocks(css=css) as demo:
    # Intro banner. Emoji are written as \U escapes so this source stays ASCII:
    # the literal characters in this file had been corrupted into mojibake by a
    # non-UTF-8 encoding round-trip and would render as garbage in the UI.
    gr.HTML(
        """
            <h1 id="title">Fuyu Multimodal Demo</h1>
            <h3><a href="https://hf.co/adept/fuyu-8b">Fuyu-8B</a> is a multimodal model that supports a variety of tasks combining text and image prompts.</h3>
            For example, you can use it for captioning by asking it to describe an image. You can also ask it questions about an image, a task known as Visual Question Answering, or VQA. This demo lets you explore captioning and VQA, with more tasks coming soon :)
            Learn more about the model in <a href="https://www.adept.ai/blog/fuyu-8b">our blog post</a>.
            <br>
          	<br>
            <strong>Note: This is a raw model release. We have not added further instruction-tuning, postprocessing or sampling strategies to control for undesirable outputs. The model may hallucinate, and you should expect to have to fine-tune the model for your use-case!</strong>
            <h3>Play with Fuyu-8B in this demo! \U0001F4AC</h3>
        """
    )

    # Visual Question Answering tab: a free-form question about an uploaded image,
    # answered by predict().
    with gr.Tab("Visual Question Answering"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Upload your Image", type="pil")
                text_input = gr.Textbox(label="Ask a Question")
            vqa_output = gr.Textbox(label="Output")

        vqa_btn = gr.Button("Answer Visual Question")

        # cache_examples=True makes Gradio run predict() on these examples ahead of
        # time, so clicking one shows a stored result.
        gr.Examples(
            [
                ["assets/vqa_example_1.png", "How is this made?"],
                ["assets/vqa_example_2.png", "What is this flower and where is its origin?"],
                ["assets/docvqa_example.png", "How many items are sold?"],
                ["assets/screen2words_ui_example.png", "What is this app about?"],
            ],
            inputs=[image_input, text_input],
            outputs=[vqa_output],
            fn=predict,
            cache_examples=True,
            label='Click on any Examples below to get VQA results quickly \U0001F447',
        )

    # Image Captioning tab: the checkbox chooses between the short and the
    # detailed caption prompt (see caption()).
    with gr.Tab("Image Captioning"):
        with gr.Row():
            with gr.Column():
                captioning_input = gr.Image(label="Upload your Image", type="pil")
                detailed_captioning_checkbox = gr.Checkbox(label="Enable detailed captioning")
            captioning_output = gr.Textbox(label="Output")
        captioning_btn = gr.Button("Generate Caption")

        gr.Examples(
            [["assets/captioning_example_1.png", False], ["assets/captioning_example_2.png", True]],
            inputs=[captioning_input, detailed_captioning_checkbox],
            outputs=[captioning_output],
            fn=caption,
            cache_examples=True,
            label='Click on any Examples below to get captioning results quickly \U0001F447',
        )

    # Buttons run the same functions as the cached examples, on the user's own inputs.
    captioning_btn.click(fn=caption, inputs=[captioning_input, detailed_captioning_checkbox], outputs=captioning_output)
    vqa_btn.click(fn=predict, inputs=[image_input, text_input], outputs=vqa_output)

    
# Start the Gradio server. server_name="0.0.0.0" binds to all network interfaces,
# so the app is reachable from outside the host/container, not only via localhost.
demo.launch(server_name="0.0.0.0")