AndresHdzC committed on
Commit
b38382a
1 Parent(s): d753a84

app.py added

Browse files
Files changed (3) hide show
  1. README.md +1 -0
  2. app.py +200 -0
  3. requirements.txt +1 -0
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: 4.15.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ python_version: 3.8
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """ Demo using Gradio interface"""
5
+
6
+ #%%
7
+ # Importing basic libraries
8
+ import os
9
+ import time
10
+ from PIL import Image
11
+ import supervision as sv
12
+ import gradio as gr
13
+ from zipfile import ZipFile
14
+ from torch.utils.data import DataLoader
15
+
16
+ #%%
17
+ # Importing the models, dataset, transformations, and utility functions from PytorchWildlife
18
+ from PytorchWildlife.models import detection as pw_detection
19
+ from PytorchWildlife.models import classification as pw_classification
20
+ from PytorchWildlife.data import transforms as pw_trans
21
+ from PytorchWildlife.data import datasets as pw_data
22
+ from PytorchWildlife import utils as pw_utils
23
+
24
+ #%%
25
# Setting the device to use for computations: the GPU ('cuda') when one is
# available, otherwise the CPU, so the demo still starts on CPU-only hosts.
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Initializing a supervision box annotator for visualizing detections
box_annotator = sv.BoxAnnotator(thickness=4, text_thickness=4, text_scale=2)

# Detection and classification models; both stay None until load_models()
# is called from the UI.
detection_model = None
classification_model = None

# Transformations for detection and classification; also set by load_models().
trans_det = None
trans_clf = None
37
+
38
+ #%% Defining functions for different detection scenarios
39
def load_models(det, clf):
    """Load the selected models into the module-level globals.

    Args:
        det (str): Name of the detector attribute to look up in
            ``PytorchWildlife.models.detection``.
        clf (str): Name of the classifier attribute to look up in
            ``PytorchWildlife.models.classification``, or the string
            ``"None"`` to run detection only.

    Returns:
        str: Status message naming the loaded detector and classifier.
    """
    global detection_model, classification_model, trans_det, trans_clf

    detection_model = getattr(pw_detection, det)(device=DEVICE, pretrained=True)

    if clf != "None":
        classification_model = getattr(pw_classification, clf)(device=DEVICE, pretrained=True)
    else:
        # Drop any classifier left over from an earlier load; otherwise it
        # would keep running although the status message reports "None".
        classification_model = None

    trans_det = pw_trans.MegaDetector_v5_Transform(target_size=detection_model.IMAGE_SIZE,
                                                   stride=detection_model.STRIDE)
    trans_clf = pw_trans.Classification_Inference_Transform(target_size=224)

    return "Loaded Detector: {}. Loaded Classifier: {}".format(det, clf)
52
+
53
+
54
def single_image_detection(input_img, det_conf_thres, clf_conf_thres, img_index=None):
    """Run detection (and optional classification) on one image and annotate it.

    Args:
        input_img (np.ndarray): Input image as the numpy array supplied by Gradio.
        det_conf_thres (float): Confidence threshold for detection.
        clf_conf_thres (float): Confidence threshold for classification; a crop
            whose confidence is not above it is labelled "Unknown".
        img_index: Image index identifier, forwarded to the detector as
            ``img_path``.

    Returns:
        The result of ``box_annotator.annotate``: the image with bounding boxes
        and one label per detection drawn on it.

    Raises:
        gr.Error: If no detection model has been loaded yet.
    """
    if detection_model is None or trans_det is None:
        raise gr.Error("No detection model loaded. Click 'Load Models!' first.")

    results_det = detection_model.single_image_detection(trans_det(input_img),
                                                         input_img.shape,
                                                         img_path=img_index,
                                                         conf_thres=det_conf_thres)
    detections = results_det["detections"]
    det_labels = results_det["labels"]

    if classification_model is None:
        labels = det_labels
    else:
        # Build exactly one label per detection. Assumes det_labels is
        # parallel to the detections (it is handed to the annotator that way
        # when no classifier is loaded) -- TODO confirm.
        labels = []
        for xyxy, det_id, det_label in zip(detections.xyxy, detections.class_id, det_labels):
            # Only run classifier when detection class is animal
            if det_id != 0:
                labels.append(det_label)
                continue
            cropped_image = sv.crop_image(image=input_img, xyxy=xyxy)
            results_clf = classification_model.single_image_classification(
                trans_clf(Image.fromarray(cropped_image)))
            confidence = results_clf["confidence"]
            prediction = results_clf["prediction"] if confidence > clf_conf_thres else "Unknown"
            labels.append("{} {:.2f}".format(prediction, confidence))

    annotated_img = box_annotator.annotate(scene=input_img, detections=detections, labels=labels)
    return annotated_img
84
+
85
+
86
def batch_detection(zip_file, det_conf_thres):
    """Run detection on every image in a zip archive and save the results as JSON.

    When a classification model is loaded, the detection crops are classified
    as well and the combined results are saved.

    Args:
        zip_file: Uploaded zip archive from Gradio; either an object exposing
            the file path as ``.name`` or a plain path string.
        det_conf_thres (float): Confidence threshold for detection.

    Returns:
        json_save_path (str): Path to the JSON file containing the results.

    Raises:
        gr.Error: If no detection model has been loaded yet.
    """
    import shutil  # local import: only this function needs it

    if detection_model is None or trans_det is None:
        raise gr.Error("No detection model loaded. Click 'Load Models!' first.")

    extract_path = os.path.join("..", "temp", "zip_upload")
    json_save_path = os.path.join(extract_path, "results.json")

    # Start from an empty folder so files left by a previous upload are not
    # picked up again and mixed into this run's results.
    shutil.rmtree(extract_path, ignore_errors=True)
    os.makedirs(extract_path, exist_ok=True)

    zip_path = getattr(zip_file, "name", zip_file)
    # ZipFile.extractall strips absolute paths and ".." components from member
    # names; archive size is not limited here.
    with ZipFile(zip_path) as zfile:
        zfile.extractall(extract_path)

    tgt_folder_path = extract_path
    det_dataset = pw_data.DetectionImageFolder(tgt_folder_path, transform=trans_det)
    det_loader = DataLoader(det_dataset, batch_size=32, shuffle=False,
                            pin_memory=True, num_workers=8, drop_last=False)
    det_results = detection_model.batch_image_detection(det_loader,
                                                        conf_thres=det_conf_thres,
                                                        id_strip=tgt_folder_path)

    if classification_model is not None:
        clf_dataset = pw_data.DetectionCrops(
            det_results,
            transform=pw_trans.Classification_Inference_Transform(target_size=224),
            path_head=tgt_folder_path
        )
        clf_loader = DataLoader(clf_dataset, batch_size=32, shuffle=False,
                                pin_memory=True, num_workers=8, drop_last=False)
        clf_results = classification_model.batch_image_classification(clf_loader,
                                                                      id_strip=tgt_folder_path)
        pw_utils.save_detection_classification_json(det_results=det_results,
                                                    clf_results=clf_results,
                                                    det_categories=detection_model.CLASS_NAMES,
                                                    clf_categories=classification_model.CLASS_NAMES,
                                                    output_path=json_save_path)
    else:
        pw_utils.save_detection_json(det_results, json_save_path,
                                     categories=detection_model.CLASS_NAMES)

    return json_save_path
126
+
127
+
128
def video_detection(video, det_conf_thres, clf_conf_thres, target_fps):
    """Run detection on a video and return the path to the processed video.

    Args:
        video (str): Video source path.
        det_conf_thres (float): Confidence threshold for detection.
        clf_conf_thres (float): Confidence threshold for classification.
        target_fps: Frame rate forwarded to ``pw_utils.process_video``.

    Returns:
        str: Path of the annotated output video.
    """
    def callback(frame, index):
        # Annotate one frame; handed to pw_utils.process_video as its callback.
        return single_image_detection(frame,
                                      img_index=index,
                                      det_conf_thres=det_conf_thres,
                                      clf_conf_thres=clf_conf_thres)

    target_path = "../temp/video_detection.mp4"
    # Nothing else in this function creates the output folder, so make sure
    # it exists before the video is written.
    os.makedirs(os.path.dirname(target_path), exist_ok=True)
    pw_utils.process_video(source_path=video, target_path=target_path,
                           callback=callback, target_fps=target_fps)
    return target_path
148
+
149
#%% Building Gradio UI
# NOTE(review): the nesting of the layout below was reconstructed from a
# flattened diff view -- confirm it against the original app.py.

with gr.Blocks() as demo:
    gr.Markdown("# Pytorch-Wildlife Demo.")

    # Model selection and loading
    with gr.Row():
        det_drop = gr.Dropdown(
            choices=["MegaDetectorV5"],
            value="MegaDetectorV5",
            label="Detection model",
            info="Will add more detection models!",
        )
        clf_drop = gr.Dropdown(
            choices=["None", "AI4GOpossum", "AI4GAmazonRainforest"],
            value="None",
            label="Classification model",
            info="Will add more classification models!",
        )
    with gr.Column():
        load_but = gr.Button("Load Models!")
        load_out = gr.Text(value="NO MODEL LOADED!!", label="Loaded models:")

    # Tab 1: single image
    with gr.Tab("Single Image Process"):
        with gr.Row():
            with gr.Column():
                sgl_in = gr.Image()
                sgl_conf_sl_det = gr.Slider(minimum=0, maximum=1, value=0.2,
                                            label="Detection Confidence Threshold")
                sgl_conf_sl_clf = gr.Slider(minimum=0, maximum=1, value=0.7,
                                            label="Classification Confidence Threshold")
            sgl_out = gr.Image()
        sgl_but = gr.Button("Detect Animals!")

    # Tab 2: zip archive of images
    with gr.Tab("Batch Image Process"):
        with gr.Row():
            with gr.Column():
                bth_in = gr.File(label="Upload zip file.")
                bth_conf_sl = gr.Slider(minimum=0, maximum=1, value=0.2,
                                        label="Detection Confidence Threshold")
            bth_out = gr.File(label="Detection Results JSON.", height=200)
        bth_but = gr.Button("Detect Animals!")

    # Tab 3: single video
    with gr.Tab("Single Video Process"):
        with gr.Row():
            with gr.Column():
                vid_in = gr.Video(label="Upload a video.")
                vid_conf_sl_det = gr.Slider(minimum=0, maximum=1, value=0.2,
                                            label="Detection Confidence Threshold")
                vid_conf_sl_clf = gr.Slider(minimum=0, maximum=1, value=0.7,
                                            label="Classification Confidence Threshold")
                vid_fr = gr.Dropdown(choices=[5, 10, 30], value=30,
                                     label="Output video framerate")
            vid_out = gr.Video()
        vid_but = gr.Button("Detect Animals!")

    # Wire each button to its handler
    load_but.click(fn=load_models,
                   inputs=[det_drop, clf_drop],
                   outputs=load_out)
    sgl_but.click(fn=single_image_detection,
                  inputs=[sgl_in, sgl_conf_sl_det, sgl_conf_sl_clf],
                  outputs=sgl_out)
    bth_but.click(fn=batch_detection,
                  inputs=[bth_in, bth_conf_sl],
                  outputs=bth_out)
    vid_but.click(fn=video_detection,
                  inputs=[vid_in, vid_conf_sl_det, vid_conf_sl_clf, vid_fr],
                  outputs=vid_out)

# Start the demo only when this file is executed as a script
if __name__ == "__main__":
    demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ PytorchWildlife