import json

import cv2
import torch

from detectors.retinanet.encoder_knn import DataEncoder, DataEncoder_fusion
from detectors.retinanet.tools import py_cpu_nms, get_sub_image
from detectors.retinanet.utils_knn import read_LatLotAlt, get_GSD, filter_slice

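# Per-model confidence thresholds, and altitude-indexed checkpoint metadata:
# each model_extension entry maps an altitude upper bound to a
# (checkpoint suffix, reference altitude) pair.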
model_conf_threshold = {'Bird_A': 0.2,
                        'Bird_B': 0.2,
                        'Bird_C': 0.2,
                        'Bird_D': 0.2,
                        'Bird_E': 0.2,
                        'Bird_drone': 0.2}
model_extension = {
    'Bird_drone': {40: ('_alt_30', 30),
                   75: ('_alt_60', 60),
                   90: ('_alt_90', 90)},
    'Bird_drone_KNN': {20: ('_alt_15', 15),
                       40: ('_alt_30', 30),
                       75: ('_alt_60', 60),
                       90: ('_alt_90', 90)},
}

def get_model_conf_threshold(model_type):
    """Return the confidence threshold for a model type (default 0.3)."""
    return model_conf_threshold.get(model_type, 0.3)
        
def get_model_extension(model_type, model_dir, altitude):
    """Return (model_dir, reference altitude) for the first altitude bucket
    that covers the given flight altitude."""
    if model_type in model_extension:
        model_ext = model_extension[model_type]
        # Sort so the lookup does not depend on dict insertion order.
        for altitude_thresh in sorted(model_ext):
            if altitude_thresh >= altitude:
                ref_altitude = model_ext[altitude_thresh][1]
                # model_dir = model_dir.replace('.pkl', model_ext[altitude_thresh][0] + '.pkl')
                return model_dir, ref_altitude
        # Above the highest bucket: fall back to the largest reference altitude.
        # model_dir = model_dir.replace('.pkl', model_ext[max(model_ext.keys())][0] + '.pkl')
        return model_dir, model_ext[max(model_ext)][1]
    else:
        return model_dir, altitude

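# Wraps one checkpoint together with its box encoder/decoder and the
# altitude-dependent preprocessing scale.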
class Retinanet_instance():
    def __init__(self, input_transform, model_type, model_dir,
                 device=torch.device('cuda'), load_w_config=True, altitude=15):
        self.transform = input_transform
        self.model_type = model_type
        self.load_w_config = load_w_config
        self.altitude = altitude
        self.model_dir, self.ref_altitude = get_model_extension(model_type, model_dir, altitude)
        self.device = device
        self.conf_threshold = get_model_conf_threshold(model_type)
        self.model = None
        self.encoder = None
        self.load_model()
    
    def load_model(self):
        if self.load_w_config:
            # The training config next to the checkpoint stores the KNN anchor shapes.
            config_dir = self.model_dir.replace('.pkl', '.json')
            with open(config_dir, 'r') as f:
                cfg = json.load(f)
            # The import makes the class resolvable when torch.load unpickles the model.
            from detectors.retinanet.retinanet_fusion import RetinaNet  # noqa: F401
            self.encoder = DataEncoder_fusion(anchor_wh=cfg['KNN_anchors'], device=self.device)
        else:
            from detectors.retinanet.retinanet import RetinaNet  # noqa: F401
            self.encoder = DataEncoder(self.device)
        # The checkpoint pickles the whole model, so no state_dict loading is needed.
        self.model = torch.load(self.model_dir, map_location=self.device)
        # Checkpoints saved from nn.DataParallel wrap the network in .module.
        self.model = getattr(self.model, 'module', self.model).to(self.device)
        self.model.eval()
        print('Model loaded on device:', next(self.model.parameters()).device)

    def inference(self, image_dir, slice_overlap, read_GPS=False, debug=True):
        mega_image = cv2.imread(image_dir)
        mega_image = cv2.cvtColor(mega_image, cv2.COLOR_BGR2RGB)
        if read_GPS:
            try:
                altitude = read_LatLotAlt(image_dir)['altitude']
                # print('Reading altitude {} from the image metadata'.format(altitude))
            except Exception:
                altitude = self.altitude
                # print('Metadata not available, using default altitude {}'.format(altitude))
        else:
            altitude = self.altitude
            # print('Using default altitude {}'.format(altitude))
        # Rescale the image so its GSD matches the one the model was trained at.
        GSD, ref_GSD = get_GSD(altitude, camera_type='Pro2', ref_altitude=self.ref_altitude)
        ratio = 1.0 * ref_GSD / GSD
        # print('Image processing altitude: {} \t Processing scale {}'.format(altitude, ratio))
        sub_image_list, coor_list = get_sub_image(
            mega_image, overlap=slice_overlap, ratio=ratio)
        
        bbox_list = []
        for index, sub_image in enumerate(sub_image_list):
            sub_bbox_list = []
            with torch.no_grad():
                inputs = self.transform(cv2.resize(
                    sub_image, (512, 512), interpolation=cv2.INTER_AREA))
                inputs = inputs.unsqueeze(0).to(self.device)
                loc_preds, cls_preds = self.model(inputs)
                boxes, labels, scores = self.encoder.decode(
                    loc_preds.data.squeeze(), cls_preds.data.squeeze(), 512,
                    CLS_THRESH=self.conf_threshold, NMS_THRESH=0.25)
            # decode returns a 1-D tensor when the slice has no detections
            if len(boxes.shape) != 1:
                for idx in range(boxes.shape[0]):
                    x1, y1, x2, y2 = list(boxes[idx].cpu().numpy())  # (x1, y1, x2, y2)
                    score = scores.cpu().numpy()[idx]
                    sub_bbox_list.append([x1, y1, x2, y2, score])
                # Drop boxes that fall in the overlap band between adjacent slices.
                sub_bbox_list = filter_slice(sub_bbox_list, coor_list[index], sub_image.shape[0],
                                             mega_image.shape[:2], dis=int(slice_overlap / 2 * 512))
                # Map slice-local coordinates back onto the full image.
                for sub_box in sub_bbox_list:
                    x1, y1, x2, y2, score = sub_box
                    bbox_list.append([coor_list[index][1] + ratio * x1, coor_list[index][0] + ratio * y1,
                                      coor_list[index][1] + ratio * x2, coor_list[index][0] + ratio * y2, score])

        # Global NMS removes duplicates across overlapping slices; guard
        # against images with no detections at all.
        if bbox_list:
            box_idx = py_cpu_nms(bbox_list, 0.25)
            bbox_list = [bbox_list[i] for i in box_idx]
        if debug:
            # Draw the slicing grid for visual inspection.
            w = sub_image_list[0].shape[0]
            for i, coor in enumerate(coor_list):
                cv2.rectangle(mega_image, (coor[1], coor[0]),
                              (coor[1] + w, coor[0] + w), (i, 255 - i, 0), 2)

        for box in bbox_list:
            cv2.putText(mega_image, str(round(box[4], 2)), (int(box[0]), int(box[1])),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
            cv2.rectangle(mega_image, (int(box[0]), int(box[1])),
                          (int(box[2]), int(box[3])), (255, 0, 0), 2)
        return mega_image, bbox_list

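# Minimal smoke test; the checkpoint and sample image paths below are
# machine-specific and need to be adapted.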
if __name__ == '__main__':
    import torchvision.transforms as transforms
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    model = Retinanet_instance(input_transform=transform, model_type='Bird_drone_KNN',
                               model_dir='/home/robert/Models/Retinanet_inference_example/checkpoint/Bird_drone_KNN/final_model.pkl',
                               device=torch.device('cpu'), load_w_config=True, altitude=15)
    image_dir = '/home/robert/Data/drone_collection/Cloud_HarvestedCrop_15m_DJI_0251.jpg'
    annotated_image, bbox_list = model.inference(image_dir=image_dir, slice_overlap=0.2)
    import matplotlib.pyplot as plt
    plt.imshow(annotated_image)
    plt.show()