Spaces:
Runtime error
Runtime error
"""
Single-image inference: Mask R-CNN (from Detectron2) followed by DeepLabCut (DLC).

Two-step pipeline: a pretrained Mask R-CNN first detects birds, then DLC-Live
estimates keypoints on each detection.
"""
import cv2 | |
import torch | |
import sys | |
sys.path.append("Repositories/DeepLabCut-live") | |
import deeplabcut as dlc | |
from dlclive import DLCLive, Processor | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from tqdm import tqdm | |
import os | |
import shutil | |
import torchvision | |
from torchvision.transforms import transforms as transforms | |
import pickle | |
import detectron2 | |
# import some common detectron2 utilities | |
from detectron2 import model_zoo | |
from detectron2.engine import DefaultPredictor | |
from detectron2.config import get_cfg | |
from detectron2.utils.visualizer import Visualizer | |
from detectron2.data import MetadataCatalog, DatasetCatalog | |
import cv2 | |
# COCO category names in a 91-slot layout (this appears to be the list used
# with torchvision detection models): index 0 is the background class and
# 'N/A' fills ids that carry no category.
# Not referenced by the code in this file. Note that Inference() filters on
# class id 14 from the Detectron2 predictor, which is NOT an index into this
# list (here index 14 is 'parking meter' and 'bird' is at index 16).
COCO_INSTANCE_CATEGORY_NAMES = [
    # 0-9
    '__background__', 'person', 'bicycle', 'car', 'motorcycle',
    'airplane', 'bus', 'train', 'truck', 'boat',
    # 10-19
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',
    'bench', 'bird', 'cat', 'dog', 'horse',
    # 20-29
    'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
    # 30-39
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    # 40-49
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'N/A', 'wine glass', 'cup', 'fork', 'knife',
    # 50-59
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    # 60-69
    'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'N/A', 'dining table', 'N/A', 'N/A',
    # 70-79
    'toilet', 'N/A', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # 80-89
    'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    # 90
    'toothbrush',
]
def Process_Crop(Crop, CropSize):
    """Fit an image crop into a fixed-size canvas, scaling down if needed.

    If either side of ``Crop`` exceeds the target, the crop is shrunk with its
    aspect ratio preserved so that it fits; it is never scaled up. The result
    is then centred on a zero-valued (black) canvas of exactly ``CropSize``.
    When the leftover space on an axis is odd, the extra pixel goes to the
    bottom / right.

    Args:
        Crop: image array of shape (H, W) or (H, W, C).
        CropSize: target size as (height, width) -- the axis order the scaling
            check has always used. (The previous padding code read it as
            (width, height); both readings agree for square sizes.)

    Returns:
        Tuple ``(OutImage, ScaleProportion, YPadTop, XPadLeft)``: the padded
        image, the factor applied to ``Crop`` (1 when no scaling happened),
        and the number of rows/columns added above and to the left, i.e. the
        offset of the (scaled) crop inside ``OutImage``.
    """
    target_h, target_w = CropSize[0], CropSize[1]

    if Crop.shape[0] > target_h or Crop.shape[1] > target_w:  # too big: scale down
        ScaleProportion = min(target_h / Crop.shape[0], target_w / Crop.shape[1])
        # max(1, ...) keeps cv2.resize from receiving a zero-sized target for
        # extremely thin crops.
        width_scaled = max(1, int(Crop.shape[1] * ScaleProportion))
        height_scaled = max(1, int(Crop.shape[0] * ScaleProportion))
        Crop = cv2.resize(Crop, (width_scaled, height_scaled), interpolation=cv2.INTER_LINEAR)
    else:
        ScaleProportion = 1

    # Split the leftover space based on the remainder itself (not on the
    # parity of the crop), so odd target sizes also come out exact.
    pad_h = target_h - Crop.shape[0]
    pad_w = target_w - Crop.shape[1]
    YPadTop, XPadLeft = pad_h // 2, pad_w // 2
    YPadBot, XPadRight = pad_h - YPadTop, pad_w - XPadLeft

    # Pad only the two spatial axes with zeros; any channel axis is untouched.
    pad_width = [(YPadTop, YPadBot), (XPadLeft, XPadRight)] + [(0, 0)] * (Crop.ndim - 2)
    OutImage = np.pad(Crop, pad_width, mode="constant", constant_values=0)
    return OutImage, ScaleProportion, YPadTop, XPadLeft
def DLCInference(Crop, dlc_liveObj, CropSize):
    """Run DLC-Live pose estimation on one image crop.

    Crops larger than ``CropSize`` are shrunk (aspect ratio preserved) before
    inference. The x/y keypoint columns are then scaled back up, so they are
    expressed in ``Crop``'s own pixel coordinates.

    Args:
        Crop: image array. ``shape[0]`` / ``shape[1]`` are compared with
            ``CropSize[0]`` / ``CropSize[1]``, i.e. CropSize is read as
            (height, width).
        dlc_liveObj: DLCLive-style object exposing ``sess``,
            ``init_inference`` and ``get_pose``. Its session is initialised on
            the first call (when ``sess`` is None).
        CropSize: maximum (height, width) fed to the network.

    Returns:
        The pose array from DLC-Live, with columns 0 and 1 (x, y) rescaled in
        place; any other columns are returned unchanged.
    """
    if Crop.shape[0] > CropSize[0] or Crop.shape[1] > CropSize[1]:  # too big: scale down
        ScaleRatio = min(CropSize[0] / Crop.shape[0], CropSize[1] / Crop.shape[1])
        # max(1, ...) keeps cv2.resize from receiving a zero-sized target.
        new_w = max(1, round(Crop.shape[1] * ScaleRatio))
        new_h = max(1, round(Crop.shape[0] * ScaleRatio))
        net_input = cv2.resize(Crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        ScaleUpRatio = 1 / ScaleRatio  # maps keypoints back to Crop's resolution
    else:
        net_input = Crop
        ScaleUpRatio = 1

    pose = None
    if dlc_liveObj.sess is None:
        # First call: init_inference builds the session. In the DLCLive API
        # it also returns the pose for the frame it was given, so use that
        # instead of running the network a second time on the same frame.
        pose = dlc_liveObj.init_inference(net_input)
    if pose is None:
        # Normal path, and the fallback if init_inference returned no pose.
        pose = dlc_liveObj.get_pose(net_input)

    pose[:, 0] = pose[:, 0] * ScaleUpRatio
    pose[:, 1] = pose[:, 1] * ScaleUpRatio
    return pose
def VisualizeAll(frame, box, DLCPredict2D, ScaleBBox, imsize):
    """Draw keypoints and the crop's bounding box onto ``frame`` in place.

    Args:
        frame: image array that is drawn on (and also returned).
        box: (x1, y1, x2, y2) of the crop in frame coordinates; keypoints are
            offset by (x1, y1). A box with no area (e.g. the all-zero box used
            when no bird was detected) is not drawn.
        DLCPredict2D: per-keypoint rows whose first two columns are x, y
            relative to the box's top-left corner.
        ScaleBBox, imsize: accepted for call compatibility; unused here.

    Returns:
        ``(frame, PlotPoints)``, where PlotPoints lists the integer [x, y]
        frame position of every keypoint drawn.
    """
    # Keypoint colours, in order: Lshoulder, Rshoulder, topKeel, botKeel,
    # Tail, Beak, Nose, Leye, Reye.
    colourList = [(0, 255, 255), (255, 0, 255), (128, 0, 128), (255, 192, 203),
                  (255, 255, 0), (0, 0, 255), (205, 133, 63), (0, 255, 0), (255, 0, 0)]
    box_x, box_y = float(box[0]), float(box[1])

    PlotPoints = []
    for idx, point in enumerate(DLCPredict2D):
        # Convert to plain Python ints: cv2 rejects float coordinates, and
        # round() on NumPy scalar types is not guaranteed to return an int.
        roundPoint = [int(round(float(point[0]) + box_x)), int(round(float(point[1]) + box_y))]
        # Cycle the palette so models with more than 9 keypoints don't IndexError.
        cv2.circle(frame, tuple(roundPoint), 1, colourList[idx % len(colourList)], 5)
        PlotPoints.append(roundPoint)

    x1, y1, x2, y2 = (int(round(float(v))) for v in (box[0], box[1], box[2], box[3]))
    if x2 > x1 and y2 > y1:  # skip placeholder / zero-area boxes
        cv2.rectangle(frame, (x1, y1), (x2, y2), [0, 0, 255], 3)
    return frame, PlotPoints
def Inference(frame, predictor, dlc_liveObj, ScaleBBox=1, Dilate=5, DLCThreshold=0.3,
              CropSize=(320, 320), BirdClassID=14):
    """Detect birds in one frame with Mask R-CNN, then run DLC on each bird.

    For every detection whose predicted class equals ``BirdClassID`` (14 is
    the bird id this script has always used with its Detectron2 predictor),
    the following steps run:

    1. The instance mask is optionally dilated and used to black out
       everything except that instance.
    2. The masked image is cropped to the detection box, optionally enlarged
       about its centre by ``ScaleBBox`` and clipped to the image.
    3. DLC-Live runs on that crop, and the keypoints and box are drawn onto
       ``frame``.

    If no bird is detected, DLC is run on the whole frame instead.

    Args:
        frame: image array; drawn on in place and returned.
        predictor: callable returning a dict whose "instances" entry carries
            ``pred_classes``, ``pred_boxes`` and ``pred_masks``.
        dlc_liveObj: DLC-Live object passed through to DLCInference.
        ScaleBBox: factor applied to each box's width/height before cropping
            (1 = unchanged).
        Dilate: side of the square dilation kernel (applied 3 iterations);
            0 or less disables dilation.
        DLCThreshold: NOTE(review): currently unused -- keypoints are drawn
            regardless of their confidence.
        CropSize: maximum (height, width) of the crop fed to DLC.
        BirdClassID: predicted class id treated as "bird".

    Returns:
        The annotated ``frame`` (the same object that was passed in).
    """
    InferFrame = frame.copy()  # clean copy, so crops never contain drawn overlays
    outputs = predictor(InferFrame)["instances"].to("cpu")
    img_h, img_w = frame.shape[0], frame.shape[1]
    imsize = [img_w, img_h]

    BirdIndex = np.where(outputs.pred_classes.numpy() == BirdClassID)[0]
    if BirdIndex.size == 0:
        # No bird detected: fall back to running DLC on the whole frame.
        DLCPredict2D = DLCInference(InferFrame, dlc_liveObj, CropSize)
        frame, _ = VisualizeAll(frame, [0, 0, 0, 0], DLCPredict2D, ScaleBBox, imsize)
        return frame

    BirdBBox = outputs.pred_boxes[BirdIndex].tensor.numpy()
    BirdMasks = (outputs.pred_masks > 0.7).numpy()[BirdIndex]
    DilateKernel = np.ones((Dilate, Dilate), np.uint8) if Dilate > 0 else None

    for box_xyxy, bird_mask in zip(BirdBBox, BirdMasks):
        Mask = np.asarray(bird_mask > 0, dtype=np.uint8)
        if DilateKernel is not None:
            # Grow the mask slightly so the bird's outline isn't clipped.
            Mask = cv2.dilate(Mask, DilateKernel, iterations=3)
        Mask = Mask.reshape(img_h, img_w, 1)
        Crop = cv2.bitwise_and(InferFrame, InferFrame, mask=Mask)

        bbox = _expand_and_clip_box(box_xyxy, ScaleBBox, img_w, img_h)
        x1, y1, x2, y2 = bbox
        BirdCrop = Crop[y1:y2, x1:x2]
        if BirdCrop.size == 0:
            continue  # box collapsed after rounding/clipping: nothing to infer on
        DLCPredict2D = DLCInference(BirdCrop, dlc_liveObj, CropSize)
        frame, _ = VisualizeAll(frame, bbox, DLCPredict2D, ScaleBBox, imsize)
    return frame


def _expand_and_clip_box(box_xyxy, scale, img_w, img_h):
    """Scale an (x1, y1, x2, y2) box about its centre, round to ints, clip to the image."""
    x_min, y_min, x_max, y_max = (float(v) for v in box_xyxy[:4])
    # Extra margin on each side so the box's width/height become `scale` times larger.
    pad_x = (scale - 1) * (x_max - x_min) / 2
    pad_y = (scale - 1) * (y_max - y_min) / 2
    x1 = max(0, int(round(x_min - pad_x)))
    y1 = max(0, int(round(y_min - pad_y)))
    x2 = min(img_w, int(round(x_max + pad_x)))
    y2 = min(img_h, int(round(y_max + pad_y)))
    return [x1, y1, x2, y2]