import os
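# Report the CUDA toolchain visible in the build environment (for debugging).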
os.system("nvcc --version")
print(os.environ.get('CUDA_PATH'))
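# Install a detectron2 wheel prebuilt for CUDA 10.2 / PyTorch 1.9, then fetch Mask2Former.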
os.system('pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
os.system("git clone https://github.com/facebookresearch/Mask2Former.git")

os.chdir("Mask2Former")
os.system("pwd")
os.system("pip install git+https://github.com/cocodataset/panopticapi.git")
os.chdir("mask2former/modeling/pixel_decoder/ops")
os.system("pwd")

os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0"
os.environ["FORCE_CUDA"] = "1"
os.system("python setup.py build install")
os.chdir("/home/user/app/Mask2Former/")
os.system("pwd")
import gradio as gr
# check the detectron2 installation and set up logging:
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
setup_logger(name="mask2former")

# import some common libraries
import numpy as np
import cv2
import torch
from PIL import Image  # used by inference() to convert numpy arrays to PIL images

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog
from detectron2.projects.deeplab import add_deeplab_config
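# Metadata (class names, colors) for visualizing COCO panoptic predictions.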
coco_metadata = MetadataCatalog.get("coco_2017_val_panoptic")

# import Mask2Former project
from mask2former import add_maskformer2_config

# Assemble the config for Swin-L (IN21k) Mask2Former trained on COCO panoptic segmentation.
cfg = get_cfg()
cfg.MODEL.DEVICE = 'cpu'  # run inference on CPU
add_deeplab_config(cfg)
add_maskformer2_config(cfg)
cfg.merge_from_file("configs/coco/panoptic-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_100ep.yaml")
cfg.MODEL.WEIGHTS = 'https://dl.fbaipublicfiles.com/maskformer/mask2former/coco/panoptic/maskformer2_swin_large_IN21k_384_bs16_100ep/model_final_f07440.pkl'
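# Enable all three segmentation tasks at test time.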
cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = True
cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON = True
cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = True
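# DefaultPredictor handles preprocessing, weight loading, and single-image inference.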
predictor = DefaultPredictor(cfg)

def inference(img):
    im = cv2.imread(img)
    outputs = predictor(im)
    # Panoptic segmentation
    v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)
    panoptic_result = v.draw_panoptic_seg(outputs["panoptic_seg"][0].to("cpu"), outputs["panoptic_seg"][1]).get_image()
    # Instance segmentation
    v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)
    instance_result = v.draw_instance_predictions(outputs["instances"].to("cpu")).get_image()
    # Semantic segmentation (per-pixel argmax over class scores)
    v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)
    semantic_result = v.draw_sem_seg(outputs["sem_seg"].argmax(0).to("cpu")).get_image()
    return (Image.fromarray(np.uint8(panoptic_result)).convert('RGB'),
            Image.fromarray(np.uint8(instance_result)).convert('RGB'),
            Image.fromarray(np.uint8(semantic_result)).convert('RGB'))
    

title = "Detectron 2"
description = "Gradio demo for Detectron 2: A PyTorch-based modular object detection library. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://ai.facebook.com/blog/-detectron2-a-pytorch-based-modular-object-detection-library-/' target='_blank'>Detectron2: A PyTorch-based modular object detection library</a> | <a href='https://github.com/facebookresearch/detectron2' target='_blank'>Github Repo</a></p>"

examples = [['airplane.png']]
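# airplane.png must be present alongside this script for the example to load.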

gr.Interface(
    inference,
    inputs=gr.inputs.Image(type="filepath"),
    outputs=[
        gr.outputs.Image(label="Panoptic segmentation", type="pil"),
        gr.outputs.Image(label="Instance segmentation", type="pil"),
        gr.outputs.Image(label="Semantic segmentation", type="pil"),
    ],
    enable_queue=True,
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()