File size: 2,019 Bytes
486f4bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b390d84
458c57f
 
 
 
 
 
486f4bf
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from transformers import IdeficsForVisionText2Text, AutoProcessor
from PIL import Image
import gradio as gr

model_id = "mrm8488/idefics-9b-ft-describe-diffusion-bf16"
# NOTE(review): the processor was originally loaded from
# `config.base_model_name_or_path`, but `config` is never defined in this file,
# so startup crashed with a NameError. The tokenizer / image processor are
# presumably unchanged by fine-tuning, so load them from the IDEFICS base
# checkpoint — TODO confirm this is the base model the fine-tune started from.
base_model_id = "HuggingFaceM4/idefics-9b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Put the model on the same device that `predict` sends its inputs to;
# leaving it on CPU while inputs go to CUDA fails with a device mismatch.
model = IdeficsForVisionText2Text.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to(device)
processor = AutoProcessor.from_pretrained(base_model_id)

def predict(prompt, image_url, max_length):
    """Describe the image at ``image_url`` with the fine-tuned IDEFICS model.

    Args:
        prompt: Instruction text placed after the image in the IDEFICS prompt.
        image_url: URL of the image to describe; fetched via the processor's
            image processor.
        max_length: Maximum number of tokens to generate (the "Max tokens"
            slider in the UI). Coerced to ``int`` because UI sliders may
            deliver floats.

    Returns:
        The decoded text of the single generated sequence. The decoded text
        may also contain the prompt text — verify against model output.

    Raises:
        gr.Error: If ``image_url`` is empty or only whitespace.
    """
    if not image_url or not image_url.strip():
        raise gr.Error("Please provide an image URL.")

    image = processor.image_processor.fetch_images(image_url.strip())
    # One IDEFICS prompt is an interleaved list of images and text.
    inputs = processor([image, prompt], return_tensors="pt").to(device)

    # Honour the caller's token budget. Use max_new_tokens rather than a total
    # max_length so the prompt (text + image tokens) does not eat into it.
    generated_ids = model.generate(**inputs, max_new_tokens=int(max_length))
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(generated_text)
    return generated_text



title = "Midjourney-like Image Captioning with IDEFICS"
description = "Gradio Demo for generating Midjourney like captions (describe functionality) with IDEFICS"

# Each row is (prompt, image URL, max tokens), matching the order of `inputs`.
examples = [
    ["Describe the following image:", "https://cdn.arstechnica.net/wp-content/uploads/2023/06/zoomout_2-1440x807.jpg", 64],
    ["Describe the following image:", "https://framerusercontent.com/images/inZdRVn7eafZNvaVre2iW1a538.png", 64],
    ["Describe the following image:", "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg", 64],
]

io = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(value="Describe the following image:"),
        gr.Textbox(label="image URL", placeholder="Insert the URL of the image to be described"),
        # gr.Slider takes minimum/maximum/value; the old min/max kwargs are invalid.
        gr.Slider(minimum=16, maximum=128, value=64, step=8, label="Max tokens"),
    ],
    outputs=gr.Textbox(label="IDEFICS Description"),
    title=title,
    description=description,
    examples=examples,
    # Don't run the 9B model over every example at startup.
    cache_examples=False,
    allow_flagging="never",
)

if __name__ == "__main__":
    io.launch(show_error=True)