# borrow some code from https://huggingface.co/spaces/xvjiarui/ODISE/tree/main
import os
import sys
from importlib.util import find_spec
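# Lazy setup: on first launch (e.g., on a fresh Hugging Face Space), fetch the
# checkpoint and install the heavy dependencies that are not pre-installed.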
print("Prepare demo ...")
if not os.path.exists("tcl.pth"):
    print("Download TCL checkpoint ...")
    os.system("wget -q https://github.com/kakaobrain/tcl/releases/download/v1.0.0/tcl.pth")
if not (find_spec("mmcv") and find_spec("mmseg")):
    print("Install mmcv & mmseg ...")
    os.system("mim install mmcv-full==1.6.2 mmsegmentation==0.27.0")
if not find_spec("detectron2"):
    print("Install detectron2 ...")
    os.system("pip install git+https://github.com/facebookresearch/detectron2.git")
sys.path.insert(0, "./tcl/")
print(" -- done.")
import json
from contextlib import ExitStack
import gradio as gr
import torch
from detectron2.evaluation import inference_context
from predictor import build_demo_model
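# Build the demo model and move it to GPU when available; fall back to CPU.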
model = build_demo_model()
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"device: {device}")
model.to(device)
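# Text shown on the demo page: title, description, and footer article.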
title = "TCL: Text-grounded Contrastive Learning"
title2 = "for Unsupervised Open-world Semantic Segmentation"
title = title + "<br/>" + title2
description_head = """
<p style='text-align: center'> <a href='https://arxiv.org/abs/2212.00785' target='_blank'>Paper</a> | <a href='https://github.com/kakaobrain/tcl' target='_blank'>Code</a> </p>
"""
description_body = f"""
Gradio Demo for "Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs".
Explore TCL's capability to perform open-world semantic segmentation **without any mask annotations**. Choose from the provided examples or upload your own image. Use the query format `bg; class1; class2; ...`, where `;` separates classes and the `bg` background query is optional (as in the third example).
This demo highlights the strengths and limitations of unsupervised open-world segmentation methods. Although TCL can handle arbitrary concepts, accurately capturing object boundaries without mask annotation remains a challenge.
"""
if device.type == "cpu":
    description_body += "\nThis demo is running on a free CPU device. Inference may take around 5-10 seconds."
description = description_head + description_body
article = """
<p style='text-align: center'><a href='https://arxiv.org/abs/2212.00785' target='_blank'>Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs</a> | <a href='https://github.com/kakaobrain/tcl' target='_blank'>Github Repo</a></p>
"""
voc_examples = [
["examples/voc_59.jpg", "bg; cat; dog"],
["examples/voc_97.jpg", "bg; car"],
["examples/voc_266.jpg", "bg; dog"],
["examples/voc_294.jpg", "bg; bird"],
["examples/voc_864.jpg", "bg; cat"],
["examples/voc_1029.jpg", "bg; bus"],
]
examples = [
[
"examples/dogs.jpg",
"bg; corgi; shepherd",
],
[
"examples/dogs.jpg",
"bg; dog",
],
[
"examples/dogs.jpg",
"corgi; shepherd; lawn, trees, and fallen leaves",
],
[
"examples/banana.jpg",
"bg; banana",
],
[
"examples/banana.jpg",
"bg; red banana; green banana; yellow banana",
],
[
"examples/frodo_sam_gollum.jpg",
"bg; frodo; gollum; samwise",
],
[
"examples/frodo_sam_gollum.jpg",
"bg; rocks; monster; boys with cape"
],
[
"examples/mb_mj.jpg",
"bg; marlon brando; michael jackson",
],
]
examples = examples + voc_examples
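# Run TCL on one image. `query` is the raw textbox string; it is split on ';'
# into individual class names before being passed to the model.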
def inference(img, query):
    query = query.split(";")
    query = [v.strip() for v in query]
    with ExitStack() as stack:
        stack.enter_context(inference_context(model))
        stack.enter_context(torch.no_grad())
        if device.type == "cuda":
            # mixed-precision inference is faster on GPU
            stack.enter_context(torch.autocast("cuda"))
        visualized_output = model.forward_vis(img, query)
    return visualized_output
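# Assemble the Gradio Blocks UI: output panel and inputs in the left column,
# the example gallery in the right column.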
theme = gr.themes.Soft(text_size=gr.themes.sizes.text_md, primary_hue="teal")
with gr.Blocks(title=title, theme=theme) as demo:
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 0rem'>" + title + "</h1>")
    # gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title2 + "</h1>")
    gr.Markdown(description)

    input_components = []
    output_components = []

    with gr.Row():
        with gr.Column(scale=4, variant="panel"):
            output_image_gr = gr.outputs.Image(label="Segmentation", type="pil").style(height=300)
            output_components.append(output_image_gr)

            with gr.Row():
                input_gr = gr.inputs.Image(type="pil")
                query_gr = gr.inputs.Textbox(default="", label="Query")
                input_components.extend([input_gr, query_gr])

            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")

        inputs = [c for c in input_components if not isinstance(c, gr.State)]
        outputs = [c for c in output_components if not isinstance(c, gr.State)]

        with gr.Column(scale=2):
            examples_handler = gr.Examples(
                examples=examples,
                inputs=inputs,
                outputs=outputs,
                fn=inference,
                # cache_examples=True,
                examples_per_page=7,
            )
    gr.Markdown(article)

    submit_btn.click(
        inference,
        input_components,
        output_components,
        scroll_to_output=True,
    )
    # Reset every component to its cleared value via a client-side JS callback.
    # The extra gr.Column.update entries inherited from the ODISE demo are
    # dropped here so the returned values match the output components.
    clear_btn.click(
        None,
        [],
        input_components + output_components,
        _js=f"""() => {json.dumps(
            [component.cleared_value if hasattr(component, "cleared_value") else None
             for component in input_components + output_components]
        )}""",
    )
demo.launch()
# demo.launch(server_name="0.0.0.0", server_port=9718)