Spaces:
Runtime error
Runtime error
""" | |
aerial-segmentation | |
Proof of concept showing the effectiveness of fine-tuned instance segmentation models for detecting trees, buildings, and LCZs in aerial imagery.
""" | |
import os | |
import gradio as gr | |
import cv2 | |
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'") | |
from transformers import DetrFeatureExtractor, DetrForSegmentation | |
from PIL import Image | |
import gradio as gr | |
import numpy as np | |
import torch | |
import torchvision | |
import detectron2 | |
import json | |
# import some common detectron2 utilities | |
import itertools | |
import seaborn as sns | |
from detectron2 import model_zoo | |
from detectron2.engine import DefaultPredictor | |
from detectron2.config import get_cfg | |
from detectron2.utils.visualizer import Visualizer | |
from detectron2.utils.visualizer import ColorMode | |
from detectron2.data import MetadataCatalog, DatasetCatalog | |
from detectron2.checkpoint import DetectionCheckpointer | |
from detectron2.utils.visualizer import ColorMode | |
from detectron2.structures import Instances | |
def list_cfg_file_versions(directory):
    """Return the sorted, de-duplicated version names of the config files in *directory*.

    A config file is a ``.yml``/``.yaml`` file whose name starts with
    ``"<prefix>v"``, where ``<prefix>`` is the part of the directory's base
    name before its first underscore (``tree_model_weights`` -> ``"treev"``).
    The version name is the file name up to its first underscore, e.g.
    ``treev1_cfg.yaml`` -> ``"treev1"``.
    """
    # Derive the prefix from the last path component only, so "./tree_model_weights",
    # "tree_model_weights/" or an absolute path behave like "tree_model_weights".
    base = os.path.basename(os.path.normpath(directory))
    prefix = f"{base.split('_')[0]}v"
    # A set removes duplicates when several config files share one version;
    # sorting makes the order independent of os.listdir.
    return sorted({
        name.split("_")[0]
        for name in os.listdir(directory)
        if name.endswith((".yml", ".yaml")) and name.startswith(prefix)
    })
def list_pth_files_in_directory(directory, version="v1"):
    """Return the sorted ``.pth`` files in *directory* that belong to *version*.

    *version* looks like ``"<name>v<number>"`` (e.g. ``"treev1"`` or ``"v1"``).
    A file matches when it contains ``<number>`` as a standalone digit run, so
    version ``"treev1"`` selects ``treev1_final.pth`` but not ``treev10.pth``
    or ``treev2_model_0001999.pth``.

    Raises:
        ValueError: if *version* contains no ``"v"``.
    """
    import re

    # Take the text after the *last* "v" so a name containing "v" keeps working.
    _, sep, number = version.rpartition("v")
    if not sep:
        raise ValueError(f"version must look like '<name>v<number>', got {version!r}")
    # A bare substring test would also match the number inside longer digit
    # runs (checkpoint iterations, v10, ...); require non-digit neighbours.
    pattern = re.compile(rf"(?<!\d){re.escape(number)}(?!\d)")
    return sorted(
        name
        for name in os.listdir(directory)
        if name.endswith(".pth") and pattern.search(name)
    )
def get_version_cfg_yml(path):
    """Resolve ``"<directory>/<version>"`` to the path of that version's config file.

    Candidates are the ``.yml``/``.yaml`` files in ``<directory>`` whose name
    contains ``<version>``. A file that *is* the version (``<version>`` itself,
    or ``<version>_*.yaml`` / ``<version>.yaml``) wins over a mere substring
    match, so ``treev1`` does not resolve to ``treev10_cfg.yaml``. Ties are
    broken by sorted file name.

    Raises:
        FileNotFoundError: if no config file matches.
    """
    directory, version = os.path.split(path)
    candidates = sorted(
        name
        for name in os.listdir(directory or os.curdir)
        if name.endswith((".yml", ".yaml")) and version in name
    )
    if not candidates:
        raise FileNotFoundError(f"no .yml/.yaml config for {version!r} in {directory or os.curdir!r}")
    exact = [
        name
        for name in candidates
        if name == version or os.path.splitext(name.split("_")[0])[0] == version
    ]
    return os.path.join(directory, (exact or candidates)[0])
def update_row_visibility(mode):
    """Return ``(tree_row, building_row)`` Row updates for the selected *mode*.

    The tree row is visible for "Trees" and "Both"; the building row for
    "Buildings" and "Both". Any other mode hides both rows.
    """
    show_trees = mode in ("Trees", "Both")
    show_buildings = mode in ("Buildings", "Both")
    return gr.Row(visible=show_trees), gr.Row(visible=show_buildings)
def update_path_options(version):
    """Return a visible, interactive Dropdown listing the ``.pth`` files for *version*.

    The weights directory is picked from the version name: names containing
    "tree" use ``tree_model_weights``, names containing "lcz" use
    ``lcz_model_weights`` (the directory ``lcz_model`` loads from), and
    everything else uses ``building_model_weights``.
    """
    if "tree" in version:
        directory = "tree_model_weights"
    elif "lcz" in version:
        directory = "lcz_model_weights"
    else:
        directory = "building_model_weights"
    return gr.Dropdown(
        choices=list_pth_files_in_directory(directory, version),
        label=f"Select a {version.split('v')[0]} model file",
        visible=True,
        interactive=True,
    )


def _make_predictor(cfg_file, weights_file, num_classes, threshold, device):
    """Build a detectron2 ``DefaultPredictor`` from a config file and a weights file."""
    cfg = get_cfg()
    cfg.merge_from_file(cfg_file)
    cfg.MODEL.DEVICE = device
    cfg.MODEL.WEIGHTS = weights_file
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    # Detections scoring below this are dropped at inference time.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
    return DefaultPredictor(cfg)


def tree_model(tree_version_dropdown, tree_pth_dropdown, tree_threshold, device="cpu"):
    """Build the tree-segmentation predictor.

    Args:
        tree_version_dropdown: version name used to find the config file in
            ``tree_model_weights`` (via ``get_version_cfg_yml``).
        tree_pth_dropdown: weights file name inside ``tree_model_weights``.
        tree_threshold: ROI-heads test score threshold.
        device: detectron2 device string, e.g. ``"cpu"``.
    """
    return _make_predictor(
        get_version_cfg_yml(f"tree_model_weights/{tree_version_dropdown}"),
        f"tree_model_weights/{tree_pth_dropdown}",
        num_classes=2,  # TODO change this
        threshold=tree_threshold,
        device=device,
    )


def building_model(building_version_dropdown, building_pth_dropdown, building_threshold, device="cpu"):
    """Build the building-segmentation predictor.

    Args:
        building_version_dropdown: version name used to find the config file in
            ``building_model_weights`` (via ``get_version_cfg_yml``).
        building_pth_dropdown: weights file name inside ``building_model_weights``.
        building_threshold: ROI-heads test score threshold.
        device: detectron2 device string, e.g. ``"cpu"``.
    """
    return _make_predictor(
        get_version_cfg_yml(f"building_model_weights/{building_version_dropdown}"),
        f"building_model_weights/{building_pth_dropdown}",
        num_classes=8,  # TODO change this
        threshold=building_threshold,
        device=device,
    )


def lcz_model(lcz_version_dropdown, lcz_pth_dropdown, lcz_threshold, device="cpu"):
    """Build the LCZ-segmentation predictor.

    Args:
        lcz_version_dropdown: accepted for signature parity with the other
            builders. NOTE(review): currently ignored; the config is always
            ``lcz_model_weights/lczs_cfg.yaml``.
        lcz_pth_dropdown: weights file name inside ``lcz_model_weights``.
        lcz_threshold: ROI-heads test score threshold.
        device: detectron2 device string, e.g. ``"cpu"``.
    """
    return _make_predictor(
        get_version_cfg_yml("lcz_model_weights/lczs_cfg.yaml"),
        # Load weights from the LCZ directory that also holds the LCZ config; the
        # previous "tree_model_weights" path looked like a copy-paste from tree_model.
        f"lcz_model_weights/{lcz_pth_dropdown}",
        num_classes=14,  # TODO change this
        threshold=lcz_threshold,
        device=device,
    )
def segment_building(im, building_predictor):
    """Run *building_predictor* on *im* and return its detected instances on the CPU.

    The confidence threshold is a property of the predictor itself, not an
    argument here.
    """
    prediction = building_predictor(im)
    detected = prediction["instances"]
    return detected.to("cpu")
def segment_tree(im, tree_predictor):
    """Return the instances *tree_predictor* finds in *im*, moved to the CPU.

    Thresholding happens inside the predictor; this only runs inference.
    """
    return tree_predictor(im)["instances"].to("cpu")
def segment_lcz(im, lcz_predictor):
    """Run the LCZ predictor on *im* and hand back its instances on the CPU.

    Thresholding happens inside the predictor; this only runs inference.
    """
    outputs_on_device = lcz_predictor(im)["instances"]
    cpu_instances = outputs_on_device.to("cpu")
    return cpu_instances
def map_color_mode(color_mode):
    """Map a UI colour-mode label to a detectron2 ``ColorMode``.

    "Black/white" -> ``IMAGE_BW``, "Random" -> ``IMAGE``, "Segmentation" or
    ``None`` -> ``SEGMENTATION``. Any other value returns ``None``.
    """
    # PEP 8: compare to the None singleton by identity, not equality.
    if color_mode is None:
        return ColorMode.SEGMENTATION
    # Built inside the function so module import does not depend on ColorMode.
    modes = {
        "Black/white": ColorMode.IMAGE_BW,
        "Random": ColorMode.IMAGE,
        "Segmentation": ColorMode.SEGMENTATION,
    }
    return modes.get(color_mode)
def load_predictor(model, version, pth, threshold):
    """Build a predictor by calling the factory *model* with ``(version, pth, threshold)``."""
    factory_args = (version, pth, threshold)
    return model(*factory_args)
def load_instances(image, predictor, segment_function):
    """Apply *segment_function* to ``(image, predictor)`` and return what it produces."""
    result = segment_function(image, predictor)
    return result
def combine_instances(tree_instances, building_instances):
    """Concatenate two detectron2 ``Instances`` into one, tree instances first."""
    to_merge = [tree_instances, building_instances]
    return Instances.cat(to_merge)
def get_metadata(dataset_name, coco_file):
    """Return the ``MetadataCatalog`` entry *dataset_name* with ``thing_classes`` set.

    Class names are the ``name`` fields of the ``categories`` list in the COCO
    annotation file *coco_file*, in file order.

    NOTE(review): detectron2 metadata may refuse to overwrite an attribute with
    a different value, so calling this twice for one *dataset_name* with
    different files could fail. Confirm if that ever happens.
    """
    metadata = MetadataCatalog.get(dataset_name)
    # JSON text is UTF-8; don't depend on the platform's default locale encoding.
    with open(coco_file, "r", encoding="utf-8") as f:
        coco = json.load(f)
    metadata.thing_classes = [category["name"] for category in coco["categories"]]
    return metadata
def visualize_image(
    im,
    mode,
    tree_threshold,
    building_threshold,
    color_mode,
    tree_version,
    tree_pth,
    building_version,
    building_pth,
    lcz_version,
    lcz_pth,
    lcz_threshold=0.5,
):
    """Run the model(s) selected by *mode* on *im* and return an annotated PIL image.

    Args:
        im: input image, a PIL image or array-like. Assumed RGB (what Gradio
            delivers); TODO confirm against the interface definition.
        mode: "Trees", "Buildings", "Both" (trees + buildings) or "LCZ".
        tree_threshold / building_threshold: score thresholds for those models.
        color_mode: UI label passed through ``map_color_mode``.
        tree_version, tree_pth, building_version, building_pth, lcz_version,
        lcz_pth: model version / weights-file selections.
        lcz_threshold: score threshold for the LCZ model. NOTE(review): callers
            passing only the first eleven arguments get 0.5, which is a
            placeholder. Confirm the intended value.

    Returns:
        PIL image with the predicted instances drawn at half scale.

    Raises:
        ValueError: if *mode* selects no model.
    """
    if isinstance(im, Image.Image):
        # Normalise palette / greyscale / RGBA uploads to 3-channel RGB.
        im = im.convert("RGB")
    rgb = np.array(im)
    # DefaultPredictor takes BGR input while Visualizer takes RGB, so give each
    # the order it expects.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    instance_mode = map_color_mode(color_mode)

    # NOTE: predictors (and their weights) are rebuilt on every call.
    found = []
    if mode in {"Trees", "Both"}:
        tree_predictor = load_predictor(tree_model, tree_version, tree_pth, tree_threshold)
        found.append(load_instances(bgr, tree_predictor, segment_tree))
    if mode in {"Buildings", "Both"}:
        building_predictor = load_predictor(building_model, building_version, building_pth, building_threshold)
        found.append(load_instances(bgr, building_predictor, segment_building))
    # "Both" means trees + buildings (the two rows update_row_visibility shows);
    # the LCZ model runs only when explicitly selected.
    if mode == "LCZ":
        lcz_predictor = load_predictor(lcz_model, lcz_version, lcz_pth, lcz_threshold)
        found.append(load_instances(bgr, lcz_predictor, segment_lcz))
    if not found:
        raise ValueError(f"Unsupported mode: {mode!r}")
    instances = found[0] if len(found) == 1 else Instances.cat(found)

    # NOTE(review): class names always come from the building annotations, so
    # tree / LCZ class ids are labelled with building category names. Confirm
    # whether each model needs its own metadata.
    metadata = get_metadata("urban-small_train", "building_model_weights/_annotations.coco.json")
    visualizer = Visualizer(rgb, metadata=metadata, scale=0.5, instance_mode=instance_mode)
    output_image = visualizer.draw_instance_predictions(instances)
    return Image.fromarray(output_image.get_image())