Commit e1832f4 by usiddiquee · Parent(s): 3b054ae
Commit message: hi

This view is limited to 50 files because it contains too many changes.

Files changed:
- app.py +176 -0
- boxmot/__init__.py +21 -0
- boxmot/appearance/__init__.py +0 -0
- boxmot/appearance/backbones/__init__.py +1 -0
- boxmot/appearance/backbones/clip/__init__.py +1 -0
- boxmot/appearance/backbones/clip/clip/__init__.py +1 -0
- boxmot/appearance/backbones/clip/clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- boxmot/appearance/backbones/clip/clip/clip.py +222 -0
- boxmot/appearance/backbones/clip/clip/model.py +504 -0
- boxmot/appearance/backbones/clip/clip/simple_tokenizer.py +136 -0
- boxmot/appearance/backbones/clip/config/__init__.py +1 -0
- boxmot/appearance/backbones/clip/config/defaults.py +239 -0
- boxmot/appearance/backbones/clip/config/defaults_base.py +190 -0
- boxmot/appearance/backbones/clip/make_model.py +161 -0
- boxmot/appearance/backbones/clip/make_model_clipreid.py +247 -0
- boxmot/appearance/backbones/hacnn.py +406 -0
- boxmot/appearance/backbones/lmbn/__init__.py +1 -0
- boxmot/appearance/backbones/lmbn/attention.py +281 -0
- boxmot/appearance/backbones/lmbn/bnneck.py +166 -0
- boxmot/appearance/backbones/lmbn/lmbn_n.py +185 -0
- boxmot/appearance/backbones/mlfn.py +240 -0
- boxmot/appearance/backbones/mobilenetv2.py +246 -0
- boxmot/appearance/backbones/osnet.py +560 -0
- boxmot/appearance/backbones/osnet_ain.py +582 -0
- boxmot/appearance/backbones/resnet.py +517 -0
- boxmot/appearance/backends/base_backend.py +135 -0
- boxmot/appearance/backends/onnx_backend.py +42 -0
- boxmot/appearance/backends/openvino_backend.py +44 -0
- boxmot/appearance/backends/pytorch_backend.py +24 -0
- boxmot/appearance/backends/tensorrt_backend.py +126 -0
- boxmot/appearance/backends/tflite_backend.py +86 -0
- boxmot/appearance/backends/torchscript_backend.py +24 -0
- boxmot/appearance/exporters/base_exporter.py +56 -0
- boxmot/appearance/exporters/onnx_exporter.py +56 -0
- boxmot/appearance/exporters/openvino_exporter.py +26 -0
- boxmot/appearance/exporters/tensorrt_exporter.py +80 -0
- boxmot/appearance/exporters/tflite_exporter.py +37 -0
- boxmot/appearance/exporters/torchscript_exporter.py +15 -0
- boxmot/appearance/reid/__init__.py +16 -0
- boxmot/appearance/reid/auto_backend.py +128 -0
- boxmot/appearance/reid/config.py +73 -0
- boxmot/appearance/reid/export.py +227 -0
- boxmot/appearance/reid/factory.py +40 -0
- boxmot/appearance/reid/registry.py +87 -0
- boxmot/configs/__init__.py +1 -0
- boxmot/configs/boosttrack.yaml +90 -0
- boxmot/configs/botsort.yaml +39 -0
- boxmot/configs/bytetrack.yaml +24 -0
- boxmot/configs/deepocsort.yaml +74 -0
- boxmot/configs/hybridsort.yaml +49 -0
app.py
ADDED
@@ -0,0 +1,176 @@
import os
import gradio as gr
import subprocess
import tempfile
import shutil
from pathlib import Path
import sys
import importlib.util

# Ensure models directory exists
MODELS_DIR = Path("models")
os.makedirs(MODELS_DIR, exist_ok=True)

def ensure_dependencies():
    """Ensure all required dependencies are installed."""
    required_packages = [
        "ultralytics",
        "boxmot",
        "supervision"
    ]

    for package in required_packages:
        try:
            importlib.import_module(package)
            print(f"✅ {package} is installed")
        except ImportError:
            print(f"⚠️ {package} is not installed, attempting to install...")
            subprocess.run([sys.executable, "-m", "pip", "install", package], check=True)

# Apply tracker patches if tracker_patch.py exists
def apply_patches():
    patch_path = Path("tracker_patch.py")
    if patch_path.exists():
        spec = importlib.util.spec_from_file_location("tracker_patch", patch_path)
        if spec:
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            if hasattr(module, "patch_trackers"):
                module.patch_trackers()
                print("✅ Applied tracker patches")
            else:
                print("⚠️ tracker_patch.py exists but has no patch_trackers function")
    else:
        print("⚠️ tracker_patch.py not found, skipping patches")

def run_tracking(video_file, yolo_model, reid_model, tracking_method, conf_threshold):
    """Run object tracking on the uploaded video."""
    try:
        # Create temporary workspace
        with tempfile.TemporaryDirectory() as temp_dir:
            # Prepare input
            input_path = os.path.join(temp_dir, "input_video.mp4")
            shutil.copy(video_file, input_path)

            # Prepare output directory
            output_dir = os.path.join(temp_dir, "output")
            os.makedirs(output_dir, exist_ok=True)

            # Build command
            cmd = [
                "python", "tracking/track.py",
                "--yolo-model", str(MODELS_DIR / yolo_model),
                "--reid-model", str(MODELS_DIR / reid_model),
                "--tracking-method", tracking_method,
                "--source", input_path,
                "--conf", str(conf_threshold),
                "--save",
                "--project", output_dir,
                "--name", "track",
                "--exist-ok"
            ]

            # Special handling for OcSort
            if tracking_method == "ocsort":
                cmd.append("--per-class")

            # Execute tracking with error handling
            process = subprocess.run(
                cmd,
                capture_output=True,
                text=True
            )

            # Check for errors in output
            if process.returncode != 0:
                error_message = process.stderr or process.stdout
                return None, f"Error in tracking process: {error_message}"

            # Find output video
            output_files = []
            for root, _, files in os.walk(output_dir):
                for file in files:
                    if file.lower().endswith((".mp4", ".avi", ".mov")):
                        output_files.append(os.path.join(root, file))

            if not output_files:
                return None, "No output video was generated. Check if tracking was successful."

            return output_files[0], "Processing completed successfully!"

    except Exception as e:
        return None, f"Error: {str(e)}"

# Define the Gradio interface
def process_video(video_path, yolo_model, reid_model, tracking_method, conf_threshold):
    # Validate inputs
    if not video_path:
        return None, "Please upload a video file"

    output_path, status = run_tracking(
        video_path,
        yolo_model,
        reid_model,
        tracking_method,
        conf_threshold
    )

    return output_path, status

# Available models and tracking methods
yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
reid_models = ["osnet_x0_25_msmt17.pt"]
tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]

# Ensure dependencies and apply patches at startup
ensure_dependencies()
apply_patches()

# Create the Gradio interface
with gr.Blocks(title="YOLO Object Tracking") as app:
    gr.Markdown("# 🚀 YOLO Object Tracking")
    gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")

    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Input Video", sources=["upload"])

            with gr.Group():
                yolo_model = gr.Dropdown(
                    choices=yolo_models,
                    value="yolov8n.pt",
                    label="YOLO Model"
                )
                reid_model = gr.Dropdown(
                    choices=reid_models,
                    value="osnet_x0_25_msmt17.pt",
                    label="ReID Model"
                )
                tracking_method = gr.Dropdown(
                    choices=tracking_methods,
                    value="bytetrack",
                    label="Tracking Method"
                )
                conf_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.3,
                    step=0.05,
                    label="Confidence Threshold"
                )

            process_btn = gr.Button("Process Video", variant="primary")

        with gr.Column():
            output_video = gr.Video(label="Output Video with Tracking", autoplay=True)
            status_text = gr.Textbox(label="Status", value="Ready to process video")

    process_btn.click(
        fn=process_video,
        inputs=[input_video, yolo_model, reid_model, tracking_method, conf_threshold],
        outputs=[output_video, status_text]
    )

# Launch the app
if __name__ == "__main__":
    app.launch(debug=True, share=True)
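For reference, a minimal usage sketch (not from this commit): the Gradio callback above can also be driven directly from Python. The video file name below is hypothetical, and the weights are assumed to exist under models/ as the dropdown defaults suggest.

# Sketch: invoke the same callback the "Process Video" button uses.
# Importing app also runs ensure_dependencies() and apply_patches() at module level.
from app import process_video

output_path, status = process_video(
    "sample.mp4",               # hypothetical local video
    "yolov8n.pt",               # YOLO detector weights, resolved under models/
    "osnet_x0_25_msmt17.pt",    # ReID weights
    "bytetrack",                # tracking method
    0.3,                        # confidence threshold
)
print(status, output_path)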
boxmot/__init__.py
ADDED
@@ -0,0 +1,21 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

__version__ = '12.0.7'

from boxmot.postprocessing.gsi import gsi
from boxmot.tracker_zoo import create_tracker, get_tracker_config
from boxmot.trackers.botsort.botsort import BotSort
from boxmot.trackers.bytetrack.bytetrack import ByteTrack
from boxmot.trackers.deepocsort.deepocsort import DeepOcSort
from boxmot.trackers.hybridsort.hybridsort import HybridSort
from boxmot.trackers.ocsort.ocsort import OcSort
from boxmot.trackers.strongsort.strongsort import StrongSort
from boxmot.trackers.imprassoc.imprassoctrack import ImprAssocTrack
from boxmot.trackers.boosttrack.boosttrack import BoostTrack


TRACKERS = ['bytetrack', 'botsort', 'strongsort', 'ocsort', 'deepocsort', 'hybridsort', 'imprassoc', 'boosttrack']

__all__ = ("__version__",
           "StrongSort", "OcSort", "ByteTrack", "BotSort", "DeepOcSort", "HybridSort", "ImprAssocTrack", "BoostTrack",
           "create_tracker", "get_tracker_config", "gsi")
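For reference, a minimal sketch (not from this commit) of the public surface re-exported above; it only touches names defined in this __init__.py.

import boxmot

print(boxmot.__version__)   # '12.0.7'
print(boxmot.TRACKERS)      # ['bytetrack', 'botsort', 'strongsort', ...]
print(boxmot.ByteTrack)     # tracker classes are importable from the package root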
boxmot/appearance/__init__.py
ADDED
File without changes
boxmot/appearance/backbones/__init__.py
ADDED
@@ -0,0 +1 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
boxmot/appearance/backbones/clip/__init__.py
ADDED
@@ -0,0 +1 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
boxmot/appearance/backbones/clip/clip/__init__.py
ADDED
@@ -0,0 +1 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
boxmot/appearance/backbones/clip/clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
boxmot/appearance/backbones/clip/clip/clip.py
ADDED
@@ -0,0 +1,222 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

import hashlib
import os
import urllib
import warnings
from typing import List, Union

import torch
from PIL import Image
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
                                    ToTensor)
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",  # noqa: E501
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",  # noqa: E501
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",  # noqa: E501
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",  # noqa: E501
    "ViT-B-32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",  # noqa: E501
    "ViT-B-16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",  # noqa: E501
}


def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]  # ['a photo of a face.']

    sot_token = _tokenizer.encoder["<|startoftext|>"]  # 49406
    eot_token = _tokenizer.encoder["<|endoftext|>"]  # 49407
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)  # 1,77

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:  # context_length 77
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
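For reference, a minimal sketch (not from this commit) of tokenize as documented above, assuming boxmot (with its bundled BPE vocabulary), ftfy and regex are installed: strings are wrapped in start/end tokens and padded to the 77-token context length.

import torch
from boxmot.appearance.backbones.clip.clip.clip import tokenize

tokens = tokenize(["a photo of a face."])   # LongTensor of shape [1, 77]
assert tokens.shape == (1, 77)
assert tokens.dtype == torch.long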
boxmot/appearance/backbones/clip/clip/model.py
ADDED
@@ -0,0 +1,504 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        # NCHW -> (HW)NC #32,2048,7,7 ->49, 32, 2048
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC 50,32,2048
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        return x


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=1)
        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x3 = self.layer3(x)
        x4 = self.layer4(x3)
        xproj = self.attnpool(x4)

        return x3, x4, xproj


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        for param in self.parameters():
            if param.dtype == torch.float16:
                param.data = param.data.to(torch.float32)
        ret = super().forward(x.to(torch.float32))
        return ret.to(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(
        self,
        h_resolution: int,
        w_resolution: int,
        patch_size: int,
        stride_size: int,
        width: int,
        layers: int,
        heads: int,
        output_dim: int
    ):
        super().__init__()
        self.h_resolution = h_resolution
        self.w_resolution = w_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=stride_size,
            bias=False
        )

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(h_resolution * w_resolution + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor, cv_emb=None):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) +
                       # shape = [*, grid ** 2 + 1, width]
                       torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
        if cv_emb is not None:
            x[:, 0] = x[:, 0] + cv_emb
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND

        x11 = self.transformer.resblocks[:11](x)
        x12 = self.transformer.resblocks[11](x11)
        x11 = x11.permute(1, 0, 2)  # LND -> NLD
        x12 = x12.permute(1, 0, 2)  # LND -> NLD

        x12 = self.ln_post(x12)

        if self.proj is not None:
            xproj = x12 @ self.proj

        return x11, x12, xproj


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 vision_stride_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 h_resolution: int,
                 w_resolution: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=h_resolution * w_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                h_resolution=h_resolution,
                w_resolution=w_resolution,
                patch_size=vision_patch_size,
                stride_size=vision_stride_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)
        x = self.ln_final(x).type(self.dtype)

        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.float()
            if l.bias is not None:
                l.bias.data = l.bias.data.float()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.float()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.float()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict, h_resolution: int, w_resolution: int, vision_stride_size: int):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and
                             k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:  # RN50
        counts: list = [len(set(k.split(".")[2] for k in state_dict if
                                k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]  # 77 (77,512)
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size, vision_stride_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
        h_resolution, w_resolution
    )
    if vit:
        state_dict["visual.positional_embedding"] = resize_pos_embed(
            state_dict["visual.positional_embedding"],
            model.visual.positional_embedding,
            h_resolution,
            w_resolution
        )
    else:  # RN50
        state_dict["visual.attnpool.positional_embedding"] = resize_pos_embed(
            state_dict["visual.attnpool.positional_embedding"],
            model.visual.attnpool.positional_embedding,
            h_resolution,
            w_resolution
        )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)

    model.load_state_dict(state_dict)
    return model.eval()


import math


def resize_pos_embed(posemb, posemb_new, hight, width):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    print('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)

    ntok_new = posemb_new.shape[0]  # 129,2048

    posemb_token, posemb_grid = posemb[:1], posemb[1:]
    ntok_new -= 1

    gs_old = int(math.sqrt(len(posemb_grid)))  # 14
    print('Position embedding resize to height:{} width: {}'.format(hight, width))
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear')
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1)
    posemb = torch.cat([posemb_token, posemb_grid.squeeze()], dim=0)
    return posemb
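For reference (not from this commit), the additive causal mask produced by CLIP.build_attention_mask above follows this pattern; a standalone sketch with a context length of 4:

import torch

context_length = 4
mask = torch.empty(context_length, context_length)
mask.fill_(float("-inf"))
mask.triu_(1)  # -inf strictly above the diagonal, 0 elsewhere
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])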
boxmot/appearance/backbones/clip/clip/simple_tokenizer.py
ADDED
@@ -0,0 +1,136 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

import gzip
import html
from functools import lru_cache

import ftfy
import regex as re

from boxmot.utils import BOXMOT


@lru_cache()
def default_bpe():
    return BOXMOT / "appearance/backbones/clip/clip/bpe_simple_vocab_16e6.txt.gz"


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + '</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)  # noqa: E501

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except Exception:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
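For reference, a minimal round-trip sketch (not from this commit), assuming boxmot is installed so the bundled bpe_simple_vocab_16e6.txt.gz can be located via default_bpe().

from boxmot.appearance.backbones.clip.clip.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()                  # loads the bundled BPE merges
ids = tok.encode("a photo of a face.")   # list of BPE token ids
print(tok.decode(ids))                   # roughly "a photo of a face ."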
boxmot/appearance/backbones/clip/config/__init__.py
ADDED
@@ -0,0 +1 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
boxmot/appearance/backbones/clip/config/defaults.py
ADDED
@@ -0,0 +1,239 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

from yacs.config import CfgNode as CN

# -----------------------------------------------------------------------------
# Convention about Training / Test specific parameters
# -----------------------------------------------------------------------------
# Whenever an argument can be either used for training or for testing, the
# corresponding name will be post-fixed by a _TRAIN for a training parameter,

# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------

_C = CN()
# -----------------------------------------------------------------------------
# MODEL
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Using cuda or cpu for training
_C.MODEL.DEVICE = "cuda"
# ID number of GPU
_C.MODEL.DEVICE_ID = '0'
# Name of backbone
_C.MODEL.NAME = 'ViT-B-16'
# Last stride of backbone
_C.MODEL.LAST_STRIDE = 1
# Path to pretrained model of backbone
_C.MODEL.PRETRAIN_PATH = '/home/mikel.brostrom/yolo_tracking/clip_market1501.pt'

# Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model
# Options: 'imagenet' , 'self' , 'finetune'
_C.MODEL.PRETRAIN_CHOICE = 'imagenet'

# If train with BNNeck, options: 'bnneck' or 'no'
_C.MODEL.NECK = 'bnneck'
# If train loss include center loss, options: 'yes' or 'no'. Loss with center loss has different optimizer configuration
_C.MODEL.IF_WITH_CENTER = 'no'

_C.MODEL.ID_LOSS_TYPE = 'softmax'
_C.MODEL.ID_LOSS_WEIGHT = 1.0
_C.MODEL.TRIPLET_LOSS_WEIGHT = 1.0
_C.MODEL.I2T_LOSS_WEIGHT = 1.0

_C.MODEL.METRIC_LOSS_TYPE = 'triplet'
# If train with multi-gpu ddp mode, options: 'True', 'False'
_C.MODEL.DIST_TRAIN = False
# If train with soft triplet loss, options: 'True', 'False'
_C.MODEL.NO_MARGIN = False
# If train with label smooth, options: 'on', 'off'
_C.MODEL.IF_LABELSMOOTH = 'on'
# If train with arcface loss, options: 'True', 'False'
_C.MODEL.COS_LAYER = False

# Transformer setting
_C.MODEL.DROP_PATH = 0.1
_C.MODEL.DROP_OUT = 0.0
_C.MODEL.ATT_DROP_RATE = 0.0
_C.MODEL.TRANSFORMER_TYPE = 'None'
_C.MODEL.STRIDE_SIZE = [16, 16]

# SIE Parameter
_C.MODEL.SIE_COE = 3.0
_C.MODEL.SIE_CAMERA = False
_C.MODEL.SIE_VIEW = False

# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# Size of the image during training
_C.INPUT.SIZE_TRAIN = [256, 128]
# Size of the image during test
_C.INPUT.SIZE_TEST = [256, 128]
# Random probability for image horizontal flip
_C.INPUT.PROB = 0.5
# Random probability for random erasing
_C.INPUT.RE_PROB = 0.5
# Values to be used for image normalization
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
# Values to be used for image normalization
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
# Value of padding size
_C.INPUT.PADDING = 10

# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# List of the dataset names for training, as present in paths_catalog.py
_C.DATASETS.NAMES = ('market1501')
# Root directory where datasets should be used (and downloaded if not found)
_C.DATASETS.ROOT_DIR = ('../data')


# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
_C.DATALOADER = CN()
# Number of data loading threads
_C.DATALOADER.NUM_WORKERS = 8
# Sampler for data loading
_C.DATALOADER.SAMPLER = 'softmax'
# Number of instance for one batch
_C.DATALOADER.NUM_INSTANCE = 16

# ---------------------------------------------------------------------------- #
# Solver
_C.SOLVER = CN()
_C.SOLVER.SEED = 1234
_C.SOLVER.MARGIN = 0.3

# stage1
# ---------------------------------------------------------------------------- #
# Name of optimizer
_C.SOLVER.STAGE1 = CN()

_C.SOLVER.STAGE1.IMS_PER_BATCH = 64

_C.SOLVER.STAGE1.OPTIMIZER_NAME = "Adam"
# Number of max epoches
_C.SOLVER.STAGE1.MAX_EPOCHS = 100
# Base learning rate
_C.SOLVER.STAGE1.BASE_LR = 3e-4
# Momentum
_C.SOLVER.STAGE1.MOMENTUM = 0.9

# Settings of weight decay
_C.SOLVER.STAGE1.WEIGHT_DECAY = 0.0005
_C.SOLVER.STAGE1.WEIGHT_DECAY_BIAS = 0.0005

# warm up factor
_C.SOLVER.STAGE1.WARMUP_FACTOR = 0.01
# warm up epochs
_C.SOLVER.STAGE1.WARMUP_EPOCHS = 5
_C.SOLVER.STAGE1.WARMUP_LR_INIT = 0.01
_C.SOLVER.STAGE1.LR_MIN = 0.000016

_C.SOLVER.STAGE1.WARMUP_ITERS = 500
# method of warm up, option: 'constant','linear'
_C.SOLVER.STAGE1.WARMUP_METHOD = "linear"

_C.SOLVER.STAGE1.COSINE_MARGIN = 0.5
_C.SOLVER.STAGE1.COSINE_SCALE = 30

# epoch number of saving checkpoints
_C.SOLVER.STAGE1.CHECKPOINT_PERIOD = 10
# iteration of display training log
_C.SOLVER.STAGE1.LOG_PERIOD = 100
# epoch number of validation
# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU will
# contain 16 images per batch
# _C.SOLVER.STAGE1.IMS_PER_BATCH = 64
_C.SOLVER.STAGE1.EVAL_PERIOD = 10

# ---------------------------------------------------------------------------- #
# Solver
# stage1
# ---------------------------------------------------------------------------- #
_C.SOLVER.STAGE2 = CN()

_C.SOLVER.STAGE2.IMS_PER_BATCH = 64
# Name of optimizer
_C.SOLVER.STAGE2.OPTIMIZER_NAME = "Adam"
# Number of max epoches
_C.SOLVER.STAGE2.MAX_EPOCHS = 100
# Base learning rate
_C.SOLVER.STAGE2.BASE_LR = 3e-4
# Whether using larger learning rate for fc layer
_C.SOLVER.STAGE2.LARGE_FC_LR = False
# Factor of learning bias
_C.SOLVER.STAGE2.BIAS_LR_FACTOR = 1
# Momentum
_C.SOLVER.STAGE2.MOMENTUM = 0.9
# Margin of triplet loss
# Learning rate of SGD to learn the centers of center loss
_C.SOLVER.STAGE2.CENTER_LR = 0.5
# Balanced weight of center loss
_C.SOLVER.STAGE2.CENTER_LOSS_WEIGHT = 0.0005

# Settings of weight decay
_C.SOLVER.STAGE2.WEIGHT_DECAY = 0.0005
_C.SOLVER.STAGE2.WEIGHT_DECAY_BIAS = 0.0005

# decay rate of learning rate
_C.SOLVER.STAGE2.GAMMA = 0.1
# decay step of learning rate
_C.SOLVER.STAGE2.STEPS = (40, 70)
# warm up factor
_C.SOLVER.STAGE2.WARMUP_FACTOR = 0.01
# warm up epochs
_C.SOLVER.STAGE2.WARMUP_EPOCHS = 5
_C.SOLVER.STAGE2.WARMUP_LR_INIT = 0.01
_C.SOLVER.STAGE2.LR_MIN = 0.000016


_C.SOLVER.STAGE2.WARMUP_ITERS = 500
# method of warm up, option: 'constant','linear'
_C.SOLVER.STAGE2.WARMUP_METHOD = "linear"

_C.SOLVER.STAGE2.COSINE_MARGIN = 0.5
_C.SOLVER.STAGE2.COSINE_SCALE = 30

# epoch number of saving checkpoints
_C.SOLVER.STAGE2.CHECKPOINT_PERIOD = 10
# iteration of display training log
_C.SOLVER.STAGE2.LOG_PERIOD = 100
# epoch number of validation
_C.SOLVER.STAGE2.EVAL_PERIOD = 10
# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU will
# contain 16 images per batch

# ---------------------------------------------------------------------------- #
# TEST
# ---------------------------------------------------------------------------- #

_C.TEST = CN()
# Number of images per batch during test
_C.TEST.IMS_PER_BATCH = 128
# If test with re-ranking, options: 'True','False'
_C.TEST.RE_RANKING = False
# Path to trained model
_C.TEST.WEIGHT = ""
# Which feature of BNNeck to be used for test, before or after BNNneck, options: 'before' or 'after'
_C.TEST.NECK_FEAT = 'after'
# Whether feature is nomalized before test, if yes, it is equivalent to cosine distance
_C.TEST.FEAT_NORM = 'yes'

# Name for saving the distmat after testing.
_C.TEST.DIST_MAT = "dist_mat.npy"
# Whether calculate the eval score option: 'True', 'False'
_C.TEST.EVAL = False
# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
# Path to checkpoint and saved log of trained model
_C.OUTPUT_DIR = ""
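Typical yacs usage of the node defined above, shown as a hedged sketch (not part of the commit; it relies only on standard CfgNode calls such as clone, merge_from_list and freeze):

# Illustrative sketch: work on a clone of the module-level _C, never on _C itself.
from boxmot.appearance.backbones.clip.config.defaults import _C

cfg = _C.clone()
cfg.merge_from_list(["MODEL.NAME", "RN50"])   # override individual keys
cfg.freeze()                                  # make the node read-only before passing it around
print(cfg.MODEL.NAME, cfg.INPUT.SIZE_TRAIN)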
boxmot/appearance/backbones/clip/config/defaults_base.py
ADDED
@@ -0,0 +1,190 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

from yacs.config import CfgNode as CN

# -----------------------------------------------------------------------------
# Convention about Training / Test specific parameters
# -----------------------------------------------------------------------------
# Whenever an argument can be either used for training or for testing, the
# corresponding name will be post-fixed by a _TRAIN for a training parameter,

# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------

_C = CN()
# -----------------------------------------------------------------------------
# MODEL
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Using cuda or cpu for training
_C.MODEL.DEVICE = "cuda"
# ID number of GPU
_C.MODEL.DEVICE_ID = '0'
# Name of backbone
_C.MODEL.NAME = 'resnet50'
# Last stride of backbone
_C.MODEL.LAST_STRIDE = 1
# Path to pretrained model of backbone
_C.MODEL.PRETRAIN_PATH = ''

# Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model
# Options: 'imagenet' , 'self' , 'finetune'
_C.MODEL.PRETRAIN_CHOICE = 'imagenet'

# If train with BNNeck, options: 'bnneck' or 'no'
_C.MODEL.NECK = 'bnneck'
# If train loss include center loss, options: 'yes' or 'no'. Loss with center loss has different optimizer configuration
_C.MODEL.IF_WITH_CENTER = 'no'

_C.MODEL.ID_LOSS_TYPE = 'softmax'
_C.MODEL.ID_LOSS_WEIGHT = 1.0
_C.MODEL.TRIPLET_LOSS_WEIGHT = 1.0
_C.MODEL.I2T_LOSS_WEIGHT = 1.0

_C.MODEL.METRIC_LOSS_TYPE = 'triplet'
# If train with multi-gpu ddp mode, options: 'True', 'False'
_C.MODEL.DIST_TRAIN = False
# If train with soft triplet loss, options: 'True', 'False'
_C.MODEL.NO_MARGIN = False
# If train with label smooth, options: 'on', 'off'
_C.MODEL.IF_LABELSMOOTH = 'on'
# If train with arcface loss, options: 'True', 'False'
_C.MODEL.COS_LAYER = False

# Transformer setting
_C.MODEL.DROP_PATH = 0.1
_C.MODEL.DROP_OUT = 0.0
_C.MODEL.ATT_DROP_RATE = 0.0
_C.MODEL.TRANSFORMER_TYPE = 'None'
_C.MODEL.STRIDE_SIZE = [16, 16]

# SIE Parameter
_C.MODEL.SIE_COE = 3.0
_C.MODEL.SIE_CAMERA = False
_C.MODEL.SIE_VIEW = False

# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# Size of the image during training
_C.INPUT.SIZE_TRAIN = [384, 128]
# Size of the image during test
_C.INPUT.SIZE_TEST = [384, 128]
# Random probability for image horizontal flip
_C.INPUT.PROB = 0.5
# Random probability for random erasing
_C.INPUT.RE_PROB = 0.5
# Values to be used for image normalization
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
# Values to be used for image normalization
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
# Value of padding size
_C.INPUT.PADDING = 10

# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# List of the dataset names for training, as present in paths_catalog.py
_C.DATASETS.NAMES = ('market1501')
# Root directory where datasets should be used (and downloaded if not found)
_C.DATASETS.ROOT_DIR = ('../data')


# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
_C.DATALOADER = CN()
# Number of data loading threads
_C.DATALOADER.NUM_WORKERS = 8
# Sampler for data loading
_C.DATALOADER.SAMPLER = 'softmax'
# Number of instance for one batch
_C.DATALOADER.NUM_INSTANCE = 16

# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
_C.SOLVER = CN()
# Name of optimizer
_C.SOLVER.OPTIMIZER_NAME = "Adam"
# Number of max epoches
_C.SOLVER.MAX_EPOCHS = 100
# Base learning rate
_C.SOLVER.BASE_LR = 3e-4
# Whether using larger learning rate for fc layer
_C.SOLVER.LARGE_FC_LR = False
# Factor of learning bias
_C.SOLVER.BIAS_LR_FACTOR = 1
# Factor of learning bias
_C.SOLVER.SEED = 1234
# Momentum
_C.SOLVER.MOMENTUM = 0.9
# Margin of triplet loss
_C.SOLVER.MARGIN = 0.3
# Learning rate of SGD to learn the centers of center loss
_C.SOLVER.CENTER_LR = 0.5
# Balanced weight of center loss
_C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005

# Settings of weight decay
_C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.0005

# decay rate of learning rate
_C.SOLVER.GAMMA = 0.1
# decay step of learning rate
_C.SOLVER.STEPS = (40, 70)
# warm up factor
_C.SOLVER.WARMUP_FACTOR = 0.01
# warm up epochs
_C.SOLVER.WARMUP_EPOCHS = 5
_C.SOLVER.WARMUP_LR_INIT = 0.01
_C.SOLVER.LR_MIN = 0.000016


_C.SOLVER.WARMUP_ITERS = 500
# method of warm up, option: 'constant','linear'
_C.SOLVER.WARMUP_METHOD = "linear"

_C.SOLVER.COSINE_MARGIN = 0.5
_C.SOLVER.COSINE_SCALE = 30

# epoch number of saving checkpoints
_C.SOLVER.CHECKPOINT_PERIOD = 10
# iteration of display training log
_C.SOLVER.LOG_PERIOD = 100
# epoch number of validation
_C.SOLVER.EVAL_PERIOD = 10
# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU will
# contain 16 images per batch
_C.SOLVER.IMS_PER_BATCH = 64

# ---------------------------------------------------------------------------- #
# TEST
# ---------------------------------------------------------------------------- #

_C.TEST = CN()
# Number of images per batch during test
_C.TEST.IMS_PER_BATCH = 128
# If test with re-ranking, options: 'True','False'
_C.TEST.RE_RANKING = False
# Path to trained model
_C.TEST.WEIGHT = ""
# Which feature of BNNeck to be used for test, before or after BNNneck, options: 'before' or 'after'
_C.TEST.NECK_FEAT = 'after'
# Whether feature is nomalized before test, if yes, it is equivalent to cosine distance
_C.TEST.FEAT_NORM = 'yes'

# Name for saving the distmat after testing.
_C.TEST.DIST_MAT = "dist_mat.npy"
# Whether calculate the eval score option: 'True', 'False'
_C.TEST.EVAL = False
# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
# Path to checkpoint and saved log of trained model
_C.OUTPUT_DIR = ""
boxmot/appearance/backbones/clip/make_model.py
ADDED
@@ -0,0 +1,161 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

import torch
import torch.nn as nn

from .clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

_tokenizer = _Tokenizer()


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        nn.init.constant_(m.bias, 0.0)

    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)


def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias:
            nn.init.constant_(m.bias, 0.0)


class build_transformer(nn.Module):
    def __init__(self, num_classes, camera_num, view_num, cfg):
        super(build_transformer, self).__init__()
        self.model_name = cfg.MODEL.NAME
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT
        if self.model_name == 'ViT-B-16':
            self.in_planes = 768
            self.in_planes_proj = 512
        elif self.model_name == 'RN50':
            self.in_planes = 2048
            self.in_planes_proj = 1024
        self.num_classes = num_classes
        self.camera_num = camera_num
        self.view_num = view_num
        self.sie_coe = cfg.MODEL.SIE_COE

        self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
        self.classifier.apply(weights_init_classifier)
        self.classifier_proj = nn.Linear(self.in_planes_proj, self.num_classes, bias=False)
        self.classifier_proj.apply(weights_init_classifier)

        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)
        self.bottleneck_proj = nn.BatchNorm1d(self.in_planes_proj)
        self.bottleneck_proj.bias.requires_grad_(False)
        self.bottleneck_proj.apply(weights_init_kaiming)

        self.h_resolution = int((cfg.INPUT.SIZE_TRAIN[0] - 16) // cfg.MODEL.STRIDE_SIZE[0] + 1)
        self.w_resolution = int((cfg.INPUT.SIZE_TRAIN[1] - 16) // cfg.MODEL.STRIDE_SIZE[1] + 1)
        self.vision_stride_size = cfg.MODEL.STRIDE_SIZE[0]
        clip_model = load_clip_to_cpu(self.model_name, self.h_resolution, self.w_resolution, self.vision_stride_size)

        self.image_encoder = clip_model.visual

        # if cfg.MODEL.SIE_CAMERA and cfg.MODEL.SIE_VIEW:
        #     self.cv_embed = nn.Parameter(torch.zeros(camera_num * view_num, self.in_planes))
        #     trunc_normal_(self.cv_embed, std=.02)
        #     print('camera number is : {}'.format(camera_num))
        # elif cfg.MODEL.SIE_CAMERA:
        #     self.cv_embed = nn.Parameter(torch.zeros(camera_num, self.in_planes))
        #     trunc_normal_(self.cv_embed, std=.02)
        #     print('camera number is : {}'.format(camera_num))
        # elif cfg.MODEL.SIE_VIEW:
        #     self.cv_embed = nn.Parameter(torch.zeros(view_num, self.in_planes))
        #     trunc_normal_(self.cv_embed, std=.02)
        #     print('camera number is : {}'.format(view_num))

    def forward(self, x, label=None, cam_label=None, view_label=None):
        if self.model_name == 'RN50':
            image_features_last, image_features, image_features_proj = self.image_encoder(x)  # B,512 B,128,512
            img_feature_last = nn.functional.avg_pool2d(
                image_features_last,
                image_features_last.shape[2:4]).view(x.shape[0], -1)
            img_feature = nn.functional.avg_pool2d(
                image_features,
                image_features.shape[2:4]).view(x.shape[0], -1)
            img_feature_proj = image_features_proj[0]

        elif self.model_name == 'ViT-B-16':
            if cam_label is not None and view_label is not None:
                cv_embed = self.sie_coe * self.cv_embed[cam_label * self.view_num + view_label]
            elif cam_label is not None:
                cv_embed = self.sie_coe * self.cv_embed[cam_label]
            elif view_label is not None:
                cv_embed = self.sie_coe * self.cv_embed[view_label]
            else:
                cv_embed = None
            # B,512 B,128,512
            image_features_last, image_features, image_features_proj = self.image_encoder(x, cv_embed)
            img_feature_last = image_features_last[:, 0]
            img_feature = image_features[:, 0]
            img_feature_proj = image_features_proj[:, 0]

        feat = self.bottleneck(img_feature)
        feat_proj = self.bottleneck_proj(img_feature_proj)

        if self.training:
            cls_score = self.classifier(feat)
            cls_score_proj = self.classifier_proj(feat_proj)
            return [cls_score, cls_score_proj], [img_feature_last, img_feature, img_feature_proj]

        else:
            if self.neck_feat == 'after':
                # print("Test with feature after BN")
                return torch.cat([feat, feat_proj], dim=1)
            else:
                return torch.cat([img_feature, img_feature_proj], dim=1)

    def load_param(self, trained_path):
        param_dict = torch.load(trained_path, map_location=torch.device("cpu"))
        for i in self.state_dict():
            self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
        # print('Loading pretrained model from {}'.format('/home/mikel.brostrom/yolo_tracking/clip_market1501.pt'))

    def load_param_finetune(self, model_path):
        param_dict = torch.load(model_path)
        for i in param_dict:
            self.state_dict()[i].copy_(param_dict[i])
        # print('Loading pretrained model for finetuning from {}'.format(model_path))


def make_model(cfg, num_class, camera_num, view_num):
    model = build_transformer(num_class, camera_num, view_num, cfg)
    return model


from .clip import clip


def load_clip_to_cpu(backbone_name, h_resolution, w_resolution, vision_stride_size):
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None

    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")

    model = clip.build_model(state_dict or model.state_dict(), h_resolution, w_resolution, vision_stride_size)

    return model
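A hedged sketch of how this factory is typically driven by the defaults above (not part of the commit; the class count 751 is a Market-1501-style assumption, and load_clip_to_cpu downloads the official CLIP weights on first use):

# Illustrative sketch: build the ViT-B-16 CLIP ReID backbone for inference.
from boxmot.appearance.backbones.clip.config.defaults import _C
from boxmot.appearance.backbones.clip.make_model import make_model

cfg = _C.clone()                                                  # MODEL.NAME = 'ViT-B-16' by default
model = make_model(cfg, num_class=751, camera_num=6, view_num=1).eval()
# In eval mode forward() returns torch.cat([feat, feat_proj], dim=1),
# i.e. a 768 + 512 = 1280-dim embedding per crop for ViT-B-16.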
boxmot/appearance/backbones/clip/make_model_clipreid.py
ADDED
@@ -0,0 +1,247 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

import torch
import torch.nn as nn

from .clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

_tokenizer = _Tokenizer()


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        nn.init.constant_(m.bias, 0.0)

    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)


def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias:
            nn.init.constant_(m.bias, 0.0)


class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
        return x


class build_transformer(nn.Module):
    def __init__(self, num_classes, camera_num, view_num, cfg):
        super(build_transformer, self).__init__()
        self.model_name = cfg.MODEL.NAME
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT
        if self.model_name == 'ViT-B-16':
            self.in_planes = 768
            self.in_planes_proj = 512
        elif self.model_name == 'RN50':
            self.in_planes = 2048
            self.in_planes_proj = 1024
        self.num_classes = num_classes
        self.camera_num = camera_num
        self.view_num = view_num
        self.sie_coe = cfg.MODEL.SIE_COE

        self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
        self.classifier.apply(weights_init_classifier)
        self.classifier_proj = nn.Linear(self.in_planes_proj, self.num_classes, bias=False)
        self.classifier_proj.apply(weights_init_classifier)

        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)
        self.bottleneck_proj = nn.BatchNorm1d(self.in_planes_proj)
        self.bottleneck_proj.bias.requires_grad_(False)
        self.bottleneck_proj.apply(weights_init_kaiming)

        self.h_resolution = int((cfg.INPUT.SIZE_TRAIN[0] - 16) // cfg.MODEL.STRIDE_SIZE[0] + 1)
        self.w_resolution = int((cfg.INPUT.SIZE_TRAIN[1] - 16) // cfg.MODEL.STRIDE_SIZE[1] + 1)
        self.vision_stride_size = cfg.MODEL.STRIDE_SIZE[0]
        clip_model = load_clip_to_cpu(self.model_name, self.h_resolution, self.w_resolution, self.vision_stride_size)

        self.image_encoder = clip_model.visual

        # if cfg.MODEL.SIE_CAMERA and cfg.MODEL.SIE_VIEW:
        #     self.cv_embed = nn.Parameter(torch.zeros(camera_num * view_num, self.in_planes))
        #     trunc_normal_(self.cv_embed, std=.02)
        #     print('camera number is : {}'.format(camera_num))
        # elif cfg.MODEL.SIE_CAMERA:
        #     self.cv_embed = nn.Parameter(torch.zeros(camera_num, self.in_planes))
        #     trunc_normal_(self.cv_embed, std=.02)
        #     print('camera number is : {}'.format(camera_num))
        # elif cfg.MODEL.SIE_VIEW:
        #     self.cv_embed = nn.Parameter(torch.zeros(view_num, self.in_planes))
        #     trunc_normal_(self.cv_embed, std=.02)
        #     print('camera number is : {}'.format(view_num))

        dataset_name = cfg.DATASETS.NAMES
        self.prompt_learner = PromptLearner(num_classes, dataset_name, clip_model.dtype, clip_model.token_embedding)
        self.text_encoder = TextEncoder(clip_model)

    def forward(self, x=None, label=None, get_image=False, get_text=False, cam_label=None, view_label=None):
        if get_text is True:
            prompts = self.prompt_learner(label)
            text_features = self.text_encoder(prompts, self.prompt_learner.tokenized_prompts)
            return text_features

        if get_image is True:
            image_features_last, image_features, image_features_proj = self.image_encoder(x)
            if self.model_name == 'RN50':
                return image_features_proj[0]
            elif self.model_name == 'ViT-B-16':
                return image_features_proj[:, 0]

        if self.model_name == 'RN50':
            image_features_last, image_features, image_features_proj = self.image_encoder(x)
            img_feature_last = nn.functional.avg_pool2d(
                image_features_last,
                image_features_last.shape[2:4]).view(x.shape[0], -1)
            img_feature = nn.functional.avg_pool2d(
                image_features,
                image_features.shape[2:4]).view(x.shape[0], -1)
            img_feature_proj = image_features_proj[0]

        elif self.model_name == 'ViT-B-16':
            if cam_label is not None and view_label is not None:
                cv_embed = self.sie_coe * self.cv_embed[cam_label * self.view_num + view_label]
            elif cam_label is not None:
                cv_embed = self.sie_coe * self.cv_embed[cam_label]
            elif view_label is not None:
                cv_embed = self.sie_coe * self.cv_embed[view_label]
            else:
                cv_embed = None
            image_features_last, image_features, image_features_proj = self.image_encoder(x, cv_embed)
            img_feature_last = image_features_last[:, 0]
            img_feature = image_features[:, 0]
            img_feature_proj = image_features_proj[:, 0]

        feat = self.bottleneck(img_feature)
        feat_proj = self.bottleneck_proj(img_feature_proj)

        if self.training:
            cls_score = self.classifier(feat)
            cls_score_proj = self.classifier_proj(feat_proj)
            return [cls_score, cls_score_proj], [img_feature_last, img_feature, img_feature_proj], img_feature_proj

        else:
            if self.neck_feat == 'after':
                # print("Test with feature after BN")
                return torch.cat([feat, feat_proj], dim=1)
            else:
                return torch.cat([img_feature, img_feature_proj], dim=1)

    def load_param(self, trained_path):
        param_dict = torch.load(trained_path)
        for i in param_dict:
            self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
        print('Loaded pretrained model from {}'.format(trained_path))

    def load_param_finetune(self, model_path):
        param_dict = torch.load(model_path)
        for i in param_dict:
            self.state_dict()[i].copy_(param_dict[i])
        print('Loading pretrained model for finetuning from {}'.format(model_path))


def make_model(cfg, num_class, camera_num, view_num):
    model = build_transformer(num_class, camera_num, view_num, cfg)
    return model


from .clip import clip


def load_clip_to_cpu(backbone_name, h_resolution, w_resolution, vision_stride_size):
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None

    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")

    model = clip.build_model(state_dict or model.state_dict(), h_resolution, w_resolution, vision_stride_size)

    return model


class PromptLearner(nn.Module):
    def __init__(self, num_class, dataset_name, dtype, token_embedding):
        super().__init__()
        if dataset_name == "VehicleID" or dataset_name == "veri":
            ctx_init = "A photo of a X X X X vehicle."
        else:
            ctx_init = "A photo of a X X X X person."

        ctx_dim = 512
        # use given words to initialize context vectors
        ctx_init = ctx_init.replace("_", " ")
        n_ctx = 4

        tokenized_prompts = clip.tokenize(ctx_init).cuda()
        with torch.no_grad():
            embedding = token_embedding(tokenized_prompts).type(dtype)
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor

        n_cls_ctx = 4
        cls_vectors = torch.empty(num_class, n_cls_ctx, ctx_dim, dtype=dtype)
        nn.init.normal_(cls_vectors, std=0.02)
        self.cls_ctx = nn.Parameter(cls_vectors)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :n_ctx + 1, :])
        self.register_buffer("token_suffix", embedding[:, n_ctx + 1 + n_cls_ctx:, :])
        self.num_class = num_class
        self.n_cls_ctx = n_cls_ctx

    def forward(self, label):
        cls_ctx = self.cls_ctx[label]
        b = label.shape[0]
        prefix = self.token_prefix.expand(b, -1, -1)
        suffix = self.token_suffix.expand(b, -1, -1)

        prompts = torch.cat(
            [
                prefix,  # (n_cls, 1, dim)
                cls_ctx,  # (n_cls, n_ctx, dim)
                suffix,  # (n_cls, *, dim)
            ],
            dim=1,
        )

        return prompts
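The prompt assembly in PromptLearner.forward can be summarised with a shape sketch, using dummy tensors in place of the real CLIP embeddings (an illustration, not part of the commit):

# Illustrative sketch of the concatenation done by PromptLearner.forward.
import torch

batch, ctx_dim, n_ctx, n_cls_ctx = 8, 512, 4, 4
prefix = torch.zeros(1, n_ctx + 1, ctx_dim).expand(batch, -1, -1)    # SOT + "a photo of a"
cls_ctx = torch.zeros(batch, n_cls_ctx, ctx_dim)                     # learned per-identity context tokens
suffix = torch.zeros(1, 77 - (n_ctx + 1) - n_cls_ctx, ctx_dim).expand(batch, -1, -1)  # "person." + EOT + padding
prompts = torch.cat([prefix, cls_ctx, suffix], dim=1)                # (batch, 77, ctx_dim), fed to TextEncoder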
boxmot/appearance/backbones/hacnn.py
ADDED
@@ -0,0 +1,406 @@
| 1 |
+
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
from __future__ import absolute_import, division
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
__all__ = ["HACNN"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ConvBlock(nn.Module):
|
| 13 |
+
"""Basic convolutional block.
|
| 14 |
+
|
| 15 |
+
convolution + batch normalization + relu.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
in_c (int): number of input channels.
|
| 19 |
+
out_c (int): number of output channels.
|
| 20 |
+
k (int or tuple): kernel size.
|
| 21 |
+
s (int or tuple): stride.
|
| 22 |
+
p (int or tuple): padding.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, in_c, out_c, k, s=1, p=0):
|
| 26 |
+
super(ConvBlock, self).__init__()
|
| 27 |
+
self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
|
| 28 |
+
self.bn = nn.BatchNorm2d(out_c)
|
| 29 |
+
|
| 30 |
+
def forward(self, x):
|
| 31 |
+
return F.relu(self.bn(self.conv(x)))
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class InceptionA(nn.Module):
|
| 35 |
+
def __init__(self, in_channels, out_channels):
|
| 36 |
+
super(InceptionA, self).__init__()
|
| 37 |
+
mid_channels = out_channels // 4
|
| 38 |
+
|
| 39 |
+
self.stream1 = nn.Sequential(
|
| 40 |
+
ConvBlock(in_channels, mid_channels, 1),
|
| 41 |
+
ConvBlock(mid_channels, mid_channels, 3, p=1),
|
| 42 |
+
)
|
| 43 |
+
self.stream2 = nn.Sequential(
|
| 44 |
+
ConvBlock(in_channels, mid_channels, 1),
|
| 45 |
+
ConvBlock(mid_channels, mid_channels, 3, p=1),
|
| 46 |
+
)
|
| 47 |
+
self.stream3 = nn.Sequential(
|
| 48 |
+
ConvBlock(in_channels, mid_channels, 1),
|
| 49 |
+
ConvBlock(mid_channels, mid_channels, 3, p=1),
|
| 50 |
+
)
|
| 51 |
+
self.stream4 = nn.Sequential(
|
| 52 |
+
nn.AvgPool2d(3, stride=1, padding=1),
|
| 53 |
+
ConvBlock(in_channels, mid_channels, 1),
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
def forward(self, x):
|
| 57 |
+
s1 = self.stream1(x)
|
| 58 |
+
s2 = self.stream2(x)
|
| 59 |
+
s3 = self.stream3(x)
|
| 60 |
+
s4 = self.stream4(x)
|
| 61 |
+
y = torch.cat([s1, s2, s3, s4], dim=1)
|
| 62 |
+
return y
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class InceptionB(nn.Module):
|
| 66 |
+
def __init__(self, in_channels, out_channels):
|
| 67 |
+
super(InceptionB, self).__init__()
|
| 68 |
+
mid_channels = out_channels // 4
|
| 69 |
+
|
| 70 |
+
self.stream1 = nn.Sequential(
|
| 71 |
+
ConvBlock(in_channels, mid_channels, 1),
|
| 72 |
+
ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
|
| 73 |
+
)
|
| 74 |
+
self.stream2 = nn.Sequential(
|
| 75 |
+
ConvBlock(in_channels, mid_channels, 1),
|
| 76 |
+
ConvBlock(mid_channels, mid_channels, 3, p=1),
|
| 77 |
+
ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
|
| 78 |
+
)
|
| 79 |
+
self.stream3 = nn.Sequential(
|
| 80 |
+
nn.MaxPool2d(3, stride=2, padding=1),
|
| 81 |
+
ConvBlock(in_channels, mid_channels * 2, 1),
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
def forward(self, x):
|
| 85 |
+
s1 = self.stream1(x)
|
| 86 |
+
s2 = self.stream2(x)
|
| 87 |
+
s3 = self.stream3(x)
|
| 88 |
+
y = torch.cat([s1, s2, s3], dim=1)
|
| 89 |
+
return y
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class SpatialAttn(nn.Module):
|
| 93 |
+
"""Spatial Attention (Sec. 3.1.I.1)"""
|
| 94 |
+
|
| 95 |
+
def __init__(self):
|
| 96 |
+
super(SpatialAttn, self).__init__()
|
| 97 |
+
self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
|
| 98 |
+
self.conv2 = ConvBlock(1, 1, 1)
|
| 99 |
+
|
| 100 |
+
def forward(self, x):
|
| 101 |
+
# global cross-channel averaging
|
| 102 |
+
x = x.mean(1, keepdim=True)
|
| 103 |
+
# 3-by-3 conv
|
| 104 |
+
x = self.conv1(x)
|
| 105 |
+
# bilinear resizing
|
| 106 |
+
x = F.upsample(
|
| 107 |
+
x, (x.size(2) * 2, x.size(3) * 2), mode="bilinear", align_corners=True
|
| 108 |
+
)
|
| 109 |
+
# scaling conv
|
| 110 |
+
x = self.conv2(x)
|
| 111 |
+
return x
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class ChannelAttn(nn.Module):
|
| 115 |
+
"""Channel Attention (Sec. 3.1.I.2)"""
|
| 116 |
+
|
| 117 |
+
def __init__(self, in_channels, reduction_rate=16):
|
| 118 |
+
super(ChannelAttn, self).__init__()
|
| 119 |
+
assert in_channels % reduction_rate == 0
|
| 120 |
+
self.conv1 = ConvBlock(in_channels, in_channels // reduction_rate, 1)
|
| 121 |
+
self.conv2 = ConvBlock(in_channels // reduction_rate, in_channels, 1)
|
| 122 |
+
|
| 123 |
+
def forward(self, x):
|
| 124 |
+
# squeeze operation (global average pooling)
|
| 125 |
+
x = F.avg_pool2d(x, x.size()[2:])
|
| 126 |
+
# excitation operation (2 conv layers)
|
| 127 |
+
x = self.conv1(x)
|
| 128 |
+
x = self.conv2(x)
|
| 129 |
+
return x
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class SoftAttn(nn.Module):
|
| 133 |
+
"""Soft Attention (Sec. 3.1.I)
|
| 134 |
+
|
| 135 |
+
Aim: Spatial Attention + Channel Attention
|
| 136 |
+
|
| 137 |
+
Output: attention maps with shape identical to input.
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
def __init__(self, in_channels):
|
| 141 |
+
super(SoftAttn, self).__init__()
|
| 142 |
+
self.spatial_attn = SpatialAttn()
|
| 143 |
+
self.channel_attn = ChannelAttn(in_channels)
|
| 144 |
+
self.conv = ConvBlock(in_channels, in_channels, 1)
|
| 145 |
+
|
| 146 |
+
def forward(self, x):
|
| 147 |
+
y_spatial = self.spatial_attn(x)
|
| 148 |
+
y_channel = self.channel_attn(x)
|
| 149 |
+
y = y_spatial * y_channel
|
| 150 |
+
y = torch.sigmoid(self.conv(y))
|
| 151 |
+
return y
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class HardAttn(nn.Module):
|
| 155 |
+
"""Hard Attention (Sec. 3.1.II)"""
|
| 156 |
+
|
| 157 |
+
def __init__(self, in_channels):
|
| 158 |
+
super(HardAttn, self).__init__()
|
| 159 |
+
self.fc = nn.Linear(in_channels, 4 * 2)
|
| 160 |
+
self.init_params()
|
| 161 |
+
|
| 162 |
+
def init_params(self):
|
| 163 |
+
self.fc.weight.data.zero_()
|
| 164 |
+
self.fc.bias.data.copy_(
|
| 165 |
+
torch.tensor([0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float)
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
def forward(self, x):
|
| 169 |
+
# squeeze operation (global average pooling)
|
| 170 |
+
x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
|
| 171 |
+
# predict transformation parameters
|
| 172 |
+
theta = torch.tanh(self.fc(x))
|
| 173 |
+
theta = theta.view(-1, 4, 2)
|
| 174 |
+
return theta
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class HarmAttn(nn.Module):
|
| 178 |
+
"""Harmonious Attention (Sec. 3.1)"""
|
| 179 |
+
|
| 180 |
+
def __init__(self, in_channels):
|
| 181 |
+
super(HarmAttn, self).__init__()
|
| 182 |
+
self.soft_attn = SoftAttn(in_channels)
|
| 183 |
+
self.hard_attn = HardAttn(in_channels)
|
| 184 |
+
|
| 185 |
+
def forward(self, x):
|
| 186 |
+
y_soft_attn = self.soft_attn(x)
|
| 187 |
+
theta = self.hard_attn(x)
|
| 188 |
+
return y_soft_attn, theta
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class HACNN(nn.Module):
|
| 192 |
+
"""Harmonious Attention Convolutional Neural Network.
|
| 193 |
+
|
| 194 |
+
Reference:
|
| 195 |
+
Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.
|
| 196 |
+
|
| 197 |
+
Public keys:
|
| 198 |
+
- ``hacnn``: HACNN.
|
| 199 |
+
"""
|
| 200 |
+
|
| 201 |
+
# Args:
|
| 202 |
+
# num_classes (int): number of classes to predict
|
| 203 |
+
# nchannels (list): number of channels AFTER concatenation
|
| 204 |
+
# feat_dim (int): feature dimension for a single stream
|
| 205 |
+
# learn_region (bool): whether to learn region features (i.e. local branch)
|
| 206 |
+
|
| 207 |
+
def __init__(
|
| 208 |
+
self,
|
| 209 |
+
num_classes,
|
| 210 |
+
loss="softmax",
|
| 211 |
+
nchannels=[128, 256, 384],
|
| 212 |
+
feat_dim=512,
|
| 213 |
+
learn_region=True,
|
| 214 |
+
use_gpu=True,
|
| 215 |
+
**kwargs
|
| 216 |
+
):
|
| 217 |
+
super(HACNN, self).__init__()
|
| 218 |
+
self.loss = loss
|
| 219 |
+
self.learn_region = learn_region
|
| 220 |
+
self.use_gpu = use_gpu
|
| 221 |
+
|
| 222 |
+
self.conv = ConvBlock(3, 32, 3, s=2, p=1)
|
| 223 |
+
|
| 224 |
+
# Construct Inception + HarmAttn blocks
|
| 225 |
+
# ============== Block 1 ==============
|
| 226 |
+
self.inception1 = nn.Sequential(
|
| 227 |
+
InceptionA(32, nchannels[0]),
|
| 228 |
+
InceptionB(nchannels[0], nchannels[0]),
|
| 229 |
+
)
|
| 230 |
+
self.ha1 = HarmAttn(nchannels[0])
|
| 231 |
+
|
| 232 |
+
# ============== Block 2 ==============
|
| 233 |
+
self.inception2 = nn.Sequential(
|
| 234 |
+
InceptionA(nchannels[0], nchannels[1]),
|
| 235 |
+
InceptionB(nchannels[1], nchannels[1]),
|
| 236 |
+
)
|
| 237 |
+
self.ha2 = HarmAttn(nchannels[1])
|
| 238 |
+
|
| 239 |
+
# ============== Block 3 ==============
|
| 240 |
+
self.inception3 = nn.Sequential(
|
| 241 |
+
InceptionA(nchannels[1], nchannels[2]),
|
| 242 |
+
InceptionB(nchannels[2], nchannels[2]),
|
| 243 |
+
)
|
| 244 |
+
self.ha3 = HarmAttn(nchannels[2])
|
| 245 |
+
|
| 246 |
+
self.fc_global = nn.Sequential(
|
| 247 |
+
nn.Linear(nchannels[2], feat_dim),
|
| 248 |
+
nn.BatchNorm1d(feat_dim),
|
| 249 |
+
nn.ReLU(),
|
| 250 |
+
)
|
| 251 |
+
self.classifier_global = nn.Linear(feat_dim, num_classes)
|
| 252 |
+
|
| 253 |
+
if self.learn_region:
|
| 254 |
+
self.init_scale_factors()
|
| 255 |
+
self.local_conv1 = InceptionB(32, nchannels[0])
|
| 256 |
+
self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
|
| 257 |
+
self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
|
| 258 |
+
self.fc_local = nn.Sequential(
|
| 259 |
+
nn.Linear(nchannels[2] * 4, feat_dim),
|
| 260 |
+
nn.BatchNorm1d(feat_dim),
|
| 261 |
+
nn.ReLU(),
|
| 262 |
+
)
|
| 263 |
+
self.classifier_local = nn.Linear(feat_dim, num_classes)
|
| 264 |
+
self.feat_dim = feat_dim * 2
|
| 265 |
+
else:
|
| 266 |
+
self.feat_dim = feat_dim
|
| 267 |
+
|
| 268 |
+
def init_scale_factors(self):
|
| 269 |
+
        # initialize scale factors (s_w, s_h) for four regions
        self.scale_factors = []
        self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
        self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
        self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
        self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))

    def stn(self, x, theta):
        """Performs spatial transform

        x: (batch, channel, height, width)
        theta: (batch, 2, 3)
        """
        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        return x

    def transform_theta(self, theta_i, region_idx):
        """Transforms theta to include (s_w, s_h), resulting in (batch, 2, 3)"""
        scale_factors = self.scale_factors[region_idx]
        theta = torch.zeros(theta_i.size(0), 2, 3)
        theta[:, :, :2] = scale_factors
        theta[:, :, -1] = theta_i
        if self.use_gpu:
            theta = theta.cuda()
        return theta

    def forward(self, x):
        assert (
            x.size(2) == 160 and x.size(3) == 64
        ), "Input size does not match, expected (160, 64) but got ({}, {})".format(
            x.size(2), x.size(3)
        )
        x = self.conv(x)

        # ============== Block 1 ==============
        # global branch
        x1 = self.inception1(x)
        x1_attn, x1_theta = self.ha1(x1)
        x1_out = x1 * x1_attn
        # local branch
        if self.learn_region:
            x1_local_list = []
            for region_idx in range(4):
                x1_theta_i = x1_theta[:, region_idx, :]
                x1_theta_i = self.transform_theta(x1_theta_i, region_idx)
                x1_trans_i = self.stn(x, x1_theta_i)
                x1_trans_i = F.upsample(
                    x1_trans_i, (24, 28), mode="bilinear", align_corners=True
                )
                x1_local_i = self.local_conv1(x1_trans_i)
                x1_local_list.append(x1_local_i)

        # ============== Block 2 ==============
        # global branch
        x2 = self.inception2(x1_out)
        x2_attn, x2_theta = self.ha2(x2)
        x2_out = x2 * x2_attn
        # local branch
        if self.learn_region:
            x2_local_list = []
            for region_idx in range(4):
                x2_theta_i = x2_theta[:, region_idx, :]
                x2_theta_i = self.transform_theta(x2_theta_i, region_idx)
                x2_trans_i = self.stn(x1_out, x2_theta_i)
                x2_trans_i = F.upsample(
                    x2_trans_i, (12, 14), mode="bilinear", align_corners=True
                )
                x2_local_i = x2_trans_i + x1_local_list[region_idx]
                x2_local_i = self.local_conv2(x2_local_i)
                x2_local_list.append(x2_local_i)

        # ============== Block 3 ==============
        # global branch
        x3 = self.inception3(x2_out)
        x3_attn, x3_theta = self.ha3(x3)
        x3_out = x3 * x3_attn
        # local branch
        if self.learn_region:
            x3_local_list = []
            for region_idx in range(4):
                x3_theta_i = x3_theta[:, region_idx, :]
                x3_theta_i = self.transform_theta(x3_theta_i, region_idx)
                x3_trans_i = self.stn(x2_out, x3_theta_i)
                x3_trans_i = F.upsample(
                    x3_trans_i, (6, 7), mode="bilinear", align_corners=True
                )
                x3_local_i = x3_trans_i + x2_local_list[region_idx]
                x3_local_i = self.local_conv3(x3_local_i)
                x3_local_list.append(x3_local_i)

        # ============== Feature generation ==============
        # global branch
        x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(
            x3_out.size(0), x3_out.size(1)
        )
        x_global = self.fc_global(x_global)
        # local branch
        if self.learn_region:
            x_local_list = []
            for region_idx in range(4):
                x_local_i = x3_local_list[region_idx]
                x_local_i = F.avg_pool2d(x_local_i, x_local_i.size()[2:]).view(
                    x_local_i.size(0), -1
                )
                x_local_list.append(x_local_i)
            x_local = torch.cat(x_local_list, 1)
            x_local = self.fc_local(x_local)

        if not self.training:
            # l2 normalization before concatenation
            if self.learn_region:
                x_global = x_global / x_global.norm(p=2, dim=1, keepdim=True)
                x_local = x_local / x_local.norm(p=2, dim=1, keepdim=True)
                return torch.cat([x_global, x_local], 1)
            else:
                return x_global

        prelogits_global = self.classifier_global(x_global)
        if self.learn_region:
            prelogits_local = self.classifier_local(x_local)

        if self.loss == "softmax":
            if self.learn_region:
                return (prelogits_global, prelogits_local)
            else:
                return prelogits_global

        elif self.loss == "triplet":
            if self.learn_region:
                return (prelogits_global, prelogits_local), (x_global, x_local)
            else:
                return prelogits_global, x_global

        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))
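The forward pass above realises the regional branch with a spatial transformer: transform_theta freezes the scale block of the 2x3 affine matrix to the per-region constant [[1, 0], [0, 0.25]] and fills only the translation column from the learned attention output, and stn then resamples the feature map via affine_grid/grid_sample. A minimal standalone sketch of that mechanism; the helper names and toy shapes below are illustrative, not part of this file:

import torch
import torch.nn.functional as F

def build_theta(translation, scale=(1.0, 0.25)):
    """Assemble a (batch, 2, 3) affine matrix from a (batch, 2) translation,
    mirroring transform_theta above: the scale block is fixed, only t_x/t_y vary."""
    theta = torch.zeros(translation.size(0), 2, 3)
    theta[:, 0, 0] = scale[0]   # s_w
    theta[:, 1, 1] = scale[1]   # s_h
    theta[:, :, -1] = translation
    return theta

def crop_region(x, theta):
    """Spatial transform as in stn above: sample x on the grid defined by theta."""
    grid = F.affine_grid(theta, x.size(), align_corners=False)
    return F.grid_sample(x, grid, align_corners=False)

if __name__ == "__main__":
    feat = torch.randn(2, 128, 24, 28)        # toy feature map (illustrative shape)
    t = torch.tanh(torch.randn(2, 2))         # stand-in for one region's (t_x, t_y)
    region = crop_region(feat, build_theta(t))
    print(region.shape)                       # torch.Size([2, 128, 24, 28])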
boxmot/appearance/backbones/lmbn/__init__.py
ADDED
|
@@ -0,0 +1 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
boxmot/appearance/backbones/lmbn/attention.py
ADDED
|
@@ -0,0 +1,281 @@
| 1 |
+
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.nn import Conv2d, Module, Parameter, ReLU, Sigmoid, Softmax
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
|
| 11 |
+
torch_ver = torch.__version__[:3]
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
"BatchDrop",
|
| 15 |
+
"BatchFeatureErase_Top",
|
| 16 |
+
"BatchRandomErasing",
|
| 17 |
+
"PAM_Module",
|
| 18 |
+
"CAM_Module",
|
| 19 |
+
"Dual_Module",
|
| 20 |
+
"SE_Module",
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class BatchRandomErasing(nn.Module):
|
| 25 |
+
def __init__(
|
| 26 |
+
self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]
|
| 27 |
+
):
|
| 28 |
+
super(BatchRandomErasing, self).__init__()
|
| 29 |
+
|
| 30 |
+
self.probability = probability
|
| 31 |
+
self.mean = mean
|
| 32 |
+
self.sl = sl
|
| 33 |
+
self.sh = sh
|
| 34 |
+
self.r1 = r1
|
| 35 |
+
|
| 36 |
+
def forward(self, img):
|
| 37 |
+
if self.training:
|
| 38 |
+
if random.uniform(0, 1) > self.probability:
|
| 39 |
+
return img
|
| 40 |
+
|
| 41 |
+
for attempt in range(100):
|
| 42 |
+
area = img.size()[2] * img.size()[3]
|
| 43 |
+
|
| 44 |
+
target_area = random.uniform(self.sl, self.sh) * area
|
| 45 |
+
aspect_ratio = random.uniform(self.r1, 1 / self.r1)
|
| 46 |
+
|
| 47 |
+
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
| 48 |
+
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
| 49 |
+
|
| 50 |
+
if w < img.size()[3] and h < img.size()[2]:
|
| 51 |
+
x1 = random.randint(0, img.size()[2] - h)
|
| 52 |
+
y1 = random.randint(0, img.size()[3] - w)
|
| 53 |
+
if img.size()[1] == 3:
|
| 54 |
+
img[:, 0, x1: x1 + h, y1: y1 + w] = self.mean[0]
|
| 55 |
+
img[:, 1, x1: x1 + h, y1: y1 + w] = self.mean[1]
|
| 56 |
+
img[:, 2, x1: x1 + h, y1: y1 + w] = self.mean[2]
|
| 57 |
+
else:
|
| 58 |
+
img[:, 0, x1: x1 + h, y1: y1 + w] = self.mean[0]
|
| 59 |
+
return img
|
| 60 |
+
|
| 61 |
+
return img
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class BatchDrop(nn.Module):
|
| 65 |
+
"""
|
| 66 |
+
Ref: Batch DropBlock Network for Person Re-identification and Beyond
|
| 67 |
+
https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py
|
| 68 |
+
Created by: daizuozhuo
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def __init__(self, h_ratio, w_ratio):
|
| 72 |
+
super(BatchDrop, self).__init__()
|
| 73 |
+
self.h_ratio = h_ratio
|
| 74 |
+
self.w_ratio = w_ratio
|
| 75 |
+
|
| 76 |
+
def forward(self, x):
|
| 77 |
+
if self.training:
|
| 78 |
+
h, w = x.size()[-2:]
|
| 79 |
+
rh = round(self.h_ratio * h)
|
| 80 |
+
rw = round(self.w_ratio * w)
|
| 81 |
+
sx = random.randint(0, h - rh)
|
| 82 |
+
sy = random.randint(0, w - rw)
|
| 83 |
+
mask = x.new_ones(x.size())
|
| 84 |
+
mask[:, :, sx: sx + rh, sy: sy + rw] = 0
|
| 85 |
+
x = x * mask
|
| 86 |
+
return x
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class BatchDropTop(nn.Module):
|
| 90 |
+
"""
|
| 91 |
+
Ref: Top-DB-Net: Top DropBlock for Activation Enhancement in Person Re-Identification
|
| 92 |
+
https://github.com/RQuispeC/top-dropblock/blob/master/torchreid/models/bdnet.py
|
| 93 |
+
Created by: RQuispeC
|
| 94 |
+
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
def __init__(self, h_ratio):
|
| 98 |
+
super(BatchDropTop, self).__init__()
|
| 99 |
+
self.h_ratio = h_ratio
|
| 100 |
+
|
| 101 |
+
def forward(self, x, visdrop=False):
|
| 102 |
+
if self.training or visdrop:
|
| 103 |
+
b, c, h, w = x.size()
|
| 104 |
+
rh = round(self.h_ratio * h)
|
| 105 |
+
act = (x**2).sum(1)
|
| 106 |
+
act = act.view(b, h * w)
|
| 107 |
+
act = F.normalize(act, p=2, dim=1)
|
| 108 |
+
act = act.view(b, h, w)
|
| 109 |
+
max_act, _ = act.max(2)
|
| 110 |
+
ind = torch.argsort(max_act, 1)
|
| 111 |
+
ind = ind[:, -rh:]
|
| 112 |
+
mask = []
|
| 113 |
+
for i in range(b):
|
| 114 |
+
rmask = torch.ones(h)
|
| 115 |
+
rmask[ind[i]] = 0
|
| 116 |
+
mask.append(rmask.unsqueeze(0))
|
| 117 |
+
mask = torch.cat(mask)
|
| 118 |
+
mask = torch.repeat_interleave(mask, w, 1).view(b, h, w)
|
| 119 |
+
mask = torch.repeat_interleave(mask, c, 0).view(b, c, h, w)
|
| 120 |
+
if x.is_cuda:
|
| 121 |
+
mask = mask.cuda()
|
| 122 |
+
if visdrop:
|
| 123 |
+
return mask
|
| 124 |
+
x = x * mask
|
| 125 |
+
return x
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class BatchFeatureErase_Top(nn.Module):
|
| 129 |
+
"""
|
| 130 |
+
Ref: Top-DB-Net: Top DropBlock for Activation Enhancement in Person Re-Identification
|
| 131 |
+
https://github.com/RQuispeC/top-dropblock/blob/master/torchreid/models/bdnet.py
|
| 132 |
+
Created by: RQuispeC
|
| 133 |
+
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
def __init__(
|
| 137 |
+
self,
|
| 138 |
+
channels,
|
| 139 |
+
bottleneck_type,
|
| 140 |
+
h_ratio=0.33,
|
| 141 |
+
w_ratio=1.0,
|
| 142 |
+
double_bottleneck=False,
|
| 143 |
+
):
|
| 144 |
+
super(BatchFeatureErase_Top, self).__init__()
|
| 145 |
+
|
| 146 |
+
self.drop_batch_bottleneck = bottleneck_type(channels, 512)
|
| 147 |
+
|
| 148 |
+
self.drop_batch_drop_basic = BatchDrop(h_ratio, w_ratio)
|
| 149 |
+
self.drop_batch_drop_top = BatchDropTop(h_ratio)
|
| 150 |
+
|
| 151 |
+
def forward(self, x, drop_top=True, bottleneck_features=True, visdrop=False):
|
| 152 |
+
features = self.drop_batch_bottleneck(x)
|
| 153 |
+
|
| 154 |
+
if drop_top:
|
| 155 |
+
x = self.drop_batch_drop_top(features, visdrop=visdrop)
|
| 156 |
+
else:
|
| 157 |
+
x = self.drop_batch_drop_basic(features, visdrop=visdrop)
|
| 158 |
+
if visdrop:
|
| 159 |
+
return x # x is dropmask
|
| 160 |
+
if bottleneck_features:
|
| 161 |
+
return x, features
|
| 162 |
+
else:
|
| 163 |
+
return x
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class SE_Module(Module):
|
| 167 |
+
def __init__(self, channels, reduction=4):
|
| 168 |
+
super(SE_Module, self).__init__()
|
| 169 |
+
self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0)
|
| 170 |
+
self.relu = ReLU(inplace=True)
|
| 171 |
+
self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0)
|
| 172 |
+
self.sigmoid = Sigmoid()
|
| 173 |
+
|
| 174 |
+
def forward(self, x):
|
| 175 |
+
module_input = x
|
| 176 |
+
x = self.fc1(x)
|
| 177 |
+
x = self.relu(x)
|
| 178 |
+
x = self.fc2(x)
|
| 179 |
+
x = self.sigmoid(x)
|
| 180 |
+
return module_input * x
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class PAM_Module(Module):
|
| 184 |
+
"""Position attention module"""
|
| 185 |
+
|
| 186 |
+
# Ref from SAGAN
|
| 187 |
+
|
| 188 |
+
def __init__(self, in_dim):
|
| 189 |
+
super(PAM_Module, self).__init__()
|
| 190 |
+
self.chanel_in = in_dim
|
| 191 |
+
|
| 192 |
+
self.query_conv = Conv2d(
|
| 193 |
+
in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1
|
| 194 |
+
)
|
| 195 |
+
self.key_conv = Conv2d(
|
| 196 |
+
in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1
|
| 197 |
+
)
|
| 198 |
+
self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
|
| 199 |
+
self.gamma = Parameter(torch.zeros(1))
|
| 200 |
+
|
| 201 |
+
self.softmax = Softmax(dim=-1)
|
| 202 |
+
|
| 203 |
+
def forward(self, x):
|
| 204 |
+
"""
|
| 205 |
+
inputs :
|
| 206 |
+
x : input feature maps( B X C X H X W)
|
| 207 |
+
returns :
|
| 208 |
+
out : attention value + input feature
|
| 209 |
+
attention: B X (HxW) X (HxW)
|
| 210 |
+
"""
|
| 211 |
+
m_batchsize, C, height, width = x.size()
|
| 212 |
+
proj_query = (
|
| 213 |
+
self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
|
| 214 |
+
)
|
| 215 |
+
proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
|
| 216 |
+
energy = torch.bmm(proj_query, proj_key)
|
| 217 |
+
attention = self.softmax(energy)
|
| 218 |
+
proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)
|
| 219 |
+
|
| 220 |
+
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
|
| 221 |
+
out = out.view(m_batchsize, C, height, width)
|
| 222 |
+
|
| 223 |
+
out = self.gamma * out + x
|
| 224 |
+
return out
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class CAM_Module(Module):
|
| 228 |
+
"""Channel attention module"""
|
| 229 |
+
|
| 230 |
+
def __init__(self, in_dim):
|
| 231 |
+
super(CAM_Module, self).__init__()
|
| 232 |
+
self.chanel_in = in_dim
|
| 233 |
+
|
| 234 |
+
self.gamma = Parameter(torch.zeros(1))
|
| 235 |
+
self.softmax = Softmax(dim=-1)
|
| 236 |
+
|
| 237 |
+
def forward(self, x):
|
| 238 |
+
"""
|
| 239 |
+
inputs :
|
| 240 |
+
x : input feature maps( B X C X H X W)
|
| 241 |
+
returns :
|
| 242 |
+
out : attention value + input feature
|
| 243 |
+
attention: B X C X C
|
| 244 |
+
"""
|
| 245 |
+
m_batchsize, C, height, width = x.size()
|
| 246 |
+
proj_query = x.view(m_batchsize, C, -1)
|
| 247 |
+
proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
|
| 248 |
+
# proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1).contiguous()
|
| 249 |
+
energy = torch.bmm(proj_query, proj_key)
|
| 250 |
+
energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
|
| 251 |
+
attention = self.softmax(energy_new)
|
| 252 |
+
proj_value = x.view(m_batchsize, C, -1)
|
| 253 |
+
|
| 254 |
+
out = torch.bmm(attention, proj_value)
|
| 255 |
+
out = out.view(m_batchsize, C, height, width)
|
| 256 |
+
|
| 257 |
+
out = self.gamma * out + x
|
| 258 |
+
return out
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class Dual_Module(Module):
|
| 262 |
+
"""
|
| 263 |
+
# Created by: CASIA IVA
|
| 264 |
+
# Email: jliu@nlpr.ia.ac.cn
|
| 265 |
+
# Copyright (c) 2018
|
| 266 |
+
|
| 267 |
+
# Reference: Dual Attention Network for Scene Segmentation
|
| 268 |
+
# https://arxiv.org/pdf/1809.02983.pdf
|
| 269 |
+
# https://github.com/junfu1115/DANet/blob/master/encoding/nn/attention.py
|
| 270 |
+
"""
|
| 271 |
+
|
| 272 |
+
def __init__(self, in_dim):
|
| 273 |
+
super(Dual_Module).__init__()
|
| 274 |
+
self.indim = in_dim
|
| 275 |
+
self.pam = PAM_Module(in_dim)
|
| 276 |
+
self.cam = CAM_Module(in_dim)
|
| 277 |
+
|
| 278 |
+
def forward(self, x):
|
| 279 |
+
out1 = self.pam(x)
|
| 280 |
+
out2 = self.cam(x)
|
| 281 |
+
return out1 + out2
|
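Both attention modules above blend their result back into the input as gamma * out + x, with gamma initialised to zero, so an untrained module starts as an identity mapping and learns how much attention to mix in. A small usage sketch, assuming the package is importable under the path added in this commit; the dummy shapes are only for illustration:

import torch
from boxmot.appearance.backbones.lmbn.attention import CAM_Module, PAM_Module

x = torch.randn(4, 64, 16, 8)          # (batch, channels, height, width)

pam = PAM_Module(64)                    # position attention over the H*W locations
cam = CAM_Module(64)                    # channel attention over the C feature maps

y_pos = pam(x)
y_chn = cam(x)
print(y_pos.shape, y_chn.shape)         # both keep the input shape: (4, 64, 16, 8)

# gamma starts at zero, so before training both modules act as identity mappings
print(torch.allclose(y_pos, x), torch.allclose(y_chn, x))   # True True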
boxmot/appearance/backbones/lmbn/bnneck.py
ADDED
|
@@ -0,0 +1,166 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

from torch import nn


class BNNeck(nn.Module):
    def __init__(self, input_dim, class_num, return_f=False):
        super(BNNeck, self).__init__()
        self.return_f = return_f
        self.bn = nn.BatchNorm1d(input_dim)
        self.bn.bias.requires_grad_(False)
        self.classifier = nn.Linear(input_dim, class_num, bias=False)
        self.bn.apply(self.weights_init_kaiming)
        self.classifier.apply(self.weights_init_classifier)

    def forward(self, x):
        before_neck = x.view(x.size(0), x.size(1))
        after_neck = self.bn(before_neck)

        if self.return_f:
            score = self.classifier(after_neck)
            return after_neck, score, before_neck
        else:
            x = self.classifier(x)
            return x

    def weights_init_kaiming(self, m):
        classname = m.__class__.__name__
        if classname.find("Linear") != -1:
            nn.init.kaiming_normal_(m.weight, a=0, mode="fan_out")
            nn.init.constant_(m.bias, 0.0)
        elif classname.find("Conv") != -1:
            nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif classname.find("BatchNorm") != -1:
            if m.affine:
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)

    def weights_init_classifier(self, m):
        classname = m.__class__.__name__
        if classname.find("Linear") != -1:
            nn.init.normal_(m.weight, std=0.001)
            if m.bias:
                nn.init.constant_(m.bias, 0.0)


class BNNeck3(nn.Module):
    def __init__(self, input_dim, class_num, feat_dim, return_f=False):
        super(BNNeck3, self).__init__()
        self.return_f = return_f
        # self.reduction = nn.Linear(input_dim, feat_dim)
        # self.bn = nn.BatchNorm1d(feat_dim)

        self.reduction = nn.Conv2d(input_dim, feat_dim, 1, bias=False)
        self.bn = nn.BatchNorm1d(feat_dim)

        self.bn.bias.requires_grad_(False)
        self.classifier = nn.Linear(feat_dim, class_num, bias=False)
        self.bn.apply(self.weights_init_kaiming)
        self.classifier.apply(self.weights_init_classifier)

    def forward(self, x):
        x = self.reduction(x)
        # before_neck = x.squeeze(dim=3).squeeze(dim=2)
        # after_neck = self.bn(x).squeeze(dim=3).squeeze(dim=2)
        before_neck = x.view(x.size(0), x.size(1))
        after_neck = self.bn(before_neck)
        if self.return_f:
            score = self.classifier(after_neck)
            return after_neck, score, before_neck
        else:
            x = self.classifier(x)
            return x

    def weights_init_kaiming(self, m):
        classname = m.__class__.__name__
        if classname.find("Linear") != -1:
            nn.init.kaiming_normal_(m.weight, a=0, mode="fan_out")
            nn.init.constant_(m.bias, 0.0)
        elif classname.find("Conv") != -1:
            nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif classname.find("BatchNorm") != -1:
            if m.affine:
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)

    def weights_init_classifier(self, m):
        classname = m.__class__.__name__
        if classname.find("Linear") != -1:
            nn.init.normal_(m.weight, std=0.001)
            if m.bias:
                nn.init.constant_(m.bias, 0.0)


# Defines the new fc layer and classification layer
# |--Linear--|--bn--|--relu--|--Linear--|


class ClassBlock(nn.Module):
    def __init__(
        self,
        input_dim,
        class_num,
        droprate=0,
        relu=False,
        bnorm=True,
        num_bottleneck=512,
        linear=True,
        return_f=False,
    ):
        super(ClassBlock, self).__init__()
        self.return_f = return_f
        add_block = []
        if linear:
            add_block += [nn.Linear(input_dim, num_bottleneck)]
        else:
            num_bottleneck = input_dim
        if bnorm:
            add_block += [nn.BatchNorm1d(num_bottleneck)]
        if relu:
            add_block += [nn.LeakyReLU(0.1)]
        if droprate > 0:
            add_block += [nn.Dropout(p=droprate)]
        add_block = nn.Sequential(*add_block)
        add_block.apply(self.weights_init_kaiming)

        classifier = []
        classifier += [nn.Linear(num_bottleneck, class_num)]
        classifier = nn.Sequential(*classifier)
        classifier.apply(self.weights_init_classifier)

        self.add_block = add_block
        self.classifier = classifier

    def forward(self, x):
        x = self.add_block(x.squeeze(3).squeeze(2))
        if self.return_f:
            f = x
            x = self.classifier(x)
            return f, x, f
        else:
            x = self.classifier(x)
            return x

    def weights_init_kaiming(self, m):
        classname = m.__class__.__name__
        # print(classname)
        if classname.find("Conv") != -1:
            # For old pytorch, you may use kaiming_normal.
            nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
        elif classname.find("Linear") != -1:
            nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out")
            nn.init.constant_(m.bias.data, 0.0)
        elif classname.find("BatchNorm1d") != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    def weights_init_classifier(self, m):
        classname = m.__class__.__name__
        if classname.find("Linear") != -1:
            nn.init.normal_(m.weight.data, std=0.001)
            nn.init.constant_(m.bias.data, 0.0)
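BNNeck follows the usual batch-norm-neck recipe: the pre-BN feature (before_neck) is intended for metric losses, the batch-normalised feature (after_neck) for retrieval, and the bias-free linear layer produces the ID logits; BNNeck3 additionally reduces the channel dimension with a 1x1 convolution first. A quick sketch of the return_f=True calling convention; the class count and feature sizes below are made up for illustration:

import torch
from boxmot.appearance.backbones.lmbn.bnneck import BNNeck, BNNeck3

pooled = torch.randn(8, 512, 1, 1)             # globally pooled feature map

neck = BNNeck(512, class_num=751, return_f=True)
after_neck, score, before_neck = neck(pooled)
print(after_neck.shape, score.shape, before_neck.shape)
# torch.Size([8, 512]) torch.Size([8, 751]) torch.Size([8, 512])

# BNNeck3 first reduces channels with a 1x1 conv, then applies the same neck
neck3 = BNNeck3(2048, class_num=751, feat_dim=512, return_f=True)
after3, score3, before3 = neck3(torch.randn(8, 2048, 1, 1))
print(after3.shape)                             # torch.Size([8, 512])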
boxmot/appearance/backbones/lmbn/lmbn_n.py
ADDED
|
@@ -0,0 +1,185 @@
| 1 |
+
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from boxmot.appearance.backbones.lmbn.attention import BatchFeatureErase_Top
|
| 9 |
+
from boxmot.appearance.backbones.lmbn.bnneck import BNNeck, BNNeck3
|
| 10 |
+
from boxmot.appearance.backbones.osnet import OSBlock, osnet_x1_0
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LMBN_n(nn.Module):
|
| 14 |
+
def __init__(self, num_classes, loss, pretrained, use_gpu):
|
| 15 |
+
super(LMBN_n, self).__init__()
|
| 16 |
+
|
| 17 |
+
self.n_ch = 2
|
| 18 |
+
self.chs = 512 // self.n_ch
|
| 19 |
+
self.training = False
|
| 20 |
+
|
| 21 |
+
osnet = osnet_x1_0(pretrained=True)
|
| 22 |
+
|
| 23 |
+
self.backone = nn.Sequential(
|
| 24 |
+
osnet.conv1, osnet.maxpool, osnet.conv2, osnet.conv3[0]
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
conv3 = osnet.conv3[1:]
|
| 28 |
+
|
| 29 |
+
self.global_branch = nn.Sequential(
|
| 30 |
+
copy.deepcopy(conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
self.partial_branch = nn.Sequential(
|
| 34 |
+
copy.deepcopy(conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
self.channel_branch = nn.Sequential(
|
| 38 |
+
copy.deepcopy(conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
self.global_pooling = nn.AdaptiveMaxPool2d((1, 1))
|
| 42 |
+
self.partial_pooling = nn.AdaptiveAvgPool2d((2, 1))
|
| 43 |
+
self.channel_pooling = nn.AdaptiveAvgPool2d((1, 1))
|
| 44 |
+
|
| 45 |
+
reduction = BNNeck3(512, num_classes, 512, return_f=True)
|
| 46 |
+
|
| 47 |
+
self.reduction_0 = copy.deepcopy(reduction)
|
| 48 |
+
self.reduction_1 = copy.deepcopy(reduction)
|
| 49 |
+
self.reduction_2 = copy.deepcopy(reduction)
|
| 50 |
+
self.reduction_3 = copy.deepcopy(reduction)
|
| 51 |
+
self.reduction_4 = copy.deepcopy(reduction)
|
| 52 |
+
|
| 53 |
+
self.shared = nn.Sequential(
|
| 54 |
+
nn.Conv2d(self.chs, 512, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True)
|
| 55 |
+
)
|
| 56 |
+
self.weights_init_kaiming(self.shared)
|
| 57 |
+
|
| 58 |
+
self.reduction_ch_0 = BNNeck(512, num_classes, return_f=True)
|
| 59 |
+
self.reduction_ch_1 = BNNeck(512, num_classes, return_f=True)
|
| 60 |
+
|
| 61 |
+
# if args.drop_block:
|
| 62 |
+
# print('Using batch random erasing block.')
|
| 63 |
+
# self.batch_drop_block = BatchRandomErasing()
|
| 64 |
+
# print('Using batch drop block.')
|
| 65 |
+
# self.batch_drop_block = BatchDrop(
|
| 66 |
+
# h_ratio=args.h_ratio, w_ratio=args.w_ratio)
|
| 67 |
+
self.batch_drop_block = BatchFeatureErase_Top(512, OSBlock)
|
| 68 |
+
|
| 69 |
+
self.activation_map = False
|
| 70 |
+
|
| 71 |
+
def forward(self, x):
|
| 72 |
+
# if self.batch_drop_block is not None:
|
| 73 |
+
# x = self.batch_drop_block(x)
|
| 74 |
+
|
| 75 |
+
x = self.backone(x)
|
| 76 |
+
|
| 77 |
+
glo = self.global_branch(x)
|
| 78 |
+
par = self.partial_branch(x)
|
| 79 |
+
cha = self.channel_branch(x)
|
| 80 |
+
|
| 81 |
+
if self.activation_map:
|
| 82 |
+
glo_ = glo
|
| 83 |
+
|
| 84 |
+
if self.batch_drop_block is not None:
|
| 85 |
+
glo_drop, glo = self.batch_drop_block(glo)
|
| 86 |
+
|
| 87 |
+
if self.activation_map:
|
| 88 |
+
_, _, h_par, _ = par.size()
|
| 89 |
+
|
| 90 |
+
fmap_p0 = par[:, :, :h_par // 2, :]
|
| 91 |
+
fmap_p1 = par[:, :, h_par // 2:, :]
|
| 92 |
+
fmap_c0 = cha[:, : self.chs, :, :]
|
| 93 |
+
fmap_c1 = cha[:, self.chs:, :, :]
|
| 94 |
+
print("Generating activation maps...")
|
| 95 |
+
|
| 96 |
+
return glo, glo_, fmap_c0, fmap_c1, fmap_p0, fmap_p1
|
| 97 |
+
|
| 98 |
+
glo_drop = self.global_pooling(glo_drop)
|
| 99 |
+
glo = self.channel_pooling(glo) # shape:(batchsize, 512,1,1)
|
| 100 |
+
g_par = self.global_pooling(par) # shape:(batchsize, 512,1,1)
|
| 101 |
+
p_par = self.partial_pooling(par) # shape:(batchsize, 512,2,1)
|
| 102 |
+
cha = self.channel_pooling(cha) # shape:(batchsize, 256,1,1)
|
| 103 |
+
|
| 104 |
+
p0 = p_par[:, :, 0:1, :]
|
| 105 |
+
p1 = p_par[:, :, 1:2, :]
|
| 106 |
+
|
| 107 |
+
f_glo = self.reduction_0(glo)
|
| 108 |
+
f_p0 = self.reduction_1(g_par)
|
| 109 |
+
f_p1 = self.reduction_2(p0)
|
| 110 |
+
f_p2 = self.reduction_3(p1)
|
| 111 |
+
f_glo_drop = self.reduction_4(glo_drop)
|
| 112 |
+
|
| 113 |
+
################
|
| 114 |
+
|
| 115 |
+
c0 = cha[:, : self.chs, :, :]
|
| 116 |
+
c1 = cha[:, self.chs:, :, :]
|
| 117 |
+
c0 = self.shared(c0)
|
| 118 |
+
c1 = self.shared(c1)
|
| 119 |
+
f_c0 = self.reduction_ch_0(c0)
|
| 120 |
+
f_c1 = self.reduction_ch_1(c1)
|
| 121 |
+
|
| 122 |
+
################
|
| 123 |
+
|
| 124 |
+
fea = [f_glo[-1], f_glo_drop[-1], f_p0[-1]]
|
| 125 |
+
|
| 126 |
+
if not self.training:
|
| 127 |
+
features = torch.stack(
|
| 128 |
+
[f_glo[0], f_glo_drop[0], f_p0[0], f_p1[0], f_p2[0], f_c0[0], f_c1[0]],
|
| 129 |
+
dim=2,
|
| 130 |
+
)
|
| 131 |
+
features = features.flatten(1, 2)
|
| 132 |
+
return features
|
| 133 |
+
|
| 134 |
+
return [
|
| 135 |
+
f_glo[1],
|
| 136 |
+
f_glo_drop[1],
|
| 137 |
+
f_p0[1],
|
| 138 |
+
f_p1[1],
|
| 139 |
+
f_p2[1],
|
| 140 |
+
f_c0[1],
|
| 141 |
+
f_c1[1],
|
| 142 |
+
], fea
|
| 143 |
+
|
| 144 |
+
def weights_init_kaiming(self, m):
|
| 145 |
+
classname = m.__class__.__name__
|
| 146 |
+
if classname.find("Linear") != -1:
|
| 147 |
+
nn.init.kaiming_normal_(m.weight, a=0, mode="fan_out")
|
| 148 |
+
nn.init.constant_(m.bias, 0.0)
|
| 149 |
+
elif classname.find("Conv") != -1:
|
| 150 |
+
nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
|
| 151 |
+
if m.bias is not None:
|
| 152 |
+
nn.init.constant_(m.bias, 0.0)
|
| 153 |
+
elif classname.find("BatchNorm") != -1:
|
| 154 |
+
if m.affine:
|
| 155 |
+
nn.init.constant_(m.weight, 1.0)
|
| 156 |
+
nn.init.constant_(m.bias, 0.0)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
|
| 160 |
+
# Here I left a simple forward function.
|
| 161 |
+
# Test the model, before you train it.
|
| 162 |
+
import argparse
|
| 163 |
+
|
| 164 |
+
parser = argparse.ArgumentParser(description="MGN")
|
| 165 |
+
parser.add_argument("--num_classes", type=int, default=751, help="")
|
| 166 |
+
parser.add_argument("--bnneck", type=bool, default=True)
|
| 167 |
+
parser.add_argument("--pool", type=str, default="max")
|
| 168 |
+
parser.add_argument("--feats", type=int, default=512)
|
| 169 |
+
parser.add_argument("--drop_block", type=bool, default=True)
|
| 170 |
+
parser.add_argument("--w_ratio", type=float, default=1.0, help="")
|
| 171 |
+
|
| 172 |
+
args = parser.parse_args()
|
| 173 |
+
# net = MCMP_n(args)
|
| 174 |
+
# net.classifier = nn.Sequential()
|
| 175 |
+
# print([p for p in net.parameters()])
|
| 176 |
+
# a=filter(lambda p: p.requires_grad, net.parameters())
|
| 177 |
+
# print(a)
|
| 178 |
+
|
| 179 |
+
# print(net)
|
| 180 |
+
# input = Variable(torch.FloatTensor(8, 3, 384, 128))
|
| 181 |
+
# net.eval()
|
| 182 |
+
# output = net(input)
|
| 183 |
+
# print(output.shape)
|
| 184 |
+
print("net output size:")
|
| 185 |
+
# print(len(output))
|
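At inference LMBN_n stacks seven 512-d BNNeck embeddings (global, drop-block global, whole part branch, two horizontal parts, two channel halves) and flattens them into a single descriptor. A usage sketch; note that the constructor calls osnet_x1_0(pretrained=True), which may try to fetch the OSNet weights, and the 384x128 input below simply mirrors the resolution used in the commented test code above:

import torch
from boxmot.appearance.backbones.lmbn.lmbn_n import LMBN_n

# num_classes / loss only matter for the training heads; use_gpu is unused at inference
model = LMBN_n(num_classes=751, loss="softmax", pretrained=True, use_gpu=False)
model.eval()

with torch.no_grad():
    img = torch.randn(1, 3, 384, 128)       # assumed ReID crop size, for illustration
    feat = model(img)

print(feat.shape)                            # (1, 3584) = 7 branches x 512 dims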
boxmot/appearance/backbones/mlfn.py
ADDED
|
@@ -0,0 +1,240 @@
| 1 |
+
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
from __future__ import absolute_import, division
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.utils.model_zoo as model_zoo
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
|
| 10 |
+
__all__ = ["mlfn"]
|
| 11 |
+
|
| 12 |
+
model_urls = {
|
| 13 |
+
# training epoch = 5, top1 = 51.6
|
| 14 |
+
"imagenet": "https://mega.nz/#!YHxAhaxC!yu9E6zWl0x5zscSouTdbZu8gdFFytDdl-RAdD2DEfpk",
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MLFNBlock(nn.Module):
|
| 19 |
+
def __init__(self, in_channels, out_channels, stride, fsm_channels, groups=32):
|
| 20 |
+
super(MLFNBlock, self).__init__()
|
| 21 |
+
self.groups = groups
|
| 22 |
+
mid_channels = out_channels // 2
|
| 23 |
+
|
| 24 |
+
# Factor Modules
|
| 25 |
+
self.fm_conv1 = nn.Conv2d(in_channels, mid_channels, 1, bias=False)
|
| 26 |
+
self.fm_bn1 = nn.BatchNorm2d(mid_channels)
|
| 27 |
+
self.fm_conv2 = nn.Conv2d(
|
| 28 |
+
mid_channels,
|
| 29 |
+
mid_channels,
|
| 30 |
+
3,
|
| 31 |
+
stride=stride,
|
| 32 |
+
padding=1,
|
| 33 |
+
bias=False,
|
| 34 |
+
groups=self.groups,
|
| 35 |
+
)
|
| 36 |
+
self.fm_bn2 = nn.BatchNorm2d(mid_channels)
|
| 37 |
+
self.fm_conv3 = nn.Conv2d(mid_channels, out_channels, 1, bias=False)
|
| 38 |
+
self.fm_bn3 = nn.BatchNorm2d(out_channels)
|
| 39 |
+
|
| 40 |
+
# Factor Selection Module
|
| 41 |
+
self.fsm = nn.Sequential(
|
| 42 |
+
nn.AdaptiveAvgPool2d(1),
|
| 43 |
+
nn.Conv2d(in_channels, fsm_channels[0], 1),
|
| 44 |
+
nn.BatchNorm2d(fsm_channels[0]),
|
| 45 |
+
nn.ReLU(inplace=True),
|
| 46 |
+
nn.Conv2d(fsm_channels[0], fsm_channels[1], 1),
|
| 47 |
+
nn.BatchNorm2d(fsm_channels[1]),
|
| 48 |
+
nn.ReLU(inplace=True),
|
| 49 |
+
nn.Conv2d(fsm_channels[1], self.groups, 1),
|
| 50 |
+
nn.BatchNorm2d(self.groups),
|
| 51 |
+
nn.Sigmoid(),
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
self.downsample = None
|
| 55 |
+
if in_channels != out_channels or stride > 1:
|
| 56 |
+
self.downsample = nn.Sequential(
|
| 57 |
+
nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
|
| 58 |
+
nn.BatchNorm2d(out_channels),
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def forward(self, x):
|
| 62 |
+
residual = x
|
| 63 |
+
s = self.fsm(x)
|
| 64 |
+
|
| 65 |
+
# reduce dimension
|
| 66 |
+
x = self.fm_conv1(x)
|
| 67 |
+
x = self.fm_bn1(x)
|
| 68 |
+
x = F.relu(x, inplace=True)
|
| 69 |
+
|
| 70 |
+
# group convolution
|
| 71 |
+
x = self.fm_conv2(x)
|
| 72 |
+
x = self.fm_bn2(x)
|
| 73 |
+
x = F.relu(x, inplace=True)
|
| 74 |
+
|
| 75 |
+
# factor selection
|
| 76 |
+
b, c = x.size(0), x.size(1)
|
| 77 |
+
n = c // self.groups
|
| 78 |
+
ss = s.repeat(1, n, 1, 1) # from (b, g, 1, 1) to (b, g*n=c, 1, 1)
|
| 79 |
+
ss = ss.view(b, n, self.groups, 1, 1)
|
| 80 |
+
ss = ss.permute(0, 2, 1, 3, 4).contiguous()
|
| 81 |
+
ss = ss.view(b, c, 1, 1)
|
| 82 |
+
x = ss * x
|
| 83 |
+
|
| 84 |
+
# recover dimension
|
| 85 |
+
x = self.fm_conv3(x)
|
| 86 |
+
x = self.fm_bn3(x)
|
| 87 |
+
x = F.relu(x, inplace=True)
|
| 88 |
+
|
| 89 |
+
if self.downsample is not None:
|
| 90 |
+
residual = self.downsample(residual)
|
| 91 |
+
|
| 92 |
+
return F.relu(residual + x, inplace=True), s
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class MLFN(nn.Module):
|
| 96 |
+
"""Multi-Level Factorisation Net.
|
| 97 |
+
|
| 98 |
+
Reference:
|
| 99 |
+
Chang et al. Multi-Level Factorisation Net for
|
| 100 |
+
Person Re-Identification. CVPR 2018.
|
| 101 |
+
|
| 102 |
+
Public keys:
|
| 103 |
+
- ``mlfn``: MLFN (Multi-Level Factorisation Net).
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
num_classes,
|
| 109 |
+
loss="softmax",
|
| 110 |
+
groups=32,
|
| 111 |
+
channels=[64, 256, 512, 1024, 2048],
|
| 112 |
+
embed_dim=1024,
|
| 113 |
+
**kwargs
|
| 114 |
+
):
|
| 115 |
+
super(MLFN, self).__init__()
|
| 116 |
+
self.loss = loss
|
| 117 |
+
self.groups = groups
|
| 118 |
+
|
| 119 |
+
# first convolutional layer
|
| 120 |
+
self.conv1 = nn.Conv2d(3, channels[0], 7, stride=2, padding=3)
|
| 121 |
+
self.bn1 = nn.BatchNorm2d(channels[0])
|
| 122 |
+
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
|
| 123 |
+
|
| 124 |
+
# main body
|
| 125 |
+
self.feature = nn.ModuleList(
|
| 126 |
+
[
|
| 127 |
+
# layer 1-3
|
| 128 |
+
MLFNBlock(channels[0], channels[1], 1, [128, 64], self.groups),
|
| 129 |
+
MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups),
|
| 130 |
+
MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups),
|
| 131 |
+
# layer 4-7
|
| 132 |
+
MLFNBlock(channels[1], channels[2], 2, [256, 128], self.groups),
|
| 133 |
+
MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups),
|
| 134 |
+
MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups),
|
| 135 |
+
MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups),
|
| 136 |
+
# layer 8-13
|
| 137 |
+
MLFNBlock(channels[2], channels[3], 2, [512, 128], self.groups),
|
| 138 |
+
MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
|
| 139 |
+
MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
|
| 140 |
+
MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
|
| 141 |
+
MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
|
| 142 |
+
MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups),
|
| 143 |
+
# layer 14-16
|
| 144 |
+
MLFNBlock(channels[3], channels[4], 2, [512, 128], self.groups),
|
| 145 |
+
MLFNBlock(channels[4], channels[4], 1, [512, 128], self.groups),
|
| 146 |
+
MLFNBlock(channels[4], channels[4], 1, [512, 128], self.groups),
|
| 147 |
+
]
|
| 148 |
+
)
|
| 149 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
| 150 |
+
|
| 151 |
+
# projection functions
|
| 152 |
+
self.fc_x = nn.Sequential(
|
| 153 |
+
nn.Conv2d(channels[4], embed_dim, 1, bias=False),
|
| 154 |
+
nn.BatchNorm2d(embed_dim),
|
| 155 |
+
nn.ReLU(inplace=True),
|
| 156 |
+
)
|
| 157 |
+
self.fc_s = nn.Sequential(
|
| 158 |
+
nn.Conv2d(self.groups * 16, embed_dim, 1, bias=False),
|
| 159 |
+
nn.BatchNorm2d(embed_dim),
|
| 160 |
+
nn.ReLU(inplace=True),
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
self.classifier = nn.Linear(embed_dim, num_classes)
|
| 164 |
+
|
| 165 |
+
self.init_params()
|
| 166 |
+
|
| 167 |
+
def init_params(self):
|
| 168 |
+
for m in self.modules():
|
| 169 |
+
if isinstance(m, nn.Conv2d):
|
| 170 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 171 |
+
if m.bias is not None:
|
| 172 |
+
nn.init.constant_(m.bias, 0)
|
| 173 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 174 |
+
nn.init.constant_(m.weight, 1)
|
| 175 |
+
nn.init.constant_(m.bias, 0)
|
| 176 |
+
elif isinstance(m, nn.Linear):
|
| 177 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
| 178 |
+
if m.bias is not None:
|
| 179 |
+
nn.init.constant_(m.bias, 0)
|
| 180 |
+
|
| 181 |
+
def forward(self, x):
|
| 182 |
+
x = self.conv1(x)
|
| 183 |
+
x = self.bn1(x)
|
| 184 |
+
x = F.relu(x, inplace=True)
|
| 185 |
+
x = self.maxpool(x)
|
| 186 |
+
|
| 187 |
+
s_hat = []
|
| 188 |
+
for block in self.feature:
|
| 189 |
+
x, s = block(x)
|
| 190 |
+
s_hat.append(s)
|
| 191 |
+
s_hat = torch.cat(s_hat, 1)
|
| 192 |
+
|
| 193 |
+
x = self.global_avgpool(x)
|
| 194 |
+
x = self.fc_x(x)
|
| 195 |
+
s_hat = self.fc_s(s_hat)
|
| 196 |
+
|
| 197 |
+
v = (x + s_hat) * 0.5
|
| 198 |
+
v = v.view(v.size(0), -1)
|
| 199 |
+
|
| 200 |
+
if not self.training:
|
| 201 |
+
return v
|
| 202 |
+
|
| 203 |
+
y = self.classifier(v)
|
| 204 |
+
|
| 205 |
+
if self.loss == "softmax":
|
| 206 |
+
return y
|
| 207 |
+
elif self.loss == "triplet":
|
| 208 |
+
return y, v
|
| 209 |
+
else:
|
| 210 |
+
raise KeyError("Unsupported loss: {}".format(self.loss))
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def init_pretrained_weights(model, model_url):
|
| 214 |
+
"""Initializes model with pretrained weights.
|
| 215 |
+
|
| 216 |
+
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
| 217 |
+
"""
|
| 218 |
+
pretrain_dict = model_zoo.load_url(model_url)
|
| 219 |
+
model_dict = model.state_dict()
|
| 220 |
+
pretrain_dict = {
|
| 221 |
+
k: v
|
| 222 |
+
for k, v in pretrain_dict.items()
|
| 223 |
+
if k in model_dict and model_dict[k].size() == v.size()
|
| 224 |
+
}
|
| 225 |
+
model_dict.update(pretrain_dict)
|
| 226 |
+
model.load_state_dict(model_dict)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def mlfn(num_classes, loss="softmax", pretrained=True, **kwargs):
|
| 230 |
+
model = MLFN(num_classes, loss, **kwargs)
|
| 231 |
+
if pretrained:
|
| 232 |
+
# init_pretrained_weights(model, model_urls['imagenet'])
|
| 233 |
+
import warnings
|
| 234 |
+
|
| 235 |
+
warnings.warn(
|
| 236 |
+
"The imagenet pretrained weights need to be manually downloaded from {}".format(
|
| 237 |
+
model_urls["imagenet"]
|
| 238 |
+
)
|
| 239 |
+
)
|
| 240 |
+
return model
|
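In eval mode the mlfn factory returns the fused 1024-d embedding (the average of the fc_x and fc_s projections); in training mode it returns logits, plus the embedding when loss="triplet". A short sketch with pretrained=False to skip the manual-download warning; the input resolution is chosen only for illustration:

import torch
from boxmot.appearance.backbones.mlfn import mlfn

model = mlfn(num_classes=751, loss="softmax", pretrained=False)
model.eval()

with torch.no_grad():
    v = model(torch.randn(2, 3, 256, 128))   # illustrative person-crop resolution

print(v.shape)                                # torch.Size([2, 1024]) embedding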
boxmot/appearance/backbones/mobilenetv2.py
ADDED
|
@@ -0,0 +1,246 @@
| 1 |
+
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
from __future__ import absolute_import, division
|
| 4 |
+
|
| 5 |
+
import torch.utils.model_zoo as model_zoo
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
__all__ = ["mobilenetv2_x1_0", "mobilenetv2_x1_4"]
|
| 10 |
+
|
| 11 |
+
model_urls = {
|
| 12 |
+
# 1.0: top-1 71.3
|
| 13 |
+
"mobilenetv2_x1_0": "https://mega.nz/#!NKp2wAIA!1NH1pbNzY_M2hVk_hdsxNM1NUOWvvGPHhaNr-fASF6c",
|
| 14 |
+
# 1.4: top-1 73.9
|
| 15 |
+
"mobilenetv2_x1_4": "https://mega.nz/#!RGhgEIwS!xN2s2ZdyqI6vQ3EwgmRXLEW3khr9tpXg96G9SUJugGk",
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ConvBlock(nn.Module):
|
| 20 |
+
"""Basic convolutional block.
|
| 21 |
+
|
| 22 |
+
convolution (bias discarded) + batch normalization + relu6.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
in_c (int): number of input channels.
|
| 26 |
+
out_c (int): number of output channels.
|
| 27 |
+
k (int or tuple): kernel size.
|
| 28 |
+
s (int or tuple): stride.
|
| 29 |
+
p (int or tuple): padding.
|
| 30 |
+
g (int): number of blocked connections from input channels
|
| 31 |
+
to output channels (default: 1).
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, in_c, out_c, k, s=1, p=0, g=1):
|
| 35 |
+
super(ConvBlock, self).__init__()
|
| 36 |
+
self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p, bias=False, groups=g)
|
| 37 |
+
self.bn = nn.BatchNorm2d(out_c)
|
| 38 |
+
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
return F.relu6(self.bn(self.conv(x)))
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Bottleneck(nn.Module):
|
| 44 |
+
def __init__(self, in_channels, out_channels, expansion_factor, stride=1):
|
| 45 |
+
super(Bottleneck, self).__init__()
|
| 46 |
+
mid_channels = in_channels * expansion_factor
|
| 47 |
+
self.use_residual = stride == 1 and in_channels == out_channels
|
| 48 |
+
self.conv1 = ConvBlock(in_channels, mid_channels, 1)
|
| 49 |
+
self.dwconv2 = ConvBlock(
|
| 50 |
+
mid_channels, mid_channels, 3, stride, 1, g=mid_channels
|
| 51 |
+
)
|
| 52 |
+
self.conv3 = nn.Sequential(
|
| 53 |
+
nn.Conv2d(mid_channels, out_channels, 1, bias=False),
|
| 54 |
+
nn.BatchNorm2d(out_channels),
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
def forward(self, x):
|
| 58 |
+
m = self.conv1(x)
|
| 59 |
+
m = self.dwconv2(m)
|
| 60 |
+
m = self.conv3(m)
|
| 61 |
+
if self.use_residual:
|
| 62 |
+
return x + m
|
| 63 |
+
else:
|
| 64 |
+
return m
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class MobileNetV2(nn.Module):
|
| 68 |
+
"""MobileNetV2.
|
| 69 |
+
|
| 70 |
+
Reference:
|
| 71 |
+
Sandler et al. MobileNetV2: Inverted Residuals and
|
| 72 |
+
Linear Bottlenecks. CVPR 2018.
|
| 73 |
+
|
| 74 |
+
Public keys:
|
| 75 |
+
- ``mobilenetv2_x1_0``: MobileNetV2 x1.0.
|
| 76 |
+
- ``mobilenetv2_x1_4``: MobileNetV2 x1.4.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(
|
| 80 |
+
self,
|
| 81 |
+
num_classes,
|
| 82 |
+
width_mult=1,
|
| 83 |
+
loss="softmax",
|
| 84 |
+
fc_dims=None,
|
| 85 |
+
dropout_p=None,
|
| 86 |
+
**kwargs
|
| 87 |
+
):
|
| 88 |
+
super(MobileNetV2, self).__init__()
|
| 89 |
+
self.loss = loss
|
| 90 |
+
self.in_channels = int(32 * width_mult)
|
| 91 |
+
self.feature_dim = int(1280 * width_mult) if width_mult > 1 else 1280
|
| 92 |
+
|
| 93 |
+
# construct layers
|
| 94 |
+
self.conv1 = ConvBlock(3, self.in_channels, 3, s=2, p=1)
|
| 95 |
+
self.conv2 = self._make_layer(Bottleneck, 1, int(16 * width_mult), 1, 1)
|
| 96 |
+
self.conv3 = self._make_layer(Bottleneck, 6, int(24 * width_mult), 2, 2)
|
| 97 |
+
self.conv4 = self._make_layer(Bottleneck, 6, int(32 * width_mult), 3, 2)
|
| 98 |
+
self.conv5 = self._make_layer(Bottleneck, 6, int(64 * width_mult), 4, 2)
|
| 99 |
+
self.conv6 = self._make_layer(Bottleneck, 6, int(96 * width_mult), 3, 1)
|
| 100 |
+
self.conv7 = self._make_layer(Bottleneck, 6, int(160 * width_mult), 3, 2)
|
| 101 |
+
self.conv8 = self._make_layer(Bottleneck, 6, int(320 * width_mult), 1, 1)
|
| 102 |
+
self.conv9 = ConvBlock(self.in_channels, self.feature_dim, 1)
|
| 103 |
+
|
| 104 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
| 105 |
+
self.fc = self._construct_fc_layer(fc_dims, self.feature_dim, dropout_p)
|
| 106 |
+
self.classifier = nn.Linear(self.feature_dim, num_classes)
|
| 107 |
+
|
| 108 |
+
self._init_params()
|
| 109 |
+
|
| 110 |
+
def _make_layer(self, block, t, c, n, s):
|
| 111 |
+
# t: expansion factor
|
| 112 |
+
# c: output channels
|
| 113 |
+
# n: number of blocks
|
| 114 |
+
# s: stride for first layer
|
| 115 |
+
layers = []
|
| 116 |
+
layers.append(block(self.in_channels, c, t, s))
|
| 117 |
+
self.in_channels = c
|
| 118 |
+
for i in range(1, n):
|
| 119 |
+
layers.append(block(self.in_channels, c, t))
|
| 120 |
+
return nn.Sequential(*layers)
|
| 121 |
+
|
| 122 |
+
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
|
| 123 |
+
"""Constructs fully connected layer.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
|
| 127 |
+
input_dim (int): input dimension
|
| 128 |
+
dropout_p (float): dropout probability, if None, dropout is unused
|
| 129 |
+
"""
|
| 130 |
+
if fc_dims is None:
|
| 131 |
+
self.feature_dim = input_dim
|
| 132 |
+
return None
|
| 133 |
+
|
| 134 |
+
assert isinstance(
|
| 135 |
+
fc_dims, (list, tuple)
|
| 136 |
+
), "fc_dims must be either list or tuple, but got {}".format(type(fc_dims))
|
| 137 |
+
|
| 138 |
+
layers = []
|
| 139 |
+
for dim in fc_dims:
|
| 140 |
+
layers.append(nn.Linear(input_dim, dim))
|
| 141 |
+
layers.append(nn.BatchNorm1d(dim))
|
| 142 |
+
layers.append(nn.ReLU(inplace=True))
|
| 143 |
+
if dropout_p is not None:
|
| 144 |
+
layers.append(nn.Dropout(p=dropout_p))
|
| 145 |
+
input_dim = dim
|
| 146 |
+
|
| 147 |
+
self.feature_dim = fc_dims[-1]
|
| 148 |
+
|
| 149 |
+
return nn.Sequential(*layers)
|
| 150 |
+
|
| 151 |
+
def _init_params(self):
|
| 152 |
+
for m in self.modules():
|
| 153 |
+
if isinstance(m, nn.Conv2d):
|
| 154 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 155 |
+
if m.bias is not None:
|
| 156 |
+
nn.init.constant_(m.bias, 0)
|
| 157 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 158 |
+
nn.init.constant_(m.weight, 1)
|
| 159 |
+
nn.init.constant_(m.bias, 0)
|
| 160 |
+
elif isinstance(m, nn.BatchNorm1d):
|
| 161 |
+
nn.init.constant_(m.weight, 1)
|
| 162 |
+
nn.init.constant_(m.bias, 0)
|
| 163 |
+
elif isinstance(m, nn.Linear):
|
| 164 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
| 165 |
+
if m.bias is not None:
|
| 166 |
+
nn.init.constant_(m.bias, 0)
|
| 167 |
+
|
| 168 |
+
def featuremaps(self, x):
|
| 169 |
+
x = self.conv1(x)
|
| 170 |
+
x = self.conv2(x)
|
| 171 |
+
x = self.conv3(x)
|
| 172 |
+
x = self.conv4(x)
|
| 173 |
+
x = self.conv5(x)
|
| 174 |
+
x = self.conv6(x)
|
| 175 |
+
x = self.conv7(x)
|
| 176 |
+
x = self.conv8(x)
|
| 177 |
+
x = self.conv9(x)
|
| 178 |
+
return x
|
| 179 |
+
|
| 180 |
+
def forward(self, x):
|
| 181 |
+
f = self.featuremaps(x)
|
| 182 |
+
v = self.global_avgpool(f)
|
| 183 |
+
v = v.view(v.size(0), -1)
|
| 184 |
+
|
| 185 |
+
if self.fc is not None:
|
| 186 |
+
v = self.fc(v)
|
| 187 |
+
|
| 188 |
+
if not self.training:
|
| 189 |
+
return v
|
| 190 |
+
|
| 191 |
+
y = self.classifier(v)
|
| 192 |
+
|
| 193 |
+
if self.loss == "softmax":
|
| 194 |
+
return y
|
| 195 |
+
elif self.loss == "triplet":
|
| 196 |
+
return y, v
|
| 197 |
+
else:
|
| 198 |
+
raise KeyError("Unsupported loss: {}".format(self.loss))
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def init_pretrained_weights(model, model_url):
|
| 202 |
+
"""Initializes model with pretrained weights.
|
| 203 |
+
|
| 204 |
+
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
| 205 |
+
"""
|
| 206 |
+
pretrain_dict = model_zoo.load_url(model_url)
|
| 207 |
+
model_dict = model.state_dict()
|
| 208 |
+
pretrain_dict = {
|
| 209 |
+
k: v
|
| 210 |
+
for k, v in pretrain_dict.items()
|
| 211 |
+
if k in model_dict and model_dict[k].size() == v.size()
|
| 212 |
+
}
|
| 213 |
+
model_dict.update(pretrain_dict)
|
| 214 |
+
model.load_state_dict(model_dict)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def mobilenetv2_x1_0(num_classes, loss, pretrained=True, **kwargs):
|
| 218 |
+
model = MobileNetV2(
|
| 219 |
+
num_classes, loss=loss, width_mult=1, fc_dims=None, dropout_p=None, **kwargs
|
| 220 |
+
)
|
| 221 |
+
if pretrained:
|
| 222 |
+
# init_pretrained_weights(model, model_urls['mobilenetv2_x1_0'])
|
| 223 |
+
import warnings
|
| 224 |
+
|
| 225 |
+
warnings.warn(
|
| 226 |
+
"The imagenet pretrained weights need to be manually downloaded from {}".format(
|
| 227 |
+
model_urls["mobilenetv2_x1_0"]
|
| 228 |
+
)
|
| 229 |
+
)
|
| 230 |
+
return model
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def mobilenetv2_x1_4(num_classes, loss, pretrained=True, **kwargs):
|
| 234 |
+
model = MobileNetV2(
|
| 235 |
+
num_classes, loss=loss, width_mult=1.4, fc_dims=None, dropout_p=None, **kwargs
|
| 236 |
+
)
|
| 237 |
+
if pretrained:
|
| 238 |
+
# init_pretrained_weights(model, model_urls['mobilenetv2_x1_4'])
|
| 239 |
+
import warnings
|
| 240 |
+
|
| 241 |
+
warnings.warn(
|
| 242 |
+
"The imagenet pretrained weights need to be manually downloaded from {}".format(
|
| 243 |
+
model_urls["mobilenetv2_x1_4"]
|
| 244 |
+
)
|
| 245 |
+
)
|
| 246 |
+
return model
|
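The MobileNetV2 factories follow the same convention: in eval mode they return the globally pooled feature vector (1280-d at width multiplier 1.0), in training mode the classifier output. A sketch under the same assumptions as above (no pretrained weights, illustrative input size):

import torch
from boxmot.appearance.backbones.mobilenetv2 import mobilenetv2_x1_0

model = mobilenetv2_x1_0(num_classes=751, loss="softmax", pretrained=False)
model.eval()

with torch.no_grad():
    v = model(torch.randn(2, 3, 256, 128))

print(v.shape)   # torch.Size([2, 1280]) -- feature_dim for width_mult = 1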
boxmot/appearance/backbones/osnet.py
ADDED
|
@@ -0,0 +1,560 @@
| 1 |
+
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
from __future__ import absolute_import, division
|
| 4 |
+
|
| 5 |
+
import warnings
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
|
| 11 |
+
__all__ = ["osnet_x1_0", "osnet_x0_75", "osnet_x0_5", "osnet_x0_25", "osnet_ibn_x1_0"]
|
| 12 |
+
|
| 13 |
+
pretrained_urls = {
|
| 14 |
+
"osnet_x1_0": "https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY",
|
| 15 |
+
"osnet_x0_75": "https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq",
|
| 16 |
+
"osnet_x0_5": "https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i",
|
| 17 |
+
"osnet_x0_25": "https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs",
|
| 18 |
+
"osnet_ibn_x1_0": "https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l",
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
##########
|
| 23 |
+
# Basic layers
|
| 24 |
+
##########
|
| 25 |
+
class ConvLayer(nn.Module):
|
| 26 |
+
"""Convolution layer (conv + bn + relu)."""
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
in_channels,
|
| 31 |
+
out_channels,
|
| 32 |
+
kernel_size,
|
| 33 |
+
stride=1,
|
| 34 |
+
padding=0,
|
| 35 |
+
groups=1,
|
| 36 |
+
IN=False,
|
| 37 |
+
):
|
| 38 |
+
super(ConvLayer, self).__init__()
|
| 39 |
+
self.conv = nn.Conv2d(
|
| 40 |
+
in_channels,
|
| 41 |
+
out_channels,
|
| 42 |
+
kernel_size,
|
| 43 |
+
stride=stride,
|
| 44 |
+
padding=padding,
|
| 45 |
+
bias=False,
|
| 46 |
+
groups=groups,
|
| 47 |
+
)
|
| 48 |
+
if IN:
|
| 49 |
+
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
|
| 50 |
+
else:
|
| 51 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
| 52 |
+
self.relu = nn.ReLU(inplace=True)
|
| 53 |
+
|
| 54 |
+
def forward(self, x):
|
| 55 |
+
x = self.conv(x)
|
| 56 |
+
x = self.bn(x)
|
| 57 |
+
x = self.relu(x)
|
| 58 |
+
return x
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Conv1x1(nn.Module):
|
| 62 |
+
"""1x1 convolution + bn + relu."""
|
| 63 |
+
|
| 64 |
+
def __init__(self, in_channels, out_channels, stride=1, groups=1):
|
| 65 |
+
super(Conv1x1, self).__init__()
|
| 66 |
+
self.conv = nn.Conv2d(
|
| 67 |
+
in_channels,
|
| 68 |
+
out_channels,
|
| 69 |
+
1,
|
| 70 |
+
stride=stride,
|
| 71 |
+
padding=0,
|
| 72 |
+
bias=False,
|
| 73 |
+
groups=groups,
|
| 74 |
+
)
|
| 75 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
| 76 |
+
self.relu = nn.ReLU(inplace=True)
|
| 77 |
+
|
| 78 |
+
def forward(self, x):
|
| 79 |
+
x = self.conv(x)
|
| 80 |
+
x = self.bn(x)
|
| 81 |
+
x = self.relu(x)
|
| 82 |
+
return x
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class Conv1x1Linear(nn.Module):
|
| 86 |
+
"""1x1 convolution + bn (w/o non-linearity)."""
|
| 87 |
+
|
| 88 |
+
def __init__(self, in_channels, out_channels, stride=1):
|
| 89 |
+
super(Conv1x1Linear, self).__init__()
|
| 90 |
+
self.conv = nn.Conv2d(
|
| 91 |
+
in_channels, out_channels, 1, stride=stride, padding=0, bias=False
|
| 92 |
+
)
|
| 93 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
| 94 |
+
|
| 95 |
+
def forward(self, x):
|
| 96 |
+
x = self.conv(x)
|
| 97 |
+
x = self.bn(x)
|
| 98 |
+
return x
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class Conv3x3(nn.Module):
|
| 102 |
+
"""3x3 convolution + bn + relu."""
|
| 103 |
+
|
| 104 |
+
def __init__(self, in_channels, out_channels, stride=1, groups=1):
|
| 105 |
+
super(Conv3x3, self).__init__()
|
| 106 |
+
self.conv = nn.Conv2d(
|
| 107 |
+
in_channels,
|
| 108 |
+
out_channels,
|
| 109 |
+
3,
|
| 110 |
+
stride=stride,
|
| 111 |
+
padding=1,
|
| 112 |
+
bias=False,
|
| 113 |
+
groups=groups,
|
| 114 |
+
)
|
| 115 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
| 116 |
+
self.relu = nn.ReLU(inplace=True)
|
| 117 |
+
|
| 118 |
+
def forward(self, x):
|
| 119 |
+
x = self.conv(x)
|
| 120 |
+
x = self.bn(x)
|
| 121 |
+
x = self.relu(x)
|
| 122 |
+
return x
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class LightConv3x3(nn.Module):
|
| 126 |
+
"""Lightweight 3x3 convolution.
|
| 127 |
+
|
| 128 |
+
1x1 (linear) + dw 3x3 (nonlinear).
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
def __init__(self, in_channels, out_channels):
|
| 132 |
+
super(LightConv3x3, self).__init__()
|
| 133 |
+
self.conv1 = nn.Conv2d(
|
| 134 |
+
in_channels, out_channels, 1, stride=1, padding=0, bias=False
|
| 135 |
+
)
|
| 136 |
+
self.conv2 = nn.Conv2d(
|
| 137 |
+
out_channels,
|
| 138 |
+
out_channels,
|
| 139 |
+
3,
|
| 140 |
+
stride=1,
|
| 141 |
+
padding=1,
|
| 142 |
+
bias=False,
|
| 143 |
+
groups=out_channels,
|
| 144 |
+
)
|
| 145 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
| 146 |
+
self.relu = nn.ReLU(inplace=True)
|
| 147 |
+
|
| 148 |
+
def forward(self, x):
|
| 149 |
+
x = self.conv1(x)
|
| 150 |
+
x = self.conv2(x)
|
| 151 |
+
x = self.bn(x)
|
| 152 |
+
x = self.relu(x)
|
| 153 |
+
return x
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
##########
|
| 157 |
+
# Building blocks for omni-scale feature learning
|
| 158 |
+
##########
|
| 159 |
+
class ChannelGate(nn.Module):
|
| 160 |
+
"""A mini-network that generates channel-wise gates conditioned on input tensor."""
|
| 161 |
+
|
| 162 |
+
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
in_channels,
|
| 165 |
+
num_gates=None,
|
| 166 |
+
return_gates=False,
|
| 167 |
+
gate_activation="sigmoid",
|
| 168 |
+
reduction=16,
|
| 169 |
+
layer_norm=False,
|
| 170 |
+
):
|
| 171 |
+
super(ChannelGate, self).__init__()
|
| 172 |
+
if num_gates is None:
|
| 173 |
+
num_gates = in_channels
|
| 174 |
+
self.return_gates = return_gates
|
| 175 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
| 176 |
+
self.fc1 = nn.Conv2d(
|
| 177 |
+
in_channels, in_channels // reduction, kernel_size=1, bias=True, padding=0
|
| 178 |
+
)
|
| 179 |
+
self.norm1 = None
|
| 180 |
+
if layer_norm:
|
| 181 |
+
self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
|
| 182 |
+
self.relu = nn.ReLU(inplace=True)
|
| 183 |
+
self.fc2 = nn.Conv2d(
|
| 184 |
+
in_channels // reduction, num_gates, kernel_size=1, bias=True, padding=0
|
| 185 |
+
)
|
| 186 |
+
if gate_activation == "sigmoid":
|
| 187 |
+
self.gate_activation = nn.Sigmoid()
|
| 188 |
+
elif gate_activation == "relu":
|
| 189 |
+
self.gate_activation = nn.ReLU(inplace=True)
|
| 190 |
+
elif gate_activation == "linear":
|
| 191 |
+
self.gate_activation = None
|
| 192 |
+
else:
|
| 193 |
+
raise RuntimeError("Unknown gate activation: {}".format(gate_activation))
|
| 194 |
+
|
| 195 |
+
def forward(self, x):
|
| 196 |
+
input = x
|
| 197 |
+
x = self.global_avgpool(x)
|
| 198 |
+
x = self.fc1(x)
|
| 199 |
+
if self.norm1 is not None:
|
| 200 |
+
x = self.norm1(x)
|
| 201 |
+
x = self.relu(x)
|
| 202 |
+
x = self.fc2(x)
|
| 203 |
+
if self.gate_activation is not None:
|
| 204 |
+
x = self.gate_activation(x)
|
| 205 |
+
if self.return_gates:
|
| 206 |
+
return x
|
| 207 |
+
return input * x
|
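# Hypothetical check of ChannelGate above (my addition, not part of the committed file):
# the module acts as a squeeze-and-excitation style gate, returning the input rescaled
# channel-wise; the module path below is assumed from this repo's layout.
import torch
from boxmot.appearance.backbones.osnet import ChannelGate

gate = ChannelGate(in_channels=64)  # reduction=16 -> 4 hidden channels in the gate MLP
y = gate(torch.randn(1, 64, 16, 8))
print(y.shape)  # torch.Size([1, 64, 16, 8]) -> same shape as the input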
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class OSBlock(nn.Module):
|
| 211 |
+
"""Omni-scale feature learning block."""
|
| 212 |
+
|
| 213 |
+
def __init__(
|
| 214 |
+
self, in_channels, out_channels, IN=False, bottleneck_reduction=4, **kwargs
|
| 215 |
+
):
|
| 216 |
+
super(OSBlock, self).__init__()
|
| 217 |
+
mid_channels = out_channels // bottleneck_reduction
|
| 218 |
+
self.conv1 = Conv1x1(in_channels, mid_channels)
|
| 219 |
+
self.conv2a = LightConv3x3(mid_channels, mid_channels)
|
| 220 |
+
self.conv2b = nn.Sequential(
|
| 221 |
+
LightConv3x3(mid_channels, mid_channels),
|
| 222 |
+
LightConv3x3(mid_channels, mid_channels),
|
| 223 |
+
)
|
| 224 |
+
self.conv2c = nn.Sequential(
|
| 225 |
+
LightConv3x3(mid_channels, mid_channels),
|
| 226 |
+
LightConv3x3(mid_channels, mid_channels),
|
| 227 |
+
LightConv3x3(mid_channels, mid_channels),
|
| 228 |
+
)
|
| 229 |
+
self.conv2d = nn.Sequential(
|
| 230 |
+
LightConv3x3(mid_channels, mid_channels),
|
| 231 |
+
LightConv3x3(mid_channels, mid_channels),
|
| 232 |
+
LightConv3x3(mid_channels, mid_channels),
|
| 233 |
+
LightConv3x3(mid_channels, mid_channels),
|
| 234 |
+
)
|
| 235 |
+
self.gate = ChannelGate(mid_channels)
|
| 236 |
+
self.conv3 = Conv1x1Linear(mid_channels, out_channels)
|
| 237 |
+
self.downsample = None
|
| 238 |
+
if in_channels != out_channels:
|
| 239 |
+
self.downsample = Conv1x1Linear(in_channels, out_channels)
|
| 240 |
+
self.IN = None
|
| 241 |
+
if IN:
|
| 242 |
+
self.IN = nn.InstanceNorm2d(out_channels, affine=True)
|
| 243 |
+
|
| 244 |
+
def forward(self, x):
|
| 245 |
+
identity = x
|
| 246 |
+
x1 = self.conv1(x)
|
| 247 |
+
x2a = self.conv2a(x1)
|
| 248 |
+
x2b = self.conv2b(x1)
|
| 249 |
+
x2c = self.conv2c(x1)
|
| 250 |
+
x2d = self.conv2d(x1)
|
| 251 |
+
x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
|
| 252 |
+
x3 = self.conv3(x2)
|
| 253 |
+
if self.downsample is not None:
|
| 254 |
+
identity = self.downsample(identity)
|
| 255 |
+
out = x3 + identity
|
| 256 |
+
if self.IN is not None:
|
| 257 |
+
out = self.IN(out)
|
| 258 |
+
return F.relu(out)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
##########
|
| 262 |
+
# Network architecture
|
| 263 |
+
##########
|
| 264 |
+
class OSNet(nn.Module):
|
| 265 |
+
"""Omni-Scale Network.
|
| 266 |
+
|
| 267 |
+
Reference:
|
| 268 |
+
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
|
| 269 |
+
- Zhou et al. Learning Generalisable Omni-Scale Representations
|
| 270 |
+
for Person Re-Identification. TPAMI, 2021.
|
| 271 |
+
"""
|
| 272 |
+
|
| 273 |
+
def __init__(
|
| 274 |
+
self,
|
| 275 |
+
num_classes,
|
| 276 |
+
blocks,
|
| 277 |
+
layers,
|
| 278 |
+
channels,
|
| 279 |
+
feature_dim=512,
|
| 280 |
+
loss="softmax",
|
| 281 |
+
IN=False,
|
| 282 |
+
**kwargs
|
| 283 |
+
):
|
| 284 |
+
super(OSNet, self).__init__()
|
| 285 |
+
num_blocks = len(blocks)
|
| 286 |
+
assert num_blocks == len(layers)
|
| 287 |
+
assert num_blocks == len(channels) - 1
|
| 288 |
+
self.loss = loss
|
| 289 |
+
self.feature_dim = feature_dim
|
| 290 |
+
|
| 291 |
+
# convolutional backbone
|
| 292 |
+
self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
|
| 293 |
+
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
|
| 294 |
+
self.conv2 = self._make_layer(
|
| 295 |
+
blocks[0],
|
| 296 |
+
layers[0],
|
| 297 |
+
channels[0],
|
| 298 |
+
channels[1],
|
| 299 |
+
reduce_spatial_size=True,
|
| 300 |
+
IN=IN,
|
| 301 |
+
)
|
| 302 |
+
self.conv3 = self._make_layer(
|
| 303 |
+
blocks[1], layers[1], channels[1], channels[2], reduce_spatial_size=True
|
| 304 |
+
)
|
| 305 |
+
self.conv4 = self._make_layer(
|
| 306 |
+
blocks[2], layers[2], channels[2], channels[3], reduce_spatial_size=False
|
| 307 |
+
)
|
| 308 |
+
self.conv5 = Conv1x1(channels[3], channels[3])
|
| 309 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
| 310 |
+
# fully connected layer
|
| 311 |
+
self.fc = self._construct_fc_layer(
|
| 312 |
+
self.feature_dim, channels[3], dropout_p=None
|
| 313 |
+
)
|
| 314 |
+
# identity classification layer
|
| 315 |
+
self.classifier = nn.Linear(self.feature_dim, num_classes)
|
| 316 |
+
|
| 317 |
+
self._init_params()
|
| 318 |
+
|
| 319 |
+
def _make_layer(
|
| 320 |
+
self, block, layer, in_channels, out_channels, reduce_spatial_size, IN=False
|
| 321 |
+
):
|
| 322 |
+
layers = []
|
| 323 |
+
|
| 324 |
+
layers.append(block(in_channels, out_channels, IN=IN))
|
| 325 |
+
for i in range(1, layer):
|
| 326 |
+
layers.append(block(out_channels, out_channels, IN=IN))
|
| 327 |
+
|
| 328 |
+
if reduce_spatial_size:
|
| 329 |
+
layers.append(
|
| 330 |
+
nn.Sequential(
|
| 331 |
+
Conv1x1(out_channels, out_channels), nn.AvgPool2d(2, stride=2)
|
| 332 |
+
)
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
return nn.Sequential(*layers)
|
| 336 |
+
|
| 337 |
+
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
|
| 338 |
+
if fc_dims is None or fc_dims < 0:
|
| 339 |
+
self.feature_dim = input_dim
|
| 340 |
+
return None
|
| 341 |
+
|
| 342 |
+
if isinstance(fc_dims, int):
|
| 343 |
+
fc_dims = [fc_dims]
|
| 344 |
+
|
| 345 |
+
layers = []
|
| 346 |
+
for dim in fc_dims:
|
| 347 |
+
layers.append(nn.Linear(input_dim, dim))
|
| 348 |
+
layers.append(nn.BatchNorm1d(dim))
|
| 349 |
+
layers.append(nn.ReLU(inplace=True))
|
| 350 |
+
if dropout_p is not None:
|
| 351 |
+
layers.append(nn.Dropout(p=dropout_p))
|
| 352 |
+
input_dim = dim
|
| 353 |
+
|
| 354 |
+
self.feature_dim = fc_dims[-1]
|
| 355 |
+
|
| 356 |
+
return nn.Sequential(*layers)
|
| 357 |
+
|
| 358 |
+
def _init_params(self):
|
| 359 |
+
for m in self.modules():
|
| 360 |
+
if isinstance(m, nn.Conv2d):
|
| 361 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 362 |
+
if m.bias is not None:
|
| 363 |
+
nn.init.constant_(m.bias, 0)
|
| 364 |
+
|
| 365 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 366 |
+
nn.init.constant_(m.weight, 1)
|
| 367 |
+
nn.init.constant_(m.bias, 0)
|
| 368 |
+
|
| 369 |
+
elif isinstance(m, nn.BatchNorm1d):
|
| 370 |
+
nn.init.constant_(m.weight, 1)
|
| 371 |
+
nn.init.constant_(m.bias, 0)
|
| 372 |
+
|
| 373 |
+
elif isinstance(m, nn.Linear):
|
| 374 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
| 375 |
+
if m.bias is not None:
|
| 376 |
+
nn.init.constant_(m.bias, 0)
|
| 377 |
+
|
| 378 |
+
def featuremaps(self, x):
|
| 379 |
+
x = self.conv1(x)
|
| 380 |
+
x = self.maxpool(x)
|
| 381 |
+
x = self.conv2(x)
|
| 382 |
+
x = self.conv3(x)
|
| 383 |
+
x = self.conv4(x)
|
| 384 |
+
x = self.conv5(x)
|
| 385 |
+
return x
|
| 386 |
+
|
| 387 |
+
def forward(self, x, return_featuremaps=False):
|
| 388 |
+
x = self.featuremaps(x)
|
| 389 |
+
if return_featuremaps:
|
| 390 |
+
return x
|
| 391 |
+
v = self.global_avgpool(x)
|
| 392 |
+
v = v.view(v.size(0), -1)
|
| 393 |
+
if self.fc is not None:
|
| 394 |
+
v = self.fc(v)
|
| 395 |
+
if not self.training:
|
| 396 |
+
return v
|
| 397 |
+
y = self.classifier(v)
|
| 398 |
+
if self.loss == "softmax":
|
| 399 |
+
return y
|
| 400 |
+
elif self.loss == "triplet":
|
| 401 |
+
return y, v
|
| 402 |
+
else:
|
| 403 |
+
raise KeyError("Unsupported loss: {}".format(self.loss))
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def init_pretrained_weights(model, key=""):
|
| 407 |
+
"""Initializes model with pretrained weights.
|
| 408 |
+
|
| 409 |
+
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
| 410 |
+
"""
|
| 411 |
+
import errno
|
| 412 |
+
import os
|
| 413 |
+
from collections import OrderedDict
|
| 414 |
+
|
| 415 |
+
import gdown
|
| 416 |
+
|
| 417 |
+
def _get_torch_home():
|
| 418 |
+
ENV_TORCH_HOME = "TORCH_HOME"
|
| 419 |
+
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
|
| 420 |
+
DEFAULT_CACHE_DIR = "~/.cache"
|
| 421 |
+
torch_home = os.path.expanduser(
|
| 422 |
+
os.getenv(
|
| 423 |
+
ENV_TORCH_HOME,
|
| 424 |
+
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"),
|
| 425 |
+
)
|
| 426 |
+
)
|
| 427 |
+
return torch_home
|
| 428 |
+
|
| 429 |
+
torch_home = _get_torch_home()
|
| 430 |
+
model_dir = os.path.join(torch_home, "checkpoints")
|
| 431 |
+
try:
|
| 432 |
+
os.makedirs(model_dir)
|
| 433 |
+
except OSError as e:
|
| 434 |
+
if e.errno == errno.EEXIST:
|
| 435 |
+
# Directory already exists, ignore.
|
| 436 |
+
pass
|
| 437 |
+
else:
|
| 438 |
+
# Unexpected OSError, re-raise.
|
| 439 |
+
raise
|
| 440 |
+
filename = key + "_imagenet.pth"
|
| 441 |
+
cached_file = os.path.join(model_dir, filename)
|
| 442 |
+
|
| 443 |
+
if not os.path.exists(cached_file):
|
| 444 |
+
gdown.download(pretrained_urls[key], cached_file, quiet=False)
|
| 445 |
+
|
| 446 |
+
state_dict = torch.load(cached_file)
|
| 447 |
+
model_dict = model.state_dict()
|
| 448 |
+
new_state_dict = OrderedDict()
|
| 449 |
+
matched_layers, discarded_layers = [], []
|
| 450 |
+
|
| 451 |
+
for k, v in state_dict.items():
|
| 452 |
+
if k.startswith("module."):
|
| 453 |
+
k = k[7:] # discard module.
|
| 454 |
+
|
| 455 |
+
if k in model_dict and model_dict[k].size() == v.size():
|
| 456 |
+
new_state_dict[k] = v
|
| 457 |
+
matched_layers.append(k)
|
| 458 |
+
else:
|
| 459 |
+
discarded_layers.append(k)
|
| 460 |
+
|
| 461 |
+
model_dict.update(new_state_dict)
|
| 462 |
+
model.load_state_dict(model_dict)
|
| 463 |
+
|
| 464 |
+
if len(matched_layers) == 0:
|
| 465 |
+
warnings.warn(
|
| 466 |
+
'The pretrained weights from "{}" cannot be loaded, '
|
| 467 |
+
"please check the key names manually "
|
| 468 |
+
"(** ignored and continue **)".format(cached_file)
|
| 469 |
+
)
|
| 470 |
+
else:
|
| 471 |
+
print(
|
| 472 |
+
'Successfully loaded imagenet pretrained weights from "{}"'.format(
|
| 473 |
+
cached_file
|
| 474 |
+
)
|
| 475 |
+
)
|
| 476 |
+
if len(discarded_layers) > 0:
|
| 477 |
+
print(
|
| 478 |
+
"** The following layers are discarded "
|
| 479 |
+
"due to unmatched keys or layer size: {}".format(discarded_layers)
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
##########
|
| 484 |
+
# Instantiation
|
| 485 |
+
##########
|
| 486 |
+
def osnet_x1_0(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
|
| 487 |
+
# standard size (width x1.0)
|
| 488 |
+
model = OSNet(
|
| 489 |
+
num_classes,
|
| 490 |
+
blocks=[OSBlock, OSBlock, OSBlock],
|
| 491 |
+
layers=[2, 2, 2],
|
| 492 |
+
channels=[64, 256, 384, 512],
|
| 493 |
+
loss=loss,
|
| 494 |
+
**kwargs
|
| 495 |
+
)
|
| 496 |
+
if pretrained:
|
| 497 |
+
init_pretrained_weights(model, key="osnet_x1_0")
|
| 498 |
+
return model
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def osnet_x0_75(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
|
| 502 |
+
# medium size (width x0.75)
|
| 503 |
+
model = OSNet(
|
| 504 |
+
num_classes,
|
| 505 |
+
blocks=[OSBlock, OSBlock, OSBlock],
|
| 506 |
+
layers=[2, 2, 2],
|
| 507 |
+
channels=[48, 192, 288, 384],
|
| 508 |
+
loss=loss,
|
| 509 |
+
**kwargs
|
| 510 |
+
)
|
| 511 |
+
if pretrained:
|
| 512 |
+
init_pretrained_weights(model, key="osnet_x0_75")
|
| 513 |
+
return model
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def osnet_x0_5(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
|
| 517 |
+
# tiny size (width x0.5)
|
| 518 |
+
model = OSNet(
|
| 519 |
+
num_classes,
|
| 520 |
+
blocks=[OSBlock, OSBlock, OSBlock],
|
| 521 |
+
layers=[2, 2, 2],
|
| 522 |
+
channels=[32, 128, 192, 256],
|
| 523 |
+
loss=loss,
|
| 524 |
+
**kwargs
|
| 525 |
+
)
|
| 526 |
+
if pretrained:
|
| 527 |
+
init_pretrained_weights(model, key="osnet_x0_5")
|
| 528 |
+
return model
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def osnet_x0_25(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
|
| 532 |
+
# very tiny size (width x0.25)
|
| 533 |
+
model = OSNet(
|
| 534 |
+
num_classes,
|
| 535 |
+
blocks=[OSBlock, OSBlock, OSBlock],
|
| 536 |
+
layers=[2, 2, 2],
|
| 537 |
+
channels=[16, 64, 96, 128],
|
| 538 |
+
loss=loss,
|
| 539 |
+
**kwargs
|
| 540 |
+
)
|
| 541 |
+
if pretrained:
|
| 542 |
+
init_pretrained_weights(model, key="osnet_x0_25")
|
| 543 |
+
return model
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def osnet_ibn_x1_0(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
|
| 547 |
+
# standard size (width x1.0) + IBN layer
|
| 548 |
+
# Ref: Pan et al. Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net. ECCV, 2018.
|
| 549 |
+
model = OSNet(
|
| 550 |
+
num_classes,
|
| 551 |
+
blocks=[OSBlock, OSBlock, OSBlock],
|
| 552 |
+
layers=[2, 2, 2],
|
| 553 |
+
channels=[64, 256, 384, 512],
|
| 554 |
+
loss=loss,
|
| 555 |
+
IN=True,
|
| 556 |
+
**kwargs
|
| 557 |
+
)
|
| 558 |
+
if pretrained:
|
| 559 |
+
init_pretrained_weights(model, key="osnet_ibn_x1_0")
|
| 560 |
+
return model
|
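A minimal, hypothetical usage sketch for the factory functions above (not part of the committed file). It assumes the import path `boxmot.appearance.backbones.osnet` from this repo's layout and a typical 256x128 re-ID crop; `pretrained=False` skips the Google Drive download.

import torch
from boxmot.appearance.backbones.osnet import osnet_x1_0

# Build OSNet x1.0 as a plain feature extractor.
model = osnet_x1_0(num_classes=1000, pretrained=False, loss="softmax")
model.eval()  # in eval mode, forward() returns the feature vector instead of logits
with torch.no_grad():
    feats = model(torch.randn(1, 3, 256, 128))  # crop size is an assumption
print(feats.shape)  # torch.Size([1, 512]); feature_dim defaults to 512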
boxmot/appearance/backbones/osnet_ain.py
ADDED
|
@@ -0,0 +1,582 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

from __future__ import absolute_import, division

import warnings

import torch
from torch import nn
from torch.nn import functional as F

__all__ = ["osnet_ain_x1_0", "osnet_ain_x0_75", "osnet_ain_x0_5", "osnet_ain_x0_25"]

pretrained_urls = {
    "osnet_ain_x1_0": "https://drive.google.com/uc?id=1-CaioD9NaqbHK_kzSMW8VE4_3KcsRjEo",
    "osnet_ain_x0_75": "https://drive.google.com/uc?id=1apy0hpsMypqstfencdH-jKIUEFOW4xoM",
    "osnet_ain_x0_5": "https://drive.google.com/uc?id=1KusKvEYyKGDTUBVRxRiz55G31wkihB6l",
    "osnet_ain_x0_25": "https://drive.google.com/uc?id=1SxQt2AvmEcgWNhaRb2xC4rP6ZwVDP0Wt",
}


##########
# Basic layers
##########
| 24 |
+
class ConvLayer(nn.Module):
|
| 25 |
+
"""Convolution layer (conv + bn + relu)."""
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
in_channels,
|
| 30 |
+
out_channels,
|
| 31 |
+
kernel_size,
|
| 32 |
+
stride=1,
|
| 33 |
+
padding=0,
|
| 34 |
+
groups=1,
|
| 35 |
+
IN=False,
|
| 36 |
+
):
|
| 37 |
+
super(ConvLayer, self).__init__()
|
| 38 |
+
self.conv = nn.Conv2d(
|
| 39 |
+
in_channels,
|
| 40 |
+
out_channels,
|
| 41 |
+
kernel_size,
|
| 42 |
+
stride=stride,
|
| 43 |
+
padding=padding,
|
| 44 |
+
bias=False,
|
| 45 |
+
groups=groups,
|
| 46 |
+
)
|
| 47 |
+
if IN:
|
| 48 |
+
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
|
| 49 |
+
else:
|
| 50 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
| 51 |
+
self.relu = nn.ReLU()
|
| 52 |
+
|
| 53 |
+
def forward(self, x):
|
| 54 |
+
x = self.conv(x)
|
| 55 |
+
x = self.bn(x)
|
| 56 |
+
return self.relu(x)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class Conv1x1(nn.Module):
|
| 60 |
+
"""1x1 convolution + bn + relu."""
|
| 61 |
+
|
| 62 |
+
def __init__(self, in_channels, out_channels, stride=1, groups=1):
|
| 63 |
+
super(Conv1x1, self).__init__()
|
| 64 |
+
self.conv = nn.Conv2d(
|
| 65 |
+
in_channels,
|
| 66 |
+
out_channels,
|
| 67 |
+
1,
|
| 68 |
+
stride=stride,
|
| 69 |
+
padding=0,
|
| 70 |
+
bias=False,
|
| 71 |
+
groups=groups,
|
| 72 |
+
)
|
| 73 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
| 74 |
+
self.relu = nn.ReLU()
|
| 75 |
+
|
| 76 |
+
def forward(self, x):
|
| 77 |
+
x = self.conv(x)
|
| 78 |
+
x = self.bn(x)
|
| 79 |
+
return self.relu(x)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class Conv1x1Linear(nn.Module):
|
| 83 |
+
"""1x1 convolution + bn (w/o non-linearity)."""
|
| 84 |
+
|
| 85 |
+
def __init__(self, in_channels, out_channels, stride=1, bn=True):
|
| 86 |
+
super(Conv1x1Linear, self).__init__()
|
| 87 |
+
self.conv = nn.Conv2d(
|
| 88 |
+
in_channels, out_channels, 1, stride=stride, padding=0, bias=False
|
| 89 |
+
)
|
| 90 |
+
self.bn = None
|
| 91 |
+
if bn:
|
| 92 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
| 93 |
+
|
| 94 |
+
def forward(self, x):
|
| 95 |
+
x = self.conv(x)
|
| 96 |
+
if self.bn is not None:
|
| 97 |
+
x = self.bn(x)
|
| 98 |
+
return x
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class Conv3x3(nn.Module):
|
| 102 |
+
"""3x3 convolution + bn + relu."""
|
| 103 |
+
|
| 104 |
+
def __init__(self, in_channels, out_channels, stride=1, groups=1):
|
| 105 |
+
super(Conv3x3, self).__init__()
|
| 106 |
+
self.conv = nn.Conv2d(
|
| 107 |
+
in_channels,
|
| 108 |
+
out_channels,
|
| 109 |
+
3,
|
| 110 |
+
stride=stride,
|
| 111 |
+
padding=1,
|
| 112 |
+
bias=False,
|
| 113 |
+
groups=groups,
|
| 114 |
+
)
|
| 115 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
| 116 |
+
self.relu = nn.ReLU()
|
| 117 |
+
|
| 118 |
+
def forward(self, x):
|
| 119 |
+
x = self.conv(x)
|
| 120 |
+
x = self.bn(x)
|
| 121 |
+
return self.relu(x)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class LightConv3x3(nn.Module):
|
| 125 |
+
"""Lightweight 3x3 convolution.
|
| 126 |
+
|
| 127 |
+
1x1 (linear) + dw 3x3 (nonlinear).
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
def __init__(self, in_channels, out_channels):
|
| 131 |
+
super(LightConv3x3, self).__init__()
|
| 132 |
+
self.conv1 = nn.Conv2d(
|
| 133 |
+
in_channels, out_channels, 1, stride=1, padding=0, bias=False
|
| 134 |
+
)
|
| 135 |
+
self.conv2 = nn.Conv2d(
|
| 136 |
+
out_channels,
|
| 137 |
+
out_channels,
|
| 138 |
+
3,
|
| 139 |
+
stride=1,
|
| 140 |
+
padding=1,
|
| 141 |
+
bias=False,
|
| 142 |
+
groups=out_channels,
|
| 143 |
+
)
|
| 144 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
| 145 |
+
self.relu = nn.ReLU()
|
| 146 |
+
|
| 147 |
+
def forward(self, x):
|
| 148 |
+
x = self.conv1(x)
|
| 149 |
+
x = self.conv2(x)
|
| 150 |
+
x = self.bn(x)
|
| 151 |
+
return self.relu(x)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class LightConvStream(nn.Module):
|
| 155 |
+
"""Lightweight convolution stream."""
|
| 156 |
+
|
| 157 |
+
def __init__(self, in_channels, out_channels, depth):
|
| 158 |
+
super(LightConvStream, self).__init__()
|
| 159 |
+
assert depth >= 1, "depth must be equal to or larger than 1, but got {}".format(
|
| 160 |
+
depth
|
| 161 |
+
)
|
| 162 |
+
layers = []
|
| 163 |
+
layers += [LightConv3x3(in_channels, out_channels)]
|
| 164 |
+
for i in range(depth - 1):
|
| 165 |
+
layers += [LightConv3x3(out_channels, out_channels)]
|
| 166 |
+
self.layers = nn.Sequential(*layers)
|
| 167 |
+
|
| 168 |
+
def forward(self, x):
|
| 169 |
+
return self.layers(x)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
##########
|
| 173 |
+
# Building blocks for omni-scale feature learning
|
| 174 |
+
##########
|
| 175 |
+
class ChannelGate(nn.Module):
|
| 176 |
+
"""A mini-network that generates channel-wise gates conditioned on input tensor."""
|
| 177 |
+
|
| 178 |
+
def __init__(
|
| 179 |
+
self,
|
| 180 |
+
in_channels,
|
| 181 |
+
num_gates=None,
|
| 182 |
+
return_gates=False,
|
| 183 |
+
gate_activation="sigmoid",
|
| 184 |
+
reduction=16,
|
| 185 |
+
layer_norm=False,
|
| 186 |
+
):
|
| 187 |
+
super(ChannelGate, self).__init__()
|
| 188 |
+
if num_gates is None:
|
| 189 |
+
num_gates = in_channels
|
| 190 |
+
self.return_gates = return_gates
|
| 191 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
| 192 |
+
self.fc1 = nn.Conv2d(
|
| 193 |
+
in_channels, in_channels // reduction, kernel_size=1, bias=True, padding=0
|
| 194 |
+
)
|
| 195 |
+
self.norm1 = None
|
| 196 |
+
if layer_norm:
|
| 197 |
+
self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
|
| 198 |
+
self.relu = nn.ReLU()
|
| 199 |
+
self.fc2 = nn.Conv2d(
|
| 200 |
+
in_channels // reduction, num_gates, kernel_size=1, bias=True, padding=0
|
| 201 |
+
)
|
| 202 |
+
if gate_activation == "sigmoid":
|
| 203 |
+
self.gate_activation = nn.Sigmoid()
|
| 204 |
+
elif gate_activation == "relu":
|
| 205 |
+
self.gate_activation = nn.ReLU()
|
| 206 |
+
elif gate_activation == "linear":
|
| 207 |
+
self.gate_activation = None
|
| 208 |
+
else:
|
| 209 |
+
raise RuntimeError("Unknown gate activation: {}".format(gate_activation))
|
| 210 |
+
|
| 211 |
+
def forward(self, x):
|
| 212 |
+
input = x
|
| 213 |
+
x = self.global_avgpool(x)
|
| 214 |
+
x = self.fc1(x)
|
| 215 |
+
if self.norm1 is not None:
|
| 216 |
+
x = self.norm1(x)
|
| 217 |
+
x = self.relu(x)
|
| 218 |
+
x = self.fc2(x)
|
| 219 |
+
if self.gate_activation is not None:
|
| 220 |
+
x = self.gate_activation(x)
|
| 221 |
+
if self.return_gates:
|
| 222 |
+
return x
|
| 223 |
+
return input * x
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class OSBlock(nn.Module):
|
| 227 |
+
"""Omni-scale feature learning block."""
|
| 228 |
+
|
| 229 |
+
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
|
| 230 |
+
super(OSBlock, self).__init__()
|
| 231 |
+
assert T >= 1
|
| 232 |
+
assert out_channels >= reduction and out_channels % reduction == 0
|
| 233 |
+
mid_channels = out_channels // reduction
|
| 234 |
+
|
| 235 |
+
self.conv1 = Conv1x1(in_channels, mid_channels)
|
| 236 |
+
self.conv2 = nn.ModuleList()
|
| 237 |
+
for t in range(1, T + 1):
|
| 238 |
+
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
|
| 239 |
+
self.gate = ChannelGate(mid_channels)
|
| 240 |
+
self.conv3 = Conv1x1Linear(mid_channels, out_channels)
|
| 241 |
+
self.downsample = None
|
| 242 |
+
if in_channels != out_channels:
|
| 243 |
+
self.downsample = Conv1x1Linear(in_channels, out_channels)
|
| 244 |
+
|
| 245 |
+
def forward(self, x):
|
| 246 |
+
identity = x
|
| 247 |
+
x1 = self.conv1(x)
|
| 248 |
+
x2 = 0
|
| 249 |
+
for conv2_t in self.conv2:
|
| 250 |
+
x2_t = conv2_t(x1)
|
| 251 |
+
x2 = x2 + self.gate(x2_t)
|
| 252 |
+
x3 = self.conv3(x2)
|
| 253 |
+
if self.downsample is not None:
|
| 254 |
+
identity = self.downsample(identity)
|
| 255 |
+
out = x3 + identity
|
| 256 |
+
return F.relu(out)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class OSBlockINin(nn.Module):
|
| 260 |
+
"""Omni-scale feature learning block with instance normalization."""
|
| 261 |
+
|
| 262 |
+
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
|
| 263 |
+
super(OSBlockINin, self).__init__()
|
| 264 |
+
assert T >= 1
|
| 265 |
+
assert out_channels >= reduction and out_channels % reduction == 0
|
| 266 |
+
mid_channels = out_channels // reduction
|
| 267 |
+
|
| 268 |
+
self.conv1 = Conv1x1(in_channels, mid_channels)
|
| 269 |
+
self.conv2 = nn.ModuleList()
|
| 270 |
+
for t in range(1, T + 1):
|
| 271 |
+
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
|
| 272 |
+
self.gate = ChannelGate(mid_channels)
|
| 273 |
+
self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False)
|
| 274 |
+
self.downsample = None
|
| 275 |
+
if in_channels != out_channels:
|
| 276 |
+
self.downsample = Conv1x1Linear(in_channels, out_channels)
|
| 277 |
+
self.IN = nn.InstanceNorm2d(out_channels, affine=True)
|
| 278 |
+
|
| 279 |
+
def forward(self, x):
|
| 280 |
+
identity = x
|
| 281 |
+
x1 = self.conv1(x)
|
| 282 |
+
x2 = 0
|
| 283 |
+
for conv2_t in self.conv2:
|
| 284 |
+
x2_t = conv2_t(x1)
|
| 285 |
+
x2 = x2 + self.gate(x2_t)
|
| 286 |
+
x3 = self.conv3(x2)
|
| 287 |
+
x3 = self.IN(x3) # IN inside residual
|
| 288 |
+
if self.downsample is not None:
|
| 289 |
+
identity = self.downsample(identity)
|
| 290 |
+
out = x3 + identity
|
| 291 |
+
return F.relu(out)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
##########
|
| 295 |
+
# Network architecture
|
| 296 |
+
##########
|
| 297 |
+
class OSNet(nn.Module):
|
| 298 |
+
"""Omni-Scale Network.
|
| 299 |
+
|
| 300 |
+
Reference:
|
| 301 |
+
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
|
| 302 |
+
- Zhou et al. Learning Generalisable Omni-Scale Representations
|
| 303 |
+
for Person Re-Identification. TPAMI, 2021.
|
| 304 |
+
"""
|
| 305 |
+
|
| 306 |
+
def __init__(
|
| 307 |
+
self,
|
| 308 |
+
num_classes,
|
| 309 |
+
blocks,
|
| 310 |
+
layers,
|
| 311 |
+
channels,
|
| 312 |
+
feature_dim=512,
|
| 313 |
+
loss="softmax",
|
| 314 |
+
conv1_IN=False,
|
| 315 |
+
**kwargs
|
| 316 |
+
):
|
| 317 |
+
super(OSNet, self).__init__()
|
| 318 |
+
num_blocks = len(blocks)
|
| 319 |
+
assert num_blocks == len(layers)
|
| 320 |
+
assert num_blocks == len(channels) - 1
|
| 321 |
+
self.loss = loss
|
| 322 |
+
self.feature_dim = feature_dim
|
| 323 |
+
|
| 324 |
+
# convolutional backbone
|
| 325 |
+
self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=conv1_IN)
|
| 326 |
+
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
|
| 327 |
+
self.conv2 = self._make_layer(blocks[0], layers[0], channels[0], channels[1])
|
| 328 |
+
self.pool2 = nn.Sequential(
|
| 329 |
+
Conv1x1(channels[1], channels[1]), nn.AvgPool2d(2, stride=2)
|
| 330 |
+
)
|
| 331 |
+
self.conv3 = self._make_layer(blocks[1], layers[1], channels[1], channels[2])
|
| 332 |
+
self.pool3 = nn.Sequential(
|
| 333 |
+
Conv1x1(channels[2], channels[2]), nn.AvgPool2d(2, stride=2)
|
| 334 |
+
)
|
| 335 |
+
self.conv4 = self._make_layer(blocks[2], layers[2], channels[2], channels[3])
|
| 336 |
+
self.conv5 = Conv1x1(channels[3], channels[3])
|
| 337 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
| 338 |
+
# fully connected layer
|
| 339 |
+
self.fc = self._construct_fc_layer(
|
| 340 |
+
self.feature_dim, channels[3], dropout_p=None
|
| 341 |
+
)
|
| 342 |
+
# identity classification layer
|
| 343 |
+
self.classifier = nn.Linear(self.feature_dim, num_classes)
|
| 344 |
+
|
| 345 |
+
self._init_params()
|
| 346 |
+
|
| 347 |
+
def _make_layer(self, blocks, layer, in_channels, out_channels):
|
| 348 |
+
layers = []
|
| 349 |
+
layers += [blocks[0](in_channels, out_channels)]
|
| 350 |
+
for i in range(1, len(blocks)):
|
| 351 |
+
layers += [blocks[i](out_channels, out_channels)]
|
| 352 |
+
return nn.Sequential(*layers)
|
| 353 |
+
|
| 354 |
+
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
|
| 355 |
+
if fc_dims is None or fc_dims < 0:
|
| 356 |
+
self.feature_dim = input_dim
|
| 357 |
+
return None
|
| 358 |
+
|
| 359 |
+
if isinstance(fc_dims, int):
|
| 360 |
+
fc_dims = [fc_dims]
|
| 361 |
+
|
| 362 |
+
layers = []
|
| 363 |
+
for dim in fc_dims:
|
| 364 |
+
layers.append(nn.Linear(input_dim, dim))
|
| 365 |
+
layers.append(nn.BatchNorm1d(dim))
|
| 366 |
+
layers.append(nn.ReLU())
|
| 367 |
+
if dropout_p is not None:
|
| 368 |
+
layers.append(nn.Dropout(p=dropout_p))
|
| 369 |
+
input_dim = dim
|
| 370 |
+
|
| 371 |
+
self.feature_dim = fc_dims[-1]
|
| 372 |
+
|
| 373 |
+
return nn.Sequential(*layers)
|
| 374 |
+
|
| 375 |
+
def _init_params(self):
|
| 376 |
+
for m in self.modules():
|
| 377 |
+
if isinstance(m, nn.Conv2d):
|
| 378 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 379 |
+
if m.bias is not None:
|
| 380 |
+
nn.init.constant_(m.bias, 0)
|
| 381 |
+
|
| 382 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 383 |
+
nn.init.constant_(m.weight, 1)
|
| 384 |
+
nn.init.constant_(m.bias, 0)
|
| 385 |
+
|
| 386 |
+
elif isinstance(m, nn.BatchNorm1d):
|
| 387 |
+
nn.init.constant_(m.weight, 1)
|
| 388 |
+
nn.init.constant_(m.bias, 0)
|
| 389 |
+
|
| 390 |
+
elif isinstance(m, nn.InstanceNorm2d):
|
| 391 |
+
nn.init.constant_(m.weight, 1)
|
| 392 |
+
nn.init.constant_(m.bias, 0)
|
| 393 |
+
|
| 394 |
+
elif isinstance(m, nn.Linear):
|
| 395 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
| 396 |
+
if m.bias is not None:
|
| 397 |
+
nn.init.constant_(m.bias, 0)
|
| 398 |
+
|
| 399 |
+
def featuremaps(self, x):
|
| 400 |
+
x = self.conv1(x)
|
| 401 |
+
x = self.maxpool(x)
|
| 402 |
+
x = self.conv2(x)
|
| 403 |
+
x = self.pool2(x)
|
| 404 |
+
x = self.conv3(x)
|
| 405 |
+
x = self.pool3(x)
|
| 406 |
+
x = self.conv4(x)
|
| 407 |
+
x = self.conv5(x)
|
| 408 |
+
return x
|
| 409 |
+
|
| 410 |
+
def forward(self, x, return_featuremaps=False):
|
| 411 |
+
x = self.featuremaps(x)
|
| 412 |
+
if return_featuremaps:
|
| 413 |
+
return x
|
| 414 |
+
v = self.global_avgpool(x)
|
| 415 |
+
v = v.view(v.size(0), -1)
|
| 416 |
+
if self.fc is not None:
|
| 417 |
+
v = self.fc(v)
|
| 418 |
+
if not self.training:
|
| 419 |
+
return v
|
| 420 |
+
y = self.classifier(v)
|
| 421 |
+
if self.loss == "softmax":
|
| 422 |
+
return y
|
| 423 |
+
elif self.loss == "triplet":
|
| 424 |
+
return y, v
|
| 425 |
+
else:
|
| 426 |
+
raise KeyError("Unsupported loss: {}".format(self.loss))
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def init_pretrained_weights(model, key=""):
|
| 430 |
+
"""Initializes model with pretrained weights.
|
| 431 |
+
|
| 432 |
+
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
| 433 |
+
"""
|
| 434 |
+
import errno
|
| 435 |
+
import os
|
| 436 |
+
from collections import OrderedDict
|
| 437 |
+
|
| 438 |
+
import gdown
|
| 439 |
+
|
| 440 |
+
def _get_torch_home():
|
| 441 |
+
ENV_TORCH_HOME = "TORCH_HOME"
|
| 442 |
+
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
|
| 443 |
+
DEFAULT_CACHE_DIR = "~/.cache"
|
| 444 |
+
torch_home = os.path.expanduser(
|
| 445 |
+
os.getenv(
|
| 446 |
+
ENV_TORCH_HOME,
|
| 447 |
+
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"),
|
| 448 |
+
)
|
| 449 |
+
)
|
| 450 |
+
return torch_home
|
| 451 |
+
|
| 452 |
+
torch_home = _get_torch_home()
|
| 453 |
+
model_dir = os.path.join(torch_home, "checkpoints")
|
| 454 |
+
try:
|
| 455 |
+
os.makedirs(model_dir)
|
| 456 |
+
except OSError as e:
|
| 457 |
+
if e.errno == errno.EEXIST:
|
| 458 |
+
# Directory already exists, ignore.
|
| 459 |
+
pass
|
| 460 |
+
else:
|
| 461 |
+
# Unexpected OSError, re-raise.
|
| 462 |
+
raise
|
| 463 |
+
filename = key + "_imagenet.pth"
|
| 464 |
+
cached_file = os.path.join(model_dir, filename)
|
| 465 |
+
|
| 466 |
+
if not os.path.exists(cached_file):
|
| 467 |
+
gdown.download(pretrained_urls[key], cached_file, quiet=False)
|
| 468 |
+
|
| 469 |
+
state_dict = torch.load(cached_file)
|
| 470 |
+
model_dict = model.state_dict()
|
| 471 |
+
new_state_dict = OrderedDict()
|
| 472 |
+
matched_layers, discarded_layers = [], []
|
| 473 |
+
|
| 474 |
+
for k, v in state_dict.items():
|
| 475 |
+
if k.startswith("module."):
|
| 476 |
+
k = k[7:] # discard module.
|
| 477 |
+
|
| 478 |
+
if k in model_dict and model_dict[k].size() == v.size():
|
| 479 |
+
new_state_dict[k] = v
|
| 480 |
+
matched_layers.append(k)
|
| 481 |
+
else:
|
| 482 |
+
discarded_layers.append(k)
|
| 483 |
+
|
| 484 |
+
model_dict.update(new_state_dict)
|
| 485 |
+
model.load_state_dict(model_dict)
|
| 486 |
+
|
| 487 |
+
if len(matched_layers) == 0:
|
| 488 |
+
warnings.warn(
|
| 489 |
+
'The pretrained weights from "{}" cannot be loaded, '
|
| 490 |
+
"please check the key names manually "
|
| 491 |
+
"(** ignored and continue **)".format(cached_file)
|
| 492 |
+
)
|
| 493 |
+
else:
|
| 494 |
+
print(
|
| 495 |
+
'Successfully loaded imagenet pretrained weights from "{}"'.format(
|
| 496 |
+
cached_file
|
| 497 |
+
)
|
| 498 |
+
)
|
| 499 |
+
if len(discarded_layers) > 0:
|
| 500 |
+
print(
|
| 501 |
+
"** The following layers are discarded "
|
| 502 |
+
"due to unmatched keys or layer size: {}".format(discarded_layers)
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
##########
|
| 507 |
+
# Instantiation
|
| 508 |
+
##########
|
| 509 |
+
def osnet_ain_x1_0(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
|
| 510 |
+
model = OSNet(
|
| 511 |
+
num_classes,
|
| 512 |
+
blocks=[
|
| 513 |
+
[OSBlockINin, OSBlockINin],
|
| 514 |
+
[OSBlock, OSBlockINin],
|
| 515 |
+
[OSBlockINin, OSBlock],
|
| 516 |
+
],
|
| 517 |
+
layers=[2, 2, 2],
|
| 518 |
+
channels=[64, 256, 384, 512],
|
| 519 |
+
loss=loss,
|
| 520 |
+
conv1_IN=True,
|
| 521 |
+
**kwargs
|
| 522 |
+
)
|
| 523 |
+
if pretrained:
|
| 524 |
+
init_pretrained_weights(model, key="osnet_ain_x1_0")
|
| 525 |
+
return model
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def osnet_ain_x0_75(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
|
| 529 |
+
model = OSNet(
|
| 530 |
+
num_classes,
|
| 531 |
+
blocks=[
|
| 532 |
+
[OSBlockINin, OSBlockINin],
|
| 533 |
+
[OSBlock, OSBlockINin],
|
| 534 |
+
[OSBlockINin, OSBlock],
|
| 535 |
+
],
|
| 536 |
+
layers=[2, 2, 2],
|
| 537 |
+
channels=[48, 192, 288, 384],
|
| 538 |
+
loss=loss,
|
| 539 |
+
conv1_IN=True,
|
| 540 |
+
**kwargs
|
| 541 |
+
)
|
| 542 |
+
if pretrained:
|
| 543 |
+
init_pretrained_weights(model, key="osnet_ain_x0_75")
|
| 544 |
+
return model
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def osnet_ain_x0_5(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
|
| 548 |
+
model = OSNet(
|
| 549 |
+
num_classes,
|
| 550 |
+
blocks=[
|
| 551 |
+
[OSBlockINin, OSBlockINin],
|
| 552 |
+
[OSBlock, OSBlockINin],
|
| 553 |
+
[OSBlockINin, OSBlock],
|
| 554 |
+
],
|
| 555 |
+
layers=[2, 2, 2],
|
| 556 |
+
channels=[32, 128, 192, 256],
|
| 557 |
+
loss=loss,
|
| 558 |
+
conv1_IN=True,
|
| 559 |
+
**kwargs
|
| 560 |
+
)
|
| 561 |
+
if pretrained:
|
| 562 |
+
init_pretrained_weights(model, key="osnet_ain_x0_5")
|
| 563 |
+
return model
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def osnet_ain_x0_25(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
|
| 567 |
+
model = OSNet(
|
| 568 |
+
num_classes,
|
| 569 |
+
blocks=[
|
| 570 |
+
[OSBlockINin, OSBlockINin],
|
| 571 |
+
[OSBlock, OSBlockINin],
|
| 572 |
+
[OSBlockINin, OSBlock],
|
| 573 |
+
],
|
| 574 |
+
layers=[2, 2, 2],
|
| 575 |
+
channels=[16, 64, 96, 128],
|
| 576 |
+
loss=loss,
|
| 577 |
+
conv1_IN=True,
|
| 578 |
+
**kwargs
|
| 579 |
+
)
|
| 580 |
+
if pretrained:
|
| 581 |
+
init_pretrained_weights(model, key="osnet_ain_x0_25")
|
| 582 |
+
return model
|
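A similar hypothetical sketch for the AIN variants above (not part of the committed file). Each factory passes per-stage lists mixing OSBlock and OSBlockINin, but the eval-time output is still the 512-d feature vector; import path and input size are assumptions.

import torch
from boxmot.appearance.backbones.osnet_ain import osnet_ain_x0_25

model = osnet_ain_x0_25(num_classes=1000, pretrained=False)
model.eval()
with torch.no_grad():
    feats = model(torch.randn(2, 3, 256, 128))  # crop size is an assumption
print(feats.shape)  # torch.Size([2, 512]); feature_dim defaults to 512 for every width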
boxmot/appearance/backbones/resnet.py
ADDED
|
@@ -0,0 +1,517 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

"""
Code source: https://github.com/pytorch/vision
"""
from __future__ import absolute_import, division

import torch.utils.model_zoo as model_zoo
from torch import nn

__all__ = [
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x8d",
    "resnet50_fc512",
]

model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
    "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
    "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
}


| 33 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
| 34 |
+
"""3x3 convolution with padding"""
|
| 35 |
+
return nn.Conv2d(
|
| 36 |
+
in_planes,
|
| 37 |
+
out_planes,
|
| 38 |
+
kernel_size=3,
|
| 39 |
+
stride=stride,
|
| 40 |
+
padding=dilation,
|
| 41 |
+
groups=groups,
|
| 42 |
+
bias=False,
|
| 43 |
+
dilation=dilation,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
| 48 |
+
"""1x1 convolution"""
|
| 49 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class BasicBlock(nn.Module):
|
| 53 |
+
expansion = 1
|
| 54 |
+
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
inplanes,
|
| 58 |
+
planes,
|
| 59 |
+
stride=1,
|
| 60 |
+
downsample=None,
|
| 61 |
+
groups=1,
|
| 62 |
+
base_width=64,
|
| 63 |
+
dilation=1,
|
| 64 |
+
norm_layer=None,
|
| 65 |
+
):
|
| 66 |
+
super(BasicBlock, self).__init__()
|
| 67 |
+
if norm_layer is None:
|
| 68 |
+
norm_layer = nn.BatchNorm2d
|
| 69 |
+
if groups != 1 or base_width != 64:
|
| 70 |
+
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
|
| 71 |
+
if dilation > 1:
|
| 72 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
| 73 |
+
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
| 74 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
| 75 |
+
self.bn1 = norm_layer(planes)
|
| 76 |
+
self.relu = nn.ReLU(inplace=True)
|
| 77 |
+
self.conv2 = conv3x3(planes, planes)
|
| 78 |
+
self.bn2 = norm_layer(planes)
|
| 79 |
+
self.downsample = downsample
|
| 80 |
+
self.stride = stride
|
| 81 |
+
|
| 82 |
+
def forward(self, x):
|
| 83 |
+
identity = x
|
| 84 |
+
|
| 85 |
+
out = self.conv1(x)
|
| 86 |
+
out = self.bn1(out)
|
| 87 |
+
out = self.relu(out)
|
| 88 |
+
|
| 89 |
+
out = self.conv2(out)
|
| 90 |
+
out = self.bn2(out)
|
| 91 |
+
|
| 92 |
+
if self.downsample is not None:
|
| 93 |
+
identity = self.downsample(x)
|
| 94 |
+
|
| 95 |
+
out += identity
|
| 96 |
+
out = self.relu(out)
|
| 97 |
+
|
| 98 |
+
return out
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class Bottleneck(nn.Module):
|
| 102 |
+
expansion = 4
|
| 103 |
+
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
inplanes,
|
| 107 |
+
planes,
|
| 108 |
+
stride=1,
|
| 109 |
+
downsample=None,
|
| 110 |
+
groups=1,
|
| 111 |
+
base_width=64,
|
| 112 |
+
dilation=1,
|
| 113 |
+
norm_layer=None,
|
| 114 |
+
):
|
| 115 |
+
super(Bottleneck, self).__init__()
|
| 116 |
+
if norm_layer is None:
|
| 117 |
+
norm_layer = nn.BatchNorm2d
|
| 118 |
+
width = int(planes * (base_width / 64.0)) * groups
|
| 119 |
+
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
| 120 |
+
self.conv1 = conv1x1(inplanes, width)
|
| 121 |
+
self.bn1 = norm_layer(width)
|
| 122 |
+
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
| 123 |
+
self.bn2 = norm_layer(width)
|
| 124 |
+
self.conv3 = conv1x1(width, planes * self.expansion)
|
| 125 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
| 126 |
+
self.relu = nn.ReLU(inplace=True)
|
| 127 |
+
self.downsample = downsample
|
| 128 |
+
self.stride = stride
|
| 129 |
+
|
| 130 |
+
def forward(self, x):
|
| 131 |
+
identity = x
|
| 132 |
+
|
| 133 |
+
out = self.conv1(x)
|
| 134 |
+
out = self.bn1(out)
|
| 135 |
+
out = self.relu(out)
|
| 136 |
+
|
| 137 |
+
out = self.conv2(out)
|
| 138 |
+
out = self.bn2(out)
|
| 139 |
+
out = self.relu(out)
|
| 140 |
+
|
| 141 |
+
out = self.conv3(out)
|
| 142 |
+
out = self.bn3(out)
|
| 143 |
+
|
| 144 |
+
if self.downsample is not None:
|
| 145 |
+
identity = self.downsample(x)
|
| 146 |
+
|
| 147 |
+
out += identity
|
| 148 |
+
out = self.relu(out)
|
| 149 |
+
|
| 150 |
+
return out
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class ResNet(nn.Module):
|
| 154 |
+
"""Residual network.
|
| 155 |
+
|
| 156 |
+
Reference:
|
| 157 |
+
- He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
|
| 158 |
+
- Xie et al. Aggregated Residual Transformations for Deep Neural Networks. CVPR 2017.
|
| 159 |
+
|
| 160 |
+
Public keys:
|
| 161 |
+
- ``resnet18``: ResNet18.
|
| 162 |
+
- ``resnet34``: ResNet34.
|
| 163 |
+
- ``resnet50``: ResNet50.
|
| 164 |
+
- ``resnet101``: ResNet101.
|
| 165 |
+
- ``resnet152``: ResNet152.
|
| 166 |
+
- ``resnext50_32x4d``: ResNeXt50.
|
| 167 |
+
- ``resnext101_32x8d``: ResNeXt101.
|
| 168 |
+
- ``resnet50_fc512``: ResNet50 + FC.
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
def __init__(
|
| 172 |
+
self,
|
| 173 |
+
num_classes,
|
| 174 |
+
loss,
|
| 175 |
+
block,
|
| 176 |
+
layers,
|
| 177 |
+
zero_init_residual=False,
|
| 178 |
+
groups=1,
|
| 179 |
+
width_per_group=64,
|
| 180 |
+
replace_stride_with_dilation=None,
|
| 181 |
+
norm_layer=None,
|
| 182 |
+
last_stride=2,
|
| 183 |
+
fc_dims=None,
|
| 184 |
+
dropout_p=None,
|
| 185 |
+
**kwargs
|
| 186 |
+
):
|
| 187 |
+
super(ResNet, self).__init__()
|
| 188 |
+
if norm_layer is None:
|
| 189 |
+
norm_layer = nn.BatchNorm2d
|
| 190 |
+
self._norm_layer = norm_layer
|
| 191 |
+
self.loss = loss
|
| 192 |
+
self.feature_dim = 512 * block.expansion
|
| 193 |
+
self.inplanes = 64
|
| 194 |
+
self.dilation = 1
|
| 195 |
+
if replace_stride_with_dilation is None:
|
| 196 |
+
# each element in the tuple indicates if we should replace
|
| 197 |
+
# the 2x2 stride with a dilated convolution instead
|
| 198 |
+
replace_stride_with_dilation = [False, False, False]
|
| 199 |
+
if len(replace_stride_with_dilation) != 3:
|
| 200 |
+
raise ValueError(
|
| 201 |
+
"replace_stride_with_dilation should be None "
|
| 202 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation)
|
| 203 |
+
)
|
| 204 |
+
self.groups = groups
|
| 205 |
+
self.base_width = width_per_group
|
| 206 |
+
self.conv1 = nn.Conv2d(
|
| 207 |
+
3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
|
| 208 |
+
)
|
| 209 |
+
self.bn1 = norm_layer(self.inplanes)
|
| 210 |
+
self.relu = nn.ReLU(inplace=True)
|
| 211 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 212 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
| 213 |
+
self.layer2 = self._make_layer(
|
| 214 |
+
block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
|
| 215 |
+
)
|
| 216 |
+
self.layer3 = self._make_layer(
|
| 217 |
+
block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
|
| 218 |
+
)
|
| 219 |
+
self.layer4 = self._make_layer(
|
| 220 |
+
block,
|
| 221 |
+
512,
|
| 222 |
+
layers[3],
|
| 223 |
+
stride=last_stride,
|
| 224 |
+
dilate=replace_stride_with_dilation[2],
|
| 225 |
+
)
|
| 226 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
| 227 |
+
self.fc = self._construct_fc_layer(fc_dims, 512 * block.expansion, dropout_p)
|
| 228 |
+
self.classifier = nn.Linear(self.feature_dim, num_classes)
|
| 229 |
+
|
| 230 |
+
self._init_params()
|
| 231 |
+
|
| 232 |
+
# Zero-initialize the last BN in each residual branch,
|
| 233 |
+
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
| 234 |
+
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
| 235 |
+
if zero_init_residual:
|
| 236 |
+
for m in self.modules():
|
| 237 |
+
if isinstance(m, Bottleneck):
|
| 238 |
+
nn.init.constant_(m.bn3.weight, 0)
|
| 239 |
+
elif isinstance(m, BasicBlock):
|
| 240 |
+
nn.init.constant_(m.bn2.weight, 0)
|
| 241 |
+
|
| 242 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
| 243 |
+
norm_layer = self._norm_layer
|
| 244 |
+
downsample = None
|
| 245 |
+
previous_dilation = self.dilation
|
| 246 |
+
if dilate:
|
| 247 |
+
self.dilation *= stride
|
| 248 |
+
stride = 1
|
| 249 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 250 |
+
downsample = nn.Sequential(
|
| 251 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
| 252 |
+
norm_layer(planes * block.expansion),
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
layers = []
|
| 256 |
+
layers.append(
|
| 257 |
+
block(
|
| 258 |
+
self.inplanes,
|
| 259 |
+
planes,
                    stride,
                    downsample,
                    self.groups,
                    self.base_width,
                    previous_dilation,
                    norm_layer,
                )
            )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """Constructs fully connected layer

        Args:
            fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
            input_dim (int): input dimension
            dropout_p (float): dropout probability, if None, dropout is unused
        """
        if fc_dims is None:
            self.feature_dim = input_dim
            return None

        assert isinstance(
            fc_dims, (list, tuple)
        ), "fc_dims must be either list or tuple, but got {}".format(type(fc_dims))

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def _init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward(self, x):
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if self.fc is not None:
            v = self.fc(v)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == "softmax":
            return y
        elif self.loss == "triplet":
            return y, v
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


"""ResNet"""


def resnet18(num_classes, loss="softmax", pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=BasicBlock,
        layers=[2, 2, 2, 2],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet18"])
    return model


def resnet34(num_classes, loss="softmax", pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=BasicBlock,
        layers=[3, 4, 6, 3],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet34"])
    return model


def resnet50(num_classes, loss="softmax", pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet50"])
    return model


def resnet101(num_classes, loss="softmax", pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet101"])
    return model


def resnet152(num_classes, loss="softmax", pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 8, 36, 3],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet152"])
    return model


"""ResNeXt"""


def resnext50_32x4d(num_classes, loss="softmax", pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        groups=32,
        width_per_group=4,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls["resnext50_32x4d"])
    return model


def resnext101_32x8d(num_classes, loss="softmax", pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        groups=32,
        width_per_group=8,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls["resnext101_32x8d"])
    return model


"""
ResNet + FC
"""


def resnet50_fc512(num_classes, loss="softmax", pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        last_stride=1,
        fc_dims=[512],
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet50"])
    return model
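As a quick sanity check of the factory functions above, a minimal usage sketch (assuming this file is importable as boxmot.appearance.backbones.resnet; the class count and input size are illustrative):

# Minimal sketch, not part of the diff: build the fc512 variant and extract features.
import torch
from boxmot.appearance.backbones.resnet import resnet50_fc512  # assumed import path

model = resnet50_fc512(num_classes=751, loss="softmax", pretrained=False)
model.eval()  # in eval mode forward() returns the pooled feature vector, not logits
with torch.no_grad():
    feats = model(torch.randn(4, 3, 256, 128))  # ReID crops are 256x128 (HxW)
print(feats.shape)  # torch.Size([4, 512]) since fc_dims=[512]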
boxmot/appearance/backends/base_backend.py
ADDED
|
@@ -0,0 +1,135 @@
import cv2
import torch
import gdown
import numpy as np
from abc import ABC, abstractmethod
from boxmot.utils import logger as LOGGER
from boxmot.appearance.reid.registry import ReIDModelRegistry
from boxmot.utils.checks import RequirementsChecker


class BaseModelBackend:
    def __init__(self, weights, device, half):
        self.weights = weights[0] if isinstance(weights, list) else weights
        self.device = device
        self.half = half
        self.model = None
        self.cuda = torch.cuda.is_available() and self.device.type != "cpu"

        self.download_model(self.weights)
        self.model_name = ReIDModelRegistry.get_model_name(self.weights)

        self.model = ReIDModelRegistry.build_model(
            self.model_name,
            num_classes=ReIDModelRegistry.get_nr_classes(self.weights),
            pretrained=not (self.weights and self.weights.is_file()),
            use_gpu=device,
        )
        self.checker = RequirementsChecker()
        self.load_model(self.weights)

    def get_crops(self, xyxys, img):
        h, w = img.shape[:2]
        resize_dims = (128, 256)
        interpolation_method = cv2.INTER_LINEAR
        mean_array = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        std_array = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)

        # Preallocate tensor for crops
        num_crops = len(xyxys)
        crops = torch.empty((num_crops, 3, resize_dims[1], resize_dims[0]),
                            dtype=torch.half if self.half else torch.float, device=self.device)

        for i, box in enumerate(xyxys):
            x1, y1, x2, y2 = box.round().astype('int')
            x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w, x2), min(h, y2)
            crop = img[y1:y2, x1:x2]

            # Resize and convert color in one step
            crop = cv2.resize(crop, resize_dims, interpolation=interpolation_method)
            crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)

            # Convert to tensor and normalize (convert to [0, 1] by dividing by 255 in batch later)
            crop = torch.from_numpy(crop).to(self.device, dtype=torch.half if self.half else torch.float)
            crops[i] = torch.permute(crop, (2, 0, 1))  # Change to (C, H, W)

        # Normalize the entire batch in one go
        crops = crops / 255.0

        # Standardize the batch
        crops = (crops - mean_array) / std_array

        return crops

    @torch.no_grad()
    def get_features(self, xyxys, img):
        if xyxys.size != 0:
            crops = self.get_crops(xyxys, img)
            crops = self.inference_preprocess(crops)
            features = self.forward(crops)
            features = self.inference_postprocess(features)
        else:
            features = np.array([])
        features = features / np.linalg.norm(features, axis=-1, keepdims=True)
        return features

    def warmup(self, imgsz=[(256, 128, 3)]):
        # warmup model by running inference once
        if self.device.type != "cpu":
            im = np.random.randint(0, 255, *imgsz, dtype=np.uint8)
            crops = self.get_crops(xyxys=np.array(
                [[0, 0, 64, 64], [0, 0, 128, 128]]),
                img=im
            )
            crops = self.inference_preprocess(crops)
            self.forward(crops)  # warmup

    def to_numpy(self, x):
        return x.cpu().numpy() if isinstance(x, torch.Tensor) else x

    def inference_preprocess(self, x):
        if self.half:
            if isinstance(x, torch.Tensor):
                if x.dtype != torch.float16:
                    x = x.half()
            elif isinstance(x, np.ndarray):
                if x.dtype != np.float16:
                    x = x.astype(np.float16)

        if self.nhwc:
            if isinstance(x, torch.Tensor):
                x = x.permute(0, 2, 3, 1)  # Convert from NCHW to NHWC
            elif isinstance(x, np.ndarray):
                x = np.transpose(x, (0, 2, 3, 1))  # Convert from NCHW to NHWC
        return x

    def inference_postprocess(self, features):
        if isinstance(features, (list, tuple)):
            return (
                self.to_numpy(features[0]) if len(features) == 1 else [self.to_numpy(x) for x in features]
            )
        else:
            return self.to_numpy(features)

    @abstractmethod
    def forward(self, im_batch):
        raise NotImplementedError("This method should be implemented by subclasses.")

    @abstractmethod
    def load_model(self, w):
        raise NotImplementedError("This method should be implemented by subclasses.")

    def download_model(self, w):
        if w.suffix == ".pt":
            model_url = ReIDModelRegistry.get_model_url(w)
            if not w.exists() and model_url is not None:
                gdown.download(model_url, str(w), quiet=False)
            elif not w.exists():
                LOGGER.error(
                    f"No URL associated with the chosen StrongSORT weights ({w}). Choose between:"
                )
                ReIDModelRegistry.show_downloadable_models()
                exit()
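A hedged sketch of how a concrete subclass of this base backend is typically driven (the file name, frame and boxes below are illustrative; download_model fetches the weights if a URL is registered for them):

# Illustrative only: run dummy detections through the PyTorch backend.
from pathlib import Path
import numpy as np
import torch
from boxmot.appearance.backends.pytorch_backend import PyTorchBackend

backend = PyTorchBackend(Path("osnet_x0_25_msmt17.pt"), torch.device("cpu"), half=False)
frame = np.random.randint(0, 255, (720, 1280, 3), dtype=np.uint8)  # BGR image
boxes = np.array([[100, 100, 200, 300], [400, 150, 480, 360]])     # xyxy detections
embs = backend.get_features(boxes, frame)  # L2-normalised appearance embeddings
print(embs.shape)  # (2, embedding_dim)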
boxmot/appearance/backends/onnx_backend.py
ADDED
|
@@ -0,0 +1,42 @@
import numpy as np
from pathlib import Path

from boxmot.appearance.backends.base_backend import BaseModelBackend


class ONNXBackend(BaseModelBackend):

    def __init__(self, weights, device, half):
        super().__init__(weights, device, half)
        self.nhwc = False
        self.half = half

    def load_model(self, w):
        # ONNXRuntime will attempt to use the first provider, and if it fails or is not
        # available for some reason, it will fall back to the next provider in the list
        if self.device == "mps":
            self.checker.check_packages(("onnxruntime-silicon==1.17.0",))
            providers = ["MPSExecutionProvider", "CPUExecutionProvider"]
        elif self.device == "cuda":
            self.checker.check_packages(("onnxruntime-gpu==1.17.0",))
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            self.checker.check_packages(("onnxruntime==1.17.0",))
            providers = ["CPUExecutionProvider"]

        # Load the ONNX model using onnxruntime
        import onnxruntime
        self.session = onnxruntime.InferenceSession(str(w), providers=providers)

    def forward(self, im_batch):
        # Convert torch tensor to numpy (onnxruntime expects numpy arrays)
        im_batch = im_batch.cpu().numpy()

        # Run inference using ONNX session
        features = self.session.run(
            [self.session.get_outputs()[0].name],
            {self.session.get_inputs()[0].name: im_batch},
        )[0]

        return features
boxmot/appearance/backends/openvino_backend.py
ADDED
|
@@ -0,0 +1,44 @@
import numpy as np
from pathlib import Path
from boxmot.utils import logger as LOGGER

from boxmot.appearance.backends.base_backend import BaseModelBackend


class OpenVinoBackend(BaseModelBackend):

    def __init__(self, weights, device, half):
        super().__init__(weights, device, half)
        self.nhwc = False
        self.half = half

    def load_model(self, w):
        self.checker.check_packages(("openvino-dev>=2022.3",))

        LOGGER.info(f"Loading {w} for OpenVINO inference...")
        try:
            # requires openvino-dev: https://pypi.org/project/openvino-dev/
            from openvino.runtime import Core, Layout
        except ImportError:
            LOGGER.error(
                f"Running {self.__class__} with the specified OpenVINO weights\n{w.name}\n"
                "requires openvino pip package to be installed!\n"
                "$ pip install openvino-dev>=2022.3\n"
            )
        ie = Core()
        if not Path(w).is_file():  # if not *.xml
            w = next(
                Path(w).glob("*.xml")
            )  # get *.xml file from *_openvino_model dir
        network = ie.read_model(model=w, weights=Path(w).with_suffix(".bin"))
        if network.get_parameters()[0].get_layout().empty:
            network.get_parameters()[0].set_layout(Layout("NCWH"))
        self.executable_network = ie.compile_model(
            network, device_name="CPU"
        )  # device_name="MYRIAD" for Intel NCS2
        self.output_layer = next(iter(self.executable_network.outputs))

    def forward(self, im_batch):
        im_batch = im_batch.cpu().numpy()  # FP32
        features = self.executable_network([im_batch])[self.output_layer]
        return features
|
boxmot/appearance/backends/pytorch_backend.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from boxmot.appearance.backends.base_backend import BaseModelBackend
|
| 5 |
+
from boxmot.appearance.reid.registry import ReIDModelRegistry
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class PyTorchBackend(BaseModelBackend):
|
| 9 |
+
|
| 10 |
+
def __init__(self, weights, device, half):
|
| 11 |
+
super().__init__(weights, device, half)
|
| 12 |
+
self.nhwc = False
|
| 13 |
+
self.half = half
|
| 14 |
+
|
| 15 |
+
def load_model(self, w):
|
| 16 |
+
# Load a PyTorch model
|
| 17 |
+
if w and w.is_file():
|
| 18 |
+
ReIDModelRegistry.load_pretrained_weights(self.model, w)
|
| 19 |
+
self.model.to(self.device).eval()
|
| 20 |
+
self.model.half() if self.half else self.model.float()
|
| 21 |
+
|
| 22 |
+
def forward(self, im_batch):
|
| 23 |
+
features = self.model(im_batch)
|
| 24 |
+
return features
|
boxmot/appearance/backends/tensorrt_backend.py
ADDED
|
@@ -0,0 +1,126 @@
import torch
import numpy as np
from pathlib import Path
from collections import OrderedDict, namedtuple
from boxmot.utils import logger as LOGGER
from boxmot.appearance.backends.base_backend import BaseModelBackend


class TensorRTBackend(BaseModelBackend):
    def __init__(self, weights, device, half):
        self.is_trt10 = False
        super().__init__(weights, device, half)
        self.nhwc = False
        self.half = half
        self.device = device
        self.weights = weights
        self.fp16 = False  # Will be updated in load_model
        self.load_model(self.weights)

    def load_model(self, w):
        LOGGER.info(f"Loading {w} for TensorRT inference...")
        self.checker.check_packages(("nvidia-tensorrt",))
        try:
            import tensorrt as trt  # TensorRT library
        except ImportError:
            raise ImportError("Please install tensorrt to use this backend.")

        if self.device.type == "cpu":
            if torch.cuda.is_available():
                self.device = torch.device("cuda:0")
            else:
                raise ValueError("CUDA device not available for TensorRT inference.")

        Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
        logger = trt.Logger(trt.Logger.INFO)

        # Deserialize the engine
        with open(w, "rb") as f, trt.Runtime(logger) as runtime:
            self.model_ = runtime.deserialize_cuda_engine(f.read())

        # Execution context
        self.context = self.model_.create_execution_context()
        self.bindings = OrderedDict()

        self.is_trt10 = not hasattr(self.model_, "num_bindings")
        num = range(self.model_.num_io_tensors) if self.is_trt10 else range(self.model_.num_bindings)

        # Parse bindings
        for index in num:
            if self.is_trt10:
                name = self.model_.get_tensor_name(index)
                dtype = trt.nptype(self.model_.get_tensor_dtype(name))
                is_input = self.model_.get_tensor_mode(name) == trt.TensorIOMode.INPUT
                if is_input and -1 in tuple(self.model_.get_tensor_shape(name)):
                    self.context.set_input_shape(name, tuple(self.model_.get_tensor_profile_shape(name, 0)[1]))
                if is_input and dtype == np.float16:
                    self.fp16 = True

                shape = tuple(self.context.get_tensor_shape(name))
            else:
                name = self.model_.get_binding_name(index)
                dtype = trt.nptype(self.model_.get_binding_dtype(index))
                is_input = self.model_.binding_is_input(index)

                # Handle dynamic shapes
                if is_input and -1 in self.model_.get_binding_shape(index):
                    profile_index = 0
                    min_shape, opt_shape, max_shape = self.model_.get_profile_shape(profile_index, index)
                    self.context.set_binding_shape(index, opt_shape)

                if is_input and dtype == np.float16:
                    self.fp16 = True

                shape = tuple(self.context.get_binding_shape(index))
            data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(self.device)
            self.bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))

        self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items())

    def forward(self, im_batch):
        temp_im_batch = im_batch.clone()
        batch_array = []
        inp_batch = im_batch.shape[0]
        out_batch = self.bindings["output"].shape[0]
        resultant_features = []

        # Divide batch to sub batches
        while inp_batch > out_batch:
            batch_array.append(temp_im_batch[:out_batch])
            temp_im_batch = temp_im_batch[out_batch:]
            inp_batch = temp_im_batch.shape[0]
        if temp_im_batch.shape[0] > 0:
            batch_array.append(temp_im_batch)

        for temp_batch in batch_array:
            # Adjust for dynamic shapes
            if temp_batch.shape != self.bindings["images"].shape:
                if self.is_trt10:
                    self.context.set_input_shape("images", temp_batch.shape)
                    self.bindings["images"] = self.bindings["images"]._replace(shape=temp_batch.shape)
                    self.bindings["output"].data.resize_(tuple(self.context.get_tensor_shape("output")))
                else:
                    i_in = self.model_.get_binding_index("images")
                    i_out = self.model_.get_binding_index("output")
                    self.context.set_binding_shape(i_in, temp_batch.shape)
                    self.bindings["images"] = self.bindings["images"]._replace(shape=temp_batch.shape)
                    output_shape = tuple(self.context.get_binding_shape(i_out))
                    self.bindings["output"].data.resize_(output_shape)

            s = self.bindings["images"].shape
            assert temp_batch.shape == s, f"Input size {temp_batch.shape} does not match model size {s}"

            self.binding_addrs["images"] = int(temp_batch.data_ptr())

            # Execute inference
            self.context.execute_v2(list(self.binding_addrs.values()))
            features = self.bindings["output"].data
            resultant_features.append(features.clone())

        if len(resultant_features) == 1:
            return resultant_features[0]
        else:
            rslt_features = torch.cat(resultant_features, dim=0)
            rslt_features = rslt_features[:im_batch.shape[0]]
            return rslt_features
boxmot/appearance/backends/tflite_backend.py
ADDED
|
@@ -0,0 +1,86 @@
import torch
import numpy as np
from pathlib import Path
from boxmot.utils import logger as LOGGER

from boxmot.appearance.backends.base_backend import BaseModelBackend


class TFLiteBackend(BaseModelBackend):
    """
    A class to handle TensorFlow Lite model inference with dynamic batch size support.

    Attributes:
        nhwc (bool): A flag indicating the order of dimensions.
        half (bool): A flag to indicate if half precision is used.
        interpreter (tf.lite.Interpreter): The TensorFlow Lite interpreter.
        current_allocated_batch_size (int): The current batch size allocated in the interpreter.
    """

    def __init__(self, weights: Path, device: str, half: bool):
        """
        Initializes the TFLiteBackend with given weights, device, and precision flag.

        Args:
            weights (Path): Path to the TFLite model file.
            device (str): Device type (e.g., 'cpu', 'gpu').
            half (bool): Flag to indicate if half precision is used.
        """
        super().__init__(weights, device, half)
        self.nhwc = True
        self.half = False
        # self.interpreter: tf.lite.Interpreter = None
        # self.current_allocated_batch_size: int = None

    def load_model(self, w):
        """
        Loads the TensorFlow Lite model and initializes the interpreter.

        Args:
            w (str): Path to the TFLite model file.
        """
        self.checker.check_packages(("tensorflow",))

        LOGGER.info(f"Loading {str(w)} for TensorFlow Lite inference...")

        import tensorflow as tf
        self.interpreter = tf.lite.Interpreter(model_path=str(w))

        self.interpreter.allocate_tensors()  # allocate
        self.input_details = self.interpreter.get_input_details()  # inputs
        self.output_details = self.interpreter.get_output_details()  # outputs
        self.current_allocated_batch_size = self.input_details[0]['shape'][0]

    def forward(self, im_batch: torch.Tensor) -> np.ndarray:
        """
        Runs forward pass for the given image batch through the TFLite model.

        Args:
            im_batch (torch.Tensor): Input image batch tensor.

        Returns:
            np.ndarray: Output features from the TFLite model.
        """
        im_batch = im_batch.cpu().numpy()

        # Extract batch size from im_batch
        batch_size = im_batch.shape[0]

        # Resize tensors if the new batch size is different from the current allocated batch size
        if batch_size != self.current_allocated_batch_size:
            # print(f"Resizing tensor input to batch size {batch_size}")
            self.interpreter.resize_tensor_input(self.input_details[0]['index'], [batch_size, 256, 128, 3])
            self.interpreter.allocate_tensors()
            self.current_allocated_batch_size = batch_size

        # Set the tensor to point to the input data
        self.interpreter.set_tensor(self.input_details[0]['index'], im_batch)

        # Run inference
        self.interpreter.invoke()

        # Get the output data
        features = self.interpreter.get_tensor(self.output_details[0]['index'])

        return features
boxmot/appearance/backends/torchscript_backend.py
ADDED
|
@@ -0,0 +1,24 @@
import torch
import numpy as np
from pathlib import Path
from boxmot.utils import logger as LOGGER

from boxmot.appearance.backends.base_backend import BaseModelBackend


class TorchscriptBackend(BaseModelBackend):

    def __init__(self, weights, device, half):
        super().__init__(weights, device, half)
        self.nhwc = False
        self.half = half

    def load_model(self, w):
        LOGGER.info(f"Loading {w} for TorchScript inference...")
        self.model = torch.jit.load(w)
        self.model.half() if self.half else self.model.float()

    def forward(self, im_batch):
        features = self.model(im_batch)
        return features
boxmot/appearance/exporters/base_exporter.py
ADDED
|
@@ -0,0 +1,56 @@
import logging
import torch
from pathlib import Path
from boxmot.utils.checks import RequirementsChecker
from boxmot.utils import logger as LOGGER


def export_decorator(export_func):
    def wrapper(self, *args, **kwargs):
        try:
            if hasattr(self, 'required_packages'):
                if hasattr(self, 'cmd'):
                    self.checker.check_packages(self.required_packages, cmd=self.cmd)
                else:
                    self.checker.check_packages(self.required_packages)

            LOGGER.info(f"\nStarting {self.file} export with {self.__class__.__name__}...")
            result = export_func(self, *args, **kwargs)
            if result:
                LOGGER.info(f"Export success, saved as {result} ({self.file_size(result):.1f} MB)")
            return result
        except Exception as e:
            LOGGER.error(f"Export failure: {e}")
            return None
    return wrapper


class BaseExporter:
    def __init__(self, model, im, file, optimize=False, dynamic=False, half=False, simplify=False):
        self.model = model
        self.im = im
        self.file = Path(file)
        self.optimize = optimize
        self.dynamic = dynamic
        self.half = half
        self.simplify = simplify
        self.checker = RequirementsChecker()
        self.workspace = 4

    @staticmethod
    def file_size(path):
        path = Path(path)
        if path.is_file():
            return path.stat().st_size / 1e6
        elif path.is_dir():
            return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / 1e6
        else:
            return 0.0

    def export(self):
        raise NotImplementedError("Export method must be implemented in subclasses.")

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if 'export' in cls.__dict__:
            cls.export = export_decorator(cls.export)
boxmot/appearance/exporters/onnx_exporter.py
ADDED
|
@@ -0,0 +1,56 @@
import torch
import onnx
from boxmot.appearance.exporters.base_exporter import BaseExporter
from boxmot.utils import logger as LOGGER


class ONNXExporter(BaseExporter):
    required_packages = ("onnx>=1.16.1",)

    def export(self):
        f = self.file.with_suffix(".onnx")

        dynamic = {"images": {0: "batch"}, "output": {0: "batch"}} if self.dynamic else None

        torch.onnx.export(
            self.model.cpu() if self.dynamic else self.model,
            self.im.cpu() if self.dynamic else self.im,
            f,
            verbose=False,
            opset_version=12,
            do_constant_folding=True,
            input_names=["images"],
            output_names=["output"],
            dynamic_axes=dynamic,
        )

        model_onnx = onnx.load(f)
        onnx.checker.check_model(model_onnx)
        onnx.save(model_onnx, f)

        if self.simplify:
            self.simplify_model(model_onnx, f)

        return f

    def simplify_model(self, model_onnx, f):
        try:
            cuda = torch.cuda.is_available()
            self.checker.check_packages(
                (
                    "onnxruntime-gpu" if cuda else "onnxruntime",
                    "onnx-simplifier>=0.4.1",
                )
            )
            import onnxsim

            LOGGER.info(
                f"Simplifying with onnx-simplifier {onnxsim.__version__}..."
            )
            model_onnx, check = onnxsim.simplify(model_onnx)
            assert check, "assert check failed"
            onnx.save(model_onnx, f)
        except Exception as e:
            LOGGER.error(f"Simplifier failure: {e}")
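A hedged sketch of driving this exporter directly (the model choice, dummy-input shape and file name are assumptions; export() is wrapped by the decorator in base_exporter.py, so it logs progress and returns the output path or None):

# Illustrative only: trace a ReID backbone to ONNX with a dynamic batch axis.
import torch
from boxmot.appearance.backbones.resnet import resnet50_fc512  # assumed import path
from boxmot.appearance.exporters.onnx_exporter import ONNXExporter

model = resnet50_fc512(num_classes=751, pretrained=False).eval()
dummy = torch.zeros(1, 3, 256, 128)                       # NCHW dummy crop batch
exporter = ONNXExporter(model, dummy, "resnet50_fc512.pt", dynamic=True)
onnx_path = exporter.export()                             # writes resnet50_fc512.onnx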
boxmot/appearance/exporters/openvino_exporter.py
ADDED
|
@@ -0,0 +1,26 @@
import os
from pathlib import Path
import openvino.runtime as ov
from openvino.tools import mo
from boxmot.appearance.exporters.base_exporter import BaseExporter
from boxmot.utils import logger as LOGGER


class OpenVINOExporter(BaseExporter):
    required_packages = ("openvino-dev>=2023.0",)

    def export(self):
        f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
        f_onnx = self.file.with_suffix(".onnx")
        f_ov = str(Path(f) / self.file.with_suffix(".xml").name)

        ov_model = mo.convert_model(
            f_onnx,
            model_name=self.file.with_suffix(".xml"),
            framework="onnx",
            compress_to_fp16=self.half,
        )
        ov.serialize(ov_model, f_ov)

        return f
boxmot/appearance/exporters/tensorrt_exporter.py
ADDED
|
@@ -0,0 +1,80 @@
import platform
import torch
from boxmot.appearance.exporters.base_exporter import BaseExporter
from boxmot.appearance.exporters.onnx_exporter import ONNXExporter
from boxmot.utils import logger as LOGGER


class EngineExporter(BaseExporter):
    required_packages = ("nvidia-tensorrt",)
    cmds = '--extra-index-url https://pypi.ngc.nvidia.com'

    def export(self):
        assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. `python export.py --device 0`"
        try:
            import tensorrt as trt
        except ImportError:
            import tensorrt as trt

        onnx_file = self.export_onnx()
        LOGGER.info(f"\nStarting export with TensorRT {trt.__version__}...")
        is_trt10 = int(trt.__version__.split(".")[0]) >= 10  # is TensorRT >= 10
        assert onnx_file.exists(), f"Failed to export ONNX file: {onnx_file}"
        f = self.file.with_suffix(".engine")
        logger = trt.Logger(trt.Logger.INFO)
        if True:
            logger.min_severity = trt.Logger.Severity.VERBOSE

        builder = trt.Builder(logger)
        config = builder.create_builder_config()
        workspace = int(self.workspace * (1 << 30))
        if is_trt10:
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
        else:  # TensorRT versions 7, 8
            config.max_workspace_size = workspace

        flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        network = builder.create_network(flag)
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(str(onnx_file)):
            raise RuntimeError(f"Failed to load ONNX file: {onnx_file}")

        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]
        LOGGER.info("Network Description:")
        for inp in inputs:
            LOGGER.info(f'\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
        for out in outputs:
            LOGGER.info(f'\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

        if self.dynamic:
            if self.im.shape[0] <= 1:
                LOGGER.warning("WARNING: --dynamic model requires maximum --batch-size argument")
            profile = builder.create_optimization_profile()
            for inp in inputs:
                if self.half:
                    inp.dtype = trt.float16
                profile.set_shape(
                    inp.name,
                    (1, *self.im.shape[1:]),
                    (max(1, self.im.shape[0] // 2), *self.im.shape[1:]),
                    self.im.shape,
                )
            config.add_optimization_profile(profile)

        LOGGER.info(f"Building FP{16 if builder.platform_has_fast_fp16 and self.half else 32} engine in {f}")
        if builder.platform_has_fast_fp16 and self.half:
            config.set_flag(trt.BuilderFlag.FP16)
        config.default_device_type = trt.DeviceType.GPU

        build = builder.build_serialized_network if is_trt10 else builder.build_engine
        with build(network, config) as engine, open(f, "wb") as t:
            t.write(engine if is_trt10 else engine.serialize())

        return f

    def export_onnx(self):
        onnx_exporter = ONNXExporter(self.model, self.im, self.file, self.optimize, self.dynamic, self.half, self.simplify)
        return onnx_exporter.export()
boxmot/appearance/exporters/tflite_exporter.py
ADDED
|
@@ -0,0 +1,37 @@
import os
from boxmot.appearance.exporters.base_exporter import BaseExporter
from boxmot.utils import logger as LOGGER


class TFLiteExporter(BaseExporter):
    required_packages = (
        "onnx2tf>=1.18.0",
        "onnx>=1.16.1",
        "tensorflow==2.17.0",
        "tf_keras",  # required by 'onnx2tf' package
        "sng4onnx>=1.0.1",  # required by 'onnx2tf' package
        "onnx_graphsurgeon>=0.3.26",  # required by 'onnx2tf' package
        "onnxslim>=0.1.31",
        "onnxruntime",
        "flatbuffers>=23.5.26",
        "psutil==5.9.5",
        "ml_dtypes==0.3.2",
        "ai_edge_litert>=1.2.0"
    )
    cmds = '--extra-index-url https://pypi.ngc.nvidia.com'

    def export(self):
        import onnx2tf
        input_onnx_file_path = str(self.file.with_suffix('.onnx'))
        output_folder_path = input_onnx_file_path.replace(".onnx", f"_saved_model{os.sep}")
        onnx2tf.convert(
            input_onnx_file_path=input_onnx_file_path,
            output_folder_path=output_folder_path,
            not_use_onnxsim=True,
            verbosity=True,
            # output_integer_quantized_tflite=self.args.int8,
            # quant_type="per-tensor",  # "per-tensor" (faster) or "per-channel" (slower but more accurate)
            # custom_input_op_name_np_data_path=np_data,
        )
        return output_folder_path
boxmot/appearance/exporters/torchscript_exporter.py
ADDED
|
@@ -0,0 +1,15 @@
import torch
from boxmot.appearance.exporters.base_exporter import BaseExporter
from boxmot.utils import logger as LOGGER


class TorchScriptExporter(BaseExporter):
    def export(self):
        f = self.file.with_suffix(".torchscript")
        ts = torch.jit.trace(self.model, self.im, strict=False)
        if self.optimize:
            torch.utils.mobile_optimizer.optimize_for_mobile(ts)._save_for_lite_interpreter(str(f))
        else:
            ts.save(str(f))

        return f
boxmot/appearance/reid/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

import pandas as pd


def export_formats():
    # yolo tracking export formats
    x = [
        ["PyTorch", "-", ".pt", True, True],
        ["TorchScript", "torchscript", ".torchscript", True, True],
        ["ONNX", "onnx", ".onnx", True, True],
        ["OpenVINO", "openvino", "_openvino_model", True, False],
        ["TensorRT", "engine", ".engine", False, True],
        ["TensorFlow Lite", "tflite", ".tflite", True, False],
    ]
    return pd.DataFrame(x, columns=["Format", "Argument", "Suffix", "CPU", "GPU"])
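For reference, printing this table shows which export targets are wired up and the device each one is expected to run on:

# Quick look at the declared export targets.
from boxmot.appearance.reid import export_formats

df = export_formats()
print(df)                    # Format / Argument / Suffix / CPU / GPU columns
print(list(df["Argument"]))  # ['-', 'torchscript', 'onnx', 'openvino', 'engine', 'tflite']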
boxmot/appearance/reid/auto_backend.py
ADDED
|
@@ -0,0 +1,128 @@
import torch
from pathlib import Path
from typing import Union, Tuple

from boxmot.utils import WEIGHTS
from boxmot.utils import logger as LOGGER
from boxmot.utils.torch_utils import select_device
from boxmot.appearance.reid import export_formats
from boxmot.appearance.backends.onnx_backend import ONNXBackend
from boxmot.appearance.backends.openvino_backend import OpenVinoBackend
from boxmot.appearance.backends.pytorch_backend import PyTorchBackend
from boxmot.appearance.backends.tensorrt_backend import TensorRTBackend
from boxmot.appearance.backends.tflite_backend import TFLiteBackend
from boxmot.appearance.backends.torchscript_backend import TorchscriptBackend
from boxmot.appearance.backends.base_backend import BaseModelBackend


class ReidAutoBackend():
    def __init__(
        self,
        weights: Path = WEIGHTS / "osnet_x0_25_msmt17.pt",
        device: torch.device = torch.device("cpu"),
        half: bool = False) -> None:
        """
        Initializes the ReidAutoBackend instance with specified weights, device, and precision mode.

        Args:
            weights (Union[str, List[str]]): Path to the model weights. Can be a string or a list of strings; if a list, the first element is used.
            device (torch.device): The device to run the model on, e.g., CPU or GPU.
            half (bool): Whether to use half precision for model inference.
        """
        super().__init__()
        w = weights[0] if isinstance(weights, list) else weights
        (
            self.pt,
            self.jit,
            self.onnx,
            self.xml,
            self.engine,
            self.tflite,
        ) = self.model_type(w)  # get backend

        self.weights = weights
        self.device = select_device(device)
        self.half = half
        self.model = self.get_backend()

    def get_backend(self) -> Union['PyTorchBackend', 'TorchscriptBackend', 'ONNXBackend', 'TensorRTBackend', 'OpenVinoBackend', 'TFLiteBackend']:
        """
        Returns an instance of the appropriate backend based on the model type.

        Returns:
            An instance of a backend class corresponding to the detected model type.

        Raises:
            SystemExit: If no supported model framework is detected.
        """

        # Mapping of conditions to backend constructors
        backend_map = {
            self.pt: PyTorchBackend,
            self.jit: TorchscriptBackend,
            self.onnx: ONNXBackend,
            self.engine: TensorRTBackend,
            self.xml: OpenVinoBackend,
            self.tflite: TFLiteBackend
        }

        # Iterate through the mapping and return the first matching backend
        for condition, backend_class in backend_map.items():
            if condition:
                return backend_class(self.weights, self.device, self.half)

        # If no condition is met, log an error and exit
        LOGGER.error("This model framework is not supported yet!")
        exit()

    def forward(self, im_batch: torch.Tensor) -> torch.Tensor:
        """
        Processes an image batch through the selected backend and returns the processed batch.

        Args:
            im_batch (torch.Tensor): The batch of images to process.

        Returns:
            torch.Tensor: The processed image batch.
        """
        im_batch = self.backend.preprocess_input(im_batch)
        return self.backend.get_features(im_batch)

    def check_suffix(self, file: Path = "osnet_x0_25_msmt17.pt", suffix: Union[str, Tuple[str, ...]] = (".pt",), msg: str = "") -> None:
        """
        Validates that the file or files have an acceptable suffix.

        Args:
            file (Union[str, List[str], Path]): The file or files to check.
            suffix (Union[str, Tuple[str, ...]]): Acceptable suffix or suffixes.
            msg (str): Additional message to log in case of an error.
        """

        suffix = [suffix] if isinstance(suffix, str) else list(suffix)
        files = [file] if isinstance(file, (str, Path)) else list(file)

        for f in files:
            file_suffix = Path(f).suffix.lower()
            if file_suffix and file_suffix not in suffix:
                LOGGER.error(f"File {f} does not have an acceptable suffix. Expected: {suffix}")

    def model_type(self, p: Path) -> Tuple[bool, ...]:
        """
        Determines the model type based on the file's suffix.

        Args:
            path (str): The file path to the model.

        Returns:
            Tuple[bool, ...]: A tuple of booleans indicating the model type, corresponding to pt, jit, onnx, xml, engine, and tflite.
        """

        sf = list(export_formats().Suffix)  # export suffixes
        self.check_suffix(p, sf)  # checks
        types = [s in Path(p).name for s in sf]
        return types
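A short usage sketch (the weights path is illustrative; note that the concrete backend chosen from the weight suffix is stored on the model attribute):

# Illustrative only: suffix-based backend selection and feature extraction.
import numpy as np
import torch
from boxmot.appearance.reid.auto_backend import ReidAutoBackend
from boxmot.utils import WEIGHTS

rab = ReidAutoBackend(weights=WEIGHTS / "osnet_x0_25_msmt17.pt",
                      device=torch.device("cpu"), half=False)
reid = rab.model                                   # e.g. a PyTorchBackend for .pt weights
img = np.zeros((480, 640, 3), dtype=np.uint8)      # dummy BGR frame
dets = np.array([[10, 10, 110, 210]])              # one xyxy detection
embs = reid.get_features(dets, img)                # (1, embedding_dim) numpy array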
boxmot/appearance/reid/config.py
ADDED
|
@@ -0,0 +1,73 @@
MODEL_TYPES = [
    "resnet50",
    "resnet101",
    "mlfn",
    "hacnn",
    "mobilenetv2_x1_0",
    "mobilenetv2_x1_4",
    "osnet_x1_0",
    "osnet_x0_75",
    "osnet_x0_5",
    "osnet_x0_25",
    "osnet_ibn_x1_0",
    "osnet_ain_x1_0",
    "lmbn_n",
    "clip",
]

TRAINED_URLS = {
    # resnet50
    "resnet50_market1501.pt": "https://drive.google.com/uc?id=1dUUZ4rHDWohmsQXCRe2C_HbYkzz94iBV",
    "resnet50_dukemtmcreid.pt": "https://drive.google.com/uc?id=17ymnLglnc64NRvGOitY3BqMRS9UWd1wg",
    "resnet50_msmt17.pt": "https://drive.google.com/uc?id=1ep7RypVDOthCRIAqDnn4_N-UhkkFHJsj",
    "resnet50_fc512_market1501.pt": "https://drive.google.com/uc?id=1kv8l5laX_YCdIGVCetjlNdzKIA3NvsSt",
    "resnet50_fc512_dukemtmcreid.pt": "https://drive.google.com/uc?id=13QN8Mp3XH81GK4BPGXobKHKyTGH50Rtx",
    "resnet50_fc512_msmt17.pt": "https://drive.google.com/uc?id=1fDJLcz4O5wxNSUvImIIjoaIF9u1Rwaud",
    # mlfn
    "mlfn_market1501.pt": "https://drive.google.com/uc?id=1wXcvhA_b1kpDfrt9s2Pma-MHxtj9pmvS",
    "mlfn_dukemtmcreid.pt": "https://drive.google.com/uc?id=1rExgrTNb0VCIcOnXfMsbwSUW1h2L1Bum",
    "mlfn_msmt17.pt": "https://drive.google.com/uc?id=18JzsZlJb3Wm7irCbZbZ07TN4IFKvR6p-",
    # hacnn
    "hacnn_market1501.pt": "https://drive.google.com/uc?id=1LRKIQduThwGxMDQMiVkTScBwR7WidmYF",
    "hacnn_dukemtmcreid.pt": "https://drive.google.com/uc?id=1zNm6tP4ozFUCUQ7Sv1Z98EAJWXJEhtYH",
    "hacnn_msmt17.pt": "https://drive.google.com/uc?id=1MsKRtPM5WJ3_Tk2xC0aGOO7pM3VaFDNZ",
    # mobilenetv2
    "mobilenetv2_x1_0_market1501.pt": "https://drive.google.com/uc?id=18DgHC2ZJkjekVoqBWszD8_Xiikz-fewp",
    "mobilenetv2_x1_0_dukemtmcreid.pt": "https://drive.google.com/uc?id=1q1WU2FETRJ3BXcpVtfJUuqq4z3psetds",
    "mobilenetv2_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1j50Hv14NOUAg7ZeB3frzfX-WYLi7SrhZ",
    "mobilenetv2_x1_4_market1501.pt": "https://drive.google.com/uc?id=1t6JCqphJG-fwwPVkRLmGGyEBhGOf2GO5",
    "mobilenetv2_x1_4_dukemtmcreid.pt": "https://drive.google.com/uc?id=12uD5FeVqLg9-AFDju2L7SQxjmPb4zpBN",
    "mobilenetv2_x1_4_msmt17.pt": "https://drive.google.com/uc?id=1ZY5P2Zgm-3RbDpbXM0kIBMPvspeNIbXz",
    # osnet
    "osnet_x1_0_market1501.pt": "https://drive.google.com/uc?id=1vduhq5DpN2q1g4fYEZfPI17MJeh9qyrA",
    "osnet_x1_0_dukemtmcreid.pt": "https://drive.google.com/uc?id=1QZO_4sNf4hdOKKKzKc-TZU9WW1v6zQbq",
    "osnet_x1_0_msmt17.pt": "https://drive.google.com/uc?id=112EMUfBPYeYg70w-syK6V6Mx8-Qb9Q1M",
    "osnet_x0_75_market1501.pt": "https://drive.google.com/uc?id=1ozRaDSQw_EQ8_93OUmjDbvLXw9TnfPer",
    "osnet_x0_75_dukemtmcreid.pt": "https://drive.google.com/uc?id=1IE3KRaTPp4OUa6PGTFL_d5_KQSJbP0Or",
    "osnet_x0_75_msmt17.pt": "https://drive.google.com/uc?id=1QEGO6WnJ-BmUzVPd3q9NoaO_GsPNlmWc",
    "osnet_x0_5_market1501.pt": "https://drive.google.com/uc?id=1PLB9rgqrUM7blWrg4QlprCuPT7ILYGKT",
    "osnet_x0_5_dukemtmcreid.pt": "https://drive.google.com/uc?id=1KoUVqmiST175hnkALg9XuTi1oYpqcyTu",
    "osnet_x0_5_msmt17.pt": "https://drive.google.com/uc?id=1UT3AxIaDvS2PdxzZmbkLmjtiqq7AIKCv",
    "osnet_x0_25_market1501.pt": "https://drive.google.com/uc?id=1z1UghYvOTtjx7kEoRfmqSMu-z62J6MAj",
    "osnet_x0_25_dukemtmcreid.pt": "https://drive.google.com/uc?id=1eumrtiXT4NOspjyEV4j8cHmlOaaCGk5l",
    "osnet_x0_25_msmt17.pt": "https://drive.google.com/uc?id=1sSwXSUlj4_tHZequ_iZ8w_Jh0VaRQMqF",
    # osnet_ain | osnet_ibn
    "osnet_ibn_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1q3Sj2ii34NlfxA4LvmHdWO_75NDRmECJ",
    "osnet_ain_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal",
    # lmbn
    "lmbn_n_duke.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_duke.pth",
    "lmbn_n_market.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_market.pth",
    "lmbn_n_cuhk03_d.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_cuhk03_d.pth",
    # clip
    "clip_market1501.pt": "https://drive.google.com/uc?id=1GnyAVeNOg3Yug1KBBWMKKbT2x43O5Ch7",
    "clip_duke.pt": "https://drive.google.com/uc?id=1ldjSkj-7pXAWmx8on5x0EftlCaolU4dY",
    "clip_veri.pt": "https://drive.google.com/uc?id=1RyfHdOBI2pan_wIGSim5-l6cM4S2WN8e",
    "clip_vehicleid.pt": "https://drive.google.com/uc?id=168BLegHHxNqatW5wx1YyL2REaThWoof5"
}

NR_CLASSES_DICT = {
    "market1501": 751,
    "duke": 702,
    "veri": 576,
    "vehicleid": 576,
}
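A small sketch of how these tables can be consulted for a given weight file name; the substring matching below only illustrates the naming convention and is not necessarily how registry.py resolves it:

# Illustrative lookup based on the weight-name convention above.
from boxmot.appearance.reid.config import NR_CLASSES_DICT, TRAINED_URLS

weight_name = "osnet_x0_25_market1501.pt"
num_classes = next(v for k, v in NR_CLASSES_DICT.items() if k in weight_name)
print(num_classes)                  # 751 (Market-1501 identities)
print(weight_name in TRAINED_URLS)  # True -> a download URL is registered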
boxmot/appearance/reid/export.py
ADDED
@@ -0,0 +1,227 @@
+#!/usr/bin/env python3
+import argparse
+import time
+from pathlib import Path
+
+import torch
+
+from boxmot.appearance.exporters.base_exporter import BaseExporter
+from boxmot.appearance.exporters.onnx_exporter import ONNXExporter
+from boxmot.appearance.exporters.openvino_exporter import OpenVINOExporter
+from boxmot.appearance.exporters.tflite_exporter import TFLiteExporter
+from boxmot.appearance.exporters.torchscript_exporter import TorchScriptExporter
+from boxmot.appearance.exporters.tensorrt_exporter import EngineExporter
+from boxmot.appearance.reid import export_formats
+from boxmot.appearance.reid.auto_backend import ReidAutoBackend
+from boxmot.appearance.reid.registry import ReIDModelRegistry
+from boxmot.utils import WEIGHTS, logger as LOGGER
+from boxmot.utils.torch_utils import select_device
+
+
+def parse_args():
+    """
+    Parse command-line arguments for the ReID export script.
+    """
+    parser = argparse.ArgumentParser(description="ReID Export Script")
+    parser.add_argument("--batch-size", type=int, default=1, help="Batch size for export")
+    parser.add_argument("--imgsz", "--img", "--img-size",
+                        nargs="+", type=int, default=[256, 128],
+                        help="Image size in the format: height width")
+    parser.add_argument("--device", default="cpu",
+                        help="CUDA device (e.g., '0', '0,1,2,3', or 'cpu')")
+    parser.add_argument("--optimize", action="store_true",
+                        help="Optimize TorchScript for mobile (CPU export only)")
+    parser.add_argument("--dynamic", action="store_true",
+                        help="Enable dynamic axes for ONNX/TF/TensorRT export")
+    parser.add_argument("--simplify", action="store_true",
+                        help="Simplify ONNX model")
+    parser.add_argument("--opset", type=int, default=12,
+                        help="ONNX opset version")
+    parser.add_argument("--workspace", type=int, default=4,
+                        help="TensorRT workspace size (GB)")
+    parser.add_argument("--verbose", action="store_true",
+                        help="Enable verbose logging for TensorRT")
+    parser.add_argument("--weights", type=Path,
+                        default=WEIGHTS / "osnet_x0_25_msmt17.pt",
+                        help="Path to the model weights (.pt file)")
+    parser.add_argument("--half", action="store_true",
+                        help="Enable FP16 half-precision export (GPU only)")
+    parser.add_argument("--include", nargs="+",
+                        default=["torchscript"],
+                        help=("Export formats to include. Options: torchscript, onnx, "
+                              "openvino, engine, tflite"))
+    return parser.parse_args()
+
+
+def validate_export_formats(include):
+    """
+    Validate the provided export formats and return corresponding flags.
+
+    Args:
+        include (list): List of export formats provided via the command line.
+
+    Returns:
+        tuple: Boolean flags for each export format in the order:
+               (torchscript, onnx, openvino, engine, tflite)
+    """
+    available_formats = tuple(export_formats()["Argument"][1:])
+    include_lower = [fmt.lower() for fmt in include]
+    flags = [fmt in include_lower for fmt in available_formats]
+    if sum(flags) != len(include_lower):
+        raise AssertionError(
+            f"ERROR: Invalid --include {include}, valid arguments are {available_formats}"
+        )
+    return tuple(flags)
+
+
+def setup_model(args):
+    """
+    Initialize and prepare the ReID model for export.
+
+    Args:
+        args: Parsed command-line arguments.
+
+    Returns:
+        tuple: (model (torch.nn.Module), dummy_input (torch.Tensor))
+    """
+    # Select the correct device
+    args.device = select_device(args.device)
+    if args.half and args.device.type == "cpu":
+        raise AssertionError("--half only compatible with GPU export, use --device 0 for GPU")
+
+    # Initialize backend model using the auto backend
+    auto_backend = ReidAutoBackend(weights=args.weights, device=args.device, half=args.half)
+    _ = auto_backend.get_backend()  # Backend model is managed internally
+
+    # Build and load the ReID model from the registry
+    model_name = ReIDModelRegistry.get_model_name(args.weights)
+    nr_classes = ReIDModelRegistry.get_nr_classes(args.weights)
+    pretrained = not (args.weights and args.weights.is_file() and args.weights.suffix == ".pt")
+    model = ReIDModelRegistry.build_model(
+        model_name,
+        num_classes=nr_classes,
+        pretrained=pretrained,
+        use_gpu=args.device,
+    ).to(args.device)
+    ReIDModelRegistry.load_pretrained_weights(model, args.weights)
+    model.eval()
+
+    # Ensure --optimize is only used with CPU exports
+    if args.optimize and args.device.type != "cpu":
+        raise AssertionError("--optimize not compatible with CUDA devices, use --device cpu")
+
+    # Adjust image size if a specific weight type is detected
+    if "lmbn" in str(args.weights):
+        args.imgsz = [384, 128]
+
+    # Create dummy input tensor for warming up the model
+    dummy_input = torch.empty(args.batch_size, 3, args.imgsz[0], args.imgsz[1]).to(args.device)
+    for _ in range(2):
+        _ = model(dummy_input)
+
+    # Convert to half precision if required
+    if args.half:
+        dummy_input = dummy_input.half()
+        model = model.half()
+
+    return model, dummy_input
+
+
+def create_export_tasks(args, model, dummy_input):
+    """
+    Create a mapping of export tasks with associated flags, exporter classes, and parameters.
+
+    Args:
+        args: Parsed command-line arguments.
+        model: Prepared ReID model.
+        dummy_input: Dummy input tensor.
+
+    Returns:
+        dict: Mapping of export format to a tuple (flag, exporter_class, export_args)
+    """
+    torchscript_flag, onnx_flag, openvino_flag, engine_flag, tflite_flag = validate_export_formats(args.include)
+    return {
+        "torchscript": (
+            torchscript_flag,
+            TorchScriptExporter,
+            (model, dummy_input, args.weights, args.optimize)
+        ),
+        "engine": (
+            engine_flag,
+            EngineExporter,
+            (model, dummy_input, args.weights, args.half, args.dynamic, args.simplify, args.verbose)
+        ),
+        "onnx": (
+            onnx_flag,
+            ONNXExporter,
+            (model, dummy_input, args.weights, args.opset, args.dynamic, args.half, args.simplify)
+        ),
+        "tflite": (
+            tflite_flag,
+            TFLiteExporter,
+            (model, dummy_input, args.weights)
+        ),
+        "openvino": (
+            openvino_flag,
+            OpenVINOExporter,
+            (model, dummy_input, args.weights, args.half)
+        )
+    }
+
+
+def perform_exports(export_tasks):
+    """
+    Iterate over export tasks and perform export for enabled formats.
+
+    Args:
+        export_tasks (dict): Mapping of export tasks.
+
+    Returns:
+        dict: Mapping of export format to export results.
+    """
+    exported_files = {}
+    for fmt, (flag, exporter_class, exp_args) in export_tasks.items():
+        if flag:
+            exporter = exporter_class(*exp_args)
+            export_result = exporter.export()
+            exported_files[fmt] = export_result
+    return exported_files
+
+
+def main():
+    """Main function to execute the ReID export process."""
+    args = parse_args()
+    start_time = time.time()
+
+    # Ensure the weights directory exists
+    WEIGHTS.mkdir(parents=False, exist_ok=True)
+
+    # Setup model and create a dummy input tensor
+    model, dummy_input = setup_model(args)
+
+    # Log model output shape and file size
+    output = model(dummy_input)
+    output_tensor = output[0] if isinstance(output, tuple) else output
+    output_shape = tuple(output_tensor.shape)
+    LOGGER.info(
+        f"\nStarting from {args.weights} with output shape {output_shape} "
+        f"({BaseExporter.file_size(args.weights):.1f} MB)"
+    )
+
+    # Create export tasks
+    export_tasks = create_export_tasks(args, model, dummy_input)
+
+    # Perform exports for enabled formats
+    exported_files = perform_exports(export_tasks)
+
+    if exported_files:
+        elapsed_time = time.time() - start_time
+        LOGGER.info(
+            f"\nExport complete ({elapsed_time:.1f}s)"
+            f"\nResults saved to {args.weights.parent.resolve()}"
+            f"\nVisualize: https://netron.app"
+        )
+
+
+if __name__ == "__main__":
+    main()
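As a rough usage sketch (hypothetical, not part of the commit; the import path is assumed from the file location and the referenced weights must exist locally), the helpers above compose into the same flow that main() runs:

from boxmot.appearance.reid.export import (
    parse_args, setup_model, create_export_tasks, perform_exports,
)

# roughly equivalent to: python export.py --weights osnet_x0_25_msmt17.pt --include torchscript onnx
args = parse_args()                                    # CLI flags -> namespace
model, dummy_input = setup_model(args)                 # build, load and warm up the ReID model
tasks = create_export_tasks(args, model, dummy_input)  # format -> (flag, exporter class, exporter args)
results = perform_exports(tasks)                       # only formats whose flag is set are exported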
boxmot/appearance/reid/factory.py
ADDED
@@ -0,0 +1,40 @@
+from boxmot.appearance.backbones.clip.make_model import make_model
+from boxmot.appearance.backbones.hacnn import HACNN
+from boxmot.appearance.backbones.lmbn.lmbn_n import LMBN_n
+from boxmot.appearance.backbones.mlfn import mlfn
+from boxmot.appearance.backbones.mobilenetv2 import mobilenetv2_x1_0, mobilenetv2_x1_4
+from boxmot.appearance.backbones.osnet import (
+    osnet_ibn_x1_0,
+    osnet_x0_5,
+    osnet_x0_25,
+    osnet_x0_75,
+    osnet_x1_0,
+)
+from boxmot.appearance.backbones.osnet_ain import (
+    osnet_ain_x0_5,
+    osnet_ain_x0_25,
+    osnet_ain_x0_75,
+    osnet_ain_x1_0,
+)
+from boxmot.appearance.backbones.resnet import resnet50, resnet101
+
+# Map model names to their respective constructors
+MODEL_FACTORY = {
+    "resnet50": resnet50,
+    "resnet101": resnet101,
+    "mobilenetv2_x1_0": mobilenetv2_x1_0,
+    "mobilenetv2_x1_4": mobilenetv2_x1_4,
+    "hacnn": HACNN,
+    "mlfn": mlfn,
+    "osnet_x1_0": osnet_x1_0,
+    "osnet_x0_75": osnet_x0_75,
+    "osnet_x0_5": osnet_x0_5,
+    "osnet_x0_25": osnet_x0_25,
+    "osnet_ibn_x1_0": osnet_ibn_x1_0,
+    "osnet_ain_x1_0": osnet_ain_x1_0,
+    "osnet_ain_x0_75": osnet_ain_x0_75,
+    "osnet_ain_x0_5": osnet_ain_x0_5,
+    "osnet_ain_x0_25": osnet_ain_x0_25,
+    "lmbn_n": LMBN_n,
+    "clip": make_model,
+}
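A small sketch of how this mapping is typically consumed (hypothetical call; the keyword arguments follow the registry's build_model shown next):

from boxmot.appearance.reid.factory import MODEL_FACTORY

build_fn = MODEL_FACTORY["osnet_x0_25"]
model = build_fn(num_classes=751, loss="softmax", pretrained=False, use_gpu=False)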
boxmot/appearance/reid/registry.py
ADDED
@@ -0,0 +1,87 @@
+# model_registry.py
+import torch
+from collections import OrderedDict
+from boxmot.utils import logger as LOGGER
+
+from boxmot.appearance.reid.config import MODEL_TYPES, TRAINED_URLS, NR_CLASSES_DICT
+from boxmot.appearance.reid.factory import MODEL_FACTORY
+
+class ReIDModelRegistry:
+    """Encapsulates model registration and related utilities."""
+
+    @staticmethod
+    def show_downloadable_models():
+        LOGGER.info("Available .pt ReID models for automatic download")
+        LOGGER.info(list(TRAINED_URLS.keys()))
+
+    @staticmethod
+    def get_model_name(model):
+        for name in MODEL_TYPES:
+            if name in model.name:
+                return name
+        return None
+
+    @staticmethod
+    def get_model_url(model):
+        return TRAINED_URLS.get(model.name, None)
+
+    @staticmethod
+    def load_pretrained_weights(model, weight_path):
+        """
+        Loads pretrained weights into a model.
+        Chooses the proper map_location based on CUDA availability.
+        """
+        device = "cpu" if not torch.cuda.is_available() else None
+        checkpoint = torch.load(weight_path, map_location=torch.device("cpu") if device == "cpu" else None)
+        state_dict = checkpoint.get("state_dict", checkpoint)
+        model_dict = model.state_dict()
+
+        if "lmbn" in weight_path.parts:
+            model.load_state_dict(state_dict, strict=True)
+        else:
+            new_state_dict = OrderedDict()
+            matched_layers, discarded_layers = [], []
+            for k, v in state_dict.items():
+                # Remove 'module.' prefix if present
+                key = k[7:] if k.startswith("module.") else k
+                if key in model_dict and model_dict[key].size() == v.size():
+                    new_state_dict[key] = v
+                    matched_layers.append(key)
+                else:
+                    discarded_layers.append(key)
+            model_dict.update(new_state_dict)
+            model.load_state_dict(model_dict)
+
+            if not matched_layers:
+                LOGGER.debug(f"Pretrained weights from {weight_path} cannot be loaded. Check key names manually.")
+            else:
+                LOGGER.success(f"Loaded pretrained weights from {weight_path}")
+
+            if discarded_layers:
+                LOGGER.debug(f"Discarded layers due to unmatched keys or size: {discarded_layers}")
+
+    @staticmethod
+    def show_available_models():
+        LOGGER.info("Available models:")
+        LOGGER.info(list(MODEL_FACTORY.keys()))
+
+    @staticmethod
+    def get_nr_classes(weights):
+        # Extract dataset name from weights name, then look up in the class dictionary
+        dataset_key = weights.name.split('_')[1]
+        return NR_CLASSES_DICT.get(dataset_key, 1)
+
+    @staticmethod
+    def build_model(name, num_classes, loss="softmax", pretrained=True, use_gpu=True):
+        if name not in MODEL_FACTORY:
+            available = list(MODEL_FACTORY.keys())
+            raise KeyError(f"Unknown model '{name}'. Must be one of {available}")
+
+        # Special case handling for clip model
+        if 'clip' in name:
+            from boxmot.appearance.backbones.clip.config.defaults import _C as cfg
+            return MODEL_FACTORY[name](cfg, num_class=num_classes, camera_num=2, view_num=1)
+
+        return MODEL_FACTORY[name](
+            num_classes=num_classes, loss=loss, pretrained=pretrained, use_gpu=use_gpu
+        )
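Putting the registry helpers together (hypothetical sketch, not part of the commit; it assumes the osnet_x0_25_msmt17.pt checkpoint referenced by export.py is available locally and that "osnet_x0_25" is listed in MODEL_TYPES):

from pathlib import Path
from boxmot.appearance.reid.registry import ReIDModelRegistry

weights = Path("osnet_x0_25_msmt17.pt")
name = ReIDModelRegistry.get_model_name(weights)    # e.g. "osnet_x0_25"
n_cls = ReIDModelRegistry.get_nr_classes(weights)   # dataset lookup, falls back to 1
model = ReIDModelRegistry.build_model(name, num_classes=n_cls, pretrained=False, use_gpu=False)
ReIDModelRegistry.load_pretrained_weights(model, weights)
model.eval()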
boxmot/configs/__init__.py
ADDED
@@ -0,0 +1 @@
+# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
boxmot/configs/boosttrack.yaml
ADDED
@@ -0,0 +1,90 @@
+max_age:
+  type: uniform
+  default: 60
+  range: [15, 90]
+
+min_hits:
+  type: uniform
+  default: 3
+  range: [1, 5]
+
+det_thresh:
+  type: uniform
+  default: 0.6
+  range: [0.1, 0.9]
+
+iou_threshold:
+  type: uniform
+  default: 0.3
+  range: [0.1, 0.9]
+
+use_ecc:
+  type: choice
+  default: True
+  options: [False, True]
+
+min_box_area:
+  type: uniform
+  default: 10
+  range: [5, 100]
+
+aspect_ratio_thresh:
+  type: uniform
+  default: 1.6
+  range: [0.1, 2.0]
+
+lambda_iou:
+  type: uniform
+  default: 0.5
+  range: [0.3, 2.0]
+
+lambda_mhd:
+  type: uniform
+  default: 0.25
+  range: [0.5, 2.0]
+
+lambda_shape:
+  type: uniform
+  default: 0.25
+  range: [0.5, 2.0]
+
+use_dlo_boost:
+  type: choice
+  default: True
+  options: [False, True]
+
+use_duo_boost:
+  type: choice
+  default: True
+  options: [False, True]
+
+dlo_boost_coef:
+  type: uniform
+  default: 0.65
+  range: [0.3, 2.0]
+
+s_sim_corr:
+  type: choice
+  default: False
+  options: [False, True]
+
+use_rich_s:
+  type: choice
+  default: True # True for BoostTrack++
+  options: [False, True]
+
+use_sb:
+  type: choice
+  default: True # True for BoostTrack++
+  options: [False, True]
+
+use_vt:
+  type: choice
+  default: True # True for BoostTrack++
+  options: [False, True]
+
+with_reid:
+  type: choice
+  default: True # True for BoostTrack+ and BoostTrack++
+  options: [False, True]
+
boxmot/configs/botsort.yaml
ADDED
@@ -0,0 +1,39 @@
+track_high_thresh:
+  type: uniform
+  default: 0.6 # from the default parameters
+  range: [0.3, 0.7]
+
+track_low_thresh:
+  type: uniform
+  default: 0.1 # from the default parameters
+  range: [0.1, 0.3]
+
+new_track_thresh:
+  type: uniform
+  default: 0.7 # from the default parameters
+  range: [0.1, 0.8]
+
+track_buffer:
+  type: randint
+  default: 30 # from the default parameters
+  range: [20, 81]
+
+match_thresh:
+  type: uniform
+  default: 0.8 # from the default parameters
+  range: [0.1, 0.9]
+
+proximity_thresh:
+  type: uniform
+  default: 0.5 # from the default parameters
+  range: [0.25, 0.75]
+
+appearance_thresh:
+  type: uniform
+  default: 0.25 # from the default parameters
+  range: [0.1, 0.8]
+
+cmc_method:
+  type: choice
+  default: ecc # from the default parameters
+  options: [sof, ecc]
boxmot/configs/bytetrack.yaml
ADDED
@@ -0,0 +1,24 @@
+min_conf:
+  type: uniform
+  default: 0.1 # from the default parameters
+  range: [0.1, 0.3]
+
+track_thresh:
+  type: uniform
+  default: 0.6 # from the default parameters
+  range: [0.4, 0.6]
+
+track_buffer:
+  type: randint
+  default: 30 # from the default parameters
+  range: [10, 61, 10] # step size of 10, upper bound exclusive
+
+match_thresh:
+  type: uniform
+  default: 0.9 # from the default parameters
+  range: [0.7, 0.9]
+
+frame_rate:
+  type: choice
+  default: 30 # from the default parameters
+  choices: [30] # static choice for Ray Search
boxmot/configs/deepocsort.yaml
ADDED
@@ -0,0 +1,74 @@
+det_thresh:
+  type: uniform
+  default: 0.5 # from the default parameters
+  range: [0.3, 0.6]
+
+max_age:
+  type: randint
+  default: 30 # from the default parameters
+  range: [10, 61, 10] # step size of 10, upper bound exclusive
+
+min_hits:
+  type: randint
+  default: 3 # from the default parameters
+  range: [1, 6] # upper bound exclusive
+
+iou_thresh:
+  type: uniform
+  default: 0.3 # from the default parameters
+  range: [0.1, 0.4]
+
+delta_t:
+  type: randint
+  default: 3 # from the default parameters
+  range: [1, 6] # upper bound exclusive
+
+asso_func:
+  type: choice
+  default: iou # from the default parameters
+  options: ['iou', 'giou', 'diou', 'ciou', 'hmiou']
+
+inertia:
+  type: uniform
+  default: 0.2 # from the default parameters
+  range: [0.1, 0.4]
+
+w_association_emb:
+  type: uniform
+  default: 0.75 # from the default parameters
+  range: [0.5, 0.9]
+
+alpha_fixed_emb:
+  type: uniform
+  default: 0.95 # from the default parameters
+  range: [0.9, 0.999]
+
+aw_param:
+  type: uniform
+  default: 0.5 # from the default parameters
+  range: [0.3, 0.7]
+
+embedding_off:
+  type: choice
+  default: false # from the default parameters
+  options: [True, False]
+
+cmc_off:
+  type: choice
+  default: false # from the default parameters
+  options: [True, False]
+
+aw_off:
+  type: choice
+  default: false # from the default parameters
+  options: [True, False]
+
+Q_xy_scaling:
+  type: uniform
+  default: 0.01 # from the default parameters
+  range: [0.01, 1]
+
+Q_s_scaling:
+  type: uniform
+  default: 0.0001 # from the default parameters
+  range: [0.0001, 1]
boxmot/configs/hybridsort.yaml
ADDED
@@ -0,0 +1,49 @@
+det_thresh:
+  type: uniform
+  default: 0.12442660055370669 # from the default parameters
+  range: [0, 0.6]
+
+max_age:
+  type: randint
+  default: 30 # from the default parameters
+  range: [10, 151, 10] # step size of 10, upper bound exclusive
+
+min_hits:
+  type: randint
+  default: 1 # from the default parameters
+  range: [1, 6] # upper bound exclusive
+
+delta_t:
+  type: randint
+  default: 5 # from the default parameters
+  range: [1, 6] # upper bound exclusive
+
+asso_func:
+  type: choice
+  default: hmiou # from the default parameters
+  options: ['iou', 'giou', 'diou']
+
+iou_threshold:
+  type: uniform
+  default: 0.3 # from the default parameters
+  range: [0.1, 0.4]
+
+inertia:
+  type: uniform
+  default: 0.369525477649008 # from the default parameters
+  range: [0.1, 0.4]
+
+TCM_first_step_weight:
+  type: uniform
+  default: 0.2866529225304586 # from the default parameters
+  range: [0, 0.5]
+
+longterm_reid_weight:
+  type: uniform
+  default: 0.0509704360503877 # from the default parameters
+  range: [0, 0.5]
+
+use_byte:
+  type: choice
+  default: False # from the default parameters
+  options: [True, False]