|
import cv2 |
|
import numpy as np |
|
import gradio as gr |
|
from detectron2 import model_zoo |
|
from detectron2.config import get_cfg |
|
from detectron2.engine import DefaultPredictor |
|
from detectron2.utils.visualizer import Visualizer |
|
from detectron2.data import MetadataCatalog |
|
|
|
def initialize_model(
    weights_path="output/model_final.pth",
    score_thresh=0.95,
    device="cpu",
):
    """Build a detectron2 ``DefaultPredictor`` for the fine-tuned wheat model.

    Registers ``thing_classes=["wheat"]`` on the ``wheat_train`` and
    ``wheat_test`` metadata entries so the Visualizer can label detections,
    then builds an inference config on top of the
    ``mask_rcnn_R_101_C4_3x`` model-zoo config.

    Args:
        weights_path: Path of the fine-tuned checkpoint to load.
        score_thresh: Minimum score a detection needs to be kept at test time.
        device: Device string the model runs on (``"cpu"`` by default).

    Returns:
        A ``DefaultPredictor`` built from the resulting config.
    """
    for split in ("train", "test"):
        MetadataCatalog.get("wheat_" + split).set(thing_classes=["wheat"])

    cfg = get_cfg()
    # get_cfg() alone yields the library's default architecture. Merge the
    # zoo config so the network graph matches the checkpoint, which is
    # assumed to be fine-tuned from mask_rcnn_R_101_C4_3x (confirm against
    # the training script).
    cfg.merge_from_file(
        model_zoo.get_config_file(
            "COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x.yaml"
        )
    )
    cfg.MODEL.DEVICE = device
    # NOTE(review): thing_classes above lists a single class, but the head is
    # built for 2. This must match what the checkpoint was trained with, so it
    # is left as-is; confirm against the training config.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.WEIGHTS = weights_path
    return DefaultPredictor(cfg)
|
|
|
def process_image(predictor, img):
    """Run *predictor* on *img* and return the image with detections drawn.

    Args:
        predictor: A detectron2 ``DefaultPredictor`` (see ``initialize_model``).
        img: Image as a 3-channel numpy array, assumed to be in RGB order
            (Gradio's numpy image input is RGB).

    Returns:
        The RGB visualization of the detections, drawn at 1.5x scale.
    """
    from detectron2.utils.visualizer import ColorMode

    # DefaultPredictor expects BGR input (cfg.INPUT.FORMAT defaults to "BGR"),
    # whereas the Visualizer expects RGB. cvtColor also returns a contiguous
    # copy, unlike a negative-stride [:, :, ::-1] view.
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    outputs = predictor(bgr)

    visualizer = Visualizer(
        img,
        metadata=MetadataCatalog.get("wheat_train"),
        scale=1.5,
        # instance_mode takes a ColorMode enum; a bare "segmentation" string
        # never compares equal to ColorMode.SEGMENTATION.
        instance_mode=ColorMode.SEGMENTATION,
    )
    drawn = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))
    # get_image() already returns RGB, which is what Gradio displays.
    return drawn.get_image()
|
|
|
# Predictor built lazily on the first request and reused afterwards.
_predictor = None


def main(img):
    """Gradio callback: detect wheat heads in *img* and return the drawing.

    The model is loaded once, on the first call, instead of rebuilding the
    predictor (and reloading its weights) for every request.

    Args:
        img: Input image from the Gradio ``"image"`` component, or None when
            nothing was uploaded.

    Returns:
        The annotated image, or None if *img* is None.
    """
    global _predictor
    if img is None:
        return None
    if _predictor is None:
        _predictor = initialize_model()
    return process_image(_predictor, img)
|
|
|
|
|
# Bind the Interface itself (previously this name held launch()'s return).
iface = gr.Interface(
    fn=main,
    inputs="image",
    outputs="image",
    title="Wheat head Detector & Counting Wheat heads",
    cache_examples=False,
)

if __name__ == "__main__":
    # The port is a launch() option (server_port); gr.Interface has no
    # `port` parameter, so passing it there does not set the port.
    iface.launch(share=True, server_port=7861)