"""
aerial-segmentation
Proof of concept showing the effectiveness of a fine-tuned instance segmentation model for detecting trees.
"""
import os
import gradio as gr
import cv2
# detectron2 is not shipped as a prebuilt wheel on PyPI, so install it from GitHub at runtime.
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
from transformers import DetrFeatureExtractor, DetrForSegmentation
from PIL import Image
import numpy as np
import torch
import torchvision
import detectron2
import json
# import some common detectron2 utilities
import itertools
import seaborn as sns
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.structures import Instances

def list_pth_files_in_directory(directory, version="v1"):
    files = os.listdir(directory)
    # keep only the numeric part of the version tag, e.g. "v1" -> "1"
    version = version.split("v")[1]
    # return files that contain the version substring and end with .pth
    pth_files = [f for f in files if version in f and f.endswith(".pth")]
    return pth_files

def get_version_cfg_yml(path):
    directory = path.split("/")[0]
    version = path.split("/")[1]
    files = os.listdir(directory)
    cfg_file = [f for f in files if (f.endswith(".yml") or f.endswith(".yaml")) and version in f]
    return directory + "/" + cfg_file[0]
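
# Assumed weights layout: each weights directory (e.g. tree_model_weights/) holds the
# versioned .pth checkpoints together with a .yml/.yaml config whose filename contains
# the version tag, so it can be matched by get_version_cfg_yml above.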

def update_row_visibility(mode):
    visibility = {
        "tree": mode in ["Trees", "Both"],
        "building": mode in ["Buildings", "Both"],
    }
    tree_row, building_row = gr.Row(visible=visibility["tree"]), gr.Row(visible=visibility["building"])
    return tree_row, building_row

def update_path_options(version):
    if "tree" in version:
        directory = "tree_model_weights"
    else:
        directory = "building_model_weight"
    return gr.Dropdown(
        choices=list_pth_files_in_directory(directory, version),
        label=f"Select a {version.split('v')[0]} model file",
        visible=True,
        interactive=True,
    )

# Model for trees
def tree_model(tree_version_dropdown, tree_pth_dropdown, tree_threshold, device="cpu"):
    tree_cfg = get_cfg()
    tree_cfg.merge_from_file(get_version_cfg_yml(f"tree_model_weights/{tree_version_dropdown}"))
    tree_cfg.MODEL.DEVICE = device
    tree_cfg.MODEL.WEIGHTS = f"tree_model_weights/{tree_pth_dropdown}"
    tree_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2  # TODO change this
    tree_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = tree_threshold
    tree_predictor = DefaultPredictor(tree_cfg)
    return tree_predictor

# Model for buildings
def building_model(building_version_dropdown, building_pth_dropdown, building_threshold, device="cpu"):
    building_cfg = get_cfg()
    building_cfg.merge_from_file(get_version_cfg_yml(f"building_model_weight/{building_version_dropdown}"))
    building_cfg.MODEL.DEVICE = device
    building_cfg.MODEL.WEIGHTS = f"building_model_weight/{building_pth_dropdown}"
    building_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 8  # TODO change this
    building_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = building_threshold
    building_predictor = DefaultPredictor(building_cfg)
    return building_predictor

# Model for LCZs
def lcz_model(lcz_version_dropdown, lcz_pth_dropdown, lcz_threshold, device="cpu"):
    lcz_cfg = get_cfg()
    lcz_cfg.merge_from_file(get_version_cfg_yml("lcz_model_weights/lczs_cfg.yaml"))
    lcz_cfg.MODEL.DEVICE = device
    # load LCZ weights from the LCZ directory, matching the config loaded above
    lcz_cfg.MODEL.WEIGHTS = f"lcz_model_weights/{lcz_pth_dropdown}"
    lcz_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 14  # TODO change this
    lcz_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = lcz_threshold
    lcz_predictor = DefaultPredictor(lcz_cfg)
    return lcz_predictor

# Run the building model on a given image with the configured confidence threshold
def segment_building(im, building_predictor):
    outputs = building_predictor(im)
    building_instances = outputs["instances"].to("cpu")
    return building_instances

# Run the tree model on a given image with the configured confidence threshold
def segment_tree(im, tree_predictor):
    outputs = tree_predictor(im)
    tree_instances = outputs["instances"].to("cpu")
    return tree_instances

# Run the LCZ model on a given image with the configured confidence threshold
def segment_lcz(im, lcz_predictor):
    outputs = lcz_predictor(im)
    lcz_instances = outputs["instances"].to("cpu")
    return lcz_instances

# Function to map strings to a detectron2 color mode
def map_color_mode(color_mode):
    if color_mode == "Black/white":
        return ColorMode.IMAGE_BW
    elif color_mode == "Random":
        return ColorMode.IMAGE
    elif color_mode == "Segmentation" or color_mode is None:
        return ColorMode.SEGMENTATION

def load_predictor(model, version, pth, threshold):
    return model(version, pth, threshold)

def load_instances(image, predictor, segment_function):
    return segment_function(image, predictor)
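
# Note: combine_instances below relies on detectron2's Instances.cat, which expects
# all instance sets to share the same image size and the same prediction fields
# (boxes, masks, scores, classes); that should hold here since the predictors are
# run on the same input image.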

def combine_instances(tree_instances, building_instances):
    return Instances.cat([tree_instances, building_instances])

def get_metadata(dataset_name, coco_file):
    metadata = MetadataCatalog.get(dataset_name)
    with open(coco_file, "r") as f:
        coco = json.load(f)
    categories = coco["categories"]
    metadata.thing_classes = [c["name"] for c in categories]
    return metadata

def visualize_image(im, mode, tree_threshold, building_threshold, color_mode,
                    tree_version, tree_pth, building_version, building_pth,
                    lcz_version, lcz_pth, lcz_threshold):
    im = np.array(im)
    color_mode = map_color_mode(color_mode)
    instances = None
    if mode in {"Trees", "Both"}:
        tree_predictor = load_predictor(tree_model, tree_version, tree_pth, tree_threshold)
        tree_instances = load_instances(im, tree_predictor, segment_tree)
        instances = tree_instances
    if mode in {"Buildings", "Both"}:
        building_predictor = load_predictor(building_model, building_version, building_pth, building_threshold)
        building_instances = load_instances(im, building_predictor, segment_building)
        instances = building_instances if mode == "Buildings" else combine_instances(instances, building_instances)
    if mode in {"LCZ", "Both"}:
        lcz_predictor = load_predictor(lcz_model, lcz_version, lcz_pth, lcz_threshold)
        lcz_instances = load_instances(im, lcz_predictor, segment_lcz)
        instances = lcz_instances if mode == "LCZ" else combine_instances(instances, lcz_instances)
    # Assuming 'urban-small_train' is intended for both Trees and Buildings
    metadata = get_metadata("urban-small_train", "building_model_weight/_annotations.coco.json")
    visualizer = Visualizer(im[:, :, ::-1], metadata=metadata, scale=0.5, instance_mode=color_mode)
    output_image = visualizer.draw_instance_predictions(instances)
    return Image.fromarray(output_image.get_image()[:, :, ::-1])
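

# --- Illustrative UI wiring (a minimal sketch): the helpers above suggest a Gradio
# Blocks layout with a mode selector, per-model weight dropdowns and confidence
# sliders. Component names, choices and defaults here are assumptions made for
# demonstration, not the original interface definition.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Row():
            input_image = gr.Image(type="pil", label="Input image")
            output_image = gr.Image(type="pil", label="Segmented image")
        mode = gr.Radio(["Trees", "Buildings", "Both"], value="Both", label="Mode")
        color_mode = gr.Radio(["Segmentation", "Random", "Black/white"],
                              value="Segmentation", label="Color mode")
        with gr.Row() as tree_row:
            tree_version = gr.Dropdown(choices=["treev1"], value="treev1", label="Tree model version")
            tree_pth = gr.Dropdown(choices=[], label="Select a tree model file")
            tree_threshold = gr.Slider(0.0, 1.0, value=0.7, label="Tree confidence threshold")
        with gr.Row() as building_row:
            building_version = gr.Dropdown(choices=["buildingv1"], value="buildingv1", label="Building model version")
            building_pth = gr.Dropdown(choices=[], label="Select a building model file")
            building_threshold = gr.Slider(0.0, 1.0, value=0.7, label="Building confidence threshold")
        with gr.Row():
            lcz_version = gr.Dropdown(choices=["lczv1"], value="lczv1", label="LCZ model version")
            lcz_pth = gr.Dropdown(choices=[], label="Select an LCZ model file")
            lcz_threshold = gr.Slider(0.0, 1.0, value=0.7, label="LCZ confidence threshold")
        # Toggle rows and refresh the list of checkpoint files when selections change.
        mode.change(update_row_visibility, inputs=mode, outputs=[tree_row, building_row])
        tree_version.change(update_path_options, inputs=tree_version, outputs=tree_pth)
        building_version.change(update_path_options, inputs=building_version, outputs=building_pth)
        run_button = gr.Button("Segment")
        run_button.click(
            visualize_image,
            inputs=[input_image, mode, tree_threshold, building_threshold, color_mode,
                    tree_version, tree_pth, building_version, building_pth,
                    lcz_version, lcz_pth, lcz_threshold],
            outputs=output_image,
        )
    demo.launch()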