Spaces:
Runtime error
Runtime error
File size: 2,673 Bytes
486f4bf ce0ce67 486f4bf ce0ce67 486f4bf ce0ce67 486f4bf 7378b29 486f4bf 8d07585 486f4bf cf8abff 486f4bf 8d07585 b390d84 458c57f 7378b29 458c57f 1651e3c 486f4bf 8f189a9 7378b29 8f189a9 9fd6fc9 cf8abff b94ad3f cf8abff 9fd6fc9 486f4bf 232d5c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import torch
from transformers import IdeficsForVisionText2Text, AutoProcessor
from peft import PeftModel, PeftConfig
import gradio as gr
# Hub id of the fine-tuned PEFT adapter that turns IDEFICS into a
# Midjourney-style "describe" captioner.
peft_model_id = "mrm8488/idefics-9b-ft-describe-diffusion-bf16-adapter"

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The adapter's config records the base checkpoint it was trained on; the base
# weights and the processor are both loaded from that same checkpoint.
config = PeftConfig.from_pretrained(peft_model_id)
base_model_id = config.base_model_name_or_path

# Load the base model in bfloat16, attach the adapter, and move it to `device`.
base_model = IdeficsForVisionText2Text.from_pretrained(base_model_id, torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base_model, peft_model_id).to(device)

processor = AutoProcessor.from_pretrained(base_model_id)

# Inference only: switch off training-mode behaviour (e.g. dropout).
model.eval()
def predict(prompt, image_url, image_pil=None, max_length=64):
    """Describe an image with the adapter-tuned IDEFICS model.

    Args:
        prompt: Text instruction placed after the image in the model input.
        image_url: URL of the image to describe. Used only when no image was
            uploaded.
        image_pil: Optional uploaded PIL image. Takes precedence over
            ``image_url`` when given.
        max_length: Forwarded to ``model.generate`` as ``max_length`` (a cap on
            the total sequence length, prompt tokens included).

    Returns:
        Tuple ``(image, generated_text)``: the image that was described and the
        first decoded sequence, with special tokens skipped.

    Raises:
        gr.Error: If neither an uploaded image nor a non-empty URL is provided.
    """
    if image_pil is not None:
        image = image_pil
    else:
        url = (image_url or "").strip()
        if not url:
            # Otherwise fetch_images would be asked to download an empty URL,
            # surfacing a low-level error instead of a readable message.
            raise gr.Error("Please provide an image URL or upload an image.")
        image = processor.image_processor.fetch_images(url)

    # A single IDEFICS prompt is a list interleaving images and text.
    inputs = processor([image, prompt], return_tensors="pt").to(device)
    # The slider value may arrive as a float; generation lengths are integers.
    generated_ids = model.generate(**inputs, max_length=int(max_length))
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return image, generated_text
# Header text shown at the top of the Gradio page.
title = "Midjourney-like Image Captioning with IDEFICS"
description = "Gradio Demo for generating *Midjourney* like captions (describe functionality) with **IDEFICS**"

# Example rows follow the interface's input order:
# (prompt, image URL, uploaded image, max length).
# Only the URL differs between rows.
_example_urls = (
    "https://miro.medium.com/v2/resize:fit:0/1*sTXgMwDUW0pk-1yK4iHYFw.png",
    "https://miro.medium.com/v2/resize:fit:1400/0*6as5rHi0sgG4W2Tq.png",
    "https://cdn.arstechnica.net/wp-content/uploads/2023/06/zoomout_2-1440x807.jpg",
    "https://framerusercontent.com/images/inZdRVn7eafZNvaVre2iW1a538.png",
    "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg",
)
examples = [
    ["Describe the following image:", url, None, 64]
    for url in _example_urls
]
# Gradio UI with four inputs (prompt, image URL, optional upload, length slider)
# and two outputs (the described image, the generated description), then launch it.
io = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Prompt", value="Describe the following image:"),
        gr.Textbox(label="image URL", placeholder="Insert the URL of the image to be described"),
        gr.Image(label="or upload an image", type="pil"),
        # gr.Slider names its bounds `minimum` / `maximum`; `min` / `max` are not
        # slider parameters.
        gr.Slider(label="Max tokens", value=64, minimum=16, maximum=128, step=8),
    ],
    outputs=[
        gr.Image(type="pil", label="Image"),
        gr.Textbox(label="IDEFICS Description"),
    ],
    title=title,
    description=description,
    examples=examples,
    # `allow_screenshot` is not a gr.Interface parameter in gradio 3+/4, so it is
    # dropped. Flagging is disabled via the string mode "never" rather than a bool.
    allow_flagging="never",
)
io.launch(debug=True)