# johann22's picture
# Update app.py
# 77d615b
# raw
# history blame
# 2.23 kB
import torch
from transformers import AutoProcessor
#from transformers import IdeficsForVisionText2Text, AutoProcessor
#from peft import PeftModel, PeftConfig
import gradio as gr
from huggingface_hub import InferenceClient
# Remote inference client bound to the hosted IDEFICS instruct checkpoint.
client = InferenceClient("HuggingFaceM4/idefics-9b-instruct")

# Repository of the fine-tuned adapter; the processor below is loaded from it.
# (Alternative previously tried here: "HuggingFaceM4/idefics-9b".)
peft_model_id = "mrm8488/idefics-9b-ft-describe-diffusion-bf16-adapter"

# Prefer the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE: an earlier revision loaded the model locally (PeftConfig / PeftModel,
# bfloat16 weights, model.to(device), model.eval()); that path is disabled and
# generation goes through `client` instead. Only the processor is loaded.
processor = AutoProcessor.from_pretrained(peft_model_id)

# Pre-determined best prompt for this fine-tune.
prompt = "Describe the following image:"

# Max generated tokens for the prompt (not referenced by the active code
# visible in this file; kept for the disabled local-generation path).
max_length = 64
def predict(image):
    """Caption an image through the hosted inference client.

    Args:
        image: The PIL image delivered by the ``gr.Image(type="pil")`` input,
            or ``None`` when the user submitted without uploading anything.

    Returns:
        The generated caption as a string, with the leading module-level
        ``prompt`` (followed by a space) removed if the model echoed it.
        An empty string is returned when ``image`` is ``None``.
    """
    import io  # local import so the block does not depend on file-level edits

    # Gradio passes None when no image was provided; nothing to caption.
    if image is None:
        return ""

    # The inference client takes raw bytes / a file object / a path / a URL,
    # not an in-memory PIL image, so encode the picture as PNG first.
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")

    response = client.image_to_text(buffer.getvalue())
    print(response)

    # The client returns the caption itself (a plain string in older
    # huggingface_hub releases, an object exposing `generated_text` in newer
    # ones) -- never token ids, so no processor decoding step is needed.
    generated_text = getattr(response, "generated_text", response)
    if generated_text is None:
        generated_text = ""
    elif not isinstance(generated_text, str):
        generated_text = str(generated_text)

    # Drop the prompt if the model echoed it back in front of the caption.
    generated_text = generated_text.replace(f"{prompt} ", "")
    print(generated_text)
    return generated_text
# Text shown at the top of the demo page.
title = "Midjourney-like Image Captioning with IDEFICS"
description = "Gradio Demo for generating *Midjourney* like captions (describe functionality) with **IDEFICS**"

# Example image files offered in the UI; Gradio expects one row (a list with
# one value per input component) per example.
_example_files = (
    "1_sTXgMwDUW0pk-1yK4iHYFw.png",
    "0_6as5rHi0sgG4W2Tq.png",
    "zoomout_2-1440x807.jpg",
    "inZdRVn7eafZNvaVre2iW1a538.webp",
    "cute-photos-of-cats-in-grass-1593184777.jpg",
    "llama2-coder-logo.png",
)
examples = [[file_name] for file_name in _example_files]

# Single image input (handed to `predict` as a PIL image) and a single
# text output holding the caption.
_image_input = gr.Image(label="Upload an image", type="pil")
_caption_output = gr.Textbox(label="IDEFICS Description")

io = gr.Interface(
    fn=predict,
    inputs=[_image_input],
    outputs=[_caption_output],
    title=title,
    description=description,
    examples=examples,
)

# Start the web server with debug output enabled.
io.launch(debug=True)