File size: 5,902 Bytes
8ae7071
0639dc1
8ae7071
 
 
6825ad3
 
8ae7071
 
 
 
 
 
 
 
 
0639dc1
8ae7071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6825ad3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ae7071
6825ad3
 
 
 
 
 
 
 
 
8ae7071
 
6825ad3
 
 
 
 
 
 
 
 
8ae7071
 
6825ad3
8ae7071
 
 
 
 
 
6825ad3
8ae7071
 
 
 
 
 
 
 
 
 
 
242850e
8ae7071
 
6825ad3
 
 
 
 
 
 
 
8ae7071
6825ad3
 
 
0639dc1
6825ad3
0639dc1
6825ad3
311bc12
6825ad3
 
 
8ae7071
6825ad3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ae7071
6825ad3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
aerial-segmentation
Proof of concept showing the effectiveness of fine-tuned instance segmentation models for detecting trees and buildings.
"""
import os
import gradio as gr

import cv2
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
from transformers import DetrFeatureExtractor, DetrForSegmentation
from PIL import Image
import gradio as gr
import numpy as np
import torch
import torchvision
import detectron2
import json

# import some common detectron2 utilities
import itertools
import seaborn as sns
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.visualizer import ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.utils.visualizer import ColorMode
from detectron2.structures import Instances


def list_pth_files_in_directory(directory, version="v1"):
    """Return the ``.pth`` files in *directory* whose name matches *version*.

    The version token is the segment that follows the first ``"v"`` in
    *version* (``"v1"`` -> ``"1"``, ``"treev2"`` -> ``"2"``). A file matches
    when its name contains that token and ends with ``.pth``.

    Args:
        directory: Folder to list (not recursive).
        version: Version label. When it contains no ``"v"`` the whole label
            is used as the token instead of raising ``IndexError``.

    Returns:
        Matching file names (not full paths), sorted alphabetically so the
        result does not depend on the arbitrary order of ``os.listdir``.
    """
    parts = version.split("v")
    # Same token as before whenever a "v" is present; otherwise fall back to
    # the label itself rather than failing on parts[1].
    token = parts[1] if len(parts) > 1 else version
    return sorted(
        name
        for name in os.listdir(directory)
        if token in name and name.endswith(".pth")
    )

def get_version_cfg_yml(path):
    """Locate the YAML config for a ``"<directory>/<version>"`` path.

    The first two ``/``-separated components of *path* are taken as the
    directory and the version. The directory is listed and the first entry
    whose name contains the version and ends in ``.yml`` or ``.yaml`` is
    returned, prefixed with ``"<directory>/"``.

    Raises:
        IndexError: If *path* has no second component, or no entry matches.
    """
    components = path.split("/")
    directory, version = components[0], components[1]

    candidates = []
    for entry in os.listdir(directory):
        if entry.endswith((".yml", ".yaml")) and version in entry:
            candidates.append(entry)

    return f"{directory}/{candidates[0]}"

def update_row_visibility(mode):
    """Build the tree and building rows with visibility derived from *mode*.

    The tree row is visible for ``"Trees"`` and ``"Both"``; the building row
    is visible for ``"Buildings"`` and ``"Both"``. Any other mode hides both.

    Returns:
        A ``(tree_row, building_row)`` tuple of ``gr.Row`` objects.
    """
    show_tree = mode in ["Trees", "Both"]
    show_building = mode in ["Buildings", "Both"]

    tree_row = gr.Row(visible=show_tree)
    building_row = gr.Row(visible=show_building)
    return tree_row, building_row

def update_path_options(version):
    """Build the weight-file dropdown for the selected model *version*.

    Versions containing ``"tree"`` are looked up in ``tree_model_weights``;
    everything else in ``building_model_weight``. The dropdown label uses the
    part of *version* before its first ``"v"``.

    Returns:
        A visible, interactive ``gr.Dropdown`` whose choices are the matching
        ``.pth`` files.
    """
    directory = "tree_model_weights" if "tree" in version else "building_model_weight"
    weight_files = list_pth_files_in_directory(directory, version)
    model_kind = version.split('v')[0]
    return gr.Dropdown(
        choices=weight_files,
        label=f"Select a {model_kind} model file",
        visible=True,
        interactive=True,
    )

# Predictor factory for the tree model
def tree_model(tree_version_dropdown, tree_pth_dropdown, tree_threshold, device="cpu"):
    """Create a detectron2 ``DefaultPredictor`` configured for the tree model.

    Args:
        tree_version_dropdown: Version label used to find the YAML config
            inside ``tree_model_weights``.
        tree_pth_dropdown: Weight file name inside ``tree_model_weights``.
        tree_threshold: Value stored in ``MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
        device: Value stored in ``MODEL.DEVICE`` (``"cpu"`` by default).

    Returns:
        A ``DefaultPredictor`` built from the resulting config.
    """
    cfg = get_cfg()
    config_path = get_version_cfg_yml(f"tree_model_weights/{tree_version_dropdown}")
    cfg.merge_from_file(config_path)

    cfg.MODEL.DEVICE = device
    cfg.MODEL.WEIGHTS = f"tree_model_weights/{tree_pth_dropdown}"
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2  # TODO change this
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = tree_threshold

    return DefaultPredictor(cfg)

# Predictor factory for the building model
def building_model(building_version_dropdown, building_pth_dropdown, building_threshold, device="cpu"):
    """Create a detectron2 ``DefaultPredictor`` configured for the building model.

    Args:
        building_version_dropdown: Version label used to find the YAML config
            inside ``building_model_weight``.
        building_pth_dropdown: Weight file name inside ``building_model_weight``.
        building_threshold: Value stored in
            ``MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
        device: Value stored in ``MODEL.DEVICE`` (``"cpu"`` by default).

    Returns:
        A ``DefaultPredictor`` built from the resulting config.
    """
    cfg = get_cfg()
    config_path = get_version_cfg_yml(f"building_model_weight/{building_version_dropdown}")
    cfg.merge_from_file(config_path)

    cfg.MODEL.DEVICE = device
    cfg.MODEL.WEIGHTS = f"building_model_weight/{building_pth_dropdown}"
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 8  # TODO change this
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = building_threshold

    return DefaultPredictor(cfg)

# Runs an already-configured buildings predictor on a given image
def segment_building(im, building_predictor):
    """Return the building predictor's ``"instances"`` output moved to the CPU.

    Args:
        im: Image passed unchanged to *building_predictor*.
        building_predictor: Callable returning a mapping with an
            ``"instances"`` entry that supports ``.to("cpu")``.
    """
    prediction = building_predictor(im)
    return prediction["instances"].to("cpu")

# Runs an already-configured trees predictor on a given image
def segment_tree(im, tree_predictor):
    """Return the tree predictor's ``"instances"`` output moved to the CPU.

    Args:
        im: Image passed unchanged to *tree_predictor*.
        tree_predictor: Callable returning a mapping with an ``"instances"``
            entry that supports ``.to("cpu")``.
    """
    prediction = tree_predictor(im)
    return prediction["instances"].to("cpu")

# Function to map strings to color mode
def map_color_mode(color_mode):
    """Translate a colour-mode label into a detectron2 ``ColorMode``.

    Args:
        color_mode: ``"Black/white"``, ``"Random"``, ``"Segmentation"`` or
            ``None``. Matching is case-sensitive.

    Returns:
        ``ColorMode.IMAGE_BW`` for ``"Black/white"``, ``ColorMode.IMAGE`` for
        ``"Random"``, ``ColorMode.SEGMENTATION`` for ``"Segmentation"`` or
        ``None``; ``None`` for any other value.
    """
    if color_mode == "Black/white":
        return ColorMode.IMAGE_BW
    if color_mode == "Random":
        return ColorMode.IMAGE
    # Identity check for None instead of ``== None``.
    if color_mode is None or color_mode == "Segmentation":
        return ColorMode.SEGMENTATION
    # Unrecognised labels still yield None, as before; now stated explicitly.
    return None

def load_predictor(model, version, pth, threshold):
    """Call the *model* factory with ``(version, pth, threshold)`` and return its result."""
    predictor = model(version, pth, threshold)
    return predictor

def load_instances(image, predictor, segment_function):
    """Apply *segment_function* to ``(image, predictor)`` and return its result."""
    instances = segment_function(image, predictor)
    return instances

def combine_instances(tree_instances, building_instances):
    """Concatenate tree and building instances (trees first) with ``Instances.cat``."""
    ordered = [tree_instances, building_instances]
    return Instances.cat(ordered)

def get_metadata(dataset_name, coco_file):
    """Fetch a dataset's metadata and set its class names from a COCO file.

    Args:
        dataset_name: Name looked up in ``MetadataCatalog``.
        coco_file: Path to a COCO-style JSON file with a ``"categories"``
            list whose items carry a ``"name"`` key.

    Returns:
        The metadata object, with ``thing_classes`` set to the category names
        in file order.
    """
    metadata = MetadataCatalog.get(dataset_name)
    # JSON is UTF-8 by specification; do not rely on the platform default
    # encoding, which may differ (for example on Windows).
    with open(coco_file, "r", encoding="utf-8") as f:
        coco = json.load(f)
    metadata.thing_classes = [category["name"] for category in coco["categories"]]
    return metadata

def visualize_image(im, mode, tree_threshold, building_threshold, color_mode, tree_version, tree_pth, building_version, building_pth):
    """Run the selected model(s) on an image and draw their predictions.

    Args:
        im: Input image; converted with ``np.array``.
        mode: ``"Trees"``, ``"Buildings"`` or ``"Both"``.
        tree_threshold: Score threshold for the tree predictor.
        building_threshold: Score threshold for the building predictor.
        color_mode: Label passed to ``map_color_mode``.
        tree_version: Tree model version label (used for ``"Trees"``/``"Both"``).
        tree_pth: Tree weight file name (used for ``"Trees"``/``"Both"``).
        building_version: Building model version label (used for
            ``"Buildings"``/``"Both"``).
        building_pth: Building weight file name (used for
            ``"Buildings"``/``"Both"``).

    Returns:
        A ``PIL.Image`` with the instance predictions drawn by a
        ``Visualizer`` created with ``scale=0.5``.

    Raises:
        ValueError: If *mode* is not one of the three supported values.
            Previously this case reached the visualizer with no instances
            and failed there with an unrelated error.
    """
    if mode not in ("Trees", "Buildings", "Both"):
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'Trees', 'Buildings' or 'Both'."
        )

    im = np.array(im)
    color_mode = map_color_mode(color_mode)

    # Collect one Instances object per requested model, trees first.
    per_model_instances = []

    if mode in ("Trees", "Both"):
        tree_predictor = load_predictor(tree_model, tree_version, tree_pth, tree_threshold)
        per_model_instances.append(load_instances(im, tree_predictor, segment_tree))

    if mode in ("Buildings", "Both"):
        building_predictor = load_predictor(building_model, building_version, building_pth, building_threshold)
        per_model_instances.append(load_instances(im, building_predictor, segment_building))

    if len(per_model_instances) == 1:
        instances = per_model_instances[0]
    else:
        instances = combine_instances(*per_model_instances)

    # Assuming 'urban-small_train' is intended for both Trees and Buildings
    metadata = get_metadata("urban-small_train", "building_model_weight/_annotations.coco.json")

    # The channel axis is reversed before drawing and reversed again on the
    # way out, so the result has the same channel order as the input.
    # NOTE(review): confirm the channel order callers supply matches what the
    # predictor and Visualizer expect.
    visualizer = Visualizer(im[:, :, ::-1], metadata=metadata, scale=0.5, instance_mode=color_mode)

    output_image = visualizer.draw_instance_predictions(instances)

    return Image.fromarray(output_image.get_image()[:, :, ::-1])