clippy / app.py
kahnchana's picture
init
1d7cddb
import gradio as gr
import numpy as np
import torch
from PIL import Image
from infer_model import CLIPpyModel
from utils import get_similarity, get_transform, ade_palette, get_cmap_image
# One-time model / preprocessing setup, executed at import and shared by every
# call to process_image.
pretrained_ckpt = "https://github.com/kahnchana/clippy/releases/download/v1.0/clippy_5k.pt"
# map_location="cpu": this demo never moves the model to an accelerator, so load
# the weights onto the CPU even if the checkpoint was saved from CUDA tensors.
ckpt = torch.utils.model_zoo.load_url(pretrained_ckpt, map_location="cpu")
clippy = CLIPpyModel()
# Preprocessing for 224x224 model input (matches the UI image shape below).
transform = get_transform((224, 224))
# strict=False tolerates key mismatches between checkpoint and model; `msg`
# presumably holds the missing/unexpected-key report (nn.Module convention —
# confirm CLIPpyModel is an nn.Module).
msg = clippy.load_state_dict(ckpt, strict=False)
# Inference only: disable training-mode behaviour such as dropout.
clippy.eval()
# Colour table used to paint each class index in the segmentation map.
palette = ade_palette()
def process_image(img, captions):
    """Segment *img* into the comma-separated classes listed in *captions*.

    Each label becomes a ``"a photo of a <label>"`` prompt and is encoded with
    ``clippy.text``. The image is encoded with ``get_pos_tokens=True``. Both are
    passed to ``get_similarity(..., do_argmax=True)``, whose result is treated
    as a per-pixel map of label indices.

    Args:
        img: Image array from the Gradio ``Image`` input, passed to
            ``PIL.Image.fromarray`` (presumably HxWxC uint8 — confirm).
        captions: Comma-separated class labels, e.g. ``"man, woman, background"``.
            Surrounding whitespace and empty entries (e.g. from a trailing
            comma) are ignored.

    Returns:
        ``(cmap, rgb_seg, joint)``:
        - the colour legend built by ``get_cmap_image``;
        - the colour-coded segmentation map as a uint8 array;
        - that map blended 50/50 over the input image.

    Raises:
        ValueError: if no image is supplied or *captions* contains no labels.
    """
    if img is None:
        raise ValueError("Please upload an image.")
    # Skip empty entries so stray commas don't add a phantom "a photo of a "
    # class that competes for pixels.
    sample_text = [label.strip() for label in (captions or "").split(",") if label.strip()]
    if not sample_text:
        raise ValueError("Please enter at least one comma-separated class label.")
    sample_prompts = [f"a photo of a {x}" for x in sample_text]

    # Normalise to RGB: the overlay below uses Image.blend, which requires both
    # images to share the same mode.
    image = Image.fromarray(img).convert("RGB")

    # Pure inference — no autograd graph is needed.
    with torch.no_grad():
        image_vector = clippy.encode_image(transform(image).unsqueeze(0), get_pos_tokens=True)
        text_vector = clippy.text.encode(sample_prompts, convert_to_tensor=True)
        similarity = get_similarity(image_vector, text_vector, (224, 224), do_argmax=True)
    # [0, 0] presumably selects the single batch item's map — confirm against
    # get_similarity's output layout.
    similarity = similarity[0, 0].numpy()

    # Paint every pixel with the palette colour of its winning label index.
    rgb_seg = np.zeros((similarity.shape[0], similarity.shape[1], 3), dtype=np.uint8)
    for idx in range(len(sample_text)):
        rgb_seg[similarity == idx] = palette[idx]

    seg_image = Image.fromarray(rgb_seg)
    # Image.blend also requires identical sizes. Resize the input to the
    # segmentation's size; PIL returns a plain copy if the sizes already match.
    joint = Image.blend(image.resize(seg_image.size), seg_image, 0.5)
    cmap = get_cmap_image({label: tuple(palette[idx]) for idx, label in enumerate(sample_text)})
    return cmap, rgb_seg, joint
# Text shown in the Gradio UI: the page title, the description shown with the
# interface, and an HTML footer linking the paper and the code repository.
title = 'CLIPpy'
description = """
Gradio Demo for CLIPpy: Perceptual Grouping in Contrastive Vision Language Models. \n \n
Upload an image and type in a set of comma-separated labels (e.g.: "man, woman, background").
CLIPpy will segment the image according to the set of class labels you provide.
"""
article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2210.09996' target='_blank'>
Perceptual Grouping in Contrastive Vision Language Models
</a>
|
<a href='https://github.com/kahnchana/clippy' target='_blank'>GitHub Repository</a></p>
"""
# Wire process_image into a Gradio Interface.
#   inputs  -> an image (requested at 224x224 via `shape`) and a free-text
#              field of comma-separated labels;
#   outputs -> three image panes, in the same order process_image returns
#              them: colour legend, segmentation map, blended overlay.
# Inputs are built before outputs to keep component creation order unchanged.
demo_inputs = [gr.Image(shape=(224, 224)), "text"]
demo_outputs = [
    gr.Image(shape=(224, 224)).style(height=150),  # legend
    gr.Image(shape=(224, 224)).style(height=260),  # segmentation map
    gr.Image(shape=(224, 224)).style(height=260),  # overlay on input
]
demo = gr.Interface(
    fn=process_image,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title=title,
    description=description,
    article=article,
)
demo.launch()