File size: 6,318 Bytes
e972e1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f368476
e972e1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1def87
e972e1f
 
 
 
 
 
4788a4c
 
 
 
 
e972e1f
 
c8d5211
e972e1f
 
 
 
 
 
 
0274e1e
e972e1f
 
d1def87
e972e1f
 
6f0521e
c8d5211
e972e1f
 
 
 
 
d1def87
e972e1f
 
 
 
 
 
 
f368476
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou (xueyan@cs.wisc.edu), Jianwei Yang (jianwyan@microsoft.com)
# --------------------------------------------------------

import os
os.system("python -m pip install git+https://github.com/MaureenZOU/detectron2-xyz.git")

import gradio as gr
import torch
import argparse

from xdecoder.BaseModel import BaseModel
from xdecoder import build_model
from utils.distributed import init_distributed
from utils.arguments import load_opt_from_config_files

from tasks import *

def parse_option():
    """Parse the demo's command-line options from ``sys.argv``.

    Only ``--conf_files`` is accepted. The parser is built with
    ``add_help=False``, so ``-h``/``--help`` are not recognised.

    Returns:
        argparse.Namespace whose ``conf_files`` attribute is the config path
        (default: ``configs/xdecoder/svlp_focalt_lang.yaml``).
    """
    cli = argparse.ArgumentParser('X-Decoder All-in-One Demo', add_help=False)
    cli.add_argument(
        '--conf_files',
        default="configs/xdecoder/svlp_focalt_lang.yaml",
        metavar="FILE",
        help='path to config file',
    )
    return cli.parse_args()

# --------------------------------------------------------
# Build args
# --------------------------------------------------------
args = parse_option()
opt = load_opt_from_config_files(args.conf_files)
opt = init_distributed(opt)

# META DATA
# Pretrained checkpoints, kept in the current working directory.
pretrained_pth_last = "xdecoder_focalt_last.pt"
pretrained_pth_novg = "xdecoder_focalt_last_novg.pt"


def _download_if_missing(path, url):
    """Download ``url`` to ``path`` with wget unless ``path`` already exists.

    The file is first fetched into ``<path>.part`` and only renamed into place
    after wget succeeds. An interrupted download therefore never leaves a
    truncated file under ``path`` that the existence check would accept on
    the next start.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
    """
    import subprocess

    if os.path.exists(path):
        return
    partial = path + ".part"
    # List form: no shell involved; check=True surfaces download failures here
    # instead of as a confusing error when the checkpoint is loaded later.
    subprocess.run(["wget", "-O", partial, url], check=True)
    os.replace(partial, path)


_download_if_missing(
    pretrained_pth_last,
    "https://projects4jw.blob.core.windows.net/x-decoder/release/xdecoder_focalt_last.pt",
)
_download_if_missing(
    pretrained_pth_novg,
    "https://projects4jw.blob.core.windows.net/x-decoder/release/xdecoder_focalt_last_novg.pt",
)


# --------------------------------------------------------
# Build model
# --------------------------------------------------------
def _load_xdecoder(checkpoint_path):
    """Build an X-Decoder from ``opt``, load ``checkpoint_path``, switch to eval mode and move to CUDA."""
    return BaseModel(opt, build_model(opt)).from_pretrained(checkpoint_path).eval().cuda()


# Used by the segmentation/referring tasks in `inference`.
model_last = _load_xdecoder(pretrained_pth_last)
# Used by captioning and text retrieval in `inference`.
model_cap = _load_xdecoder(pretrained_pth_novg)

# Dummy text-embedding pass on both models before serving requests
# (presumably initialises the language encoder's text embeddings -- confirm).
with torch.no_grad():
    for _warm_model in (model_last, model_cap):
        _warm_model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(["background", "background"], is_eval=True)

# --------------------------------------------------------
# Inference
# --------------------------------------------------------
@torch.no_grad()
def inference(image, task, *args, **kwargs):
    """Run the selected X-Decoder demo task on ``image``.

    Args:
        image: PIL image from the Gradio image input; converted to RGB first.
        task: Task name chosen in the "Task" radio button.
        *args, **kwargs: Remaining Gradio inputs (the text boxes), forwarded
            unchanged to the task handler.

    Returns:
        Whatever the matching task handler returns (rendered by the Gradio outputs).

    Raises:
        ValueError: if no image was supplied or ``task`` is not a known task.
    """
    if image is None:
        raise ValueError("Please upload an image.")

    # task name -> (handler, model argument). Some handlers take two models,
    # and the order of the pair differs per handler, so it is kept as-is.
    handlers = {
        'Referring Editing': (referring_inpainting, model_last),
        'Referring Segmentation': (referring_segmentation, model_last),
        'Open Vocabulary Semantic Segmentation': (open_semseg, model_last),
        'Open Vocabulary Panoptic Segmentation': (open_panoseg, model_last),
        'Open Vocabulary Instance Segmentation': (open_instseg, model_last),
        'Image Captioning': (image_captioning, model_cap),
        'Referring Captioning (Beta)': (referring_captioning, [model_last, model_cap]),
        'Text Retrieval': (text_retrieval, model_cap),
        'Image/Region Retrieval': (region_retrieval, [model_cap, model_last]),
    }
    if task not in handlers:
        # Previously this fell through and returned None, which Gradio cannot
        # map onto the interface's three outputs.
        raise ValueError(f"Unknown task: {task!r}")
    handler, models = handlers[task]

    image = image.convert("RGB")
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        return handler(models, image, *args, **kwargs)

# --------------------------------------------------------
# Launch app
# --------------------------------------------------------
title = "X-Decoder All-in-One Demo"
description = """<p style='text-align: center'> <a href='https://x-decoder-vl.github.io/' target='_blank'>Project Page</a> | <a href='https://arxiv.org/pdf/2212.11270.pdf' target='_blank'>Paper</a> | <a href='https://github.com/microsoft/X-Decoder' target='_blank'>Github Repo</a> | <a href='https://youtu.be/wYp6vmyolqE' target='_blank'>Video</a> </p>
<p>Skip the queue by duplicating this space and upgrading to GPU in settings</p>
<a href="https://huggingface.co/spaces/xdecoder/Demo?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
"""

article = "The Demo is Run on X-Decoder (Focal-T)."

# Task names offered in the UI; each must exactly match a task handled by `inference`.
TASK_CHOICES = [
    "Referring Segmentation",
    "Referring Editing",
    'Open Vocabulary Semantic Segmentation',
    'Open Vocabulary Instance Segmentation',
    "Open Vocabulary Panoptic Segmentation",
    "Image Captioning",
    "Text Retrieval",
    "Image/Region Retrieval",
    "Referring Captioning (Beta)",
]
# Must be an exact entry of TASK_CHOICES; a value outside the list is not a
# selectable option and is not a task `inference` handles.
DEFAULT_TASK = "Open Vocabulary Semantic Segmentation"

inputs = [
    gr.inputs.Image(type='pil'),
    gr.inputs.Radio(choices=TASK_CHOICES, type="value", default=DEFAULT_TASK, label="Task"),
    gr.Textbox(label="xdecoder_text"),
    gr.Textbox(label="inpainting_text"),
    gr.Textbox(label="task_description"),
]
gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=[
        gr.outputs.Image(
        type="pil",
        label="segmentation results"),
        gr.Textbox(label="text results"),
        gr.outputs.Image(
        type="pil",
        label="editing results"),
    ],
    examples=[
    ["./images/fruit.jpg", "Referring Segmentation", "The larger watermelon.,The front white flower.,White tea pot.,Flower bunch.,white vase.,The peach on the left.,The brown knife.,The handkerchief.", '', 'Format: s,s,s'],
    ["./images/apples.jpg", "Referring Editing", "the green apple", 'a red apple', 'x-decoder + ldm (inference takes ~20s),  use inpainting_text "clean and empty scene" for image inpainting'],
    ["./images/animals.png", "Open Vocabulary Semantic Segmentation", "zebra,antelope,giraffe,ostrich,sky,water,grass,sand,tree", '', 'Format: x,x,x'],
    ["./images/street.jpg", "Open Vocabulary Panoptic Segmentation", "stuff:building,sky,street,tree,rock,sidewalk;thing:car,person,traffic light", '', 'Format: stuff:x,x,x;thing:y,y,y'],
    ["./images/owls.jpeg", "Open Vocabulary Instance Segmentation", "owl", '', 'Format: y,y,y'],
    ["./images/mountain.jpeg", "Image Captioning", "", '', ''],
    ["./images/rose.webp", "Text Retrieval", "lily,rose,peoney,tulip", '', 'Format: s,s,s'],
    ["./images/region_retrieval.png", "Image/Region Retrieval", "The tangerine on the plate.", '', 'Please describe the object in a detailed way (80 images in the pool).'],
    ["./images/landscape.jpg", "Referring Captioning (Beta)", "cloud", '', 'Please fill in a noun/noun phrase. (may start with a/the)'],
    ],
    title=title,
    description=description,
    article=article,
    allow_flagging='never',
    cache_examples=True,
).launch()