File size: 3,825 Bytes
8ae7071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4def87
8ae7071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d812008
cc6b299
8ae7071
 
 
 
 
 
 
 
 
 
 
311bc12
 
 
 
 
 
8ae7071
311bc12
8ae7071
 
311bc12
 
 
 
 
8ae7071
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
tree-segmentation
Proof of concept showing the effectiveness of fine-tuned instance segmentation models for detecting trees and buildings.
"""
import os
import cv2
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
from transformers import DetrFeatureExtractor, DetrForSegmentation
from PIL import Image
import gradio as gr
import numpy as np
import torch
import torchvision
import detectron2

# import some common detectron2 utilities
import itertools
import seaborn as sns
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.visualizer import ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.utils.visualizer import ColorMode
from detectron2.structures import Instances


def _load_predictor(config_file, weights_file, num_classes):
    """Build a CPU detectron2 ``DefaultPredictor`` from a YAML config and weights.

    Args:
        config_file: Path of the YAML config merged on top of ``get_cfg()``.
        weights_file: Path of the checkpoint assigned to ``MODEL.WEIGHTS``.
        num_classes: Value assigned to ``MODEL.ROI_HEADS.NUM_CLASSES`` after the
            YAML is merged (so it overrides whatever the file specifies).

    Returns:
        A ``(cfg, predictor)`` tuple. The config is returned as well because
        module-level code keeps a handle on it (the ``segment_*`` functions
        assign to it).
    """
    cfg = get_cfg()
    cfg.merge_from_file(config_file)
    cfg.MODEL.DEVICE = "cpu"
    cfg.MODEL.WEIGHTS = weights_file
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    predictor = DefaultPredictor(cfg)
    return cfg, predictor


# Model for trees (2 classes).
tree_cfg, tree_predictor = _load_predictor(
    "tree_model_weights/tree_cfg.yml",
    "tree_model_weights/treev1_best.pth",
    2,
)

# Model for buildings (8 classes).
building_cfg, building_predictor = _load_predictor(
    "building_model_weight/buildings_poc_cfg.yml",
    "building_model_weight/model_final.pth",
    8,
)

def _predict_instances(predictor, cfg, im, confidence_threshold):
    """Run ``predictor`` on ``im`` and return CPU ``Instances`` filtered by score.

    ``DefaultPredictor`` builds its model once, from a clone of the config it
    was constructed with. Assigning ``cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST``
    afterwards therefore never reaches the model. To make the per-call
    threshold take effect, this helper:

    1. still records the threshold on ``cfg`` (kept for any code that reads it);
    2. updates the live ``test_score_thresh`` on the model's ROI box predictor
       when that attribute exists (e.g. detectron2's standard R-CNN heads), so
       thresholds below the value baked in at construction also work;
    3. drops any remaining detections scoring below the threshold, as a
       fallback for heads that do not expose ``test_score_thresh``.

    NOTE: this mutates shared model state, just as the original code mutated
    the shared config; concurrent calls with different thresholds can
    interfere with each other.

    Args:
        predictor: A detectron2 ``DefaultPredictor``.
        cfg: The config object associated with ``predictor``.
        im: Image passed straight to ``predictor``.
        confidence_threshold: Minimum score a detection must have to be kept.

    Returns:
        The predicted ``Instances``, moved to the CPU.
    """
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold

    roi_heads = getattr(predictor.model, "roi_heads", None)
    box_predictor = getattr(roi_heads, "box_predictor", None)
    if hasattr(box_predictor, "test_score_thresh"):
        box_predictor.test_score_thresh = confidence_threshold

    outputs = predictor(im)
    instances = outputs["instances"].to("cpu")
    if instances.has("scores"):
        instances = instances[instances.scores >= confidence_threshold]
    return instances


# A function that runs the buildings model on a given image and confidence threshold
def segment_building(im, confidence_threshold):
    """Run the buildings model on ``im``, keeping detections >= ``confidence_threshold``.

    Returns the predicted ``Instances`` on the CPU.
    """
    return _predict_instances(building_predictor, building_cfg, im, confidence_threshold)


# A function that runs the trees model on a given image and confidence threshold
def segment_tree(im, confidence_threshold):
    """Run the trees model on ``im``, keeping detections >= ``confidence_threshold``.

    Returns the predicted ``Instances`` on the CPU.
    """
    return _predict_instances(tree_predictor, tree_cfg, im, confidence_threshold)

# Function to map UI labels (or ColorMode values) to a detectron2 ColorMode
def map_color_mode(color_mode):
    """Translate a colour-mode label into a detectron2 ``ColorMode``.

    Args:
        color_mode: One of ``"Black/white"``, ``"Random"`` or
            ``"Segmentation"``, or a ``ColorMode`` member. A member is
            returned unchanged; ``visualize_image`` uses
            ``ColorMode.SEGMENTATION`` as its default, and previously that
            default was silently turned into ``None`` here.

    Returns:
        The matching ``ColorMode``.

    Raises:
        ValueError: if ``color_mode`` is neither a ``ColorMode`` nor one of the
            supported labels. Previously such values fell through and
            returned ``None``.
    """
    if isinstance(color_mode, ColorMode):
        return color_mode

    modes = {
        "Black/white": ColorMode.IMAGE_BW,
        "Random": ColorMode.IMAGE,
        "Segmentation": ColorMode.SEGMENTATION,
    }
    try:
        return modes[color_mode]
    except (KeyError, TypeError):
        raise ValueError(
            f"Unknown color mode {color_mode!r}; expected one of {sorted(modes)}"
        ) from None

def _to_rgb_array(im):
    """Return ``im`` as an ``(H, W, 3)`` RGB numpy array.

    PIL images are converted with ``convert("RGB")``, which drops alpha and
    expands greyscale/palette images. Other inputs go through ``np.array``
    and are assumed to already be RGB-ordered (TODO confirm for every caller).
    2-D greyscale arrays are replicated to three channels, and a fourth
    (alpha) channel is dropped.
    """
    if isinstance(im, Image.Image):
        return np.array(im.convert("RGB"))
    rgb = np.array(im)
    if rgb.ndim == 2:
        rgb = np.stack([rgb] * 3, axis=-1)
    elif rgb.ndim == 3 and rgb.shape[2] == 4:
        rgb = rgb[:, :, :3]
    return rgb


def _run_models(bgr, mode, tree_threshold, building_threshold):
    """Run the model(s) selected by ``mode`` on a BGR image and return ``Instances``.

    ``mode`` is matched case-insensitively against "trees", "buildings" and
    "both".

    Raises:
        ValueError: for any other ``mode``. Previously an unknown mode left
            ``instances`` unbound and crashed with UnboundLocalError.
    """
    key = str(mode).strip().lower()
    if key == "trees":
        return segment_tree(bgr, tree_threshold)
    if key == "buildings":
        return segment_building(bgr, building_threshold)
    if key == "both":
        tree_instances = segment_tree(bgr, tree_threshold)
        building_instances = segment_building(bgr, building_threshold)
        # NOTE(review): the two models are configured with 2 and 8 classes.
        # Presumably both number them from 0, so after concatenation a given
        # pred_classes id may mean a tree class or a building class. Offset or
        # relabel if class names matter. Instances.cat also requires both
        # results to carry the same fields — TODO confirm.
        return Instances.cat([tree_instances, building_instances])
    raise ValueError(
        f"Unknown mode {mode!r}; expected 'Trees', 'Buildings' or 'Both'."
    )


def visualize_image(im, mode="BOTH", tree_threshold=0.7, building_threshold=0.7, color_mode=ColorMode.SEGMENTATION):
    """Segment trees and/or buildings in ``im`` and return an annotated image.

    Args:
        im: Input image: a PIL image, or an RGB array-like.
        mode: ``"Trees"``, ``"Buildings"`` or ``"Both"``, matched
            case-insensitively so the default ``"BOTH"`` works.
        tree_threshold: Score threshold forwarded to ``segment_tree``.
        building_threshold: Score threshold forwarded to ``segment_building``.
        color_mode: A ``ColorMode`` member, or one of the labels accepted by
            ``map_color_mode``.

    Returns:
        PIL.Image.Image: the input at half scale with predictions drawn on it.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    rgb = _to_rgb_array(im)

    # The default is already a ColorMode; only UI labels need translating.
    if not isinstance(color_mode, ColorMode):
        color_mode = map_color_mode(color_mode)

    # detectron2's DefaultPredictor expects BGR input (it converts internally
    # if the model was configured for RGB), so flip channels for inference.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    instances = _run_models(bgr, mode, tree_threshold, building_threshold)

    metadata = MetadataCatalog.get("your_model_metadata")
    # Visualizer takes an RGB image. Labels are drawn by
    # draw_instance_predictions from the metadata's thing_classes (class ids
    # otherwise, per detectron2's Visualizer). The previous manual labelling
    # loop crashed because Instances objects are not iterable.
    visualizer = Visualizer(rgb,
                            metadata=metadata,
                            scale=0.5,
                            instance_mode=color_mode)
    output_image = visualizer.draw_instance_predictions(instances)

    # get_image() returns RGB, which is what PIL expects; no flip needed.
    return Image.fromarray(output_image.get_image())