""" aerial-segmentation Proof of concept showing effectiveness of a fine tuned instance segmentation model for detecting trees. """ import os import gradio as gr import cv2 os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'") from transformers import DetrFeatureExtractor, DetrForSegmentation from PIL import Image import gradio as gr import numpy as np import torch import torchvision import detectron2 import json # import some common detectron2 utilities import itertools import seaborn as sns from detectron2 import model_zoo from detectron2.engine import DefaultPredictor from detectron2.config import get_cfg from detectron2.utils.visualizer import Visualizer from detectron2.utils.visualizer import ColorMode from detectron2.data import MetadataCatalog, DatasetCatalog from detectron2.checkpoint import DetectionCheckpointer from detectron2.utils.visualizer import ColorMode from detectron2.structures import Instances def list_pth_files_in_directory(directory, version="v1"): files = os.listdir(directory) version = version.split("v")[1] # return files that contains substring version and end with .pth pth_files = [f for f in files if version in f and f.endswith(".pth")] return pth_files def get_version_cfg_yml(path): directory = path.split("/")[0] version = path.split("/")[1] files = os.listdir(directory) cfg_file = [f for f in files if (f.endswith(".yml") or f.endswith(".yaml")) and version in f] return directory + "/" + cfg_file[0] def update_row_visibility(mode): visibility = { "tree": mode in ["Trees", "Both"], "building": mode in ["Buildings", "Both"] } tree_row, building_row = gr.Row(visible=visibility["tree"]), gr.Row(visible=visibility["building"]) return tree_row, building_row def update_path_options(version): if "tree" in version: directory = "tree_model_weights" else: directory = "building_model_weight" return gr.Dropdown(choices=list_pth_files_in_directory(directory, version), label=f"Select a {version.split('v')[0]} model file", visible=True, interactive=True) # Model for trees def tree_model(tree_version_dropdown, tree_pth_dropdown, tree_threshold, device="cpu"): tree_cfg = get_cfg() tree_cfg.merge_from_file(get_version_cfg_yml(f"tree_model_weights/{tree_version_dropdown}")) tree_cfg.MODEL.DEVICE=device tree_cfg.MODEL.WEIGHTS = f"tree_model_weights/{tree_pth_dropdown}" tree_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2 # TODO change this tree_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = tree_threshold tree_predictor = DefaultPredictor(tree_cfg) return tree_predictor # Model for buildings def building_model(building_version_dropdown, building_pth_dropdown, building_threshold, device="cpu"): building_cfg = get_cfg() building_cfg.merge_from_file(get_version_cfg_yml(f"building_model_weight/{building_version_dropdown}")) building_cfg.MODEL.DEVICE=device building_cfg.MODEL.WEIGHTS = f"building_model_weight/{building_pth_dropdown}" building_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 8 # TODO change this building_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = building_threshold building_predictor = DefaultPredictor(building_cfg) return building_predictor # A function that runs the buildings model on an given image and confidence threshold def segment_building(im, building_predictor): outputs = building_predictor(im) building_instances = outputs["instances"].to("cpu") return building_instances # A function that runs the trees model on an given image and confidence threshold def segment_tree(im, tree_predictor): outputs = tree_predictor(im) tree_instances = outputs["instances"].to("cpu") return 
tree_instances # Function to map strings to color mode def map_color_mode(color_mode): if color_mode == "Black/white": return ColorMode.IMAGE_BW elif color_mode == "Random": return ColorMode.IMAGE elif color_mode == "Segmentation" or color_mode == None: return ColorMode.SEGMENTATION def load_predictor(model, version, pth, threshold): return model(version, pth, threshold) def load_instances(image, predictor, segment_function): return segment_function(image, predictor) def combine_instances(tree_instances, building_instances): return Instances.cat([tree_instances, building_instances]) def get_metadata(dataset_name, coco_file): metadata = MetadataCatalog.get(dataset_name) with open(coco_file, "r") as f: coco = json.load(f) categories = coco["categories"] metadata.thing_classes = [c["name"] for c in categories] return metadata def visualize_image(im, mode, tree_threshold, building_threshold, color_mode, tree_version, tree_pth, building_version, building_pth): im = np.array(im) color_mode = map_color_mode(color_mode) instances = None if mode in {"Trees", "Both"}: tree_predictor = load_predictor(tree_model, tree_version, tree_pth, tree_threshold) tree_instances = load_instances(im, tree_predictor, segment_tree) instances = tree_instances if mode in {"Buildings", "Both"}: building_predictor = load_predictor(building_model, building_version, building_pth, building_threshold) building_instances = load_instances(im, building_predictor, segment_building) instances = building_instances if mode == "Buildings" else combine_instances(instances, building_instances) # Assuming 'urban-small_train' is intended for both Trees and Buildings metadata = get_metadata("urban-small_train", "building_model_weight/_annotations.coco.json") visualizer = Visualizer(im[:, :, ::-1], metadata=metadata, scale=0.5, instance_mode=color_mode) output_image = visualizer.draw_instance_predictions(instances) return Image.fromarray(output_image.get_image()[:, :, ::-1])
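

# Example usage: a minimal sketch of calling visualize_image directly on a local
# file, outside the Gradio UI. The image path, version strings, and weight file
# names below are illustrative assumptions, not files shipped with this script;
# matching .yaml configs and .pth weights are assumed to exist in
# tree_model_weights/ and building_model_weight/.
if __name__ == "__main__":
    demo_image = Image.open("sample_aerial.png")  # assumed sample image
    result = visualize_image(
        demo_image,
        mode="Both",
        tree_threshold=0.5,
        building_threshold=0.5,
        color_mode="Segmentation",
        tree_version="treev1",                # assumed version identifier
        tree_pth="treev1_model.pth",          # assumed file in tree_model_weights/
        building_version="buildingv1",        # assumed version identifier
        building_pth="buildingv1_model.pth",  # assumed file in building_model_weight/
    )
    result.save("segmented_sample.png")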