Frank Pacini committed on
Commit
6155c0e
β€’
1 Parent(s): e694ec3
CustomFile.py ADDED
@@ -0,0 +1,19 @@
+ import gradio as gr
+ # from typing import Dict
+ # import base64
+
+ # def encode_file_to_base64(f):
+ #     with open(f, "rb") as file:
+ #         encoded_string = base64.b64encode(file.read())
+ #         base64_str = str(encoded_string, "utf-8")
+ #         return base64_str
+
+ class CustomFile(gr.File):
+     # def postprocess(self, y: str) -> Dict:
+     #     res = super().postprocess(y)
+     #     if res is not None:
+     #         res['data'] = encode_file_to_base64(res['name'])
+     #     return res
+     def dummy(self):
+         return
+
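The class above keeps the base64 postprocess override commented out and only adds a `dummy` method. A minimal sketch of what re-enabling that override could look like, assuming Gradio 3.12's `gr.File.postprocess` still returns a dict with 'name' and 'data' keys (the subclass name `Base64File` is illustrative, not part of this commit):

import base64
from typing import Dict

import gradio as gr


def encode_file_to_base64(f):
    # Read the file bytes and return them as a UTF-8 base64 string.
    with open(f, "rb") as file:
        return str(base64.b64encode(file.read()), "utf-8")


class Base64File(gr.File):
    def postprocess(self, y: str) -> Dict:
        # Keep the default payload but replace the served file path with inline base64 data.
        res = super().postprocess(y)
        if res is not None:
            res['data'] = encode_file_to_base64(res['name'])
        return res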
README.md CHANGED
@@ -1,12 +1,13 @@
 ---
- title: Fal2022 Videoanalysis V2
- emoji: 💩
+ title: Fall2022 Videoanalysis
+ emoji: 📈
 colorFrom: yellow
- colorTo: red
+ colorTo: purple
 sdk: gradio
 sdk_version: 3.12.0
 app_file: app.py
 pinned: false
+ license: apache-2.0
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,35 @@
+ import gradio as gr
+
+ from slowfast import slow_fast_train
+ from video_object_extraction import video_object_extraction
+ from audio_feature_extraction_final import audio_feature_extraction
+ from CustomFile import CustomFile
+
+ import numpy as np
+ import pandas as pd
+ import pickle
+ import torch
+
+ try:
+     import detectron2
+ except ImportError:
+     import os
+     os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
+
+
+ def predict(video_path, frames):
+     # gpu = torch.cuda.is_available()
+     # video_1, df1 = slow_fast_train(video_path, gpu)
+     # video_2, df2 = video_object_extraction(video_path, frames)
+     # audio_path = audio_feature_extraction(video_path, gpu)
+     # return ([video_1, video_2, audio_path], df1, df2)
+     audio_features = np.random.rand(2, 2)
+     audio_path = 'audio_embeddings'
+     with open(audio_path, 'wb') as f:
+         pickle.dump(audio_features, f)
+     df = pd.DataFrame()
+     return ([video_path, video_path, audio_path], df, df)
+
+
+ iface = gr.Interface(predict, inputs=[gr.Video(), gr.Slider(1, 100, value=15)], outputs=[gr.File(), gr.Dataframe(max_rows=10), gr.Dataframe(max_rows=10)])
+ iface.launch(show_error=True, debug=True)
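As wired above, `predict` currently returns placeholder outputs (the input video twice, a pickled random array, and two empty DataFrames) while the real pipeline calls stay commented out. A hedged way to exercise it outside the Gradio UI, with `sample.mp4` as a hypothetical local file:

files, df1, df2 = predict('sample.mp4', frames=15)
print(files)      # ['sample.mp4', 'sample.mp4', 'audio_embeddings']
print(df1.empty)  # True: the DataFrames are placeholders until the pipeline is re-enabled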
audio_feature_extraction_final.py ADDED
@@ -0,0 +1,125 @@
1
+ import torch
2
+ from torchaudio import load as torchaudio_load
3
+ from moviepy.editor import VideoFileClip
4
+
5
+ from pyannote.audio import Pipeline
6
+ from sklearn.preprocessing import LabelEncoder
7
+ from librosa import load as librosa_load
8
+ import librosa.display
9
+ import math
10
+ import pandas as pd
11
+
12
+ import sys
13
+ from tqdm import tqdm
14
+ import numpy as np
15
+ from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration, pipeline as transformers_pipeline
16
+ import pickle
17
+
18
+
19
+ """"Author: Frank"""
20
+ def extract_s2t_features(gpu):
21
+ model_name="medium"
22
+ processor = Speech2TextProcessor.from_pretrained("facebook/s2t-{}-librispeech-asr".format(model_name))
23
+ model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-{}-librispeech-asr".format(model_name))
24
+ if gpu:
25
+ model = model.cuda()
26
+ model.load_state_dict(torch.load('s2t_model'))
27
+ model.eval()
28
+
29
+ sample_rate = 16000
30
+ embedding_window = 10 # in secs
31
+
32
+ audio, _ = torchaudio_load('temp.wav')
33
+ audio = torch.mean(audio, dim=0)
34
+
35
+ embs = []
36
+ audio_clips = audio.split(embedding_window*sample_rate)
37
+ if len(audio_clips) > 1:
38
+ audio_clips = audio_clips[:-1]
39
+ for clip in tqdm(audio_clips):
40
+ with torch.no_grad():
41
+ inputs = processor(clip, sampling_rate=16000, return_tensors="pt")
42
+ features = inputs["input_features"]
43
+ decoder_input = torch.zeros(features.shape[:2], dtype=torch.int32)
44
+ if gpu:
45
+ features, decoder_input = features.cuda(), decoder_input.cuda()
46
+
47
+ h = model.model(features, decoder_input_ids=decoder_input).last_hidden_state.cpu()
48
+ emb = torch.mean(h,axis=1)
49
+ embs.append(emb)
50
+ return torch.cat(embs).numpy()
51
+
52
+
53
+ """"Author: Sichao"""
54
+ def extract_speaker_features(gpu):
55
+ x , sample_rate = librosa_load('temp.wav')
56
+ print('Input sample rate: {}, Length: {} s'.format(sample_rate, x.size/sample_rate))
57
+
58
+ # speaker diarization
59
+ print('Start speaker diarization...')
60
+ pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1", use_auth_token='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')
61
+ diarization = pipeline('temp.wav')
62
+ speaker_per_sec_dict = {i: 'UNKNOWN' for i in range(0, math.ceil(x.size/sample_rate))}
63
+
64
+ for turn, _, speaker in diarization.itertracks(yield_label=True):
65
+ for clip_start in range(math.ceil(turn.start), math.ceil(turn.end)):
66
+ if speaker_per_sec_dict[clip_start] == 'UNKNOWN':
67
+ speaker_per_sec_dict[clip_start] = speaker
68
+ elif speaker_per_sec_dict[clip_start] != speaker:
69
+ speaker_per_sec_dict[clip_start] = speaker_per_sec_dict[clip_start] + ' ' + speaker
70
+
71
+ speaker_per_clip = []
72
+ for i in range(0, math.ceil(x.size/sample_rate), 10):
73
+ speakers = []
74
+ for j in range(10):
75
+ if i + j in speaker_per_sec_dict and speaker_per_sec_dict[i + j] != 'UNKNOWN':
76
+ speakers.append(speaker_per_sec_dict[i + j])
77
+ if len(speakers) > 0:
78
+ is_single_speaker = all(s == speakers[0] for s in speakers)
79
+ if is_single_speaker:
80
+ speaker_per_clip.append(speakers[0])
81
+ else:
82
+ speaker_per_clip.append('MULTI SPEAKER')
83
+ else:
84
+ speaker_per_clip.append('UNKNOWN')
85
+
86
+ # Adult child classification
87
+ print('Start adult child classification...')
88
+ device = 0 if gpu else -1
89
+ audio_classifier = transformers_pipeline(task="audio-classification", model="bookbot/wav2vec2-adult-child-cls", device=device)
90
+ clip_idxs = [i for i in range(0, math.ceil(x.size/sample_rate), 10)]
91
+ classifications = []
92
+ for clip_start in tqdm(clip_idxs):
93
+ with torch.no_grad():
94
+ preds = audio_classifier(x[clip_start*sample_rate:(clip_start + 10)*sample_rate])
95
+ preds = [{"score": round(pred["score"], 4), "label": pred["label"]} for pred in preds]
96
+ classifications.append(preds[0]['label'])
97
+
98
+ # output
99
+ print('Output...')
100
+ output = {'clip_start':clip_idxs, 'diarization':speaker_per_clip, 'adult_child_classification':classifications}
101
+ output_df = pd.DataFrame(output)
102
+ # Create an instance of LabelEncoder.
103
+ le = LabelEncoder()
104
+
105
+ # Encode the string labels as numeric values.
106
+ output_df['diarization_numeric'] = le.fit_transform(output_df['diarization'])
107
+ output_df['adult_child_classification_numeric'] = le.fit_transform(output_df['adult_child_classification'])
108
+ return output_df['diarization_numeric'].values, output_df['adult_child_classification_numeric'].values
109
+
110
+ def audio_feature_extraction(input_path, gpu=False):
111
+ output_path = 'audio_embedding'
112
+ audioTrack = VideoFileClip(input_path).audio
113
+ audioTrack.write_audiofile('temp.wav', codec='pcm_s16le', fps=16000)
114
+
115
+ print('Extracting s2t features...')
116
+ s2t_features = extract_s2t_features(gpu)
117
+ print('Extracting speaker features...')
118
+ diarization_features, adult_child_class_features = extract_speaker_features(gpu)
119
+
120
+ if len(diarization_features) > 1:
121
+ diarization_features, adult_child_class_features = diarization_features[:-1], adult_child_class_features[:-1]
122
+ audio_features = np.concatenate((s2t_features, diarization_features[:, None], adult_child_class_features[:, None]), axis=1)
123
+ with open(output_path, 'wb') as f:
124
+ pickle.dump(audio_features, f)
125
+ return output_path
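`audio_feature_extraction` writes the concatenated feature matrix (the Speech2Text embeddings plus the diarization and adult/child label columns) to a pickle file and returns its path. A hedged usage sketch, with `input.mp4` as a hypothetical video:

import pickle

path = audio_feature_extraction('input.mp4', gpu=False)
with open(path, 'rb') as f:
    audio_features = pickle.load(f)
# One row per 10-second clip; the last two columns are the encoded diarization
# and adult/child labels appended after the Speech2Text embedding dimensions.
print(audio_features.shape)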
ava_action_list.pbtxt ADDED
@@ -0,0 +1,240 @@
1
+ item {
2
+ name: "bend/bow (at the waist)"
3
+ id: 1
4
+ }
5
+ item {
6
+ name: "crouch/kneel"
7
+ id: 3
8
+ }
9
+ item {
10
+ name: "dance"
11
+ id: 4
12
+ }
13
+ item {
14
+ name: "fall down"
15
+ id: 5
16
+ }
17
+ item {
18
+ name: "get up"
19
+ id: 6
20
+ }
21
+ item {
22
+ name: "jump/leap"
23
+ id: 7
24
+ }
25
+ item {
26
+ name: "lie/sleep"
27
+ id: 8
28
+ }
29
+ item {
30
+ name: "martial art"
31
+ id: 9
32
+ }
33
+ item {
34
+ name: "run/jog"
35
+ id: 10
36
+ }
37
+ item {
38
+ name: "sit"
39
+ id: 11
40
+ }
41
+ item {
42
+ name: "stand"
43
+ id: 12
44
+ }
45
+ item {
46
+ name: "swim"
47
+ id: 13
48
+ }
49
+ item {
50
+ name: "walk"
51
+ id: 14
52
+ }
53
+ item {
54
+ name: "answer phone"
55
+ id: 15
56
+ }
57
+ item {
58
+ name: "carry/hold (an object)"
59
+ id: 17
60
+ }
61
+ item {
62
+ name: "climb (e.g., a mountain)"
63
+ id: 20
64
+ }
65
+ item {
66
+ name: "close (e.g., a door, a box)"
67
+ id: 22
68
+ }
69
+ item {
70
+ name: "cut"
71
+ id: 24
72
+ }
73
+ item {
74
+ name: "dress/put on clothing"
75
+ id: 26
76
+ }
77
+ item {
78
+ name: "drink"
79
+ id: 27
80
+ }
81
+ item {
82
+ name: "drive (e.g., a car, a truck)"
83
+ id: 28
84
+ }
85
+ item {
86
+ name: "eat"
87
+ id: 29
88
+ }
89
+ item {
90
+ name: "enter"
91
+ id: 30
92
+ }
93
+ item {
94
+ name: "hit (an object)"
95
+ id: 34
96
+ }
97
+ item {
98
+ name: "lift/pick up"
99
+ id: 36
100
+ }
101
+ item {
102
+ name: "listen (e.g., to music)"
103
+ id: 37
104
+ }
105
+ item {
106
+ name: "open (e.g., a window, a car door)"
107
+ id: 38
108
+ }
109
+ item {
110
+ name: "play musical instrument"
111
+ id: 41
112
+ }
113
+ item {
114
+ name: "point to (an object)"
115
+ id: 43
116
+ }
117
+ item {
118
+ name: "pull (an object)"
119
+ id: 45
120
+ }
121
+ item {
122
+ name: "push (an object)"
123
+ id: 46
124
+ }
125
+ item {
126
+ name: "put down"
127
+ id: 47
128
+ }
129
+ item {
130
+ name: "read"
131
+ id: 48
132
+ }
133
+ item {
134
+ name: "ride (e.g., a bike, a car, a horse)"
135
+ id: 49
136
+ }
137
+ item {
138
+ name: "sail boat"
139
+ id: 51
140
+ }
141
+ item {
142
+ name: "shoot"
143
+ id: 52
144
+ }
145
+ item {
146
+ name: "smoke"
147
+ id: 54
148
+ }
149
+ item {
150
+ name: "take a photo"
151
+ id: 56
152
+ }
153
+ item {
154
+ name: "text on/look at a cellphone"
155
+ id: 57
156
+ }
157
+ item {
158
+ name: "throw"
159
+ id: 58
160
+ }
161
+ item {
162
+ name: "touch (an object)"
163
+ id: 59
164
+ }
165
+ item {
166
+ name: "turn (e.g., a screwdriver)"
167
+ id: 60
168
+ }
169
+ item {
170
+ name: "watch (e.g., TV)"
171
+ id: 61
172
+ }
173
+ item {
174
+ name: "work on a computer"
175
+ id: 62
176
+ }
177
+ item {
178
+ name: "write"
179
+ id: 63
180
+ }
181
+ item {
182
+ name: "fight/hit (a person)"
183
+ id: 64
184
+ }
185
+ item {
186
+ name: "give/serve (an object) to (a person)"
187
+ id: 65
188
+ }
189
+ item {
190
+ name: "grab (a person)"
191
+ id: 66
192
+ }
193
+ item {
194
+ name: "hand clap"
195
+ id: 67
196
+ }
197
+ item {
198
+ name: "hand shake"
199
+ id: 68
200
+ }
201
+ item {
202
+ name: "hand wave"
203
+ id: 69
204
+ }
205
+ item {
206
+ name: "hug (a person)"
207
+ id: 70
208
+ }
209
+ item {
210
+ name: "kiss (a person)"
211
+ id: 72
212
+ }
213
+ item {
214
+ name: "lift (a person)"
215
+ id: 73
216
+ }
217
+ item {
218
+ name: "listen to (a person)"
219
+ id: 74
220
+ }
221
+ item {
222
+ name: "push (another person)"
223
+ id: 76
224
+ }
225
+ item {
226
+ name: "sing to (e.g., self, a person, a group)"
227
+ id: 77
228
+ }
229
+ item {
230
+ name: "take (an object) from (a person)"
231
+ id: 78
232
+ }
233
+ item {
234
+ name: "talk to (e.g., self, a person, a group)"
235
+ id: 79
236
+ }
237
+ item {
238
+ name: "watch (a person)"
239
+ id: 80
240
+ }
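slowfast.py consumes this label map through PyTorchVideo's AVA helper; a minimal sketch of that lookup, mirroring the call in slowfast.py:

from pytorchvideo.data.ava import AvaLabeledVideoFramePaths

# Returns an {id: name} dict plus the set of allowed class ids.
label_map, allowed_class_ids = AvaLabeledVideoFramePaths.read_label_map('ava_action_list.pbtxt')
print(label_map[12])  # "stand"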
coco.names ADDED
@@ -0,0 +1,80 @@
1
+ person
2
+ bicycle
3
+ car
4
+ motorbike
5
+ aeroplane
6
+ bus
7
+ train
8
+ truck
9
+ boat
10
+ traffic light
11
+ fire hydrant
12
+ stop sign
13
+ parking meter
14
+ bench
15
+ bird
16
+ cat
17
+ dog
18
+ horse
19
+ sheep
20
+ cow
21
+ elephant
22
+ bear
23
+ zebra
24
+ giraffe
25
+ backpack
26
+ umbrella
27
+ handbag
28
+ tie
29
+ suitcase
30
+ frisbee
31
+ skis
32
+ snowboard
33
+ sports ball
34
+ kite
35
+ baseball bat
36
+ baseball glove
37
+ skateboard
38
+ surfboard
39
+ tennis racket
40
+ bottle
41
+ wine glass
42
+ cup
43
+ fork
44
+ knife
45
+ spoon
46
+ bowl
47
+ banana
48
+ apple
49
+ sandwich
50
+ orange
51
+ broccoli
52
+ carrot
53
+ hot dog
54
+ pizza
55
+ donut
56
+ cake
57
+ chair
58
+ sofa
59
+ pottedplant
60
+ bed
61
+ diningtable
62
+ toilet
63
+ tvmonitor
64
+ laptop
65
+ mouse
66
+ remote
67
+ keyboard
68
+ cell phone
69
+ microwave
70
+ oven
71
+ toaster
72
+ sink
73
+ refrigerator
74
+ book
75
+ clock
76
+ vase
77
+ scissors
78
+ teddy bear
79
+ hair drier
80
+ toothbrush
environment.yml ADDED
@@ -0,0 +1,5 @@
+ name: env
+ dependencies:
+   - cudatoolkit
+   - pip:
+     - -r requirements.txt
requirements.txt ADDED
@@ -0,0 +1,28 @@
+ imutils
+ matplotlib
+ numpy
+ pandas
+ opencv-python
+ ffmpeg-python
+ pytorchvideo
+
+ cython
+ scipy
+ tqdm
+ gdown
+ cmake
+
+ # Torch
+ --find-links https://download.pytorch.org/whl/cu111
+ torch==1.10.0
+ torchvision==0.11.1
+
+ # Detectron
+ --find-links https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/index.html
+ detectron2
+
+ moviepy
+ pyannote.audio
+ scikit-learn
+ librosa
+ transformers
slowfast.py ADDED
@@ -0,0 +1,191 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import cv2
4
+ import torch
5
+ import warnings
6
+ from detectron2.config import get_cfg
7
+ from detectron2 import model_zoo
8
+ from detectron2.engine import DefaultPredictor
9
+ import ffmpeg
10
+ import pytorchvideo
11
+ from pytorchvideo.transforms.functional import (
12
+ uniform_temporal_subsample,
13
+ short_side_scale_with_boxes,
14
+ clip_boxes_to_image
15
+ )
16
+ from torchvision.transforms._functional_video import normalize
17
+ from pytorchvideo.data.ava import AvaLabeledVideoFramePaths
18
+ from pytorchvideo.models.hub import slowfast_r50_detection # Another option is slow_r50_detection
19
+ from visualization import VideoVisualizer
20
+
21
+
22
+ # This method takes in an image and generates the bounding boxes for people in the image.
23
+ def get_person_bboxes(inp_img, predictor):
24
+ predictions = predictor(inp_img.cpu().detach().numpy())['instances'].to('cpu')
25
+ boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
26
+ scores = predictions.scores if predictions.has("scores") else None
27
+ classes = np.array(predictions.pred_classes.tolist() if predictions.has("pred_classes") else None)
28
+ predicted_boxes = boxes[np.logical_and(classes==0, scores>0.75 )].tensor.cpu() # only person
29
+ return predicted_boxes
30
+
31
+
32
+ def ava_inference_transform(
33
+ clip,
34
+ boxes,
35
+ num_frames = 32, # 4 if using slowfast_r50_detection, change this to 32
36
+ crop_size = 256,
37
+ data_mean = [0.45, 0.45, 0.45],
38
+ data_std = [0.225, 0.225, 0.225],
39
+ slow_fast_alpha = 4, # if using slowfast_r50_detection, change None to 4
40
+ device = 'cpu'):
41
+
42
+ boxes = np.array(boxes)
43
+ ori_boxes = boxes.copy()
44
+
45
+ # Image [0, 255] -> [0, 1].
46
+ clip = uniform_temporal_subsample(clip, num_frames)
47
+ clip = clip.float()
48
+ clip = clip / 255.0
49
+
50
+ height, width = clip.shape[2], clip.shape[3]
51
+ # The format of boxes is [x1, y1, x2, y2]. The input boxes are in the
52
+ # range of [0, width] for x and [0,height] for y
53
+ boxes = clip_boxes_to_image(boxes, height, width)
54
+
55
+ # Resize short side to crop_size. Non-local and STRG uses 256.
56
+ clip, boxes = short_side_scale_with_boxes(clip, size=crop_size, boxes=boxes)
57
+
58
+ # Normalize images by mean and std.
59
+ clip = normalize(clip, np.array(data_mean, dtype=np.float32), np.array(data_std, dtype=np.float32))
60
+
61
+ boxes = clip_boxes_to_image(boxes, clip.shape[2], clip.shape[3])
62
+
63
+ # Incase of slowfast, generate both pathways
64
+ if slow_fast_alpha is not None:
65
+ fast_pathway = clip
66
+ # Perform temporal sampling from the fast pathway.
67
+ slow_pathway = torch.index_select(clip, 1, torch.linspace(
68
+ 0, clip.shape[1] - 1, clip.shape[1] // slow_fast_alpha).long())
69
+ clip = [slow_pathway.unsqueeze(0).to(device), fast_pathway.unsqueeze(0).to(device)]
70
+
71
+ return clip, torch.from_numpy(boxes), ori_boxes
72
+
73
+ # get video info
74
+ def with_opencv(filename):
75
+ video = cv2.VideoCapture(filename)
76
+ frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
77
+ fps = video.get(cv2.CAP_PROP_FPS)
78
+ s = round(frame_count / fps)
79
+ video.release()
80
+ return int(s), fps
81
+
82
+
83
+ def slow_fast_train(file_path, gpu=False):
84
+ device = 'cuda' if gpu else 'cpu'
85
+ top_k = 1
86
+
87
+ video_model = slowfast_r50_detection(True) # Another option is slow_r50_detection(True)
88
+ video_model = video_model.eval().to(device)
89
+ cfg = get_cfg()
90
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
91
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.55 # set threshold for this model
92
+ cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
93
+ cfg.MODEL.DEVICE = device
94
+ predictor = DefaultPredictor(cfg)
95
+ # Create an id to label name mapping
96
+ label_map, allowed_class_ids = AvaLabeledVideoFramePaths.read_label_map('ava_action_list.pbtxt')
97
+ # Create a video visualizer that can plot bounding boxes and visualize actions on bboxes.
98
+ video_visualizer = VideoVisualizer(81, label_map, top_k=top_k, mode="thres",thres=0.5) #get top3 predictions show in each bounding box
99
+
100
+ #preprocess video data
101
+ encoded_vid = pytorchvideo.data.encoded_video.EncodedVideo.from_path(file_path)
102
+
103
+ # Video predictions are generated each frame/second for the wholevideo.
104
+ total_sec, fps = with_opencv(file_path)
105
+ time_stamp_range = range(0, total_sec) # time stamps in video for which clip is sampled
106
+ clip_duration = 1.0 # Duration of clip used for each inference step.
107
+ gif_imgs = []
108
+ xleft, ytop, xright, ybottom = [], [], [], []
109
+ labels = []
110
+ time_frame = []
111
+ scores = []
112
+
113
+ for time_stamp in time_stamp_range:
114
+
115
+ # Generate clip around the designated time stamps
116
+ inp_imgs = encoded_vid.get_clip(
117
+ time_stamp - clip_duration/2.0,
118
+ time_stamp + clip_duration/2.0)
119
+ inp_imgs = inp_imgs['video']
120
+
121
+ #if time_stamp % 15 == 0:
122
+ # Generate people bbox predictions using Detectron2's off the self pre-trained predictor
123
+ # We use the the middle image in each clip to generate the bounding boxes.
124
+ inp_img = inp_imgs[:,inp_imgs.shape[1]//2,:,:]
125
+ inp_img = inp_img.permute(1,2,0)
126
+
127
+ # Predicted boxes are of the form List[(x_1, y_1, x_2, y_2)]
128
+ predicted_boxes = get_person_bboxes(inp_img, predictor)
129
+ if len(predicted_boxes) == 0:
130
+ print("Skipping clip no frames detected at time stamp: ", time_stamp)
131
+ continue
132
+
133
+ # Preprocess clip and bounding boxes for video action recognition.
134
+ inputs, inp_boxes, _ = ava_inference_transform(inp_imgs, predicted_boxes.numpy(), device=device)
135
+ # Prepend data sample id for each bounding box.
136
+ # For more details refere to the RoIAlign in Detectron2
137
+ inp_boxes = torch.cat([torch.zeros(inp_boxes.shape[0],1), inp_boxes], dim=1)
138
+
139
+ # Generate actions predictions for the bounding boxes in the clip.
140
+ # The model here takes in the pre-processed video clip and the detected bounding boxes.
141
+ preds = video_model(inputs, inp_boxes.to(device)) #change inputs to inputs.unsqueeze(0).to(device) if using slow_r50
142
+
143
+ preds = preds.to('cpu')
144
+ # The model is trained on AVA and AVA labels are 1 indexed so, prepend 0 to convert to 0 index.
145
+ preds = torch.cat([torch.zeros(preds.shape[0],1), preds], dim=1)
146
+
147
+ # Plot predictions on the video and save for later visualization.
148
+ inp_imgs = inp_imgs.permute(1,2,3,0)
149
+ inp_imgs = inp_imgs/255.0
150
+ out_img_pred = video_visualizer.draw_clip_range(inp_imgs, preds, predicted_boxes)
151
+ gif_imgs += out_img_pred
152
+
153
+ #format of bboxes(x_left, y_top, x_right, y_bottom)
154
+ predicted_boxes_lst = predicted_boxes.tolist()
155
+ topscores, topclasses = torch.topk(preds, k=1)
156
+ topscores, topclasses = topscores.tolist(), topclasses.tolist()
157
+ topclasses = np.concatenate(topclasses)
158
+ topscores = np.concatenate(topscores)
159
+
160
+ #add top 1 prediction of behaviors in each time step
161
+ for i in range(len(predicted_boxes_lst)):
162
+ xleft.append(predicted_boxes_lst[i][0])
163
+ ytop.append(predicted_boxes_lst[i][1])
164
+ xright.append(predicted_boxes_lst[i][2])
165
+ ybottom.append(predicted_boxes_lst[i][3])
166
+ labels.append(label_map.get(topclasses[i]))
167
+ time_frame.append(time_stamp)
168
+ scores.append(topscores[i])
169
+
170
+ print("Finished generating predictions.")
171
+ # Generate Metadata file
172
+ metadata = pd.DataFrame()
173
+ metadata['frame'] = time_frame
174
+ metadata['x_left'] = xleft
175
+ metadata['y_top'] = ytop
176
+ metadata['x_right'] = xright
177
+ metadata['y_bottom'] = ybottom
178
+ metadata['label'] = labels
179
+ metadata['confidence'] = scores
180
+
181
+ height, width = gif_imgs[0].shape[0], gif_imgs[0].shape[1]
182
+ video_save_path = 'activity_recognition.mp4'
183
+ video = cv2.VideoWriter(video_save_path, cv2.VideoWriter_fourcc(*'mp4v'), int(fps), (width, height))
184
+
185
+ for image in gif_imgs:
186
+ img = (255*image).astype(np.uint8)
187
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
188
+ video.write(img)
189
+ video.release()
190
+
191
+ return video_save_path, metadata
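`slow_fast_train` detects people once per second, classifies their actions with SlowFast, writes the annotated clip to activity_recognition.mp4, and returns that path along with a per-detection metadata table. A hedged usage sketch, with `classroom.mp4` as a hypothetical input:

video_out, metadata = slow_fast_train('classroom.mp4', gpu=torch.cuda.is_available())
print(video_out)  # 'activity_recognition.mp4'
# One row per detected person per sampled second: frame, box corners, label, confidence.
print(metadata[['frame', 'label', 'confidence']].head())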
video_object_extraction.py ADDED
@@ -0,0 +1,185 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Tue Nov 8 16:18:28 2022
5
+
6
+ @author: ariellee
7
+ """
8
+
9
+ # import argparse
10
+ from pathlib import Path
11
+ import cv2
12
+ import numpy as np
13
+ from imutils.video import FPS
14
+ import pandas as pd
15
+ import os
16
+
17
+
18
+ # def str2bool(v):
19
+ # """
20
+ # Converts string to bool type, enables command line
21
+ # arguments in the format of '--arg1 true --arg2 false'
22
+ # """
23
+ # if isinstance(v, bool):
24
+ # return v
25
+ # if v.lower() in ('yes', 'true', 't', 'y', '1'):
26
+ # return True
27
+ # elif v.lower() in ('no', 'false', 'f', 'n', '0'):
28
+ # return False
29
+ # else:
30
+ # raise argparse.ArgumentTypeError('Boolean value expected (true/false)')
31
+
32
+
33
+ # def get_args_parser():
34
+ # parser = argparse.ArgumentParser('Wheelock evaluation script for classroom object detection',
35
+ # add_help=False)
36
+
37
+ # parser.add_argument('--output_dir', default='', type=str,
38
+ # help='path to save the feature extraction results')
39
+
40
+ # parser.add_argument('--output_name', default='video_out', type=str, help='name of csv \
41
+ # file with object features and annotated video with object tracking \
42
+ # and bounding boxes')
43
+
44
+ # parser.add_argument('--video_path', default='short',
45
+ # type=str, help='path to input video, do not include file extension')
46
+
47
+ # parser.add_argument('--is_mp4', type=str2bool, default=False,
48
+ # help='must be an mp4 file')
49
+
50
+ # parser.add_argument('--save_csv', type=str2bool, default=True,
51
+ # help='if true, a csv file of extracted features will be saved in output_dir')
52
+
53
+ # parser.add_argument('--labels', default='coco.names', type=str,
54
+ # help='labels for classes model can detect')
55
+
56
+ # parser.add_argument('--weights', default='yolov3.weights', type=str,
57
+ # help='weights for pretrained yolo model')
58
+
59
+ # parser.add_argument('--cfg', default='yolov3.cfg', type=str,
60
+ # help='model configuration parameters')
61
+ # return parser
62
+
63
+
64
+ def video_object_extraction(video_path, frames):
65
+ '''
66
+ Object detection and feature extraction with yolov3
67
+ Uses darknet repo by pjreddie
68
+
69
+ Returns: (1) csv file with extracted object features
70
+ columns: frame_number, x_start, y_start, x_end, y_end, label, confidence
71
+ (2) mp4 video with object bounding boxes and tracking
72
+
73
+ '''
74
+ # video_path = args.video_path + '.mp4'
75
+ print('Reading from video {}...'.format(video_path))
76
+ cap = cv2.VideoCapture(video_path)
77
+
78
+ # get total number of frames in the video
79
+ total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
80
+
81
+ # get height and width of video
82
+ H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
83
+ W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
84
+
85
+ fps = FPS().start()
86
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
87
+
88
+ # (cols, rows) format
89
+ # root = os.path.join(args.output_dir, args.output_name)
90
+ wp = 'object_detection.mp4'
91
+ g_fps = int(cap.get(cv2.CAP_PROP_FPS))
92
+ writer = cv2.VideoWriter(wp, fourcc, g_fps, (W, H))
93
+ # labels = open(args.labels).read().strip().split('\n')
94
+ labels = open('coco.names').read().strip().split('\n')
95
+ bbox_colors = np.random.randint(0, 255, size=(len(labels), 3), dtype='uint8')
96
+
97
+ yolo = cv2.dnn.readNetFromDarknet('yolov3.cfg', 'yolov3.weights')
98
+ out_layers = yolo.getLayerNames()
99
+ layers = [out_layers[i - 1] for i in yolo.getUnconnectedOutLayers()]
100
+ count = 0
101
+ stat_list = []
102
+
103
+ while count < total_frames:
104
+
105
+ _, frame = cap.read()
106
+
107
+ if count == 0 or count % frames == 0:
108
+ blob = cv2.dnn.blobFromImage(frame, 1 / 255.0, (416, 416), swapRB=True)
109
+ yolo.setInput(blob)
110
+
111
+ layer_outputs = yolo.forward(layers)
112
+ boxes = []
113
+ confidences = []
114
+ classes = []
115
+
116
+ # loop over layer outputs and objects detected
117
+ for output in layer_outputs:
118
+ for obj in output:
119
+
120
+ # extract class and detection likelihood of current object
121
+ scores = obj[5:]
122
+ obj_class = np.argmax(scores)
123
+ confidence = scores[obj_class]
124
+
125
+ # get rid of bad predictions
126
+ if confidence > 0.4:
127
+
128
+ # scale bbox coordinates relative to frame size
129
+ box = obj[0:4] * np.array([W, H, W, H])
130
+ centerX, centerY, width, height = box.astype('int')
131
+
132
+ # final coordiantes
133
+ x = int(centerX - (width / 2))
134
+ y = int(centerY - (height / 2))
135
+
136
+ # update list of bbox coordinates, confidences, classes
137
+ boxes.append([x, y, int(width), int(height)])
138
+ confidences.append(float(confidence))
139
+ classes.append(obj_class)
140
+
141
+ # non-max suppression for overlapping bounding boxes
142
+ idxs = cv2.dnn.NMSBoxes(boxes, confidences, 0.4, 0.4)
143
+
144
+ for i in idxs.flatten():
145
+
146
+ # extract coordinates
147
+ (x, y) = (boxes[i][0], boxes[i][1])
148
+ (w, h) = (boxes[i][2], boxes[i][3])
149
+
150
+ # set up + add bboxes to frame
151
+ color = [int(c) for c in bbox_colors[classes[i]]]
152
+ cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
153
+ text = "{}: {:.4f}".format(labels[classes[i]], confidences[i])
154
+ (text_width, text_height), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
155
+ cv2.rectangle(frame, (x, y - text_height), (x + text_width, y), color, cv2.FILLED)
156
+ cv2.putText(frame, text, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (38, 38, 38), 2)
157
+
158
+ # format of each csv file is: frame number / x / y / w / h / label / confidence
159
+ stat_list.append([count, x, y, w, h, labels[classes[i]], confidences[i]])
160
+
161
+ writer.write(frame)
162
+ fps.update()
163
+ count += 1
164
+
165
+ df = pd.DataFrame(stat_list, columns=['frame', 'x_left', 'y_top', 'x_right',
166
+ 'y_bottom', 'label', 'confidence'])
167
+ fps.stop()
168
+ print('Time elapsed (seconds): {:.2f}'.format(fps.elapsed()))
169
+ writer.release()
170
+ cap.release()
171
+
172
+ return wp, df
173
+
174
+
175
+ # if __name__ == '__main__':
176
+
177
+ # parser = argparse.ArgumentParser('Wheelock evaluation script for classroom object detection', parents=[get_args_parser()])
178
+ # args = parser.parse_args()
179
+
180
+ # if not args.is_mp4:
181
+ # print('Video must be an mp4 file.')
182
+ # else:
183
+ # if args.output_dir:
184
+ # Path(args.output_dir).mkdir(parents=True, exist_ok=True)
185
+ # main(args)
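The commented-out argparse entry point is superseded by direct calls from app.py; a hedged usage sketch with a hypothetical `classroom.mp4`, where `frames` is the stride between YOLO detection passes:

annotated_path, detections = video_object_extraction('classroom.mp4', frames=15)
print(annotated_path)  # 'object_detection.mp4'
print(detections[['frame', 'label', 'confidence']].head())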
visualization.py ADDED
@@ -0,0 +1,706 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ # Note: This file has been borrowed from the facebookresearch/slowfast repo. It is used to add the bounding boxes and predictions to the frame.
3
+ # TODO: Migrate this into the core PyTorchVideo library.
4
+ from __future__ import annotations
5
+
6
+ import itertools
7
+ # import logging
8
+ from types import SimpleNamespace
9
+ from typing import Dict, List, Optional, Tuple, Union
10
+
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ import torch
14
+ from detectron2.utils.visualizer import Visualizer
15
+
16
+
17
+ # logger = logging.getLogger(__name__)
18
+
19
+
20
+ def _create_text_labels(
21
+ classes: List[int],
22
+ scores: List[float],
23
+ class_names: List[str],
24
+ ground_truth: bool = False,
25
+ ) -> List[str]:
26
+ """
27
+ Create text labels.
28
+ Args:
29
+ classes (list[int]): a list of class ids for each example.
30
+ scores (list[float] or None): list of scores for each example.
31
+ class_names (list[str]): a list of class names, ordered by their ids.
32
+ ground_truth (bool): whether the labels are ground truth.
33
+ Returns:
34
+ labels (list[str]): formatted text labels.
35
+ """
36
+ try:
37
+ labels = [class_names.get(c, "n/a") for c in classes]
38
+ except IndexError:
39
+ # logger.error("Class indices get out of range: {}".format(classes))
40
+ return None
41
+
42
+ if ground_truth:
43
+ labels = ["[{}] {}".format("GT", label) for label in labels]
44
+ elif scores is not None:
45
+ assert len(classes) == len(scores)
46
+ labels = ["[{:.2f}] {}".format(s, label) for s, label in zip(scores, labels)]
47
+ return labels
48
+
49
+
50
+ class ImgVisualizer(Visualizer):
51
+ def __init__(
52
+ self, img_rgb: torch.Tensor, meta: Optional[SimpleNamespace] = None, **kwargs
53
+ ) -> None:
54
+ """
55
+ See https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/visualizer.py
56
+ for more details.
57
+ Args:
58
+ img_rgb: a tensor or numpy array of shape (H, W, C), where H and W correspond to
59
+ the height and width of the image respectively. C is the number of
60
+ color channels. The image is required to be in RGB format since that
61
+ is a requirement of the Matplotlib library. The image is also expected
62
+ to be in the range [0, 255].
63
+ meta (MetadataCatalog): image metadata.
64
+ See https://github.com/facebookresearch/detectron2/blob/81d5a87763bfc71a492b5be89b74179bd7492f6b/detectron2/data/catalog.py#L90
65
+ """
66
+ super(ImgVisualizer, self).__init__(img_rgb, meta, **kwargs)
67
+
68
+ def draw_text(
69
+ self,
70
+ text: str,
71
+ position: List[int],
72
+ *,
73
+ font_size: Optional[int] = None,
74
+ color: str = "w",
75
+ horizontal_alignment: str = "center",
76
+ vertical_alignment: str = "bottom",
77
+ box_facecolor: str = "black",
78
+ alpha: float = 0.5,
79
+ ) -> None:
80
+ """
81
+ Draw text at the specified position.
82
+ Args:
83
+ text (str): the text to draw on image.
84
+ position (list of 2 ints): the x,y coordinate to place the text.
85
+ font_size (Optional[int]): font of the text. If not provided, a font size
86
+ proportional to the image width is calculated and used.
87
+ color (str): color of the text. Refer to `matplotlib.colors` for full list
88
+ of formats that are accepted.
89
+ horizontal_alignment (str): see `matplotlib.text.Text`.
90
+ vertical_alignment (str): see `matplotlib.text.Text`.
91
+ box_facecolor (str): color of the box wrapped around the text. Refer to
92
+ `matplotlib.colors` for full list of formats that are accepted.
93
+ alpha (float): transparency level of the box.
94
+ """
95
+ if not font_size:
96
+ font_size = self._default_font_size
97
+ x, y = position
98
+ self.output.ax.text(
99
+ x,
100
+ y,
101
+ text,
102
+ size=font_size * self.output.scale,
103
+ family="monospace",
104
+ bbox={
105
+ "facecolor": box_facecolor,
106
+ "alpha": alpha,
107
+ "pad": 0.7,
108
+ "edgecolor": "none",
109
+ },
110
+ verticalalignment=vertical_alignment,
111
+ horizontalalignment=horizontal_alignment,
112
+ color=color,
113
+ zorder=10,
114
+ )
115
+
116
+ def draw_multiple_text(
117
+ self,
118
+ text_ls: List[str],
119
+ box_coordinate: torch.Tensor,
120
+ *,
121
+ top_corner: bool = True,
122
+ font_size: Optional[int] = None,
123
+ color: str = "w",
124
+ box_facecolors: str = "black",
125
+ alpha: float = 0.5,
126
+ ) -> None:
127
+ """
128
+ Draw a list of text labels for some bounding box on the image.
129
+ Args:
130
+ text_ls (list of strings): a list of text labels.
131
+ box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
132
+ coordinates of the box.
133
+ top_corner (bool): If True, draw the text labels at (x_left, y_top) of the box.
134
+ Else, draw labels at (x_left, y_bottom).
135
+ font_size (Optional[int]): font of the text. If not provided, a font size
136
+ proportional to the image width is calculated and used.
137
+ color (str): color of the text. Refer to `matplotlib.colors` for full list
138
+ of formats that are accepted.
139
+ box_facecolors (str): colors of the box wrapped around the text. Refer to
140
+ `matplotlib.colors` for full list of formats that are accepted.
141
+ alpha (float): transparency level of the box.
142
+ """
143
+ if not isinstance(box_facecolors, list):
144
+ box_facecolors = [box_facecolors] * len(text_ls)
145
+ assert len(box_facecolors) == len(
146
+ text_ls
147
+ ), "Number of colors provided is not equal to the number of text labels."
148
+ if not font_size:
149
+ font_size = self._default_font_size
150
+ text_box_width = font_size + font_size // 2
151
+ # If the texts does not fit in the assigned location,
152
+ # we split the text and draw it in another place.
153
+ if top_corner:
154
+ num_text_split = self._align_y_top(
155
+ box_coordinate, len(text_ls), text_box_width
156
+ )
157
+ y_corner = 1
158
+ else:
159
+ num_text_split = len(text_ls) - self._align_y_bottom(
160
+ box_coordinate, len(text_ls), text_box_width
161
+ )
162
+ y_corner = 3
163
+
164
+ text_color_sorted = sorted(
165
+ zip(text_ls, box_facecolors), key=lambda x: x[0], reverse=True
166
+ )
167
+ if len(text_color_sorted) != 0:
168
+ text_ls, box_facecolors = zip(*text_color_sorted)
169
+ else:
170
+ text_ls, box_facecolors = [], []
171
+ text_ls, box_facecolors = list(text_ls), list(box_facecolors)
172
+ self.draw_multiple_text_upward(
173
+ text_ls[:num_text_split][::-1],
174
+ box_coordinate,
175
+ y_corner=y_corner,
176
+ font_size=font_size,
177
+ color=color,
178
+ box_facecolors=box_facecolors[:num_text_split][::-1],
179
+ alpha=alpha,
180
+ )
181
+ self.draw_multiple_text_downward(
182
+ text_ls[num_text_split:],
183
+ box_coordinate,
184
+ y_corner=y_corner,
185
+ font_size=font_size,
186
+ color=color,
187
+ box_facecolors=box_facecolors[num_text_split:],
188
+ alpha=alpha,
189
+ )
190
+
191
+ def draw_multiple_text_upward(
192
+ self,
193
+ text_ls: List[str],
194
+ box_coordinate: torch.Tensor,
195
+ *,
196
+ y_corner: int = 1,
197
+ font_size: Optional[int] = None,
198
+ color: str = "w",
199
+ box_facecolors: str = "black",
200
+ alpha: float = 0.5,
201
+ ) -> None:
202
+ """
203
+ Draw a list of text labels for some bounding box on the image in upward direction.
204
+ The next text label will be on top of the previous one.
205
+ Args:
206
+ text_ls (list of strings): a list of text labels.
207
+ box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
208
+ coordinates of the box.
209
+ y_corner (int): Value of either 1 or 3. Indicate the index of the y-coordinate of
210
+ the box to draw labels around.
211
+ font_size (Optional[int]): font of the text. If not provided, a font size
212
+ proportional to the image width is calculated and used.
213
+ color (str): color of the text. Refer to `matplotlib.colors` for full list
214
+ of formats that are accepted.
215
+ box_facecolors (str or list of strs): colors of the box wrapped around the
216
+ text. Refer to `matplotlib.colors` for full list of formats that
217
+ are accepted.
218
+ alpha (float): transparency level of the box.
219
+ """
220
+ if not isinstance(box_facecolors, list):
221
+ box_facecolors = [box_facecolors] * len(text_ls)
222
+ assert len(box_facecolors) == len(
223
+ text_ls
224
+ ), "Number of colors provided is not equal to the number of text labels."
225
+
226
+ assert y_corner in [1, 3], "Y_corner must be either 1 or 3"
227
+ if not font_size:
228
+ font_size = self._default_font_size
229
+
230
+ x, horizontal_alignment = self._align_x_coordinate(box_coordinate)
231
+ y = box_coordinate[y_corner].item()
232
+ for i, text in enumerate(text_ls):
233
+ self.draw_text(
234
+ text,
235
+ (x, y),
236
+ font_size=font_size,
237
+ color=color,
238
+ horizontal_alignment=horizontal_alignment,
239
+ vertical_alignment="bottom",
240
+ box_facecolor=box_facecolors[i],
241
+ alpha=alpha,
242
+ )
243
+ y -= font_size + font_size // 2
244
+
245
+ def draw_multiple_text_downward(
246
+ self,
247
+ text_ls: List[str],
248
+ box_coordinate: torch.Tensor,
249
+ *,
250
+ y_corner: int = 1,
251
+ font_size: Optional[int] = None,
252
+ color: str = "w",
253
+ box_facecolors: str = "black",
254
+ alpha: float = 0.5,
255
+ ) -> None:
256
+ """
257
+ Draw a list of text labels for some bounding box on the image in downward direction.
258
+ The next text label will be below the previous one.
259
+ Args:
260
+ text_ls (list of strings): a list of text labels.
261
+ box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
262
+ coordinates of the box.
263
+ y_corner (int): Value of either 1 or 3. Indicate the index of the y-coordinate of
264
+ the box to draw labels around.
265
+ font_size (Optional[int]): font of the text. If not provided, a font size
266
+ proportional to the image width is calculated and used.
267
+ color (str): color of the text. Refer to `matplotlib.colors` for full list
268
+ of formats that are accepted.
269
+ box_facecolors (str): colors of the box wrapped around the text. Refer to
270
+ `matplotlib.colors` for full list of formats that are accepted.
271
+ alpha (float): transparency level of the box.
272
+ """
273
+ if not isinstance(box_facecolors, list):
274
+ box_facecolors = [box_facecolors] * len(text_ls)
275
+ assert len(box_facecolors) == len(
276
+ text_ls
277
+ ), "Number of colors provided is not equal to the number of text labels."
278
+
279
+ assert y_corner in [1, 3], "Y_corner must be either 1 or 3"
280
+ if not font_size:
281
+ font_size = self._default_font_size
282
+
283
+ x, horizontal_alignment = self._align_x_coordinate(box_coordinate)
284
+ y = box_coordinate[y_corner].item()
285
+ for i, text in enumerate(text_ls):
286
+ self.draw_text(
287
+ text,
288
+ (x, y),
289
+ font_size=font_size,
290
+ color=color,
291
+ horizontal_alignment=horizontal_alignment,
292
+ vertical_alignment="top",
293
+ box_facecolor=box_facecolors[i],
294
+ alpha=alpha,
295
+ )
296
+ y += font_size + font_size // 2
297
+
298
+ def _align_x_coordinate(self, box_coordinate: torch.Tensor) -> Tuple[float, str]:
299
+ """
300
+ Choose an x-coordinate from the box to make sure the text label
301
+ does not go out of frames. By default, the left x-coordinate is
302
+ chosen and text is aligned left. If the box is too close to the
303
+ right side of the image, then the right x-coordinate is chosen
304
+ instead and the text is aligned right.
305
+ Args:
306
+ box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
307
+ coordinates of the box.
308
+ Returns:
309
+ x_coordinate (float): the chosen x-coordinate.
310
+ alignment (str): whether to align left or right.
311
+ """
312
+ # If the x-coordinate is greater than 5/6 of the image width,
313
+ # then we align text to the right of the box. This is
314
+ # chosen by heuristics.
315
+ if box_coordinate[0] > (self.output.width * 5) // 6:
316
+ return box_coordinate[2], "right"
317
+
318
+ return box_coordinate[0], "left"
319
+
320
+ def _align_y_top(
321
+ self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float
322
+ ) -> int:
323
+ """
324
+ Calculate the number of text labels to plot on top of the box
325
+ without going out of frames.
326
+ Args:
327
+ box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
328
+ coordinates of the box.
329
+ num_text (int): the number of text labels to plot.
330
+ textbox_width (float): the width of the box wrapped around text label.
331
+ """
332
+ dist_to_top = box_coordinate[1]
333
+ num_text_top = dist_to_top // textbox_width
334
+
335
+ if isinstance(num_text_top, torch.Tensor):
336
+ num_text_top = int(num_text_top.item())
337
+
338
+ return min(num_text, num_text_top)
339
+
340
+ def _align_y_bottom(
341
+ self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float
342
+ ) -> int:
343
+ """
344
+ Calculate the number of text labels to plot at the bottom of the box
345
+ without going out of frames.
346
+ Args:
347
+ box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
348
+ coordinates of the box.
349
+ num_text (int): the number of text labels to plot.
350
+ textbox_width (float): the width of the box wrapped around text label.
351
+ """
352
+ dist_to_bottom = self.output.height - box_coordinate[3]
353
+ num_text_bottom = dist_to_bottom // textbox_width
354
+
355
+ if isinstance(num_text_bottom, torch.Tensor):
356
+ num_text_bottom = int(num_text_bottom.item())
357
+
358
+ return min(num_text, num_text_bottom)
359
+
360
+
361
+ class VideoVisualizer:
362
+ def __init__(
363
+ self,
364
+ num_classes: int,
365
+ class_names: Dict,
366
+ top_k: int = 1,
367
+ colormap: str = "rainbow",
368
+ thres: float = 0.7,
369
+ lower_thres: float = 0.3,
370
+ common_class_names: Optional[List[str]] = None,
371
+ mode: str = "top-k",
372
+ ) -> None:
373
+ """
374
+ Args:
375
+ num_classes (int): total number of classes.
376
+ class_names (dict): Dict mapping classID to name.
377
+ top_k (int): number of top predicted classes to plot.
378
+ colormap (str): the colormap to choose color for class labels from.
379
+ See https://matplotlib.org/tutorials/colors/colormaps.html
380
+ thres (float): threshold for picking predicted classes to visualize.
381
+ lower_thres (Optional[float]): If `common_class_names` if given,
382
+ this `lower_thres` will be applied to uncommon classes and
383
+ `thres` will be applied to classes in `common_class_names`.
384
+ common_class_names (Optional[list of str]): list of common class names
385
+ to apply `thres`. Class names not included in `common_class_names` will
386
+ have `lower_thres` as a threshold. If None, all classes will have
387
+ `thres` as a threshold. This is helpful for model trained on
388
+ highly imbalanced dataset.
389
+ mode (str): Supported modes are {"top-k", "thres"}.
390
+ This is used for choosing predictions for visualization.
391
+ """
392
+ assert mode in ["top-k", "thres"], "Mode {} is not supported.".format(mode)
393
+ self.mode = mode
394
+ self.num_classes = num_classes
395
+ self.class_names = class_names
396
+ self.top_k = top_k
397
+ self.thres = thres
398
+ self.lower_thres = lower_thres
399
+
400
+ if mode == "thres":
401
+ self._get_thres_array(common_class_names=common_class_names)
402
+
403
+ self.color_map = plt.get_cmap(colormap)
404
+
405
+ def _get_color(self, class_id: int) -> List[float]:
406
+ """
407
+ Get color for a class id.
408
+ Args:
409
+ class_id (int): class id.
410
+ """
411
+ return self.color_map(class_id / self.num_classes)[:3]
412
+
413
+ def draw_one_frame(
414
+ self,
415
+ frame: Union[torch.Tensor, np.ndarray],
416
+ preds: Union[torch.Tensor, List[float]],
417
+ bboxes: Optional[torch.Tensor] = None,
418
+ alpha: float = 0.5,
419
+ text_alpha: float = 0.7,
420
+ ground_truth: bool = False,
421
+ ) -> np.ndarray:
422
+ """
423
+ Draw labels and bounding boxes for one image. By default, predicted
424
+ labels are drawn in the top left corner of the image or corresponding
425
+ bounding boxes. For ground truth labels (setting True for ground_truth flag),
426
+ labels will be drawn in the bottom left corner.
427
+ Args:
428
+ frame (array-like): a tensor or numpy array of shape (H, W, C),
429
+ where H and W correspond to
430
+ the height and width of the image respectively. C is the number of
431
+ color channels. The image is required to be in RGB format since that
432
+ is a requirement of the Matplotlib library. The image is also expected
433
+ to be in the range [0, 255].
434
+ preds (tensor or list): If ground_truth is False, provide a float tensor of
435
+ shape (num_boxes, num_classes) that contains all of the confidence
436
+ scores of the model. For recognition task, input shape can be (num_classes,).
437
+ To plot true label (ground_truth is True), preds is a list contains int32
438
+ of the shape (num_boxes, true_class_ids) or (true_class_ids,).
439
+ bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
440
+ of the bounding boxes.
441
+ alpha (Optional[float]): transparency level of the bounding boxes.
442
+ text_alpha (Optional[float]): transparency level of the box wrapped around
443
+ text labels.
444
+ ground_truth (bool): whether the provided bounding boxes are ground-truth.
445
+ Returns:
446
+ An image with bounding box annotations and corresponding bbox
447
+ labels plotted on it.
448
+ """
449
+ if isinstance(preds, torch.Tensor):
450
+ if preds.ndim == 1:
451
+ preds = preds.unsqueeze(0)
452
+ n_instances = preds.shape[0]
453
+ elif isinstance(preds, list):
454
+ n_instances = len(preds)
455
+ else:
456
+ # logger.error("Unsupported type of prediction input.")
457
+ return
458
+
459
+ if ground_truth:
460
+ top_scores, top_classes = [None] * n_instances, preds
461
+
462
+ elif self.mode == "top-k":
463
+ top_scores, top_classes = torch.topk(preds, k=self.top_k)
464
+ top_scores, top_classes = top_scores.tolist(), top_classes.tolist()
465
+ elif self.mode == "thres":
466
+ top_scores, top_classes = [], []
467
+ for pred in preds:
468
+ mask = pred >= self.thres
469
+ top_scores.append(pred[mask].tolist())
470
+ top_class = torch.squeeze(torch.nonzero(mask), dim=-1).tolist()
471
+ top_classes.append(top_class)
472
+
473
+ # Create labels top k predicted classes with their scores.
474
+ text_labels = []
475
+ for i in range(n_instances):
476
+ text_labels.append(
477
+ _create_text_labels(
478
+ top_classes[i],
479
+ top_scores[i],
480
+ self.class_names,
481
+ ground_truth=ground_truth,
482
+ )
483
+ )
484
+ frame_visualizer = ImgVisualizer(frame, meta=None)
485
+ font_size = min(max(np.sqrt(frame.shape[0] * frame.shape[1]) // 25, 5), 9)
486
+ top_corner = not ground_truth
487
+ if bboxes is not None:
488
+ assert len(preds) == len(
489
+ bboxes
490
+ ), "Encounter {} predictions and {} bounding boxes".format(
491
+ len(preds), len(bboxes)
492
+ )
493
+ for i, box in enumerate(bboxes):
494
+ text = text_labels[i]
495
+ pred_class = top_classes[i]
496
+ colors = [self._get_color(pred) for pred in pred_class]
497
+
498
+ box_color = "r" if ground_truth else "g"
499
+ line_style = "--" if ground_truth else "-."
500
+ frame_visualizer.draw_box(
501
+ box,
502
+ alpha=alpha,
503
+ edge_color=box_color,
504
+ line_style=line_style,
505
+ )
506
+ frame_visualizer.draw_multiple_text(
507
+ text,
508
+ box,
509
+ top_corner=top_corner,
510
+ font_size=font_size,
511
+ box_facecolors=colors,
512
+ alpha=text_alpha,
513
+ )
514
+ else:
515
+ text = text_labels[0]
516
+ pred_class = top_classes[0]
517
+ colors = [self._get_color(pred) for pred in pred_class]
518
+ frame_visualizer.draw_multiple_text(
519
+ text,
520
+ torch.Tensor([0, 5, frame.shape[1], frame.shape[0] - 5]),
521
+ top_corner=top_corner,
522
+ font_size=font_size,
523
+ box_facecolors=colors,
524
+ alpha=text_alpha,
525
+ )
526
+
527
+ return frame_visualizer.output.get_image()
528
+
529
+ def draw_clip_range(
530
+ self,
531
+ frames: Union[torch.Tensor, np.ndarray],
532
+ preds: Union[torch.Tensor, List[float]],
533
+ bboxes: Optional[torch.Tensor] = None,
534
+ text_alpha: float = 0.5,
535
+ ground_truth: bool = False,
536
+ keyframe_idx: Optional[int] = None,
537
+ draw_range: Optional[List[int]] = None,
538
+ repeat_frame: int = 1,
539
+ ) -> List[np.ndarray]:
540
+ """
541
+ Draw predicted labels or ground truth classes to clip.
542
+ Draw bounding boxes to clip if bboxes is provided. Boxes will gradually
543
+ fade in and out the clip, centered around the clip's central frame,
544
+ within the provided `draw_range`.
545
+ Args:
546
+ frames (array-like): video data in the shape (T, H, W, C).
547
+ preds (tensor): a tensor of shape (num_boxes, num_classes) that
548
+ contains all of the confidence scores of the model. For recognition
549
+ task or for ground_truth labels, input shape can be (num_classes,).
550
+ bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
551
+ of the bounding boxes.
552
+ text_alpha (float): transparency label of the box wrapped around text labels.
553
+ ground_truth (bool): whether the provided bounding boxes are ground-truth.
554
+ keyframe_idx (int): the index of keyframe in the clip.
555
+ draw_range (Optional[list[ints]): only draw frames in range
556
+ [start_idx, end_idx] inclusively in the clip. If None, draw on
557
+ the entire clip.
558
+ repeat_frame (int): repeat each frame in draw_range for `repeat_frame`
559
+ time for slow-motion effect.
560
+ Returns:
561
+ A list of frames with bounding box annotations and corresponding
562
+ bbox labels plotted on them.
563
+ """
564
+ if draw_range is None:
565
+ draw_range = [0, len(frames) - 1]
566
+ if draw_range is not None:
567
+ draw_range[0] = max(0, draw_range[0])
568
+ left_frames = frames[: draw_range[0]]
569
+ right_frames = frames[draw_range[1] + 1 :]
570
+
571
+ draw_frames = frames[draw_range[0] : draw_range[1] + 1]
572
+ if keyframe_idx is None:
573
+ keyframe_idx = len(frames) // 2
574
+
575
+ img_ls = (
576
+ list(left_frames)
577
+ + self.draw_clip(
578
+ draw_frames,
579
+ preds,
580
+ bboxes=bboxes,
581
+ text_alpha=text_alpha,
582
+ ground_truth=ground_truth,
583
+ keyframe_idx=keyframe_idx - draw_range[0],
584
+ repeat_frame=repeat_frame,
585
+ )
586
+ + list(right_frames)
587
+ )
588
+
589
+ return img_ls
590
+
591
+ def draw_clip(
592
+ self,
593
+ frames: Union[torch.Tensor, np.ndarray],
594
+ preds: Union[torch.Tensor, List[float]],
595
+ bboxes: Optional[torch.Tensor] = None,
596
+ text_alpha: float = 0.5,
597
+ ground_truth: bool = False,
598
+ keyframe_idx: Optional[int] = None,
599
+ repeat_frame: int = 1,
600
+ ) -> List[np.ndarray]:
601
+ """
602
+ Draw predicted labels or ground truth classes to clip. Draw bounding boxes to clip
603
+ if bboxes is provided. Boxes will gradually fade in and out the clip, centered
604
+ around the clip's central frame.
605
+ Args:
606
+ frames (array-like): video data in the shape (T, H, W, C).
607
+ preds (tensor): a tensor of shape (num_boxes, num_classes) that contains
608
+ all of the confidence scores of the model. For recognition task or for
609
+ ground_truth labels, input shape can be (num_classes,).
610
+ bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
611
+ of the bounding boxes.
612
+ text_alpha (float): transparency label of the box wrapped around text labels.
613
+ ground_truth (bool): whether the provided bounding boxes are ground-truth.
614
+ keyframe_idx (int): the index of keyframe in the clip.
615
+ repeat_frame (int): repeat each frame in draw_range for `repeat_frame`
616
+ time for slow-motion effect.
617
+ Returns:
618
+ A list of frames with bounding box annotations and corresponding
619
+ bbox labels plotted on them.
620
+ """
621
+ assert repeat_frame >= 1, "`repeat_frame` must be a positive integer."
622
+
623
+ repeated_seq = range(0, len(frames))
624
+ repeated_seq = list(
625
+ itertools.chain.from_iterable(
626
+ itertools.repeat(x, repeat_frame) for x in repeated_seq
627
+ )
628
+ )
629
+
630
+ frames, adjusted = self._adjust_frames_type(frames)
631
+ if keyframe_idx is None:
632
+ half_left = len(repeated_seq) // 2
633
+ half_right = (len(repeated_seq) + 1) // 2
634
+ else:
635
+ mid = int((keyframe_idx / len(frames)) * len(repeated_seq))
636
+ half_left = mid
637
+ half_right = len(repeated_seq) - mid
638
+
639
+ alpha_ls = np.concatenate(
640
+ [
641
+ np.linspace(0, 1, num=half_left),
642
+ np.linspace(1, 0, num=half_right),
643
+ ]
644
+ )
645
+ text_alpha = text_alpha
646
+ frames = frames[repeated_seq]
647
+ img_ls = []
648
+ for alpha, frame in zip(alpha_ls, frames):
649
+ draw_img = self.draw_one_frame(
650
+ frame,
651
+ preds,
652
+ bboxes,
653
+ alpha=alpha,
654
+ text_alpha=text_alpha,
655
+ ground_truth=ground_truth,
656
+ )
657
+ if adjusted:
658
+ draw_img = draw_img.astype("float32") / 255
659
+
660
+ img_ls.append(draw_img)
661
+
662
+ return img_ls
663
+
664
+ def _adjust_frames_type(
665
+ self, frames: torch.Tensor
666
+ ) -> Tuple[List[np.ndarray], bool]:
667
+ """
668
+ Modify video data to have dtype of uint8 and values range in [0, 255].
669
+ Args:
670
+ frames (array-like): 4D array of shape (T, H, W, C).
671
+ Returns:
672
+ frames (list of frames): list of frames in range [0, 1].
673
+ adjusted (bool): whether the original frames needed to be adjusted.
674
+ """
675
+ assert (
676
+ frames is not None and len(frames) != 0
677
+ ), "Frames do not contain any values"
678
+ frames = np.array(frames)
679
+ assert np.array(frames).ndim == 4, "Frames must have 4 dimensions"
680
+ adjusted = False
681
+ if frames.dtype in [np.float32, np.float64]:
682
+ frames *= 255
683
+ frames = frames.astype(np.uint8)
684
+ adjusted = True
685
+
686
+ return frames, adjusted
687
+
688
+ def _get_thres_array(self, common_class_names: Optional[List[str]] = None) -> None:
689
+ """
690
+ Compute a thresholds array for all classes based on `self.thres` and `self.lower_thres`.
691
+ Args:
692
+ common_class_names (Optional[list of str]): a list of common class names.
693
+ """
694
+ common_class_ids = []
695
+ if common_class_names is not None:
696
+ common_classes = set(common_class_names)
697
+
698
+ for key, name in self.class_names.items():
699
+ if name in common_classes:
700
+ common_class_ids.append(key)
701
+ else:
702
+ common_class_ids = list(range(self.num_classes))
703
+
704
+ thres_array = np.full(shape=(self.num_classes,), fill_value=self.lower_thres)
705
+ thres_array[common_class_ids] = self.thres
706
+ self.thres = torch.from_numpy(thres_array)
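
The fade-in/fade-out effect in `draw_clip` above is driven entirely by the alpha ramp built around the keyframe: each frame index is repeated `repeat_frame` times, and box opacity rises linearly from 0 to 1 up to the keyframe position, then falls back to 0. The standalone sketch below reproduces just that schedule so the behaviour can be checked in isolation; the clip length, keyframe index, and repeat factor are illustrative values only, not taken from this repository.

import itertools
import numpy as np

def fade_alpha_schedule(num_frames: int, keyframe_idx: int, repeat_frame: int = 1) -> np.ndarray:
    """Per-frame box opacity, mirroring the keyframe branch of draw_clip."""
    # Repeat each frame index `repeat_frame` times (slow-motion effect).
    repeated_seq = list(
        itertools.chain.from_iterable(
            itertools.repeat(i, repeat_frame) for i in range(num_frames)
        )
    )
    # Split the (repeated) clip at the keyframe position.
    mid = int((keyframe_idx / num_frames) * len(repeated_seq))
    half_left, half_right = mid, len(repeated_seq) - mid
    # Opacity ramps up to the keyframe, then back down to zero.
    return np.concatenate(
        [np.linspace(0, 1, num=half_left), np.linspace(1, 0, num=half_right)]
    )

# Illustrative values only: a 32-frame clip, keyframe in the middle, 2x slow motion.
alphas = fade_alpha_schedule(num_frames=32, keyframe_idx=16, repeat_frame=2)
print(alphas.shape)  # (64,)
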
yolov3.cfg ADDED
@@ -0,0 +1,789 @@
1
+ [net]
2
+ # Testing
3
+ # batch=1
4
+ # subdivisions=1
5
+ # Training
6
+ batch=64
7
+ subdivisions=16
8
+ width=608
9
+ height=608
10
+ channels=3
11
+ momentum=0.9
12
+ decay=0.0005
13
+ angle=0
14
+ saturation = 1.5
15
+ exposure = 1.5
16
+ hue=.1
17
+
18
+ learning_rate=0.001
19
+ burn_in=1000
20
+ max_batches = 500200
21
+ policy=steps
22
+ steps=400000,450000
23
+ scales=.1,.1
24
+
25
+ [convolutional]
26
+ batch_normalize=1
27
+ filters=32
28
+ size=3
29
+ stride=1
30
+ pad=1
31
+ activation=leaky
32
+
33
+ # Downsample
34
+
35
+ [convolutional]
36
+ batch_normalize=1
37
+ filters=64
38
+ size=3
39
+ stride=2
40
+ pad=1
41
+ activation=leaky
42
+
43
+ [convolutional]
44
+ batch_normalize=1
45
+ filters=32
46
+ size=1
47
+ stride=1
48
+ pad=1
49
+ activation=leaky
50
+
51
+ [convolutional]
52
+ batch_normalize=1
53
+ filters=64
54
+ size=3
55
+ stride=1
56
+ pad=1
57
+ activation=leaky
58
+
59
+ [shortcut]
60
+ from=-3
61
+ activation=linear
62
+
63
+ # Downsample
64
+
65
+ [convolutional]
66
+ batch_normalize=1
67
+ filters=128
68
+ size=3
69
+ stride=2
70
+ pad=1
71
+ activation=leaky
72
+
73
+ [convolutional]
74
+ batch_normalize=1
75
+ filters=64
76
+ size=1
77
+ stride=1
78
+ pad=1
79
+ activation=leaky
80
+
81
+ [convolutional]
82
+ batch_normalize=1
83
+ filters=128
84
+ size=3
85
+ stride=1
86
+ pad=1
87
+ activation=leaky
88
+
89
+ [shortcut]
90
+ from=-3
91
+ activation=linear
92
+
93
+ [convolutional]
94
+ batch_normalize=1
95
+ filters=64
96
+ size=1
97
+ stride=1
98
+ pad=1
99
+ activation=leaky
100
+
101
+ [convolutional]
102
+ batch_normalize=1
103
+ filters=128
104
+ size=3
105
+ stride=1
106
+ pad=1
107
+ activation=leaky
108
+
109
+ [shortcut]
110
+ from=-3
111
+ activation=linear
112
+
113
+ # Downsample
114
+
115
+ [convolutional]
116
+ batch_normalize=1
117
+ filters=256
118
+ size=3
119
+ stride=2
120
+ pad=1
121
+ activation=leaky
122
+
123
+ [convolutional]
124
+ batch_normalize=1
125
+ filters=128
126
+ size=1
127
+ stride=1
128
+ pad=1
129
+ activation=leaky
130
+
131
+ [convolutional]
132
+ batch_normalize=1
133
+ filters=256
134
+ size=3
135
+ stride=1
136
+ pad=1
137
+ activation=leaky
138
+
139
+ [shortcut]
140
+ from=-3
141
+ activation=linear
142
+
143
+ [convolutional]
144
+ batch_normalize=1
145
+ filters=128
146
+ size=1
147
+ stride=1
148
+ pad=1
149
+ activation=leaky
150
+
151
+ [convolutional]
152
+ batch_normalize=1
153
+ filters=256
154
+ size=3
155
+ stride=1
156
+ pad=1
157
+ activation=leaky
158
+
159
+ [shortcut]
160
+ from=-3
161
+ activation=linear
162
+
163
+ [convolutional]
164
+ batch_normalize=1
165
+ filters=128
166
+ size=1
167
+ stride=1
168
+ pad=1
169
+ activation=leaky
170
+
171
+ [convolutional]
172
+ batch_normalize=1
173
+ filters=256
174
+ size=3
175
+ stride=1
176
+ pad=1
177
+ activation=leaky
178
+
179
+ [shortcut]
180
+ from=-3
181
+ activation=linear
182
+
183
+ [convolutional]
184
+ batch_normalize=1
185
+ filters=128
186
+ size=1
187
+ stride=1
188
+ pad=1
189
+ activation=leaky
190
+
191
+ [convolutional]
192
+ batch_normalize=1
193
+ filters=256
194
+ size=3
195
+ stride=1
196
+ pad=1
197
+ activation=leaky
198
+
199
+ [shortcut]
200
+ from=-3
201
+ activation=linear
202
+
203
+
204
+ [convolutional]
205
+ batch_normalize=1
206
+ filters=128
207
+ size=1
208
+ stride=1
209
+ pad=1
210
+ activation=leaky
211
+
212
+ [convolutional]
213
+ batch_normalize=1
214
+ filters=256
215
+ size=3
216
+ stride=1
217
+ pad=1
218
+ activation=leaky
219
+
220
+ [shortcut]
221
+ from=-3
222
+ activation=linear
223
+
224
+ [convolutional]
225
+ batch_normalize=1
226
+ filters=128
227
+ size=1
228
+ stride=1
229
+ pad=1
230
+ activation=leaky
231
+
232
+ [convolutional]
233
+ batch_normalize=1
234
+ filters=256
235
+ size=3
236
+ stride=1
237
+ pad=1
238
+ activation=leaky
239
+
240
+ [shortcut]
241
+ from=-3
242
+ activation=linear
243
+
244
+ [convolutional]
245
+ batch_normalize=1
246
+ filters=128
247
+ size=1
248
+ stride=1
249
+ pad=1
250
+ activation=leaky
251
+
252
+ [convolutional]
253
+ batch_normalize=1
254
+ filters=256
255
+ size=3
256
+ stride=1
257
+ pad=1
258
+ activation=leaky
259
+
260
+ [shortcut]
261
+ from=-3
262
+ activation=linear
263
+
264
+ [convolutional]
265
+ batch_normalize=1
266
+ filters=128
267
+ size=1
268
+ stride=1
269
+ pad=1
270
+ activation=leaky
271
+
272
+ [convolutional]
273
+ batch_normalize=1
274
+ filters=256
275
+ size=3
276
+ stride=1
277
+ pad=1
278
+ activation=leaky
279
+
280
+ [shortcut]
281
+ from=-3
282
+ activation=linear
283
+
284
+ # Downsample
285
+
286
+ [convolutional]
287
+ batch_normalize=1
288
+ filters=512
289
+ size=3
290
+ stride=2
291
+ pad=1
292
+ activation=leaky
293
+
294
+ [convolutional]
295
+ batch_normalize=1
296
+ filters=256
297
+ size=1
298
+ stride=1
299
+ pad=1
300
+ activation=leaky
301
+
302
+ [convolutional]
303
+ batch_normalize=1
304
+ filters=512
305
+ size=3
306
+ stride=1
307
+ pad=1
308
+ activation=leaky
309
+
310
+ [shortcut]
311
+ from=-3
312
+ activation=linear
313
+
314
+
315
+ [convolutional]
316
+ batch_normalize=1
317
+ filters=256
318
+ size=1
319
+ stride=1
320
+ pad=1
321
+ activation=leaky
322
+
323
+ [convolutional]
324
+ batch_normalize=1
325
+ filters=512
326
+ size=3
327
+ stride=1
328
+ pad=1
329
+ activation=leaky
330
+
331
+ [shortcut]
332
+ from=-3
333
+ activation=linear
334
+
335
+
336
+ [convolutional]
337
+ batch_normalize=1
338
+ filters=256
339
+ size=1
340
+ stride=1
341
+ pad=1
342
+ activation=leaky
343
+
344
+ [convolutional]
345
+ batch_normalize=1
346
+ filters=512
347
+ size=3
348
+ stride=1
349
+ pad=1
350
+ activation=leaky
351
+
352
+ [shortcut]
353
+ from=-3
354
+ activation=linear
355
+
356
+
357
+ [convolutional]
358
+ batch_normalize=1
359
+ filters=256
360
+ size=1
361
+ stride=1
362
+ pad=1
363
+ activation=leaky
364
+
365
+ [convolutional]
366
+ batch_normalize=1
367
+ filters=512
368
+ size=3
369
+ stride=1
370
+ pad=1
371
+ activation=leaky
372
+
373
+ [shortcut]
374
+ from=-3
375
+ activation=linear
376
+
377
+ [convolutional]
378
+ batch_normalize=1
379
+ filters=256
380
+ size=1
381
+ stride=1
382
+ pad=1
383
+ activation=leaky
384
+
385
+ [convolutional]
386
+ batch_normalize=1
387
+ filters=512
388
+ size=3
389
+ stride=1
390
+ pad=1
391
+ activation=leaky
392
+
393
+ [shortcut]
394
+ from=-3
395
+ activation=linear
396
+
397
+
398
+ [convolutional]
399
+ batch_normalize=1
400
+ filters=256
401
+ size=1
402
+ stride=1
403
+ pad=1
404
+ activation=leaky
405
+
406
+ [convolutional]
407
+ batch_normalize=1
408
+ filters=512
409
+ size=3
410
+ stride=1
411
+ pad=1
412
+ activation=leaky
413
+
414
+ [shortcut]
415
+ from=-3
416
+ activation=linear
417
+
418
+
419
+ [convolutional]
420
+ batch_normalize=1
421
+ filters=256
422
+ size=1
423
+ stride=1
424
+ pad=1
425
+ activation=leaky
426
+
427
+ [convolutional]
428
+ batch_normalize=1
429
+ filters=512
430
+ size=3
431
+ stride=1
432
+ pad=1
433
+ activation=leaky
434
+
435
+ [shortcut]
436
+ from=-3
437
+ activation=linear
438
+
439
+ [convolutional]
440
+ batch_normalize=1
441
+ filters=256
442
+ size=1
443
+ stride=1
444
+ pad=1
445
+ activation=leaky
446
+
447
+ [convolutional]
448
+ batch_normalize=1
449
+ filters=512
450
+ size=3
451
+ stride=1
452
+ pad=1
453
+ activation=leaky
454
+
455
+ [shortcut]
456
+ from=-3
457
+ activation=linear
458
+
459
+ # Downsample
460
+
461
+ [convolutional]
462
+ batch_normalize=1
463
+ filters=1024
464
+ size=3
465
+ stride=2
466
+ pad=1
467
+ activation=leaky
468
+
469
+ [convolutional]
470
+ batch_normalize=1
471
+ filters=512
472
+ size=1
473
+ stride=1
474
+ pad=1
475
+ activation=leaky
476
+
477
+ [convolutional]
478
+ batch_normalize=1
479
+ filters=1024
480
+ size=3
481
+ stride=1
482
+ pad=1
483
+ activation=leaky
484
+
485
+ [shortcut]
486
+ from=-3
487
+ activation=linear
488
+
489
+ [convolutional]
490
+ batch_normalize=1
491
+ filters=512
492
+ size=1
493
+ stride=1
494
+ pad=1
495
+ activation=leaky
496
+
497
+ [convolutional]
498
+ batch_normalize=1
499
+ filters=1024
500
+ size=3
501
+ stride=1
502
+ pad=1
503
+ activation=leaky
504
+
505
+ [shortcut]
506
+ from=-3
507
+ activation=linear
508
+
509
+ [convolutional]
510
+ batch_normalize=1
511
+ filters=512
512
+ size=1
513
+ stride=1
514
+ pad=1
515
+ activation=leaky
516
+
517
+ [convolutional]
518
+ batch_normalize=1
519
+ filters=1024
520
+ size=3
521
+ stride=1
522
+ pad=1
523
+ activation=leaky
524
+
525
+ [shortcut]
526
+ from=-3
527
+ activation=linear
528
+
529
+ [convolutional]
530
+ batch_normalize=1
531
+ filters=512
532
+ size=1
533
+ stride=1
534
+ pad=1
535
+ activation=leaky
536
+
537
+ [convolutional]
538
+ batch_normalize=1
539
+ filters=1024
540
+ size=3
541
+ stride=1
542
+ pad=1
543
+ activation=leaky
544
+
545
+ [shortcut]
546
+ from=-3
547
+ activation=linear
548
+
549
+ ######################
550
+
551
+ [convolutional]
552
+ batch_normalize=1
553
+ filters=512
554
+ size=1
555
+ stride=1
556
+ pad=1
557
+ activation=leaky
558
+
559
+ [convolutional]
560
+ batch_normalize=1
561
+ size=3
562
+ stride=1
563
+ pad=1
564
+ filters=1024
565
+ activation=leaky
566
+
567
+ [convolutional]
568
+ batch_normalize=1
569
+ filters=512
570
+ size=1
571
+ stride=1
572
+ pad=1
573
+ activation=leaky
574
+
575
+ [convolutional]
576
+ batch_normalize=1
577
+ size=3
578
+ stride=1
579
+ pad=1
580
+ filters=1024
581
+ activation=leaky
582
+
583
+ [convolutional]
584
+ batch_normalize=1
585
+ filters=512
586
+ size=1
587
+ stride=1
588
+ pad=1
589
+ activation=leaky
590
+
591
+ [convolutional]
592
+ batch_normalize=1
593
+ size=3
594
+ stride=1
595
+ pad=1
596
+ filters=1024
597
+ activation=leaky
598
+
599
+ [convolutional]
600
+ size=1
601
+ stride=1
602
+ pad=1
603
+ filters=255
604
+ activation=linear
605
+
606
+
607
+ [yolo]
608
+ mask = 6,7,8
609
+ anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
610
+ classes=80
611
+ num=9
612
+ jitter=.3
613
+ ignore_thresh = .7
614
+ truth_thresh = 1
615
+ random=1
616
+
617
+
618
+ [route]
619
+ layers = -4
620
+
621
+ [convolutional]
622
+ batch_normalize=1
623
+ filters=256
624
+ size=1
625
+ stride=1
626
+ pad=1
627
+ activation=leaky
628
+
629
+ [upsample]
630
+ stride=2
631
+
632
+ [route]
633
+ layers = -1, 61
634
+
635
+
636
+
637
+ [convolutional]
638
+ batch_normalize=1
639
+ filters=256
640
+ size=1
641
+ stride=1
642
+ pad=1
643
+ activation=leaky
644
+
645
+ [convolutional]
646
+ batch_normalize=1
647
+ size=3
648
+ stride=1
649
+ pad=1
650
+ filters=512
651
+ activation=leaky
652
+
653
+ [convolutional]
654
+ batch_normalize=1
655
+ filters=256
656
+ size=1
657
+ stride=1
658
+ pad=1
659
+ activation=leaky
660
+
661
+ [convolutional]
662
+ batch_normalize=1
663
+ size=3
664
+ stride=1
665
+ pad=1
666
+ filters=512
667
+ activation=leaky
668
+
669
+ [convolutional]
670
+ batch_normalize=1
671
+ filters=256
672
+ size=1
673
+ stride=1
674
+ pad=1
675
+ activation=leaky
676
+
677
+ [convolutional]
678
+ batch_normalize=1
679
+ size=3
680
+ stride=1
681
+ pad=1
682
+ filters=512
683
+ activation=leaky
684
+
685
+ [convolutional]
686
+ size=1
687
+ stride=1
688
+ pad=1
689
+ filters=255
690
+ activation=linear
691
+
692
+
693
+ [yolo]
694
+ mask = 3,4,5
695
+ anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
696
+ classes=80
697
+ num=9
698
+ jitter=.3
699
+ ignore_thresh = .7
700
+ truth_thresh = 1
701
+ random=1
702
+
703
+
704
+
705
+ [route]
706
+ layers = -4
707
+
708
+ [convolutional]
709
+ batch_normalize=1
710
+ filters=128
711
+ size=1
712
+ stride=1
713
+ pad=1
714
+ activation=leaky
715
+
716
+ [upsample]
717
+ stride=2
718
+
719
+ [route]
720
+ layers = -1, 36
721
+
722
+
723
+
724
+ [convolutional]
725
+ batch_normalize=1
726
+ filters=128
727
+ size=1
728
+ stride=1
729
+ pad=1
730
+ activation=leaky
731
+
732
+ [convolutional]
733
+ batch_normalize=1
734
+ size=3
735
+ stride=1
736
+ pad=1
737
+ filters=256
738
+ activation=leaky
739
+
740
+ [convolutional]
741
+ batch_normalize=1
742
+ filters=128
743
+ size=1
744
+ stride=1
745
+ pad=1
746
+ activation=leaky
747
+
748
+ [convolutional]
749
+ batch_normalize=1
750
+ size=3
751
+ stride=1
752
+ pad=1
753
+ filters=256
754
+ activation=leaky
755
+
756
+ [convolutional]
757
+ batch_normalize=1
758
+ filters=128
759
+ size=1
760
+ stride=1
761
+ pad=1
762
+ activation=leaky
763
+
764
+ [convolutional]
765
+ batch_normalize=1
766
+ size=3
767
+ stride=1
768
+ pad=1
769
+ filters=256
770
+ activation=leaky
771
+
772
+ [convolutional]
773
+ size=1
774
+ stride=1
775
+ pad=1
776
+ filters=255
777
+ activation=linear
778
+
779
+
780
+ [yolo]
781
+ mask = 0,1,2
782
+ anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
783
+ classes=80
784
+ num=9
785
+ jitter=.3
786
+ ignore_thresh = .7
787
+ truth_thresh = 1
788
+ random=1
789
+
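
yolov3.cfg follows the standard Darknet format: bracketed section headers such as [net], [convolutional], [shortcut], [route], [upsample], and [yolo], each followed by key=value options, with # lines treated as comments. The [net] block holds the global input and training hyperparameters (608x608x3 input, batch 64, and so on), and the three [yolo] blocks are the detection heads at the three output scales. The sketch below is a minimal illustration of how such a file can be read into layer blocks; it is not the loader used by the object-extraction code in this Space, and the path and printed values are illustrative.

from typing import Dict, List

def parse_darknet_cfg(path: str) -> List[Dict[str, str]]:
    """Parse a Darknet .cfg file into a list of blocks.

    Each block is a dict with a '_type' key ('net', 'convolutional', 'yolo', ...)
    plus the raw key=value options that follow the section header.
    """
    blocks: List[Dict[str, str]] = []
    with open(path) as f:
        for raw in f:
            line = raw.strip()
            if not line or line.startswith("#"):
                continue  # skip blank lines and comments
            if line.startswith("[") and line.endswith("]"):
                blocks.append({"_type": line[1:-1]})
            else:
                key, _, value = line.partition("=")
                blocks[-1][key.strip()] = value.strip()
    return blocks

# Example usage (path is illustrative):
blocks = parse_darknet_cfg("yolov3.cfg")
print(blocks[0]["_type"], blocks[0]["width"], blocks[0]["height"])  # net 608 608
print(sum(b["_type"] == "yolo" for b in blocks))                    # 3 detection heads
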