import os 
import sys

# Install runtime dependencies at startup (this demo is written for a hosted
# environment such as a Hugging Face Space, where app.py installs what it needs).
os.system("pip install gdown")
os.system("pip install imutils")
os.system("python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'")
os.system("pip install git+https://github.com/cocodataset/panopticapi.git")

# Build the custom deformable-attention ops used by the pixel decoder.
os.system("python fcclip/modeling/pixel_decoder/ops/setup.py build install")

import gradio as gr
# check pytorch installation: 
from detectron2.utils.logger import setup_logger
from contextlib import ExitStack
# import some common libraries
import numpy as np
import cv2
import torch
import itertools
# import some common detectron2 utilities
from detectron2.config import get_cfg
from detectron2.utils.visualizer import ColorMode, random_color
from detectron2.data import MetadataCatalog
from detectron2.projects.deeplab import add_deeplab_config


# import FCCLIP project
from fcclip import add_maskformer2_config, add_fcclip_config
from demo.predictor import DefaultPredictor, OpenVocabVisualizer
from PIL import Image
import imutils
import json

setup_logger()
logger = setup_logger(name="fcclip")

# Build the FC-CLIP config: Detectron2 base + DeepLab + MaskFormer2 + FC-CLIP options.
cfg = get_cfg()
cfg.MODEL.DEVICE = 'cpu'
add_deeplab_config(cfg)
add_maskformer2_config(cfg)
add_fcclip_config(cfg)
cfg.merge_from_file("configs/coco/panoptic-segmentation/fcclip/fcclip_convnext_large_eval_ade20k.yaml")

# Download the pretrained checkpoint from Google Drive and point the config at it.
os.system("gdown 1-91PIns86vyNaL3CzMmDD39zKGnPMtvj")
cfg.MODEL.WEIGHTS = './fcclip_cocopan.pth'

# This demo only runs panoptic segmentation.
cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = False
cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON = False
cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = True
predictor = DefaultPredictor(cfg)

title = "FC-CLIP"
description = """Gradio demo for FC-CLIP. To use it, simply upload your image or click one of the examples to load them. FC-CLIP performs open-vocabulary segmentation, so you may also add extra classes of your own.
The expected format is 'a1,a2;b1,b2', where classes are separated by ';' and a1,a2 are synonyms for the first class.
The first synonym is displayed as the class name. Read more at the links below."""
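
# For illustration only (not used by the demo logic): the extra-vocabulary string
# described above is split on ';' into classes and on ',' into synonyms, and the
# first synonym of each class is used as its display name, e.g.
#   "luggage,suitcase;handbag"  ->  [["luggage", "suitcase"], ["handbag"]]
#   display names               ->  ["luggage", "handbag"]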

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2308.02487' target='_blank'>FC-CLIP</a> | <a href='https://github.com/bytedance/fc-clip' target='_blank'>Github Repo</a></p>"

examples = [
    [
        "demo/examples/coco.jpg",
        "black pickup truck,pickup truck;blue sky,sky",
        ["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
    ],
    [
        "demo/examples/ade.jpg",
        "luggage,suitcase,baggage;handbag",
        ["ADE (150 categories)"],
    ],
    [
        "demo/examples/ego4d.jpg",
        "faucet,tap;kitchen paper,paper towels",
        ["COCO (133 categories)"],
    ],
]
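# Each example row above mirrors the demo inputs defined below:
# (image path, extra vocabulary, dataset checkboxes).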


coco_metadata = MetadataCatalog.get("openvocab_coco_2017_val_panoptic_with_sem_seg")
ade20k_metadata = MetadataCatalog.get("openvocab_ade20k_panoptic_val")
# LVIS class names: keep only the text after the first ':' on each line of the label file.
with open("./fcclip/data/datasets/lvis_1203_with_prompt_eng.txt", "r") as f:
    lvis_classes = f.read().splitlines()
lvis_classes = [x[x.find(':')+1:] for x in lvis_classes]
# Cycle the COCO stuff colors so that every LVIS class gets a color.
lvis_colors = list(
    itertools.islice(itertools.cycle(coco_metadata.stuff_colors), len(lvis_classes))
)
# Rearrange dataset metadata into thing_classes / stuff_classes (stuff excludes things).
coco_thing_classes = coco_metadata.thing_classes
coco_stuff_classes = [x for x in coco_metadata.stuff_classes if x not in coco_thing_classes]
coco_thing_colors = coco_metadata.thing_colors
coco_stuff_colors = [x for x in coco_metadata.stuff_colors if x not in coco_thing_colors]
ade20k_thing_classes = ade20k_metadata.thing_classes
ade20k_stuff_classes = [x for x in ade20k_metadata.stuff_classes if x not in ade20k_thing_classes]
ade20k_thing_colors = ade20k_metadata.thing_colors
ade20k_stuff_colors = [x for x in ade20k_metadata.stuff_colors if x not in ade20k_thing_colors]

def build_demo_classes_and_metadata(vocab, label_list):
    """Combine the user-supplied extra vocabulary with the selected dataset
    vocabularies and register temporary metadata for the visualizer."""
    extra_classes = []

    if vocab:
        # Classes are separated by ';'; each class is a list of comma-separated synonyms.
        for words in vocab.split(";"):
            extra_classes.append([word.strip() for word in words.split(",")])
    extra_colors = [random_color(rgb=True, maximum=1) for _ in range(len(extra_classes))]
    print("extra_classes:", extra_classes)
    demo_thing_classes = extra_classes
    demo_stuff_classes = []
    demo_thing_colors = extra_colors
    demo_stuff_colors = []

    if any("COCO" in label for label in label_list):
        demo_thing_classes += coco_thing_classes
        demo_stuff_classes += coco_stuff_classes
        demo_thing_colors += coco_thing_colors
        demo_stuff_colors += coco_stuff_colors
    if any("ADE" in label for label in label_list):
        demo_thing_classes += ade20k_thing_classes
        demo_stuff_classes += ade20k_stuff_classes
        demo_thing_colors += ade20k_thing_colors
        demo_stuff_colors += ade20k_stuff_colors
    if any("LVIS" in label for label in label_list):
        demo_thing_classes += lvis_classes
        demo_thing_colors += lvis_colors

    MetadataCatalog.pop("fcclip_demo_metadata", None)
    demo_metadata = MetadataCatalog.get("fcclip_demo_metadata")
    # A class entry may be a list of synonyms or a comma-joined string of synonyms;
    # the first synonym is used as the display name.
    def first_synonym(c):
        return c[0] if isinstance(c, (list, tuple)) else c.split(",")[0]

    demo_metadata.thing_classes = [first_synonym(c) for c in demo_thing_classes]
    demo_metadata.stuff_classes = [
        *demo_metadata.thing_classes,
        *[first_synonym(c) for c in demo_stuff_classes],
    ]
    demo_metadata.thing_colors = demo_thing_colors
    demo_metadata.stuff_colors = demo_thing_colors + demo_stuff_colors
    demo_metadata.stuff_dataset_id_to_contiguous_id = {
        idx: idx for idx in range(len(demo_metadata.stuff_classes))
    }
    demo_metadata.thing_dataset_id_to_contiguous_id = {
        idx: idx for idx in range(len(demo_metadata.thing_classes))
    }

    demo_classes = demo_thing_classes + demo_stuff_classes

    return demo_classes, demo_metadata


def inference(image_path, vocab, label_list):
    """Run FC-CLIP panoptic segmentation on image_path with the selected
    vocabularies and return the visualization as a PIL image."""
    logger.info("building class names")
    vocab = vocab.replace(", ", ",").replace("; ", ";")
    demo_classes, demo_metadata = build_demo_classes_and_metadata(vocab, label_list)
    predictor.set_metadata(demo_metadata)

    im = cv2.imread(image_path)
    outputs = predictor(im)
    v = OpenVocabVisualizer(im[:, :, ::-1], demo_metadata, instance_mode=ColorMode.IMAGE)
    panoptic_result = v.draw_panoptic_seg(outputs["panoptic_seg"][0].to("cpu"), outputs["panoptic_seg"][1]).get_image()
    return Image.fromarray(np.uint8(panoptic_result)).convert('RGB')
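
# Minimal usage sketch of inference() outside the Gradio UI (for illustration;
# the image path and vocabulary are taken from the `examples` list above):
#
#   result = inference(
#       "demo/examples/coco.jpg",
#       "black pickup truck,pickup truck;blue sky,sky",
#       ["COCO (133 categories)"],
#   )
#   result.save("panoptic_result.jpg")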

    
with gr.Blocks(title=title) as demo:
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
    gr.Markdown(description)
    input_components = []
    output_components = []

    with gr.Row():
        output_image_gr = gr.Image(label="Panoptic Segmentation", type="pil")
        output_components.append(output_image_gr)

    with gr.Row().style(equal_height=True, mobile_collapse=True):
        with gr.Column(scale=3, variant="panel") as input_component_column:
            input_image_gr = gr.Image(type="filepath")
            extra_vocab_gr = gr.Textbox(value="", label="Extra Vocabulary")
            category_list_gr = gr.CheckboxGroup(
                choices=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
                value=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
                label="Category to use",
            )
            input_components.extend([input_image_gr, extra_vocab_gr, category_list_gr])

        with gr.Column(scale=2):
            examples_handler = gr.Examples(
                examples=examples,
                inputs=[c for c in input_components if not isinstance(c, gr.State)],
                outputs=[c for c in output_components if not isinstance(c, gr.State)],
                fn=inference,
                cache_examples=torch.cuda.is_available(),
                examples_per_page=5,
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")

    gr.Markdown(article)

    submit_btn.click(
        inference,
        input_components,
        output_components,
        api_name="predict",
        scroll_to_output=True,
    )

    clear_btn.click(
        None,
        [],
        (input_components + output_components + [input_component_column]),
        # The JS callback must return exactly one value per output component:
        # cleared values for the inputs/outputs, plus an update that keeps the
        # input column visible.
        _js=f"""() => {json.dumps(
                    [component.cleared_value if hasattr(component, "cleared_value") else None
                     for component in input_components + output_components]
                    + [gr.Column.update(visible=True)]
                )}
                """,
    )

demo.launch()

