import torch
from transformers import IdeficsForVisionText2Text, AutoProcessor
from peft import PeftModel, PeftConfig
import gradio as gr

peft_model_id = "mrm8488/idefics-9b-ft-describe-diffusion-bf16-adapter"
device = "cuda" if torch.cuda.is_available() else "cpu"


# Load the LoRA adapter config, the base IDEFICS model in bfloat16,
# and the matching processor, then attach the adapter weights.
config = PeftConfig.from_pretrained(peft_model_id)
model = IdeficsForVisionText2Text.from_pretrained(config.base_model_name_or_path, torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(model, peft_model_id)
processor = AutoProcessor.from_pretrained(config.base_model_name_or_path)
model = model.to(device)
model.eval()

# Pre-determined best prompt for this fine-tune
prompt = "Describe the following image:"

# Maximum total sequence length (prompt + generated tokens)
max_length = 64

def predict(image):
    # IDEFICS expects an interleaved list of images and text
    prompts = [[image, prompt]]
    inputs = processor(prompts[0], return_tensors="pt").to(device)
    generated_ids = model.generate(**inputs, max_length=max_length)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Strip the echoed prompt so only the generated caption is returned
    generated_text = generated_text.replace(f"{prompt} ", "")
    return generated_text
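
# Example usage outside Gradio, as a quick sanity check
# (assumes a local file "example.jpg" exists; the filename is illustrative):
#   from PIL import Image
#   print(predict(Image.open("example.jpg")))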

title = "Midjourney-like Image Captioning with IDEFICS"
description = "Gradio Demo for generating *Midjourney* like captions (describe functionality) with **IDEFICS**"

examples = [
    ["1_sTXgMwDUW0pk-1yK4iHYFw.png"],
    ["0_6as5rHi0sgG4W2Tq.png"],
    ["zoomout_2-1440x807.jpg"],
    ["inZdRVn7eafZNvaVre2iW1a538.webp"],
    ["cute-photos-of-cats-in-grass-1593184777.jpg"],
    ["llama2-coder-logo.png"]
]
io = gr.Interface(fn=predict, 
                  inputs=[
                      gr.Image(label="Upload an image", type="pil"),
                  ],
                  outputs=[
                      gr.Textbox(label="IDEFICS Description")
                  ],
                  title=title, description=description, examples=examples)
io.launch(debug=True)
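
# Note: debug=True keeps the process in the foreground and surfaces tracebacks in the console;
# launch(share=True) can additionally be passed for a temporary public URL when running
# outside Hugging Face Spaces.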