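"""RetinaNet inference wrapper for detecting birds in drone imagery.

Selects an altitude-specific checkpoint, slices a large image into
overlapping 512x512 tiles, runs detection on each tile, and merges the
per-tile boxes back into full-image coordinates with non-maximum
suppression.
"""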
import json

import cv2
import torch

from detectors.retinanet.encoder_knn import DataEncoder, DataEncoder_fusion
from detectors.retinanet.tools import py_cpu_nms, get_sub_image
from detectors.retinanet.utils_knn import read_LatLotAlt, get_GSD, filter_slice
# Per-model confidence thresholds applied when decoding predictions.
model_conf_threshold = {'Bird_A': 0.2,
                        'Bird_B': 0.2,
                        'Bird_C': 0.2,
                        'Bird_D': 0.2,
                        'Bird_E': 0.2,
                        'Bird_drone': 0.2}

# Altitude buckets (upper bound in metres) mapped to a checkpoint suffix and
# the reference altitude that variant of the model was trained at.
model_extension = {
    'Bird_drone': {40: ('_alt_30', 30),
                   75: ('_alt_60', 60),
                   90: ('_alt_90', 90)},
    'Bird_drone_KNN': {20: ('_alt_15', 15),
                       40: ('_alt_30', 30),
                       75: ('_alt_60', 60),
                       90: ('_alt_90', 90)},
}
def get_model_conf_threshold(model_type):
    """Return the confidence threshold for a model type, defaulting to 0.3."""
    if model_type in model_conf_threshold:
        return model_conf_threshold[model_type]
    return 0.3


def get_model_extension(model_type, model_dir, altitude):
    """Pick the reference altitude whose bucket covers the given flight
    altitude (the per-altitude checkpoint swap is currently commented out)."""
    if model_type in model_extension:
        model_ext = model_extension[model_type]
        # Buckets are declared in ascending order, so the first threshold
        # that is >= altitude is the tightest match.
        for altitude_thresh in model_ext:
            if altitude_thresh >= altitude:
                ref_altitude = model_ext[altitude_thresh][1]
                # model_dir = model_dir.replace('.pkl', model_ext[altitude_thresh][0] + '.pkl')
                return model_dir, ref_altitude
        # Altitude exceeds every bucket; fall back to the highest one.
        # model_dir = model_dir.replace('.pkl', model_ext[max(model_ext.keys())][0] + '.pkl')
        return model_dir, model_ext[max(model_ext.keys())][1]
    return model_dir, altitude
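
# Example: for model_type='Bird_drone_KNN', altitude=35 falls in the 40 m
# bucket, so get_model_extension returns a reference altitude of 30;
# altitude=120 overshoots every bucket and falls back to 90.
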
class Retinanet_instance():
    def __init__(self, input_transform, model_type, model_dir,
                 device=torch.device('cuda'), load_w_config=True, altitude=15):
        self.transform = input_transform
        self.model_type = model_type
        self.load_w_config = load_w_config
        self.altitude = altitude
        self.model_dir, self.ref_altitude = get_model_extension(model_type, model_dir, altitude)
        self.device = device
        self.conf_threshold = get_model_conf_threshold(model_type)
        self.model = None
        self.encoder = None
        self.load_model()
    def load_model(self):
        if self.load_w_config:
            # The checkpoint's JSON sidecar stores the KNN-selected anchor shapes.
            config_dir = self.model_dir.replace('.pkl', '.json')
            with open(config_dir, 'r') as f:
                cfg = json.load(f)
            from detectors.retinanet.retinanet_fusion import RetinaNet
            self.model = RetinaNet(num_classes=1, num_anchors=len(cfg['KNN_anchors']))
            self.encoder = DataEncoder_fusion(anchor_wh=cfg['KNN_anchors'], device=self.device)
            # self.model.load_state_dict(torch.load(self.model_dir))
        else:
            from detectors.retinanet.retinanet import RetinaNet
            self.model = RetinaNet(num_classes=1)
            self.encoder = DataEncoder(self.device)
        # The .pkl checkpoint stores the full model object, so the network
        # built above is replaced by the loaded one and unwrapped via .module.
        self.model = torch.load(self.model_dir, map_location=self.device)
        self.model = self.model.module.to(self.device)
        self.model.eval()
        print('Model parameters on device:', next(self.model.parameters()).device)
    def inference(self, image_dir, slice_overlap, read_GPS=False, debug=True):
        """Run sliced detection on one image; return the (optionally
        annotated) RGB image and a list of [x1, y1, x2, y2, score] boxes
        in full-image pixel coordinates."""
        mega_image = cv2.imread(image_dir)
        mega_image = cv2.cvtColor(mega_image, cv2.COLOR_BGR2RGB)
        if read_GPS:
            try:
                altitude = read_LatLotAlt(image_dir)['altitude']
                # print('Reading altitude {} from meta data'.format(altitude))
            except Exception:
                altitude = self.altitude
                # print('Meta data not available, use default altitude {}'.format(altitude))
        else:
            altitude = self.altitude
            # print('Using default altitude {}'.format(altitude))
        # Rescale so the effective ground sampling distance (GSD) matches the
        # reference altitude the selected model was trained at.
        GSD, ref_GSD = get_GSD(altitude, camera_type='Pro2', ref_altitude=self.ref_altitude)
        ratio = 1.0 * ref_GSD / GSD
        # print('Image processing altitude: {} \t Processing scale {}'.format(altitude, ratio))
        sub_image_list, coor_list = get_sub_image(
            mega_image, overlap=slice_overlap, ratio=ratio)
        bbox_list = []
        for index, sub_image in enumerate(sub_image_list):
            sub_bbox_list = []
            with torch.no_grad():
                inputs = self.transform(cv2.resize(
                    sub_image, (512, 512), interpolation=cv2.INTER_AREA))
                inputs = inputs.unsqueeze(0).to(self.device)
                loc_preds, cls_preds = self.model(inputs)
                boxes, labels, scores = self.encoder.decode(
                    loc_preds.data.squeeze(), cls_preds.data.squeeze(), 512,
                    CLS_THRESH=self.conf_threshold, NMS_THRESH=0.25)
            if len(boxes.shape) != 1:
                for idx in range(boxes.shape[0]):
                    x1, y1, x2, y2 = list(boxes[idx].cpu().numpy())  # (x1, y1, x2, y2)
                    score = scores.cpu().numpy()[idx]
                    sub_bbox_list.append([x1, y1, x2, y2, score])
            # Drop boxes that fall in the overlapping margin between slices,
            # so each detection is kept by exactly one slice.
            sub_bbox_list = filter_slice(sub_bbox_list, coor_list[index], sub_image.shape[0],
                                         mega_image.shape[:2], dis=int(slice_overlap / 2 * 512))
            # Map slice-local coordinates back into full-image coordinates.
            for sub_box in sub_bbox_list:
                x1, y1, x2, y2, score = sub_box
                bbox_list.append([coor_list[index][1] + ratio * x1, coor_list[index][0] + ratio * y1,
                                  coor_list[index][1] + ratio * x2, coor_list[index][0] + ratio * y2, score])
        # Global NMS merges duplicates that survive on neighbouring slices.
        box_idx = py_cpu_nms(bbox_list, 0.25)
        bbox_list = [bbox_list[i] for i in box_idx]
        if debug:
            # Draw slice boundaries and final detections on the returned image.
            w = sub_image_list[0].shape[0]
            for i, coor in enumerate(coor_list):
                cv2.rectangle(mega_image, (coor[1], coor[0]), (coor[1] + w, coor[0] + w), (i, 255 - i, 0), 2)
            for box in bbox_list:
                cv2.putText(mega_image, str(round(box[4], 2)), (int(box[0]), int(box[1])),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
                cv2.rectangle(mega_image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (255, 0, 0), 2)
        return mega_image, bbox_list
if __name__ == '__main__':
    import torchvision.transforms as transforms
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    model = Retinanet_instance(input_transform=transform, model_type='Bird_drone_KNN',
                               model_dir='/home/robert/Models/Retinanet_inference_example/checkpoint/Bird_drone_KNN/final_model.pkl',
                               device=torch.device('cpu'), load_w_config=True, altitude=15)
    image_dir = '/home/robert/Data/drone_collection/Cloud_HarvestedCrop_15m_DJI_0251.jpg'
    re = model.inference(image_dir=image_dir, slice_overlap=0.2)
    import matplotlib.pyplot as plt
    plt.imshow(re[0])
    plt.show()
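    # A minimal way to consume the detections (each entry of re[1] is
    # [x1, y1, x2, y2, score] in full-image pixel coordinates):
    for x1, y1, x2, y2, score in re[1]:
        print('bird at ({:.0f}, {:.0f})-({:.0f}, {:.0f}), score {:.2f}'.format(
            x1, y1, x2, y2, score))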