File size: 7,483 Bytes
8ae7071
0639dc1
8ae7071
 
 
6825ad3
 
8ae7071
 
 
 
 
 
 
 
 
0639dc1
8ae7071
 
 
 
 
 
 
 
 
 
 
 
 
 
084c97d
 
 
 
 
8ae7071
6825ad3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79a62ec
 
 
6825ad3
79a62ec
 
 
6825ad3
 
 
 
79a62ec
084c97d
79a62ec
 
6825ad3
 
8ae7071
6825ad3
 
 
 
 
 
 
 
 
8ae7071
 
6825ad3
 
084c97d
6825ad3
084c97d
6825ad3
 
 
 
8ae7071
977aa5f
27aaed4
977aa5f
 
 
27aaed4
977aa5f
 
 
 
 
8ae7071
6825ad3
8ae7071
 
 
 
 
 
6825ad3
8ae7071
 
 
 
 
977aa5f
 
 
 
 
 
 
8ae7071
 
 
 
 
 
242850e
8ae7071
 
6825ad3
 
 
 
 
 
 
 
8ae7071
6825ad3
 
 
0639dc1
6825ad3
0639dc1
6825ad3
311bc12
977aa5f
6825ad3
 
8ae7071
6825ad3
 
 
 
 
 
 
 
 
 
 
 
977aa5f
 
 
 
 
6825ad3
084c97d
6825ad3
 
 
8ae7071
6825ad3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""
aerial-segmentation
Proof of concept showing effectiveness of a fine tuned instance segmentation model for detecting trees.
"""
import os
import gradio as gr

import cv2
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
from transformers import DetrFeatureExtractor, DetrForSegmentation
from PIL import Image
import gradio as gr
import numpy as np
import torch
import torchvision
import detectron2
import json

# import some common detectron2 utilities
import itertools
import seaborn as sns
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.visualizer import ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.utils.visualizer import ColorMode
from detectron2.structures import Instances

def list_cfg_file_versions(directory):
    """Return the version prefixes of the config files stored in ``directory``.

    A file counts as a config file when its name ends in ``.yml`` or ``.yaml``
    and starts with the text before the first underscore of ``directory``
    followed by ``v``. For every such file, the part of
    its name before the first underscore is returned.
    """
    entries = os.listdir(directory)
    expected_prefix = directory.split("_")[0] + "v"
    versions = []
    for name in entries:
        if not name.endswith((".yml", ".yaml")):
            continue
        if name.startswith(expected_prefix):
            versions.append(name.split("_")[0])
    return versions

def list_pth_files_in_directory(directory, version="v1"):
    """Return the ``.pth`` files in ``directory`` that belong to ``version``.

    Only the text between the first and second ``v`` of ``version`` is used
    as the match token (``"1"`` for ``"treev1"`` or the default ``"v1"``);
    a file is kept when that token occurs anywhere in its name.
    """
    entries = os.listdir(directory)
    token = version.split("v")[1]
    return [name for name in entries if name.endswith(".pth") and token in name]

def get_version_cfg_yml(path):
    """Resolve ``"<directory>/<version>"`` to the matching config file path.

    The first two ``/``-separated components of ``path`` are taken as the
    directory and the version token. The result is ``"<directory>/<file>"``
    for the first listed ``.yml``/``.yaml`` file whose name contains the
    version token. An ``IndexError`` is raised when ``path`` has no ``/`` or
    when no file matches.
    """
    parts = path.split("/")
    directory, version = parts[0], parts[1]
    matches = []
    for name in os.listdir(directory):
        if version in name and name.endswith((".yml", ".yaml")):
            matches.append(name)
    return f"{directory}/{matches[0]}"

def update_row_visibility(mode):
    """Return the tree, building and LCZ rows with visibility set for ``mode``.

    The tree row is visible for "Trees" and "Trees & Buildings", the building
    row for "Buildings" and "Trees & Buildings", and the LCZ row for "LCZ".
    The computed visibility mapping is also printed.
    """
    combined = "Trees & Buildings"
    visibility = {
        "tree": mode in ("Trees", combined),
        "building": mode in ("Buildings", combined),
        "lcz": mode in ("LCZ",),
    }
    rows = tuple(gr.Row(visible=visibility[key]) for key in ("tree", "building", "lcz"))
    print(visibility)
    return rows

def update_path_options(version):
    """Return a dropdown listing the ``.pth`` weight files for ``version``.

    The weights directory is chosen by substring match on ``version``,
    checked in this order: "tree" -> ``tree_model_weights``, "building" ->
    ``building_model_weights``, "lcz" -> ``lcz_model_weights``.

    Args:
        version: Version string selected in the UI, e.g. "treev1"
            (presumably "<family>v<number>"; confirm against the dropdown
            choices defined elsewhere).

    Returns:
        A visible, interactive ``gr.Dropdown`` whose choices are the matching
        weight files and whose label names the model family (the text before
        the first "v" in ``version``).

    Raises:
        ValueError: If ``version`` names none of the known model families.
            Without this check the function failed with an opaque
            ``UnboundLocalError`` on ``directory``.
    """
    weight_directories = (
        ("tree", "tree_model_weights"),
        ("building", "building_model_weights"),
        ("lcz", "lcz_model_weights"),
    )
    directory = next(
        (folder for keyword, folder in weight_directories if keyword in version),
        None,
    )
    if directory is None:
        raise ValueError(
            f"Unknown model version {version!r}: expected it to contain "
            "'tree', 'building' or 'lcz'"
        )
    return gr.Dropdown(choices=list_pth_files_in_directory(directory, version), label=f"Select a {version.split('v')[0]} model file", visible=True, interactive=True)

# Model for trees
def tree_model(tree_version_dropdown, tree_pth_dropdown, tree_threshold, device="cpu"):
    """Build a detectron2 ``DefaultPredictor`` for the tree model.

    Starts from the default config, merges the YAML file resolved for
    ``tree_version_dropdown`` under ``tree_model_weights``, then overrides
    the device, the weights file, the number of ROI-head classes and the
    test-time score threshold.
    """
    cfg = get_cfg()
    cfg.merge_from_file(get_version_cfg_yml(f"tree_model_weights/{tree_version_dropdown}"))
    # Overrides are applied after the merge so they win over the YAML values.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = tree_threshold
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2  # TODO change this
    cfg.MODEL.WEIGHTS = f"tree_model_weights/{tree_pth_dropdown}"
    cfg.MODEL.DEVICE = device
    return DefaultPredictor(cfg)

# Model for buildings
def building_model(building_version_dropdown, building_pth_dropdown, building_threshold, device="cpu"):
    """Build a detectron2 ``DefaultPredictor`` for the building model.

    Starts from the default config, merges the YAML file resolved for
    ``building_version_dropdown`` under ``building_model_weights``, then
    overrides the device, the weights file, the number of ROI-head classes
    and the test-time score threshold.
    """
    cfg = get_cfg()
    cfg.merge_from_file(get_version_cfg_yml(f"building_model_weights/{building_version_dropdown}"))
    # Overrides are applied after the merge so they win over the YAML values.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = building_threshold
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 8  # TODO change this
    cfg.MODEL.WEIGHTS = f"building_model_weights/{building_pth_dropdown}"
    cfg.MODEL.DEVICE = device
    return DefaultPredictor(cfg)

# Model for LCZs
def lcz_model(lcz_version_dropdown, lcz_pth_dropdown, lcz_threshold, device="cpu"):
    """Build a detectron2 ``DefaultPredictor`` for the LCZ model.

    Args:
        lcz_version_dropdown: Selected LCZ version. Currently unused: the
            config is always resolved from the fixed name ``lczs_cfg.yaml``.
        lcz_pth_dropdown: File name of the weights inside ``lcz_model_weights``.
        lcz_threshold: Test-time score threshold for the ROI heads.
        device: Device string stored in ``MODEL.DEVICE`` (default "cpu").

    Returns:
        A ``DefaultPredictor`` configured for the LCZ model.
    """
    lcz_cfg = get_cfg()
    lcz_cfg.merge_from_file(get_version_cfg_yml("lcz_model_weights/lczs_cfg.yaml"))
    lcz_cfg.MODEL.DEVICE = device
    # Fix: the weights were read from "tree_model_weights/", but the LCZ
    # dropdown lists files from "lcz_model_weights/", so the selected file
    # would be looked up in the wrong directory.
    lcz_cfg.MODEL.WEIGHTS = f"lcz_model_weights/{lcz_pth_dropdown}"
    lcz_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 14  # TODO change this
    lcz_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = lcz_threshold
    lcz_predictor = DefaultPredictor(lcz_cfg)
    return lcz_predictor

# Runs the buildings predictor on a given image and returns the detected instances on the CPU
def segment_building(im, building_predictor):
    """Run ``building_predictor`` on ``im`` and return its instances on the CPU.

    The predictor is called with the image and must return a mapping with an
    "instances" entry; that entry is moved with ``.to("cpu")`` and returned.
    """
    return building_predictor(im)["instances"].to("cpu")

# Runs the trees predictor on a given image and returns the detected instances on the CPU
def segment_tree(im, tree_predictor):
    """Run ``tree_predictor`` on ``im`` and return its instances on the CPU.

    The predictor is called with the image and must return a mapping with an
    "instances" entry; that entry is moved with ``.to("cpu")`` and returned.
    """
    return tree_predictor(im)["instances"].to("cpu")

# Runs the LCZ predictor on a given image and returns the detected instances on the CPU
def segment_lcz(im, lcz_predictor):
    """Run ``lcz_predictor`` on ``im`` and return its instances on the CPU.

    The predictor is called with the image and must return a mapping with an
    "instances" entry; that entry is moved with ``.to("cpu")`` and returned.
    """
    return lcz_predictor(im)["instances"].to("cpu")

# Function to map strings to color mode
def map_color_mode(color_mode):
    """Translate a UI color-mode label into a detectron2 ``ColorMode``.

    Args:
        color_mode: "Black/white", "Random", "Segmentation" or ``None``.

    Returns:
        ``ColorMode.IMAGE_BW`` for "Black/white", ``ColorMode.IMAGE`` for
        "Random", ``ColorMode.SEGMENTATION`` for "Segmentation" or ``None``,
        and ``None`` for any other value.
    """
    if color_mode == "Black/white":
        return ColorMode.IMAGE_BW
    if color_mode == "Random":
        return ColorMode.IMAGE
    # Identity check for None instead of `== None` (PEP 8).
    if color_mode is None or color_mode == "Segmentation":
        return ColorMode.SEGMENTATION
    # Unknown labels keep returning None, made explicit here.
    return None

def load_predictor(model, version, pth, threshold):
    """Call the ``model`` factory with ``(version, pth, threshold)`` and return its result."""
    predictor = model(version, pth, threshold)
    return predictor

def load_instances(image, predictor, segment_function):
    """Apply ``segment_function`` to ``(image, predictor)`` and return its result."""
    instances = segment_function(image, predictor)
    return instances

def combine_instances(tree_instances, building_instances):
    """Concatenate the two instance sets with ``Instances.cat``, trees first."""
    parts = [tree_instances, building_instances]
    return Instances.cat(parts)

def get_metadata(dataset_name, coco_file):
    """Return the catalog metadata for ``dataset_name`` with class names set.

    The names are read from the "categories" list of the JSON file at
    ``coco_file`` (each entry's "name") and stored as ``thing_classes`` on
    the entry obtained from ``MetadataCatalog``.
    """
    metadata = MetadataCatalog.get(dataset_name)
    with open(coco_file, "r") as handle:
        annotations = json.load(handle)
    class_names = []
    for category in annotations["categories"]:
        class_names.append(category["name"])
    metadata.thing_classes = class_names
    return metadata

def visualize_image(im, mode, tree_threshold, building_threshold, color_mode, tree_version, tree_pth, building_version, building_pth, lcz_version, lcz_pth, lcz_threshold=0.5):
    """Run the selected model(s) on ``im`` and return the annotated image.

    Args:
        im: Input image; converted with ``np.array``.
        mode: "Trees", "Buildings", "LCZ", or a combined trees-and-buildings
            mode spelled either "Both" or "Trees & Buildings".
        tree_threshold: Score threshold for the tree model.
        building_threshold: Score threshold for the building model.
        color_mode: UI color-mode label, mapped through ``map_color_mode``.
        tree_version, tree_pth: Version and weights file of the tree model.
        building_version, building_pth: Version and weights file of the
            building model.
        lcz_version, lcz_pth: Version and weights file of the LCZ model.
        lcz_threshold: Score threshold for the LCZ model. New trailing
            parameter; the 0.5 default is a placeholder chosen here.
            TODO: wire a UI control to it.

    Returns:
        A ``PIL.Image`` with the predicted instances drawn on the input.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.

    Fixes relative to the previous version:
        * ``lcz_threshold`` was an undefined name, so LCZ mode raised
          ``NameError``; it is now a parameter.
        * ``LCZ_instances`` was a misspelling of ``lcz_instances``.
        * "Both" also entered the LCZ branch and therefore always crashed;
          the LCZ model now runs only for mode "LCZ".
          NOTE(review): confirm that combined mode should not include LCZ.
        * "Trees & Buildings" is accepted as a synonym of "Both".
        * An unknown mode no longer reaches the visualizer with ``None``.
    """
    combined_modes = ("Both", "Trees & Buildings")
    # Validate before any model is loaded so bad input fails fast.
    if mode not in ("Trees", "Buildings", "LCZ") + combined_modes:
        raise ValueError(f"Unsupported mode: {mode!r}")

    im = np.array(im)
    color_mode = map_color_mode(color_mode)

    instances = None

    if mode == "Trees" or mode in combined_modes:
        tree_predictor = load_predictor(tree_model, tree_version, tree_pth, tree_threshold)
        instances = load_instances(im, tree_predictor, segment_tree)

    if mode == "Buildings" or mode in combined_modes:
        building_predictor = load_predictor(building_model, building_version, building_pth, building_threshold)
        building_instances = load_instances(im, building_predictor, segment_building)
        # In combined mode the tree instances are already present; append to them.
        instances = building_instances if instances is None else combine_instances(instances, building_instances)

    if mode == "LCZ":
        lcz_predictor = load_predictor(lcz_model, lcz_version, lcz_pth, lcz_threshold)
        instances = load_instances(im, lcz_predictor, segment_lcz)

    # Assuming 'urban-small_train' is intended for both Trees and Buildings
    metadata = get_metadata("urban-small_train", "building_model_weights/_annotations.coco.json")
    # The channel order is reversed for the Visualizer and reversed back for the output.
    visualizer = Visualizer(im[:, :, ::-1], metadata=metadata, scale=0.5, instance_mode=color_mode)

    output_image = visualizer.draw_instance_predictions(instances)

    return Image.fromarray(output_image.get_image()[:, :, ::-1])