"""Gradio demo: Midjourney-style image captioning ("describe") with IDEFICS.

The user supplies a text prompt and an image URL; a fine-tuned IDEFICS model
generates a description of the image, which is shown in the UI.
"""
import torch
from transformers import IdeficsForVisionText2Text, AutoProcessor
from PIL import Image
import gradio as gr

model_id = "mrm8488/idefics-9b-ft-describe-diffusion-bf16"
device = "cuda" if torch.cuda.is_available() else "cpu"

# The weights must live on the same device the inputs are sent to in
# predict(); otherwise generate() fails with a device mismatch on CUDA hosts.
model = IdeficsForVisionText2Text.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to(device)
processor = AutoProcessor.from_pretrained(model_id)


def predict(prompt, image_url, max_length):
    """Generate a description of the image found at ``image_url``.

    Args:
        prompt: Instruction text placed after the image in the model prompt.
        image_url: URL of the image to describe.
        max_length: Maximum number of new tokens to generate, taken from the
            "Max tokens" slider. It is cast to ``int`` because the slider
            may deliver a float.

    Returns:
        The decoded model output with special tokens removed.

    Raises:
        gr.Error: If no image URL was provided.
    """
    if not image_url or not image_url.strip():
        # Fail with a readable message instead of an opaque fetch error.
        raise gr.Error("Please provide an image URL.")
    image = processor.image_processor.fetch_images(image_url.strip())
    # IDEFICS consumes interleaved [image, text, ...] sequences; a single
    # sequence (not wrapped in an outer list) is treated as a batch of one.
    inputs = processor([image, prompt], return_tensors="pt").to(device)
    # Honour the UI slider. The original hard-coded max_length=128 here, and
    # max_length would also count the prompt tokens.
    generated_ids = model.generate(**inputs, max_new_tokens=int(max_length))
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]
    print(generated_text)
    return generated_text


title = "Midjourney-like Image Captioning with IDEFICS"
description = "Gradio Demo for generating Midjourney like captions (describe functionality) with IDEFICS"

examples = [
    ["Describe the following image:", "https://cdn.arstechnica.net/wp-content/uploads/2023/06/zoomout_2-1440x807.jpg", 64],
    ["Describe the following image:", "https://framerusercontent.com/images/inZdRVn7eafZNvaVre2iW1a538.png", 64],
    ["Describe the following image:", "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg", 64],
]

# Top-level gr.* components replace the removed gr.inputs/gr.outputs
# namespaces. Slider bounds use the `minimum`/`maximum` keyword names.
io = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(value="Describe the following image:"),
        gr.Textbox(
            label="image URL",
            placeholder="Insert the URL of the image to be described",
        ),
        gr.Slider(label="Max tokens", value=64, minimum=16, maximum=128, step=8),
    ],
    outputs=gr.Textbox(label="IDEFICS Description"),
    title=title,
    description=description,
    examples=examples,
    allow_flagging="never",
)

if __name__ == "__main__":
    io.launch(show_error=True)