usiddiquee commited on
Commit
e1832f4
·
1 Parent(s): 3b054ae
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +176 -0
  2. boxmot/__init__.py +21 -0
  3. boxmot/appearance/__init__.py +0 -0
  4. boxmot/appearance/backbones/__init__.py +1 -0
  5. boxmot/appearance/backbones/clip/__init__.py +1 -0
  6. boxmot/appearance/backbones/clip/clip/__init__.py +1 -0
  7. boxmot/appearance/backbones/clip/clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  8. boxmot/appearance/backbones/clip/clip/clip.py +222 -0
  9. boxmot/appearance/backbones/clip/clip/model.py +504 -0
  10. boxmot/appearance/backbones/clip/clip/simple_tokenizer.py +136 -0
  11. boxmot/appearance/backbones/clip/config/__init__.py +1 -0
  12. boxmot/appearance/backbones/clip/config/defaults.py +239 -0
  13. boxmot/appearance/backbones/clip/config/defaults_base.py +190 -0
  14. boxmot/appearance/backbones/clip/make_model.py +161 -0
  15. boxmot/appearance/backbones/clip/make_model_clipreid.py +247 -0
  16. boxmot/appearance/backbones/hacnn.py +406 -0
  17. boxmot/appearance/backbones/lmbn/__init__.py +1 -0
  18. boxmot/appearance/backbones/lmbn/attention.py +281 -0
  19. boxmot/appearance/backbones/lmbn/bnneck.py +166 -0
  20. boxmot/appearance/backbones/lmbn/lmbn_n.py +185 -0
  21. boxmot/appearance/backbones/mlfn.py +240 -0
  22. boxmot/appearance/backbones/mobilenetv2.py +246 -0
  23. boxmot/appearance/backbones/osnet.py +560 -0
  24. boxmot/appearance/backbones/osnet_ain.py +582 -0
  25. boxmot/appearance/backbones/resnet.py +517 -0
  26. boxmot/appearance/backends/base_backend.py +135 -0
  27. boxmot/appearance/backends/onnx_backend.py +42 -0
  28. boxmot/appearance/backends/openvino_backend.py +44 -0
  29. boxmot/appearance/backends/pytorch_backend.py +24 -0
  30. boxmot/appearance/backends/tensorrt_backend.py +126 -0
  31. boxmot/appearance/backends/tflite_backend.py +86 -0
  32. boxmot/appearance/backends/torchscript_backend.py +24 -0
  33. boxmot/appearance/exporters/base_exporter.py +56 -0
  34. boxmot/appearance/exporters/onnx_exporter.py +56 -0
  35. boxmot/appearance/exporters/openvino_exporter.py +26 -0
  36. boxmot/appearance/exporters/tensorrt_exporter.py +80 -0
  37. boxmot/appearance/exporters/tflite_exporter.py +37 -0
  38. boxmot/appearance/exporters/torchscript_exporter.py +15 -0
  39. boxmot/appearance/reid/__init__.py +16 -0
  40. boxmot/appearance/reid/auto_backend.py +128 -0
  41. boxmot/appearance/reid/config.py +73 -0
  42. boxmot/appearance/reid/export.py +227 -0
  43. boxmot/appearance/reid/factory.py +40 -0
  44. boxmot/appearance/reid/registry.py +87 -0
  45. boxmot/configs/__init__.py +1 -0
  46. boxmot/configs/boosttrack.yaml +90 -0
  47. boxmot/configs/botsort.yaml +39 -0
  48. boxmot/configs/bytetrack.yaml +24 -0
  49. boxmot/configs/deepocsort.yaml +74 -0
  50. boxmot/configs/hybridsort.yaml +49 -0
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import subprocess
4
+ import tempfile
5
+ import shutil
6
+ from pathlib import Path
7
+ import sys
8
+ import importlib.util
9
+
10
+ # Ensure models directory exists
11
+ MODELS_DIR = Path("models")
12
+ os.makedirs(MODELS_DIR, exist_ok=True)
13
+
14
+ def ensure_dependencies():
15
+ """Ensure all required dependencies are installed."""
16
+ required_packages = [
17
+ "ultralytics",
18
+ "boxmot",
19
+ "supervision"
20
+ ]
21
+
22
+ for package in required_packages:
23
+ try:
24
+ importlib.import_module(package)
25
+ print(f"✅ {package} is installed")
26
+ except ImportError:
27
+ print(f"⚠️ {package} is not installed, attempting to install...")
28
+ subprocess.run([sys.executable, "-m", "pip", "install", package], check=True)
29
+
30
+ # Apply tracker patches if tracker_patch.py exists
31
+ def apply_patches():
32
+ patch_path = Path("tracker_patch.py")
33
+ if patch_path.exists():
34
+ spec = importlib.util.spec_from_file_location("tracker_patch", patch_path)
35
+ if spec:
36
+ module = importlib.util.module_from_spec(spec)
37
+ spec.loader.exec_module(module)
38
+ if hasattr(module, "patch_trackers"):
39
+ module.patch_trackers()
40
+ print("✅ Applied tracker patches")
41
+ else:
42
+ print("⚠️ tracker_patch.py exists but has no patch_trackers function")
43
+ else:
44
+ print("⚠️ tracker_patch.py not found, skipping patches")
45
+
46
+ def run_tracking(video_file, yolo_model, reid_model, tracking_method, conf_threshold):
47
+ """Run object tracking on the uploaded video."""
48
+ try:
49
+ # Create temporary workspace
50
+ with tempfile.TemporaryDirectory() as temp_dir:
51
+ # Prepare input
52
+ input_path = os.path.join(temp_dir, "input_video.mp4")
53
+ shutil.copy(video_file, input_path)
54
+
55
+ # Prepare output directory
56
+ output_dir = os.path.join(temp_dir, "output")
57
+ os.makedirs(output_dir, exist_ok=True)
58
+
59
+ # Build command
60
+ cmd = [
61
+ "python", "tracking/track.py",
62
+ "--yolo-model", str(MODELS_DIR / yolo_model),
63
+ "--reid-model", str(MODELS_DIR / reid_model),
64
+ "--tracking-method", tracking_method,
65
+ "--source", input_path,
66
+ "--conf", str(conf_threshold),
67
+ "--save",
68
+ "--project", output_dir,
69
+ "--name", "track",
70
+ "--exist-ok"
71
+ ]
72
+
73
+ # Special handling for OcSort
74
+ if tracking_method == "ocsort":
75
+ cmd.append("--per-class")
76
+
77
+ # Execute tracking with error handling
78
+ process = subprocess.run(
79
+ cmd,
80
+ capture_output=True,
81
+ text=True
82
+ )
83
+
84
+ # Check for errors in output
85
+ if process.returncode != 0:
86
+ error_message = process.stderr or process.stdout
87
+ return None, f"Error in tracking process: {error_message}"
88
+
89
+ # Find output video
90
+ output_files = []
91
+ for root, _, files in os.walk(output_dir):
92
+ for file in files:
93
+ if file.lower().endswith((".mp4", ".avi", ".mov")):
94
+ output_files.append(os.path.join(root, file))
95
+
96
+ if not output_files:
97
+ return None, "No output video was generated. Check if tracking was successful."
98
+
99
+ return output_files[0], "Processing completed successfully!"
100
+
101
+ except Exception as e:
102
+ return None, f"Error: {str(e)}"
103
+
104
+ # Define the Gradio interface
105
+ def process_video(video_path, yolo_model, reid_model, tracking_method, conf_threshold):
106
+ # Validate inputs
107
+ if not video_path:
108
+ return None, "Please upload a video file"
109
+
110
+ output_path, status = run_tracking(
111
+ video_path,
112
+ yolo_model,
113
+ reid_model,
114
+ tracking_method,
115
+ conf_threshold
116
+ )
117
+
118
+ return output_path, status
119
+
120
+ # Available models and tracking methods
121
+ yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
122
+ reid_models = ["osnet_x0_25_msmt17.pt"]
123
+ tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
124
+
125
+ # Ensure dependencies and apply patches at startup
126
+ ensure_dependencies()
127
+ apply_patches()
128
+
129
+ # Create the Gradio interface
130
+ with gr.Blocks(title="YOLO Object Tracking") as app:
131
+ gr.Markdown("# 🚀 YOLO Object Tracking")
132
+ gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")
133
+
134
+ with gr.Row():
135
+ with gr.Column():
136
+ input_video = gr.Video(label="Input Video", sources=["upload"])
137
+
138
+ with gr.Group():
139
+ yolo_model = gr.Dropdown(
140
+ choices=yolo_models,
141
+ value="yolov8n.pt",
142
+ label="YOLO Model"
143
+ )
144
+ reid_model = gr.Dropdown(
145
+ choices=reid_models,
146
+ value="osnet_x0_25_msmt17.pt",
147
+ label="ReID Model"
148
+ )
149
+ tracking_method = gr.Dropdown(
150
+ choices=tracking_methods,
151
+ value="bytetrack",
152
+ label="Tracking Method"
153
+ )
154
+ conf_threshold = gr.Slider(
155
+ minimum=0.1,
156
+ maximum=0.9,
157
+ value=0.3,
158
+ step=0.05,
159
+ label="Confidence Threshold"
160
+ )
161
+
162
+ process_btn = gr.Button("Process Video", variant="primary")
163
+
164
+ with gr.Column():
165
+ output_video = gr.Video(label="Output Video with Tracking", autoplay=True)
166
+ status_text = gr.Textbox(label="Status", value="Ready to process video")
167
+
168
+ process_btn.click(
169
+ fn=process_video,
170
+ inputs=[input_video, yolo_model, reid_model, tracking_method, conf_threshold],
171
+ outputs=[output_video, status_text]
172
+ )
173
+
174
+ # Launch the app
175
+ if __name__ == "__main__":
176
+ app.launch(debug=True, share=True)
boxmot/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ __version__ = '12.0.7'
4
+
5
+ from boxmot.postprocessing.gsi import gsi
6
+ from boxmot.tracker_zoo import create_tracker, get_tracker_config
7
+ from boxmot.trackers.botsort.botsort import BotSort
8
+ from boxmot.trackers.bytetrack.bytetrack import ByteTrack
9
+ from boxmot.trackers.deepocsort.deepocsort import DeepOcSort
10
+ from boxmot.trackers.hybridsort.hybridsort import HybridSort
11
+ from boxmot.trackers.ocsort.ocsort import OcSort
12
+ from boxmot.trackers.strongsort.strongsort import StrongSort
13
+ from boxmot.trackers.imprassoc.imprassoctrack import ImprAssocTrack
14
+ from boxmot.trackers.boosttrack.boosttrack import BoostTrack
15
+
16
+
17
+ TRACKERS = ['bytetrack', 'botsort', 'strongsort', 'ocsort', 'deepocsort', 'hybridsort', 'imprassoc', 'boosttrack']
18
+
19
+ __all__ = ("__version__",
20
+ "StrongSort", "OcSort", "ByteTrack", "BotSort", "DeepOcSort", "HybridSort", "ImprAssocTrack", "BoostTrack",
21
+ "create_tracker", "get_tracker_config", "gsi")
boxmot/appearance/__init__.py ADDED
File without changes
boxmot/appearance/backbones/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
boxmot/appearance/backbones/clip/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
boxmot/appearance/backbones/clip/clip/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
boxmot/appearance/backbones/clip/clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
boxmot/appearance/backbones/clip/clip/clip.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ import hashlib
4
+ import os
5
+ import urllib
6
+ import warnings
7
+ from typing import List, Union
8
+
9
+ import torch
10
+ from PIL import Image
11
+ from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
12
+ ToTensor)
13
+ from tqdm import tqdm
14
+
15
+ from .model import build_model
16
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
17
+
18
+ try:
19
+ from torchvision.transforms import InterpolationMode
20
+ BICUBIC = InterpolationMode.BICUBIC
21
+ except ImportError:
22
+ BICUBIC = Image.BICUBIC
23
+
24
+
25
+ __all__ = ["available_models", "load", "tokenize"]
26
+ _tokenizer = _Tokenizer()
27
+
28
+ _MODELS = {
29
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", # noqa: E501
30
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", # noqa: E501
31
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", # noqa: E501
32
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", # noqa: E501
33
+ "ViT-B-32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", # noqa: E501
34
+ "ViT-B-16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", # noqa: E501
35
+ }
36
+
37
+
38
+ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
39
+ os.makedirs(root, exist_ok=True)
40
+ filename = os.path.basename(url)
41
+
42
+ expected_sha256 = url.split("/")[-2]
43
+ download_target = os.path.join(root, filename)
44
+
45
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
46
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
47
+
48
+ if os.path.isfile(download_target):
49
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
50
+ return download_target
51
+ else:
52
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
53
+
54
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
55
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
56
+ while True:
57
+ buffer = source.read(8192)
58
+ if not buffer:
59
+ break
60
+
61
+ output.write(buffer)
62
+ loop.update(len(buffer))
63
+
64
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
65
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
66
+
67
+ return download_target
68
+
69
+
70
+ def _transform(n_px):
71
+ return Compose([
72
+ Resize(n_px, interpolation=BICUBIC),
73
+ CenterCrop(n_px),
74
+ lambda image: image.convert("RGB"),
75
+ ToTensor(),
76
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
77
+ ])
78
+
79
+
80
+ def available_models() -> List[str]:
81
+ """Returns the names of available CLIP models"""
82
+ return list(_MODELS.keys())
83
+
84
+
85
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
86
+ """Load a CLIP model
87
+
88
+ Parameters
89
+ ----------
90
+ name : str
91
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
92
+
93
+ device : Union[str, torch.device]
94
+ The device to put the loaded model
95
+
96
+ jit : bool
97
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
98
+
99
+ Returns
100
+ -------
101
+ model : torch.nn.Module
102
+ The CLIP model
103
+
104
+ preprocess : Callable[[PIL.Image], torch.Tensor]
105
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
106
+ """
107
+ if name in _MODELS:
108
+ model_path = _download(_MODELS[name])
109
+ elif os.path.isfile(name):
110
+ model_path = name
111
+ else:
112
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
113
+
114
+ try:
115
+ # loading JIT archive
116
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
117
+ state_dict = None
118
+ except RuntimeError:
119
+ # loading saved state dict
120
+ if jit:
121
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
122
+ jit = False
123
+ state_dict = torch.load(model_path, map_location="cpu")
124
+
125
+ if not jit:
126
+ model = build_model(state_dict or model.state_dict()).to(device)
127
+ if str(device) == "cpu":
128
+ model.float()
129
+ return model, _transform(model.visual.input_resolution)
130
+
131
+ # patch the device names
132
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
133
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
134
+
135
+ def patch_device(module):
136
+ try:
137
+ graphs = [module.graph] if hasattr(module, "graph") else []
138
+ except RuntimeError:
139
+ graphs = []
140
+
141
+ if hasattr(module, "forward1"):
142
+ graphs.append(module.forward1.graph)
143
+
144
+ for graph in graphs:
145
+ for node in graph.findAllNodes("prim::Constant"):
146
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
147
+ node.copyAttributes(device_node)
148
+
149
+ model.apply(patch_device)
150
+ patch_device(model.encode_image)
151
+ patch_device(model.encode_text)
152
+
153
+ # patch dtype to float32 on CPU
154
+ if str(device) == "cpu":
155
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
156
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
157
+ float_node = float_input.node()
158
+
159
+ def patch_float(module):
160
+ try:
161
+ graphs = [module.graph] if hasattr(module, "graph") else []
162
+ except RuntimeError:
163
+ graphs = []
164
+
165
+ if hasattr(module, "forward1"):
166
+ graphs.append(module.forward1.graph)
167
+
168
+ for graph in graphs:
169
+ for node in graph.findAllNodes("aten::to"):
170
+ inputs = list(node.inputs())
171
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
172
+ if inputs[i].node()["value"] == 5:
173
+ inputs[i].node().copyAttributes(float_node)
174
+
175
+ model.apply(patch_float)
176
+ patch_float(model.encode_image)
177
+ patch_float(model.encode_text)
178
+
179
+ model.float()
180
+
181
+ return model, _transform(model.input_resolution.item())
182
+
183
+
184
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
185
+ """
186
+ Returns the tokenized representation of given input string(s)
187
+
188
+ Parameters
189
+ ----------
190
+ texts : Union[str, List[str]]
191
+ An input string or a list of input strings to tokenize
192
+
193
+ context_length : int
194
+ The context length to use; all CLIP models use 77 as the context length
195
+
196
+ truncate: bool
197
+ Whether to truncate the text in case its encoding is longer than the context length
198
+
199
+ Returns
200
+ -------
201
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
202
+ """
203
+ # import pdb
204
+ # pdb.set_trace()
205
+ if isinstance(texts, str):
206
+ texts = [texts] # ['a photo of a face.']
207
+
208
+ sot_token = _tokenizer.encoder["<|startoftext|>"] # 49406
209
+ eot_token = _tokenizer.encoder["<|endoftext|>"] # 49407
210
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
211
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) # 1,77
212
+
213
+ for i, tokens in enumerate(all_tokens):
214
+ if len(tokens) > context_length: # context_length 77
215
+ if truncate:
216
+ tokens = tokens[:context_length]
217
+ tokens[-1] = eot_token
218
+ else:
219
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
220
+ result[i, :len(tokens)] = torch.tensor(tokens)
221
+
222
+ return result
boxmot/appearance/backbones/clip/clip/model.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ from collections import OrderedDict
4
+ from typing import Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+
12
+ class Bottleneck(nn.Module):
13
+ expansion = 4
14
+
15
+ def __init__(self, inplanes, planes, stride=1):
16
+ super().__init__()
17
+
18
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
19
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
20
+ self.bn1 = nn.BatchNorm2d(planes)
21
+
22
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
23
+ self.bn2 = nn.BatchNorm2d(planes)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+
30
+ self.relu = nn.ReLU(inplace=True)
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([
37
+ ("-1", nn.AvgPool2d(stride)),
38
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
+ ("1", nn.BatchNorm2d(planes * self.expansion))
40
+ ]))
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ identity = x
44
+
45
+ out = self.relu(self.bn1(self.conv1(x)))
46
+ out = self.relu(self.bn2(self.conv2(out)))
47
+ out = self.avgpool(out)
48
+ out = self.bn3(self.conv3(out))
49
+
50
+ if self.downsample is not None:
51
+ identity = self.downsample(x)
52
+
53
+ out += identity
54
+ out = self.relu(out)
55
+ return out
56
+
57
+
58
+ class AttentionPool2d(nn.Module):
59
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
+ super().__init__()
61
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
62
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
+ self.num_heads = num_heads
67
+
68
+ def forward(self, x):
69
+ # NCHW -> (HW)NC #32,2048,7,7 ->49, 32, 2048
70
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)
71
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 50,32,2048
72
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
73
+ x, _ = F.multi_head_attention_forward(
74
+ query=x, key=x, value=x,
75
+ embed_dim_to_check=x.shape[-1],
76
+ num_heads=self.num_heads,
77
+ q_proj_weight=self.q_proj.weight,
78
+ k_proj_weight=self.k_proj.weight,
79
+ v_proj_weight=self.v_proj.weight,
80
+ in_proj_weight=None,
81
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
82
+ bias_k=None,
83
+ bias_v=None,
84
+ add_zero_attn=False,
85
+ dropout_p=0,
86
+ out_proj_weight=self.c_proj.weight,
87
+ out_proj_bias=self.c_proj.bias,
88
+ use_separate_proj_weight=True,
89
+ training=self.training,
90
+ need_weights=False
91
+ )
92
+
93
+ return x
94
+
95
+
96
+ class ModifiedResNet(nn.Module):
97
+ """
98
+ A ResNet class that is similar to torchvision's but contains the following changes:
99
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
100
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
101
+ - The final pooling layer is a QKV attention instead of an average pool
102
+ """
103
+
104
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
105
+ super().__init__()
106
+ self.output_dim = output_dim
107
+ self.input_resolution = input_resolution
108
+
109
+ # the 3-layer stem
110
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
111
+ self.bn1 = nn.BatchNorm2d(width // 2)
112
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113
+ self.bn2 = nn.BatchNorm2d(width // 2)
114
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115
+ self.bn3 = nn.BatchNorm2d(width)
116
+ self.avgpool = nn.AvgPool2d(2)
117
+ self.relu = nn.ReLU(inplace=True)
118
+
119
+ # residual layers
120
+ self._inplanes = width # this is a *mutable* variable used during construction
121
+ self.layer1 = self._make_layer(width, layers[0])
122
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=1)
125
+ embed_dim = width * 32 # the ResNet feature dimension
126
+ self.attnpool = AttentionPool2d(input_resolution, embed_dim, heads, output_dim)
127
+
128
+ def _make_layer(self, planes, blocks, stride=1):
129
+ layers = [Bottleneck(self._inplanes, planes, stride)]
130
+
131
+ self._inplanes = planes * Bottleneck.expansion
132
+ for _ in range(1, blocks):
133
+ layers.append(Bottleneck(self._inplanes, planes))
134
+
135
+ return nn.Sequential(*layers)
136
+
137
+ def forward(self, x):
138
+ def stem(x):
139
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
140
+ x = self.relu(bn(conv(x)))
141
+ x = self.avgpool(x)
142
+ return x
143
+
144
+ x = x.type(self.conv1.weight.dtype)
145
+ x = stem(x)
146
+ x = self.layer1(x)
147
+ x = self.layer2(x)
148
+ x3 = self.layer3(x)
149
+ x4 = self.layer4(x3)
150
+ xproj = self.attnpool(x4)
151
+
152
+ return x3, x4, xproj
153
+
154
+
155
+ class LayerNorm(nn.LayerNorm):
156
+ """Subclass torch's LayerNorm to handle fp16."""
157
+
158
+ def forward(self, x: torch.Tensor):
159
+ orig_type = x.dtype
160
+ for param in self.parameters():
161
+ if param.dtype == torch.float16:
162
+ param.data = param.data.to(torch.float32)
163
+ ret = super().forward(x.to(torch.float32))
164
+ return ret.to(orig_type)
165
+
166
+
167
+ class QuickGELU(nn.Module):
168
+ def forward(self, x: torch.Tensor):
169
+ return x * torch.sigmoid(1.702 * x)
170
+
171
+
172
+ class ResidualAttentionBlock(nn.Module):
173
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
174
+ super().__init__()
175
+
176
+ self.attn = nn.MultiheadAttention(d_model, n_head)
177
+ self.ln_1 = LayerNorm(d_model)
178
+ self.mlp = nn.Sequential(OrderedDict([
179
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
180
+ ("gelu", QuickGELU()),
181
+ ("c_proj", nn.Linear(d_model * 4, d_model))
182
+ ]))
183
+ self.ln_2 = LayerNorm(d_model)
184
+ self.attn_mask = attn_mask
185
+
186
+ def attention(self, x: torch.Tensor):
187
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
188
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
189
+
190
+ def forward(self, x: torch.Tensor):
191
+ x = x + self.attention(self.ln_1(x))
192
+ x = x + self.mlp(self.ln_2(x))
193
+ return x
194
+
195
+
196
+ class Transformer(nn.Module):
197
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
198
+ super().__init__()
199
+ self.width = width
200
+ self.layers = layers
201
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
202
+
203
+ def forward(self, x: torch.Tensor):
204
+ return self.resblocks(x)
205
+
206
+
207
+ class VisionTransformer(nn.Module):
208
+ def __init__(
209
+ self,
210
+ h_resolution: int,
211
+ w_resolution: int,
212
+ patch_size: int,
213
+ stride_size: int,
214
+ width: int,
215
+ layers: int,
216
+ heads: int,
217
+ output_dim: int
218
+ ):
219
+ super().__init__()
220
+ self.h_resolution = h_resolution
221
+ self.w_resolution = w_resolution
222
+ self.output_dim = output_dim
223
+ self.conv1 = nn.Conv2d(
224
+ in_channels=3,
225
+ out_channels=width,
226
+ kernel_size=patch_size,
227
+ stride=stride_size,
228
+ bias=False
229
+ )
230
+
231
+ scale = width ** -0.5
232
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
233
+ self.positional_embedding = nn.Parameter(scale * torch.randn(h_resolution * w_resolution + 1, width))
234
+ self.ln_pre = LayerNorm(width)
235
+
236
+ self.transformer = Transformer(width, layers, heads)
237
+
238
+ self.ln_post = LayerNorm(width)
239
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
240
+
241
+ def forward(self, x: torch.Tensor, cv_emb=None):
242
+ x = self.conv1(x) # shape = [*, width, grid, grid]
243
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
244
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
245
+ x = torch.cat([self.class_embedding.to(x.dtype) +
246
+ # shape = [*, grid ** 2 + 1, width]
247
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
248
+ if cv_emb is not None:
249
+ x[:, 0] = x[:, 0] + cv_emb
250
+ x = x + self.positional_embedding.to(x.dtype)
251
+ x = self.ln_pre(x)
252
+
253
+ x = x.permute(1, 0, 2) # NLD -> LND
254
+
255
+ x11 = self.transformer.resblocks[:11](x)
256
+ x12 = self.transformer.resblocks[11](x11)
257
+ x11 = x11.permute(1, 0, 2) # LND -> NLD
258
+ x12 = x12.permute(1, 0, 2) # LND -> NLD
259
+
260
+ x12 = self.ln_post(x12)
261
+
262
+ if self.proj is not None:
263
+ xproj = x12 @ self.proj
264
+
265
+ return x11, x12, xproj
266
+
267
+
268
+ class CLIP(nn.Module):
269
+ def __init__(self,
270
+ embed_dim: int,
271
+ # vision
272
+ image_resolution: int,
273
+ vision_layers: Union[Tuple[int, int, int, int], int],
274
+ vision_width: int,
275
+ vision_patch_size: int,
276
+ vision_stride_size: int,
277
+ # text
278
+ context_length: int,
279
+ vocab_size: int,
280
+ transformer_width: int,
281
+ transformer_heads: int,
282
+ transformer_layers: int,
283
+ h_resolution: int,
284
+ w_resolution: int
285
+ ):
286
+ super().__init__()
287
+
288
+ self.context_length = context_length
289
+
290
+ if isinstance(vision_layers, (tuple, list)):
291
+ vision_heads = vision_width * 32 // 64
292
+ self.visual = ModifiedResNet(
293
+ layers=vision_layers,
294
+ output_dim=embed_dim,
295
+ heads=vision_heads,
296
+ input_resolution=h_resolution * w_resolution,
297
+ width=vision_width
298
+ )
299
+ else:
300
+ vision_heads = vision_width // 64
301
+ self.visual = VisionTransformer(
302
+ h_resolution=h_resolution,
303
+ w_resolution=w_resolution,
304
+ patch_size=vision_patch_size,
305
+ stride_size=vision_stride_size,
306
+ width=vision_width,
307
+ layers=vision_layers,
308
+ heads=vision_heads,
309
+ output_dim=embed_dim
310
+ )
311
+
312
+ self.transformer = Transformer(
313
+ width=transformer_width,
314
+ layers=transformer_layers,
315
+ heads=transformer_heads,
316
+ attn_mask=self.build_attention_mask()
317
+ )
318
+
319
+ self.vocab_size = vocab_size
320
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
321
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
322
+ self.ln_final = LayerNorm(transformer_width)
323
+
324
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
325
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
326
+
327
+ self.initialize_parameters()
328
+
329
+ def initialize_parameters(self):
330
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
331
+ nn.init.normal_(self.positional_embedding, std=0.01)
332
+
333
+ if isinstance(self.visual, ModifiedResNet):
334
+ if self.visual.attnpool is not None:
335
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
336
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
337
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
338
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
339
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
340
+
341
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
342
+ for name, param in resnet_block.named_parameters():
343
+ if name.endswith("bn3.weight"):
344
+ nn.init.zeros_(param)
345
+
346
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
347
+ attn_std = self.transformer.width ** -0.5
348
+ fc_std = (2 * self.transformer.width) ** -0.5
349
+ for block in self.transformer.resblocks:
350
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
351
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
352
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
353
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
354
+
355
+ if self.text_projection is not None:
356
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
357
+
358
+ def build_attention_mask(self):
359
+ # lazily create causal attention mask, with full attention between the vision tokens
360
+ # pytorch uses additive attention mask; fill with -inf
361
+ mask = torch.empty(self.context_length, self.context_length)
362
+ mask.fill_(float("-inf"))
363
+ mask.triu_(1) # zero out the lower diagonal
364
+ return mask
365
+
366
+ @property
367
+ def dtype(self):
368
+ return self.visual.conv1.weight.dtype
369
+
370
+ def encode_image(self, image):
371
+ return self.visual(image.type(self.dtype))
372
+
373
+ def encode_text(self, text):
374
+ x = self.token_embedding(text).type(self.dtype)
375
+
376
+ x = x + self.positional_embedding.type(self.dtype)
377
+ x = x.permute(1, 0, 2)
378
+ x = self.transformer(x)
379
+ x = x.permute(1, 0, 2)
380
+ x = self.ln_final(x).type(self.dtype)
381
+
382
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
383
+
384
+ return x
385
+
386
+ def forward(self, image, text):
387
+ image_features = self.encode_image(image)
388
+ text_features = self.encode_text(text)
389
+
390
+ # normalized features
391
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
392
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
393
+
394
+ # cosine similarity as logits
395
+ logit_scale = self.logit_scale.exp()
396
+ logits_per_image = logit_scale * image_features @ text_features.t()
397
+ logits_per_text = logit_scale * text_features @ image_features.t()
398
+
399
+ # shape = [global_batch_size, global_batch_size]
400
+ return logits_per_image, logits_per_text
401
+
402
+
403
+ def convert_weights(model: nn.Module):
404
+ """Convert applicable model parameters to fp16"""
405
+
406
+ def _convert_weights_to_fp16(l):
407
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
408
+ l.weight.data = l.weight.data.float()
409
+ if l.bias is not None:
410
+ l.bias.data = l.bias.data.float()
411
+
412
+ if isinstance(l, nn.MultiheadAttention):
413
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
414
+ tensor = getattr(l, attr)
415
+ if tensor is not None:
416
+ tensor.data = tensor.data.float()
417
+
418
+ for name in ["text_projection", "proj"]:
419
+ if hasattr(l, name):
420
+ attr = getattr(l, name)
421
+ if attr is not None:
422
+ attr.data = attr.data.float()
423
+
424
+ model.apply(_convert_weights_to_fp16)
425
+
426
+
427
+ def build_model(state_dict: dict, h_resolution: int, w_resolution: int, vision_stride_size: int):
428
+ vit = "visual.proj" in state_dict
429
+
430
+ if vit:
431
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
432
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and
433
+ k.endswith(".attn.in_proj_weight")])
434
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
435
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
436
+ image_resolution = vision_patch_size * grid_size
437
+ else: # RN50
438
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if
439
+ k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
440
+ vision_layers = tuple(counts)
441
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
442
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
443
+ vision_patch_size = None
444
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
445
+ image_resolution = output_width * 32
446
+
447
+ embed_dim = state_dict["text_projection"].shape[1]
448
+ context_length = state_dict["positional_embedding"].shape[0] # 77 (77,512)
449
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
450
+ transformer_width = state_dict["ln_final.weight"].shape[0]
451
+ transformer_heads = transformer_width // 64
452
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
453
+
454
+ model = CLIP(
455
+ embed_dim,
456
+ image_resolution, vision_layers, vision_width, vision_patch_size, vision_stride_size,
457
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
458
+ h_resolution, w_resolution
459
+ )
460
+ if vit:
461
+ state_dict["visual.positional_embedding"] = resize_pos_embed(
462
+ state_dict["visual.positional_embedding"],
463
+ model.visual.positional_embedding,
464
+ h_resolution,
465
+ w_resolution
466
+ )
467
+ else: # RN50
468
+ state_dict["visual.attnpool.positional_embedding"] = resize_pos_embed(
469
+ state_dict["visual.attnpool.positional_embedding"],
470
+ model.visual.attnpool.positional_embedding,
471
+ h_resolution,
472
+ w_resolution
473
+ )
474
+
475
+ for key in ["input_resolution", "context_length", "vocab_size"]:
476
+ if key in state_dict:
477
+ del state_dict[key]
478
+
479
+ convert_weights(model)
480
+
481
+ model.load_state_dict(state_dict)
482
+ return model.eval()
483
+
484
+
485
+ import math
486
+
487
+
488
+ def resize_pos_embed(posemb, posemb_new, hight, width):
489
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
490
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
491
+ print('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
492
+
493
+ ntok_new = posemb_new.shape[0] # 129,2048
494
+
495
+ posemb_token, posemb_grid = posemb[:1], posemb[1:]
496
+ ntok_new -= 1
497
+
498
+ gs_old = int(math.sqrt(len(posemb_grid))) # 14
499
+ print('Position embedding resize to height:{} width: {}'.format(hight, width))
500
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
501
+ posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear')
502
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1)
503
+ posemb = torch.cat([posemb_token, posemb_grid.squeeze()], dim=0)
504
+ return posemb
boxmot/appearance/backbones/clip/clip/simple_tokenizer.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ import gzip
4
+ import html
5
+ from functools import lru_cache
6
+
7
+ import ftfy
8
+ import regex as re
9
+
10
+ from boxmot.utils import BOXMOT
11
+
12
+
13
+ @lru_cache()
14
+ def default_bpe():
15
+ return BOXMOT / "appearance/backbones/clip/clip/bpe_simple_vocab_16e6.txt.gz"
16
+
17
+
18
+ @lru_cache()
19
+ def bytes_to_unicode():
20
+ """
21
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
22
+ The reversible bpe codes work on unicode strings.
23
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
24
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
25
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
26
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
27
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
28
+ """
29
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
30
+ cs = bs[:]
31
+ n = 0
32
+ for b in range(2**8):
33
+ if b not in bs:
34
+ bs.append(b)
35
+ cs.append(2 ** 8 + n)
36
+ n += 1
37
+ cs = [chr(n) for n in cs]
38
+ return dict(zip(bs, cs))
39
+
40
+
41
+ def get_pairs(word):
42
+ """Return set of symbol pairs in a word.
43
+ Word is represented as tuple of symbols (symbols being variable-length strings).
44
+ """
45
+ pairs = set()
46
+ prev_char = word[0]
47
+ for char in word[1:]:
48
+ pairs.add((prev_char, char))
49
+ prev_char = char
50
+ return pairs
51
+
52
+
53
+ def basic_clean(text):
54
+ text = ftfy.fix_text(text)
55
+ text = html.unescape(html.unescape(text))
56
+ return text.strip()
57
+
58
+
59
+ def whitespace_clean(text):
60
+ text = re.sub(r'\s+', ' ', text)
61
+ text = text.strip()
62
+ return text
63
+
64
+
65
+ class SimpleTokenizer(object):
66
+ def __init__(self, bpe_path: str = default_bpe()):
67
+ self.byte_encoder = bytes_to_unicode()
68
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
69
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
70
+ merges = merges[1:49152 - 256 - 2 + 1]
71
+ merges = [tuple(merge.split()) for merge in merges]
72
+ vocab = list(bytes_to_unicode().values())
73
+ vocab = vocab + [v + '</w>' for v in vocab]
74
+ for merge in merges:
75
+ vocab.append(''.join(merge))
76
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
77
+ self.encoder = dict(zip(vocab, range(len(vocab))))
78
+ self.decoder = {v: k for k, v in self.encoder.items()}
79
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
80
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
81
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) # noqa: E501
82
+
83
+ def bpe(self, token):
84
+ if token in self.cache:
85
+ return self.cache[token]
86
+ word = tuple(token[:-1]) + (token[-1] + '</w>',)
87
+ pairs = get_pairs(word)
88
+
89
+ if not pairs:
90
+ return token + '</w>'
91
+
92
+ while True:
93
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
94
+ if bigram not in self.bpe_ranks:
95
+ break
96
+ first, second = bigram
97
+ new_word = []
98
+ i = 0
99
+ while i < len(word):
100
+ try:
101
+ j = word.index(first, i)
102
+ new_word.extend(word[i:j])
103
+ i = j
104
+ except Exception:
105
+
106
+ new_word.extend(word[i:])
107
+ break
108
+
109
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
110
+ new_word.append(first + second)
111
+ i += 2
112
+ else:
113
+ new_word.append(word[i])
114
+ i += 1
115
+ new_word = tuple(new_word)
116
+ word = new_word
117
+ if len(word) == 1:
118
+ break
119
+ else:
120
+ pairs = get_pairs(word)
121
+ word = ' '.join(word)
122
+ self.cache[token] = word
123
+ return word
124
+
125
+ def encode(self, text):
126
+ bpe_tokens = []
127
+ text = whitespace_clean(basic_clean(text)).lower()
128
+ for token in re.findall(self.pat, text):
129
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
130
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
131
+ return bpe_tokens
132
+
133
+ def decode(self, tokens):
134
+ text = ''.join([self.decoder[token] for token in tokens])
135
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
136
+ return text
boxmot/appearance/backbones/clip/config/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
boxmot/appearance/backbones/clip/config/defaults.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ from yacs.config import CfgNode as CN
4
+
5
+ # -----------------------------------------------------------------------------
6
+ # Convention about Training / Test specific parameters
7
+ # -----------------------------------------------------------------------------
8
+ # Whenever an argument can be either used for training or for testing, the
9
+ # corresponding name will be post-fixed by a _TRAIN for a training parameter,
10
+
11
+ # -----------------------------------------------------------------------------
12
+ # Config definition
13
+ # -----------------------------------------------------------------------------
14
+
15
+ _C = CN()
16
+ # -----------------------------------------------------------------------------
17
+ # MODEL
18
+ # -----------------------------------------------------------------------------
19
+ _C.MODEL = CN()
20
+ # Using cuda or cpu for training
21
+ _C.MODEL.DEVICE = "cuda"
22
+ # ID number of GPU
23
+ _C.MODEL.DEVICE_ID = '0'
24
+ # Name of backbone
25
+ _C.MODEL.NAME = 'ViT-B-16'
26
+ # Last stride of backbone
27
+ _C.MODEL.LAST_STRIDE = 1
28
+ # Path to pretrained model of backbone
29
+ _C.MODEL.PRETRAIN_PATH = '/home/mikel.brostrom/yolo_tracking/clip_market1501.pt'
30
+
31
+ # Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model
32
+ # Options: 'imagenet' , 'self' , 'finetune'
33
+ _C.MODEL.PRETRAIN_CHOICE = 'imagenet'
34
+
35
+ # If train with BNNeck, options: 'bnneck' or 'no'
36
+ _C.MODEL.NECK = 'bnneck'
37
+ # If train loss include center loss, options: 'yes' or 'no'. Loss with center loss has different optimizer configuration
38
+ _C.MODEL.IF_WITH_CENTER = 'no'
39
+
40
+ _C.MODEL.ID_LOSS_TYPE = 'softmax'
41
+ _C.MODEL.ID_LOSS_WEIGHT = 1.0
42
+ _C.MODEL.TRIPLET_LOSS_WEIGHT = 1.0
43
+ _C.MODEL.I2T_LOSS_WEIGHT = 1.0
44
+
45
+ _C.MODEL.METRIC_LOSS_TYPE = 'triplet'
46
+ # If train with multi-gpu ddp mode, options: 'True', 'False'
47
+ _C.MODEL.DIST_TRAIN = False
48
+ # If train with soft triplet loss, options: 'True', 'False'
49
+ _C.MODEL.NO_MARGIN = False
50
+ # If train with label smooth, options: 'on', 'off'
51
+ _C.MODEL.IF_LABELSMOOTH = 'on'
52
+ # If train with arcface loss, options: 'True', 'False'
53
+ _C.MODEL.COS_LAYER = False
54
+
55
+ # Transformer setting
56
+ _C.MODEL.DROP_PATH = 0.1
57
+ _C.MODEL.DROP_OUT = 0.0
58
+ _C.MODEL.ATT_DROP_RATE = 0.0
59
+ _C.MODEL.TRANSFORMER_TYPE = 'None'
60
+ _C.MODEL.STRIDE_SIZE = [16, 16]
61
+
62
+ # SIE Parameter
63
+ _C.MODEL.SIE_COE = 3.0
64
+ _C.MODEL.SIE_CAMERA = False
65
+ _C.MODEL.SIE_VIEW = False
66
+
67
+ # -----------------------------------------------------------------------------
68
+ # INPUT
69
+ # -----------------------------------------------------------------------------
70
+ _C.INPUT = CN()
71
+ # Size of the image during training
72
+ _C.INPUT.SIZE_TRAIN = [256, 128]
73
+ # Size of the image during test
74
+ _C.INPUT.SIZE_TEST = [256, 128]
75
+ # Random probability for image horizontal flip
76
+ _C.INPUT.PROB = 0.5
77
+ # Random probability for random erasing
78
+ _C.INPUT.RE_PROB = 0.5
79
+ # Values to be used for image normalization
80
+ _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
81
+ # Values to be used for image normalization
82
+ _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
83
+ # Value of padding size
84
+ _C.INPUT.PADDING = 10
85
+
86
+ # -----------------------------------------------------------------------------
87
+ # Dataset
88
+ # -----------------------------------------------------------------------------
89
+ _C.DATASETS = CN()
90
+ # List of the dataset names for training, as present in paths_catalog.py
91
+ _C.DATASETS.NAMES = ('market1501')
92
+ # Root directory where datasets should be used (and downloaded if not found)
93
+ _C.DATASETS.ROOT_DIR = ('../data')
94
+
95
+
96
+ # -----------------------------------------------------------------------------
97
+ # DataLoader
98
+ # -----------------------------------------------------------------------------
99
+ _C.DATALOADER = CN()
100
+ # Number of data loading threads
101
+ _C.DATALOADER.NUM_WORKERS = 8
102
+ # Sampler for data loading
103
+ _C.DATALOADER.SAMPLER = 'softmax'
104
+ # Number of instance for one batch
105
+ _C.DATALOADER.NUM_INSTANCE = 16
106
+
107
+ # ---------------------------------------------------------------------------- #
108
+ # Solver
109
+ _C.SOLVER = CN()
110
+ _C.SOLVER.SEED = 1234
111
+ _C.SOLVER.MARGIN = 0.3
112
+
113
+ # stage1
114
+ # ---------------------------------------------------------------------------- #
115
+ # Name of optimizer
116
+ _C.SOLVER.STAGE1 = CN()
117
+
118
+ _C.SOLVER.STAGE1.IMS_PER_BATCH = 64
119
+
120
+ _C.SOLVER.STAGE1.OPTIMIZER_NAME = "Adam"
121
+ # Number of max epoches
122
+ _C.SOLVER.STAGE1.MAX_EPOCHS = 100
123
+ # Base learning rate
124
+ _C.SOLVER.STAGE1.BASE_LR = 3e-4
125
+ # Momentum
126
+ _C.SOLVER.STAGE1.MOMENTUM = 0.9
127
+
128
+ # Settings of weight decay
129
+ _C.SOLVER.STAGE1.WEIGHT_DECAY = 0.0005
130
+ _C.SOLVER.STAGE1.WEIGHT_DECAY_BIAS = 0.0005
131
+
132
+ # warm up factor
133
+ _C.SOLVER.STAGE1.WARMUP_FACTOR = 0.01
134
+ # warm up epochs
135
+ _C.SOLVER.STAGE1.WARMUP_EPOCHS = 5
136
+ _C.SOLVER.STAGE1.WARMUP_LR_INIT = 0.01
137
+ _C.SOLVER.STAGE1.LR_MIN = 0.000016
138
+
139
+ _C.SOLVER.STAGE1.WARMUP_ITERS = 500
140
+ # method of warm up, option: 'constant','linear'
141
+ _C.SOLVER.STAGE1.WARMUP_METHOD = "linear"
142
+
143
+ _C.SOLVER.STAGE1.COSINE_MARGIN = 0.5
144
+ _C.SOLVER.STAGE1.COSINE_SCALE = 30
145
+
146
+ # epoch number of saving checkpoints
147
+ _C.SOLVER.STAGE1.CHECKPOINT_PERIOD = 10
148
+ # iteration of display training log
149
+ _C.SOLVER.STAGE1.LOG_PERIOD = 100
150
+ # epoch number of validation
151
+ # Number of images per batch
152
+ # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU will
153
+ # contain 16 images per batch
154
+ # _C.SOLVER.STAGE1.IMS_PER_BATCH = 64
155
+ _C.SOLVER.STAGE1.EVAL_PERIOD = 10
156
+
157
+ # ---------------------------------------------------------------------------- #
158
+ # Solver
159
+ # stage1
160
+ # ---------------------------------------------------------------------------- #
161
+ _C.SOLVER.STAGE2 = CN()
162
+
163
+ _C.SOLVER.STAGE2.IMS_PER_BATCH = 64
164
+ # Name of optimizer
165
+ _C.SOLVER.STAGE2.OPTIMIZER_NAME = "Adam"
166
+ # Number of max epoches
167
+ _C.SOLVER.STAGE2.MAX_EPOCHS = 100
168
+ # Base learning rate
169
+ _C.SOLVER.STAGE2.BASE_LR = 3e-4
170
+ # Whether using larger learning rate for fc layer
171
+ _C.SOLVER.STAGE2.LARGE_FC_LR = False
172
+ # Factor of learning bias
173
+ _C.SOLVER.STAGE2.BIAS_LR_FACTOR = 1
174
+ # Momentum
175
+ _C.SOLVER.STAGE2.MOMENTUM = 0.9
176
+ # Margin of triplet loss
177
+ # Learning rate of SGD to learn the centers of center loss
178
+ _C.SOLVER.STAGE2.CENTER_LR = 0.5
179
+ # Balanced weight of center loss
180
+ _C.SOLVER.STAGE2.CENTER_LOSS_WEIGHT = 0.0005
181
+
182
+ # Settings of weight decay
183
+ _C.SOLVER.STAGE2.WEIGHT_DECAY = 0.0005
184
+ _C.SOLVER.STAGE2.WEIGHT_DECAY_BIAS = 0.0005
185
+
186
+ # decay rate of learning rate
187
+ _C.SOLVER.STAGE2.GAMMA = 0.1
188
+ # decay step of learning rate
189
+ _C.SOLVER.STAGE2.STEPS = (40, 70)
190
+ # warm up factor
191
+ _C.SOLVER.STAGE2.WARMUP_FACTOR = 0.01
192
+ # warm up epochs
193
+ _C.SOLVER.STAGE2.WARMUP_EPOCHS = 5
194
+ _C.SOLVER.STAGE2.WARMUP_LR_INIT = 0.01
195
+ _C.SOLVER.STAGE2.LR_MIN = 0.000016
196
+
197
+
198
+ _C.SOLVER.STAGE2.WARMUP_ITERS = 500
199
+ # method of warm up, option: 'constant','linear'
200
+ _C.SOLVER.STAGE2.WARMUP_METHOD = "linear"
201
+
202
+ _C.SOLVER.STAGE2.COSINE_MARGIN = 0.5
203
+ _C.SOLVER.STAGE2.COSINE_SCALE = 30
204
+
205
+ # epoch number of saving checkpoints
206
+ _C.SOLVER.STAGE2.CHECKPOINT_PERIOD = 10
207
+ # iteration of display training log
208
+ _C.SOLVER.STAGE2.LOG_PERIOD = 100
209
+ # epoch number of validation
210
+ _C.SOLVER.STAGE2.EVAL_PERIOD = 10
211
+ # Number of images per batch
212
+ # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU will
213
+ # contain 16 images per batch
214
+
215
+ # ---------------------------------------------------------------------------- #
216
+ # TEST
217
+ # ---------------------------------------------------------------------------- #
218
+
219
+ _C.TEST = CN()
220
+ # Number of images per batch during test
221
+ _C.TEST.IMS_PER_BATCH = 128
222
+ # If test with re-ranking, options: 'True','False'
223
+ _C.TEST.RE_RANKING = False
224
+ # Path to trained model
225
+ _C.TEST.WEIGHT = ""
226
+ # Which feature of BNNeck to be used for test, before or after BNNneck, options: 'before' or 'after'
227
+ _C.TEST.NECK_FEAT = 'after'
228
+ # Whether feature is nomalized before test, if yes, it is equivalent to cosine distance
229
+ _C.TEST.FEAT_NORM = 'yes'
230
+
231
+ # Name for saving the distmat after testing.
232
+ _C.TEST.DIST_MAT = "dist_mat.npy"
233
+ # Whether calculate the eval score option: 'True', 'False'
234
+ _C.TEST.EVAL = False
235
+ # ---------------------------------------------------------------------------- #
236
+ # Misc options
237
+ # ---------------------------------------------------------------------------- #
238
+ # Path to checkpoint and saved log of trained model
239
+ _C.OUTPUT_DIR = ""
boxmot/appearance/backbones/clip/config/defaults_base.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ from yacs.config import CfgNode as CN
4
+
5
+ # -----------------------------------------------------------------------------
6
+ # Convention about Training / Test specific parameters
7
+ # -----------------------------------------------------------------------------
8
+ # Whenever an argument can be either used for training or for testing, the
9
+ # corresponding name will be post-fixed by a _TRAIN for a training parameter,
10
+
11
+ # -----------------------------------------------------------------------------
12
+ # Config definition
13
+ # -----------------------------------------------------------------------------
14
+
15
+ _C = CN()
16
+ # -----------------------------------------------------------------------------
17
+ # MODEL
18
+ # -----------------------------------------------------------------------------
19
+ _C.MODEL = CN()
20
+ # Using cuda or cpu for training
21
+ _C.MODEL.DEVICE = "cuda"
22
+ # ID number of GPU
23
+ _C.MODEL.DEVICE_ID = '0'
24
+ # Name of backbone
25
+ _C.MODEL.NAME = 'resnet50'
26
+ # Last stride of backbone
27
+ _C.MODEL.LAST_STRIDE = 1
28
+ # Path to pretrained model of backbone
29
+ _C.MODEL.PRETRAIN_PATH = ''
30
+
31
+ # Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model
32
+ # Options: 'imagenet' , 'self' , 'finetune'
33
+ _C.MODEL.PRETRAIN_CHOICE = 'imagenet'
34
+
35
+ # If train with BNNeck, options: 'bnneck' or 'no'
36
+ _C.MODEL.NECK = 'bnneck'
37
+ # If train loss include center loss, options: 'yes' or 'no'. Loss with center loss has different optimizer configuration
38
+ _C.MODEL.IF_WITH_CENTER = 'no'
39
+
40
+ _C.MODEL.ID_LOSS_TYPE = 'softmax'
41
+ _C.MODEL.ID_LOSS_WEIGHT = 1.0
42
+ _C.MODEL.TRIPLET_LOSS_WEIGHT = 1.0
43
+ _C.MODEL.I2T_LOSS_WEIGHT = 1.0
44
+
45
+ _C.MODEL.METRIC_LOSS_TYPE = 'triplet'
46
+ # If train with multi-gpu ddp mode, options: 'True', 'False'
47
+ _C.MODEL.DIST_TRAIN = False
48
+ # If train with soft triplet loss, options: 'True', 'False'
49
+ _C.MODEL.NO_MARGIN = False
50
+ # If train with label smooth, options: 'on', 'off'
51
+ _C.MODEL.IF_LABELSMOOTH = 'on'
52
+ # If train with arcface loss, options: 'True', 'False'
53
+ _C.MODEL.COS_LAYER = False
54
+
55
+ # Transformer setting
56
+ _C.MODEL.DROP_PATH = 0.1
57
+ _C.MODEL.DROP_OUT = 0.0
58
+ _C.MODEL.ATT_DROP_RATE = 0.0
59
+ _C.MODEL.TRANSFORMER_TYPE = 'None'
60
+ _C.MODEL.STRIDE_SIZE = [16, 16]
61
+
62
+ # SIE Parameter
63
+ _C.MODEL.SIE_COE = 3.0
64
+ _C.MODEL.SIE_CAMERA = False
65
+ _C.MODEL.SIE_VIEW = False
66
+
67
+ # -----------------------------------------------------------------------------
68
+ # INPUT
69
+ # -----------------------------------------------------------------------------
70
+ _C.INPUT = CN()
71
+ # Size of the image during training
72
+ _C.INPUT.SIZE_TRAIN = [384, 128]
73
+ # Size of the image during test
74
+ _C.INPUT.SIZE_TEST = [384, 128]
75
+ # Random probability for image horizontal flip
76
+ _C.INPUT.PROB = 0.5
77
+ # Random probability for random erasing
78
+ _C.INPUT.RE_PROB = 0.5
79
+ # Values to be used for image normalization
80
+ _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
81
+ # Values to be used for image normalization
82
+ _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
83
+ # Value of padding size
84
+ _C.INPUT.PADDING = 10
85
+
86
+ # -----------------------------------------------------------------------------
87
+ # Dataset
88
+ # -----------------------------------------------------------------------------
89
+ _C.DATASETS = CN()
90
+ # List of the dataset names for training, as present in paths_catalog.py
91
+ _C.DATASETS.NAMES = ('market1501')
92
+ # Root directory where datasets should be used (and downloaded if not found)
93
+ _C.DATASETS.ROOT_DIR = ('../data')
94
+
95
+
96
+ # -----------------------------------------------------------------------------
97
+ # DataLoader
98
+ # -----------------------------------------------------------------------------
99
+ _C.DATALOADER = CN()
100
+ # Number of data loading threads
101
+ _C.DATALOADER.NUM_WORKERS = 8
102
+ # Sampler for data loading
103
+ _C.DATALOADER.SAMPLER = 'softmax'
104
+ # Number of instance for one batch
105
+ _C.DATALOADER.NUM_INSTANCE = 16
106
+
107
+ # ---------------------------------------------------------------------------- #
108
+ # Solver
109
+ # ---------------------------------------------------------------------------- #
110
+ _C.SOLVER = CN()
111
+ # Name of optimizer
112
+ _C.SOLVER.OPTIMIZER_NAME = "Adam"
113
+ # Number of max epoches
114
+ _C.SOLVER.MAX_EPOCHS = 100
115
+ # Base learning rate
116
+ _C.SOLVER.BASE_LR = 3e-4
117
+ # Whether using larger learning rate for fc layer
118
+ _C.SOLVER.LARGE_FC_LR = False
119
+ # Factor of learning bias
120
+ _C.SOLVER.BIAS_LR_FACTOR = 1
121
+ # Random seed
122
+ _C.SOLVER.SEED = 1234
123
+ # Momentum
124
+ _C.SOLVER.MOMENTUM = 0.9
125
+ # Margin of triplet loss
126
+ _C.SOLVER.MARGIN = 0.3
127
+ # Learning rate of SGD to learn the centers of center loss
128
+ _C.SOLVER.CENTER_LR = 0.5
129
+ # Balanced weight of center loss
130
+ _C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005
131
+
132
+ # Settings of weight decay
133
+ _C.SOLVER.WEIGHT_DECAY = 0.0005
134
+ _C.SOLVER.WEIGHT_DECAY_BIAS = 0.0005
135
+
136
+ # decay rate of learning rate
137
+ _C.SOLVER.GAMMA = 0.1
138
+ # decay step of learning rate
139
+ _C.SOLVER.STEPS = (40, 70)
140
+ # warm up factor
141
+ _C.SOLVER.WARMUP_FACTOR = 0.01
142
+ # warm up epochs
143
+ _C.SOLVER.WARMUP_EPOCHS = 5
144
+ _C.SOLVER.WARMUP_LR_INIT = 0.01
145
+ _C.SOLVER.LR_MIN = 0.000016
146
+
147
+
148
+ _C.SOLVER.WARMUP_ITERS = 500
149
+ # method of warm up, option: 'constant','linear'
150
+ _C.SOLVER.WARMUP_METHOD = "linear"
151
+
152
+ _C.SOLVER.COSINE_MARGIN = 0.5
153
+ _C.SOLVER.COSINE_SCALE = 30
154
+
155
+ # epoch number of saving checkpoints
156
+ _C.SOLVER.CHECKPOINT_PERIOD = 10
157
+ # iteration of display training log
158
+ _C.SOLVER.LOG_PERIOD = 100
159
+ # epoch number of validation
160
+ _C.SOLVER.EVAL_PERIOD = 10
161
+ # Number of images per batch
162
+ # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU will
163
+ # contain 16 images per batch
164
+ _C.SOLVER.IMS_PER_BATCH = 64
165
+
166
+ # ---------------------------------------------------------------------------- #
167
+ # TEST
168
+ # ---------------------------------------------------------------------------- #
169
+
170
+ _C.TEST = CN()
171
+ # Number of images per batch during test
172
+ _C.TEST.IMS_PER_BATCH = 128
173
+ # If test with re-ranking, options: 'True','False'
174
+ _C.TEST.RE_RANKING = False
175
+ # Path to trained model
176
+ _C.TEST.WEIGHT = ""
177
+ # Which feature of BNNeck to be used for test, before or after BNNeck, options: 'before' or 'after'
178
+ _C.TEST.NECK_FEAT = 'after'
179
+ # Whether the feature is normalized before test; if yes, it is equivalent to cosine distance
180
+ _C.TEST.FEAT_NORM = 'yes'
181
+
182
+ # Name for saving the distmat after testing.
183
+ _C.TEST.DIST_MAT = "dist_mat.npy"
184
+ # Whether to calculate the eval score, options: 'True', 'False'
185
+ _C.TEST.EVAL = False
186
+ # ---------------------------------------------------------------------------- #
187
+ # Misc options
188
+ # ---------------------------------------------------------------------------- #
189
+ # Path to checkpoint and saved log of trained model
190
+ _C.OUTPUT_DIR = ""
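`defaults_base.py` mirrors the same yacs pattern with the baseline (non-CLIP-ReID) keys. Experiments usually override it from a YAML file rather than in code; the file name and values below are hypothetical:

```python
# Sketch: overriding the base defaults from a YAML file (file name is hypothetical).
from boxmot.appearance.backbones.clip.config.defaults_base import _C

# vit_base.yml could contain, for example:
#   MODEL:
#     NAME: 'ViT-B-16'
#     TRANSFORMER_TYPE: 'ViT-B-16'
#   INPUT:
#     SIZE_TRAIN: [256, 128]
cfg = _C.clone()
cfg.merge_from_file("vit_base.yml")   # unknown keys raise, known keys are type-checked
cfg.freeze()
```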
boxmot/appearance/backbones/clip/make_model.py ADDED
@@ -0,0 +1,161 @@
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
7
+
8
+ _tokenizer = _Tokenizer()
9
+
10
+
11
+ def weights_init_kaiming(m):
12
+ classname = m.__class__.__name__
13
+ if classname.find('Linear') != -1:
14
+ nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
15
+ nn.init.constant_(m.bias, 0.0)
16
+
17
+ elif classname.find('Conv') != -1:
18
+ nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
19
+ if m.bias is not None:
20
+ nn.init.constant_(m.bias, 0.0)
21
+ elif classname.find('BatchNorm') != -1:
22
+ if m.affine:
23
+ nn.init.constant_(m.weight, 1.0)
24
+ nn.init.constant_(m.bias, 0.0)
25
+
26
+
27
+ def weights_init_classifier(m):
28
+ classname = m.__class__.__name__
29
+ if classname.find('Linear') != -1:
30
+ nn.init.normal_(m.weight, std=0.001)
31
+ if m.bias is not None:
32
+ nn.init.constant_(m.bias, 0.0)
33
+
34
+
35
+ class build_transformer(nn.Module):
36
+ def __init__(self, num_classes, camera_num, view_num, cfg):
37
+ super(build_transformer, self).__init__()
38
+ self.model_name = cfg.MODEL.NAME
39
+ self.cos_layer = cfg.MODEL.COS_LAYER
40
+ self.neck = cfg.MODEL.NECK
41
+ self.neck_feat = cfg.TEST.NECK_FEAT
42
+ if self.model_name == 'ViT-B-16':
43
+ self.in_planes = 768
44
+ self.in_planes_proj = 512
45
+ elif self.model_name == 'RN50':
46
+ self.in_planes = 2048
47
+ self.in_planes_proj = 1024
48
+ self.num_classes = num_classes
49
+ self.camera_num = camera_num
50
+ self.view_num = view_num
51
+ self.sie_coe = cfg.MODEL.SIE_COE
52
+
53
+ self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
54
+ self.classifier.apply(weights_init_classifier)
55
+ self.classifier_proj = nn.Linear(self.in_planes_proj, self.num_classes, bias=False)
56
+ self.classifier_proj.apply(weights_init_classifier)
57
+
58
+ self.bottleneck = nn.BatchNorm1d(self.in_planes)
59
+ self.bottleneck.bias.requires_grad_(False)
60
+ self.bottleneck.apply(weights_init_kaiming)
61
+ self.bottleneck_proj = nn.BatchNorm1d(self.in_planes_proj)
62
+ self.bottleneck_proj.bias.requires_grad_(False)
63
+ self.bottleneck_proj.apply(weights_init_kaiming)
64
+
65
+ self.h_resolution = int((cfg.INPUT.SIZE_TRAIN[0] - 16) // cfg.MODEL.STRIDE_SIZE[0] + 1)
66
+ self.w_resolution = int((cfg.INPUT.SIZE_TRAIN[1] - 16) // cfg.MODEL.STRIDE_SIZE[1] + 1)
67
+ self.vision_stride_size = cfg.MODEL.STRIDE_SIZE[0]
68
+ clip_model = load_clip_to_cpu(self.model_name, self.h_resolution, self.w_resolution, self.vision_stride_size)
69
+
70
+ self.image_encoder = clip_model.visual
71
+
72
+ # if cfg.MODEL.SIE_CAMERA and cfg.MODEL.SIE_VIEW:
73
+ # self.cv_embed = nn.Parameter(torch.zeros(camera_num * view_num, self.in_planes))
74
+ # trunc_normal_(self.cv_embed, std=.02)
75
+ # print('camera number is : {}'.format(camera_num))
76
+ # elif cfg.MODEL.SIE_CAMERA:
77
+ # self.cv_embed = nn.Parameter(torch.zeros(camera_num, self.in_planes))
78
+ # trunc_normal_(self.cv_embed, std=.02)
79
+ # print('camera number is : {}'.format(camera_num))
80
+ # elif cfg.MODEL.SIE_VIEW:
81
+ # self.cv_embed = nn.Parameter(torch.zeros(view_num, self.in_planes))
82
+ # trunc_normal_(self.cv_embed, std=.02)
83
+ # print('camera number is : {}'.format(view_num))
84
+
85
+ def forward(self, x, label=None, cam_label=None, view_label=None):
86
+ if self.model_name == 'RN50':
87
+ image_features_last, image_features, image_features_proj = self.image_encoder(x) # B,512 B,128,512
88
+ img_feature_last = nn.functional.avg_pool2d(
89
+ image_features_last,
90
+ image_features_last.shape[2:4]).view(x.shape[0], -1)
91
+ img_feature = nn.functional.avg_pool2d(
92
+ image_features,
93
+ image_features.shape[2:4]).view(x.shape[0], -1)
94
+ img_feature_proj = image_features_proj[0]
95
+
96
+ elif self.model_name == 'ViT-B-16':
97
+ if cam_label is not None and view_label is not None:
98
+ cv_embed = self.sie_coe * self.cv_embed[cam_label * self.view_num + view_label]
99
+ elif cam_label is not None:
100
+ cv_embed = self.sie_coe * self.cv_embed[cam_label]
101
+ elif view_label is not None:
102
+ cv_embed = self.sie_coe * self.cv_embed[view_label]
103
+ else:
104
+ cv_embed = None
105
+ # B,512 B,128,512
106
+ image_features_last, image_features, image_features_proj = self.image_encoder(x, cv_embed)
107
+ img_feature_last = image_features_last[:, 0]
108
+ img_feature = image_features[:, 0]
109
+ img_feature_proj = image_features_proj[:, 0]
110
+
111
+ feat = self.bottleneck(img_feature)
112
+ feat_proj = self.bottleneck_proj(img_feature_proj)
113
+
114
+ if self.training:
115
+ cls_score = self.classifier(feat)
116
+ cls_score_proj = self.classifier_proj(feat_proj)
117
+ return [cls_score, cls_score_proj], [img_feature_last, img_feature, img_feature_proj]
118
+
119
+ else:
120
+ if self.neck_feat == 'after':
121
+ # print("Test with feature after BN")
122
+ return torch.cat([feat, feat_proj], dim=1)
123
+ else:
124
+ return torch.cat([img_feature, img_feature_proj], dim=1)
125
+
126
+ def load_param(self, trained_path):
127
+ param_dict = torch.load(trained_path, map_location=torch.device("cpu"))
128
+ for i in self.state_dict():
129
+ self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
130
+ # print('Loading pretrained model from {}'.format('/home/mikel.brostrom/yolo_tracking/clip_market1501.pt'))
131
+
132
+ def load_param_finetune(self, model_path):
133
+ param_dict = torch.load(model_path)
134
+ for i in param_dict:
135
+ self.state_dict()[i].copy_(param_dict[i])
136
+ # print('Loading pretrained model for finetuning from {}'.format(model_path))
137
+
138
+
139
+ def make_model(cfg, num_class, camera_num, view_num):
140
+ model = build_transformer(num_class, camera_num, view_num, cfg)
141
+ return model
142
+
143
+
144
+ from .clip import clip
145
+
146
+
147
+ def load_clip_to_cpu(backbone_name, h_resolution, w_resolution, vision_stride_size):
148
+ url = clip._MODELS[backbone_name]
149
+ model_path = clip._download(url)
150
+
151
+ try:
152
+ # loading JIT archive
153
+ model = torch.jit.load(model_path, map_location="cpu").eval()
154
+ state_dict = None
155
+
156
+ except RuntimeError:
157
+ state_dict = torch.load(model_path, map_location="cpu")
158
+
159
+ model = clip.build_model(state_dict or model.state_dict(), h_resolution, w_resolution, vision_stride_size)
160
+
161
+ return model
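A minimal inference sketch for the backbone above, assuming the bundled `clip` module can download the `ViT-B-16` checkpoint; the config values are illustrative and taken from `defaults_base.py`, which is shown in full above:

```python
# Sketch: build the CLIP image-encoder wrapper above and embed a dummy batch.
import torch

from boxmot.appearance.backbones.clip.config.defaults_base import _C
from boxmot.appearance.backbones.clip.make_model import make_model

cfg = _C.clone()
cfg.MODEL.NAME = 'ViT-B-16'
cfg.INPUT.SIZE_TRAIN = [256, 128]     # h_resolution / w_resolution are derived from this

model = make_model(cfg, num_class=751, camera_num=6, view_num=1).eval()
with torch.no_grad():
    feats = model(torch.randn(4, 3, 256, 128))
print(feats.shape)                    # (4, 768 + 512) with TEST.NECK_FEAT = 'after'
```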
boxmot/appearance/backbones/clip/make_model_clipreid.py ADDED
@@ -0,0 +1,247 @@
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
7
+
8
+ _tokenizer = _Tokenizer()
9
+
10
+
11
+ def weights_init_kaiming(m):
12
+ classname = m.__class__.__name__
13
+ if classname.find('Linear') != -1:
14
+ nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
15
+ nn.init.constant_(m.bias, 0.0)
16
+
17
+ elif classname.find('Conv') != -1:
18
+ nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
19
+ if m.bias is not None:
20
+ nn.init.constant_(m.bias, 0.0)
21
+ elif classname.find('BatchNorm') != -1:
22
+ if m.affine:
23
+ nn.init.constant_(m.weight, 1.0)
24
+ nn.init.constant_(m.bias, 0.0)
25
+
26
+
27
+ def weights_init_classifier(m):
28
+ classname = m.__class__.__name__
29
+ if classname.find('Linear') != -1:
30
+ nn.init.normal_(m.weight, std=0.001)
31
+ if m.bias is not None:
32
+ nn.init.constant_(m.bias, 0.0)
33
+
34
+
35
+ class TextEncoder(nn.Module):
36
+ def __init__(self, clip_model):
37
+ super().__init__()
38
+ self.transformer = clip_model.transformer
39
+ self.positional_embedding = clip_model.positional_embedding
40
+ self.ln_final = clip_model.ln_final
41
+ self.text_projection = clip_model.text_projection
42
+ self.dtype = clip_model.dtype
43
+
44
+ def forward(self, prompts, tokenized_prompts):
45
+ x = prompts + self.positional_embedding.type(self.dtype)
46
+ x = x.permute(1, 0, 2) # NLD -> LND
47
+ x = self.transformer(x)
48
+ x = x.permute(1, 0, 2) # LND -> NLD
49
+ x = self.ln_final(x).type(self.dtype)
50
+
51
+ # x.shape = [batch_size, n_ctx, transformer.width]
52
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
53
+ x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
54
+ return x
55
+
56
+
57
+ class build_transformer(nn.Module):
58
+ def __init__(self, num_classes, camera_num, view_num, cfg):
59
+ super(build_transformer, self).__init__()
60
+ self.model_name = cfg.MODEL.NAME
61
+ self.cos_layer = cfg.MODEL.COS_LAYER
62
+ self.neck = cfg.MODEL.NECK
63
+ self.neck_feat = cfg.TEST.NECK_FEAT
64
+ if self.model_name == 'ViT-B-16':
65
+ self.in_planes = 768
66
+ self.in_planes_proj = 512
67
+ elif self.model_name == 'RN50':
68
+ self.in_planes = 2048
69
+ self.in_planes_proj = 1024
70
+ self.num_classes = num_classes
71
+ self.camera_num = camera_num
72
+ self.view_num = view_num
73
+ self.sie_coe = cfg.MODEL.SIE_COE
74
+
75
+ self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
76
+ self.classifier.apply(weights_init_classifier)
77
+ self.classifier_proj = nn.Linear(self.in_planes_proj, self.num_classes, bias=False)
78
+ self.classifier_proj.apply(weights_init_classifier)
79
+
80
+ self.bottleneck = nn.BatchNorm1d(self.in_planes)
81
+ self.bottleneck.bias.requires_grad_(False)
82
+ self.bottleneck.apply(weights_init_kaiming)
83
+ self.bottleneck_proj = nn.BatchNorm1d(self.in_planes_proj)
84
+ self.bottleneck_proj.bias.requires_grad_(False)
85
+ self.bottleneck_proj.apply(weights_init_kaiming)
86
+
87
+ self.h_resolution = int((cfg.INPUT.SIZE_TRAIN[0] - 16) // cfg.MODEL.STRIDE_SIZE[0] + 1)
88
+ self.w_resolution = int((cfg.INPUT.SIZE_TRAIN[1] - 16) // cfg.MODEL.STRIDE_SIZE[1] + 1)
89
+ self.vision_stride_size = cfg.MODEL.STRIDE_SIZE[0]
90
+ clip_model = load_clip_to_cpu(self.model_name, self.h_resolution, self.w_resolution, self.vision_stride_size)
91
+
92
+ self.image_encoder = clip_model.visual
93
+
94
+ # if cfg.MODEL.SIE_CAMERA and cfg.MODEL.SIE_VIEW:
95
+ # self.cv_embed = nn.Parameter(torch.zeros(camera_num * view_num, self.in_planes))
96
+ # trunc_normal_(self.cv_embed, std=.02)
97
+ # print('camera number is : {}'.format(camera_num))
98
+ # elif cfg.MODEL.SIE_CAMERA:
99
+ # self.cv_embed = nn.Parameter(torch.zeros(camera_num, self.in_planes))
100
+ # trunc_normal_(self.cv_embed, std=.02)
101
+ # print('camera number is : {}'.format(camera_num))
102
+ # elif cfg.MODEL.SIE_VIEW:
103
+ # self.cv_embed = nn.Parameter(torch.zeros(view_num, self.in_planes))
104
+ # trunc_normal_(self.cv_embed, std=.02)
105
+ # print('camera number is : {}'.format(view_num))
106
+
107
+ dataset_name = cfg.DATASETS.NAMES
108
+ self.prompt_learner = PromptLearner(num_classes, dataset_name, clip_model.dtype, clip_model.token_embedding)
109
+ self.text_encoder = TextEncoder(clip_model)
110
+
111
+ def forward(self, x=None, label=None, get_image=False, get_text=False, cam_label=None, view_label=None):
112
+ if get_text is True:
113
+ prompts = self.prompt_learner(label)
114
+ text_features = self.text_encoder(prompts, self.prompt_learner.tokenized_prompts)
115
+ return text_features
116
+
117
+ if get_image is True:
118
+ image_features_last, image_features, image_features_proj = self.image_encoder(x)
119
+ if self.model_name == 'RN50':
120
+ return image_features_proj[0]
121
+ elif self.model_name == 'ViT-B-16':
122
+ return image_features_proj[:, 0]
123
+
124
+ if self.model_name == 'RN50':
125
+ image_features_last, image_features, image_features_proj = self.image_encoder(x)
126
+ img_feature_last = nn.functional.avg_pool2d(
127
+ image_features_last,
128
+ image_features_last.shape[2:4]).view(x.shape[0], -1)
129
+ img_feature = nn.functional.avg_pool2d(
130
+ image_features,
131
+ image_features.shape[2:4]).view(x.shape[0], -1)
132
+ img_feature_proj = image_features_proj[0]
133
+
134
+ elif self.model_name == 'ViT-B-16':
135
+ if cam_label is not None and view_label is not None:
136
+ cv_embed = self.sie_coe * self.cv_embed[cam_label * self.view_num + view_label]
137
+ elif cam_label is not None:
138
+ cv_embed = self.sie_coe * self.cv_embed[cam_label]
139
+ elif view_label is not None:
140
+ cv_embed = self.sie_coe * self.cv_embed[view_label]
141
+ else:
142
+ cv_embed = None
143
+ image_features_last, image_features, image_features_proj = self.image_encoder(x, cv_embed)
144
+ img_feature_last = image_features_last[:, 0]
145
+ img_feature = image_features[:, 0]
146
+ img_feature_proj = image_features_proj[:, 0]
147
+
148
+ feat = self.bottleneck(img_feature)
149
+ feat_proj = self.bottleneck_proj(img_feature_proj)
150
+
151
+ if self.training:
152
+ cls_score = self.classifier(feat)
153
+ cls_score_proj = self.classifier_proj(feat_proj)
154
+ return [cls_score, cls_score_proj], [img_feature_last, img_feature, img_feature_proj], img_feature_proj
155
+
156
+ else:
157
+ if self.neck_feat == 'after':
158
+ # print("Test with feature after BN")
159
+ return torch.cat([feat, feat_proj], dim=1)
160
+ else:
161
+ return torch.cat([img_feature, img_feature_proj], dim=1)
162
+
163
+ def load_param(self, trained_path):
164
+ param_dict = torch.load(trained_path)
165
+ for i in param_dict:
166
+ self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
167
+ print('Loaded pretrained model from {}'.format(trained_path))
168
+
169
+ def load_param_finetune(self, model_path):
170
+ param_dict = torch.load(model_path)
171
+ for i in param_dict:
172
+ self.state_dict()[i].copy_(param_dict[i])
173
+ print('Loading pretrained model for finetuning from {}'.format(model_path))
174
+
175
+
176
+ def make_model(cfg, num_class, camera_num, view_num):
177
+ model = build_transformer(num_class, camera_num, view_num, cfg)
178
+ return model
179
+
180
+
181
+ from .clip import clip
182
+
183
+
184
+ def load_clip_to_cpu(backbone_name, h_resolution, w_resolution, vision_stride_size):
185
+ url = clip._MODELS[backbone_name]
186
+ model_path = clip._download(url)
187
+
188
+ try:
189
+ # loading JIT archive
190
+ model = torch.jit.load(model_path, map_location="cpu").eval()
191
+ state_dict = None
192
+
193
+ except RuntimeError:
194
+ state_dict = torch.load(model_path, map_location="cpu")
195
+
196
+ model = clip.build_model(state_dict or model.state_dict(), h_resolution, w_resolution, vision_stride_size)
197
+
198
+ return model
199
+
200
+
201
+ class PromptLearner(nn.Module):
202
+ def __init__(self, num_class, dataset_name, dtype, token_embedding):
203
+ super().__init__()
204
+ if dataset_name == "VehicleID" or dataset_name == "veri":
205
+ ctx_init = "A photo of a X X X X vehicle."
206
+ else:
207
+ ctx_init = "A photo of a X X X X person."
208
+
209
+ ctx_dim = 512
210
+ # use given words to initialize context vectors
211
+ ctx_init = ctx_init.replace("_", " ")
212
+ n_ctx = 4
213
+
214
+ tokenized_prompts = clip.tokenize(ctx_init).cuda()
215
+ with torch.no_grad():
216
+ embedding = token_embedding(tokenized_prompts).type(dtype)
217
+ self.tokenized_prompts = tokenized_prompts # torch.Tensor
218
+
219
+ n_cls_ctx = 4
220
+ cls_vectors = torch.empty(num_class, n_cls_ctx, ctx_dim, dtype=dtype)
221
+ nn.init.normal_(cls_vectors, std=0.02)
222
+ self.cls_ctx = nn.Parameter(cls_vectors)
223
+
224
+ # These token vectors will be saved when in save_model(),
225
+ # but they should be ignored in load_model() as we want to use
226
+ # those computed using the current class names
227
+ self.register_buffer("token_prefix", embedding[:, :n_ctx + 1, :])
228
+ self.register_buffer("token_suffix", embedding[:, n_ctx + 1 + n_cls_ctx:, :])
229
+ self.num_class = num_class
230
+ self.n_cls_ctx = n_cls_ctx
231
+
232
+ def forward(self, label):
233
+ cls_ctx = self.cls_ctx[label]
234
+ b = label.shape[0]
235
+ prefix = self.token_prefix.expand(b, -1, -1)
236
+ suffix = self.token_suffix.expand(b, -1, -1)
237
+
238
+ prompts = torch.cat(
239
+ [
240
+ prefix, # (n_cls, 1, dim)
241
+ cls_ctx, # (n_cls, n_ctx, dim)
242
+ suffix, # (n_cls, *, dim)
243
+ ],
244
+ dim=1,
245
+ )
246
+
247
+ return prompts
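The CLIP-ReID variant above adds a `PromptLearner` and exposes `get_text` / `get_image` switches for the two-stage CLIP-ReID training. A sketch of how those flags are driven; CUDA is required because `PromptLearner` tokenizes with `.cuda()` in its constructor, and the sizes and class counts below are assumptions:

```python
# Sketch: querying text and image features from the CLIP-ReID model above.
import torch

from boxmot.appearance.backbones.clip.config.defaults_base import _C
from boxmot.appearance.backbones.clip.make_model_clipreid import make_model

cfg = _C.clone()
cfg.MODEL.NAME = 'ViT-B-16'
cfg.INPUT.SIZE_TRAIN = [256, 128]

model = make_model(cfg, num_class=751, camera_num=6, view_num=1).cuda().eval()

labels = torch.arange(4).cuda()                     # identity indices for the prompt learner
with torch.no_grad():
    text_feats = model(label=labels, get_text=True)               # per-identity text features
    img_feats = model(x=torch.randn(4, 3, 256, 128).cuda(),
                      get_image=True)                              # projected image features
```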
boxmot/appearance/backbones/hacnn.py ADDED
@@ -0,0 +1,406 @@
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ from __future__ import absolute_import, division
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ __all__ = ["HACNN"]
10
+
11
+
12
+ class ConvBlock(nn.Module):
13
+ """Basic convolutional block.
14
+
15
+ convolution + batch normalization + relu.
16
+
17
+ Args:
18
+ in_c (int): number of input channels.
19
+ out_c (int): number of output channels.
20
+ k (int or tuple): kernel size.
21
+ s (int or tuple): stride.
22
+ p (int or tuple): padding.
23
+ """
24
+
25
+ def __init__(self, in_c, out_c, k, s=1, p=0):
26
+ super(ConvBlock, self).__init__()
27
+ self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
28
+ self.bn = nn.BatchNorm2d(out_c)
29
+
30
+ def forward(self, x):
31
+ return F.relu(self.bn(self.conv(x)))
32
+
33
+
34
+ class InceptionA(nn.Module):
35
+ def __init__(self, in_channels, out_channels):
36
+ super(InceptionA, self).__init__()
37
+ mid_channels = out_channels // 4
38
+
39
+ self.stream1 = nn.Sequential(
40
+ ConvBlock(in_channels, mid_channels, 1),
41
+ ConvBlock(mid_channels, mid_channels, 3, p=1),
42
+ )
43
+ self.stream2 = nn.Sequential(
44
+ ConvBlock(in_channels, mid_channels, 1),
45
+ ConvBlock(mid_channels, mid_channels, 3, p=1),
46
+ )
47
+ self.stream3 = nn.Sequential(
48
+ ConvBlock(in_channels, mid_channels, 1),
49
+ ConvBlock(mid_channels, mid_channels, 3, p=1),
50
+ )
51
+ self.stream4 = nn.Sequential(
52
+ nn.AvgPool2d(3, stride=1, padding=1),
53
+ ConvBlock(in_channels, mid_channels, 1),
54
+ )
55
+
56
+ def forward(self, x):
57
+ s1 = self.stream1(x)
58
+ s2 = self.stream2(x)
59
+ s3 = self.stream3(x)
60
+ s4 = self.stream4(x)
61
+ y = torch.cat([s1, s2, s3, s4], dim=1)
62
+ return y
63
+
64
+
65
+ class InceptionB(nn.Module):
66
+ def __init__(self, in_channels, out_channels):
67
+ super(InceptionB, self).__init__()
68
+ mid_channels = out_channels // 4
69
+
70
+ self.stream1 = nn.Sequential(
71
+ ConvBlock(in_channels, mid_channels, 1),
72
+ ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
73
+ )
74
+ self.stream2 = nn.Sequential(
75
+ ConvBlock(in_channels, mid_channels, 1),
76
+ ConvBlock(mid_channels, mid_channels, 3, p=1),
77
+ ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
78
+ )
79
+ self.stream3 = nn.Sequential(
80
+ nn.MaxPool2d(3, stride=2, padding=1),
81
+ ConvBlock(in_channels, mid_channels * 2, 1),
82
+ )
83
+
84
+ def forward(self, x):
85
+ s1 = self.stream1(x)
86
+ s2 = self.stream2(x)
87
+ s3 = self.stream3(x)
88
+ y = torch.cat([s1, s2, s3], dim=1)
89
+ return y
90
+
91
+
92
+ class SpatialAttn(nn.Module):
93
+ """Spatial Attention (Sec. 3.1.I.1)"""
94
+
95
+ def __init__(self):
96
+ super(SpatialAttn, self).__init__()
97
+ self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
98
+ self.conv2 = ConvBlock(1, 1, 1)
99
+
100
+ def forward(self, x):
101
+ # global cross-channel averaging
102
+ x = x.mean(1, keepdim=True)
103
+ # 3-by-3 conv
104
+ x = self.conv1(x)
105
+ # bilinear resizing
106
+ x = F.upsample(
107
+ x, (x.size(2) * 2, x.size(3) * 2), mode="bilinear", align_corners=True
108
+ )
109
+ # scaling conv
110
+ x = self.conv2(x)
111
+ return x
112
+
113
+
114
+ class ChannelAttn(nn.Module):
115
+ """Channel Attention (Sec. 3.1.I.2)"""
116
+
117
+ def __init__(self, in_channels, reduction_rate=16):
118
+ super(ChannelAttn, self).__init__()
119
+ assert in_channels % reduction_rate == 0
120
+ self.conv1 = ConvBlock(in_channels, in_channels // reduction_rate, 1)
121
+ self.conv2 = ConvBlock(in_channels // reduction_rate, in_channels, 1)
122
+
123
+ def forward(self, x):
124
+ # squeeze operation (global average pooling)
125
+ x = F.avg_pool2d(x, x.size()[2:])
126
+ # excitation operation (2 conv layers)
127
+ x = self.conv1(x)
128
+ x = self.conv2(x)
129
+ return x
130
+
131
+
132
+ class SoftAttn(nn.Module):
133
+ """Soft Attention (Sec. 3.1.I)
134
+
135
+ Aim: Spatial Attention + Channel Attention
136
+
137
+ Output: attention maps with shape identical to input.
138
+ """
139
+
140
+ def __init__(self, in_channels):
141
+ super(SoftAttn, self).__init__()
142
+ self.spatial_attn = SpatialAttn()
143
+ self.channel_attn = ChannelAttn(in_channels)
144
+ self.conv = ConvBlock(in_channels, in_channels, 1)
145
+
146
+ def forward(self, x):
147
+ y_spatial = self.spatial_attn(x)
148
+ y_channel = self.channel_attn(x)
149
+ y = y_spatial * y_channel
150
+ y = torch.sigmoid(self.conv(y))
151
+ return y
152
+
153
+
154
+ class HardAttn(nn.Module):
155
+ """Hard Attention (Sec. 3.1.II)"""
156
+
157
+ def __init__(self, in_channels):
158
+ super(HardAttn, self).__init__()
159
+ self.fc = nn.Linear(in_channels, 4 * 2)
160
+ self.init_params()
161
+
162
+ def init_params(self):
163
+ self.fc.weight.data.zero_()
164
+ self.fc.bias.data.copy_(
165
+ torch.tensor([0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float)
166
+ )
167
+
168
+ def forward(self, x):
169
+ # squeeze operation (global average pooling)
170
+ x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
171
+ # predict transformation parameters
172
+ theta = torch.tanh(self.fc(x))
173
+ theta = theta.view(-1, 4, 2)
174
+ return theta
175
+
176
+
177
+ class HarmAttn(nn.Module):
178
+ """Harmonious Attention (Sec. 3.1)"""
179
+
180
+ def __init__(self, in_channels):
181
+ super(HarmAttn, self).__init__()
182
+ self.soft_attn = SoftAttn(in_channels)
183
+ self.hard_attn = HardAttn(in_channels)
184
+
185
+ def forward(self, x):
186
+ y_soft_attn = self.soft_attn(x)
187
+ theta = self.hard_attn(x)
188
+ return y_soft_attn, theta
189
+
190
+
191
+ class HACNN(nn.Module):
192
+ """Harmonious Attention Convolutional Neural Network.
193
+
194
+ Reference:
195
+ Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.
196
+
197
+ Public keys:
198
+ - ``hacnn``: HACNN.
199
+ """
200
+
201
+ # Args:
202
+ # num_classes (int): number of classes to predict
203
+ # nchannels (list): number of channels AFTER concatenation
204
+ # feat_dim (int): feature dimension for a single stream
205
+ # learn_region (bool): whether to learn region features (i.e. local branch)
206
+
207
+ def __init__(
208
+ self,
209
+ num_classes,
210
+ loss="softmax",
211
+ nchannels=[128, 256, 384],
212
+ feat_dim=512,
213
+ learn_region=True,
214
+ use_gpu=True,
215
+ **kwargs
216
+ ):
217
+ super(HACNN, self).__init__()
218
+ self.loss = loss
219
+ self.learn_region = learn_region
220
+ self.use_gpu = use_gpu
221
+
222
+ self.conv = ConvBlock(3, 32, 3, s=2, p=1)
223
+
224
+ # Construct Inception + HarmAttn blocks
225
+ # ============== Block 1 ==============
226
+ self.inception1 = nn.Sequential(
227
+ InceptionA(32, nchannels[0]),
228
+ InceptionB(nchannels[0], nchannels[0]),
229
+ )
230
+ self.ha1 = HarmAttn(nchannels[0])
231
+
232
+ # ============== Block 2 ==============
233
+ self.inception2 = nn.Sequential(
234
+ InceptionA(nchannels[0], nchannels[1]),
235
+ InceptionB(nchannels[1], nchannels[1]),
236
+ )
237
+ self.ha2 = HarmAttn(nchannels[1])
238
+
239
+ # ============== Block 3 ==============
240
+ self.inception3 = nn.Sequential(
241
+ InceptionA(nchannels[1], nchannels[2]),
242
+ InceptionB(nchannels[2], nchannels[2]),
243
+ )
244
+ self.ha3 = HarmAttn(nchannels[2])
245
+
246
+ self.fc_global = nn.Sequential(
247
+ nn.Linear(nchannels[2], feat_dim),
248
+ nn.BatchNorm1d(feat_dim),
249
+ nn.ReLU(),
250
+ )
251
+ self.classifier_global = nn.Linear(feat_dim, num_classes)
252
+
253
+ if self.learn_region:
254
+ self.init_scale_factors()
255
+ self.local_conv1 = InceptionB(32, nchannels[0])
256
+ self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
257
+ self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
258
+ self.fc_local = nn.Sequential(
259
+ nn.Linear(nchannels[2] * 4, feat_dim),
260
+ nn.BatchNorm1d(feat_dim),
261
+ nn.ReLU(),
262
+ )
263
+ self.classifier_local = nn.Linear(feat_dim, num_classes)
264
+ self.feat_dim = feat_dim * 2
265
+ else:
266
+ self.feat_dim = feat_dim
267
+
268
+ def init_scale_factors(self):
269
+ # initialize scale factors (s_w, s_h) for four regions
270
+ self.scale_factors = []
271
+ self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
272
+ self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
273
+ self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
274
+ self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
275
+
276
+ def stn(self, x, theta):
277
+ """Performs spatial transform
278
+
279
+ x: (batch, channel, height, width)
280
+ theta: (batch, 2, 3)
281
+ """
282
+ grid = F.affine_grid(theta, x.size())
283
+ x = F.grid_sample(x, grid)
284
+ return x
285
+
286
+ def transform_theta(self, theta_i, region_idx):
287
+ """Transforms theta to include (s_w, s_h), resulting in (batch, 2, 3)"""
288
+ scale_factors = self.scale_factors[region_idx]
289
+ theta = torch.zeros(theta_i.size(0), 2, 3)
290
+ theta[:, :, :2] = scale_factors
291
+ theta[:, :, -1] = theta_i
292
+ if self.use_gpu:
293
+ theta = theta.cuda()
294
+ return theta
295
+
296
+ def forward(self, x):
297
+ assert (
298
+ x.size(2) == 160 and x.size(3) == 64
299
+ ), "Input size does not match, expected (160, 64) but got ({}, {})".format(
300
+ x.size(2), x.size(3)
301
+ )
302
+ x = self.conv(x)
303
+
304
+ # ============== Block 1 ==============
305
+ # global branch
306
+ x1 = self.inception1(x)
307
+ x1_attn, x1_theta = self.ha1(x1)
308
+ x1_out = x1 * x1_attn
309
+ # local branch
310
+ if self.learn_region:
311
+ x1_local_list = []
312
+ for region_idx in range(4):
313
+ x1_theta_i = x1_theta[:, region_idx, :]
314
+ x1_theta_i = self.transform_theta(x1_theta_i, region_idx)
315
+ x1_trans_i = self.stn(x, x1_theta_i)
316
+ x1_trans_i = F.upsample(
317
+ x1_trans_i, (24, 28), mode="bilinear", align_corners=True
318
+ )
319
+ x1_local_i = self.local_conv1(x1_trans_i)
320
+ x1_local_list.append(x1_local_i)
321
+
322
+ # ============== Block 2 ==============
323
+ # Block 2
324
+ # global branch
325
+ x2 = self.inception2(x1_out)
326
+ x2_attn, x2_theta = self.ha2(x2)
327
+ x2_out = x2 * x2_attn
328
+ # local branch
329
+ if self.learn_region:
330
+ x2_local_list = []
331
+ for region_idx in range(4):
332
+ x2_theta_i = x2_theta[:, region_idx, :]
333
+ x2_theta_i = self.transform_theta(x2_theta_i, region_idx)
334
+ x2_trans_i = self.stn(x1_out, x2_theta_i)
335
+ x2_trans_i = F.upsample(
336
+ x2_trans_i, (12, 14), mode="bilinear", align_corners=True
337
+ )
338
+ x2_local_i = x2_trans_i + x1_local_list[region_idx]
339
+ x2_local_i = self.local_conv2(x2_local_i)
340
+ x2_local_list.append(x2_local_i)
341
+
342
+ # ============== Block 3 ==============
343
+ # Block 3
344
+ # global branch
345
+ x3 = self.inception3(x2_out)
346
+ x3_attn, x3_theta = self.ha3(x3)
347
+ x3_out = x3 * x3_attn
348
+ # local branch
349
+ if self.learn_region:
350
+ x3_local_list = []
351
+ for region_idx in range(4):
352
+ x3_theta_i = x3_theta[:, region_idx, :]
353
+ x3_theta_i = self.transform_theta(x3_theta_i, region_idx)
354
+ x3_trans_i = self.stn(x2_out, x3_theta_i)
355
+ x3_trans_i = F.upsample(
356
+ x3_trans_i, (6, 7), mode="bilinear", align_corners=True
357
+ )
358
+ x3_local_i = x3_trans_i + x2_local_list[region_idx]
359
+ x3_local_i = self.local_conv3(x3_local_i)
360
+ x3_local_list.append(x3_local_i)
361
+
362
+ # ============== Feature generation ==============
363
+ # global branch
364
+ x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(
365
+ x3_out.size(0), x3_out.size(1)
366
+ )
367
+ x_global = self.fc_global(x_global)
368
+ # local branch
369
+ if self.learn_region:
370
+ x_local_list = []
371
+ for region_idx in range(4):
372
+ x_local_i = x3_local_list[region_idx]
373
+ x_local_i = F.avg_pool2d(x_local_i, x_local_i.size()[2:]).view(
374
+ x_local_i.size(0), -1
375
+ )
376
+ x_local_list.append(x_local_i)
377
+ x_local = torch.cat(x_local_list, 1)
378
+ x_local = self.fc_local(x_local)
379
+
380
+ if not self.training:
381
+ # l2 normalization before concatenation
382
+ if self.learn_region:
383
+ x_global = x_global / x_global.norm(p=2, dim=1, keepdim=True)
384
+ x_local = x_local / x_local.norm(p=2, dim=1, keepdim=True)
385
+ return torch.cat([x_global, x_local], 1)
386
+ else:
387
+ return x_global
388
+
389
+ prelogits_global = self.classifier_global(x_global)
390
+ if self.learn_region:
391
+ prelogits_local = self.classifier_local(x_local)
392
+
393
+ if self.loss == "softmax":
394
+ if self.learn_region:
395
+ return (prelogits_global, prelogits_local)
396
+ else:
397
+ return prelogits_global
398
+
399
+ elif self.loss == "triplet":
400
+ if self.learn_region:
401
+ return (prelogits_global, prelogits_local), (x_global, x_local)
402
+ else:
403
+ return prelogits_global, x_global
404
+
405
+ else:
406
+ raise KeyError("Unsupported loss: {}".format(self.loss))
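HACNN is the only backbone in this commit with a hard-coded input resolution (the assert above enforces 160x64). A quick CPU shape check; the class count is illustrative:

```python
# Sketch: run the HACNN backbone above on a dummy 160x64 batch.
import torch

from boxmot.appearance.backbones.hacnn import HACNN

model = HACNN(num_classes=751, loss="softmax", use_gpu=False).eval()
with torch.no_grad():
    feats = model(torch.randn(8, 3, 160, 64))      # inference path returns embeddings
print(feats.shape)                                 # (8, 1024): global + local, 512-d each
```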
boxmot/appearance/backbones/lmbn/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
boxmot/appearance/backbones/lmbn/attention.py ADDED
@@ -0,0 +1,281 @@
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ import math
4
+ import random
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import Conv2d, Module, Parameter, ReLU, Sigmoid, Softmax
9
+ from torch.nn import functional as F
10
+
11
+ torch_ver = torch.__version__[:3]
12
+
13
+ __all__ = [
14
+ "BatchDrop",
15
+ "BatchFeatureErase_Top",
16
+ "BatchRandomErasing",
17
+ "PAM_Module",
18
+ "CAM_Module",
19
+ "Dual_Module",
20
+ "SE_Module",
21
+ ]
22
+
23
+
24
+ class BatchRandomErasing(nn.Module):
25
+ def __init__(
26
+ self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]
27
+ ):
28
+ super(BatchRandomErasing, self).__init__()
29
+
30
+ self.probability = probability
31
+ self.mean = mean
32
+ self.sl = sl
33
+ self.sh = sh
34
+ self.r1 = r1
35
+
36
+ def forward(self, img):
37
+ if self.training:
38
+ if random.uniform(0, 1) > self.probability:
39
+ return img
40
+
41
+ for attempt in range(100):
42
+ area = img.size()[2] * img.size()[3]
43
+
44
+ target_area = random.uniform(self.sl, self.sh) * area
45
+ aspect_ratio = random.uniform(self.r1, 1 / self.r1)
46
+
47
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
48
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
49
+
50
+ if w < img.size()[3] and h < img.size()[2]:
51
+ x1 = random.randint(0, img.size()[2] - h)
52
+ y1 = random.randint(0, img.size()[3] - w)
53
+ if img.size()[1] == 3:
54
+ img[:, 0, x1: x1 + h, y1: y1 + w] = self.mean[0]
55
+ img[:, 1, x1: x1 + h, y1: y1 + w] = self.mean[1]
56
+ img[:, 2, x1: x1 + h, y1: y1 + w] = self.mean[2]
57
+ else:
58
+ img[:, 0, x1: x1 + h, y1: y1 + w] = self.mean[0]
59
+ return img
60
+
61
+ return img
62
+
63
+
64
+ class BatchDrop(nn.Module):
65
+ """
66
+ Ref: Batch DropBlock Network for Person Re-identification and Beyond
67
+ https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py
68
+ Created by: daizuozhuo
69
+ """
70
+
71
+ def __init__(self, h_ratio, w_ratio):
72
+ super(BatchDrop, self).__init__()
73
+ self.h_ratio = h_ratio
74
+ self.w_ratio = w_ratio
75
+
76
+ def forward(self, x):
77
+ if self.training:
78
+ h, w = x.size()[-2:]
79
+ rh = round(self.h_ratio * h)
80
+ rw = round(self.w_ratio * w)
81
+ sx = random.randint(0, h - rh)
82
+ sy = random.randint(0, w - rw)
83
+ mask = x.new_ones(x.size())
84
+ mask[:, :, sx: sx + rh, sy: sy + rw] = 0
85
+ x = x * mask
86
+ return x
87
+
88
+
89
+ class BatchDropTop(nn.Module):
90
+ """
91
+ Ref: Top-DB-Net: Top DropBlock for Activation Enhancement in Person Re-Identification
92
+ https://github.com/RQuispeC/top-dropblock/blob/master/torchreid/models/bdnet.py
93
+ Created by: RQuispeC
94
+
95
+ """
96
+
97
+ def __init__(self, h_ratio):
98
+ super(BatchDropTop, self).__init__()
99
+ self.h_ratio = h_ratio
100
+
101
+ def forward(self, x, visdrop=False):
102
+ if self.training or visdrop:
103
+ b, c, h, w = x.size()
104
+ rh = round(self.h_ratio * h)
105
+ act = (x**2).sum(1)
106
+ act = act.view(b, h * w)
107
+ act = F.normalize(act, p=2, dim=1)
108
+ act = act.view(b, h, w)
109
+ max_act, _ = act.max(2)
110
+ ind = torch.argsort(max_act, 1)
111
+ ind = ind[:, -rh:]
112
+ mask = []
113
+ for i in range(b):
114
+ rmask = torch.ones(h)
115
+ rmask[ind[i]] = 0
116
+ mask.append(rmask.unsqueeze(0))
117
+ mask = torch.cat(mask)
118
+ mask = torch.repeat_interleave(mask, w, 1).view(b, h, w)
119
+ mask = torch.repeat_interleave(mask, c, 0).view(b, c, h, w)
120
+ if x.is_cuda:
121
+ mask = mask.cuda()
122
+ if visdrop:
123
+ return mask
124
+ x = x * mask
125
+ return x
126
+
127
+
128
+ class BatchFeatureErase_Top(nn.Module):
129
+ """
130
+ Ref: Top-DB-Net: Top DropBlock for Activation Enhancement in Person Re-Identification
131
+ https://github.com/RQuispeC/top-dropblock/blob/master/torchreid/models/bdnet.py
132
+ Created by: RQuispeC
133
+
134
+ """
135
+
136
+ def __init__(
137
+ self,
138
+ channels,
139
+ bottleneck_type,
140
+ h_ratio=0.33,
141
+ w_ratio=1.0,
142
+ double_bottleneck=False,
143
+ ):
144
+ super(BatchFeatureErase_Top, self).__init__()
145
+
146
+ self.drop_batch_bottleneck = bottleneck_type(channels, 512)
147
+
148
+ self.drop_batch_drop_basic = BatchDrop(h_ratio, w_ratio)
149
+ self.drop_batch_drop_top = BatchDropTop(h_ratio)
150
+
151
+ def forward(self, x, drop_top=True, bottleneck_features=True, visdrop=False):
152
+ features = self.drop_batch_bottleneck(x)
153
+
154
+ if drop_top:
155
+ x = self.drop_batch_drop_top(features, visdrop=visdrop)
156
+ else:
157
+ x = self.drop_batch_drop_basic(features, visdrop=visdrop)
158
+ if visdrop:
159
+ return x # x is dropmask
160
+ if bottleneck_features:
161
+ return x, features
162
+ else:
163
+ return x
164
+
165
+
166
+ class SE_Module(Module):
167
+ def __init__(self, channels, reduction=4):
168
+ super(SE_Module, self).__init__()
169
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0)
170
+ self.relu = ReLU(inplace=True)
171
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0)
172
+ self.sigmoid = Sigmoid()
173
+
174
+ def forward(self, x):
175
+ module_input = x
176
+ x = self.fc1(x)
177
+ x = self.relu(x)
178
+ x = self.fc2(x)
179
+ x = self.sigmoid(x)
180
+ return module_input * x
181
+
182
+
183
+ class PAM_Module(Module):
184
+ """Position attention module"""
185
+
186
+ # Ref from SAGAN
187
+
188
+ def __init__(self, in_dim):
189
+ super(PAM_Module, self).__init__()
190
+ self.chanel_in = in_dim
191
+
192
+ self.query_conv = Conv2d(
193
+ in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1
194
+ )
195
+ self.key_conv = Conv2d(
196
+ in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1
197
+ )
198
+ self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
199
+ self.gamma = Parameter(torch.zeros(1))
200
+
201
+ self.softmax = Softmax(dim=-1)
202
+
203
+ def forward(self, x):
204
+ """
205
+ inputs :
206
+ x : input feature maps( B X C X H X W)
207
+ returns :
208
+ out : attention value + input feature
209
+ attention: B X (HxW) X (HxW)
210
+ """
211
+ m_batchsize, C, height, width = x.size()
212
+ proj_query = (
213
+ self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
214
+ )
215
+ proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
216
+ energy = torch.bmm(proj_query, proj_key)
217
+ attention = self.softmax(energy)
218
+ proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)
219
+
220
+ out = torch.bmm(proj_value, attention.permute(0, 2, 1))
221
+ out = out.view(m_batchsize, C, height, width)
222
+
223
+ out = self.gamma * out + x
224
+ return out
225
+
226
+
227
+ class CAM_Module(Module):
228
+ """Channel attention module"""
229
+
230
+ def __init__(self, in_dim):
231
+ super(CAM_Module, self).__init__()
232
+ self.chanel_in = in_dim
233
+
234
+ self.gamma = Parameter(torch.zeros(1))
235
+ self.softmax = Softmax(dim=-1)
236
+
237
+ def forward(self, x):
238
+ """
239
+ inputs :
240
+ x : input feature maps( B X C X H X W)
241
+ returns :
242
+ out : attention value + input feature
243
+ attention: B X C X C
244
+ """
245
+ m_batchsize, C, height, width = x.size()
246
+ proj_query = x.view(m_batchsize, C, -1)
247
+ proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
248
+ # proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1).contiguous()
249
+ energy = torch.bmm(proj_query, proj_key)
250
+ energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
251
+ attention = self.softmax(energy_new)
252
+ proj_value = x.view(m_batchsize, C, -1)
253
+
254
+ out = torch.bmm(attention, proj_value)
255
+ out = out.view(m_batchsize, C, height, width)
256
+
257
+ out = self.gamma * out + x
258
+ return out
259
+
260
+
261
+ class Dual_Module(Module):
262
+ """
263
+ # Created by: CASIA IVA
264
+ # Email: jliu@nlpr.ia.ac.cn
265
+ # Copyright (c) 2018
266
+
267
+ # Reference: Dual Attention Network for Scene Segmentation
268
+ # https://arxiv.org/pdf/1809.02983.pdf
269
+ # https://github.com/junfu1115/DANet/blob/master/encoding/nn/attention.py
270
+ """
271
+
272
+ def __init__(self, in_dim):
273
+ super(Dual_Module).__init__()
274
+ self.indim = in_dim
275
+ self.pam = PAM_Module(in_dim)
276
+ self.cam = CAM_Module(in_dim)
277
+
278
+ def forward(self, x):
279
+ out1 = self.pam(x)
280
+ out2 = self.cam(x)
281
+ return out1 + out2
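Both attention modules above are shape-preserving, so they can be inserted into a backbone without changing downstream layers. A small sanity check with arbitrary tensor sizes:

```python
# Sketch: PAM_Module and CAM_Module keep the (B, C, H, W) shape of their input.
import torch

from boxmot.appearance.backbones.lmbn.attention import CAM_Module, PAM_Module

x = torch.randn(2, 64, 24, 8)          # dummy feature map
assert PAM_Module(64)(x).shape == x.shape
assert CAM_Module(64)(x).shape == x.shape
```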
boxmot/appearance/backbones/lmbn/bnneck.py ADDED
@@ -0,0 +1,166 @@
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ from torch import nn
4
+
5
+
6
+ class BNNeck(nn.Module):
7
+ def __init__(self, input_dim, class_num, return_f=False):
8
+ super(BNNeck, self).__init__()
9
+ self.return_f = return_f
10
+ self.bn = nn.BatchNorm1d(input_dim)
11
+ self.bn.bias.requires_grad_(False)
12
+ self.classifier = nn.Linear(input_dim, class_num, bias=False)
13
+ self.bn.apply(self.weights_init_kaiming)
14
+ self.classifier.apply(self.weights_init_classifier)
15
+
16
+ def forward(self, x):
17
+ before_neck = x.view(x.size(0), x.size(1))
18
+ after_neck = self.bn(before_neck)
19
+
20
+ if self.return_f:
21
+ score = self.classifier(after_neck)
22
+ return after_neck, score, before_neck
23
+ else:
24
+ x = self.classifier(x)
25
+ return x
26
+
27
+ def weights_init_kaiming(self, m):
28
+ classname = m.__class__.__name__
29
+ if classname.find("Linear") != -1:
30
+ nn.init.kaiming_normal_(m.weight, a=0, mode="fan_out")
31
+ nn.init.constant_(m.bias, 0.0)
32
+ elif classname.find("Conv") != -1:
33
+ nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
34
+ if m.bias is not None:
35
+ nn.init.constant_(m.bias, 0.0)
36
+ elif classname.find("BatchNorm") != -1:
37
+ if m.affine:
38
+ nn.init.constant_(m.weight, 1.0)
39
+ nn.init.constant_(m.bias, 0.0)
40
+
41
+ def weights_init_classifier(self, m):
42
+ classname = m.__class__.__name__
43
+ if classname.find("Linear") != -1:
44
+ nn.init.normal_(m.weight, std=0.001)
45
+ if m.bias is not None:
46
+ nn.init.constant_(m.bias, 0.0)
47
+
48
+
49
+ class BNNeck3(nn.Module):
50
+ def __init__(self, input_dim, class_num, feat_dim, return_f=False):
51
+ super(BNNeck3, self).__init__()
52
+ self.return_f = return_f
53
+ # self.reduction = nn.Linear(input_dim, feat_dim)
54
+ # self.bn = nn.BatchNorm1d(feat_dim)
55
+
56
+ self.reduction = nn.Conv2d(input_dim, feat_dim, 1, bias=False)
57
+ self.bn = nn.BatchNorm1d(feat_dim)
58
+
59
+ self.bn.bias.requires_grad_(False)
60
+ self.classifier = nn.Linear(feat_dim, class_num, bias=False)
61
+ self.bn.apply(self.weights_init_kaiming)
62
+ self.classifier.apply(self.weights_init_classifier)
63
+
64
+ def forward(self, x):
65
+ x = self.reduction(x)
66
+ # before_neck = x.squeeze(dim=3).squeeze(dim=2)
67
+ # after_neck = self.bn(x).squeeze(dim=3).squeeze(dim=2)
68
+ before_neck = x.view(x.size(0), x.size(1))
69
+ after_neck = self.bn(before_neck)
70
+ if self.return_f:
71
+ score = self.classifier(after_neck)
72
+ return after_neck, score, before_neck
73
+ else:
74
+ x = self.classifier(x)
75
+ return x
76
+
77
+ def weights_init_kaiming(self, m):
78
+ classname = m.__class__.__name__
79
+ if classname.find("Linear") != -1:
80
+ nn.init.kaiming_normal_(m.weight, a=0, mode="fan_out")
81
+ nn.init.constant_(m.bias, 0.0)
82
+ elif classname.find("Conv") != -1:
83
+ nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
84
+ if m.bias is not None:
85
+ nn.init.constant_(m.bias, 0.0)
86
+ elif classname.find("BatchNorm") != -1:
87
+ if m.affine:
88
+ nn.init.constant_(m.weight, 1.0)
89
+ nn.init.constant_(m.bias, 0.0)
90
+
91
+ def weights_init_classifier(self, m):
92
+ classname = m.__class__.__name__
93
+ if classname.find("Linear") != -1:
94
+ nn.init.normal_(m.weight, std=0.001)
95
+ if m.bias is not None:
96
+ nn.init.constant_(m.bias, 0.0)
97
+
98
+
99
+ # Defines the new fc layer and classification layer
100
+ # |--Linear--|--bn--|--relu--|--Linear--|
101
+
102
+
103
+ class ClassBlock(nn.Module):
104
+ def __init__(
105
+ self,
106
+ input_dim,
107
+ class_num,
108
+ droprate=0,
109
+ relu=False,
110
+ bnorm=True,
111
+ num_bottleneck=512,
112
+ linear=True,
113
+ return_f=False,
114
+ ):
115
+ super(ClassBlock, self).__init__()
116
+ self.return_f = return_f
117
+ add_block = []
118
+ if linear:
119
+ add_block += [nn.Linear(input_dim, num_bottleneck)]
120
+ else:
121
+ num_bottleneck = input_dim
122
+ if bnorm:
123
+ add_block += [nn.BatchNorm1d(num_bottleneck)]
124
+ if relu:
125
+ add_block += [nn.LeakyReLU(0.1)]
126
+ if droprate > 0:
127
+ add_block += [nn.Dropout(p=droprate)]
128
+ add_block = nn.Sequential(*add_block)
129
+ add_block.apply(self.weights_init_kaiming)
130
+
131
+ classifier = []
132
+ classifier += [nn.Linear(num_bottleneck, class_num)]
133
+ classifier = nn.Sequential(*classifier)
134
+ classifier.apply(self.weights_init_classifier)
135
+
136
+ self.add_block = add_block
137
+ self.classifier = classifier
138
+
139
+ def forward(self, x):
140
+ x = self.add_block(x.squeeze(3).squeeze(2))
141
+ if self.return_f:
142
+ f = x
143
+ x = self.classifier(x)
144
+ return f, x, f
145
+ else:
146
+ x = self.classifier(x)
147
+ return x
148
+
149
+ def weights_init_kaiming(self, m):
150
+ classname = m.__class__.__name__
151
+ # print(classname)
152
+ if classname.find("Conv") != -1:
153
+ # For old pytorch, you may use kaiming_normal.
154
+ nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
155
+ elif classname.find("Linear") != -1:
156
+ nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out")
157
+ nn.init.constant_(m.bias.data, 0.0)
158
+ elif classname.find("BatchNorm1d") != -1:
159
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
160
+ nn.init.constant_(m.bias.data, 0.0)
161
+
162
+ def weights_init_classifier(self, m):
163
+ classname = m.__class__.__name__
164
+ if classname.find("Linear") != -1:
165
+ nn.init.normal_(m.weight.data, std=0.001)
166
+ nn.init.constant_(m.bias.data, 0.0)
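`BNNeck3` expects an already-pooled `(B, C, 1, 1)` map: it reduces it with a 1x1 conv, applies the batch-norm neck, and with `return_f=True` returns the post-BN feature, the classification score and the pre-BN feature, which is how LMBN consumes it below. A short sketch with illustrative dimensions:

```python
# Sketch: feeding a pooled feature map through BNNeck3 as LMBN does.
import torch

from boxmot.appearance.backbones.lmbn.bnneck import BNNeck3

neck = BNNeck3(input_dim=512, class_num=751, feat_dim=512, return_f=True).eval()
pooled = torch.randn(4, 512, 1, 1)                 # e.g. output of AdaptiveAvgPool2d((1, 1))
after_neck, score, before_neck = neck(pooled)
print(after_neck.shape, score.shape)               # (4, 512) and (4, 751)
```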
boxmot/appearance/backbones/lmbn/lmbn_n.py ADDED
@@ -0,0 +1,185 @@
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ import copy
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from boxmot.appearance.backbones.lmbn.attention import BatchFeatureErase_Top
9
+ from boxmot.appearance.backbones.lmbn.bnneck import BNNeck, BNNeck3
10
+ from boxmot.appearance.backbones.osnet import OSBlock, osnet_x1_0
11
+
12
+
13
+ class LMBN_n(nn.Module):
14
+ def __init__(self, num_classes, loss, pretrained, use_gpu):
15
+ super(LMBN_n, self).__init__()
16
+
17
+ self.n_ch = 2
18
+ self.chs = 512 // self.n_ch
19
+ self.training = False
20
+
21
+ osnet = osnet_x1_0(pretrained=True)
22
+
23
+ self.backone = nn.Sequential(
24
+ osnet.conv1, osnet.maxpool, osnet.conv2, osnet.conv3[0]
25
+ )
26
+
27
+ conv3 = osnet.conv3[1:]
28
+
29
+ self.global_branch = nn.Sequential(
30
+ copy.deepcopy(conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)
31
+ )
32
+
33
+ self.partial_branch = nn.Sequential(
34
+ copy.deepcopy(conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)
35
+ )
36
+
37
+ self.channel_branch = nn.Sequential(
38
+ copy.deepcopy(conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)
39
+ )
40
+
41
+ self.global_pooling = nn.AdaptiveMaxPool2d((1, 1))
42
+ self.partial_pooling = nn.AdaptiveAvgPool2d((2, 1))
43
+ self.channel_pooling = nn.AdaptiveAvgPool2d((1, 1))
44
+
45
+ reduction = BNNeck3(512, num_classes, 512, return_f=True)
46
+
47
+ self.reduction_0 = copy.deepcopy(reduction)
48
+ self.reduction_1 = copy.deepcopy(reduction)
49
+ self.reduction_2 = copy.deepcopy(reduction)
50
+ self.reduction_3 = copy.deepcopy(reduction)
51
+ self.reduction_4 = copy.deepcopy(reduction)
52
+
53
+ self.shared = nn.Sequential(
54
+ nn.Conv2d(self.chs, 512, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True)
55
+ )
56
+ self.weights_init_kaiming(self.shared)
57
+
58
+ self.reduction_ch_0 = BNNeck(512, num_classes, return_f=True)
59
+ self.reduction_ch_1 = BNNeck(512, num_classes, return_f=True)
60
+
61
+ # if args.drop_block:
62
+ # print('Using batch random erasing block.')
63
+ # self.batch_drop_block = BatchRandomErasing()
64
+ # print('Using batch drop block.')
65
+ # self.batch_drop_block = BatchDrop(
66
+ # h_ratio=args.h_ratio, w_ratio=args.w_ratio)
67
+ self.batch_drop_block = BatchFeatureErase_Top(512, OSBlock)
68
+
69
+ self.activation_map = False
70
+
71
+ def forward(self, x):
72
+ # if self.batch_drop_block is not None:
73
+ # x = self.batch_drop_block(x)
74
+
75
+ x = self.backone(x)
76
+
77
+ glo = self.global_branch(x)
78
+ par = self.partial_branch(x)
79
+ cha = self.channel_branch(x)
80
+
81
+ if self.activation_map:
82
+ glo_ = glo
83
+
84
+ if self.batch_drop_block is not None:
85
+ glo_drop, glo = self.batch_drop_block(glo)
86
+
87
+ if self.activation_map:
88
+ _, _, h_par, _ = par.size()
89
+
90
+ fmap_p0 = par[:, :, :h_par // 2, :]
91
+ fmap_p1 = par[:, :, h_par // 2:, :]
92
+ fmap_c0 = cha[:, : self.chs, :, :]
93
+ fmap_c1 = cha[:, self.chs:, :, :]
94
+ print("Generating activation maps...")
95
+
96
+ return glo, glo_, fmap_c0, fmap_c1, fmap_p0, fmap_p1
97
+
98
+ glo_drop = self.global_pooling(glo_drop)
99
+ glo = self.channel_pooling(glo) # shape:(batchsize, 512,1,1)
100
+ g_par = self.global_pooling(par) # shape:(batchsize, 512,1,1)
101
+ p_par = self.partial_pooling(par) # shape:(batchsize, 512,2,1)
102
+ cha = self.channel_pooling(cha) # shape:(batchsize, 256,1,1)
103
+
104
+ p0 = p_par[:, :, 0:1, :]
105
+ p1 = p_par[:, :, 1:2, :]
106
+
107
+ f_glo = self.reduction_0(glo)
108
+ f_p0 = self.reduction_1(g_par)
109
+ f_p1 = self.reduction_2(p0)
110
+ f_p2 = self.reduction_3(p1)
111
+ f_glo_drop = self.reduction_4(glo_drop)
112
+
113
+ ################
114
+
115
+ c0 = cha[:, : self.chs, :, :]
116
+ c1 = cha[:, self.chs:, :, :]
117
+ c0 = self.shared(c0)
118
+ c1 = self.shared(c1)
119
+ f_c0 = self.reduction_ch_0(c0)
120
+ f_c1 = self.reduction_ch_1(c1)
121
+
122
+ ################
123
+
124
+ fea = [f_glo[-1], f_glo_drop[-1], f_p0[-1]]
125
+
126
+ if not self.training:
127
+ features = torch.stack(
128
+ [f_glo[0], f_glo_drop[0], f_p0[0], f_p1[0], f_p2[0], f_c0[0], f_c1[0]],
129
+ dim=2,
130
+ )
131
+ features = features.flatten(1, 2)
132
+ return features
133
+
134
+ return [
135
+ f_glo[1],
136
+ f_glo_drop[1],
137
+ f_p0[1],
138
+ f_p1[1],
139
+ f_p2[1],
140
+ f_c0[1],
141
+ f_c1[1],
142
+ ], fea
143
+
144
+ def weights_init_kaiming(self, m):
145
+ classname = m.__class__.__name__
146
+ if classname.find("Linear") != -1:
147
+ nn.init.kaiming_normal_(m.weight, a=0, mode="fan_out")
148
+ nn.init.constant_(m.bias, 0.0)
149
+ elif classname.find("Conv") != -1:
150
+ nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
151
+ if m.bias is not None:
152
+ nn.init.constant_(m.bias, 0.0)
153
+ elif classname.find("BatchNorm") != -1:
154
+ if m.affine:
155
+ nn.init.constant_(m.weight, 1.0)
156
+ nn.init.constant_(m.bias, 0.0)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ # Here I left a simple forward function.
161
+ # Test the model, before you train it.
162
+ import argparse
163
+
164
+ parser = argparse.ArgumentParser(description="MGN")
165
+ parser.add_argument("--num_classes", type=int, default=751, help="")
166
+ parser.add_argument("--bnneck", type=bool, default=True)
167
+ parser.add_argument("--pool", type=str, default="max")
168
+ parser.add_argument("--feats", type=int, default=512)
169
+ parser.add_argument("--drop_block", type=bool, default=True)
170
+ parser.add_argument("--w_ratio", type=float, default=1.0, help="")
171
+
172
+ args = parser.parse_args()
173
+ # net = MCMP_n(args)
174
+ # net.classifier = nn.Sequential()
175
+ # print([p for p in net.parameters()])
176
+ # a=filter(lambda p: p.requires_grad, net.parameters())
177
+ # print(a)
178
+
179
+ # print(net)
180
+ # input = Variable(torch.FloatTensor(8, 3, 384, 128))
181
+ # net.eval()
182
+ # output = net(input)
183
+ # print(output.shape)
184
+ print("net output size:")
185
+ # print(len(output))
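For orientation, a minimal inference-time sketch of the branch head whose forward is shown above. The class name LMBN_n, the import path, and the exact args fields it needs are assumptions inferred from the file location and the __main__ parser above (the real constructor may require more fields); what the diff does confirm is that in eval mode the forward stacks seven branch embeddings and flattens them into one vector per image.

# Hypothetical usage sketch: class name, import path and args fields are assumptions.
import argparse
import torch
from boxmot.appearance.backbones.lmbn.lmbn_n import LMBN_n  # assumed name at the path added in this commit

args = argparse.Namespace(num_classes=751, bnneck=True, pool="max",
                          feats=512, drop_block=True, w_ratio=1.0)
net = LMBN_n(args).eval()                    # eval() selects the "not self.training" branch
with torch.no_grad():
    emb = net(torch.randn(8, 3, 384, 128))   # 384x128 matches the commented-out test above
print(emb.shape)                             # expected (8, 7 * feats): seven stacked branch embeddings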
boxmot/appearance/backbones/mlfn.py ADDED
@@ -0,0 +1,240 @@
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ from __future__ import absolute_import, division
4
+
5
+ import torch
6
+ import torch.utils.model_zoo as model_zoo
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ __all__ = ["mlfn"]
11
+
12
+ model_urls = {
13
+ # training epoch = 5, top1 = 51.6
14
+ "imagenet": "https://mega.nz/#!YHxAhaxC!yu9E6zWl0x5zscSouTdbZu8gdFFytDdl-RAdD2DEfpk",
15
+ }
16
+
17
+
18
+ class MLFNBlock(nn.Module):
19
+ def __init__(self, in_channels, out_channels, stride, fsm_channels, groups=32):
20
+ super(MLFNBlock, self).__init__()
21
+ self.groups = groups
22
+ mid_channels = out_channels // 2
23
+
24
+ # Factor Modules
25
+ self.fm_conv1 = nn.Conv2d(in_channels, mid_channels, 1, bias=False)
26
+ self.fm_bn1 = nn.BatchNorm2d(mid_channels)
27
+ self.fm_conv2 = nn.Conv2d(
28
+ mid_channels,
29
+ mid_channels,
30
+ 3,
31
+ stride=stride,
32
+ padding=1,
33
+ bias=False,
34
+ groups=self.groups,
35
+ )
36
+ self.fm_bn2 = nn.BatchNorm2d(mid_channels)
37
+ self.fm_conv3 = nn.Conv2d(mid_channels, out_channels, 1, bias=False)
38
+ self.fm_bn3 = nn.BatchNorm2d(out_channels)
39
+
40
+ # Factor Selection Module
41
+ self.fsm = nn.Sequential(
42
+ nn.AdaptiveAvgPool2d(1),
43
+ nn.Conv2d(in_channels, fsm_channels[0], 1),
44
+ nn.BatchNorm2d(fsm_channels[0]),
45
+ nn.ReLU(inplace=True),
46
+ nn.Conv2d(fsm_channels[0], fsm_channels[1], 1),
47
+ nn.BatchNorm2d(fsm_channels[1]),
48
+ nn.ReLU(inplace=True),
49
+ nn.Conv2d(fsm_channels[1], self.groups, 1),
50
+ nn.BatchNorm2d(self.groups),
51
+ nn.Sigmoid(),
52
+ )
53
+
54
+ self.downsample = None
55
+ if in_channels != out_channels or stride > 1:
56
+ self.downsample = nn.Sequential(
57
+ nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
58
+ nn.BatchNorm2d(out_channels),
59
+ )
60
+
61
+ def forward(self, x):
62
+ residual = x
63
+ s = self.fsm(x)
64
+
65
+ # reduce dimension
66
+ x = self.fm_conv1(x)
67
+ x = self.fm_bn1(x)
68
+ x = F.relu(x, inplace=True)
69
+
70
+ # group convolution
71
+ x = self.fm_conv2(x)
72
+ x = self.fm_bn2(x)
73
+ x = F.relu(x, inplace=True)
74
+
75
+ # factor selection
76
+ b, c = x.size(0), x.size(1)
77
+ n = c // self.groups
78
+ ss = s.repeat(1, n, 1, 1) # from (b, g, 1, 1) to (b, g*n=c, 1, 1)
79
+ ss = ss.view(b, n, self.groups, 1, 1)
80
+ ss = ss.permute(0, 2, 1, 3, 4).contiguous()
81
+ ss = ss.view(b, c, 1, 1)
82
+ x = ss * x
83
+
84
+ # recover dimension
85
+ x = self.fm_conv3(x)
86
+ x = self.fm_bn3(x)
87
+ x = F.relu(x, inplace=True)
88
+
89
+ if self.downsample is not None:
90
+ residual = self.downsample(residual)
91
+
92
+ return F.relu(residual + x, inplace=True), s
93
+
94
+
95
+ class MLFN(nn.Module):
96
+ """Multi-Level Factorisation Net.
97
+
98
+ Reference:
99
+ Chang et al. Multi-Level Factorisation Net for
100
+ Person Re-Identification. CVPR 2018.
101
+
102
+ Public keys:
103
+ - ``mlfn``: MLFN (Multi-Level Factorisation Net).
104
+ """
105
+
106
+ def __init__(
107
+ self,
108
+ num_classes,
109
+ loss="softmax",
110
+ groups=32,
111
+ channels=[64, 256, 512, 1024, 2048],
112
+ embed_dim=1024,
113
+ **kwargs
114
+ ):
115
+ super(MLFN, self).__init__()
116
+ self.loss = loss
117
+ self.groups = groups
118
+
119
+ # first convolutional layer
120
+ self.conv1 = nn.Conv2d(3, channels[0], 7, stride=2, padding=3)
121
+ self.bn1 = nn.BatchNorm2d(channels[0])
122
+ self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
123
+
124
+ # main body
125
+ self.feature = nn.ModuleList(
126
+ [
127
+ # layer 1-3
128
+ MLFNBlock(channels[0], channels[1], 1, [128, 64], self.groups),
129
+ MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups),
130
+ MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups),
131
+ # layer 4-7
132
+ MLFNBlock(channels[1], channels[2], 2, [256, 128], self.groups),
133
+ MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups),
134
+ MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups),
135
+ MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups),
136
+ # layer 8-13
137
+ MLFNBlock(channels[2], channels[3], 2, [512, 128], self.groups),
138
+ MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
139
+ MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
140
+ MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
141
+ MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
142
+ MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
143
+ # layer 14-16
144
+ MLFNBlock(channels[3], channels[4], 2, [512, 128], self.groups),
145
+ MLFNBlock(channels[4], channels[4], 1, [512, 128], self.groups),
146
+ MLFNBlock(channels[4], channels[4], 1, [512, 128], self.groups),
147
+ ]
148
+ )
149
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
150
+
151
+ # projection functions
152
+ self.fc_x = nn.Sequential(
153
+ nn.Conv2d(channels[4], embed_dim, 1, bias=False),
154
+ nn.BatchNorm2d(embed_dim),
155
+ nn.ReLU(inplace=True),
156
+ )
157
+ self.fc_s = nn.Sequential(
158
+ nn.Conv2d(self.groups * 16, embed_dim, 1, bias=False),
159
+ nn.BatchNorm2d(embed_dim),
160
+ nn.ReLU(inplace=True),
161
+ )
162
+
163
+ self.classifier = nn.Linear(embed_dim, num_classes)
164
+
165
+ self.init_params()
166
+
167
+ def init_params(self):
168
+ for m in self.modules():
169
+ if isinstance(m, nn.Conv2d):
170
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
171
+ if m.bias is not None:
172
+ nn.init.constant_(m.bias, 0)
173
+ elif isinstance(m, nn.BatchNorm2d):
174
+ nn.init.constant_(m.weight, 1)
175
+ nn.init.constant_(m.bias, 0)
176
+ elif isinstance(m, nn.Linear):
177
+ nn.init.normal_(m.weight, 0, 0.01)
178
+ if m.bias is not None:
179
+ nn.init.constant_(m.bias, 0)
180
+
181
+ def forward(self, x):
182
+ x = self.conv1(x)
183
+ x = self.bn1(x)
184
+ x = F.relu(x, inplace=True)
185
+ x = self.maxpool(x)
186
+
187
+ s_hat = []
188
+ for block in self.feature:
189
+ x, s = block(x)
190
+ s_hat.append(s)
191
+ s_hat = torch.cat(s_hat, 1)
192
+
193
+ x = self.global_avgpool(x)
194
+ x = self.fc_x(x)
195
+ s_hat = self.fc_s(s_hat)
196
+
197
+ v = (x + s_hat) * 0.5
198
+ v = v.view(v.size(0), -1)
199
+
200
+ if not self.training:
201
+ return v
202
+
203
+ y = self.classifier(v)
204
+
205
+ if self.loss == "softmax":
206
+ return y
207
+ elif self.loss == "triplet":
208
+ return y, v
209
+ else:
210
+ raise KeyError("Unsupported loss: {}".format(self.loss))
211
+
212
+
213
+ def init_pretrained_weights(model, model_url):
214
+ """Initializes model with pretrained weights.
215
+
216
+ Layers that don't match with pretrained layers in name or size are kept unchanged.
217
+ """
218
+ pretrain_dict = model_zoo.load_url(model_url)
219
+ model_dict = model.state_dict()
220
+ pretrain_dict = {
221
+ k: v
222
+ for k, v in pretrain_dict.items()
223
+ if k in model_dict and model_dict[k].size() == v.size()
224
+ }
225
+ model_dict.update(pretrain_dict)
226
+ model.load_state_dict(model_dict)
227
+
228
+
229
+ def mlfn(num_classes, loss="softmax", pretrained=True, **kwargs):
230
+ model = MLFN(num_classes, loss, **kwargs)
231
+ if pretrained:
232
+ # init_pretrained_weights(model, model_urls['imagenet'])
233
+ import warnings
234
+
235
+ warnings.warn(
236
+ "The imagenet pretrained weights need to be manually downloaded from {}".format(
237
+ model_urls["imagenet"]
238
+ )
239
+ )
240
+ return model
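As a quick sanity check of the mlfn factory added above, the sketch below builds the model with pretrained=False (sidestepping the manual weight download the warning points to) and pushes a dummy batch through it; in eval mode the forward returns the embedding v rather than classifier logits. The 256x128 input size is just a common ReID choice, not a requirement.

import torch
from boxmot.appearance.backbones.mlfn import mlfn  # path as added in this commit

model = mlfn(num_classes=751, loss="softmax", pretrained=False).eval()
with torch.no_grad():
    v = model(torch.randn(4, 3, 256, 128))  # dummy batch of 4 crops
print(v.shape)  # torch.Size([4, 1024]) -- embed_dim defaults to 1024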
boxmot/appearance/backbones/mobilenetv2.py ADDED
@@ -0,0 +1,246 @@
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ from __future__ import absolute_import, division
4
+
5
+ import torch.utils.model_zoo as model_zoo
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ __all__ = ["mobilenetv2_x1_0", "mobilenetv2_x1_4"]
10
+
11
+ model_urls = {
12
+ # 1.0: top-1 71.3
13
+ "mobilenetv2_x1_0": "https://mega.nz/#!NKp2wAIA!1NH1pbNzY_M2hVk_hdsxNM1NUOWvvGPHhaNr-fASF6c",
14
+ # 1.4: top-1 73.9
15
+ "mobilenetv2_x1_4": "https://mega.nz/#!RGhgEIwS!xN2s2ZdyqI6vQ3EwgmRXLEW3khr9tpXg96G9SUJugGk",
16
+ }
17
+
18
+
19
+ class ConvBlock(nn.Module):
20
+ """Basic convolutional block.
21
+
22
+ convolution (bias discarded) + batch normalization + relu6.
23
+
24
+ Args:
25
+ in_c (int): number of input channels.
26
+ out_c (int): number of output channels.
27
+ k (int or tuple): kernel size.
28
+ s (int or tuple): stride.
29
+ p (int or tuple): padding.
30
+ g (int): number of blocked connections from input channels
31
+ to output channels (default: 1).
32
+ """
33
+
34
+ def __init__(self, in_c, out_c, k, s=1, p=0, g=1):
35
+ super(ConvBlock, self).__init__()
36
+ self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p, bias=False, groups=g)
37
+ self.bn = nn.BatchNorm2d(out_c)
38
+
39
+ def forward(self, x):
40
+ return F.relu6(self.bn(self.conv(x)))
41
+
42
+
43
+ class Bottleneck(nn.Module):
44
+ def __init__(self, in_channels, out_channels, expansion_factor, stride=1):
45
+ super(Bottleneck, self).__init__()
46
+ mid_channels = in_channels * expansion_factor
47
+ self.use_residual = stride == 1 and in_channels == out_channels
48
+ self.conv1 = ConvBlock(in_channels, mid_channels, 1)
49
+ self.dwconv2 = ConvBlock(
50
+ mid_channels, mid_channels, 3, stride, 1, g=mid_channels
51
+ )
52
+ self.conv3 = nn.Sequential(
53
+ nn.Conv2d(mid_channels, out_channels, 1, bias=False),
54
+ nn.BatchNorm2d(out_channels),
55
+ )
56
+
57
+ def forward(self, x):
58
+ m = self.conv1(x)
59
+ m = self.dwconv2(m)
60
+ m = self.conv3(m)
61
+ if self.use_residual:
62
+ return x + m
63
+ else:
64
+ return m
65
+
66
+
67
+ class MobileNetV2(nn.Module):
68
+ """MobileNetV2.
69
+
70
+ Reference:
71
+ Sandler et al. MobileNetV2: Inverted Residuals and
72
+ Linear Bottlenecks. CVPR 2018.
73
+
74
+ Public keys:
75
+ - ``mobilenetv2_x1_0``: MobileNetV2 x1.0.
76
+ - ``mobilenetv2_x1_4``: MobileNetV2 x1.4.
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ num_classes,
82
+ width_mult=1,
83
+ loss="softmax",
84
+ fc_dims=None,
85
+ dropout_p=None,
86
+ **kwargs
87
+ ):
88
+ super(MobileNetV2, self).__init__()
89
+ self.loss = loss
90
+ self.in_channels = int(32 * width_mult)
91
+ self.feature_dim = int(1280 * width_mult) if width_mult > 1 else 1280
92
+
93
+ # construct layers
94
+ self.conv1 = ConvBlock(3, self.in_channels, 3, s=2, p=1)
95
+ self.conv2 = self._make_layer(Bottleneck, 1, int(16 * width_mult), 1, 1)
96
+ self.conv3 = self._make_layer(Bottleneck, 6, int(24 * width_mult), 2, 2)
97
+ self.conv4 = self._make_layer(Bottleneck, 6, int(32 * width_mult), 3, 2)
98
+ self.conv5 = self._make_layer(Bottleneck, 6, int(64 * width_mult), 4, 2)
99
+ self.conv6 = self._make_layer(Bottleneck, 6, int(96 * width_mult), 3, 1)
100
+ self.conv7 = self._make_layer(Bottleneck, 6, int(160 * width_mult), 3, 2)
101
+ self.conv8 = self._make_layer(Bottleneck, 6, int(320 * width_mult), 1, 1)
102
+ self.conv9 = ConvBlock(self.in_channels, self.feature_dim, 1)
103
+
104
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
105
+ self.fc = self._construct_fc_layer(fc_dims, self.feature_dim, dropout_p)
106
+ self.classifier = nn.Linear(self.feature_dim, num_classes)
107
+
108
+ self._init_params()
109
+
110
+ def _make_layer(self, block, t, c, n, s):
111
+ # t: expansion factor
112
+ # c: output channels
113
+ # n: number of blocks
114
+ # s: stride for first layer
115
+ layers = []
116
+ layers.append(block(self.in_channels, c, t, s))
117
+ self.in_channels = c
118
+ for i in range(1, n):
119
+ layers.append(block(self.in_channels, c, t))
120
+ return nn.Sequential(*layers)
121
+
122
+ def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
123
+ """Constructs fully connected layer.
124
+
125
+ Args:
126
+ fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
127
+ input_dim (int): input dimension
128
+ dropout_p (float): dropout probability, if None, dropout is unused
129
+ """
130
+ if fc_dims is None:
131
+ self.feature_dim = input_dim
132
+ return None
133
+
134
+ assert isinstance(
135
+ fc_dims, (list, tuple)
136
+ ), "fc_dims must be either list or tuple, but got {}".format(type(fc_dims))
137
+
138
+ layers = []
139
+ for dim in fc_dims:
140
+ layers.append(nn.Linear(input_dim, dim))
141
+ layers.append(nn.BatchNorm1d(dim))
142
+ layers.append(nn.ReLU(inplace=True))
143
+ if dropout_p is not None:
144
+ layers.append(nn.Dropout(p=dropout_p))
145
+ input_dim = dim
146
+
147
+ self.feature_dim = fc_dims[-1]
148
+
149
+ return nn.Sequential(*layers)
150
+
151
+ def _init_params(self):
152
+ for m in self.modules():
153
+ if isinstance(m, nn.Conv2d):
154
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
155
+ if m.bias is not None:
156
+ nn.init.constant_(m.bias, 0)
157
+ elif isinstance(m, nn.BatchNorm2d):
158
+ nn.init.constant_(m.weight, 1)
159
+ nn.init.constant_(m.bias, 0)
160
+ elif isinstance(m, nn.BatchNorm1d):
161
+ nn.init.constant_(m.weight, 1)
162
+ nn.init.constant_(m.bias, 0)
163
+ elif isinstance(m, nn.Linear):
164
+ nn.init.normal_(m.weight, 0, 0.01)
165
+ if m.bias is not None:
166
+ nn.init.constant_(m.bias, 0)
167
+
168
+ def featuremaps(self, x):
169
+ x = self.conv1(x)
170
+ x = self.conv2(x)
171
+ x = self.conv3(x)
172
+ x = self.conv4(x)
173
+ x = self.conv5(x)
174
+ x = self.conv6(x)
175
+ x = self.conv7(x)
176
+ x = self.conv8(x)
177
+ x = self.conv9(x)
178
+ return x
179
+
180
+ def forward(self, x):
181
+ f = self.featuremaps(x)
182
+ v = self.global_avgpool(f)
183
+ v = v.view(v.size(0), -1)
184
+
185
+ if self.fc is not None:
186
+ v = self.fc(v)
187
+
188
+ if not self.training:
189
+ return v
190
+
191
+ y = self.classifier(v)
192
+
193
+ if self.loss == "softmax":
194
+ return y
195
+ elif self.loss == "triplet":
196
+ return y, v
197
+ else:
198
+ raise KeyError("Unsupported loss: {}".format(self.loss))
199
+
200
+
201
+ def init_pretrained_weights(model, model_url):
202
+ """Initializes model with pretrained weights.
203
+
204
+ Layers that don't match with pretrained layers in name or size are kept unchanged.
205
+ """
206
+ pretrain_dict = model_zoo.load_url(model_url)
207
+ model_dict = model.state_dict()
208
+ pretrain_dict = {
209
+ k: v
210
+ for k, v in pretrain_dict.items()
211
+ if k in model_dict and model_dict[k].size() == v.size()
212
+ }
213
+ model_dict.update(pretrain_dict)
214
+ model.load_state_dict(model_dict)
215
+
216
+
217
+ def mobilenetv2_x1_0(num_classes, loss, pretrained=True, **kwargs):
218
+ model = MobileNetV2(
219
+ num_classes, loss=loss, width_mult=1, fc_dims=None, dropout_p=None, **kwargs
220
+ )
221
+ if pretrained:
222
+ # init_pretrained_weights(model, model_urls['mobilenetv2_x1_0'])
223
+ import warnings
224
+
225
+ warnings.warn(
226
+ "The imagenet pretrained weights need to be manually downloaded from {}".format(
227
+ model_urls["mobilenetv2_x1_0"]
228
+ )
229
+ )
230
+ return model
231
+
232
+
233
+ def mobilenetv2_x1_4(num_classes, loss, pretrained=True, **kwargs):
234
+ model = MobileNetV2(
235
+ num_classes, loss=loss, width_mult=1.4, fc_dims=None, dropout_p=None, **kwargs
236
+ )
237
+ if pretrained:
238
+ # init_pretrained_weights(model, model_urls['mobilenetv2_x1_4'])
239
+ import warnings
240
+
241
+ warnings.warn(
242
+ "The imagenet pretrained weights need to be manually downloaded from {}".format(
243
+ model_urls["mobilenetv2_x1_4"]
244
+ )
245
+ )
246
+ return model
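Both MobileNetV2 factories expose the same interface; a short smoke test, again with pretrained=False. Note how the feature dimension follows width_mult in the constructor above: 1280 for x1.0 and int(1280 * 1.4) = 1792 for x1.4.

import torch
from boxmot.appearance.backbones.mobilenetv2 import mobilenetv2_x1_0, mobilenetv2_x1_4

for factory in (mobilenetv2_x1_0, mobilenetv2_x1_4):
    model = factory(num_classes=751, loss="softmax", pretrained=False).eval()
    with torch.no_grad():
        v = model(torch.randn(2, 3, 256, 128))
    print(v.shape)  # (2, 1280) for x1.0, (2, 1792) for x1.4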
boxmot/appearance/backbones/osnet.py ADDED
@@ -0,0 +1,560 @@
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ from __future__ import absolute_import, division
4
+
5
+ import warnings
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ __all__ = ["osnet_x1_0", "osnet_x0_75", "osnet_x0_5", "osnet_x0_25", "osnet_ibn_x1_0"]
12
+
13
+ pretrained_urls = {
14
+ "osnet_x1_0": "https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY",
15
+ "osnet_x0_75": "https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq",
16
+ "osnet_x0_5": "https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i",
17
+ "osnet_x0_25": "https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs",
18
+ "osnet_ibn_x1_0": "https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l",
19
+ }
20
+
21
+
22
+ ##########
23
+ # Basic layers
24
+ ##########
25
+ class ConvLayer(nn.Module):
26
+ """Convolution layer (conv + bn + relu)."""
27
+
28
+ def __init__(
29
+ self,
30
+ in_channels,
31
+ out_channels,
32
+ kernel_size,
33
+ stride=1,
34
+ padding=0,
35
+ groups=1,
36
+ IN=False,
37
+ ):
38
+ super(ConvLayer, self).__init__()
39
+ self.conv = nn.Conv2d(
40
+ in_channels,
41
+ out_channels,
42
+ kernel_size,
43
+ stride=stride,
44
+ padding=padding,
45
+ bias=False,
46
+ groups=groups,
47
+ )
48
+ if IN:
49
+ self.bn = nn.InstanceNorm2d(out_channels, affine=True)
50
+ else:
51
+ self.bn = nn.BatchNorm2d(out_channels)
52
+ self.relu = nn.ReLU(inplace=True)
53
+
54
+ def forward(self, x):
55
+ x = self.conv(x)
56
+ x = self.bn(x)
57
+ x = self.relu(x)
58
+ return x
59
+
60
+
61
+ class Conv1x1(nn.Module):
62
+ """1x1 convolution + bn + relu."""
63
+
64
+ def __init__(self, in_channels, out_channels, stride=1, groups=1):
65
+ super(Conv1x1, self).__init__()
66
+ self.conv = nn.Conv2d(
67
+ in_channels,
68
+ out_channels,
69
+ 1,
70
+ stride=stride,
71
+ padding=0,
72
+ bias=False,
73
+ groups=groups,
74
+ )
75
+ self.bn = nn.BatchNorm2d(out_channels)
76
+ self.relu = nn.ReLU(inplace=True)
77
+
78
+ def forward(self, x):
79
+ x = self.conv(x)
80
+ x = self.bn(x)
81
+ x = self.relu(x)
82
+ return x
83
+
84
+
85
+ class Conv1x1Linear(nn.Module):
86
+ """1x1 convolution + bn (w/o non-linearity)."""
87
+
88
+ def __init__(self, in_channels, out_channels, stride=1):
89
+ super(Conv1x1Linear, self).__init__()
90
+ self.conv = nn.Conv2d(
91
+ in_channels, out_channels, 1, stride=stride, padding=0, bias=False
92
+ )
93
+ self.bn = nn.BatchNorm2d(out_channels)
94
+
95
+ def forward(self, x):
96
+ x = self.conv(x)
97
+ x = self.bn(x)
98
+ return x
99
+
100
+
101
+ class Conv3x3(nn.Module):
102
+ """3x3 convolution + bn + relu."""
103
+
104
+ def __init__(self, in_channels, out_channels, stride=1, groups=1):
105
+ super(Conv3x3, self).__init__()
106
+ self.conv = nn.Conv2d(
107
+ in_channels,
108
+ out_channels,
109
+ 3,
110
+ stride=stride,
111
+ padding=1,
112
+ bias=False,
113
+ groups=groups,
114
+ )
115
+ self.bn = nn.BatchNorm2d(out_channels)
116
+ self.relu = nn.ReLU(inplace=True)
117
+
118
+ def forward(self, x):
119
+ x = self.conv(x)
120
+ x = self.bn(x)
121
+ x = self.relu(x)
122
+ return x
123
+
124
+
125
+ class LightConv3x3(nn.Module):
126
+ """Lightweight 3x3 convolution.
127
+
128
+ 1x1 (linear) + dw 3x3 (nonlinear).
129
+ """
130
+
131
+ def __init__(self, in_channels, out_channels):
132
+ super(LightConv3x3, self).__init__()
133
+ self.conv1 = nn.Conv2d(
134
+ in_channels, out_channels, 1, stride=1, padding=0, bias=False
135
+ )
136
+ self.conv2 = nn.Conv2d(
137
+ out_channels,
138
+ out_channels,
139
+ 3,
140
+ stride=1,
141
+ padding=1,
142
+ bias=False,
143
+ groups=out_channels,
144
+ )
145
+ self.bn = nn.BatchNorm2d(out_channels)
146
+ self.relu = nn.ReLU(inplace=True)
147
+
148
+ def forward(self, x):
149
+ x = self.conv1(x)
150
+ x = self.conv2(x)
151
+ x = self.bn(x)
152
+ x = self.relu(x)
153
+ return x
154
+
155
+
156
+ ##########
157
+ # Building blocks for omni-scale feature learning
158
+ ##########
159
+ class ChannelGate(nn.Module):
160
+ """A mini-network that generates channel-wise gates conditioned on input tensor."""
161
+
162
+ def __init__(
163
+ self,
164
+ in_channels,
165
+ num_gates=None,
166
+ return_gates=False,
167
+ gate_activation="sigmoid",
168
+ reduction=16,
169
+ layer_norm=False,
170
+ ):
171
+ super(ChannelGate, self).__init__()
172
+ if num_gates is None:
173
+ num_gates = in_channels
174
+ self.return_gates = return_gates
175
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
176
+ self.fc1 = nn.Conv2d(
177
+ in_channels, in_channels // reduction, kernel_size=1, bias=True, padding=0
178
+ )
179
+ self.norm1 = None
180
+ if layer_norm:
181
+ self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
182
+ self.relu = nn.ReLU(inplace=True)
183
+ self.fc2 = nn.Conv2d(
184
+ in_channels // reduction, num_gates, kernel_size=1, bias=True, padding=0
185
+ )
186
+ if gate_activation == "sigmoid":
187
+ self.gate_activation = nn.Sigmoid()
188
+ elif gate_activation == "relu":
189
+ self.gate_activation = nn.ReLU(inplace=True)
190
+ elif gate_activation == "linear":
191
+ self.gate_activation = None
192
+ else:
193
+ raise RuntimeError("Unknown gate activation: {}".format(gate_activation))
194
+
195
+ def forward(self, x):
196
+ input = x
197
+ x = self.global_avgpool(x)
198
+ x = self.fc1(x)
199
+ if self.norm1 is not None:
200
+ x = self.norm1(x)
201
+ x = self.relu(x)
202
+ x = self.fc2(x)
203
+ if self.gate_activation is not None:
204
+ x = self.gate_activation(x)
205
+ if self.return_gates:
206
+ return x
207
+ return input * x
208
+
209
+
210
+ class OSBlock(nn.Module):
211
+ """Omni-scale feature learning block."""
212
+
213
+ def __init__(
214
+ self, in_channels, out_channels, IN=False, bottleneck_reduction=4, **kwargs
215
+ ):
216
+ super(OSBlock, self).__init__()
217
+ mid_channels = out_channels // bottleneck_reduction
218
+ self.conv1 = Conv1x1(in_channels, mid_channels)
219
+ self.conv2a = LightConv3x3(mid_channels, mid_channels)
220
+ self.conv2b = nn.Sequential(
221
+ LightConv3x3(mid_channels, mid_channels),
222
+ LightConv3x3(mid_channels, mid_channels),
223
+ )
224
+ self.conv2c = nn.Sequential(
225
+ LightConv3x3(mid_channels, mid_channels),
226
+ LightConv3x3(mid_channels, mid_channels),
227
+ LightConv3x3(mid_channels, mid_channels),
228
+ )
229
+ self.conv2d = nn.Sequential(
230
+ LightConv3x3(mid_channels, mid_channels),
231
+ LightConv3x3(mid_channels, mid_channels),
232
+ LightConv3x3(mid_channels, mid_channels),
233
+ LightConv3x3(mid_channels, mid_channels),
234
+ )
235
+ self.gate = ChannelGate(mid_channels)
236
+ self.conv3 = Conv1x1Linear(mid_channels, out_channels)
237
+ self.downsample = None
238
+ if in_channels != out_channels:
239
+ self.downsample = Conv1x1Linear(in_channels, out_channels)
240
+ self.IN = None
241
+ if IN:
242
+ self.IN = nn.InstanceNorm2d(out_channels, affine=True)
243
+
244
+ def forward(self, x):
245
+ identity = x
246
+ x1 = self.conv1(x)
247
+ x2a = self.conv2a(x1)
248
+ x2b = self.conv2b(x1)
249
+ x2c = self.conv2c(x1)
250
+ x2d = self.conv2d(x1)
251
+ x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
252
+ x3 = self.conv3(x2)
253
+ if self.downsample is not None:
254
+ identity = self.downsample(identity)
255
+ out = x3 + identity
256
+ if self.IN is not None:
257
+ out = self.IN(out)
258
+ return F.relu(out)
259
+
260
+
261
+ ##########
262
+ # Network architecture
263
+ ##########
264
+ class OSNet(nn.Module):
265
+ """Omni-Scale Network.
266
+
267
+ Reference:
268
+ - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
269
+ - Zhou et al. Learning Generalisable Omni-Scale Representations
270
+ for Person Re-Identification. TPAMI, 2021.
271
+ """
272
+
273
+ def __init__(
274
+ self,
275
+ num_classes,
276
+ blocks,
277
+ layers,
278
+ channels,
279
+ feature_dim=512,
280
+ loss="softmax",
281
+ IN=False,
282
+ **kwargs
283
+ ):
284
+ super(OSNet, self).__init__()
285
+ num_blocks = len(blocks)
286
+ assert num_blocks == len(layers)
287
+ assert num_blocks == len(channels) - 1
288
+ self.loss = loss
289
+ self.feature_dim = feature_dim
290
+
291
+ # convolutional backbone
292
+ self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
293
+ self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
294
+ self.conv2 = self._make_layer(
295
+ blocks[0],
296
+ layers[0],
297
+ channels[0],
298
+ channels[1],
299
+ reduce_spatial_size=True,
300
+ IN=IN,
301
+ )
302
+ self.conv3 = self._make_layer(
303
+ blocks[1], layers[1], channels[1], channels[2], reduce_spatial_size=True
304
+ )
305
+ self.conv4 = self._make_layer(
306
+ blocks[2], layers[2], channels[2], channels[3], reduce_spatial_size=False
307
+ )
308
+ self.conv5 = Conv1x1(channels[3], channels[3])
309
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
310
+ # fully connected layer
311
+ self.fc = self._construct_fc_layer(
312
+ self.feature_dim, channels[3], dropout_p=None
313
+ )
314
+ # identity classification layer
315
+ self.classifier = nn.Linear(self.feature_dim, num_classes)
316
+
317
+ self._init_params()
318
+
319
+ def _make_layer(
320
+ self, block, layer, in_channels, out_channels, reduce_spatial_size, IN=False
321
+ ):
322
+ layers = []
323
+
324
+ layers.append(block(in_channels, out_channels, IN=IN))
325
+ for i in range(1, layer):
326
+ layers.append(block(out_channels, out_channels, IN=IN))
327
+
328
+ if reduce_spatial_size:
329
+ layers.append(
330
+ nn.Sequential(
331
+ Conv1x1(out_channels, out_channels), nn.AvgPool2d(2, stride=2)
332
+ )
333
+ )
334
+
335
+ return nn.Sequential(*layers)
336
+
337
+ def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
338
+ if fc_dims is None or fc_dims < 0:
339
+ self.feature_dim = input_dim
340
+ return None
341
+
342
+ if isinstance(fc_dims, int):
343
+ fc_dims = [fc_dims]
344
+
345
+ layers = []
346
+ for dim in fc_dims:
347
+ layers.append(nn.Linear(input_dim, dim))
348
+ layers.append(nn.BatchNorm1d(dim))
349
+ layers.append(nn.ReLU(inplace=True))
350
+ if dropout_p is not None:
351
+ layers.append(nn.Dropout(p=dropout_p))
352
+ input_dim = dim
353
+
354
+ self.feature_dim = fc_dims[-1]
355
+
356
+ return nn.Sequential(*layers)
357
+
358
+ def _init_params(self):
359
+ for m in self.modules():
360
+ if isinstance(m, nn.Conv2d):
361
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
362
+ if m.bias is not None:
363
+ nn.init.constant_(m.bias, 0)
364
+
365
+ elif isinstance(m, nn.BatchNorm2d):
366
+ nn.init.constant_(m.weight, 1)
367
+ nn.init.constant_(m.bias, 0)
368
+
369
+ elif isinstance(m, nn.BatchNorm1d):
370
+ nn.init.constant_(m.weight, 1)
371
+ nn.init.constant_(m.bias, 0)
372
+
373
+ elif isinstance(m, nn.Linear):
374
+ nn.init.normal_(m.weight, 0, 0.01)
375
+ if m.bias is not None:
376
+ nn.init.constant_(m.bias, 0)
377
+
378
+ def featuremaps(self, x):
379
+ x = self.conv1(x)
380
+ x = self.maxpool(x)
381
+ x = self.conv2(x)
382
+ x = self.conv3(x)
383
+ x = self.conv4(x)
384
+ x = self.conv5(x)
385
+ return x
386
+
387
+ def forward(self, x, return_featuremaps=False):
388
+ x = self.featuremaps(x)
389
+ if return_featuremaps:
390
+ return x
391
+ v = self.global_avgpool(x)
392
+ v = v.view(v.size(0), -1)
393
+ if self.fc is not None:
394
+ v = self.fc(v)
395
+ if not self.training:
396
+ return v
397
+ y = self.classifier(v)
398
+ if self.loss == "softmax":
399
+ return y
400
+ elif self.loss == "triplet":
401
+ return y, v
402
+ else:
403
+ raise KeyError("Unsupported loss: {}".format(self.loss))
404
+
405
+
406
+ def init_pretrained_weights(model, key=""):
407
+ """Initializes model with pretrained weights.
408
+
409
+ Layers that don't match with pretrained layers in name or size are kept unchanged.
410
+ """
411
+ import errno
412
+ import os
413
+ from collections import OrderedDict
414
+
415
+ import gdown
416
+
417
+ def _get_torch_home():
418
+ ENV_TORCH_HOME = "TORCH_HOME"
419
+ ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
420
+ DEFAULT_CACHE_DIR = "~/.cache"
421
+ torch_home = os.path.expanduser(
422
+ os.getenv(
423
+ ENV_TORCH_HOME,
424
+ os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"),
425
+ )
426
+ )
427
+ return torch_home
428
+
429
+ torch_home = _get_torch_home()
430
+ model_dir = os.path.join(torch_home, "checkpoints")
431
+ try:
432
+ os.makedirs(model_dir)
433
+ except OSError as e:
434
+ if e.errno == errno.EEXIST:
435
+ # Directory already exists, ignore.
436
+ pass
437
+ else:
438
+ # Unexpected OSError, re-raise.
439
+ raise
440
+ filename = key + "_imagenet.pth"
441
+ cached_file = os.path.join(model_dir, filename)
442
+
443
+ if not os.path.exists(cached_file):
444
+ gdown.download(pretrained_urls[key], cached_file, quiet=False)
445
+
446
+ state_dict = torch.load(cached_file)
447
+ model_dict = model.state_dict()
448
+ new_state_dict = OrderedDict()
449
+ matched_layers, discarded_layers = [], []
450
+
451
+ for k, v in state_dict.items():
452
+ if k.startswith("module."):
453
+ k = k[7:] # discard module.
454
+
455
+ if k in model_dict and model_dict[k].size() == v.size():
456
+ new_state_dict[k] = v
457
+ matched_layers.append(k)
458
+ else:
459
+ discarded_layers.append(k)
460
+
461
+ model_dict.update(new_state_dict)
462
+ model.load_state_dict(model_dict)
463
+
464
+ if len(matched_layers) == 0:
465
+ warnings.warn(
466
+ 'The pretrained weights from "{}" cannot be loaded, '
467
+ "please check the key names manually "
468
+ "(** ignored and continue **)".format(cached_file)
469
+ )
470
+ else:
471
+ print(
472
+ 'Successfully loaded imagenet pretrained weights from "{}"'.format(
473
+ cached_file
474
+ )
475
+ )
476
+ if len(discarded_layers) > 0:
477
+ print(
478
+ "** The following layers are discarded "
479
+ "due to unmatched keys or layer size: {}".format(discarded_layers)
480
+ )
481
+
482
+
483
+ ##########
484
+ # Instantiation
485
+ ##########
486
+ def osnet_x1_0(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
487
+ # standard size (width x1.0)
488
+ model = OSNet(
489
+ num_classes,
490
+ blocks=[OSBlock, OSBlock, OSBlock],
491
+ layers=[2, 2, 2],
492
+ channels=[64, 256, 384, 512],
493
+ loss=loss,
494
+ **kwargs
495
+ )
496
+ if pretrained:
497
+ init_pretrained_weights(model, key="osnet_x1_0")
498
+ return model
499
+
500
+
501
+ def osnet_x0_75(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
502
+ # medium size (width x0.75)
503
+ model = OSNet(
504
+ num_classes,
505
+ blocks=[OSBlock, OSBlock, OSBlock],
506
+ layers=[2, 2, 2],
507
+ channels=[48, 192, 288, 384],
508
+ loss=loss,
509
+ **kwargs
510
+ )
511
+ if pretrained:
512
+ init_pretrained_weights(model, key="osnet_x0_75")
513
+ return model
514
+
515
+
516
+ def osnet_x0_5(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
517
+ # tiny size (width x0.5)
518
+ model = OSNet(
519
+ num_classes,
520
+ blocks=[OSBlock, OSBlock, OSBlock],
521
+ layers=[2, 2, 2],
522
+ channels=[32, 128, 192, 256],
523
+ loss=loss,
524
+ **kwargs
525
+ )
526
+ if pretrained:
527
+ init_pretrained_weights(model, key="osnet_x0_5")
528
+ return model
529
+
530
+
531
+ def osnet_x0_25(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
532
+ # very tiny size (width x0.25)
533
+ model = OSNet(
534
+ num_classes,
535
+ blocks=[OSBlock, OSBlock, OSBlock],
536
+ layers=[2, 2, 2],
537
+ channels=[16, 64, 96, 128],
538
+ loss=loss,
539
+ **kwargs
540
+ )
541
+ if pretrained:
542
+ init_pretrained_weights(model, key="osnet_x0_25")
543
+ return model
544
+
545
+
546
+ def osnet_ibn_x1_0(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
547
+ # standard size (width x1.0) + IBN layer
548
+ # Ref: Pan et al. Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net. ECCV, 2018.
549
+ model = OSNet(
550
+ num_classes,
551
+ blocks=[OSBlock, OSBlock, OSBlock],
552
+ layers=[2, 2, 2],
553
+ channels=[64, 256, 384, 512],
554
+ loss=loss,
555
+ IN=True,
556
+ **kwargs
557
+ )
558
+ if pretrained:
559
+ init_pretrained_weights(model, key="osnet_ibn_x1_0")
560
+ return model
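All OSNet variants above share the same call signature; the sketch below uses the smallest one and passes pretrained=False so that init_pretrained_weights does not attempt a gdown download into the torch cache.

import torch
from boxmot.appearance.backbones.osnet import osnet_x0_25

model = osnet_x0_25(num_classes=751, pretrained=False, loss="softmax").eval()
with torch.no_grad():
    v = model(torch.randn(4, 3, 256, 128))
print(v.shape)  # torch.Size([4, 512]): feature_dim defaults to 512 regardless of width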
boxmot/appearance/backbones/osnet_ain.py ADDED
@@ -0,0 +1,582 @@
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ from __future__ import absolute_import, division
4
+
5
+ import warnings
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ __all__ = ["osnet_ain_x1_0", "osnet_ain_x0_75", "osnet_ain_x0_5", "osnet_ain_x0_25"]
12
+
13
+ pretrained_urls = {
14
+ "osnet_ain_x1_0": "https://drive.google.com/uc?id=1-CaioD9NaqbHK_kzSMW8VE4_3KcsRjEo",
15
+ "osnet_ain_x0_75": "https://drive.google.com/uc?id=1apy0hpsMypqstfencdH-jKIUEFOW4xoM",
16
+ "osnet_ain_x0_5": "https://drive.google.com/uc?id=1KusKvEYyKGDTUBVRxRiz55G31wkihB6l",
17
+ "osnet_ain_x0_25": "https://drive.google.com/uc?id=1SxQt2AvmEcgWNhaRb2xC4rP6ZwVDP0Wt",
18
+ }
19
+
20
+
21
+ ##########
22
+ # Basic layers
23
+ ##########
24
+ class ConvLayer(nn.Module):
25
+ """Convolution layer (conv + bn + relu)."""
26
+
27
+ def __init__(
28
+ self,
29
+ in_channels,
30
+ out_channels,
31
+ kernel_size,
32
+ stride=1,
33
+ padding=0,
34
+ groups=1,
35
+ IN=False,
36
+ ):
37
+ super(ConvLayer, self).__init__()
38
+ self.conv = nn.Conv2d(
39
+ in_channels,
40
+ out_channels,
41
+ kernel_size,
42
+ stride=stride,
43
+ padding=padding,
44
+ bias=False,
45
+ groups=groups,
46
+ )
47
+ if IN:
48
+ self.bn = nn.InstanceNorm2d(out_channels, affine=True)
49
+ else:
50
+ self.bn = nn.BatchNorm2d(out_channels)
51
+ self.relu = nn.ReLU()
52
+
53
+ def forward(self, x):
54
+ x = self.conv(x)
55
+ x = self.bn(x)
56
+ return self.relu(x)
57
+
58
+
59
+ class Conv1x1(nn.Module):
60
+ """1x1 convolution + bn + relu."""
61
+
62
+ def __init__(self, in_channels, out_channels, stride=1, groups=1):
63
+ super(Conv1x1, self).__init__()
64
+ self.conv = nn.Conv2d(
65
+ in_channels,
66
+ out_channels,
67
+ 1,
68
+ stride=stride,
69
+ padding=0,
70
+ bias=False,
71
+ groups=groups,
72
+ )
73
+ self.bn = nn.BatchNorm2d(out_channels)
74
+ self.relu = nn.ReLU()
75
+
76
+ def forward(self, x):
77
+ x = self.conv(x)
78
+ x = self.bn(x)
79
+ return self.relu(x)
80
+
81
+
82
+ class Conv1x1Linear(nn.Module):
83
+ """1x1 convolution + bn (w/o non-linearity)."""
84
+
85
+ def __init__(self, in_channels, out_channels, stride=1, bn=True):
86
+ super(Conv1x1Linear, self).__init__()
87
+ self.conv = nn.Conv2d(
88
+ in_channels, out_channels, 1, stride=stride, padding=0, bias=False
89
+ )
90
+ self.bn = None
91
+ if bn:
92
+ self.bn = nn.BatchNorm2d(out_channels)
93
+
94
+ def forward(self, x):
95
+ x = self.conv(x)
96
+ if self.bn is not None:
97
+ x = self.bn(x)
98
+ return x
99
+
100
+
101
+ class Conv3x3(nn.Module):
102
+ """3x3 convolution + bn + relu."""
103
+
104
+ def __init__(self, in_channels, out_channels, stride=1, groups=1):
105
+ super(Conv3x3, self).__init__()
106
+ self.conv = nn.Conv2d(
107
+ in_channels,
108
+ out_channels,
109
+ 3,
110
+ stride=stride,
111
+ padding=1,
112
+ bias=False,
113
+ groups=groups,
114
+ )
115
+ self.bn = nn.BatchNorm2d(out_channels)
116
+ self.relu = nn.ReLU()
117
+
118
+ def forward(self, x):
119
+ x = self.conv(x)
120
+ x = self.bn(x)
121
+ return self.relu(x)
122
+
123
+
124
+ class LightConv3x3(nn.Module):
125
+ """Lightweight 3x3 convolution.
126
+
127
+ 1x1 (linear) + dw 3x3 (nonlinear).
128
+ """
129
+
130
+ def __init__(self, in_channels, out_channels):
131
+ super(LightConv3x3, self).__init__()
132
+ self.conv1 = nn.Conv2d(
133
+ in_channels, out_channels, 1, stride=1, padding=0, bias=False
134
+ )
135
+ self.conv2 = nn.Conv2d(
136
+ out_channels,
137
+ out_channels,
138
+ 3,
139
+ stride=1,
140
+ padding=1,
141
+ bias=False,
142
+ groups=out_channels,
143
+ )
144
+ self.bn = nn.BatchNorm2d(out_channels)
145
+ self.relu = nn.ReLU()
146
+
147
+ def forward(self, x):
148
+ x = self.conv1(x)
149
+ x = self.conv2(x)
150
+ x = self.bn(x)
151
+ return self.relu(x)
152
+
153
+
154
+ class LightConvStream(nn.Module):
155
+ """Lightweight convolution stream."""
156
+
157
+ def __init__(self, in_channels, out_channels, depth):
158
+ super(LightConvStream, self).__init__()
159
+ assert depth >= 1, "depth must be equal to or larger than 1, but got {}".format(
160
+ depth
161
+ )
162
+ layers = []
163
+ layers += [LightConv3x3(in_channels, out_channels)]
164
+ for i in range(depth - 1):
165
+ layers += [LightConv3x3(out_channels, out_channels)]
166
+ self.layers = nn.Sequential(*layers)
167
+
168
+ def forward(self, x):
169
+ return self.layers(x)
170
+
171
+
172
+ ##########
173
+ # Building blocks for omni-scale feature learning
174
+ ##########
175
+ class ChannelGate(nn.Module):
176
+ """A mini-network that generates channel-wise gates conditioned on input tensor."""
177
+
178
+ def __init__(
179
+ self,
180
+ in_channels,
181
+ num_gates=None,
182
+ return_gates=False,
183
+ gate_activation="sigmoid",
184
+ reduction=16,
185
+ layer_norm=False,
186
+ ):
187
+ super(ChannelGate, self).__init__()
188
+ if num_gates is None:
189
+ num_gates = in_channels
190
+ self.return_gates = return_gates
191
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
192
+ self.fc1 = nn.Conv2d(
193
+ in_channels, in_channels // reduction, kernel_size=1, bias=True, padding=0
194
+ )
195
+ self.norm1 = None
196
+ if layer_norm:
197
+ self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
198
+ self.relu = nn.ReLU()
199
+ self.fc2 = nn.Conv2d(
200
+ in_channels // reduction, num_gates, kernel_size=1, bias=True, padding=0
201
+ )
202
+ if gate_activation == "sigmoid":
203
+ self.gate_activation = nn.Sigmoid()
204
+ elif gate_activation == "relu":
205
+ self.gate_activation = nn.ReLU()
206
+ elif gate_activation == "linear":
207
+ self.gate_activation = None
208
+ else:
209
+ raise RuntimeError("Unknown gate activation: {}".format(gate_activation))
210
+
211
+ def forward(self, x):
212
+ input = x
213
+ x = self.global_avgpool(x)
214
+ x = self.fc1(x)
215
+ if self.norm1 is not None:
216
+ x = self.norm1(x)
217
+ x = self.relu(x)
218
+ x = self.fc2(x)
219
+ if self.gate_activation is not None:
220
+ x = self.gate_activation(x)
221
+ if self.return_gates:
222
+ return x
223
+ return input * x
224
+
225
+
226
+ class OSBlock(nn.Module):
227
+ """Omni-scale feature learning block."""
228
+
229
+ def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
230
+ super(OSBlock, self).__init__()
231
+ assert T >= 1
232
+ assert out_channels >= reduction and out_channels % reduction == 0
233
+ mid_channels = out_channels // reduction
234
+
235
+ self.conv1 = Conv1x1(in_channels, mid_channels)
236
+ self.conv2 = nn.ModuleList()
237
+ for t in range(1, T + 1):
238
+ self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
239
+ self.gate = ChannelGate(mid_channels)
240
+ self.conv3 = Conv1x1Linear(mid_channels, out_channels)
241
+ self.downsample = None
242
+ if in_channels != out_channels:
243
+ self.downsample = Conv1x1Linear(in_channels, out_channels)
244
+
245
+ def forward(self, x):
246
+ identity = x
247
+ x1 = self.conv1(x)
248
+ x2 = 0
249
+ for conv2_t in self.conv2:
250
+ x2_t = conv2_t(x1)
251
+ x2 = x2 + self.gate(x2_t)
252
+ x3 = self.conv3(x2)
253
+ if self.downsample is not None:
254
+ identity = self.downsample(identity)
255
+ out = x3 + identity
256
+ return F.relu(out)
257
+
258
+
259
+ class OSBlockINin(nn.Module):
260
+ """Omni-scale feature learning block with instance normalization."""
261
+
262
+ def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
263
+ super(OSBlockINin, self).__init__()
264
+ assert T >= 1
265
+ assert out_channels >= reduction and out_channels % reduction == 0
266
+ mid_channels = out_channels // reduction
267
+
268
+ self.conv1 = Conv1x1(in_channels, mid_channels)
269
+ self.conv2 = nn.ModuleList()
270
+ for t in range(1, T + 1):
271
+ self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
272
+ self.gate = ChannelGate(mid_channels)
273
+ self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False)
274
+ self.downsample = None
275
+ if in_channels != out_channels:
276
+ self.downsample = Conv1x1Linear(in_channels, out_channels)
277
+ self.IN = nn.InstanceNorm2d(out_channels, affine=True)
278
+
279
+ def forward(self, x):
280
+ identity = x
281
+ x1 = self.conv1(x)
282
+ x2 = 0
283
+ for conv2_t in self.conv2:
284
+ x2_t = conv2_t(x1)
285
+ x2 = x2 + self.gate(x2_t)
286
+ x3 = self.conv3(x2)
287
+ x3 = self.IN(x3) # IN inside residual
288
+ if self.downsample is not None:
289
+ identity = self.downsample(identity)
290
+ out = x3 + identity
291
+ return F.relu(out)
292
+
293
+
294
+ ##########
295
+ # Network architecture
296
+ ##########
297
+ class OSNet(nn.Module):
298
+ """Omni-Scale Network.
299
+
300
+ Reference:
301
+ - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
302
+ - Zhou et al. Learning Generalisable Omni-Scale Representations
303
+ for Person Re-Identification. TPAMI, 2021.
304
+ """
305
+
306
+ def __init__(
307
+ self,
308
+ num_classes,
309
+ blocks,
310
+ layers,
311
+ channels,
312
+ feature_dim=512,
313
+ loss="softmax",
314
+ conv1_IN=False,
315
+ **kwargs
316
+ ):
317
+ super(OSNet, self).__init__()
318
+ num_blocks = len(blocks)
319
+ assert num_blocks == len(layers)
320
+ assert num_blocks == len(channels) - 1
321
+ self.loss = loss
322
+ self.feature_dim = feature_dim
323
+
324
+ # convolutional backbone
325
+ self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=conv1_IN)
326
+ self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
327
+ self.conv2 = self._make_layer(blocks[0], layers[0], channels[0], channels[1])
328
+ self.pool2 = nn.Sequential(
329
+ Conv1x1(channels[1], channels[1]), nn.AvgPool2d(2, stride=2)
330
+ )
331
+ self.conv3 = self._make_layer(blocks[1], layers[1], channels[1], channels[2])
332
+ self.pool3 = nn.Sequential(
333
+ Conv1x1(channels[2], channels[2]), nn.AvgPool2d(2, stride=2)
334
+ )
335
+ self.conv4 = self._make_layer(blocks[2], layers[2], channels[2], channels[3])
336
+ self.conv5 = Conv1x1(channels[3], channels[3])
337
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
338
+ # fully connected layer
339
+ self.fc = self._construct_fc_layer(
340
+ self.feature_dim, channels[3], dropout_p=None
341
+ )
342
+ # identity classification layer
343
+ self.classifier = nn.Linear(self.feature_dim, num_classes)
344
+
345
+ self._init_params()
346
+
347
+ def _make_layer(self, blocks, layer, in_channels, out_channels):
348
+ layers = []
349
+ layers += [blocks[0](in_channels, out_channels)]
350
+ for i in range(1, len(blocks)):
351
+ layers += [blocks[i](out_channels, out_channels)]
352
+ return nn.Sequential(*layers)
353
+
354
+ def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
355
+ if fc_dims is None or fc_dims < 0:
356
+ self.feature_dim = input_dim
357
+ return None
358
+
359
+ if isinstance(fc_dims, int):
360
+ fc_dims = [fc_dims]
361
+
362
+ layers = []
363
+ for dim in fc_dims:
364
+ layers.append(nn.Linear(input_dim, dim))
365
+ layers.append(nn.BatchNorm1d(dim))
366
+ layers.append(nn.ReLU())
367
+ if dropout_p is not None:
368
+ layers.append(nn.Dropout(p=dropout_p))
369
+ input_dim = dim
370
+
371
+ self.feature_dim = fc_dims[-1]
372
+
373
+ return nn.Sequential(*layers)
374
+
375
+ def _init_params(self):
376
+ for m in self.modules():
377
+ if isinstance(m, nn.Conv2d):
378
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
379
+ if m.bias is not None:
380
+ nn.init.constant_(m.bias, 0)
381
+
382
+ elif isinstance(m, nn.BatchNorm2d):
383
+ nn.init.constant_(m.weight, 1)
384
+ nn.init.constant_(m.bias, 0)
385
+
386
+ elif isinstance(m, nn.BatchNorm1d):
387
+ nn.init.constant_(m.weight, 1)
388
+ nn.init.constant_(m.bias, 0)
389
+
390
+ elif isinstance(m, nn.InstanceNorm2d):
391
+ nn.init.constant_(m.weight, 1)
392
+ nn.init.constant_(m.bias, 0)
393
+
394
+ elif isinstance(m, nn.Linear):
395
+ nn.init.normal_(m.weight, 0, 0.01)
396
+ if m.bias is not None:
397
+ nn.init.constant_(m.bias, 0)
398
+
399
+ def featuremaps(self, x):
400
+ x = self.conv1(x)
401
+ x = self.maxpool(x)
402
+ x = self.conv2(x)
403
+ x = self.pool2(x)
404
+ x = self.conv3(x)
405
+ x = self.pool3(x)
406
+ x = self.conv4(x)
407
+ x = self.conv5(x)
408
+ return x
409
+
410
+ def forward(self, x, return_featuremaps=False):
411
+ x = self.featuremaps(x)
412
+ if return_featuremaps:
413
+ return x
414
+ v = self.global_avgpool(x)
415
+ v = v.view(v.size(0), -1)
416
+ if self.fc is not None:
417
+ v = self.fc(v)
418
+ if not self.training:
419
+ return v
420
+ y = self.classifier(v)
421
+ if self.loss == "softmax":
422
+ return y
423
+ elif self.loss == "triplet":
424
+ return y, v
425
+ else:
426
+ raise KeyError("Unsupported loss: {}".format(self.loss))
427
+
428
+
429
+ def init_pretrained_weights(model, key=""):
430
+ """Initializes model with pretrained weights.
431
+
432
+ Layers that don't match with pretrained layers in name or size are kept unchanged.
433
+ """
434
+ import errno
435
+ import os
436
+ from collections import OrderedDict
437
+
438
+ import gdown
439
+
440
+ def _get_torch_home():
441
+ ENV_TORCH_HOME = "TORCH_HOME"
442
+ ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
443
+ DEFAULT_CACHE_DIR = "~/.cache"
444
+ torch_home = os.path.expanduser(
445
+ os.getenv(
446
+ ENV_TORCH_HOME,
447
+ os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"),
448
+ )
449
+ )
450
+ return torch_home
451
+
452
+ torch_home = _get_torch_home()
453
+ model_dir = os.path.join(torch_home, "checkpoints")
454
+ try:
455
+ os.makedirs(model_dir)
456
+ except OSError as e:
457
+ if e.errno == errno.EEXIST:
458
+ # Directory already exists, ignore.
459
+ pass
460
+ else:
461
+ # Unexpected OSError, re-raise.
462
+ raise
463
+ filename = key + "_imagenet.pth"
464
+ cached_file = os.path.join(model_dir, filename)
465
+
466
+ if not os.path.exists(cached_file):
467
+ gdown.download(pretrained_urls[key], cached_file, quiet=False)
468
+
469
+ state_dict = torch.load(cached_file)
470
+ model_dict = model.state_dict()
471
+ new_state_dict = OrderedDict()
472
+ matched_layers, discarded_layers = [], []
473
+
474
+ for k, v in state_dict.items():
475
+ if k.startswith("module."):
476
+ k = k[7:] # discard module.
477
+
478
+ if k in model_dict and model_dict[k].size() == v.size():
479
+ new_state_dict[k] = v
480
+ matched_layers.append(k)
481
+ else:
482
+ discarded_layers.append(k)
483
+
484
+ model_dict.update(new_state_dict)
485
+ model.load_state_dict(model_dict)
486
+
487
+ if len(matched_layers) == 0:
488
+ warnings.warn(
489
+ 'The pretrained weights from "{}" cannot be loaded, '
490
+ "please check the key names manually "
491
+ "(** ignored and continue **)".format(cached_file)
492
+ )
493
+ else:
494
+ print(
495
+ 'Successfully loaded imagenet pretrained weights from "{}"'.format(
496
+ cached_file
497
+ )
498
+ )
499
+ if len(discarded_layers) > 0:
500
+ print(
501
+ "** The following layers are discarded "
502
+ "due to unmatched keys or layer size: {}".format(discarded_layers)
503
+ )
504
+
505
+
506
+ ##########
507
+ # Instantiation
508
+ ##########
509
+ def osnet_ain_x1_0(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
510
+ model = OSNet(
511
+ num_classes,
512
+ blocks=[
513
+ [OSBlockINin, OSBlockINin],
514
+ [OSBlock, OSBlockINin],
515
+ [OSBlockINin, OSBlock],
516
+ ],
517
+ layers=[2, 2, 2],
518
+ channels=[64, 256, 384, 512],
519
+ loss=loss,
520
+ conv1_IN=True,
521
+ **kwargs
522
+ )
523
+ if pretrained:
524
+ init_pretrained_weights(model, key="osnet_ain_x1_0")
525
+ return model
526
+
527
+
528
+ def osnet_ain_x0_75(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
529
+ model = OSNet(
530
+ num_classes,
531
+ blocks=[
532
+ [OSBlockINin, OSBlockINin],
533
+ [OSBlock, OSBlockINin],
534
+ [OSBlockINin, OSBlock],
535
+ ],
536
+ layers=[2, 2, 2],
537
+ channels=[48, 192, 288, 384],
538
+ loss=loss,
539
+ conv1_IN=True,
540
+ **kwargs
541
+ )
542
+ if pretrained:
543
+ init_pretrained_weights(model, key="osnet_ain_x0_75")
544
+ return model
545
+
546
+
547
+ def osnet_ain_x0_5(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
548
+ model = OSNet(
549
+ num_classes,
550
+ blocks=[
551
+ [OSBlockINin, OSBlockINin],
552
+ [OSBlock, OSBlockINin],
553
+ [OSBlockINin, OSBlock],
554
+ ],
555
+ layers=[2, 2, 2],
556
+ channels=[32, 128, 192, 256],
557
+ loss=loss,
558
+ conv1_IN=True,
559
+ **kwargs
560
+ )
561
+ if pretrained:
562
+ init_pretrained_weights(model, key="osnet_ain_x0_5")
563
+ return model
564
+
565
+
566
+ def osnet_ain_x0_25(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
567
+ model = OSNet(
568
+ num_classes,
569
+ blocks=[
570
+ [OSBlockINin, OSBlockINin],
571
+ [OSBlock, OSBlockINin],
572
+ [OSBlockINin, OSBlock],
573
+ ],
574
+ layers=[2, 2, 2],
575
+ channels=[16, 64, 96, 128],
576
+ loss=loss,
577
+ conv1_IN=True,
578
+ **kwargs
579
+ )
580
+ if pretrained:
581
+ init_pretrained_weights(model, key="osnet_ain_x0_25")
582
+ return model
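When pretrained=True, both OSNet files cache the downloaded checkpoint under a torch home directory resolved from environment variables. The standalone helper below mirrors _get_torch_home plus the checkpoints subfolder shown above, for reference only.

import os

def torch_checkpoint_dir():
    # TORCH_HOME wins; otherwise $XDG_CACHE_HOME/torch; otherwise ~/.cache/torch
    torch_home = os.path.expanduser(
        os.getenv("TORCH_HOME",
                  os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
    )
    return os.path.join(torch_home, "checkpoints")

print(torch_checkpoint_dir())  # e.g. /home/<user>/.cache/torch/checkpoints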
boxmot/appearance/backbones/resnet.py ADDED
@@ -0,0 +1,517 @@
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ """
4
+ Code source: https://github.com/pytorch/vision
5
+ """
6
+ from __future__ import absolute_import, division
7
+
8
+ import torch.utils.model_zoo as model_zoo
9
+ from torch import nn
10
+
11
+ __all__ = [
12
+ "resnet18",
13
+ "resnet34",
14
+ "resnet50",
15
+ "resnet101",
16
+ "resnet152",
17
+ "resnext50_32x4d",
18
+ "resnext101_32x8d",
19
+ "resnet50_fc512",
20
+ ]
21
+
22
+ model_urls = {
23
+ "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
24
+ "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
25
+ "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
26
+ "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
27
+ "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
28
+ "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
29
+ "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
30
+ }
31
+
32
+
33
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
34
+ """3x3 convolution with padding"""
35
+ return nn.Conv2d(
36
+ in_planes,
37
+ out_planes,
38
+ kernel_size=3,
39
+ stride=stride,
40
+ padding=dilation,
41
+ groups=groups,
42
+ bias=False,
43
+ dilation=dilation,
44
+ )
45
+
46
+
47
+ def conv1x1(in_planes, out_planes, stride=1):
48
+ """1x1 convolution"""
49
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
50
+
51
+
52
+ class BasicBlock(nn.Module):
53
+ expansion = 1
54
+
55
+ def __init__(
56
+ self,
57
+ inplanes,
58
+ planes,
59
+ stride=1,
60
+ downsample=None,
61
+ groups=1,
62
+ base_width=64,
63
+ dilation=1,
64
+ norm_layer=None,
65
+ ):
66
+ super(BasicBlock, self).__init__()
67
+ if norm_layer is None:
68
+ norm_layer = nn.BatchNorm2d
69
+ if groups != 1 or base_width != 64:
70
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
71
+ if dilation > 1:
72
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
73
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
74
+ self.conv1 = conv3x3(inplanes, planes, stride)
75
+ self.bn1 = norm_layer(planes)
76
+ self.relu = nn.ReLU(inplace=True)
77
+ self.conv2 = conv3x3(planes, planes)
78
+ self.bn2 = norm_layer(planes)
79
+ self.downsample = downsample
80
+ self.stride = stride
81
+
82
+ def forward(self, x):
83
+ identity = x
84
+
85
+ out = self.conv1(x)
86
+ out = self.bn1(out)
87
+ out = self.relu(out)
88
+
89
+ out = self.conv2(out)
90
+ out = self.bn2(out)
91
+
92
+ if self.downsample is not None:
93
+ identity = self.downsample(x)
94
+
95
+ out += identity
96
+ out = self.relu(out)
97
+
98
+ return out
99
+
100
+
101
+ class Bottleneck(nn.Module):
102
+ expansion = 4
103
+
104
+ def __init__(
105
+ self,
106
+ inplanes,
107
+ planes,
108
+ stride=1,
109
+ downsample=None,
110
+ groups=1,
111
+ base_width=64,
112
+ dilation=1,
113
+ norm_layer=None,
114
+ ):
115
+ super(Bottleneck, self).__init__()
116
+ if norm_layer is None:
117
+ norm_layer = nn.BatchNorm2d
118
+ width = int(planes * (base_width / 64.0)) * groups
119
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
120
+ self.conv1 = conv1x1(inplanes, width)
121
+ self.bn1 = norm_layer(width)
122
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
123
+ self.bn2 = norm_layer(width)
124
+ self.conv3 = conv1x1(width, planes * self.expansion)
125
+ self.bn3 = norm_layer(planes * self.expansion)
126
+ self.relu = nn.ReLU(inplace=True)
127
+ self.downsample = downsample
128
+ self.stride = stride
129
+
130
+ def forward(self, x):
131
+ identity = x
132
+
133
+ out = self.conv1(x)
134
+ out = self.bn1(out)
135
+ out = self.relu(out)
136
+
137
+ out = self.conv2(out)
138
+ out = self.bn2(out)
139
+ out = self.relu(out)
140
+
141
+ out = self.conv3(out)
142
+ out = self.bn3(out)
143
+
144
+ if self.downsample is not None:
145
+ identity = self.downsample(x)
146
+
147
+ out += identity
148
+ out = self.relu(out)
149
+
150
+ return out
151
+
152
+
153
+ class ResNet(nn.Module):
154
+ """Residual network.
155
+
156
+ Reference:
157
+ - He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
158
+ - Xie et al. Aggregated Residual Transformations for Deep Neural Networks. CVPR 2017.
159
+
160
+ Public keys:
161
+ - ``resnet18``: ResNet18.
162
+ - ``resnet34``: ResNet34.
163
+ - ``resnet50``: ResNet50.
164
+ - ``resnet101``: ResNet101.
165
+ - ``resnet152``: ResNet152.
166
+ - ``resnext50_32x4d``: ResNeXt50.
167
+ - ``resnext101_32x8d``: ResNeXt101.
168
+ - ``resnet50_fc512``: ResNet50 + FC.
169
+ """
170
+
171
+ def __init__(
172
+ self,
173
+ num_classes,
174
+ loss,
175
+ block,
176
+ layers,
177
+ zero_init_residual=False,
178
+ groups=1,
179
+ width_per_group=64,
180
+ replace_stride_with_dilation=None,
181
+ norm_layer=None,
182
+ last_stride=2,
183
+ fc_dims=None,
184
+ dropout_p=None,
185
+ **kwargs
186
+ ):
187
+ super(ResNet, self).__init__()
188
+ if norm_layer is None:
189
+ norm_layer = nn.BatchNorm2d
190
+ self._norm_layer = norm_layer
191
+ self.loss = loss
192
+ self.feature_dim = 512 * block.expansion
193
+ self.inplanes = 64
194
+ self.dilation = 1
195
+ if replace_stride_with_dilation is None:
196
+ # each element in the tuple indicates if we should replace
197
+ # the 2x2 stride with a dilated convolution instead
198
+ replace_stride_with_dilation = [False, False, False]
199
+ if len(replace_stride_with_dilation) != 3:
200
+ raise ValueError(
201
+ "replace_stride_with_dilation should be None "
202
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
203
+ )
204
+ self.groups = groups
205
+ self.base_width = width_per_group
206
+ self.conv1 = nn.Conv2d(
207
+ 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
208
+ )
209
+ self.bn1 = norm_layer(self.inplanes)
210
+ self.relu = nn.ReLU(inplace=True)
211
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
212
+ self.layer1 = self._make_layer(block, 64, layers[0])
213
+ self.layer2 = self._make_layer(
214
+ block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
215
+ )
216
+ self.layer3 = self._make_layer(
217
+ block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
218
+ )
219
+ self.layer4 = self._make_layer(
220
+ block,
221
+ 512,
222
+ layers[3],
223
+ stride=last_stride,
224
+ dilate=replace_stride_with_dilation[2],
225
+ )
226
+ self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
227
+ self.fc = self._construct_fc_layer(fc_dims, 512 * block.expansion, dropout_p)
228
+ self.classifier = nn.Linear(self.feature_dim, num_classes)
229
+
230
+ self._init_params()
231
+
232
+ # Zero-initialize the last BN in each residual branch,
233
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
234
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
235
+ if zero_init_residual:
236
+ for m in self.modules():
237
+ if isinstance(m, Bottleneck):
238
+ nn.init.constant_(m.bn3.weight, 0)
239
+ elif isinstance(m, BasicBlock):
240
+ nn.init.constant_(m.bn2.weight, 0)
241
+
242
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
243
+ norm_layer = self._norm_layer
244
+ downsample = None
245
+ previous_dilation = self.dilation
246
+ if dilate:
247
+ self.dilation *= stride
248
+ stride = 1
249
+ if stride != 1 or self.inplanes != planes * block.expansion:
250
+ downsample = nn.Sequential(
251
+ conv1x1(self.inplanes, planes * block.expansion, stride),
252
+ norm_layer(planes * block.expansion),
253
+ )
254
+
255
+ layers = []
256
+ layers.append(
257
+ block(
258
+ self.inplanes,
259
+ planes,
260
+ stride,
261
+ downsample,
262
+ self.groups,
263
+ self.base_width,
264
+ previous_dilation,
265
+ norm_layer,
266
+ )
267
+ )
268
+ self.inplanes = planes * block.expansion
269
+ for _ in range(1, blocks):
270
+ layers.append(
271
+ block(
272
+ self.inplanes,
273
+ planes,
274
+ groups=self.groups,
275
+ base_width=self.base_width,
276
+ dilation=self.dilation,
277
+ norm_layer=norm_layer,
278
+ )
279
+ )
280
+
281
+ return nn.Sequential(*layers)
282
+
283
+ def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
284
+ """Constructs fully connected layer
285
+
286
+ Args:
287
+ fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
288
+ input_dim (int): input dimension
289
+ dropout_p (float): dropout probability, if None, dropout is unused
290
+ """
291
+ if fc_dims is None:
292
+ self.feature_dim = input_dim
293
+ return None
294
+
295
+ assert isinstance(
296
+ fc_dims, (list, tuple)
297
+ ), "fc_dims must be either list or tuple, but got {}".format(type(fc_dims))
298
+
299
+ layers = []
300
+ for dim in fc_dims:
301
+ layers.append(nn.Linear(input_dim, dim))
302
+ layers.append(nn.BatchNorm1d(dim))
303
+ layers.append(nn.ReLU(inplace=True))
304
+ if dropout_p is not None:
305
+ layers.append(nn.Dropout(p=dropout_p))
306
+ input_dim = dim
307
+
308
+ self.feature_dim = fc_dims[-1]
309
+
310
+ return nn.Sequential(*layers)
311
+
312
+ def _init_params(self):
313
+ for m in self.modules():
314
+ if isinstance(m, nn.Conv2d):
315
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
316
+ if m.bias is not None:
317
+ nn.init.constant_(m.bias, 0)
318
+ elif isinstance(m, nn.BatchNorm2d):
319
+ nn.init.constant_(m.weight, 1)
320
+ nn.init.constant_(m.bias, 0)
321
+ elif isinstance(m, nn.BatchNorm1d):
322
+ nn.init.constant_(m.weight, 1)
323
+ nn.init.constant_(m.bias, 0)
324
+ elif isinstance(m, nn.Linear):
325
+ nn.init.normal_(m.weight, 0, 0.01)
326
+ if m.bias is not None:
327
+ nn.init.constant_(m.bias, 0)
328
+
329
+ def featuremaps(self, x):
330
+ x = self.conv1(x)
331
+ x = self.bn1(x)
332
+ x = self.relu(x)
333
+ x = self.maxpool(x)
334
+ x = self.layer1(x)
335
+ x = self.layer2(x)
336
+ x = self.layer3(x)
337
+ x = self.layer4(x)
338
+ return x
339
+
340
+ def forward(self, x):
341
+ f = self.featuremaps(x)
342
+ v = self.global_avgpool(f)
343
+ v = v.view(v.size(0), -1)
344
+
345
+ if self.fc is not None:
346
+ v = self.fc(v)
347
+
348
+ if not self.training:
349
+ return v
350
+
351
+ y = self.classifier(v)
352
+
353
+ if self.loss == "softmax":
354
+ return y
355
+ elif self.loss == "triplet":
356
+ return y, v
357
+ else:
358
+ raise KeyError("Unsupported loss: {}".format(self.loss))
359
+
360
+
361
+ def init_pretrained_weights(model, model_url):
362
+ """Initializes model with pretrained weights.
363
+
364
+ Layers that don't match with pretrained layers in name or size are kept unchanged.
365
+ """
366
+ pretrain_dict = model_zoo.load_url(model_url)
367
+ model_dict = model.state_dict()
368
+ pretrain_dict = {
369
+ k: v
370
+ for k, v in pretrain_dict.items()
371
+ if k in model_dict and model_dict[k].size() == v.size()
372
+ }
373
+ model_dict.update(pretrain_dict)
374
+ model.load_state_dict(model_dict)
375
+
376
+
377
+ """ResNet"""
378
+
379
+
380
+ def resnet18(num_classes, loss="softmax", pretrained=True, **kwargs):
381
+ model = ResNet(
382
+ num_classes=num_classes,
383
+ loss=loss,
384
+ block=BasicBlock,
385
+ layers=[2, 2, 2, 2],
386
+ last_stride=2,
387
+ fc_dims=None,
388
+ dropout_p=None,
389
+ **kwargs
390
+ )
391
+ if pretrained:
392
+ init_pretrained_weights(model, model_urls["resnet18"])
393
+ return model
394
+
395
+
396
+ def resnet34(num_classes, loss="softmax", pretrained=True, **kwargs):
397
+ model = ResNet(
398
+ num_classes=num_classes,
399
+ loss=loss,
400
+ block=BasicBlock,
401
+ layers=[3, 4, 6, 3],
402
+ last_stride=2,
403
+ fc_dims=None,
404
+ dropout_p=None,
405
+ **kwargs
406
+ )
407
+ if pretrained:
408
+ init_pretrained_weights(model, model_urls["resnet34"])
409
+ return model
410
+
411
+
412
+ def resnet50(num_classes, loss="softmax", pretrained=True, **kwargs):
413
+ model = ResNet(
414
+ num_classes=num_classes,
415
+ loss=loss,
416
+ block=Bottleneck,
417
+ layers=[3, 4, 6, 3],
418
+ last_stride=2,
419
+ fc_dims=None,
420
+ dropout_p=None,
421
+ **kwargs
422
+ )
423
+ if pretrained:
424
+ init_pretrained_weights(model, model_urls["resnet50"])
425
+ return model
426
+
427
+
428
+ def resnet101(num_classes, loss="softmax", pretrained=True, **kwargs):
429
+ model = ResNet(
430
+ num_classes=num_classes,
431
+ loss=loss,
432
+ block=Bottleneck,
433
+ layers=[3, 4, 23, 3],
434
+ last_stride=2,
435
+ fc_dims=None,
436
+ dropout_p=None,
437
+ **kwargs
438
+ )
439
+ if pretrained:
440
+ init_pretrained_weights(model, model_urls["resnet101"])
441
+ return model
442
+
443
+
444
+ def resnet152(num_classes, loss="softmax", pretrained=True, **kwargs):
445
+ model = ResNet(
446
+ num_classes=num_classes,
447
+ loss=loss,
448
+ block=Bottleneck,
449
+ layers=[3, 8, 36, 3],
450
+ last_stride=2,
451
+ fc_dims=None,
452
+ dropout_p=None,
453
+ **kwargs
454
+ )
455
+ if pretrained:
456
+ init_pretrained_weights(model, model_urls["resnet152"])
457
+ return model
458
+
459
+
460
+ """ResNeXt"""
461
+
462
+
463
+ def resnext50_32x4d(num_classes, loss="softmax", pretrained=True, **kwargs):
464
+ model = ResNet(
465
+ num_classes=num_classes,
466
+ loss=loss,
467
+ block=Bottleneck,
468
+ layers=[3, 4, 6, 3],
469
+ last_stride=2,
470
+ fc_dims=None,
471
+ dropout_p=None,
472
+ groups=32,
473
+ width_per_group=4,
474
+ **kwargs
475
+ )
476
+ if pretrained:
477
+ init_pretrained_weights(model, model_urls["resnext50_32x4d"])
478
+ return model
479
+
480
+
481
+ def resnext101_32x8d(num_classes, loss="softmax", pretrained=True, **kwargs):
482
+ model = ResNet(
483
+ num_classes=num_classes,
484
+ loss=loss,
485
+ block=Bottleneck,
486
+ layers=[3, 4, 23, 3],
487
+ last_stride=2,
488
+ fc_dims=None,
489
+ dropout_p=None,
490
+ groups=32,
491
+ width_per_group=8,
492
+ **kwargs
493
+ )
494
+ if pretrained:
495
+ init_pretrained_weights(model, model_urls["resnext101_32x8d"])
496
+ return model
497
+
498
+
499
+ """
500
+ ResNet + FC
501
+ """
502
+
503
+
504
+ def resnet50_fc512(num_classes, loss="softmax", pretrained=True, **kwargs):
505
+ model = ResNet(
506
+ num_classes=num_classes,
507
+ loss=loss,
508
+ block=Bottleneck,
509
+ layers=[3, 4, 6, 3],
510
+ last_stride=1,
511
+ fc_dims=[512],
512
+ dropout_p=None,
513
+ **kwargs
514
+ )
515
+ if pretrained:
516
+ init_pretrained_weights(model, model_urls["resnet50"])
517
+ return model
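Usage sketch for the constructors above (an illustration only, assuming the boxmot package is importable; the 751-class head matches the Market-1501 entry in NR_CLASSES_DICT further down):

import torch
from boxmot.appearance.backbones.resnet import resnet50

# Build the ReID backbone without fetching ImageNet weights.
model = resnet50(num_classes=751, loss="softmax", pretrained=False).eval()

# In eval mode forward() returns pooled feature vectors rather than classifier logits.
dummy = torch.randn(4, 3, 256, 128)  # (batch, channels, height, width) ReID crop size
with torch.no_grad():
    feats = model(dummy)
print(feats.shape)  # torch.Size([4, 2048]): 512 * Bottleneck.expansion, since fc_dims is None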
boxmot/appearance/backends/base_backend.py ADDED
@@ -0,0 +1,135 @@
1
+ import cv2
2
+ import torch
3
+ import gdown
4
+ import numpy as np
5
+ from abc import ABC, abstractmethod
6
+ from boxmot.utils import logger as LOGGER
7
+ from boxmot.appearance.reid.registry import ReIDModelRegistry
8
+ from boxmot.utils.checks import RequirementsChecker
9
+
10
+
11
+ class BaseModelBackend:
12
+ def __init__(self, weights, device, half):
13
+ self.weights = weights[0] if isinstance(weights, list) else weights
14
+ self.device = device
15
+ self.half = half
16
+ self.model = None
17
+ self.cuda = torch.cuda.is_available() and self.device.type != "cpu"
18
+
19
+ self.download_model(self.weights)
20
+ self.model_name = ReIDModelRegistry.get_model_name(self.weights)
21
+
22
+ self.model = ReIDModelRegistry.build_model(
23
+ self.model_name,
24
+ num_classes=ReIDModelRegistry.get_nr_classes(self.weights),
25
+ pretrained=not (self.weights and self.weights.is_file()),
26
+ use_gpu=device,
27
+ )
28
+ self.checker = RequirementsChecker()
29
+ self.load_model(self.weights)
30
+
31
+
32
+ def get_crops(self, xyxys, img):
33
+ h, w = img.shape[:2]
34
+ resize_dims = (128, 256)
35
+ interpolation_method = cv2.INTER_LINEAR
36
+ mean_array = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
37
+ std_array = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
38
+
39
+ # Preallocate tensor for crops
40
+ num_crops = len(xyxys)
41
+ crops = torch.empty((num_crops, 3, resize_dims[1], resize_dims[0]),
42
+ dtype=torch.half if self.half else torch.float, device=self.device)
43
+
44
+ for i, box in enumerate(xyxys):
45
+ x1, y1, x2, y2 = box.round().astype('int')
46
+ x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w, x2), min(h, y2)
47
+ crop = img[y1:y2, x1:x2]
48
+
49
+ # Resize and convert color in one step
50
+ crop = cv2.resize(crop, resize_dims, interpolation=interpolation_method)
51
+ crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
52
+
53
+ # Convert to tensor and normalize (convert to [0, 1] by dividing by 255 in batch later)
54
+ crop = torch.from_numpy(crop).to(self.device, dtype=torch.half if self.half else torch.float)
55
+ crops[i] = torch.permute(crop, (2, 0, 1)) # Change to (C, H, W)
56
+
57
+ # Normalize the entire batch in one go
58
+ crops = crops / 255.0
59
+
60
+ # Standardize the batch
61
+ crops = (crops - mean_array) / std_array
62
+
63
+ return crops
64
+
65
+
66
+ @torch.no_grad()
67
+ def get_features(self, xyxys, img):
68
+ if xyxys.size != 0:
69
+ crops = self.get_crops(xyxys, img)
70
+ crops = self.inference_preprocess(crops)
71
+ features = self.forward(crops)
72
+ features = self.inference_postprocess(features)
73
+ else:
74
+ features = np.array([])
75
+ features = features / np.linalg.norm(features, axis=-1, keepdims=True)
76
+ return features
77
+
78
+ def warmup(self, imgsz=[(256, 128, 3)]):
79
+ # warmup model by running inference once
80
+ if self.device.type != "cpu":
81
+ im = np.random.randint(0, 255, *imgsz, dtype=np.uint8)
82
+ crops = self.get_crops(xyxys=np.array(
83
+ [[0, 0, 64, 64], [0, 0, 128, 128]]),
84
+ img=im
85
+ )
86
+ crops = self.inference_preprocess(crops)
87
+ self.forward(crops) # warmup
88
+
89
+ def to_numpy(self, x):
90
+ return x.cpu().numpy() if isinstance(x, torch.Tensor) else x
91
+
92
+ def inference_preprocess(self, x):
93
+ if self.half:
94
+ if isinstance(x, torch.Tensor):
95
+ if x.dtype != torch.float16:
96
+ x = x.half()
97
+ elif isinstance(x, np.ndarray):
98
+ if x.dtype != np.float16:
99
+ x = x.astype(np.float16)
100
+
101
+ if self.nhwc:
102
+ if isinstance(x, torch.Tensor):
103
+ x = x.permute(0, 2, 3, 1) # Convert from NCHW to NHWC
104
+ elif isinstance(x, np.ndarray):
105
+ x = np.transpose(x, (0, 2, 3, 1)) # Convert from NCHW to NHWC
106
+ return x
107
+
108
+ def inference_postprocess(self, features):
109
+ if isinstance(features, (list, tuple)):
110
+ return (
111
+ self.to_numpy(features[0]) if len(features) == 1 else [self.to_numpy(x) for x in features]
112
+ )
113
+ else:
114
+ return self.to_numpy(features)
115
+
116
+ @abstractmethod
117
+ def forward(self, im_batch):
118
+ raise NotImplementedError("This method should be implemented by subclasses.")
119
+
120
+ @abstractmethod
121
+ def load_model(self, w):
122
+ raise NotImplementedError("This method should be implemented by subclasses.")
123
+
124
+
125
+ def download_model(self, w):
126
+ if w.suffix == ".pt":
127
+ model_url = ReIDModelRegistry.get_model_url(w)
128
+ if not w.exists() and model_url is not None:
129
+ gdown.download(model_url, str(w), quiet=False)
130
+ elif not w.exists():
131
+ LOGGER.error(
132
+ f"No URL associated with the chosen StrongSORT weights ({w}). Choose between:"
133
+ )
134
+ ReIDModelRegistry.show_downloadable_models()
135
+ exit()
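For reference, the per-crop preprocessing done in get_crops above is equivalent to this standalone sketch (a restatement for clarity, not code from the repository):

import cv2
import numpy as np
import torch

def preprocess_crop(img_bgr: np.ndarray, box: np.ndarray) -> torch.Tensor:
    # Clip the box to the frame, crop, resize to 128x256 (w x h) and convert BGR -> RGB.
    h, w = img_bgr.shape[:2]
    x1, y1, x2, y2 = box.round().astype(int)
    x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w, x2), min(h, y2)
    crop = cv2.resize(img_bgr[y1:y2, x1:x2], (128, 256), interpolation=cv2.INTER_LINEAR)
    crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
    # Scale to [0, 1], move channels first, then standardize with ImageNet statistics.
    t = torch.from_numpy(crop).float().permute(2, 0, 1) / 255.0
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (t - mean) / std

crop = preprocess_crop(np.zeros((480, 640, 3), dtype=np.uint8), np.array([10.0, 20.0, 110.0, 220.0]))
print(crop.shape)  # torch.Size([3, 256, 128])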
boxmot/appearance/backends/onnx_backend.py ADDED
@@ -0,0 +1,42 @@
1
+ import numpy as np
2
+ from pathlib import Path
3
+
4
+ from boxmot.appearance.backends.base_backend import BaseModelBackend
5
+
6
+
7
+ class ONNXBackend(BaseModelBackend):
8
+
9
+ def __init__(self, weights, device, half):
10
+ super().__init__(weights, device, half)
11
+ self.nhwc = False
12
+ self.half = half
13
+
14
+ def load_model(self, w):
15
+
16
+ # ONNXRuntime will attempt to use the first provider, and if it fails or is not
17
+ # available for some reason, it will fall back to the next provider in the list
18
+ if self.device == "mps":
19
+ self.checker.check_packages(("onnxruntime-silicon==1.17.0",))
20
+ providers = ["MPSExecutionProvider", "CPUExecutionProvider"]
21
+ elif self.device == "cuda":
22
+ self.checker.check_packages(("onnxruntime-gpu==1.17.0",))
23
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
24
+ else:
25
+ self.checker.check_packages(("onnxruntime==1.17.0",))
26
+ providers = ["CPUExecutionProvider"]
27
+
28
+ # Load the ONNX model using onnxruntime
29
+ import onnxruntime
30
+ self.session = onnxruntime.InferenceSession(str(w), providers=providers)
31
+
32
+ def forward(self, im_batch):
33
+ # Convert torch tensor to numpy (onnxruntime expects numpy arrays)
34
+ im_batch = im_batch.cpu().numpy()
35
+
36
+ # Run inference using ONNX session
37
+ features = self.session.run(
38
+ [self.session.get_outputs()[0].name],
39
+ {self.session.get_inputs()[0].name: im_batch},
40
+ )[0]
41
+
42
+ return features
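The same call pattern can be exercised with onnxruntime directly (the model path is a placeholder; any ReID model exported by the ONNXExporter further down, with input "images" and output "output", fits):

import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("osnet_x0_25_msmt17.onnx", providers=["CPUExecutionProvider"])
batch = np.random.rand(2, 3, 256, 128).astype(np.float32)  # NCHW, already-normalized crops
features = session.run(
    [session.get_outputs()[0].name],
    {session.get_inputs()[0].name: batch},
)[0]
print(features.shape)  # e.g. (2, 512) for an osnet_x0_25 export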
boxmot/appearance/backends/openvino_backend.py ADDED
@@ -0,0 +1,44 @@
1
+ import numpy as np
2
+ from pathlib import Path
3
+ from boxmot.utils import logger as LOGGER
4
+
5
+ from boxmot.appearance.backends.base_backend import BaseModelBackend
6
+
7
+
8
+ class OpenVinoBackend(BaseModelBackend):
9
+
10
+ def __init__(self, weights, device, half):
11
+ super().__init__(weights, device, half)
12
+ self.nhwc = False
13
+ self.half = half
14
+
15
+ def load_model(self, w):
16
+ self.checker.check_packages(("openvino-dev>=2022.3",))
17
+
18
+ LOGGER.info(f"Loading {w} for OpenVINO inference...")
19
+ try:
20
+ # requires openvino-dev: https://pypi.org/project/openvino-dev/
21
+ from openvino.runtime import Core, Layout
22
+ except ImportError:
23
+ LOGGER.error(
24
+ f"Running {self.__class__} with the specified OpenVINO weights\n{w.name}\n"
25
+ "requires openvino pip package to be installed!\n"
26
+ "$ pip install openvino-dev>=2022.3\n"
27
+ )
28
+ ie = Core()
29
+ if not Path(w).is_file(): # if not *.xml
30
+ w = next(
31
+ Path(w).glob("*.xml")
32
+ ) # get *.xml file from *_openvino_model dir
33
+ network = ie.read_model(model=w, weights=Path(w).with_suffix(".bin"))
34
+ if network.get_parameters()[0].get_layout().empty:
35
+ network.get_parameters()[0].set_layout(Layout("NCWH"))
36
+ self.executable_network = ie.compile_model(
37
+ network, device_name="CPU"
38
+ ) # device_name="MYRIAD" for Intel NCS2
39
+ self.output_layer = next(iter(self.executable_network.outputs))
40
+
41
+ def forward(self, im_batch):
42
+ im_batch = im_batch.cpu().numpy() # FP32
43
+ features = self.executable_network([im_batch])[self.output_layer]
44
+ return features
boxmot/appearance/backends/pytorch_backend.py ADDED
@@ -0,0 +1,24 @@
1
+ import numpy as np
2
+ from pathlib import Path
3
+
4
+ from boxmot.appearance.backends.base_backend import BaseModelBackend
5
+ from boxmot.appearance.reid.registry import ReIDModelRegistry
6
+
7
+
8
+ class PyTorchBackend(BaseModelBackend):
9
+
10
+ def __init__(self, weights, device, half):
11
+ super().__init__(weights, device, half)
12
+ self.nhwc = False
13
+ self.half = half
14
+
15
+ def load_model(self, w):
16
+ # Load a PyTorch model
17
+ if w and w.is_file():
18
+ ReIDModelRegistry.load_pretrained_weights(self.model, w)
19
+ self.model.to(self.device).eval()
20
+ self.model.half() if self.half else self.model.float()
21
+
22
+ def forward(self, im_batch):
23
+ features = self.model(im_batch)
24
+ return features
boxmot/appearance/backends/tensorrt_backend.py ADDED
@@ -0,0 +1,126 @@
1
+ import torch
2
+ import numpy as np
3
+ from pathlib import Path
4
+ from collections import OrderedDict, namedtuple
5
+ from boxmot.utils import logger as LOGGER
6
+ from boxmot.appearance.backends.base_backend import BaseModelBackend
7
+
8
+ class TensorRTBackend(BaseModelBackend):
9
+ def __init__(self, weights, device, half):
10
+ self.is_trt10 = False
11
+ super().__init__(weights, device, half)
12
+ self.nhwc = False
13
+ self.half = half
14
+ self.device = device
15
+ self.weights = weights
16
+ self.fp16 = False # Will be updated in load_model
17
+ self.load_model(self.weights)
18
+
19
+ def load_model(self, w):
20
+ LOGGER.info(f"Loading {w} for TensorRT inference...")
21
+ self.checker.check_packages(("nvidia-tensorrt",))
22
+ try:
23
+ import tensorrt as trt # TensorRT library
24
+ except ImportError:
25
+ raise ImportError("Please install tensorrt to use this backend.")
26
+
27
+ if self.device.type == "cpu":
28
+ if torch.cuda.is_available():
29
+ self.device = torch.device("cuda:0")
30
+ else:
31
+ raise ValueError("CUDA device not available for TensorRT inference.")
32
+
33
+ Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
34
+ logger = trt.Logger(trt.Logger.INFO)
35
+
36
+ # Deserialize the engine
37
+ with open(w, "rb") as f, trt.Runtime(logger) as runtime:
38
+ self.model_ = runtime.deserialize_cuda_engine(f.read())
39
+
40
+ # Execution context
41
+ self.context = self.model_.create_execution_context()
42
+ self.bindings = OrderedDict()
43
+
44
+ self.is_trt10 = not hasattr(self.model_, "num_bindings")
45
+ num = range(self.model_.num_io_tensors) if self.is_trt10 else range(self.model_.num_bindings)
46
+
47
+ # Parse bindings
48
+ for index in num:
49
+ if self.is_trt10:
50
+ name = self.model_.get_tensor_name(index)
51
+ dtype = trt.nptype(self.model_.get_tensor_dtype(name))
52
+ is_input = self.model_.get_tensor_mode(name) == trt.TensorIOMode.INPUT
53
+ if is_input and -1 in tuple(self.model_.get_tensor_shape(name)):
54
+ self.context.set_input_shape(name, tuple(self.model_.get_tensor_profile_shape(name, 0)[1]))
55
+ if is_input and dtype == np.float16:
56
+ self.fp16 = True
57
+
58
+ shape = tuple(self.context.get_tensor_shape(name))
59
+
60
+ else:
61
+ name = self.model_.get_binding_name(index)
62
+ dtype = trt.nptype(self.model_.get_binding_dtype(index))
63
+ is_input = self.model_.binding_is_input(index)
64
+
65
+ # Handle dynamic shapes
66
+ if is_input and -1 in self.model_.get_binding_shape(index):
67
+ profile_index = 0
68
+ min_shape, opt_shape, max_shape = self.model_.get_profile_shape(profile_index, index)
69
+ self.context.set_binding_shape(index, opt_shape)
70
+
71
+ if is_input and dtype == np.float16:
72
+ self.fp16 = True
73
+
74
+ shape = tuple(self.context.get_binding_shape(index))
75
+ data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(self.device)
76
+ self.bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
77
+
78
+ self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items())
79
+
80
+ def forward(self, im_batch):
81
+ temp_im_batch = im_batch.clone()
82
+ batch_array = []
83
+ inp_batch = im_batch.shape[0]
84
+ out_batch = self.bindings["output"].shape[0]
85
+ resultant_features = []
86
+
87
+ # Divide batch to sub batches
88
+ while inp_batch > out_batch:
89
+ batch_array.append(temp_im_batch[:out_batch])
90
+ temp_im_batch = temp_im_batch[out_batch:]
91
+ inp_batch = temp_im_batch.shape[0]
92
+ if temp_im_batch.shape[0] > 0:
93
+ batch_array.append(temp_im_batch)
94
+
95
+ for temp_batch in batch_array:
96
+ # Adjust for dynamic shapes
97
+ if temp_batch.shape != self.bindings["images"].shape:
98
+ if self.is_trt10:
99
+
100
+ self.context.set_input_shape("images", temp_batch.shape)
101
+ self.bindings["images"] = self.bindings["images"]._replace(shape=temp_batch.shape)
102
+ self.bindings["output"].data.resize_(tuple(self.context.get_tensor_shape("output")))
103
+ else:
104
+ i_in = self.model_.get_binding_index("images")
105
+ i_out = self.model_.get_binding_index("output")
106
+ self.context.set_binding_shape(i_in, temp_batch.shape)
107
+ self.bindings["images"] = self.bindings["images"]._replace(shape=temp_batch.shape)
108
+ output_shape = tuple(self.context.get_binding_shape(i_out))
109
+ self.bindings["output"].data.resize_(output_shape)
110
+
111
+ s = self.bindings["images"].shape
112
+ assert temp_batch.shape == s, f"Input size {temp_batch.shape} does not match model size {s}"
113
+
114
+ self.binding_addrs["images"] = int(temp_batch.data_ptr())
115
+
116
+ # Execute inference
117
+ self.context.execute_v2(list(self.binding_addrs.values()))
118
+ features = self.bindings["output"].data
119
+ resultant_features.append(features.clone())
120
+
121
+ if len(resultant_features) == 1:
122
+ return resultant_features[0]
123
+ else:
124
+ rslt_features = torch.cat(resultant_features, dim=0)
125
+ rslt_features = rslt_features[:im_batch.shape[0]]
126
+ return rslt_features
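The sub-batching in forward() can be shown in isolation (a simplified sketch with a stand-in callable instead of a TensorRT engine; the real code additionally resizes the engine bindings for the smaller last chunk):

import torch

def run_in_sub_batches(model, im_batch: torch.Tensor, engine_batch: int) -> torch.Tensor:
    # Split the input into engine-sized chunks, run each, then re-join and trim the outputs.
    outputs, remaining = [], im_batch
    while remaining.shape[0] > engine_batch:
        outputs.append(model(remaining[:engine_batch]))
        remaining = remaining[engine_batch:]
    if remaining.shape[0] > 0:
        outputs.append(model(remaining))
    return torch.cat(outputs, dim=0)[: im_batch.shape[0]]

fake_engine = lambda x: torch.zeros(x.shape[0], 512)  # stand-in producing 512-d features
feats = run_in_sub_batches(fake_engine, torch.randn(10, 3, 256, 128), engine_batch=4)
print(feats.shape)  # torch.Size([10, 512])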
boxmot/appearance/backends/tflite_backend.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch
2
+ import numpy as np
3
+ from pathlib import Path
4
+ from boxmot.utils import logger as LOGGER
5
+
6
+ from boxmot.appearance.backends.base_backend import BaseModelBackend
7
+
8
+
9
+ class TFLiteBackend(BaseModelBackend):
10
+ """
11
+ A class to handle TensorFlow Lite model inference with dynamic batch size support.
12
+
13
+ Attributes:
14
+ nhwc (bool): A flag indicating the order of dimensions.
15
+ half (bool): A flag to indicate if half precision is used.
16
+ interpreter (tf.lite.Interpreter): The TensorFlow Lite interpreter.
17
+ current_allocated_batch_size (int): The current batch size allocated in the interpreter.
18
+ """
19
+
20
+ def __init__(self, weights: Path, device: str, half: bool):
21
+ """
22
+ Initializes the TFLiteBackend with given weights, device, and precision flag.
23
+
24
+ Args:
25
+ weights (Path): Path to the TFLite model file.
26
+ device (str): Device type (e.g., 'cpu', 'gpu').
27
+ half (bool): Flag to indicate if half precision is used.
28
+ """
29
+ super().__init__(weights, device, half)
30
+ self.nhwc = True
31
+ self.half = False
32
+ # self.interpreter: tf.lite.Interpreter = None
33
+ # self.current_allocated_batch_size: int = None
34
+
35
+ def load_model(self, w):
36
+ """
37
+ Loads the TensorFlow Lite model and initializes the interpreter.
38
+
39
+ Args:
40
+ w (str): Path to the TFLite model file.
41
+ """
42
+ self.checker.check_packages(("tensorflow",))
43
+
44
+ LOGGER.info(f"Loading {str(w)} for TensorFlow Lite inference...")
45
+
46
+ import tensorflow as tf
47
+ self.interpreter = tf.lite.Interpreter(model_path=str(w))
48
+
49
+
50
+ self.interpreter.allocate_tensors() # allocate
51
+ self.input_details = self.interpreter.get_input_details() # inputs
52
+ self.output_details = self.interpreter.get_output_details() # outputs
53
+ self.current_allocated_batch_size = self.input_details[0]['shape'][0]
54
+
55
+ def forward(self, im_batch: torch.Tensor) -> np.ndarray:
56
+ """
57
+ Runs forward pass for the given image batch through the TFLite model.
58
+
59
+ Args:
60
+ im_batch (torch.Tensor): Input image batch tensor.
61
+
62
+ Returns:
63
+ np.ndarray: Output features from the TFLite model.
64
+ """
65
+ im_batch = im_batch.cpu().numpy()
66
+
67
+ # Extract batch size from im_batch
68
+ batch_size = im_batch.shape[0]
69
+
70
+ # Resize tensors if the new batch size is different from the current allocated batch size
71
+ if batch_size != self.current_allocated_batch_size:
72
+ # print(f"Resizing tensor input to batch size {batch_size}")
73
+ self.interpreter.resize_tensor_input(self.input_details[0]['index'], [batch_size, 256, 128, 3])
74
+ self.interpreter.allocate_tensors()
75
+ self.current_allocated_batch_size = batch_size
76
+
77
+ # Set the tensor to point to the input data
78
+ self.interpreter.set_tensor(self.input_details[0]['index'], im_batch)
79
+
80
+ # Run inference
81
+ self.interpreter.invoke()
82
+
83
+ # Get the output data
84
+ features = self.interpreter.get_tensor(self.output_details[0]['index'])
85
+
86
+ return features
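The interpreter calls above are the stock tf.lite API; a standalone sketch (the model path is a placeholder for a TFLite file produced by the TFLiteExporter further down):

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="osnet_x0_25_msmt17.tflite")  # placeholder path
interpreter.allocate_tensors()
inp, out = interpreter.get_input_details()[0], interpreter.get_output_details()[0]

batch = np.random.rand(3, 256, 128, 3).astype(np.float32)  # NHWC: channels last, unlike the other backends
interpreter.resize_tensor_input(inp["index"], list(batch.shape))
interpreter.allocate_tensors()
interpreter.set_tensor(inp["index"], batch)
interpreter.invoke()
features = interpreter.get_tensor(out["index"])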
boxmot/appearance/backends/torchscript_backend.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch
2
+ import numpy as np
3
+ from pathlib import Path
4
+ from boxmot.utils import logger as LOGGER
5
+
6
+ from boxmot.appearance.backends.base_backend import BaseModelBackend
7
+
8
+
9
+ class TorchscriptBackend(BaseModelBackend):
10
+
11
+ def __init__(self, weights, device, half):
12
+ super().__init__(weights, device, half)
13
+ self.nhwc = False
14
+ self.half = half
15
+
16
+ def load_model(self, w):
17
+
18
+ LOGGER.info(f"Loading {w} for TorchScript inference...")
19
+ self.model = torch.jit.load(w)
20
+ self.model.half() if self.half else self.model.float()
21
+
22
+ def forward(self, im_batch):
23
+ features = self.model(im_batch)
24
+ return features
boxmot/appearance/exporters/base_exporter.py ADDED
@@ -0,0 +1,56 @@
1
+ import logging
2
+ import torch
3
+ from pathlib import Path
4
+ from boxmot.utils.checks import RequirementsChecker
5
+ from boxmot.utils import logger as LOGGER
6
+
7
+
8
+ def export_decorator(export_func):
9
+ def wrapper(self, *args, **kwargs):
10
+ try:
11
+ if hasattr(self, 'required_packages'):
12
+ if hasattr(self, 'cmd'):
13
+ self.checker.check_packages(self.required_packages, cmd=self.cmd)
14
+ else:
15
+ self.checker.check_packages(self.required_packages)
16
+
17
+ LOGGER.info(f"\nStarting {self.file} export with {self.__class__.__name__}...")
18
+ result = export_func(self, *args, **kwargs)
19
+ if result:
20
+ LOGGER.info(f"Export success, saved as {result} ({self.file_size(result):.1f} MB)")
21
+ return result
22
+ except Exception as e:
23
+ LOGGER.error(f"Export failure: {e}")
24
+ return None
25
+ return wrapper
26
+
27
+
28
+ class BaseExporter:
29
+ def __init__(self, model, im, file, optimize=False, dynamic=False, half=False, simplify=False):
30
+ self.model = model
31
+ self.im = im
32
+ self.file = Path(file)
33
+ self.optimize = optimize
34
+ self.dynamic = dynamic
35
+ self.half = half
36
+ self.simplify = simplify
37
+ self.checker = RequirementsChecker()
38
+ self.workspace = 4
39
+
40
+ @staticmethod
41
+ def file_size(path):
42
+ path = Path(path)
43
+ if path.is_file():
44
+ return path.stat().st_size / 1e6
45
+ elif path.is_dir():
46
+ return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / 1e6
47
+ else:
48
+ return 0.0
49
+
50
+ def export(self):
51
+ raise NotImplementedError("Export method must be implemented in subclasses.")
52
+
53
+ def __init_subclass__(cls, **kwargs):
54
+ super().__init_subclass__(**kwargs)
55
+ if 'export' in cls.__dict__:
56
+ cls.export = export_decorator(cls.export)
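A toy subclass showing how __init_subclass__ wires the logging and error handling around export() (the .noop suffix is made up for this sketch; assumes boxmot is installed):

from pathlib import Path
import torch
from boxmot.appearance.exporters.base_exporter import BaseExporter

class NoOpExporter(BaseExporter):
    # export() is wrapped automatically by export_decorator when the subclass is defined.
    def export(self):
        f = self.file.with_suffix(".noop")
        f.write_bytes(b"placeholder")
        return f

exporter = NoOpExporter(torch.nn.Identity(), torch.zeros(1, 3, 256, 128), Path("dummy_model.pt"))
exporter.export()  # logs the start message and "Export success, saved as dummy_model.noop (...)"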
boxmot/appearance/exporters/onnx_exporter.py ADDED
@@ -0,0 +1,56 @@
1
+ import torch
2
+ import onnx
3
+ from boxmot.appearance.exporters.base_exporter import BaseExporter
4
+ from boxmot.utils import logger as LOGGER
5
+
6
+
7
+ class ONNXExporter(BaseExporter):
8
+ required_packages = ("onnx>=1.16.1",)
9
+
10
+ def export(self):
11
+
12
+ f = self.file.with_suffix(".onnx")
13
+
14
+ dynamic = {"images": {0: "batch"}, "output": {0: "batch"}} if self.dynamic else None
15
+
16
+ torch.onnx.export(
17
+ self.model.cpu() if self.dynamic else self.model,
18
+ self.im.cpu() if self.dynamic else self.im,
19
+ f,
20
+ verbose=False,
21
+ opset_version=12,
22
+ do_constant_folding=True,
23
+ input_names=["images"],
24
+ output_names=["output"],
25
+ dynamic_axes=dynamic,
26
+ )
27
+
28
+ model_onnx = onnx.load(f)
29
+ onnx.checker.check_model(model_onnx)
30
+ onnx.save(model_onnx, f)
31
+
32
+ if self.simplify:
33
+ self.simplify_model(model_onnx, f)
34
+
35
+ return f
36
+
37
+
38
+ def simplify_model(self, model_onnx, f):
39
+ try:
40
+ cuda = torch.cuda.is_available()
41
+ self.checker.check_packages(
42
+ (
43
+ "onnxruntime-gpu" if cuda else "onnxruntime",
44
+ "onnx-simplifier>=0.4.1",
45
+ )
46
+ )
47
+ import onnxsim
48
+
49
+ LOGGER.info(
50
+ f"Simplifying with onnx-simplifier {onnxsim.__version__}..."
51
+ )
52
+ model_onnx, check = onnxsim.simplify(model_onnx)
53
+ assert check, "assert check failed"
54
+ onnx.save(model_onnx, f)
55
+ except Exception as e:
56
+ LOGGER.error(f"Simplifier failure: {e}")
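Exporting a stand-in module through the class above (a sketch; in the repo the model and dummy input come from the export script further down, and the onnx package must be installed):

import torch
from pathlib import Path
from boxmot.appearance.exporters.onnx_exporter import ONNXExporter

# Tiny stand-in network producing an (N, 8) embedding-like output.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
).eval()
im = torch.zeros(1, 3, 256, 128)

exporter = ONNXExporter(model, im, Path("toy_reid.pt"), dynamic=True)
onnx_path = exporter.export()  # writes toy_reid.onnx with a dynamic batch axis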
boxmot/appearance/exporters/openvino_exporter.py ADDED
@@ -0,0 +1,26 @@
1
+ import os
2
+ from pathlib import Path
3
+ import openvino.runtime as ov
4
+ from openvino.tools import mo
5
+ from boxmot.appearance.exporters.base_exporter import BaseExporter
6
+ from boxmot.utils import logger as LOGGER
7
+
8
+
9
+ class OpenVINOExporter(BaseExporter):
10
+ required_packages = ("openvino-dev>=2023.0",)
11
+
12
+ def export(self):
13
+
14
+ f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
15
+ f_onnx = self.file.with_suffix(".onnx")
16
+ f_ov = str(Path(f) / self.file.with_suffix(".xml").name)
17
+
18
+ ov_model = mo.convert_model(
19
+ f_onnx,
20
+ model_name=self.file.with_suffix(".xml"),
21
+ framework="onnx",
22
+ compress_to_fp16=self.half,
23
+ )
24
+ ov.serialize(ov_model, f_ov)
25
+
26
+ return f
boxmot/appearance/exporters/tensorrt_exporter.py ADDED
@@ -0,0 +1,80 @@
1
+ import platform
2
+ import torch
3
+ from boxmot.appearance.exporters.base_exporter import BaseExporter
4
+ from boxmot.appearance.exporters.onnx_exporter import ONNXExporter
5
+ from boxmot.utils import logger as LOGGER
6
+
7
+
8
+ class EngineExporter(BaseExporter):
9
+ required_packages = ("nvidia-tensorrt",)
10
+ cmds = '--extra-index-url https://pypi.ngc.nvidia.com'
11
+
12
+ def export(self):
13
+
14
+ assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. `python export.py --device 0`"
15
+ try:
16
+ import tensorrt as trt
17
+ except ImportError:
18
+ import tensorrt as trt
19
+
20
+ onnx_file = self.export_onnx()
21
+ LOGGER.info(f"\nStarting export with TensorRT {trt.__version__}...")
22
+ is_trt10 = int(trt.__version__.split(".")[0]) >= 10 # is TensorRT >= 10
23
+ assert onnx_file.exists(), f"Failed to export ONNX file: {onnx_file}"
24
+ f = self.file.with_suffix(".engine")
25
+ logger = trt.Logger(trt.Logger.INFO)
26
+ if True:
27
+ logger.min_severity = trt.Logger.Severity.VERBOSE
28
+
29
+ builder = trt.Builder(logger)
30
+ config = builder.create_builder_config()
31
+ workspace = int(self.workspace * (1 << 30))
32
+ if is_trt10:
33
+ config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
34
+ else: # TensorRT versions 7, 8
35
+ config.max_workspace_size = workspace
36
+
37
+ flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
38
+ network = builder.create_network(flag)
39
+ parser = trt.OnnxParser(network, logger)
40
+ if not parser.parse_from_file(str(onnx_file)):
41
+ raise RuntimeError(f"Failed to load ONNX file: {onnx_file}")
42
+
43
+ inputs = [network.get_input(i) for i in range(network.num_inputs)]
44
+ outputs = [network.get_output(i) for i in range(network.num_outputs)]
45
+ LOGGER.info("Network Description:")
46
+ for inp in inputs:
47
+ LOGGER.info(f'\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
48
+ for out in outputs:
49
+ LOGGER.info(f'\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
50
+
51
+ if self.dynamic:
52
+ if self.im.shape[0] <= 1:
53
+ LOGGER.warning("WARNING: --dynamic model requires maximum --batch-size argument")
54
+ profile = builder.create_optimization_profile()
55
+ for inp in inputs:
56
+ if self.half:
57
+ inp.dtype = trt.float16
58
+ profile.set_shape(
59
+ inp.name,
60
+ (1, *self.im.shape[1:]),
61
+ (max(1, self.im.shape[0] // 2), *self.im.shape[1:]),
62
+ self.im.shape,
63
+ )
64
+ config.add_optimization_profile(profile)
65
+
66
+ LOGGER.info(f"Building FP{16 if builder.platform_has_fast_fp16 and self.half else 32} engine in {f}")
67
+ if builder.platform_has_fast_fp16 and self.half:
68
+ config.set_flag(trt.BuilderFlag.FP16)
69
+ config.default_device_type = trt.DeviceType.GPU
70
+
71
+ build = builder.build_serialized_network if is_trt10 else builder.build_engine
72
+ with build(network, config) as engine, open(f, "wb") as t:
73
+ t.write(engine if is_trt10 else engine.serialize())
74
+
75
+ return f
76
+
77
+
78
+ def export_onnx(self):
79
+ onnx_exporter = ONNXExporter(self.model, self.im, self.file, self.optimize, self.dynamic, self.half, self.simplify)
80
+ return onnx_exporter.export()
boxmot/appearance/exporters/tflite_exporter.py ADDED
@@ -0,0 +1,37 @@
1
+ import os
2
+ from boxmot.appearance.exporters.base_exporter import BaseExporter
3
+ from boxmot.utils import logger as LOGGER
4
+
5
+
6
+ class TFLiteExporter(BaseExporter):
7
+ required_packages = (
8
+ "onnx2tf>=1.18.0",
9
+ "onnx>=1.16.1",
10
+ "tensorflow==2.17.0",
11
+ "tf_keras", # required by 'onnx2tf' package
12
+ "sng4onnx>=1.0.1", # required by 'onnx2tf' package
13
+ "onnx_graphsurgeon>=0.3.26", # required by 'onnx2tf' package
14
+ "onnxslim>=0.1.31",
15
+ "onnxruntime",
16
+ "flatbuffers>=23.5.26",
17
+ "psutil==5.9.5",
18
+ "ml_dtypes==0.3.2",
19
+ "ai_edge_litert>=1.2.0"
20
+ )
21
+ cmds = '--extra-index-url https://pypi.ngc.nvidia.com'
22
+
23
+ def export(self):
24
+
25
+ import onnx2tf
26
+ input_onnx_file_path = str(self.file.with_suffix('.onnx'))
27
+ output_folder_path = input_onnx_file_path.replace(".onnx", f"_saved_model{os.sep}")
28
+ onnx2tf.convert(
29
+ input_onnx_file_path=input_onnx_file_path,
30
+ output_folder_path=output_folder_path,
31
+ not_use_onnxsim=True,
32
+ verbosity=True,
33
+ # output_integer_quantized_tflite=self.args.int8,
34
+ # quant_type="per-tensor", # "per-tensor" (faster) or "per-channel" (slower but more accurate)
35
+ # custom_input_op_name_np_data_path=np_data,
36
+ )
37
+ return output_folder_path
boxmot/appearance/exporters/torchscript_exporter.py ADDED
@@ -0,0 +1,15 @@
1
+ import torch
2
+ from boxmot.appearance.exporters.base_exporter import BaseExporter
3
+ from boxmot.utils import logger as LOGGER
4
+
5
+
6
+ class TorchScriptExporter(BaseExporter):
7
+ def export(self):
8
+ f = self.file.with_suffix(".torchscript")
9
+ ts = torch.jit.trace(self.model, self.im, strict=False)
10
+ if self.optimize:
11
+ torch.utils.mobile_optimizer.optimize_for_mobile(ts)._save_for_lite_interpreter(str(f))
12
+ else:
13
+ ts.save(str(f))
14
+
15
+ return f
boxmot/appearance/reid/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
2
+
3
+ import pandas as pd
4
+
5
+
6
+ def export_formats():
7
+ # yolo tracking export formats
8
+ x = [
9
+ ["PyTorch", "-", ".pt", True, True],
10
+ ["TorchScript", "torchscript", ".torchscript", True, True],
11
+ ["ONNX", "onnx", ".onnx", True, True],
12
+ ["OpenVINO", "openvino", "_openvino_model", True, False],
13
+ ["TensorRT", "engine", ".engine", False, True],
14
+ ["TensorFlow Lite", "tflite", ".tflite", True, False],
15
+ ]
16
+ return pd.DataFrame(x, columns=["Format", "Argument", "Suffix", "CPU", "GPU"])
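The table drives both suffix matching and the export CLI, e.g.:

from boxmot.appearance.reid import export_formats

fmts = export_formats()
# "Suffix" is what ReidAutoBackend.model_type() matches against a weights filename;
# "Argument" (minus the PyTorch row) is what --include accepts in the export script.
print(list(fmts.Suffix))    # ['.pt', '.torchscript', '.onnx', '_openvino_model', '.engine', '.tflite']
print(list(fmts.Argument))  # ['-', 'torchscript', 'onnx', 'openvino', 'engine', 'tflite']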
boxmot/appearance/reid/auto_backend.py ADDED
@@ -0,0 +1,128 @@
1
+ import torch
2
+ from pathlib import Path
3
+ from typing import Union, Tuple
4
+
5
+ from boxmot.utils import WEIGHTS
6
+ from boxmot.utils import logger as LOGGER
7
+ from boxmot.utils.torch_utils import select_device
8
+ from boxmot.appearance.reid import export_formats
9
+ from boxmot.appearance.backends.onnx_backend import ONNXBackend
10
+ from boxmot.appearance.backends.openvino_backend import OpenVinoBackend
11
+ from boxmot.appearance.backends.pytorch_backend import PyTorchBackend
12
+ from boxmot.appearance.backends.tensorrt_backend import TensorRTBackend
13
+ from boxmot.appearance.backends.tflite_backend import TFLiteBackend
14
+ from boxmot.appearance.backends.torchscript_backend import TorchscriptBackend
15
+ from boxmot.appearance.backends.base_backend import BaseModelBackend
16
+
17
+
18
+
19
+ class ReidAutoBackend():
20
+ def __init__(
21
+ self,
22
+ weights: Path = WEIGHTS / "osnet_x0_25_msmt17.pt",
23
+ device: torch.device = torch.device("cpu"),
24
+ half: bool = False) -> None:
25
+ """
26
+ Initializes the ReidAutoBackend instance with specified weights, device, and precision mode.
27
+
28
+ Args:
29
+ weights (Union[str, List[str]]): Path to the model weights. Can be a string or a list of strings; if a list, the first element is used.
30
+ device (torch.device): The device to run the model on, e.g., CPU or GPU.
31
+ half (bool): Whether to use half precision for model inference.
32
+ """
33
+ super().__init__()
34
+ w = weights[0] if isinstance(weights, list) else weights
35
+ (
36
+ self.pt,
37
+ self.jit,
38
+ self.onnx,
39
+ self.xml,
40
+ self.engine,
41
+ self.tflite,
42
+ ) = self.model_type(w) # get backend
43
+
44
+ self.weights = weights
45
+ self.device = select_device(device)
46
+ self.half = half
47
+ self.model = self.get_backend()
48
+
49
+
50
+ def get_backend(self) -> Union['PyTorchBackend', 'TorchscriptBackend', 'ONNXBackend', 'TensorRTBackend', 'OpenVinoBackend', 'TFLiteBackend']:
51
+ """
52
+ Returns an instance of the appropriate backend based on the model type.
53
+
54
+ Returns:
55
+ An instance of a backend class corresponding to the detected model type.
56
+
57
+ Raises:
58
+ SystemExit: If no supported model framework is detected.
59
+ """
60
+
61
+ # Mapping of conditions to backend constructors
62
+ backend_map = {
63
+ self.pt: PyTorchBackend,
64
+ self.jit: TorchscriptBackend,
65
+ self.onnx: ONNXBackend,
66
+ self.engine: TensorRTBackend,
67
+ self.xml: OpenVinoBackend,
68
+ self.tflite: TFLiteBackend
69
+ }
70
+
71
+ # Iterate through the mapping and return the first matching backend
72
+ for condition, backend_class in backend_map.items():
73
+ if condition:
74
+ return backend_class(self.weights, self.device, self.half)
75
+
76
+ # If no condition is met, log an error and exit
77
+ LOGGER.error("This model framework is not supported yet!")
78
+ exit()
79
+
80
+
81
+ def forward(self, im_batch: torch.Tensor) -> torch.Tensor:
82
+ """
83
+ Processes an image batch through the selected backend and returns the processed batch.
84
+
85
+ Args:
86
+ im_batch (torch.Tensor): The batch of images to process.
87
+
88
+ Returns:
89
+ torch.Tensor: The processed image batch.
90
+ """
91
+ im_batch = self.backend.preprocess_input(im_batch)
92
+ return self.backend.get_features(im_batch)
93
+
94
+
95
+ def check_suffix(self, file: Path = "osnet_x0_25_msmt17.pt", suffix: Union[str, Tuple[str, ...]] = (".pt",), msg: str = "") -> None:
96
+ """
97
+ Validates that the file or files have an acceptable suffix.
98
+
99
+ Args:
100
+ file (Union[str, List[str], Path]): The file or files to check.
101
+ suffix (Union[str, Tuple[str, ...]]): Acceptable suffix or suffixes.
102
+ msg (str): Additional message to log in case of an error.
103
+ """
104
+
105
+ suffix = [suffix] if isinstance(suffix, str) else list(suffix)
106
+ files = [file] if isinstance(file, (str, Path)) else list(file)
107
+
108
+ for f in files:
109
+ file_suffix = Path(f).suffix.lower()
110
+ if file_suffix and file_suffix not in suffix:
111
+ LOGGER.error(f"File {f} does not have an acceptable suffix. Expected: {suffix}")
112
+
113
+
114
+ def model_type(self, p: Path) -> Tuple[bool, ...]:
115
+ """
116
+ Determines the model type based on the file's suffix.
117
+
118
+ Args:
119
+ p (Path): The file path to the model.
120
+
121
+ Returns:
122
+ Tuple[bool, ...]: A tuple of booleans indicating the model type, corresponding to pt, jit, onnx, xml, engine, and tflite.
123
+ """
124
+
125
+ sf = list(export_formats().Suffix) # export suffixes
126
+ self.check_suffix(p, sf) # checks
127
+ types = [s in Path(p).name for s in sf]
128
+ return types
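A usage sketch tying the pieces together (the weights filename is one of the TRAINED_URLS entries below and is downloaded on first use; the 512-dim output is what osnet variants typically produce):

import numpy as np
import torch
from pathlib import Path
from boxmot.appearance.reid.auto_backend import ReidAutoBackend

rab = ReidAutoBackend(
    weights=Path("osnet_x0_25_msmt17.pt"),  # ".pt" suffix selects PyTorchBackend
    device=torch.device("cpu"),
    half=False,
)
backend = rab.model  # already built in __init__; get_backend() would construct a second instance

img = np.zeros((480, 640, 3), dtype=np.uint8)              # BGR frame
dets = np.array([[10, 20, 110, 220], [50, 60, 150, 260]])  # xyxy boxes
embs = backend.get_features(dets, img)                     # e.g. (2, 512), L2-normalized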
boxmot/appearance/reid/config.py ADDED
@@ -0,0 +1,73 @@
1
+ MODEL_TYPES = [
2
+ "resnet50",
3
+ "resnet101",
4
+ "mlfn",
5
+ "hacnn",
6
+ "mobilenetv2_x1_0",
7
+ "mobilenetv2_x1_4",
8
+ "osnet_x1_0",
9
+ "osnet_x0_75",
10
+ "osnet_x0_5",
11
+ "osnet_x0_25",
12
+ "osnet_ibn_x1_0",
13
+ "osnet_ain_x1_0",
14
+ "lmbn_n",
15
+ "clip",
16
+ ]
17
+
18
+ TRAINED_URLS = {
19
+ # resnet50
20
+ "resnet50_market1501.pt": "https://drive.google.com/uc?id=1dUUZ4rHDWohmsQXCRe2C_HbYkzz94iBV",
21
+ "resnet50_dukemtmcreid.pt": "https://drive.google.com/uc?id=17ymnLglnc64NRvGOitY3BqMRS9UWd1wg",
22
+ "resnet50_msmt17.pt": "https://drive.google.com/uc?id=1ep7RypVDOthCRIAqDnn4_N-UhkkFHJsj",
23
+ "resnet50_fc512_market1501.pt": "https://drive.google.com/uc?id=1kv8l5laX_YCdIGVCetjlNdzKIA3NvsSt",
24
+ "resnet50_fc512_dukemtmcreid.pt": "https://drive.google.com/uc?id=13QN8Mp3XH81GK4BPGXobKHKyTGH50Rtx",
25
+ "resnet50_fc512_msmt17.pt": "https://drive.google.com/uc?id=1fDJLcz4O5wxNSUvImIIjoaIF9u1Rwaud",
26
+ # mlfn
27
+ "mlfn_market1501.pt": "https://drive.google.com/uc?id=1wXcvhA_b1kpDfrt9s2Pma-MHxtj9pmvS",
28
+ "mlfn_dukemtmcreid.pt": "https://drive.google.com/uc?id=1rExgrTNb0VCIcOnXfMsbwSUW1h2L1Bum",
29
+ "mlfn_msmt17.pt": "https://drive.google.com/uc?id=18JzsZlJb3Wm7irCbZbZ07TN4IFKvR6p-",
30
+ # hacnn
31
+ "hacnn_market1501.pt": "https://drive.google.com/uc?id=1LRKIQduThwGxMDQMiVkTScBwR7WidmYF",
32
+ "hacnn_dukemtmcreid.pt": "https://drive.google.com/uc?id=1zNm6tP4ozFUCUQ7Sv1Z98EAJWXJEhtYH",
33
+ "hacnn_msmt17.pt": "https://drive.google.com/uc?id=1MsKRtPM5WJ3_Tk2xC0aGOO7pM3VaFDNZ",
34
+ # mobilenetv2
35
+ "mobilenetv2_x1_0_market1501.pt": "https://drive.google.com/uc?id=18DgHC2ZJkjekVoqBWszD8_Xiikz-fewp",
36
+ "mobilenetv2_x1_0_dukemtmcreid.pt": "https://drive.google.com/uc?id=1q1WU2FETRJ3BXcpVtfJUuqq4z3psetds",
37
+ "mobilenetv2_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1j50Hv14NOUAg7ZeB3frzfX-WYLi7SrhZ",
38
+ "mobilenetv2_x1_4_market1501.pt": "https://drive.google.com/uc?id=1t6JCqphJG-fwwPVkRLmGGyEBhGOf2GO5",
39
+ "mobilenetv2_x1_4_dukemtmcreid.pt": "https://drive.google.com/uc?id=12uD5FeVqLg9-AFDju2L7SQxjmPb4zpBN",
40
+ "mobilenetv2_x1_4_msmt17.pt": "https://drive.google.com/uc?id=1ZY5P2Zgm-3RbDpbXM0kIBMPvspeNIbXz",
41
+ # osnet
42
+ "osnet_x1_0_market1501.pt": "https://drive.google.com/uc?id=1vduhq5DpN2q1g4fYEZfPI17MJeh9qyrA",
43
+ "osnet_x1_0_dukemtmcreid.pt": "https://drive.google.com/uc?id=1QZO_4sNf4hdOKKKzKc-TZU9WW1v6zQbq",
44
+ "osnet_x1_0_msmt17.pt": "https://drive.google.com/uc?id=112EMUfBPYeYg70w-syK6V6Mx8-Qb9Q1M",
45
+ "osnet_x0_75_market1501.pt": "https://drive.google.com/uc?id=1ozRaDSQw_EQ8_93OUmjDbvLXw9TnfPer",
46
+ "osnet_x0_75_dukemtmcreid.pt": "https://drive.google.com/uc?id=1IE3KRaTPp4OUa6PGTFL_d5_KQSJbP0Or",
47
+ "osnet_x0_75_msmt17.pt": "https://drive.google.com/uc?id=1QEGO6WnJ-BmUzVPd3q9NoaO_GsPNlmWc",
48
+ "osnet_x0_5_market1501.pt": "https://drive.google.com/uc?id=1PLB9rgqrUM7blWrg4QlprCuPT7ILYGKT",
49
+ "osnet_x0_5_dukemtmcreid.pt": "https://drive.google.com/uc?id=1KoUVqmiST175hnkALg9XuTi1oYpqcyTu",
50
+ "osnet_x0_5_msmt17.pt": "https://drive.google.com/uc?id=1UT3AxIaDvS2PdxzZmbkLmjtiqq7AIKCv",
51
+ "osnet_x0_25_market1501.pt": "https://drive.google.com/uc?id=1z1UghYvOTtjx7kEoRfmqSMu-z62J6MAj",
52
+ "osnet_x0_25_dukemtmcreid.pt": "https://drive.google.com/uc?id=1eumrtiXT4NOspjyEV4j8cHmlOaaCGk5l",
53
+ "osnet_x0_25_msmt17.pt": "https://drive.google.com/uc?id=1sSwXSUlj4_tHZequ_iZ8w_Jh0VaRQMqF",
54
+ # osnet_ain | osnet_ibn
55
+ "osnet_ibn_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1q3Sj2ii34NlfxA4LvmHdWO_75NDRmECJ",
56
+ "osnet_ain_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal",
57
+ # lmbn
58
+ "lmbn_n_duke.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_duke.pth",
59
+ "lmbn_n_market.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_market.pth",
60
+ "lmbn_n_cuhk03_d.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_cuhk03_d.pth",
61
+ # clip
62
+ "clip_market1501.pt": "https://drive.google.com/uc?id=1GnyAVeNOg3Yug1KBBWMKKbT2x43O5Ch7",
63
+ "clip_duke.pt": "https://drive.google.com/uc?id=1ldjSkj-7pXAWmx8on5x0EftlCaolU4dY",
64
+ "clip_veri.pt": "https://drive.google.com/uc?id=1RyfHdOBI2pan_wIGSim5-l6cM4S2WN8e",
65
+ "clip_vehicleid.pt": "https://drive.google.com/uc?id=168BLegHHxNqatW5wx1YyL2REaThWoof5"
66
+ }
67
+
68
+ NR_CLASSES_DICT = {
69
+ "market1501": 751,
70
+ "duke": 702,
71
+ "veri": 576,
72
+ "vehicleid": 576,
73
+ }
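A rough sketch of how these tables can be consumed (the actual lookup lives in ReIDModelRegistry, whose listing is truncated below, so the fallback of 1 class here is an assumption of this sketch):

from pathlib import Path
from boxmot.appearance.reid.config import TRAINED_URLS, NR_CLASSES_DICT

w = Path("osnet_x1_0_market1501.pt")
url = TRAINED_URLS.get(w.name)  # Google Drive link used by gdown in BaseModelBackend.download_model
dataset = next((d for d in NR_CLASSES_DICT if d in w.name), None)
num_classes = NR_CLASSES_DICT.get(dataset, 1)
print(url is not None, dataset, num_classes)  # True market1501 751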
boxmot/appearance/reid/export.py ADDED
@@ -0,0 +1,227 @@
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import torch
7
+
8
+ from boxmot.appearance.exporters.base_exporter import BaseExporter
9
+ from boxmot.appearance.exporters.onnx_exporter import ONNXExporter
10
+ from boxmot.appearance.exporters.openvino_exporter import OpenVINOExporter
11
+ from boxmot.appearance.exporters.tflite_exporter import TFLiteExporter
12
+ from boxmot.appearance.exporters.torchscript_exporter import TorchScriptExporter
13
+ from boxmot.appearance.exporters.tensorrt_exporter import EngineExporter
14
+ from boxmot.appearance.reid import export_formats
15
+ from boxmot.appearance.reid.auto_backend import ReidAutoBackend
16
+ from boxmot.appearance.reid.registry import ReIDModelRegistry
17
+ from boxmot.utils import WEIGHTS, logger as LOGGER
18
+ from boxmot.utils.torch_utils import select_device
19
+
20
+
21
+ def parse_args():
22
+ """
23
+ Parse command-line arguments for the ReID export script.
24
+ """
25
+ parser = argparse.ArgumentParser(description="ReID Export Script")
26
+ parser.add_argument("--batch-size", type=int, default=1, help="Batch size for export")
27
+ parser.add_argument("--imgsz", "--img", "--img-size",
28
+ nargs="+", type=int, default=[256, 128],
29
+ help="Image size in the format: height width")
30
+ parser.add_argument("--device", default="cpu",
31
+ help="CUDA device (e.g., '0', '0,1,2,3', or 'cpu')")
32
+ parser.add_argument("--optimize", action="store_true",
33
+ help="Optimize TorchScript for mobile (CPU export only)")
34
+ parser.add_argument("--dynamic", action="store_true",
35
+ help="Enable dynamic axes for ONNX/TF/TensorRT export")
36
+ parser.add_argument("--simplify", action="store_true",
37
+ help="Simplify ONNX model")
38
+ parser.add_argument("--opset", type=int, default=12,
39
+ help="ONNX opset version")
40
+ parser.add_argument("--workspace", type=int, default=4,
41
+ help="TensorRT workspace size (GB)")
42
+ parser.add_argument("--verbose", action="store_true",
43
+ help="Enable verbose logging for TensorRT")
44
+ parser.add_argument("--weights", type=Path,
45
+ default=WEIGHTS / "osnet_x0_25_msmt17.pt",
46
+ help="Path to the model weights (.pt file)")
47
+ parser.add_argument("--half", action="store_true",
48
+ help="Enable FP16 half-precision export (GPU only)")
49
+ parser.add_argument("--include", nargs="+",
50
+ default=["torchscript"],
51
+ help=("Export formats to include. Options: torchscript, onnx, "
52
+ "openvino, engine, tflite"))
53
+ return parser.parse_args()
54
+
55
+
56
+ def validate_export_formats(include):
57
+ """
58
+ Validate the provided export formats and return corresponding flags.
59
+
60
+ Args:
61
+ include (list): List of export formats provided via the command line.
62
+
63
+ Returns:
64
+ tuple: Boolean flags for each export format in the order:
65
+ (torchscript, onnx, openvino, engine, tflite)
66
+ """
67
+ available_formats = tuple(export_formats()["Argument"][1:])
68
+ include_lower = [fmt.lower() for fmt in include]
69
+ flags = [fmt in include_lower for fmt in available_formats]
70
+     if sum(flags) != len(include_lower):
+         raise AssertionError(
+             f"ERROR: Invalid --include {include}, valid arguments are {available_formats}"
+         )
+     return tuple(flags)
+
+
+ def setup_model(args):
+     """
+     Initialize and prepare the ReID model for export.
+
+     Args:
+         args: Parsed command-line arguments.
+
+     Returns:
+         tuple: (model (torch.nn.Module), dummy_input (torch.Tensor))
+     """
+     # Select the correct device
+     args.device = select_device(args.device)
+     if args.half and args.device.type == "cpu":
+         raise AssertionError("--half only compatible with GPU export, use --device 0 for GPU")
+
+     # Initialize backend model using the auto backend
+     auto_backend = ReidAutoBackend(weights=args.weights, device=args.device, half=args.half)
+     _ = auto_backend.get_backend()  # Backend model is managed internally
+
+     # Build and load the ReID model from the registry
+     model_name = ReIDModelRegistry.get_model_name(args.weights)
+     nr_classes = ReIDModelRegistry.get_nr_classes(args.weights)
+     pretrained = not (args.weights and args.weights.is_file() and args.weights.suffix == ".pt")
+     model = ReIDModelRegistry.build_model(
+         model_name,
+         num_classes=nr_classes,
+         pretrained=pretrained,
+         use_gpu=args.device,
+     ).to(args.device)
+     ReIDModelRegistry.load_pretrained_weights(model, args.weights)
+     model.eval()
+
+     # Ensure --optimize is only used with CPU exports
+     if args.optimize and args.device.type != "cpu":
+         raise AssertionError("--optimize not compatible with CUDA devices, use --device cpu")
+
+     # Adjust image size if a specific weight type is detected
+     if "lmbn" in str(args.weights):
+         args.imgsz = [384, 128]
+
+     # Create dummy input tensor for warming up the model
+     dummy_input = torch.empty(args.batch_size, 3, args.imgsz[0], args.imgsz[1]).to(args.device)
+     for _ in range(2):
+         _ = model(dummy_input)
+
+     # Convert to half precision if required
+     if args.half:
+         dummy_input = dummy_input.half()
+         model = model.half()
+
+     return model, dummy_input
+
+
+ def create_export_tasks(args, model, dummy_input):
+     """
+     Create a mapping of export tasks with associated flags, exporter classes, and parameters.
+
+     Args:
+         args: Parsed command-line arguments.
+         model: Prepared ReID model.
+         dummy_input: Dummy input tensor.
+
+     Returns:
+         dict: Mapping of export format to a tuple (flag, exporter_class, export_args)
+     """
+     torchscript_flag, onnx_flag, openvino_flag, engine_flag, tflite_flag = validate_export_formats(args.include)
+     return {
+         "torchscript": (
+             torchscript_flag,
+             TorchScriptExporter,
+             (model, dummy_input, args.weights, args.optimize)
+         ),
+         "engine": (
+             engine_flag,
+             EngineExporter,
+             (model, dummy_input, args.weights, args.half, args.dynamic, args.simplify, args.verbose)
+         ),
+         "onnx": (
+             onnx_flag,
+             ONNXExporter,
+             (model, dummy_input, args.weights, args.opset, args.dynamic, args.half, args.simplify)
+         ),
+         "tflite": (
+             tflite_flag,
+             TFLiteExporter,
+             (model, dummy_input, args.weights)
+         ),
+         "openvino": (
+             openvino_flag,
+             OpenVINOExporter,
+             (model, dummy_input, args.weights, args.half)
+         )
+     }
+
+
+ def perform_exports(export_tasks):
+     """
+     Iterate over export tasks and perform export for enabled formats.
+
+     Args:
+         export_tasks (dict): Mapping of export tasks.
+
+     Returns:
+         dict: Mapping of export format to export results.
+     """
+     exported_files = {}
+     for fmt, (flag, exporter_class, exp_args) in export_tasks.items():
+         if flag:
+             exporter = exporter_class(*exp_args)
+             export_result = exporter.export()
+             exported_files[fmt] = export_result
+     return exported_files
+
+
+ def main():
+     """Main function to execute the ReID export process."""
+     args = parse_args()
+     start_time = time.time()
+
+     # Ensure the weights directory exists
+     WEIGHTS.mkdir(parents=False, exist_ok=True)
+
+     # Setup model and create a dummy input tensor
+     model, dummy_input = setup_model(args)
+
+     # Log model output shape and file size
+     output = model(dummy_input)
+     output_tensor = output[0] if isinstance(output, tuple) else output
+     output_shape = tuple(output_tensor.shape)
+     LOGGER.info(
+         f"\nStarting from {args.weights} with output shape {output_shape} "
+         f"({BaseExporter.file_size(args.weights):.1f} MB)"
+     )
+
+     # Create export tasks
+     export_tasks = create_export_tasks(args, model, dummy_input)
+
+     # Perform exports for enabled formats
+     exported_files = perform_exports(export_tasks)
+
+     if exported_files:
+         elapsed_time = time.time() - start_time
+         LOGGER.info(
+             f"\nExport complete ({elapsed_time:.1f}s)"
+             f"\nResults saved to {args.weights.parent.resolve()}"
+             f"\nVisualize: https://netron.app"
+         )
+
+
+ if __name__ == "__main__":
+     main()
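For orientation, the export entry points above compose as in the minimal sketch below (not part of the commit). The `argparse.Namespace` fields mirror the attributes accessed in `setup_model`/`create_export_tasks`; the weights filename and values are purely illustrative placeholders.

```python
# Illustrative usage sketch, assuming the functions above are importable
# from boxmot.appearance.reid.export; all values are placeholders.
import argparse
from pathlib import Path

from boxmot.appearance.reid.export import (
    setup_model, create_export_tasks, perform_exports,
)

args = argparse.Namespace(
    weights=Path("osnet_x0_25_msmt17.pt"),  # hypothetical weights file
    include=["onnx"],                       # requested export formats
    device="cpu",
    half=False,
    optimize=False,
    imgsz=[256, 128],                       # assumed (H, W) for non-LMBN models
    batch_size=1,
    dynamic=False,
    simplify=True,
    opset=12,
    verbose=False,
)

model, dummy_input = setup_model(args)
tasks = create_export_tasks(args, model, dummy_input)
exported = perform_exports(tasks)  # e.g. {"onnx": <result of ONNXExporter.export()>}
```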
boxmot/appearance/reid/factory.py ADDED
@@ -0,0 +1,40 @@
+ from boxmot.appearance.backbones.clip.make_model import make_model
+ from boxmot.appearance.backbones.hacnn import HACNN
+ from boxmot.appearance.backbones.lmbn.lmbn_n import LMBN_n
+ from boxmot.appearance.backbones.mlfn import mlfn
+ from boxmot.appearance.backbones.mobilenetv2 import mobilenetv2_x1_0, mobilenetv2_x1_4
+ from boxmot.appearance.backbones.osnet import (
+     osnet_ibn_x1_0,
+     osnet_x0_5,
+     osnet_x0_25,
+     osnet_x0_75,
+     osnet_x1_0,
+ )
+ from boxmot.appearance.backbones.osnet_ain import (
+     osnet_ain_x0_5,
+     osnet_ain_x0_25,
+     osnet_ain_x0_75,
+     osnet_ain_x1_0,
+ )
+ from boxmot.appearance.backbones.resnet import resnet50, resnet101
+
+ # Map model names to their respective constructors
+ MODEL_FACTORY = {
+     "resnet50": resnet50,
+     "resnet101": resnet101,
+     "mobilenetv2_x1_0": mobilenetv2_x1_0,
+     "mobilenetv2_x1_4": mobilenetv2_x1_4,
+     "hacnn": HACNN,
+     "mlfn": mlfn,
+     "osnet_x1_0": osnet_x1_0,
+     "osnet_x0_75": osnet_x0_75,
+     "osnet_x0_5": osnet_x0_5,
+     "osnet_x0_25": osnet_x0_25,
+     "osnet_ibn_x1_0": osnet_ibn_x1_0,
+     "osnet_ain_x1_0": osnet_ain_x1_0,
+     "osnet_ain_x0_75": osnet_ain_x0_75,
+     "osnet_ain_x0_5": osnet_ain_x0_5,
+     "osnet_ain_x0_25": osnet_ain_x0_25,
+     "lmbn_n": LMBN_n,
+     "clip": make_model,
+ }
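As a quick, hedged illustration (not part of the diff), a backbone constructor is looked up by name from this mapping; the keyword arguments follow the signature used by `ReIDModelRegistry.build_model` in the next file, and the class count is an assumed value.

```python
# Hypothetical lookup; "osnet_x0_25" is one of the registered keys above.
from boxmot.appearance.reid.factory import MODEL_FACTORY

builder = MODEL_FACTORY["osnet_x0_25"]            # raises KeyError for unregistered names
model = builder(num_classes=751, loss="softmax",  # 751 classes is an illustrative value
                pretrained=True, use_gpu=False)
```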
boxmot/appearance/reid/registry.py ADDED
@@ -0,0 +1,87 @@
+ # model_registry.py
+ import torch
+ from collections import OrderedDict
+ from boxmot.utils import logger as LOGGER
+
+ from boxmot.appearance.reid.config import MODEL_TYPES, TRAINED_URLS, NR_CLASSES_DICT
+ from boxmot.appearance.reid.factory import MODEL_FACTORY
+
+ class ReIDModelRegistry:
+     """Encapsulates model registration and related utilities."""
+
+     @staticmethod
+     def show_downloadable_models():
+         LOGGER.info("Available .pt ReID models for automatic download")
+         LOGGER.info(list(TRAINED_URLS.keys()))
+
+     @staticmethod
+     def get_model_name(model):
+         for name in MODEL_TYPES:
+             if name in model.name:
+                 return name
+         return None
+
+     @staticmethod
+     def get_model_url(model):
+         return TRAINED_URLS.get(model.name, None)
+
+     @staticmethod
+     def load_pretrained_weights(model, weight_path):
+         """
+         Loads pretrained weights into a model.
+         Chooses the proper map_location based on CUDA availability.
+         """
+         device = "cpu" if not torch.cuda.is_available() else None
+         checkpoint = torch.load(weight_path, map_location=torch.device("cpu") if device == "cpu" else None)
+         state_dict = checkpoint.get("state_dict", checkpoint)
+         model_dict = model.state_dict()
+
+         if "lmbn" in weight_path.parts:
+             model.load_state_dict(model_dict, strict=True)
+         else:
+             new_state_dict = OrderedDict()
+             matched_layers, discarded_layers = [], []
+             for k, v in state_dict.items():
+                 # Remove 'module.' prefix if present
+                 key = k[7:] if k.startswith("module.") else k
+                 if key in model_dict and model_dict[key].size() == v.size():
+                     new_state_dict[key] = v
+                     matched_layers.append(key)
+                 else:
+                     discarded_layers.append(key)
+             model_dict.update(new_state_dict)
+             model.load_state_dict(model_dict)
+
+             if not matched_layers:
+                 LOGGER.debug(f"Pretrained weights from {weight_path} cannot be loaded. Check key names manually.")
+             else:
+                 LOGGER.success(f"Loaded pretrained weights from {weight_path}")
+
+             if discarded_layers:
+                 LOGGER.debug(f"Discarded layers due to unmatched keys or size: {discarded_layers}")
+
+     @staticmethod
+     def show_available_models():
+         LOGGER.info("Available models:")
+         LOGGER.info(list(MODEL_FACTORY.keys()))
+
+     @staticmethod
+     def get_nr_classes(weights):
+         # Extract dataset name from weights name, then look up in the class dictionary
+         dataset_key = weights.name.split('_')[1]
+         return NR_CLASSES_DICT.get(dataset_key, 1)
+
+     @staticmethod
+     def build_model(name, num_classes, loss="softmax", pretrained=True, use_gpu=True):
+         if name not in MODEL_FACTORY:
+             available = list(MODEL_FACTORY.keys())
+             raise KeyError(f"Unknown model '{name}'. Must be one of {available}")
+
+         # Special case handling for clip model
+         if 'clip' in name:
+             from boxmot.appearance.backbones.clip.config.defaults import _C as cfg
+             return MODEL_FACTORY[name](cfg, num_class=num_classes, camera_num=2, view_num=1)
+
+         return MODEL_FACTORY[name](
+             num_classes=num_classes, loss=loss, pretrained=pretrained, use_gpu=use_gpu
+         )
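A short end-to-end sketch of the registry (illustrative only, not from the commit): the weights filename is an assumption, `get_model_name` only resolves names listed in MODEL_TYPES, and `get_nr_classes` falls back to 1 when the dataset key cannot be derived from the filename.

```python
# Hedged usage sketch of ReIDModelRegistry
from pathlib import Path
from boxmot.appearance.reid.registry import ReIDModelRegistry

weights = Path("osnet_x0_25_msmt17.pt")                    # hypothetical local file
name = ReIDModelRegistry.get_model_name(weights)           # matches a MODEL_TYPES entry, else None
n_cls = ReIDModelRegistry.get_nr_classes(weights)          # NR_CLASSES_DICT lookup, default 1

model = ReIDModelRegistry.build_model(name, num_classes=n_cls, pretrained=False)
ReIDModelRegistry.load_pretrained_weights(model, weights)  # prefix-stripping, size-checked load
model.eval()
```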
boxmot/configs/__init__.py ADDED
@@ -0,0 +1 @@
+ # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
boxmot/configs/boosttrack.yaml ADDED
@@ -0,0 +1,90 @@
+ max_age:
+   type: uniform
+   default: 60
+   range: [15, 90]
+
+ min_hits:
+   type: uniform
+   default: 3
+   range: [1, 5]
+
+ det_thresh:
+   type: uniform
+   default: 0.6
+   range: [0.1, 0.9]
+
+ iou_threshold:
+   type: uniform
+   default: 0.3
+   range: [0.1, 0.9]
+
+ use_ecc:
+   type: choice
+   default: True
+   options: [False, True]
+
+ min_box_area:
+   type: uniform
+   default: 10
+   range: [5, 100]
+
+ aspect_ratio_thresh:
+   type: uniform
+   default: 1.6
+   range: [0.1, 2.0]
+
+ lambda_iou:
+   type: uniform
+   default: 0.5
+   range: [0.3, 2.0]
+
+ lambda_mhd:
+   type: uniform
+   default: 0.25
+   range: [0.5, 2.0]
+
+ lambda_shape:
+   type: uniform
+   default: 0.25
+   range: [0.5, 2.0]
+
+ use_dlo_boost:
+   type: choice
+   default: True
+   options: [False, True]
+
+ use_duo_boost:
+   type: choice
+   default: True
+   options: [False, True]
+
+ dlo_boost_coef:
+   type: uniform
+   default: 0.65
+   range: [0.3, 2.0]
+
+ s_sim_corr:
+   type: choice
+   default: False
+   options: [False, True]
+
+ use_rich_s:
+   type: choice
+   default: True  # True for BoostTrack++
+   options: [False, True]
+
+ use_sb:
+   type: choice
+   default: True  # True for BoostTrack++
+   options: [False, True]
+
+ use_vt:
+   type: choice
+   default: True  # True for BoostTrack++
+   options: [False, True]
+
+ with_reid:
+   type: choice
+   default: True  # True for BoostTrack+ and BoostTrack++
+   options: [False, True]
+
boxmot/configs/botsort.yaml ADDED
@@ -0,0 +1,39 @@
+ track_high_thresh:
+   type: uniform
+   default: 0.6  # from the default parameters
+   range: [0.3, 0.7]
+
+ track_low_thresh:
+   type: uniform
+   default: 0.1  # from the default parameters
+   range: [0.1, 0.3]
+
+ new_track_thresh:
+   type: uniform
+   default: 0.7  # from the default parameters
+   range: [0.1, 0.8]
+
+ track_buffer:
+   type: randint
+   default: 30  # from the default parameters
+   range: [20, 81]
+
+ match_thresh:
+   type: uniform
+   default: 0.8  # from the default parameters
+   range: [0.1, 0.9]
+
+ proximity_thresh:
+   type: uniform
+   default: 0.5  # from the default parameters
+   range: [0.25, 0.75]
+
+ appearance_thresh:
+   type: uniform
+   default: 0.25  # from the default parameters
+   range: [0.1, 0.8]
+
+ cmc_method:
+   type: choice
+   default: ecc  # from the default parameters
+   options: [sof, ecc]
boxmot/configs/bytetrack.yaml ADDED
@@ -0,0 +1,24 @@
+ min_conf:
+   type: uniform
+   default: 0.1  # from the default parameters
+   range: [0.1, 0.3]
+
+ track_thresh:
+   type: uniform
+   default: 0.6  # from the default parameters
+   range: [0.4, 0.6]
+
+ track_buffer:
+   type: randint
+   default: 30  # from the default parameters
+   range: [10, 61, 10]  # step size of 10, upper bound exclusive
+
+ match_thresh:
+   type: uniform
+   default: 0.9  # from the default parameters
+   range: [0.7, 0.9]
+
+ frame_rate:
+   type: choice
+   default: 30  # from the default parameters
+   choices: [30]  # static choice for Ray Search
boxmot/configs/deepocsort.yaml ADDED
@@ -0,0 +1,74 @@
+ det_thresh:
+   type: uniform
+   default: 0.5  # from the default parameters
+   range: [0.3, 0.6]
+
+ max_age:
+   type: randint
+   default: 30  # from the default parameters
+   range: [10, 61, 10]  # step size of 10, upper bound exclusive
+
+ min_hits:
+   type: randint
+   default: 3  # from the default parameters
+   range: [1, 6]  # upper bound exclusive
+
+ iou_thresh:
+   type: uniform
+   default: 0.3  # from the default parameters
+   range: [0.1, 0.4]
+
+ delta_t:
+   type: randint
+   default: 3  # from the default parameters
+   range: [1, 6]  # upper bound exclusive
+
+ asso_func:
+   type: choice
+   default: iou  # from the default parameters
+   options: ['iou', 'giou', 'diou', 'ciou', 'hmiou']
+
+ inertia:
+   type: uniform
+   default: 0.2  # from the default parameters
+   range: [0.1, 0.4]
+
+ w_association_emb:
+   type: uniform
+   default: 0.75  # from the default parameters
+   range: [0.5, 0.9]
+
+ alpha_fixed_emb:
+   type: uniform
+   default: 0.95  # from the default parameters
+   range: [0.9, 0.999]
+
+ aw_param:
+   type: uniform
+   default: 0.5  # from the default parameters
+   range: [0.3, 0.7]
+
+ embedding_off:
+   type: choice
+   default: false  # from the default parameters
+   options: [True, False]
+
+ cmc_off:
+   type: choice
+   default: false  # from the default parameters
+   options: [True, False]
+
+ aw_off:
+   type: choice
+   default: false  # from the default parameters
+   options: [True, False]
+
+ Q_xy_scaling:
+   type: uniform
+   default: 0.01  # from the default parameters
+   range: [0.01, 1]
+
+ Q_s_scaling:
+   type: uniform
+   default: 0.0001  # from the default parameters
+   range: [0.0001, 1]
boxmot/configs/hybridsort.yaml ADDED
@@ -0,0 +1,49 @@
+ det_thresh:
+   type: uniform
+   default: 0.12442660055370669  # from the default parameters
+   range: [0, 0.6]
+
+ max_age:
+   type: randint
+   default: 30  # from the default parameters
+   range: [10, 151, 10]  # step size of 10, upper bound exclusive
+
+ min_hits:
+   type: randint
+   default: 1  # from the default parameters
+   range: [1, 6]  # upper bound exclusive
+
+ delta_t:
+   type: randint
+   default: 5  # from the default parameters
+   range: [1, 6]  # upper bound exclusive
+
+ asso_func:
+   type: choice
+   default: hmiou  # from the default parameters
+   options: ['iou', 'giou', 'diou']
+
+ iou_threshold:
+   type: uniform
+   default: 0.3  # from the default parameters
+   range: [0.1, 0.4]
+
+ inertia:
+   type: uniform
+   default: 0.369525477649008  # from the default parameters
+   range: [0.1, 0.4]
+
+ TCM_first_step_weight:
+   type: uniform
+   default: 0.2866529225304586  # from the default parameters
+   range: [0, 0.5]
+
+ longterm_reid_weight:
+   type: uniform
+   default: 0.0509704360503877  # from the default parameters
+   range: [0, 0.5]
+
+ use_byte:
+   type: choice
+   default: False  # from the default parameters
+   options: [True, False]
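All of the tracker configs above share one schema per hyperparameter: a `type` (`uniform`, `randint`, or `choice`), a `default`, and either a `range` or an `options`/`choices` list. The project's actual tuning integration is not part of this view; the sketch below only shows, under that assumption, how such a file could be sampled with PyYAML and the standard library.

```python
# Hedged sketch: random sampling from a boxmot-style tuning YAML.
import random
import yaml

def sample_params(cfg_path):
    """Draw one configuration from a search-space YAML like the ones above."""
    with open(cfg_path) as f:
        space = yaml.safe_load(f)

    params = {}
    for key, spec in space.items():
        kind = spec.get("type")
        if kind == "uniform":
            lo, hi = spec["range"][:2]
            params[key] = random.uniform(lo, hi)
        elif kind == "randint":
            lo, hi = spec["range"][:2]
            step = spec["range"][2] if len(spec["range"]) > 2 else 1  # optional step, upper bound exclusive
            params[key] = random.randrange(lo, hi, step)
        elif kind == "choice":
            opts = spec.get("options", spec.get("choices", [spec["default"]]))
            params[key] = random.choice(opts)
        else:
            params[key] = spec["default"]  # unknown type: fall back to the default
    return params

# e.g. sample_params("boxmot/configs/bytetrack.yaml")
```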