""" aerial-segmentation Proof of concept showing effectiveness of a fine tuned instance segmentation model for detecting trees. """ import os import gradio as gr import cv2 os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'") from transformers import DetrFeatureExtractor, DetrForSegmentation from PIL import Image import gradio as gr import numpy as np import torch import torchvision import detectron2 import json # import some common detectron2 utilities import itertools import seaborn as sns from detectron2 import model_zoo from detectron2.engine import DefaultPredictor from detectron2.config import get_cfg from detectron2.utils.visualizer import Visualizer from detectron2.utils.visualizer import ColorMode from detectron2.data import MetadataCatalog, DatasetCatalog from detectron2.checkpoint import DetectionCheckpointer from detectron2.utils.visualizer import ColorMode from detectron2.structures import Instances def list_cfg_file_versions(directory): files = os.listdir(directory) # return files that contains substring version and end with .yml cfg_files = [f.split("_")[0] for f in files if (f.endswith(".yml") or f.endswith(".yaml")) and f.startswith(f"{directory.split('_')[0]}v")] return cfg_files def list_pth_files_in_directory(directory, version="v1"): files = os.listdir(directory) version = version.split("v")[1] # return files that contains substring version and end with .pth pth_files = [f for f in files if version in f and f.endswith(".pth")] return pth_files def get_version_cfg_yml(path): directory = path.split("/")[0] version = path.split("/")[1] files = os.listdir(directory) cfg_file = [f for f in files if (f.endswith(".yml") or f.endswith(".yaml")) and version in f] return directory + "/" + cfg_file[0] def update_row_visibility(mode): visibility = { "tree": mode in ["Trees", "Trees & Buildings"], "building": mode in ["Buildings", "Trees & Buildings"], "lcz": mode in ["LCZ"] } tree_row, building_row, lcz_row = gr.Row(visible=visibility["tree"]), gr.Row(visible=visibility["building"]), gr.Row(visible=visibility["lcz"]) print(visibility) return tree_row, building_row, lcz_row def update_path_options(version): if "tree" in version: directory = "tree_model_weights" elif "building" in version: directory = "building_model_weights" elif "lcz" in version: directory = "lcz_model_weights" return gr.Dropdown(choices=list_pth_files_in_directory(directory, version), label=f"Select a {version.split('v')[0]} model file", visible=True, interactive=True) # Model for trees def tree_model(tree_version_dropdown, tree_pth_dropdown, tree_threshold, device="cpu"): tree_cfg = get_cfg() tree_cfg.merge_from_file(get_version_cfg_yml(f"tree_model_weights/{tree_version_dropdown}")) tree_cfg.MODEL.DEVICE=device tree_cfg.MODEL.WEIGHTS = f"tree_model_weights/{tree_pth_dropdown}" tree_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2 # TODO change this tree_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = tree_threshold tree_predictor = DefaultPredictor(tree_cfg) return tree_predictor # Model for buildings def building_model(building_version_dropdown, building_pth_dropdown, building_threshold, device="cpu"): building_cfg = get_cfg() building_cfg.merge_from_file(get_version_cfg_yml(f"building_model_weights/{building_version_dropdown}")) building_cfg.MODEL.DEVICE=device building_cfg.MODEL.WEIGHTS = f"building_model_weights/{building_pth_dropdown}" building_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 8 # TODO change this building_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = building_threshold building_predictor = DefaultPredictor(building_cfg) return building_predictor # Model for LCZs def lcz_model(lcz_version_dropdown, lcz_pth_dropdown, lcz_threshold, device="cpu"): lcz_cfg = get_cfg() lcz_cfg.merge_from_file(get_version_cfg_yml("lcz_model_weights/lczs_cfg.yaml")) lcz_cfg.MODEL.DEVICE=device lcz_cfg.MODEL.WEIGHTS = f"tree_model_weights/{lcz_pth_dropdown}" lcz_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 14 # TODO change this lcz_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = lcz_threshold lcz_predictor = DefaultPredictor(lcz_cfg) return lcz_predictor # A function that runs the buildings model on an given image and confidence threshold def segment_building(im, building_predictor): outputs = building_predictor(im) building_instances = outputs["instances"].to("cpu") return building_instances # A function that runs the trees model on an given image and confidence threshold def segment_tree(im, tree_predictor): outputs = tree_predictor(im) tree_instances = outputs["instances"].to("cpu") return tree_instances # A function that runs the trees model on an given image and confidence threshold def segment_lcz(im, lcz_predictor): outputs = lcz_predictor(im) lcz_instances = outputs["instances"].to("cpu") return lcz_instances # Function to map strings to color mode def map_color_mode(color_mode): if color_mode == "Black/white": return ColorMode.IMAGE_BW elif color_mode == "Random": return ColorMode.IMAGE elif color_mode == "Segmentation" or color_mode == None: return ColorMode.SEGMENTATION def load_predictor(model, version, pth, threshold): return model(version, pth, threshold) def load_instances(image, predictor, segment_function): return segment_function(image, predictor) def combine_instances(tree_instances, building_instances): return Instances.cat([tree_instances, building_instances]) def get_metadata(dataset_name, coco_file): metadata = MetadataCatalog.get(dataset_name) with open(coco_file, "r") as f: coco = json.load(f) categories = coco["categories"] metadata.thing_classes = [c["name"] for c in categories] return metadata def visualize_image(im, mode, tree_threshold, building_threshold, color_mode, tree_version, tree_pth, building_version, building_pth, lcz_version, lcz_pth): im = np.array(im) color_mode = map_color_mode(color_mode) instances = None if mode in {"Trees", "Both"}: tree_predictor = load_predictor(tree_model, tree_version, tree_pth, tree_threshold) tree_instances = load_instances(im, tree_predictor, segment_tree) instances = tree_instances if mode in {"Buildings", "Both"}: building_predictor = load_predictor(building_model, building_version, building_pth, building_threshold) building_instances = load_instances(im, building_predictor, segment_building) instances = building_instances if mode == "Buildings" else combine_instances(instances, building_instances) if mode in {"LCZ", "Both"}: lcz_predictor = load_predictor(lcz_model, lcz_version, lcz_pth, lcz_threshold) lcz_instances = load_instances(im, lcz_predictor, segment_lcz) instances = lcz_instances if mode == "LCZ" else combine_instances(instances, LCZ_instances) # Assuming 'urban-small_train' is intended for both Trees and Buildings metadata = get_metadata("urban-small_train", "building_model_weights/_annotations.coco.json") visualizer = Visualizer(im[:, :, ::-1], metadata=metadata, scale=0.5, instance_mode=color_mode) output_image = visualizer.draw_instance_predictions(instances) return Image.fromarray(output_image.get_image()[:, :, ::-1])