File size: 2,310 Bytes
24910f2
c2f4d83
 
24910f2
e3dc417
 
b17bb63
24910f2
 
 
 
 
 
a4c8bb5
24910f2
b17bb63
24910f2
 
b17bb63
 
 
 
6955470
b17bb63
4378f9c
24910f2
a4c8bb5
82d85d1
b17bb63
9964a9f
bca0412
24910f2
de5ee51
 
 
1263ab0
ef7cf07
a4c8bb5
dea2364
 
8fc220a
dea2364
c2f4d83
d67669b
88d8a4f
 
24910f2
59bd0c5
de5ee51
a4c8bb5
 
de5ee51
24910f2
ef7cf07
 
b17bb63
c2f4d83
 
24910f2
 
c2f4d83
de5ee51
ef7cf07
742d503
ef7cf07
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
tree-segmentation
Proof of concept demonstrating the effectiveness of a fine-tuned instance segmentation model for detecting trees.
"""
import os
import cv2
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
from transformers import DetrFeatureExtractor, DetrForSegmentation
from PIL import Image
import gradio as gr
import numpy as np
import torch
import torchvision
import detectron2

# import some common detectron2 utilities
import itertools
import seaborn as sns
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.visualizer import ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.checkpoint import DetectionCheckpointer

# Detectron2 config shared by the app: start from the saved training config,
# then override the inference-time settings below. merge_from_file must run
# first so these explicit overrides win over values in the YAML.
cfg = get_cfg()
cfg.merge_from_file("model_weights/treev1_cfg.yaml")
# NOTE(review): presumably matches the class count of the fine-tuned
# checkpoint's head -- confirm against treev1_cfg.yaml.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
cfg.MODEL.WEIGHTS = "model_weights/treev1_best.pth"
cfg.MODEL.DEVICE = "cpu"  # force CPU inference

import functools


@functools.lru_cache(maxsize=2)
def _get_predictor(confidence_threshold):
    """Return a DefaultPredictor using *confidence_threshold* as its ROI-head score cutoff.

    Creating a DefaultPredictor builds the model and loads the checkpoint named
    by ``cfg.MODEL.WEIGHTS``, so predictors for the most recently used
    thresholds are cached rather than rebuilt on every request. The shared
    module-level ``cfg`` is cloned, never mutated, so one request's threshold
    cannot leak into another's.
    """
    predictor_cfg = cfg.clone()
    predictor_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
    return DefaultPredictor(predictor_cfg)


def segment_image(im, confidence_threshold):
    """Detect tree instances in *im* and return the image with masks drawn on it.

    Args:
        im: Input image as delivered by the Gradio ``Image(type="pil")``
            component (a PIL image; any array ``np.array`` accepts also works,
            assumed to be RGB, H x W x 3).
        confidence_threshold: Minimum detection score kept by the model's ROI
            heads (Detectron2 ``SCORE_THRESH_TEST``), a fraction in [0, 1].

    Returns:
        A PIL RGB image, rendered at half scale, with the predicted instance
        masks overlaid.
    """
    if isinstance(im, Image.Image):
        # Uploads may be RGBA / palette / grayscale; the model needs 3 channels.
        im = im.convert("RGB")
    rgb = np.array(im)

    # DefaultPredictor's contract is BGR input (it converts internally when the
    # model is configured for RGB), while PIL images are RGB -- flip once here.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    outputs = _get_predictor(confidence_threshold)(bgr)
    instances = outputs["instances"].to("cpu")
    print(len(instances), " trees detected.")

    # Visualizer takes RGB and get_image() returns RGB, so the result can go
    # straight to PIL without further channel flipping.
    visualizer = Visualizer(rgb, scale=0.5, instance_mode=ColorMode.SEGMENTATION)
    drawn = visualizer.draw_instance_predictions(instances)
    return Image.fromarray(drawn.get_image())

# gradio components 

gr_slider_confidence = gr.inputs.Slider(0,1,.1,.7,
                                        label='Set confidence threshold % for masks')

# gradio outputs
inputs = gr.inputs.Image(type="pil", label="Input Image")
outputs = gr.outputs.Image(type="pil", label="Output Image")

title = "Tree Segmentation"
# The app runs a Detectron2 DefaultPredictor configured with R-CNN ROI heads;
# DETR is imported at the top of the file but never used, so it is not named here.
description = (
    "An instance segmentation demo for identifying trees in aerial images "
    "using a fine-tuned Detectron2 Mask R-CNN model"
)

# Create the user interface and launch it. The Interface is also bound to
# `demo` so it can be referenced by name (e.g. by tooling that imports this module).
demo = gr.Interface(
    segment_image,
    inputs=[inputs, gr_slider_confidence],
    outputs=outputs,
    title=title,
    description=description,
)
demo.launch(debug=True)