Spaces:
Runtime error
Runtime error
File size: 4,519 Bytes
8ae7071 0639dc1 8ae7071 0639dc1 8ae7071 e793dfa 8ae7071 db42787 8ae7071 d4def87 8ae7071 242850e 8ae7071 242850e cc6b299 8ae7071 242850e 8ae7071 008d1df 45a81a5 8ae7071 45a81a5 8ae7071 05b14d0 008d1df 45a81a5 008d1df 0639dc1 05b14d0 311bc12 8ae7071 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
"""
aerial-segmentation
Proof of concept showing the effectiveness of fine-tuned instance segmentation models for detecting trees and buildings.
"""
import os
import cv2
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
from transformers import DetrFeatureExtractor, DetrForSegmentation
from PIL import Image
import gradio as gr
import numpy as np
import torch
import torchvision
import detectron2
import json
# import some common detectron2 utilities
import itertools
import seaborn as sns
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.visualizer import ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.utils.visualizer import ColorMode
from detectron2.structures import Instances
def _build_predictor(config_path, weights_path, num_classes):
    """Create a CPU detectron2 config and a predictor built from it.

    Args:
        config_path: Path of the YAML config merged on top of the defaults.
        weights_path: Path of the checkpoint stored in ``MODEL.WEIGHTS``.
        num_classes: Value stored in ``MODEL.ROI_HEADS.NUM_CLASSES``.

    Returns:
        A ``(cfg, predictor)`` tuple.
    """
    cfg = get_cfg()
    cfg.merge_from_file(config_path)
    cfg.MODEL.DEVICE = 'cpu'
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    return cfg, DefaultPredictor(cfg)


# Tree model: config and weights live under tree_model_weights/.
tree_cfg, tree_predictor = _build_predictor(
    "tree_model_weights/treev2_cfg.yml",
    "tree_model_weights/treev2_best.pth",
    2,
)

# Building model: config and weights live under building_model_weight/.
building_cfg, building_predictor = _build_predictor(
    "building_model_weight/buildings_poc_cfg.yml",
    "building_model_weight/model_final.pth",
    8,
)
def segment_building(im, confidence_threshold):
    """Run the buildings model on an image and keep the confident detections.

    Args:
        im: Image array handed unchanged to ``building_predictor``.
            NOTE(review): detectron2's DefaultPredictor documents BGR input
            unless the config's INPUT.FORMAT says otherwise -- confirm the
            channel order the caller supplies.
        confidence_threshold: Detections scoring at or below this value are
            dropped. ``None`` disables the extra filtering.

    Returns:
        The predicted detectron2 ``Instances``, moved to the CPU and filtered
        by score.
    """
    # Kept so the config still reflects the last requested value, but this
    # alone does not change what ``building_predictor`` returns:
    # DefaultPredictor clones its config and builds the model when it is
    # constructed, so the score threshold is fixed at that point.
    building_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
    outputs = building_predictor(im)
    building_instances = outputs["instances"].to("cpu")
    if confidence_threshold is None:
        return building_instances
    # Enforce the requested threshold on the results themselves. Detections
    # already discarded by the threshold the predictor was built with cannot
    # be recovered here.
    return building_instances[building_instances.scores > confidence_threshold]
def segment_tree(im, confidence_threshold):
    """Run the trees model on an image and keep the confident detections.

    Args:
        im: Image array handed unchanged to ``tree_predictor``.
            NOTE(review): detectron2's DefaultPredictor documents BGR input
            unless the config's INPUT.FORMAT says otherwise -- confirm the
            channel order the caller supplies.
        confidence_threshold: Detections scoring at or below this value are
            dropped. ``None`` disables the extra filtering.

    Returns:
        The predicted detectron2 ``Instances``, moved to the CPU and filtered
        by score.
    """
    # Kept so the config still reflects the last requested value, but this
    # alone does not change what ``tree_predictor`` returns: DefaultPredictor
    # clones its config and builds the model when it is constructed, so the
    # score threshold is fixed at that point.
    tree_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
    outputs = tree_predictor(im)
    tree_instances = outputs["instances"].to("cpu")
    if confidence_threshold is None:
        return tree_instances
    # Enforce the requested threshold on the results themselves. Detections
    # already discarded by the threshold the predictor was built with cannot
    # be recovered here.
    return tree_instances[tree_instances.scores > confidence_threshold]
def map_color_mode(color_mode):
    """Translate a color-mode label into a detectron2 ``ColorMode``.

    Args:
        color_mode: One of "Black/white", "Random", "Segmentation", or None.
            Matching is case-sensitive.

    Returns:
        ``ColorMode.IMAGE_BW`` for "Black/white", ``ColorMode.IMAGE`` for
        "Random", ``ColorMode.SEGMENTATION`` for "Segmentation" or None, and
        None for any other value.
    """
    if color_mode is None or color_mode == "Segmentation":
        return ColorMode.SEGMENTATION
    if color_mode == "Black/white":
        return ColorMode.IMAGE_BW
    if color_mode == "Random":
        return ColorMode.IMAGE
    # Unrecognised labels have always produced None; keep that explicit.
    return None
def visualize_image(im, mode, tree_threshold: float, building_threshold: float, color_mode):
    """Segment an image with the selected model(s) and draw the detections.

    Args:
        im: Input image; converted with ``np.array`` and indexed as a
            three-axis array whose last axis is reversed before drawing.
            NOTE(review): the reversal suggests BGR input was assumed --
            confirm the channel order the UI actually delivers.
        mode: "Trees", "Buildings", "Both", or None (treated as "Both").
        tree_threshold: Confidence threshold passed to ``segment_tree``.
        building_threshold: Confidence threshold passed to
            ``segment_building``.
        color_mode: Label understood by ``map_color_mode``.

    Returns:
        A PIL image of the input with the predicted instances drawn on it.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    # Reject unsupported modes before any inference runs; previously this
    # surfaced only at the very end as an unbound ``instances`` variable.
    if mode is not None and mode not in ("Trees", "Buildings", "Both"):
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'Trees', 'Buildings' or 'Both'."
        )

    im = np.array(im)
    color_mode = map_color_mode(color_mode)

    if mode == "Trees":
        instances = segment_tree(im, tree_threshold)
    elif mode == "Buildings":
        instances = segment_building(im, building_threshold)
    else:  # "Both" or None
        tree_instances = segment_tree(im, tree_threshold)
        building_instances = segment_building(im, building_threshold)
        # NOTE(review): the two models' class ids are concatenated as-is, so
        # they share one label space in the merged result -- confirm that the
        # labels drawn for the combined output are the intended ones.
        instances = Instances.cat([tree_instances, building_instances])

    metadata = MetadataCatalog.get("urban-trees-fdokv_train")
    print("metadata", type(metadata), metadata)
    print('metadata.get("thing_classes")', type(metadata.get("thing_classes")), metadata.get("thing_classes"))

    visualizer = Visualizer(
        im[:, :, ::-1],
        metadata=metadata,
        scale=0.5,
        instance_mode=color_mode,
    )

    dataset_names = MetadataCatalog.list()
    print(dataset_names)

    # Record the building category names on the "urban-small_train" metadata.
    # NOTE(review): the visualizer above was created with the
    # "urban-trees-fdokv_train" metadata, so these names are not what it draws
    # with -- confirm which metadata is meant to label the output.
    building_metadata = MetadataCatalog.get("urban-small_train")
    category_names = building_metadata.get("thing_classes")
    print(category_names)

    with open("building_model_weight/_annotations.coco.json", "r", encoding="utf-8") as f:
        coco = json.load(f)
    categories = coco["categories"]
    print("categories", categories)
    building_metadata.thing_classes = [c["name"] for c in categories]
    print("metadata.thing_classes", building_metadata.thing_classes)

    output_image = visualizer.draw_instance_predictions(instances)
    # Reverse the channel axis again to undo the flip applied for drawing.
    return Image.fromarray(output_image.get_image()[:, :, ::-1])
|