File size: 2,231 Bytes
1d7cddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
import numpy as np
import torch
from PIL import Image

from infer_model import CLIPpyModel
from utils import get_similarity, get_transform, ade_palette, get_cmap_image

pretrained_ckpt = "https://github.com/kahnchana/clippy/releases/download/v1.0/clippy_5k.pt"
# Download the released weights. map_location="cpu" lets a checkpoint saved from GPU
# tensors load on CPU-only hosts; the demo runs on CPU anyway (process_image calls
# .numpy() on the model output and never moves inputs to another device).
ckpt = torch.utils.model_zoo.load_url(pretrained_ckpt, map_location="cpu")

clippy = CLIPpyModel()
transform = get_transform((224, 224))

# strict=False tolerates missing/unexpected keys; report them instead of discarding
# the result so a mismatched checkpoint is visible in the logs.
msg = clippy.load_state_dict(ckpt, strict=False)
print(f"CLIPpy checkpoint loaded: {msg}")
# Inference only. NOTE(review): assumes CLIPpyModel is a torch nn.Module (its
# load_state_dict(..., strict=False) call suggests so); eval() disables dropout.
clippy.eval()

palette = ade_palette()


def process_image(img, captions):
    """Segment ``img`` into the comma-separated classes listed in ``captions``.

    Each label becomes an "a photo of a <label>" prompt. The prompts are compared
    against the CLIPpy image encoding, and every pixel of the 224x224 similarity
    map is assigned the index of its best-matching label (``do_argmax=True``).

    Args:
        img: Input image as a numpy array, as handed over by ``gr.Image``
            (assumed H x W x C uint8 — TODO confirm for non-RGB uploads).
        captions: Comma-separated class labels, e.g. ``"man, woman, background"``.
            Surrounding whitespace, empty entries and duplicate labels are ignored.

    Returns:
        A ``(cmap, rgb_seg, joint)`` tuple:
        - the colour legend built by ``get_cmap_image``;
        - the colour-coded segmentation map as a uint8 array;
        - that map blended 50/50 over the input image.

    Raises:
        gr.Error: If ``captions`` contains no non-empty label.
    """
    # Drop blank entries (e.g. from a trailing comma) and duplicates, keeping the
    # first occurrence. A duplicate would otherwise collapse into a single key of the
    # legend dict, so the legend colour could disagree with the painted segment.
    sample_text = list(dict.fromkeys(x.strip() for x in captions.split(",") if x.strip()))
    if not sample_text:
        raise gr.Error("Please enter at least one comma-separated label.")
    sample_prompts = [f"a photo of a {x}" for x in sample_text]
    # Cycle through the palette so more labels than palette entries cannot raise
    # IndexError; the same colours feed both the segmentation and the legend.
    colors = [tuple(palette[idx % len(palette)]) for idx in range(len(sample_text))]

    # Image.blend below needs an RGB image; convert() is a no-op for RGB input.
    image = Image.fromarray(img).convert("RGB")
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        image_vector = clippy.encode_image(transform(image).unsqueeze(0), get_pos_tokens=True)
        text_vector = clippy.text.encode(sample_prompts, convert_to_tensor=True)
        similarity = get_similarity(image_vector, text_vector, (224, 224), do_argmax=True)[0, 0].numpy()

    # Paint each label index with its colour; unmatched pixels stay black.
    rgb_seg = np.zeros((similarity.shape[0], similarity.shape[1], 3), dtype=np.uint8)
    for idx, color in enumerate(colors):
        rgb_seg[similarity == idx] = color

    overlay = Image.fromarray(rgb_seg)
    # Image.blend requires equal sizes. The similarity map is requested at 224x224,
    # so scale it to the input size (nearest keeps label colours crisp).
    if overlay.size != image.size:
        overlay = overlay.resize(image.size, resample=Image.NEAREST)
    joint = Image.blend(image, overlay, 0.5)
    cmap = get_cmap_image(dict(zip(sample_text, colors)))

    return cmap, rgb_seg, joint


title = 'CLIPpy'

# Markdown shown above the interface; the escaped "\n \n" inserts a paragraph break.
description = """
Gradio Demo for CLIPpy: Perceptual Grouping in Contrastive Vision Language Models. \n \n
Upload an image and type in a set of comma separated labels (e.g.: "man, woman, background").
CLIPpy will segment the image according to the set of class labels you provide.
"""

# HTML footer linking the paper and the code repository.
article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2210.09996' target='_blank'>
Perceptual Grouping in Contrastive Vision Language Models
</a>
|
<a href='https://github.com/kahnchana/clippy' target='_blank'>Github Repository</a></p>
"""

# Inputs: the image to segment and the comma-separated label string consumed by
# process_image. Outputs follow process_image's return order: colour legend,
# segmentation map, and segmentation blended over the input.
# NOTE(review): `shape=` and `.style()` are Gradio 3.x APIs (removed in Gradio 4) —
# pin gradio<4 or migrate to the height/width constructor arguments.
demo = gr.Interface(
    fn=process_image,
    inputs=[gr.Image(shape=(224, 224)), "text"],
    outputs=[gr.Image(shape=(224, 224)).style(height=150),
             gr.Image(shape=(224, 224)).style(height=260),
             gr.Image(shape=(224, 224)).style(height=260)],
    title=title,
    description=description,
    article=article,
)

# Start the server only when run as a script; importing this module just builds `demo`.
if __name__ == "__main__":
    demo.launch()