Commit e215925 · Bhaskar Saranga committed
Parent(s): 5f086ec

Added tracker

This view is limited to 50 files because the commit contains too many changes.
- app.py +14 -3
- requirements.txt +48 -2
- runs/track/test.txt +0 -0
- track.py +397 -0
- trackers/__init__.py +0 -0
- trackers/bytetrack/basetrack.py +52 -0
- trackers/bytetrack/byte_tracker.py +353 -0
- trackers/bytetrack/configs/bytetrack.yaml +7 -0
- trackers/bytetrack/kalman_filter.py +270 -0
- trackers/bytetrack/matching.py +219 -0
- trackers/multi_tracker_zoo.py +52 -0
- trackers/ocsort/association.py +377 -0
- trackers/ocsort/configs/ocsort.yaml +12 -0
- trackers/ocsort/kalmanfilter.py +1581 -0
- trackers/ocsort/ocsort.py +328 -0
- trackers/reid_export.py +313 -0
- trackers/strongsort/.gitignore +13 -0
- trackers/strongsort/__init__.py +0 -0
- trackers/strongsort/configs/strongsort.yaml +11 -0
- trackers/strongsort/deep/checkpoint/.gitkeep +0 -0
- trackers/strongsort/deep/checkpoint/osnet_x0_25_market1501.pth +3 -0
- trackers/strongsort/deep/checkpoint/osnet_x0_25_msmt17.pth +3 -0
- trackers/strongsort/deep/checkpoint/osnet_x1_0_msmt17.pth +3 -0
- trackers/strongsort/deep/models/__init__.py +122 -0
- trackers/strongsort/deep/models/densenet.py +380 -0
- trackers/strongsort/deep/models/hacnn.py +414 -0
- trackers/strongsort/deep/models/inceptionresnetv2.py +361 -0
- trackers/strongsort/deep/models/inceptionv4.py +381 -0
- trackers/strongsort/deep/models/mlfn.py +269 -0
- trackers/strongsort/deep/models/mobilenetv2.py +274 -0
- trackers/strongsort/deep/models/mudeep.py +206 -0
- trackers/strongsort/deep/models/nasnet.py +1131 -0
- trackers/strongsort/deep/models/osnet.py +598 -0
- trackers/strongsort/deep/models/osnet_ain.py +609 -0
- trackers/strongsort/deep/models/pcb.py +314 -0
- trackers/strongsort/deep/models/resnet.py +530 -0
- trackers/strongsort/deep/models/resnet_ibn_a.py +289 -0
- trackers/strongsort/deep/models/resnet_ibn_b.py +274 -0
- trackers/strongsort/deep/models/resnetmid.py +307 -0
- trackers/strongsort/deep/models/senet.py +688 -0
- trackers/strongsort/deep/models/shufflenet.py +198 -0
- trackers/strongsort/deep/models/shufflenetv2.py +262 -0
- trackers/strongsort/deep/models/squeezenet.py +236 -0
- trackers/strongsort/deep/models/xception.py +344 -0
- trackers/strongsort/deep/reid_model_factory.py +215 -0
- trackers/strongsort/reid_multibackend.py +237 -0
- trackers/strongsort/sort/__init__.py +0 -0
- trackers/strongsort/sort/detection.py +58 -0
- trackers/strongsort/sort/iou_matching.py +82 -0
- trackers/strongsort/sort/kalman_filter.py +214 -0
app.py
CHANGED
@@ -11,6 +11,7 @@ from utils.plots import plot_one_box
 from utils.torch_utils import time_synchronized
 import time
 from ultralytics import YOLO
+from track import MOT
 
 def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleup=True, stride=32):
     # Resize and pad image while meeting stride-multiple constraints
@@ -178,6 +179,10 @@ def inference_comp(image,iou_threshold,confidence_threshold):
     v7_out, v7_fps = inference(image, "yolov7",iou_threshold,confidence_threshold)
     return v7_out,v8_out,v7_fps,v8_fps
 
+def MODT(sourceVideo, model_link, trackingmethod):
+    model_path = 'weights/'+str(model_link)+'.pt'
+    return MOT(model_path, trackingmethod, sourceVideo), 30
+
 examples_images = ['data/images/1.jpg',
                    'data/images/2.jpg',
                    'data/images/bus.jpg',
@@ -185,6 +190,7 @@ examples_images = ['data/images/1.jpg',
 examples_videos = ['data/video/1.mp4','data/video/2.mp4']
 
 models = ['yolov8m','yolov7','yolov7t']
+trackers = ['strongsort', 'bytetrack', 'ocsort']
 
 with gr.Blocks() as demo:
     gr.Markdown("## IDD Inference on Yolo V7 and V8 ")
@@ -205,11 +211,14 @@ with gr.Blocks() as demo:
             video_input = gr.Video(type='pil', label="Input Video", source="upload")
             video_output = gr.Video(type="pil", label="Output Video",format="mp4")
             fps_video = gr.Number(0,label='FPS')
-            video_drop = gr.Dropdown(choices=models,value=models[0])
+            video_drop = gr.Dropdown(label="Model", choices=models,value=models[0])
+            tracking_drop = gr.Dropdown(label="Tracker", choices=trackers,value=trackers[0])
             video_iou_threshold = gr.Slider(label="IOU Threshold",interactive=True, minimum=0.0, maximum=1.0, value=0.45)
             video_conf_threshold = gr.Slider(label="Confidence Threshold",interactive=True, minimum=0.0, maximum=1.0, value=0.25)
             gr.Examples(examples=examples_videos,inputs=video_input,outputs=video_output)
-
+            with gr.Row():
+                video_button_detect = gr.Button("Detect")
+                video_button_track = gr.Button("Track")
 
     with gr.Tab("Compare Models"):
         gr.Markdown("## YOLOv7 vs YOLOv8 Object detection comparision")
@@ -231,12 +240,14 @@ with gr.Blocks() as demo:
     text_button.click(inference, inputs=[image_input,image_drop,
                       image_iou_threshold,image_conf_threshold],
                       outputs=[image_output,fps_image])
-
+    video_button_detect.click(inference2, inputs=[video_input,video_drop,
                       video_iou_threshold,video_conf_threshold],
                       outputs=[video_output,fps_video])
    text_comp_button.click(inference_comp,inputs=[image_comp_input,
                           image_comp_iou_threshold,
                           image_comp_conf_threshold],
                           outputs=[image_comp_output_v7,image_comp_output_v8,v7_fps_image,v8_fps_image])
+    video_button_track.click(MODT,inputs=[video_input,video_drop, tracking_drop],
+                             outputs=[video_output, fps_video])
 
 demo.launch(debug=True,enable_queue=True)
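For reference, the new Track button wires video_input, video_drop and tracking_drop into MODT, which in turn calls MOT from track.py. A minimal sketch of exercising the same path outside Gradio (illustrative only, not part of the commit; it assumes the listed weights file and sample clip exist in the Space):

# illustrative direct call; 'weights/yolov8m.pt' and the sample video are assumed to be present
from track import MOT
result_path = MOT('weights/yolov8m.pt', 'bytetrack', 'data/video/1.mp4')
print(result_path)  # MOT returns the path of the re-encoded result, e.g. 'output.mp4'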
requirements.txt
CHANGED
@@ -1,6 +1,18 @@
-
+# pip install -r requirements.txt
+
+# Base ----------------------------------------
+gitpython
+ipython  # interactive notebook
+matplotlib>=3.2.2
+numpy==1.23.1
 opencv-python>=4.1.1
 torch>=1.7.0,!=1.12.0
+Pillow>=7.1.2
+psutil  # system resources
+PyYAML>=5.3.1
+requests>=2.23.0
+scipy>=1.4.1
+thop>=0.1.1  # FLOPs computation
 torchvision>=0.8.1,!=0.13.0
 gradio>=3.9.1
 tqdm>=4.64.0
@@ -8,4 +20,38 @@ seaborn>=0.11.0
 scipy>=1.4.1
 Pillow>=7.1.2
 huggingface-hub >= 0.11.0
-ultralytics >=8.0.34
+ultralytics >=8.0.34
+
+# Logging ---------------------------------------------------------------------
+tensorboard>=2.4.1
+# clearml>=1.2.0
+# comet
+
+# Plotting --------------------------------------------------------------------
+pandas>=1.1.4
+seaborn>=0.11.0
+
+# StrongSORT ------------------------------------------------------------------
+easydict
+
+# torchreid -------------------------------------------------------------------
+gdown
+
+# ByteTrack -------------------------------------------------------------------
+lap
+
+# OCSORT ----------------------------------------------------------------------
+filterpy
+
+# Export ----------------------------------------------------------------------
+# onnx>=1.9.0  # ONNX export
+# onnx-simplifier>=0.4.1  # ONNX simplifier
+# nvidia-pyindex  # TensorRT export
+# nvidia-tensorrt  # TensorRT export
+# openvino-dev  # OpenVINO export
+
+# Hyperparam search -----------------------------------------------------------
+# optuna
+# plotly  # for hp importance and pareto front plots
+# kaleido
+# joblib
runs/track/test.txt
ADDED
File without changes
track.py
ADDED
@@ -0,0 +1,397 @@
import argparse
import cv2
import os
# limit the number of cpus used by high performance libraries
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

import sys
import platform
import numpy as np
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # yolov5 strongsort root directory
WEIGHTS = ROOT / 'weights'

if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if str(ROOT / 'yolov8') not in sys.path:
    sys.path.append(str(ROOT / 'yolov8'))  # add yolov5 ROOT to PATH
if str(ROOT / 'trackers' / 'strongsort') not in sys.path:
    sys.path.append(str(ROOT / 'trackers' / 'strongsort'))  # add strong_sort ROOT to PATH

ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

import logging
#from yolov8.ultralytics.nn.autobackend import AutoBackend
from ultralytics.nn.autobackend import AutoBackend
#from yolov8.ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadStreams
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadStreams
#from yolov8.ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
#from yolov8.ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops

#from yolov8.ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow, print_args, check_requirements
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow, print_args, check_requirements
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device
from ultralytics.yolo.utils.ops import Profile, non_max_suppression, scale_boxes, process_mask, process_mask_native
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box

from trackers.multi_tracker_zoo import create_tracker


@torch.no_grad()
def run(
        source='0',
        yolo_weights=WEIGHTS / 'yolov5m.pt',  # model.pt path(s),
        reid_weights=WEIGHTS / 'osnet_x0_25_msmt17.pt',  # model.pt path,
        tracking_method='strongsort',
        tracking_config=None,
        imgsz=(640, 640),  # inference size (height, width)
        conf_thres=0.25,  # confidence threshold
        iou_thres=0.45,  # NMS IOU threshold
        max_det=1000,  # maximum detections per image
        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        show_vid=False,  # show results
        save_txt=False,  # save results to *.txt
        save_conf=False,  # save confidences in --save-txt labels
        save_crop=False,  # save cropped prediction boxes
        save_trajectories=False,  # save trajectories for each track
        save_vid=True,  # save confidences in --save-txt labels
        nosave=False,  # do not save images/videos
        classes=None,  # filter by class: --class 0, or --class 0 2 3
        agnostic_nms=False,  # class-agnostic NMS
        augment=False,  # augmented inference
        visualize=False,  # visualize features
        update=False,  # update all models
        #project=ROOT / 'runs' / 'track',  # save results to project/name
        project=ROOT,  # save results to project/name
        name='exp',  # save results to project/name
        exist_ok=True,  # existing project/name ok, do not increment
        line_thickness=2,  # bounding box thickness (pixels)
        hide_labels=False,  # hide labels
        hide_conf=False,  # hide confidences
        hide_class=False,  # hide IDs
        half=False,  # use FP16 half-precision inference
        dnn=False,  # use OpenCV DNN for ONNX inference
        vid_stride=1,  # video frame-rate stride
        retina_masks=False,
):

    source = str(source)
    save_img = not nosave and not source.endswith('.txt')  # save inference images
    is_file = Path(source).suffix[1:] in (VID_FORMATS)
    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
    webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
    if is_url and is_file:
        source = check_file(source)  # download

    # Directories
    if not isinstance(yolo_weights, list):  # single yolo model
        exp_name = yolo_weights.stem
    elif type(yolo_weights) is list and len(yolo_weights) == 1:  # single models after --yolo_weights
        exp_name = Path(yolo_weights[0]).stem
    else:  # multiple models after --yolo_weights
        exp_name = 'ensemble'
    exp_name = name if name else exp_name + "_" + reid_weights.stem
    save_dir = increment_path(Path(project) / exp_name, exist_ok=exist_ok)  # increment run
    (save_dir / 'tracks' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Load model
    device = select_device(device)
    is_seg = '-seg' in str(yolo_weights)
    model = AutoBackend(yolo_weights, device=device, dnn=dnn, fp16=half)
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_imgsz(imgsz, stride=stride)  # check image size

    # Dataloader
    bs = 1
    if webcam:
        show_vid = check_imshow(warn=True)
        dataset = LoadStreams(
            source,
            imgsz=imgsz,
            stride=stride,
            auto=pt,
            transforms=getattr(model.model, 'transforms', None),
            vid_stride=vid_stride
        )
        bs = len(dataset)
    else:
        dataset = LoadImages(
            source,
            imgsz=imgsz,
            stride=stride,
            auto=pt,
            transforms=getattr(model.model, 'transforms', None),
            vid_stride=vid_stride
        )
    vid_path, vid_writer, txt_path = [None] * bs, [None] * bs, [None] * bs
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup

    # Create as many strong sort instances as there are video sources
    tracker_list = []
    for i in range(bs):
        tracker = create_tracker(tracking_method, tracking_config, reid_weights, device, half)
        tracker_list.append(tracker, )
        if hasattr(tracker_list[i], 'model'):
            if hasattr(tracker_list[i].model, 'warmup'):
                tracker_list[i].model.warmup()
    outputs = [None] * bs

    # Run tracking
    #model.warmup(imgsz=(1 if pt else bs, 3, *imgsz))  # warmup
    seen, windows, dt = 0, [], (Profile(), Profile(), Profile(), Profile())
    curr_frames, prev_frames = [None] * bs, [None] * bs
    for frame_idx, batch in enumerate(dataset):
        path, im, im0s, vid_cap, s = batch
        visualize = increment_path(save_dir / Path(path[0]).stem, mkdir=True) if visualize else False
        with dt[0]:
            im = torch.from_numpy(im).to(device)
            im = im.half() if half else im.float()  # uint8 to fp16/32
            im /= 255.0  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # expand for batch dim

        # Inference
        with dt[1]:
            preds = model(im, augment=augment, visualize=visualize)

        # Apply NMS
        with dt[2]:
            if is_seg:
                masks = []
                p = non_max_suppression(preds[0], conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det, nm=32)
                proto = preds[1][-1]
            else:
                p = non_max_suppression(preds, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

        # Process detections
        filename = 'out.mp4'
        for i, det in enumerate(p):  # detections per image
            seen += 1
            if webcam:  # bs >= 1
                p, im0, _ = path[i], im0s[i].copy(), dataset.count
                p = Path(p)  # to Path
                s += f'{i}: '
                txt_file_name = p.name
                save_path = str(save_dir / filename)  # im.jpg, vid.mp4, ...
            else:
                p, im0, _ = path, im0s.copy(), getattr(dataset, 'frame', 0)
                p = Path(p)  # to Path
                # video file
                if source.endswith(VID_FORMATS):
                    txt_file_name = p.stem
                    save_path = str(save_dir / filename)  # im.jpg, vid.mp4, ...
                    LOGGER.info(f"p.name is {p.name}, save_path value is {save_path}")
                # folder with imgs
                else:
                    txt_file_name = p.parent.name  # get folder name containing current img
                    save_path = str(save_dir / p.parent.name)  # im.jpg, vid.mp4, ...
            curr_frames[i] = im0

            txt_path = str(save_dir / 'tracks' / txt_file_name)  # im.txt
            s += '%gx%g ' % im.shape[2:]  # print string
            imc = im0.copy() if save_crop else im0  # for save_crop

            annotator = Annotator(im0, line_width=line_thickness, example=str(names))

            if hasattr(tracker_list[i], 'tracker') and hasattr(tracker_list[i].tracker, 'camera_update'):
                if prev_frames[i] is not None and curr_frames[i] is not None:  # camera motion compensation
                    tracker_list[i].tracker.camera_update(prev_frames[i], curr_frames[i])

            if det is not None and len(det):
                if is_seg:
                    shape = im0.shape
                    # scale bbox first the crop masks
                    if retina_masks:
                        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], shape).round()  # rescale boxes to im0 size
                        masks.append(process_mask_native(proto[i], det[:, 6:], det[:, :4], im0.shape[:2]))  # HWC
                    else:
                        masks.append(process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True))  # HWC
                        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], shape).round()  # rescale boxes to im0 size
                else:
                    det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()  # rescale boxes to im0 size

                # Print results
                for c in det[:, 5].unique():
                    n = (det[:, 5] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # pass detections to strongsort
                with dt[3]:
                    outputs[i] = tracker_list[i].update(det.cpu(), im0)

                # draw boxes for visualization
                if len(outputs[i]) > 0:

                    if is_seg:
                        # Mask plotting
                        annotator.masks(
                            masks[i],
                            colors=[colors(x, True) for x in det[:, 5]],
                            im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(device).permute(2, 0, 1).flip(0).contiguous() /
                                   255 if retina_masks else im[i]
                        )

                    for j, (output) in enumerate(outputs[i]):

                        bbox = output[0:4]
                        id = output[4]
                        cls = output[5]
                        conf = output[6]

                        if save_txt:
                            # to MOT format
                            bbox_left = output[0]
                            bbox_top = output[1]
                            bbox_w = output[2] - output[0]
                            bbox_h = output[3] - output[1]
                            # Write MOT compliant results to file
                            with open(txt_path + '.txt', 'a') as f:
                                f.write(('%g ' * 10 + '\n') % (frame_idx + 1, id, bbox_left,  # MOT format
                                                               bbox_top, bbox_w, bbox_h, -1, -1, -1, i))

                        if save_vid or save_crop or show_vid:  # Add bbox/seg to image
                            c = int(cls)  # integer class
                            id = int(id)  # integer id
                            label = None if hide_labels else (f'{id} {names[c]}' if hide_conf else \
                                (f'{id} {conf:.2f}' if hide_class else f'{id} {names[c]} {conf:.2f}'))
                            color = colors(c, True)
                            annotator.box_label(bbox, label, color=color)

                            if save_trajectories and tracking_method == 'strongsort':
                                q = output[7]
                                tracker_list[i].trajectory(im0, q, color=color)
                            if save_crop:
                                txt_file_name = txt_file_name if (isinstance(path, list) and len(path) > 1) else ''
                                save_one_box(np.array(bbox, dtype=np.int16), imc, file=save_dir / 'crops' / txt_file_name / names[c] / f'{id}' / f'{p.stem}.jpg', BGR=True)

            else:
                pass
                #tracker_list[i].tracker.pred_n_update_all_tracks()

            # Stream results
            im0 = annotator.result()
            if show_vid:
                if platform.system() == 'Linux' and p not in windows:
                    windows.append(p)
                    cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
                    cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
                cv2.imshow(str(p), im0)
                if cv2.waitKey(1) == ord('q'):  # 1 millisecond
                    exit()

            # Save results (image with detections)
            if save_vid:
                LOGGER.info(f"vid_path, save_path {vid_path[i]}{save_path}")
                if vid_path[i] != save_path:  # new video
                    vid_path[i] = save_path
                    if isinstance(vid_writer[i], cv2.VideoWriter):
                        vid_writer[i].release()  # release previous video writer
                    if vid_cap:  # video
                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    else:  # stream
                        fps, w, h = 30, im0.shape[1], im0.shape[0]
                    save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
                    LOGGER.info(f"test Results saved to {colorstr('bold', save_path)}")
                    vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                vid_writer[i].write(im0)

            prev_frames[i] = curr_frames[i]

        # Print total time (preprocessing + inference + NMS + tracking)
        LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{sum([dt.dt for dt in dt if hasattr(dt, 'dt')]) * 1E3:.1f}ms")

    # Print results
    t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image
    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS, %.1fms {tracking_method} update per image at shape {(1, 3, *imgsz)}' % t)
    if save_txt or save_vid:
        s = f"\n{len(list((save_dir / 'tracks').glob('*.txt')))} tracks saved to {save_dir / 'tracks'}" if save_txt else ''
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    if update:
        strip_optimizer(yolo_weights)  # update model (to fix SourceChangeWarning)


def parse_opt():
    parser = argparse.ArgumentParser()
    #parser.add_argument('--yolo-weights', nargs='+', type=Path, default=WEIGHTS / 'yolov8s-seg.pt', help='model.pt path(s)')
    parser.add_argument('--reid-weights', type=Path, default=WEIGHTS / 'osnet_x0_25_msmt17.pt')
    #parser.add_argument('--tracking-method', type=str, default='bytetrack', help='strongsort, ocsort, bytetrack')
    parser.add_argument('--tracking-config', type=Path, default=None)
    #parser.add_argument('--source', type=str, default='0', help='file/dir/URL/glob, 0 for webcam')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
    parser.add_argument('--conf-thres', type=float, default=0.5, help='confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='NMS IoU threshold')
    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--show-vid', action='store_true', help='display tracking video results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
    parser.add_argument('--save-trajectories', action='store_true', help='save trajectories for each track')
    parser.add_argument('--save-vid', action='store_true', default=True, help='save video tracking results')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    # class 0 is person, 1 is bycicle, 2 is car... 79 is oven
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--visualize', action='store_true', help='visualize features')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default=ROOT, help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to ROOT')
    parser.add_argument('--exist-ok', default='True', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--line-thickness', default=2, type=int, help='bounding box thickness (pixels)')
    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
    parser.add_argument('--hide-class', default=False, action='store_true', help='hide IDs')
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
    parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
    parser.add_argument('--retina-masks', action='store_true', help='whether to plot masks in native resolution')
    #opt = parser.parse_args()
    #opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
    #opt.tracking_config = ROOT / 'trackers' / opt.tracking_method / 'configs' / (opt.tracking_method + '.yaml')
    #print_args(vars(opt))
    #return opt
    return parser


def main(opt):
    check_requirements(requirements=ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))
    run(**vars(opt))


#if __name__ == "__main__":
#    opt = parse_opt()
#    main(opt)

def MOT(yoloweights, trackingmethod, sourceVideo):
    parser = parse_opt()
    parser.add_argument('--yolo-weights', nargs='+', type=Path, default=yoloweights, help='model.pt path(s)')
    parser.add_argument('--tracking-method', type=str, default=trackingmethod, help='strongsort, ocsort, bytetrack')
    parser.add_argument('--source', type=str, default=sourceVideo, help='file/dir/URL/glob, 0 for webcam')
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
    opt.tracking_config = ROOT / 'trackers' / opt.tracking_method / 'configs' / (opt.tracking_method + '.yaml')
    print_args(vars(opt))
    main(opt)
    save_dir = increment_path('exp', exist_ok=True)
    input = os.path.join(save_dir, 'out.mp4')
    outpath = 'output.mp4'  #'output/'+ 'output.mp4'
    command = f"ffmpeg -i {input} -vf fps=30 -vcodec libx264 {outpath}"
    print(command)
    os.system(command)
    #!ffmpeg -i $input -vf fps=30 -vcodec libx264 $outpath tbd
    return outpath
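As context: when save_txt is enabled, run() appends one MOT-format row per track per frame, with the columns frame, id, bbox_left, bbox_top, bbox_w, bbox_h, -1, -1, -1, source index (taken from the f.write call above). A minimal sketch of reading those rows back; the file path is an assumption, actual files land under <save_dir>/tracks/:

# parse MOT-format rows produced by run() with --save-txt (path below is hypothetical)
with open('runs/track/exp/tracks/1.txt') as f:
    for line in f:
        frame, track_id, left, top, w, h, _, _, _, src = map(float, line.split())
        print(int(frame), int(track_id), left, top, w, h)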
trackers/__init__.py
ADDED
File without changes
trackers/bytetrack/basetrack.py
ADDED
@@ -0,0 +1,52 @@
import numpy as np
from collections import OrderedDict


class TrackState(object):
    New = 0
    Tracked = 1
    Lost = 2
    Removed = 3


class BaseTrack(object):
    _count = 0

    track_id = 0
    is_activated = False
    state = TrackState.New

    history = OrderedDict()
    features = []
    curr_feature = None
    score = 0
    start_frame = 0
    frame_id = 0
    time_since_update = 0

    # multi-camera
    location = (np.inf, np.inf)

    @property
    def end_frame(self):
        return self.frame_id

    @staticmethod
    def next_id():
        BaseTrack._count += 1
        return BaseTrack._count

    def activate(self, *args):
        raise NotImplementedError

    def predict(self):
        raise NotImplementedError

    def update(self, *args, **kwargs):
        raise NotImplementedError

    def mark_lost(self):
        self.state = TrackState.Lost

    def mark_removed(self):
        self.state = TrackState.Removed
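BaseTrack only provides the shared bookkeeping (a process-wide ID counter and the TrackState lifecycle); concrete trackers subclass it. A minimal sketch of that contract, with a made-up subclass name used purely for illustration:

from trackers.bytetrack.basetrack import BaseTrack, TrackState

class DummyTrack(BaseTrack):  # hypothetical subclass, for illustration only
    def activate(self, frame_id):
        self.track_id = self.next_id()  # class-level counter, so IDs are unique across tracks
        self.state = TrackState.Tracked
        self.frame_id = frame_id

t = DummyTrack()
t.activate(frame_id=1)
t.mark_lost()     # state -> TrackState.Lost
t.mark_removed()  # state -> TrackState.Removed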
trackers/bytetrack/byte_tracker.py
ADDED
@@ -0,0 +1,353 @@
import numpy as np
from collections import deque
import os
import os.path as osp
import copy
import torch
import torch.nn.functional as F

from ultralytics.yolo.utils.ops import xywh2xyxy, xyxy2xywh


from trackers.bytetrack.kalman_filter import KalmanFilter
from trackers.bytetrack import matching
from trackers.bytetrack.basetrack import BaseTrack, TrackState

class STrack(BaseTrack):
    shared_kalman = KalmanFilter()
    def __init__(self, tlwh, score, cls):

        # wait activate
        self._tlwh = np.asarray(tlwh, dtype=np.float32)
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False

        self.score = score
        self.tracklet_len = 0
        self.cls = cls

    def predict(self):
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            mean_state[7] = 0
        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

    @staticmethod
    def multi_predict(stracks):
        if len(stracks) > 0:
            multi_mean = np.asarray([st.mean.copy() for st in stracks])
            multi_covariance = np.asarray([st.covariance for st in stracks])
            for i, st in enumerate(stracks):
                if st.state != TrackState.Tracked:
                    multi_mean[i][7] = 0
            multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
            for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
                stracks[i].mean = mean
                stracks[i].covariance = cov

    def activate(self, kalman_filter, frame_id):
        """Start a new tracklet"""
        self.kalman_filter = kalman_filter
        self.track_id = self.next_id()
        self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))

        self.tracklet_len = 0
        self.state = TrackState.Tracked
        if frame_id == 1:
            self.is_activated = True
        # self.is_activated = True
        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track, frame_id, new_id=False):
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
        )
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        if new_id:
            self.track_id = self.next_id()
        self.score = new_track.score
        self.cls = new_track.cls

    def update(self, new_track, frame_id):
        """
        Update a matched track
        :type new_track: STrack
        :type frame_id: int
        :type update_feature: bool
        :return:
        """
        self.frame_id = frame_id
        self.tracklet_len += 1
        # self.cls = cls

        new_tlwh = new_track.tlwh
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
        self.state = TrackState.Tracked
        self.is_activated = True

        self.score = new_track.score

    @property
    # @jit(nopython=True)
    def tlwh(self):
        """Get current position in bounding box format `(top left x, top left y,
        width, height)`.
        """
        if self.mean is None:
            return self._tlwh.copy()
        ret = self.mean[:4].copy()
        ret[2] *= ret[3]
        ret[:2] -= ret[2:] / 2
        return ret

    @property
    # @jit(nopython=True)
    def tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    @staticmethod
    # @jit(nopython=True)
    def tlwh_to_xyah(tlwh):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.
        """
        ret = np.asarray(tlwh).copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

    def to_xyah(self):
        return self.tlwh_to_xyah(self.tlwh)

    @staticmethod
    # @jit(nopython=True)
    def tlbr_to_tlwh(tlbr):
        ret = np.asarray(tlbr).copy()
        ret[2:] -= ret[:2]
        return ret

    @staticmethod
    # @jit(nopython=True)
    def tlwh_to_tlbr(tlwh):
        ret = np.asarray(tlwh).copy()
        ret[2:] += ret[:2]
        return ret

    def __repr__(self):
        return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)


class BYTETracker(object):
    def __init__(self, track_thresh=0.45, match_thresh=0.8, track_buffer=25, frame_rate=30):
        self.tracked_stracks = []  # type: list[STrack]
        self.lost_stracks = []  # type: list[STrack]
        self.removed_stracks = []  # type: list[STrack]

        self.frame_id = 0
        self.track_buffer = track_buffer

        self.track_thresh = track_thresh
        self.match_thresh = match_thresh
        self.det_thresh = track_thresh + 0.1
        self.buffer_size = int(frame_rate / 30.0 * track_buffer)
        self.max_time_lost = self.buffer_size
        self.kalman_filter = KalmanFilter()

    def update(self, dets, _):
        self.frame_id += 1
        activated_starcks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []

        xyxys = dets[:, 0:4]
        xywh = xyxy2xywh(xyxys)
        confs = dets[:, 4]
        clss = dets[:, 5]

        classes = clss.numpy()
        xyxys = xyxys.numpy()
        confs = confs.numpy()

        remain_inds = confs > self.track_thresh
        inds_low = confs > 0.1
        inds_high = confs < self.track_thresh

        inds_second = np.logical_and(inds_low, inds_high)

        dets_second = xywh[inds_second]
        dets = xywh[remain_inds]

        scores_keep = confs[remain_inds]
        scores_second = confs[inds_second]

        clss_keep = classes[remain_inds]
        clss_second = classes[inds_second]


        if len(dets) > 0:
            '''Detections'''
            detections = [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores_keep, clss_keep)]
        else:
            detections = []

        ''' Add newly detected tracklets to tracked_stracks'''
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)

        ''' Step 2: First association, with high score detection boxes'''
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location with KF
        STrack.multi_predict(strack_pool)
        dists = matching.iou_distance(strack_pool, detections)
        #if not self.args.mot20:
        dists = matching.fuse_score(dists, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.match_thresh)

        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(detections[idet], self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        ''' Step 3: Second association, with low score detection boxes'''
        # association the untrack to the low score detections
        if len(dets_second) > 0:
            '''Detections'''
            detections_second = [STrack(xywh, s, c) for (xywh, s, c) in zip(dets_second, scores_second, clss_second)]
        else:
            detections_second = []
        r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
        dists = matching.iou_distance(r_tracked_stracks, detections_second)
        matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections_second[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)

        '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        #if not self.args.mot20:
        dists = matching.fuse_score(dists, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_starcks.append(unconfirmed[itracked])
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        """ Step 4: Init new stracks"""
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.det_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_starcks.append(track)
        """ Step 5: Update state"""
        for track in self.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        # print('Ramained match {} s'.format(t4-t3))

        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
        # get scores of lost tracks
        output_stracks = [track for track in self.tracked_stracks if track.is_activated]
        outputs = []
        for t in output_stracks:
            output = []
            tlwh = t.tlwh
            tid = t.track_id
            tlwh = np.expand_dims(tlwh, axis=0)
            xyxy = xywh2xyxy(tlwh)
            xyxy = np.squeeze(xyxy, axis=0)
            output.extend(xyxy)
            output.append(tid)
            output.append(t.cls)
            output.append(t.score)
            outputs.append(output)

        return outputs
        #track_id, class_id, conf


def joint_stracks(tlista, tlistb):
    exists = {}
    res = []
    for t in tlista:
        exists[t.track_id] = 1
        res.append(t)
    for t in tlistb:
        tid = t.track_id
        if not exists.get(tid, 0):
            exists[tid] = 1
            res.append(t)
    return res


def sub_stracks(tlista, tlistb):
    stracks = {}
    for t in tlista:
        stracks[t.track_id] = t
    for t in tlistb:
        tid = t.track_id
        if stracks.get(tid, 0):
            del stracks[tid]
    return list(stracks.values())


def remove_duplicate_stracks(stracksa, stracksb):
    pdist = matching.iou_distance(stracksa, stracksb)
    pairs = np.where(pdist < 0.15)
    dupa, dupb = list(), list()
    for p, q in zip(*pairs):
        timep = stracksa[p].frame_id - stracksa[p].start_frame
        timeq = stracksb[q].frame_id - stracksb[q].start_frame
        if timep > timeq:
            dupb.append(q)
        else:
            dupa.append(p)
    resa = [t for i, t in enumerate(stracksa) if not i in dupa]
    resb = [t for i, t in enumerate(stracksb) if not i in dupb]
    return resa, resb
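BYTETracker.update expects a detection tensor with columns [x1, y1, x2, y2, conf, cls] (as passed from track.py via det.cpu()) and returns one row per confirmed track as [x1, y1, x2, y2, track_id, cls, conf]; high-confidence boxes are associated first, and the remaining tracks are retried against the low-confidence boxes. A minimal frame-by-frame sketch with made-up detection values:

import torch
from trackers.bytetrack.byte_tracker import BYTETracker

tracker = BYTETracker(track_thresh=0.45, match_thresh=0.8, track_buffer=25, frame_rate=30)
# one frame's detections: x1, y1, x2, y2, confidence, class (values are illustrative)
dets = torch.tensor([[100., 150., 200., 300., 0.91, 0.],
                     [400., 120., 470., 260., 0.35, 2.]])
tracks = tracker.update(dets, None)  # second argument (the image) is unused by this tracker
for x1, y1, x2, y2, track_id, cls, conf in tracks:
    print(int(track_id), int(cls), round(float(conf), 2))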
trackers/bytetrack/configs/bytetrack.yaml
ADDED
@@ -0,0 +1,7 @@
bytetrack:
  track_thresh: 0.6  # tracking confidence threshold
  track_buffer: 30  # the frames for keep lost tracks
  match_thresh: 0.8  # matching threshold for tracking
  frame_rate: 30  # FPS
  conf_thres: 0.5122620708221085
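These keys mirror the BYTETracker constructor arguments above. The actual wiring lives in trackers/multi_tracker_zoo.py, whose body is not shown in this view; the loader below is only an assumed sketch of how the YAML maps onto the constructor:

import yaml
from trackers.bytetrack.byte_tracker import BYTETracker

# hypothetical loader; the repo's own factory is trackers/multi_tracker_zoo.py
with open('trackers/bytetrack/configs/bytetrack.yaml') as f:
    cfg = yaml.safe_load(f)['bytetrack']

tracker = BYTETracker(track_thresh=cfg['track_thresh'],
                      match_thresh=cfg['match_thresh'],
                      track_buffer=cfg['track_buffer'],
                      frame_rate=cfg['frame_rate'])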
trackers/bytetrack/kalman_filter.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# vim: expandtab:ts=4:sw=4
|
2 |
+
import numpy as np
|
3 |
+
import scipy.linalg
|
4 |
+
|
5 |
+
|
6 |
+
"""
|
7 |
+
Table for the 0.95 quantile of the chi-square distribution with N degrees of
|
8 |
+
freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
|
9 |
+
function and used as Mahalanobis gating threshold.
|
10 |
+
"""
|
11 |
+
chi2inv95 = {
|
12 |
+
1: 3.8415,
|
13 |
+
2: 5.9915,
|
14 |
+
3: 7.8147,
|
15 |
+
4: 9.4877,
|
16 |
+
5: 11.070,
|
17 |
+
6: 12.592,
|
18 |
+
7: 14.067,
|
19 |
+
8: 15.507,
|
20 |
+
9: 16.919}
|
21 |
+
|
22 |
+
|
23 |
+
class KalmanFilter(object):
|
24 |
+
"""
|
25 |
+
A simple Kalman filter for tracking bounding boxes in image space.
|
26 |
+
|
27 |
+
The 8-dimensional state space
|
28 |
+
|
29 |
+
x, y, a, h, vx, vy, va, vh
|
30 |
+
|
31 |
+
contains the bounding box center position (x, y), aspect ratio a, height h,
|
32 |
+
and their respective velocities.
|
33 |
+
|
34 |
+
Object motion follows a constant velocity model. The bounding box location
|
35 |
+
(x, y, a, h) is taken as direct observation of the state space (linear
|
36 |
+
observation model).
|
37 |
+
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self):
|
41 |
+
ndim, dt = 4, 1.
|
42 |
+
|
43 |
+
# Create Kalman filter model matrices.
|
44 |
+
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
|
45 |
+
for i in range(ndim):
|
46 |
+
self._motion_mat[i, ndim + i] = dt
|
47 |
+
self._update_mat = np.eye(ndim, 2 * ndim)
|
48 |
+
|
49 |
+
# Motion and observation uncertainty are chosen relative to the current
|
50 |
+
# state estimate. These weights control the amount of uncertainty in
|
51 |
+
# the model. This is a bit hacky.
|
52 |
+
self._std_weight_position = 1. / 20
|
53 |
+
self._std_weight_velocity = 1. / 160
|
54 |
+
|
55 |
+
def initiate(self, measurement):
|
56 |
+
"""Create track from unassociated measurement.
|
57 |
+
|
58 |
+
Parameters
|
59 |
+
----------
|
60 |
+
measurement : ndarray
|
61 |
+
Bounding box coordinates (x, y, a, h) with center position (x, y),
|
62 |
+
aspect ratio a, and height h.
|
63 |
+
|
64 |
+
Returns
|
65 |
+
-------
|
66 |
+
(ndarray, ndarray)
|
67 |
+
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
68 |
+
dimensional) of the new track. Unobserved velocities are initialized
|
69 |
+
to 0 mean.
|
70 |
+
|
71 |
+
"""
|
72 |
+
mean_pos = measurement
|
73 |
+
mean_vel = np.zeros_like(mean_pos)
|
74 |
+
mean = np.r_[mean_pos, mean_vel]
|
75 |
+
|
76 |
+
std = [
|
77 |
+
2 * self._std_weight_position * measurement[3],
|
78 |
+
2 * self._std_weight_position * measurement[3],
|
79 |
+
1e-2,
|
80 |
+
2 * self._std_weight_position * measurement[3],
|
81 |
+
10 * self._std_weight_velocity * measurement[3],
|
82 |
+
10 * self._std_weight_velocity * measurement[3],
|
83 |
+
1e-5,
|
84 |
+
10 * self._std_weight_velocity * measurement[3]]
|
85 |
+
covariance = np.diag(np.square(std))
|
86 |
+
return mean, covariance
|
87 |
+
|
88 |
+
def predict(self, mean, covariance):
|
89 |
+
"""Run Kalman filter prediction step.
|
90 |
+
|
91 |
+
Parameters
|
92 |
+
----------
|
93 |
+
mean : ndarray
|
94 |
+
            The 8 dimensional mean vector of the object state at the previous
            time step.
        covariance : ndarray
            The 8x8 dimensional covariance matrix of the object state at the
            previous time step.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector and covariance matrix of the predicted
            state. Unobserved velocities are initialized to 0 mean.

        """
        std_pos = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-2,
            self._std_weight_position * mean[3]]
        std_vel = [
            self._std_weight_velocity * mean[3],
            self._std_weight_velocity * mean[3],
            1e-5,
            self._std_weight_velocity * mean[3]]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))

        # mean = np.dot(self._motion_mat, mean)
        mean = np.dot(mean, self._motion_mat.T)
        covariance = np.linalg.multi_dot((
            self._motion_mat, covariance, self._motion_mat.T)) + motion_cov

        return mean, covariance

    def project(self, mean, covariance):
        """Project state distribution to measurement space.

        Parameters
        ----------
        mean : ndarray
            The state's mean vector (8 dimensional array).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).

        Returns
        -------
        (ndarray, ndarray)
            Returns the projected mean and covariance matrix of the given state
            estimate.

        """
        std = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-1,
            self._std_weight_position * mean[3]]
        innovation_cov = np.diag(np.square(std))

        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot((
            self._update_mat, covariance, self._update_mat.T))
        return mean, covariance + innovation_cov

    def multi_predict(self, mean, covariance):
        """Run Kalman filter prediction step (vectorized version).

        Parameters
        ----------
        mean : ndarray
            The Nx8 dimensional mean matrix of the object states at the previous
            time step.
        covariance : ndarray
            The Nx8x8 dimensional covariance matrices of the object states at the
            previous time step.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector and covariance matrix of the predicted
            state. Unobserved velocities are initialized to 0 mean.
        """
        std_pos = [
            self._std_weight_position * mean[:, 3],
            self._std_weight_position * mean[:, 3],
            1e-2 * np.ones_like(mean[:, 3]),
            self._std_weight_position * mean[:, 3]]
        std_vel = [
            self._std_weight_velocity * mean[:, 3],
            self._std_weight_velocity * mean[:, 3],
            1e-5 * np.ones_like(mean[:, 3]),
            self._std_weight_velocity * mean[:, 3]]
        sqr = np.square(np.r_[std_pos, std_vel]).T

        motion_cov = []
        for i in range(len(mean)):
            motion_cov.append(np.diag(sqr[i]))
        motion_cov = np.asarray(motion_cov)

        mean = np.dot(mean, self._motion_mat.T)
        left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
        covariance = np.dot(left, self._motion_mat.T) + motion_cov

        return mean, covariance

    def update(self, mean, covariance, measurement):
        """Run Kalman filter correction step.

        Parameters
        ----------
        mean : ndarray
            The predicted state's mean vector (8 dimensional).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).
        measurement : ndarray
            The 4 dimensional measurement vector (x, y, a, h), where (x, y)
            is the center position, a the aspect ratio, and h the height of the
            bounding box.

        Returns
        -------
        (ndarray, ndarray)
            Returns the measurement-corrected state distribution.

        """
        projected_mean, projected_cov = self.project(mean, covariance)

        chol_factor, lower = scipy.linalg.cho_factor(
            projected_cov, lower=True, check_finite=False)
        kalman_gain = scipy.linalg.cho_solve(
            (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
            check_finite=False).T
        innovation = measurement - projected_mean

        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot((
            kalman_gain, projected_cov, kalman_gain.T))
        return new_mean, new_covariance

    def gating_distance(self, mean, covariance, measurements,
                        only_position=False, metric='maha'):
        """Compute gating distance between state distribution and measurements.

        A suitable distance threshold can be obtained from `chi2inv95`. If
        `only_position` is False, the chi-square distribution has 4 degrees of
        freedom, otherwise 2.

        Parameters
        ----------
        mean : ndarray
            Mean vector over the state distribution (8 dimensional).
        covariance : ndarray
            Covariance of the state distribution (8x8 dimensional).
        measurements : ndarray
            An Nx4 dimensional matrix of N measurements, each in
            format (x, y, a, h) where (x, y) is the bounding box center
            position, a the aspect ratio, and h the height.
        only_position : Optional[bool]
            If True, distance computation is done with respect to the bounding
            box center position only.

        Returns
        -------
        ndarray
            Returns an array of length N, where the i-th element contains the
            squared Mahalanobis distance between (mean, covariance) and
            `measurements[i]`.
        """
        mean, covariance = self.project(mean, covariance)
        if only_position:
            mean, covariance = mean[:2], covariance[:2, :2]
            measurements = measurements[:, :2]

        d = measurements - mean
        if metric == 'gaussian':
            return np.sum(d * d, axis=1)
        elif metric == 'maha':
            cholesky_factor = np.linalg.cholesky(covariance)
            z = scipy.linalg.solve_triangular(
                cholesky_factor, d.T, lower=True, check_finite=False,
                overwrite_b=True)
            squared_maha = np.sum(z * z, axis=0)
            return squared_maha
        else:
            raise ValueError('invalid distance metric')
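A minimal sketch of driving this constant-velocity (x, y, a, h) filter for a single track; illustrative only and not part of the commit. It assumes the standard DeepSORT-style `initiate()` constructor method defined earlier in this file:

# Illustrative sketch (not part of the commit). Assumes the standard
# initiate() method from the DeepSORT-style filter defined earlier in this file.
import numpy as np
from trackers.bytetrack.kalman_filter import KalmanFilter

kf = KalmanFilter()
det_xyah = np.array([320., 240., 0.5, 80.])           # (center x, center y, aspect ratio, height)
mean, cov = kf.initiate(det_xyah)                     # start a track from the first detection
mean, cov = kf.predict(mean, cov)                     # propagate one frame ahead
mean, cov = kf.update(mean, cov, det_xyah)            # correct with the matched detection
dist = kf.gating_distance(mean, cov, det_xyah[None])  # squared Mahalanobis distance, length-1 array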
trackers/bytetrack/matching.py
ADDED
@@ -0,0 +1,219 @@
import cv2
import numpy as np
import scipy
import lap
from scipy.spatial.distance import cdist

from trackers.bytetrack import kalman_filter
import time


def merge_matches(m1, m2, shape):
    O, P, Q = shape
    m1 = np.asarray(m1)
    m2 = np.asarray(m2)

    M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
    M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))

    mask = M1 * M2
    match = mask.nonzero()
    match = list(zip(match[0], match[1]))
    unmatched_O = tuple(set(range(O)) - set([i for i, j in match]))
    unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match]))

    return match, unmatched_O, unmatched_Q


def _indices_to_matches(cost_matrix, indices, thresh):
    matched_cost = cost_matrix[tuple(zip(*indices))]
    matched_mask = (matched_cost <= thresh)

    matches = indices[matched_mask]
    unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
    unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))

    return matches, unmatched_a, unmatched_b


def linear_assignment(cost_matrix, thresh):
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
    matches, unmatched_a, unmatched_b = [], [], []
    cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
    for ix, mx in enumerate(x):
        if mx >= 0:
            matches.append([ix, mx])
    unmatched_a = np.where(x < 0)[0]
    unmatched_b = np.where(y < 0)[0]
    matches = np.asarray(matches)
    return matches, unmatched_a, unmatched_b


def ious(atlbrs, btlbrs):
    """
    Compute cost based on IoU
    :type atlbrs: list[tlbr] | np.ndarray
    :type btlbrs: list[tlbr] | np.ndarray

    :rtype ious np.ndarray
    """
    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
    if ious.size == 0:
        return ious

    ious = bbox_ious(
        np.ascontiguousarray(atlbrs, dtype=np.float32),
        np.ascontiguousarray(btlbrs, dtype=np.float32)
    )

    return ious


def iou_distance(atracks, btracks):
    """
    Compute cost based on IoU
    :type atracks: list[STrack]
    :type btracks: list[STrack]

    :rtype cost_matrix np.ndarray
    """

    if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
        atlbrs = atracks
        btlbrs = btracks
    else:
        atlbrs = [track.tlbr for track in atracks]
        btlbrs = [track.tlbr for track in btracks]
    _ious = ious(atlbrs, btlbrs)
    cost_matrix = 1 - _ious

    return cost_matrix


def v_iou_distance(atracks, btracks):
    """
    Compute cost based on IoU
    :type atracks: list[STrack]
    :type btracks: list[STrack]

    :rtype cost_matrix np.ndarray
    """

    if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
        atlbrs = atracks
        btlbrs = btracks
    else:
        atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks]
        btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks]
    _ious = ious(atlbrs, btlbrs)
    cost_matrix = 1 - _ious

    return cost_matrix


def embedding_distance(tracks, detections, metric='cosine'):
    """
    :param tracks: list[STrack]
    :param detections: list[BaseTrack]
    :param metric:
    :return: cost_matrix np.ndarray
    """

    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
    if cost_matrix.size == 0:
        return cost_matrix
    det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32)
    # for i, track in enumerate(tracks):
    #     cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1, -1), det_features, metric))
    track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)
    cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric))  # Normalized features
    return cost_matrix


def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
    if cost_matrix.size == 0:
        return cost_matrix
    gating_dim = 2 if only_position else 4
    gating_threshold = kalman_filter.chi2inv95[gating_dim]
    measurements = np.asarray([det.to_xyah() for det in detections])
    for row, track in enumerate(tracks):
        gating_distance = kf.gating_distance(
            track.mean, track.covariance, measurements, only_position)
        cost_matrix[row, gating_distance > gating_threshold] = np.inf
    return cost_matrix


def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
    if cost_matrix.size == 0:
        return cost_matrix
    gating_dim = 2 if only_position else 4
    gating_threshold = kalman_filter.chi2inv95[gating_dim]
    measurements = np.asarray([det.to_xyah() for det in detections])
    for row, track in enumerate(tracks):
        gating_distance = kf.gating_distance(
            track.mean, track.covariance, measurements, only_position, metric='maha')
        cost_matrix[row, gating_distance > gating_threshold] = np.inf
        cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
    return cost_matrix


def fuse_iou(cost_matrix, tracks, detections):
    if cost_matrix.size == 0:
        return cost_matrix
    reid_sim = 1 - cost_matrix
    iou_dist = iou_distance(tracks, detections)
    iou_sim = 1 - iou_dist
    fuse_sim = reid_sim * (1 + iou_sim) / 2
    det_scores = np.array([det.score for det in detections])
    det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
    # fuse_sim = fuse_sim * (1 + det_scores) / 2
    fuse_cost = 1 - fuse_sim
    return fuse_cost


def fuse_score(cost_matrix, detections):
    if cost_matrix.size == 0:
        return cost_matrix
    iou_sim = 1 - cost_matrix
    det_scores = np.array([det.score for det in detections])
    det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
    fuse_sim = iou_sim * det_scores
    fuse_cost = 1 - fuse_sim
    return fuse_cost


def bbox_ious(boxes, query_boxes):
    """
    Parameters
    ----------
    boxes: (N, 4) ndarray of float
    query_boxes: (K, 4) ndarray of float
    Returns
    -------
    overlaps: (N, K) ndarray of overlap between boxes and query_boxes
    """
    N = boxes.shape[0]
    K = query_boxes.shape[0]
    overlaps = np.zeros((N, K), dtype=np.float32)

    for k in range(K):
        box_area = (
            (query_boxes[k, 2] - query_boxes[k, 0] + 1) *
            (query_boxes[k, 3] - query_boxes[k, 1] + 1)
        )
        for n in range(N):
            iw = (
                min(boxes[n, 2], query_boxes[k, 2]) -
                max(boxes[n, 0], query_boxes[k, 0]) + 1
            )
            if iw > 0:
                ih = (
                    min(boxes[n, 3], query_boxes[k, 3]) -
                    max(boxes[n, 1], query_boxes[k, 1]) + 1
                )
                if ih > 0:
                    ua = float(
                        (boxes[n, 2] - boxes[n, 0] + 1) *
                        (boxes[n, 3] - boxes[n, 1] + 1) +
                        box_area - iw * ih
                    )
                    overlaps[n, k] = iw * ih / ua
    return overlaps
trackers/multi_tracker_zoo.py
ADDED
@@ -0,0 +1,52 @@
from trackers.strongsort.utils.parser import get_config


def create_tracker(tracker_type, tracker_config, reid_weights, device, half):

    cfg = get_config()
    cfg.merge_from_file(tracker_config)

    if tracker_type == 'strongsort':
        from trackers.strongsort.strong_sort import StrongSORT
        strongsort = StrongSORT(
            reid_weights,
            device,
            half,
            max_dist=cfg.strongsort.max_dist,
            max_iou_dist=cfg.strongsort.max_iou_dist,
            max_age=cfg.strongsort.max_age,
            max_unmatched_preds=cfg.strongsort.max_unmatched_preds,
            n_init=cfg.strongsort.n_init,
            nn_budget=cfg.strongsort.nn_budget,
            mc_lambda=cfg.strongsort.mc_lambda,
            ema_alpha=cfg.strongsort.ema_alpha,
        )
        return strongsort

    elif tracker_type == 'ocsort':
        from trackers.ocsort.ocsort import OCSort
        ocsort = OCSort(
            det_thresh=cfg.ocsort.det_thresh,
            max_age=cfg.ocsort.max_age,
            min_hits=cfg.ocsort.min_hits,
            iou_threshold=cfg.ocsort.iou_thresh,
            delta_t=cfg.ocsort.delta_t,
            asso_func=cfg.ocsort.asso_func,
            inertia=cfg.ocsort.inertia,
            use_byte=cfg.ocsort.use_byte,
        )
        return ocsort

    elif tracker_type == 'bytetrack':
        from trackers.bytetrack.byte_tracker import BYTETracker
        bytetracker = BYTETracker(
            track_thresh=cfg.bytetrack.track_thresh,
            match_thresh=cfg.bytetrack.match_thresh,
            track_buffer=cfg.bytetrack.track_buffer,
            frame_rate=cfg.bytetrack.frame_rate
        )
        return bytetracker
    else:
        print('No such tracker')
        exit()
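A sketch of how this factory is expected to be called from track.py (argument values are placeholders; the exact per-frame update call lives in track.py):

# Sketch of the expected call (values are placeholders, not part of the commit):
from pathlib import Path
import torch
from trackers.multi_tracker_zoo import create_tracker

tracker = create_tracker(
    'bytetrack',                                  # tracker_type
    'trackers/bytetrack/configs/bytetrack.yaml',  # tracker_config
    Path('osnet_x0_25_msmt17.pt'),                # reid_weights (only used by strongsort)
    torch.device('cpu'),                          # device
    False,                                        # half precision off
)
# track.py then feeds per-frame detections to the returned tracker's update() method.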
trackers/ocsort/association.py
ADDED
@@ -0,0 +1,377 @@
import os
import numpy as np


def iou_batch(bboxes1, bboxes2):
    """
    From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
    """
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
    yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
    xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
    yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
    w = np.maximum(0., xx2 - xx1)
    h = np.maximum(0., yy2 - yy1)
    wh = w * h
    o = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
        + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh)
    return o


def giou_batch(bboxes1, bboxes2):
    """
    :param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
    :param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
    :return:
    """
    # for details should go to https://arxiv.org/pdf/1902.09630.pdf
    # ensure predict's bbox form
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
    yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
    xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
    yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
    w = np.maximum(0., xx2 - xx1)
    h = np.maximum(0., yy2 - yy1)
    wh = w * h
    iou = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
        + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh)

    xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
    yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
    xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
    yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
    wc = xxc2 - xxc1
    hc = yyc2 - yyc1
    assert ((wc > 0).all() and (hc > 0).all())
    area_enclose = wc * hc
    giou = iou - (area_enclose - wh) / area_enclose
    giou = (giou + 1.) / 2.0  # resize from (-1,1) to (0,1)
    return giou


def diou_batch(bboxes1, bboxes2):
    """
    :param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
    :param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
    :return:
    """
    # for details should go to https://arxiv.org/pdf/1902.09630.pdf
    # ensure predict's bbox form
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    # calculate the intersection box
    xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
    yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
    xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
    yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
    w = np.maximum(0., xx2 - xx1)
    h = np.maximum(0., yy2 - yy1)
    wh = w * h
    iou = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
        + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh)

    centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
    centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
    centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
    centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0

    inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2

    xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
    yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
    xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
    yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])

    outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2
    diou = iou - inner_diag / outer_diag

    return (diou + 1) / 2.0  # resize from (-1,1) to (0,1)


def ciou_batch(bboxes1, bboxes2):
    """
    :param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
    :param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
    :return:
    """
    # for details should go to https://arxiv.org/pdf/1902.09630.pdf
    # ensure predict's bbox form
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    # calculate the intersection box
    xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
    yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
    xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
    yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
    w = np.maximum(0., xx2 - xx1)
    h = np.maximum(0., yy2 - yy1)
    wh = w * h
    iou = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
        + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh)

    centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
    centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
    centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
    centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0

    inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2

    xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
    yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
    xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
    yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])

    outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2

    w1 = bboxes1[..., 2] - bboxes1[..., 0]
    h1 = bboxes1[..., 3] - bboxes1[..., 1]
    w2 = bboxes2[..., 2] - bboxes2[..., 0]
    h2 = bboxes2[..., 3] - bboxes2[..., 1]

    # prevent dividing over zero. add one pixel shift
    h2 = h2 + 1.
    h1 = h1 + 1.
    arctan = np.arctan(w2 / h2) - np.arctan(w1 / h1)
    v = (4 / (np.pi ** 2)) * (arctan ** 2)
    S = 1 - iou
    alpha = v / (S + v)
    ciou = iou - inner_diag / outer_diag - alpha * v

    return (ciou + 1) / 2.0  # resize from (-1,1) to (0,1)


def ct_dist(bboxes1, bboxes2):
    """
    Measure the center distance between two sets of bounding boxes,
    this is a coarse implementation, we don't recommend using it only
    for association, which can be unstable and sensitive to frame rate
    and object speed.
    """
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
    centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
    centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
    centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0

    ct_dist2 = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2

    ct_dist = np.sqrt(ct_dist2)

    # The linear rescaling is a naive version and needs more study
    ct_dist = ct_dist / ct_dist.max()
    return ct_dist.max() - ct_dist  # resize to (0,1)


def speed_direction_batch(dets, tracks):
    tracks = tracks[..., np.newaxis]
    CX1, CY1 = (dets[:, 0] + dets[:, 2]) / 2.0, (dets[:, 1] + dets[:, 3]) / 2.0
    CX2, CY2 = (tracks[:, 0] + tracks[:, 2]) / 2.0, (tracks[:, 1] + tracks[:, 3]) / 2.0
    dx = CX1 - CX2
    dy = CY1 - CY2
    norm = np.sqrt(dx**2 + dy**2) + 1e-6
    dx = dx / norm
    dy = dy / norm
    return dy, dx  # size: num_track x num_det


def linear_assignment(cost_matrix):
    try:
        import lap
        _, x, y = lap.lapjv(cost_matrix, extend_cost=True)
        return np.array([[y[i], i] for i in x if i >= 0])
    except ImportError:
        from scipy.optimize import linear_sum_assignment
        x, y = linear_sum_assignment(cost_matrix)
        return np.array(list(zip(x, y)))


def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
    """
    Assigns detections to tracked objects (both represented as bounding boxes)
    Returns 3 lists of matches, unmatched_detections and unmatched_trackers
    """
    if (len(trackers) == 0):
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    iou_matrix = iou_batch(detections, trackers)

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(-iou_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = []
    for d, det in enumerate(detections):
        if (d not in matched_indices[:, 0]):
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if (t not in matched_indices[:, 1]):
            unmatched_trackers.append(t)

    # filter out matched with low IOU
    matches = []
    for m in matched_indices:
        if (iou_matrix[m[0], m[1]] < iou_threshold):
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if (len(matches) == 0):
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)


def associate(detections, trackers, iou_threshold, velocities, previous_obs, vdc_weight):
    if (len(trackers) == 0):
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    Y, X = speed_direction_batch(detections, previous_obs)
    inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
    inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
    inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
    diff_angle_cos = inertia_X * X + inertia_Y * Y
    diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
    diff_angle = np.arccos(diff_angle_cos)
    diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi

    valid_mask = np.ones(previous_obs.shape[0])
    valid_mask[np.where(previous_obs[:, 4] < 0)] = 0

    iou_matrix = iou_batch(detections, trackers)
    scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
    # iou_matrix = iou_matrix * scores  # a trick that sometimes works, we don't encourage this
    valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)

    angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
    angle_diff_cost = angle_diff_cost.T
    angle_diff_cost = angle_diff_cost * scores

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(-(iou_matrix + angle_diff_cost))
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = []
    for d, det in enumerate(detections):
        if (d not in matched_indices[:, 0]):
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if (t not in matched_indices[:, 1]):
            unmatched_trackers.append(t)

    # filter out matched with low IOU
    matches = []
    for m in matched_indices:
        if (iou_matrix[m[0], m[1]] < iou_threshold):
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if (len(matches) == 0):
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)


def associate_kitti(detections, trackers, det_cates, iou_threshold,
                    velocities, previous_obs, vdc_weight):
    if (len(trackers) == 0):
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    """
    Cost from the velocity direction consistency
    """
    Y, X = speed_direction_batch(detections, previous_obs)
    inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
    inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
    inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
    diff_angle_cos = inertia_X * X + inertia_Y * Y
    diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
    diff_angle = np.arccos(diff_angle_cos)
    diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi

    valid_mask = np.ones(previous_obs.shape[0])
    valid_mask[np.where(previous_obs[:, 4] < 0)] = 0
    valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)

    scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
    angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
    angle_diff_cost = angle_diff_cost.T
    angle_diff_cost = angle_diff_cost * scores

    """
    Cost from IoU
    """
    iou_matrix = iou_batch(detections, trackers)

    """
    With multiple categories, generate the cost for category mismatch
    """
    num_dets = detections.shape[0]
    num_trk = trackers.shape[0]
    cate_matrix = np.zeros((num_dets, num_trk))
    for i in range(num_dets):
        for j in range(num_trk):
            if det_cates[i] != trackers[j, 4]:
                cate_matrix[i][j] = -1e6

    cost_matrix = -iou_matrix - angle_diff_cost - cate_matrix

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(cost_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = []
    for d, det in enumerate(detections):
        if (d not in matched_indices[:, 0]):
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if (t not in matched_indices[:, 1]):
            unmatched_trackers.append(t)

    # filter out matched with low IOU
    matches = []
    for m in matched_indices:
        if (iou_matrix[m[0], m[1]] < iou_threshold):
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if (len(matches) == 0):
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
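A quick, illustrative check of the IoU-based association helpers above (not part of the commit):

# Illustrative sketch (not part of the commit): associate two detections with two trackers.
import numpy as np
from trackers.ocsort.association import iou_batch, associate_detections_to_trackers

dets = np.array([[0, 0, 10, 10], [20, 20, 30, 30]], dtype=float)       # [x1, y1, x2, y2]
trks = np.array([[1, 1, 11, 11], [100, 100, 120, 120]], dtype=float)

print(iou_batch(dets, trks))  # 2x2 IoU matrix; only the first detection/tracker pair overlaps
m, unmatched_d, unmatched_t = associate_detections_to_trackers(dets, trks, iou_threshold=0.3)
# m -> [[0, 0]]; detection 1 and tracker 1 stay unmatched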
trackers/ocsort/configs/ocsort.yaml
ADDED
@@ -0,0 +1,12 @@
# Trial number: 137
# HOTA, MOTA, IDF1: [55.567]
ocsort:
  asso_func: giou
  conf_thres: 0.5122620708221085
  delta_t: 1
  det_thresh: 0
  inertia: 0.3941737016672115
  iou_thresh: 0.22136877277096445
  max_age: 50
  min_hits: 1
  use_byte: false
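These keys map onto the OCSort constructor exactly as in trackers/multi_tracker_zoo.py above; an illustrative load (not part of the commit):

# Illustrative config load, mirroring create_tracker() in trackers/multi_tracker_zoo.py.
from trackers.strongsort.utils.parser import get_config
from trackers.ocsort.ocsort import OCSort

cfg = get_config()
cfg.merge_from_file('trackers/ocsort/configs/ocsort.yaml')
ocsort = OCSort(
    det_thresh=cfg.ocsort.det_thresh,
    max_age=cfg.ocsort.max_age,
    min_hits=cfg.ocsort.min_hits,
    iou_threshold=cfg.ocsort.iou_thresh,
    delta_t=cfg.ocsort.delta_t,
    asso_func=cfg.ocsort.asso_func,   # 'giou' selects giou_batch from association.py
    inertia=cfg.ocsort.inertia,
    use_byte=cfg.ocsort.use_byte,
)
# Note: conf_thres is stored in the YAML but is not consumed by create_tracker above.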
trackers/ocsort/kalmanfilter.py
ADDED
@@ -0,0 +1,1581 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# pylint: disable=invalid-name, too-many-arguments, too-many-branches,
|
3 |
+
# pylint: disable=too-many-locals, too-many-instance-attributes, too-many-lines
|
4 |
+
|
5 |
+
"""
|
6 |
+
This module implements the linear Kalman filter in both an object
|
7 |
+
oriented and procedural form. The KalmanFilter class implements
|
8 |
+
the filter by storing the various matrices in instance variables,
|
9 |
+
minimizing the amount of bookkeeping you have to do.
|
10 |
+
All Kalman filters operate with a predict->update cycle. The
|
11 |
+
predict step, implemented with the method or function predict(),
|
12 |
+
uses the state transition matrix F to predict the state in the next
|
13 |
+
time period (epoch). The state is stored as a gaussian (x, P), where
|
14 |
+
x is the state (column) vector, and P is its covariance. Covariance
|
15 |
+
matrix Q specifies the process covariance. In Bayesian terms, this
|
16 |
+
prediction is called the *prior*, which you can think of colloquially
|
17 |
+
as the estimate prior to incorporating the measurement.
|
18 |
+
The update step, implemented with the method or function `update()`,
|
19 |
+
incorporates the measurement z with covariance R, into the state
|
20 |
+
estimate (x, P). The class stores the system uncertainty in S,
|
21 |
+
the innovation (residual between prediction and measurement in
|
22 |
+
measurement space) in y, and the Kalman gain in k. The procedural
|
23 |
+
form returns these variables to you. In Bayesian terms this computes
|
24 |
+
the *posterior* - the estimate after the information from the
|
25 |
+
measurement is incorporated.
|
26 |
+
Whether you use the OO form or procedural form is up to you. If
|
27 |
+
matrices such as H, R, and F are changing each epoch, you'll probably
|
28 |
+
opt to use the procedural form. If they are unchanging, the OO
|
29 |
+
form is perhaps easier to use since you won't need to keep track
|
30 |
+
of these matrices. This is especially useful if you are implementing
|
31 |
+
banks of filters or comparing various KF designs for performance;
|
32 |
+
a trivial coding bug could lead to using the wrong sets of matrices.
|
33 |
+
This module also offers an implementation of the RTS smoother, and
|
34 |
+
other helper functions, such as log likelihood computations.
|
35 |
+
The Saver class allows you to easily save the state of the
|
36 |
+
KalmanFilter class after every update
|
37 |
+
This module expects NumPy arrays for all values that expect
|
38 |
+
arrays, although in a few cases, particularly method parameters,
|
39 |
+
it will accept types that convert to NumPy arrays, such as lists
|
40 |
+
of lists. These exceptions are documented in the method or function.
|
41 |
+
Examples
|
42 |
+
--------
|
43 |
+
The following example constructs a constant velocity kinematic
|
44 |
+
filter, filters noisy data, and plots the results. It also demonstrates
|
45 |
+
using the Saver class to save the state of the filter at each epoch.
|
46 |
+
.. code-block:: Python
|
47 |
+
import matplotlib.pyplot as plt
|
48 |
+
import numpy as np
|
49 |
+
from filterpy.kalman import KalmanFilter
|
50 |
+
from filterpy.common import Q_discrete_white_noise, Saver
|
51 |
+
r_std, q_std = 2., 0.003
|
52 |
+
cv = KalmanFilter(dim_x=2, dim_z=1)
|
53 |
+
cv.x = np.array([[0., 1.]]) # position, velocity
|
54 |
+
cv.F = np.array([[1, dt],[ [0, 1]])
|
55 |
+
cv.R = np.array([[r_std^^2]])
|
56 |
+
f.H = np.array([[1., 0.]])
|
57 |
+
f.P = np.diag([.1^^2, .03^^2)
|
58 |
+
f.Q = Q_discrete_white_noise(2, dt, q_std**2)
|
59 |
+
saver = Saver(cv)
|
60 |
+
for z in range(100):
|
61 |
+
cv.predict()
|
62 |
+
cv.update([z + randn() * r_std])
|
63 |
+
saver.save() # save the filter's state
|
64 |
+
saver.to_array()
|
65 |
+
plt.plot(saver.x[:, 0])
|
66 |
+
# plot all of the priors
|
67 |
+
plt.plot(saver.x_prior[:, 0])
|
68 |
+
# plot mahalanobis distance
|
69 |
+
plt.figure()
|
70 |
+
plt.plot(saver.mahalanobis)
|
71 |
+
This code implements the same filter using the procedural form
|
72 |
+
x = np.array([[0., 1.]]) # position, velocity
|
73 |
+
F = np.array([[1, dt],[ [0, 1]])
|
74 |
+
R = np.array([[r_std^^2]])
|
75 |
+
H = np.array([[1., 0.]])
|
76 |
+
P = np.diag([.1^^2, .03^^2)
|
77 |
+
Q = Q_discrete_white_noise(2, dt, q_std**2)
|
78 |
+
for z in range(100):
|
79 |
+
x, P = predict(x, P, F=F, Q=Q)
|
80 |
+
x, P = update(x, P, z=[z + randn() * r_std], R=R, H=H)
|
81 |
+
xs.append(x[0, 0])
|
82 |
+
plt.plot(xs)
|
83 |
+
For more examples see the test subdirectory, or refer to the
|
84 |
+
book cited below. In it I both teach Kalman filtering from basic
|
85 |
+
principles, and teach the use of this library in great detail.
|
86 |
+
FilterPy library.
|
87 |
+
http://github.com/rlabbe/filterpy
|
88 |
+
Documentation at:
|
89 |
+
https://filterpy.readthedocs.org
|
90 |
+
Supporting book at:
|
91 |
+
https://github.com/rlabbe/Kalman-and-Bayesian-Filters-in-Python
|
92 |
+
This is licensed under an MIT license. See the readme.MD file
|
93 |
+
for more information.
|
94 |
+
Copyright 2014-2018 Roger R Labbe Jr.
|
95 |
+
"""
|
96 |
+
|
97 |
+
from __future__ import absolute_import, division
|
98 |
+
|
99 |
+
from copy import deepcopy
|
100 |
+
from math import log, exp, sqrt
|
101 |
+
import sys
|
102 |
+
import numpy as np
|
103 |
+
from numpy import dot, zeros, eye, isscalar, shape
|
104 |
+
import numpy.linalg as linalg
|
105 |
+
from filterpy.stats import logpdf
|
106 |
+
from filterpy.common import pretty_str, reshape_z
|
107 |
+
|
108 |
+
|
109 |
+
class KalmanFilterNew(object):
|
110 |
+
""" Implements a Kalman filter. You are responsible for setting the
|
111 |
+
various state variables to reasonable values; the defaults will
|
112 |
+
not give you a functional filter.
|
113 |
+
For now the best documentation is my free book Kalman and Bayesian
|
114 |
+
Filters in Python [2]_. The test files in this directory also give you a
|
115 |
+
basic idea of use, albeit without much description.
|
116 |
+
In brief, you will first construct this object, specifying the size of
|
117 |
+
the state vector with dim_x and the size of the measurement vector that
|
118 |
+
you will be using with dim_z. These are mostly used to perform size checks
|
119 |
+
when you assign values to the various matrices. For example, if you
|
120 |
+
specified dim_z=2 and then try to assign a 3x3 matrix to R (the
|
121 |
+
measurement noise matrix you will get an assert exception because R
|
122 |
+
should be 2x2. (If for whatever reason you need to alter the size of
|
123 |
+
things midstream just use the underscore version of the matrices to
|
124 |
+
assign directly: your_filter._R = a_3x3_matrix.)
|
125 |
+
After construction the filter will have default matrices created for you,
|
126 |
+
but you must specify the values for each. It’s usually easiest to just
|
127 |
+
overwrite them rather than assign to each element yourself. This will be
|
128 |
+
clearer in the example below. All are of type numpy.array.
|
129 |
+
Examples
|
130 |
+
--------
|
131 |
+
Here is a filter that tracks position and velocity using a sensor that only
|
132 |
+
reads position.
|
133 |
+
First construct the object with the required dimensionality. Here the state
|
134 |
+
(`dim_x`) has 2 coefficients (position and velocity), and the measurement
|
135 |
+
(`dim_z`) has one. In FilterPy `x` is the state, `z` is the measurement.
|
136 |
+
.. code::
|
137 |
+
from filterpy.kalman import KalmanFilter
|
138 |
+
f = KalmanFilter (dim_x=2, dim_z=1)
|
139 |
+
Assign the initial value for the state (position and velocity). You can do this
|
140 |
+
with a two dimensional array like so:
|
141 |
+
.. code::
|
142 |
+
f.x = np.array([[2.], # position
|
143 |
+
[0.]]) # velocity
|
144 |
+
or just use a one dimensional array, which I prefer doing.
|
145 |
+
.. code::
|
146 |
+
f.x = np.array([2., 0.])
|
147 |
+
Define the state transition matrix:
|
148 |
+
.. code::
|
149 |
+
f.F = np.array([[1.,1.],
|
150 |
+
[0.,1.]])
|
151 |
+
Define the measurement function. Here we need to convert a position-velocity
|
152 |
+
vector into just a position vector, so we use:
|
153 |
+
.. code::
|
154 |
+
f.H = np.array([[1., 0.]])
|
155 |
+
Define the state's covariance matrix P.
|
156 |
+
.. code::
|
157 |
+
f.P = np.array([[1000., 0.],
|
158 |
+
[ 0., 1000.] ])
|
159 |
+
Now assign the measurement noise. Here the dimension is 1x1, so I can
|
160 |
+
use a scalar
|
161 |
+
.. code::
|
162 |
+
f.R = 5
|
163 |
+
I could have done this instead:
|
164 |
+
.. code::
|
165 |
+
f.R = np.array([[5.]])
|
166 |
+
Note that this must be a 2 dimensional array.
|
167 |
+
Finally, I will assign the process noise. Here I will take advantage of
|
168 |
+
another FilterPy library function:
|
169 |
+
.. code::
|
170 |
+
from filterpy.common import Q_discrete_white_noise
|
171 |
+
f.Q = Q_discrete_white_noise(dim=2, dt=0.1, var=0.13)
|
172 |
+
Now just perform the standard predict/update loop:
|
173 |
+
.. code::
|
174 |
+
while some_condition_is_true:
|
175 |
+
z = get_sensor_reading()
|
176 |
+
f.predict()
|
177 |
+
f.update(z)
|
178 |
+
do_something_with_estimate (f.x)
|
179 |
+
**Procedural Form**
|
180 |
+
This module also contains stand alone functions to perform Kalman filtering.
|
181 |
+
Use these if you are not a fan of objects.
|
182 |
+
**Example**
|
183 |
+
.. code::
|
184 |
+
while True:
|
185 |
+
z, R = read_sensor()
|
186 |
+
x, P = predict(x, P, F, Q)
|
187 |
+
x, P = update(x, P, z, R, H)
|
188 |
+
See my book Kalman and Bayesian Filters in Python [2]_.
|
189 |
+
You will have to set the following attributes after constructing this
|
190 |
+
object for the filter to perform properly. Please note that there are
|
191 |
+
various checks in place to ensure that you have made everything the
|
192 |
+
'correct' size. However, it is possible to provide incorrectly sized
|
193 |
+
arrays such that the linear algebra can not perform an operation.
|
194 |
+
It can also fail silently - you can end up with matrices of a size that
|
195 |
+
allows the linear algebra to work, but are the wrong shape for the problem
|
196 |
+
you are trying to solve.
|
197 |
+
Parameters
|
198 |
+
----------
|
199 |
+
dim_x : int
|
200 |
+
Number of state variables for the Kalman filter. For example, if
|
201 |
+
you are tracking the position and velocity of an object in two
|
202 |
+
dimensions, dim_x would be 4.
|
203 |
+
This is used to set the default size of P, Q, and u
|
204 |
+
dim_z : int
|
205 |
+
Number of of measurement inputs. For example, if the sensor
|
206 |
+
provides you with position in (x,y), dim_z would be 2.
|
207 |
+
dim_u : int (optional)
|
208 |
+
size of the control input, if it is being used.
|
209 |
+
Default value of 0 indicates it is not used.
|
210 |
+
compute_log_likelihood : bool (default = True)
|
211 |
+
Computes log likelihood by default, but this can be a slow
|
212 |
+
computation, so if you never use it you can turn this computation
|
213 |
+
off.
|
214 |
+
Attributes
|
215 |
+
----------
|
216 |
+
x : numpy.array(dim_x, 1)
|
217 |
+
Current state estimate. Any call to update() or predict() updates
|
218 |
+
this variable.
|
219 |
+
P : numpy.array(dim_x, dim_x)
|
220 |
+
Current state covariance matrix. Any call to update() or predict()
|
221 |
+
updates this variable.
|
222 |
+
x_prior : numpy.array(dim_x, 1)
|
223 |
+
Prior (predicted) state estimate. The *_prior and *_post attributes
|
224 |
+
are for convenience; they store the prior and posterior of the
|
225 |
+
current epoch. Read Only.
|
226 |
+
P_prior : numpy.array(dim_x, dim_x)
|
227 |
+
Prior (predicted) state covariance matrix. Read Only.
|
228 |
+
x_post : numpy.array(dim_x, 1)
|
229 |
+
Posterior (updated) state estimate. Read Only.
|
230 |
+
P_post : numpy.array(dim_x, dim_x)
|
231 |
+
Posterior (updated) state covariance matrix. Read Only.
|
232 |
+
z : numpy.array
|
233 |
+
Last measurement used in update(). Read only.
|
234 |
+
R : numpy.array(dim_z, dim_z)
|
235 |
+
Measurement noise covariance matrix. Also known as the
|
236 |
+
observation covariance.
|
237 |
+
Q : numpy.array(dim_x, dim_x)
|
238 |
+
Process noise covariance matrix. Also known as the transition
|
239 |
+
covariance.
|
240 |
+
F : numpy.array()
|
241 |
+
State Transition matrix. Also known as `A` in some formulation.
|
242 |
+
H : numpy.array(dim_z, dim_x)
|
243 |
+
Measurement function. Also known as the observation matrix, or as `C`.
|
244 |
+
y : numpy.array
|
245 |
+
Residual of the update step. Read only.
|
246 |
+
K : numpy.array(dim_x, dim_z)
|
247 |
+
Kalman gain of the update step. Read only.
|
248 |
+
S : numpy.array
|
249 |
+
System uncertainty (P projected to measurement space). Read only.
|
250 |
+
SI : numpy.array
|
251 |
+
Inverse system uncertainty. Read only.
|
252 |
+
log_likelihood : float
|
253 |
+
log-likelihood of the last measurement. Read only.
|
254 |
+
likelihood : float
|
255 |
+
likelihood of last measurement. Read only.
|
256 |
+
Computed from the log-likelihood. The log-likelihood can be very
|
257 |
+
small, meaning a large negative value such as -28000. Taking the
|
258 |
+
exp() of that results in 0.0, which can break typical algorithms
|
259 |
+
which multiply by this value, so by default we always return a
|
260 |
+
number >= sys.float_info.min.
|
261 |
+
mahalanobis : float
|
262 |
+
mahalanobis distance of the innovation. Read only.
|
263 |
+
inv : function, default numpy.linalg.inv
|
264 |
+
If you prefer another inverse function, such as the Moore-Penrose
|
265 |
+
pseudo inverse, set it to that instead: kf.inv = np.linalg.pinv
|
266 |
+
This is only used to invert self.S. If you know it is diagonal, you
|
267 |
+
might choose to set it to filterpy.common.inv_diagonal, which is
|
268 |
+
several times faster than numpy.linalg.inv for diagonal matrices.
|
269 |
+
alpha : float
|
270 |
+
Fading memory setting. 1.0 gives the normal Kalman filter, and
|
271 |
+
values slightly larger than 1.0 (such as 1.02) give a fading
|
272 |
+
memory effect - previous measurements have less influence on the
|
273 |
+
filter's estimates. This formulation of the Fading memory filter
|
274 |
+
(there are many) is due to Dan Simon [1]_.
|
275 |
+
References
|
276 |
+
----------
|
277 |
+
.. [1] Dan Simon. "Optimal State Estimation." John Wiley & Sons.
|
278 |
+
p. 208-212. (2006)
|
279 |
+
.. [2] Roger Labbe. "Kalman and Bayesian Filters in Python"
|
280 |
+
https://github.com/rlabbe/Kalman-and-Bayesian-Filters-in-Python
|
281 |
+
"""
|
282 |
+
|
283 |
+
def __init__(self, dim_x, dim_z, dim_u=0):
|
284 |
+
if dim_x < 1:
|
285 |
+
raise ValueError('dim_x must be 1 or greater')
|
286 |
+
if dim_z < 1:
|
287 |
+
raise ValueError('dim_z must be 1 or greater')
|
288 |
+
if dim_u < 0:
|
289 |
+
raise ValueError('dim_u must be 0 or greater')
|
290 |
+
|
291 |
+
self.dim_x = dim_x
|
292 |
+
self.dim_z = dim_z
|
293 |
+
self.dim_u = dim_u
|
294 |
+
|
295 |
+
self.x = zeros((dim_x, 1)) # state
|
296 |
+
self.P = eye(dim_x) # uncertainty covariance
|
297 |
+
self.Q = eye(dim_x) # process uncertainty
|
298 |
+
self.B = None # control transition matrix
|
299 |
+
self.F = eye(dim_x) # state transition matrix
|
300 |
+
self.H = zeros((dim_z, dim_x)) # measurement function
|
301 |
+
self.R = eye(dim_z) # measurement uncertainty
|
302 |
+
self._alpha_sq = 1. # fading memory control
|
303 |
+
self.M = np.zeros((dim_x, dim_z)) # process-measurement cross correlation
|
304 |
+
self.z = np.array([[None]*self.dim_z]).T
|
305 |
+
|
306 |
+
# gain and residual are computed during the innovation step. We
|
307 |
+
# save them so that in case you want to inspect them for various
|
308 |
+
# purposes
|
309 |
+
self.K = np.zeros((dim_x, dim_z)) # kalman gain
|
310 |
+
self.y = zeros((dim_z, 1))
|
311 |
+
self.S = np.zeros((dim_z, dim_z)) # system uncertainty
|
312 |
+
self.SI = np.zeros((dim_z, dim_z)) # inverse system uncertainty
|
313 |
+
|
314 |
+
# identity matrix. Do not alter this.
|
315 |
+
self._I = np.eye(dim_x)
|
316 |
+
|
317 |
+
# these will always be a copy of x,P after predict() is called
|
318 |
+
self.x_prior = self.x.copy()
|
319 |
+
self.P_prior = self.P.copy()
|
320 |
+
|
321 |
+
# these will always be a copy of x,P after update() is called
|
322 |
+
self.x_post = self.x.copy()
|
323 |
+
self.P_post = self.P.copy()
|
324 |
+
|
325 |
+
# Only computed only if requested via property
|
326 |
+
self._log_likelihood = log(sys.float_info.min)
|
327 |
+
self._likelihood = sys.float_info.min
|
328 |
+
self._mahalanobis = None
|
329 |
+
|
330 |
+
# keep all observations
|
331 |
+
self.history_obs = []
|
332 |
+
|
333 |
+
self.inv = np.linalg.inv
|
334 |
+
|
335 |
+
self.attr_saved = None
|
336 |
+
self.observed = False
|
337 |
+
|
338 |
+
|
339 |
+
def predict(self, u=None, B=None, F=None, Q=None):
|
340 |
+
"""
|
341 |
+
Predict next state (prior) using the Kalman filter state propagation
|
342 |
+
equations.
|
343 |
+
Parameters
|
344 |
+
----------
|
345 |
+
u : np.array, default 0
|
346 |
+
Optional control vector.
|
347 |
+
B : np.array(dim_x, dim_u), or None
|
348 |
+
Optional control transition matrix; a value of None
|
349 |
+
will cause the filter to use `self.B`.
|
350 |
+
F : np.array(dim_x, dim_x), or None
|
351 |
+
Optional state transition matrix; a value of None
|
352 |
+
will cause the filter to use `self.F`.
|
353 |
+
Q : np.array(dim_x, dim_x), scalar, or None
|
354 |
+
Optional process noise matrix; a value of None will cause the
|
355 |
+
filter to use `self.Q`.
|
356 |
+
"""
|
357 |
+
|
358 |
+
if B is None:
|
359 |
+
B = self.B
|
360 |
+
if F is None:
|
361 |
+
F = self.F
|
362 |
+
if Q is None:
|
363 |
+
Q = self.Q
|
364 |
+
elif isscalar(Q):
|
365 |
+
Q = eye(self.dim_x) * Q
|
366 |
+
|
367 |
+
|
368 |
+
# x = Fx + Bu
|
369 |
+
if B is not None and u is not None:
|
370 |
+
self.x = dot(F, self.x) + dot(B, u)
|
371 |
+
else:
|
372 |
+
self.x = dot(F, self.x)
|
373 |
+
|
374 |
+
# P = FPF' + Q
|
375 |
+
self.P = self._alpha_sq * dot(dot(F, self.P), F.T) + Q
|
376 |
+
|
377 |
+
# save prior
|
378 |
+
self.x_prior = self.x.copy()
|
379 |
+
self.P_prior = self.P.copy()
|
380 |
+
|
381 |
+
|
382 |
+
|
383 |
+
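# Usage sketch (illustrative comment only; `kf` and `dt` are hypothetical names):
# how predict() is typically configured for a 1-D constant-velocity model,
# mirroring the x = Fx + Bu and P = FPF' + Q equations implemented above.
#
#   kf = KalmanFilterNew(dim_x=2, dim_z=1)
#   dt = 1.0
#   kf.F = np.array([[1., dt],
#                    [0., 1.]])     # position advances by velocity * dt
#   kf.H = np.array([[1., 0.]])     # only position is measured
#   kf.Q = np.eye(2) * 0.01         # small process noise
#   kf.predict()                    # kf.x_prior / kf.P_prior now hold the forecast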
def freeze(self):
|
384 |
+
"""
|
385 |
+
Save the current parameters before running forward without observations
|
386 |
+
"""
|
387 |
+
self.attr_saved = deepcopy(self.__dict__)
|
388 |
+
|
389 |
+
|
390 |
+
def unfreeze(self):
|
391 |
+
if self.attr_saved is not None:
|
392 |
+
new_history = deepcopy(self.history_obs)
|
393 |
+
self.__dict__ = self.attr_saved
|
394 |
+
# self.history_obs = new_history
|
395 |
+
self.history_obs = self.history_obs[:-1]
|
396 |
+
occur = [int(d is None) for d in new_history]
|
397 |
+
indices = np.where(np.array(occur)==0)[0]
|
398 |
+
index1 = indices[-2]
|
399 |
+
index2 = indices[-1]
|
400 |
+
box1 = new_history[index1]
|
401 |
+
x1, y1, s1, r1 = box1
|
402 |
+
w1 = np.sqrt(s1 * r1)
|
403 |
+
h1 = np.sqrt(s1 / r1)
|
404 |
+
box2 = new_history[index2]
|
405 |
+
x2, y2, s2, r2 = box2
|
406 |
+
w2 = np.sqrt(s2 * r2)
|
407 |
+
h2 = np.sqrt(s2 / r2)
|
408 |
+
time_gap = index2 - index1
|
409 |
+
dx = (x2-x1)/time_gap
|
410 |
+
dy = (y2-y1)/time_gap
|
411 |
+
dw = (w2-w1)/time_gap
|
412 |
+
dh = (h2-h1)/time_gap
|
413 |
+
for i in range(index2 - index1):
|
414 |
+
"""
|
415 |
+
The default virtual trajectory generation is by linear
|
416 |
+
motion (constant speed hypothesis), you could modify this
|
417 |
+
part to implement your own.
|
418 |
+
"""
|
419 |
+
x = x1 + (i+1) * dx
|
420 |
+
y = y1 + (i+1) * dy
|
421 |
+
w = w1 + (i+1) * dw
|
422 |
+
h = h1 + (i+1) * dh
|
423 |
+
s = w * h
|
424 |
+
r = w / float(h)
|
425 |
+
new_box = np.array([x, y, s, r]).reshape((4, 1))
|
426 |
+
"""
|
427 |
+
I still use predict-update loop here to refresh the parameters,
|
428 |
+
but this can be faster by directly modifying the internal parameters
|
429 |
+
as suggested in the paper. I keep this naive but slow way for
|
430 |
+
easy read and understanding
|
431 |
+
"""
|
432 |
+
self.update(new_box)
|
433 |
+
if not i == (index2-index1-1):
|
434 |
+
self.predict()
|
435 |
+
|
436 |
+
|
437 |
+
def update(self, z, R=None, H=None):
|
438 |
+
"""
|
439 |
+
Add a new measurement (z) to the Kalman filter.
|
440 |
+
If z is None, nothing is computed. However, x_post and P_post are
|
441 |
+
updated with the prior (x_prior, P_prior), and self.z is set to None.
|
442 |
+
Parameters
|
443 |
+
----------
|
444 |
+
z : (dim_z, 1): array_like
|
445 |
+
measurement for this update. z can be a scalar if dim_z is 1,
|
446 |
+
otherwise it must be convertible to a column vector.
|
447 |
+
If you pass in a value of H, z must be a column vector
|
448 |
+
of the correct size.
|
449 |
+
R : np.array, scalar, or None
|
450 |
+
Optionally provide R to override the measurement noise for this
|
451 |
+
one call, otherwise self.R will be used.
|
452 |
+
H : np.array, or None
|
453 |
+
Optionally provide H to override the measurement function for this
|
454 |
+
one call, otherwise self.H will be used.
|
455 |
+
"""
|
456 |
+
|
457 |
+
# set to None to force recompute
|
458 |
+
self._log_likelihood = None
|
459 |
+
self._likelihood = None
|
460 |
+
self._mahalanobis = None
|
461 |
+
|
462 |
+
# append the observation
|
463 |
+
self.history_obs.append(z)
|
464 |
+
|
465 |
+
if z is None:
|
466 |
+
if self.observed:
|
467 |
+
"""
|
468 |
+
Got no observation so freeze the current parameters for future
|
469 |
+
potential online smoothing.
|
470 |
+
"""
|
471 |
+
self.freeze()
|
472 |
+
self.observed = False
|
473 |
+
self.z = np.array([[None]*self.dim_z]).T
|
474 |
+
self.x_post = self.x.copy()
|
475 |
+
self.P_post = self.P.copy()
|
476 |
+
self.y = zeros((self.dim_z, 1))
|
477 |
+
return
|
478 |
+
|
479 |
+
# self.observed = True
|
480 |
+
if not self.observed:
|
481 |
+
"""
|
482 |
+
Get observation, use online smoothing to re-update parameters
|
483 |
+
"""
|
484 |
+
self.unfreeze()
|
485 |
+
self.observed = True
|
486 |
+
|
487 |
+
if R is None:
|
488 |
+
R = self.R
|
489 |
+
elif isscalar(R):
|
490 |
+
R = eye(self.dim_z) * R
|
491 |
+
|
492 |
+
if H is None:
|
493 |
+
z = reshape_z(z, self.dim_z, self.x.ndim)
|
494 |
+
H = self.H
|
495 |
+
|
496 |
+
# y = z - Hx
|
497 |
+
# error (residual) between measurement and prediction
|
498 |
+
self.y = z - dot(H, self.x)
|
499 |
+
|
500 |
+
# common subexpression for speed
|
501 |
+
PHT = dot(self.P, H.T)
|
502 |
+
|
503 |
+
# S = HPH' + R
|
504 |
+
# project system uncertainty into measurement space
|
505 |
+
self.S = dot(H, PHT) + R
|
506 |
+
self.SI = self.inv(self.S)
|
507 |
+
# K = PH'inv(S)
|
508 |
+
# map system uncertainty into kalman gain
|
509 |
+
self.K = dot(PHT, self.SI)
|
510 |
+
|
511 |
+
# x = x + Ky
|
512 |
+
# predict new x with residual scaled by the kalman gain
|
513 |
+
self.x = self.x + dot(self.K, self.y)
|
514 |
+
|
515 |
+
# P = (I-KH)P(I-KH)' + KRK'
|
516 |
+
# This is more numerically stable
|
517 |
+
# and works for non-optimal K vs the equation
|
518 |
+
# P = (I-KH)P usually seen in the literature.
|
519 |
+
|
520 |
+
I_KH = self._I - dot(self.K, H)
|
521 |
+
self.P = dot(dot(I_KH, self.P), I_KH.T) + dot(dot(self.K, R), self.K.T)
|
522 |
+
|
523 |
+
# save measurement and posterior state
|
524 |
+
self.z = deepcopy(z)
|
525 |
+
self.x_post = self.x.copy()
|
526 |
+
self.P_post = self.P.copy()
|
527 |
+
|
528 |
+
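# Usage sketch (illustrative comment only; z1, z2, z4 are hypothetical detections):
# how update(None) interacts with freeze()/unfreeze() across an occlusion gap.
# The first missed frame freezes the saved state, and the first real observation
# after the gap replays a linear virtual trajectory through unfreeze().
#
#   kf.predict(); kf.update(z1)      # observed frame
#   kf.predict(); kf.update(z2)      # observed frame
#   kf.predict(); kf.update(None)    # missed frame -> freeze() is called
#   kf.predict(); kf.update(None)    # still occluded, filter keeps predicting
#   kf.predict(); kf.update(z4)      # re-observed -> unfreeze() re-runs the gap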
def predict_steadystate(self, u=0, B=None):
|
529 |
+
"""
|
530 |
+
Predict state (prior) using the Kalman filter state propagation
|
531 |
+
equations. Only x is updated, P is left unchanged. See
|
532 |
+
update_steadystate() for a longer explanation of when to use this
|
533 |
+
method.
|
534 |
+
Parameters
|
535 |
+
----------
|
536 |
+
u : np.array
|
537 |
+
Optional control vector. If non-zero, it is multiplied by B
|
538 |
+
to create the control input into the system.
|
539 |
+
B : np.array(dim_x, dim_u), or None
|
540 |
+
Optional control transition matrix; a value of None
|
541 |
+
will cause the filter to use `self.B`.
|
542 |
+
"""
|
543 |
+
|
544 |
+
if B is None:
|
545 |
+
B = self.B
|
546 |
+
|
547 |
+
# x = Fx + Bu
|
548 |
+
if B is not None:
|
549 |
+
self.x = dot(self.F, self.x) + dot(B, u)
|
550 |
+
else:
|
551 |
+
self.x = dot(self.F, self.x)
|
552 |
+
|
553 |
+
# save prior
|
554 |
+
self.x_prior = self.x.copy()
|
555 |
+
self.P_prior = self.P.copy()
|
556 |
+
|
557 |
+
def update_steadystate(self, z):
|
558 |
+
"""
|
559 |
+
Add a new measurement (z) to the Kalman filter without recomputing
|
560 |
+
the Kalman gain K, the state covariance P, or the system
|
561 |
+
uncertainty S.
|
562 |
+
You can use this for LTI systems since the Kalman gain and covariance
|
563 |
+
converge to a fixed value. Precompute these and assign them explicitly,
|
564 |
+
or run the Kalman filter using the normal predict()/update() cycle
|
565 |
+
until they converge.
|
566 |
+
The main advantage of this call is speed. We do significantly less
|
567 |
+
computation, notably avoiding a costly matrix inversion.
|
568 |
+
Use in conjunction with predict_steadystate(), otherwise P will grow
|
569 |
+
without bound.
|
570 |
+
Parameters
|
571 |
+
----------
|
572 |
+
z : (dim_z, 1): array_like
|
573 |
+
measurement for this update. z can be a scalar if dim_z is 1,
|
574 |
+
otherwise it must be convertible to a column vector.
|
575 |
+
Examples
|
576 |
+
--------
|
577 |
+
>>> cv = kinematic_kf(dim=3, order=2) # 3D const velocity filter
|
578 |
+
>>> # let filter converge on representative data, then save k and P
|
579 |
+
>>> for i in range(100):
|
580 |
+
>>> cv.predict()
|
581 |
+
>>> cv.update([i, i, i])
|
582 |
+
>>> saved_K = np.copy(cv.K)
|
583 |
+
>>> saved_P = np.copy(cv.P)
|
584 |
+
later on:
|
585 |
+
>>> cv = kinematic_kf(dim=3, order=2) # 3D const velocity filter
|
586 |
+
>>> cv.K = np.copy(saved_K)
|
587 |
+
>>> cv.P = np.copy(saved_P)
|
588 |
+
>>> for i in range(100):
|
589 |
+
>>> cv.predict_steadystate()
|
590 |
+
>>> cv.update_steadystate([i, i, i])
|
591 |
+
"""
|
592 |
+
|
593 |
+
# set to None to force recompute
|
594 |
+
self._log_likelihood = None
|
595 |
+
self._likelihood = None
|
596 |
+
self._mahalanobis = None
|
597 |
+
|
598 |
+
if z is None:
|
599 |
+
self.z = np.array([[None]*self.dim_z]).T
|
600 |
+
self.x_post = self.x.copy()
|
601 |
+
self.P_post = self.P.copy()
|
602 |
+
self.y = zeros((self.dim_z, 1))
|
603 |
+
return
|
604 |
+
|
605 |
+
z = reshape_z(z, self.dim_z, self.x.ndim)
|
606 |
+
|
607 |
+
# y = z - Hx
|
608 |
+
# error (residual) between measurement and prediction
|
609 |
+
self.y = z - dot(self.H, self.x)
|
610 |
+
|
611 |
+
# x = x + Ky
|
612 |
+
# predict new x with residual scaled by the kalman gain
|
613 |
+
self.x = self.x + dot(self.K, self.y)
|
614 |
+
|
615 |
+
self.z = deepcopy(z)
|
616 |
+
self.x_post = self.x.copy()
|
617 |
+
self.P_post = self.P.copy()
|
618 |
+
|
619 |
+
# set to None to force recompute
|
620 |
+
self._log_likelihood = None
|
621 |
+
self._likelihood = None
|
622 |
+
self._mahalanobis = None
|
623 |
+
|
624 |
+
def update_correlated(self, z, R=None, H=None):
|
625 |
+
""" Add a new measurement (z) to the Kalman filter assuming that
|
626 |
+
process noise and measurement noise are correlated as defined in
|
627 |
+
the `self.M` matrix.
|
628 |
+
A partial derivation can be found in [1]
|
629 |
+
If z is None, nothing is changed.
|
630 |
+
Parameters
|
631 |
+
----------
|
632 |
+
z : (dim_z, 1): array_like
|
633 |
+
measurement for this update. z can be a scalar if dim_z is 1,
|
634 |
+
otherwise it must be convertible to a column vector.
|
635 |
+
R : np.array, scalar, or None
|
636 |
+
Optionally provide R to override the measurement noise for this
|
637 |
+
one call, otherwise self.R will be used.
|
638 |
+
H : np.array, or None
|
639 |
+
Optionally provide H to override the measurement function for this
|
640 |
+
one call, otherwise self.H will be used.
|
641 |
+
References
|
642 |
+
----------
|
643 |
+
.. [1] Bulut, Y. (2011). Applied Kalman filter theory (Doctoral dissertation, Northeastern University).
|
644 |
+
http://people.duke.edu/~hpgavin/SystemID/References/Balut-KalmanFilter-PhD-NEU-2011.pdf
|
645 |
+
"""
|
646 |
+
|
647 |
+
# set to None to force recompute
|
648 |
+
self._log_likelihood = None
|
649 |
+
self._likelihood = None
|
650 |
+
self._mahalanobis = None
|
651 |
+
|
652 |
+
if z is None:
|
653 |
+
self.z = np.array([[None]*self.dim_z]).T
|
654 |
+
self.x_post = self.x.copy()
|
655 |
+
self.P_post = self.P.copy()
|
656 |
+
self.y = zeros((self.dim_z, 1))
|
657 |
+
return
|
658 |
+
|
659 |
+
if R is None:
|
660 |
+
R = self.R
|
661 |
+
elif isscalar(R):
|
662 |
+
R = eye(self.dim_z) * R
|
663 |
+
|
664 |
+
# rename for readability and a tiny extra bit of speed
|
665 |
+
if H is None:
|
666 |
+
z = reshape_z(z, self.dim_z, self.x.ndim)
|
667 |
+
H = self.H
|
668 |
+
|
669 |
+
# handle special case: if z is in form [[z]] but x is not a column
|
670 |
+
# vector dimensions will not match
|
671 |
+
if self.x.ndim == 1 and shape(z) == (1, 1):
|
672 |
+
z = z[0]
|
673 |
+
|
674 |
+
if shape(z) == (): # is it scalar, e.g. z=3 or z=np.array(3)
|
675 |
+
z = np.asarray([z])
|
676 |
+
|
677 |
+
# y = z - Hx
|
678 |
+
# error (residual) between measurement and prediction
|
679 |
+
self.y = z - dot(H, self.x)
|
680 |
+
|
681 |
+
# common subexpression for speed
|
682 |
+
PHT = dot(self.P, H.T)
|
683 |
+
|
684 |
+
# project system uncertainty into measurement space
|
685 |
+
self.S = dot(H, PHT) + dot(H, self.M) + dot(self.M.T, H.T) + R
|
686 |
+
self.SI = self.inv(self.S)
|
687 |
+
|
688 |
+
# K = PH'inv(S)
|
689 |
+
# map system uncertainty into kalman gain
|
690 |
+
self.K = dot(PHT + self.M, self.SI)
|
691 |
+
|
692 |
+
# x = x + Ky
|
693 |
+
# predict new x with residual scaled by the kalman gain
|
694 |
+
self.x = self.x + dot(self.K, self.y)
|
695 |
+
self.P = self.P - dot(self.K, dot(H, self.P) + self.M.T)
|
696 |
+
|
697 |
+
self.z = deepcopy(z)
|
698 |
+
self.x_post = self.x.copy()
|
699 |
+
self.P_post = self.P.copy()
|
700 |
+
|
701 |
+
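# Usage sketch (illustrative comment only; the 0.1 value is hypothetical):
# using update_correlated() when process and measurement noise are correlated.
# The cross-covariance lives in self.M with shape (dim_x, dim_z).
#
#   kf.M = np.zeros((kf.dim_x, kf.dim_z))
#   kf.M[0, 0] = 0.1                 # correlation between state 0 and measurement 0
#   kf.predict()
#   kf.update_correlated(z)          # uses S = HPH' + HM + M'H' + R as above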
def batch_filter(self, zs, Fs=None, Qs=None, Hs=None,
|
702 |
+
Rs=None, Bs=None, us=None, update_first=False,
|
703 |
+
saver=None):
|
704 |
+
""" Batch processes a sequences of measurements.
|
705 |
+
Parameters
|
706 |
+
----------
|
707 |
+
zs : list-like
|
708 |
+
list of measurements at each time step `self.dt`. Missing
|
709 |
+
measurements must be represented by `None`.
|
710 |
+
Fs : None, list-like, default=None
|
711 |
+
optional value or list of values to use for the state transition
|
712 |
+
matrix F.
|
713 |
+
If Fs is None then self.F is used for all epochs.
|
714 |
+
Otherwise it must contain a list-like list of F's, one for
|
715 |
+
each epoch. This allows you to have varying F per epoch.
|
716 |
+
Qs : None, np.array or list-like, default=None
|
717 |
+
optional value or list of values to use for the process error
|
718 |
+
covariance Q.
|
719 |
+
If Qs is None then self.Q is used for all epochs.
|
720 |
+
Otherwise it must contain a list-like list of Q's, one for
|
721 |
+
each epoch. This allows you to have varying Q per epoch.
|
722 |
+
Hs : None, np.array or list-like, default=None
|
723 |
+
optional list of values to use for the measurement matrix H.
|
724 |
+
If Hs is None then self.H is used for all epochs.
|
725 |
+
If Hs contains a single matrix, then it is used as H for all
|
726 |
+
epochs.
|
727 |
+
Otherwise it must contain a list-like list of H's, one for
|
728 |
+
each epoch. This allows you to have varying H per epoch.
|
729 |
+
Rs : None, np.array or list-like, default=None
|
730 |
+
optional list of values to use for the measurement error
|
731 |
+
covariance R.
|
732 |
+
If Rs is None then self.R is used for all epochs.
|
733 |
+
Otherwise it must contain a list-like list of R's, one for
|
734 |
+
each epoch. This allows you to have varying R per epoch.
|
735 |
+
Bs : None, np.array or list-like, default=None
|
736 |
+
optional list of values to use for the control transition matrix B.
|
737 |
+
If Bs is None then self.B is used for all epochs.
|
738 |
+
Otherwise it must contain a list-like list of B's, one for
|
739 |
+
each epoch. This allows you to have varying B per epoch.
|
740 |
+
us : None, np.array or list-like, default=None
|
741 |
+
optional list of values to use for the control input vector;
|
742 |
+
If us is None then None is used for all epochs (equivalent to 0,
|
743 |
+
or no control input).
|
744 |
+
Otherwise it must contain a list-like list of u's, one for
|
745 |
+
each epoch.
|
746 |
+
update_first : bool, optional, default=False
|
747 |
+
controls whether the order of operations is update followed by
|
748 |
+
predict, or predict followed by update. Default is predict->update.
|
749 |
+
saver : filterpy.common.Saver, optional
|
750 |
+
filterpy.common.Saver object. If provided, saver.save() will be
|
751 |
+
called after every epoch
|
752 |
+
Returns
|
753 |
+
-------
|
754 |
+
means : np.array((n,dim_x,1))
|
755 |
+
array of the state for each time step after the update. Each entry
|
756 |
+
is an np.array. In other words `means[k,:]` is the state at step
|
757 |
+
`k`.
|
758 |
+
covariance : np.array((n,dim_x,dim_x))
|
759 |
+
array of the covariances for each time step after the update.
|
760 |
+
In other words `covariance[k,:,:]` is the covariance at step `k`.
|
761 |
+
means_predictions : np.array((n,dim_x,1))
|
762 |
+
array of the state for each time step after the predictions. Each
|
763 |
+
entry is an np.array. In other words `means[k,:]` is the state at
|
764 |
+
step `k`.
|
765 |
+
covariance_predictions : np.array((n,dim_x,dim_x))
|
766 |
+
array of the covariances for each time step after the prediction.
|
767 |
+
In other words `covariance[k,:,:]` is the covariance at step `k`.
|
768 |
+
Examples
|
769 |
+
--------
|
770 |
+
.. code-block:: Python
|
771 |
+
# this example demonstrates tracking a measurement where the time
|
772 |
+
# between measurements varies, as stored in dts. This requires
|
773 |
+
# that F be recomputed for each epoch. The output is then smoothed
|
774 |
+
# with an RTS smoother.
|
775 |
+
zs = [t + random.randn()*4 for t in range (40)]
|
776 |
+
Fs = [np.array([[1., dt], [0, 1]]) for dt in dts]
|
777 |
+
(mu, cov, _, _) = kf.batch_filter(zs, Fs=Fs)
|
778 |
+
(xs, Ps, Ks, Pps) = kf.rts_smoother(mu, cov, Fs=Fs)
|
779 |
+
"""
|
780 |
+
|
781 |
+
#pylint: disable=too-many-statements
|
782 |
+
n = np.size(zs, 0)
|
783 |
+
if Fs is None:
|
784 |
+
Fs = [self.F] * n
|
785 |
+
if Qs is None:
|
786 |
+
Qs = [self.Q] * n
|
787 |
+
if Hs is None:
|
788 |
+
Hs = [self.H] * n
|
789 |
+
if Rs is None:
|
790 |
+
Rs = [self.R] * n
|
791 |
+
if Bs is None:
|
792 |
+
Bs = [self.B] * n
|
793 |
+
if us is None:
|
794 |
+
us = [0] * n
|
795 |
+
|
796 |
+
# mean estimates from Kalman Filter
|
797 |
+
if self.x.ndim == 1:
|
798 |
+
means = zeros((n, self.dim_x))
|
799 |
+
means_p = zeros((n, self.dim_x))
|
800 |
+
else:
|
801 |
+
means = zeros((n, self.dim_x, 1))
|
802 |
+
means_p = zeros((n, self.dim_x, 1))
|
803 |
+
|
804 |
+
# state covariances from Kalman Filter
|
805 |
+
covariances = zeros((n, self.dim_x, self.dim_x))
|
806 |
+
covariances_p = zeros((n, self.dim_x, self.dim_x))
|
807 |
+
|
808 |
+
if update_first:
|
809 |
+
for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)):
|
810 |
+
|
811 |
+
self.update(z, R=R, H=H)
|
812 |
+
means[i, :] = self.x
|
813 |
+
covariances[i, :, :] = self.P
|
814 |
+
|
815 |
+
self.predict(u=u, B=B, F=F, Q=Q)
|
816 |
+
means_p[i, :] = self.x
|
817 |
+
covariances_p[i, :, :] = self.P
|
818 |
+
|
819 |
+
if saver is not None:
|
820 |
+
saver.save()
|
821 |
+
else:
|
822 |
+
for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)):
|
823 |
+
|
824 |
+
self.predict(u=u, B=B, F=F, Q=Q)
|
825 |
+
means_p[i, :] = self.x
|
826 |
+
covariances_p[i, :, :] = self.P
|
827 |
+
|
828 |
+
self.update(z, R=R, H=H)
|
829 |
+
means[i, :] = self.x
|
830 |
+
covariances[i, :, :] = self.P
|
831 |
+
|
832 |
+
if saver is not None:
|
833 |
+
saver.save()
|
834 |
+
|
835 |
+
return (means, covariances, means_p, covariances_p)
|
836 |
+
|
837 |
+
def rts_smoother(self, Xs, Ps, Fs=None, Qs=None, inv=np.linalg.inv):
|
838 |
+
"""
|
839 |
+
Runs the Rauch-Tung-Striebel Kalman smoother on a set of
|
840 |
+
means and covariances computed by a Kalman filter. The usual input
|
841 |
+
would come from the output of `KalmanFilter.batch_filter()`.
|
842 |
+
Parameters
|
843 |
+
----------
|
844 |
+
Xs : numpy.array
|
845 |
+
array of the means (state variable x) of the output of a Kalman
|
846 |
+
filter.
|
847 |
+
Ps : numpy.array
|
848 |
+
array of the covariances of the output of a kalman filter.
|
849 |
+
Fs : list-like collection of numpy.array, optional
|
850 |
+
State transition matrix of the Kalman filter at each time step.
|
851 |
+
Optional, if not provided the filter's self.F will be used
|
852 |
+
Qs : list-like collection of numpy.array, optional
|
853 |
+
Process noise of the Kalman filter at each time step. Optional,
|
854 |
+
if not provided the filter's self.Q will be used
|
855 |
+
inv : function, default numpy.linalg.inv
|
856 |
+
If you prefer another inverse function, such as the Moore-Penrose
|
857 |
+
pseudo inverse, set it to that instead: kf.inv = np.linalg.pinv
|
858 |
+
Returns
|
859 |
+
-------
|
860 |
+
x : numpy.ndarray
|
861 |
+
smoothed means
|
862 |
+
P : numpy.ndarray
|
863 |
+
smoothed state covariances
|
864 |
+
K : numpy.ndarray
|
865 |
+
smoother gain at each step
|
866 |
+
Pp : numpy.ndarray
|
867 |
+
Predicted state covariances
|
868 |
+
Examples
|
869 |
+
--------
|
870 |
+
.. code-block:: Python
|
871 |
+
zs = [t + random.randn()*4 for t in range (40)]
|
872 |
+
(mu, cov, _, _) = kalman.batch_filter(zs)
|
873 |
+
(x, P, K, Pp) = rts_smoother(mu, cov, kf.F, kf.Q)
|
874 |
+
"""
|
875 |
+
|
876 |
+
if len(Xs) != len(Ps):
|
877 |
+
raise ValueError('length of Xs and Ps must be the same')
|
878 |
+
|
879 |
+
n = Xs.shape[0]
|
880 |
+
dim_x = Xs.shape[1]
|
881 |
+
|
882 |
+
if Fs is None:
|
883 |
+
Fs = [self.F] * n
|
884 |
+
if Qs is None:
|
885 |
+
Qs = [self.Q] * n
|
886 |
+
|
887 |
+
# smoother gain
|
888 |
+
K = zeros((n, dim_x, dim_x))
|
889 |
+
|
890 |
+
x, P, Pp = Xs.copy(), Ps.copy(), Ps.copy()
|
891 |
+
for k in range(n-2, -1, -1):
|
892 |
+
Pp[k] = dot(dot(Fs[k+1], P[k]), Fs[k+1].T) + Qs[k+1]
|
893 |
+
|
894 |
+
#pylint: disable=bad-whitespace
|
895 |
+
K[k] = dot(dot(P[k], Fs[k+1].T), inv(Pp[k]))
|
896 |
+
x[k] += dot(K[k], x[k+1] - dot(Fs[k+1], x[k]))
|
897 |
+
P[k] += dot(dot(K[k], P[k+1] - Pp[k]), K[k].T)
|
898 |
+
|
899 |
+
return (x, P, K, Pp)
|
900 |
+
|
901 |
+
def get_prediction(self, u=None, B=None, F=None, Q=None):
|
902 |
+
"""
|
903 |
+
Predict next state (prior) using the Kalman filter state propagation
|
904 |
+
equations and returns it without modifying the object.
|
905 |
+
Parameters
|
906 |
+
----------
|
907 |
+
u : np.array, default 0
|
908 |
+
Optional control vector.
|
909 |
+
B : np.array(dim_x, dim_u), or None
|
910 |
+
Optional control transition matrix; a value of None
|
911 |
+
will cause the filter to use `self.B`.
|
912 |
+
F : np.array(dim_x, dim_x), or None
|
913 |
+
Optional state transition matrix; a value of None
|
914 |
+
will cause the filter to use `self.F`.
|
915 |
+
Q : np.array(dim_x, dim_x), scalar, or None
|
916 |
+
Optional process noise matrix; a value of None will cause the
|
917 |
+
filter to use `self.Q`.
|
918 |
+
Returns
|
919 |
+
-------
|
920 |
+
(x, P) : tuple
|
921 |
+
State vector and covariance array of the prediction.
|
922 |
+
"""
|
923 |
+
|
924 |
+
if B is None:
|
925 |
+
B = self.B
|
926 |
+
if F is None:
|
927 |
+
F = self.F
|
928 |
+
if Q is None:
|
929 |
+
Q = self.Q
|
930 |
+
elif isscalar(Q):
|
931 |
+
Q = eye(self.dim_x) * Q
|
932 |
+
|
933 |
+
# x = Fx + Bu
|
934 |
+
if B is not None and u is not None:
|
935 |
+
x = dot(F, self.x) + dot(B, u)
|
936 |
+
else:
|
937 |
+
x = dot(F, self.x)
|
938 |
+
|
939 |
+
# P = FPF' + Q
|
940 |
+
P = self._alpha_sq * dot(dot(F, self.P), F.T) + Q
|
941 |
+
|
942 |
+
return x, P
|
943 |
+
|
944 |
+
def get_update(self, z=None):
|
945 |
+
"""
|
946 |
+
Computes the new estimate based on measurement `z` and returns it
|
947 |
+
without altering the state of the filter.
|
948 |
+
Parameters
|
949 |
+
----------
|
950 |
+
z : (dim_z, 1): array_like
|
951 |
+
measurement for this update. z can be a scalar if dim_z is 1,
|
952 |
+
otherwise it must be convertible to a column vector.
|
953 |
+
Returns
|
954 |
+
-------
|
955 |
+
(x, P) : tuple
|
956 |
+
State vector and covariance array of the update.
|
957 |
+
"""
|
958 |
+
|
959 |
+
if z is None:
|
960 |
+
return self.x, self.P
|
961 |
+
z = reshape_z(z, self.dim_z, self.x.ndim)
|
962 |
+
|
963 |
+
R = self.R
|
964 |
+
H = self.H
|
965 |
+
P = self.P
|
966 |
+
x = self.x
|
967 |
+
|
968 |
+
# error (residual) between measurement and prediction
|
969 |
+
y = z - dot(H, x)
|
970 |
+
|
971 |
+
# common subexpression for speed
|
972 |
+
PHT = dot(P, H.T)
|
973 |
+
|
974 |
+
# project system uncertainty into measurement space
|
975 |
+
S = dot(H, PHT) + R
|
976 |
+
|
977 |
+
# map system uncertainty into kalman gain
|
978 |
+
K = dot(PHT, self.inv(S))
|
979 |
+
|
980 |
+
# predict new x with residual scaled by the kalman gain
|
981 |
+
x = x + dot(K, y)
|
982 |
+
|
983 |
+
# P = (I-KH)P(I-KH)' + KRK'
|
984 |
+
I_KH = self._I - dot(K, H)
|
985 |
+
P = dot(dot(I_KH, P), I_KH.T) + dot(dot(K, R), K.T)
|
986 |
+
|
987 |
+
return x, P
|
988 |
+
|
989 |
+
def residual_of(self, z):
|
990 |
+
"""
|
991 |
+
Returns the residual for the given measurement (z). Does not alter
|
992 |
+
the state of the filter.
|
993 |
+
"""
|
994 |
+
z = reshape_z(z, self.dim_z, self.x.ndim)
|
995 |
+
return z - dot(self.H, self.x_prior)
|
996 |
+
|
997 |
+
def measurement_of_state(self, x):
|
998 |
+
"""
|
999 |
+
Helper function that converts a state into a measurement.
|
1000 |
+
Parameters
|
1001 |
+
----------
|
1002 |
+
x : np.array
|
1003 |
+
kalman state vector
|
1004 |
+
Returns
|
1005 |
+
-------
|
1006 |
+
z : (dim_z, 1): array_like
|
1007 |
+
measurement for this update. z can be a scalar if dim_z is 1,
|
1008 |
+
otherwise it must be convertible to a column vector.
|
1009 |
+
"""
|
1010 |
+
|
1011 |
+
return dot(self.H, x)
|
1012 |
+
|
1013 |
+
@property
|
1014 |
+
def log_likelihood(self):
|
1015 |
+
"""
|
1016 |
+
log-likelihood of the last measurement.
|
1017 |
+
"""
|
1018 |
+
if self._log_likelihood is None:
|
1019 |
+
self._log_likelihood = logpdf(x=self.y, cov=self.S)
|
1020 |
+
return self._log_likelihood
|
1021 |
+
|
1022 |
+
@property
|
1023 |
+
def likelihood(self):
|
1024 |
+
"""
|
1025 |
+
Computed from the log-likelihood. The log-likelihood can be very
|
1026 |
+
small, meaning a large negative value such as -28000. Taking the
|
1027 |
+
exp() of that results in 0.0, which can break typical algorithms
|
1028 |
+
which multiply by this value, so by default we always return a
|
1029 |
+
number >= sys.float_info.min.
|
1030 |
+
"""
|
1031 |
+
if self._likelihood is None:
|
1032 |
+
self._likelihood = exp(self.log_likelihood)
|
1033 |
+
if self._likelihood == 0:
|
1034 |
+
self._likelihood = sys.float_info.min
|
1035 |
+
return self._likelihood
|
1036 |
+
|
1037 |
+
@property
|
1038 |
+
def mahalanobis(self):
|
1039 |
+
""""
|
1040 |
+
Mahalanobis distance of measurement. E.g. 3 means measurement
|
1041 |
+
was 3 standard deviations away from the predicted value.
|
1042 |
+
Returns
|
1043 |
+
-------
|
1044 |
+
mahalanobis : float
|
1045 |
+
"""
|
1046 |
+
if self._mahalanobis is None:
|
1047 |
+
self._mahalanobis = sqrt(float(dot(dot(self.y.T, self.SI), self.y)))
|
1048 |
+
return self._mahalanobis
|
1049 |
+
|
1050 |
+
@property
|
1051 |
+
def alpha(self):
|
1052 |
+
"""
|
1053 |
+
Fading memory setting. 1.0 gives the normal Kalman filter, and
|
1054 |
+
values slightly larger than 1.0 (such as 1.02) give a fading
|
1055 |
+
memory effect - previous measurements have less influence on the
|
1056 |
+
filter's estimates. This formulation of the Fading memory filter
|
1057 |
+
(there are many) is due to Dan Simon [1]_.
|
1058 |
+
"""
|
1059 |
+
return self._alpha_sq**.5
|
1060 |
+
|
1061 |
+
def log_likelihood_of(self, z):
|
1062 |
+
"""
|
1063 |
+
log likelihood of the measurement `z`. This should only be called
|
1064 |
+
after a call to update(). Calling after predict() will yield an
|
1065 |
+
incorrect result."""
|
1066 |
+
|
1067 |
+
if z is None:
|
1068 |
+
return log(sys.float_info.min)
|
1069 |
+
return logpdf(z, dot(self.H, self.x), self.S)
|
1070 |
+
|
1071 |
+
@alpha.setter
|
1072 |
+
def alpha(self, value):
|
1073 |
+
if not np.isscalar(value) or value < 1:
|
1074 |
+
raise ValueError('alpha must be a float greater than 1')
|
1075 |
+
|
1076 |
+
self._alpha_sq = value**2
|
1077 |
+
|
1078 |
+
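# Usage sketch (illustrative comment only; 1.02 is an example value): enabling the
# fading-memory behaviour described in the alpha property docstring.
#
#   kf.alpha = 1.02    # stored internally as _alpha_sq = 1.02**2 and applied in predict()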
def __repr__(self):
|
1079 |
+
return '\n'.join([
|
1080 |
+
'KalmanFilter object',
|
1081 |
+
pretty_str('dim_x', self.dim_x),
|
1082 |
+
pretty_str('dim_z', self.dim_z),
|
1083 |
+
pretty_str('dim_u', self.dim_u),
|
1084 |
+
pretty_str('x', self.x),
|
1085 |
+
pretty_str('P', self.P),
|
1086 |
+
pretty_str('x_prior', self.x_prior),
|
1087 |
+
pretty_str('P_prior', self.P_prior),
|
1088 |
+
pretty_str('x_post', self.x_post),
|
1089 |
+
pretty_str('P_post', self.P_post),
|
1090 |
+
pretty_str('F', self.F),
|
1091 |
+
pretty_str('Q', self.Q),
|
1092 |
+
pretty_str('R', self.R),
|
1093 |
+
pretty_str('H', self.H),
|
1094 |
+
pretty_str('K', self.K),
|
1095 |
+
pretty_str('y', self.y),
|
1096 |
+
pretty_str('S', self.S),
|
1097 |
+
pretty_str('SI', self.SI),
|
1098 |
+
pretty_str('M', self.M),
|
1099 |
+
pretty_str('B', self.B),
|
1100 |
+
pretty_str('z', self.z),
|
1101 |
+
pretty_str('log-likelihood', self.log_likelihood),
|
1102 |
+
pretty_str('likelihood', self.likelihood),
|
1103 |
+
pretty_str('mahalanobis', self.mahalanobis),
|
1104 |
+
pretty_str('alpha', self.alpha),
|
1105 |
+
pretty_str('inv', self.inv)
|
1106 |
+
])
|
1107 |
+
|
1108 |
+
def test_matrix_dimensions(self, z=None, H=None, R=None, F=None, Q=None):
|
1109 |
+
"""
|
1110 |
+
Performs a series of asserts to check that the size of everything
|
1111 |
+
is what it should be. This can help you debug problems in your design.
|
1112 |
+
If you pass in H, R, F, Q those will be used instead of this object's
|
1113 |
+
value for those matrices.
|
1114 |
+
Testing `z` (the measurement) is problematic. x is a vector, and can be
|
1115 |
+
implemented as either a 1D array or as a nx1 column vector. Thus Hx
|
1116 |
+
can be of different shapes. Then, if Hx is a single value, it can
|
1117 |
+
be either a 1D array or 2D vector. If either is true, z can reasonably
|
1118 |
+
be a scalar (either '3' or np.array('3') are scalars under this
|
1119 |
+
definition), a 1D, 1 element array, or a 2D, 1 element array. You are
|
1120 |
+
allowed to pass in any combination that works.
|
1121 |
+
"""
|
1122 |
+
|
1123 |
+
if H is None:
|
1124 |
+
H = self.H
|
1125 |
+
if R is None:
|
1126 |
+
R = self.R
|
1127 |
+
if F is None:
|
1128 |
+
F = self.F
|
1129 |
+
if Q is None:
|
1130 |
+
Q = self.Q
|
1131 |
+
x = self.x
|
1132 |
+
P = self.P
|
1133 |
+
|
1134 |
+
assert x.ndim == 1 or x.ndim == 2, \
|
1135 |
+
"x must have one or two dimensions, but has {}".format(x.ndim)
|
1136 |
+
|
1137 |
+
if x.ndim == 1:
|
1138 |
+
assert x.shape[0] == self.dim_x, \
|
1139 |
+
"Shape of x must be ({},{}), but is {}".format(
|
1140 |
+
self.dim_x, 1, x.shape)
|
1141 |
+
else:
|
1142 |
+
assert x.shape == (self.dim_x, 1), \
|
1143 |
+
"Shape of x must be ({},{}), but is {}".format(
|
1144 |
+
self.dim_x, 1, x.shape)
|
1145 |
+
|
1146 |
+
assert P.shape == (self.dim_x, self.dim_x), \
|
1147 |
+
"Shape of P must be ({},{}), but is {}".format(
|
1148 |
+
self.dim_x, self.dim_x, P.shape)
|
1149 |
+
|
1150 |
+
assert Q.shape == (self.dim_x, self.dim_x), \
|
1151 |
+
"Shape of Q must be ({},{}), but is {}".format(
|
1152 |
+
self.dim_x, self.dim_x, Q.shape)
|
1153 |
+
|
1154 |
+
assert F.shape == (self.dim_x, self.dim_x), \
|
1155 |
+
"Shape of F must be ({},{}), but is {}".format(
|
1156 |
+
self.dim_x, self.dim_x, F.shape)
|
1157 |
+
|
1158 |
+
assert np.ndim(H) == 2, \
|
1159 |
+
"Shape of H must be (dim_z, {}), but is {}".format(
|
1160 |
+
P.shape[0], shape(H))
|
1161 |
+
|
1162 |
+
assert H.shape[1] == P.shape[0], \
|
1163 |
+
"Shape of H must be (dim_z, {}), but is {}".format(
|
1164 |
+
P.shape[0], H.shape)
|
1165 |
+
|
1166 |
+
# shape of R must be the same as HPH'
|
1167 |
+
hph_shape = (H.shape[0], H.shape[0])
|
1168 |
+
r_shape = shape(R)
|
1169 |
+
|
1170 |
+
if H.shape[0] == 1:
|
1171 |
+
# r can be scalar, 1D, or 2D in this case
|
1172 |
+
assert r_shape in [(), (1,), (1, 1)], \
|
1173 |
+
"R must be scalar or one element array, but is shaped {}".format(
|
1174 |
+
r_shape)
|
1175 |
+
else:
|
1176 |
+
assert r_shape == hph_shape, \
|
1177 |
+
"shape of R should be {} but it is {}".format(hph_shape, r_shape)
|
1178 |
+
|
1179 |
+
|
1180 |
+
if z is not None:
|
1181 |
+
z_shape = shape(z)
|
1182 |
+
else:
|
1183 |
+
z_shape = (self.dim_z, 1)
|
1184 |
+
|
1185 |
+
# H@x must have shape of z
|
1186 |
+
Hx = dot(H, x)
|
1187 |
+
|
1188 |
+
if z_shape == (): # scalar or np.array(scalar)
|
1189 |
+
assert Hx.ndim == 1 or shape(Hx) == (1, 1), \
|
1190 |
+
"shape of z should be {}, not {} for the given H".format(
|
1191 |
+
shape(Hx), z_shape)
|
1192 |
+
|
1193 |
+
elif shape(Hx) == (1,):
|
1194 |
+
assert z_shape[0] == 1, 'Shape of z must be {} for the given H'.format(shape(Hx))
|
1195 |
+
|
1196 |
+
else:
|
1197 |
+
assert (z_shape == shape(Hx) or
|
1198 |
+
(len(z_shape) == 1 and shape(Hx) == (z_shape[0], 1))), \
|
1199 |
+
"shape of z should be {}, not {} for the given H".format(
|
1200 |
+
shape(Hx), z_shape)
|
1201 |
+
|
1202 |
+
if np.ndim(Hx) > 1 and shape(Hx) != (1, 1):
|
1203 |
+
assert shape(Hx) == z_shape, \
|
1204 |
+
'shape of z should be {} for the given H, but it is {}'.format(
|
1205 |
+
shape(Hx), z_shape)
|
1206 |
+
|
1207 |
+
|
1208 |
+
def update(x, P, z, R, H=None, return_all=False):
|
1209 |
+
"""
|
1210 |
+
Add a new measurement (z) to the Kalman filter. If z is None, nothing
|
1211 |
+
is changed.
|
1212 |
+
This can handle either the multidimensional or unidimensional case. If
|
1213 |
+
all parameters are floats instead of arrays the filter will still work,
|
1214 |
+
and return floats for x, P as the result.
|
1215 |
+
update(1, 2, 1, 1, 1) # univariate
|
1216 |
+
update(x, P, z, R)  # multivariate
|
1217 |
+
Parameters
|
1218 |
+
----------
|
1219 |
+
x : numpy.array(dim_x, 1), or float
|
1220 |
+
State estimate vector
|
1221 |
+
P : numpy.array(dim_x, dim_x), or float
|
1222 |
+
Covariance matrix
|
1223 |
+
z : (dim_z, 1): array_like
|
1224 |
+
measurement for this update. z can be a scalar if dim_z is 1,
|
1225 |
+
otherwise it must be convertible to a column vector.
|
1226 |
+
R : numpy.array(dim_z, dim_z), or float
|
1227 |
+
Measurement noise matrix
|
1228 |
+
H : numpy.array(dim_x, dim_x), or float, optional
|
1229 |
+
Measurement function. If not provided, a value of 1 is assumed.
|
1230 |
+
return_all : bool, default False
|
1231 |
+
If true, y, K, S, and log_likelihood are returned, otherwise
|
1232 |
+
only x and P are returned.
|
1233 |
+
Returns
|
1234 |
+
-------
|
1235 |
+
x : numpy.array
|
1236 |
+
Posterior state estimate vector
|
1237 |
+
P : numpy.array
|
1238 |
+
Posterior covariance matrix
|
1239 |
+
y : numpy.array or scalar
|
1240 |
+
Residual. Difference between measurement and state in measurement space
|
1241 |
+
K : numpy.array
|
1242 |
+
Kalman gain
|
1243 |
+
S : numpy.array
|
1244 |
+
System uncertainty in measurement space
|
1245 |
+
log_likelihood : float
|
1246 |
+
log likelihood of the measurement
|
1247 |
+
"""
|
1248 |
+
|
1249 |
+
#pylint: disable=bare-except
|
1250 |
+
|
1251 |
+
if z is None:
|
1252 |
+
if return_all:
|
1253 |
+
return x, P, None, None, None, None
|
1254 |
+
return x, P
|
1255 |
+
|
1256 |
+
if H is None:
|
1257 |
+
H = np.array([1])
|
1258 |
+
|
1259 |
+
if np.isscalar(H):
|
1260 |
+
H = np.array([H])
|
1261 |
+
|
1262 |
+
Hx = np.atleast_1d(dot(H, x))
|
1263 |
+
z = reshape_z(z, Hx.shape[0], x.ndim)
|
1264 |
+
|
1265 |
+
# error (residual) between measurement and prediction
|
1266 |
+
y = z - Hx
|
1267 |
+
|
1268 |
+
# project system uncertainty into measurement space
|
1269 |
+
S = dot(dot(H, P), H.T) + R
|
1270 |
+
|
1271 |
+
|
1272 |
+
# map system uncertainty into kalman gain
|
1273 |
+
try:
|
1274 |
+
K = dot(dot(P, H.T), linalg.inv(S))
|
1275 |
+
except:
|
1276 |
+
# can't invert a 1D array, annoyingly
|
1277 |
+
K = dot(dot(P, H.T), 1./S)
|
1278 |
+
|
1279 |
+
|
1280 |
+
# predict new x with residual scaled by the kalman gain
|
1281 |
+
x = x + dot(K, y)
|
1282 |
+
|
1283 |
+
# P = (I-KH)P(I-KH)' + KRK'
|
1284 |
+
KH = dot(K, H)
|
1285 |
+
|
1286 |
+
try:
|
1287 |
+
I_KH = np.eye(KH.shape[0]) - KH
|
1288 |
+
except:
|
1289 |
+
I_KH = np.array([1 - KH])
|
1290 |
+
P = dot(dot(I_KH, P), I_KH.T) + dot(dot(K, R), K.T)
|
1291 |
+
|
1292 |
+
|
1293 |
+
if return_all:
|
1294 |
+
# compute log likelihood
|
1295 |
+
log_likelihood = logpdf(z, dot(H, x), S)
|
1296 |
+
return x, P, y, K, S, log_likelihood
|
1297 |
+
return x, P
|
1298 |
+
|
1299 |
+
|
1300 |
+
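# Usage sketch (illustrative comment only; the numbers are hypothetical): the
# procedural update() above applied to a single scalar state, with the default
# measurement function H = [1].
#
#   x = np.array([10.])              # prior mean
#   P = np.array([[3.]])             # prior variance
#   x, P = update(x, P, z=np.array([12.]), R=np.array([[2.]]))
#   # x is pulled toward the measurement in proportion to the Kalman gain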
def update_steadystate(x, z, K, H=None):
|
1301 |
+
"""
|
1302 |
+
Add a new measurement (z) to the Kalman filter. If z is None, nothing
|
1303 |
+
is changed.
|
1304 |
+
Parameters
|
1305 |
+
----------
|
1306 |
+
x : numpy.array(dim_x, 1), or float
|
1307 |
+
State estimate vector
|
1308 |
+
z : (dim_z, 1): array_like
|
1309 |
+
measurement for this update. z can be a scalar if dim_z is 1,
|
1310 |
+
otherwise it must be convertible to a column vector.
|
1311 |
+
K : numpy.array, or float
|
1312 |
+
Kalman gain matrix
|
1313 |
+
H : numpy.array(dim_x, dim_x), or float, optional
|
1314 |
+
Measurement function. If not provided, a value of 1 is assumed.
|
1315 |
+
Returns
|
1316 |
+
-------
|
1317 |
+
x : numpy.array
|
1318 |
+
Posterior state estimate vector
|
1319 |
+
Examples
|
1320 |
+
--------
|
1321 |
+
This can handle either the multidimensional or unidimensional case. If
|
1322 |
+
all parameters are floats instead of arrays the filter will still work,
|
1323 |
+
and return floats for x, P as the result.
|
1324 |
+
>>> update_steadystate(1, 2, 1) # univariate
|
1325 |
+
>>> update_steadystate(x, z, K, H)
|
1326 |
+
"""
|
1327 |
+
|
1328 |
+
|
1329 |
+
if z is None:
|
1330 |
+
return x
|
1331 |
+
|
1332 |
+
if H is None:
|
1333 |
+
H = np.array([1])
|
1334 |
+
|
1335 |
+
if np.isscalar(H):
|
1336 |
+
H = np.array([H])
|
1337 |
+
|
1338 |
+
Hx = np.atleast_1d(dot(H, x))
|
1339 |
+
z = reshape_z(z, Hx.shape[0], x.ndim)
|
1340 |
+
|
1341 |
+
# error (residual) between measurement and prediction
|
1342 |
+
y = z - Hx
|
1343 |
+
|
1344 |
+
# estimate new x with residual scaled by the kalman gain
|
1345 |
+
return x + dot(K, y)
|
1346 |
+
|
1347 |
+
|
1348 |
+
def predict(x, P, F=1, Q=0, u=0, B=1, alpha=1.):
|
1349 |
+
"""
|
1350 |
+
Predict next state (prior) using the Kalman filter state propagation
|
1351 |
+
equations.
|
1352 |
+
Parameters
|
1353 |
+
----------
|
1354 |
+
x : numpy.array
|
1355 |
+
State estimate vector
|
1356 |
+
P : numpy.array
|
1357 |
+
Covariance matrix
|
1358 |
+
F : numpy.array()
|
1359 |
+
State Transition matrix
|
1360 |
+
Q : numpy.array, Optional
|
1361 |
+
Process noise matrix
|
1362 |
+
u : numpy.array, Optional, default 0.
|
1363 |
+
Control vector. If non-zero, it is multiplied by B
|
1364 |
+
to create the control input into the system.
|
1365 |
+
B : numpy.array, optional, default 0.
|
1366 |
+
Control transition matrix.
|
1367 |
+
alpha : float, Optional, default=1.0
|
1368 |
+
Fading memory setting. 1.0 gives the normal Kalman filter, and
|
1369 |
+
values slightly larger than 1.0 (such as 1.02) give a fading
|
1370 |
+
memory effect - previous measurements have less influence on the
|
1371 |
+
filter's estimates. This formulation of the Fading memory filter
|
1372 |
+
(there are many) is due to Dan Simon
|
1373 |
+
Returns
|
1374 |
+
-------
|
1375 |
+
x : numpy.array
|
1376 |
+
Prior state estimate vector
|
1377 |
+
P : numpy.array
|
1378 |
+
Prior covariance matrix
|
1379 |
+
"""
|
1380 |
+
|
1381 |
+
if np.isscalar(F):
|
1382 |
+
F = np.array(F)
|
1383 |
+
x = dot(F, x) + dot(B, u)
|
1384 |
+
P = (alpha * alpha) * dot(dot(F, P), F.T) + Q
|
1385 |
+
|
1386 |
+
return x, P
|
1387 |
+
|
1388 |
+
|
1389 |
+
def predict_steadystate(x, F=1, u=0, B=1):
|
1390 |
+
"""
|
1391 |
+
Predict next state (prior) using the Kalman filter state propagation
|
1392 |
+
equations. This steady state form only computes x, assuming that the
|
1393 |
+
covariance is constant.
|
1394 |
+
Parameters
|
1395 |
+
----------
|
1396 |
+
x : numpy.array
|
1397 |
+
State estimate vector
|
1398 |
+
P : numpy.array
|
1399 |
+
Covariance matrix
|
1400 |
+
F : numpy.array()
|
1401 |
+
State Transition matrix
|
1402 |
+
u : numpy.array, Optional, default 0.
|
1403 |
+
Control vector. If non-zero, it is multiplied by B
|
1404 |
+
to create the control input into the system.
|
1405 |
+
B : numpy.array, optional, default 0.
|
1406 |
+
Control transition matrix.
|
1407 |
+
Returns
|
1408 |
+
-------
|
1409 |
+
x : numpy.array
|
1410 |
+
Prior state estimate vector
|
1411 |
+
"""
|
1412 |
+
|
1413 |
+
if np.isscalar(F):
|
1414 |
+
F = np.array(F)
|
1415 |
+
x = dot(F, x) + dot(B, u)
|
1416 |
+
|
1417 |
+
return x
|
1418 |
+
|
1419 |
+
|
1420 |
+
|
1421 |
+
def batch_filter(x, P, zs, Fs, Qs, Hs, Rs, Bs=None, us=None,
|
1422 |
+
update_first=False, saver=None):
|
1423 |
+
"""
|
1424 |
+
Batch processes a sequences of measurements.
|
1425 |
+
Parameters
|
1426 |
+
----------
|
1427 |
+
zs : list-like
|
1428 |
+
list of measurements at each time step. Missing measurements must be
|
1429 |
+
represented by None.
|
1430 |
+
Fs : list-like
|
1431 |
+
list of values to use for the state transition matrix matrix.
|
1432 |
+
Qs : list-like
|
1433 |
+
list of values to use for the process error
|
1434 |
+
covariance.
|
1435 |
+
Hs : list-like
|
1436 |
+
list of values to use for the measurement matrix.
|
1437 |
+
Rs : list-like
|
1438 |
+
list of values to use for the measurement error
|
1439 |
+
covariance.
|
1440 |
+
Bs : list-like, optional
|
1441 |
+
list of values to use for the control transition matrix;
|
1442 |
+
a value of None in any position will cause the filter
|
1443 |
+
to use `self.B` for that time step.
|
1444 |
+
us : list-like, optional
|
1445 |
+
list of values to use for the control input vector;
|
1446 |
+
a value of None in any position will cause the filter to use
|
1447 |
+
0 for that time step.
|
1448 |
+
update_first : bool, optional
|
1449 |
+
controls whether the order of operations is update followed by
|
1450 |
+
predict, or predict followed by update. Default is predict->update.
|
1451 |
+
saver : filterpy.common.Saver, optional
|
1452 |
+
filterpy.common.Saver object. If provided, saver.save() will be
|
1453 |
+
called after every epoch
|
1454 |
+
Returns
|
1455 |
+
-------
|
1456 |
+
means : np.array((n,dim_x,1))
|
1457 |
+
array of the state for each time step after the update. Each entry
|
1458 |
+
is an np.array. In other words `means[k,:]` is the state at step
|
1459 |
+
`k`.
|
1460 |
+
covariance : np.array((n,dim_x,dim_x))
|
1461 |
+
array of the covariances for each time step after the update.
|
1462 |
+
In other words `covariance[k,:,:]` is the covariance at step `k`.
|
1463 |
+
means_predictions : np.array((n,dim_x,1))
|
1464 |
+
array of the state for each time step after the predictions. Each
|
1465 |
+
entry is an np.array. In other words `means[k,:]` is the state at
|
1466 |
+
step `k`.
|
1467 |
+
covariance_predictions : np.array((n,dim_x,dim_x))
|
1468 |
+
array of the covariances for each time step after the prediction.
|
1469 |
+
In other words `covariance[k,:,:]` is the covariance at step `k`.
|
1470 |
+
Examples
|
1471 |
+
--------
|
1472 |
+
.. code-block:: Python
|
1473 |
+
zs = [t + random.randn()*4 for t in range (40)]
|
1474 |
+
Fs = [kf.F for t in range (40)]
|
1475 |
+
Hs = [kf.H for t in range (40)]
|
1476 |
+
(mu, cov, _, _) = kf.batch_filter(zs, Rs=R_list, Fs=Fs, Hs=Hs, Qs=None,
|
1477 |
+
Bs=None, us=None, update_first=False)
|
1478 |
+
(xs, Ps, Ks, Pps) = kf.rts_smoother(mu, cov, Fs=Fs, Qs=None)
|
1479 |
+
"""
|
1480 |
+
|
1481 |
+
n = np.size(zs, 0)
|
1482 |
+
dim_x = x.shape[0]
|
1483 |
+
|
1484 |
+
# mean estimates from Kalman Filter
|
1485 |
+
if x.ndim == 1:
|
1486 |
+
means = zeros((n, dim_x))
|
1487 |
+
means_p = zeros((n, dim_x))
|
1488 |
+
else:
|
1489 |
+
means = zeros((n, dim_x, 1))
|
1490 |
+
means_p = zeros((n, dim_x, 1))
|
1491 |
+
|
1492 |
+
# state covariances from Kalman Filter
|
1493 |
+
covariances = zeros((n, dim_x, dim_x))
|
1494 |
+
covariances_p = zeros((n, dim_x, dim_x))
|
1495 |
+
|
1496 |
+
if us is None:
|
1497 |
+
us = [0.] * n
|
1498 |
+
Bs = [0.] * n
|
1499 |
+
|
1500 |
+
if update_first:
|
1501 |
+
for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)):
|
1502 |
+
|
1503 |
+
x, P = update(x, P, z, R=R, H=H)
|
1504 |
+
means[i, :] = x
|
1505 |
+
covariances[i, :, :] = P
|
1506 |
+
|
1507 |
+
x, P = predict(x, P, u=u, B=B, F=F, Q=Q)
|
1508 |
+
means_p[i, :] = x
|
1509 |
+
covariances_p[i, :, :] = P
|
1510 |
+
if saver is not None:
|
1511 |
+
saver.save()
|
1512 |
+
else:
|
1513 |
+
for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)):
|
1514 |
+
|
1515 |
+
x, P = predict(x, P, u=u, B=B, F=F, Q=Q)
|
1516 |
+
means_p[i, :] = x
|
1517 |
+
covariances_p[i, :, :] = P
|
1518 |
+
|
1519 |
+
x, P = update(x, P, z, R=R, H=H)
|
1520 |
+
means[i, :] = x
|
1521 |
+
covariances[i, :, :] = P
|
1522 |
+
if saver is not None:
|
1523 |
+
saver.save()
|
1524 |
+
|
1525 |
+
return (means, covariances, means_p, covariances_p)
|
1526 |
+
|
1527 |
+
|
1528 |
+
|
1529 |
+
def rts_smoother(Xs, Ps, Fs, Qs):
|
1530 |
+
"""
|
1531 |
+
Runs the Rauch-Tung-Striebel Kalman smoother on a set of
|
1532 |
+
means and covariances computed by a Kalman filter. The usual input
|
1533 |
+
would come from the output of `KalmanFilter.batch_filter()`.
|
1534 |
+
Parameters
|
1535 |
+
----------
|
1536 |
+
Xs : numpy.array
|
1537 |
+
array of the means (state variable x) of the output of a Kalman
|
1538 |
+
filter.
|
1539 |
+
Ps : numpy.array
|
1540 |
+
array of the covariances of the output of a kalman filter.
|
1541 |
+
Fs : list-like collection of numpy.array
|
1542 |
+
State transition matrix of the Kalman filter at each time step.
|
1543 |
+
Qs : list-like collection of numpy.array, optional
|
1544 |
+
Process noise of the Kalman filter at each time step.
|
1545 |
+
Returns
|
1546 |
+
-------
|
1547 |
+
x : numpy.ndarray
|
1548 |
+
smoothed means
|
1549 |
+
P : numpy.ndarray
|
1550 |
+
smoothed state covariances
|
1551 |
+
K : numpy.ndarray
|
1552 |
+
smoother gain at each step
|
1553 |
+
pP : numpy.ndarray
|
1554 |
+
predicted state covariances
|
1555 |
+
Examples
|
1556 |
+
--------
|
1557 |
+
.. code-block:: Python
|
1558 |
+
zs = [t + random.randn()*4 for t in range (40)]
|
1559 |
+
(mu, cov, _, _) = kalman.batch_filter(zs)
|
1560 |
+
(x, P, K, pP) = rts_smoother(mu, cov, kf.F, kf.Q)
|
1561 |
+
"""
|
1562 |
+
|
1563 |
+
if len(Xs) != len(Ps):
|
1564 |
+
raise ValueError('length of Xs and Ps must be the same')
|
1565 |
+
|
1566 |
+
n = Xs.shape[0]
|
1567 |
+
dim_x = Xs.shape[1]
|
1568 |
+
|
1569 |
+
# smoother gain
|
1570 |
+
K = zeros((n, dim_x, dim_x))
|
1571 |
+
x, P, pP = Xs.copy(), Ps.copy(), Ps.copy()
|
1572 |
+
|
1573 |
+
for k in range(n-2, -1, -1):
|
1574 |
+
pP[k] = dot(dot(Fs[k], P[k]), Fs[k].T) + Qs[k]
|
1575 |
+
|
1576 |
+
#pylint: disable=bad-whitespace
|
1577 |
+
K[k] = dot(dot(P[k], Fs[k].T), linalg.inv(pP[k]))
|
1578 |
+
x[k] += dot(K[k], x[k+1] - dot(Fs[k], x[k]))
|
1579 |
+
P[k] += dot(dot(K[k], P[k+1] - pP[k]), K[k].T)
|
1580 |
+
|
1581 |
+
return (x, P, K, pP)
|
trackers/ocsort/ocsort.py
ADDED
@@ -0,0 +1,328 @@
1 |
+
"""
|
2 |
+
This script is adapted from the SORT script by Alex Bewley alex@bewley.ai
|
3 |
+
"""
|
4 |
+
from __future__ import print_function
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from .association import *
|
8 |
+
from ultralytics.yolo.utils.ops import xywh2xyxy
|
9 |
+
|
10 |
+
|
11 |
+
def k_previous_obs(observations, cur_age, k):
|
12 |
+
if len(observations) == 0:
|
13 |
+
return [-1, -1, -1, -1, -1]
|
14 |
+
for i in range(k):
|
15 |
+
dt = k - i
|
16 |
+
if cur_age - dt in observations:
|
17 |
+
return observations[cur_age-dt]
|
18 |
+
max_age = max(observations.keys())
|
19 |
+
return observations[max_age]
|
20 |
+
|
21 |
+
|
22 |
+
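# Usage sketch (illustrative comment only; the dict values are hypothetical):
# k_previous_obs() looks back `k` ages for the most recent stored observation and
# falls back to the newest one when that age is missing.
#
#   obs = {3: np.array([10, 10, 20, 20, 0.9]), 7: np.array([12, 11, 22, 21, 0.8])}
#   k_previous_obs(obs, cur_age=10, k=3)   # age 10-3=7 exists -> returns obs[7]
#   k_previous_obs(obs, cur_age=10, k=2)   # ages 8 and 9 missing -> returns obs[7] (newest)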
def convert_bbox_to_z(bbox):
|
23 |
+
"""
|
24 |
+
Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
|
25 |
+
[x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
|
26 |
+
the aspect ratio
|
27 |
+
"""
|
28 |
+
w = bbox[2] - bbox[0]
|
29 |
+
h = bbox[3] - bbox[1]
|
30 |
+
x = bbox[0] + w/2.
|
31 |
+
y = bbox[1] + h/2.
|
32 |
+
s = w * h # scale is just area
|
33 |
+
r = w / float(h+1e-6)
|
34 |
+
return np.array([x, y, s, r]).reshape((4, 1))
|
35 |
+
|
36 |
+
|
37 |
+
def convert_x_to_bbox(x, score=None):
|
38 |
+
"""
|
39 |
+
Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
|
40 |
+
[x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
|
41 |
+
"""
|
42 |
+
w = np.sqrt(x[2] * x[3])
|
43 |
+
h = x[2] / w
|
44 |
+
if score is None:
|
45 |
+
return np.array([x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2.]).reshape((1, 4))
|
46 |
+
else:
|
47 |
+
return np.array([x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2., score]).reshape((1, 5))
|
48 |
+
|
49 |
+
|
50 |
+
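# Usage sketch (illustrative comment only; the box is hypothetical): the two
# conversions above are approximate inverses of each other.
#
#   box = np.array([10., 20., 50., 60.])   # [x1, y1, x2, y2]
#   z = convert_bbox_to_z(box)             # [cx, cy, area, aspect_ratio]
#   convert_x_to_bbox(z)                   # ~ array([[10., 20., 50., 60.]])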
def speed_direction(bbox1, bbox2):
|
51 |
+
cx1, cy1 = (bbox1[0]+bbox1[2]) / 2.0, (bbox1[1]+bbox1[3])/2.0
|
52 |
+
cx2, cy2 = (bbox2[0]+bbox2[2]) / 2.0, (bbox2[1]+bbox2[3])/2.0
|
53 |
+
speed = np.array([cy2-cy1, cx2-cx1])
|
54 |
+
norm = np.sqrt((cy2-cy1)**2 + (cx2-cx1)**2) + 1e-6
|
55 |
+
return speed / norm
|
56 |
+
|
57 |
+
|
58 |
+
class KalmanBoxTracker(object):
|
59 |
+
"""
|
60 |
+
This class represents the internal state of individual tracked objects observed as bbox.
|
61 |
+
"""
|
62 |
+
count = 0
|
63 |
+
|
64 |
+
def __init__(self, bbox, cls, delta_t=3, orig=False):
|
65 |
+
"""
|
66 |
+
Initialises a tracker using initial bounding box.
|
67 |
+
|
68 |
+
"""
|
69 |
+
# define constant velocity model
|
70 |
+
if not orig:
|
71 |
+
from .kalmanfilter import KalmanFilterNew as KalmanFilter
|
72 |
+
self.kf = KalmanFilter(dim_x=7, dim_z=4)
|
73 |
+
else:
|
74 |
+
from filterpy.kalman import KalmanFilter
|
75 |
+
self.kf = KalmanFilter(dim_x=7, dim_z=4)
|
76 |
+
self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0], [0, 0, 1, 0, 0, 0, 1], [
|
77 |
+
0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 1]])
|
78 |
+
self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0],
|
79 |
+
[0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]])
|
80 |
+
|
81 |
+
self.kf.R[2:, 2:] *= 10.
|
82 |
+
self.kf.P[4:, 4:] *= 1000. # give high uncertainty to the unobservable initial velocities
|
83 |
+
self.kf.P *= 10.
|
84 |
+
self.kf.Q[-1, -1] *= 0.01
|
85 |
+
self.kf.Q[4:, 4:] *= 0.01
|
86 |
+
|
87 |
+
self.kf.x[:4] = convert_bbox_to_z(bbox)
|
88 |
+
self.time_since_update = 0
|
89 |
+
self.id = KalmanBoxTracker.count
|
90 |
+
KalmanBoxTracker.count += 1
|
91 |
+
self.history = []
|
92 |
+
self.hits = 0
|
93 |
+
self.hit_streak = 0
|
94 |
+
self.age = 0
|
95 |
+
self.conf = bbox[-1]
|
96 |
+
self.cls = cls
|
97 |
+
"""
|
98 |
+
NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of
|
99 |
+
function k_previous_obs. It is ugly and I do not like it. But to support generating the observation array in a
|
100 |
+
fast and unified way, which you would see below k_observations = np.array([k_previous_obs(...]]), let's bear it for now.
|
101 |
+
"""
|
102 |
+
self.last_observation = np.array([-1, -1, -1, -1, -1]) # placeholder
|
103 |
+
self.observations = dict()
|
104 |
+
self.history_observations = []
|
105 |
+
self.velocity = None
|
106 |
+
self.delta_t = delta_t
|
107 |
+
|
108 |
+
def update(self, bbox, cls):
|
109 |
+
"""
|
110 |
+
Updates the state vector with observed bbox.
|
111 |
+
"""
|
112 |
+
|
113 |
+
if bbox is not None:
|
114 |
+
self.conf = bbox[-1]
|
115 |
+
self.cls = cls
|
116 |
+
if self.last_observation.sum() >= 0:  # a previous observation exists
|
117 |
+
previous_box = None
|
118 |
+
for i in range(self.delta_t):
|
119 |
+
dt = self.delta_t - i
|
120 |
+
if self.age - dt in self.observations:
|
121 |
+
previous_box = self.observations[self.age-dt]
|
122 |
+
break
|
123 |
+
if previous_box is None:
|
124 |
+
previous_box = self.last_observation
|
125 |
+
"""
|
126 |
+
Estimate the track speed direction with observations \Delta t steps away
|
127 |
+
"""
|
128 |
+
self.velocity = speed_direction(previous_box, bbox)
|
129 |
+
|
130 |
+
"""
|
131 |
+
Insert new observations. This is an ugly way to maintain both self.observations
|
132 |
+
and self.history_observations. Bear it for the moment.
|
133 |
+
"""
|
134 |
+
self.last_observation = bbox
|
135 |
+
self.observations[self.age] = bbox
|
136 |
+
self.history_observations.append(bbox)
|
137 |
+
|
138 |
+
self.time_since_update = 0
|
139 |
+
self.history = []
|
140 |
+
self.hits += 1
|
141 |
+
self.hit_streak += 1
|
142 |
+
self.kf.update(convert_bbox_to_z(bbox))
|
143 |
+
else:
|
144 |
+
self.kf.update(bbox)
|
145 |
+
|
146 |
+
def predict(self):
|
147 |
+
"""
|
148 |
+
Advances the state vector and returns the predicted bounding box estimate.
|
149 |
+
"""
|
150 |
+
if((self.kf.x[6]+self.kf.x[2]) <= 0):
|
151 |
+
self.kf.x[6] *= 0.0
|
152 |
+
|
153 |
+
self.kf.predict()
|
154 |
+
self.age += 1
|
155 |
+
if(self.time_since_update > 0):
|
156 |
+
self.hit_streak = 0
|
157 |
+
self.time_since_update += 1
|
158 |
+
self.history.append(convert_x_to_bbox(self.kf.x))
|
159 |
+
return self.history[-1]
|
160 |
+
|
161 |
+
def get_state(self):
|
162 |
+
"""
|
163 |
+
Returns the current bounding box estimate.
|
164 |
+
"""
|
165 |
+
return convert_x_to_bbox(self.kf.x)
|
166 |
+
|
167 |
+
|
168 |
+
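# Usage sketch (illustrative comment only; det1, det2 and cls are hypothetical):
# the per-frame life cycle of a single KalmanBoxTracker, where detections are
# [x1, y1, x2, y2, conf] arrays.
#
#   det1 = np.array([10., 20., 50., 60., 0.9])
#   trk = KalmanBoxTracker(det1, cls=0)
#   trk.predict()                    # propagate one frame, ages the track
#   det2 = np.array([12., 21., 52., 61., 0.85])
#   trk.update(det2, cls=0)          # match found: refresh state and hit_streak
#   trk.predict()
#   trk.update(None, None)           # no match this frame: time_since_update grows
#   trk.get_state()                  # current [x1, y1, x2, y2] estimate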
"""
|
169 |
+
We support multiple ways for association cost calculation, by default
|
170 |
+
we use IoU. GIoU may have better performance in some situations. We note
|
171 |
+
that we do not normalize the cost of all methods to (0,1), which may not be
|
172 |
+
the best practice.
|
173 |
+
"""
|
174 |
+
ASSO_FUNCS = { "iou": iou_batch,
|
175 |
+
"giou": giou_batch,
|
176 |
+
"ciou": ciou_batch,
|
177 |
+
"diou": diou_batch,
|
178 |
+
"ct_dist": ct_dist}
|
179 |
+
|
180 |
+
|
181 |
+
class OCSort(object):
|
182 |
+
def __init__(self, det_thresh, max_age=30, min_hits=3,
|
183 |
+
iou_threshold=0.3, delta_t=3, asso_func="iou", inertia=0.2, use_byte=False):
|
184 |
+
"""
|
185 |
+
Sets key parameters for SORT
|
186 |
+
"""
|
187 |
+
self.max_age = max_age
|
188 |
+
self.min_hits = min_hits
|
189 |
+
self.iou_threshold = iou_threshold
|
190 |
+
self.trackers = []
|
191 |
+
self.frame_count = 0
|
192 |
+
self.det_thresh = det_thresh
|
193 |
+
self.delta_t = delta_t
|
194 |
+
self.asso_func = ASSO_FUNCS[asso_func]
|
195 |
+
self.inertia = inertia
|
196 |
+
self.use_byte = use_byte
|
197 |
+
KalmanBoxTracker.count = 0
|
198 |
+
|
199 |
+
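# Usage sketch (illustrative comment only; `detections_per_frame` is hypothetical):
# driving the tracker frame by frame. Because update() below calls .numpy() on the
# detection columns, `dets` is assumed to be a torch tensor shaped (N, 6) as
# [x1, y1, x2, y2, conf, cls].
#
#   tracker = OCSort(det_thresh=0.3)
#   for frame_idx, dets in enumerate(detections_per_frame):   # dets: torch.Tensor (N, 6)
#       online_targets = tracker.update(dets, None)            # tracked boxes with persistent ids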
def update(self, dets, _):
|
200 |
+
"""
|
201 |
+
Params:
|
202 |
+
dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
|
203 |
+
Requires: this method must be called once for each frame even with empty detections (use np.empty((0, 5)) for frames without detections).
|
204 |
+
Returns a similar array, where the last column is the object ID.
|
205 |
+
NOTE: The number of objects returned may differ from the number of detections provided.
|
206 |
+
"""
|
207 |
+
|
208 |
+
self.frame_count += 1
|
209 |
+
|
210 |
+
xyxys = dets[:, 0:4]
|
211 |
+
confs = dets[:, 4]
|
212 |
+
clss = dets[:, 5]
|
213 |
+
|
214 |
+
classes = clss.numpy()
|
215 |
+
xyxys = xyxys.numpy()
|
216 |
+
confs = confs.numpy()
|
217 |
+
|
218 |
+
output_results = np.column_stack((xyxys, confs, classes))
|
219 |
+
|
220 |
+
inds_low = confs > 0.1
|
221 |
+
inds_high = confs < self.det_thresh
|
222 |
+
inds_second = np.logical_and(inds_low, inds_high) # self.det_thresh > score > 0.1, for second matching
|
223 |
+
dets_second = output_results[inds_second] # detections for second matching
|
224 |
+
remain_inds = confs > self.det_thresh
|
225 |
+
dets = output_results[remain_inds]
|
226 |
+
|
227 |
+
# get predicted locations from existing trackers.
|
228 |
+
trks = np.zeros((len(self.trackers), 5))
|
229 |
+
to_del = []
|
230 |
+
ret = []
|
231 |
+
for t, trk in enumerate(trks):
|
232 |
+
pos = self.trackers[t].predict()[0]
|
233 |
+
trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
|
234 |
+
if np.any(np.isnan(pos)):
|
235 |
+
to_del.append(t)
|
236 |
+
trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
|
237 |
+
for t in reversed(to_del):
|
238 |
+
self.trackers.pop(t)
|
239 |
+
|
240 |
+
velocities = np.array(
|
241 |
+
[trk.velocity if trk.velocity is not None else np.array((0, 0)) for trk in self.trackers])
|
242 |
+
last_boxes = np.array([trk.last_observation for trk in self.trackers])
|
243 |
+
k_observations = np.array(
|
244 |
+
[k_previous_obs(trk.observations, trk.age, self.delta_t) for trk in self.trackers])
|
245 |
+
|
246 |
+
"""
|
247 |
+
First round of association
|
248 |
+
"""
|
249 |
+
matched, unmatched_dets, unmatched_trks = associate(
|
250 |
+
dets, trks, self.iou_threshold, velocities, k_observations, self.inertia)
|
251 |
+
for m in matched:
|
252 |
+
self.trackers[m[1]].update(dets[m[0], :5], dets[m[0], 5])
|
253 |
+
|
254 |
+
"""
|
255 |
+
Second round of associaton by OCR
|
256 |
+
"""
|
257 |
+
# BYTE association
|
258 |
+
if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[0] > 0:
|
259 |
+
u_trks = trks[unmatched_trks]
|
260 |
+
iou_left = self.asso_func(dets_second, u_trks) # iou between low score detections and unmatched tracks
|
261 |
+
iou_left = np.array(iou_left)
|
262 |
+
if iou_left.max() > self.iou_threshold:
|
263 |
+
"""
|
264 |
+
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
|
265 |
+
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
|
266 |
+
uniform here for simplicity
|
267 |
+
"""
|
268 |
+
matched_indices = linear_assignment(-iou_left)
|
269 |
+
to_remove_trk_indices = []
|
270 |
+
for m in matched_indices:
|
271 |
+
det_ind, trk_ind = m[0], unmatched_trks[m[1]]
|
272 |
+
if iou_left[m[0], m[1]] < self.iou_threshold:
|
273 |
+
continue
|
274 |
+
self.trackers[trk_ind].update(dets_second[det_ind, :5], dets_second[det_ind, 5])
|
275 |
+
to_remove_trk_indices.append(trk_ind)
|
276 |
+
unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))
|
277 |
+
|
278 |
+
if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
|
279 |
+
left_dets = dets[unmatched_dets]
|
280 |
+
left_trks = last_boxes[unmatched_trks]
|
281 |
+
iou_left = self.asso_func(left_dets, left_trks)
|
282 |
+
iou_left = np.array(iou_left)
|
283 |
+
if iou_left.max() > self.iou_threshold:
|
284 |
+
"""
|
285 |
+
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
|
286 |
+
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
|
287 |
+
uniform here for simplicity
|
288 |
+
"""
|
289 |
+
rematched_indices = linear_assignment(-iou_left)
|
290 |
+
to_remove_det_indices = []
|
291 |
+
to_remove_trk_indices = []
|
292 |
+
for m in rematched_indices:
|
293 |
+
det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
|
294 |
+
if iou_left[m[0], m[1]] < self.iou_threshold:
|
295 |
+
continue
|
296 |
+
self.trackers[trk_ind].update(dets[det_ind, :5], dets[det_ind, 5])
|
297 |
+
to_remove_det_indices.append(det_ind)
|
298 |
+
to_remove_trk_indices.append(trk_ind)
|
299 |
+
unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices))
|
300 |
+
unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))
|
301 |
+
|
302 |
+
for m in unmatched_trks:
|
303 |
+
self.trackers[m].update(None, None)
|
304 |
+
|
305 |
+
# create and initialise new trackers for unmatched detections
|
306 |
+
for i in unmatched_dets:
|
307 |
+
trk = KalmanBoxTracker(dets[i, :5], dets[i, 5], delta_t=self.delta_t)
|
308 |
+
self.trackers.append(trk)
|
309 |
+
i = len(self.trackers)
|
310 |
+
for trk in reversed(self.trackers):
|
311 |
+
if trk.last_observation.sum() < 0:
|
312 |
+
d = trk.get_state()[0]
|
313 |
+
else:
|
314 |
+
"""
|
315 |
+
this is optional to use the recent observation or the kalman filter prediction,
|
316 |
+
we didn't notice significant difference here
|
317 |
+
"""
|
318 |
+
d = trk.last_observation[:4]
|
319 |
+
if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
|
320 |
+
# +1 as MOT benchmark requires positive
|
321 |
+
ret.append(np.concatenate((d, [trk.id+1], [trk.cls], [trk.conf])).reshape(1, -1))
|
322 |
+
i -= 1
|
323 |
+
# remove dead tracklet
|
324 |
+
if(trk.time_since_update > self.max_age):
|
325 |
+
self.trackers.pop(i)
|
326 |
+
if(len(ret) > 0):
|
327 |
+
return np.concatenate(ret)
|
328 |
+
return np.empty((0, 5))
|
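The per-frame contract of update() above is easiest to see with a tiny driver. The following is only an illustrative sketch, not part of the added files: it assumes the repository root is on the Python path and that detections arrive as a torch tensor with one [x1, y1, x2, y2, conf, cls] row per detection (the layout update() indexes); the box values are made up.

import torch
from trackers.ocsort.ocsort import OCSort

tracker = OCSort(det_thresh=0.3, max_age=30, min_hits=3, iou_threshold=0.3)

# one fake frame with two detections: [x1, y1, x2, y2, conf, cls]
dets = torch.tensor([[100., 120., 180., 300., 0.92, 0.],
                     [400., 150., 470., 330., 0.81, 0.]])
tracks = tracker.update(dets, None)  # second argument (the image) is unused by OCSort
# each returned row is [x1, y1, x2, y2, track_id, cls, conf]
print(tracks)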
trackers/reid_export.py
ADDED
@@ -0,0 +1,313 @@
import argparse

import os
# limit the number of cpus used by high performance libraries
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

import sys
import numpy as np
from pathlib import Path
import torch
import time
import platform
import pandas as pd
import subprocess
import torch.backends.cudnn as cudnn
from torch.utils.mobile_optimizer import optimize_for_mobile

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0].parents[0]  # yolov5 strongsort root directory
WEIGHTS = ROOT / 'weights'


if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if str(ROOT / 'yolov5') not in sys.path:
    sys.path.append(str(ROOT / 'yolov5'))  # add yolov5 ROOT to PATH

ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

import logging
from ultralytics.yolo.utils.torch_utils import select_device
from ultralytics.yolo.utils import LOGGER, colorstr, ops
from ultralytics.yolo.utils.checks import check_requirements, check_version
from trackers.strongsort.deep.models import build_model
from trackers.strongsort.deep.reid_model_factory import get_model_name, load_pretrained_weights


def file_size(path):
    # Return file/dir size (MB)
    path = Path(path)
    if path.is_file():
        return path.stat().st_size / 1E6
    elif path.is_dir():
        return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
    else:
        return 0.0


def export_formats():
    # Supported export formats
    x = [
        ['PyTorch', '-', '.pt', True, True],
        ['TorchScript', 'torchscript', '.torchscript', True, True],
        ['ONNX', 'onnx', '.onnx', True, True],
        ['OpenVINO', 'openvino', '_openvino_model', True, False],
        ['TensorRT', 'engine', '.engine', False, True],
        ['TensorFlow Lite', 'tflite', '.tflite', True, False],
    ]
    return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])


def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
    # TorchScript model export
    try:
        LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = file.with_suffix('.torchscript')

        ts = torch.jit.trace(model, im, strict=False)
        if optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
            optimize_for_mobile(ts)._save_for_lite_interpreter(str(f))
        else:
            ts.save(str(f))

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'{prefix} export failure: {e}')


def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
    # ONNX export
    try:
        check_requirements(('onnx',))
        import onnx

        f = file.with_suffix('.onnx')
        LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')

        if dynamic:
            dynamic = {'images': {0: 'batch'}}  # shape(1,3,640,640)
            dynamic['output'] = {0: 'batch'}  # shape(1,25200,85)

        torch.onnx.export(
            model.cpu() if dynamic else model,  # --dynamic only compatible with cpu
            im.cpu() if dynamic else im,
            f,
            verbose=False,
            opset_version=opset,
            do_constant_folding=True,
            input_names=['images'],
            output_names=['output'],
            dynamic_axes=dynamic or None
        )
        # Checks
        model_onnx = onnx.load(f)  # load onnx model
        onnx.checker.check_model(model_onnx)  # check onnx model
        onnx.save(model_onnx, f)

        # Simplify
        if simplify:
            try:
                cuda = torch.cuda.is_available()
                check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
                import onnxsim

                LOGGER.info(f'simplifying with onnx-simplifier {onnxsim.__version__}...')
                model_onnx, check = onnxsim.simplify(model_onnx)
                assert check, 'assert check failed'
                onnx.save(model_onnx, f)
            except Exception as e:
                LOGGER.info(f'simplifier failure: {e}')
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'export failure: {e}')


def export_openvino(file, half, prefix=colorstr('OpenVINO:')):
    # OpenVINO export
    check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
    import openvino.inference_engine as ie
    try:
        LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
        f = str(file).replace('.pt', f'_openvino_model{os.sep}')

        cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
        subprocess.check_output(cmd.split())  # export
    except Exception as e:
        LOGGER.info(f'export failure: {e}')
    LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    return f


def export_tflite(file, half, prefix=colorstr('TFLite:')):
    # TFLite export via openvino2tensorflow (starting from the OpenVINO model)
    try:
        check_requirements(('openvino2tensorflow', 'tensorflow', 'tensorflow_datasets'))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
        import openvino.inference_engine as ie
        LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
        output = Path(str(file).replace(f'_openvino_model{os.sep}', f'_tflite_model{os.sep}'))
        modelxml = list(Path(file).glob('*.xml'))[0]
        cmd = f"openvino2tensorflow \
            --model_path {modelxml} \
            --model_output_path {output} \
            --output_pb \
            --output_saved_model \
            --output_no_quant_float32_tflite \
            --output_dynamic_range_quant_tflite"
        subprocess.check_output(cmd.split())  # export

        LOGGER.info(f'{prefix} export success, results saved in {output} ({file_size(output):.1f} MB)')
        return output
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')


def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
    # TensorRT export https://developer.nvidia.com/tensorrt
    try:
        assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
        try:
            import tensorrt as trt
        except Exception:
            if platform.system() == 'Linux':
                check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
            import tensorrt as trt

        if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
            grid = model.model[-1].anchor_grid
            model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
            export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
            model.model[-1].anchor_grid = grid
        else:  # TensorRT >= 8
            check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
            export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
        onnx = file.with_suffix('.onnx')

        LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
        assert onnx.exists(), f'failed to export ONNX file: {onnx}'
        f = file.with_suffix('.engine')  # TensorRT engine file
        logger = trt.Logger(trt.Logger.INFO)
        if verbose:
            logger.min_severity = trt.Logger.Severity.VERBOSE

        builder = trt.Builder(logger)
        config = builder.create_builder_config()
        config.max_workspace_size = workspace * 1 << 30
        # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice

        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        network = builder.create_network(flag)
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(str(onnx)):
            raise RuntimeError(f'failed to load ONNX file: {onnx}')

        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]
        LOGGER.info(f'{prefix} Network Description:')
        for inp in inputs:
            LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
        for out in outputs:
            LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

        if dynamic:
            if im.shape[0] <= 1:
                LOGGER.warning(f"{prefix}WARNING: --dynamic model requires maximum --batch-size argument")
            profile = builder.create_optimization_profile()
            for inp in inputs:
                profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
            config.add_optimization_profile(profile)

        LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}')
        if builder.platform_has_fast_fp16 and half:
            config.set_flag(trt.BuilderFlag.FP16)
        with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
            t.write(engine.serialize())
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="ReID export")
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[256, 128], help='image (h, w)')
    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
    parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
    parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
    parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
    parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
    parser.add_argument('--weights', nargs='+', type=str, default=WEIGHTS / 'osnet_x0_25_msmt17.pt', help='model.pt path(s)')
    parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
    parser.add_argument('--include',
                        nargs='+',
                        default=['torchscript'],
                        help='torchscript, onnx, openvino, engine')
    args = parser.parse_args()

    t = time.time()

    include = [x.lower() for x in args.include]  # to lowercase
    fmts = tuple(export_formats()['Argument'][1:])  # --include arguments
    flags = [x in include for x in fmts]
    assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
    jit, onnx, openvino, engine, tflite = flags  # export booleans

    args.device = select_device(args.device)
    if args.half:
        assert args.device.type != 'cpu', '--half only compatible with GPU export, i.e. use --device 0'
        assert not args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'

    if type(args.weights) is list:
        args.weights = Path(args.weights[0])

    model = build_model(
        get_model_name(args.weights),
        num_classes=1,
        pretrained=not (args.weights and args.weights.is_file() and args.weights.suffix == '.pt'),
        use_gpu=args.device
    ).to(args.device)
    load_pretrained_weights(model, args.weights)
    model.eval()

    if args.optimize:
        assert args.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'

    im = torch.zeros(args.batch_size, 3, args.imgsz[0], args.imgsz[1]).to(args.device)  # image size (batch, 3, h, w) BCHW
    for _ in range(2):
        y = model(im)  # dry runs
    if args.half:
        im, model = im.half(), model.half()  # to FP16
    shape = tuple((y[0] if isinstance(y, tuple) else y).shape)  # model output shape
    LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {args.weights} with output shape {shape} ({file_size(args.weights):.1f} MB)")

    # Exports
    f = [''] * len(fmts)  # exported filenames
    if jit:
        f[0] = export_torchscript(model, im, args.weights, args.optimize)
    if engine:  # TensorRT required before ONNX
        f[1] = export_engine(model, im, args.weights, args.half, args.dynamic, args.simplify, args.workspace, args.verbose)
    if onnx:  # OpenVINO requires ONNX
        f[2] = export_onnx(model, im, args.weights, args.opset, args.dynamic, args.simplify)
    if openvino:
        f[3] = export_openvino(args.weights, args.half)
    if tflite:
        export_tflite(f, False)

    # Finish
    f = [str(x) for x in f if x]  # filter out '' and None
    if any(f):
        LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
                    f"\nResults saved to {colorstr('bold', args.weights.parent.resolve())}"
                    f"\nVisualize: https://netron.app")
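For orientation, a typical command-line invocation of this new script might look as follows (illustrative only; the weight path matches the script's default under WEIGHTS, and valid --include values are the Argument column of export_formats()):

python trackers/reid_export.py \
    --weights weights/osnet_x0_25_msmt17.pt \
    --include torchscript onnx \
    --device cpu \
    --imgsz 256 128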
trackers/strongsort/.gitignore
ADDED
@@ -0,0 +1,13 @@
# Folders
__pycache__/
build/
*.egg-info


# Files
*.weights
*.t7
*.mp4
*.avi
*.so
*.txt
trackers/strongsort/__init__.py
ADDED
File without changes
|
trackers/strongsort/configs/strongsort.yaml
ADDED
@@ -0,0 +1,11 @@
strongsort:
  ecc: true
  ema_alpha: 0.8962157769329083
  max_age: 40
  max_dist: 0.1594374041012136
  max_iou_dist: 0.5431835667667874
  max_unmatched_preds: 0
  mc_lambda: 0.995
  n_init: 3
  nn_budget: 100
  conf_thres: 0.5122620708221085
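A short sketch of how these tuned StrongSORT hyper-parameters can be read back in Python (illustrative only; it assumes PyYAML is installed and is independent of whatever loader trackers/multi_tracker_zoo.py actually uses):

import yaml

with open('trackers/strongsort/configs/strongsort.yaml') as f:
    cfg = yaml.safe_load(f)['strongsort']

# e.g. track lifetime and confirmation settings
print(cfg['max_age'], cfg['n_init'], cfg['max_dist'])  # 40 3 0.1594374041012136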
trackers/strongsort/deep/checkpoint/.gitkeep
ADDED
File without changes
|
trackers/strongsort/deep/checkpoint/osnet_x0_25_market1501.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c0ff09177a21417a19bc73bbf5af4eb5d2b097c2074ed35e67819ee6cd93612c
size 2462783
trackers/strongsort/deep/checkpoint/osnet_x0_25_msmt17.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6f57607fed9f502b9efed546108132ee715df5a5b6e6932c6269bacb47f59f99
size 3057863
trackers/strongsort/deep/checkpoint/osnet_x1_0_msmt17.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7d73dc67c016fd044e4027ff856019496392a7aca8fa0ed56d862a1632c1cf2
size 10994685
trackers/strongsort/deep/models/__init__.py
ADDED
@@ -0,0 +1,122 @@
from __future__ import absolute_import
import torch

from .pcb import *
from .mlfn import *
from .hacnn import *
from .osnet import *
from .senet import *
from .mudeep import *
from .nasnet import *
from .resnet import *
from .densenet import *
from .xception import *
from .osnet_ain import *
from .resnetmid import *
from .shufflenet import *
from .squeezenet import *
from .inceptionv4 import *
from .mobilenetv2 import *
from .resnet_ibn_a import *
from .resnet_ibn_b import *
from .shufflenetv2 import *
from .inceptionresnetv2 import *

__model_factory = {
    # image classification models
    'resnet18': resnet18,
    'resnet34': resnet34,
    'resnet50': resnet50,
    'resnet101': resnet101,
    'resnet152': resnet152,
    'resnext50_32x4d': resnext50_32x4d,
    'resnext101_32x8d': resnext101_32x8d,
    'resnet50_fc512': resnet50_fc512,
    'se_resnet50': se_resnet50,
    'se_resnet50_fc512': se_resnet50_fc512,
    'se_resnet101': se_resnet101,
    'se_resnext50_32x4d': se_resnext50_32x4d,
    'se_resnext101_32x4d': se_resnext101_32x4d,
    'densenet121': densenet121,
    'densenet169': densenet169,
    'densenet201': densenet201,
    'densenet161': densenet161,
    'densenet121_fc512': densenet121_fc512,
    'inceptionresnetv2': inceptionresnetv2,
    'inceptionv4': inceptionv4,
    'xception': xception,
    'resnet50_ibn_a': resnet50_ibn_a,
    'resnet50_ibn_b': resnet50_ibn_b,
    # lightweight models
    'nasnsetmobile': nasnetamobile,
    'mobilenetv2_x1_0': mobilenetv2_x1_0,
    'mobilenetv2_x1_4': mobilenetv2_x1_4,
    'shufflenet': shufflenet,
    'squeezenet1_0': squeezenet1_0,
    'squeezenet1_0_fc512': squeezenet1_0_fc512,
    'squeezenet1_1': squeezenet1_1,
    'shufflenet_v2_x0_5': shufflenet_v2_x0_5,
    'shufflenet_v2_x1_0': shufflenet_v2_x1_0,
    'shufflenet_v2_x1_5': shufflenet_v2_x1_5,
    'shufflenet_v2_x2_0': shufflenet_v2_x2_0,
    # reid-specific models
    'mudeep': MuDeep,
    'resnet50mid': resnet50mid,
    'hacnn': HACNN,
    'pcb_p6': pcb_p6,
    'pcb_p4': pcb_p4,
    'mlfn': mlfn,
    'osnet_x1_0': osnet_x1_0,
    'osnet_x0_75': osnet_x0_75,
    'osnet_x0_5': osnet_x0_5,
    'osnet_x0_25': osnet_x0_25,
    'osnet_ibn_x1_0': osnet_ibn_x1_0,
    'osnet_ain_x1_0': osnet_ain_x1_0,
    'osnet_ain_x0_75': osnet_ain_x0_75,
    'osnet_ain_x0_5': osnet_ain_x0_5,
    'osnet_ain_x0_25': osnet_ain_x0_25
}


def show_avai_models():
    """Displays available models.

    Examples::
        >>> from torchreid import models
        >>> models.show_avai_models()
    """
    print(list(__model_factory.keys()))


def build_model(
    name, num_classes, loss='softmax', pretrained=True, use_gpu=True
):
    """A function wrapper for building a model.

    Args:
        name (str): model name.
        num_classes (int): number of training identities.
        loss (str, optional): loss function to optimize the model. Currently
            supports "softmax" and "triplet". Default is "softmax".
        pretrained (bool, optional): whether to load ImageNet-pretrained weights.
            Default is True.
        use_gpu (bool, optional): whether to use gpu. Default is True.

    Returns:
        nn.Module

    Examples::
        >>> from torchreid import models
        >>> model = models.build_model('resnet50', 751, loss='softmax')
    """
    avai_models = list(__model_factory.keys())
    if name not in avai_models:
        raise KeyError(
            'Unknown model: {}. Must be one of {}'.format(name, avai_models)
        )
    return __model_factory[name](
        num_classes=num_classes,
        loss=loss,
        pretrained=pretrained,
        use_gpu=use_gpu
    )
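To connect this factory to the checkpoints added above, here is a minimal sketch (illustrative, not part of the added files) that builds the lightweight OSNet variant matching the osnet_x0_25_* checkpoints and runs a dummy forward pass; the 256x128 input matches the default --imgsz of trackers/reid_export.py, and in eval mode OSNet returns the appearance embedding rather than class logits:

import torch
from trackers.strongsort.deep.models import build_model

model = build_model('osnet_x0_25', num_classes=1, pretrained=False, use_gpu=False)
model.eval()
with torch.no_grad():
    emb = model(torch.zeros(1, 3, 256, 128))  # (batch, channels, height, width)
print(emb.shape)  # one embedding vector per input crop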
trackers/strongsort/deep/models/densenet.py
ADDED
@@ -0,0 +1,380 @@
1 |
+
"""
|
2 |
+
Code source: https://github.com/pytorch/vision
|
3 |
+
"""
|
4 |
+
from __future__ import division, absolute_import
|
5 |
+
import re
|
6 |
+
from collections import OrderedDict
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from torch.utils import model_zoo
|
11 |
+
|
12 |
+
__all__ = [
|
13 |
+
'densenet121', 'densenet169', 'densenet201', 'densenet161',
|
14 |
+
'densenet121_fc512'
|
15 |
+
]
|
16 |
+
|
17 |
+
model_urls = {
|
18 |
+
'densenet121':
|
19 |
+
'https://download.pytorch.org/models/densenet121-a639ec97.pth',
|
20 |
+
'densenet169':
|
21 |
+
'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
|
22 |
+
'densenet201':
|
23 |
+
'https://download.pytorch.org/models/densenet201-c1103571.pth',
|
24 |
+
'densenet161':
|
25 |
+
'https://download.pytorch.org/models/densenet161-8d451a50.pth',
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
class _DenseLayer(nn.Sequential):
|
30 |
+
|
31 |
+
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
|
32 |
+
super(_DenseLayer, self).__init__()
|
33 |
+
self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
|
34 |
+
self.add_module('relu1', nn.ReLU(inplace=True)),
|
35 |
+
self.add_module(
|
36 |
+
'conv1',
|
37 |
+
nn.Conv2d(
|
38 |
+
num_input_features,
|
39 |
+
bn_size * growth_rate,
|
40 |
+
kernel_size=1,
|
41 |
+
stride=1,
|
42 |
+
bias=False
|
43 |
+
)
|
44 |
+
),
|
45 |
+
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
|
46 |
+
self.add_module('relu2', nn.ReLU(inplace=True)),
|
47 |
+
self.add_module(
|
48 |
+
'conv2',
|
49 |
+
nn.Conv2d(
|
50 |
+
bn_size * growth_rate,
|
51 |
+
growth_rate,
|
52 |
+
kernel_size=3,
|
53 |
+
stride=1,
|
54 |
+
padding=1,
|
55 |
+
bias=False
|
56 |
+
)
|
57 |
+
),
|
58 |
+
self.drop_rate = drop_rate
|
59 |
+
|
60 |
+
def forward(self, x):
|
61 |
+
new_features = super(_DenseLayer, self).forward(x)
|
62 |
+
if self.drop_rate > 0:
|
63 |
+
new_features = F.dropout(
|
64 |
+
new_features, p=self.drop_rate, training=self.training
|
65 |
+
)
|
66 |
+
return torch.cat([x, new_features], 1)
|
67 |
+
|
68 |
+
|
69 |
+
class _DenseBlock(nn.Sequential):
|
70 |
+
|
71 |
+
def __init__(
|
72 |
+
self, num_layers, num_input_features, bn_size, growth_rate, drop_rate
|
73 |
+
):
|
74 |
+
super(_DenseBlock, self).__init__()
|
75 |
+
for i in range(num_layers):
|
76 |
+
layer = _DenseLayer(
|
77 |
+
num_input_features + i*growth_rate, growth_rate, bn_size,
|
78 |
+
drop_rate
|
79 |
+
)
|
80 |
+
self.add_module('denselayer%d' % (i+1), layer)
|
81 |
+
|
82 |
+
|
83 |
+
class _Transition(nn.Sequential):
|
84 |
+
|
85 |
+
def __init__(self, num_input_features, num_output_features):
|
86 |
+
super(_Transition, self).__init__()
|
87 |
+
self.add_module('norm', nn.BatchNorm2d(num_input_features))
|
88 |
+
self.add_module('relu', nn.ReLU(inplace=True))
|
89 |
+
self.add_module(
|
90 |
+
'conv',
|
91 |
+
nn.Conv2d(
|
92 |
+
num_input_features,
|
93 |
+
num_output_features,
|
94 |
+
kernel_size=1,
|
95 |
+
stride=1,
|
96 |
+
bias=False
|
97 |
+
)
|
98 |
+
)
|
99 |
+
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
|
100 |
+
|
101 |
+
|
102 |
+
class DenseNet(nn.Module):
|
103 |
+
"""Densely connected network.
|
104 |
+
|
105 |
+
Reference:
|
106 |
+
Huang et al. Densely Connected Convolutional Networks. CVPR 2017.
|
107 |
+
|
108 |
+
Public keys:
|
109 |
+
- ``densenet121``: DenseNet121.
|
110 |
+
- ``densenet169``: DenseNet169.
|
111 |
+
- ``densenet201``: DenseNet201.
|
112 |
+
- ``densenet161``: DenseNet161.
|
113 |
+
- ``densenet121_fc512``: DenseNet121 + FC.
|
114 |
+
"""
|
115 |
+
|
116 |
+
def __init__(
|
117 |
+
self,
|
118 |
+
num_classes,
|
119 |
+
loss,
|
120 |
+
growth_rate=32,
|
121 |
+
block_config=(6, 12, 24, 16),
|
122 |
+
num_init_features=64,
|
123 |
+
bn_size=4,
|
124 |
+
drop_rate=0,
|
125 |
+
fc_dims=None,
|
126 |
+
dropout_p=None,
|
127 |
+
**kwargs
|
128 |
+
):
|
129 |
+
|
130 |
+
super(DenseNet, self).__init__()
|
131 |
+
self.loss = loss
|
132 |
+
|
133 |
+
# First convolution
|
134 |
+
self.features = nn.Sequential(
|
135 |
+
OrderedDict(
|
136 |
+
[
|
137 |
+
(
|
138 |
+
'conv0',
|
139 |
+
nn.Conv2d(
|
140 |
+
3,
|
141 |
+
num_init_features,
|
142 |
+
kernel_size=7,
|
143 |
+
stride=2,
|
144 |
+
padding=3,
|
145 |
+
bias=False
|
146 |
+
)
|
147 |
+
),
|
148 |
+
('norm0', nn.BatchNorm2d(num_init_features)),
|
149 |
+
('relu0', nn.ReLU(inplace=True)),
|
150 |
+
(
|
151 |
+
'pool0',
|
152 |
+
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
153 |
+
),
|
154 |
+
]
|
155 |
+
)
|
156 |
+
)
|
157 |
+
|
158 |
+
# Each denseblock
|
159 |
+
num_features = num_init_features
|
160 |
+
for i, num_layers in enumerate(block_config):
|
161 |
+
block = _DenseBlock(
|
162 |
+
num_layers=num_layers,
|
163 |
+
num_input_features=num_features,
|
164 |
+
bn_size=bn_size,
|
165 |
+
growth_rate=growth_rate,
|
166 |
+
drop_rate=drop_rate
|
167 |
+
)
|
168 |
+
self.features.add_module('denseblock%d' % (i+1), block)
|
169 |
+
num_features = num_features + num_layers*growth_rate
|
170 |
+
if i != len(block_config) - 1:
|
171 |
+
trans = _Transition(
|
172 |
+
num_input_features=num_features,
|
173 |
+
num_output_features=num_features // 2
|
174 |
+
)
|
175 |
+
self.features.add_module('transition%d' % (i+1), trans)
|
176 |
+
num_features = num_features // 2
|
177 |
+
|
178 |
+
# Final batch norm
|
179 |
+
self.features.add_module('norm5', nn.BatchNorm2d(num_features))
|
180 |
+
|
181 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
182 |
+
self.feature_dim = num_features
|
183 |
+
self.fc = self._construct_fc_layer(fc_dims, num_features, dropout_p)
|
184 |
+
|
185 |
+
# Linear layer
|
186 |
+
self.classifier = nn.Linear(self.feature_dim, num_classes)
|
187 |
+
|
188 |
+
self._init_params()
|
189 |
+
|
190 |
+
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
|
191 |
+
"""Constructs fully connected layer.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
|
195 |
+
input_dim (int): input dimension
|
196 |
+
dropout_p (float): dropout probability, if None, dropout is unused
|
197 |
+
"""
|
198 |
+
if fc_dims is None:
|
199 |
+
self.feature_dim = input_dim
|
200 |
+
return None
|
201 |
+
|
202 |
+
assert isinstance(
|
203 |
+
fc_dims, (list, tuple)
|
204 |
+
), 'fc_dims must be either list or tuple, but got {}'.format(
|
205 |
+
type(fc_dims)
|
206 |
+
)
|
207 |
+
|
208 |
+
layers = []
|
209 |
+
for dim in fc_dims:
|
210 |
+
layers.append(nn.Linear(input_dim, dim))
|
211 |
+
layers.append(nn.BatchNorm1d(dim))
|
212 |
+
layers.append(nn.ReLU(inplace=True))
|
213 |
+
if dropout_p is not None:
|
214 |
+
layers.append(nn.Dropout(p=dropout_p))
|
215 |
+
input_dim = dim
|
216 |
+
|
217 |
+
self.feature_dim = fc_dims[-1]
|
218 |
+
|
219 |
+
return nn.Sequential(*layers)
|
220 |
+
|
221 |
+
def _init_params(self):
|
222 |
+
for m in self.modules():
|
223 |
+
if isinstance(m, nn.Conv2d):
|
224 |
+
nn.init.kaiming_normal_(
|
225 |
+
m.weight, mode='fan_out', nonlinearity='relu'
|
226 |
+
)
|
227 |
+
if m.bias is not None:
|
228 |
+
nn.init.constant_(m.bias, 0)
|
229 |
+
elif isinstance(m, nn.BatchNorm2d):
|
230 |
+
nn.init.constant_(m.weight, 1)
|
231 |
+
nn.init.constant_(m.bias, 0)
|
232 |
+
elif isinstance(m, nn.BatchNorm1d):
|
233 |
+
nn.init.constant_(m.weight, 1)
|
234 |
+
nn.init.constant_(m.bias, 0)
|
235 |
+
elif isinstance(m, nn.Linear):
|
236 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
237 |
+
if m.bias is not None:
|
238 |
+
nn.init.constant_(m.bias, 0)
|
239 |
+
|
240 |
+
def forward(self, x):
|
241 |
+
f = self.features(x)
|
242 |
+
f = F.relu(f, inplace=True)
|
243 |
+
v = self.global_avgpool(f)
|
244 |
+
v = v.view(v.size(0), -1)
|
245 |
+
|
246 |
+
if self.fc is not None:
|
247 |
+
v = self.fc(v)
|
248 |
+
|
249 |
+
if not self.training:
|
250 |
+
return v
|
251 |
+
|
252 |
+
y = self.classifier(v)
|
253 |
+
|
254 |
+
if self.loss == 'softmax':
|
255 |
+
return y
|
256 |
+
elif self.loss == 'triplet':
|
257 |
+
return y, v
|
258 |
+
else:
|
259 |
+
raise KeyError('Unsupported loss: {}'.format(self.loss))
|
260 |
+
|
261 |
+
|
262 |
+
def init_pretrained_weights(model, model_url):
|
263 |
+
"""Initializes model with pretrained weights.
|
264 |
+
|
265 |
+
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
266 |
+
"""
|
267 |
+
pretrain_dict = model_zoo.load_url(model_url)
|
268 |
+
|
269 |
+
# '.'s are no longer allowed in module names, but pervious _DenseLayer
|
270 |
+
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
|
271 |
+
# They are also in the checkpoints in model_urls. This pattern is used
|
272 |
+
# to find such keys.
|
273 |
+
pattern = re.compile(
|
274 |
+
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
|
275 |
+
)
|
276 |
+
for key in list(pretrain_dict.keys()):
|
277 |
+
res = pattern.match(key)
|
278 |
+
if res:
|
279 |
+
new_key = res.group(1) + res.group(2)
|
280 |
+
pretrain_dict[new_key] = pretrain_dict[key]
|
281 |
+
del pretrain_dict[key]
|
282 |
+
|
283 |
+
model_dict = model.state_dict()
|
284 |
+
pretrain_dict = {
|
285 |
+
k: v
|
286 |
+
for k, v in pretrain_dict.items()
|
287 |
+
if k in model_dict and model_dict[k].size() == v.size()
|
288 |
+
}
|
289 |
+
model_dict.update(pretrain_dict)
|
290 |
+
model.load_state_dict(model_dict)
|
291 |
+
|
292 |
+
|
293 |
+
"""
|
294 |
+
Dense network configurations:
|
295 |
+
--
|
296 |
+
densenet121: num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)
|
297 |
+
densenet169: num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32)
|
298 |
+
densenet201: num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32)
|
299 |
+
densenet161: num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24)
|
300 |
+
"""
|
301 |
+
|
302 |
+
|
303 |
+
def densenet121(num_classes, loss='softmax', pretrained=True, **kwargs):
|
304 |
+
model = DenseNet(
|
305 |
+
num_classes=num_classes,
|
306 |
+
loss=loss,
|
307 |
+
num_init_features=64,
|
308 |
+
growth_rate=32,
|
309 |
+
block_config=(6, 12, 24, 16),
|
310 |
+
fc_dims=None,
|
311 |
+
dropout_p=None,
|
312 |
+
**kwargs
|
313 |
+
)
|
314 |
+
if pretrained:
|
315 |
+
init_pretrained_weights(model, model_urls['densenet121'])
|
316 |
+
return model
|
317 |
+
|
318 |
+
|
319 |
+
def densenet169(num_classes, loss='softmax', pretrained=True, **kwargs):
|
320 |
+
model = DenseNet(
|
321 |
+
num_classes=num_classes,
|
322 |
+
loss=loss,
|
323 |
+
num_init_features=64,
|
324 |
+
growth_rate=32,
|
325 |
+
block_config=(6, 12, 32, 32),
|
326 |
+
fc_dims=None,
|
327 |
+
dropout_p=None,
|
328 |
+
**kwargs
|
329 |
+
)
|
330 |
+
if pretrained:
|
331 |
+
init_pretrained_weights(model, model_urls['densenet169'])
|
332 |
+
return model
|
333 |
+
|
334 |
+
|
335 |
+
def densenet201(num_classes, loss='softmax', pretrained=True, **kwargs):
|
336 |
+
model = DenseNet(
|
337 |
+
num_classes=num_classes,
|
338 |
+
loss=loss,
|
339 |
+
num_init_features=64,
|
340 |
+
growth_rate=32,
|
341 |
+
block_config=(6, 12, 48, 32),
|
342 |
+
fc_dims=None,
|
343 |
+
dropout_p=None,
|
344 |
+
**kwargs
|
345 |
+
)
|
346 |
+
if pretrained:
|
347 |
+
init_pretrained_weights(model, model_urls['densenet201'])
|
348 |
+
return model
|
349 |
+
|
350 |
+
|
351 |
+
def densenet161(num_classes, loss='softmax', pretrained=True, **kwargs):
|
352 |
+
model = DenseNet(
|
353 |
+
num_classes=num_classes,
|
354 |
+
loss=loss,
|
355 |
+
num_init_features=96,
|
356 |
+
growth_rate=48,
|
357 |
+
block_config=(6, 12, 36, 24),
|
358 |
+
fc_dims=None,
|
359 |
+
dropout_p=None,
|
360 |
+
**kwargs
|
361 |
+
)
|
362 |
+
if pretrained:
|
363 |
+
init_pretrained_weights(model, model_urls['densenet161'])
|
364 |
+
return model
|
365 |
+
|
366 |
+
|
367 |
+
def densenet121_fc512(num_classes, loss='softmax', pretrained=True, **kwargs):
|
368 |
+
model = DenseNet(
|
369 |
+
num_classes=num_classes,
|
370 |
+
loss=loss,
|
371 |
+
num_init_features=64,
|
372 |
+
growth_rate=32,
|
373 |
+
block_config=(6, 12, 24, 16),
|
374 |
+
fc_dims=[512],
|
375 |
+
dropout_p=None,
|
376 |
+
**kwargs
|
377 |
+
)
|
378 |
+
if pretrained:
|
379 |
+
init_pretrained_weights(model, model_urls['densenet121'])
|
380 |
+
return model
|
trackers/strongsort/deep/models/hacnn.py
ADDED
@@ -0,0 +1,414 @@
1 |
+
from __future__ import division, absolute_import
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
__all__ = ['HACNN']
|
7 |
+
|
8 |
+
|
9 |
+
class ConvBlock(nn.Module):
|
10 |
+
"""Basic convolutional block.
|
11 |
+
|
12 |
+
convolution + batch normalization + relu.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
in_c (int): number of input channels.
|
16 |
+
out_c (int): number of output channels.
|
17 |
+
k (int or tuple): kernel size.
|
18 |
+
s (int or tuple): stride.
|
19 |
+
p (int or tuple): padding.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, in_c, out_c, k, s=1, p=0):
|
23 |
+
super(ConvBlock, self).__init__()
|
24 |
+
self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
|
25 |
+
self.bn = nn.BatchNorm2d(out_c)
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
return F.relu(self.bn(self.conv(x)))
|
29 |
+
|
30 |
+
|
31 |
+
class InceptionA(nn.Module):
|
32 |
+
|
33 |
+
def __init__(self, in_channels, out_channels):
|
34 |
+
super(InceptionA, self).__init__()
|
35 |
+
mid_channels = out_channels // 4
|
36 |
+
|
37 |
+
self.stream1 = nn.Sequential(
|
38 |
+
ConvBlock(in_channels, mid_channels, 1),
|
39 |
+
ConvBlock(mid_channels, mid_channels, 3, p=1),
|
40 |
+
)
|
41 |
+
self.stream2 = nn.Sequential(
|
42 |
+
ConvBlock(in_channels, mid_channels, 1),
|
43 |
+
ConvBlock(mid_channels, mid_channels, 3, p=1),
|
44 |
+
)
|
45 |
+
self.stream3 = nn.Sequential(
|
46 |
+
ConvBlock(in_channels, mid_channels, 1),
|
47 |
+
ConvBlock(mid_channels, mid_channels, 3, p=1),
|
48 |
+
)
|
49 |
+
self.stream4 = nn.Sequential(
|
50 |
+
nn.AvgPool2d(3, stride=1, padding=1),
|
51 |
+
ConvBlock(in_channels, mid_channels, 1),
|
52 |
+
)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
s1 = self.stream1(x)
|
56 |
+
s2 = self.stream2(x)
|
57 |
+
s3 = self.stream3(x)
|
58 |
+
s4 = self.stream4(x)
|
59 |
+
y = torch.cat([s1, s2, s3, s4], dim=1)
|
60 |
+
return y
|
61 |
+
|
62 |
+
|
63 |
+
class InceptionB(nn.Module):
|
64 |
+
|
65 |
+
def __init__(self, in_channels, out_channels):
|
66 |
+
super(InceptionB, self).__init__()
|
67 |
+
mid_channels = out_channels // 4
|
68 |
+
|
69 |
+
self.stream1 = nn.Sequential(
|
70 |
+
ConvBlock(in_channels, mid_channels, 1),
|
71 |
+
ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
|
72 |
+
)
|
73 |
+
self.stream2 = nn.Sequential(
|
74 |
+
ConvBlock(in_channels, mid_channels, 1),
|
75 |
+
ConvBlock(mid_channels, mid_channels, 3, p=1),
|
76 |
+
ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
|
77 |
+
)
|
78 |
+
self.stream3 = nn.Sequential(
|
79 |
+
nn.MaxPool2d(3, stride=2, padding=1),
|
80 |
+
ConvBlock(in_channels, mid_channels * 2, 1),
|
81 |
+
)
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
s1 = self.stream1(x)
|
85 |
+
s2 = self.stream2(x)
|
86 |
+
s3 = self.stream3(x)
|
87 |
+
y = torch.cat([s1, s2, s3], dim=1)
|
88 |
+
return y
|
89 |
+
|
90 |
+
|
91 |
+
class SpatialAttn(nn.Module):
|
92 |
+
"""Spatial Attention (Sec. 3.1.I.1)"""
|
93 |
+
|
94 |
+
def __init__(self):
|
95 |
+
super(SpatialAttn, self).__init__()
|
96 |
+
self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
|
97 |
+
self.conv2 = ConvBlock(1, 1, 1)
|
98 |
+
|
99 |
+
def forward(self, x):
|
100 |
+
# global cross-channel averaging
|
101 |
+
x = x.mean(1, keepdim=True)
|
102 |
+
# 3-by-3 conv
|
103 |
+
x = self.conv1(x)
|
104 |
+
# bilinear resizing
|
105 |
+
x = F.upsample(
|
106 |
+
x, (x.size(2) * 2, x.size(3) * 2),
|
107 |
+
mode='bilinear',
|
108 |
+
align_corners=True
|
109 |
+
)
|
110 |
+
# scaling conv
|
111 |
+
x = self.conv2(x)
|
112 |
+
return x
|
113 |
+
|
114 |
+
|
115 |
+
class ChannelAttn(nn.Module):
|
116 |
+
"""Channel Attention (Sec. 3.1.I.2)"""
|
117 |
+
|
118 |
+
def __init__(self, in_channels, reduction_rate=16):
|
119 |
+
super(ChannelAttn, self).__init__()
|
120 |
+
assert in_channels % reduction_rate == 0
|
121 |
+
self.conv1 = ConvBlock(in_channels, in_channels // reduction_rate, 1)
|
122 |
+
self.conv2 = ConvBlock(in_channels // reduction_rate, in_channels, 1)
|
123 |
+
|
124 |
+
def forward(self, x):
|
125 |
+
# squeeze operation (global average pooling)
|
126 |
+
x = F.avg_pool2d(x, x.size()[2:])
|
127 |
+
# excitation operation (2 conv layers)
|
128 |
+
x = self.conv1(x)
|
129 |
+
x = self.conv2(x)
|
130 |
+
return x
|
131 |
+
|
132 |
+
|
133 |
+
class SoftAttn(nn.Module):
|
134 |
+
"""Soft Attention (Sec. 3.1.I)
|
135 |
+
|
136 |
+
Aim: Spatial Attention + Channel Attention
|
137 |
+
|
138 |
+
Output: attention maps with shape identical to input.
|
139 |
+
"""
|
140 |
+
|
141 |
+
def __init__(self, in_channels):
|
142 |
+
super(SoftAttn, self).__init__()
|
143 |
+
self.spatial_attn = SpatialAttn()
|
144 |
+
self.channel_attn = ChannelAttn(in_channels)
|
145 |
+
self.conv = ConvBlock(in_channels, in_channels, 1)
|
146 |
+
|
147 |
+
def forward(self, x):
|
148 |
+
y_spatial = self.spatial_attn(x)
|
149 |
+
y_channel = self.channel_attn(x)
|
150 |
+
y = y_spatial * y_channel
|
151 |
+
y = torch.sigmoid(self.conv(y))
|
152 |
+
return y
|
153 |
+
|
154 |
+
|
155 |
+
class HardAttn(nn.Module):
|
156 |
+
"""Hard Attention (Sec. 3.1.II)"""
|
157 |
+
|
158 |
+
def __init__(self, in_channels):
|
159 |
+
super(HardAttn, self).__init__()
|
160 |
+
self.fc = nn.Linear(in_channels, 4 * 2)
|
161 |
+
self.init_params()
|
162 |
+
|
163 |
+
def init_params(self):
|
164 |
+
self.fc.weight.data.zero_()
|
165 |
+
self.fc.bias.data.copy_(
|
166 |
+
torch.tensor(
|
167 |
+
[0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float
|
168 |
+
)
|
169 |
+
)
|
170 |
+
|
171 |
+
def forward(self, x):
|
172 |
+
# squeeze operation (global average pooling)
|
173 |
+
x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
|
174 |
+
# predict transformation parameters
|
175 |
+
theta = torch.tanh(self.fc(x))
|
176 |
+
theta = theta.view(-1, 4, 2)
|
177 |
+
return theta
|
178 |
+
|
179 |
+
|
180 |
+
class HarmAttn(nn.Module):
|
181 |
+
"""Harmonious Attention (Sec. 3.1)"""
|
182 |
+
|
183 |
+
def __init__(self, in_channels):
|
184 |
+
super(HarmAttn, self).__init__()
|
185 |
+
self.soft_attn = SoftAttn(in_channels)
|
186 |
+
self.hard_attn = HardAttn(in_channels)
|
187 |
+
|
188 |
+
def forward(self, x):
|
189 |
+
y_soft_attn = self.soft_attn(x)
|
190 |
+
theta = self.hard_attn(x)
|
191 |
+
return y_soft_attn, theta
|
192 |
+
|
193 |
+
|
194 |
+
class HACNN(nn.Module):
|
195 |
+
"""Harmonious Attention Convolutional Neural Network.
|
196 |
+
|
197 |
+
Reference:
|
198 |
+
Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.
|
199 |
+
|
200 |
+
Public keys:
|
201 |
+
- ``hacnn``: HACNN.
|
202 |
+
"""
|
203 |
+
|
204 |
+
# Args:
|
205 |
+
# num_classes (int): number of classes to predict
|
206 |
+
# nchannels (list): number of channels AFTER concatenation
|
207 |
+
# feat_dim (int): feature dimension for a single stream
|
208 |
+
# learn_region (bool): whether to learn region features (i.e. local branch)
|
209 |
+
|
210 |
+
def __init__(
|
211 |
+
self,
|
212 |
+
num_classes,
|
213 |
+
loss='softmax',
|
214 |
+
nchannels=[128, 256, 384],
|
215 |
+
feat_dim=512,
|
216 |
+
learn_region=True,
|
217 |
+
use_gpu=True,
|
218 |
+
**kwargs
|
219 |
+
):
|
220 |
+
super(HACNN, self).__init__()
|
221 |
+
self.loss = loss
|
222 |
+
self.learn_region = learn_region
|
223 |
+
self.use_gpu = use_gpu
|
224 |
+
|
225 |
+
self.conv = ConvBlock(3, 32, 3, s=2, p=1)
|
226 |
+
|
227 |
+
# Construct Inception + HarmAttn blocks
|
228 |
+
# ============== Block 1 ==============
|
229 |
+
self.inception1 = nn.Sequential(
|
230 |
+
InceptionA(32, nchannels[0]),
|
231 |
+
InceptionB(nchannels[0], nchannels[0]),
|
232 |
+
)
|
233 |
+
self.ha1 = HarmAttn(nchannels[0])
|
234 |
+
|
235 |
+
# ============== Block 2 ==============
|
236 |
+
self.inception2 = nn.Sequential(
|
237 |
+
InceptionA(nchannels[0], nchannels[1]),
|
238 |
+
InceptionB(nchannels[1], nchannels[1]),
|
239 |
+
)
|
240 |
+
self.ha2 = HarmAttn(nchannels[1])
|
241 |
+
|
242 |
+
# ============== Block 3 ==============
|
243 |
+
self.inception3 = nn.Sequential(
|
244 |
+
InceptionA(nchannels[1], nchannels[2]),
|
245 |
+
InceptionB(nchannels[2], nchannels[2]),
|
246 |
+
)
|
247 |
+
self.ha3 = HarmAttn(nchannels[2])
|
248 |
+
|
249 |
+
self.fc_global = nn.Sequential(
|
250 |
+
nn.Linear(nchannels[2], feat_dim),
|
251 |
+
nn.BatchNorm1d(feat_dim),
|
252 |
+
nn.ReLU(),
|
253 |
+
)
|
254 |
+
self.classifier_global = nn.Linear(feat_dim, num_classes)
|
255 |
+
|
256 |
+
if self.learn_region:
|
257 |
+
self.init_scale_factors()
|
258 |
+
self.local_conv1 = InceptionB(32, nchannels[0])
|
259 |
+
self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
|
260 |
+
self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
|
261 |
+
self.fc_local = nn.Sequential(
|
262 |
+
nn.Linear(nchannels[2] * 4, feat_dim),
|
263 |
+
nn.BatchNorm1d(feat_dim),
|
264 |
+
nn.ReLU(),
|
265 |
+
)
|
266 |
+
self.classifier_local = nn.Linear(feat_dim, num_classes)
|
267 |
+
self.feat_dim = feat_dim * 2
|
268 |
+
else:
|
269 |
+
self.feat_dim = feat_dim
|
270 |
+
|
271 |
+
def init_scale_factors(self):
|
272 |
+
# initialize scale factors (s_w, s_h) for four regions
|
273 |
+
self.scale_factors = []
|
274 |
+
self.scale_factors.append(
|
275 |
+
torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
|
276 |
+
)
|
277 |
+
self.scale_factors.append(
|
278 |
+
torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
|
279 |
+
)
|
280 |
+
self.scale_factors.append(
|
281 |
+
torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
|
282 |
+
)
|
283 |
+
self.scale_factors.append(
|
284 |
+
torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
|
285 |
+
)
|
286 |
+
|
287 |
+
def stn(self, x, theta):
|
288 |
+
"""Performs spatial transform
|
289 |
+
|
290 |
+
x: (batch, channel, height, width)
|
291 |
+
theta: (batch, 2, 3)
|
292 |
+
"""
|
293 |
+
grid = F.affine_grid(theta, x.size())
|
294 |
+
x = F.grid_sample(x, grid)
|
295 |
+
return x
|
296 |
+
|
297 |
+
def transform_theta(self, theta_i, region_idx):
|
298 |
+
"""Transforms theta to include (s_w, s_h), resulting in (batch, 2, 3)"""
|
299 |
+
scale_factors = self.scale_factors[region_idx]
|
300 |
+
theta = torch.zeros(theta_i.size(0), 2, 3)
|
301 |
+
theta[:, :, :2] = scale_factors
|
302 |
+
theta[:, :, -1] = theta_i
|
303 |
+
if self.use_gpu:
|
304 |
+
theta = theta.cuda()
|
305 |
+
return theta
|
306 |
+
|
307 |
+
def forward(self, x):
|
308 |
+
assert x.size(2) == 160 and x.size(3) == 64, \
|
309 |
+
'Input size does not match, expected (160, 64) but got ({}, {})'.format(x.size(2), x.size(3))
|
310 |
+
x = self.conv(x)
|
311 |
+
|
312 |
+
# ============== Block 1 ==============
|
313 |
+
# global branch
|
314 |
+
x1 = self.inception1(x)
|
315 |
+
x1_attn, x1_theta = self.ha1(x1)
|
316 |
+
x1_out = x1 * x1_attn
|
317 |
+
# local branch
|
318 |
+
if self.learn_region:
|
319 |
+
x1_local_list = []
|
320 |
+
for region_idx in range(4):
|
321 |
+
x1_theta_i = x1_theta[:, region_idx, :]
|
322 |
+
x1_theta_i = self.transform_theta(x1_theta_i, region_idx)
|
323 |
+
x1_trans_i = self.stn(x, x1_theta_i)
|
324 |
+
x1_trans_i = F.upsample(
|
325 |
+
x1_trans_i, (24, 28), mode='bilinear', align_corners=True
|
326 |
+
)
|
327 |
+
x1_local_i = self.local_conv1(x1_trans_i)
|
328 |
+
x1_local_list.append(x1_local_i)
|
329 |
+
|
330 |
+
# ============== Block 2 ==============
|
331 |
+
# Block 2
|
332 |
+
# global branch
|
333 |
+
x2 = self.inception2(x1_out)
|
334 |
+
x2_attn, x2_theta = self.ha2(x2)
|
335 |
+
x2_out = x2 * x2_attn
|
336 |
+
# local branch
|
337 |
+
if self.learn_region:
|
338 |
+
x2_local_list = []
|
339 |
+
for region_idx in range(4):
|
340 |
+
x2_theta_i = x2_theta[:, region_idx, :]
|
341 |
+
x2_theta_i = self.transform_theta(x2_theta_i, region_idx)
|
342 |
+
x2_trans_i = self.stn(x1_out, x2_theta_i)
|
343 |
+
x2_trans_i = F.upsample(
|
344 |
+
x2_trans_i, (12, 14), mode='bilinear', align_corners=True
|
345 |
+
)
|
346 |
+
x2_local_i = x2_trans_i + x1_local_list[region_idx]
|
347 |
+
x2_local_i = self.local_conv2(x2_local_i)
|
348 |
+
x2_local_list.append(x2_local_i)
|
349 |
+
|
350 |
+
# ============== Block 3 ==============
|
351 |
+
# Block 3
|
352 |
+
# global branch
|
353 |
+
x3 = self.inception3(x2_out)
|
354 |
+
x3_attn, x3_theta = self.ha3(x3)
|
355 |
+
x3_out = x3 * x3_attn
|
356 |
+
# local branch
|
357 |
+
if self.learn_region:
|
358 |
+
x3_local_list = []
|
359 |
+
for region_idx in range(4):
|
360 |
+
x3_theta_i = x3_theta[:, region_idx, :]
|
361 |
+
x3_theta_i = self.transform_theta(x3_theta_i, region_idx)
|
362 |
+
x3_trans_i = self.stn(x2_out, x3_theta_i)
|
363 |
+
x3_trans_i = F.upsample(
|
364 |
+
x3_trans_i, (6, 7), mode='bilinear', align_corners=True
|
365 |
+
)
|
366 |
+
x3_local_i = x3_trans_i + x2_local_list[region_idx]
|
367 |
+
x3_local_i = self.local_conv3(x3_local_i)
|
368 |
+
x3_local_list.append(x3_local_i)
|
369 |
+
|
370 |
+
# ============== Feature generation ==============
|
371 |
+
# global branch
|
372 |
+
x_global = F.avg_pool2d(x3_out,
|
373 |
+
x3_out.size()[2:]
|
374 |
+
).view(x3_out.size(0), x3_out.size(1))
|
375 |
+
x_global = self.fc_global(x_global)
|
376 |
+
# local branch
|
377 |
+
if self.learn_region:
|
378 |
+
x_local_list = []
|
379 |
+
for region_idx in range(4):
|
380 |
+
x_local_i = x3_local_list[region_idx]
|
381 |
+
x_local_i = F.avg_pool2d(x_local_i,
|
382 |
+
x_local_i.size()[2:]
|
383 |
+
).view(x_local_i.size(0), -1)
|
384 |
+
x_local_list.append(x_local_i)
|
385 |
+
x_local = torch.cat(x_local_list, 1)
|
386 |
+
x_local = self.fc_local(x_local)
|
387 |
+
|
388 |
+
if not self.training:
|
389 |
+
# l2 normalization before concatenation
|
390 |
+
if self.learn_region:
|
391 |
+
x_global = x_global / x_global.norm(p=2, dim=1, keepdim=True)
|
392 |
+
x_local = x_local / x_local.norm(p=2, dim=1, keepdim=True)
|
393 |
+
return torch.cat([x_global, x_local], 1)
|
394 |
+
else:
|
395 |
+
return x_global
|
396 |
+
|
397 |
+
prelogits_global = self.classifier_global(x_global)
|
398 |
+
if self.learn_region:
|
399 |
+
prelogits_local = self.classifier_local(x_local)
|
400 |
+
|
401 |
+
if self.loss == 'softmax':
|
402 |
+
if self.learn_region:
|
403 |
+
return (prelogits_global, prelogits_local)
|
404 |
+
else:
|
405 |
+
return prelogits_global
|
406 |
+
|
407 |
+
elif self.loss == 'triplet':
|
408 |
+
if self.learn_region:
|
409 |
+
return (prelogits_global, prelogits_local), (x_global, x_local)
|
410 |
+
else:
|
411 |
+
return prelogits_global, x_global
|
412 |
+
|
413 |
+
else:
|
414 |
+
raise KeyError("Unsupported loss: {}".format(self.loss))
|
trackers/strongsort/deep/models/inceptionresnetv2.py
ADDED
@@ -0,0 +1,361 @@
"""
Code imported from https://github.com/Cadene/pretrained-models.pytorch
"""
from __future__ import division, absolute_import
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

__all__ = ['inceptionresnetv2']

pretrained_settings = {
    'inceptionresnetv2': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000
        },
        'imagenet+background': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1001
        }
    }
}


class BasicConv2d(nn.Module):

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride,
            padding=padding, bias=False
        )  # verify bias false
        self.bn = nn.BatchNorm2d(
            out_planes,
            eps=0.001,  # value found in tensorflow
            momentum=0.1,  # default pytorch value
            affine=True
        )
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Mixed_5b(nn.Module):

    def __init__(self):
        super(Mixed_5b, self).__init__()
        self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
        self.branch1 = nn.Sequential(
            BasicConv2d(192, 48, kernel_size=1, stride=1),
            BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(192, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
        )
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(192, 64, kernel_size=1, stride=1)
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Block35(nn.Module):

    def __init__(self, scale=1.0):
        super(Block35, self).__init__()
        self.scale = scale
        self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
        self.branch1 = nn.Sequential(
            BasicConv2d(320, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(320, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
            BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
        )
        self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        out = self.relu(out)
        return out


class Mixed_6a(nn.Module):

    def __init__(self):
        super(Mixed_6a, self).__init__()
        self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
        self.branch1 = nn.Sequential(
            BasicConv2d(320, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2)
        )
        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        return out


class Block17(nn.Module):

    def __init__(self, scale=1.0):
        super(Block17, self).__init__()
        self.scale = scale
        self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
        self.branch1 = nn.Sequential(
            BasicConv2d(1088, 128, kernel_size=1, stride=1),
            BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0))
        )
        self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        out = self.relu(out)
        return out


class Mixed_7a(nn.Module):

    def __init__(self):
        super(Mixed_7a, self).__init__()
        self.branch0 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2)
        )
        self.branch1 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=2)
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
            BasicConv2d(288, 320, kernel_size=3, stride=2)
        )
        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Block8(nn.Module):

    def __init__(self, scale=1.0, noReLU=False):
        super(Block8, self).__init__()
        self.scale = scale
        self.noReLU = noReLU
        self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
        self.branch1 = nn.Sequential(
            BasicConv2d(2080, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
        )
        self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
        if not self.noReLU:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        if not self.noReLU:
            out = self.relu(out)
        return out


# ----------------
# Model Definition
# ----------------
class InceptionResNetV2(nn.Module):
    """Inception-ResNet-V2.

    Reference:
        Szegedy et al. Inception-v4, Inception-ResNet and the Impact of Residual
        Connections on Learning. AAAI 2017.

    Public keys:
        - ``inceptionresnetv2``: Inception-ResNet-V2.
    """

    def __init__(self, num_classes, loss='softmax', **kwargs):
        super(InceptionResNetV2, self).__init__()
        self.loss = loss

        # Modules
        self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
        self.maxpool_5a = nn.MaxPool2d(3, stride=2)
        self.mixed_5b = Mixed_5b()
        self.repeat = nn.Sequential(
            Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17),
            Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17),
            Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17),
            Block35(scale=0.17)
        )
        self.mixed_6a = Mixed_6a()
        self.repeat_1 = nn.Sequential(
            Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
            Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
            Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
            Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
            Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
            Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
            Block17(scale=0.10), Block17(scale=0.10)
        )
        self.mixed_7a = Mixed_7a()
        self.repeat_2 = nn.Sequential(
            Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20),
            Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20),
            Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20)
        )
        self.block8 = Block8(noReLU=True)
        self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(1536, num_classes)

    def load_imagenet_weights(self):
        settings = pretrained_settings['inceptionresnetv2']['imagenet']
        pretrain_dict = model_zoo.load_url(settings['url'])
        model_dict = self.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        self.load_state_dict(model_dict)

    def featuremaps(self, x):
        x = self.conv2d_1a(x)
        x = self.conv2d_2a(x)
        x = self.conv2d_2b(x)
        x = self.maxpool_3a(x)
        x = self.conv2d_3b(x)
        x = self.conv2d_4a(x)
        x = self.maxpool_5a(x)
        x = self.mixed_5b(x)
        x = self.repeat(x)
        x = self.mixed_6a(x)
        x = self.repeat_1(x)
        x = self.mixed_7a(x)
        x = self.repeat_2(x)
        x = self.block8(x)
        x = self.conv2d_7b(x)
        return x

    def forward(self, x):
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))


def inceptionresnetv2(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = InceptionResNetV2(num_classes=num_classes, loss=loss, **kwargs)
    if pretrained:
        model.load_imagenet_weights()
    return model
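The load_imagenet_weights / init_pretrained_weights helpers in these model files all use the same partial-loading pattern: only parameters whose name and shape match the current model are copied, so the freshly initialised ReID classifier head is left untouched. A condensed sketch of that idea (illustrative, not additional code from this repository):

def load_matching_weights(model, pretrain_dict):
    """Copy only the entries that exist in the model with the same shape."""
    model_dict = model.state_dict()
    matched = {k: v for k, v in pretrain_dict.items()
               if k in model_dict and model_dict[k].size() == v.size()}
    model_dict.update(matched)      # unmatched layers (e.g. the classifier) keep their init
    model.load_state_dict(model_dict)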
trackers/strongsort/deep/models/inceptionv4.py
ADDED
@@ -0,0 +1,381 @@
from __future__ import division, absolute_import
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

__all__ = ['inceptionv4']
"""
Code imported from https://github.com/Cadene/pretrained-models.pytorch
"""

pretrained_settings = {
    'inceptionv4': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000
        },
        'imagenet+background': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1001
        }
    }
}


class BasicConv2d(nn.Module):

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride,
            padding=padding, bias=False
        )  # verify bias false
        self.bn = nn.BatchNorm2d(
            out_planes,
            eps=0.001,  # value found in tensorflow
            momentum=0.1,  # default pytorch value
            affine=True
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Mixed_3a(nn.Module):

    def __init__(self):
        super(Mixed_3a, self).__init__()
        self.maxpool = nn.MaxPool2d(3, stride=2)
        self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)

    def forward(self, x):
        x0 = self.maxpool(x)
        x1 = self.conv(x)
        out = torch.cat((x0, x1), 1)
        return out


class Mixed_4a(nn.Module):

    def __init__(self):
        super(Mixed_4a, self).__init__()
        self.branch0 = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1)
        )
        self.branch1 = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(64, 96, kernel_size=(3, 3), stride=1)
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        return out


class Mixed_5a(nn.Module):

    def __init__(self):
        super(Mixed_5a, self).__init__()
        self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
        self.maxpool = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.conv(x)
        x1 = self.maxpool(x)
        out = torch.cat((x0, x1), 1)
        return out


class Inception_A(nn.Module):

    def __init__(self):
        super(Inception_A, self).__init__()
        self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)
        self.branch1 = nn.Sequential(
            BasicConv2d(384, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1)
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(384, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
        )
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(384, 96, kernel_size=1, stride=1)
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Reduction_A(nn.Module):

    def __init__(self):
        super(Reduction_A, self).__init__()
        self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)
        self.branch1 = nn.Sequential(
            BasicConv2d(384, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1),
            BasicConv2d(224, 256, kernel_size=3, stride=2)
        )
        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        return out


class Inception_B(nn.Module):

    def __init__(self):
        super(Inception_B, self).__init__()
        self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)
        self.branch1 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0))
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3))
        )
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(1024, 128, kernel_size=1, stride=1)
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Reduction_B(nn.Module):

    def __init__(self):
        super(Reduction_B, self).__init__()
        self.branch0 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=3, stride=2)
        )
        self.branch1 = nn.Sequential(
            BasicConv2d(1024, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(320, 320, kernel_size=3, stride=2)
        )
        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        return out


class Inception_C(nn.Module):

    def __init__(self):
        super(Inception_C, self).__init__()
        self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)

        self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
        self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))

        self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
        self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(1536, 256, kernel_size=1, stride=1)
        )

    def forward(self, x):
        x0 = self.branch0(x)

        x1_0 = self.branch1_0(x)
        x1_1a = self.branch1_1a(x1_0)
        x1_1b = self.branch1_1b(x1_0)
        x1 = torch.cat((x1_1a, x1_1b), 1)

        x2_0 = self.branch2_0(x)
        x2_1 = self.branch2_1(x2_0)
        x2_2 = self.branch2_2(x2_1)
        x2_3a = self.branch2_3a(x2_2)
        x2_3b = self.branch2_3b(x2_2)
        x2 = torch.cat((x2_3a, x2_3b), 1)

        x3 = self.branch3(x)

        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class InceptionV4(nn.Module):
    """Inception-v4.

    Reference:
        Szegedy et al. Inception-v4, Inception-ResNet and the Impact of Residual
        Connections on Learning. AAAI 2017.

    Public keys:
        - ``inceptionv4``: InceptionV4.
    """

    def __init__(self, num_classes, loss, **kwargs):
        super(InceptionV4, self).__init__()
        self.loss = loss

        self.features = nn.Sequential(
            BasicConv2d(3, 32, kernel_size=3, stride=2),
            BasicConv2d(32, 32, kernel_size=3, stride=1),
            BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
            Mixed_3a(),
            Mixed_4a(),
            Mixed_5a(),
            Inception_A(),
            Inception_A(),
            Inception_A(),
            Inception_A(),
            Reduction_A(),  # Mixed_6a
            Inception_B(),
            Inception_B(),
            Inception_B(),
            Inception_B(),
            Inception_B(),
            Inception_B(),
            Inception_B(),
            Reduction_B(),  # Mixed_7a
            Inception_C(),
            Inception_C(),
            Inception_C()
        )
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(1536, num_classes)

    def forward(self, x):
        f = self.features(x)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def inceptionv4(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = InceptionV4(num_classes, loss, **kwargs)
    if pretrained:
        model_url = pretrained_settings['inceptionv4']['imagenet']['url']
        init_pretrained_weights(model, model_url)
    return model
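In eval mode InceptionV4 (like the other backbones in this directory) returns a 1536-dimensional pooled feature vector, and a tracker compares such embeddings between new detections and existing tracks. A small illustrative sketch of a cosine-distance comparison; this is an assumption about typical usage, not code taken from this repository:

import torch
import torch.nn.functional as F

def cosine_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine distance between two batches of embeddings."""
    a = F.normalize(a, dim=1)
    b = F.normalize(b, dim=1)
    return 1.0 - a @ b.t()          # 0 = same direction, 2 = opposite

# e.g. dists = cosine_distance(track_feats, detection_feats)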
trackers/strongsort/deep/models/mlfn.py
ADDED
@@ -0,0 +1,269 @@
from __future__ import division, absolute_import
import torch
import torch.utils.model_zoo as model_zoo
from torch import nn
from torch.nn import functional as F

__all__ = ['mlfn']

model_urls = {
    # training epoch = 5, top1 = 51.6
    'imagenet':
    'https://mega.nz/#!YHxAhaxC!yu9E6zWl0x5zscSouTdbZu8gdFFytDdl-RAdD2DEfpk',
}


class MLFNBlock(nn.Module):

    def __init__(self, in_channels, out_channels, stride, fsm_channels, groups=32):
        super(MLFNBlock, self).__init__()
        self.groups = groups
        mid_channels = out_channels // 2

        # Factor Modules
        self.fm_conv1 = nn.Conv2d(in_channels, mid_channels, 1, bias=False)
        self.fm_bn1 = nn.BatchNorm2d(mid_channels)
        self.fm_conv2 = nn.Conv2d(
            mid_channels, mid_channels, 3, stride=stride, padding=1,
            bias=False, groups=self.groups
        )
        self.fm_bn2 = nn.BatchNorm2d(mid_channels)
        self.fm_conv3 = nn.Conv2d(mid_channels, out_channels, 1, bias=False)
        self.fm_bn3 = nn.BatchNorm2d(out_channels)

        # Factor Selection Module
        self.fsm = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, fsm_channels[0], 1),
            nn.BatchNorm2d(fsm_channels[0]),
            nn.ReLU(inplace=True),
            nn.Conv2d(fsm_channels[0], fsm_channels[1], 1),
            nn.BatchNorm2d(fsm_channels[1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(fsm_channels[1], self.groups, 1),
            nn.BatchNorm2d(self.groups),
            nn.Sigmoid(),
        )

        self.downsample = None
        if in_channels != out_channels or stride > 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        residual = x
        s = self.fsm(x)

        # reduce dimension
        x = self.fm_conv1(x)
        x = self.fm_bn1(x)
        x = F.relu(x, inplace=True)

        # group convolution
        x = self.fm_conv2(x)
        x = self.fm_bn2(x)
        x = F.relu(x, inplace=True)

        # factor selection
        b, c = x.size(0), x.size(1)
        n = c // self.groups
        ss = s.repeat(1, n, 1, 1)  # from (b, g, 1, 1) to (b, g*n=c, 1, 1)
        ss = ss.view(b, n, self.groups, 1, 1)
        ss = ss.permute(0, 2, 1, 3, 4).contiguous()
        ss = ss.view(b, c, 1, 1)
        x = ss * x

        # recover dimension
        x = self.fm_conv3(x)
        x = self.fm_bn3(x)
        x = F.relu(x, inplace=True)

        if self.downsample is not None:
            residual = self.downsample(residual)

        return F.relu(residual + x, inplace=True), s


class MLFN(nn.Module):
    """Multi-Level Factorisation Net.

    Reference:
        Chang et al. Multi-Level Factorisation Net for
        Person Re-Identification. CVPR 2018.

    Public keys:
        - ``mlfn``: MLFN (Multi-Level Factorisation Net).
    """

    def __init__(
        self,
        num_classes,
        loss='softmax',
        groups=32,
        channels=[64, 256, 512, 1024, 2048],
        embed_dim=1024,
        **kwargs
    ):
        super(MLFN, self).__init__()
        self.loss = loss
        self.groups = groups

        # first convolutional layer
        self.conv1 = nn.Conv2d(3, channels[0], 7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(channels[0])
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

        # main body
        self.feature = nn.ModuleList(
            [
                # layer 1-3
                MLFNBlock(channels[0], channels[1], 1, [128, 64], self.groups),
                MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups),
                MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups),
                # layer 4-7
                MLFNBlock(channels[1], channels[2], 2, [256, 128], self.groups),
                MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups),
                MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups),
                MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups),
                # layer 8-13
                MLFNBlock(channels[2], channels[3], 2, [512, 128], self.groups),
                MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
                MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
                MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
                MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
                MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
                # layer 14-16
                MLFNBlock(channels[3], channels[4], 2, [512, 128], self.groups),
                MLFNBlock(channels[4], channels[4], 1, [512, 128], self.groups),
                MLFNBlock(channels[4], channels[4], 1, [512, 128], self.groups),
            ]
        )
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)

        # projection functions
        self.fc_x = nn.Sequential(
            nn.Conv2d(channels[4], embed_dim, 1, bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
        )
        self.fc_s = nn.Sequential(
            nn.Conv2d(self.groups * 16, embed_dim, 1, bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
        )

        self.classifier = nn.Linear(embed_dim, num_classes)

        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)
        x = self.maxpool(x)

        s_hat = []
        for block in self.feature:
            x, s = block(x)
            s_hat.append(s)
        s_hat = torch.cat(s_hat, 1)

        x = self.global_avgpool(x)
        x = self.fc_x(x)
        s_hat = self.fc_s(s_hat)

        v = (x + s_hat) * 0.5
        v = v.view(v.size(0), -1)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def mlfn(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = MLFN(num_classes, loss, **kwargs)
    if pretrained:
        # init_pretrained_weights(model, model_urls['imagenet'])
        import warnings
        warnings.warn(
            'The imagenet pretrained weights need to be manually downloaded from {}'
            .format(model_urls['imagenet'])
        )
    return model
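Because the MLFN ImageNet weights are hosted on mega.nz, the factory above only emits a warning instead of downloading them. A hedged sketch of loading a manually downloaded checkpoint into the model (the file name is a placeholder, not something shipped with this commit):

import torch

model = mlfn(num_classes=1000, pretrained=False)
# 'mlfn_imagenet.pth' stands in for the manually downloaded checkpoint.
state = torch.load('mlfn_imagenet.pth', map_location='cpu')
state = {k: v for k, v in state.items()
         if k in model.state_dict() and model.state_dict()[k].size() == v.size()}
model.load_state_dict(state, strict=False)   # strict=False tolerates the unmatched head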
trackers/strongsort/deep/models/mobilenetv2.py
ADDED
@@ -0,0 +1,274 @@
from __future__ import division, absolute_import
import torch.utils.model_zoo as model_zoo
from torch import nn
from torch.nn import functional as F

__all__ = ['mobilenetv2_x1_0', 'mobilenetv2_x1_4']

model_urls = {
    # 1.0: top-1 71.3
    'mobilenetv2_x1_0':
    'https://mega.nz/#!NKp2wAIA!1NH1pbNzY_M2hVk_hdsxNM1NUOWvvGPHhaNr-fASF6c',
    # 1.4: top-1 73.9
    'mobilenetv2_x1_4':
    'https://mega.nz/#!RGhgEIwS!xN2s2ZdyqI6vQ3EwgmRXLEW3khr9tpXg96G9SUJugGk',
}


class ConvBlock(nn.Module):
    """Basic convolutional block.

    convolution (bias discarded) + batch normalization + relu6.

    Args:
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): kernel size.
        s (int or tuple): stride.
        p (int or tuple): padding.
        g (int): number of blocked connections from input channels
            to output channels (default: 1).
    """

    def __init__(self, in_c, out_c, k, s=1, p=0, g=1):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p, bias=False, groups=g)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        return F.relu6(self.bn(self.conv(x)))


class Bottleneck(nn.Module):

    def __init__(self, in_channels, out_channels, expansion_factor, stride=1):
        super(Bottleneck, self).__init__()
        mid_channels = in_channels * expansion_factor
        self.use_residual = stride == 1 and in_channels == out_channels
        self.conv1 = ConvBlock(in_channels, mid_channels, 1)
        self.dwconv2 = ConvBlock(mid_channels, mid_channels, 3, stride, 1, g=mid_channels)
        self.conv3 = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        m = self.conv1(x)
        m = self.dwconv2(m)
        m = self.conv3(m)
        if self.use_residual:
            return x + m
        else:
            return m


class MobileNetV2(nn.Module):
    """MobileNetV2.

    Reference:
        Sandler et al. MobileNetV2: Inverted Residuals and
        Linear Bottlenecks. CVPR 2018.

    Public keys:
        - ``mobilenetv2_x1_0``: MobileNetV2 x1.0.
        - ``mobilenetv2_x1_4``: MobileNetV2 x1.4.
    """

    def __init__(
        self,
        num_classes,
        width_mult=1,
        loss='softmax',
        fc_dims=None,
        dropout_p=None,
        **kwargs
    ):
        super(MobileNetV2, self).__init__()
        self.loss = loss
        self.in_channels = int(32 * width_mult)
        self.feature_dim = int(1280 * width_mult) if width_mult > 1 else 1280

        # construct layers
        self.conv1 = ConvBlock(3, self.in_channels, 3, s=2, p=1)
        self.conv2 = self._make_layer(Bottleneck, 1, int(16 * width_mult), 1, 1)
        self.conv3 = self._make_layer(Bottleneck, 6, int(24 * width_mult), 2, 2)
        self.conv4 = self._make_layer(Bottleneck, 6, int(32 * width_mult), 3, 2)
        self.conv5 = self._make_layer(Bottleneck, 6, int(64 * width_mult), 4, 2)
        self.conv6 = self._make_layer(Bottleneck, 6, int(96 * width_mult), 3, 1)
        self.conv7 = self._make_layer(Bottleneck, 6, int(160 * width_mult), 3, 2)
        self.conv8 = self._make_layer(Bottleneck, 6, int(320 * width_mult), 1, 1)
        self.conv9 = ConvBlock(self.in_channels, self.feature_dim, 1)

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = self._construct_fc_layer(fc_dims, self.feature_dim, dropout_p)
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        self._init_params()

    def _make_layer(self, block, t, c, n, s):
        # t: expansion factor
        # c: output channels
        # n: number of blocks
        # s: stride for first layer
        layers = []
        layers.append(block(self.in_channels, c, t, s))
        self.in_channels = c
        for i in range(1, n):
            layers.append(block(self.in_channels, c, t))
        return nn.Sequential(*layers)

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """Constructs fully connected layer.

        Args:
            fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
            input_dim (int): input dimension
            dropout_p (float): dropout probability, if None, dropout is unused
        """
        if fc_dims is None:
            self.feature_dim = input_dim
            return None

        assert isinstance(
            fc_dims, (list, tuple)
        ), 'fc_dims must be either list or tuple, but got {}'.format(type(fc_dims))

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def _init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.conv8(x)
        x = self.conv9(x)
        return x

    def forward(self, x):
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if self.fc is not None:
            v = self.fc(v)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def mobilenetv2_x1_0(num_classes, loss, pretrained=True, **kwargs):
    model = MobileNetV2(
        num_classes, loss=loss, width_mult=1, fc_dims=None, dropout_p=None, **kwargs
    )
    if pretrained:
        # init_pretrained_weights(model, model_urls['mobilenetv2_x1_0'])
        import warnings
        warnings.warn(
            'The imagenet pretrained weights need to be manually downloaded from {}'
            .format(model_urls['mobilenetv2_x1_0'])
        )
    return model


def mobilenetv2_x1_4(num_classes, loss, pretrained=True, **kwargs):
    model = MobileNetV2(
        num_classes, loss=loss, width_mult=1.4, fc_dims=None, dropout_p=None, **kwargs
    )
    if pretrained:
        # init_pretrained_weights(model, model_urls['mobilenetv2_x1_4'])
        import warnings
        warnings.warn(
            'The imagenet pretrained weights need to be manually downloaded from {}'
            .format(model_urls['mobilenetv2_x1_4'])
        )
    return model
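The two MobileNetV2 factories differ only in width_mult, which scales every stage's channel count; the final feature dimension stays at 1280 for x1.0 and grows for x1.4. A quick arithmetic check of the expressions used above (illustrative only):

# Channel scaling used by the two factories above.
for width_mult in (1, 1.4):
    stem = int(32 * width_mult)                                   # first conv channels
    feat = int(1280 * width_mult) if width_mult > 1 else 1280     # embedding size
    print(width_mult, stem, feat)                                 # 1 -> 32/1280, 1.4 -> 44/1792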
trackers/strongsort/deep/models/mudeep.py
ADDED
@@ -0,0 +1,206 @@
from __future__ import division, absolute_import
import torch
from torch import nn
from torch.nn import functional as F

__all__ = ['MuDeep']


class ConvBlock(nn.Module):
    """Basic convolutional block.

    convolution + batch normalization + relu.

    Args:
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): kernel size.
        s (int or tuple): stride.
        p (int or tuple): padding.
    """

    def __init__(self, in_c, out_c, k, s, p):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))


class ConvLayers(nn.Module):
    """Preprocessing layers."""

    def __init__(self):
        super(ConvLayers, self).__init__()
        self.conv1 = ConvBlock(3, 48, k=3, s=1, p=1)
        self.conv2 = ConvBlock(48, 96, k=3, s=1, p=1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.maxpool(x)
        return x


class MultiScaleA(nn.Module):
    """Multi-scale stream layer A (Sec.3.1)"""

    def __init__(self):
        super(MultiScaleA, self).__init__()
        self.stream1 = nn.Sequential(
            ConvBlock(96, 96, k=1, s=1, p=0),
            ConvBlock(96, 24, k=3, s=1, p=1),
        )
        self.stream2 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(96, 24, k=1, s=1, p=0),
        )
        self.stream3 = ConvBlock(96, 24, k=1, s=1, p=0)
        self.stream4 = nn.Sequential(
            ConvBlock(96, 16, k=1, s=1, p=0),
            ConvBlock(16, 24, k=3, s=1, p=1),
            ConvBlock(24, 24, k=3, s=1, p=1),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        s4 = self.stream4(x)
        y = torch.cat([s1, s2, s3, s4], dim=1)
        return y


class Reduction(nn.Module):
    """Reduction layer (Sec.3.1)"""

    def __init__(self):
        super(Reduction, self).__init__()
        self.stream1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.stream2 = ConvBlock(96, 96, k=3, s=2, p=1)
        self.stream3 = nn.Sequential(
            ConvBlock(96, 48, k=1, s=1, p=0),
            ConvBlock(48, 56, k=3, s=1, p=1),
            ConvBlock(56, 64, k=3, s=2, p=1),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        y = torch.cat([s1, s2, s3], dim=1)
        return y


class MultiScaleB(nn.Module):
    """Multi-scale stream layer B (Sec.3.1)"""

    def __init__(self):
        super(MultiScaleB, self).__init__()
        self.stream1 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(256, 256, k=1, s=1, p=0),
        )
        self.stream2 = nn.Sequential(
            ConvBlock(256, 64, k=1, s=1, p=0),
            ConvBlock(64, 128, k=(1, 3), s=1, p=(0, 1)),
            ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)),
        )
        self.stream3 = ConvBlock(256, 256, k=1, s=1, p=0)
        self.stream4 = nn.Sequential(
            ConvBlock(256, 64, k=1, s=1, p=0),
            ConvBlock(64, 64, k=(1, 3), s=1, p=(0, 1)),
            ConvBlock(64, 128, k=(3, 1), s=1, p=(1, 0)),
            ConvBlock(128, 128, k=(1, 3), s=1, p=(0, 1)),
            ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        s4 = self.stream4(x)
        return s1, s2, s3, s4


class Fusion(nn.Module):
    """Saliency-based learning fusion layer (Sec.3.2)"""

    def __init__(self):
        super(Fusion, self).__init__()
        self.a1 = nn.Parameter(torch.rand(1, 256, 1, 1))
        self.a2 = nn.Parameter(torch.rand(1, 256, 1, 1))
        self.a3 = nn.Parameter(torch.rand(1, 256, 1, 1))
        self.a4 = nn.Parameter(torch.rand(1, 256, 1, 1))

        # We add an average pooling layer to reduce the spatial dimension
        # of feature maps, which differs from the original paper.
        self.avgpool = nn.AvgPool2d(kernel_size=4, stride=4, padding=0)

    def forward(self, x1, x2, x3, x4):
        s1 = self.a1.expand_as(x1) * x1
        s2 = self.a2.expand_as(x2) * x2
        s3 = self.a3.expand_as(x3) * x3
        s4 = self.a4.expand_as(x4) * x4
        y = self.avgpool(s1 + s2 + s3 + s4)
        return y


class MuDeep(nn.Module):
    """Multiscale deep neural network.

    Reference:
        Qian et al. Multi-scale Deep Learning Architectures
        for Person Re-identification. ICCV 2017.

    Public keys:
        - ``mudeep``: Multiscale deep neural network.
    """

    def __init__(self, num_classes, loss='softmax', **kwargs):
        super(MuDeep, self).__init__()
        self.loss = loss

        self.block1 = ConvLayers()
        self.block2 = MultiScaleA()
        self.block3 = Reduction()
        self.block4 = MultiScaleB()
        self.block5 = Fusion()

        # Due to this fully connected layer, input image has to be fixed
        # in shape, i.e. (3, 256, 128), such that the last convolutional feature
        # maps are of shape (256, 16, 8). If input shape is changed,
        # the input dimension of this layer has to be changed accordingly.
        self.fc = nn.Sequential(
            nn.Linear(256 * 16 * 8, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(4096, num_classes)
        self.feat_dim = 4096

    def featuremaps(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(*x)
        return x

    def forward(self, x):
        x = self.featuremaps(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        y = self.classifier(x)

        if not self.training:
            return x

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, x
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))
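As the comment in MuDeep.__init__ notes, the flattened fully connected layer hard-codes an input resolution of 256 x 128; any other crop size fails at the first Linear layer. A quick shape check, illustrative only and not part of the diff:

import torch

model = MuDeep(num_classes=100)
model.eval()
with torch.no_grad():
    feats = model(torch.randn(2, 3, 256, 128))   # works: self.fc expects 256*16*8 inputs
print(feats.shape)                               # (2, 4096)
# A 224x112 input, for example, would raise a size-mismatch error inside self.fc.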
trackers/strongsort/deep/models/nasnet.py
ADDED
@@ -0,0 +1,1131 @@
from __future__ import division, absolute_import
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo

__all__ = ['nasnetamobile']
"""
NASNet Mobile
Thanks to Anastasiia (https://github.com/DagnyT) for the great help, support and motivation!


------------------------------------------------------------------------------------
      Architecture        | Top-1 Acc | Top-5 Acc | Multiply-Adds | Params (M)
------------------------------------------------------------------------------------
|  NASNet-A (4 @ 1056)    |   74.08%  |   91.74%  |     564 M     |    5.3     |
------------------------------------------------------------------------------------
# References:
 - [Learning Transferable Architectures for Scalable Image Recognition]
   (https://arxiv.org/abs/1707.07012)
"""
"""
Code imported from https://github.com/Cadene/pretrained-models.pytorch
"""

pretrained_settings = {
    'nasnetamobile': {
        'imagenet': {
            # 'url': 'https://github.com/veronikayurchuk/pretrained-models.pytorch/releases/download/v1.0/nasnetmobile-7e03cead.pth.tar',
            'url':
            'http://data.lip6.fr/cadene/pretrainedmodels/nasnetamobile-7e03cead.pth',
            'input_space': 'RGB',
            'input_size': [3, 224, 224],  # resize 256
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000
        },
        # 'imagenet+background': {
        #     # 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth',
        #     'input_space': 'RGB',
        #     'input_size': [3, 224, 224],  # resize 256
        #     'input_range': [0, 1],
        #     'mean': [0.5, 0.5, 0.5],
        #     'std': [0.5, 0.5, 0.5],
        #     'num_classes': 1001
        # }
    }
}

52 |
+
class MaxPoolPad(nn.Module):
|
53 |
+
|
54 |
+
def __init__(self):
|
55 |
+
super(MaxPoolPad, self).__init__()
|
56 |
+
self.pad = nn.ZeroPad2d((1, 0, 1, 0))
|
57 |
+
self.pool = nn.MaxPool2d(3, stride=2, padding=1)
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
x = self.pad(x)
|
61 |
+
x = self.pool(x)
|
62 |
+
x = x[:, :, 1:, 1:].contiguous()
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
+
class AvgPoolPad(nn.Module):
|
67 |
+
|
68 |
+
def __init__(self, stride=2, padding=1):
|
69 |
+
super(AvgPoolPad, self).__init__()
|
70 |
+
self.pad = nn.ZeroPad2d((1, 0, 1, 0))
|
71 |
+
self.pool = nn.AvgPool2d(
|
72 |
+
3, stride=stride, padding=padding, count_include_pad=False
|
73 |
+
)
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
x = self.pad(x)
|
77 |
+
x = self.pool(x)
|
78 |
+
x = x[:, :, 1:, 1:].contiguous()
|
79 |
+
return x
|
80 |
+
|
81 |
+
|
82 |
+
class SeparableConv2d(nn.Module):
|
83 |
+
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
in_channels,
|
87 |
+
out_channels,
|
88 |
+
dw_kernel,
|
89 |
+
dw_stride,
|
90 |
+
dw_padding,
|
91 |
+
bias=False
|
92 |
+
):
|
93 |
+
super(SeparableConv2d, self).__init__()
|
94 |
+
self.depthwise_conv2d = nn.Conv2d(
|
95 |
+
in_channels,
|
96 |
+
in_channels,
|
97 |
+
dw_kernel,
|
98 |
+
stride=dw_stride,
|
99 |
+
padding=dw_padding,
|
100 |
+
bias=bias,
|
101 |
+
groups=in_channels
|
102 |
+
)
|
103 |
+
self.pointwise_conv2d = nn.Conv2d(
|
104 |
+
in_channels, out_channels, 1, stride=1, bias=bias
|
105 |
+
)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
x = self.depthwise_conv2d(x)
|
109 |
+
x = self.pointwise_conv2d(x)
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class BranchSeparables(nn.Module):
|
114 |
+
|
115 |
+
def __init__(
|
116 |
+
self,
|
117 |
+
in_channels,
|
118 |
+
out_channels,
|
119 |
+
kernel_size,
|
120 |
+
stride,
|
121 |
+
padding,
|
122 |
+
name=None,
|
123 |
+
bias=False
|
124 |
+
):
|
125 |
+
super(BranchSeparables, self).__init__()
|
126 |
+
self.relu = nn.ReLU()
|
127 |
+
self.separable_1 = SeparableConv2d(
|
128 |
+
in_channels, in_channels, kernel_size, stride, padding, bias=bias
|
129 |
+
)
|
130 |
+
self.bn_sep_1 = nn.BatchNorm2d(
|
131 |
+
in_channels, eps=0.001, momentum=0.1, affine=True
|
132 |
+
)
|
133 |
+
self.relu1 = nn.ReLU()
|
134 |
+
self.separable_2 = SeparableConv2d(
|
135 |
+
in_channels, out_channels, kernel_size, 1, padding, bias=bias
|
136 |
+
)
|
137 |
+
self.bn_sep_2 = nn.BatchNorm2d(
|
138 |
+
out_channels, eps=0.001, momentum=0.1, affine=True
|
139 |
+
)
|
140 |
+
self.name = name
|
141 |
+
|
142 |
+
def forward(self, x):
|
143 |
+
x = self.relu(x)
|
144 |
+
if self.name == 'specific':
|
145 |
+
x = nn.ZeroPad2d((1, 0, 1, 0))(x)
|
146 |
+
x = self.separable_1(x)
|
147 |
+
if self.name == 'specific':
|
148 |
+
x = x[:, :, 1:, 1:].contiguous()
|
149 |
+
|
150 |
+
x = self.bn_sep_1(x)
|
151 |
+
x = self.relu1(x)
|
152 |
+
x = self.separable_2(x)
|
153 |
+
x = self.bn_sep_2(x)
|
154 |
+
return x
|
155 |
+
|
156 |
+
|
157 |
+
class BranchSeparablesStem(nn.Module):
|
158 |
+
|
159 |
+
def __init__(
|
160 |
+
self,
|
161 |
+
in_channels,
|
162 |
+
out_channels,
|
163 |
+
kernel_size,
|
164 |
+
stride,
|
165 |
+
padding,
|
166 |
+
bias=False
|
167 |
+
):
|
168 |
+
super(BranchSeparablesStem, self).__init__()
|
169 |
+
self.relu = nn.ReLU()
|
170 |
+
self.separable_1 = SeparableConv2d(
|
171 |
+
in_channels, out_channels, kernel_size, stride, padding, bias=bias
|
172 |
+
)
|
173 |
+
self.bn_sep_1 = nn.BatchNorm2d(
|
174 |
+
out_channels, eps=0.001, momentum=0.1, affine=True
|
175 |
+
)
|
176 |
+
self.relu1 = nn.ReLU()
|
177 |
+
self.separable_2 = SeparableConv2d(
|
178 |
+
out_channels, out_channels, kernel_size, 1, padding, bias=bias
|
179 |
+
)
|
180 |
+
self.bn_sep_2 = nn.BatchNorm2d(
|
181 |
+
out_channels, eps=0.001, momentum=0.1, affine=True
|
182 |
+
)
|
183 |
+
|
184 |
+
def forward(self, x):
|
185 |
+
x = self.relu(x)
|
186 |
+
x = self.separable_1(x)
|
187 |
+
x = self.bn_sep_1(x)
|
188 |
+
x = self.relu1(x)
|
189 |
+
x = self.separable_2(x)
|
190 |
+
x = self.bn_sep_2(x)
|
191 |
+
return x
|
192 |
+
|
193 |
+
|
194 |
+
class BranchSeparablesReduction(BranchSeparables):
|
195 |
+
|
196 |
+
def __init__(
|
197 |
+
self,
|
198 |
+
in_channels,
|
199 |
+
out_channels,
|
200 |
+
kernel_size,
|
201 |
+
stride,
|
202 |
+
padding,
|
203 |
+
z_padding=1,
|
204 |
+
bias=False
|
205 |
+
):
|
206 |
+
BranchSeparables.__init__(
|
207 |
+
self, in_channels, out_channels, kernel_size, stride, padding, bias
|
208 |
+
)
|
209 |
+
self.padding = nn.ZeroPad2d((z_padding, 0, z_padding, 0))
|
210 |
+
|
211 |
+
def forward(self, x):
|
212 |
+
x = self.relu(x)
|
213 |
+
x = self.padding(x)
|
214 |
+
x = self.separable_1(x)
|
215 |
+
x = x[:, :, 1:, 1:].contiguous()
|
216 |
+
x = self.bn_sep_1(x)
|
217 |
+
x = self.relu1(x)
|
218 |
+
x = self.separable_2(x)
|
219 |
+
x = self.bn_sep_2(x)
|
220 |
+
return x
|
221 |
+
|
222 |
+
|
223 |
+
class CellStem0(nn.Module):
|
224 |
+
|
225 |
+
def __init__(self, stem_filters, num_filters=42):
|
226 |
+
super(CellStem0, self).__init__()
|
227 |
+
self.num_filters = num_filters
|
228 |
+
self.stem_filters = stem_filters
|
229 |
+
self.conv_1x1 = nn.Sequential()
|
230 |
+
self.conv_1x1.add_module('relu', nn.ReLU())
|
231 |
+
self.conv_1x1.add_module(
|
232 |
+
'conv',
|
233 |
+
nn.Conv2d(
|
234 |
+
self.stem_filters, self.num_filters, 1, stride=1, bias=False
|
235 |
+
)
|
236 |
+
)
|
237 |
+
self.conv_1x1.add_module(
|
238 |
+
'bn',
|
239 |
+
nn.BatchNorm2d(
|
240 |
+
self.num_filters, eps=0.001, momentum=0.1, affine=True
|
241 |
+
)
|
242 |
+
)
|
243 |
+
|
244 |
+
self.comb_iter_0_left = BranchSeparables(
|
245 |
+
self.num_filters, self.num_filters, 5, 2, 2
|
246 |
+
)
|
247 |
+
self.comb_iter_0_right = BranchSeparablesStem(
|
248 |
+
self.stem_filters, self.num_filters, 7, 2, 3, bias=False
|
249 |
+
)
|
250 |
+
|
251 |
+
self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
|
252 |
+
self.comb_iter_1_right = BranchSeparablesStem(
|
253 |
+
self.stem_filters, self.num_filters, 7, 2, 3, bias=False
|
254 |
+
)
|
255 |
+
|
256 |
+
self.comb_iter_2_left = nn.AvgPool2d(
|
257 |
+
3, stride=2, padding=1, count_include_pad=False
|
258 |
+
)
|
259 |
+
self.comb_iter_2_right = BranchSeparablesStem(
|
260 |
+
self.stem_filters, self.num_filters, 5, 2, 2, bias=False
|
261 |
+
)
|
262 |
+
|
263 |
+
self.comb_iter_3_right = nn.AvgPool2d(
|
264 |
+
3, stride=1, padding=1, count_include_pad=False
|
265 |
+
)
|
266 |
+
|
267 |
+
self.comb_iter_4_left = BranchSeparables(
|
268 |
+
self.num_filters, self.num_filters, 3, 1, 1, bias=False
|
269 |
+
)
|
270 |
+
self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
|
271 |
+
|
272 |
+
def forward(self, x):
|
273 |
+
x1 = self.conv_1x1(x)
|
274 |
+
|
275 |
+
x_comb_iter_0_left = self.comb_iter_0_left(x1)
|
276 |
+
x_comb_iter_0_right = self.comb_iter_0_right(x)
|
277 |
+
x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
|
278 |
+
|
279 |
+
x_comb_iter_1_left = self.comb_iter_1_left(x1)
|
280 |
+
x_comb_iter_1_right = self.comb_iter_1_right(x)
|
281 |
+
x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
|
282 |
+
|
283 |
+
x_comb_iter_2_left = self.comb_iter_2_left(x1)
|
284 |
+
x_comb_iter_2_right = self.comb_iter_2_right(x)
|
285 |
+
x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
|
286 |
+
|
287 |
+
x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
|
288 |
+
x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
|
289 |
+
|
290 |
+
x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
|
291 |
+
x_comb_iter_4_right = self.comb_iter_4_right(x1)
|
292 |
+
x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
|
293 |
+
|
294 |
+
x_out = torch.cat(
|
295 |
+
[x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1
|
296 |
+
)
|
297 |
+
return x_out
|
298 |
+
|
299 |
+
|
300 |
+
class CellStem1(nn.Module):
|
301 |
+
|
302 |
+
def __init__(self, stem_filters, num_filters):
|
303 |
+
super(CellStem1, self).__init__()
|
304 |
+
self.num_filters = num_filters
|
305 |
+
self.stem_filters = stem_filters
|
306 |
+
self.conv_1x1 = nn.Sequential()
|
307 |
+
self.conv_1x1.add_module('relu', nn.ReLU())
|
308 |
+
self.conv_1x1.add_module(
|
309 |
+
'conv',
|
310 |
+
nn.Conv2d(
|
311 |
+
2 * self.num_filters,
|
312 |
+
self.num_filters,
|
313 |
+
1,
|
314 |
+
stride=1,
|
315 |
+
bias=False
|
316 |
+
)
|
317 |
+
)
|
318 |
+
self.conv_1x1.add_module(
|
319 |
+
'bn',
|
320 |
+
nn.BatchNorm2d(
|
321 |
+
self.num_filters, eps=0.001, momentum=0.1, affine=True
|
322 |
+
)
|
323 |
+
)
|
324 |
+
|
325 |
+
self.relu = nn.ReLU()
|
326 |
+
self.path_1 = nn.Sequential()
|
327 |
+
self.path_1.add_module(
|
328 |
+
'avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)
|
329 |
+
)
|
330 |
+
self.path_1.add_module(
|
331 |
+
'conv',
|
332 |
+
nn.Conv2d(
|
333 |
+
self.stem_filters,
|
334 |
+
self.num_filters // 2,
|
335 |
+
1,
|
336 |
+
stride=1,
|
337 |
+
bias=False
|
338 |
+
)
|
339 |
+
)
|
340 |
+
self.path_2 = nn.ModuleList()
|
341 |
+
self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
|
342 |
+
self.path_2.add_module(
|
343 |
+
'avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)
|
344 |
+
)
|
345 |
+
self.path_2.add_module(
|
346 |
+
'conv',
|
347 |
+
nn.Conv2d(
|
348 |
+
self.stem_filters,
|
349 |
+
self.num_filters // 2,
|
350 |
+
1,
|
351 |
+
stride=1,
|
352 |
+
bias=False
|
353 |
+
)
|
354 |
+
)
|
355 |
+
|
356 |
+
self.final_path_bn = nn.BatchNorm2d(
|
357 |
+
self.num_filters, eps=0.001, momentum=0.1, affine=True
|
358 |
+
)
|
359 |
+
|
360 |
+
self.comb_iter_0_left = BranchSeparables(
|
361 |
+
self.num_filters,
|
362 |
+
self.num_filters,
|
363 |
+
5,
|
364 |
+
2,
|
365 |
+
2,
|
366 |
+
name='specific',
|
367 |
+
bias=False
|
368 |
+
)
|
369 |
+
self.comb_iter_0_right = BranchSeparables(
|
370 |
+
self.num_filters,
|
371 |
+
self.num_filters,
|
372 |
+
7,
|
373 |
+
2,
|
374 |
+
3,
|
375 |
+
name='specific',
|
376 |
+
bias=False
|
377 |
+
)
|
378 |
+
|
379 |
+
# self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
|
380 |
+
self.comb_iter_1_left = MaxPoolPad()
|
381 |
+
self.comb_iter_1_right = BranchSeparables(
|
382 |
+
self.num_filters,
|
383 |
+
self.num_filters,
|
384 |
+
7,
|
385 |
+
2,
|
386 |
+
3,
|
387 |
+
name='specific',
|
388 |
+
bias=False
|
389 |
+
)
|
390 |
+
|
391 |
+
# self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
|
392 |
+
self.comb_iter_2_left = AvgPoolPad()
|
393 |
+
self.comb_iter_2_right = BranchSeparables(
|
394 |
+
self.num_filters,
|
395 |
+
self.num_filters,
|
396 |
+
5,
|
397 |
+
2,
|
398 |
+
2,
|
399 |
+
name='specific',
|
400 |
+
bias=False
|
401 |
+
)
|
402 |
+
|
403 |
+
self.comb_iter_3_right = nn.AvgPool2d(
|
404 |
+
3, stride=1, padding=1, count_include_pad=False
|
405 |
+
)
|
406 |
+
|
407 |
+
self.comb_iter_4_left = BranchSeparables(
|
408 |
+
self.num_filters,
|
409 |
+
self.num_filters,
|
410 |
+
3,
|
411 |
+
1,
|
412 |
+
1,
|
413 |
+
name='specific',
|
414 |
+
bias=False
|
415 |
+
)
|
416 |
+
# self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
|
417 |
+
self.comb_iter_4_right = MaxPoolPad()
|
418 |
+
|
419 |
+
def forward(self, x_conv0, x_stem_0):
|
420 |
+
x_left = self.conv_1x1(x_stem_0)
|
421 |
+
|
422 |
+
x_relu = self.relu(x_conv0)
|
423 |
+
# path 1
|
424 |
+
x_path1 = self.path_1(x_relu)
|
425 |
+
# path 2
|
426 |
+
x_path2 = self.path_2.pad(x_relu)
|
427 |
+
x_path2 = x_path2[:, :, 1:, 1:]
|
428 |
+
x_path2 = self.path_2.avgpool(x_path2)
|
429 |
+
x_path2 = self.path_2.conv(x_path2)
|
430 |
+
# final path
|
431 |
+
x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
|
432 |
+
|
433 |
+
x_comb_iter_0_left = self.comb_iter_0_left(x_left)
|
434 |
+
x_comb_iter_0_right = self.comb_iter_0_right(x_right)
|
435 |
+
x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
|
436 |
+
|
437 |
+
x_comb_iter_1_left = self.comb_iter_1_left(x_left)
|
438 |
+
x_comb_iter_1_right = self.comb_iter_1_right(x_right)
|
439 |
+
x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
|
440 |
+
|
441 |
+
x_comb_iter_2_left = self.comb_iter_2_left(x_left)
|
442 |
+
x_comb_iter_2_right = self.comb_iter_2_right(x_right)
|
443 |
+
x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
|
444 |
+
|
445 |
+
x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
|
446 |
+
x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
|
447 |
+
|
448 |
+
x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
|
449 |
+
x_comb_iter_4_right = self.comb_iter_4_right(x_left)
|
450 |
+
x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
|
451 |
+
|
452 |
+
x_out = torch.cat(
|
453 |
+
[x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1
|
454 |
+
)
|
455 |
+
return x_out
|
456 |
+
|
457 |
+
|
458 |
+
class FirstCell(nn.Module):
|
459 |
+
|
460 |
+
def __init__(
|
461 |
+
self, in_channels_left, out_channels_left, in_channels_right,
|
462 |
+
out_channels_right
|
463 |
+
):
|
464 |
+
super(FirstCell, self).__init__()
|
465 |
+
self.conv_1x1 = nn.Sequential()
|
466 |
+
self.conv_1x1.add_module('relu', nn.ReLU())
|
467 |
+
self.conv_1x1.add_module(
|
468 |
+
'conv',
|
469 |
+
nn.Conv2d(
|
470 |
+
in_channels_right, out_channels_right, 1, stride=1, bias=False
|
471 |
+
)
|
472 |
+
)
|
473 |
+
self.conv_1x1.add_module(
|
474 |
+
'bn',
|
475 |
+
nn.BatchNorm2d(
|
476 |
+
out_channels_right, eps=0.001, momentum=0.1, affine=True
|
477 |
+
)
|
478 |
+
)
|
479 |
+
|
480 |
+
self.relu = nn.ReLU()
|
481 |
+
self.path_1 = nn.Sequential()
|
482 |
+
self.path_1.add_module(
|
483 |
+
'avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)
|
484 |
+
)
|
485 |
+
self.path_1.add_module(
|
486 |
+
'conv',
|
487 |
+
nn.Conv2d(
|
488 |
+
in_channels_left, out_channels_left, 1, stride=1, bias=False
|
489 |
+
)
|
490 |
+
)
|
491 |
+
self.path_2 = nn.ModuleList()
|
492 |
+
self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
|
493 |
+
self.path_2.add_module(
|
494 |
+
'avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)
|
495 |
+
)
|
496 |
+
self.path_2.add_module(
|
497 |
+
'conv',
|
498 |
+
nn.Conv2d(
|
499 |
+
in_channels_left, out_channels_left, 1, stride=1, bias=False
|
500 |
+
)
|
501 |
+
)
|
502 |
+
|
503 |
+
self.final_path_bn = nn.BatchNorm2d(
|
504 |
+
out_channels_left * 2, eps=0.001, momentum=0.1, affine=True
|
505 |
+
)
|
506 |
+
|
507 |
+
self.comb_iter_0_left = BranchSeparables(
|
508 |
+
out_channels_right, out_channels_right, 5, 1, 2, bias=False
|
509 |
+
)
|
510 |
+
self.comb_iter_0_right = BranchSeparables(
|
511 |
+
out_channels_right, out_channels_right, 3, 1, 1, bias=False
|
512 |
+
)
|
513 |
+
|
514 |
+
self.comb_iter_1_left = BranchSeparables(
|
515 |
+
out_channels_right, out_channels_right, 5, 1, 2, bias=False
|
516 |
+
)
|
517 |
+
self.comb_iter_1_right = BranchSeparables(
|
518 |
+
out_channels_right, out_channels_right, 3, 1, 1, bias=False
|
519 |
+
)
|
520 |
+
|
521 |
+
self.comb_iter_2_left = nn.AvgPool2d(
|
522 |
+
3, stride=1, padding=1, count_include_pad=False
|
523 |
+
)
|
524 |
+
|
525 |
+
self.comb_iter_3_left = nn.AvgPool2d(
|
526 |
+
3, stride=1, padding=1, count_include_pad=False
|
527 |
+
)
|
528 |
+
self.comb_iter_3_right = nn.AvgPool2d(
|
529 |
+
3, stride=1, padding=1, count_include_pad=False
|
530 |
+
)
|
531 |
+
|
532 |
+
self.comb_iter_4_left = BranchSeparables(
|
533 |
+
out_channels_right, out_channels_right, 3, 1, 1, bias=False
|
534 |
+
)
|
535 |
+
|
536 |
+
def forward(self, x, x_prev):
|
537 |
+
x_relu = self.relu(x_prev)
|
538 |
+
# path 1
|
539 |
+
x_path1 = self.path_1(x_relu)
|
540 |
+
# path 2
|
541 |
+
x_path2 = self.path_2.pad(x_relu)
|
542 |
+
x_path2 = x_path2[:, :, 1:, 1:]
|
543 |
+
x_path2 = self.path_2.avgpool(x_path2)
|
544 |
+
x_path2 = self.path_2.conv(x_path2)
|
545 |
+
# final path
|
546 |
+
x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
|
547 |
+
|
548 |
+
x_right = self.conv_1x1(x)
|
549 |
+
|
550 |
+
x_comb_iter_0_left = self.comb_iter_0_left(x_right)
|
551 |
+
x_comb_iter_0_right = self.comb_iter_0_right(x_left)
|
552 |
+
x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
|
553 |
+
|
554 |
+
x_comb_iter_1_left = self.comb_iter_1_left(x_left)
|
555 |
+
x_comb_iter_1_right = self.comb_iter_1_right(x_left)
|
556 |
+
x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
|
557 |
+
|
558 |
+
x_comb_iter_2_left = self.comb_iter_2_left(x_right)
|
559 |
+
x_comb_iter_2 = x_comb_iter_2_left + x_left
|
560 |
+
|
561 |
+
x_comb_iter_3_left = self.comb_iter_3_left(x_left)
|
562 |
+
x_comb_iter_3_right = self.comb_iter_3_right(x_left)
|
563 |
+
x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right
|
564 |
+
|
565 |
+
x_comb_iter_4_left = self.comb_iter_4_left(x_right)
|
566 |
+
x_comb_iter_4 = x_comb_iter_4_left + x_right
|
567 |
+
|
568 |
+
x_out = torch.cat(
|
569 |
+
[
|
570 |
+
x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2,
|
571 |
+
x_comb_iter_3, x_comb_iter_4
|
572 |
+
], 1
|
573 |
+
)
|
574 |
+
return x_out
|
575 |
+
|
576 |
+
|
577 |
+
class NormalCell(nn.Module):
|
578 |
+
|
579 |
+
def __init__(
|
580 |
+
self, in_channels_left, out_channels_left, in_channels_right,
|
581 |
+
out_channels_right
|
582 |
+
):
|
583 |
+
super(NormalCell, self).__init__()
|
584 |
+
self.conv_prev_1x1 = nn.Sequential()
|
585 |
+
self.conv_prev_1x1.add_module('relu', nn.ReLU())
|
586 |
+
self.conv_prev_1x1.add_module(
|
587 |
+
'conv',
|
588 |
+
nn.Conv2d(
|
589 |
+
in_channels_left, out_channels_left, 1, stride=1, bias=False
|
590 |
+
)
|
591 |
+
)
|
592 |
+
self.conv_prev_1x1.add_module(
|
593 |
+
'bn',
|
594 |
+
nn.BatchNorm2d(
|
595 |
+
out_channels_left, eps=0.001, momentum=0.1, affine=True
|
596 |
+
)
|
597 |
+
)
|
598 |
+
|
599 |
+
self.conv_1x1 = nn.Sequential()
|
600 |
+
self.conv_1x1.add_module('relu', nn.ReLU())
|
601 |
+
self.conv_1x1.add_module(
|
602 |
+
'conv',
|
603 |
+
nn.Conv2d(
|
604 |
+
in_channels_right, out_channels_right, 1, stride=1, bias=False
|
605 |
+
)
|
606 |
+
)
|
607 |
+
self.conv_1x1.add_module(
|
608 |
+
'bn',
|
609 |
+
nn.BatchNorm2d(
|
610 |
+
out_channels_right, eps=0.001, momentum=0.1, affine=True
|
611 |
+
)
|
612 |
+
)
|
613 |
+
|
614 |
+
self.comb_iter_0_left = BranchSeparables(
|
615 |
+
out_channels_right, out_channels_right, 5, 1, 2, bias=False
|
616 |
+
)
|
617 |
+
self.comb_iter_0_right = BranchSeparables(
|
618 |
+
out_channels_left, out_channels_left, 3, 1, 1, bias=False
|
619 |
+
)
|
620 |
+
|
621 |
+
self.comb_iter_1_left = BranchSeparables(
|
622 |
+
out_channels_left, out_channels_left, 5, 1, 2, bias=False
|
623 |
+
)
|
624 |
+
self.comb_iter_1_right = BranchSeparables(
|
625 |
+
out_channels_left, out_channels_left, 3, 1, 1, bias=False
|
626 |
+
)
|
627 |
+
|
628 |
+
self.comb_iter_2_left = nn.AvgPool2d(
|
629 |
+
3, stride=1, padding=1, count_include_pad=False
|
630 |
+
)
|
631 |
+
|
632 |
+
self.comb_iter_3_left = nn.AvgPool2d(
|
633 |
+
3, stride=1, padding=1, count_include_pad=False
|
634 |
+
)
|
635 |
+
self.comb_iter_3_right = nn.AvgPool2d(
|
636 |
+
3, stride=1, padding=1, count_include_pad=False
|
637 |
+
)
|
638 |
+
|
639 |
+
self.comb_iter_4_left = BranchSeparables(
|
640 |
+
out_channels_right, out_channels_right, 3, 1, 1, bias=False
|
641 |
+
)
|
642 |
+
|
643 |
+
def forward(self, x, x_prev):
|
644 |
+
x_left = self.conv_prev_1x1(x_prev)
|
645 |
+
x_right = self.conv_1x1(x)
|
646 |
+
|
647 |
+
x_comb_iter_0_left = self.comb_iter_0_left(x_right)
|
648 |
+
x_comb_iter_0_right = self.comb_iter_0_right(x_left)
|
649 |
+
x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
|
650 |
+
|
651 |
+
x_comb_iter_1_left = self.comb_iter_1_left(x_left)
|
652 |
+
x_comb_iter_1_right = self.comb_iter_1_right(x_left)
|
653 |
+
x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
|
654 |
+
|
655 |
+
x_comb_iter_2_left = self.comb_iter_2_left(x_right)
|
656 |
+
x_comb_iter_2 = x_comb_iter_2_left + x_left
|
657 |
+
|
658 |
+
x_comb_iter_3_left = self.comb_iter_3_left(x_left)
|
659 |
+
x_comb_iter_3_right = self.comb_iter_3_right(x_left)
|
660 |
+
x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right
|
661 |
+
|
662 |
+
x_comb_iter_4_left = self.comb_iter_4_left(x_right)
|
663 |
+
x_comb_iter_4 = x_comb_iter_4_left + x_right
|
664 |
+
|
665 |
+
x_out = torch.cat(
|
666 |
+
[
|
667 |
+
x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2,
|
668 |
+
x_comb_iter_3, x_comb_iter_4
|
669 |
+
], 1
|
670 |
+
)
|
671 |
+
return x_out
|
672 |
+
|
673 |
+
|
674 |
+
class ReductionCell0(nn.Module):
|
675 |
+
|
676 |
+
def __init__(
|
677 |
+
self, in_channels_left, out_channels_left, in_channels_right,
|
678 |
+
out_channels_right
|
679 |
+
):
|
680 |
+
super(ReductionCell0, self).__init__()
|
681 |
+
self.conv_prev_1x1 = nn.Sequential()
|
682 |
+
self.conv_prev_1x1.add_module('relu', nn.ReLU())
|
683 |
+
self.conv_prev_1x1.add_module(
|
684 |
+
'conv',
|
685 |
+
nn.Conv2d(
|
686 |
+
in_channels_left, out_channels_left, 1, stride=1, bias=False
|
687 |
+
)
|
688 |
+
)
|
689 |
+
self.conv_prev_1x1.add_module(
|
690 |
+
'bn',
|
691 |
+
nn.BatchNorm2d(
|
692 |
+
out_channels_left, eps=0.001, momentum=0.1, affine=True
|
693 |
+
)
|
694 |
+
)
|
695 |
+
|
696 |
+
self.conv_1x1 = nn.Sequential()
|
697 |
+
self.conv_1x1.add_module('relu', nn.ReLU())
|
698 |
+
self.conv_1x1.add_module(
|
699 |
+
'conv',
|
700 |
+
nn.Conv2d(
|
701 |
+
in_channels_right, out_channels_right, 1, stride=1, bias=False
|
702 |
+
)
|
703 |
+
)
|
704 |
+
self.conv_1x1.add_module(
|
705 |
+
'bn',
|
706 |
+
nn.BatchNorm2d(
|
707 |
+
out_channels_right, eps=0.001, momentum=0.1, affine=True
|
708 |
+
)
|
709 |
+
)
|
710 |
+
|
711 |
+
self.comb_iter_0_left = BranchSeparablesReduction(
|
712 |
+
out_channels_right, out_channels_right, 5, 2, 2, bias=False
|
713 |
+
)
|
714 |
+
self.comb_iter_0_right = BranchSeparablesReduction(
|
715 |
+
out_channels_right, out_channels_right, 7, 2, 3, bias=False
|
716 |
+
)
|
717 |
+
|
718 |
+
self.comb_iter_1_left = MaxPoolPad()
|
719 |
+
self.comb_iter_1_right = BranchSeparablesReduction(
|
720 |
+
out_channels_right, out_channels_right, 7, 2, 3, bias=False
|
721 |
+
)
|
722 |
+
|
723 |
+
self.comb_iter_2_left = AvgPoolPad()
|
724 |
+
self.comb_iter_2_right = BranchSeparablesReduction(
|
725 |
+
out_channels_right, out_channels_right, 5, 2, 2, bias=False
|
726 |
+
)
|
727 |
+
|
728 |
+
self.comb_iter_3_right = nn.AvgPool2d(
|
729 |
+
3, stride=1, padding=1, count_include_pad=False
|
730 |
+
)
|
731 |
+
|
732 |
+
self.comb_iter_4_left = BranchSeparablesReduction(
|
733 |
+
out_channels_right, out_channels_right, 3, 1, 1, bias=False
|
734 |
+
)
|
735 |
+
self.comb_iter_4_right = MaxPoolPad()
|
736 |
+
|
737 |
+
def forward(self, x, x_prev):
|
738 |
+
x_left = self.conv_prev_1x1(x_prev)
|
739 |
+
x_right = self.conv_1x1(x)
|
740 |
+
|
741 |
+
x_comb_iter_0_left = self.comb_iter_0_left(x_right)
|
742 |
+
x_comb_iter_0_right = self.comb_iter_0_right(x_left)
|
743 |
+
x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
|
744 |
+
|
745 |
+
x_comb_iter_1_left = self.comb_iter_1_left(x_right)
|
746 |
+
x_comb_iter_1_right = self.comb_iter_1_right(x_left)
|
747 |
+
x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
|
748 |
+
|
749 |
+
x_comb_iter_2_left = self.comb_iter_2_left(x_right)
|
750 |
+
x_comb_iter_2_right = self.comb_iter_2_right(x_left)
|
751 |
+
x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
|
752 |
+
|
753 |
+
x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
|
754 |
+
x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
|
755 |
+
|
756 |
+
x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
|
757 |
+
x_comb_iter_4_right = self.comb_iter_4_right(x_right)
|
758 |
+
x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
|
759 |
+
|
760 |
+
x_out = torch.cat(
|
761 |
+
[x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1
|
762 |
+
)
|
763 |
+
return x_out
|
764 |
+
|
765 |
+
|
766 |
+
class ReductionCell1(nn.Module):
|
767 |
+
|
768 |
+
def __init__(
|
769 |
+
self, in_channels_left, out_channels_left, in_channels_right,
|
770 |
+
out_channels_right
|
771 |
+
):
|
772 |
+
super(ReductionCell1, self).__init__()
|
773 |
+
self.conv_prev_1x1 = nn.Sequential()
|
774 |
+
self.conv_prev_1x1.add_module('relu', nn.ReLU())
|
775 |
+
self.conv_prev_1x1.add_module(
|
776 |
+
'conv',
|
777 |
+
nn.Conv2d(
|
778 |
+
in_channels_left, out_channels_left, 1, stride=1, bias=False
|
779 |
+
)
|
780 |
+
)
|
781 |
+
self.conv_prev_1x1.add_module(
|
782 |
+
'bn',
|
783 |
+
nn.BatchNorm2d(
|
784 |
+
out_channels_left, eps=0.001, momentum=0.1, affine=True
|
785 |
+
)
|
786 |
+
)
|
787 |
+
|
788 |
+
self.conv_1x1 = nn.Sequential()
|
789 |
+
self.conv_1x1.add_module('relu', nn.ReLU())
|
790 |
+
self.conv_1x1.add_module(
|
791 |
+
'conv',
|
792 |
+
nn.Conv2d(
|
793 |
+
in_channels_right, out_channels_right, 1, stride=1, bias=False
|
794 |
+
)
|
795 |
+
)
|
796 |
+
self.conv_1x1.add_module(
|
797 |
+
'bn',
|
798 |
+
nn.BatchNorm2d(
|
799 |
+
out_channels_right, eps=0.001, momentum=0.1, affine=True
|
800 |
+
)
|
801 |
+
)
|
802 |
+
|
803 |
+
self.comb_iter_0_left = BranchSeparables(
|
804 |
+
out_channels_right,
|
805 |
+
out_channels_right,
|
806 |
+
5,
|
807 |
+
2,
|
808 |
+
2,
|
809 |
+
name='specific',
|
810 |
+
bias=False
|
811 |
+
)
|
812 |
+
self.comb_iter_0_right = BranchSeparables(
|
813 |
+
out_channels_right,
|
814 |
+
out_channels_right,
|
815 |
+
7,
|
816 |
+
2,
|
817 |
+
3,
|
818 |
+
name='specific',
|
819 |
+
bias=False
|
820 |
+
)
|
821 |
+
|
822 |
+
# self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
|
823 |
+
self.comb_iter_1_left = MaxPoolPad()
|
824 |
+
self.comb_iter_1_right = BranchSeparables(
|
825 |
+
out_channels_right,
|
826 |
+
out_channels_right,
|
827 |
+
7,
|
828 |
+
2,
|
829 |
+
3,
|
830 |
+
name='specific',
|
831 |
+
bias=False
|
832 |
+
)
|
833 |
+
|
834 |
+
# self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
|
835 |
+
self.comb_iter_2_left = AvgPoolPad()
|
836 |
+
self.comb_iter_2_right = BranchSeparables(
|
837 |
+
out_channels_right,
|
838 |
+
out_channels_right,
|
839 |
+
5,
|
840 |
+
2,
|
841 |
+
2,
|
842 |
+
name='specific',
|
843 |
+
bias=False
|
844 |
+
)
|
845 |
+
|
846 |
+
self.comb_iter_3_right = nn.AvgPool2d(
|
847 |
+
3, stride=1, padding=1, count_include_pad=False
|
848 |
+
)
|
849 |
+
|
850 |
+
self.comb_iter_4_left = BranchSeparables(
|
851 |
+
out_channels_right,
|
852 |
+
out_channels_right,
|
853 |
+
3,
|
854 |
+
1,
|
855 |
+
1,
|
856 |
+
name='specific',
|
857 |
+
bias=False
|
858 |
+
)
|
859 |
+
# self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
|
860 |
+
self.comb_iter_4_right = MaxPoolPad()
|
861 |
+
|
862 |
+
def forward(self, x, x_prev):
|
863 |
+
x_left = self.conv_prev_1x1(x_prev)
|
864 |
+
x_right = self.conv_1x1(x)
|
865 |
+
|
866 |
+
x_comb_iter_0_left = self.comb_iter_0_left(x_right)
|
867 |
+
x_comb_iter_0_right = self.comb_iter_0_right(x_left)
|
868 |
+
x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
|
869 |
+
|
870 |
+
x_comb_iter_1_left = self.comb_iter_1_left(x_right)
|
871 |
+
x_comb_iter_1_right = self.comb_iter_1_right(x_left)
|
872 |
+
x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
|
873 |
+
|
874 |
+
x_comb_iter_2_left = self.comb_iter_2_left(x_right)
|
875 |
+
x_comb_iter_2_right = self.comb_iter_2_right(x_left)
|
876 |
+
x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
|
877 |
+
|
878 |
+
x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
|
879 |
+
x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
|
880 |
+
|
881 |
+
x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
|
882 |
+
x_comb_iter_4_right = self.comb_iter_4_right(x_right)
|
883 |
+
x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
|
884 |
+
|
885 |
+
x_out = torch.cat(
|
886 |
+
[x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1
|
887 |
+
)
|
888 |
+
return x_out
|
889 |
+
|
890 |
+
|
891 |
+
class NASNetAMobile(nn.Module):
|
892 |
+
"""Neural Architecture Search (NAS).
|
893 |
+
|
894 |
+
Reference:
|
895 |
+
Zoph et al. Learning Transferable Architectures
|
896 |
+
for Scalable Image Recognition. CVPR 2018.
|
897 |
+
|
898 |
+
Public keys:
|
899 |
+
- ``nasnetamobile``: NASNet-A Mobile.
|
900 |
+
"""
|
901 |
+
|
902 |
+
def __init__(
|
903 |
+
self,
|
904 |
+
num_classes,
|
905 |
+
loss,
|
906 |
+
stem_filters=32,
|
907 |
+
penultimate_filters=1056,
|
908 |
+
filters_multiplier=2,
|
909 |
+
**kwargs
|
910 |
+
):
|
911 |
+
super(NASNetAMobile, self).__init__()
|
912 |
+
self.stem_filters = stem_filters
|
913 |
+
self.penultimate_filters = penultimate_filters
|
914 |
+
self.filters_multiplier = filters_multiplier
|
915 |
+
self.loss = loss
|
916 |
+
|
917 |
+
filters = self.penultimate_filters // 24
|
918 |
+
# 24 is default value for the architecture
|
919 |
+
|
920 |
+
self.conv0 = nn.Sequential()
|
921 |
+
self.conv0.add_module(
|
922 |
+
'conv',
|
923 |
+
nn.Conv2d(
|
924 |
+
in_channels=3,
|
925 |
+
out_channels=self.stem_filters,
|
926 |
+
kernel_size=3,
|
927 |
+
padding=0,
|
928 |
+
stride=2,
|
929 |
+
bias=False
|
930 |
+
)
|
931 |
+
)
|
932 |
+
self.conv0.add_module(
|
933 |
+
'bn',
|
934 |
+
nn.BatchNorm2d(
|
935 |
+
self.stem_filters, eps=0.001, momentum=0.1, affine=True
|
936 |
+
)
|
937 |
+
)
|
938 |
+
|
939 |
+
self.cell_stem_0 = CellStem0(
|
940 |
+
self.stem_filters, num_filters=filters // (filters_multiplier**2)
|
941 |
+
)
|
942 |
+
self.cell_stem_1 = CellStem1(
|
943 |
+
self.stem_filters, num_filters=filters // filters_multiplier
|
944 |
+
)
|
945 |
+
|
946 |
+
self.cell_0 = FirstCell(
|
947 |
+
in_channels_left=filters,
|
948 |
+
out_channels_left=filters // 2, # 1, 0.5
|
949 |
+
in_channels_right=2 * filters,
|
950 |
+
out_channels_right=filters
|
951 |
+
) # 2, 1
|
952 |
+
self.cell_1 = NormalCell(
|
953 |
+
in_channels_left=2 * filters,
|
954 |
+
out_channels_left=filters, # 2, 1
|
955 |
+
in_channels_right=6 * filters,
|
956 |
+
out_channels_right=filters
|
957 |
+
) # 6, 1
|
958 |
+
self.cell_2 = NormalCell(
|
959 |
+
in_channels_left=6 * filters,
|
960 |
+
out_channels_left=filters, # 6, 1
|
961 |
+
in_channels_right=6 * filters,
|
962 |
+
out_channels_right=filters
|
963 |
+
) # 6, 1
|
964 |
+
self.cell_3 = NormalCell(
|
965 |
+
in_channels_left=6 * filters,
|
966 |
+
out_channels_left=filters, # 6, 1
|
967 |
+
in_channels_right=6 * filters,
|
968 |
+
out_channels_right=filters
|
969 |
+
) # 6, 1
|
970 |
+
|
971 |
+
self.reduction_cell_0 = ReductionCell0(
|
972 |
+
in_channels_left=6 * filters,
|
973 |
+
out_channels_left=2 * filters, # 6, 2
|
974 |
+
in_channels_right=6 * filters,
|
975 |
+
out_channels_right=2 * filters
|
976 |
+
) # 6, 2
|
977 |
+
|
978 |
+
self.cell_6 = FirstCell(
|
979 |
+
in_channels_left=6 * filters,
|
980 |
+
out_channels_left=filters, # 6, 1
|
981 |
+
in_channels_right=8 * filters,
|
982 |
+
out_channels_right=2 * filters
|
983 |
+
) # 8, 2
|
984 |
+
self.cell_7 = NormalCell(
|
985 |
+
in_channels_left=8 * filters,
|
986 |
+
out_channels_left=2 * filters, # 8, 2
|
987 |
+
in_channels_right=12 * filters,
|
988 |
+
out_channels_right=2 * filters
|
989 |
+
) # 12, 2
|
990 |
+
self.cell_8 = NormalCell(
|
991 |
+
in_channels_left=12 * filters,
|
992 |
+
out_channels_left=2 * filters, # 12, 2
|
993 |
+
in_channels_right=12 * filters,
|
994 |
+
out_channels_right=2 * filters
|
995 |
+
) # 12, 2
|
996 |
+
self.cell_9 = NormalCell(
|
997 |
+
in_channels_left=12 * filters,
|
998 |
+
out_channels_left=2 * filters, # 12, 2
|
999 |
+
in_channels_right=12 * filters,
|
1000 |
+
out_channels_right=2 * filters
|
1001 |
+
) # 12, 2
|
1002 |
+
|
1003 |
+
self.reduction_cell_1 = ReductionCell1(
|
1004 |
+
in_channels_left=12 * filters,
|
1005 |
+
out_channels_left=4 * filters, # 12, 4
|
1006 |
+
in_channels_right=12 * filters,
|
1007 |
+
out_channels_right=4 * filters
|
1008 |
+
) # 12, 4
|
1009 |
+
|
1010 |
+
self.cell_12 = FirstCell(
|
1011 |
+
in_channels_left=12 * filters,
|
1012 |
+
out_channels_left=2 * filters, # 12, 2
|
1013 |
+
in_channels_right=16 * filters,
|
1014 |
+
out_channels_right=4 * filters
|
1015 |
+
) # 16, 4
|
1016 |
+
self.cell_13 = NormalCell(
|
1017 |
+
in_channels_left=16 * filters,
|
1018 |
+
out_channels_left=4 * filters, # 16, 4
|
1019 |
+
in_channels_right=24 * filters,
|
1020 |
+
out_channels_right=4 * filters
|
1021 |
+
) # 24, 4
|
1022 |
+
self.cell_14 = NormalCell(
|
1023 |
+
in_channels_left=24 * filters,
|
1024 |
+
out_channels_left=4 * filters, # 24, 4
|
1025 |
+
in_channels_right=24 * filters,
|
1026 |
+
out_channels_right=4 * filters
|
1027 |
+
) # 24, 4
|
1028 |
+
self.cell_15 = NormalCell(
|
1029 |
+
in_channels_left=24 * filters,
|
1030 |
+
out_channels_left=4 * filters, # 24, 4
|
1031 |
+
in_channels_right=24 * filters,
|
1032 |
+
out_channels_right=4 * filters
|
1033 |
+
) # 24, 4
|
1034 |
+
|
1035 |
+
self.relu = nn.ReLU()
|
1036 |
+
self.dropout = nn.Dropout()
|
1037 |
+
self.classifier = nn.Linear(24 * filters, num_classes)
|
1038 |
+
|
1039 |
+
self._init_params()
|
1040 |
+
|
1041 |
+
def _init_params(self):
|
1042 |
+
for m in self.modules():
|
1043 |
+
if isinstance(m, nn.Conv2d):
|
1044 |
+
nn.init.kaiming_normal_(
|
1045 |
+
m.weight, mode='fan_out', nonlinearity='relu'
|
1046 |
+
)
|
1047 |
+
if m.bias is not None:
|
1048 |
+
nn.init.constant_(m.bias, 0)
|
1049 |
+
elif isinstance(m, nn.BatchNorm2d):
|
1050 |
+
nn.init.constant_(m.weight, 1)
|
1051 |
+
nn.init.constant_(m.bias, 0)
|
1052 |
+
elif isinstance(m, nn.BatchNorm1d):
|
1053 |
+
nn.init.constant_(m.weight, 1)
|
1054 |
+
nn.init.constant_(m.bias, 0)
|
1055 |
+
elif isinstance(m, nn.Linear):
|
1056 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
1057 |
+
if m.bias is not None:
|
1058 |
+
nn.init.constant_(m.bias, 0)
|
1059 |
+
|
1060 |
+
def features(self, input):
|
1061 |
+
x_conv0 = self.conv0(input)
|
1062 |
+
x_stem_0 = self.cell_stem_0(x_conv0)
|
1063 |
+
x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)
|
1064 |
+
|
1065 |
+
x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
|
1066 |
+
x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
|
1067 |
+
x_cell_2 = self.cell_2(x_cell_1, x_cell_0)
|
1068 |
+
x_cell_3 = self.cell_3(x_cell_2, x_cell_1)
|
1069 |
+
|
1070 |
+
x_reduction_cell_0 = self.reduction_cell_0(x_cell_3, x_cell_2)
|
1071 |
+
|
1072 |
+
x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_3)
|
1073 |
+
x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0)
|
1074 |
+
x_cell_8 = self.cell_8(x_cell_7, x_cell_6)
|
1075 |
+
x_cell_9 = self.cell_9(x_cell_8, x_cell_7)
|
1076 |
+
|
1077 |
+
x_reduction_cell_1 = self.reduction_cell_1(x_cell_9, x_cell_8)
|
1078 |
+
|
1079 |
+
x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_9)
|
1080 |
+
x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1)
|
1081 |
+
x_cell_14 = self.cell_14(x_cell_13, x_cell_12)
|
1082 |
+
x_cell_15 = self.cell_15(x_cell_14, x_cell_13)
|
1083 |
+
|
1084 |
+
x_cell_15 = self.relu(x_cell_15)
|
1085 |
+
x_cell_15 = F.avg_pool2d(
|
1086 |
+
x_cell_15,
|
1087 |
+
x_cell_15.size()[2:]
|
1088 |
+
) # global average pool
|
1089 |
+
x_cell_15 = x_cell_15.view(x_cell_15.size(0), -1)
|
1090 |
+
x_cell_15 = self.dropout(x_cell_15)
|
1091 |
+
|
1092 |
+
return x_cell_15
|
1093 |
+
|
1094 |
+
def forward(self, input):
|
1095 |
+
v = self.features(input)
|
1096 |
+
|
1097 |
+
if not self.training:
|
1098 |
+
return v
|
1099 |
+
|
1100 |
+
y = self.classifier(v)
|
1101 |
+
|
1102 |
+
if self.loss == 'softmax':
|
1103 |
+
return y
|
1104 |
+
elif self.loss == 'triplet':
|
1105 |
+
return y, v
|
1106 |
+
else:
|
1107 |
+
raise KeyError('Unsupported loss: {}'.format(self.loss))
|


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def nasnetamobile(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = NASNetAMobile(num_classes, loss, **kwargs)
    if pretrained:
        model_url = pretrained_settings['nasnetamobile']['imagenet']['url']
        init_pretrained_weights(model, model_url)
    return model
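A short, hypothetical sketch of how the nasnetamobile factory above could be used as a feature extractor (module path, class count, and batch size are illustrative; pretrained=False avoids the weight download). In eval mode forward() returns the globally pooled features before the classifier, which are 1056-dim for the default penultimate_filters.

# Hypothetical usage sketch (not part of this commit): NASNet-A Mobile as a re-ID feature extractor.
import torch
from trackers.strongsort.deep.models.nasnet import nasnetamobile  # assumed module path

model = nasnetamobile(num_classes=1000, pretrained=False)  # num_classes is arbitrary here
model.eval()  # eval mode -> forward() returns pooled features, not logits

batch = torch.randn(4, 3, 224, 224)  # matches 'input_size' in pretrained_settings
with torch.no_grad():
    features = model(batch)
print(features.shape)  # expected: torch.Size([4, 1056])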
trackers/strongsort/deep/models/osnet.py
ADDED
@@ -0,0 +1,598 @@
from __future__ import division, absolute_import
import warnings
import torch
from torch import nn
from torch.nn import functional as F

__all__ = [
    'osnet_x1_0', 'osnet_x0_75', 'osnet_x0_5', 'osnet_x0_25', 'osnet_ibn_x1_0'
]

pretrained_urls = {
    'osnet_x1_0':
    'https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY',
    'osnet_x0_75':
    'https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq',
    'osnet_x0_5':
    'https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i',
    'osnet_x0_25':
    'https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs',
    'osnet_ibn_x1_0':
    'https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l'
}


##########
# Basic layers
##########
28 |
+
class ConvLayer(nn.Module):
|
29 |
+
"""Convolution layer (conv + bn + relu)."""
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
in_channels,
|
34 |
+
out_channels,
|
35 |
+
kernel_size,
|
36 |
+
stride=1,
|
37 |
+
padding=0,
|
38 |
+
groups=1,
|
39 |
+
IN=False
|
40 |
+
):
|
41 |
+
super(ConvLayer, self).__init__()
|
42 |
+
self.conv = nn.Conv2d(
|
43 |
+
in_channels,
|
44 |
+
out_channels,
|
45 |
+
kernel_size,
|
46 |
+
stride=stride,
|
47 |
+
padding=padding,
|
48 |
+
bias=False,
|
49 |
+
groups=groups
|
50 |
+
)
|
51 |
+
if IN:
|
52 |
+
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
|
53 |
+
else:
|
54 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
55 |
+
self.relu = nn.ReLU(inplace=True)
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
x = self.conv(x)
|
59 |
+
x = self.bn(x)
|
60 |
+
x = self.relu(x)
|
61 |
+
return x
|
62 |
+
|
63 |
+
|
64 |
+
class Conv1x1(nn.Module):
|
65 |
+
"""1x1 convolution + bn + relu."""
|
66 |
+
|
67 |
+
def __init__(self, in_channels, out_channels, stride=1, groups=1):
|
68 |
+
super(Conv1x1, self).__init__()
|
69 |
+
self.conv = nn.Conv2d(
|
70 |
+
in_channels,
|
71 |
+
out_channels,
|
72 |
+
1,
|
73 |
+
stride=stride,
|
74 |
+
padding=0,
|
75 |
+
bias=False,
|
76 |
+
groups=groups
|
77 |
+
)
|
78 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
79 |
+
self.relu = nn.ReLU(inplace=True)
|
80 |
+
|
81 |
+
def forward(self, x):
|
82 |
+
x = self.conv(x)
|
83 |
+
x = self.bn(x)
|
84 |
+
x = self.relu(x)
|
85 |
+
return x
|
86 |
+
|
87 |
+
|
88 |
+
class Conv1x1Linear(nn.Module):
|
89 |
+
"""1x1 convolution + bn (w/o non-linearity)."""
|
90 |
+
|
91 |
+
def __init__(self, in_channels, out_channels, stride=1):
|
92 |
+
super(Conv1x1Linear, self).__init__()
|
93 |
+
self.conv = nn.Conv2d(
|
94 |
+
in_channels, out_channels, 1, stride=stride, padding=0, bias=False
|
95 |
+
)
|
96 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
97 |
+
|
98 |
+
def forward(self, x):
|
99 |
+
x = self.conv(x)
|
100 |
+
x = self.bn(x)
|
101 |
+
return x
|
102 |
+
|
103 |
+
|
104 |
+
class Conv3x3(nn.Module):
|
105 |
+
"""3x3 convolution + bn + relu."""
|
106 |
+
|
107 |
+
def __init__(self, in_channels, out_channels, stride=1, groups=1):
|
108 |
+
super(Conv3x3, self).__init__()
|
109 |
+
self.conv = nn.Conv2d(
|
110 |
+
in_channels,
|
111 |
+
out_channels,
|
112 |
+
3,
|
113 |
+
stride=stride,
|
114 |
+
padding=1,
|
115 |
+
bias=False,
|
116 |
+
groups=groups
|
117 |
+
)
|
118 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
119 |
+
self.relu = nn.ReLU(inplace=True)
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
x = self.conv(x)
|
123 |
+
x = self.bn(x)
|
124 |
+
x = self.relu(x)
|
125 |
+
return x
|
126 |
+
|
127 |
+
|
128 |
+
class LightConv3x3(nn.Module):
|
129 |
+
"""Lightweight 3x3 convolution.
|
130 |
+
|
131 |
+
1x1 (linear) + dw 3x3 (nonlinear).
|
132 |
+
"""
|
133 |
+
|
134 |
+
def __init__(self, in_channels, out_channels):
|
135 |
+
super(LightConv3x3, self).__init__()
|
136 |
+
self.conv1 = nn.Conv2d(
|
137 |
+
in_channels, out_channels, 1, stride=1, padding=0, bias=False
|
138 |
+
)
|
139 |
+
self.conv2 = nn.Conv2d(
|
140 |
+
out_channels,
|
141 |
+
out_channels,
|
142 |
+
3,
|
143 |
+
stride=1,
|
144 |
+
padding=1,
|
145 |
+
bias=False,
|
146 |
+
groups=out_channels
|
147 |
+
)
|
148 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
149 |
+
self.relu = nn.ReLU(inplace=True)
|
150 |
+
|
151 |
+
def forward(self, x):
|
152 |
+
x = self.conv1(x)
|
153 |
+
x = self.conv2(x)
|
154 |
+
x = self.bn(x)
|
155 |
+
x = self.relu(x)
|
156 |
+
return x
|
157 |
+
|
158 |
+
|
159 |
+
##########
|
160 |
+
# Building blocks for omni-scale feature learning
|
161 |
+
##########
|
162 |
+
class ChannelGate(nn.Module):
|
163 |
+
"""A mini-network that generates channel-wise gates conditioned on input tensor."""
|
164 |
+
|
165 |
+
def __init__(
|
166 |
+
self,
|
167 |
+
in_channels,
|
168 |
+
num_gates=None,
|
169 |
+
return_gates=False,
|
170 |
+
gate_activation='sigmoid',
|
171 |
+
reduction=16,
|
172 |
+
layer_norm=False
|
173 |
+
):
|
174 |
+
super(ChannelGate, self).__init__()
|
175 |
+
if num_gates is None:
|
176 |
+
num_gates = in_channels
|
177 |
+
self.return_gates = return_gates
|
178 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
179 |
+
self.fc1 = nn.Conv2d(
|
180 |
+
in_channels,
|
181 |
+
in_channels // reduction,
|
182 |
+
kernel_size=1,
|
183 |
+
bias=True,
|
184 |
+
padding=0
|
185 |
+
)
|
186 |
+
self.norm1 = None
|
187 |
+
if layer_norm:
|
188 |
+
self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
|
189 |
+
self.relu = nn.ReLU(inplace=True)
|
190 |
+
self.fc2 = nn.Conv2d(
|
191 |
+
in_channels // reduction,
|
192 |
+
num_gates,
|
193 |
+
kernel_size=1,
|
194 |
+
bias=True,
|
195 |
+
padding=0
|
196 |
+
)
|
197 |
+
if gate_activation == 'sigmoid':
|
198 |
+
self.gate_activation = nn.Sigmoid()
|
199 |
+
elif gate_activation == 'relu':
|
200 |
+
self.gate_activation = nn.ReLU(inplace=True)
|
201 |
+
elif gate_activation == 'linear':
|
202 |
+
self.gate_activation = None
|
203 |
+
else:
|
204 |
+
raise RuntimeError(
|
205 |
+
"Unknown gate activation: {}".format(gate_activation)
|
206 |
+
)
|
207 |
+
|
208 |
+
def forward(self, x):
|
209 |
+
input = x
|
210 |
+
x = self.global_avgpool(x)
|
211 |
+
x = self.fc1(x)
|
212 |
+
if self.norm1 is not None:
|
213 |
+
x = self.norm1(x)
|
214 |
+
x = self.relu(x)
|
215 |
+
x = self.fc2(x)
|
216 |
+
if self.gate_activation is not None:
|
217 |
+
x = self.gate_activation(x)
|
218 |
+
if self.return_gates:
|
219 |
+
return x
|
220 |
+
return input * x
|
221 |
+
|
222 |
+
|
223 |
+
class OSBlock(nn.Module):
|
224 |
+
"""Omni-scale feature learning block."""
|
225 |
+
|
226 |
+
def __init__(
|
227 |
+
self,
|
228 |
+
in_channels,
|
229 |
+
out_channels,
|
230 |
+
IN=False,
|
231 |
+
bottleneck_reduction=4,
|
232 |
+
**kwargs
|
233 |
+
):
|
234 |
+
super(OSBlock, self).__init__()
|
235 |
+
mid_channels = out_channels // bottleneck_reduction
|
236 |
+
self.conv1 = Conv1x1(in_channels, mid_channels)
|
237 |
+
self.conv2a = LightConv3x3(mid_channels, mid_channels)
|
238 |
+
self.conv2b = nn.Sequential(
|
239 |
+
LightConv3x3(mid_channels, mid_channels),
|
240 |
+
LightConv3x3(mid_channels, mid_channels),
|
241 |
+
)
|
242 |
+
self.conv2c = nn.Sequential(
|
243 |
+
LightConv3x3(mid_channels, mid_channels),
|
244 |
+
LightConv3x3(mid_channels, mid_channels),
|
245 |
+
LightConv3x3(mid_channels, mid_channels),
|
246 |
+
)
|
247 |
+
self.conv2d = nn.Sequential(
|
248 |
+
LightConv3x3(mid_channels, mid_channels),
|
249 |
+
LightConv3x3(mid_channels, mid_channels),
|
250 |
+
LightConv3x3(mid_channels, mid_channels),
|
251 |
+
LightConv3x3(mid_channels, mid_channels),
|
252 |
+
)
|
253 |
+
self.gate = ChannelGate(mid_channels)
|
254 |
+
self.conv3 = Conv1x1Linear(mid_channels, out_channels)
|
255 |
+
self.downsample = None
|
256 |
+
if in_channels != out_channels:
|
257 |
+
self.downsample = Conv1x1Linear(in_channels, out_channels)
|
258 |
+
self.IN = None
|
259 |
+
if IN:
|
260 |
+
self.IN = nn.InstanceNorm2d(out_channels, affine=True)
|
261 |
+
|
262 |
+
def forward(self, x):
|
263 |
+
identity = x
|
264 |
+
x1 = self.conv1(x)
|
265 |
+
x2a = self.conv2a(x1)
|
266 |
+
x2b = self.conv2b(x1)
|
267 |
+
x2c = self.conv2c(x1)
|
268 |
+
x2d = self.conv2d(x1)
|
269 |
+
x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
|
270 |
+
x3 = self.conv3(x2)
|
271 |
+
if self.downsample is not None:
|
272 |
+
identity = self.downsample(identity)
|
273 |
+
out = x3 + identity
|
274 |
+
if self.IN is not None:
|
275 |
+
out = self.IN(out)
|
276 |
+
return F.relu(out)
|
277 |
+
|
278 |
+
|
279 |
+
##########
|
280 |
+
# Network architecture
|
281 |
+
##########
|
282 |
+
class OSNet(nn.Module):
|
283 |
+
"""Omni-Scale Network.
|
284 |
+
|
285 |
+
Reference:
|
286 |
+
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
|
287 |
+
- Zhou et al. Learning Generalisable Omni-Scale Representations
|
288 |
+
for Person Re-Identification. TPAMI, 2021.
|
289 |
+
"""
|
290 |
+
|
291 |
+
def __init__(
|
292 |
+
self,
|
293 |
+
num_classes,
|
294 |
+
blocks,
|
295 |
+
layers,
|
296 |
+
channels,
|
297 |
+
feature_dim=512,
|
298 |
+
loss='softmax',
|
299 |
+
IN=False,
|
300 |
+
**kwargs
|
301 |
+
):
|
302 |
+
super(OSNet, self).__init__()
|
303 |
+
num_blocks = len(blocks)
|
304 |
+
assert num_blocks == len(layers)
|
305 |
+
assert num_blocks == len(channels) - 1
|
306 |
+
self.loss = loss
|
+        self.feature_dim = feature_dim
+
+        # convolutional backbone
+        self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
+        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
+        self.conv2 = self._make_layer(
+            blocks[0],
+            layers[0],
+            channels[0],
+            channels[1],
+            reduce_spatial_size=True,
+            IN=IN
+        )
+        self.conv3 = self._make_layer(
+            blocks[1],
+            layers[1],
+            channels[1],
+            channels[2],
+            reduce_spatial_size=True
+        )
+        self.conv4 = self._make_layer(
+            blocks[2],
+            layers[2],
+            channels[2],
+            channels[3],
+            reduce_spatial_size=False
+        )
+        self.conv5 = Conv1x1(channels[3], channels[3])
+        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+        # fully connected layer
+        self.fc = self._construct_fc_layer(
+            self.feature_dim, channels[3], dropout_p=None
+        )
+        # identity classification layer
+        self.classifier = nn.Linear(self.feature_dim, num_classes)
+
+        self._init_params()
+
+    def _make_layer(
+        self,
+        block,
+        layer,
+        in_channels,
+        out_channels,
+        reduce_spatial_size,
+        IN=False
+    ):
+        layers = []
+
+        layers.append(block(in_channels, out_channels, IN=IN))
+        for i in range(1, layer):
+            layers.append(block(out_channels, out_channels, IN=IN))
+
+        if reduce_spatial_size:
+            layers.append(
+                nn.Sequential(
+                    Conv1x1(out_channels, out_channels),
+                    nn.AvgPool2d(2, stride=2)
+                )
+            )
+
+        return nn.Sequential(*layers)
+
+    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
+        if fc_dims is None or fc_dims < 0:
+            self.feature_dim = input_dim
+            return None
+
+        if isinstance(fc_dims, int):
+            fc_dims = [fc_dims]
+
+        layers = []
+        for dim in fc_dims:
+            layers.append(nn.Linear(input_dim, dim))
+            layers.append(nn.BatchNorm1d(dim))
+            layers.append(nn.ReLU(inplace=True))
+            if dropout_p is not None:
+                layers.append(nn.Dropout(p=dropout_p))
+            input_dim = dim
+
+        self.feature_dim = fc_dims[-1]
+
+        return nn.Sequential(*layers)
+
+    def _init_params(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(
+                    m.weight, mode='fan_out', nonlinearity='relu'
+                )
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+            elif isinstance(m, nn.BatchNorm1d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def featuremaps(self, x):
+        x = self.conv1(x)
+        x = self.maxpool(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.conv4(x)
+        x = self.conv5(x)
+        return x
+
+    def forward(self, x, return_featuremaps=False):
+        x = self.featuremaps(x)
+        if return_featuremaps:
+            return x
+        v = self.global_avgpool(x)
+        v = v.view(v.size(0), -1)
+        if self.fc is not None:
+            v = self.fc(v)
+        if not self.training:
+            return v
+        y = self.classifier(v)
+        if self.loss == 'softmax':
+            return y
+        elif self.loss == 'triplet':
+            return y, v
+        else:
+            raise KeyError("Unsupported loss: {}".format(self.loss))
+
+
+def init_pretrained_weights(model, key=''):
+    """Initializes model with pretrained weights.
+
+    Layers that don't match with pretrained layers in name or size are kept unchanged.
+    """
+    import os
+    import errno
+    import gdown
+    from collections import OrderedDict
+
+    def _get_torch_home():
+        ENV_TORCH_HOME = 'TORCH_HOME'
+        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+        DEFAULT_CACHE_DIR = '~/.cache'
+        torch_home = os.path.expanduser(
+            os.getenv(
+                ENV_TORCH_HOME,
+                os.path.join(
+                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
+                )
+            )
+        )
+        return torch_home
+
+    torch_home = _get_torch_home()
+    model_dir = os.path.join(torch_home, 'checkpoints')
+    try:
+        os.makedirs(model_dir)
+    except OSError as e:
+        if e.errno == errno.EEXIST:
+            # Directory already exists, ignore.
+            pass
+        else:
+            # Unexpected OSError, re-raise.
+            raise
+    filename = key + '_imagenet.pth'
+    cached_file = os.path.join(model_dir, filename)
+
+    if not os.path.exists(cached_file):
+        gdown.download(pretrained_urls[key], cached_file, quiet=False)
+
+    state_dict = torch.load(cached_file)
+    model_dict = model.state_dict()
+    new_state_dict = OrderedDict()
+    matched_layers, discarded_layers = [], []
+
+    for k, v in state_dict.items():
+        if k.startswith('module.'):
+            k = k[7:] # discard module.
+
+        if k in model_dict and model_dict[k].size() == v.size():
+            new_state_dict[k] = v
+            matched_layers.append(k)
+        else:
+            discarded_layers.append(k)
+
+    model_dict.update(new_state_dict)
+    model.load_state_dict(model_dict)
+
+    if len(matched_layers) == 0:
+        warnings.warn(
+            'The pretrained weights from "{}" cannot be loaded, '
+            'please check the key names manually '
+            '(** ignored and continue **)'.format(cached_file)
+        )
+    else:
+        print(
+            'Successfully loaded imagenet pretrained weights from "{}"'.
+            format(cached_file)
+        )
+        if len(discarded_layers) > 0:
+            print(
+                '** The following layers are discarded '
+                'due to unmatched keys or layer size: {}'.
+                format(discarded_layers)
+            )
+
+
+##########
+# Instantiation
+##########
+def osnet_x1_0(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
+    # standard size (width x1.0)
+    model = OSNet(
+        num_classes,
+        blocks=[OSBlock, OSBlock, OSBlock],
+        layers=[2, 2, 2],
+        channels=[64, 256, 384, 512],
+        loss=loss,
+        **kwargs
+    )
+    if pretrained:
+        init_pretrained_weights(model, key='osnet_x1_0')
+    return model
+
+
+def osnet_x0_75(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
+    # medium size (width x0.75)
+    model = OSNet(
+        num_classes,
+        blocks=[OSBlock, OSBlock, OSBlock],
+        layers=[2, 2, 2],
+        channels=[48, 192, 288, 384],
+        loss=loss,
+        **kwargs
+    )
+    if pretrained:
+        init_pretrained_weights(model, key='osnet_x0_75')
+    return model
+
+
+def osnet_x0_5(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
+    # tiny size (width x0.5)
+    model = OSNet(
+        num_classes,
+        blocks=[OSBlock, OSBlock, OSBlock],
+        layers=[2, 2, 2],
+        channels=[32, 128, 192, 256],
+        loss=loss,
+        **kwargs
+    )
+    if pretrained:
+        init_pretrained_weights(model, key='osnet_x0_5')
+    return model
+
+
+def osnet_x0_25(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
+    # very tiny size (width x0.25)
+    model = OSNet(
+        num_classes,
+        blocks=[OSBlock, OSBlock, OSBlock],
+        layers=[2, 2, 2],
+        channels=[16, 64, 96, 128],
+        loss=loss,
+        **kwargs
+    )
+    if pretrained:
+        init_pretrained_weights(model, key='osnet_x0_25')
+    return model
+
+
+def osnet_ibn_x1_0(
+    num_classes=1000, pretrained=True, loss='softmax', **kwargs
+):
+    # standard size (width x1.0) + IBN layer
+    # Ref: Pan et al. Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net. ECCV, 2018.
+    model = OSNet(
+        num_classes,
+        blocks=[OSBlock, OSBlock, OSBlock],
+        layers=[2, 2, 2],
+        channels=[64, 256, 384, 512],
+        loss=loss,
+        IN=True,
+        **kwargs
+    )
+    if pretrained:
+        init_pretrained_weights(model, key='osnet_ibn_x1_0')
+    return model
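The factory functions at the end of osnet.py build the OSNet re-ID backbones used by the StrongSORT tracker and, when pretrained=True, download the matching ImageNet checkpoint into the torch cache directory via gdown. A minimal usage sketch follows (not part of the commit; the import path and input size are assumptions based on this repo layout):

    import torch
    from trackers.strongsort.deep.models.osnet import osnet_x0_25

    # pretrained=False skips the gdown download; eval mode makes forward() return features
    model = osnet_x0_25(num_classes=1000, pretrained=False, loss='softmax')
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 256, 128))  # typical re-ID crop size
    print(feats.shape)  # (1, 512) with the default feature_dim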
trackers/strongsort/deep/models/osnet_ain.py
ADDED
@@ -0,0 +1,609 @@
1 |
+
from __future__ import division, absolute_import
|
2 |
+
import warnings
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
__all__ = [
|
8 |
+
'osnet_ain_x1_0', 'osnet_ain_x0_75', 'osnet_ain_x0_5', 'osnet_ain_x0_25'
|
9 |
+
]
|
10 |
+
|
11 |
+
pretrained_urls = {
|
12 |
+
'osnet_ain_x1_0':
|
13 |
+
'https://drive.google.com/uc?id=1-CaioD9NaqbHK_kzSMW8VE4_3KcsRjEo',
|
14 |
+
'osnet_ain_x0_75':
|
15 |
+
'https://drive.google.com/uc?id=1apy0hpsMypqstfencdH-jKIUEFOW4xoM',
|
16 |
+
'osnet_ain_x0_5':
|
17 |
+
'https://drive.google.com/uc?id=1KusKvEYyKGDTUBVRxRiz55G31wkihB6l',
|
18 |
+
'osnet_ain_x0_25':
|
19 |
+
'https://drive.google.com/uc?id=1SxQt2AvmEcgWNhaRb2xC4rP6ZwVDP0Wt'
|
20 |
+
}
|
21 |
+
|
22 |
+
|
23 |
+
##########
|
24 |
+
# Basic layers
|
25 |
+
##########
|
26 |
+
class ConvLayer(nn.Module):
|
27 |
+
"""Convolution layer (conv + bn + relu)."""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
in_channels,
|
32 |
+
out_channels,
|
33 |
+
kernel_size,
|
34 |
+
stride=1,
|
35 |
+
padding=0,
|
36 |
+
groups=1,
|
37 |
+
IN=False
|
38 |
+
):
|
39 |
+
super(ConvLayer, self).__init__()
|
40 |
+
self.conv = nn.Conv2d(
|
41 |
+
in_channels,
|
42 |
+
out_channels,
|
43 |
+
kernel_size,
|
44 |
+
stride=stride,
|
45 |
+
padding=padding,
|
46 |
+
bias=False,
|
47 |
+
groups=groups
|
48 |
+
)
|
49 |
+
if IN:
|
50 |
+
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
|
51 |
+
else:
|
52 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
53 |
+
self.relu = nn.ReLU()
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
x = self.conv(x)
|
57 |
+
x = self.bn(x)
|
58 |
+
return self.relu(x)
|
59 |
+
|
60 |
+
|
61 |
+
class Conv1x1(nn.Module):
|
62 |
+
"""1x1 convolution + bn + relu."""
|
63 |
+
|
64 |
+
def __init__(self, in_channels, out_channels, stride=1, groups=1):
|
65 |
+
super(Conv1x1, self).__init__()
|
66 |
+
self.conv = nn.Conv2d(
|
67 |
+
in_channels,
|
68 |
+
out_channels,
|
69 |
+
1,
|
70 |
+
stride=stride,
|
71 |
+
padding=0,
|
72 |
+
bias=False,
|
73 |
+
groups=groups
|
74 |
+
)
|
75 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
76 |
+
self.relu = nn.ReLU()
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
x = self.conv(x)
|
80 |
+
x = self.bn(x)
|
81 |
+
return self.relu(x)
|
82 |
+
|
83 |
+
|
84 |
+
class Conv1x1Linear(nn.Module):
|
85 |
+
"""1x1 convolution + bn (w/o non-linearity)."""
|
86 |
+
|
87 |
+
def __init__(self, in_channels, out_channels, stride=1, bn=True):
|
88 |
+
super(Conv1x1Linear, self).__init__()
|
89 |
+
self.conv = nn.Conv2d(
|
90 |
+
in_channels, out_channels, 1, stride=stride, padding=0, bias=False
|
91 |
+
)
|
92 |
+
self.bn = None
|
93 |
+
if bn:
|
94 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
x = self.conv(x)
|
98 |
+
if self.bn is not None:
|
99 |
+
x = self.bn(x)
|
100 |
+
return x
|
101 |
+
|
102 |
+
|
103 |
+
class Conv3x3(nn.Module):
|
104 |
+
"""3x3 convolution + bn + relu."""
|
105 |
+
|
106 |
+
def __init__(self, in_channels, out_channels, stride=1, groups=1):
|
107 |
+
super(Conv3x3, self).__init__()
|
108 |
+
self.conv = nn.Conv2d(
|
109 |
+
in_channels,
|
110 |
+
out_channels,
|
111 |
+
3,
|
112 |
+
stride=stride,
|
113 |
+
padding=1,
|
114 |
+
bias=False,
|
115 |
+
groups=groups
|
116 |
+
)
|
117 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
118 |
+
self.relu = nn.ReLU()
|
119 |
+
|
120 |
+
def forward(self, x):
|
121 |
+
x = self.conv(x)
|
122 |
+
x = self.bn(x)
|
123 |
+
return self.relu(x)
|
124 |
+
|
125 |
+
|
126 |
+
class LightConv3x3(nn.Module):
|
127 |
+
"""Lightweight 3x3 convolution.
|
128 |
+
|
129 |
+
1x1 (linear) + dw 3x3 (nonlinear).
|
130 |
+
"""
|
131 |
+
|
132 |
+
def __init__(self, in_channels, out_channels):
|
133 |
+
super(LightConv3x3, self).__init__()
|
134 |
+
self.conv1 = nn.Conv2d(
|
135 |
+
in_channels, out_channels, 1, stride=1, padding=0, bias=False
|
136 |
+
)
|
137 |
+
self.conv2 = nn.Conv2d(
|
138 |
+
out_channels,
|
139 |
+
out_channels,
|
140 |
+
3,
|
141 |
+
stride=1,
|
142 |
+
padding=1,
|
143 |
+
bias=False,
|
144 |
+
groups=out_channels
|
145 |
+
)
|
146 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
147 |
+
self.relu = nn.ReLU()
|
148 |
+
|
149 |
+
def forward(self, x):
|
150 |
+
x = self.conv1(x)
|
151 |
+
x = self.conv2(x)
|
152 |
+
x = self.bn(x)
|
153 |
+
return self.relu(x)
|
154 |
+
|
155 |
+
|
156 |
+
class LightConvStream(nn.Module):
|
157 |
+
"""Lightweight convolution stream."""
|
158 |
+
|
159 |
+
def __init__(self, in_channels, out_channels, depth):
|
160 |
+
super(LightConvStream, self).__init__()
|
161 |
+
assert depth >= 1, 'depth must be equal to or larger than 1, but got {}'.format(
|
162 |
+
depth
|
163 |
+
)
|
164 |
+
layers = []
|
165 |
+
layers += [LightConv3x3(in_channels, out_channels)]
|
166 |
+
for i in range(depth - 1):
|
167 |
+
layers += [LightConv3x3(out_channels, out_channels)]
|
168 |
+
self.layers = nn.Sequential(*layers)
|
169 |
+
|
170 |
+
def forward(self, x):
|
171 |
+
return self.layers(x)
|
172 |
+
|
173 |
+
|
174 |
+
##########
|
175 |
+
# Building blocks for omni-scale feature learning
|
176 |
+
##########
|
177 |
+
class ChannelGate(nn.Module):
|
178 |
+
"""A mini-network that generates channel-wise gates conditioned on input tensor."""
|
179 |
+
|
180 |
+
def __init__(
|
181 |
+
self,
|
182 |
+
in_channels,
|
183 |
+
num_gates=None,
|
184 |
+
return_gates=False,
|
185 |
+
gate_activation='sigmoid',
|
186 |
+
reduction=16,
|
187 |
+
layer_norm=False
|
188 |
+
):
|
189 |
+
super(ChannelGate, self).__init__()
|
190 |
+
if num_gates is None:
|
191 |
+
num_gates = in_channels
|
192 |
+
self.return_gates = return_gates
|
193 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
194 |
+
self.fc1 = nn.Conv2d(
|
195 |
+
in_channels,
|
196 |
+
in_channels // reduction,
|
197 |
+
kernel_size=1,
|
198 |
+
bias=True,
|
199 |
+
padding=0
|
200 |
+
)
|
201 |
+
self.norm1 = None
|
202 |
+
if layer_norm:
|
203 |
+
self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
|
204 |
+
self.relu = nn.ReLU()
|
205 |
+
self.fc2 = nn.Conv2d(
|
206 |
+
in_channels // reduction,
|
207 |
+
num_gates,
|
208 |
+
kernel_size=1,
|
209 |
+
bias=True,
|
210 |
+
padding=0
|
211 |
+
)
|
212 |
+
if gate_activation == 'sigmoid':
|
213 |
+
self.gate_activation = nn.Sigmoid()
|
214 |
+
elif gate_activation == 'relu':
|
215 |
+
self.gate_activation = nn.ReLU()
|
216 |
+
elif gate_activation == 'linear':
|
217 |
+
self.gate_activation = None
|
218 |
+
else:
|
219 |
+
raise RuntimeError(
|
220 |
+
"Unknown gate activation: {}".format(gate_activation)
|
221 |
+
)
|
222 |
+
|
223 |
+
def forward(self, x):
|
224 |
+
input = x
|
225 |
+
x = self.global_avgpool(x)
|
226 |
+
x = self.fc1(x)
|
227 |
+
if self.norm1 is not None:
|
228 |
+
x = self.norm1(x)
|
229 |
+
x = self.relu(x)
|
230 |
+
x = self.fc2(x)
|
231 |
+
if self.gate_activation is not None:
|
232 |
+
x = self.gate_activation(x)
|
233 |
+
if self.return_gates:
|
234 |
+
return x
|
235 |
+
return input * x
|
236 |
+
|
237 |
+
|
238 |
+
class OSBlock(nn.Module):
|
239 |
+
"""Omni-scale feature learning block."""
|
240 |
+
|
241 |
+
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
|
242 |
+
super(OSBlock, self).__init__()
|
243 |
+
assert T >= 1
|
244 |
+
assert out_channels >= reduction and out_channels % reduction == 0
|
245 |
+
mid_channels = out_channels // reduction
|
246 |
+
|
247 |
+
self.conv1 = Conv1x1(in_channels, mid_channels)
|
248 |
+
self.conv2 = nn.ModuleList()
|
249 |
+
for t in range(1, T + 1):
|
250 |
+
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
|
251 |
+
self.gate = ChannelGate(mid_channels)
|
252 |
+
self.conv3 = Conv1x1Linear(mid_channels, out_channels)
|
253 |
+
self.downsample = None
|
254 |
+
if in_channels != out_channels:
|
255 |
+
self.downsample = Conv1x1Linear(in_channels, out_channels)
|
256 |
+
|
257 |
+
def forward(self, x):
|
258 |
+
identity = x
|
259 |
+
x1 = self.conv1(x)
|
260 |
+
x2 = 0
|
261 |
+
for conv2_t in self.conv2:
|
262 |
+
x2_t = conv2_t(x1)
|
263 |
+
x2 = x2 + self.gate(x2_t)
|
264 |
+
x3 = self.conv3(x2)
|
265 |
+
if self.downsample is not None:
|
266 |
+
identity = self.downsample(identity)
|
267 |
+
out = x3 + identity
|
268 |
+
return F.relu(out)
|
269 |
+
|
270 |
+
|
271 |
+
class OSBlockINin(nn.Module):
|
272 |
+
"""Omni-scale feature learning block with instance normalization."""
|
273 |
+
|
274 |
+
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
|
275 |
+
super(OSBlockINin, self).__init__()
|
276 |
+
assert T >= 1
|
277 |
+
assert out_channels >= reduction and out_channels % reduction == 0
|
278 |
+
mid_channels = out_channels // reduction
|
279 |
+
|
280 |
+
self.conv1 = Conv1x1(in_channels, mid_channels)
|
281 |
+
self.conv2 = nn.ModuleList()
|
282 |
+
for t in range(1, T + 1):
|
283 |
+
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
|
284 |
+
self.gate = ChannelGate(mid_channels)
|
285 |
+
self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False)
|
286 |
+
self.downsample = None
|
287 |
+
if in_channels != out_channels:
|
288 |
+
self.downsample = Conv1x1Linear(in_channels, out_channels)
|
289 |
+
self.IN = nn.InstanceNorm2d(out_channels, affine=True)
|
290 |
+
|
291 |
+
def forward(self, x):
|
292 |
+
identity = x
|
293 |
+
x1 = self.conv1(x)
|
294 |
+
x2 = 0
|
295 |
+
for conv2_t in self.conv2:
|
296 |
+
x2_t = conv2_t(x1)
|
297 |
+
x2 = x2 + self.gate(x2_t)
|
298 |
+
x3 = self.conv3(x2)
|
299 |
+
x3 = self.IN(x3) # IN inside residual
|
300 |
+
if self.downsample is not None:
|
301 |
+
identity = self.downsample(identity)
|
302 |
+
out = x3 + identity
|
303 |
+
return F.relu(out)
|
304 |
+
|
305 |
+
|
306 |
+
##########
|
307 |
+
# Network architecture
|
308 |
+
##########
|
309 |
+
class OSNet(nn.Module):
|
310 |
+
"""Omni-Scale Network.
|
311 |
+
|
312 |
+
Reference:
|
313 |
+
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
|
314 |
+
- Zhou et al. Learning Generalisable Omni-Scale Representations
|
315 |
+
for Person Re-Identification. TPAMI, 2021.
|
316 |
+
"""
|
317 |
+
|
318 |
+
def __init__(
|
319 |
+
self,
|
320 |
+
num_classes,
|
321 |
+
blocks,
|
322 |
+
layers,
|
323 |
+
channels,
|
324 |
+
feature_dim=512,
|
325 |
+
loss='softmax',
|
326 |
+
conv1_IN=False,
|
327 |
+
**kwargs
|
328 |
+
):
|
329 |
+
super(OSNet, self).__init__()
|
330 |
+
num_blocks = len(blocks)
|
331 |
+
assert num_blocks == len(layers)
|
332 |
+
assert num_blocks == len(channels) - 1
|
333 |
+
self.loss = loss
|
334 |
+
self.feature_dim = feature_dim
|
335 |
+
|
336 |
+
# convolutional backbone
|
337 |
+
self.conv1 = ConvLayer(
|
338 |
+
3, channels[0], 7, stride=2, padding=3, IN=conv1_IN
|
339 |
+
)
|
340 |
+
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
|
341 |
+
self.conv2 = self._make_layer(
|
342 |
+
blocks[0], layers[0], channels[0], channels[1]
|
343 |
+
)
|
344 |
+
self.pool2 = nn.Sequential(
|
345 |
+
Conv1x1(channels[1], channels[1]), nn.AvgPool2d(2, stride=2)
|
346 |
+
)
|
347 |
+
self.conv3 = self._make_layer(
|
348 |
+
blocks[1], layers[1], channels[1], channels[2]
|
349 |
+
)
|
350 |
+
self.pool3 = nn.Sequential(
|
351 |
+
Conv1x1(channels[2], channels[2]), nn.AvgPool2d(2, stride=2)
|
352 |
+
)
|
353 |
+
self.conv4 = self._make_layer(
|
354 |
+
blocks[2], layers[2], channels[2], channels[3]
|
355 |
+
)
|
356 |
+
self.conv5 = Conv1x1(channels[3], channels[3])
|
357 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
358 |
+
# fully connected layer
|
359 |
+
self.fc = self._construct_fc_layer(
|
360 |
+
self.feature_dim, channels[3], dropout_p=None
|
361 |
+
)
|
362 |
+
# identity classification layer
|
363 |
+
self.classifier = nn.Linear(self.feature_dim, num_classes)
|
364 |
+
|
365 |
+
self._init_params()
|
366 |
+
|
367 |
+
def _make_layer(self, blocks, layer, in_channels, out_channels):
|
368 |
+
layers = []
|
369 |
+
layers += [blocks[0](in_channels, out_channels)]
|
370 |
+
for i in range(1, len(blocks)):
|
371 |
+
layers += [blocks[i](out_channels, out_channels)]
|
372 |
+
return nn.Sequential(*layers)
|
373 |
+
|
374 |
+
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
|
375 |
+
if fc_dims is None or fc_dims < 0:
|
376 |
+
self.feature_dim = input_dim
|
377 |
+
return None
|
378 |
+
|
379 |
+
if isinstance(fc_dims, int):
|
380 |
+
fc_dims = [fc_dims]
|
381 |
+
|
382 |
+
layers = []
|
383 |
+
for dim in fc_dims:
|
384 |
+
layers.append(nn.Linear(input_dim, dim))
|
385 |
+
layers.append(nn.BatchNorm1d(dim))
|
386 |
+
layers.append(nn.ReLU())
|
387 |
+
if dropout_p is not None:
|
388 |
+
layers.append(nn.Dropout(p=dropout_p))
|
389 |
+
input_dim = dim
|
390 |
+
|
391 |
+
self.feature_dim = fc_dims[-1]
|
392 |
+
|
393 |
+
return nn.Sequential(*layers)
|
394 |
+
|
395 |
+
def _init_params(self):
|
396 |
+
for m in self.modules():
|
397 |
+
if isinstance(m, nn.Conv2d):
|
398 |
+
nn.init.kaiming_normal_(
|
399 |
+
m.weight, mode='fan_out', nonlinearity='relu'
|
400 |
+
)
|
401 |
+
if m.bias is not None:
|
402 |
+
nn.init.constant_(m.bias, 0)
|
403 |
+
|
404 |
+
elif isinstance(m, nn.BatchNorm2d):
|
405 |
+
nn.init.constant_(m.weight, 1)
|
406 |
+
nn.init.constant_(m.bias, 0)
|
407 |
+
|
408 |
+
elif isinstance(m, nn.BatchNorm1d):
|
409 |
+
nn.init.constant_(m.weight, 1)
|
410 |
+
nn.init.constant_(m.bias, 0)
|
411 |
+
|
412 |
+
elif isinstance(m, nn.InstanceNorm2d):
|
413 |
+
nn.init.constant_(m.weight, 1)
|
414 |
+
nn.init.constant_(m.bias, 0)
|
415 |
+
|
416 |
+
elif isinstance(m, nn.Linear):
|
417 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
418 |
+
if m.bias is not None:
|
419 |
+
nn.init.constant_(m.bias, 0)
|
420 |
+
|
421 |
+
def featuremaps(self, x):
|
422 |
+
x = self.conv1(x)
|
423 |
+
x = self.maxpool(x)
|
424 |
+
x = self.conv2(x)
|
425 |
+
x = self.pool2(x)
|
426 |
+
x = self.conv3(x)
|
427 |
+
x = self.pool3(x)
|
428 |
+
x = self.conv4(x)
|
429 |
+
x = self.conv5(x)
|
430 |
+
return x
|
431 |
+
|
432 |
+
def forward(self, x, return_featuremaps=False):
|
433 |
+
x = self.featuremaps(x)
|
434 |
+
if return_featuremaps:
|
435 |
+
return x
|
436 |
+
v = self.global_avgpool(x)
|
437 |
+
v = v.view(v.size(0), -1)
|
438 |
+
if self.fc is not None:
|
439 |
+
v = self.fc(v)
|
440 |
+
if not self.training:
|
441 |
+
return v
|
442 |
+
y = self.classifier(v)
|
443 |
+
if self.loss == 'softmax':
|
444 |
+
return y
|
445 |
+
elif self.loss == 'triplet':
|
446 |
+
return y, v
|
447 |
+
else:
|
448 |
+
raise KeyError("Unsupported loss: {}".format(self.loss))
|
449 |
+
|
450 |
+
|
451 |
+
def init_pretrained_weights(model, key=''):
|
452 |
+
"""Initializes model with pretrained weights.
|
453 |
+
|
454 |
+
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
455 |
+
"""
|
456 |
+
import os
|
457 |
+
import errno
|
458 |
+
import gdown
|
459 |
+
from collections import OrderedDict
|
460 |
+
|
461 |
+
def _get_torch_home():
|
462 |
+
ENV_TORCH_HOME = 'TORCH_HOME'
|
463 |
+
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
|
464 |
+
DEFAULT_CACHE_DIR = '~/.cache'
|
465 |
+
torch_home = os.path.expanduser(
|
466 |
+
os.getenv(
|
467 |
+
ENV_TORCH_HOME,
|
468 |
+
os.path.join(
|
469 |
+
os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
|
470 |
+
)
|
471 |
+
)
|
472 |
+
)
|
473 |
+
return torch_home
|
474 |
+
|
475 |
+
torch_home = _get_torch_home()
|
476 |
+
model_dir = os.path.join(torch_home, 'checkpoints')
|
477 |
+
try:
|
478 |
+
os.makedirs(model_dir)
|
479 |
+
except OSError as e:
|
480 |
+
if e.errno == errno.EEXIST:
|
481 |
+
# Directory already exists, ignore.
|
482 |
+
pass
|
483 |
+
else:
|
484 |
+
# Unexpected OSError, re-raise.
|
485 |
+
raise
|
486 |
+
filename = key + '_imagenet.pth'
|
487 |
+
cached_file = os.path.join(model_dir, filename)
|
488 |
+
|
489 |
+
if not os.path.exists(cached_file):
|
490 |
+
gdown.download(pretrained_urls[key], cached_file, quiet=False)
|
491 |
+
|
492 |
+
state_dict = torch.load(cached_file)
|
493 |
+
model_dict = model.state_dict()
|
494 |
+
new_state_dict = OrderedDict()
|
495 |
+
matched_layers, discarded_layers = [], []
|
496 |
+
|
497 |
+
for k, v in state_dict.items():
|
498 |
+
if k.startswith('module.'):
|
499 |
+
k = k[7:] # discard module.
|
500 |
+
|
501 |
+
if k in model_dict and model_dict[k].size() == v.size():
|
502 |
+
new_state_dict[k] = v
|
503 |
+
matched_layers.append(k)
|
504 |
+
else:
|
505 |
+
discarded_layers.append(k)
|
506 |
+
|
507 |
+
model_dict.update(new_state_dict)
|
508 |
+
model.load_state_dict(model_dict)
|
509 |
+
|
510 |
+
if len(matched_layers) == 0:
|
511 |
+
warnings.warn(
|
512 |
+
'The pretrained weights from "{}" cannot be loaded, '
|
513 |
+
'please check the key names manually '
|
514 |
+
'(** ignored and continue **)'.format(cached_file)
|
515 |
+
)
|
516 |
+
else:
|
517 |
+
print(
|
518 |
+
'Successfully loaded imagenet pretrained weights from "{}"'.
|
519 |
+
format(cached_file)
|
520 |
+
)
|
521 |
+
if len(discarded_layers) > 0:
|
522 |
+
print(
|
523 |
+
'** The following layers are discarded '
|
524 |
+
'due to unmatched keys or layer size: {}'.
|
525 |
+
format(discarded_layers)
|
526 |
+
)
|
527 |
+
|
528 |
+
|
529 |
+
##########
|
530 |
+
# Instantiation
|
531 |
+
##########
|
532 |
+
def osnet_ain_x1_0(
|
533 |
+
num_classes=1000, pretrained=True, loss='softmax', **kwargs
|
534 |
+
):
|
535 |
+
model = OSNet(
|
536 |
+
num_classes,
|
537 |
+
blocks=[
|
538 |
+
[OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin],
|
539 |
+
[OSBlockINin, OSBlock]
|
540 |
+
],
|
541 |
+
layers=[2, 2, 2],
|
542 |
+
channels=[64, 256, 384, 512],
|
543 |
+
loss=loss,
|
544 |
+
conv1_IN=True,
|
545 |
+
**kwargs
|
546 |
+
)
|
547 |
+
if pretrained:
|
548 |
+
init_pretrained_weights(model, key='osnet_ain_x1_0')
|
549 |
+
return model
|
550 |
+
|
551 |
+
|
552 |
+
def osnet_ain_x0_75(
|
553 |
+
num_classes=1000, pretrained=True, loss='softmax', **kwargs
|
554 |
+
):
|
555 |
+
model = OSNet(
|
556 |
+
num_classes,
|
557 |
+
blocks=[
|
558 |
+
[OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin],
|
559 |
+
[OSBlockINin, OSBlock]
|
560 |
+
],
|
561 |
+
layers=[2, 2, 2],
|
562 |
+
channels=[48, 192, 288, 384],
|
563 |
+
loss=loss,
|
564 |
+
conv1_IN=True,
|
565 |
+
**kwargs
|
566 |
+
)
|
567 |
+
if pretrained:
|
568 |
+
init_pretrained_weights(model, key='osnet_ain_x0_75')
|
569 |
+
return model
|
570 |
+
|
571 |
+
|
572 |
+
def osnet_ain_x0_5(
|
573 |
+
num_classes=1000, pretrained=True, loss='softmax', **kwargs
|
574 |
+
):
|
575 |
+
model = OSNet(
|
576 |
+
num_classes,
|
577 |
+
blocks=[
|
578 |
+
[OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin],
|
579 |
+
[OSBlockINin, OSBlock]
|
580 |
+
],
|
581 |
+
layers=[2, 2, 2],
|
582 |
+
channels=[32, 128, 192, 256],
|
583 |
+
loss=loss,
|
584 |
+
conv1_IN=True,
|
585 |
+
**kwargs
|
586 |
+
)
|
587 |
+
if pretrained:
|
588 |
+
init_pretrained_weights(model, key='osnet_ain_x0_5')
|
589 |
+
return model
|
590 |
+
|
591 |
+
|
592 |
+
def osnet_ain_x0_25(
|
593 |
+
num_classes=1000, pretrained=True, loss='softmax', **kwargs
|
594 |
+
):
|
595 |
+
model = OSNet(
|
596 |
+
num_classes,
|
597 |
+
blocks=[
|
598 |
+
[OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin],
|
599 |
+
[OSBlockINin, OSBlock]
|
600 |
+
],
|
601 |
+
layers=[2, 2, 2],
|
602 |
+
channels=[16, 64, 96, 128],
|
603 |
+
loss=loss,
|
604 |
+
conv1_IN=True,
|
605 |
+
**kwargs
|
606 |
+
)
|
607 |
+
if pretrained:
|
608 |
+
init_pretrained_weights(model, key='osnet_ain_x0_25')
|
609 |
+
return model
|
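osnet_ain.py mirrors osnet.py but mixes OSBlockINin (instance normalisation inside the residual branch) with plain OSBlock stages and enables instance normalisation in the stem conv (conv1_IN=True); this is the AIN variant aimed at cross-domain generalisation. A hedged usage sketch, assuming the same repo layout as above:

    from trackers.strongsort.deep.models.osnet_ain import osnet_ain_x1_0

    reid_backbone = osnet_ain_x1_0(num_classes=1000, pretrained=False)
    reid_backbone.eval()  # eval mode returns the embedding instead of class logits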
trackers/strongsort/deep/models/pcb.py
ADDED
@@ -0,0 +1,314 @@
+from __future__ import division, absolute_import
+import torch.utils.model_zoo as model_zoo
+from torch import nn
+from torch.nn import functional as F
+
+__all__ = ['pcb_p6', 'pcb_p4']
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=1,
+        bias=False
+    )
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(
+            planes,
+            planes,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            bias=False
+        )
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(
+            planes, planes * self.expansion, kernel_size=1, bias=False
+        )
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class DimReduceLayer(nn.Module):
+
+    def __init__(self, in_channels, out_channels, nonlinear):
+        super(DimReduceLayer, self).__init__()
+        layers = []
+        layers.append(
+            nn.Conv2d(
+                in_channels, out_channels, 1, stride=1, padding=0, bias=False
+            )
+        )
+        layers.append(nn.BatchNorm2d(out_channels))
+
+        if nonlinear == 'relu':
+            layers.append(nn.ReLU(inplace=True))
+        elif nonlinear == 'leakyrelu':
+            layers.append(nn.LeakyReLU(0.1))
+
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class PCB(nn.Module):
+    """Part-based Convolutional Baseline.
+
+    Reference:
+        Sun et al. Beyond Part Models: Person Retrieval with Refined
+        Part Pooling (and A Strong Convolutional Baseline). ECCV 2018.
+
+    Public keys:
+        - ``pcb_p4``: PCB with 4-part strips.
+        - ``pcb_p6``: PCB with 6-part strips.
+    """
+
+    def __init__(
+        self,
+        num_classes,
+        loss,
+        block,
+        layers,
+        parts=6,
+        reduced_dim=256,
+        nonlinear='relu',
+        **kwargs
+    ):
+        self.inplanes = 64
+        super(PCB, self).__init__()
+        self.loss = loss
+        self.parts = parts
+        self.feature_dim = 512 * block.expansion
+
+        # backbone network
+        self.conv1 = nn.Conv2d(
+            3, 64, kernel_size=7, stride=2, padding=3, bias=False
+        )
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=1)
+
+        # pcb layers
+        self.parts_avgpool = nn.AdaptiveAvgPool2d((self.parts, 1))
+        self.dropout = nn.Dropout(p=0.5)
+        self.conv5 = DimReduceLayer(
+            512 * block.expansion, reduced_dim, nonlinear=nonlinear
+        )
+        self.feature_dim = reduced_dim
+        self.classifier = nn.ModuleList(
+            [
+                nn.Linear(self.feature_dim, num_classes)
+                for _ in range(self.parts)
+            ]
+        )
+
+        self._init_params()
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(
+                    self.inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False
+                ),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def _init_params(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(
+                    m.weight, mode='fan_out', nonlinearity='relu'
+                )
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm1d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def featuremaps(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        return x
+
+    def forward(self, x):
+        f = self.featuremaps(x)
+        v_g = self.parts_avgpool(f)
+
+        if not self.training:
+            v_g = F.normalize(v_g, p=2, dim=1)
+            return v_g.view(v_g.size(0), -1)
+
+        v_g = self.dropout(v_g)
+        v_h = self.conv5(v_g)
+
+        y = []
+        for i in range(self.parts):
+            v_h_i = v_h[:, :, i, :]
+            v_h_i = v_h_i.view(v_h_i.size(0), -1)
+            y_i = self.classifier[i](v_h_i)
+            y.append(y_i)
+
+        if self.loss == 'softmax':
+            return y
+        elif self.loss == 'triplet':
+            v_g = F.normalize(v_g, p=2, dim=1)
+            return y, v_g.view(v_g.size(0), -1)
+        else:
+            raise KeyError('Unsupported loss: {}'.format(self.loss))
+
+
+def init_pretrained_weights(model, model_url):
+    """Initializes model with pretrained weights.
+
+    Layers that don't match with pretrained layers in name or size are kept unchanged.
+    """
+    pretrain_dict = model_zoo.load_url(model_url)
+    model_dict = model.state_dict()
+    pretrain_dict = {
+        k: v
+        for k, v in pretrain_dict.items()
+        if k in model_dict and model_dict[k].size() == v.size()
+    }
+    model_dict.update(pretrain_dict)
+    model.load_state_dict(model_dict)
+
+
+def pcb_p6(num_classes, loss='softmax', pretrained=True, **kwargs):
+    model = PCB(
+        num_classes=num_classes,
+        loss=loss,
+        block=Bottleneck,
+        layers=[3, 4, 6, 3],
+        last_stride=1,
+        parts=6,
+        reduced_dim=256,
+        nonlinear='relu',
+        **kwargs
+    )
+    if pretrained:
+        init_pretrained_weights(model, model_urls['resnet50'])
+    return model
+
+
+def pcb_p4(num_classes, loss='softmax', pretrained=True, **kwargs):
+    model = PCB(
+        num_classes=num_classes,
+        loss=loss,
+        block=Bottleneck,
+        layers=[3, 4, 6, 3],
+        last_stride=1,
+        parts=4,
+        reduced_dim=256,
+        nonlinear='relu',
+        **kwargs
+    )
+    if pretrained:
+        init_pretrained_weights(model, model_urls['resnet50'])
+    return model
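In training mode PCB.forward returns one logit vector per horizontal part (a list of `parts` tensors, one classifier per stripe), while in eval mode it returns the L2-normalised pooled part features flattened into a single descriptor. An illustrative sketch (not in the commit; import path, input size and class count are assumptions):

    import torch
    from trackers.strongsort.deep.models.pcb import pcb_p6

    pcb = pcb_p6(num_classes=751, pretrained=False)
    pcb.eval()
    with torch.no_grad():
        desc = pcb(torch.randn(1, 3, 384, 128))
    print(desc.shape)  # (1, 6 * 2048) = (1, 12288) for pcb_p6 on the ResNet50 backbone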
trackers/strongsort/deep/models/resnet.py
ADDED
@@ -0,0 +1,530 @@
1 |
+
"""
|
2 |
+
Code source: https://github.com/pytorch/vision
|
3 |
+
"""
|
4 |
+
from __future__ import division, absolute_import
|
5 |
+
import torch.utils.model_zoo as model_zoo
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
__all__ = [
|
9 |
+
'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
|
10 |
+
'resnext50_32x4d', 'resnext101_32x8d', 'resnet50_fc512'
|
11 |
+
]
|
12 |
+
|
13 |
+
model_urls = {
|
14 |
+
'resnet18':
|
15 |
+
'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
16 |
+
'resnet34':
|
17 |
+
'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
18 |
+
'resnet50':
|
19 |
+
'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
20 |
+
'resnet101':
|
21 |
+
'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
22 |
+
'resnet152':
|
23 |
+
'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
24 |
+
'resnext50_32x4d':
|
25 |
+
'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
26 |
+
'resnext101_32x8d':
|
27 |
+
'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
28 |
+
}
|
29 |
+
|
30 |
+
|
31 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
32 |
+
"""3x3 convolution with padding"""
|
33 |
+
return nn.Conv2d(
|
34 |
+
in_planes,
|
35 |
+
out_planes,
|
36 |
+
kernel_size=3,
|
37 |
+
stride=stride,
|
38 |
+
padding=dilation,
|
39 |
+
groups=groups,
|
40 |
+
bias=False,
|
41 |
+
dilation=dilation
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
46 |
+
"""1x1 convolution"""
|
47 |
+
return nn.Conv2d(
|
48 |
+
in_planes, out_planes, kernel_size=1, stride=stride, bias=False
|
49 |
+
)
|
50 |
+
|
51 |
+
|
52 |
+
class BasicBlock(nn.Module):
|
53 |
+
expansion = 1
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
inplanes,
|
58 |
+
planes,
|
59 |
+
stride=1,
|
60 |
+
downsample=None,
|
61 |
+
groups=1,
|
62 |
+
base_width=64,
|
63 |
+
dilation=1,
|
64 |
+
norm_layer=None
|
65 |
+
):
|
66 |
+
super(BasicBlock, self).__init__()
|
67 |
+
if norm_layer is None:
|
68 |
+
norm_layer = nn.BatchNorm2d
|
69 |
+
if groups != 1 or base_width != 64:
|
70 |
+
raise ValueError(
|
71 |
+
'BasicBlock only supports groups=1 and base_width=64'
|
72 |
+
)
|
73 |
+
if dilation > 1:
|
74 |
+
raise NotImplementedError(
|
75 |
+
"Dilation > 1 not supported in BasicBlock"
|
76 |
+
)
|
77 |
+
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
78 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
79 |
+
self.bn1 = norm_layer(planes)
|
80 |
+
self.relu = nn.ReLU(inplace=True)
|
81 |
+
self.conv2 = conv3x3(planes, planes)
|
82 |
+
self.bn2 = norm_layer(planes)
|
83 |
+
self.downsample = downsample
|
84 |
+
self.stride = stride
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
identity = x
|
88 |
+
|
89 |
+
out = self.conv1(x)
|
90 |
+
out = self.bn1(out)
|
91 |
+
out = self.relu(out)
|
92 |
+
|
93 |
+
out = self.conv2(out)
|
94 |
+
out = self.bn2(out)
|
95 |
+
|
96 |
+
if self.downsample is not None:
|
97 |
+
identity = self.downsample(x)
|
98 |
+
|
99 |
+
out += identity
|
100 |
+
out = self.relu(out)
|
101 |
+
|
102 |
+
return out
|
103 |
+
|
104 |
+
|
105 |
+
class Bottleneck(nn.Module):
|
106 |
+
expansion = 4
|
107 |
+
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
inplanes,
|
111 |
+
planes,
|
112 |
+
stride=1,
|
113 |
+
downsample=None,
|
114 |
+
groups=1,
|
115 |
+
base_width=64,
|
116 |
+
dilation=1,
|
117 |
+
norm_layer=None
|
118 |
+
):
|
119 |
+
super(Bottleneck, self).__init__()
|
120 |
+
if norm_layer is None:
|
121 |
+
norm_layer = nn.BatchNorm2d
|
122 |
+
width = int(planes * (base_width/64.)) * groups
|
123 |
+
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
124 |
+
self.conv1 = conv1x1(inplanes, width)
|
125 |
+
self.bn1 = norm_layer(width)
|
126 |
+
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
127 |
+
self.bn2 = norm_layer(width)
|
128 |
+
self.conv3 = conv1x1(width, planes * self.expansion)
|
129 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
130 |
+
self.relu = nn.ReLU(inplace=True)
|
131 |
+
self.downsample = downsample
|
132 |
+
self.stride = stride
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
identity = x
|
136 |
+
|
137 |
+
out = self.conv1(x)
|
138 |
+
out = self.bn1(out)
|
139 |
+
out = self.relu(out)
|
140 |
+
|
141 |
+
out = self.conv2(out)
|
142 |
+
out = self.bn2(out)
|
143 |
+
out = self.relu(out)
|
144 |
+
|
145 |
+
out = self.conv3(out)
|
146 |
+
out = self.bn3(out)
|
147 |
+
|
148 |
+
if self.downsample is not None:
|
149 |
+
identity = self.downsample(x)
|
150 |
+
|
151 |
+
out += identity
|
152 |
+
out = self.relu(out)
|
153 |
+
|
154 |
+
return out
|
155 |
+
|
156 |
+
|
157 |
+
class ResNet(nn.Module):
|
158 |
+
"""Residual network.
|
159 |
+
|
160 |
+
Reference:
|
161 |
+
- He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
|
162 |
+
- Xie et al. Aggregated Residual Transformations for Deep Neural Networks. CVPR 2017.
|
163 |
+
|
164 |
+
Public keys:
|
165 |
+
- ``resnet18``: ResNet18.
|
166 |
+
- ``resnet34``: ResNet34.
|
167 |
+
- ``resnet50``: ResNet50.
|
168 |
+
- ``resnet101``: ResNet101.
|
169 |
+
- ``resnet152``: ResNet152.
|
170 |
+
- ``resnext50_32x4d``: ResNeXt50.
|
171 |
+
- ``resnext101_32x8d``: ResNeXt101.
|
172 |
+
- ``resnet50_fc512``: ResNet50 + FC.
|
173 |
+
"""
|
174 |
+
|
175 |
+
def __init__(
|
176 |
+
self,
|
177 |
+
num_classes,
|
178 |
+
loss,
|
179 |
+
block,
|
180 |
+
layers,
|
181 |
+
zero_init_residual=False,
|
182 |
+
groups=1,
|
183 |
+
width_per_group=64,
|
184 |
+
replace_stride_with_dilation=None,
|
185 |
+
norm_layer=None,
|
186 |
+
last_stride=2,
|
187 |
+
fc_dims=None,
|
188 |
+
dropout_p=None,
|
189 |
+
**kwargs
|
190 |
+
):
|
191 |
+
super(ResNet, self).__init__()
|
192 |
+
if norm_layer is None:
|
193 |
+
norm_layer = nn.BatchNorm2d
|
194 |
+
self._norm_layer = norm_layer
|
195 |
+
self.loss = loss
|
196 |
+
self.feature_dim = 512 * block.expansion
|
197 |
+
self.inplanes = 64
|
198 |
+
self.dilation = 1
|
199 |
+
if replace_stride_with_dilation is None:
|
200 |
+
# each element in the tuple indicates if we should replace
|
201 |
+
# the 2x2 stride with a dilated convolution instead
|
202 |
+
replace_stride_with_dilation = [False, False, False]
|
203 |
+
if len(replace_stride_with_dilation) != 3:
|
204 |
+
raise ValueError(
|
205 |
+
"replace_stride_with_dilation should be None "
|
206 |
+
"or a 3-element tuple, got {}".
|
207 |
+
format(replace_stride_with_dilation)
|
208 |
+
)
|
209 |
+
self.groups = groups
|
210 |
+
self.base_width = width_per_group
|
211 |
+
self.conv1 = nn.Conv2d(
|
212 |
+
3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
|
213 |
+
)
|
214 |
+
self.bn1 = norm_layer(self.inplanes)
|
215 |
+
self.relu = nn.ReLU(inplace=True)
|
216 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
217 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
218 |
+
self.layer2 = self._make_layer(
|
219 |
+
block,
|
220 |
+
128,
|
221 |
+
layers[1],
|
222 |
+
stride=2,
|
223 |
+
dilate=replace_stride_with_dilation[0]
|
224 |
+
)
|
225 |
+
self.layer3 = self._make_layer(
|
226 |
+
block,
|
227 |
+
256,
|
228 |
+
layers[2],
|
229 |
+
stride=2,
|
230 |
+
dilate=replace_stride_with_dilation[1]
|
231 |
+
)
|
232 |
+
self.layer4 = self._make_layer(
|
233 |
+
block,
|
234 |
+
512,
|
235 |
+
layers[3],
|
236 |
+
stride=last_stride,
|
237 |
+
dilate=replace_stride_with_dilation[2]
|
238 |
+
)
|
239 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
240 |
+
self.fc = self._construct_fc_layer(
|
241 |
+
fc_dims, 512 * block.expansion, dropout_p
|
242 |
+
)
|
243 |
+
self.classifier = nn.Linear(self.feature_dim, num_classes)
|
244 |
+
|
245 |
+
self._init_params()
|
246 |
+
|
247 |
+
# Zero-initialize the last BN in each residual branch,
|
248 |
+
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
249 |
+
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
250 |
+
if zero_init_residual:
|
251 |
+
for m in self.modules():
|
252 |
+
if isinstance(m, Bottleneck):
|
253 |
+
nn.init.constant_(m.bn3.weight, 0)
|
254 |
+
elif isinstance(m, BasicBlock):
|
255 |
+
nn.init.constant_(m.bn2.weight, 0)
|
256 |
+
|
257 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
258 |
+
norm_layer = self._norm_layer
|
259 |
+
downsample = None
|
260 |
+
previous_dilation = self.dilation
|
261 |
+
if dilate:
|
262 |
+
self.dilation *= stride
|
263 |
+
stride = 1
|
264 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
265 |
+
downsample = nn.Sequential(
|
266 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
267 |
+
norm_layer(planes * block.expansion),
|
268 |
+
)
|
269 |
+
|
270 |
+
layers = []
|
271 |
+
layers.append(
|
272 |
+
block(
|
273 |
+
self.inplanes, planes, stride, downsample, self.groups,
|
274 |
+
self.base_width, previous_dilation, norm_layer
|
275 |
+
)
|
276 |
+
)
|
277 |
+
self.inplanes = planes * block.expansion
|
278 |
+
for _ in range(1, blocks):
|
279 |
+
layers.append(
|
280 |
+
block(
|
281 |
+
self.inplanes,
|
282 |
+
planes,
|
283 |
+
groups=self.groups,
|
284 |
+
base_width=self.base_width,
|
285 |
+
dilation=self.dilation,
|
286 |
+
norm_layer=norm_layer
|
287 |
+
)
|
288 |
+
)
|
289 |
+
|
290 |
+
return nn.Sequential(*layers)
|
291 |
+
|
292 |
+
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
|
293 |
+
"""Constructs fully connected layer
|
294 |
+
|
295 |
+
Args:
|
296 |
+
fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
|
297 |
+
input_dim (int): input dimension
|
298 |
+
dropout_p (float): dropout probability, if None, dropout is unused
|
299 |
+
"""
|
300 |
+
if fc_dims is None:
|
301 |
+
self.feature_dim = input_dim
|
302 |
+
return None
|
303 |
+
|
304 |
+
assert isinstance(
|
305 |
+
fc_dims, (list, tuple)
|
306 |
+
), 'fc_dims must be either list or tuple, but got {}'.format(
|
307 |
+
type(fc_dims)
|
308 |
+
)
|
309 |
+
|
310 |
+
layers = []
|
311 |
+
for dim in fc_dims:
|
312 |
+
layers.append(nn.Linear(input_dim, dim))
|
313 |
+
layers.append(nn.BatchNorm1d(dim))
|
314 |
+
layers.append(nn.ReLU(inplace=True))
|
315 |
+
if dropout_p is not None:
|
316 |
+
layers.append(nn.Dropout(p=dropout_p))
|
317 |
+
input_dim = dim
|
318 |
+
|
319 |
+
self.feature_dim = fc_dims[-1]
|
320 |
+
|
321 |
+
return nn.Sequential(*layers)
|
322 |
+
|
323 |
+
def _init_params(self):
|
324 |
+
for m in self.modules():
|
325 |
+
if isinstance(m, nn.Conv2d):
|
326 |
+
nn.init.kaiming_normal_(
|
327 |
+
m.weight, mode='fan_out', nonlinearity='relu'
|
328 |
+
)
|
329 |
+
if m.bias is not None:
|
330 |
+
nn.init.constant_(m.bias, 0)
|
331 |
+
elif isinstance(m, nn.BatchNorm2d):
|
332 |
+
nn.init.constant_(m.weight, 1)
|
333 |
+
nn.init.constant_(m.bias, 0)
|
334 |
+
elif isinstance(m, nn.BatchNorm1d):
|
335 |
+
nn.init.constant_(m.weight, 1)
|
336 |
+
nn.init.constant_(m.bias, 0)
|
337 |
+
elif isinstance(m, nn.Linear):
|
338 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
339 |
+
if m.bias is not None:
|
340 |
+
nn.init.constant_(m.bias, 0)
|
341 |
+
|
342 |
+
def featuremaps(self, x):
|
343 |
+
x = self.conv1(x)
|
344 |
+
x = self.bn1(x)
|
345 |
+
x = self.relu(x)
|
346 |
+
x = self.maxpool(x)
|
347 |
+
x = self.layer1(x)
|
348 |
+
x = self.layer2(x)
|
349 |
+
x = self.layer3(x)
|
350 |
+
x = self.layer4(x)
|
351 |
+
return x
|
352 |
+
|
353 |
+
def forward(self, x):
|
354 |
+
f = self.featuremaps(x)
|
355 |
+
v = self.global_avgpool(f)
|
356 |
+
v = v.view(v.size(0), -1)
|
357 |
+
|
358 |
+
if self.fc is not None:
|
359 |
+
v = self.fc(v)
|
360 |
+
|
361 |
+
if not self.training:
|
362 |
+
return v
|
363 |
+
|
364 |
+
y = self.classifier(v)
|
365 |
+
|
366 |
+
if self.loss == 'softmax':
|
367 |
+
return y
|
368 |
+
elif self.loss == 'triplet':
|
369 |
+
return y, v
|
370 |
+
else:
|
371 |
+
raise KeyError("Unsupported loss: {}".format(self.loss))
|
372 |
+
|
373 |
+
|
374 |
+
def init_pretrained_weights(model, model_url):
|
375 |
+
"""Initializes model with pretrained weights.
|
376 |
+
|
377 |
+
Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


"""ResNet"""


def resnet18(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=BasicBlock,
        layers=[2, 2, 2, 2],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['resnet18'])
    return model


def resnet34(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=BasicBlock,
        layers=[3, 4, 6, 3],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['resnet34'])
    return model


def resnet50(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['resnet50'])
    return model


def resnet101(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['resnet101'])
    return model


def resnet152(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 8, 36, 3],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['resnet152'])
    return model


"""ResNeXt"""


def resnext50_32x4d(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        groups=32,
        width_per_group=4,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['resnext50_32x4d'])
    return model


def resnext101_32x8d(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        groups=32,
        width_per_group=8,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['resnext101_32x8d'])
    return model


"""
ResNet + FC
"""


def resnet50_fc512(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        last_stride=1,
        fc_dims=[512],
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['resnet50'])
    return model
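Note (not part of the committed file): a minimal usage sketch of one of the factory functions above. It assumes the module is importable as trackers.strongsort.deep.models.resnet from the repo root and that torch is installed; in eval mode forward() returns the pooled embedding rather than class logits.

# Hypothetical usage sketch, not part of the commit.
import torch

from trackers.strongsort.deep.models.resnet import resnet50_fc512

# pretrained=False keeps the sketch offline (no ImageNet weight download).
model = resnet50_fc512(num_classes=751, loss='softmax', pretrained=False)
model.eval()  # eval mode: forward() returns the feature vector, not logits

with torch.no_grad():
    feats = model(torch.randn(2, 3, 256, 128))  # a common ReID input size

print(feats.shape)  # torch.Size([2, 512]) because fc_dims=[512]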
trackers/strongsort/deep/models/resnet_ibn_a.py
ADDED
@@ -0,0 +1,289 @@
"""
Credit to https://github.com/XingangPan/IBN-Net.
"""
from __future__ import division, absolute_import
import math
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

__all__ = ['resnet50_ibn_a']

model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False
    )


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class IBN(nn.Module):

    def __init__(self, planes):
        super(IBN, self).__init__()
        half1 = int(planes / 2)
        self.half = half1
        half2 = planes - half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = nn.BatchNorm2d(half2)

    def forward(self, x):
        split = torch.split(x, self.half, 1)
        out1 = self.IN(split[0].contiguous())
        out2 = self.BN(split[1].contiguous())
        out = torch.cat((out1, out2), 1)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        if ibn:
            self.bn1 = IBN(planes)
        else:
            self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """Residual network + IBN layer.

    Reference:
        - He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
        - Pan et al. Two at Once: Enhancing Learning and Generalization
          Capacities via IBN-Net. ECCV 2018.
    """

    def __init__(
        self,
        block,
        layers,
        num_classes=1000,
        loss='softmax',
        fc_dims=None,
        dropout_p=None,
        **kwargs
    ):
        scale = 64
        self.inplanes = scale
        super(ResNet, self).__init__()
        self.loss = loss
        self.feature_dim = scale * 8 * block.expansion

        self.conv1 = nn.Conv2d(
            3, scale, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(scale)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, scale, layers[0])
        self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = self._construct_fc_layer(
            fc_dims, scale * 8 * block.expansion, dropout_p
        )
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.InstanceNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        ibn = True
        if planes == 512:
            ibn = False
        layers.append(block(self.inplanes, planes, ibn, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, ibn))

        return nn.Sequential(*layers)

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """Constructs fully connected layer

        Args:
            fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
            input_dim (int): input dimension
            dropout_p (float): dropout probability, if None, dropout is unused
        """
        if fc_dims is None:
            self.feature_dim = input_dim
            return None

        assert isinstance(
            fc_dims, (list, tuple)
        ), 'fc_dims must be either list or tuple, but got {}'.format(
            type(fc_dims)
        )

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def featuremaps(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward(self, x):
        f = self.featuremaps(x)
        v = self.avgpool(f)
        v = v.view(v.size(0), -1)
        if self.fc is not None:
            v = self.fc(v)
        if not self.training:
            return v
        y = self.classifier(v)
        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def resnet50_ibn_a(num_classes, loss='softmax', pretrained=False, **kwargs):
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], num_classes=num_classes, loss=loss, **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['resnet50'])
    return model
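Note (not part of the committed file): a hedged sketch of how the resnet50_ibn_a factory above might be used as a re-identification feature extractor; the import path and input size are assumptions.

# Hypothetical usage sketch, not part of the commit.
import torch

from trackers.strongsort.deep.models.resnet_ibn_a import resnet50_ibn_a

model = resnet50_ibn_a(num_classes=1000, pretrained=False)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 256, 128))

# fc_dims defaults to None, so the raw pooled feature (64 * 8 * 4 = 2048 dims) is returned.
print(feats.shape)  # torch.Size([1, 2048])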
trackers/strongsort/deep/models/resnet_ibn_b.py
ADDED
@@ -0,0 +1,274 @@
"""
Credit to https://github.com/XingangPan/IBN-Net.
"""
from __future__ import division, absolute_import
import math
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

__all__ = ['resnet50_ibn_b']

model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False
    )


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, IN=False):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.IN = None
        if IN:
            self.IN = nn.InstanceNorm2d(planes * 4, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        if self.IN is not None:
            out = self.IN(out)
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """Residual network + IBN layer.

    Reference:
        - He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
        - Pan et al. Two at Once: Enhancing Learning and Generalization
          Capacities via IBN-Net. ECCV 2018.
    """

    def __init__(
        self,
        block,
        layers,
        num_classes=1000,
        loss='softmax',
        fc_dims=None,
        dropout_p=None,
        **kwargs
    ):
        scale = 64
        self.inplanes = scale
        super(ResNet, self).__init__()
        self.loss = loss
        self.feature_dim = scale * 8 * block.expansion

        self.conv1 = nn.Conv2d(
            3, scale, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.InstanceNorm2d(scale, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(
            block, scale, layers[0], stride=1, IN=True
        )
        self.layer2 = self._make_layer(
            block, scale * 2, layers[1], stride=2, IN=True
        )
        self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = self._construct_fc_layer(
            fc_dims, scale * 8 * block.expansion, dropout_p
        )
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.InstanceNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, IN=False):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks - 1):
            layers.append(block(self.inplanes, planes))
        layers.append(block(self.inplanes, planes, IN=IN))

        return nn.Sequential(*layers)

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """Constructs fully connected layer

        Args:
            fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
            input_dim (int): input dimension
            dropout_p (float): dropout probability, if None, dropout is unused
        """
        if fc_dims is None:
            self.feature_dim = input_dim
            return None

        assert isinstance(
            fc_dims, (list, tuple)
        ), 'fc_dims must be either list or tuple, but got {}'.format(
            type(fc_dims)
        )

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def featuremaps(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward(self, x):
        f = self.featuremaps(x)
        v = self.avgpool(f)
        v = v.view(v.size(0), -1)
        if self.fc is not None:
            v = self.fc(v)
        if not self.training:
            return v
        y = self.classifier(v)
        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def resnet50_ibn_b(num_classes, loss='softmax', pretrained=False, **kwargs):
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], num_classes=num_classes, loss=loss, **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['resnet50'])
    return model
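Note (not part of the committed file): a hedged sketch contrasting the two IBN variants. resnet_ibn_a inserts the IBN split into the bottleneck's first normalization, while resnet_ibn_b applies instance normalization after the residual addition in the early stages; both expose the same 2048-dim embedding in eval mode. Import paths are assumptions.

# Hypothetical comparison sketch, not part of the commit.
import torch

from trackers.strongsort.deep.models.resnet_ibn_a import resnet50_ibn_a
from trackers.strongsort.deep.models.resnet_ibn_b import resnet50_ibn_b

x = torch.randn(1, 3, 256, 128)
for build in (resnet50_ibn_a, resnet50_ibn_b):
    m = build(num_classes=1000, pretrained=False).eval()
    with torch.no_grad():
        print(build.__name__, m(x).shape)  # both print torch.Size([1, 2048])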
trackers/strongsort/deep/models/resnetmid.py
ADDED
@@ -0,0 +1,307 @@
1 |
+
from __future__ import division, absolute_import
|
2 |
+
import torch
|
3 |
+
import torch.utils.model_zoo as model_zoo
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
__all__ = ['resnet50mid']
|
7 |
+
|
8 |
+
model_urls = {
|
9 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
10 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
11 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
12 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
13 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
14 |
+
}
|
15 |
+
|
16 |
+
|
17 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
18 |
+
"""3x3 convolution with padding"""
|
19 |
+
return nn.Conv2d(
|
20 |
+
in_planes,
|
21 |
+
out_planes,
|
22 |
+
kernel_size=3,
|
23 |
+
stride=stride,
|
24 |
+
padding=1,
|
25 |
+
bias=False
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
class BasicBlock(nn.Module):
|
30 |
+
expansion = 1
|
31 |
+
|
32 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
33 |
+
super(BasicBlock, self).__init__()
|
34 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
35 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
36 |
+
self.relu = nn.ReLU(inplace=True)
|
37 |
+
self.conv2 = conv3x3(planes, planes)
|
38 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
39 |
+
self.downsample = downsample
|
40 |
+
self.stride = stride
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
residual = x
|
44 |
+
|
45 |
+
out = self.conv1(x)
|
46 |
+
out = self.bn1(out)
|
47 |
+
out = self.relu(out)
|
48 |
+
|
49 |
+
out = self.conv2(out)
|
50 |
+
out = self.bn2(out)
|
51 |
+
|
52 |
+
if self.downsample is not None:
|
53 |
+
residual = self.downsample(x)
|
54 |
+
|
55 |
+
out += residual
|
56 |
+
out = self.relu(out)
|
57 |
+
|
58 |
+
return out
|
59 |
+
|
60 |
+
|
61 |
+
class Bottleneck(nn.Module):
|
62 |
+
expansion = 4
|
63 |
+
|
64 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
65 |
+
super(Bottleneck, self).__init__()
|
66 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
67 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
68 |
+
self.conv2 = nn.Conv2d(
|
69 |
+
planes,
|
70 |
+
planes,
|
71 |
+
kernel_size=3,
|
72 |
+
stride=stride,
|
73 |
+
padding=1,
|
74 |
+
bias=False
|
75 |
+
)
|
76 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
77 |
+
self.conv3 = nn.Conv2d(
|
78 |
+
planes, planes * self.expansion, kernel_size=1, bias=False
|
79 |
+
)
|
80 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
81 |
+
self.relu = nn.ReLU(inplace=True)
|
82 |
+
self.downsample = downsample
|
83 |
+
self.stride = stride
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
residual = x
|
87 |
+
|
88 |
+
out = self.conv1(x)
|
89 |
+
out = self.bn1(out)
|
90 |
+
out = self.relu(out)
|
91 |
+
|
92 |
+
out = self.conv2(out)
|
93 |
+
out = self.bn2(out)
|
94 |
+
out = self.relu(out)
|
95 |
+
|
96 |
+
out = self.conv3(out)
|
97 |
+
out = self.bn3(out)
|
98 |
+
|
99 |
+
if self.downsample is not None:
|
100 |
+
residual = self.downsample(x)
|
101 |
+
|
102 |
+
out += residual
|
103 |
+
out = self.relu(out)
|
104 |
+
|
105 |
+
return out
|
106 |
+
|
107 |
+
|
108 |
+
class ResNetMid(nn.Module):
|
109 |
+
"""Residual network + mid-level features.
|
110 |
+
|
111 |
+
Reference:
|
112 |
+
Yu et al. The Devil is in the Middle: Exploiting Mid-level Representations for
|
113 |
+
Cross-Domain Instance Matching. arXiv:1711.08106.
|
114 |
+
|
115 |
+
Public keys:
|
116 |
+
- ``resnet50mid``: ResNet50 + mid-level feature fusion.
|
117 |
+
"""
|
118 |
+
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
num_classes,
|
122 |
+
loss,
|
123 |
+
block,
|
124 |
+
layers,
|
125 |
+
last_stride=2,
|
126 |
+
fc_dims=None,
|
127 |
+
**kwargs
|
128 |
+
):
|
129 |
+
self.inplanes = 64
|
130 |
+
super(ResNetMid, self).__init__()
|
131 |
+
self.loss = loss
|
132 |
+
self.feature_dim = 512 * block.expansion
|
133 |
+
|
134 |
+
# backbone network
|
135 |
+
self.conv1 = nn.Conv2d(
|
136 |
+
3, 64, kernel_size=7, stride=2, padding=3, bias=False
|
137 |
+
)
|
138 |
+
self.bn1 = nn.BatchNorm2d(64)
|
139 |
+
self.relu = nn.ReLU(inplace=True)
|
140 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
141 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
142 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
143 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
144 |
+
self.layer4 = self._make_layer(
|
145 |
+
block, 512, layers[3], stride=last_stride
|
146 |
+
)
|
147 |
+
|
148 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
149 |
+
assert fc_dims is not None
|
150 |
+
self.fc_fusion = self._construct_fc_layer(
|
151 |
+
fc_dims, 512 * block.expansion * 2
|
152 |
+
)
|
153 |
+
self.feature_dim += 512 * block.expansion
|
154 |
+
self.classifier = nn.Linear(self.feature_dim, num_classes)
|
155 |
+
|
156 |
+
self._init_params()
|
157 |
+
|
158 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
159 |
+
downsample = None
|
160 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
161 |
+
downsample = nn.Sequential(
|
162 |
+
nn.Conv2d(
|
163 |
+
self.inplanes,
|
164 |
+
planes * block.expansion,
|
165 |
+
kernel_size=1,
|
166 |
+
stride=stride,
|
167 |
+
bias=False
|
168 |
+
),
|
169 |
+
nn.BatchNorm2d(planes * block.expansion),
|
170 |
+
)
|
171 |
+
|
172 |
+
layers = []
|
173 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
174 |
+
self.inplanes = planes * block.expansion
|
175 |
+
for i in range(1, blocks):
|
176 |
+
layers.append(block(self.inplanes, planes))
|
177 |
+
|
178 |
+
return nn.Sequential(*layers)
|
179 |
+
|
180 |
+
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
|
181 |
+
"""Constructs fully connected layer
|
182 |
+
|
183 |
+
Args:
|
184 |
+
fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
|
185 |
+
input_dim (int): input dimension
|
186 |
+
dropout_p (float): dropout probability, if None, dropout is unused
|
187 |
+
"""
|
188 |
+
if fc_dims is None:
|
189 |
+
self.feature_dim = input_dim
|
190 |
+
return None
|
191 |
+
|
192 |
+
assert isinstance(
|
193 |
+
fc_dims, (list, tuple)
|
194 |
+
), 'fc_dims must be either list or tuple, but got {}'.format(
|
195 |
+
type(fc_dims)
|
196 |
+
)
|
197 |
+
|
198 |
+
layers = []
|
199 |
+
for dim in fc_dims:
|
200 |
+
layers.append(nn.Linear(input_dim, dim))
|
201 |
+
layers.append(nn.BatchNorm1d(dim))
|
202 |
+
layers.append(nn.ReLU(inplace=True))
|
203 |
+
if dropout_p is not None:
|
204 |
+
layers.append(nn.Dropout(p=dropout_p))
|
205 |
+
input_dim = dim
|
206 |
+
|
207 |
+
self.feature_dim = fc_dims[-1]
|
208 |
+
|
209 |
+
return nn.Sequential(*layers)
|
210 |
+
|
211 |
+
def _init_params(self):
|
212 |
+
for m in self.modules():
|
213 |
+
if isinstance(m, nn.Conv2d):
|
214 |
+
nn.init.kaiming_normal_(
|
215 |
+
m.weight, mode='fan_out', nonlinearity='relu'
|
216 |
+
)
|
217 |
+
if m.bias is not None:
|
218 |
+
nn.init.constant_(m.bias, 0)
|
219 |
+
elif isinstance(m, nn.BatchNorm2d):
|
220 |
+
nn.init.constant_(m.weight, 1)
|
221 |
+
nn.init.constant_(m.bias, 0)
|
222 |
+
elif isinstance(m, nn.BatchNorm1d):
|
223 |
+
nn.init.constant_(m.weight, 1)
|
224 |
+
nn.init.constant_(m.bias, 0)
|
225 |
+
elif isinstance(m, nn.Linear):
|
226 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
227 |
+
if m.bias is not None:
|
228 |
+
nn.init.constant_(m.bias, 0)
|
229 |
+
|
230 |
+
def featuremaps(self, x):
|
231 |
+
x = self.conv1(x)
|
232 |
+
x = self.bn1(x)
|
233 |
+
x = self.relu(x)
|
234 |
+
x = self.maxpool(x)
|
235 |
+
x = self.layer1(x)
|
236 |
+
x = self.layer2(x)
|
237 |
+
x = self.layer3(x)
|
238 |
+
x4a = self.layer4[0](x)
|
239 |
+
x4b = self.layer4[1](x4a)
|
240 |
+
x4c = self.layer4[2](x4b)
|
241 |
+
return x4a, x4b, x4c
|
242 |
+
|
243 |
+
def forward(self, x):
|
244 |
+
x4a, x4b, x4c = self.featuremaps(x)
|
245 |
+
|
246 |
+
v4a = self.global_avgpool(x4a)
|
247 |
+
v4b = self.global_avgpool(x4b)
|
248 |
+
v4c = self.global_avgpool(x4c)
|
249 |
+
v4ab = torch.cat([v4a, v4b], 1)
|
250 |
+
v4ab = v4ab.view(v4ab.size(0), -1)
|
251 |
+
v4ab = self.fc_fusion(v4ab)
|
252 |
+
v4c = v4c.view(v4c.size(0), -1)
|
253 |
+
v = torch.cat([v4ab, v4c], 1)
|
254 |
+
|
255 |
+
if not self.training:
|
256 |
+
return v
|
257 |
+
|
258 |
+
y = self.classifier(v)
|
259 |
+
|
260 |
+
if self.loss == 'softmax':
|
261 |
+
return y
|
262 |
+
elif self.loss == 'triplet':
|
263 |
+
return y, v
|
264 |
+
else:
|
265 |
+
raise KeyError('Unsupported loss: {}'.format(self.loss))
|
266 |
+
|
267 |
+
|
268 |
+
def init_pretrained_weights(model, model_url):
|
269 |
+
"""Initializes model with pretrained weights.
|
270 |
+
|
271 |
+
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
272 |
+
"""
|
273 |
+
pretrain_dict = model_zoo.load_url(model_url)
|
274 |
+
model_dict = model.state_dict()
|
275 |
+
pretrain_dict = {
|
276 |
+
k: v
|
277 |
+
for k, v in pretrain_dict.items()
|
278 |
+
if k in model_dict and model_dict[k].size() == v.size()
|
279 |
+
}
|
280 |
+
model_dict.update(pretrain_dict)
|
281 |
+
model.load_state_dict(model_dict)
|
282 |
+
|
283 |
+
|
284 |
+
"""
|
285 |
+
Residual network configurations:
|
286 |
+
--
|
287 |
+
resnet18: block=BasicBlock, layers=[2, 2, 2, 2]
|
288 |
+
resnet34: block=BasicBlock, layers=[3, 4, 6, 3]
|
289 |
+
resnet50: block=Bottleneck, layers=[3, 4, 6, 3]
|
290 |
+
resnet101: block=Bottleneck, layers=[3, 4, 23, 3]
|
291 |
+
resnet152: block=Bottleneck, layers=[3, 8, 36, 3]
|
292 |
+
"""
|
293 |
+
|
294 |
+
|
295 |
+
def resnet50mid(num_classes, loss='softmax', pretrained=True, **kwargs):
|
296 |
+
model = ResNetMid(
|
297 |
+
num_classes=num_classes,
|
298 |
+
loss=loss,
|
299 |
+
block=Bottleneck,
|
300 |
+
layers=[3, 4, 6, 3],
|
301 |
+
last_stride=2,
|
302 |
+
fc_dims=[1024],
|
303 |
+
**kwargs
|
304 |
+
)
|
305 |
+
if pretrained:
|
306 |
+
init_pretrained_weights(model, model_urls['resnet50'])
|
307 |
+
return model
|
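Note (not part of the committed file): a hedged sketch for the mid-level fusion model above. The expected 3072-dim output follows from concatenating the 1024-dim fused mid-level feature (the x4a/x4b branch passed through fc_fusion) with the 2048-dim final feature (x4c). The import path is an assumption.

# Hypothetical usage sketch, not part of the commit.
import torch

from trackers.strongsort.deep.models.resnetmid import resnet50mid

model = resnet50mid(num_classes=751, pretrained=False).eval()

with torch.no_grad():
    feats = model(torch.randn(2, 3, 256, 128))

print(feats.shape)  # torch.Size([2, 3072]) -> 1024 (fused x4a/x4b) + 2048 (x4c)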
trackers/strongsort/deep/models/senet.py
ADDED
@@ -0,0 +1,688 @@
1 |
+
from __future__ import division, absolute_import
|
2 |
+
import math
|
3 |
+
from collections import OrderedDict
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.utils import model_zoo
|
6 |
+
|
7 |
+
__all__ = [
|
8 |
+
'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152',
|
9 |
+
'se_resnext50_32x4d', 'se_resnext101_32x4d', 'se_resnet50_fc512'
|
10 |
+
]
|
11 |
+
"""
|
12 |
+
Code imported from https://github.com/Cadene/pretrained-models.pytorch
|
13 |
+
"""
|
14 |
+
|
15 |
+
pretrained_settings = {
|
16 |
+
'senet154': {
|
17 |
+
'imagenet': {
|
18 |
+
'url':
|
19 |
+
'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
|
20 |
+
'input_space': 'RGB',
|
21 |
+
'input_size': [3, 224, 224],
|
22 |
+
'input_range': [0, 1],
|
23 |
+
'mean': [0.485, 0.456, 0.406],
|
24 |
+
'std': [0.229, 0.224, 0.225],
|
25 |
+
'num_classes': 1000
|
26 |
+
}
|
27 |
+
},
|
28 |
+
'se_resnet50': {
|
29 |
+
'imagenet': {
|
30 |
+
'url':
|
31 |
+
'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
|
32 |
+
'input_space': 'RGB',
|
33 |
+
'input_size': [3, 224, 224],
|
34 |
+
'input_range': [0, 1],
|
35 |
+
'mean': [0.485, 0.456, 0.406],
|
36 |
+
'std': [0.229, 0.224, 0.225],
|
37 |
+
'num_classes': 1000
|
38 |
+
}
|
39 |
+
},
|
40 |
+
'se_resnet101': {
|
41 |
+
'imagenet': {
|
42 |
+
'url':
|
43 |
+
'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
|
44 |
+
'input_space': 'RGB',
|
45 |
+
'input_size': [3, 224, 224],
|
46 |
+
'input_range': [0, 1],
|
47 |
+
'mean': [0.485, 0.456, 0.406],
|
48 |
+
'std': [0.229, 0.224, 0.225],
|
49 |
+
'num_classes': 1000
|
50 |
+
}
|
51 |
+
},
|
52 |
+
'se_resnet152': {
|
53 |
+
'imagenet': {
|
54 |
+
'url':
|
55 |
+
'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
|
56 |
+
'input_space': 'RGB',
|
57 |
+
'input_size': [3, 224, 224],
|
58 |
+
'input_range': [0, 1],
|
59 |
+
'mean': [0.485, 0.456, 0.406],
|
60 |
+
'std': [0.229, 0.224, 0.225],
|
61 |
+
'num_classes': 1000
|
62 |
+
}
|
63 |
+
},
|
64 |
+
'se_resnext50_32x4d': {
|
65 |
+
'imagenet': {
|
66 |
+
'url':
|
67 |
+
'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
|
68 |
+
'input_space': 'RGB',
|
69 |
+
'input_size': [3, 224, 224],
|
70 |
+
'input_range': [0, 1],
|
71 |
+
'mean': [0.485, 0.456, 0.406],
|
72 |
+
'std': [0.229, 0.224, 0.225],
|
73 |
+
'num_classes': 1000
|
74 |
+
}
|
75 |
+
},
|
76 |
+
'se_resnext101_32x4d': {
|
77 |
+
'imagenet': {
|
78 |
+
'url':
|
79 |
+
'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
|
80 |
+
'input_space': 'RGB',
|
81 |
+
'input_size': [3, 224, 224],
|
82 |
+
'input_range': [0, 1],
|
83 |
+
'mean': [0.485, 0.456, 0.406],
|
84 |
+
'std': [0.229, 0.224, 0.225],
|
85 |
+
'num_classes': 1000
|
86 |
+
}
|
87 |
+
},
|
88 |
+
}
|
89 |
+
|
90 |
+
|
91 |
+
class SEModule(nn.Module):
|
92 |
+
|
93 |
+
def __init__(self, channels, reduction):
|
94 |
+
super(SEModule, self).__init__()
|
95 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
96 |
+
self.fc1 = nn.Conv2d(
|
97 |
+
channels, channels // reduction, kernel_size=1, padding=0
|
98 |
+
)
|
99 |
+
self.relu = nn.ReLU(inplace=True)
|
100 |
+
self.fc2 = nn.Conv2d(
|
101 |
+
channels // reduction, channels, kernel_size=1, padding=0
|
102 |
+
)
|
103 |
+
self.sigmoid = nn.Sigmoid()
|
104 |
+
|
105 |
+
def forward(self, x):
|
106 |
+
module_input = x
|
107 |
+
x = self.avg_pool(x)
|
108 |
+
x = self.fc1(x)
|
109 |
+
x = self.relu(x)
|
110 |
+
x = self.fc2(x)
|
111 |
+
x = self.sigmoid(x)
|
112 |
+
return module_input * x
|
113 |
+
|
114 |
+
|
115 |
+
class Bottleneck(nn.Module):
|
116 |
+
"""
|
117 |
+
Base class for bottlenecks that implements `forward()` method.
|
118 |
+
"""
|
119 |
+
|
120 |
+
def forward(self, x):
|
121 |
+
residual = x
|
122 |
+
|
123 |
+
out = self.conv1(x)
|
124 |
+
out = self.bn1(out)
|
125 |
+
out = self.relu(out)
|
126 |
+
|
127 |
+
out = self.conv2(out)
|
128 |
+
out = self.bn2(out)
|
129 |
+
out = self.relu(out)
|
130 |
+
|
131 |
+
out = self.conv3(out)
|
132 |
+
out = self.bn3(out)
|
133 |
+
|
134 |
+
if self.downsample is not None:
|
135 |
+
residual = self.downsample(x)
|
136 |
+
|
137 |
+
out = self.se_module(out) + residual
|
138 |
+
out = self.relu(out)
|
139 |
+
|
140 |
+
return out
|
141 |
+
|
142 |
+
|
143 |
+
class SEBottleneck(Bottleneck):
|
144 |
+
"""
|
145 |
+
Bottleneck for SENet154.
|
146 |
+
"""
|
147 |
+
expansion = 4
|
148 |
+
|
149 |
+
def __init__(
|
150 |
+
self, inplanes, planes, groups, reduction, stride=1, downsample=None
|
151 |
+
):
|
152 |
+
super(SEBottleneck, self).__init__()
|
153 |
+
self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
|
154 |
+
self.bn1 = nn.BatchNorm2d(planes * 2)
|
155 |
+
self.conv2 = nn.Conv2d(
|
156 |
+
planes * 2,
|
157 |
+
planes * 4,
|
158 |
+
kernel_size=3,
|
159 |
+
stride=stride,
|
160 |
+
padding=1,
|
161 |
+
groups=groups,
|
162 |
+
bias=False
|
163 |
+
)
|
164 |
+
self.bn2 = nn.BatchNorm2d(planes * 4)
|
165 |
+
self.conv3 = nn.Conv2d(
|
166 |
+
planes * 4, planes * 4, kernel_size=1, bias=False
|
167 |
+
)
|
168 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
169 |
+
self.relu = nn.ReLU(inplace=True)
|
170 |
+
self.se_module = SEModule(planes * 4, reduction=reduction)
|
171 |
+
self.downsample = downsample
|
172 |
+
self.stride = stride
|
173 |
+
|
174 |
+
|
175 |
+
class SEResNetBottleneck(Bottleneck):
|
176 |
+
"""
|
177 |
+
ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
|
178 |
+
implementation and uses `stride=stride` in `conv1` and not in `conv2`
|
179 |
+
(the latter is used in the torchvision implementation of ResNet).
|
180 |
+
"""
|
181 |
+
expansion = 4
|
182 |
+
|
183 |
+
def __init__(
|
184 |
+
self, inplanes, planes, groups, reduction, stride=1, downsample=None
|
185 |
+
):
|
186 |
+
super(SEResNetBottleneck, self).__init__()
|
187 |
+
self.conv1 = nn.Conv2d(
|
188 |
+
inplanes, planes, kernel_size=1, bias=False, stride=stride
|
189 |
+
)
|
190 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
191 |
+
self.conv2 = nn.Conv2d(
|
192 |
+
planes,
|
193 |
+
planes,
|
194 |
+
kernel_size=3,
|
195 |
+
padding=1,
|
196 |
+
groups=groups,
|
197 |
+
bias=False
|
198 |
+
)
|
199 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
200 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
201 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
202 |
+
self.relu = nn.ReLU(inplace=True)
|
203 |
+
self.se_module = SEModule(planes * 4, reduction=reduction)
|
204 |
+
self.downsample = downsample
|
205 |
+
self.stride = stride
|
206 |
+
|
207 |
+
|
208 |
+
class SEResNeXtBottleneck(Bottleneck):
|
209 |
+
"""ResNeXt bottleneck type C with a Squeeze-and-Excitation module"""
|
210 |
+
expansion = 4
|
211 |
+
|
212 |
+
def __init__(
|
213 |
+
self,
|
214 |
+
inplanes,
|
215 |
+
planes,
|
216 |
+
groups,
|
217 |
+
reduction,
|
218 |
+
stride=1,
|
219 |
+
downsample=None,
|
220 |
+
base_width=4
|
221 |
+
):
|
222 |
+
super(SEResNeXtBottleneck, self).__init__()
|
223 |
+
width = int(math.floor(planes * (base_width/64.)) * groups)
|
224 |
+
self.conv1 = nn.Conv2d(
|
225 |
+
inplanes, width, kernel_size=1, bias=False, stride=1
|
226 |
+
)
|
227 |
+
self.bn1 = nn.BatchNorm2d(width)
|
228 |
+
self.conv2 = nn.Conv2d(
|
229 |
+
width,
|
230 |
+
width,
|
231 |
+
kernel_size=3,
|
232 |
+
stride=stride,
|
233 |
+
padding=1,
|
234 |
+
groups=groups,
|
235 |
+
bias=False
|
236 |
+
)
|
237 |
+
self.bn2 = nn.BatchNorm2d(width)
|
238 |
+
self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
|
239 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
240 |
+
self.relu = nn.ReLU(inplace=True)
|
241 |
+
self.se_module = SEModule(planes * 4, reduction=reduction)
|
242 |
+
self.downsample = downsample
|
243 |
+
self.stride = stride
|
244 |
+
|
245 |
+
|
246 |
+
class SENet(nn.Module):
|
247 |
+
"""Squeeze-and-excitation network.
|
248 |
+
|
249 |
+
Reference:
|
250 |
+
Hu et al. Squeeze-and-Excitation Networks. CVPR 2018.
|
251 |
+
|
252 |
+
Public keys:
|
253 |
+
- ``senet154``: SENet154.
|
254 |
+
- ``se_resnet50``: ResNet50 + SE.
|
255 |
+
- ``se_resnet101``: ResNet101 + SE.
|
256 |
+
- ``se_resnet152``: ResNet152 + SE.
|
257 |
+
- ``se_resnext50_32x4d``: ResNeXt50 (groups=32, width=4) + SE.
|
258 |
+
- ``se_resnext101_32x4d``: ResNeXt101 (groups=32, width=4) + SE.
|
259 |
+
- ``se_resnet50_fc512``: (ResNet50 + SE) + FC.
|
260 |
+
"""
|
261 |
+
|
262 |
+
def __init__(
|
263 |
+
self,
|
264 |
+
num_classes,
|
265 |
+
loss,
|
266 |
+
block,
|
267 |
+
layers,
|
268 |
+
groups,
|
269 |
+
reduction,
|
270 |
+
dropout_p=0.2,
|
271 |
+
inplanes=128,
|
272 |
+
input_3x3=True,
|
273 |
+
downsample_kernel_size=3,
|
274 |
+
downsample_padding=1,
|
275 |
+
last_stride=2,
|
276 |
+
fc_dims=None,
|
277 |
+
**kwargs
|
278 |
+
):
|
279 |
+
"""
|
280 |
+
Parameters
|
281 |
+
----------
|
282 |
+
block (nn.Module): Bottleneck class.
|
283 |
+
- For SENet154: SEBottleneck
|
284 |
+
- For SE-ResNet models: SEResNetBottleneck
|
285 |
+
- For SE-ResNeXt models: SEResNeXtBottleneck
|
286 |
+
layers (list of ints): Number of residual blocks for 4 layers of the
|
287 |
+
network (layer1...layer4).
|
288 |
+
groups (int): Number of groups for the 3x3 convolution in each
|
289 |
+
bottleneck block.
|
290 |
+
- For SENet154: 64
|
291 |
+
- For SE-ResNet models: 1
|
292 |
+
- For SE-ResNeXt models: 32
|
293 |
+
reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
|
294 |
+
- For all models: 16
|
295 |
+
dropout_p (float or None): Drop probability for the Dropout layer.
|
296 |
+
If `None` the Dropout layer is not used.
|
297 |
+
- For SENet154: 0.2
|
298 |
+
- For SE-ResNet models: None
|
299 |
+
- For SE-ResNeXt models: None
|
300 |
+
inplanes (int): Number of input channels for layer1.
|
301 |
+
- For SENet154: 128
|
302 |
+
- For SE-ResNet models: 64
|
303 |
+
- For SE-ResNeXt models: 64
|
304 |
+
input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
|
305 |
+
a single 7x7 convolution in layer0.
|
306 |
+
- For SENet154: True
|
307 |
+
- For SE-ResNet models: False
|
308 |
+
- For SE-ResNeXt models: False
|
309 |
+
downsample_kernel_size (int): Kernel size for downsampling convolutions
|
310 |
+
in layer2, layer3 and layer4.
|
311 |
+
- For SENet154: 3
|
312 |
+
- For SE-ResNet models: 1
|
313 |
+
- For SE-ResNeXt models: 1
|
314 |
+
downsample_padding (int): Padding for downsampling convolutions in
|
315 |
+
layer2, layer3 and layer4.
|
316 |
+
- For SENet154: 1
|
317 |
+
- For SE-ResNet models: 0
|
318 |
+
- For SE-ResNeXt models: 0
|
319 |
+
num_classes (int): Number of outputs in `classifier` layer.
|
320 |
+
"""
|
321 |
+
super(SENet, self).__init__()
|
322 |
+
self.inplanes = inplanes
|
323 |
+
self.loss = loss
|
324 |
+
|
325 |
+
if input_3x3:
|
326 |
+
layer0_modules = [
|
327 |
+
(
|
328 |
+
'conv1',
|
329 |
+
nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False)
|
330 |
+
),
|
331 |
+
('bn1', nn.BatchNorm2d(64)),
|
332 |
+
('relu1', nn.ReLU(inplace=True)),
|
333 |
+
(
|
334 |
+
'conv2',
|
335 |
+
nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)
|
336 |
+
),
|
337 |
+
('bn2', nn.BatchNorm2d(64)),
|
338 |
+
('relu2', nn.ReLU(inplace=True)),
|
339 |
+
(
|
340 |
+
'conv3',
|
341 |
+
nn.Conv2d(
|
342 |
+
64, inplanes, 3, stride=1, padding=1, bias=False
|
343 |
+
)
|
344 |
+
),
|
345 |
+
('bn3', nn.BatchNorm2d(inplanes)),
|
346 |
+
('relu3', nn.ReLU(inplace=True)),
|
347 |
+
]
|
348 |
+
else:
|
349 |
+
layer0_modules = [
|
350 |
+
(
|
351 |
+
'conv1',
|
352 |
+
nn.Conv2d(
|
353 |
+
3,
|
354 |
+
inplanes,
|
355 |
+
kernel_size=7,
|
356 |
+
stride=2,
|
357 |
+
padding=3,
|
358 |
+
bias=False
|
359 |
+
)
|
360 |
+
),
|
361 |
+
('bn1', nn.BatchNorm2d(inplanes)),
|
362 |
+
('relu1', nn.ReLU(inplace=True)),
|
363 |
+
]
|
364 |
+
# To preserve compatibility with Caffe weights `ceil_mode=True`
|
365 |
+
# is used instead of `padding=1`.
|
366 |
+
layer0_modules.append(
|
367 |
+
('pool', nn.MaxPool2d(3, stride=2, ceil_mode=True))
|
368 |
+
)
|
369 |
+
self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
|
370 |
+
self.layer1 = self._make_layer(
|
371 |
+
block,
|
372 |
+
planes=64,
|
373 |
+
blocks=layers[0],
|
374 |
+
groups=groups,
|
375 |
+
reduction=reduction,
|
376 |
+
downsample_kernel_size=1,
|
377 |
+
downsample_padding=0
|
378 |
+
)
|
379 |
+
self.layer2 = self._make_layer(
|
380 |
+
block,
|
381 |
+
planes=128,
|
382 |
+
blocks=layers[1],
|
383 |
+
stride=2,
|
384 |
+
groups=groups,
|
385 |
+
reduction=reduction,
|
386 |
+
downsample_kernel_size=downsample_kernel_size,
|
387 |
+
downsample_padding=downsample_padding
|
388 |
+
)
|
389 |
+
self.layer3 = self._make_layer(
|
390 |
+
block,
|
391 |
+
planes=256,
|
392 |
+
blocks=layers[2],
|
393 |
+
stride=2,
|
394 |
+
groups=groups,
|
395 |
+
reduction=reduction,
|
396 |
+
downsample_kernel_size=downsample_kernel_size,
|
397 |
+
downsample_padding=downsample_padding
|
398 |
+
)
|
399 |
+
self.layer4 = self._make_layer(
|
400 |
+
block,
|
401 |
+
planes=512,
|
402 |
+
blocks=layers[3],
|
403 |
+
stride=last_stride,
|
404 |
+
groups=groups,
|
405 |
+
reduction=reduction,
|
406 |
+
downsample_kernel_size=downsample_kernel_size,
|
407 |
+
downsample_padding=downsample_padding
|
408 |
+
)
|
409 |
+
|
410 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
411 |
+
self.fc = self._construct_fc_layer(
|
412 |
+
fc_dims, 512 * block.expansion, dropout_p
|
413 |
+
)
|
414 |
+
self.classifier = nn.Linear(self.feature_dim, num_classes)
|
415 |
+
|
416 |
+
def _make_layer(
|
417 |
+
self,
|
418 |
+
block,
|
419 |
+
planes,
|
420 |
+
blocks,
|
421 |
+
groups,
|
422 |
+
reduction,
|
423 |
+
stride=1,
|
424 |
+
downsample_kernel_size=1,
|
425 |
+
downsample_padding=0
|
426 |
+
):
|
427 |
+
downsample = None
|
428 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
429 |
+
downsample = nn.Sequential(
|
430 |
+
nn.Conv2d(
|
431 |
+
self.inplanes,
|
432 |
+
planes * block.expansion,
|
433 |
+
kernel_size=downsample_kernel_size,
|
434 |
+
stride=stride,
|
435 |
+
padding=downsample_padding,
|
436 |
+
bias=False
|
437 |
+
),
|
438 |
+
nn.BatchNorm2d(planes * block.expansion),
|
439 |
+
)
|
440 |
+
|
441 |
+
layers = []
|
442 |
+
layers.append(
|
443 |
+
block(
|
444 |
+
self.inplanes, planes, groups, reduction, stride, downsample
|
445 |
+
)
|
446 |
+
)
|
447 |
+
self.inplanes = planes * block.expansion
|
448 |
+
for i in range(1, blocks):
|
449 |
+
layers.append(block(self.inplanes, planes, groups, reduction))
|
450 |
+
|
451 |
+
return nn.Sequential(*layers)
|
452 |
+
|
453 |
+
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
|
454 |
+
"""
|
455 |
+
Construct fully connected layer
|
456 |
+
|
457 |
+
- fc_dims (list or tuple): dimensions of fc layers, if None,
|
458 |
+
no fc layers are constructed
|
459 |
+
- input_dim (int): input dimension
|
460 |
+
- dropout_p (float): dropout probability, if None, dropout is unused
|
461 |
+
"""
|
462 |
+
if fc_dims is None:
|
463 |
+
self.feature_dim = input_dim
|
464 |
+
return None
|
465 |
+
|
466 |
+
assert isinstance(
|
467 |
+
fc_dims, (list, tuple)
|
468 |
+
), 'fc_dims must be either list or tuple, but got {}'.format(
|
469 |
+
type(fc_dims)
|
470 |
+
)
|
471 |
+
|
472 |
+
layers = []
|
473 |
+
for dim in fc_dims:
|
474 |
+
layers.append(nn.Linear(input_dim, dim))
|
475 |
+
layers.append(nn.BatchNorm1d(dim))
|
476 |
+
layers.append(nn.ReLU(inplace=True))
|
477 |
+
if dropout_p is not None:
|
478 |
+
layers.append(nn.Dropout(p=dropout_p))
|
479 |
+
input_dim = dim
|
480 |
+
|
481 |
+
self.feature_dim = fc_dims[-1]
|
482 |
+
|
483 |
+
return nn.Sequential(*layers)
|
484 |
+
|
485 |
+
def featuremaps(self, x):
|
486 |
+
x = self.layer0(x)
|
487 |
+
x = self.layer1(x)
|
488 |
+
x = self.layer2(x)
|
489 |
+
x = self.layer3(x)
|
490 |
+
x = self.layer4(x)
|
491 |
+
return x
|
492 |
+
|
493 |
+
def forward(self, x):
|
494 |
+
f = self.featuremaps(x)
|
495 |
+
v = self.global_avgpool(f)
|
496 |
+
v = v.view(v.size(0), -1)
|
497 |
+
|
498 |
+
if self.fc is not None:
|
499 |
+
v = self.fc(v)
|
500 |
+
|
501 |
+
if not self.training:
|
502 |
+
return v
|
503 |
+
|
504 |
+
y = self.classifier(v)
|
505 |
+
|
506 |
+
if self.loss == 'softmax':
|
507 |
+
return y
|
508 |
+
elif self.loss == 'triplet':
|
509 |
+
return y, v
|
510 |
+
else:
|
511 |
+
raise KeyError("Unsupported loss: {}".format(self.loss))
|
512 |
+
|
513 |
+
|
514 |
+
def init_pretrained_weights(model, model_url):
|
515 |
+
"""Initializes model with pretrained weights.
|
516 |
+
|
517 |
+
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
518 |
+
"""
|
519 |
+
pretrain_dict = model_zoo.load_url(model_url)
|
520 |
+
model_dict = model.state_dict()
|
521 |
+
pretrain_dict = {
|
522 |
+
k: v
|
523 |
+
for k, v in pretrain_dict.items()
|
524 |
+
if k in model_dict and model_dict[k].size() == v.size()
|
525 |
+
}
|
526 |
+
model_dict.update(pretrain_dict)
|
527 |
+
model.load_state_dict(model_dict)
|
528 |
+
|
529 |
+
|
530 |
+
def senet154(num_classes, loss='softmax', pretrained=True, **kwargs):
|
531 |
+
model = SENet(
|
532 |
+
num_classes=num_classes,
|
533 |
+
loss=loss,
|
534 |
+
block=SEBottleneck,
|
535 |
+
layers=[3, 8, 36, 3],
|
536 |
+
groups=64,
|
537 |
+
reduction=16,
|
538 |
+
dropout_p=0.2,
|
539 |
+
last_stride=2,
|
540 |
+
fc_dims=None,
|
541 |
+
**kwargs
|
542 |
+
)
|
543 |
+
if pretrained:
|
544 |
+
model_url = pretrained_settings['senet154']['imagenet']['url']
|
545 |
+
init_pretrained_weights(model, model_url)
|
546 |
+
return model
|
547 |
+
|
548 |
+
|
549 |
+
def se_resnet50(num_classes, loss='softmax', pretrained=True, **kwargs):
|
550 |
+
    model = SENet(
        num_classes=num_classes, loss=loss, block=SEResNetBottleneck,
        layers=[3, 4, 6, 3], groups=1, reduction=16, dropout_p=None,
        inplanes=64, input_3x3=False, downsample_kernel_size=1,
        downsample_padding=0, last_stride=2, fc_dims=None, **kwargs
    )
    if pretrained:
        model_url = pretrained_settings['se_resnet50']['imagenet']['url']
        init_pretrained_weights(model, model_url)
    return model


def se_resnet50_fc512(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = SENet(
        num_classes=num_classes, loss=loss, block=SEResNetBottleneck,
        layers=[3, 4, 6, 3], groups=1, reduction=16, dropout_p=None,
        inplanes=64, input_3x3=False, downsample_kernel_size=1,
        downsample_padding=0, last_stride=1, fc_dims=[512], **kwargs
    )
    if pretrained:
        model_url = pretrained_settings['se_resnet50']['imagenet']['url']
        init_pretrained_weights(model, model_url)
    return model


def se_resnet101(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = SENet(
        num_classes=num_classes, loss=loss, block=SEResNetBottleneck,
        layers=[3, 4, 23, 3], groups=1, reduction=16, dropout_p=None,
        inplanes=64, input_3x3=False, downsample_kernel_size=1,
        downsample_padding=0, last_stride=2, fc_dims=None, **kwargs
    )
    if pretrained:
        model_url = pretrained_settings['se_resnet101']['imagenet']['url']
        init_pretrained_weights(model, model_url)
    return model


def se_resnet152(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = SENet(
        num_classes=num_classes, loss=loss, block=SEResNetBottleneck,
        layers=[3, 8, 36, 3], groups=1, reduction=16, dropout_p=None,
        inplanes=64, input_3x3=False, downsample_kernel_size=1,
        downsample_padding=0, last_stride=2, fc_dims=None, **kwargs
    )
    if pretrained:
        model_url = pretrained_settings['se_resnet152']['imagenet']['url']
        init_pretrained_weights(model, model_url)
    return model


def se_resnext50_32x4d(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = SENet(
        num_classes=num_classes, loss=loss, block=SEResNeXtBottleneck,
        layers=[3, 4, 6, 3], groups=32, reduction=16, dropout_p=None,
        inplanes=64, input_3x3=False, downsample_kernel_size=1,
        downsample_padding=0, last_stride=2, fc_dims=None, **kwargs
    )
    if pretrained:
        model_url = pretrained_settings['se_resnext50_32x4d']['imagenet']['url']
        init_pretrained_weights(model, model_url)
    return model


def se_resnext101_32x4d(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = SENet(
        num_classes=num_classes, loss=loss, block=SEResNeXtBottleneck,
        layers=[3, 4, 23, 3], groups=32, reduction=16, dropout_p=None,
        inplanes=64, input_3x3=False, downsample_kernel_size=1,
        downsample_padding=0, last_stride=2, fc_dims=None, **kwargs
    )
    if pretrained:
        model_url = pretrained_settings['se_resnext101_32x4d']['imagenet']['url']
        init_pretrained_weights(model, model_url)
    return model
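These six constructors differ only in the bottleneck block, layer counts, group width, and FC head they hand to SENet; the rest of the behaviour (partial pretrained loading, eval-mode feature output) is shared. A minimal usage sketch, assuming the import path of this commit's layout; the class count and crop size below are illustrative values, not anything fixed by this file:

import torch
# assumption: repo root on sys.path, so the file added in this commit is importable
from trackers.strongsort.deep.models.senet import se_resnet50

model = se_resnet50(num_classes=751, loss='softmax', pretrained=False)
model.eval()  # in eval mode the forward pass returns pooled feature embeddings, not logits
with torch.no_grad():
    crops = torch.randn(4, 3, 256, 128)   # a batch of person crops (N, C, H, W)
    feats = model(crops)                  # one embedding row per crop
print(feats.shape)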
trackers/strongsort/deep/models/shufflenet.py
ADDED
@@ -0,0 +1,198 @@
from __future__ import division, absolute_import
import torch
import torch.utils.model_zoo as model_zoo
from torch import nn
from torch.nn import functional as F

__all__ = ['shufflenet']

model_urls = {
    # training epoch = 90, top1 = 61.8
    'imagenet': 'https://mega.nz/#!RDpUlQCY!tr_5xBEkelzDjveIYBBcGcovNCOrgfiJO9kiidz9fZM',
}


class ChannelShuffle(nn.Module):

    def __init__(self, num_groups):
        super(ChannelShuffle, self).__init__()
        self.g = num_groups

    def forward(self, x):
        b, c, h, w = x.size()
        n = c // self.g
        # reshape -> transpose -> flatten
        x = x.view(b, self.g, n, h, w)
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        x = x.view(b, c, h, w)
        return x


class Bottleneck(nn.Module):

    def __init__(self, in_channels, out_channels, stride, num_groups, group_conv1x1=True):
        super(Bottleneck, self).__init__()
        assert stride in [1, 2], 'Warning: stride must be either 1 or 2'
        self.stride = stride
        mid_channels = out_channels // 4
        if stride == 2:
            out_channels -= in_channels
        # group conv is not applied to first conv1x1 at stage 2
        num_groups_conv1x1 = num_groups if group_conv1x1 else 1
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, groups=num_groups_conv1x1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.shuffle1 = ChannelShuffle(num_groups)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, stride=stride, padding=1, groups=mid_channels, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, 1, groups=num_groups, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        if stride == 2:
            self.shortcut = nn.AvgPool2d(3, stride=2, padding=1)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.shuffle1(out)
        out = self.bn2(self.conv2(out))
        out = self.bn3(self.conv3(out))
        if self.stride == 2:
            res = self.shortcut(x)
            out = F.relu(torch.cat([res, out], 1))
        else:
            out = F.relu(x + out)
        return out


# configuration of (num_groups: #out_channels) based on Table 1 in the paper
cfg = {
    1: [144, 288, 576],
    2: [200, 400, 800],
    3: [240, 480, 960],
    4: [272, 544, 1088],
    8: [384, 768, 1536],
}


class ShuffleNet(nn.Module):
    """ShuffleNet.

    Reference:
        Zhang et al. ShuffleNet: An Extremely Efficient Convolutional Neural
        Network for Mobile Devices. CVPR 2018.

    Public keys:
        - ``shufflenet``: ShuffleNet (groups=3).
    """

    def __init__(self, num_classes, loss='softmax', num_groups=3, **kwargs):
        super(ShuffleNet, self).__init__()
        self.loss = loss

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(24),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        self.stage2 = nn.Sequential(
            Bottleneck(24, cfg[num_groups][0], 2, num_groups, group_conv1x1=False),
            Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
            Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
            Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
        )

        self.stage3 = nn.Sequential(
            Bottleneck(cfg[num_groups][0], cfg[num_groups][1], 2, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
        )

        self.stage4 = nn.Sequential(
            Bottleneck(cfg[num_groups][1], cfg[num_groups][2], 2, num_groups),
            Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
            Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
            Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
        )

        self.classifier = nn.Linear(cfg[num_groups][2], num_classes)
        self.feat_dim = cfg[num_groups][2]

    def forward(self, x):
        x = self.conv1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1)

        if not self.training:
            return x

        y = self.classifier(x)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, x
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def shufflenet(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ShuffleNet(num_classes, loss, **kwargs)
    if pretrained:
        # init_pretrained_weights(model, model_urls['imagenet'])
        import warnings
        warnings.warn(
            'The imagenet pretrained weights need to be manually downloaded from {}'
            .format(model_urls['imagenet'])
        )
    return model
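A small sanity check of the ChannelShuffle module defined above: it reshapes to (batch, groups, channels/groups, H, W), swaps the group and channel axes, and flattens back, so the shape is preserved while channels are interleaved across groups. Illustrative sketch, assuming the repository root is on sys.path:

import torch
# assumption: import path matches the file added in this commit
from trackers.strongsort.deep.models.shufflenet import ChannelShuffle

shuffle = ChannelShuffle(num_groups=3)
x = torch.arange(2 * 6 * 4 * 4, dtype=torch.float32).view(2, 6, 4, 4)
y = shuffle(x)
assert y.shape == x.shape      # shape is unchanged
assert not torch.equal(y, x)   # but channels are interleaved across the 3 groups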
trackers/strongsort/deep/models/shufflenetv2.py
ADDED
@@ -0,0 +1,262 @@
"""
Code source: https://github.com/pytorch/vision
"""
from __future__ import division, absolute_import
import torch
import torch.utils.model_zoo as model_zoo
from torch import nn

__all__ = [
    'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5',
    'shufflenet_v2_x2_0'
]

model_urls = {
    'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
    'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
    'shufflenetv2_x1.5': None,
    'shufflenetv2_x2.0': None,
}


def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups, channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x


class InvertedResidual(nn.Module):

    def __init__(self, inp, oup, stride):
        super(InvertedResidual, self).__init__()

        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        assert (self.stride != 1) or (inp == branch_features << 1)

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )

        self.branch2 = nn.Sequential(
            nn.Conv2d(
                inp if (self.stride > 1) else branch_features,
                branch_features, kernel_size=1, stride=1, padding=0, bias=False
            ),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out


class ShuffleNetV2(nn.Module):
    """ShuffleNetV2.

    Reference:
        Ma et al. ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design. ECCV 2018.

    Public keys:
        - ``shufflenet_v2_x0_5``: ShuffleNetV2 x0.5.
        - ``shufflenet_v2_x1_0``: ShuffleNetV2 x1.0.
        - ``shufflenet_v2_x1_5``: ShuffleNetV2 x1.5.
        - ``shufflenet_v2_x2_0``: ShuffleNetV2 x2.0.
    """

    def __init__(self, num_classes, loss, stages_repeats, stages_out_channels, **kwargs):
        super(ShuffleNetV2, self).__init__()
        self.loss = loss

        if len(stages_repeats) != 3:
            raise ValueError('expected stages_repeats as list of 3 positive ints')
        if len(stages_out_channels) != 5:
            raise ValueError('expected stages_out_channels as list of 5 positive ints')
        self._stage_out_channels = stages_out_channels

        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(
            stage_names, stages_repeats, self._stage_out_channels[1:]
        ):
            seq = [InvertedResidual(input_channels, output_channels, 2)]
            for i in range(repeats - 1):
                seq.append(InvertedResidual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.classifier = nn.Linear(output_channels, num_classes)

    def featuremaps(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        return x

    def forward(self, x):
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    if model_url is None:
        import warnings
        warnings.warn('ImageNet pretrained weights are unavailable for this model')
        return
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def shufflenet_v2_x0_5(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ShuffleNetV2(num_classes, loss, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
    if pretrained:
        init_pretrained_weights(model, model_urls['shufflenetv2_x0.5'])
    return model


def shufflenet_v2_x1_0(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ShuffleNetV2(num_classes, loss, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
    if pretrained:
        init_pretrained_weights(model, model_urls['shufflenetv2_x1.0'])
    return model


def shufflenet_v2_x1_5(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ShuffleNetV2(num_classes, loss, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
    if pretrained:
        init_pretrained_weights(model, model_urls['shufflenetv2_x1.5'])
    return model


def shufflenet_v2_x2_0(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ShuffleNetV2(num_classes, loss, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
    if pretrained:
        init_pretrained_weights(model, model_urls['shufflenetv2_x2.0'])
    return model
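As with the other backbones added in this commit, the forward pass returns logits in softmax mode and a (logits, embeddings) pair in triplet mode while training, and just the pooled embedding in eval mode. A hedged sketch; the class count and crop size are arbitrary illustration values:

import torch
# assumption: import path matches the file added in this commit
from trackers.strongsort.deep.models.shufflenetv2 import shufflenet_v2_x0_5

net = shufflenet_v2_x0_5(num_classes=10, loss='triplet', pretrained=False)
net.train()
logits, embeddings = net(torch.randn(2, 3, 256, 128))
print(logits.shape, embeddings.shape)   # (2, 10) and (2, 1024) for the x0.5 variant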
trackers/strongsort/deep/models/squeezenet.py
ADDED
@@ -0,0 +1,236 @@
"""
Code source: https://github.com/pytorch/vision
"""
from __future__ import division, absolute_import
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

__all__ = ['squeezenet1_0', 'squeezenet1_1', 'squeezenet1_0_fc512']

model_urls = {
    'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
    'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
}


class Fire(nn.Module):

    def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes):
        super(Fire, self).__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.squeeze_activation(self.squeeze(x))
        return torch.cat(
            [
                self.expand1x1_activation(self.expand1x1(x)),
                self.expand3x3_activation(self.expand3x3(x))
            ], 1
        )


class SqueezeNet(nn.Module):
    """SqueezeNet.

    Reference:
        Iandola et al. SqueezeNet: AlexNet-level accuracy with 50x fewer parameters
        and < 0.5 MB model size. arXiv:1602.07360.

    Public keys:
        - ``squeezenet1_0``: SqueezeNet (version=1.0).
        - ``squeezenet1_1``: SqueezeNet (version=1.1).
        - ``squeezenet1_0_fc512``: SqueezeNet (version=1.0) + FC.
    """

    def __init__(self, num_classes, loss, version=1.0, fc_dims=None, dropout_p=None, **kwargs):
        super(SqueezeNet, self).__init__()
        self.loss = loss
        self.feature_dim = 512

        if version not in [1.0, 1.1]:
            raise ValueError(
                'Unsupported SqueezeNet version {version}: '
                '1.0 or 1.1 expected'.format(version=version)
            )

        if version == 1.0:
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        else:
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = self._construct_fc_layer(fc_dims, 512, dropout_p)
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        self._init_params()

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """Constructs fully connected layer.

        Args:
            fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
            input_dim (int): input dimension
            dropout_p (float): dropout probability, if None, dropout is unused
        """
        if fc_dims is None:
            self.feature_dim = input_dim
            return None

        assert isinstance(fc_dims, (list, tuple)), \
            'fc_dims must be either list or tuple, but got {}'.format(type(fc_dims))

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def _init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        f = self.features(x)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if self.fc is not None:
            v = self.fc(v)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url, map_location=None)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def squeezenet1_0(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = SqueezeNet(num_classes, loss, version=1.0, fc_dims=None, dropout_p=None, **kwargs)
    if pretrained:
        init_pretrained_weights(model, model_urls['squeezenet1_0'])
    return model


def squeezenet1_0_fc512(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = SqueezeNet(num_classes, loss, version=1.0, fc_dims=[512], dropout_p=None, **kwargs)
    if pretrained:
        init_pretrained_weights(model, model_urls['squeezenet1_0'])
    return model


def squeezenet1_1(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = SqueezeNet(num_classes, loss, version=1.1, fc_dims=None, dropout_p=None, **kwargs)
    if pretrained:
        init_pretrained_weights(model, model_urls['squeezenet1_1'])
    return model
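In eval mode the wrapper skips the classifier and returns the pooled embedding, which is what a ReID stage consumes. Illustrative sketch only (pretrained=False simply avoids the weight download; the import path follows this commit's layout):

import torch
from trackers.strongsort.deep.models.squeezenet import squeezenet1_0

net = squeezenet1_0(num_classes=10, pretrained=False)
net.eval()
with torch.no_grad():
    emb = net(torch.randn(1, 3, 256, 128))
print(emb.shape)   # one embedding per image (512-d here, since fc_dims is None)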
trackers/strongsort/deep/models/xception.py
ADDED
@@ -0,0 +1,344 @@
from __future__ import division, absolute_import
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo

__all__ = ['xception']

pretrained_settings = {
    'xception': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000,
            # the resize parameter of the validation transform should be 333,
            # and make sure to center crop at 299x299
            'scale': 0.8975
        }
    }
}


class SeparableConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels, in_channels, kernel_size, stride, padding, dilation,
            groups=in_channels, bias=bias
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


class Block(nn.Module):

    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()

        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)
        rep = []

        filters = in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        for i in range(reps - 1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            rep = rep[1:]
        else:
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x += skip
        return x


class Xception(nn.Module):
    """Xception.

    Reference:
        Chollet. Xception: Deep Learning with Depthwise
        Separable Convolutions. CVPR 2017.

    Public keys:
        - ``xception``: Xception.
    """

    def __init__(self, num_classes, loss, fc_dims=None, dropout_p=None, **kwargs):
        super(Xception, self).__init__()
        self.loss = loss

        self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)

        self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)

        self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)

        self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.feature_dim = 2048
        self.fc = self._construct_fc_layer(fc_dims, 2048, dropout_p)
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        self._init_params()

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """Constructs fully connected layer.

        Args:
            fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
            input_dim (int): input dimension
            dropout_p (float): dropout probability, if None, dropout is unused
        """
        if fc_dims is None:
            self.feature_dim = input_dim
            return None

        assert isinstance(fc_dims, (list, tuple)), \
            'fc_dims must be either list or tuple, but got {}'.format(type(fc_dims))

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def _init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x, inplace=True)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x, inplace=True)

        x = self.conv4(x)
        x = self.bn4(x)
        x = F.relu(x, inplace=True)
        return x

    def forward(self, x):
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if self.fc is not None:
            v = self.fc(v)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initialize models with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def xception(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = Xception(num_classes, loss, fc_dims=None, dropout_p=None, **kwargs)
    if pretrained:
        model_url = pretrained_settings['xception']['imagenet']['url']
        init_pretrained_weights(model, model_url)
    return model
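The init_pretrained_weights helper above keeps only tensors whose key and shape match the freshly built model, so the 1000-class ImageNet head is silently dropped whenever num_classes differs. Sketch under the same assumptions as the earlier examples (illustrative values only):

import torch
from trackers.strongsort.deep.models.xception import xception

net = xception(num_classes=10, pretrained=False)   # pretrained=False skips the download
net.eval()
with torch.no_grad():
    emb = net(torch.randn(1, 3, 299, 299))
print(emb.shape)   # 2048-d embedding per image when fc_dims is None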
trackers/strongsort/deep/reid_model_factory.py
ADDED
@@ -0,0 +1,215 @@
import sys
import time
import warnings
from collections import OrderedDict

import torch


__model_types = [
    'resnet50', 'mlfn', 'hacnn', 'mobilenetv2_x1_0', 'mobilenetv2_x1_4',
    'osnet_x1_0', 'osnet_x0_75', 'osnet_x0_5', 'osnet_x0_25',
    'osnet_ibn_x1_0', 'osnet_ain_x1_0']

__trained_urls = {

    # market1501 models ########################################################
    'resnet50_market1501.pt': 'https://drive.google.com/uc?id=1dUUZ4rHDWohmsQXCRe2C_HbYkzz94iBV',
    'resnet50_dukemtmcreid.pt': 'https://drive.google.com/uc?id=17ymnLglnc64NRvGOitY3BqMRS9UWd1wg',
    'resnet50_msmt17.pt': 'https://drive.google.com/uc?id=1ep7RypVDOthCRIAqDnn4_N-UhkkFHJsj',

    'resnet50_fc512_market1501.pt': 'https://drive.google.com/uc?id=1kv8l5laX_YCdIGVCetjlNdzKIA3NvsSt',
    'resnet50_fc512_dukemtmcreid.pt': 'https://drive.google.com/uc?id=13QN8Mp3XH81GK4BPGXobKHKyTGH50Rtx',
    'resnet50_fc512_msmt17.pt': 'https://drive.google.com/uc?id=1fDJLcz4O5wxNSUvImIIjoaIF9u1Rwaud',

    'mlfn_market1501.pt': 'https://drive.google.com/uc?id=1wXcvhA_b1kpDfrt9s2Pma-MHxtj9pmvS',
    'mlfn_dukemtmcreid.pt': 'https://drive.google.com/uc?id=1rExgrTNb0VCIcOnXfMsbwSUW1h2L1Bum',
    'mlfn_msmt17.pt': 'https://drive.google.com/uc?id=18JzsZlJb3Wm7irCbZbZ07TN4IFKvR6p-',

    'hacnn_market1501.pt': 'https://drive.google.com/uc?id=1LRKIQduThwGxMDQMiVkTScBwR7WidmYF',
    'hacnn_dukemtmcreid.pt': 'https://drive.google.com/uc?id=1zNm6tP4ozFUCUQ7Sv1Z98EAJWXJEhtYH',
    'hacnn_msmt17.pt': 'https://drive.google.com/uc?id=1MsKRtPM5WJ3_Tk2xC0aGOO7pM3VaFDNZ',

    'mobilenetv2_x1_0_market1501.pt': 'https://drive.google.com/uc?id=18DgHC2ZJkjekVoqBWszD8_Xiikz-fewp',
    'mobilenetv2_x1_0_dukemtmcreid.pt': 'https://drive.google.com/uc?id=1q1WU2FETRJ3BXcpVtfJUuqq4z3psetds',
    'mobilenetv2_x1_0_msmt17.pt': 'https://drive.google.com/uc?id=1j50Hv14NOUAg7ZeB3frzfX-WYLi7SrhZ',

    'mobilenetv2_x1_4_market1501.pt': 'https://drive.google.com/uc?id=1t6JCqphJG-fwwPVkRLmGGyEBhGOf2GO5',
    'mobilenetv2_x1_4_dukemtmcreid.pt': 'https://drive.google.com/uc?id=12uD5FeVqLg9-AFDju2L7SQxjmPb4zpBN',
    'mobilenetv2_x1_4_msmt17.pt': 'https://drive.google.com/uc?id=1ZY5P2Zgm-3RbDpbXM0kIBMPvspeNIbXz',

    'osnet_x1_0_market1501.pt': 'https://drive.google.com/uc?id=1vduhq5DpN2q1g4fYEZfPI17MJeh9qyrA',
    'osnet_x1_0_dukemtmcreid.pt': 'https://drive.google.com/uc?id=1QZO_4sNf4hdOKKKzKc-TZU9WW1v6zQbq',
    'osnet_x1_0_msmt17.pt': 'https://drive.google.com/uc?id=112EMUfBPYeYg70w-syK6V6Mx8-Qb9Q1M',

    'osnet_x0_75_market1501.pt': 'https://drive.google.com/uc?id=1ozRaDSQw_EQ8_93OUmjDbvLXw9TnfPer',
    'osnet_x0_75_dukemtmcreid.pt': 'https://drive.google.com/uc?id=1IE3KRaTPp4OUa6PGTFL_d5_KQSJbP0Or',
    'osnet_x0_75_msmt17.pt': 'https://drive.google.com/uc?id=1QEGO6WnJ-BmUzVPd3q9NoaO_GsPNlmWc',

    'osnet_x0_5_market1501.pt': 'https://drive.google.com/uc?id=1PLB9rgqrUM7blWrg4QlprCuPT7ILYGKT',
    'osnet_x0_5_dukemtmcreid.pt': 'https://drive.google.com/uc?id=1KoUVqmiST175hnkALg9XuTi1oYpqcyTu',
    'osnet_x0_5_msmt17.pt': 'https://drive.google.com/uc?id=1UT3AxIaDvS2PdxzZmbkLmjtiqq7AIKCv',

    'osnet_x0_25_market1501.pt': 'https://drive.google.com/uc?id=1z1UghYvOTtjx7kEoRfmqSMu-z62J6MAj',
    'osnet_x0_25_dukemtmcreid.pt': 'https://drive.google.com/uc?id=1eumrtiXT4NOspjyEV4j8cHmlOaaCGk5l',
    'osnet_x0_25_msmt17.pt': 'https://drive.google.com/uc?id=1sSwXSUlj4_tHZequ_iZ8w_Jh0VaRQMqF',

    ####### msmt17 models ######################################################
    'resnet50_msmt17.pt': 'https://drive.google.com/uc?id=1yiBteqgIZoOeywE8AhGmEQl7FTVwrQmf',
    'osnet_x1_0_msmt17.pt': 'https://drive.google.com/uc?id=1IosIFlLiulGIjwW3H8uMRmx3MzPwf86x',
    'osnet_x0_75_msmt17.pt': 'https://drive.google.com/uc?id=1fhjSS_7SUGCioIf2SWXaRGPqIY9j7-uw',

    'osnet_x0_5_msmt17.pt': 'https://drive.google.com/uc?id=1DHgmb6XV4fwG3n-CnCM0zdL9nMsZ9_RF',
    'osnet_x0_25_msmt17.pt': 'https://drive.google.com/uc?id=1Kkx2zW89jq_NETu4u42CFZTMVD5Hwm6e',
    'osnet_ibn_x1_0_msmt17.pt': 'https://drive.google.com/uc?id=1q3Sj2ii34NlfxA4LvmHdWO_75NDRmECJ',
    'osnet_ain_x1_0_msmt17.pt': 'https://drive.google.com/uc?id=1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal',
}


def show_downloadeable_models():
    print('\nAvailable .pt ReID models for automatic download')
    print(list(__trained_urls.keys()))


def get_model_url(model):
    if model.name in __trained_urls:
        return __trained_urls[model.name]
    else:
        return None


def is_model_in_model_types(model):
    if model.name in __model_types:
        return True
    else:
        return False


def get_model_name(model):
    for x in __model_types:
        if x in model.name:
            return x
    return None


def download_url(url, dst):
    """Downloads file from a url to a destination.

    Args:
        url (str): url to download file.
        dst (str): destination path.
    """
    from six.moves import urllib
    print('* url="{}"'.format(url))
    print('* destination="{}"'.format(dst))

    def _reporthook(count, block_size, total_size):
        global start_time
        if count == 0:
            start_time = time.time()
            return
        duration = time.time() - start_time
        progress_size = int(count * block_size)
        speed = int(progress_size / (1024 * duration))
        percent = int(count * block_size * 100 / total_size)
        sys.stdout.write(
            '\r...%d%%, %d MB, %d KB/s, %d seconds passed' %
            (percent, progress_size / (1024 * 1024), speed, duration)
        )
        sys.stdout.flush()

    urllib.request.urlretrieve(url, dst, _reporthook)
    sys.stdout.write('\n')


def load_pretrained_weights(model, weight_path):
    r"""Loads pretrained weights to model.

    Features::
        - Incompatible layers (unmatched in name or size) will be ignored.
        - Can automatically deal with keys containing "module.".

    Args:
        model (nn.Module): network model.
        weight_path (str): path to pretrained weights.

    Examples::
        >>> from torchreid.utils import load_pretrained_weights
        >>> weight_path = 'log/my_model/model-best.pth.tar'
        >>> load_pretrained_weights(model, weight_path)
    """
    checkpoint = torch.load(weight_path)
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        if k.startswith('module.'):
            k = k[7:]  # discard module.

        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if len(matched_layers) == 0:
        warnings.warn(
            'The pretrained weights "{}" cannot be loaded, '
            'please check the key names manually '
            '(** ignored and continue **)'.format(weight_path)
        )
    else:
        print('Successfully loaded pretrained weights from "{}"'.format(weight_path))
        if len(discarded_layers) > 0:
            print(
                '** The following layers are discarded '
                'due to unmatched keys or layer size: {}'.format(discarded_layers)
            )
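The lookup helpers above key off the weight file's .name attribute, so a pathlib.Path (which is what reid_multibackend.py passes in) works directly; get_model_name returns the first architecture string from __model_types contained in the filename. Illustrative sketch, assuming this commit's module path:

from pathlib import Path
from trackers.strongsort.deep.reid_model_factory import get_model_name, get_model_url

w = Path('osnet_x0_25_msmt17.pt')
print(get_model_name(w))         # -> 'osnet_x0_25'
print(get_model_url(w) is None)  # False: this checkpoint has a Google Drive URL registered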
trackers/strongsort/reid_multibackend.py
ADDED
@@ -0,0 +1,237 @@
import torch.nn as nn
import torch
from pathlib import Path
import numpy as np
from itertools import islice
import torchvision.transforms as transforms
import cv2
import sys
import torchvision.transforms as T
from collections import OrderedDict, namedtuple
import gdown
from os.path import exists as file_exists


from ultralytics.yolo.utils.checks import check_requirements, check_version
from ultralytics.yolo.utils import LOGGER
from trackers.strongsort.deep.reid_model_factory import (show_downloadeable_models, get_model_url, get_model_name,
                                                         download_url, load_pretrained_weights)
from trackers.strongsort.deep.models import build_model


def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
    # Check file(s) for acceptable suffix
    if file and suffix:
        if isinstance(suffix, str):
            suffix = [suffix]
        for f in file if isinstance(file, (list, tuple)) else [file]:
            s = Path(f).suffix.lower()  # file suffix
            if len(s):
                assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"


class ReIDDetectMultiBackend(nn.Module):
    # ReID models MultiBackend class for python inference on various backends
    def __init__(self, weights='osnet_x0_25_msmt17.pt', device=torch.device('cpu'), fp16=False):
        super().__init__()

        w = weights[0] if isinstance(weights, list) else weights
        self.pt, self.jit, self.onnx, self.xml, self.engine, self.tflite = self.model_type(w)  # get backend
        self.fp16 = fp16
        self.fp16 &= self.pt or self.jit or self.engine  # FP16

        # Build transform functions
        self.device = device
        self.image_size = (256, 128)
        self.pixel_mean = [0.485, 0.456, 0.406]
        self.pixel_std = [0.229, 0.224, 0.225]
        self.transforms = []
        self.transforms += [T.Resize(self.image_size)]
        self.transforms += [T.ToTensor()]
        self.transforms += [T.Normalize(mean=self.pixel_mean, std=self.pixel_std)]
        self.preprocess = T.Compose(self.transforms)
        self.to_pil = T.ToPILImage()

        model_name = get_model_name(w)

        if w.suffix == '.pt':
            model_url = get_model_url(w)
            if not file_exists(w) and model_url is not None:
                gdown.download(model_url, str(w), quiet=False)
            elif file_exists(w):
                pass
            else:
                print(f'No URL associated to the chosen StrongSORT weights ({w}). Choose between:')
                show_downloadeable_models()
                exit()

        # Build model
        self.model = build_model(
            model_name,
            num_classes=1,
            pretrained=not (w and w.is_file()),
            use_gpu=device
        )

        if self.pt:  # PyTorch
            # populate model arch with weights
            if w and w.is_file() and w.suffix == '.pt':
                load_pretrained_weights(self.model, w)

            self.model.to(device).eval()
            self.model.half() if self.fp16 else self.model.float()
        elif self.jit:
            LOGGER.info(f'Loading {w} for TorchScript inference...')
            self.model = torch.jit.load(w)
            self.model.half() if self.fp16 else self.model.float()
        elif self.onnx:  # ONNX Runtime
            LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
            cuda = torch.cuda.is_available() and device.type != 'cpu'
            # check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
            import onnxruntime
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
            self.session = onnxruntime.InferenceSession(str(w), providers=providers)
        elif self.engine:  # TensorRT
            LOGGER.info(f'Loading {w} for TensorRT inference...')
            import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
            check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
            if device.type == 'cpu':
                device = torch.device('cuda:0')
            Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
            logger = trt.Logger(trt.Logger.INFO)
            with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                self.model_ = runtime.deserialize_cuda_engine(f.read())
            self.context = self.model_.create_execution_context()
            self.bindings = OrderedDict()
            self.fp16 = False  # default updated below
            dynamic = False
            for index in range(self.model_.num_bindings):
                name = self.model_.get_binding_name(index)
                dtype = trt.nptype(self.model_.get_binding_dtype(index))
                if self.model_.binding_is_input(index):
                    if -1 in tuple(self.model_.get_binding_shape(index)):  # dynamic
                        dynamic = True
                        self.context.set_binding_shape(index, tuple(self.model_.get_profile_shape(0, index)[2]))
                    if dtype == np.float16:
                        self.fp16 = True
                shape = tuple(self.context.get_binding_shape(index))
                im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
                self.bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
            self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items())
            batch_size = self.bindings['images'].shape[0]  # if dynamic, this is instead max batch size
        elif self.xml:  # OpenVINO
            LOGGER.info(f'Loading {w} for OpenVINO inference...')
            check_requirements(('openvino',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
            from openvino.runtime import Core, Layout, get_batch
            ie = Core()
            if not Path(w).is_file():  # if not *.xml
                w = next(Path(w).glob('*.xml'))  # get *.xml file from *_openvino_model dir
            network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
            if network.get_parameters()[0].get_layout().empty:
                network.get_parameters()[0].set_layout(Layout("NCWH"))
            batch_dim = get_batch(network)
            if batch_dim.is_static:
                batch_size = batch_dim.get_length()
            self.executable_network = ie.compile_model(network, device_name="CPU")  # device_name="MYRIAD" for Intel NCS2
            self.output_layer = next(iter(self.executable_network.outputs))

        elif self.tflite:
            LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
            try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                from tflite_runtime.interpreter import Interpreter, load_delegate
            except ImportError:
                import tensorflow as tf
                Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
            self.interpreter = tf.lite.Interpreter(model_path=w)
            self.interpreter.allocate_tensors()
            # Get input and output tensors.
            self.input_details = self.interpreter.get_input_details()
            self.output_details = self.interpreter.get_output_details()

            # Test model on random input data.
            input_data = np.array(np.random.random_sample((1, 256, 128, 3)), dtype=np.float32)
            self.interpreter.set_tensor(self.input_details[0]['index'], input_data)

            self.interpreter.invoke()

            # The function `get_tensor()` returns a copy of the tensor data.
            output_data = self.interpreter.get_tensor(self.output_details[0]['index'])
        else:
            print('This model framework is not supported yet!')
            exit()

    @staticmethod
    def model_type(p='path/to/model.pt'):
        # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
        from trackers.reid_export import export_formats
        sf = list(export_formats().Suffix)  # export suffixes
        check_suffix(p, sf)  # checks
        types = [s in Path(p).name for s in sf]
        return types

    def _preprocess(self, im_batch):

        images = []
        for element in im_batch:
            image = self.to_pil(element)
            image = self.preprocess(image)
            images.append(image)

        images = torch.stack(images, dim=0)
        images = images.to(self.device)

        return images

    def forward(self, im_batch):

        # preprocess batch
        im_batch = self._preprocess(im_batch)

        # batch to half
        if self.fp16 and im_batch.dtype != torch.float16:
            im_batch = im_batch.half()

        # batch processing
        features = []
        if self.pt:
            features = self.model(im_batch)
        elif self.jit:  # TorchScript
            features = self.model(im_batch)
        elif self.onnx:  # ONNX Runtime
            im_batch = im_batch.cpu().numpy()  # torch to numpy
            features = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im_batch})[0]
        elif self.engine:  # TensorRT
            if True and im_batch.shape != self.bindings['images'].shape:
                i_in, i_out = (self.model_.get_binding_index(x) for x in ('images', 'output'))
                self.context.set_binding_shape(i_in, im_batch.shape)  # reshape if dynamic
                self.bindings['images'] = self.bindings['images']._replace(shape=im_batch.shape)
                self.bindings['output'].data.resize_(tuple(self.context.get_binding_shape(i_out)))
            s = self.bindings['images'].shape
            assert im_batch.shape == s, f"input size {im_batch.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
            self.binding_addrs['images'] = int(im_batch.data_ptr())
            self.context.execute_v2(list(self.binding_addrs.values()))
            features = self.bindings['output'].data
        elif self.xml:  # OpenVINO
            im_batch = im_batch.cpu().numpy()  # FP32
            features = self.executable_network([im_batch])[self.output_layer]
        else:
            print('Framework not supported at the moment, we are working on it...')
            exit()

        if isinstance(features, (list, tuple)):
            return self.from_numpy(features[0]) if len(features) == 1 else [self.from_numpy(x) for x in features]
        else:
            return self.from_numpy(features)

    def from_numpy(self, x):
        return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x

    def warmup(self, imgsz=[(256, 128, 3)]):
        # Warmup model by running inference once
        warmup_types = self.pt, self.jit, self.onnx, self.engine, self.tflite
        if any(warmup_types) and self.device.type != 'cpu':
            im = [np.empty(*imgsz).astype(np.uint8)]  # input
            for _ in range(2 if self.jit else 1):  #
|
237 |
+
self.forward(im) # warmup
|
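The forward/warmup pair above is the whole inference path of the ReID backend: _preprocess turns a list of HWC crops into a normalized batch, forward dispatches that batch to whichever backend was loaded, and warmup pushes one dummy batch through on GPU so the first tracked frame does not pay the initialization cost. A minimal usage sketch follows; the class name ReIDDetectMultiBackend, its constructor arguments, and the returned feature size are assumptions for illustration, not part of this commit.

import numpy as np
import torch

# Illustration only -- not part of this commit. The class name, constructor
# signature and weights path are assumptions based on the methods shown above.
from trackers.strongsort.reid_multibackend import ReIDDetectMultiBackend

device = torch.device('cpu')
reid = ReIDDetectMultiBackend(
    weights='trackers/strongsort/deep/checkpoint/osnet_x0_25_msmt17.pth',
    device=device)

reid.warmup()  # per the code above this is a no-op on CPU and one dummy pass on GPU

# three fake person crops in HWC uint8 layout, as cut out of a video frame
crops = [np.zeros((256, 128, 3), dtype=np.uint8) for _ in range(3)]
features = reid.forward(crops)  # one ReID embedding per crop, e.g. shape (3, 512)
print(features.shape)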
trackers/strongsort/sort/__init__.py
ADDED
File without changes
trackers/strongsort/sort/detection.py
ADDED
@@ -0,0 +1,58 @@
# vim: expandtab:ts=4:sw=4
import numpy as np


class Detection(object):
    """
    This class represents a bounding box detection in a single image.

    Parameters
    ----------
    tlwh : array_like
        Bounding box in format `(x, y, w, h)`.
    confidence : float
        Detector confidence score.
    feature : array_like
        A feature vector that describes the object contained in this image.

    Attributes
    ----------
    tlwh : ndarray
        Bounding box in format `(top left x, top left y, width, height)`.
    confidence : ndarray
        Detector confidence score.
    feature : ndarray | NoneType
        A feature vector that describes the object contained in this image.

    """

    def __init__(self, tlwh, confidence, feature):
        self.tlwh = np.asarray(tlwh, dtype=np.float32)
        self.confidence = float(confidence)
        self.feature = np.asarray(feature.cpu(), dtype=np.float32)

    def to_tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    def to_xyah(self):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.
        """
        ret = self.tlwh.copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

def to_xyah_ext(bbox):
    """Convert bounding box to format `(center x, center y, aspect ratio,
    height)`, where the aspect ratio is `width / height`.
    """
    ret = bbox.copy()
    ret[:2] += ret[2:] / 2
    ret[2] /= ret[3]
    return ret
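A quick check of the conversions above (illustration only, not part of the commit; the import path and feature size are assumed): a tlwh box (10, 20, 50, 100) has bottom-right corner (60, 120), so to_tlbr returns (10, 20, 60, 120), and its center is (35, 70) with aspect ratio 50/100 = 0.5, so to_xyah returns (35, 70, 0.5, 100).

import torch

# Usage sketch; the import path is assumed, not part of this commit.
from trackers.strongsort.sort.detection import Detection

feature = torch.zeros(512)  # __init__ calls .cpu() on the feature, so a tensor is expected
det = Detection([10, 20, 50, 100], confidence=0.9, feature=feature)

print(det.to_tlbr())  # [ 10.  20.  60. 120.]  -> (min x, min y, max x, max y)
print(det.to_xyah())  # [ 35.  70.   0.5 100.] -> (center x, center y, w/h, h)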
trackers/strongsort/sort/iou_matching.py
ADDED
@@ -0,0 +1,82 @@
# vim: expandtab:ts=4:sw=4
from __future__ import absolute_import
import numpy as np
from . import linear_assignment


def iou(bbox, candidates):
    """Compute intersection over union.

    Parameters
    ----------
    bbox : ndarray
        A bounding box in format `(top left x, top left y, width, height)`.
    candidates : ndarray
        A matrix of candidate bounding boxes (one per row) in the same format
        as `bbox`.

    Returns
    -------
    ndarray
        The intersection over union in [0, 1] between the `bbox` and each
        candidate. A higher score means a larger fraction of the `bbox` is
        occluded by the candidate.

    """
    bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:]
    candidates_tl = candidates[:, :2]
    candidates_br = candidates[:, :2] + candidates[:, 2:]

    tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
               np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
    br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
               np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
    wh = np.maximum(0., br - tl)

    area_intersection = wh.prod(axis=1)
    area_bbox = bbox[2:].prod()
    area_candidates = candidates[:, 2:].prod(axis=1)
    return area_intersection / (area_bbox + area_candidates - area_intersection)


def iou_cost(tracks, detections, track_indices=None,
             detection_indices=None):
    """An intersection over union distance metric.

    Parameters
    ----------
    tracks : List[deep_sort.track.Track]
        A list of tracks.
    detections : List[deep_sort.detection.Detection]
        A list of detections.
    track_indices : Optional[List[int]]
        A list of indices to tracks that should be matched. Defaults to
        all `tracks`.
    detection_indices : Optional[List[int]]
        A list of indices to detections that should be matched. Defaults
        to all `detections`.

    Returns
    -------
    ndarray
        Returns a cost matrix of shape
        len(track_indices), len(detection_indices) where entry (i, j) is
        `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.

    """
    if track_indices is None:
        track_indices = np.arange(len(tracks))
    if detection_indices is None:
        detection_indices = np.arange(len(detections))

    cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
    for row, track_idx in enumerate(track_indices):
        if tracks[track_idx].time_since_update > 1:
            cost_matrix[row, :] = linear_assignment.INFTY_COST
            continue

        bbox = tracks[track_idx].to_tlwh()
        candidates = np.asarray(
            [detections[i].tlwh for i in detection_indices])
        cost_matrix[row, :] = 1. - iou(bbox, candidates)
    return cost_matrix
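iou above takes one tlwh box and a matrix of candidate tlwh boxes and returns one overlap score per row, which iou_cost then turns into 1 - IoU matching costs. A small numeric check (illustration only; the import path is assumed): a box against itself scores 1.0; against a copy shifted by half its width the intersection is 50 and the union 150, giving 1/3; a disjoint box scores 0.

import numpy as np

# Usage sketch; the import path is assumed, not part of this commit.
from trackers.strongsort.sort.iou_matching import iou

bbox = np.array([0., 0., 10., 10.])            # (top left x, top left y, w, h)
candidates = np.array([[0., 0., 10., 10.],     # identical box
                       [5., 0., 10., 10.],     # half-overlapping box
                       [20., 20., 10., 10.]])  # disjoint box

print(iou(bbox, candidates))  # ~[1.0, 0.3333, 0.0]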
trackers/strongsort/sort/kalman_filter.py
ADDED
@@ -0,0 +1,214 @@
# vim: expandtab:ts=4:sw=4
import numpy as np
import scipy.linalg
"""
Table for the 0.95 quantile of the chi-square distribution with N degrees of
freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
function and used as Mahalanobis gating threshold.
"""
chi2inv95 = {
    1: 3.8415,
    2: 5.9915,
    3: 7.8147,
    4: 9.4877,
    5: 11.070,
    6: 12.592,
    7: 14.067,
    8: 15.507,
    9: 16.919}


class KalmanFilter(object):
    """
    A simple Kalman filter for tracking bounding boxes in image space.
    The 8-dimensional state space
        x, y, a, h, vx, vy, va, vh
    contains the bounding box center position (x, y), aspect ratio a, height h,
    and their respective velocities.
    Object motion follows a constant velocity model. The bounding box location
    (x, y, a, h) is taken as direct observation of the state space (linear
    observation model).
    """

    def __init__(self):
        ndim, dt = 4, 1.

        # Create Kalman filter model matrices.
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt

        self._update_mat = np.eye(ndim, 2 * ndim)

        # Motion and observation uncertainty are chosen relative to the current
        # state estimate. These weights control the amount of uncertainty in
        # the model. This is a bit hacky.
        self._std_weight_position = 1. / 20
        self._std_weight_velocity = 1. / 160

    def initiate(self, measurement):
        """Create track from unassociated measurement.
        Parameters
        ----------
        measurement : ndarray
            Bounding box coordinates (x, y, a, h) with center position (x, y),
            aspect ratio a, and height h.
        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector (8 dimensional) and covariance matrix (8x8
            dimensional) of the new track. Unobserved velocities are initialized
            to 0 mean.
        """
        mean_pos = measurement
        mean_vel = np.zeros_like(mean_pos)
        mean = np.r_[mean_pos, mean_vel]

        std = [
            2 * self._std_weight_position * measurement[0],  # the center point x
            2 * self._std_weight_position * measurement[1],  # the center point y
            1 * measurement[2],  # the ratio of width/height
            2 * self._std_weight_position * measurement[3],  # the height
            10 * self._std_weight_velocity * measurement[0],
            10 * self._std_weight_velocity * measurement[1],
            0.1 * measurement[2],
            10 * self._std_weight_velocity * measurement[3]]
        covariance = np.diag(np.square(std))
        return mean, covariance

    def predict(self, mean, covariance):
        """Run Kalman filter prediction step.
        Parameters
        ----------
        mean : ndarray
            The 8 dimensional mean vector of the object state at the previous
            time step.
        covariance : ndarray
            The 8x8 dimensional covariance matrix of the object state at the
            previous time step.
        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector and covariance matrix of the predicted
            state. Unobserved velocities are initialized to 0 mean.
        """
        std_pos = [
            self._std_weight_position * mean[0],
            self._std_weight_position * mean[1],
            1 * mean[2],
            self._std_weight_position * mean[3]]
        std_vel = [
            self._std_weight_velocity * mean[0],
            self._std_weight_velocity * mean[1],
            0.1 * mean[2],
            self._std_weight_velocity * mean[3]]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))

        mean = np.dot(self._motion_mat, mean)
        covariance = np.linalg.multi_dot((
            self._motion_mat, covariance, self._motion_mat.T)) + motion_cov

        return mean, covariance

    def project(self, mean, covariance, confidence=.0):
        """Project state distribution to measurement space.
        Parameters
        ----------
        mean : ndarray
            The state's mean vector (8 dimensional array).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).
        confidence: (dyh) detection box confidence score
        Returns
        -------
        (ndarray, ndarray)
            Returns the projected mean and covariance matrix of the given state
            estimate.
        """
        std = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-1,
            self._std_weight_position * mean[3]]


        std = [(1 - confidence) * x for x in std]

        innovation_cov = np.diag(np.square(std))

        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot((
            self._update_mat, covariance, self._update_mat.T))
        return mean, covariance + innovation_cov

    def update(self, mean, covariance, measurement, confidence=.0):
        """Run Kalman filter correction step.
        Parameters
        ----------
        mean : ndarray
            The predicted state's mean vector (8 dimensional).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).
        measurement : ndarray
            The 4 dimensional measurement vector (x, y, a, h), where (x, y)
            is the center position, a the aspect ratio, and h the height of the
            bounding box.
        confidence: (dyh) detection box confidence score
        Returns
        -------
        (ndarray, ndarray)
            Returns the measurement-corrected state distribution.
        """
        projected_mean, projected_cov = self.project(mean, covariance, confidence)

        chol_factor, lower = scipy.linalg.cho_factor(
            projected_cov, lower=True, check_finite=False)
        kalman_gain = scipy.linalg.cho_solve(
            (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
            check_finite=False).T
        innovation = measurement - projected_mean

        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot((
            kalman_gain, projected_cov, kalman_gain.T))
        return new_mean, new_covariance

    def gating_distance(self, mean, covariance, measurements,
                        only_position=False):
        """Compute gating distance between state distribution and measurements.
        A suitable distance threshold can be obtained from `chi2inv95`. If
        `only_position` is False, the chi-square distribution has 4 degrees of
        freedom, otherwise 2.
        Parameters
        ----------
        mean : ndarray
            Mean vector over the state distribution (8 dimensional).
        covariance : ndarray
            Covariance of the state distribution (8x8 dimensional).
        measurements : ndarray
            An Nx4 dimensional matrix of N measurements, each in
            format (x, y, a, h) where (x, y) is the bounding box center
            position, a the aspect ratio, and h the height.
        only_position : Optional[bool]
            If True, distance computation is done with respect to the bounding
            box center position only.
        Returns
        -------
        ndarray
            Returns an array of length N, where the i-th element contains the
            squared Mahalanobis distance between (mean, covariance) and
            `measurements[i]`.
        """
        mean, covariance = self.project(mean, covariance)

        if only_position:
            mean, covariance = mean[:2], covariance[:2, :2]
            measurements = measurements[:, :2]

        cholesky_factor = np.linalg.cholesky(covariance)
        d = measurements - mean
        z = scipy.linalg.solve_triangular(
            cholesky_factor, d.T, lower=True, check_finite=False,
            overwrite_b=True)
        squared_maha = np.sum(z * z, axis=0)
        return squared_maha
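The filter above follows the usual DeepSORT recipe with one StrongSORT twist: project scales the measurement noise by (1 - confidence), so high-confidence detections correct the state more strongly (the NSA-Kalman idea). A typical cycle is initiate on the first detection, predict once per frame, update with each matched measurement, and gating_distance together with chi2inv95 to reject implausible matches. A usage sketch follows (illustration only; the import path is assumed, not part of this commit):

import numpy as np

# Usage sketch; the import path is assumed, not part of this commit.
from trackers.strongsort.sort.kalman_filter import KalmanFilter, chi2inv95

kf = KalmanFilter()

# First detection in (center x, center y, aspect ratio, height)
mean, cov = kf.initiate(np.array([320., 240., 0.5, 100.]))

# Next frame: propagate, then correct with a new, fairly confident detection.
mean, cov = kf.predict(mean, cov)
mean, cov = kf.update(mean, cov, np.array([324., 242., 0.5, 102.]), confidence=0.9)

# Gate candidate measurements with the squared Mahalanobis distance.
candidates = np.array([[325., 243., 0.5, 101.],
                       [600., 400., 0.5, 80.]])
d2 = kf.gating_distance(mean, cov, candidates)
print(d2 < chi2inv95[4])  # e.g. [ True False ]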