import gradio as gr
import numpy as np
import torch
from PIL import Image

from infer_model import CLIPpyModel
from utils import get_similarity, get_transform, ade_palette, get_cmap_image

# Download the pre-trained CLIPpy checkpoint and load it into the model.
pretrained_ckpt = "https://github.com/kahnchana/clippy/releases/download/v1.0/clippy_5k.pt"
ckpt = torch.utils.model_zoo.load_url(pretrained_ckpt)
clippy = CLIPpyModel()
transform = get_transform((224, 224))
msg = clippy.load_state_dict(ckpt, strict=False)  # strict=False tolerates minor key mismatches
palette = ade_palette()  # ADE20K colour palette: one RGB triplet per class index
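
# Optional sanity check (a sketch, not part of the original demo): `msg` is the
# standard (missing_keys, unexpected_keys) result of nn.Module.load_state_dict,
# so uncommenting the line below would surface any weights that failed to load.
# print(msg.missing_keys, msg.unexpected_keys)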


def process_image(img, captions):
    """Segment `img` according to the comma-separated labels in `captions`."""
    sample_text = [x.strip() for x in captions.split(",")]
    sample_prompts = [f"a photo of a {x}" for x in sample_text]
    image = Image.fromarray(img)
    # Per-patch image embeddings and one text embedding per prompt.
    image_vector = clippy.encode_image(transform(image).unsqueeze(0), get_pos_tokens=True)
    text_vector = clippy.text.encode(sample_prompts, convert_to_tensor=True)
    # Argmax over the per-label similarity maps gives one class index per pixel.
    similarity = get_similarity(image_vector, text_vector, (224, 224), do_argmax=True)[0, 0].numpy()
    # Paint each predicted class with its palette colour.
    rgb_seg = np.zeros((similarity.shape[0], similarity.shape[1], 3), dtype=np.uint8)
    for idx, _ in enumerate(sample_text):
        rgb_seg[similarity == idx] = palette[idx]
    # Overlay the segmentation on the input and build a label-to-colour legend.
    joint = Image.blend(image, Image.fromarray(rgb_seg), 0.5)
    cmap = get_cmap_image({label: tuple(palette[idx]) for idx, label in enumerate(sample_text)})
    return cmap, rgb_seg, joint
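
# A minimal standalone usage sketch (not part of the original demo):
# "sample.jpg" is a hypothetical local image; resizing to 224x224 mirrors
# what the Gradio image component does before calling process_image().
#
#   img = np.array(Image.open("sample.jpg").convert("RGB").resize((224, 224)))
#   cmap, rgb_seg, joint = process_image(img, "man, woman, background")
#   joint.save("overlay.png")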

title = 'CLIPpy'
description = """
Gradio demo for CLIPpy: Perceptual Grouping in Contrastive Vision Language Models. \n \n
Upload an image and type in a set of comma-separated labels (e.g. "man, woman, background").
CLIPpy will segment the image according to the class labels you provide.
"""
| article = """ | |
| <p style='text-align: center'> | |
| <a href='https://arxiv.org/abs/2210.09996' target='_blank'> | |
| Perceptual Grouping in Contrastive Vision Language Models | |
| </a> | |
| | | |
| <a href='https://github.com/kahnchana/clippy' target='_blank'>Github Repository</a></p> | |
| """ | |

# Gradio 3.x API: `shape=(224, 224)` resizes images, `.style(height=...)` sets display size.
demo = gr.Interface(
    fn=process_image,
    inputs=[gr.Image(shape=(224, 224)), "text"],
    outputs=[
        gr.Image(shape=(224, 224)).style(height=150),  # colour-map legend
        gr.Image(shape=(224, 224)).style(height=260),  # raw segmentation
        gr.Image(shape=(224, 224)).style(height=260),  # blended overlay
    ],
    title=title,
    description=description,
    article=article,
)

demo.launch()
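
# Optional (a standard Gradio flag, not used in the original):
# demo.launch(share=True) would also expose a temporary public URL.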