# Author: mrm8488 - commit 8d07585 ("Update app.py")
import torch
from transformers import IdeficsForVisionText2Text, AutoProcessor
from peft import PeftModel, PeftConfig
import gradio as gr
# Identifier of the fine-tuned PEFT adapter; its config names the base
# checkpoint the adapter must be applied on top of.
peft_model_id = "mrm8488/idefics-9b-ft-describe-diffusion-bf16-adapter"

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

config = PeftConfig.from_pretrained(peft_model_id)

# The processor comes from the base checkpoint (the adapter only changes weights).
processor = AutoProcessor.from_pretrained(config.base_model_name_or_path)

# Load the base IDEFICS weights in bfloat16, wrap them with the adapter,
# move everything to the selected device, and switch to inference mode.
model = IdeficsForVisionText2Text.from_pretrained(
    config.base_model_name_or_path,
    torch_dtype=torch.bfloat16,
)
model = PeftModel.from_pretrained(model, peft_model_id).to(device)
model.eval()
def predict(prompt, image_url, max_length):
    """Generate an IDEFICS description for the image at ``image_url``.

    Args:
        prompt: Text instruction placed after the image in the model input
            (the UI default is "Describe the following image:").
        image_url: URL of the image to describe; it is downloaded with the
            processor's ``fetch_images`` helper. Surrounding whitespace is
            ignored.
        max_length: Forwarded to ``model.generate(max_length=...)``. Note that
            ``max_length`` bounds the whole sequence (prompt tokens plus
            generated tokens), not only the newly generated ones.

    Returns:
        The first sequence from ``processor.batch_decode`` with special
        tokens skipped. NOTE(review): for this causal model the decoded text
        presumably still starts with the prompt text - confirm if only the
        description should be shown.

    Raises:
        gr.Error: if ``image_url`` is empty or only whitespace, so the UI shows
            a readable message instead of a low-level download error.
    """
    image_url = (image_url or "").strip()
    if not image_url:
        raise gr.Error("Please provide the URL of an image to describe.")

    image = processor.image_processor.fetch_images(image_url)

    # A single IDEFICS prompt is a list interleaving images and text: here the
    # image followed by the instruction.
    inputs = processor([image, prompt], return_tensors="pt").to(device)

    # Slider values may arrive as floats; generate's length limit is an int.
    generated_ids = model.generate(**inputs, max_length=int(max_length))
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Text shown at the top of the Gradio page.
title = "Midjourney-like Image Captioning with IDEFICS"
description = (
    "Gradio Demo for generating *Midjourney* like captions "
    "(describe functionality) with **IDEFICS**"
)

# Example rows, one per input component: [prompt, image URL, max tokens].
# Every example shares the same prompt and length; only the image differs.
examples = [
    ["Describe the following image:", url, 64]
    for url in (
        "https://cdn.arstechnica.net/wp-content/uploads/2023/06/zoomout_2-1440x807.jpg",
        "https://framerusercontent.com/images/inZdRVn7eafZNvaVre2iW1a538.png",
        "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg",
    )
]
# Build the Gradio UI around `predict`; inputs are listed in the same order as
# predict's parameters (prompt, image_url, max_length).
io = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Prompt", value="Describe the following image:"),
        gr.Textbox(label="image URL", placeholder="Insert the URL of the image to be described"),
        # Slider's range parameters are `minimum`/`maximum`; the former
        # `min`/`max` keywords are not part of its signature, so the intended
        # 16-128 range was not applied. The value feeds generate(max_length=...),
        # which also counts prompt tokens.
        gr.Slider(label="Max tokens", value=64, minimum=16, maximum=128, step=8),
    ],
    outputs=gr.Textbox(label="IDEFICS Description"),
    title=title,
    description=description,
    examples=examples,
    # Gradio expects one of "never" / "auto" / "manual" here. The deprecated
    # Gradio 2 `allow_screenshot` option has been dropped.
    allow_flagging="never",
)
io.launch(debug=True)