thai thong committed · Commit 9e56ba5 · Parent(s): b881c5d

add file to build StrongSORT

This view is limited to 50 files because the commit contains too many changes.
- app.py +71 -29
- detect_strongsort.py +392 -0
- strong_sort/.gitignore +13 -0
- strong_sort/LICENSE +21 -0
- strong_sort/README.md +137 -0
- strong_sort/__init__.py +11 -0
- strong_sort/configs/strong_sort.yaml +10 -0
- strong_sort/deep/__init__.py +0 -0
- strong_sort/deep/checkpoint/.gitkeep +0 -0
- strong_sort/deep/reid/.flake8 +18 -0
- strong_sort/deep/reid/.gitignore +140 -0
- strong_sort/deep/reid/.isort.cfg +10 -0
- strong_sort/deep/reid/.style.yapf +7 -0
- strong_sort/deep/reid/LICENSE +21 -0
- strong_sort/deep/reid/README.rst +317 -0
- strong_sort/deep/reid/configs/im_osnet_ain_x1_0_softmax_256x128_amsgrad_cosine.yaml +35 -0
- strong_sort/deep/reid/configs/im_osnet_ibn_x1_0_softmax_256x128_amsgrad.yaml +36 -0
- strong_sort/deep/reid/configs/im_osnet_x0_25_softmax_256x128_amsgrad.yaml +36 -0
- strong_sort/deep/reid/configs/im_osnet_x0_5_softmax_256x128_amsgrad.yaml +36 -0
- strong_sort/deep/reid/configs/im_osnet_x0_75_softmax_256x128_amsgrad.yaml +36 -0
- strong_sort/deep/reid/configs/im_osnet_x1_0_softmax_256x128_amsgrad.yaml +36 -0
- strong_sort/deep/reid/configs/im_osnet_x1_0_softmax_256x128_amsgrad_cosine.yaml +35 -0
- strong_sort/deep/reid/configs/im_r50_softmax_256x128_amsgrad.yaml +36 -0
- strong_sort/deep/reid/configs/im_r50fc512_softmax_256x128_amsgrad.yaml +36 -0
- strong_sort/deep/reid/docs/AWESOME_REID.md +69 -0
- strong_sort/deep/reid/docs/MODEL_ZOO.md +93 -0
- strong_sort/deep/reid/docs/Makefile +19 -0
- strong_sort/deep/reid/docs/conf.py +181 -0
- strong_sort/deep/reid/docs/datasets.rst +264 -0
- strong_sort/deep/reid/docs/evaluation.rst +21 -0
- strong_sort/deep/reid/docs/figures/actmap.jpg +0 -0
- strong_sort/deep/reid/docs/figures/ranking_results.jpg +0 -0
- strong_sort/deep/reid/docs/index.rst +35 -0
- strong_sort/deep/reid/docs/pkg/data.rst +86 -0
- strong_sort/deep/reid/docs/pkg/engine.rst +31 -0
- strong_sort/deep/reid/docs/pkg/losses.rst +18 -0
- strong_sort/deep/reid/docs/pkg/metrics.rst +25 -0
- strong_sort/deep/reid/docs/pkg/models.rst +43 -0
- strong_sort/deep/reid/docs/pkg/optim.rst +18 -0
- strong_sort/deep/reid/docs/pkg/utils.rst +41 -0
- strong_sort/deep/reid/docs/user_guide.rst +351 -0
- strong_sort/deep/reid/linter.sh +11 -0
- strong_sort/deep/reid/projects/DML/README.md +16 -0
- strong_sort/deep/reid/projects/DML/default_config.py +207 -0
- strong_sort/deep/reid/projects/DML/dml.py +149 -0
- strong_sort/deep/reid/projects/DML/im_osnet_x1_0_dml_256x128_amsgrad_cosine.yaml +42 -0
- strong_sort/deep/reid/projects/DML/main.py +166 -0
- strong_sort/deep/reid/projects/OSNet_AIN/README.md +38 -0
- strong_sort/deep/reid/projects/OSNet_AIN/default_config.py +210 -0
- strong_sort/deep/reid/projects/OSNet_AIN/main.py +145 -0
app.py CHANGED (+71 -29)

The updated app.py (regions not shown in the truncated diff view are marked with elision comments):

```python
import gradio as gr
from detect_strongsort import run
import os
import threading

should_continue = True


def yolov9_inference(model_id, image_size, conf_threshold, iou_threshold, img_path=None, vid_path=None):
    global should_continue
    img_extensions = ['.jpg', '.jpeg', '.png', '.gif']  # Add more image extensions if needed
    vid_extensions = ['.mp4', '.avi', '.mov', '.mkv']  # Add more video extensions if needed

    input_path = None
    if img_path is not None:
        _, img_extension = os.path.splitext(img_path)
        if img_extension.lower() in img_extensions:
            input_path = img_path
    elif vid_path is not None:
        _, vid_extension = os.path.splitext(vid_path)
        if vid_extension.lower() in vid_extensions:
            input_path = vid_path

    output_path = run(yolo_weights=model_id, imgsz=(image_size, image_size),
                      conf_thres=conf_threshold, iou_thres=iou_threshold,
                      source=input_path, device='cpu',
                      strong_sort_weights="osnet_x0_25_msmt17.pt", hide_conf=True)
    # Assuming output_path is the path to the output file
    _, output_extension = os.path.splitext(output_path)
    if output_extension.lower() in img_extensions:
        output_image = output_path  # Load the image file here
        output_video = None
    elif output_extension.lower() in vid_extensions:
        output_image = None
        output_video = output_path  # Load the video file here

    return output_image, output_video, output_path


def inference(model_id, image_size, conf_threshold, iou_threshold, img_path=None, vid_path=None):
    global should_continue
    should_continue = True
    output_image, output_video, output_path = yolov9_inference(
        model_id, image_size, conf_threshold, iou_threshold, img_path, vid_path)
    return output_image, output_video, output_path


def stop_processing():
    global should_continue
    should_continue = False
    return "Stop..."


def app():
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                gr.HTML("<h2>Input Parameters</h2>")
                img_path = gr.File(label="Image")
                vid_path = gr.File(label="Video")
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "last_best_model.pt",
                        "best_model-converted.pt"
                    ],
                    value="./last_best_model.pt"
                )
                image_size = gr.Slider(
                    label="Image Size",
                    # remaining slider parameters elided in the truncated diff view
                )
                # conf_threshold and iou_threshold sliders are defined here;
                # the diff view only shows the tail of that block (value=0.5)
                yolov9_infer = gr.Button(value="Inference")
                stop_button = gr.Button(value="Stop")
            with gr.Column():
                gr.HTML("<h2>Output</h2>")
                output_image = gr.Image(type="numpy", label="Output Image")
                output_video = gr.Video(label="Output Video")
                output_path = gr.Textbox(label="Output path")

        yolov9_infer.click(
            fn=inference,
            inputs=[
                model_id,
                image_size,
                conf_threshold,
                iou_threshold,
                img_path,
                vid_path
            ],
            outputs=[output_image, output_video, output_path],
        )
        stop_button.click(stop_processing)


gradio_app = gr.Blocks()
with gradio_app:
    # opening lines elided in the diff view; the visible context ends a gr.HTML block:
    gr.HTML("""
    <h1>
    YOLOv9: Real-time Object Detection
    </h1>
    """)
    css = """
    body {
        background-color: #f0f0f0;
    }
    h1 {
        color: #4CAF50;
    }
    """
    with gr.Row():
        with gr.Column():
            app()

gradio_app.launch(debug=True)
```
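Note that `stop_processing` only flips the module-level `should_continue` flag; nothing in the diff shown here ever polls it. A minimal sketch of how a frame-processing loop could honor the flag (the `frames` iterable and `process` callback are hypothetical, not part of the commit):

```python
# Hypothetical sketch: cooperative cancellation via the should_continue flag.
# `frames` and `process` are placeholders standing in for a video reader
# and the per-frame detection/tracking work.
def process_frames(frames, process):
    global should_continue
    for frame in frames:
        if not should_continue:  # set to False by stop_processing()
            break
        process(frame)
```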
detect_strongsort.py ADDED (+392 lines)

```python
import argparse

import os
# limit the number of cpus used by high performance libraries
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
import platform
import sys
import numpy as np
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from numpy import random
from time import time


FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv9 + StrongSORT root directory
WEIGHTS = ROOT / 'weights'
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if str(ROOT / 'yolov9') not in sys.path:
    sys.path.append(str(ROOT / 'yolov9'))  # add yolov9 ROOT to PATH
if str(ROOT / 'strong_sort') not in sys.path:
    sys.path.append(str(ROOT / 'strong_sort'))  # add strong_sort ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.experimental import attempt_load
from models.common import DetectMultiBackend
from utils.dataloaders import LoadImages, LoadStreams, LoadScreenshots
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
                           increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
from utils.torch_utils import select_device, time_sync, smart_inference_mode
from utils.plots import Annotator, colors, save_one_box
from strong_sort.utils.parser import get_config
from strong_sort.strong_sort import StrongSORT


VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv'  # include video suffixes


def plot_one_box(x, img, color=None, label=None, line_thickness=3):
    # Plots one bounding box on image img
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)


@smart_inference_mode()
def run(
        source='0',
        data=ROOT / 'data/coco.yaml',  # dataset.yaml path
        yolo_weights=WEIGHTS / 'yolo.pt',  # model.pt path(s)
        strong_sort_weights=WEIGHTS / 'osnet_x0_25_msmt17.pt',  # model.pt path
        config_strongsort=ROOT / 'strong_sort/configs/strong_sort.yaml',
        imgsz=(640, 640),  # inference size (height, width)
        conf_thres=0.25,  # confidence threshold
        iou_thres=0.45,  # NMS IOU threshold
        max_det=1000,  # maximum detections per image
        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        view_img=False,  # show results
        save_txt=False,  # save results to *.txt
        save_conf=False,  # save confidences in --save-txt labels
        save_crop=False,  # save cropped prediction boxes
        nosave=False,  # do not save images/videos
        classes=None,  # filter by class: --class 0, or --class 0 2 3
        agnostic_nms=False,  # class-agnostic NMS
        augment=False,  # augmented inference
        visualize=False,  # visualize features
        update=False,  # update all models
        project=ROOT / 'runs/track',  # save results to project/name
        name='exp',  # save results to project/name
        exist_ok=False,  # existing project/name ok, do not increment
        line_thickness=3,  # bounding box thickness (pixels)
        hide_labels=False,  # hide labels
        hide_conf=False,  # hide confidences
        half=False,  # use FP16 half-precision inference
        dnn=False,  # use OpenCV DNN for ONNX inference
        vid_stride=1,  # video frame-rate stride
):
    source = str(source)
    save_img = not nosave and not source.endswith('.txt')  # save inference images
    is_file = Path(source).suffix[1:] in VID_FORMATS
    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
    webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
    screenshot = source.lower().startswith('screen')

    if is_url and is_file:
        source = check_file(source)  # download

    # Directories
    if not isinstance(yolo_weights, list):  # single yolo model
        exp_name = Path(yolo_weights).stem
    elif type(yolo_weights) is list and len(yolo_weights) == 1:  # single model after --yolo_weights
        exp_name = Path(yolo_weights[0]).stem
        yolo_weights = Path(yolo_weights[0])
    else:  # multiple models after --yolo_weights
        exp_name = 'ensemble'
    exp_name = name if name else exp_name + "_" + Path(strong_sort_weights).stem
    save_dir = increment_path(Path(project) / exp_name, exist_ok=exist_ok)  # increment run
    save_dir = Path(save_dir)
    (save_dir / 'tracks' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Load model
    device = select_device(device)
    model = DetectMultiBackend(yolo_weights, device=device, dnn=dnn, data=data, fp16=half)
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size(imgsz, s=stride)  # check image size

    # Dataloader
    bs = 1  # batch_size
    if webcam:
        view_img = check_imshow(warn=True)
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
        bs = len(dataset)
    elif screenshot:
        dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
    else:
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
    vid_path, vid_writer, txt_path = [None] * bs, [None] * bs, [None] * bs

    # Initialize StrongSORT
    cfg = get_config()
    cfg.merge_from_file(config_strongsort)

    # Create as many StrongSORT instances as there are video sources
    strongsort_list = []
    for i in range(bs):
        strongsort_list.append(
            StrongSORT(
                strong_sort_weights,
                device,
                half,
                # max_dist=cfg.STRONGSORT.MAX_DIST,
                max_iou_distance=cfg.STRONGSORT.MAX_IOU_DISTANCE,
                max_age=cfg.STRONGSORT.MAX_AGE,
                n_init=cfg.STRONGSORT.N_INIT,
                nn_budget=cfg.STRONGSORT.NN_BUDGET,
                mc_lambda=cfg.STRONGSORT.MC_LAMBDA,
                ema_alpha=cfg.STRONGSORT.EMA_ALPHA,
            )
        )
        strongsort_list[i].model.warmup()
    outputs = [None] * bs

    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]

    # Run tracking
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
    seen, windows, dt, sdt = 0, [], (Profile(), Profile(), Profile(), Profile()), [0.0, 0.0, 0.0, 0.0]
    curr_frames, prev_frames = [None] * bs, [None] * bs
    for frame_idx, (path, im, im0s, vid_cap, s) in enumerate(dataset):
        t1 = time_sync()
        with dt[0]:
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # expand for batch dim
            t2 = time_sync()
            sdt[0] += t2 - t1

        # Inference
        with dt[1]:
            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
            pred = model(im, augment=augment, visualize=visualize)
            pred = pred[0][1]
            t3 = time_sync()
            sdt[1] += t3 - t2

        # Apply NMS
        with dt[2]:
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
            sdt[2] += time_sync() - t3

        # Second-stage classifier (optional)
        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            seen += 1
            if webcam:  # bs >= 1
                p, im0, _ = path[i], im0s[i].copy(), dataset.count
                p = Path(p)  # to Path
                s += f'{i}: '
                txt_file_name = p.stem + f'_{i}'  # unique text file name
                save_path = str(save_dir / p.stem) + f'_{i}'  # unique video file name
            else:
                p, im0, _ = path, im0s.copy(), getattr(dataset, 'frame', 0)
                p = Path(p)  # to Path
                if source.endswith(VID_FORMATS):  # video file
                    txt_file_name = p.stem
                    save_path = str(save_dir / p.name)  # im.jpg, vid.mp4, ...
                else:  # folder with imgs
                    txt_file_name = p.parent.name  # get folder name containing current img
                    save_path = str(save_dir / p.parent.name)

            curr_frames[i] = im0

            txt_path = str(save_dir / 'tracks' / txt_file_name)  # im.txt
            s += '%gx%g ' % im.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            imc = im0.copy() if save_crop else im0  # for save_crop
            annotator = Annotator(im0, line_width=line_thickness, example=str(names))

            if cfg.STRONGSORT.ECC:  # camera motion compensation
                strongsort_list[i].tracker.camera_update(prev_frames[i], curr_frames[i])

            if det is not None and len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                xywhs = xyxy2xywh(det[:, 0:4])
                confs = det[:, 4]
                clss = det[:, 5]

                # Pass detections to StrongSORT
                t4 = time_sync()
                outputs[i] = strongsort_list[i].update(xywhs.cpu(), confs.cpu(), clss.cpu(), im0)
                t5 = time_sync()
                sdt[3] += t5 - t4

                # Write results
                for j, (output, conf) in enumerate(zip(outputs[i], confs)):
                    xyxy = output[0:4]
                    id = output[4]
                    cls = output[5]
                    if save_txt:  # write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (int(p.stem), frame_idx, id, cls, *xywh, conf) if save_conf else (p.stem, frame_idx, cls, *xywh)  # label format
                        with open(txt_path + '.txt', 'a') as file:
                            file.write(('%g ' * len(line) + '\n') % line)

                    if save_img or save_crop or view_img:  # add bbox to image
                        c = int(cls)  # integer class
                        label = None if hide_labels else (names[c] if hide_conf else f' {id} {names[c]} {conf:.2f}')
                        plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=2)
                        if save_crop:
                            save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)

                # # draw boxes for visualization
                # if len(outputs[i]) > 0:
                #     for j, (output, conf) in enumerate(zip(outputs[i], confs)):
                #         bboxes = output[0:4]
                #         id = output[4]
                #         cls = output[5]
                #         if save_txt:
                #             # to MOT format
                #             bbox_left = output[0]
                #             bbox_top = output[1]
                #             bbox_w = output[2] - output[0]
                #             bbox_h = output[3] - output[1]
                #             # format video_name frame id xmin ymin width height score class
                #             with open(txt_path + '.txt', 'a') as file:
                #                 file.write(f'{p.stem} {frame_idx} {id} {bbox_left} {bbox_top} {bbox_w} {bbox_h} {conf:.2f} {cls}\n')
                #         if save_img or save_crop or view_img:  # add bbox to image
                #             c = int(cls)  # integer class
                #             id = int(id)  # integer id
                #             label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                #             plot_one_box(bboxes, im0, label=label, color=colors[int(cls)], line_thickness=2)
                #             if save_crop:
                #                 txt_file_name = txt_file_name if (isinstance(path, list) and len(path) > 1) else ''
                #                 save_one_box(bboxes, imc, file=save_dir / 'crops' / txt_file_name / names[c] / f'{id}' / f'{p.stem}.jpg', BGR=True)

                print(f'{s}Done. YOLO:({t3 - t2:.3f}s), StrongSORT:({t5 - t4:.3f}s)')

            else:
                strongsort_list[i].increment_ages()
                print('No detections')

            # Stream results
            im0 = annotator.result()
            if view_img:
                if platform.system() == 'Linux' and p not in windows:
                    windows.append(p)
                    cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
                    cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1 millisecond

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'image':
                    cv2.imwrite(save_path, im0)
                else:  # 'video' or 'stream'
                    if vid_path[i] != save_path:  # new video
                        vid_path[i] = save_path
                        if isinstance(vid_writer[i], cv2.VideoWriter):
                            vid_writer[i].release()  # release previous video writer
                        if vid_cap:  # video
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:  # stream
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                        save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), fps, (w, h))
                    vid_writer[i].write(im0)

            prev_frames[i] = curr_frames[i]

        # Print time (inference-only)
        LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")

    # Print results
    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape, %.1fms StrongSORT' % tuple(1E3 * x / seen for x in sdt))
    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    if update:
        strip_optimizer(yolo_weights[0])  # update model (to fix SourceChangeWarning)
    return save_path


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--yolo-weights', nargs='+', type=str, default=WEIGHTS / 'yolov9.pt', help='model.pt path(s)')
    parser.add_argument('--strong-sort-weights', type=str, default=WEIGHTS / 'osnet_x0_25_msmt17.pt')
    parser.add_argument('--config-strongsort', type=str, default='strong_sort/configs/strong_sort.yaml')
    parser.add_argument('--source', type=str, default='0', help='file/dir/URL/glob, 0 for webcam')
    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
    parser.add_argument('--conf-thres', type=float, default=0.5, help='confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='NMS IoU threshold')
    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='show results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    # class 0 is person, 1 is bicycle, 2 is car... 79 is oven
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--visualize', action='store_true', help='visualize features')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default=ROOT / 'runs/track', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
    parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
    return opt


def main(opt):
    # check_requirements(requirements=ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))
    run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)
```
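app.py drives this file by calling `run()` directly, so the same entry point works from any script. A minimal sketch mirroring that call (the weight and video file names are placeholders):

```python
# Hypothetical usage of detect_strongsort.run(); file names are placeholders.
from detect_strongsort import run

save_path = run(
    yolo_weights='last_best_model.pt',             # YOLOv9 detector weights
    strong_sort_weights='osnet_x0_25_msmt17.pt',   # OSNet re-ID weights
    source='input.mp4',                            # video/image path, or '0' for webcam
    imgsz=(640, 640),
    conf_thres=0.5,
    iou_thres=0.5,
    device='cpu',
)
print(f'Annotated output written to {save_path}')
```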
strong_sort/.gitignore ADDED (+13 lines)

```
# Folders
__pycache__/
build/
*.egg-info


# Files
*.weights
*.t7
*.mp4
*.avi
*.so
*.txt
```
strong_sort/LICENSE ADDED (+21 lines)

```
MIT License

Copyright (c) 2020 Ziqiang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
strong_sort/README.md ADDED (+137 lines)

# Deep Sort with PyTorch

![](demo/demo.gif)

## Update (1-1-2020)
Changes
- fix bugs
- refactor code
- accelerate detection by adding nms on gpu

## Latest Update (07-22)
Changes
- bug fix (Thanks @JieChen91 and @yingsen1 for bug reporting).
- using a batch for feature extraction on each frame, which leads to a small speed-up.
- code improvement.

Further improvement directions
- Train the detector on a specific dataset rather than the official one.
- Retrain the RE-ID model on a pedestrian dataset for better performance.
- Replace the YOLOv3 detector with more advanced ones.

**Any contributions to this repository are welcome!**


## Introduction
This is an implementation of the MOT tracking algorithm Deep SORT. Deep SORT is basically the same as SORT, but adds a CNN model to extract features from the image region of a person bounded by a detector. This CNN model is in fact a RE-ID model; the detector used in the [PAPER](https://arxiv.org/abs/1703.07402) is FasterRCNN, and the original source code is [HERE](https://github.com/nwojke/deep_sort).
However, in the original code the CNN model is implemented with TensorFlow, which I'm not familiar with. So I re-implemented the CNN feature-extraction model with PyTorch and changed the CNN model a little bit. Also, I use **YOLOv3** to generate bboxes instead of FasterRCNN.

## Dependencies
- python 3 (python 2 not sure)
- numpy
- scipy
- opencv-python
- sklearn
- torch >= 0.4
- torchvision >= 0.1
- pillow
- vizer
- edict

## Quick Start
0. Check that all dependencies are installed
```bash
pip install -r requirements.txt
```
For users in China, you can specify a PyPI mirror to speed up installation, e.g.:
```bash
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```

1. Clone this repository
```
git clone git@github.com:ZQPei/deep_sort_pytorch.git
```

2. Download YOLOv3 parameters
```
cd detector/YOLOv3/weight/
wget https://pjreddie.com/media/files/yolov3.weights
wget https://pjreddie.com/media/files/yolov3-tiny.weights
cd ../../../
```

3. Download deepsort parameters ckpt.t7
```
cd deep_sort/deep/checkpoint
# download ckpt.t7 from
https://drive.google.com/drive/folders/1xhG0kRH1EX5B9_Iz8gQJb7UNnn_riXi6 to this folder
cd ../../../
```

4. Compile the nms module
```bash
cd detector/YOLOv3/nms
sh build.sh
cd ../../..
```

Notice:
If compilation fails, the simplest fix is to **upgrade your pytorch >= 1.1 and torchvision >= 0.3**, which avoids the troublesome compilation problems that are most likely caused by either a too-old gcc version or missing libraries.

5. Run the demo
```
usage: python yolov3_deepsort.py VIDEO_PATH
[--help]
[--frame_interval FRAME_INTERVAL]
[--config_detection CONFIG_DETECTION]
[--config_deepsort CONFIG_DEEPSORT]
[--display]
[--display_width DISPLAY_WIDTH]
[--display_height DISPLAY_HEIGHT]
[--save_path SAVE_PATH]
[--cpu]

# yolov3 + deepsort
python yolov3_deepsort.py [VIDEO_PATH]

# yolov3_tiny + deepsort
python yolov3_deepsort.py [VIDEO_PATH] --config_detection ./configs/yolov3_tiny.yaml

# yolov3 + deepsort on webcam
python3 yolov3_deepsort.py /dev/video0 --camera 0

# yolov3_tiny + deepsort on webcam
python3 yolov3_deepsort.py /dev/video0 --config_detection ./configs/yolov3_tiny.yaml --camera 0
```
Use `--display` to enable the display.
Results will be saved to `./output/results.avi` and `./output/results.txt`.

All files above can also be accessed from BaiduDisk!
Link: [BaiduDisk](https://pan.baidu.com/s/1YJ1iPpdFTlUyLFoonYvozg)
Password: fbuw

## Training the RE-ID model
The original model used in the paper is in original_model.py, and its parameters are here: [original_ckpt.t7](https://drive.google.com/drive/folders/1xhG0kRH1EX5B9_Iz8gQJb7UNnn_riXi6).

To train the model, first you need to download the [Market1501](http://www.liangzheng.com.cn/Project/project_reid.html) dataset or the [Mars](http://www.liangzheng.com.cn/Project/project_mars.html) dataset.

Then you can use [train.py](deep_sort/deep/train.py) to train your own parameters, and evaluate them using [test.py](deep_sort/deep/test.py) and [evaluate.py](deep_sort/deep/evalute.py).
![train.jpg](deep_sort/deep/train.jpg)

## Demo videos and images
[demo.avi](https://drive.google.com/drive/folders/1xhG0kRH1EX5B9_Iz8gQJb7UNnn_riXi6)
[demo2.avi](https://drive.google.com/drive/folders/1xhG0kRH1EX5B9_Iz8gQJb7UNnn_riXi6)

![1.jpg](demo/1.jpg)
![2.jpg](demo/2.jpg)


## References
- paper: [Simple Online and Realtime Tracking with a Deep Association Metric](https://arxiv.org/abs/1703.07402)
- code: [nwojke/deep_sort](https://github.com/nwojke/deep_sort)
- paper: [YOLOv3](https://pjreddie.com/media/files/papers/YOLOv3.pdf)
- code: [Joseph Redmon/yolov3](https://pjreddie.com/darknet/yolo/)
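The README describes a two-stage pipeline: a detector proposes boxes, a re-ID CNN embeds each crop, and the tracker associates embeddings across frames. A minimal per-frame sketch of that flow, mirroring the `strongsort_list[i].update(...)` call in detect_strongsort.py above (`frames`, `detector`, and `tracker` are placeholders, not objects from this commit):

```python
# Hypothetical per-frame tracking loop; `detector` and `tracker` stand in for
# the YOLO model and the Deep SORT / StrongSORT instance, respectively.
for frame in frames:
    boxes_xywh, confidences, class_ids = detector(frame)  # detection stage
    # the tracker extracts re-ID features from each box crop internally,
    # associates them with existing tracks, and returns rows of
    # [x1, y1, x2, y2, track_id, class_id]
    tracks = tracker.update(boxes_xywh, confidences, class_ids, frame)
```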
strong_sort/__init__.py ADDED (+11 lines)

```python
from .strong_sort import StrongSORT


__all__ = ['StrongSORT', 'build_tracker']


def build_tracker(cfg, use_cuda):
    return StrongSORT(
        cfg.STRONGSORT.REID_CKPT,
        max_dist=cfg.STRONGSORT.MAX_DIST,
        min_confidence=cfg.STRONGSORT.MIN_CONFIDENCE,
        nms_max_overlap=cfg.STRONGSORT.NMS_MAX_OVERLAP,
        max_iou_distance=cfg.STRONGSORT.MAX_IOU_DISTANCE,
        max_age=cfg.STRONGSORT.MAX_AGE,
        n_init=cfg.STRONGSORT.N_INIT,
        nn_budget=cfg.STRONGSORT.NN_BUDGET,
        use_cuda=use_cuda)
```
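A sketch of how `build_tracker` could be called with the config parser shipped in this commit. Note that `REID_CKPT`, `MIN_CONFIDENCE`, and `NMS_MAX_OVERLAP` are referenced by this factory but do not appear in the bundled strong_sort.yaml (shown below), so they would have to be added to the config file for this to run:

```python
# Hypothetical use of build_tracker; assumes the yaml has been extended with
# REID_CKPT, MIN_CONFIDENCE, and NMS_MAX_OVERLAP entries (not in the shipped file).
from strong_sort import build_tracker
from strong_sort.utils.parser import get_config

cfg = get_config()
cfg.merge_from_file('strong_sort/configs/strong_sort.yaml')
tracker = build_tracker(cfg, use_cuda=False)
```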
strong_sort/configs/strong_sort.yaml ADDED (+10 lines)

```yaml
STRONGSORT:
  ECC: True              # activate camera motion compensation
  MC_LAMBDA: 0.995       # matching with both appearance (1 - MC_LAMBDA) and motion cost
  EMA_ALPHA: 0.9         # updates appearance state in an exponential moving average manner
  MAX_DIST: 0.2          # matching threshold; samples with larger distance are considered an invalid match
  MAX_IOU_DISTANCE: 0.7  # gating threshold; associations with cost larger than this value are disregarded
  MAX_AGE: 30            # maximum number of consecutive misses before a track is deleted
  N_INIT: 1              # number of frames that a track remains in the initialization phase
  NN_BUDGET: 100         # maximum size of the appearance-descriptor gallery
```
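detect_strongsort.py reads these keys through the yacs-style parser bundled with the tracker; for example, the ECC flag gates the per-frame `tracker.camera_update(...)` call shown earlier. A minimal sketch of loading the file (assumes it is run from the repository root):

```python
# Minimal sketch: load the StrongSORT config the same way detect_strongsort.py does.
from strong_sort.utils.parser import get_config

cfg = get_config()
cfg.merge_from_file('strong_sort/configs/strong_sort.yaml')
print(cfg.STRONGSORT.ECC)      # True -> camera motion compensation applied per frame
print(cfg.STRONGSORT.MAX_AGE)  # 30 consecutive misses before a track is deleted
```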
strong_sort/deep/__init__.py ADDED (empty file)

strong_sort/deep/checkpoint/.gitkeep ADDED (empty file)
strong_sort/deep/reid/.flake8 ADDED (+18 lines)

```
[flake8]
ignore =
    # At least two spaces before inline comment
    E261,
    # Line lengths are recommended to be no greater than 79 characters
    E501,
    # Missing whitespace around arithmetic operator
    E226,
    # Blank line contains whitespace
    W293,
    # Do not use bare 'except'
    E722,
    # Line break after binary operator
    W504,
    # isort found an import in the wrong position
    I001
max-line-length = 79
exclude = __init__.py, build, torchreid/metrics/rank_cylib/
```
strong_sort/deep/reid/.gitignore ADDED (+140 lines)

```
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Cython eval code
*.c
*.html

# OS X
.DS_Store
.Spotlight-V100
.Trashes
._*

# ReID
reid-data/
log/
saved-models/
model-zoo/
debug*
```
strong_sort/deep/reid/.isort.cfg ADDED (+10 lines)

```
[isort]
line_length=79
multi_line_output=3
length_sort=true
known_standard_library=numpy,setuptools
known_myself=torchreid
known_third_party=matplotlib,cv2,torch,torchvision,PIL,yacs
no_lines_before=STDLIB,THIRDPARTY
sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER
default_section=FIRSTPARTY
```
strong_sort/deep/reid/.style.yapf ADDED (+7 lines)

```
[style]
BASED_ON_STYLE = pep8
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
DEDENT_CLOSING_BRACKETS = true
SPACES_BEFORE_COMMENT = 1
ARITHMETIC_PRECEDENCE_INDICATION = true
```
strong_sort/deep/reid/LICENSE ADDED (+21 lines)

```
MIT License

Copyright (c) 2018 Kaiyang Zhou

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
strong_sort/deep/reid/README.rst
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Torchreid
|
2 |
+
===========
|
3 |
+
Torchreid is a library for deep-learning person re-identification, written in `PyTorch <https://pytorch.org/>`_ and developed for our ICCV'19 project, `Omni-Scale Feature Learning for Person Re-Identification <https://arxiv.org/abs/1905.00953>`_.
|
4 |
+
|
5 |
+
It features:
|
6 |
+
|
7 |
+
- multi-GPU training
|
8 |
+
- support both image- and video-reid
|
9 |
+
- end-to-end training and evaluation
|
10 |
+
- incredibly easy preparation of reid datasets
|
11 |
+
- multi-dataset training
|
12 |
+
- cross-dataset evaluation
|
13 |
+
- standard protocol used by most research papers
|
14 |
+
- highly extensible (easy to add models, datasets, training methods, etc.)
|
15 |
+
- implementations of state-of-the-art deep reid models
|
16 |
+
- access to pretrained reid models
|
17 |
+
- advanced training techniques
|
18 |
+
- visualization tools (tensorboard, ranks, etc.)
|
19 |
+
|
20 |
+
|
21 |
+
Code: https://github.com/KaiyangZhou/deep-person-reid.
|
22 |
+
|
23 |
+
Documentation: https://kaiyangzhou.github.io/deep-person-reid/.
|
24 |
+
|
25 |
+
How-to instructions: https://kaiyangzhou.github.io/deep-person-reid/user_guide.
|
26 |
+
|
27 |
+
Model zoo: https://kaiyangzhou.github.io/deep-person-reid/MODEL_ZOO.
|
28 |
+
|
29 |
+
Tech report: https://arxiv.org/abs/1910.10093.
|
30 |
+
|
31 |
+
You can find some research projects that are built on top of Torchreid `here <https://github.com/KaiyangZhou/deep-person-reid/tree/master/projects>`_.
|
32 |
+
|
33 |
+
|
34 |
+
What's new
|
35 |
+
------------
|
36 |
+
- [Aug 2021] We have released the ImageNet-pretrained models of ``osnet_ain_x0_75``, ``osnet_ain_x0_5`` and ``osnet_ain_x0_25``. The pretraining setup follows `pycls <https://github.com/facebookresearch/pycls/blob/master/configs/archive/imagenet/resnet/R-50-1x64d_step_8gpu.yaml>`_.
|
37 |
+
- [Apr 2021] We have updated the appendix in the `TPAMI version of OSNet <https://arxiv.org/abs/1910.06827v5>`_ to include results in the multi-source domain generalization setting. The trained models can be found in the `Model Zoo <https://kaiyangzhou.github.io/deep-person-reid/MODEL_ZOO.html>`_.
|
38 |
+
- [Apr 2021] We have added a script to automate the process of calculating average results over multiple splits. For more details please see ``tools/parse_test_res.py``.
|
39 |
+
- [Apr 2021] ``v1.4.0``: We added the person search dataset, `CUHK-SYSU <http://www.ee.cuhk.edu.hk/~xgwang/PS/dataset.html>`_. Please see the `documentation <https://kaiyangzhou.github.io/deep-person-reid/>`_ regarding how to download the dataset (it contains cropped person images).
|
40 |
+
- [Apr 2021] All models in the model zoo have been moved to google drive. Please raise an issue if any model's performance is inconsistent with the numbers shown in the model zoo page (could be caused by wrong links).
|
41 |
+
- [Mar 2021] `OSNet <https://arxiv.org/abs/1910.06827>`_ will appear in the TPAMI journal! Compared with the conference version, which focuses on discriminative feature learning using the omni-scale building block, this journal extension further considers generalizable feature learning by integrating `instance normalization layers <https://arxiv.org/abs/1607.08022>`_ with the OSNet architecture. We hope this journal paper can motivate more future work to taclke the generalization issue in cross-dataset re-ID.
|
42 |
+
- [Mar 2021] Generalization across domains (datasets) in person re-ID is crucial in real-world applications, which is closely related to the topic of *domain generalization*. Interested in learning how the field of domain generalization has developed over the last decade? Check our recent survey in this topic at https://arxiv.org/abs/2103.02503, with coverage on the history, datasets, related problems, methodologies, potential directions, and so on (*methods designed for generalizable re-ID are also covered*!).
|
43 |
+
- [Feb 2021] ``v1.3.6`` Added `University-1652 <https://dl.acm.org/doi/abs/10.1145/3394171.3413896>`_, a new dataset for multi-view multi-source geo-localization (credit to `Zhedong Zheng <https://github.com/layumi>`_).
|
44 |
+
- [Feb 2021] ``v1.3.5``: Now the `cython code <https://github.com/KaiyangZhou/deep-person-reid/pull/412>`_ works on Windows (credit to `lablabla <https://github.com/lablabla>`_).
|
45 |
+
- [Jan 2021] Our recent work, `MixStyle <https://openreview.net/forum?id=6xHJ37MVxxp>`_ (mixing instance-level feature statistics of samples of different domains for improving domain generalization), has been accepted to ICLR'21. The code has been released at https://github.com/KaiyangZhou/mixstyle-release where the person re-ID part is based on Torchreid.
|
46 |
+
- [Jan 2021] A new evaluation metric called `mean Inverse Negative Penalty (mINP)` for person re-ID has been introduced in `Deep Learning for Person Re-identification: A Survey and Outlook (TPAMI 2021) <https://arxiv.org/abs/2001.04193>`_. Their code can be accessed at `<https://github.com/mangye16/ReID-Survey>`_.
|
47 |
+
- [Aug 2020] ``v1.3.3``: Fixed bug in ``visrank`` (caused by not unpacking ``dsetid``).
|
48 |
+
- [Aug 2020] ``v1.3.2``: Added ``_junk_pids`` to ``grid`` and ``prid``. This avoids using mislabeled gallery images for training when setting ``combineall=True``.
|
49 |
+
- [Aug 2020] ``v1.3.0``: (1) Added ``dsetid`` to the existing 3-tuple data source, resulting in ``(impath, pid, camid, dsetid)``. This variable denotes the dataset ID and is useful when combining multiple datasets for training (as a dataset indicator). E.g., when combining ``market1501`` and ``cuhk03``, the former will be assigned ``dsetid=0`` while the latter will be assigned ``dsetid=1``. (2) Added ``RandomDatasetSampler``. Analogous to ``RandomDomainSampler``, ``RandomDatasetSampler`` samples a certain number of images (``batch_size // num_datasets``) from each of specified datasets (the amount is determined by ``num_datasets``).
|
50 |
+
- [Aug 2020] ``v1.2.6``: Added ``RandomDomainSampler`` (it samples ``num_cams`` cameras each with ``batch_size // num_cams`` images to form a mini-batch).
|
51 |
+
- [Jun 2020] ``v1.2.5``: (1) Dataloader's output from ``__getitem__`` has been changed from ``list`` to ``dict``. Previously, an element, e.g. image tensor, was fetched with ``imgs=data[0]``. Now it should be obtained by ``imgs=data['img']``. See this `commit <https://github.com/KaiyangZhou/deep-person-reid/commit/aefe335d68f39a20160860e6d14c2d34f539b8a5>`_ for detailed changes. (2) Added ``k_tfm`` as an option to image data loader, which allows data augmentation to be applied ``k_tfm`` times *independently* to an image. If ``k_tfm > 1``, ``imgs=data['img']`` returns a list with ``k_tfm`` image tensors.
|
52 |
+
- [May 2020] Added the person attribute recognition code used in `Omni-Scale Feature Learning for Person Re-Identification (ICCV'19) <https://arxiv.org/abs/1905.00953>`_. See ``projects/attribute_recognition/``.
|
53 |
+
- [May 2020] ``v1.2.1``: Added a simple API for feature extraction (``torchreid/utils/feature_extractor.py``). See the `documentation <https://kaiyangzhou.github.io/deep-person-reid/user_guide.html>`_ for the instruction.
|
54 |
+
- [Apr 2020] Code for reproducing the experiments of `deep mutual learning <https://zpascal.net/cvpr2018/Zhang_Deep_Mutual_Learning_CVPR_2018_paper.pdf>`_ in the `OSNet paper <https://arxiv.org/pdf/1905.00953v6.pdf>`__ (Supp. B) has been released at ``projects/DML``.
|
55 |
+
- [Apr 2020] Upgraded to ``v1.2.0``. The engine class has been made more model-agnostic to improve extensibility. See `Engine <torchreid/engine/engine.py>`_ and `ImageSoftmaxEngine <torchreid/engine/image/softmax.py>`_ for more details. Credit to `Dassl.pytorch <https://github.com/KaiyangZhou/Dassl.pytorch>`_.
|
56 |
+
- [Dec 2019] Our `OSNet paper <https://arxiv.org/pdf/1905.00953v6.pdf>`_ has been updated, with additional experiments (in section B of the supplementary) showing some useful techniques for improving OSNet's performance in practice.
|
57 |
+
- [Nov 2019] ``ImageDataManager`` can load training data from target datasets by setting ``load_train_targets=True``, and the train-loader can be accessed with ``train_loader_t = datamanager.train_loader_t``. This feature is useful for domain adaptation research.
|
58 |
+
|
59 |
+
|
60 |
+
Installation
|
61 |
+
---------------
|
62 |
+
|
63 |
+
Make sure `conda <https://www.anaconda.com/distribution/>`_ is installed.
|
64 |
+
|
65 |
+
|
66 |
+
.. code-block:: bash
|
67 |
+
|
68 |
+
# cd to your preferred directory and clone this repo
|
69 |
+
git clone https://github.com/KaiyangZhou/deep-person-reid.git
|
70 |
+
|
71 |
+
# create environment
|
72 |
+
cd deep-person-reid/
|
73 |
+
conda create --name torchreid python=3.7
|
74 |
+
conda activate torchreid
|
75 |
+
|
76 |
+
# install dependencies
|
77 |
+
# make sure `which python` and `which pip` point to the correct path
|
78 |
+
pip install -r requirements.txt
|
79 |
+
|
80 |
+
# install torch and torchvision (select the proper cuda version to suit your machine)
|
81 |
+
conda install pytorch torchvision cudatoolkit=9.0 -c pytorch
|
82 |
+
|
83 |
+
# install torchreid (don't need to re-build it if you modify the source code)
|
84 |
+
python setup.py develop
|
85 |
+
|
86 |
+
|
87 |
+
Get started: 30 seconds to Torchreid
-------------------------------------
1. Import ``torchreid``

.. code-block:: python

    import torchreid

2. Load data manager

.. code-block:: python

    datamanager = torchreid.data.ImageDataManager(
        root="reid-data",
        sources="market1501",
        targets="market1501",
        height=256,
        width=128,
        batch_size_train=32,
        batch_size_test=100,
        transforms=["random_flip", "random_crop"]
    )

3. Build model, optimizer and lr_scheduler

.. code-block:: python

    model = torchreid.models.build_model(
        name="resnet50",
        num_classes=datamanager.num_train_pids,
        loss="softmax",
        pretrained=True
    )

    model = model.cuda()

    optimizer = torchreid.optim.build_optimizer(
        model,
        optim="adam",
        lr=0.0003
    )

    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer,
        lr_scheduler="single_step",
        stepsize=20
    )

4. Build engine

.. code-block:: python

    engine = torchreid.engine.ImageSoftmaxEngine(
        datamanager,
        model,
        optimizer=optimizer,
        scheduler=scheduler,
        label_smooth=True
    )

5. Run training and test

.. code-block:: python

    engine.run(
        save_dir="log/resnet50",
        max_epoch=60,
        eval_freq=10,
        print_freq=10,
        test_only=False
    )

A unified interface
-----------------------
In "deep-person-reid/scripts/", we provide a unified interface to train and test a model. See "scripts/main.py" and "scripts/default_config.py" for more details. The folder "configs/" contains some predefined configs which you can use as a starting point.

Below we provide an example to train and test `OSNet (Zhou et al. ICCV'19) <https://arxiv.org/abs/1905.00953>`_. Assume :code:`PATH_TO_DATA` is the directory containing reid datasets. The environment variable :code:`CUDA_VISIBLE_DEVICES` is omitted; set it if you have a pool of GPUs and want to use a specific subset of them.

Conventional setting
^^^^^^^^^^^^^^^^^^^^^

To train OSNet on Market1501, do

.. code-block:: bash

    python scripts/main.py \
        --config-file configs/im_osnet_x1_0_softmax_256x128_amsgrad_cosine.yaml \
        --transforms random_flip random_erase \
        --root $PATH_TO_DATA

The config file sets Market1501 as the default dataset. If you want to use DukeMTMC-reID, do

.. code-block:: bash

    python scripts/main.py \
        --config-file configs/im_osnet_x1_0_softmax_256x128_amsgrad_cosine.yaml \
        -s dukemtmcreid \
        -t dukemtmcreid \
        --transforms random_flip random_erase \
        --root $PATH_TO_DATA \
        data.save_dir log/osnet_x1_0_dukemtmcreid_softmax_cosinelr

The code will automatically (download and) load the ImageNet pretrained weights. After training is done, the model will be saved as "log/osnet_x1_0_market1501_softmax_cosinelr/model.pth.tar-250". Under the same folder, you can find the `tensorboard <https://pytorch.org/docs/stable/tensorboard.html>`_ file. To visualize the learning curves with tensorboard, run :code:`tensorboard --logdir=log/osnet_x1_0_market1501_softmax_cosinelr` in the terminal and visit :code:`http://localhost:6006/` in your web browser.

Evaluation is automatically performed at the end of training. To run the test again using the trained model, do

.. code-block:: bash

    python scripts/main.py \
        --config-file configs/im_osnet_x1_0_softmax_256x128_amsgrad_cosine.yaml \
        --root $PATH_TO_DATA \
        model.load_weights log/osnet_x1_0_market1501_softmax_cosinelr/model.pth.tar-250 \
        test.evaluate True

Cross-domain setting
^^^^^^^^^^^^^^^^^^^^^

Suppose you want to train OSNet on DukeMTMC-reID and test its performance on Market1501; you can do

.. code-block:: bash

    python scripts/main.py \
        --config-file configs/im_osnet_x1_0_softmax_256x128_amsgrad.yaml \
        -s dukemtmcreid \
        -t market1501 \
        --transforms random_flip color_jitter \
        --root $PATH_TO_DATA

Here we only test the cross-domain performance. However, if you also want to test the performance on the source dataset, i.e. DukeMTMC-reID, you can set :code:`-t dukemtmcreid market1501`, which will evaluate the model on the two datasets separately.

Different from the same-domain setting, here we replace :code:`random_erase` with :code:`color_jitter`. This can improve the generalization performance on the unseen target dataset.

Pretrained models are available in the `Model Zoo <https://kaiyangzhou.github.io/deep-person-reid/MODEL_ZOO.html>`_.

Datasets
--------

Image-reid datasets
^^^^^^^^^^^^^^^^^^^^^
- `Market1501 <https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Zheng_Scalable_Person_Re-Identification_ICCV_2015_paper.pdf>`_
- `CUHK03 <https://www.cv-foundation.org/openaccess/content_cvpr_2014/papers/Li_DeepReID_Deep_Filter_2014_CVPR_paper.pdf>`_
- `DukeMTMC-reID <https://arxiv.org/abs/1701.07717>`_
- `MSMT17 <https://arxiv.org/abs/1711.08565>`_
- `VIPeR <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.331.7285&rep=rep1&type=pdf>`_
- `GRID <http://www.eecs.qmul.ac.uk/~txiang/publications/LoyXiangGong_cvpr_2009.pdf>`_
- `CUHK01 <http://www.ee.cuhk.edu.hk/~xgwang/papers/liZWaccv12.pdf>`_
- `SenseReID <http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhao_Spindle_Net_Person_CVPR_2017_paper.pdf>`_
- `QMUL-iLIDS <http://www.eecs.qmul.ac.uk/~sgg/papers/ZhengGongXiang_BMVC09.pdf>`_
- `PRID <https://pdfs.semanticscholar.org/4c1b/f0592be3e535faf256c95e27982db9b3d3d3.pdf>`_

Geo-localization datasets
^^^^^^^^^^^^^^^^^^^^^^^^^^^
- `University-1652 <https://dl.acm.org/doi/abs/10.1145/3394171.3413896>`_

Video-reid datasets
^^^^^^^^^^^^^^^^^^^^^^^
- `MARS <http://www.liangzheng.org/1320.pdf>`_
- `iLIDS-VID <https://www.eecs.qmul.ac.uk/~sgg/papers/WangEtAl_ECCV14.pdf>`_
- `PRID2011 <https://pdfs.semanticscholar.org/4c1b/f0592be3e535faf256c95e27982db9b3d3d3.pdf>`_
- `DukeMTMC-VideoReID <http://openaccess.thecvf.com/content_cvpr_2018/papers/Wu_Exploit_the_Unknown_CVPR_2018_paper.pdf>`_


Models
-------

ImageNet classification models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- `ResNet <https://arxiv.org/abs/1512.03385>`_
- `ResNeXt <https://arxiv.org/abs/1611.05431>`_
- `SENet <https://arxiv.org/abs/1709.01507>`_
- `DenseNet <https://arxiv.org/abs/1608.06993>`_
- `Inception-ResNet-V2 <https://arxiv.org/abs/1602.07261>`_
- `Inception-V4 <https://arxiv.org/abs/1602.07261>`_
- `Xception <https://arxiv.org/abs/1610.02357>`_
- `IBN-Net <https://arxiv.org/abs/1807.09441>`_

Lightweight models
^^^^^^^^^^^^^^^^^^^
- `NASNet <https://arxiv.org/abs/1707.07012>`_
- `MobileNetV2 <https://arxiv.org/abs/1801.04381>`_
- `ShuffleNet <https://arxiv.org/abs/1707.01083>`_
- `ShuffleNetV2 <https://arxiv.org/abs/1807.11164>`_
- `SqueezeNet <https://arxiv.org/abs/1602.07360>`_

ReID-specific models
^^^^^^^^^^^^^^^^^^^^^^
- `MuDeep <https://arxiv.org/abs/1709.05165>`_
- `ResNet-mid <https://arxiv.org/abs/1711.08106>`_
- `HACNN <https://arxiv.org/abs/1802.08122>`_
- `PCB <https://arxiv.org/abs/1711.09349>`_
- `MLFN <https://arxiv.org/abs/1803.09132>`_
- `OSNet <https://arxiv.org/abs/1905.00953>`_
- `OSNet-AIN <https://arxiv.org/abs/1910.06827>`_


Useful links
-------------
- `OSNet-IBN1-Lite (test-only code with lite docker container) <https://github.com/RodMech/OSNet-IBN1-Lite>`_
- `Deep Learning for Person Re-identification: A Survey and Outlook <https://github.com/mangye16/ReID-Survey>`_


Citation
---------
If you use this code or the models in your research, please give credit to the following papers:

.. code-block:: bash

    @article{torchreid,
      title={Torchreid: A Library for Deep Learning Person Re-Identification in Pytorch},
      author={Zhou, Kaiyang and Xiang, Tao},
      journal={arXiv preprint arXiv:1910.10093},
      year={2019}
    }

    @inproceedings{zhou2019osnet,
      title={Omni-Scale Feature Learning for Person Re-Identification},
      author={Zhou, Kaiyang and Yang, Yongxin and Cavallaro, Andrea and Xiang, Tao},
      booktitle={ICCV},
      year={2019}
    }

    @article{zhou2021osnet,
      title={Learning Generalisable Omni-Scale Representations for Person Re-Identification},
      author={Zhou, Kaiyang and Yang, Yongxin and Cavallaro, Andrea and Xiang, Tao},
      journal={TPAMI},
      year={2021}
    }
strong_sort/deep/reid/configs/im_osnet_ain_x1_0_softmax_256x128_amsgrad_cosine.yaml
ADDED
@@ -0,0 +1,35 @@
model:
  name: 'osnet_ain_x1_0'
  pretrained: True

data:
  type: 'image'
  sources: ['market1501']
  targets: ['market1501', 'dukemtmcreid']
  height: 256
  width: 128
  combineall: False
  transforms: ['random_flip', 'color_jitter']
  save_dir: 'log/osnet_ain_x1_0_market1501_softmax_cosinelr'

loss:
  name: 'softmax'
  softmax:
    label_smooth: True

train:
  optim: 'amsgrad'
  lr: 0.0015
  max_epoch: 100
  batch_size: 64
  fixbase_epoch: 10
  open_layers: ['classifier']
  lr_scheduler: 'cosine'

test:
  batch_size: 300
  dist_metric: 'cosine'
  normalize_feature: False
  evaluate: False
  eval_freq: -1
  rerank: False
strong_sort/deep/reid/configs/im_osnet_ibn_x1_0_softmax_256x128_amsgrad.yaml
ADDED
@@ -0,0 +1,36 @@
model:
  name: 'osnet_ibn_x1_0'
  pretrained: True

data:
  type: 'image'
  sources: ['market1501']
  targets: ['dukemtmcreid']
  height: 256
  width: 128
  combineall: False
  transforms: ['random_flip', 'color_jitter']
  save_dir: 'log/osnet_ibn_x1_0_market2duke_softmax'

loss:
  name: 'softmax'
  softmax:
    label_smooth: True

train:
  optim: 'amsgrad'
  lr: 0.0015
  max_epoch: 150
  batch_size: 64
  fixbase_epoch: 10
  open_layers: ['classifier']
  lr_scheduler: 'single_step'
  stepsize: [60]

test:
  batch_size: 300
  dist_metric: 'euclidean'
  normalize_feature: False
  evaluate: False
  eval_freq: -1
  rerank: False
strong_sort/deep/reid/configs/im_osnet_x0_25_softmax_256x128_amsgrad.yaml
ADDED
@@ -0,0 +1,36 @@
model:
  name: 'osnet_x0_25'
  pretrained: True

data:
  type: 'image'
  sources: ['market1501']
  targets: ['market1501']
  height: 256
  width: 128
  combineall: False
  transforms: ['random_flip']
  save_dir: 'log/osnet_x0_25_market1501_softmax'

loss:
  name: 'softmax'
  softmax:
    label_smooth: True

train:
  optim: 'amsgrad'
  lr: 0.003
  max_epoch: 180
  batch_size: 128
  fixbase_epoch: 10
  open_layers: ['classifier']
  lr_scheduler: 'single_step'
  stepsize: [80]

test:
  batch_size: 300
  dist_metric: 'euclidean'
  normalize_feature: False
  evaluate: False
  eval_freq: -1
  rerank: False
strong_sort/deep/reid/configs/im_osnet_x0_5_softmax_256x128_amsgrad.yaml
ADDED
@@ -0,0 +1,36 @@
model:
  name: 'osnet_x0_5'
  pretrained: True

data:
  type: 'image'
  sources: ['market1501']
  targets: ['market1501']
  height: 256
  width: 128
  combineall: False
  transforms: ['random_flip']
  save_dir: 'log/osnet_x0_5_market1501_softmax'

loss:
  name: 'softmax'
  softmax:
    label_smooth: True

train:
  optim: 'amsgrad'
  lr: 0.003
  max_epoch: 180
  batch_size: 128
  fixbase_epoch: 10
  open_layers: ['classifier']
  lr_scheduler: 'single_step'
  stepsize: [80]

test:
  batch_size: 300
  dist_metric: 'euclidean'
  normalize_feature: False
  evaluate: False
  eval_freq: -1
  rerank: False
strong_sort/deep/reid/configs/im_osnet_x0_75_softmax_256x128_amsgrad.yaml
ADDED
@@ -0,0 +1,36 @@
model:
  name: 'osnet_x0_75'
  pretrained: True

data:
  type: 'image'
  sources: ['market1501']
  targets: ['market1501']
  height: 256
  width: 128
  combineall: False
  transforms: ['random_flip']
  save_dir: 'log/osnet_x0_75_market1501_softmax'

loss:
  name: 'softmax'
  softmax:
    label_smooth: True

train:
  optim: 'amsgrad'
  lr: 0.0015
  max_epoch: 150
  batch_size: 64
  fixbase_epoch: 10
  open_layers: ['classifier']
  lr_scheduler: 'single_step'
  stepsize: [60]

test:
  batch_size: 300
  dist_metric: 'euclidean'
  normalize_feature: False
  evaluate: False
  eval_freq: -1
  rerank: False
strong_sort/deep/reid/configs/im_osnet_x1_0_softmax_256x128_amsgrad.yaml
ADDED
@@ -0,0 +1,36 @@
model:
  name: 'osnet_x1_0'
  pretrained: True

data:
  type: 'image'
  sources: ['market1501']
  targets: ['market1501']
  height: 256
  width: 128
  combineall: False
  transforms: ['random_flip']
  save_dir: 'log/osnet_x1_0_market1501_softmax'

loss:
  name: 'softmax'
  softmax:
    label_smooth: True

train:
  optim: 'amsgrad'
  lr: 0.0015
  max_epoch: 150
  batch_size: 64
  fixbase_epoch: 10
  open_layers: ['classifier']
  lr_scheduler: 'single_step'
  stepsize: [60]

test:
  batch_size: 300
  dist_metric: 'euclidean'
  normalize_feature: False
  evaluate: False
  eval_freq: -1
  rerank: False
strong_sort/deep/reid/configs/im_osnet_x1_0_softmax_256x128_amsgrad_cosine.yaml
ADDED
@@ -0,0 +1,35 @@
model:
  name: 'osnet_x1_0'
  pretrained: True

data:
  type: 'image'
  sources: ['market1501']
  targets: ['market1501']
  height: 256
  width: 128
  combineall: False
  transforms: ['random_flip']
  save_dir: 'log/osnet_x1_0_market1501_softmax_cosinelr'

loss:
  name: 'softmax'
  softmax:
    label_smooth: True

train:
  optim: 'amsgrad'
  lr: 0.0015
  max_epoch: 250
  batch_size: 64
  fixbase_epoch: 10
  open_layers: ['classifier']
  lr_scheduler: 'cosine'

test:
  batch_size: 300
  dist_metric: 'euclidean'
  normalize_feature: False
  evaluate: False
  eval_freq: -1
  rerank: False
strong_sort/deep/reid/configs/im_r50_softmax_256x128_amsgrad.yaml
ADDED
@@ -0,0 +1,36 @@
model:
  name: 'resnet50_fc512'
  pretrained: True

data:
  type: 'image'
  sources: ['market1501']
  targets: ['market1501']
  height: 256
  width: 128
  combineall: False
  transforms: ['random_flip']
  save_dir: 'log/resnet50_market1501_softmax'

loss:
  name: 'softmax'
  softmax:
    label_smooth: True

train:
  optim: 'amsgrad'
  lr: 0.0003
  max_epoch: 60
  batch_size: 32
  fixbase_epoch: 5
  open_layers: ['classifier']
  lr_scheduler: 'single_step'
  stepsize: [20]

test:
  batch_size: 100
  dist_metric: 'euclidean'
  normalize_feature: False
  evaluate: False
  eval_freq: -1
  rerank: False
strong_sort/deep/reid/configs/im_r50fc512_softmax_256x128_amsgrad.yaml
ADDED
@@ -0,0 +1,36 @@
model:
  name: 'resnet50_fc512'
  pretrained: True

data:
  type: 'image'
  sources: ['market1501']
  targets: ['market1501']
  height: 256
  width: 128
  combineall: False
  transforms: ['random_flip']
  save_dir: 'log/resnet50_fc512_market1501_softmax'

loss:
  name: 'softmax'
  softmax:
    label_smooth: True

train:
  optim: 'amsgrad'
  lr: 0.0003
  max_epoch: 60
  batch_size: 32
  fixbase_epoch: 5
  open_layers: ['fc', 'classifier']
  lr_scheduler: 'single_step'
  stepsize: [20]

test:
  batch_size: 100
  dist_metric: 'euclidean'
  normalize_feature: False
  evaluate: False
  eval_freq: -1
  rerank: False
strong_sort/deep/reid/docs/AWESOME_REID.md
ADDED
@@ -0,0 +1,69 @@
# Awesome-ReID
Here is a collection of ReID-related research with links to papers and code. You are welcome to submit [PR](https://help.github.com/articles/creating-a-pull-request/)s if you find something missing.


- [TPAMI21] Learning Generalisable Omni-Scale Representations for Person Re-Identification [[paper](https://arxiv.org/abs/1910.06827)][[code](https://github.com/KaiyangZhou/deep-person-reid)]

- [TPAMI21] Deep Learning for Person Re-identification: A Survey and Outlook [[paper](https://arxiv.org/abs/2001.04193)] [[code](https://github.com/mangye16/ReID-Survey)]

- [ICCV19] RGB-Infrared Cross-Modality Person Re-Identification via Joint Pixel and Feature Alignment. [[paper](http://openaccess.thecvf.com/content_ICCV_2019/papers/Wang_RGB-Infrared_Cross-Modality_Person_Re-Identification_via_Joint_Pixel_and_Feature_Alignment_ICCV_2019_paper.pdf)] [[code](https://github.com/wangguanan/AlignGAN)]

- [ICCV19] Unsupervised Graph Association for Person Re-identification. [[paper](https://github.com/yichuan9527/Unsupervised-Graph-Association-for-Person-Re-identification)] [[code](https://github.com/yichuan9527/Unsupervised-Graph-Association-for-Person-Re-identification)]

- [ICCV19] Self-similarity Grouping: A Simple Unsupervised Cross Domain Adaptation Approach for Person Re-identification. [[paper](http://openaccess.thecvf.com/content_ICCV_2019/papers/Fu_Self-Similarity_Grouping_A_Simple_Unsupervised_Cross_Domain_Adaptation_Approach_for_ICCV_2019_paper.pdf)] [[code](https://github.com/OasisYang/SSG)]

- [ICCV19] Spectral Feature Transformation for Person Re-Identification. [[paper](http://openaccess.thecvf.com/content_ICCV_2019/papers/Luo_Spectral_Feature_Transformation_for_Person_Re-Identification_ICCV_2019_paper.pdf)] [[code](https://github.com/LuckyDC/SFT_REID)]

- [ICCV19] Beyond Human Parts: Dual Part-Aligned Representations for Person Re-Identification. [[paper](http://openaccess.thecvf.com/content_ICCV_2019/papers/Guo_Beyond_Human_Parts_Dual_Part-Aligned_Representations_for_Person_Re-Identification_ICCV_2019_paper.pdf)] [[code](https://github.com/ggjy/P2Net.pytorch)]

- [ICCV19] Co-segmentation Inspired Attention Networks for Video-based Person Re-identification. [[paper](http://openaccess.thecvf.com/content_ICCV_2019/papers/Subramaniam_Co-Segmentation_Inspired_Attention_Networks_for_Video-Based_Person_Re-Identification_ICCV_2019_paper.pdf)][[code](https://github.com/InnovArul/vidreid_cosegmentation)]

- [ICCV19] Mixed High-Order Attention Network for Person Re-Identification. [[paper](https://arxiv.org/abs/1908.05819)][[code](https://github.com/chenbinghui1/MHN)]

- [ICCV19] ABD-Net: Attentive but Diverse Person Re-Identification. [[paper](https://arxiv.org/abs/1908.01114)] [[code](https://github.com/TAMU-VITA/ABD-Net)]

- [ICCV19] Omni-Scale Feature Learning for Person Re-Identification. [[paper](https://arxiv.org/abs/1905.00953)] [[code](https://github.com/KaiyangZhou/deep-person-reid)]

- [CVPR19] Joint Discriminative and Generative Learning for Person Re-identification. [[paper](https://arxiv.org/abs/1904.07223)][[code](https://github.com/NVlabs/DG-Net)]
- [CVPR19] Invariance Matters: Exemplar Memory for Domain Adaptive Person Re-identification. [[paper](https://arxiv.org/abs/1904.01990)][[code](https://github.com/zhunzhong07/ECN)]
- [CVPR19] Dissecting Person Re-identification from the Viewpoint of Viewpoint. [[paper](https://arxiv.org/abs/1812.02162)][[code](https://github.com/sxzrt/Dissecting-Person-Re-ID-from-the-Viewpoint-of-Viewpoint)]
- [CVPR19] Unsupervised Person Re-identification by Soft Multilabel Learning. [[paper](https://arxiv.org/abs/1903.06325)][[code](https://github.com/KovenYu/MAR)]
- [CVPR19] Patch-based Discriminative Feature Learning for Unsupervised Person Re-identification. [[paper](https://kovenyu.com/publication/2019-cvpr-pedal/)][[code](https://github.com/QizeYang/PAUL)]

- [AAAI19] Spatial and Temporal Mutual Promotion for Video-based Person Re-identification. [[paper](https://arxiv.org/abs/1812.10305)][[code](https://github.com/yolomax/person-reid-lib)]

- [AAAI19] Spatial-Temporal Person Re-identification. [[paper](https://arxiv.org/abs/1812.03282)][[code](https://github.com/Wanggcong/Spatial-Temporal-Re-identification)]

- [AAAI19] Horizontal Pyramid Matching for Person Re-identification. [[paper](https://arxiv.org/abs/1804.05275)][[code](https://github.com/OasisYang/HPM)]

- [AAAI19] Backbone Can Not be Trained at Once: Rolling Back to Pre-trained Network for Person Re-identification. [[paper](https://arxiv.org/abs/1901.06140)][[code](https://github.com/youngminPIL/rollback)]

- [AAAI19] A Bottom-Up Clustering Approach to Unsupervised Person Re-identification. [[paper](https://vana77.github.io/vana77.github.io/images/AAAI19.pdf)][[code](https://github.com/vana77/Bottom-up-Clustering-Person-Re-identification)]

- [NIPS18] FD-GAN: Pose-guided Feature Distilling GAN for Robust Person Re-identification. [[paper](https://arxiv.org/abs/1810.02936)][[code](https://github.com/yxgeee/FD-GAN)]

- [ECCV18] Generalizing A Person Retrieval Model Hetero- and Homogeneously. [[paper](http://openaccess.thecvf.com/content_ECCV_2018/papers/Zhun_Zhong_Generalizing_A_Person_ECCV_2018_paper.pdf)][[code](https://github.com/zhunzhong07/HHL)]

- [ECCV18] Pose-Normalized Image Generation for Person Re-identification. [[paper](https://arxiv.org/abs/1712.02225)][[code](https://github.com/naiq/PN_GAN)]

- [CVPR18] Camera Style Adaptation for Person Re-Identification. [[paper](https://arxiv.org/abs/1711.10295)][[code](https://github.com/zhunzhong07/CamStyle)]

- [CVPR18] Deep Group-Shuffling Random Walk for Person Re-Identification. [[paper](https://arxiv.org/abs/1807.11178)][[code](https://github.com/YantaoShen/kpm_rw_person_reid)]

- [CVPR18] End-to-End Deep Kronecker-Product Matching for Person Re-identification. [[paper](https://arxiv.org/abs/1807.11182)][[code](https://github.com/YantaoShen/kpm_rw_person_reid)]

- [CVPR18] Features for Multi-Target Multi-Camera Tracking and Re-Identification. [[paper](https://arxiv.org/abs/1803.10859)][[code](https://github.com/ergysr/DeepCC)]

- [CVPR18] Group Consistent Similarity Learning via Deep CRF for Person Re-Identification. [[paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Chen_Group_Consistent_Similarity_CVPR_2018_paper.pdf)][[code](https://github.com/dapengchen123/crf_affinity)]

- [CVPR18] Harmonious Attention Network for Person Re-Identification. [[paper](https://arxiv.org/abs/1802.08122)][[code](https://github.com/KaiyangZhou/deep-person-reid)]

- [CVPR18] Human Semantic Parsing for Person Re-Identification. [[paper](https://arxiv.org/abs/1804.00216)][[code](https://github.com/emrahbasaran/SPReID)]

- [CVPR18] Multi-Level Factorisation Net for Person Re-Identification. [[paper](https://arxiv.org/abs/1803.09132)][[code](https://github.com/KaiyangZhou/deep-person-reid)]

- [CVPR18] Resource Aware Person Re-identification across Multiple Resolutions. [[paper](https://arxiv.org/abs/1805.08805)][[code](https://github.com/mileyan/DARENet)]

- [CVPR18] Exploit the Unknown Gradually: One-Shot Video-Based Person Re-Identification by Stepwise Learning. [[paper](https://yu-wu.net/pdf/CVPR2018_Exploit-Unknown-Gradually.pdf)][[code](https://github.com/Yu-Wu/Exploit-Unknown-Gradually)]

- [ArXiv18] Revisiting Temporal Modeling for Video-based Person ReID. [[paper](https://arxiv.org/abs/1805.02104)][[code](https://github.com/jiyanggao/Video-Person-ReID)]
strong_sort/deep/reid/docs/MODEL_ZOO.md
ADDED
@@ -0,0 +1,93 @@
# Model Zoo

- Results are presented in the format of *<Rank-1 (mAP)>*.
- When computing model size and FLOPs, only layers that are used at test time are considered (see `torchreid.utils.compute_model_complexity`; a short sketch follows this list).
- Asterisk (\*) means the model is trained from scratch.
- `combineall=True` means all images in the dataset are used for model training.
- Why not use heavy data augmentation like [random erasing](https://arxiv.org/abs/1708.04896) for model training? It's because heavy data augmentation might harm the cross-dataset generalization performance (see [this paper](https://arxiv.org/abs/1708.04896)).

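As referenced above, a minimal sketch of the complexity utility (the model name and class count below are arbitrary examples; the input shape mirrors the (256, 128) setting used throughout these tables):

```python
import torchreid
from torchreid.utils import compute_model_complexity

# measure test-time parameters and FLOPs for a single 256x128 RGB input
model = torchreid.models.build_model('osnet_x1_0', num_classes=1000)
num_params, flops = compute_model_complexity(model, (1, 3, 256, 128))
print(num_params, flops)
```
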
## ImageNet pretrained models


| Model | Download |
| :--- | :---: |
| shufflenet | [model](https://drive.google.com/file/d/1RFnYcHK1TM-yt3yLsNecaKCoFO4Yb6a-/view?usp=sharing) |
| mobilenetv2_x1_0 | [model](https://drive.google.com/file/d/1K7_CZE_L_Tf-BRY6_vVm0G-0ZKjVWh3R/view?usp=sharing) |
| mobilenetv2_x1_4 | [model](https://drive.google.com/file/d/10c0ToIGIVI0QZTx284nJe8QfSJl5bIta/view?usp=sharing) |
| mlfn | [model](https://drive.google.com/file/d/1PP8Eygct5OF4YItYRfA3qypYY9xiqHuV/view?usp=sharing) |
| osnet_x1_0 | [model](https://drive.google.com/file/d/1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY/view?usp=sharing) |
| osnet_x0_75 | [model](https://drive.google.com/file/d/1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq/view?usp=sharing) |
| osnet_x0_5 | [model](https://drive.google.com/file/d/16DGLbZukvVYgINws8u8deSaOqjybZ83i/view?usp=sharing) |
| osnet_x0_25 | [model](https://drive.google.com/file/d/1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs/view?usp=sharing) |
| osnet_ibn_x1_0 | [model](https://drive.google.com/file/d/1sr90V6irlYYDd4_4ISU2iruoRG8J__6l/view?usp=sharing) |
| osnet_ain_x1_0 | [model](https://drive.google.com/file/d/1-CaioD9NaqbHK_kzSMW8VE4_3KcsRjEo/view?usp=sharing) |
| osnet_ain_x0_75 | [model](https://drive.google.com/file/d/1apy0hpsMypqstfencdH-jKIUEFOW4xoM/view?usp=sharing) |
| osnet_ain_x0_5 | [model](https://drive.google.com/file/d/1KusKvEYyKGDTUBVRxRiz55G31wkihB6l/view?usp=sharing) |
| osnet_ain_x0_25 | [model](https://drive.google.com/file/d/1SxQt2AvmEcgWNhaRb2xC4rP6ZwVDP0Wt/view?usp=sharing) |


## Same-domain ReID


| Model | # Param (10^6) | GFLOPs | Loss | Input | Transforms | Distance | market1501 | dukemtmcreid | msmt17 |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| resnet50 | 23.5 | 2.7 | softmax | (256, 128) | `random_flip`, `random_crop` | `euclidean` | [87.9 (70.4)](https://drive.google.com/file/d/1dUUZ4rHDWohmsQXCRe2C_HbYkzz94iBV/view?usp=sharing) | [78.3 (58.9)](https://drive.google.com/file/d/17ymnLglnc64NRvGOitY3BqMRS9UWd1wg/view?usp=sharing) | [63.2 (33.9)](https://drive.google.com/file/d/1ep7RypVDOthCRIAqDnn4_N-UhkkFHJsj/view?usp=sharing) |
| resnet50_fc512 | 24.6 | 4.1 | softmax | (256, 128) | `random_flip`, `random_crop` | `euclidean` | [90.8 (75.3)](https://drive.google.com/file/d/1kv8l5laX_YCdIGVCetjlNdzKIA3NvsSt/view?usp=sharing) | [81.0 (64.0)](https://drive.google.com/file/d/13QN8Mp3XH81GK4BPGXobKHKyTGH50Rtx/view?usp=sharing) | [69.6 (38.4)](https://drive.google.com/file/d/1fDJLcz4O5wxNSUvImIIjoaIF9u1Rwaud/view?usp=sharing) |
| mlfn | 32.5 | 2.8 | softmax | (256, 128) | `random_flip`, `random_crop` | `euclidean` | [90.1 (74.3)](https://drive.google.com/file/d/1wXcvhA_b1kpDfrt9s2Pma-MHxtj9pmvS/view?usp=sharing) | [81.1 (63.2)](https://drive.google.com/file/d/1rExgrTNb0VCIcOnXfMsbwSUW1h2L1Bum/view?usp=sharing) | [66.4 (37.2)](https://drive.google.com/file/d/18JzsZlJb3Wm7irCbZbZ07TN4IFKvR6p-/view?usp=sharing) |
| hacnn<sup>*</sup> | 4.5 | 0.5 | softmax | (160, 64) | `random_flip`, `random_crop` | `euclidean` | [90.9 (75.6)](https://drive.google.com/file/d/1LRKIQduThwGxMDQMiVkTScBwR7WidmYF/view?usp=sharing) | [80.1 (63.2)](https://drive.google.com/file/d/1zNm6tP4ozFUCUQ7Sv1Z98EAJWXJEhtYH/view?usp=sharing) | [64.7 (37.2)](https://drive.google.com/file/d/1MsKRtPM5WJ3_Tk2xC0aGOO7pM3VaFDNZ/view?usp=sharing) |
| mobilenetv2_x1_0 | 2.2 | 0.2 | softmax | (256, 128) | `random_flip`, `random_crop` | `euclidean` | [85.6 (67.3)](https://drive.google.com/file/d/18DgHC2ZJkjekVoqBWszD8_Xiikz-fewp/view?usp=sharing) | [74.2 (54.7)](https://drive.google.com/file/d/1q1WU2FETRJ3BXcpVtfJUuqq4z3psetds/view?usp=sharing) | [57.4 (29.3)](https://drive.google.com/file/d/1j50Hv14NOUAg7ZeB3frzfX-WYLi7SrhZ/view?usp=sharing) |
| mobilenetv2_x1_4 | 4.3 | 0.4 | softmax | (256, 128) | `random_flip`, `random_crop` | `euclidean` | [87.0 (68.5)](https://drive.google.com/file/d/1t6JCqphJG-fwwPVkRLmGGyEBhGOf2GO5/view?usp=sharing) | [76.2 (55.8)](https://drive.google.com/file/d/12uD5FeVqLg9-AFDju2L7SQxjmPb4zpBN/view?usp=sharing) | [60.1 (31.5)](https://drive.google.com/file/d/1ZY5P2Zgm-3RbDpbXM0kIBMPvspeNIbXz/view?usp=sharing) |
| osnet_x1_0 | 2.2 | 0.98 | softmax | (256, 128) | `random_flip` | `euclidean` | [94.2 (82.6)](https://drive.google.com/file/d/1vduhq5DpN2q1g4fYEZfPI17MJeh9qyrA/view?usp=sharing) | [87.0 (70.2)](https://drive.google.com/file/d/1QZO_4sNf4hdOKKKzKc-TZU9WW1v6zQbq/view?usp=sharing) | [74.9 (43.8)](https://drive.google.com/file/d/112EMUfBPYeYg70w-syK6V6Mx8-Qb9Q1M/view?usp=sharing) |
| osnet_x0_75 | 1.3 | 0.57 | softmax | (256, 128) | `random_flip` | `euclidean` | [93.7 (81.2)](https://drive.google.com/file/d/1ozRaDSQw_EQ8_93OUmjDbvLXw9TnfPer/view?usp=sharing) | [85.8 (69.8)](https://drive.google.com/file/d/1IE3KRaTPp4OUa6PGTFL_d5_KQSJbP0Or/view?usp=sharing) | [72.8 (41.4)](https://drive.google.com/file/d/1QEGO6WnJ-BmUzVPd3q9NoaO_GsPNlmWc/view?usp=sharing) |
| osnet_x0_5 | 0.6 | 0.27 | softmax | (256, 128) | `random_flip` | `euclidean` | [92.5 (79.8)](https://drive.google.com/file/d/1PLB9rgqrUM7blWrg4QlprCuPT7ILYGKT/view?usp=sharing) | [85.1 (67.4)](https://drive.google.com/file/d/1KoUVqmiST175hnkALg9XuTi1oYpqcyTu/view?usp=sharing) | [69.7 (37.5)](https://drive.google.com/file/d/1UT3AxIaDvS2PdxzZmbkLmjtiqq7AIKCv/view?usp=sharing) |
| osnet_x0_25 | 0.2 | 0.08 | softmax | (256, 128) | `random_flip` | `euclidean` | [91.2 (75.0)](https://drive.google.com/file/d/1z1UghYvOTtjx7kEoRfmqSMu-z62J6MAj/view?usp=sharing) | [82.0 (61.4)](https://drive.google.com/file/d/1eumrtiXT4NOspjyEV4j8cHmlOaaCGk5l/view?usp=sharing) | [61.4 (29.5)](https://drive.google.com/file/d/1sSwXSUlj4_tHZequ_iZ8w_Jh0VaRQMqF/view?usp=sharing) |


## Cross-domain ReID

#### Market1501 -> DukeMTMC-reID


| Model | # Param (10^6) | GFLOPs | Loss | Input | Transforms | Distance | Rank-1 | Rank-5 | Rank-10 | mAP | Download |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| osnet_ibn_x1_0 | 2.2 | 0.98 | softmax | (256, 128) | `random_flip`, `color_jitter` | `euclidean` | 48.5 | 62.3 | 67.4 | 26.7 | [model](https://drive.google.com/file/d/1uWW7_z_IcUmRNPqQOrEBdsvic94fWH37/view?usp=sharing) |
| osnet_ain_x1_0 | 2.2 | 0.98 | softmax | (256, 128) | `random_flip`, `color_jitter` | `cosine` | 52.4 | 66.1 | 71.2 | 30.5 | [model](https://drive.google.com/file/d/14bNFGm0FhwHEkEpYKqKiDWjLNhXywFAd/view?usp=sharing) |


#### DukeMTMC-reID -> Market1501


| Model | # Param (10^6) | GFLOPs | Loss | Input | Transforms | Distance | Rank-1 | Rank-5 | Rank-10 | mAP | Download |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| osnet_ibn_x1_0 | 2.2 | 0.98 | softmax | (256, 128) | `random_flip`, `color_jitter` | `euclidean` | 57.7 | 73.7 | 80.0 | 26.1 | [model](https://drive.google.com/file/d/1CNxL1IP0BjcE1TSttiVOID1VNipAjiF3/view?usp=sharing) |
| osnet_ain_x1_0 | 2.2 | 0.98 | softmax | (256, 128) | `random_flip`, `color_jitter` | `cosine` | 61.0 | 77.0 | 82.5 | 30.6 | [model](https://drive.google.com/file/d/1hypJvq8G04SOby6jvF337GEkg5K_bmCw/view?usp=sharing) |


#### MSMT17 (`combineall=True`) -> Market1501 & DukeMTMC-reID


| Model | # Param (10^6) | GFLOPs | Loss | Input | Transforms | Distance | msmt17 -> market1501 | msmt17 -> dukemtmcreid | Download |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| resnet50 | 23.5 | 2.7 | softmax | (256, 128) | `random_flip`, `color_jitter` | `euclidean` | 46.3 (22.8) | 52.3 (32.1) | [model](https://drive.google.com/file/d/1yiBteqgIZoOeywE8AhGmEQl7FTVwrQmf/view?usp=sharing) |
| osnet_x1_0 | 2.2 | 0.98 | softmax | (256, 128) | `random_flip`, `color_jitter` | `euclidean` | 66.6 (37.5) | 66.0 (45.3) | [model](https://drive.google.com/file/d/1IosIFlLiulGIjwW3H8uMRmx3MzPwf86x/view?usp=sharing) |
| osnet_x0_75 | 1.3 | 0.57 | softmax | (256, 128) | `random_flip`, `color_jitter` | `euclidean` | 63.6 (35.5) | 65.3 (44.5) | [model](https://drive.google.com/file/d/1fhjSS_7SUGCioIf2SWXaRGPqIY9j7-uw/view?usp=sharing) |
| osnet_x0_5 | 0.6 | 0.27 | softmax | (256, 128) | `random_flip`, `color_jitter` | `euclidean` | 64.3 (34.9) | 65.2 (43.3) | [model](https://drive.google.com/file/d/1DHgmb6XV4fwG3n-CnCM0zdL9nMsZ9_RF/view?usp=sharing) |
| osnet_x0_25 | 0.2 | 0.08 | softmax | (256, 128) | `random_flip`, `color_jitter` | `euclidean` | 59.9 (31.0) | 61.5 (39.6) | [model](https://drive.google.com/file/d/1Kkx2zW89jq_NETu4u42CFZTMVD5Hwm6e/view?usp=sharing) |
| osnet_ibn_x1_0 | 2.2 | 0.98 | softmax | (256, 128) | `random_flip`, `color_jitter` | `euclidean` | 66.5 (37.2) | 67.4 (45.6) | [model](https://drive.google.com/file/d/1q3Sj2ii34NlfxA4LvmHdWO_75NDRmECJ/view?usp=sharing) |
| osnet_ain_x1_0 | 2.2 | 0.98 | softmax | (256, 128) | `random_flip`, `color_jitter` | `cosine` | 70.1 (43.3) | 71.1 (52.7) | [model](https://drive.google.com/file/d/1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal/view?usp=sharing) |


#### Multi-source domain generalization

The models below are trained using multiple source datasets, as described in [Zhou et al. TPAMI'21](https://arxiv.org/abs/1910.06827).

Regarding the abbreviations, MS is MSMT17; M is Market1501; D is DukeMTMC-reID; and C is CUHK03.

All models were trained with [im_osnet_ain_x1_0_softmax_256x128_amsgrad_cosine.yaml](https://github.com/KaiyangZhou/deep-person-reid/blob/master/configs/im_osnet_ain_x1_0_softmax_256x128_amsgrad_cosine.yaml) and `max_epoch=50`.

| Model | # Param (10^6) | GFLOPs | Loss | Input | Transforms | Distance | MS+D+C->M | MS+M+C->D | MS+D+M->C | D+M+C->MS |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| osnet_x1_0 | 2.2 | 0.98 | softmax | (256, 128) | `random_flip`, `color_jitter` | `cosine` | [72.5 (44.2)](https://drive.google.com/file/d/1tuYY1vQXReEd8N8_npUkc7npPDDmjNCV/view?usp=sharing) | [65.2 (47.0)](https://drive.google.com/file/d/1UxUI4NsE108UCvcy3O1Ufe73nIVPKCiu/view?usp=sharing) | [23.9 (23.3)](https://drive.google.com/file/d/1kAA6qHJvbaJtyh1b39ZyEqWROwUgWIhl/view?usp=sharing) | [33.2 (12.6)](https://drive.google.com/file/d/1wAHuYVTzj8suOwqCNcEmu6YdbVnHDvA2/view?usp=sharing) |
| osnet_ibn_x1_0 | 2.2 | 0.98 | softmax | (256, 128) | `random_flip`, `color_jitter` | `cosine` | [73.0 (44.9)](https://drive.google.com/file/d/14sH6yZwuNHPTElVoEZ26zozOOZIej5Mf/view?usp=sharing) | [64.6 (45.7)](https://drive.google.com/file/d/1Sk-2SSwKAF8n1Z4p_Lm_pl0E6v2WlIBn/view?usp=sharing) | [25.7 (25.4)](https://drive.google.com/file/d/1actHP7byqWcK4eBE1ojnspSMdo7k2W4G/view?usp=sharing) | [39.8 (16.2)](https://drive.google.com/file/d/1BGOSdLdZgqHe2qFafatb-5sPY40JlYfp/view?usp=sharing) |
| osnet_ain_x1_0 | 2.2 | 0.98 | softmax | (256, 128) | `random_flip`, `color_jitter` | `cosine` | [73.3 (45.8)](https://drive.google.com/file/d/1nIrszJVYSHf3Ej8-j6DTFdWz8EnO42PB/view?usp=sharing) | [65.6 (47.2)](https://drive.google.com/file/d/1YjJ1ZprCmaKG6MH2P9nScB9FL_Utf9t1/view?usp=sharing) | [27.4 (27.1)](https://drive.google.com/file/d/1IxIg5P0cei3KPOJQ9ZRWDE_Mdrz01ha2/view?usp=sharing) | [40.2 (16.2)](https://drive.google.com/file/d/1KcoUKzLmsUoGHI7B6as_Z2fXL50gzexS/view?usp=sharing) |
strong_sort/deep/reid/docs/Makefile
ADDED
@@ -0,0 +1,19 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS  =
SPHINXBUILD = sphinx-build
SOURCEDIR   = .
BUILDDIR    = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
strong_sort/deep/reid/docs/conf.py
ADDED
@@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys

sys.path.insert(0, os.path.abspath('..'))

# -- Project information -----------------------------------------------------

project = u'torchreid'
copyright = u'2019, Kaiyang Zhou'
author = u'Kaiyang Zhou'

version_file = '../torchreid/__init__.py'
with open(version_file, 'r') as f:
    exec(compile(f.read(), version_file, 'exec'))
__version__ = locals()['__version__']

# The short X.Y version
version = __version__
# The full version, including alpha/beta/rc tags
release = __version__

# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinxcontrib.napoleon',
    'sphinx.ext.viewcode',
    'sphinx.ext.githubpages',
    'sphinx_markdown_tables',
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = ['.rst', '.md']
# source_suffix = '.rst'
source_parsers = {'.md': 'recommonmark.parser.CommonMarkParser'}

# The master toctree document.
master_doc = 'index'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = [u'_build', 'Thumbs.db', '.DS_Store']

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}

# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'torchreiddoc'

# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',

    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (
        master_doc, 'torchreid.tex', u'torchreid Documentation',
        u'Kaiyang Zhou', 'manual'
    ),
]

# -- Options for manual page output ------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
    (master_doc, 'torchreid', u'torchreid Documentation', [author], 1)
]

# -- Options for Texinfo output ----------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (
        master_doc, 'torchreid', u'torchreid Documentation', author,
        'torchreid', 'One line description of project.', 'Miscellaneous'
    ),
]

# -- Options for Epub output -------------------------------------------------

# Bibliographic Dublin Core info.
epub_title = project

# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''

# A unique identification for the text.
#
# epub_uid = ''

# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']

# -- Extension configuration -------------------------------------------------
strong_sort/deep/reid/docs/datasets.rst
ADDED
@@ -0,0 +1,264 @@
.. _datasets:

Datasets
=========

Here we provide a comprehensive guide on how to prepare the datasets.

Suppose you want to store the reid data in a directory called "path/to/reid-data/"; you need to specify the ``root`` as *root='path/to/reid-data/'* when initializing ``DataManager``. Below we use ``$REID`` to denote "path/to/reid-data".

Please refer to :ref:`torchreid_data` for details regarding the arguments.


.. note::
    A dataset with a :math:`\dagger` symbol means that the process is automated, so you can directly call the dataset in ``DataManager`` (which automatically downloads the dataset and organizes the data structure). However, we also provide a way below to help the manual setup in case the automation fails.


.. note::
    The keys to use specific datasets are enclosed in the parentheses beside the datasets' names.


.. note::
    You are advised to use the provided names for dataset folders, such as "market1501" for Market1501 and "dukemtmcreid" for DukeMTMC-reID, when doing the manual setup; otherwise you need to modify the source code accordingly (i.e. the ``dataset_dir`` attribute).

.. note::
    Some download links provided by the original authors might not work. You can email `Kaiyang Zhou <https://kaiyangzhou.github.io/>`_ to request new links. Please do provide your full name, institution, and purpose of using the data in the email (best use your work email address).

.. contents::
    :local:


Image Datasets
--------------

Market1501 :math:`^\dagger` (``market1501``)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- Create a directory named "market1501" under ``$REID``.
- Download the dataset to "market1501" from http://www.liangzheng.org/Project/project_reid.html and extract the files.
- The data structure should look like

.. code-block:: none

    market1501/
        Market-1501-v15.09.15/
            query/
            bounding_box_train/
            bounding_box_test/

- To use the extra 500K distractors (i.e. Market1501 + 500K), go to the **Market-1501+500k Dataset** section at http://www.liangzheng.org/Project/project_reid.html, download the zip file "distractors_500k.zip" and extract it under "market1501/Market-1501-v15.09.15". The argument to use these 500K distractors is ``market1501_500k`` in ``ImageDataManager``, as sketched below.
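
A minimal sketch of loading this dataset once the folders are in place (the ``root`` path is a placeholder; ``market1501_500k`` is needed only for the distractor variant):

.. code-block:: python

    import torchreid

    datamanager = torchreid.data.ImageDataManager(
        root='path/to/reid-data',
        sources='market1501',
        height=256,
        width=128,
        market1501_500k=False  # set True to include the 500K distractors
    )

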
CUHK03 (``cuhk03``)
^^^^^^^^^^^^^^^^^^^^^
- Create a folder named "cuhk03" under ``$REID``.
- Download the dataset to "cuhk03/" from http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html and extract "cuhk03_release.zip", resulting in "cuhk03/cuhk03_release/".
- Download the new split (767/700) from `person-re-ranking <https://github.com/zhunzhong07/person-re-ranking/tree/master/evaluation/data/CUHK03>`_. What you need are "cuhk03_new_protocol_config_detected.mat" and "cuhk03_new_protocol_config_labeled.mat". Put these two mat files under "cuhk03/".
- The data structure should look like

.. code-block:: none

    cuhk03/
        cuhk03_release/
        cuhk03_new_protocol_config_detected.mat
        cuhk03_new_protocol_config_labeled.mat

- In the default mode, we load data using the new split (767/700). If you want to use the original (20) splits (1367/100), please set ``cuhk03_classic_split`` to True in ``ImageDataManager``. As the CMC is computed differently from Market1501 for the 1367/100 split (see `here <http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html>`_), you need to enable ``use_metric_cuhk03`` in ``ImageDataManager`` to activate the *single-gallery-shot* metric for fair comparison with some methods that adopt the old splits (there is no need to report mAP). In addition, we support both *labeled* and *detected* modes. The default mode loads *detected* images. Enable ``cuhk03_labeled`` in ``ImageDataManager`` if you want to train and test on *labeled* images. A short sketch combining these switches follows the note below.

.. note::
    The code will extract images in "cuhk-03.mat" and save them under "cuhk03/images_detected" and "cuhk03/images_labeled". Also, four json files will be automatically generated, i.e. "splits_classic_detected.json", "splits_classic_labeled.json", "splits_new_detected.json" and "splits_new_labeled.json". If the parent path of ``$REID`` is changed, these json files should be manually deleted. The code can automatically generate new json files to match the new path.
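
As referenced above, a minimal sketch combining the CUHK03-specific switches (the ``root`` path is a placeholder):

.. code-block:: python

    import torchreid

    # old 1367/100 splits, labeled images, single-gallery-shot CMC metric
    datamanager = torchreid.data.ImageDataManager(
        root='path/to/reid-data',
        sources='cuhk03',
        cuhk03_labeled=True,
        cuhk03_classic_split=True,
        use_metric_cuhk03=True
    )
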
DukeMTMC-reID :math:`^\dagger` (``dukemtmcreid``)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- Create a directory called "dukemtmc-reid" under ``$REID``.
- Download "DukeMTMC-reID" from http://vision.cs.duke.edu/DukeMTMC/ and extract it under "dukemtmc-reid".
- The data structure should look like

.. code-block:: none

    dukemtmc-reid/
        DukeMTMC-reID/
            query/
            bounding_box_train/
            bounding_box_test/
            ...

MSMT17 (``msmt17``)
^^^^^^^^^^^^^^^^^^^^^
- Create a directory called "msmt17" under ``$REID``.
- Download the dataset from http://www.pkuvmc.com/publications/msmt17.html to "msmt17" and extract the files.
- The data structure should look like

.. code-block:: none

    msmt17/
        MSMT17_V1/ # or MSMT17_V2
            train/
            test/
            list_train.txt
            list_query.txt
            list_gallery.txt
            list_val.txt

VIPeR :math:`^\dagger` (``viper``)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- The download link is http://users.soe.ucsc.edu/~manduchi/VIPeR.v1.0.zip.
- Organize the dataset in a folder named "viper" as follows

.. code-block:: none

    viper/
        VIPeR/
            cam_a/
            cam_b/

GRID :math:`^\dagger` (``grid``)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- The download link is http://personal.ie.cuhk.edu.hk/~ccloy/files/datasets/underground_reid.zip.
- Organize the dataset in a folder named "grid" as follows

.. code-block:: none

    grid/
        underground_reid/
            probe/
            gallery/
            ...

CUHK01 (``cuhk01``)
^^^^^^^^^^^^^^^^^^^^^^^^
- Create a folder named "cuhk01" under ``$REID``.
- Download "CUHK01.zip" from http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html and place it under "cuhk01/".
- The code can automatically extract the files, or you can do it yourself.
- The data structure should look like

.. code-block:: none

    cuhk01/
        campus/

SenseReID (``sensereid``)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
- Create "sensereid" under ``$REID``.
- Download the dataset from this `link <https://drive.google.com/file/d/0B56OfSrVI8hubVJLTzkwV2VaOWM/view>`_ and extract it to "sensereid".
- Organize the data to be like

.. code-block:: none

    sensereid/
        SenseReID/
            test_probe/
            test_gallery/

QMUL-iLIDS :math:`^\dagger` (``ilids``)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- Create a folder named "ilids" under ``$REID``.
- Download the dataset from http://www.eecs.qmul.ac.uk/~jason/data/i-LIDS_Pedestrian.tgz and organize it to look like

.. code-block:: none

    ilids/
        i-LIDS_Pedestrian/
            Persons/

PRID (``prid``)
^^^^^^^^^^^^^^^^^^^
- Create a directory named "prid2011" under ``$REID``.
- Download the dataset from https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/PRID11/ and extract it under "prid2011".
- The data structure should end up with

.. code-block:: none

    prid2011/
        prid_2011/
            single_shot/
            multi_shot/

CUHK02 (``cuhk02``)
^^^^^^^^^^^^^^^^^^^^^
- Create a folder named "cuhk02" under ``$REID``.
- Download the data from http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html and put it under "cuhk02/".
- Extract the file so the data structure looks like

.. code-block:: none

    cuhk02/
        Dataset/
            P1/
            P2/
            P3/
            P4/
            P5/

CUHKSYSU (``cuhksysu``)
^^^^^^^^^^^^^^^^^^^^^^^^^^
- Create a folder named "cuhksysu" under ``$REID``.
- Download the data to "cuhksysu/" from this `google drive link <https://drive.google.com/file/d/1XmiNVrfK2ZmI0ZZ2HHT80HHbDrnE4l3W/view?usp=sharing>`_.
- Extract the zip file under "cuhksysu/".
- The data structure should look like

.. code-block:: none

    cuhksysu/
        cropped_images


Video Datasets
--------------

MARS (``mars``)
^^^^^^^^^^^^^^^^^
+
- Create "mars/" under ``$REID``.
|
213 |
+
- Download the dataset from http://www.liangzheng.com.cn/Project/project_mars.html and place it in "mars/".
|
214 |
+
- Extract "bbox_train.zip" and "bbox_test.zip".
|
215 |
+
- Download the split metadata from https://github.com/liangzheng06/MARS-evaluation/tree/master/info and put "info/" in "mars/".
|
216 |
+
- The data structure should end up with
|
217 |
+
|
218 |
+
.. code-block:: none
|
219 |
+
|
220 |
+
mars/
|
221 |
+
bbox_test/
|
222 |
+
bbox_train/
|
223 |
+
info/
|
224 |
+
|
225 |
+
iLIDS-VID :math:`^\dagger` (``ilidsvid``)
|
226 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
227 |
+
- Create "ilids-vid" under ``$REID``.
|
228 |
+
- Download the dataset from https://xiatian-zhu.github.io/downloads_qmul_iLIDS-VID_ReID_dataset.html to "ilids-vid".
|
229 |
+
- Organize the data structure to match
|
230 |
+
|
231 |
+
.. code-block:: none
|
232 |
+
|
233 |
+
ilids-vid/
|
234 |
+
i-LIDS-VID/
|
235 |
+
train-test people splits/
|
236 |
+
|
237 |
+
PRID2011 (``prid2011``)
|
238 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^
|
239 |
+
- Create a directory named "prid2011" under ``$REID``.
|
240 |
+
- Download the dataset from https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/PRID11/ and extract it under "prid2011".
|
241 |
+
- Download the split created by *iLIDS-VID* from `this google drive <https://drive.google.com/open?id=1qw7SI7YdIgfHetIQO7LLW4SHpL_qkieT>`_ and put it under "prid2011/". Following the standard protocol, only 178 persons whose sequences are more than a threshold are used.
|
242 |
+
- The data structure should end up with
|
243 |
+
|
244 |
+
.. code-block:: none
|
245 |
+
|
246 |
+
prid2011/
|
247 |
+
splits_prid2011.json
|
248 |
+
prid_2011/
|
249 |
+
single_shot/
|
250 |
+
multi_shot/
|
251 |
+
|
252 |
+
DukeMTMC-VideoReID :math:`^\dagger` (``dukemtmcvidreid``)
|
253 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
254 |
+
- Create "dukemtmc-vidreid" under ``$REID``.
|
255 |
+
- Download "DukeMTMC-VideoReID" from http://vision.cs.duke.edu/DukeMTMC/ and unzip the file to "dukemtmc-vidreid/".
|
256 |
+
- The data structure should look like
|
257 |
+
|
258 |
+
.. code-block:: none
|
259 |
+
|
260 |
+
dukemtmc-vidreid/
|
261 |
+
DukeMTMC-VideoReID/
|
262 |
+
train/
|
263 |
+
query/
|
264 |
+
gallery/
|
strong_sort/deep/reid/docs/evaluation.rst
ADDED
@@ -0,0 +1,21 @@
Evaluation
==========

Image ReID
-----------
- **Market1501**, **DukeMTMC-reID**, **CUHK03 (767/700 split)** and **MSMT17** have a fixed split, so keeping ``split_id=0`` is fine.
- **CUHK03 (classic split)** has 20 fixed splits, so vary ``split_id`` from 0 to 19.
- **VIPeR** contains 632 identities, each with 2 images under two camera views. Evaluation should be done over 10 random splits. Each split randomly divides the 632 identities into 316 train ids (632 images) and 316 test ids (632 images). Note that each random split contains two sub-splits: one uses camera A as query and camera B as gallery, while the other uses camera B as query and camera A as gallery. Thus, 20 splits are generated in total, with ``split_id`` ranging from 0 to 19. Models can be trained on ``split_id=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]`` (``split_id=0`` and ``split_id=1`` share the same train set, and so on). At test time, models trained on ``split_id=0`` can be directly evaluated on ``split_id=1``, models trained on ``split_id=2`` can be directly evaluated on ``split_id=3``, and so forth.
- **CUHK01** is similar to VIPeR in split generation.
- **GRID**, **iLIDS** and **PRID** have 10 random splits, so evaluation should be done by varying ``split_id`` from 0 to 9 (a sketch is given at the end of this page).
- **SenseReID** has no training images and is used for evaluation only.


.. note::
    The ``split_id`` argument is defined in ``ImageDataManager`` and ``VideoDataManager``. Please refer to :ref:`torchreid_data`.


Video ReID
-----------
- **MARS** and **DukeMTMC-VideoReID** have a fixed single split, so using ``split_id=0`` is fine.
- **iLIDS-VID** and **PRID2011** have 10 predefined splits, so evaluation should be done by varying ``split_id`` from 0 to 9.
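
As an illustration, here is a minimal sketch of running a 10-split protocol. ``build_and_train`` is a hypothetical helper (not part of Torchreid) that would build a model and engine for one split and return its Rank-1 accuracy:

.. code-block:: python

    import numpy as np
    import torchreid

    ranks = []
    for split_id in range(10):
        # each split_id re-divides the identities into train/test
        datamanager = torchreid.data.ImageDataManager(
            root='reid-data',
            sources='grid',
            height=256,
            width=128,
            split_id=split_id
        )
        rank1 = build_and_train(datamanager)  # hypothetical helper
        ranks.append(rank1)
    print('Average Rank-1 over 10 splits: {:.1%}'.format(np.mean(ranks)))
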
strong_sort/deep/reid/docs/figures/actmap.jpg
ADDED
strong_sort/deep/reid/docs/figures/ranking_results.jpg
ADDED
strong_sort/deep/reid/docs/index.rst
ADDED
@@ -0,0 +1,35 @@
.. include:: ../README.rst


.. toctree::
    :hidden:

    user_guide
    datasets
    evaluation

.. toctree::
    :caption: Package Reference
    :hidden:

    pkg/data
    pkg/engine
    pkg/losses
    pkg/metrics
    pkg/models
    pkg/optim
    pkg/utils

.. toctree::
    :caption: Resources
    :hidden:

    AWESOME_REID.md
    MODEL_ZOO.md


Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
strong_sort/deep/reid/docs/pkg/data.rst
ADDED
@@ -0,0 +1,86 @@
.. _torchreid_data:

torchreid.data
==============


Data Manager
---------------------------

.. automodule:: torchreid.data.datamanager
    :members:


Sampler
-----------------------

.. automodule:: torchreid.data.sampler
    :members:


Transforms
---------------------------

.. automodule:: torchreid.data.transforms
    :members:


Dataset
---------------------------

.. automodule:: torchreid.data.datasets.dataset
    :members:


.. automodule:: torchreid.data.datasets.__init__
    :members:


Image Datasets
------------------------------

.. automodule:: torchreid.data.datasets.image.market1501
    :members:

.. automodule:: torchreid.data.datasets.image.cuhk03
    :members:

.. automodule:: torchreid.data.datasets.image.dukemtmcreid
    :members:

.. automodule:: torchreid.data.datasets.image.msmt17
    :members:

.. automodule:: torchreid.data.datasets.image.viper
    :members:

.. automodule:: torchreid.data.datasets.image.grid
    :members:

.. automodule:: torchreid.data.datasets.image.cuhk01
    :members:

.. automodule:: torchreid.data.datasets.image.ilids
    :members:

.. automodule:: torchreid.data.datasets.image.sensereid
    :members:

.. automodule:: torchreid.data.datasets.image.prid
    :members:


Video Datasets
------------------------------

.. automodule:: torchreid.data.datasets.video.mars
    :members:

.. automodule:: torchreid.data.datasets.video.ilidsvid
    :members:

.. automodule:: torchreid.data.datasets.video.prid2011
    :members:

.. automodule:: torchreid.data.datasets.video.dukemtmcvidreid
    :members:
strong_sort/deep/reid/docs/pkg/engine.rst
ADDED
@@ -0,0 +1,31 @@
.. _torchreid_engine:

torchreid.engine
==================


Base Engine
------------

.. autoclass:: torchreid.engine.engine.Engine
    :members:


Image Engines
-------------

.. autoclass:: torchreid.engine.image.softmax.ImageSoftmaxEngine
    :members:


.. autoclass:: torchreid.engine.image.triplet.ImageTripletEngine
    :members:


Video Engines
-------------

.. autoclass:: torchreid.engine.video.softmax.VideoSoftmaxEngine


.. autoclass:: torchreid.engine.video.triplet.VideoTripletEngine
strong_sort/deep/reid/docs/pkg/losses.rst
ADDED
@@ -0,0 +1,18 @@
.. _torchreid_losses:

torchreid.losses
=================


Softmax
--------

.. automodule:: torchreid.losses.cross_entropy_loss
    :members:


Triplet
-------

.. automodule:: torchreid.losses.hard_mine_triplet_loss
    :members:
strong_sort/deep/reid/docs/pkg/metrics.rst
ADDED
@@ -0,0 +1,25 @@
.. _torchreid_metrics:

torchreid.metrics
=================


Distance
---------

.. automodule:: torchreid.metrics.distance
    :members:


Accuracy
--------

.. automodule:: torchreid.metrics.accuracy
    :members:


Rank
-----

.. automodule:: torchreid.metrics.rank
    :members: evaluate_rank
strong_sort/deep/reid/docs/pkg/models.rst
ADDED
@@ -0,0 +1,43 @@
.. _torchreid_models:

torchreid.models
=================

Interface
---------

.. automodule:: torchreid.models.__init__
    :members:


ImageNet Classification Models
-------------------------------

.. autoclass:: torchreid.models.resnet.ResNet
.. autoclass:: torchreid.models.senet.SENet
.. autoclass:: torchreid.models.densenet.DenseNet
.. autoclass:: torchreid.models.inceptionresnetv2.InceptionResNetV2
.. autoclass:: torchreid.models.inceptionv4.InceptionV4
.. autoclass:: torchreid.models.xception.Xception


Lightweight Models
------------------

.. autoclass:: torchreid.models.nasnet.NASNetAMobile
.. autoclass:: torchreid.models.mobilenetv2.MobileNetV2
.. autoclass:: torchreid.models.shufflenet.ShuffleNet
.. autoclass:: torchreid.models.squeezenet.SqueezeNet
.. autoclass:: torchreid.models.shufflenetv2.ShuffleNetV2


ReID-specific Models
--------------------

.. autoclass:: torchreid.models.mudeep.MuDeep
.. autoclass:: torchreid.models.resnetmid.ResNetMid
.. autoclass:: torchreid.models.hacnn.HACNN
.. autoclass:: torchreid.models.pcb.PCB
.. autoclass:: torchreid.models.mlfn.MLFN
.. autoclass:: torchreid.models.osnet.OSNet
.. autoclass:: torchreid.models.osnet_ain.OSNet
strong_sort/deep/reid/docs/pkg/optim.rst
ADDED
@@ -0,0 +1,18 @@
.. _torchreid_optim:

torchreid.optim
=================


Optimizer
----------

.. automodule:: torchreid.optim.optimizer
    :members: build_optimizer


LR Scheduler
-------------

.. automodule:: torchreid.optim.lr_scheduler
    :members: build_lr_scheduler
strong_sort/deep/reid/docs/pkg/utils.rst
ADDED
@@ -0,0 +1,41 @@
.. _torchreid_utils:

torchreid.utils
=================

Average Meter
--------------

.. automodule:: torchreid.utils.avgmeter
    :members:


Loggers
-------

.. automodule:: torchreid.utils.loggers
    :members:


Generic Tools
---------------

.. automodule:: torchreid.utils.tools
    :members:


ReID Tools
----------

.. automodule:: torchreid.utils.reidtools
    :members:


Torch Tools
------------

.. automodule:: torchreid.utils.torchtools
    :members:


.. automodule:: torchreid.utils.model_complexity
    :members:
strong_sort/deep/reid/docs/user_guide.rst
ADDED
@@ -0,0 +1,351 @@
How-to
============

.. contents::
    :local:


Prepare datasets
-----------------
See :ref:`datasets`.


Find model keys
-----------------
Keys are listed under the *Public keys* section within each model class in :ref:`torchreid_models`.


Show available models
----------------------

.. code-block:: python

    import torchreid
    torchreid.models.show_avai_models()


Change the training sampler
-----------------------------
The default ``train_sampler`` is "RandomSampler". You can pass a specific sampler name to ``train_sampler``, e.g. ``train_sampler='RandomIdentitySampler'`` for triplet loss, as sketched below.
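
A minimal sketch of switching to the identity sampler (assuming the Market1501 data has been prepared):

.. code-block:: python

    import torchreid

    # RandomIdentitySampler groups several images of the same identity in each
    # batch, which triplet loss needs for mining positive/negative pairs
    datamanager = torchreid.data.ImageDataManager(
        root='reid-data',
        sources='market1501',
        height=256,
        width=128,
        batch_size=32,
        num_instances=4,  # images per identity in each batch
        train_sampler='RandomIdentitySampler'
    )
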

Choose an optimizer/lr_scheduler
----------------------------------
Please refer to the source code of ``build_optimizer``/``build_lr_scheduler`` in :ref:`torchreid_optim` for details.


Resume training
----------------
Suppose the checkpoint is saved in "log/resnet50/model.pth.tar-30", you can do

.. code-block:: python

    start_epoch = torchreid.utils.resume_from_checkpoint(
        'log/resnet50/model.pth.tar-30',
        model,
        optimizer
    )

    engine.run(
        save_dir='log/resnet50',
        max_epoch=60,
        start_epoch=start_epoch
    )


Compute model complexity
--------------------------
We provide a tool in ``torchreid.utils.model_complexity.py`` to automatically compute the model complexity, i.e. the number of parameters and FLOPs.

.. code-block:: python

    from torchreid import models, utils

    model = models.build_model(name='resnet50', num_classes=1000)
    num_params, flops = utils.compute_model_complexity(model, (1, 3, 256, 128))

    # show detailed complexity for each module
    utils.compute_model_complexity(model, (1, 3, 256, 128), verbose=True)

    # count flops for all layers including ReLU and BatchNorm
    utils.compute_model_complexity(model, (1, 3, 256, 128), verbose=True, only_conv_linear=False)

Note that (1) this function only provides an estimate of the theoretical time complexity rather than the actual running time, which depends on implementations and hardware; and (2) FLOPs are only counted for layers that are used at test time. This means that redundant layers such as the person ID classification layer will be ignored. The inference graph depends on how you define the computations in ``forward()``.

Combine multiple datasets
---------------------------
Easy. Just give whatever datasets (keys) you want to the ``sources`` argument when instantiating a data manager. For example,

.. code-block:: python

    datamanager = torchreid.data.ImageDataManager(
        root='reid-data',
        sources=['market1501', 'dukemtmcreid', 'cuhk03', 'msmt17'],
        height=256,
        width=128,
        batch_size=32
    )

In this example, the target datasets are Market1501, DukeMTMC-reID, CUHK03 and MSMT17 as the ``targets`` argument is not specified. Please refer to ``Engine.test()`` in :ref:`torchreid_engine` for details on how evaluation is performed.


Do cross-dataset evaluation
-----------------------------
Easy. Just give whatever datasets (keys) you want to the argument ``targets``, like

.. code-block:: python

    datamanager = torchreid.data.ImageDataManager(
        root='reid-data',
        sources='market1501',
        targets='dukemtmcreid', # or targets='cuhk03' or targets=['dukemtmcreid', 'cuhk03']
        height=256,
        width=128,
        batch_size=32
    )


Combine train, query and gallery
---------------------------------
This can be easily done by setting ``combineall=True`` when instantiating a data manager. Below is an example of using Market1501,

.. code-block:: python

    datamanager = torchreid.data.ImageDataManager(
        root='reid-data',
        sources='market1501',
        height=256,
        width=128,
        batch_size=32,
        market1501_500k=False,
        combineall=True # this is the key argument
    )

More specifically, with ``combineall=False``, you will get

.. code-block:: none

    => Loaded Market1501
      ----------------------------------------
      subset   | # ids | # images | # cameras
      ----------------------------------------
      train    |   751 |    12936 |         6
      query    |   750 |     3368 |         6
      gallery  |   751 |    15913 |         6
      ----------------------------------------

with ``combineall=True``, you will get

.. code-block:: none

    => Loaded Market1501
      ----------------------------------------
      subset   | # ids | # images | # cameras
      ----------------------------------------
      train    |  1501 |    29419 |         6
      query    |   750 |     3368 |         6
      gallery  |   751 |    15913 |         6
      ----------------------------------------


Optimize layers with different learning rates
-----------------------------------------------
A common practice for fine-tuning pretrained models is to use a smaller learning rate for the base layers and a larger learning rate for the randomly initialized layers (referred to as ``new_layers``). ``torchreid.optim.optimizer`` implements this feature. What you need to do is to set ``staged_lr=True`` and give the names of the ``new_layers``, such as "classifier".

Below is an example of setting different learning rates for base layers and new layers in ResNet50,

.. code-block:: python

    # New layer "classifier" has a learning rate of 0.01
    # The base layers have a learning rate of 0.001
    optimizer = torchreid.optim.build_optimizer(
        model,
        optim='sgd',
        lr=0.01,
        staged_lr=True,
        new_layers='classifier',
        base_lr_mult=0.1
    )

Please refer to :ref:`torchreid_optim` for more details.


Do two-stepped transfer learning
-------------------------------------
To prevent the pretrained layers from being damaged by harmful gradients back-propagated from randomly initialized layers, one can adopt the *two-stepped transfer learning strategy* presented in `Deep Transfer Learning for Person Re-identification <https://arxiv.org/abs/1611.05244>`_. The basic idea is to pretrain the randomly initialized layers for a few epochs while keeping the base layers frozen, before training all layers end-to-end.

This has been implemented in ``Engine.train()`` (see :ref:`torchreid_engine`). The arguments related to this feature are ``fixbase_epoch`` and ``open_layers``. Intuitively, ``fixbase_epoch`` denotes the number of epochs to keep the base layers frozen, while ``open_layers`` specifies which layers are open for training during that period.

For example, say you want to pretrain the classification layer named "classifier" in ResNet50 for 5 epochs before training all layers, you can do

.. code-block:: python

    engine.run(
        save_dir='log/resnet50',
        max_epoch=60,
        eval_freq=10,
        print_freq=10,
        test_only=False,
        fixbase_epoch=5,
        open_layers='classifier'
    )
    # or open_layers=['fc', 'classifier'] if there is another fc layer that
    # is randomly initialized, like resnet50_fc512

Note that ``fixbase_epoch`` is counted into ``max_epoch``. In the above example, the base network will be fixed for 5 epochs and then open for training for 55 epochs. Thus, if you want to freeze some layers throughout training, set ``fixbase_epoch`` equal to ``max_epoch`` and put the names of the layers you want to train in ``open_layers``.

Test a trained model
----------------------
You can load a trained model using :code:`torchreid.utils.load_pretrained_weights(model, weight_path)` and set ``test_only=True`` in ``engine.run()``.


Fine-tune a model pre-trained on reid datasets
-----------------------------------------------
Use :code:`torchreid.utils.load_pretrained_weights(model, weight_path)` to load the pre-trained weights and then fine-tune on the dataset you want; a combined sketch follows.
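
A minimal sketch covering both cases (assuming ``model``, ``datamanager`` and an engine have been built as in the earlier examples, and that "log/resnet50/model.pth.tar-60" is your trained checkpoint):

.. code-block:: python

    import torchreid

    # load the trained weights into the model
    torchreid.utils.load_pretrained_weights(model, 'log/resnet50/model.pth.tar-60')

    # evaluate only; drop test_only=True to fine-tune instead
    engine.run(
        save_dir='log/resnet50-test',
        test_only=True
    )
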

Visualize learning curves with tensorboard
--------------------------------------------
The ``SummaryWriter()`` for tensorboard is automatically initialized in ``engine.run()`` when you train your model, so you do not need to do anything extra. After training is done, the ``*tf.events*`` file will be saved in ``save_dir``. Then simply run ``tensorboard --logdir=your_save_dir`` in your terminal and visit ``http://localhost:6006/`` in a web browser. See `pytorch tensorboard <https://pytorch.org/docs/stable/tensorboard.html>`_ for further information.


Visualize ranking results
---------------------------
This can be achieved by setting ``visrank`` to True in ``engine.run()``. ``visrank_topk`` determines the top-k images to be visualized (the default is ``visrank_topk=10``). Note that ``visrank`` can only be used in test mode, i.e. with ``test_only=True`` in ``engine.run()``. The output will be saved under ``save_dir/visrank_DATASETNAME``, where each plot contains the top-k similar gallery images given a query. An example is shown below, where red and green denote incorrect and correct matches respectively; a code sketch follows the figure.

.. image:: figures/ranking_results.jpg
    :width: 800px
    :align: center
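
A minimal sketch of generating the ranking plots (assuming a trained model has been loaded as in the previous section):

.. code-block:: python

    # visrank requires test mode
    engine.run(
        save_dir='log/resnet50-visrank',
        test_only=True,
        visrank=True,
        visrank_topk=10  # number of gallery images shown per query
    )
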

Visualize activation maps
--------------------------
To understand where the CNN focuses to extract features for ReID, you can visualize the activation maps as in `OSNet <https://arxiv.org/abs/1905.00953>`_. This is implemented in ``tools/visualize_actmap.py`` (check the code for more details). An example running command is

.. code-block:: shell

    python tools/visualize_actmap.py \
    --root $DATA/reid \
    -d market1501 \
    -m osnet_x1_0 \
    --weights PATH_TO_PRETRAINED_WEIGHTS \
    --save-dir log/visactmap_osnet_x1_0_market1501

The output will look like the following (from left to right: image, activation map, overlaid image)

.. image:: figures/actmap.jpg
    :width: 300px
    :align: center


.. note::
    In order to visualize activation maps, the CNN needs to output the last convolutional feature maps in eval mode. See ``torchreid/models/osnet.py`` for an example.

Use your own dataset
----------------------
1. Write your own dataset class. Below is a template for an image dataset. The same pattern also applies to a video dataset class, for which you simply change ``ImageDataset`` to ``VideoDataset``.

.. code-block:: python

    from __future__ import absolute_import
    from __future__ import print_function
    from __future__ import division

    import sys
    import os
    import os.path as osp

    from torchreid.data import ImageDataset


    class NewDataset(ImageDataset):
        dataset_dir = 'new_dataset'

        def __init__(self, root='', **kwargs):
            self.root = osp.abspath(osp.expanduser(root))
            self.dataset_dir = osp.join(self.root, self.dataset_dir)

            # All you need to do here is to generate three lists,
            # which are train, query and gallery.
            # Each list contains tuples of (img_path, pid, camid),
            # where
            # - img_path (str): absolute path to an image.
            # - pid (int): person ID, e.g. 0, 1.
            # - camid (int): camera ID, e.g. 0, 1.
            # Note that
            # - pid and camid should be 0-based.
            # - query and gallery should share the same pid scope (e.g.
            #   pid=0 in query refers to the same person as pid=0 in gallery).
            # - train, query and gallery share the same camid scope (e.g.
            #   camid=0 in train refers to the same camera as camid=0
            #   in query/gallery).
            train = ...
            query = ...
            gallery = ...

            super(NewDataset, self).__init__(train, query, gallery, **kwargs)


2. Register your dataset.

.. code-block:: python

    import torchreid
    torchreid.data.register_image_dataset('new_dataset', NewDataset)


3. Initialize a data manager with your dataset.

.. code-block:: python

    # use your own dataset only
    datamanager = torchreid.data.ImageDataManager(
        root='reid-data',
        sources='new_dataset'
    )
    # combine with other datasets
    datamanager = torchreid.data.ImageDataManager(
        root='reid-data',
        sources=['new_dataset', 'dukemtmcreid']
    )
    # cross-dataset evaluation
    datamanager = torchreid.data.ImageDataManager(
        root='reid-data',
        sources=['new_dataset', 'dukemtmcreid'],
        targets='market1501' # or targets=['market1501', 'cuhk03']
    )

Design your own Engine
------------------------
A new Engine should be designed if you have your own loss function. The base Engine class ``torchreid.engine.Engine`` implements some generic methods which you can inherit to avoid re-writing. Please refer to the source code for more details. We suggest looking at how ``ImageSoftmaxEngine`` and ``ImageTripletEngine`` are constructed (also ``VideoSoftmaxEngine`` and ``VideoTripletEngine``). In many cases, all you need to implement is a ``forward_backward()`` method; a sketch follows.
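
A minimal sketch of a custom Engine (``MyLoss`` is a hypothetical criterion you would supply; the structure mirrors the existing image engines):

.. code-block:: python

    from torchreid.engine import Engine


    class MyEngine(Engine):

        def __init__(self, datamanager, model, optimizer, scheduler=None, use_gpu=True):
            super(MyEngine, self).__init__(datamanager, use_gpu)
            self.model = model
            self.optimizer = optimizer
            self.scheduler = scheduler
            # register so that checkpoints save/restore this model
            self.register_model('model', model, optimizer, scheduler)
            self.criterion = MyLoss()  # hypothetical custom loss

        def forward_backward(self, data):
            imgs, pids = self.parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()

            outputs = self.model(imgs)
            loss = self.criterion(outputs, pids)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # returned scalars are logged and written to tensorboard
            return {'loss': loss.item()}
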

Use Torchreid as a feature extractor in your projects
-------------------------------------------------------
We have provided a simple API for feature extraction, which accepts input of various types such as a list of image paths or numpy arrays. More details can be found in the code at ``torchreid/utils/feature_extractor.py``. Here we show a simple example of how to extract features given a list of image paths.

.. code-block:: python

    from torchreid.utils import FeatureExtractor

    extractor = FeatureExtractor(
        model_name='osnet_x1_0',
        model_path='a/b/c/model.pth.tar',
        device='cuda'
    )

    image_list = [
        'a/b/c/image001.jpg',
        'a/b/c/image002.jpg',
        'a/b/c/image003.jpg',
        'a/b/c/image004.jpg',
        'a/b/c/image005.jpg'
    ]

    features = extractor(image_list)
    print(features.shape) # output (5, 512)
strong_sort/deep/reid/linter.sh
ADDED
@@ -0,0 +1,11 @@
echo "Running isort"
isort -y -sp .
echo "Done"

echo "Running yapf"
yapf -i -r -vv -e build .
echo "Done"

echo "Running flake8"
flake8 .
echo "Done"
strong_sort/deep/reid/projects/DML/README.md
ADDED
@@ -0,0 +1,16 @@
# Deep mutual learning

This repo implements [Deep Mutual Learning (CVPR'18)](https://zpascal.net/cvpr2018/Zhang_Deep_Mutual_Learning_CVPR_2018_paper.pdf) (DML) for person re-id.

We used this code in our [OSNet](https://arxiv.org/pdf/1905.00953.pdf) paper (see Supp. B). The training command to reproduce the result of "triplet + DML" (Table 12f in the paper) is
```bash
python main.py \
--config-file im_osnet_x1_0_dml_256x128_amsgrad_cosine.yaml \
--root $DATA
```

`$DATA` corresponds to the path to your dataset folder.

Change `model.deploy` to `both` if you want to enable model ensembling (see the snippet below).

If you have any questions, please raise an issue in the Issues area.
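
Since `main.py` merges trailing command-line options into the config via `cfg.merge_from_list(args.opts)`, one way to switch `model.deploy` without editing the yaml is (a sketch, assuming the standard yacs override syntax):

```bash
python main.py \
--config-file im_osnet_x1_0_dml_256x128_amsgrad_cosine.yaml \
--root $DATA \
model.deploy both
```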
strong_sort/deep/reid/projects/DML/default_config.py
ADDED
@@ -0,0 +1,207 @@
from yacs.config import CfgNode as CN


def get_default_config():
    cfg = CN()

    # model
    cfg.model = CN()
    cfg.model.name = 'resnet50'
    cfg.model.pretrained = True  # automatically load pretrained model weights if available
    cfg.model.load_weights1 = ''  # path to model-1 weights
    cfg.model.load_weights2 = ''  # path to model-2 weights
    cfg.model.resume1 = ''  # path to checkpoint for resume training
    cfg.model.resume2 = ''  # path to checkpoint for resume training
    cfg.model.deploy = 'model1'  # model1, model2 or both

    # data
    cfg.data = CN()
    cfg.data.type = 'image'
    cfg.data.root = 'reid-data'
    cfg.data.sources = ['market1501']
    cfg.data.targets = ['market1501']
    cfg.data.workers = 4  # number of data loading workers
    cfg.data.split_id = 0  # split index
    cfg.data.height = 256  # image height
    cfg.data.width = 128  # image width
    cfg.data.combineall = False  # combine train, query and gallery for training
    cfg.data.transforms = ['random_flip']  # data augmentation
    cfg.data.norm_mean = [0.485, 0.456, 0.406]  # default is imagenet mean
    cfg.data.norm_std = [0.229, 0.224, 0.225]  # default is imagenet std
    cfg.data.save_dir = 'log'  # path to save log
    cfg.data.load_train_targets = False

    # specific datasets
    cfg.market1501 = CN()
    cfg.market1501.use_500k_distractors = False  # add 500k distractors to the gallery set for market1501
    cfg.cuhk03 = CN()
    cfg.cuhk03.labeled_images = False  # use labeled images, if False, use detected images
    cfg.cuhk03.classic_split = False  # use classic split by Li et al. CVPR14
    cfg.cuhk03.use_metric_cuhk03 = False  # use cuhk03's metric for evaluation

    # sampler
    cfg.sampler = CN()
    cfg.sampler.train_sampler = 'RandomSampler'
    cfg.sampler.num_instances = 4  # number of instances per identity for RandomIdentitySampler

    # video reid setting
    cfg.video = CN()
    cfg.video.seq_len = 15  # number of images to sample in a tracklet
    cfg.video.sample_method = 'evenly'  # how to sample images from a tracklet
    cfg.video.pooling_method = 'avg'  # how to pool features over a tracklet

    # train
    cfg.train = CN()
    cfg.train.optim = 'adam'
    cfg.train.lr = 0.0003
    cfg.train.weight_decay = 5e-4
    cfg.train.max_epoch = 60
    cfg.train.start_epoch = 0
    cfg.train.batch_size = 32
    cfg.train.fixbase_epoch = 0  # number of epochs to fix base layers
    cfg.train.open_layers = [
        'classifier'
    ]  # layers for training while keeping others frozen
    cfg.train.staged_lr = False  # set different lr to different layers
    cfg.train.new_layers = ['classifier']  # newly added layers with default lr
    cfg.train.base_lr_mult = 0.1  # learning rate multiplier for base layers
    cfg.train.lr_scheduler = 'single_step'
    cfg.train.stepsize = [20]  # stepsize to decay learning rate
    cfg.train.gamma = 0.1  # learning rate decay multiplier
    cfg.train.print_freq = 20  # print frequency
    cfg.train.seed = 1  # random seed

    # optimizer
    cfg.sgd = CN()
    cfg.sgd.momentum = 0.9  # momentum factor for sgd and rmsprop
    cfg.sgd.dampening = 0.  # dampening for momentum
    cfg.sgd.nesterov = False  # Nesterov momentum
    cfg.rmsprop = CN()
    cfg.rmsprop.alpha = 0.99  # smoothing constant
    cfg.adam = CN()
    cfg.adam.beta1 = 0.9  # exponential decay rate for first moment
    cfg.adam.beta2 = 0.999  # exponential decay rate for second moment

    # loss
    cfg.loss = CN()
    cfg.loss.name = 'triplet'
    cfg.loss.softmax = CN()
    cfg.loss.softmax.label_smooth = True  # use label smoothing regularizer
    cfg.loss.triplet = CN()
    cfg.loss.triplet.margin = 0.3  # distance margin
    cfg.loss.triplet.weight_t = 1.  # weight to balance hard triplet loss
    cfg.loss.triplet.weight_x = 0.  # weight to balance cross entropy loss
    cfg.loss.dml = CN()
    cfg.loss.dml.weight_ml = 1.  # weight for mutual learning loss

    # test
    cfg.test = CN()
    cfg.test.batch_size = 100
    cfg.test.dist_metric = 'euclidean'  # distance metric, ['euclidean', 'cosine']
    cfg.test.normalize_feature = False  # normalize feature vectors before computing distance
    cfg.test.ranks = [1, 5, 10, 20]  # cmc ranks
    cfg.test.evaluate = False  # test only
    cfg.test.eval_freq = -1  # evaluation frequency (-1 means to only test after training)
    cfg.test.start_eval = 0  # start to evaluate after a specific epoch
    cfg.test.rerank = False  # use person re-ranking
    cfg.test.visrank = False  # visualize ranked results (only available when cfg.test.evaluate=True)
    cfg.test.visrank_topk = 10  # top-k ranks to visualize

    return cfg


def imagedata_kwargs(cfg):
    return {
        'root': cfg.data.root,
        'sources': cfg.data.sources,
        'targets': cfg.data.targets,
        'height': cfg.data.height,
        'width': cfg.data.width,
        'transforms': cfg.data.transforms,
        'norm_mean': cfg.data.norm_mean,
        'norm_std': cfg.data.norm_std,
        'use_gpu': cfg.use_gpu,
        'split_id': cfg.data.split_id,
        'combineall': cfg.data.combineall,
        'load_train_targets': cfg.data.load_train_targets,
        'batch_size_train': cfg.train.batch_size,
        'batch_size_test': cfg.test.batch_size,
        'workers': cfg.data.workers,
        'num_instances': cfg.sampler.num_instances,
        'train_sampler': cfg.sampler.train_sampler,
        # image
        'cuhk03_labeled': cfg.cuhk03.labeled_images,
        'cuhk03_classic_split': cfg.cuhk03.classic_split,
        'market1501_500k': cfg.market1501.use_500k_distractors,
    }


def videodata_kwargs(cfg):
    return {
        'root': cfg.data.root,
        'sources': cfg.data.sources,
        'targets': cfg.data.targets,
        'height': cfg.data.height,
        'width': cfg.data.width,
        'transforms': cfg.data.transforms,
        'norm_mean': cfg.data.norm_mean,
        'norm_std': cfg.data.norm_std,
        'use_gpu': cfg.use_gpu,
        'split_id': cfg.data.split_id,
        'combineall': cfg.data.combineall,
        'batch_size_train': cfg.train.batch_size,
        'batch_size_test': cfg.test.batch_size,
        'workers': cfg.data.workers,
        'num_instances': cfg.sampler.num_instances,
        'train_sampler': cfg.sampler.train_sampler,
        # video
        'seq_len': cfg.video.seq_len,
        'sample_method': cfg.video.sample_method
    }


def optimizer_kwargs(cfg):
    return {
        'optim': cfg.train.optim,
        'lr': cfg.train.lr,
        'weight_decay': cfg.train.weight_decay,
        'momentum': cfg.sgd.momentum,
        'sgd_dampening': cfg.sgd.dampening,
        'sgd_nesterov': cfg.sgd.nesterov,
        'rmsprop_alpha': cfg.rmsprop.alpha,
        'adam_beta1': cfg.adam.beta1,
        'adam_beta2': cfg.adam.beta2,
        'staged_lr': cfg.train.staged_lr,
        'new_layers': cfg.train.new_layers,
        'base_lr_mult': cfg.train.base_lr_mult
    }


def lr_scheduler_kwargs(cfg):
    return {
        'lr_scheduler': cfg.train.lr_scheduler,
        'stepsize': cfg.train.stepsize,
        'gamma': cfg.train.gamma,
        'max_epoch': cfg.train.max_epoch
    }


def engine_run_kwargs(cfg):
    return {
        'save_dir': cfg.data.save_dir,
        'max_epoch': cfg.train.max_epoch,
        'start_epoch': cfg.train.start_epoch,
        'fixbase_epoch': cfg.train.fixbase_epoch,
        'open_layers': cfg.train.open_layers,
        'start_eval': cfg.test.start_eval,
        'eval_freq': cfg.test.eval_freq,
        'test_only': cfg.test.evaluate,
        'print_freq': cfg.train.print_freq,
        'dist_metric': cfg.test.dist_metric,
        'normalize_feature': cfg.test.normalize_feature,
        'visrank': cfg.test.visrank,
        'visrank_topk': cfg.test.visrank_topk,
        'use_metric_cuhk03': cfg.cuhk03.use_metric_cuhk03,
        'ranks': cfg.test.ranks,
        'rerank': cfg.test.rerank
    }
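
As a usage sketch, this config module pairs with `main.py` below roughly as follows (assuming a hypothetical `my_exp.yaml` override file):

```python
import torch
from default_config import get_default_config, imagedata_kwargs

cfg = get_default_config()
cfg.use_gpu = torch.cuda.is_available()        # main.py attaches this key before merging
cfg.merge_from_file('my_exp.yaml')             # yaml overrides the defaults above
cfg.merge_from_list(['model.deploy', 'both'])  # command-line style overrides

kwargs = imagedata_kwargs(cfg)  # ready to pass to torchreid.data.ImageDataManager
```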
strong_sort/deep/reid/projects/DML/dml.py
ADDED
@@ -0,0 +1,149 @@
from __future__ import division, print_function, absolute_import
import torch
from torch.nn import functional as F

from torchreid.utils import open_all_layers, open_specified_layers
from torchreid.engine import Engine
from torchreid.losses import TripletLoss, CrossEntropyLoss


class ImageDMLEngine(Engine):

    def __init__(
        self,
        datamanager,
        model1,
        optimizer1,
        scheduler1,
        model2,
        optimizer2,
        scheduler2,
        margin=0.3,
        weight_t=0.5,
        weight_x=1.,
        weight_ml=1.,
        use_gpu=True,
        label_smooth=True,
        deploy='model1'
    ):
        super(ImageDMLEngine, self).__init__(datamanager, use_gpu)

        self.model1 = model1
        self.optimizer1 = optimizer1
        self.scheduler1 = scheduler1
        self.register_model('model1', model1, optimizer1, scheduler1)

        self.model2 = model2
        self.optimizer2 = optimizer2
        self.scheduler2 = scheduler2
        self.register_model('model2', model2, optimizer2, scheduler2)

        self.weight_t = weight_t
        self.weight_x = weight_x
        self.weight_ml = weight_ml

        assert deploy in ['model1', 'model2', 'both']
        self.deploy = deploy

        self.criterion_t = TripletLoss(margin=margin)
        self.criterion_x = CrossEntropyLoss(
            num_classes=self.datamanager.num_train_pids,
            use_gpu=self.use_gpu,
            label_smooth=label_smooth
        )

    def forward_backward(self, data):
        imgs, pids = self.parse_data_for_train(data)

        if self.use_gpu:
            imgs = imgs.cuda()
            pids = pids.cuda()

        outputs1, features1 = self.model1(imgs)
        loss1_x = self.compute_loss(self.criterion_x, outputs1, pids)
        loss1_t = self.compute_loss(self.criterion_t, features1, pids)

        outputs2, features2 = self.model2(imgs)
        loss2_x = self.compute_loss(self.criterion_x, outputs2, pids)
        loss2_t = self.compute_loss(self.criterion_t, features2, pids)

        # mutual learning: each model mimics its peer's detached predictions
        loss1_ml = self.compute_kl_div(
            outputs2.detach(), outputs1, is_logit=True
        )
        loss2_ml = self.compute_kl_div(
            outputs1.detach(), outputs2, is_logit=True
        )

        loss1 = 0
        loss1 += loss1_x * self.weight_x
        loss1 += loss1_t * self.weight_t
        loss1 += loss1_ml * self.weight_ml

        loss2 = 0
        loss2 += loss2_x * self.weight_x
        loss2 += loss2_t * self.weight_t
        loss2 += loss2_ml * self.weight_ml

        self.optimizer1.zero_grad()
        loss1.backward()
        self.optimizer1.step()

        self.optimizer2.zero_grad()
        loss2.backward()
        self.optimizer2.step()

        loss_dict = {
            'loss1_x': loss1_x.item(),
            'loss1_t': loss1_t.item(),
            'loss1_ml': loss1_ml.item(),
            'loss2_x': loss2_x.item(),
            'loss2_t': loss2_t.item(),
            'loss2_ml': loss2_ml.item()
        }

        return loss_dict

    @staticmethod
    def compute_kl_div(p, q, is_logit=True):
        # cross entropy -(p * log q); since p is detached, the gradient
        # w.r.t. q equals that of KL(p || q)
        if is_logit:
            p = F.softmax(p, dim=1)
            q = F.softmax(q, dim=1)
        return -(p * torch.log(q + 1e-8)).sum(1).mean()

    def two_stepped_transfer_learning(
        self, epoch, fixbase_epoch, open_layers, model=None
    ):
        """Two-stepped transfer learning.

        The idea is to freeze base layers for a certain number of epochs
        and then open all layers for training.

        Reference: https://arxiv.org/abs/1611.05244
        """
        model1 = self.model1
        model2 = self.model2

        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print(
                '* Only train {} (epoch: {}/{})'.format(
                    open_layers, epoch + 1, fixbase_epoch
                )
            )
            open_specified_layers(model1, open_layers)
            open_specified_layers(model2, open_layers)
        else:
            open_all_layers(model1)
            open_all_layers(model2)

    def extract_features(self, input):
        if self.deploy == 'model1':
            return self.model1(input)

        elif self.deploy == 'model2':
            return self.model2(input)

        else:
            features = []
            features.append(self.model1(input))
            features.append(self.model2(input))
            return torch.cat(features, 1)
strong_sort/deep/reid/projects/DML/im_osnet_x1_0_dml_256x128_amsgrad_cosine.yaml
ADDED
@@ -0,0 +1,42 @@
model:
  name: 'osnet_x1_0'
  pretrained: True
  deploy: 'model1'

data:
  type: 'image'
  sources: ['market1501']
  targets: ['market1501']
  height: 256
  width: 128
  combineall: False
  transforms: ['random_flip', 'random_erase']
  save_dir: 'log/osnet_x1_0_market1501_dml_cosinelr'

loss:
  name: 'triplet'
  softmax:
    label_smooth: True
  triplet:
    margin: 0.3
    weight_t: 0.5
    weight_x: 1.
  dml:
    weight_ml: 1.

train:
  optim: 'amsgrad'
  lr: 0.0015
  max_epoch: 250
  batch_size: 64
  fixbase_epoch: 10
  open_layers: ['classifier']
  lr_scheduler: 'cosine'

test:
  batch_size: 300
  dist_metric: 'cosine'
  normalize_feature: False
  evaluate: False
  eval_freq: -1
  rerank: False
strong_sort/deep/reid/projects/DML/main.py
ADDED
@@ -0,0 +1,166 @@
import sys
import copy
import time
import os.path as osp
import argparse
import torch
import torch.nn as nn

import torchreid
from torchreid.utils import (
    Logger, check_isfile, set_random_seed, collect_env_info,
    resume_from_checkpoint, load_pretrained_weights, compute_model_complexity
)

from dml import ImageDMLEngine
from default_config import (
    imagedata_kwargs, optimizer_kwargs, engine_run_kwargs, get_default_config,
    lr_scheduler_kwargs
)


def reset_config(cfg, args):
    if args.root:
        cfg.data.root = args.root
    if args.sources:
        cfg.data.sources = args.sources
    if args.targets:
        cfg.data.targets = args.targets
    if args.transforms:
        cfg.data.transforms = args.transforms


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '--config-file', type=str, default='', help='path to config file'
    )
    parser.add_argument(
        '-s',
        '--sources',
        type=str,
        nargs='+',
        help='source datasets (delimited by space)'
    )
    parser.add_argument(
        '-t',
        '--targets',
        type=str,
        nargs='+',
        help='target datasets (delimited by space)'
    )
    parser.add_argument(
        '--transforms', type=str, nargs='+', help='data augmentation'
    )
    parser.add_argument(
        '--root', type=str, default='', help='path to data root'
    )
    parser.add_argument(
        'opts',
        default=None,
        nargs=argparse.REMAINDER,
        help='Modify config options using the command-line'
    )
    args = parser.parse_args()

    cfg = get_default_config()
    cfg.use_gpu = torch.cuda.is_available()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    reset_config(cfg, args)
    cfg.merge_from_list(args.opts)
    set_random_seed(cfg.train.seed)

    log_name = 'test.log' if cfg.test.evaluate else 'train.log'
    log_name += time.strftime('-%Y-%m-%d-%H-%M-%S')
    sys.stdout = Logger(osp.join(cfg.data.save_dir, log_name))

    print('Show configuration\n{}\n'.format(cfg))
    print('Collecting env info ...')
    print('** System info **\n{}\n'.format(collect_env_info()))

    if cfg.use_gpu:
        torch.backends.cudnn.benchmark = True

    datamanager = torchreid.data.ImageDataManager(**imagedata_kwargs(cfg))

    print('Building model-1: {}'.format(cfg.model.name))
    model1 = torchreid.models.build_model(
        name=cfg.model.name,
        num_classes=datamanager.num_train_pids,
        loss=cfg.loss.name,
        pretrained=cfg.model.pretrained,
        use_gpu=cfg.use_gpu
    )
    num_params, flops = compute_model_complexity(
        model1, (1, 3, cfg.data.height, cfg.data.width)
    )
    print('Model complexity: params={:,} flops={:,}'.format(num_params, flops))

    print('Copying model-1 to model-2')
    model2 = copy.deepcopy(model1)

    if cfg.model.load_weights1 and check_isfile(cfg.model.load_weights1):
        load_pretrained_weights(model1, cfg.model.load_weights1)

    if cfg.model.load_weights2 and check_isfile(cfg.model.load_weights2):
        load_pretrained_weights(model2, cfg.model.load_weights2)

    if cfg.use_gpu:
        model1 = nn.DataParallel(model1).cuda()
        model2 = nn.DataParallel(model2).cuda()

    optimizer1 = torchreid.optim.build_optimizer(
        model1, **optimizer_kwargs(cfg)
    )
    scheduler1 = torchreid.optim.build_lr_scheduler(
        optimizer1, **lr_scheduler_kwargs(cfg)
    )

    optimizer2 = torchreid.optim.build_optimizer(
        model2, **optimizer_kwargs(cfg)
    )
    scheduler2 = torchreid.optim.build_lr_scheduler(
        optimizer2, **lr_scheduler_kwargs(cfg)
    )

    if cfg.model.resume1 and check_isfile(cfg.model.resume1):
        cfg.train.start_epoch = resume_from_checkpoint(
            cfg.model.resume1,
            model1,
            optimizer=optimizer1,
            scheduler=scheduler1
        )

    if cfg.model.resume2 and check_isfile(cfg.model.resume2):
        resume_from_checkpoint(
            cfg.model.resume2,
            model2,
            optimizer=optimizer2,
            scheduler=scheduler2
        )

    print('Building DML-engine for image-reid')
    engine = ImageDMLEngine(
        datamanager,
        model1,
        optimizer1,
        scheduler1,
        model2,
        optimizer2,
        scheduler2,
        margin=cfg.loss.triplet.margin,
        weight_t=cfg.loss.triplet.weight_t,
        weight_x=cfg.loss.triplet.weight_x,
        weight_ml=cfg.loss.dml.weight_ml,
        use_gpu=cfg.use_gpu,
        label_smooth=cfg.loss.softmax.label_smooth,
        deploy=cfg.model.deploy
    )
    engine.run(**engine_run_kwargs(cfg))


if __name__ == '__main__':
    main()
strong_sort/deep/reid/projects/OSNet_AIN/README.md
ADDED
@@ -0,0 +1,38 @@
# Differentiable NAS for OSNet-AIN

## Introduction
This repository contains the neural architecture search (NAS) code (based on [Torchreid](https://arxiv.org/abs/1910.10093)) for [OSNet-AIN](https://arxiv.org/abs/1910.06827), an extension of [OSNet](https://arxiv.org/abs/1905.00953) that achieves strong performance on cross-domain person re-identification (re-ID) benchmarks (*without using any target data*). OSNet-AIN builds on the idea of using [instance normalisation](https://arxiv.org/abs/1607.08022) (IN) layers to eliminate instance-specific contrast in images for domain-generalisable representation learning. This is inspired by [neural style transfer](https://arxiv.org/abs/1703.06868) works that use IN to remove image styles. Although IN naturally suits the cross-domain person re-ID task, it remains unclear where IN should be inserted into a re-ID CNN to maximise the performance gain. To avoid exhaustively evaluating all possible designs, OSNet-AIN learns the optimal OSNet+IN design from data using a differentiable NAS algorithm. For technical details, please refer to our paper at https://arxiv.org/abs/1910.06827.

<div align="center">
<img src="https://drive.google.com/uc?export=view&id=1yvVIi2Ml7WBe85Uhaa54qyG4g8z-MGEB" width="500px" />
</div>
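To make the search space concrete, here is a toy, DARTS-style sketch of the core idea: each candidate location gets a differentiable switch between an IN layer and an identity mapping, and the switch weights are trained jointly with the model. This is an illustration only, not the implementation in this repository, which instead uses lambda-controlled stochastic sampling (see the `cfg.nas.*` options in `default_config.py`).
```
import torch
import torch.nn as nn
import torch.nn.functional as F


class SearchableNorm(nn.Module):
    """Differentiable switch between inserting IN and doing nothing.

    During the search, the output is a softmax-weighted mixture of the two
    candidates, so the architecture weights receive gradients and are learned
    jointly with the network weights; afterwards the candidate with the
    larger weight is kept.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.candidates = nn.ModuleList([
            nn.InstanceNorm2d(num_channels, affine=True),  # candidate 1: insert IN
            nn.Identity(),                                 # candidate 2: leave unchanged
        ])
        # one learnable architecture weight per candidate
        self.alpha = nn.Parameter(torch.zeros(len(self.candidates)))

    def forward(self, x):
        weights = F.softmax(self.alpha, dim=0)
        return sum(w * op(x) for w, op in zip(weights, self.candidates))


# toy usage: attach the switch after a conv block
block = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), SearchableNorm(16))
out = block(torch.randn(2, 3, 64, 64))  # shape: (2, 16, 64, 64)
```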
## Training
Assume the reid data is stored at `$DATA`. Run
```
python main.py --config-file nas.yaml --root $DATA
```

The structure of the found architecture will be shown at the end of training.

The default config was designed for 8 Tesla V100 32GB GPUs. You can lower the batch size to fit your device memory; see the override example below.
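Any config value can be overridden from the command line: `main.py` forwards trailing `KEY VALUE` pairs to `cfg.merge_from_list()` (via the `opts` argument), so, for example, a smaller batch size can be requested without editing the YAML file:
```
python main.py --config-file nas.yaml --root $DATA train.batch_size 16
```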

**Note** that the test result obtained at the end of the architecture search is not meaningful (due to the stochastic sampling layers). Therefore, do not rely on it to judge the model performance. Instead, you should construct the found architecture in `osnet_child.py`, then re-train and evaluate the model on the reid datasets.
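As a rough sketch of that second stage (it is not part of this project's code), re-training can reuse the standard Torchreid pipeline; `build_child_model` below is a hypothetical name for whatever constructor you add to `osnet_child.py` from the printed design:
```
import torchreid
import osnet_child  # the module where you hard-code the found architecture

datamanager = torchreid.data.ImageDataManager(
    root='reid-data', sources='market1501', height=256, width=128
)
# hypothetical constructor built from the architecture printed after the search
model = osnet_child.build_child_model(num_classes=datamanager.num_train_pids).cuda()

optimizer = torchreid.optim.build_optimizer(model, optim='amsgrad', lr=0.0015)
scheduler = torchreid.optim.build_lr_scheduler(
    optimizer, lr_scheduler='cosine', max_epoch=250
)
engine = torchreid.engine.ImageSoftmaxEngine(
    datamanager, model, optimizer, scheduler=scheduler, label_smooth=True
)
engine.run(save_dir='log/osnet_child', max_epoch=250, eval_freq=10)
```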
## Citation
If you find this code useful to your research, please consider citing the following papers.
```
@article{zhou2021osnet,
  title={Learning Generalisable Omni-Scale Representations for Person Re-Identification},
  author={Zhou, Kaiyang and Yang, Yongxin and Cavallaro, Andrea and Xiang, Tao},
  journal={TPAMI},
  year={2021}
}

@inproceedings{zhou2019osnet,
  title={Omni-Scale Feature Learning for Person Re-Identification},
  author={Zhou, Kaiyang and Yang, Yongxin and Cavallaro, Andrea and Xiang, Tao},
  booktitle={ICCV},
  year={2019}
}
```
strong_sort/deep/reid/projects/OSNet_AIN/default_config.py
ADDED
@@ -0,0 +1,210 @@
from yacs.config import CfgNode as CN


def get_default_config():
    cfg = CN()

    # model
    cfg.model = CN()
    cfg.model.name = 'resnet50'
    cfg.model.pretrained = True # automatically load pretrained model weights if available
    cfg.model.load_weights = '' # path to model weights
    cfg.model.resume = '' # path to checkpoint for resume training

    # NAS
    cfg.nas = CN()
    cfg.nas.mc_iter = 1 # Monte Carlo sampling
    cfg.nas.init_lmda = 10. # initial lambda value
    cfg.nas.min_lmda = 1. # minimum lambda value
    cfg.nas.lmda_decay_step = 20 # decay step for lambda
    cfg.nas.lmda_decay_rate = 0.5 # decay rate for lambda
    cfg.nas.fixed_lmda = False # keep lambda unchanged

    # data
    cfg.data = CN()
    cfg.data.type = 'image'
    cfg.data.root = 'reid-data'
    cfg.data.sources = ['market1501']
    cfg.data.targets = ['market1501']
    cfg.data.workers = 4 # number of data loading workers
    cfg.data.split_id = 0 # split index
    cfg.data.height = 256 # image height
    cfg.data.width = 128 # image width
    cfg.data.combineall = False # combine train, query and gallery for training
    cfg.data.transforms = ['random_flip'] # data augmentation
    cfg.data.norm_mean = [0.485, 0.456, 0.406] # default is imagenet mean
    cfg.data.norm_std = [0.229, 0.224, 0.225] # default is imagenet std
    cfg.data.save_dir = 'log' # path to save log

    # specific datasets
    cfg.market1501 = CN()
    cfg.market1501.use_500k_distractors = False # add 500k distractors to the gallery set for market1501
    cfg.cuhk03 = CN()
    cfg.cuhk03.labeled_images = False # use labeled images, if False, use detected images
    cfg.cuhk03.classic_split = False # use classic split by Li et al. CVPR14
    cfg.cuhk03.use_metric_cuhk03 = False # use cuhk03's metric for evaluation

    # sampler
    cfg.sampler = CN()
    cfg.sampler.train_sampler = 'RandomSampler'
    cfg.sampler.num_instances = 4 # number of instances per identity for RandomIdentitySampler

    # video reid setting
    cfg.video = CN()
    cfg.video.seq_len = 15 # number of images to sample in a tracklet
    cfg.video.sample_method = 'evenly' # how to sample images from a tracklet
    cfg.video.pooling_method = 'avg' # how to pool features over a tracklet

    # train
    cfg.train = CN()
    cfg.train.optim = 'adam'
    cfg.train.lr = 0.0003
    cfg.train.weight_decay = 5e-4
    cfg.train.max_epoch = 60
    cfg.train.start_epoch = 0
    cfg.train.batch_size = 32
    cfg.train.fixbase_epoch = 0 # number of epochs to fix base layers
    cfg.train.open_layers = [
        'classifier'
    ] # layers for training while keeping others frozen
    cfg.train.staged_lr = False # set different lr to different layers
    cfg.train.new_layers = ['classifier'] # newly added layers with default lr
    cfg.train.base_lr_mult = 0.1 # learning rate multiplier for base layers
    cfg.train.lr_scheduler = 'single_step'
    cfg.train.stepsize = [20] # stepsize to decay learning rate
    cfg.train.gamma = 0.1 # learning rate decay multiplier
    cfg.train.print_freq = 20 # print frequency
    cfg.train.seed = 1 # random seed

    # optimizer
    cfg.sgd = CN()
    cfg.sgd.momentum = 0.9 # momentum factor for sgd and rmsprop
    cfg.sgd.dampening = 0. # dampening for momentum
    cfg.sgd.nesterov = False # Nesterov momentum
    cfg.rmsprop = CN()
    cfg.rmsprop.alpha = 0.99 # smoothing constant
    cfg.adam = CN()
    cfg.adam.beta1 = 0.9 # exponential decay rate for first moment
    cfg.adam.beta2 = 0.999 # exponential decay rate for second moment

    # loss
    cfg.loss = CN()
    cfg.loss.name = 'softmax'
    cfg.loss.softmax = CN()
    cfg.loss.softmax.label_smooth = True # use label smoothing regularizer
    cfg.loss.triplet = CN()
    cfg.loss.triplet.margin = 0.3 # distance margin
    cfg.loss.triplet.weight_t = 1. # weight to balance hard triplet loss
    cfg.loss.triplet.weight_x = 0. # weight to balance cross entropy loss

    # test
    cfg.test = CN()
    cfg.test.batch_size = 100
    cfg.test.dist_metric = 'euclidean' # distance metric, ['euclidean', 'cosine']
    cfg.test.normalize_feature = False # normalize feature vectors before computing distance
    cfg.test.ranks = [1, 5, 10, 20] # cmc ranks
    cfg.test.evaluate = False # test only
    cfg.test.eval_freq = -1 # evaluation frequency (-1 means to only test after training)
    cfg.test.start_eval = 0 # start to evaluate after a specific epoch
    cfg.test.rerank = False # use person re-ranking
    cfg.test.visrank = False # visualize ranked results (only available when cfg.test.evaluate=True)
    cfg.test.visrank_topk = 10 # top-k ranks to visualize
    cfg.test.visactmap = False # visualize CNN activation maps

    return cfg


def imagedata_kwargs(cfg):
    return {
        'root': cfg.data.root,
        'sources': cfg.data.sources,
        'targets': cfg.data.targets,
        'height': cfg.data.height,
        'width': cfg.data.width,
        'transforms': cfg.data.transforms,
        'norm_mean': cfg.data.norm_mean,
        'norm_std': cfg.data.norm_std,
        'use_gpu': cfg.use_gpu,
        'split_id': cfg.data.split_id,
        'combineall': cfg.data.combineall,
        'batch_size_train': cfg.train.batch_size,
        'batch_size_test': cfg.test.batch_size,
        'workers': cfg.data.workers,
        'num_instances': cfg.sampler.num_instances,
        'train_sampler': cfg.sampler.train_sampler,
        # image
        'cuhk03_labeled': cfg.cuhk03.labeled_images,
        'cuhk03_classic_split': cfg.cuhk03.classic_split,
        'market1501_500k': cfg.market1501.use_500k_distractors,
    }


def videodata_kwargs(cfg):
    return {
        'root': cfg.data.root,
        'sources': cfg.data.sources,
        'targets': cfg.data.targets,
        'height': cfg.data.height,
        'width': cfg.data.width,
        'transforms': cfg.data.transforms,
        'norm_mean': cfg.data.norm_mean,
        'norm_std': cfg.data.norm_std,
        'use_gpu': cfg.use_gpu,
        'split_id': cfg.data.split_id,
        'combineall': cfg.data.combineall,
        'batch_size_train': cfg.train.batch_size,
        'batch_size_test': cfg.test.batch_size,
        'workers': cfg.data.workers,
        'num_instances': cfg.sampler.num_instances,
        'train_sampler': cfg.sampler.train_sampler,
        # video
        'seq_len': cfg.video.seq_len,
        'sample_method': cfg.video.sample_method
    }


def optimizer_kwargs(cfg):
    return {
        'optim': cfg.train.optim,
        'lr': cfg.train.lr,
        'weight_decay': cfg.train.weight_decay,
        'momentum': cfg.sgd.momentum,
        'sgd_dampening': cfg.sgd.dampening,
        'sgd_nesterov': cfg.sgd.nesterov,
        'rmsprop_alpha': cfg.rmsprop.alpha,
        'adam_beta1': cfg.adam.beta1,
        'adam_beta2': cfg.adam.beta2,
        'staged_lr': cfg.train.staged_lr,
        'new_layers': cfg.train.new_layers,
        'base_lr_mult': cfg.train.base_lr_mult
    }


def lr_scheduler_kwargs(cfg):
    return {
        'lr_scheduler': cfg.train.lr_scheduler,
        'stepsize': cfg.train.stepsize,
        'gamma': cfg.train.gamma,
        'max_epoch': cfg.train.max_epoch
    }


def engine_run_kwargs(cfg):
    return {
        'save_dir': cfg.data.save_dir,
        'max_epoch': cfg.train.max_epoch,
        'start_epoch': cfg.train.start_epoch,
        'fixbase_epoch': cfg.train.fixbase_epoch,
        'open_layers': cfg.train.open_layers,
        'start_eval': cfg.test.start_eval,
        'eval_freq': cfg.test.eval_freq,
        'test_only': cfg.test.evaluate,
        'print_freq': cfg.train.print_freq,
        'dist_metric': cfg.test.dist_metric,
        'normalize_feature': cfg.test.normalize_feature,
        'visrank': cfg.test.visrank,
        'visrank_topk': cfg.test.visrank_topk,
        'use_metric_cuhk03': cfg.cuhk03.use_metric_cuhk03,
        'ranks': cfg.test.ranks,
        'rerank': cfg.test.rerank
    }
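For reference, a minimal sketch of how the helpers in `default_config.py` are consumed (mirroring what `main.py` below does; assumes `torchreid` and `yacs` are installed):
```
import torch
import torchreid
from default_config import get_default_config, imagedata_kwargs

cfg = get_default_config()
cfg.use_gpu = torch.cuda.is_available()  # main.py sets this before using the helpers
cfg.merge_from_list(['train.batch_size', '64'])  # yacs-style command-line override
datamanager = torchreid.data.ImageDataManager(**imagedata_kwargs(cfg))
```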
strong_sort/deep/reid/projects/OSNet_AIN/main.py
ADDED
@@ -0,0 +1,145 @@
import os
import sys
import time
import os.path as osp
import argparse
import torch
import torch.nn as nn

import torchreid
from torchreid.utils import (
    Logger, check_isfile, set_random_seed, collect_env_info,
    resume_from_checkpoint, compute_model_complexity
)

import osnet_search as osnet_models
from softmax_nas import ImageSoftmaxNASEngine
from default_config import (
    imagedata_kwargs, optimizer_kwargs, engine_run_kwargs, get_default_config,
    lr_scheduler_kwargs
)


def reset_config(cfg, args):
    if args.root:
        cfg.data.root = args.root
    if args.sources:
        cfg.data.sources = args.sources
    if args.targets:
        cfg.data.targets = args.targets
    if args.transforms:
        cfg.data.transforms = args.transforms


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '--config-file', type=str, default='', help='path to config file'
    )
    parser.add_argument(
        '-s',
        '--sources',
        type=str,
        nargs='+',
        help='source datasets (delimited by space)'
    )
    parser.add_argument(
        '-t',
        '--targets',
        type=str,
        nargs='+',
        help='target datasets (delimited by space)'
    )
    parser.add_argument(
        '--transforms', type=str, nargs='+', help='data augmentation'
    )
    parser.add_argument(
        '--root', type=str, default='', help='path to data root'
    )
    parser.add_argument(
        '--gpu-devices',
        type=str,
        default='',
    )
    parser.add_argument(
        'opts',
        default=None,
        nargs=argparse.REMAINDER,
        help='Modify config options using the command-line'
    )
    args = parser.parse_args()

    cfg = get_default_config()
    cfg.use_gpu = torch.cuda.is_available()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    reset_config(cfg, args)
    cfg.merge_from_list(args.opts)
    set_random_seed(cfg.train.seed)

    if cfg.use_gpu and args.gpu_devices:
        # if gpu_devices is not specified, all available gpus will be used
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    log_name = 'test.log' if cfg.test.evaluate else 'train.log'
    log_name += time.strftime('-%Y-%m-%d-%H-%M-%S')
    sys.stdout = Logger(osp.join(cfg.data.save_dir, log_name))

    print('Show configuration\n{}\n'.format(cfg))
    print('Collecting env info ...')
    print('** System info **\n{}\n'.format(collect_env_info()))

    if cfg.use_gpu:
        torch.backends.cudnn.benchmark = True

    datamanager = torchreid.data.ImageDataManager(**imagedata_kwargs(cfg))

    print('Building model: {}'.format(cfg.model.name))
    model = osnet_models.build_model(
        cfg.model.name, num_classes=datamanager.num_train_pids
    )
    num_params, flops = compute_model_complexity(
        model, (1, 3, cfg.data.height, cfg.data.width)
    )
    print('Model complexity: params={:,} flops={:,}'.format(num_params, flops))

    if cfg.use_gpu:
        model = nn.DataParallel(model).cuda()

    optimizer = torchreid.optim.build_optimizer(model, **optimizer_kwargs(cfg))
    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer, **lr_scheduler_kwargs(cfg)
    )

    if cfg.model.resume and check_isfile(cfg.model.resume):
        cfg.train.start_epoch = resume_from_checkpoint(
            cfg.model.resume, model, optimizer=optimizer
        )

    print('Building NAS engine')
    engine = ImageSoftmaxNASEngine(
        datamanager,
        model,
        optimizer,
        scheduler=scheduler,
        use_gpu=cfg.use_gpu,
        label_smooth=cfg.loss.softmax.label_smooth,
        mc_iter=cfg.nas.mc_iter,
        init_lmda=cfg.nas.init_lmda,
        min_lmda=cfg.nas.min_lmda,
        lmda_decay_step=cfg.nas.lmda_decay_step,
        lmda_decay_rate=cfg.nas.lmda_decay_rate,
        fixed_lmda=cfg.nas.fixed_lmda
    )
    engine.run(**engine_run_kwargs(cfg))

    print('*** Display the found architecture ***')
    if cfg.use_gpu:
        model.module.build_child_graph()
    else:
        model.build_child_graph()


if __name__ == '__main__':
    main()