mrm8488 committed on
Commit
486f4bf
1 Parent(s): 90a1f93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py CHANGED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import IdeficsForVisionText2Text, AutoProcessor
3
+ from PIL import Image
4
+ import gradio as gr
5
+
6
# Checkpoint served by this demo.
model_id = "mrm8488/idefics-9b-ft-describe-diffusion-bf16"
# NOTE(review): the processor location used to be read from an undefined
# `config` object (`config.base_model_name_or_path`), which raised NameError
# at import time. The base checkpoint below is assumed from the fine-tune's
# name -- confirm it matches the checkpoint the model was trained from.
base_model_id = "HuggingFaceM4/idefics-9b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Move the model to the same device the inputs are sent to in `predict`;
# otherwise generation fails with a device mismatch when CUDA is available.
model = IdeficsForVisionText2Text.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to(device)
processor = AutoProcessor.from_pretrained(base_model_id)
11
+
12
def predict(prompt, image_url, max_length):
    """Generate a text description for the image found at ``image_url``.

    Args:
        prompt: Instruction text passed to the model together with the image.
        image_url: URL of the image, fetched through the processor's image
            processor.
        max_length: Upper bound forwarded to ``model.generate`` as
            ``max_length``.

    Returns:
        The first decoded sequence, with special tokens stripped.
    """
    image = processor.image_processor.fetch_images(image_url)

    # A single prompt is a flat list interleaving the image and the text.
    inputs = processor([image, prompt], return_tensors="pt").to(device)

    # Honour the caller-supplied limit instead of a hard-coded 128. The UI
    # slider may deliver a float, so coerce to int before passing it on.
    generated_ids = model.generate(**inputs, max_length=int(max_length))
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(generated_text)
    return generated_text
22
+
23
+
24
+
25
title = "Midjourney-like Image Captioning with IDEFICS"
description = "Gradio Demo for generating Midjourney like captions (describe functionality) with IDEFICS"

# Build the demo UI. The three inputs map positionally onto
# predict(prompt, image_url, max_length).
io = gr.Interface(
    fn=predict,  # was `image_caption`, which is not defined in this file
    inputs=[
        gr.Textbox(value="Describe the following image:"),
        gr.Textbox(
            label="image URL",
            placeholder="Insert the URL of the image to be described",
        ),
        gr.Slider(label="Max tokens", minimum=16, maximum=128, value=64, step=8),
    ],
    outputs=gr.Textbox(label="IDEFICS Description"),
    title=title,
    description=description,
    # String form of the flagging switch; the legacy boolean is rejected by
    # newer Gradio releases.
    allow_flagging="never",
)
io.launch(show_error=True)