"""
tree-segmentation
Proof of concept showing the effectiveness of a fine-tuned instance segmentation model for detecting trees.
"""
import os

# Detectron2 is not published on PyPI, so install it from GitHub when the Space starts.
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")

from PIL import Image
import gradio as gr
import numpy as np
import torch
import detectron2

# import some common detectron2 utilities
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer, ColorMode
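# Build the Detectron2 config from the fine-tuned tree model's YAML and checkpoint; inference runs on CPU.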
cfg = get_cfg()
cfg.merge_from_file("model_weights/treev1_cfg.yaml")
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.WEIGHTS = "model_weights/treev1_best.pth"
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
def segment_image(im, confidence_threshold):
    # Keep only detections scoring above the user-selected confidence threshold.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
    predictor = DefaultPredictor(cfg)
    im = np.array(im)
    outputs = predictor(im)
    # Visualize the predicted instances; [:, :, ::-1] reverses the channel order (RGB <-> BGR).
    v = Visualizer(im[:, :, ::-1],
                   scale=0.5,
                   instance_mode=ColorMode.SEGMENTATION)
    print(len(outputs["instances"]), "trees detected.")
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    return Image.fromarray(out.get_image()[:, :, ::-1])
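# Example of calling segment_image directly (without the Gradio UI), e.g. for a quick local
# sanity check; "sample_aerial_tile.png" is a hypothetical file name, not part of this repo:
#
#     img = Image.open("sample_aerial_tile.png")
#     result = segment_image(img, confidence_threshold=0.7)
#     result.save("segmented_trees.png")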
# gradio input and output components
gr_slider_confidence = gr.inputs.Slider(0, 1, 0.1, 0.7,
                                        label="Set confidence threshold for masks")
inputs = gr.inputs.Image(type="pil", label="Input Image")
outputs = gr.outputs.Image(type="pil", label="Output Image")
title = "Tree Segmentation"
description = "An instance segmentation demo for identifying trees in aerial images, using a Detectron2 Mask R-CNN model (R-101 backbone) fine-tuned for tree detection"
# Create user interface and launch
gr.Interface(segment_image,
             inputs=[inputs, gr_slider_confidence],
             outputs=outputs,
             title=title,
             description=description).launch(debug=True)