# FreeSeg/app.py
import os
import sys
import cv2
import json
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image
import gradio as gr
import torch
import torch.nn.functional as F
from torch.utils import data
import torchvision.transforms as transform
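# Make the bundled CLIP fork importable and install runtime dependencies
# (CLIP and detectron2) when the demo starts.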
sys.path.insert(0, "third_party/CLIP/")
os.system(f"pip3 install -Ue third_party/CLIP/")
os.system(f"pip install git+https://github.com/facebookresearch/detectron2.git")
# import some common detectron2 utilities
from detectron2.config import CfgNode as CN
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.structures import Boxes, ImageList, Instances, BitMasks
# import Mask2Former project
from mask2former import add_mask_former_config
setup_logger()
logger = setup_logger(name="freeseg")
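# DefaultPredictor subclass that also forwards the open-vocabulary class labels
# to the model together with the preprocessed image.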
class Predictor(DefaultPredictor):
    def forward(self, original_image, labels=None):
        with torch.no_grad():
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]
            height, width = original_image.shape[:2]
            image = self.aug.get_transform(original_image).apply_image(original_image)
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
            inputs = {"image": image, "height": height, "width": width}
            predictions = self.model([inputs], labels)[0]
            return predictions
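# Build a CPU predictor from the FreeSeg config and demo checkpoint,
# with the requested task names written into cfg.INPUT.TASK_NAME.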
def create_predictor(task_names):
    cfg = get_cfg()
    add_deeplab_config(cfg)
    add_mask_former_config(cfg)
    cfg.merge_from_file("configs/coco-stuff-164k-156/mask2former_R101c_alltask_bs32_60k.yaml")
    cfg.MODEL.WEIGHTS = 'checkpoints/model_demo.pth'
    cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = True
    cfg.MODEL.DEVICE = "cpu"
    cfg.INPUT.TASK_NAME = [task.lower() for task in task_names]
    predictor = Predictor(cfg)
    return predictor
"""
# FreeSeg Demo
"""
title = "FreeSeg"
description = """
<p style='text-align: center'> <a href='https://freeseg.github.io/' target='_blank'>Project Page</a> | <a href='https://arxiv.org/abs/2303.17225' target='_blank'>Paper</a> | <a href='https://github.com/bytedance/FreeSeg' target='_blank'>Code</a> </p>
Gradio demo for FreeSeg: Unified, Universal and Open-Vocabulary Image Segmentation. \n
You may click one of the examples or upload your own image. \n
""" # noqa
article = """
<p style='text-align: center'><a href='https://arxiv.org/abs/2303.17225' target='_blank'>FreeSeg: Unified, Universal and Open-Vocabulary Image Segmentation</a> | <a href='https://github.com/bytedance/FreeSeg' target='_blank'>Github Repo</a></p>
""" # noqa
examples = [
    [
        "examples/cat.jpg",
        "cat, grass, stone, other",
        ["Semantic segmentation."],
    ],
    [
        "examples/bus.jpg",
        "bus, person, road, building, tree, sky, other",
        ["Semantic segmentation.", "Instance segmentation.", "Panoptic segmentation."],
    ],
]
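# Run the selected segmentation tasks on one image and return a PIL visualization
# (or None) for each of the semantic, instance, and panoptic outputs.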
def inference(image_path, labels, task_list):
    labels = [lbl.strip() for lbl in labels.split(",")]
    predictor = create_predictor(task_list)
    coco_metadata = MetadataCatalog.get("coco_2017_test_full_task")
    coco_metadata.stuff_classes[:len(labels)] = labels
    image = Image.open(image_path)
    # image = np.array(image)
    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    outputs = predictor.forward(image, labels)
    results = {"sem_seg": None, "ins_seg": None, "pan_seg": None}
    if "Semantic segmentation." in task_list:
        sem_seg_out = outputs["sem_seg"].argmax(0).to("cpu")
        image_back = np.zeros_like(image)
        v = Visualizer(image[:, :, ::-1], coco_metadata, scale=0.6, instance_mode=ColorMode.IMAGE)
        semantic_result = v.draw_sem_seg(sem_seg_out, alpha=0.6).get_image()
        semantic_result = Image.fromarray(semantic_result)
        results["sem_seg"] = semantic_result
    if "Panoptic segmentation." in task_list:
        coco_metadata.thing_classes[:len(labels)] = labels
        panvis = Visualizer(
            image[:, :, ::-1],
            coco_metadata,
            scale=0.6,
            instance_mode=ColorMode.IMAGE
        )
        panoptic_seg, segments_info = outputs["panoptic_seg"]
        panvis_output = panvis.draw_panoptic_seg_predictions(
            panoptic_seg.cpu(), segments_info, alpha=0.6
        )
        panvis_output = Image.fromarray(panvis_output.get_image())
        results["pan_seg"] = panvis_output
    if "Instance segmentation." in task_list:
        insvis = Visualizer(
            image[:, :, ::-1],
            coco_metadata,
            scale=0.6,
            instance_mode=ColorMode.SEGMENTATION
        )
        instances = outputs["instances"].to(torch.device("cpu"))
        insvis_output = insvis.draw_instance_predictions(predictions=instances)
        insvis_output = Image.fromarray(insvis_output.get_image())
        results["ins_seg"] = insvis_output
    return results["sem_seg"], results["ins_seg"], results["pan_seg"]
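# Gradio Blocks UI: image, label, and task inputs; one output image per task;
# clickable examples; and Submit/Clear buttons wired to the handlers below.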
with gr.Blocks(title=title) as demo:
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
gr.Markdown(description)
input_components = []
output_components = []
with gr.Row().style(equal_height=True, mobile_collapse=True):
with gr.Column(scale=3, variant="panel") as input_component_column:
input_image_gr = gr.inputs.Image(type="filepath")
labels_gr = gr.inputs.Textbox(default="", label="Class labels")
task_list_gr = gr.inputs.CheckboxGroup(
choices=["Semantic segmentation.", "Instance segmentation.", "Panoptic segmentation."],
default=["Semantic segmentation."],
label="Task names",
)
input_components.extend([input_image_gr, labels_gr, task_list_gr])
with gr.Row():
submit_btn = gr.Button("Submit", variant="primary")
clear_btn = gr.Button("Clear")
with gr.Row():
with gr.Row(scale=3, variant="panel") as output_component_row:
output_image_sem_gr = gr.outputs.Image(label="Semantic segmentation", type="pil")
output_components.append(output_image_sem_gr)
output_image_ins_gr = gr.outputs.Image(label="Instance segmentation", type="pil")
output_components.append(output_image_ins_gr)
output_image_pan_gr = gr.outputs.Image(label="Panoptic segmentation", type="pil")
output_components.append(output_image_pan_gr)
with gr.Column(scale=2):
examples_handler = gr.Examples(
examples=examples,
inputs=[c for c in input_components if not isinstance(c, gr.State)],
outputs=[c for c in output_components if not isinstance(c, gr.State)],
fn=inference,
cache_examples=torch.cuda.is_available(),
examples_per_page=5,
)
gr.Markdown(article)
    submit_btn.click(
        inference,
        input_components,
        output_components,
        api_name="predict",
        scroll_to_output=True,
    )
    clear_btn.click(
        None,
        [],
        (input_components + output_components + [input_component_column]),
        # Reset every input/output to its cleared value and keep the input column visible.
        _js=f"""() => {json.dumps(
            [component.cleared_value if hasattr(component, "cleared_value") else None
             for component in input_components + output_components]
            + [gr.Column.update(visible=True)]
        )}""",
    )
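# Start the Gradio server.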
demo.launch()