mrm8488's picture
Update app.py
5bbc196
raw
history blame
1.93 kB
import torch
from transformers import IdeficsForVisionText2Text, AutoProcessor
from peft import PeftModel, PeftConfig
import gradio as gr
# Hub id of the PEFT adapter that is applied on top of the IDEFICS base model.
peft_model_id = "mrm8488/idefics-9b-ft-describe-diffusion-bf16-adapter"
# Use the GPU when torch can see one; otherwise everything runs on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The adapter config records the base checkpoint it was trained against, and
# both the processor and the base weights are loaded from that checkpoint.
config = PeftConfig.from_pretrained(peft_model_id)
processor = AutoProcessor.from_pretrained(config.base_model_name_or_path)

# Load the base model in bfloat16, wrap it with the adapter weights, then
# move the combined module to the selected device.
model = PeftModel.from_pretrained(
    IdeficsForVisionText2Text.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=torch.bfloat16,
    ),
    peft_model_id,
).to(device)
# Inference only: put the module in evaluation mode.
model.eval()
# Pre-determined best prompt for this fine-tune. predict() places it after the
# image in the model input and strips it from the decoded output again.
prompt = "Describe the following image:"
# Passed to generate() as max_length, which caps the TOTAL sequence length
# (prompt tokens + generated tokens), so fewer than this many new tokens are
# actually produced.
max_length = 64
def predict(image):
    """Generate a Midjourney-style caption for ``image``.

    Relies on the module-level ``processor``, ``model``, ``device``,
    ``prompt`` and ``max_length``.

    Args:
        image: The uploaded image as delivered by the ``gr.Image(type="pil")``
            input, or ``None`` when the user submits without an image.

    Returns:
        The decoded model output with the leading instruction prompt removed
        and surrounding whitespace stripped; an empty string when no image
        was provided.
    """
    # Gradio hands over None for an empty image input; there is nothing to
    # caption, so return early instead of letting the processor fail on it.
    if image is None:
        return ""

    # One interleaved [image, text] prompt, exactly as before (the previous
    # [[image, prompt]] wrapper was immediately unwrapped with [0]).
    inputs = processor([image, prompt], return_tensors="pt").to(device)
    # NOTE: max_length bounds prompt + generated tokens, not new tokens only.
    generated_ids = model.generate(**inputs, max_length=max_length)
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # The decoded text is expected to echo the prompt before the caption.
    # Remove only that first occurrence so the same phrase appearing inside
    # the generated caption is kept, and drop the leftover leading space.
    return decoded.replace(prompt, "", 1).strip()
# Text shown at the top of the demo page; the description uses Markdown
# emphasis (*...*, **...**).
title = "Midjourney-like Image Captioning with IDEFICS"
description = "Gradio Demo for generating *Midjourney* like captions (describe functionality) with **IDEFICS**"
# One row per example because the interface has a single input. These are
# relative paths, presumably image files shipped next to this script.
examples = [
    ["1_sTXgMwDUW0pk-1yK4iHYFw.png"],
    ["0_6as5rHi0sgG4W2Tq.png"],
    ["zoomout_2-1440x807.jpg"],
    ["inZdRVn7eafZNvaVre2iW1a538.webp"],
    ["cute-photos-of-cats-in-grass-1593184777.jpg"],
    ["llama2-coder-logo.png"],
]

# Single image in, single caption out, wired to predict().
io = gr.Interface(
    fn=predict,
    inputs=[
        # type="pil" makes Gradio pass predict() a PIL image.
        gr.Image(label="Upload an image", type="pil"),
    ],
    outputs=[
        gr.Textbox(label="IDEFICS Description"),
    ],
    title=title,
    description=description,
    examples=examples,
    # Gradio documents allow_flagging as "never" / "auto" / "manual"; the old
    # boolean False was only a deprecated alias for "never". The former
    # allow_screenshot argument is deprecated and ignored, so it is dropped.
    allow_flagging="never",
)
# debug=True: see Gradio's launch() docs (keeps the process attached and
# surfaces errors in the console).
io.launch(debug=True)