|
import os |
|
import sys |
|
import cv2 |
|
import json |
|
import argparse |
|
import numpy as np |
|
from tqdm import tqdm |
|
from PIL import Image |
|
import gradio as gr |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch.utils import data |
|
import torchvision.transforms as transform |
|
|
|
|
|
# Make the vendored CLIP package importable, then install the runtime
# dependencies that the hosting environment does not ship with. This must run
# before the detectron2 / mask2former imports further down the file.
sys.path.insert(0, "third_party/CLIP/")


def _pip_install(*pip_args):
    """Run ``pip install <pip_args>`` using the interpreter running this script.

    Invoking ``sys.executable -m pip`` guarantees packages land in the same
    environment that will import them, unlike a bare ``pip``/``pip3`` found on
    PATH. The command is passed as an argument list, so no shell is involved.
    A failed install is reported on stderr but does not abort: the package may
    already be installed, and any real problem surfaces at import time.
    """
    import subprocess

    cmd = [sys.executable, "-m", "pip", "install", *pip_args]
    returncode = subprocess.run(cmd).returncode
    if returncode != 0:
        print(
            f"warning: {' '.join(cmd)} exited with status {returncode}",
            file=sys.stderr,
        )


# Editable (-e), upgraded (-U) install of the vendored CLIP checkout.
_pip_install("-U", "-e", "third_party/CLIP/")
_pip_install("git+https://github.com/facebookresearch/detectron2.git")
|
|
|
|
|
|
|
from detectron2.config import CfgNode as CN |
|
from detectron2.engine import DefaultPredictor |
|
from detectron2.config import get_cfg |
|
from detectron2.utils.visualizer import Visualizer, ColorMode |
|
from detectron2.data import MetadataCatalog |
|
from detectron2.utils.file_io import PathManager |
|
from detectron2.utils.logger import setup_logger |
|
from detectron2.projects.deeplab import add_deeplab_config |
|
from detectron2.structures import Boxes, ImageList, Instances, BitMasks |
|
|
|
|
|
from mask2former import add_mask_former_config |
|
|
|
# Logging setup: call detectron2's setup_logger() once with its defaults, then
# once more with name="freeseg"; the logger returned by the second call is
# bound to `logger`.
# NOTE(review): `logger` is not referenced anywhere in the visible code.
setup_logger()
logger = setup_logger(name="freeseg")
|
|
|
|
|
class Predictor(DefaultPredictor):
    """DefaultPredictor subclass whose forward pass also hands class labels to the model.

    ``forward`` mirrors the single-image preprocessing of
    ``DefaultPredictor.__call__`` but calls ``self.model`` with an extra
    ``labels`` argument.
    """

    def forward(self, original_image, labels=None):
        """Preprocess one image, run the model on it and return its prediction.

        Args:
            original_image: H x W x C numpy image. It is reversed along the
                channel axis when ``self.input_format == "RGB"``; the caller in
                this file passes BGR (converted with cv2) -- confirm for other
                callers.
            labels: passed through unchanged as the second argument of
                ``self.model``.

        Returns:
            The first element of the list returned by ``self.model``.
        """
        with torch.no_grad():
            frame = original_image[:, :, ::-1] if self.input_format == "RGB" else original_image
            # Size before augmentation; presumably used by the model to map
            # predictions back to the input resolution.
            src_h, src_w = frame.shape[:2]
            resized = self.aug.get_transform(frame).apply_image(frame)
            # HWC array -> CHW float32 tensor.
            chw_tensor = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))
            batch = [{"image": chw_tensor, "height": src_h, "width": src_w}]
            return self.model(batch, labels)[0]
|
|
|
|
|
def create_predictor(
    task_names,
    *,
    config_file="configs/coco-stuff-164k-156/mask2former_R101c_alltask_bs32_60k.yaml",
    weights="checkpoints/model_demo.pth",
    device="cpu",
):
    """Build a FreeSeg ``Predictor`` configured for the requested tasks.

    Args:
        task_names: iterable of task names (e.g. "Semantic segmentation.");
            each is lower-cased and the list is stored in
            ``cfg.INPUT.TASK_NAME``.
        config_file: YAML config merged on top of the base config. Defaults to
            the demo's COCO-Stuff 164k (156-class) all-task config.
        weights: checkpoint path assigned to ``cfg.MODEL.WEIGHTS``.
        device: value for ``cfg.MODEL.DEVICE``; defaults to "cpu".

    Returns:
        A ``Predictor`` built from the assembled config. Semantic inference is
        always enabled (``MASK_FORMER.TEST.SEMANTIC_ON = True``).
    """
    cfg = get_cfg()
    # Extend the base config with the DeepLab and Mask2Former option groups
    # before merging the YAML file that sets them.
    add_deeplab_config(cfg)
    add_mask_former_config(cfg)
    cfg.merge_from_file(config_file)
    cfg.MODEL.WEIGHTS = weights
    cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = True
    cfg.MODEL.DEVICE = device
    cfg.INPUT.TASK_NAME = [task.lower() for task in task_names]
    return Predictor(cfg)


# FreeSeg Demo
|
|
|
# Text rendered by the Gradio UI (via gr.Markdown) and the page title.
title = "FreeSeg"

description = """
<p style='text-align: center'> <a href='https://freeseg.github.io/' target='_blank'>Project Page</a> | <a href='https://arxiv.org/abs/2303.17225' target='_blank'>Paper</a> | <a href='https://github.com/bytedance/FreeSeg' target='_blank'>Code</a> </p>

Gradio demo for FreeSeg: Unified, Universal and Open-Vocabulary Image Segmentation. \n
You may click on one of the examples or upload your own image. \n

"""

article = """
<p style='text-align: center'><a href='https://arxiv.org/abs/2303.17225' target='_blank'>FreeSeg: Unified, Universal and Open-Vocabulary Image Segmentation</a> | <a href='https://github.com/bytedance/FreeSeg' target='_blank'>Github Repo</a></p>
"""

# Example rows for gr.Examples. Each row follows the argument order of
# inference(image_path, labels, task_list): an image path, a comma-separated
# class-label string, and a list of task names.
examples = [
    [
        "examples/cat.jpg",
        "cat, grass, stone, other",
        ["Semantic segmentation."],
    ],
    [
        "examples/bus.jpg",
        "bus, person, road, building, tree, sky, other",
        ["Semantic segmentation.", "Instance segmentation.", "Panoptic segmentation."],
    ],
]
|
|
|
|
|
|
|
def _draw_semantic(image, metadata, outputs):
    """Overlay the semantic prediction on ``image`` and return it as a PIL image.

    ``image`` is the BGR array given to the predictor and is flipped to RGB for
    the Visualizer. ``outputs["sem_seg"]`` is reduced with argmax over dim 0
    (presumably the class axis) to obtain a per-pixel label map.
    """
    label_map = outputs["sem_seg"].argmax(0).to("cpu")
    visualizer = Visualizer(image[:, :, ::-1], metadata, scale=0.6, instance_mode=ColorMode.IMAGE)
    return Image.fromarray(visualizer.draw_sem_seg(label_map, alpha=0.6).get_image())


def _draw_panoptic(image, metadata, outputs):
    """Overlay the panoptic prediction (segment map + segments_info) on ``image``."""
    visualizer = Visualizer(image[:, :, ::-1], metadata, scale=0.6, instance_mode=ColorMode.IMAGE)
    panoptic_seg, segments_info = outputs["panoptic_seg"]
    drawn = visualizer.draw_panoptic_seg_predictions(panoptic_seg.cpu(), segments_info, alpha=0.6)
    return Image.fromarray(drawn.get_image())


def _draw_instance(image, metadata, outputs):
    """Overlay the instance predictions on ``image`` using segmentation colouring."""
    visualizer = Visualizer(
        image[:, :, ::-1], metadata, scale=0.6, instance_mode=ColorMode.SEGMENTATION
    )
    instances = outputs["instances"].to(torch.device("cpu"))
    return Image.fromarray(visualizer.draw_instance_predictions(predictions=instances).get_image())


def inference(image_path, labels, task_list):
    """Run FreeSeg on one image and render every requested segmentation task.

    Args:
        image_path: path to the input image file (the UI uses an Image input
            with type="filepath"). Any PIL-readable mode (RGBA, greyscale,
            palette, ...) is accepted and converted to RGB.
        labels: comma-separated class names, e.g. "cat, grass, other".
            Whitespace around each name is stripped and empty entries (such as
            one produced by a trailing comma) are dropped.
        task_list: selected task names, any of "Semantic segmentation.",
            "Instance segmentation." and "Panoptic segmentation.".

    Returns:
        A ``(semantic, instance, panoptic)`` tuple of PIL images; the entry for
        every task that was not requested is ``None``.

    Raises:
        ValueError: if ``labels`` contains no non-empty class name.
    """
    task_list = list(task_list or [])
    if not task_list:
        # Nothing to render, so skip building a predictor altogether.
        return None, None, None

    class_names = [name.strip() for name in (labels or "").split(",")]
    class_names = [name for name in class_names if name]
    if not class_names:
        raise ValueError("Please provide at least one class label (comma-separated).")

    # NOTE(review): a new predictor is built on every request; the model is
    # presumably reloaded from disk each time.
    predictor = create_predictor(task_list)

    # MetadataCatalog.get returns the same registered object on every call, so
    # these slice assignments overwrite its first len(class_names) names in
    # place; later entries keep whatever they held before.
    coco_metadata = MetadataCatalog.get("coco_2017_test_full_task")
    coco_metadata.stuff_classes[:len(class_names)] = class_names

    # Context manager closes the file; convert("RGB") guarantees a 3-channel
    # array for the RGB->BGR conversion the predictor input expects.
    with Image.open(image_path) as pil_image:
        rgb = np.array(pil_image.convert("RGB"))
    image = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    outputs = predictor.forward(image, class_names)

    want_semantic = "Semantic segmentation." in task_list
    want_instance = "Instance segmentation." in task_list
    want_panoptic = "Panoptic segmentation." in task_list

    if want_panoptic or want_instance:
        # Both the panoptic and the instance visualizations name "thing"
        # segments via thing_classes, so refresh it for either task.
        coco_metadata.thing_classes[:len(class_names)] = class_names

    semantic_result = _draw_semantic(image, coco_metadata, outputs) if want_semantic else None
    panoptic_result = _draw_panoptic(image, coco_metadata, outputs) if want_panoptic else None
    instance_result = _draw_instance(image, coco_metadata, outputs) if want_instance else None

    return semantic_result, instance_result, panoptic_result
|
|
|
|
|
|
|
# Gradio UI: an input panel with Submit/Clear buttons, a row holding the three
# output images next to the examples, and the event wiring.
with gr.Blocks(title=title) as demo:
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
    gr.Markdown(description)
    input_components = []
    output_components = []

    with gr.Row().style(equal_height=True, mobile_collapse=True):
        with gr.Column(scale=3, variant="panel") as input_component_column:
            input_image_gr = gr.inputs.Image(type="filepath")
            labels_gr = gr.inputs.Textbox(default="", label="Class labels")
            task_list_gr = gr.inputs.CheckboxGroup(
                choices=["Semantic segmentation.", "Instance segmentation.", "Panoptic segmentation."],
                default=["Semantic segmentation."],
                label="Task names",
            )
            # Order matches inference(image_path, labels, task_list).
            input_components.extend([input_image_gr, labels_gr, task_list_gr])

            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear")

    with gr.Row():
        with gr.Row(scale=3, variant="panel") as output_component_row:
            # Order matches inference's return value: (semantic, instance, panoptic).
            output_image_sem_gr = gr.outputs.Image(label="Semantic segmentation", type="pil")
            output_components.append(output_image_sem_gr)

            output_image_ins_gr = gr.outputs.Image(label="Instance segmentation", type="pil")
            output_components.append(output_image_ins_gr)

            output_image_pan_gr = gr.outputs.Image(label="Panoptic segmentation", type="pil")
            output_components.append(output_image_pan_gr)

        with gr.Column(scale=2):
            examples_handler = gr.Examples(
                examples=examples,
                inputs=[c for c in input_components if not isinstance(c, gr.State)],
                outputs=[c for c in output_components if not isinstance(c, gr.State)],
                fn=inference,
                # Example outputs are only precomputed when a GPU is available.
                cache_examples=torch.cuda.is_available(),
                examples_per_page=5,
            )

    gr.Markdown(article)

    submit_btn.click(
        inference,
        input_components,
        output_components,
        api_name="predict",
        scroll_to_output=True,
    )

    # Clear runs purely client-side (fn=None): the _js callback must return
    # exactly one value per output component, in order -- the cleared value of
    # each input/output component, then a visibility update for
    # input_component_column.
    clear_outputs = input_components + output_components + [input_component_column]
    clear_values = [
        getattr(component, "cleared_value", None)
        for component in input_components + output_components
    ] + [gr.Column.update(visible=True)]
    clear_btn.click(
        None,
        [],
        clear_outputs,
        _js=f"""() => {json.dumps(clear_values)}
""",
    )

demo.launch()
|
|
|
|