File size: 4,520 Bytes
8ae7071
0639dc1
8ae7071
 
 
 
 
 
 
 
 
 
 
 
0639dc1
8ae7071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e793dfa
8ae7071
3451b3f
8ae7071
 
 
 
 
d4def87
8ae7071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242850e
8ae7071
 
242850e
cc6b299
8ae7071
 
 
 
 
 
242850e
8ae7071
 
 
008d1df
 
45a81a5
 
8ae7071
 
45a81a5
8ae7071
 
05b14d0
008d1df
 
45a81a5
008d1df
 
 
0639dc1
 
 
 
 
 
05b14d0
 
 
 
 
 
 
 
311bc12
8ae7071
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
aerial-segmentation
Proof of concept showing the effectiveness of fine-tuned instance segmentation models for detecting trees and buildings in aerial imagery.
"""
import os
import cv2
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
from transformers import DetrFeatureExtractor, DetrForSegmentation
from PIL import Image
import gradio as gr
import numpy as np
import torch
import torchvision
import detectron2
import json

# import some common detectron2 utilities
import itertools
import seaborn as sns
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.visualizer import ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.utils.visualizer import ColorMode
from detectron2.structures import Instances


def _load_predictor(config_path, weights_path, num_classes):
    """Build a CPU ``DefaultPredictor`` from a config file and trained weights.

    Args:
        config_path: YAML config merged on top of detectron2's defaults.
        weights_path: Path assigned to ``MODEL.WEIGHTS``.
        num_classes: Value assigned to ``MODEL.ROI_HEADS.NUM_CLASSES``.

    Returns:
        A ``(cfg, predictor)`` pair: the config node and the predictor built from it.
    """
    cfg = get_cfg()
    cfg.merge_from_file(config_path)
    cfg.MODEL.DEVICE = "cpu"
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    return cfg, DefaultPredictor(cfg)


# Model for trees (2 classes).
tree_cfg, tree_predictor = _load_predictor(
    "tree_model_weights/treev2_cfg.yml",
    "tree_model_weights/treev2_final.pth",
    2,
)

# Model for buildings (8 classes).
building_cfg, building_predictor = _load_predictor(
    "building_model_weight/buildings_poc_cfg.yml",
    "building_model_weight/model_final.pth",
    8,
)

# Run the buildings model on an image and keep only confident detections.
def segment_building(im, confidence_threshold):
    """Run the building model on ``im`` and return detections above a threshold.

    ``DefaultPredictor`` clones its config and builds the model when it is
    constructed. Changing ``building_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST`` at
    call time therefore does not affect ``building_predictor``. Instead, the
    threshold is applied here by filtering on per-instance scores, using the
    same strict ``>`` comparison as detectron2's test-time score filter.

    Detections the model already discarded cannot be recovered. These are the
    ones below the SCORE_THRESH_TEST that was in the config when the predictor
    was built. Thresholds lower than that value therefore behave like that value.

    Args:
        im: Image array passed straight to ``building_predictor``.
        confidence_threshold: Instances must score strictly above this to be kept.

    Returns:
        ``Instances`` on the CPU containing only the detections whose score
        exceeds ``confidence_threshold``.
    """
    building_instances = building_predictor(im)["instances"].to("cpu")
    return building_instances[building_instances.scores > confidence_threshold]

# Run the trees model on an image and keep only confident detections.
def segment_tree(im, confidence_threshold):
    """Run the tree model on ``im`` and return detections above a threshold.

    ``DefaultPredictor`` clones its config and builds the model when it is
    constructed. Changing ``tree_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST`` at
    call time therefore does not affect ``tree_predictor``. Instead, the
    threshold is applied here by filtering on per-instance scores, using the
    same strict ``>`` comparison as detectron2's test-time score filter.

    Detections the model already discarded cannot be recovered. These are the
    ones below the SCORE_THRESH_TEST that was in the config when the predictor
    was built. Thresholds lower than that value therefore behave like that value.

    Args:
        im: Image array passed straight to ``tree_predictor``.
        confidence_threshold: Instances must score strictly above this to be kept.

    Returns:
        ``Instances`` on the CPU containing only the detections whose score
        exceeds ``confidence_threshold``.
    """
    tree_instances = tree_predictor(im)["instances"].to("cpu")
    return tree_instances[tree_instances.scores > confidence_threshold]

# Map a UI colour-mode label to a detectron2 ColorMode.
def map_color_mode(color_mode):
    """Translate a colour-mode label from the UI into a detectron2 ``ColorMode``.

    Args:
        color_mode: ``"Black/white"``, ``"Random"`` or ``"Segmentation"``.
            ``None`` is treated as ``"Segmentation"``.

    Returns:
        ``ColorMode.IMAGE_BW``, ``ColorMode.IMAGE`` or ``ColorMode.SEGMENTATION``
        respectively.

    Raises:
        ValueError: If ``color_mode`` is not one of the labels above. Raising
            here avoids silently handing ``None`` to the ``Visualizer``.
    """
    if color_mode is None or color_mode == "Segmentation":
        return ColorMode.SEGMENTATION
    if color_mode == "Black/white":
        return ColorMode.IMAGE_BW
    if color_mode == "Random":
        return ColorMode.IMAGE
    raise ValueError(
        f"Unknown color mode {color_mode!r}; "
        "expected 'Black/white', 'Random' or 'Segmentation'"
    )

def _building_metadata():
    """Return the ``urban-small_train`` metadata with building class names set.

    If that metadata has no ``thing_classes`` yet, they are read from the
    ``"categories"`` list of ``building_model_weight/_annotations.coco.json``
    and stored on it. Otherwise the names already present are reused, so the
    file is not re-read on every request.
    """
    metadata = MetadataCatalog.get("urban-small_train")
    if metadata.get("thing_classes") is None:
        with open("building_model_weight/_annotations.coco.json", "r", encoding="utf-8") as f:
            categories = json.load(f)["categories"]
        # NOTE(review): assumes the building model's class ids index this list
        # in file order (COCO exports sometimes include a supercategory entry
        # first) — confirm against the training dataset registration.
        metadata.thing_classes = [c["name"] for c in categories]
    return metadata


def visualize_image(im, mode, tree_threshold: float, building_threshold: float, color_mode):
    """Segment ``im`` with the selected model(s) and return an annotated image.

    Args:
        im: Input image; anything ``np.array`` accepts (e.g. a PIL image).
        mode: ``"Trees"``, ``"Buildings"`` or ``"Both"``; ``None`` means ``"Both"``.
        tree_threshold: Confidence threshold forwarded to ``segment_tree``.
        building_threshold: Confidence threshold forwarded to ``segment_building``.
        color_mode: Label understood by ``map_color_mode``.

    Returns:
        A ``PIL.Image`` with the predicted instances drawn at half scale.

    Raises:
        ValueError: If ``mode`` (or ``color_mode``) is not a supported value.
    """
    im = np.array(im)
    instance_mode = map_color_mode(color_mode)

    if mode == "Trees":
        instances = segment_tree(im, tree_threshold)
        metadata = MetadataCatalog.get("urban-trees-fdokv_train")
    elif mode == "Buildings":
        instances = segment_building(im, building_threshold)
        # Use the metadata that carries the building class names so they
        # actually appear on the rendered labels.
        metadata = _building_metadata()
    elif mode == "Both" or mode is None:
        instances = Instances.cat([
            segment_tree(im, tree_threshold),
            segment_building(im, building_threshold),
        ])
        # NOTE(review): the tree (2-class) and building (8-class) models use
        # separate class-id spaces, so one metadata object cannot label both
        # correctly; the tree metadata is used for the combined result.
        metadata = MetadataCatalog.get("urban-trees-fdokv_train")
    else:
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'Trees', 'Buildings' or 'Both'"
        )

    # NOTE(review): channels are reversed before drawing and reversed back
    # afterwards; confirm this matches the colour order the UI supplies.
    visualizer = Visualizer(
        im[:, :, ::-1],
        metadata=metadata,
        scale=0.5,
        instance_mode=instance_mode,
    )
    output_image = visualizer.draw_instance_predictions(instances)
    return Image.fromarray(output_image.get_image()[:, :, ::-1])