# Borrows some code from https://huggingface.co/spaces/xvjiarui/ODISE/tree/main
"""Gradio demo for TCL (Text-grounded Contrastive Learning).

On start-up the script:
  1. downloads the TCL checkpoint (``tcl.pth``) if it is not present,
  2. installs mmcv / mmsegmentation and detectron2 if they cannot be found,
  3. builds the demo model and moves it to CUDA when available,
and then serves a Gradio UI that segments an uploaded image according to a
``;``-separated list of text queries.
"""
import importlib
import json
import os
import subprocess
import sys
import urllib.request
from contextlib import ExitStack
from importlib.util import find_spec

CKPT_PATH = "tcl.pth"
CKPT_URL = "https://github.com/kakaobrain/tcl/releases/download/v1.0.0/tcl.pth"


def _run(cmd):
    """Run *cmd* (an argv list, no shell) and raise if it exits non-zero."""
    subprocess.run(cmd, check=True)


def _download(url, dst):
    """Download *url* to *dst*, publishing the file only once it is complete.

    The data is written to a temporary ``.part`` file and renamed into place
    after the transfer finishes. An interrupted download therefore never
    leaves a truncated checkpoint behind that the ``os.path.exists`` check
    below would accept on every later start-up.
    """
    tmp = dst + ".part"
    try:
        urllib.request.urlretrieve(url, tmp)
        os.replace(tmp, dst)
    finally:
        if os.path.exists(tmp):
            os.remove(tmp)


print("Prepare demo ...")
if not os.path.exists(CKPT_PATH):
    print("Download TCL checkpoint ...")
    _download(CKPT_URL, CKPT_PATH)

if not (find_spec("mmcv") and find_spec("mmseg")):
    print("Install mmcv & mmseg ...")
    _run(["mim", "install", "mmcv-full==1.6.2", "mmsegmentation==0.27.0"])
if not find_spec("detectron2"):
    print("Install detectron ...")
    # Use this interpreter's pip so the package lands where we import from.
    _run([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/detectron2.git",
    ])
# Packages installed after interpreter start-up can be hidden by the import
# system's cached directory listings; drop those caches before importing them.
importlib.invalidate_caches()

sys.path.insert(0, "./tcl/")
print(" -- done.")

# These imports must follow the dependency bootstrap above.
import gradio as gr  # noqa: E402
import torch  # noqa: E402
from detectron2.evaluation import inference_context  # noqa: E402
from predictor import build_demo_model  # noqa: E402

model = build_demo_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")
model.to(device)

title = "TCL: Text-grounded Contrastive Learning"
title2 = "for Unsupervised Open-world Semantic Segmentation"
# Plain-text title for the browser tab; the on-page heading below uses an
# HTML line break, which a tab title would display literally.
page_title = f"{title} {title2}"
title = title + "<br>" + title2

description_head = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2212.00785' target='_blank'>Paper</a> |
<a href='https://github.com/kakaobrain/tcl' target='_blank'>Code</a>
</p>
"""

description_body = """
Gradio Demo for "Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs".

Explore TCL's capability to perform open-world semantic segmentation **without any mask annotations**. Choose from provided examples or upload your own image. Use the query format `bg; class1; class2; ...`, with `;` as the separator, and the `bg` background query being optional (as in the third example).

This demo highlights the strengths and limitations of unsupervised open-world segmentation methods. Although TCL can handle arbitrary concepts, accurately capturing object boundaries without mask annotation remains a challenge.
"""
if device.type == "cpu":
    description_body += "\nThis demo is running on a free CPU device. Inference times may take around 5-10 seconds."
description = description_head + description_body

article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2212.00785' target='_blank'>Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs</a>
| <a href='https://github.com/kakaobrain/tcl' target='_blank'>Github Repo</a>
</p>
"""

voc_examples = [
    ["examples/voc_59.jpg", "bg; cat; dog"],
    ["examples/voc_97.jpg", "bg; car"],
    ["examples/voc_266.jpg", "bg; dog"],
    ["examples/voc_294.jpg", "bg; bird"],
    ["examples/voc_864.jpg", "bg; cat"],
    ["examples/voc_1029.jpg", "bg; bus"],
]

examples = [
    ["examples/dogs.jpg", "bg; corgi; shepherd"],
    ["examples/dogs.jpg", "bg; dog"],
    ["examples/dogs.jpg", "corgi; shepherd; lawn, trees, and fallen leaves"],
    ["examples/banana.jpg", "bg; banana"],
    ["examples/banana.jpg", "bg; red banana; green banana; yellow banana"],
    ["examples/frodo_sam_gollum.jpg", "bg; frodo; gollum; samwise"],
    ["examples/frodo_sam_gollum.jpg", "bg; rocks; monster; boys with cape"],
    ["examples/mb_mj.jpg", "bg; marlon brando; michael jackson"],
]
examples = examples + voc_examples


def inference(img, query):
    """Segment *img* into the text classes listed in *query*.

    Args:
        img: Input image from the ``type="pil"`` image input (``None`` when
            the user has not uploaded anything).
        query: ``;``-separated class names, e.g. ``"bg; cat; dog"``.

    Returns:
        The result of ``model.forward_vis``, displayed by the ``type="pil"``
        output image.

    Raises:
        gr.Error: If no image is given or the query contains no class names.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    # Drop blank entries so a trailing ';' or a stray ';;' does not become
    # an empty-string query.
    classes = [name.strip() for name in (query or "").split(";") if name.strip()]
    if not classes:
        raise gr.Error("Please enter at least one class name in the query.")

    with ExitStack() as stack:
        stack.enter_context(inference_context(model))
        stack.enter_context(torch.no_grad())
        if device.type == "cuda":
            stack.enter_context(torch.autocast("cuda"))
        return model.forward_vis(img, classes)


theme = gr.themes.Soft(text_size=gr.themes.sizes.text_md, primary_hue="teal")
with gr.Blocks(title=page_title, theme=theme) as demo:
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
    gr.Markdown(description)
    input_components = []
    output_components = []

    with gr.Row():
        with gr.Column(scale=4, variant="panel"):
            output_image_gr = gr.outputs.Image(label="Segmentation", type="pil").style(height=300)
            output_components.append(output_image_gr)

            with gr.Row():
                input_gr = gr.inputs.Image(type="pil")
                query_gr = gr.inputs.Textbox(default="", label="Query")
                input_components.extend([input_gr, query_gr])

            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")

        with gr.Column(scale=2):
            examples_handler = gr.Examples(
                examples=examples,
                inputs=input_components,
                outputs=output_components,
                fn=inference,
                examples_per_page=7,
            )

    gr.Markdown(article)

    submit_btn.click(
        inference,
        input_components,
        output_components,
        scroll_to_output=True,
    )

    # The client-side clear callback must return exactly one value per
    # output component (here: image input, query textbox, output image).
    all_components = input_components + output_components
    cleared_values = [getattr(component, "cleared_value", None) for component in all_components]
    clear_btn.click(
        None,
        [],
        all_components,
        _js=f"() => {json.dumps(cleared_values)}",
    )

demo.launch()
# For local serving: demo.launch(server_name="0.0.0.0", server_port=9718)