File size: 8,127 Bytes
c20ef4d
 
 
86af013
b6f51cf
 
 
959b3a8
b6f51cf
 
a065329
959b3a8
b6f51cf
 
a065329
026ff32
b6f51cf
 
d581ff8
b6f51cf
07db84b
b6f51cf
2b6769d
b6f51cf
 
57ac38a
b6f51cf
052cded
 
b6f51cf
936e066
 
959b3a8
07db84b
b6f51cf
 
 
 
 
a065329
07db84b
a065329
57ac38a
07db84b
645a407
07db84b
2b6769d
a065329
 
 
 
2b6769d
07db84b
b6f51cf
 
 
 
 
 
 
 
 
 
 
 
 
07db84b
b6f51cf
 
 
 
 
 
 
 
 
 
 
a065329
b6f51cf
 
07db84b
 
b6f51cf
07db84b
 
 
b6f51cf
07db84b
 
 
 
 
a065329
07db84b
 
 
959b3a8
07db84b
 
 
645a407
3f4d50f
a065329
 
 
 
026ff32
 
 
 
07db84b
 
 
 
 
 
 
959b3a8
b6f51cf
 
 
959b3a8
b6f51cf
07db84b
 
 
 
b6f51cf
 
07db84b
 
 
645a407
a065329
 
 
 
 
07db84b
b6f51cf
 
 
959b3a8
b6f51cf
959b3a8
 
 
b6f51cf
 
 
 
959b3a8
 
b6f51cf
026ff32
959b3a8
 
07db84b
b6f51cf
 
 
 
 
 
 
 
a065329
b6f51cf
16fb395
a065329
 
 
 
 
b6f51cf
d581ff8
8d64a48
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# Built from https://huggingface.co/spaces/hlydecker/MegaDetector_v5 
# Built from https://huggingface.co/spaces/sofmi/MegaDetector_DLClive/blob/main/app.py
# Built from https://huggingface.co/spaces/Neslihan/megadetector_dlcmodels/blob/main/app.py 

import os
import yaml
import numpy as np
from matplotlib import cm
import gradio as gr

from PIL import Image, ImageColor, ImageFont, ImageDraw 
# Make sure large files have been fetched with `git lfs pull` before running.
from DLC_models.download_utils import DownloadModel
from dlclive import DLCLive, Processor

from viz_utils import save_results_as_json, draw_keypoints_on_image, draw_bbox_w_text, save_results_only_dlc
from detection_utils import predict_md, crop_animal_detections, predict_dlc
from ui_utils import gradio_inputs_for_MD_DLC, gradio_outputs_for_MD_DLC, gradio_description_and_examples

# import pdb
#########################################
# Input params - Global vars

# MegaDetector checkpoints, keyed by the model name offered in the UI.
MD_models_dict = {
    'md_v5a': "MD_models/md_v5a.0.0.pt",
    'md_v5b': "MD_models/md_v5b.0.0.pt",
}

# Target directories of the DeepLabCut models, keyed by model name.
# predict_pipeline downloads a model into its directory when that directory
# is missing or empty. The key order is what gets passed to the UI builder.
DLC_models_dict = {
    # 'full_cat': "DLC_models/DLC_Cat/",
    # 'full_dog': "DLC_models/DLC_Dog/",
    'full_human': "DLC_models/DLC_human_dancing/",
    'full_macaque': 'DLC_models/DLC_monkey/',
    'primate_face': "DLC_models/DLC_FacialLandmarks/",
}
                 

# FONTS = {'amiko': "fonts/Amiko-Regular.ttf",
#         'nature': "fonts/LoveNature.otf", 
#         'painter':"fonts/PainterDecorator.otf",
#         'animals': "fonts/UncialAnimals.ttf", 
#         'zen': "fonts/ZEN.TTF"}
#####################################################
def _resolve_dlc_model_dir(dlc_model_name):
    """Return the path of the DLC model ``dlc_model_name``, downloading it if needed.

    The model's target directory comes from ``DLC_models_dict``. If that
    directory exists and is non-empty, it is reused as-is. Otherwise
    ``DownloadModel`` is called to populate it.
    """
    model_dir = DLC_models_dict[dlc_model_name]
    # If the model is found, do not download it again (a previous execution
    # likely fetched it already).
    # TODO: can we ask the user whether to reload the DLC model if a directory is found?
    if os.path.isdir(model_dir) and os.listdir(model_dir):
        return model_dir
    downloaded_path = DownloadModel(dlc_model_name, model_dir)
    # Defensive: if DownloadModel returns nothing, fall back to the directory
    # it was asked to download into.
    return downloaded_path or model_dir


def _load_keypoint_label_map(model_dir):
    """Return a dict mapping keypoint ids to keypoint names.

    The mapping is read from ``pose_cfg.yaml`` inside ``model_dir``.
    """
    pose_cfg_path = os.path.join(model_dir, 'pose_cfg.yaml')
    with open(pose_cfg_path, "r", encoding="utf-8") as stream:
        pose_cfg_dict = yaml.safe_load(stream)
    # 'all_joints' is a list of one-element lists; take the id from each and
    # pair it with the matching entry of 'all_joints_names'.
    joint_ids = [joint[0] for joint in pose_cfg_dict['all_joints']]
    return dict(zip(joint_ids, pose_cfg_dict['all_joints_names']))


def predict_pipeline(img_input,
                     mega_model_input,
                     dlc_model_input_str,
                     flag_dlc_only,
                     flag_show_str_labels,
                     bbox_likelihood_th,
                     kpts_likelihood_th,
                     font_style,
                     font_size,
                     keypt_color,
                     marker_size,
                     ):
    """Detect animals with MegaDetector and draw DeepLabCut keypoints on them.

    Args:
        img_input: the input image. It is presumably a PIL image, since
            ``.resize`` is called on it and it is drawn on directly.
        mega_model_input: key into ``MD_models_dict`` selecting the
            MegaDetector checkpoint. It is unused when ``flag_dlc_only``
            is set.
        dlc_model_input_str: key into ``DLC_models_dict`` selecting the DLC
            model.
        flag_dlc_only: if True, skip MegaDetector and run DLC on the full
            image (mostly for testing).
        flag_show_str_labels: forwarded to ``draw_keypoints_on_image``
            (whether to show keypoint names).
        bbox_likelihood_th: MegaDetector confidence threshold. It is used
            for cropping and for deciding which boxes are drawn.
        kpts_likelihood_th: keypoint likelihood threshold forwarded to
            ``predict_dlc``.
        font_style, font_size: text settings forwarded to the drawing helpers.
        keypt_color, marker_size: keypoint marker settings forwarded to
            ``draw_keypoints_on_image``.

    Returns:
        A tuple ``(annotated_image, results_file)``.
        - In DLC-only mode it is the annotated input image and the result of
          ``save_results_only_dlc``.
        - Otherwise it is the annotated, MD-resized image and the result of
          ``save_results_as_json``.
    """
    if not flag_dlc_only:
        # Run MegaDetector
        md_results = predict_md(img_input,
                                MD_models_dict[mega_model_input],
                                size=640)
        # Obtain animal crops for bboxes with confidence above threshold
        list_crops = crop_animal_detections(img_input,
                                            md_results,
                                            bbox_likelihood_th)

    # Get the DLC model and its keypoint id -> name map
    path_to_DLCmodel = _resolve_dlc_model_dir(dlc_model_input_str)
    # NOTE(review): the label map is read from the configured directory, not
    # from the path DownloadModel returned; these are presumably the same.
    map_label_id_to_str = _load_keypoint_label_map(DLC_models_dict[dlc_model_input_str])

    dlc_proc = Processor()

    if flag_dlc_only:
        # Ignore MD crops: compute and draw keypoints on the full input image
        list_kpts_per_crop = predict_dlc([np.asarray(img_input)],
                                         kpts_likelihood_th,
                                         path_to_DLCmodel,
                                         dlc_proc)
        draw_keypoints_on_image(img_input,
                                list_kpts_per_crop[0],  # a numpy array with shape [num_keypoints, 2].
                                map_label_id_to_str,
                                flag_show_str_labels,
                                use_normalized_coordinates=False,
                                font_style=font_style,
                                font_size=font_size,
                                keypt_color=keypt_color,
                                marker_size=marker_size)
        download_file = save_results_only_dlc(list_kpts_per_crop[0],
                                              map_label_id_to_str,
                                              dlc_model_input_str)
        return img_input, download_file

    # Compute keypoints for each MD crop
    list_kpts_per_crop = predict_dlc(list_crops,
                                     kpts_likelihood_th,
                                     path_to_DLCmodel,
                                     dlc_proc)

    # Resize the input image to match the MegaDetector output resolution
    img_background = img_input.resize((md_results.ims[0].shape[1],
                                       md_results.ims[0].shape[0]))

    # Convert the detections once instead of once (or twice) per crop.
    all_bboxes = md_results.xyxy[0].tolist()

    # Draw keypoints on each crop and paste it back onto the background image.
    # NOTE(review): this assumes the ic-th crop from crop_animal_detections
    # matches row ic of md_results.xyxy[0] (i.e. the kept detections form a
    # prefix of that tensor). Verify against detection_utils.
    for ic, (np_crop, kpts_crop) in enumerate(zip(list_crops,
                                                  list_kpts_per_crop)):
        img_crop = Image.fromarray(np_crop)

        # Draw keypoints on the crop
        draw_keypoints_on_image(img_crop,
                                kpts_crop,  # a numpy array with shape [num_keypoints, 2].
                                map_label_id_to_str,
                                flag_show_str_labels,
                                use_normalized_coordinates=False,  # if True, md_results.xyxyn would be needed
                                font_style=font_style,
                                font_size=font_size,
                                keypt_color=keypt_color,
                                marker_size=marker_size)

        bb_per_animal = all_bboxes[ic]  # [xmin, ymin, xmax, ymax, confidence, ...]

        # Paste the crop at its top-left corner in the background image
        img_background.paste(img_crop,
                             box=(int(bb_per_animal[0]), int(bb_per_animal[1])))

        # Plot the bbox when its confidence exceeds the threshold
        if bbox_likelihood_th < bb_per_animal[4]:
            draw_bbox_w_text(img_background,
                             bb_per_animal,
                             font_style=font_style,
                             font_size=font_size)  # TODO: add selectable color for bbox?

    # Save detection results as json
    download_file = save_results_as_json(md_results,
                                         list_kpts_per_crop,
                                         map_label_id_to_str,
                                         bbox_likelihood_th,
                                         dlc_model_input_str,
                                         mega_model_input)

    return img_background, download_file

#########################################################
# Define user interface and launch
# Input widgets are built from the available MegaDetector and DLC model names.
inputs = gradio_inputs_for_MD_DLC(list(MD_models_dict), list(DLC_models_dict))
outputs = gradio_outputs_for_MD_DLC()
gr_title, gr_description, examples = gradio_description_and_examples()

# Wire predict_pipeline into a Gradio Interface and start the server.
# NOTE(review): `enable_queue` and the "huggingface" theme are only accepted
# by older Gradio releases; confirm against the pinned Gradio version.
demo = gr.Interface(
    fn=predict_pipeline,
    inputs=inputs,
    outputs=outputs,
    title=gr_title,
    description=gr_description,
    examples=examples,
    theme="huggingface",
)

demo.launch(enable_queue=True, share=True)