File size: 2,673 Bytes
486f4bf
 
ce0ce67
486f4bf
 
ce0ce67
486f4bf
 
ce0ce67
 
 
 
 
 
 
486f4bf
7378b29
 
 
 
 
486f4bf
 
8d07585
486f4bf
cf8abff
486f4bf
 
 
 
8d07585
b390d84
458c57f
7378b29
 
 
 
 
458c57f
 
1651e3c
486f4bf
8f189a9
 
7378b29
8f189a9
9fd6fc9
cf8abff
b94ad3f
cf8abff
 
9fd6fc9
486f4bf
232d5c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
from transformers import IdeficsForVisionText2Text, AutoProcessor
from peft import PeftModel, PeftConfig
import gradio as gr

# Hub id of the LoRA adapter; its PEFT config records the base IDEFICS checkpoint.
peft_model_id = "mrm8488/idefics-9b-ft-describe-diffusion-bf16-adapter"
# Run on the GPU whenever torch can see one, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

config = PeftConfig.from_pretrained(peft_model_id)
# Tokenizer + image processor are taken from the base checkpoint, not the adapter.
processor = AutoProcessor.from_pretrained(config.base_model_name_or_path)
# Load the base weights in bfloat16, attach the adapter on top, and move the
# combined model to the selected device (nn.Module.to returns the module itself).
model = PeftModel.from_pretrained(
    IdeficsForVisionText2Text.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=torch.bfloat16,
    ),
    peft_model_id,
).to(device)
# Inference only: disable training-mode behaviour such as dropout.
model.eval()

def predict(prompt, image_url, image_pil=None, max_length=64):
    """Generate an IDEFICS description for an image.

    The image is taken from ``image_pil`` when one was uploaded; otherwise it
    is downloaded from ``image_url`` with the processor's ``fetch_images``.

    Args:
        prompt: Text instruction placed after the image in the IDEFICS prompt.
        image_url: URL of the image to describe; only used when ``image_pil``
            is None.
        image_pil: Optional uploaded PIL image; takes precedence over the URL.
        max_length: Maximum number of newly generated tokens. Prompt and image
            tokens are not counted against this budget. The name is kept for
            backward compatibility.

    Returns:
        A ``(image, text)`` tuple: the image that was described and the
        decoded generation.

    Raises:
        gr.Error: if neither an uploaded image nor a non-empty URL is given.
    """
    if image_pil is not None:
        image = image_pil
    else:
        # An empty Textbox arrives as ""; fail with a readable message instead
        # of an obscure download error from fetch_images.
        url = (image_url or "").strip()
        if not url:
            raise gr.Error("Please provide an image URL or upload an image.")
        image = processor.image_processor.fetch_images(url)
    # IDEFICS takes one interleaved prompt: a list mixing images and text.
    inputs = processor([image, prompt], return_tensors="pt").to(device)
    # generate(max_length=...) would include the prompt/image tokens in the
    # budget, truncating or suppressing output for small slider values, so
    # only count new tokens. Slider values may be floats; generate wants an int.
    generated_ids = model.generate(**inputs, max_new_tokens=int(max_length))
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return image, generated_text



title = "Midjourney-like Image Captioning with IDEFICS"
description = "Gradio Demo for generating *Midjourney* like captions (describe functionality) with **IDEFICS**"

# Each row supplies the four inputs in order: prompt, image URL,
# uploaded image (None -> the URL is used), and the token budget.
examples = [
    ["Describe the following image:", "https://miro.medium.com/v2/resize:fit:0/1*sTXgMwDUW0pk-1yK4iHYFw.png", None, 64],
    ["Describe the following image:", "https://miro.medium.com/v2/resize:fit:1400/0*6as5rHi0sgG4W2Tq.png", None, 64],
    ["Describe the following image:", "https://cdn.arstechnica.net/wp-content/uploads/2023/06/zoomout_2-1440x807.jpg", None, 64],
    ["Describe the following image:", "https://framerusercontent.com/images/inZdRVn7eafZNvaVre2iW1a538.png", None, 64],
    ["Describe the following image:", "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg", None, 64],
]

io = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Prompt", value="Describe the following image:"),
        gr.Textbox(label="image URL", placeholder="Insert the URL of the image to be described"),
        gr.Image(label="or upload an image", type="pil"),
        # gr.Slider's range parameters are `minimum`/`maximum`; the previous
        # `min`/`max` keywords are not Slider parameters, so the 16-128 range
        # was never applied.
        gr.Slider(label="Max tokens", value=64, minimum=16, maximum=128, step=8),
    ],
    outputs=[
        gr.Image(type="pil", label="Image"),
        gr.Textbox(label="IDEFICS Description"),
    ],
    title=title,
    description=description,
    examples=examples,
    # `allow_flagging` takes "never"/"auto"/"manual" (booleans are deprecated);
    # the removed `allow_screenshot` option is no longer passed.
    allow_flagging="never",
)
io.launch(debug=True)