SAM-CAT-Seg / app.py
seokju cho
initial commit
f8f62f3
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
import argparse
import glob
import multiprocessing as mp
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = ""
def _pip_install(requirement):
    """Best-effort install of *requirement* into the running interpreter.

    Args:
        requirement: Anything ``pip install`` accepts, e.g. a ``git+https``
            URL.

    A failed install is reported on stderr but does not raise, so start-up
    continues exactly as far as it did before (a later import of the missing
    package will still fail with ``ModuleNotFoundError``).
    """
    import importlib
    import subprocess
    import sys

    # `python -m pip` guarantees the package lands in the environment of the
    # interpreter executing this script, not whichever `pip` is first on PATH.
    status = subprocess.call([sys.executable, "-m", "pip", "install", requirement])
    if status != 0:
        print(
            f"warning: 'pip install {requirement}' exited with status {status}",
            file=sys.stderr,
        )
    # Let the import system notice packages that appeared after start-up.
    importlib.invalidate_caches()


# Install the heavy optional dependencies on first run if they are missing.
try:
    import detectron2
except ModuleNotFoundError:
    _pip_install('git+https://github.com/facebookresearch/detectron2.git')

try:
    import segment_anything
except ModuleNotFoundError:
    _pip_install('git+https://github.com/facebookresearch/segment-anything.git')
# fmt: off
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
# fmt: on
import tempfile
import time
import warnings
import cv2
import numpy as np
import tqdm
from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.logger import setup_logger
from cat_seg import add_cat_seg_config
from demo.predictor import VisualizationDemo
import gradio as gr
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg as fc
# constants
# NOTE(review): presumably the display-window title carried over from the
# upstream detectron2 demo this script was adapted from; nothing in the
# visible code of this file reads it.
WINDOW_NAME = "MaskFormer demo"
def setup_cfg(args):
    """Assemble the frozen config object used by the demo.

    Starts from ``get_cfg()``, applies the DeepLab and CAT-Seg config
    extensions, then merges the file named by ``args.config_file`` followed by
    the ``KEY VALUE`` overrides in ``args.opts``.  When CUDA is available the
    device is set to ``"cuda"`` *after* both merges, so it wins over any
    ``MODEL.DEVICE`` value that came from the file or the overrides.

    Args:
        args: Parsed command-line namespace providing ``config_file`` and
            ``opts``.

    Returns:
        The config object, frozen.
    """
    config = get_cfg()

    # Extend the base config with the DeepLab and CAT-Seg options.
    for extend in (add_deeplab_config, add_cat_seg_config):
        extend(config)

    # File first, then command-line overrides.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)

    gpu_available = torch.cuda.is_available()
    if gpu_available:
        config.MODEL.DEVICE = "cuda"

    config.freeze()
    return config
def get_parser():
    """Create the command-line parser for the demo.

    Options:
        --config-file: path to the config file
            (default ``configs/vitl_swinb_384.yaml``).
        --input: one or more image paths, or a single glob pattern.
        --opts: everything after this flag is collected as ``KEY VALUE``
            config overrides; when the flag is absent the built-in override
            list below is used.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    # Overrides applied when --opts is not given on the command line.
    default_overrides = [
        "MODEL.WEIGHTS", "model_final_cls.pth",
        "MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
        "MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
        "TEST.SLIDING_WINDOW", "True",
        "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
        "MODEL.PROMPT_ENSEMBLE_TYPE", "single",
        "MODEL.DEVICE", "cpu",
    ]

    cli = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    cli.add_argument(
        "--config-file",
        metavar="FILE",
        default="configs/vitl_swinb_384.yaml",
        help="path to config file",
    )
    cli.add_argument(
        "--input",
        nargs="+",
        help="A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",
    )
    cli.add_argument(
        "--opts",
        nargs=argparse.REMAINDER,
        default=default_overrides,
        help="Modify config options using the command-line 'KEY VALUE' pairs",
    )
    return cli
def save_masks(preds, text):
    """Write one black-and-white mask image per class name.

    Each pixel is assigned the class with the highest score; for every class
    name a PNG is written to the current working directory in which pixels of
    that class are 255 and all others are 0.

    Args:
        preds: Mapping whose ``'sem_seg'`` entry holds the per-class scores
            (laid out as C, H, W according to the original comment); the
            arg-max is taken over the first axis.
        text: Iterable of class names.  The i-th name labels class index
            ``i`` and determines the file name ``mask_{name}.png``.
    """
    labels = preds['sem_seg'].argmax(dim=0).cpu().numpy()  # C H W -> H W
    for index, name in enumerate(text):
        out_path = f"mask_{name}.png"
        # Build an explicit 8-bit image instead of relying on the implicit
        # bool -> 64-bit integer promotion of `bool_array * 255`.
        mask = (labels == index).astype("uint8") * 255
        # cv2.imwrite signals failure through its return value, not an
        # exception (e.g. when the target path is not writable).
        if not cv2.imwrite(out_path, mask):
            warnings.warn(f"could not write mask image {out_path!r}")
def predict(image, text, model_type="Segment Anycat"):
    """Run the demo model on one image and return the rendered visualization.

    Args:
        image: Input image as delivered by the ``gr.Image`` component.
        text: Category names as typed into the text box; forwarded to the
            model unchanged.
        model_type: ``"CAT-Seg"`` runs without SAM; any other value enables
            SAM.  The parameter now has a default because the interface
            registers only the image and text inputs, so a required third
            argument made every request fail with ``TypeError``.
            NOTE(review): the default enables SAM to match the demo
            description (CAT-Seg combined with SAM); the disabled radio
            button defaulted to ``"CAT-Seg"`` — confirm which is intended.

    Returns:
        ``numpy.ndarray`` of shape (H, W, 3), dtype ``uint8``: the rendered
        figure with its colour channels in reversed order, as before.
    """
    use_sam = model_type != "CAT-Seg"
    _predictions, visualized_output = demo.run_on_image(image, text, use_sam)
    # Debug aid: save_masks(_predictions, text.split(',')) dumps per-class masks.

    canvas = fc(visualized_output.fig)
    canvas.draw()
    # buffer_rgba() replaces tostring_rgb(), which recent Matplotlib releases
    # deprecated and then removed.  Shape is (H, W, 4); alpha is dropped to
    # keep the three-channel output.
    rgba = np.asarray(canvas.buffer_rgba())
    rgb = rgba[..., :3]
    # Reverse the channel order and return an independent contiguous copy so
    # the result does not alias the canvas' internal buffer.
    return np.ascontiguousarray(rgb[..., ::-1])
if __name__ == "__main__":
    # Parse the CLI, build the config and load the model once at start-up.
    # Assigning `demo` at module scope is what makes it visible to `predict`.
    args = get_parser().parse_args()
    cfg = setup_cfg(args)
    demo = VisualizationDemo(cfg)

    # UI inputs, in the order they are handed to `predict`.
    # Disabled third input: gr.Radio(["CAT-Seg", "Segment Anycat"], value="CAT-Seg")
    demo_inputs = [
        gr.Image(),
        gr.Textbox(placeholder='background, cat, person'),
    ]

    # Text passed to the interface as its description.
    demo_description = """## Segment Anything with CAT-Seg!
Welcome to the Segment Anything with CAT-Seg!
In this demo, we combine state-of-the-art open-vocabulary semantic segmentation model, CAT-Seg with SAM(Segment Anything) for semantically labelling mask predictions from SAM.
Please note that this is an optimized version of the full model, and as such, its performance may be limited compared to the full model.
Also, the demo might run on a CPU depending on the demand, so it may take a little time to process your image.
To get started, simply upload an image and a comma-separated list of categories, and let the model work its magic!"""

    iface = gr.Interface(
        fn=predict,
        inputs=demo_inputs,
        outputs="image",
        description=demo_description,
    )
    iface.launch()