"""Gradio demo that captions images / answers questions with a CLIP + GPT2 model."""

import os

import gradio as gr
import torch

from clip_gpt2 import CLIPGPT2, CLIPGPT2Config, CLIPGPT2Processor

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Upper bound on the number of tokens generated per request.
MAX_NEW_TOKENS = 50

config = CLIPGPT2Config(image_from_pretrained=False, text_from_pretrained=False)
model = CLIPGPT2(config)
# Deserialize onto the CPU first and move the assembled model afterwards, so
# the checkpoint is not held on the GPU in addition to the model itself.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.to(device)
# A freshly constructed module is in training mode; switch to eval so that
# dropout is disabled while generating.
model.eval()

processor = CLIPGPT2Processor(config)
# Configure the tokenizer once at start-up so every request (including the
# first one) sees the same padding setup.
processor.tokenizer.padding_side = 'left'
processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id

title = "Generate Image Captions With CLIP And GPT2"


def _on_device(tensor):
    """Return *tensor* moved to the inference device; pass ``None`` through."""
    return tensor.to(device) if tensor is not None else None


def generate_image_captions(image, text):
    """Generate text for an image and/or a text prompt.

    Args:
        image: Image value supplied by the ``gr.Image`` input (may be absent
            when the user submits text only).
        text: Prompt supplied by the ``gr.Textbox`` input.

    Returns:
        The decoded first generated sequence, with special tokens removed.
    """
    inputs = processor(images=image, texts=text, return_tensors="pt")

    # Any of these entries may be missing from the processor output (e.g. no
    # pixel values without an image), hence ``.get`` and the None pass-through.
    with torch.no_grad():
        prediction = model.generate(
            pixel_values=_on_device(inputs.get("pixel_values", None)),
            input_ids=_on_device(inputs.get("input_ids", None)),
            attention_mask=_on_device(inputs.get("attention_mask", None)),
            max_new_tokens=MAX_NEW_TOKENS,
        )

    return processor.decode(prediction[0], skip_special_tokens=True)


article = "This demo is originated from this paper: [original paper](https://arxiv.org/abs/2209.15162)"

description = """
### Expand GPT2's language capabilities to vision with CLIP!
### Tips:
- Only English is supported.
- When no image is provided, the model degrades to a vanilla GPT2-Large!
- When no description is provided, the model automatically generates a caption for the provided image.
- Try appending 'Answer:' after your question, the model is more likely to give desired outputs this way.
"""

demo = gr.Interface(
    fn=generate_image_captions,
    inputs=[
        gr.Image(),
        gr.Textbox(placeholder="A picture of", lines=3),
    ],
    outputs="text",
    # Each example is [image path, prompt]; images are resolved against the
    # current working directory.
    examples=[
        [os.path.join(os.getcwd(), 'two_bear.png'), ""],
        [os.path.join(os.getcwd(), 'three_women.png'), "What is the woman in the middle's dress's color? \nAnswer:"],
        [os.path.join(os.getcwd(), 'cat_with_food.png'), "Describe the picture:"],
        [os.path.join(os.getcwd(), 'dog_with_frisbee.png'), "What is the color of the frisbee in the photo? Answer:"],
        [os.path.join(os.getcwd(), 'stop_sign.png'), "What does the sign in the picture say? Answer:"],
    ],
    article=article,
    title=title,
    description=description,
    cache_examples=False,
)

demo.launch()