File size: 3,996 Bytes
b6f51cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116

from tkinter import W
import gradio as gr
from matplotlib import cm
import torch
import torchvision
from dlclive import DLCLive, Processor
import matplotlib
from PIL import Image, ImageColor, ImageFont, ImageDraw 
import numpy as np
import math


import yaml
import pdb

############################################
# Predict detections with MegaDetector v5a model
def predict_md(im,
               megadetector_model,  # Megadet_Models[mega_model_input]
               size=640,
               *,
               force_reload=True):
    """Run a MegaDetector (YOLOv5 'custom') model on a PIL image.

    The image is first rescaled, preserving aspect ratio, so that its longest
    side equals ``size``. Then the model is loaded through
    ``torch.hub.load('ultralytics/yolov5', 'custom', ...)`` and called on it.

    Args:
        im: input PIL image.
        megadetector_model: model argument forwarded to ``torch.hub.load``
            (presumably a path to the MegaDetector weights -- confirm with caller).
        size: target length, in pixels, of the longest side after rescaling.
        force_reload: forwarded to ``torch.hub.load``. Defaults to True, as
            before; pass False to reuse the locally cached hub repo instead of
            re-downloading it on every call.

    Returns:
        Whatever the loaded model returns when called on the resized image
        (a YOLOv5 results object exposing e.g. ``names``, ``ims``, ``xyxy``).
    """
    # Rescale so the longest side equals `size`; never collapse a side to 0 px.
    scale = size / max(im.size)
    new_size = tuple(max(1, int(side * scale)) for side in im.size)
    # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
    # Image.Resampling only exists from Pillow 9.1, hence the getattr fallback.
    resample = getattr(Image, "Resampling", Image).LANCZOS
    im = im.resize(new_size, resample)

    md_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    MD_model = torch.hub.load('ultralytics/yolov5',  # repo_or_dir
                              'custom',  # model
                              megadetector_model,  # args for callable model
                              force_reload=force_reload,
                              device=md_device)

    # send model to gpu if possible
    if md_device.type == 'cuda':
        print('Sending model to GPU')
        MD_model.to(md_device)

    # inference; results exposes e.g. 'ims', 'pred', 'names', 'xyxy', 'xywh', ...
    results = MD_model(im)
    return results


##########################################
def crop_animal_detections(img_in,
                           yolo_results,
                           likelihood_th):
    """Crop every 'animal' detection whose confidence meets a threshold.

    Args:
        img_in: PIL image the detections refer to. It is resized to the
            resolution of ``yolo_results.ims[0]`` so that the bbox coordinates
            in ``yolo_results.xyxy`` line up with it.
        yolo_results: YOLOv5 results object; its ``names`` (class-index ->
            label mapping), ``ims`` and ``xyxy`` attributes are read.
        likelihood_th: minimum detection confidence (inclusive) for a crop
            to be kept.

    Returns:
        List of crops as numpy arrays, in detection order.

    Raises:
        ValueError: if ``yolo_results.names`` contains no 'animal' label.
    """
    # Class index of 'animal' is loop-invariant: look it up once.
    list_labels_as_str = list(yolo_results.names.values())  # e.g. ['animal', 'person', 'vehicle']
    animal_label_idx = list_labels_as_str.index('animal')

    # image to crop, rescaled to the image the detector ran on
    height, width = yolo_results.ims[0].shape[:2]
    img_in = img_in.resize((width, height))

    list_np_animal_crops = []
    # One detection array per image. Columns 0-3 are the bbox corners,
    # column 4 the confidence and column 5 the class index (per the reads below).
    for det_array in yolo_results.xyxy:
        for j in range(det_array.shape[0]):
            pred_llk = det_array[j, 4]
            pred_label = det_array[j, 5]
            # keep animal crops at or above threshold
            if (pred_label == animal_label_idx) and (pred_llk >= likelihood_th):
                # Round the bbox outwards to whole pixels so the crop fully
                # contains the detected box.
                area = (int(math.floor(det_array[j, 0])),
                        int(math.floor(det_array[j, 1])),
                        int(math.ceil(det_array[j, 2])),
                        int(math.ceil(det_array[j, 3])))
                list_np_animal_crops.append(np.asarray(img_in.crop(area)))

    return list_np_animal_crops

##########################################
def predict_dlc(list_np_crops,
                kpts_likelihood_th,
                DLCmodel,
                dlc_proc):
    """Estimate keypoints with DeepLabCut-Live on each crop.

    Args:
        list_np_crops: list of image crops as numpy arrays.
        kpts_likelihood_th: keypoints whose likelihood (last column of the
            pose array) is below this value have their whole row set to NaN.
        DLCmodel: model argument forwarded to ``DLCLive``.
        dlc_proc: processor forwarded to ``DLCLive``.

    Returns:
        List with one keypoint array per crop, in input order. The arrays are
        assumed (n_keypoints, 3) with columns x, y, likelihood -- confirm
        against DLCLive. Returns an empty list when there are no crops.
    """
    # No crops: nothing to infer, and list_np_crops[0] below would fail.
    if len(list_np_crops) == 0:
        return []

    # run dlc thru list of crops; the first crop is used to initialise inference
    dlc_live = DLCLive(DLCmodel, processor=dlc_proc)
    dlc_live.init_inference(list_np_crops[0])

    list_kpts_per_crop = []
    for crop in list_np_crops:
        keypts_xyp = dlc_live.get_pose(crop)  # third column is llk
        # Blank out (x, y, llk) of keypoints below the likelihood threshold.
        # Assign NaN directly instead of relying on ndarray.fill() returning
        # None, which numpy only happened to coerce to NaN.
        keypts_xyp[keypts_xyp[:, -1] < kpts_likelihood_th, :] = np.nan
        list_kpts_per_crop.append(keypts_xyp)

    return list_kpts_per_crop