clip-gpt2 / app.py
bczhou's picture
Update app.py
9669c25
import gradio as gr
from clip_gpt2 import CLIPGPT2, CLIPGPT2Config, CLIPGPT2Processor
import os
import torch
# Target device for deserialising the checkpoint tensors: GPU when available,
# CPU otherwise.
# NOTE(review): `device` is only used as `map_location` below; the model is
# never moved with `.to(device)` in this file, so inference runs wherever
# CLIPGPT2 places its parameters on construction.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the model from its config and fill it from the local checkpoint.
# Presumably the two `*_from_pretrained=False` flags stop the sub-models from
# fetching their own pretrained weights -- confirm against clip_gpt2.
config = CLIPGPT2Config(image_from_pretrained=False, text_from_pretrained=False)
model = CLIPGPT2(config)
model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
# A freshly constructed module starts in training mode; switch to evaluation
# mode so dropout (and similar train-only layers) are disabled while serving.
model.eval()

processor = CLIPGPT2Processor(config)

# Heading shown at the top of the demo page.
title = "Generate Image Captions With CLIP And GPT2"
def generate_image_captions(image, text):
    """Generate text conditioned on an image and/or a text prompt.

    Args:
        image: Value produced by the ``gr.Image`` input; passed to the
            processor as ``images`` without modification.
        text: Prompt from the textbox; passed to the processor as ``texts``
            without modification.

    Returns:
        The first generated sequence decoded to a string with special tokens
        skipped.
    """
    tokenizer = processor.tokenizer
    # Configure the tokenizer *before* it is used. Previously these two
    # assignments ran after tokenisation and generation, so the very first
    # request was served with different settings than every later one.
    tokenizer.padding_side = 'left'
    tokenizer.pad_token_id = tokenizer.eos_token_id

    inputs = processor(images=image, texts=text, return_tensors="pt")

    # Inference only: no autograd graph is needed, so avoid building one.
    with torch.no_grad():
        prediction = model.generate(
            # Any of these may be absent from the processor output, in which
            # case None is forwarded to generate().
            pixel_values=inputs.get("pixel_values", None),
            input_ids=inputs.get("input_ids", None),
            attention_mask=inputs.get("attention_mask", None),
            max_new_tokens=50,
        )

    return processor.decode(prediction[0], skip_special_tokens=True)
# Footer text (Markdown) linking to the paper the demo is based on.
article = (
    "This demo is originated from this paper: "
    "[original paper](https://arxiv.org/abs/2209.15162)"
)

# Markdown shown under the title: a headline followed by usage tips.
# The leading and trailing empty entries reproduce the newline right after the
# opening and right before the closing of the original triple-quoted literal.
description = "\n".join((
    "",
    "### Expand GPT2's language capabilities to vision with CLIP!",
    "### Tips:",
    "- Only English is supported.",
    "- When no image is provided, the model degrades to a vanilla GPT2-Large!",
    "- When no description is provided, the model automatically generates a caption for the provided image.",
    "- Try appending 'Answer:' after your question, the model is more likely to give desired outputs this way.",
    "",
))
# Gradio UI: an image input plus a three-line prompt box feeding
# generate_image_captions, with plain-text output.
demo = gr.Interface(
    fn=generate_image_captions,
    inputs=[gr.Image(), gr.Textbox(placeholder="A picture of", lines=3)],
    outputs="text",
    # Each example is [image path, prompt]; image files are resolved against
    # the current working directory.
    examples=[
        [os.path.join(os.getcwd(), image_file), prompt]
        for image_file, prompt in (
            ('two_bear.png', ""),
            ('three_women.png', "What is the woman in the middle's dress's color? Answer:"),
            ('cat_with_food.png', "Describe the picture:"),
            ('dog_with_frisbee.png', "What is the color of the frisbee in the photo? Answer:"),
            ('stop_sign.png', "What does the sign in the picture say? Answer:"),
        )
    ],
    title=title,
    description=description,
    article=article,
    cache_examples=False,
)

# Start the web server for the demo.
demo.launch()