aigis-lczs / backend.py
hlydecker's picture
Update backend.py
27aaed4 verified
raw
history blame
6.96 kB
"""
aerial-segmentation
Proof of concept showing the effectiveness of fine-tuned instance segmentation models for detecting trees, buildings, and LCZs.
"""
import os
import gradio as gr
import cv2
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
from transformers import DetrFeatureExtractor, DetrForSegmentation
from PIL import Image
import gradio as gr
import numpy as np
import torch
import torchvision
import detectron2
import json
# import some common detectron2 utilities
import itertools
import seaborn as sns
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.visualizer import ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.utils.visualizer import ColorMode
from detectron2.structures import Instances
def list_pth_files_in_directory(directory, version="v1"):
    """Return the ``.pth`` weight files in *directory* that match *version*.

    Args:
        directory: Folder to scan (non-recursive, via ``os.listdir``).
        version: Version label such as ``"v1"`` or ``"tree-v2"``. Only the
            text after the LAST ``"v"`` is used as the match key, so
            ``"tree-v2"`` matches any file name containing ``"2"``. If the
            label contains no ``"v"``, the whole label is used as the key.

    Returns:
        Sorted list of bare file names (not full paths) that contain the
        version key and end with ``.pth``. Matching is a plain substring
        test, so key ``"1"`` also matches e.g. ``"model_v10.pth"``.
    """
    # rsplit on the last "v": a prefix that itself contains "v"
    # (e.g. "vegetation-v1") still yields "1", whereas split("v")[1] would
    # yield "egetation-" and would raise IndexError for labels without "v".
    version_key = version.rsplit("v", 1)[-1]
    # os.listdir order is arbitrary; sort so the dropdown order is stable.
    return sorted(
        name
        for name in os.listdir(directory)
        if version_key in name and name.endswith(".pth")
    )
def get_version_cfg_yml(path):
    """Resolve the detectron2 config file for a ``"<directory>/<version>"`` path.

    The text after the last ``/`` is treated as a version key. The first
    ``.yml``/``.yaml`` file (in sorted order) in the directory whose name
    contains that key is returned.

    Args:
        path: ``"<directory>/<version>"``, e.g. ``"tree_model_weights/v1"``.
            The directory part may itself contain slashes (nested or
            absolute paths). Without any slash, the current directory is
            searched.

    Returns:
        ``"<directory>/<config file name>"``.

    Raises:
        FileNotFoundError: if the directory does not exist or contains no
            matching config file.
    """
    # rpartition keeps nested/absolute directories intact; split("/")[0]
    # kept only the first component (an empty string for absolute paths).
    directory, sep, version = path.rpartition("/")
    # "" with a separator means the root ("/v1"); no separator means cwd.
    search_dir = directory or sep or "."
    candidates = sorted(
        name
        for name in os.listdir(search_dir)
        if name.endswith((".yml", ".yaml")) and version in name
    )
    if not candidates:
        raise FileNotFoundError(
            f"No .yml/.yaml config matching {version!r} in {search_dir!r}"
        )
    return directory + sep + candidates[0]
def update_row_visibility(mode):
    """Return ``(tree_row, building_row)`` Gradio rows with visibility set for *mode*.

    The tree row is visible for "Trees" or "Both"; the building row for
    "Buildings" or "Both". Any other mode (including "LCZ") hides both.
    """
    show_tree = mode in ("Trees", "Both")
    show_building = mode in ("Buildings", "Both")
    return gr.Row(visible=show_tree), gr.Row(visible=show_building)
def update_path_options(version):
    """Return a visible, interactive dropdown of ``.pth`` files for *version*.

    Labels containing "tree" are looked up in ``tree_model_weights``; every
    other label is looked up in ``building_model_weight``. The dropdown label
    uses the text before the first "v" in *version*.
    """
    # NOTE(review): there is no LCZ branch, so LCZ version labels resolve to
    # building_model_weight — confirm this is intended.
    directory = "tree_model_weights" if "tree" in version else "building_model_weight"
    choices = list_pth_files_in_directory(directory, version)
    model_kind = version.split('v')[0]
    return gr.Dropdown(
        choices=choices,
        label=f"Select a {model_kind} model file",
        visible=True,
        interactive=True,
    )
# Model for trees
def tree_model(tree_version_dropdown, tree_pth_dropdown, tree_threshold, device="cpu"):
    """Build a detectron2 ``DefaultPredictor`` for the tree segmentation model.

    Args:
        tree_version_dropdown: Version key used to find the config file in
            ``tree_model_weights`` (resolved by ``get_version_cfg_yml``).
        tree_pth_dropdown: Weight file name inside ``tree_model_weights``.
        tree_threshold: Value for ``MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
        device: Value for ``MODEL.DEVICE`` (default ``"cpu"``).
    """
    weights_dir = "tree_model_weights"
    cfg = get_cfg()
    cfg.merge_from_file(get_version_cfg_yml(f"{weights_dir}/{tree_version_dropdown}"))
    cfg.MODEL.DEVICE = device
    cfg.MODEL.WEIGHTS = f"{weights_dir}/{tree_pth_dropdown}"
    roi_heads = cfg.MODEL.ROI_HEADS
    # TODO(review): hard-coded class count; presumably must match the checkpoint.
    roi_heads.NUM_CLASSES = 2
    roi_heads.SCORE_THRESH_TEST = tree_threshold
    return DefaultPredictor(cfg)
# Model for buildings
def building_model(building_version_dropdown, building_pth_dropdown, building_threshold, device="cpu"):
    """Build a detectron2 ``DefaultPredictor`` for the building segmentation model.

    Args:
        building_version_dropdown: Version key used to find the config file
            in ``building_model_weight`` (resolved by ``get_version_cfg_yml``).
        building_pth_dropdown: Weight file name inside ``building_model_weight``.
        building_threshold: Value for ``MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
        device: Value for ``MODEL.DEVICE`` (default ``"cpu"``).
    """
    weights_dir = "building_model_weight"
    cfg = get_cfg()
    cfg.merge_from_file(get_version_cfg_yml(f"{weights_dir}/{building_version_dropdown}"))
    cfg.MODEL.DEVICE = device
    cfg.MODEL.WEIGHTS = f"{weights_dir}/{building_pth_dropdown}"
    roi_heads = cfg.MODEL.ROI_HEADS
    # TODO(review): hard-coded class count; presumably must match the checkpoint.
    roi_heads.NUM_CLASSES = 8
    roi_heads.SCORE_THRESH_TEST = building_threshold
    return DefaultPredictor(cfg)
# Model for LCZs
def lcz_model(lcz_version_dropdown, lcz_pth_dropdown, lcz_threshold, device="cpu"):
    """Build a detectron2 ``DefaultPredictor`` for the LCZ segmentation model.

    Args:
        lcz_version_dropdown: Currently unused — the config is always resolved
            from ``"lcz_model_weights/lczs_cfg.yaml"``. Kept for signature
            parity with ``tree_model`` / ``building_model``.
        lcz_pth_dropdown: Weight file name inside ``lcz_model_weights``.
        lcz_threshold: Value for ``MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
        device: Value for ``MODEL.DEVICE`` (default ``"cpu"``).
    """
    lcz_cfg = get_cfg()
    lcz_cfg.merge_from_file(get_version_cfg_yml("lcz_model_weights/lczs_cfg.yaml"))
    lcz_cfg.MODEL.DEVICE = device
    # Previously f"tree_model_weights/{...}" — copy-pasted from tree_model while
    # the config is read from lcz_model_weights.
    # NOTE(review): assumes the LCZ .pth files live next to lczs_cfg.yaml.
    lcz_cfg.MODEL.WEIGHTS = f"lcz_model_weights/{lcz_pth_dropdown}"
    # TODO(review): hard-coded class count; presumably must match the checkpoint.
    lcz_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 14
    lcz_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = lcz_threshold
    return DefaultPredictor(lcz_cfg)
# Run the buildings predictor on an image and return its detections on the CPU
def segment_building(im, building_predictor):
    """Run *building_predictor* on *im* and return its ``"instances"`` moved to CPU.

    The score threshold is whatever the predictor was built with; this
    function only runs inference and calls ``.to("cpu")`` on the result.
    """
    return building_predictor(im)["instances"].to("cpu")
# Run the trees predictor on an image and return its detections on the CPU
def segment_tree(im, tree_predictor):
    """Run *tree_predictor* on *im* and return its ``"instances"`` moved to CPU.

    The score threshold is whatever the predictor was built with; this
    function only runs inference and calls ``.to("cpu")`` on the result.
    """
    return tree_predictor(im)["instances"].to("cpu")
# Run the LCZ predictor on an image and return its detections on the CPU
def segment_lcz(im, lcz_predictor):
    """Run *lcz_predictor* on *im* and return its ``"instances"`` moved to CPU.

    The score threshold is whatever the predictor was built with; this
    function only runs inference and calls ``.to("cpu")`` on the result.
    """
    return lcz_predictor(im)["instances"].to("cpu")
# Map UI color-mode labels to detectron2 ColorMode values
def map_color_mode(color_mode):
    """Translate a UI color-mode label into a detectron2 ``ColorMode``.

    "Black/white" -> ``ColorMode.IMAGE_BW``, "Random" -> ``ColorMode.IMAGE``,
    "Segmentation" or ``None`` -> ``ColorMode.SEGMENTATION``.

    Returns:
        The matching ``ColorMode``, or ``None`` for any other label.
    """
    # Identity check for None (PEP 8) instead of ``== None``.
    if color_mode is None:
        return ColorMode.SEGMENTATION
    modes = {
        "Black/white": ColorMode.IMAGE_BW,
        "Random": ColorMode.IMAGE,
        "Segmentation": ColorMode.SEGMENTATION,
    }
    # Unknown labels yield None explicitly rather than via implicit fall-through.
    return modes.get(color_mode)
def load_predictor(model, version, pth, threshold):
    """Build a predictor by calling the factory *model* positionally.

    *model* is a builder such as ``tree_model``; it receives
    ``(version, pth, threshold)`` and its return value is passed through.
    """
    factory_args = (version, pth, threshold)
    return model(*factory_args)
def load_instances(image, predictor, segment_function):
    """Return ``segment_function(image, predictor)``.

    *segment_function* is one of the ``segment_*`` helpers; both arguments
    are passed positionally.
    """
    call_args = (image, predictor)
    return segment_function(*call_args)
def combine_instances(tree_instances, building_instances):
    """Concatenate two detectron2 ``Instances`` objects, first argument first.

    Delegates to ``Instances.cat``; per detectron2, the inputs are expected
    to share the same image size and fields.
    """
    parts = [tree_instances, building_instances]
    return Instances.cat(parts)
def get_metadata(dataset_name, coco_file):
    """Set ``thing_classes`` on a detectron2 dataset's metadata from a COCO file.

    Reads *coco_file* as JSON, takes the ``"name"`` of every entry in its
    ``"categories"`` list (in file order), and assigns that list to
    ``MetadataCatalog.get(dataset_name).thing_classes``.

    Returns:
        The detectron2 ``Metadata`` object for *dataset_name*.

    Raises:
        OSError: if *coco_file* cannot be opened.
        KeyError: if the JSON lacks ``"categories"`` or a category lacks ``"name"``.
    """
    # NOTE(review): detectron2's Metadata asserts when an existing attribute is
    # re-set to a different value, so reusing one dataset_name with a different
    # COCO file will fail — confirm against the installed version.
    metadata = MetadataCatalog.get(dataset_name)
    # JSON text is UTF-8; don't depend on the platform's locale encoding when
    # category names contain non-ASCII characters.
    with open(coco_file, "r", encoding="utf-8") as f:
        coco = json.load(f)
    metadata.thing_classes = [category["name"] for category in coco["categories"]]
    return metadata
def visualize_image(
    im,
    mode,
    tree_threshold,
    building_threshold,
    color_mode,
    tree_version,
    tree_pth,
    building_version,
    building_pth,
    lcz_version,
    lcz_pth,
    lcz_threshold=None,
):
    """Run the model(s) selected by *mode* on *im* and return an annotated image.

    Args:
        im: Input image; converted with ``np.array`` (e.g. a PIL image or an
            ndarray from the Gradio image component).
        mode: "Trees", "Buildings", "LCZ" or "Both". "Both" runs the tree,
            building and LCZ branches below and concatenates their predictions.
        tree_threshold: Score threshold for the tree model.
        building_threshold: Score threshold for the building model.
        color_mode: UI color-mode label, translated by ``map_color_mode``.
        tree_version, tree_pth: Config version key and weight file for trees.
        building_version, building_pth: Same, for buildings.
        lcz_version, lcz_pth: Same, for the LCZ model.
        lcz_threshold: Score threshold for the LCZ model. Optional trailing
            parameter (so existing positional callers keep working); when
            None the building threshold is reused.

    Returns:
        ``PIL.Image`` with the predictions drawn by detectron2's ``Visualizer``
        at ``scale=0.5``.

    Raises:
        ValueError: if *mode* does not select any model.
    """
    im = np.array(im)
    color_mode = map_color_mode(color_mode)
    # The original body read an undefined ``lcz_threshold`` (NameError for
    # "LCZ"/"Both").
    # NOTE(review): reusing the building threshold is a placeholder until the
    # UI supplies a dedicated LCZ threshold.
    if lcz_threshold is None:
        lcz_threshold = building_threshold

    instances = None
    if mode in {"Trees", "Both"}:
        tree_predictor = load_predictor(tree_model, tree_version, tree_pth, tree_threshold)
        instances = load_instances(im, tree_predictor, segment_tree)
    if mode in {"Buildings", "Both"}:
        building_predictor = load_predictor(building_model, building_version, building_pth, building_threshold)
        building_instances = load_instances(im, building_predictor, segment_building)
        instances = building_instances if instances is None else combine_instances(instances, building_instances)
    # NOTE(review): "Both" also runs the LCZ model here, while
    # update_row_visibility treats "Both" as trees + buildings only — confirm.
    if mode in {"LCZ", "Both"}:
        lcz_predictor = load_predictor(lcz_model, lcz_version, lcz_pth, lcz_threshold)
        lcz_instances = load_instances(im, lcz_predictor, segment_lcz)
        # Previously referenced the undefined name ``LCZ_instances``.
        instances = lcz_instances if instances is None else combine_instances(instances, lcz_instances)
    if instances is None:
        # Fail clearly instead of passing None into the detectron2 Visualizer.
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'Trees', 'Buildings', 'LCZ' or 'Both'"
        )

    # Assuming 'urban-small_train' is intended for both Trees and Buildings
    # NOTE(review): class names always come from the building COCO file, so
    # tree/LCZ class ids are labelled with building class names — confirm.
    metadata = get_metadata("urban-small_train", "building_model_weight/_annotations.coco.json")
    # The Visualizer receives the channel-reversed array; its rendered output
    # is reversed back before building the PIL image.
    visualizer = Visualizer(im[:, :, ::-1], metadata=metadata, scale=0.5, instance_mode=color_mode)
    output_image = visualizer.draw_instance_predictions(instances)
    return Image.fromarray(output_image.get_image()[:, :, ::-1])