Spaces:
Sleeping
Sleeping
github-actions[bot]
commited on
Commit
·
123489f
0
Parent(s):
Sync to HuggingFace Spaces
Browse files- .gitattributes +35 -0
- .github/workflows/main.yml +25 -0
- .gitignore +4 -0
- CHANGELOG.md +12 -0
- LICENSE +21 -0
- README.md +42 -0
- app.py +880 -0
- data/sample-1.mp4 +3 -0
- data/sample-2.mp4 +3 -0
- export_onnx_model.py +201 -0
- packages.txt +1 -0
- poetry.lock +0 -0
- pyproject.toml +31 -0
- requirements.txt +106 -0
- track_anything.py +88 -0
- tracker/base_tracker.py +142 -0
- tracker/config/config.yaml +15 -0
- tracker/inference/__init__.py +0 -0
- tracker/inference/inference_core.py +149 -0
- tracker/inference/kv_memory_store.py +234 -0
- tracker/inference/memory_manager.py +373 -0
- tracker/model/__init__.py +0 -0
- tracker/model/aggregate.py +16 -0
- tracker/model/cbam.py +119 -0
- tracker/model/group_modules.py +92 -0
- tracker/model/losses.py +76 -0
- tracker/model/memory_util.py +87 -0
- tracker/model/modules.py +261 -0
- tracker/model/network.py +241 -0
- tracker/model/resnet.py +191 -0
- tracker/model/trainer.py +302 -0
- tracker/util/__init__.py +0 -0
- tracker/util/mask_mapper.py +87 -0
- tracker/util/range_transform.py +12 -0
- tracker/util/tensor_util.py +50 -0
- utils/base_segmenter.py +149 -0
- utils/blur.py +81 -0
- utils/interact_tools.py +109 -0
- utils/painter.py +360 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/main.yml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
on:
|
2 |
+
push:
|
3 |
+
branches:
|
4 |
+
- main
|
5 |
+
jobs:
|
6 |
+
huggingface-sync:
|
7 |
+
runs-on: ubuntu-latest
|
8 |
+
steps:
|
9 |
+
- name: Checkout Repository
|
10 |
+
uses: actions/checkout@v3
|
11 |
+
|
12 |
+
- name: Hugging Face Sync
|
13 |
+
uses: JacobLinCool/huggingface-sync@v1
|
14 |
+
with:
|
15 |
+
user: Y-T-G
|
16 |
+
space: Blur-Anything
|
17 |
+
emoji: 💻
|
18 |
+
token: ${{ secrets.HF_TOKEN }}
|
19 |
+
github: ${{ secrets.GITHUB_TOKEN }}
|
20 |
+
colorFrom: yellow
|
21 |
+
colorTo: pino
|
22 |
+
sdk: gradio
|
23 |
+
app_file: app.py
|
24 |
+
pinned: false
|
25 |
+
license: mit
|
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
checkpoints/*
|
2 |
+
output/*
|
3 |
+
notebook.ipynb
|
4 |
+
*.pyc
|
CHANGELOG.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Changelog
|
2 |
+
|
3 |
+
## v0.2.0 - 2023-08-11
|
4 |
+
|
5 |
+
### MobileSAM
|
6 |
+
- Added quantized ONNX MobileSAM model. Pass `--sam_model_type vit_t` to use it.
|
7 |
+
|
8 |
+
## v0.1.0 - 2023-05-06
|
9 |
+
|
10 |
+
### Blur-Anything Initial Release
|
11 |
+
- Added blur implementation
|
12 |
+
- Using pims instead of storing frames in memory for better memory usage
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Mohammed Yasin
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Blur Anything
|
3 |
+
emoji: 💻
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: pino
|
6 |
+
sdk: gradio
|
7 |
+
app_file: app.py
|
8 |
+
pinned: false
|
9 |
+
---
|
10 |
+
|
11 |
+
# Blur Anything For Videos
|
12 |
+
|
13 |
+
Blur Anything is an adaptation of the excellent [Track Anything](https://github.com/gaomingqi/Track-Anything) project which is in turn based on Meta's Segment Anything and XMem. It allows you to blur anything in a video, including faces, license plates, etc.
|
14 |
+
|
15 |
+
<div>
|
16 |
+
<a src="https://img.shields.io/badge/%F0%9F%A4%97-Open_in_Spaces-informational.svg?style=flat-square" href="https://huggingface.co/spaces/Y-T-G/Blur-Anything">
|
17 |
+
<img src="https://img.shields.io/badge/%F0%9F%A4%97-Open_in_Spaces-informational.svg?style=flat-square">
|
18 |
+
</a>
|
19 |
+
</div>
|
20 |
+
|
21 |
+
## Get Started
|
22 |
+
```shell
|
23 |
+
# Clone the repository:
|
24 |
+
git clone https://github.com/Y-T-G/Blur-Anything.git
|
25 |
+
cd Blur-Anything
|
26 |
+
|
27 |
+
# Install dependencies:
|
28 |
+
pip install -r requirements.txt
|
29 |
+
|
30 |
+
# Run the Blur-Anything gradio demo.
|
31 |
+
python app.py --device cuda:0
|
32 |
+
# python app.py --device cuda:0 --sam_model_type vit_b # for lower memory usage
|
33 |
+
```
|
34 |
+
|
35 |
+
## To Do
|
36 |
+
- [x] Add a gradio demo
|
37 |
+
- [ ] Add support to use YouTube video URL
|
38 |
+
- [ ] Add option to completely black out the object
|
39 |
+
|
40 |
+
## Acknowledgements
|
41 |
+
|
42 |
+
The project is an adaptation of [Track Anything](https://github.com/gaomingqi/Track-Anything) which is based on [Segment Anything](https://github.com/facebookresearch/segment-anything) and [XMem](https://github.com/hkchengrex/XMem).
|
app.py
ADDED
@@ -0,0 +1,880 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import requests
|
4 |
+
import sys
|
5 |
+
import json
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torchvision
|
11 |
+
import pims
|
12 |
+
|
13 |
+
from export_onnx_model import run_export
|
14 |
+
from onnxruntime.quantization import QuantType
|
15 |
+
from onnxruntime.quantization.quantize import quantize_dynamic
|
16 |
+
|
17 |
+
sys.path.append(sys.path[0] + "/tracker")
|
18 |
+
sys.path.append(sys.path[0] + "/tracker/model")
|
19 |
+
|
20 |
+
from track_anything import TrackingAnything
|
21 |
+
from track_anything import parse_augment
|
22 |
+
|
23 |
+
from utils.painter import mask_painter
|
24 |
+
from utils.blur import blur_frames_and_write
|
25 |
+
|
26 |
+
|
27 |
+
# download checkpoints
|
28 |
+
def download_checkpoint(url, folder, filename):
|
29 |
+
os.makedirs(folder, exist_ok=True)
|
30 |
+
filepath = os.path.join(folder, filename)
|
31 |
+
|
32 |
+
if not os.path.exists(filepath):
|
33 |
+
print("Downloading checkpoints...")
|
34 |
+
response = requests.get(url, stream=True)
|
35 |
+
with open(filepath, "wb") as f:
|
36 |
+
for chunk in response.iter_content(chunk_size=8192):
|
37 |
+
if chunk:
|
38 |
+
f.write(chunk)
|
39 |
+
|
40 |
+
print("Download successful.")
|
41 |
+
|
42 |
+
return filepath
|
43 |
+
|
44 |
+
|
45 |
+
# convert points input to prompt state
|
46 |
+
def get_prompt(click_state, click_input):
|
47 |
+
inputs = json.loads(click_input)
|
48 |
+
points = click_state[0]
|
49 |
+
labels = click_state[1]
|
50 |
+
for input in inputs:
|
51 |
+
points.append(input[:2])
|
52 |
+
labels.append(input[2])
|
53 |
+
click_state[0] = points
|
54 |
+
click_state[1] = labels
|
55 |
+
prompt = {
|
56 |
+
"prompt_type": ["click"],
|
57 |
+
"input_point": click_state[0],
|
58 |
+
"input_label": click_state[1],
|
59 |
+
"multimask_output": "False",
|
60 |
+
}
|
61 |
+
return prompt
|
62 |
+
|
63 |
+
|
64 |
+
# extract frames from upload video
|
65 |
+
def get_frames_from_video(video_input, video_state):
|
66 |
+
"""
|
67 |
+
Args:
|
68 |
+
video_path:str
|
69 |
+
timestamp:float64
|
70 |
+
Return
|
71 |
+
[[0:nearest_frame], [nearest_frame:], nearest_frame]
|
72 |
+
"""
|
73 |
+
video_path = video_input
|
74 |
+
frames = []
|
75 |
+
user_name = time.time()
|
76 |
+
operation_log = [
|
77 |
+
("", ""),
|
78 |
+
(
|
79 |
+
"Video uploaded. Click the image for adding targets to track and blur.",
|
80 |
+
"Normal",
|
81 |
+
),
|
82 |
+
]
|
83 |
+
try:
|
84 |
+
frames = pims.Video(video_path)
|
85 |
+
fps = frames.frame_rate
|
86 |
+
image_size = (frames.shape[1], frames.shape[2])
|
87 |
+
|
88 |
+
except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
|
89 |
+
print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
|
90 |
+
|
91 |
+
# initialize video_state
|
92 |
+
video_state = {
|
93 |
+
"user_name": user_name,
|
94 |
+
"video_name": os.path.split(video_path)[-1],
|
95 |
+
"origin_images": frames,
|
96 |
+
"painted_images": [0] * len(frames),
|
97 |
+
"masks": [0] * len(frames),
|
98 |
+
"logits": [None] * len(frames),
|
99 |
+
"select_frame_number": 0,
|
100 |
+
"fps": fps,
|
101 |
+
}
|
102 |
+
video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(
|
103 |
+
video_state["video_name"], video_state["fps"], len(frames), image_size
|
104 |
+
)
|
105 |
+
model.samcontroler.sam_controler.reset_image()
|
106 |
+
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
|
107 |
+
return (
|
108 |
+
video_state,
|
109 |
+
video_info,
|
110 |
+
video_state["origin_images"][0],
|
111 |
+
gr.update(visible=True, maximum=len(frames), value=1),
|
112 |
+
gr.update(visible=True, maximum=len(frames), value=len(frames)),
|
113 |
+
gr.update(visible=True),
|
114 |
+
gr.update(visible=True),
|
115 |
+
gr.update(visible=True),
|
116 |
+
gr.update(visible=True),
|
117 |
+
gr.update(visible=True),
|
118 |
+
gr.update(visible=True),
|
119 |
+
gr.update(visible=True),
|
120 |
+
gr.update(visible=True),
|
121 |
+
gr.update(visible=True),
|
122 |
+
gr.update(visible=True),
|
123 |
+
gr.update(visible=True, value=operation_log),
|
124 |
+
)
|
125 |
+
|
126 |
+
|
127 |
+
def run_example(example):
|
128 |
+
return video_input
|
129 |
+
|
130 |
+
|
131 |
+
# get the select frame from gradio slider
|
132 |
+
def select_template(image_selection_slider, video_state, interactive_state):
|
133 |
+
# images = video_state[1]
|
134 |
+
image_selection_slider -= 1
|
135 |
+
video_state["select_frame_number"] = image_selection_slider
|
136 |
+
|
137 |
+
# once select a new template frame, set the image in sam
|
138 |
+
|
139 |
+
model.samcontroler.sam_controler.reset_image()
|
140 |
+
model.samcontroler.sam_controler.set_image(
|
141 |
+
video_state["origin_images"][image_selection_slider]
|
142 |
+
)
|
143 |
+
|
144 |
+
# update the masks when select a new template frame
|
145 |
+
operation_log = [
|
146 |
+
("", ""),
|
147 |
+
(
|
148 |
+
"Select frame {}. Try click image and add mask for tracking.".format(
|
149 |
+
image_selection_slider
|
150 |
+
),
|
151 |
+
"Normal",
|
152 |
+
),
|
153 |
+
]
|
154 |
+
|
155 |
+
return (
|
156 |
+
video_state["painted_images"][image_selection_slider],
|
157 |
+
video_state,
|
158 |
+
interactive_state,
|
159 |
+
operation_log,
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
# set the tracking end frame
|
164 |
+
def set_end_number(track_pause_number_slider, video_state, interactive_state):
|
165 |
+
interactive_state["track_end_number"] = track_pause_number_slider
|
166 |
+
operation_log = [
|
167 |
+
("", ""),
|
168 |
+
(
|
169 |
+
"Set the tracking finish at frame {}".format(track_pause_number_slider),
|
170 |
+
"Normal",
|
171 |
+
),
|
172 |
+
]
|
173 |
+
|
174 |
+
return (
|
175 |
+
interactive_state,
|
176 |
+
operation_log,
|
177 |
+
)
|
178 |
+
|
179 |
+
|
180 |
+
def get_resize_ratio(resize_ratio_slider, interactive_state):
|
181 |
+
interactive_state["resize_ratio"] = resize_ratio_slider
|
182 |
+
|
183 |
+
return interactive_state
|
184 |
+
|
185 |
+
|
186 |
+
def get_blur_strength(blur_strength_slider, interactive_state):
|
187 |
+
interactive_state["blur_strength"] = blur_strength_slider
|
188 |
+
|
189 |
+
return interactive_state
|
190 |
+
|
191 |
+
|
192 |
+
# use sam to get the mask
|
193 |
+
def sam_refine(
|
194 |
+
video_state, point_prompt, click_state, interactive_state, evt: gr.SelectData
|
195 |
+
):
|
196 |
+
"""
|
197 |
+
Args:
|
198 |
+
template_frame: PIL.Image
|
199 |
+
point_prompt: flag for positive or negative button click
|
200 |
+
click_state: [[points], [labels]]
|
201 |
+
"""
|
202 |
+
if point_prompt == "Positive":
|
203 |
+
coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
|
204 |
+
interactive_state["positive_click_times"] += 1
|
205 |
+
else:
|
206 |
+
coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
|
207 |
+
interactive_state["negative_click_times"] += 1
|
208 |
+
|
209 |
+
# prompt for sam model
|
210 |
+
model.samcontroler.sam_controler.reset_image()
|
211 |
+
model.samcontroler.sam_controler.set_image(
|
212 |
+
video_state["origin_images"][video_state["select_frame_number"]]
|
213 |
+
)
|
214 |
+
prompt = get_prompt(click_state=click_state, click_input=coordinate)
|
215 |
+
|
216 |
+
mask, logit, painted_image = model.first_frame_click(
|
217 |
+
image=video_state["origin_images"][video_state["select_frame_number"]],
|
218 |
+
points=np.array(prompt["input_point"]),
|
219 |
+
labels=np.array(prompt["input_label"]),
|
220 |
+
multimask=prompt["multimask_output"],
|
221 |
+
)
|
222 |
+
|
223 |
+
video_state["masks"][video_state["select_frame_number"]] = mask
|
224 |
+
video_state["logits"][video_state["select_frame_number"]] = logit
|
225 |
+
video_state["painted_images"][video_state["select_frame_number"]] = painted_image
|
226 |
+
|
227 |
+
operation_log = [
|
228 |
+
("", ""),
|
229 |
+
(
|
230 |
+
"Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment",
|
231 |
+
"Normal",
|
232 |
+
),
|
233 |
+
]
|
234 |
+
return painted_image, video_state, interactive_state, operation_log
|
235 |
+
|
236 |
+
|
237 |
+
def add_multi_mask(video_state, interactive_state, mask_dropdown):
|
238 |
+
try:
|
239 |
+
mask = video_state["masks"][video_state["select_frame_number"]]
|
240 |
+
interactive_state["multi_mask"]["masks"].append(mask)
|
241 |
+
interactive_state["multi_mask"]["mask_names"].append(
|
242 |
+
"mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))
|
243 |
+
)
|
244 |
+
mask_dropdown.append(
|
245 |
+
"mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))
|
246 |
+
)
|
247 |
+
select_frame, run_status = show_mask(
|
248 |
+
video_state, interactive_state, mask_dropdown
|
249 |
+
)
|
250 |
+
|
251 |
+
operation_log = [
|
252 |
+
("", ""),
|
253 |
+
(
|
254 |
+
"Added a mask, use the mask select for target tracking or blurring.",
|
255 |
+
"Normal",
|
256 |
+
),
|
257 |
+
]
|
258 |
+
except Exception:
|
259 |
+
operation_log = [
|
260 |
+
("Please click the left image to generate mask.", "Error"),
|
261 |
+
("", ""),
|
262 |
+
]
|
263 |
+
return (
|
264 |
+
interactive_state,
|
265 |
+
gr.update(
|
266 |
+
choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown
|
267 |
+
),
|
268 |
+
select_frame,
|
269 |
+
[[], []],
|
270 |
+
operation_log,
|
271 |
+
)
|
272 |
+
|
273 |
+
|
274 |
+
def clear_click(video_state, click_state):
|
275 |
+
click_state = [[], []]
|
276 |
+
template_frame = video_state["origin_images"][video_state["select_frame_number"]]
|
277 |
+
operation_log = [
|
278 |
+
("", ""),
|
279 |
+
("Clear points history and refresh the image.", "Normal"),
|
280 |
+
]
|
281 |
+
return template_frame, click_state, operation_log
|
282 |
+
|
283 |
+
|
284 |
+
def remove_multi_mask(interactive_state, mask_dropdown):
|
285 |
+
interactive_state["multi_mask"]["mask_names"] = []
|
286 |
+
interactive_state["multi_mask"]["masks"] = []
|
287 |
+
|
288 |
+
operation_log = [("", ""), ("Remove all mask, please add new masks", "Normal")]
|
289 |
+
return interactive_state, gr.update(choices=[], value=[]), operation_log
|
290 |
+
|
291 |
+
|
292 |
+
def show_mask(video_state, interactive_state, mask_dropdown):
|
293 |
+
mask_dropdown.sort()
|
294 |
+
select_frame = video_state["origin_images"][video_state["select_frame_number"]]
|
295 |
+
|
296 |
+
for i in range(len(mask_dropdown)):
|
297 |
+
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
|
298 |
+
mask = interactive_state["multi_mask"]["masks"][mask_number]
|
299 |
+
select_frame = mask_painter(
|
300 |
+
select_frame, mask.astype("uint8"), mask_color=mask_number + 2
|
301 |
+
)
|
302 |
+
|
303 |
+
operation_log = [
|
304 |
+
("", ""),
|
305 |
+
("Select {} for tracking or blurring".format(mask_dropdown), "Normal"),
|
306 |
+
]
|
307 |
+
return select_frame, operation_log
|
308 |
+
|
309 |
+
|
310 |
+
# tracking vos
|
311 |
+
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
312 |
+
operation_log = [
|
313 |
+
("", ""),
|
314 |
+
(
|
315 |
+
"Track the selected masks, and then you can select the masks for blurring.",
|
316 |
+
"Normal",
|
317 |
+
),
|
318 |
+
]
|
319 |
+
model.xmem.clear_memory()
|
320 |
+
if interactive_state["track_end_number"]:
|
321 |
+
following_frames = video_state["origin_images"][
|
322 |
+
video_state["select_frame_number"]: interactive_state["track_end_number"]
|
323 |
+
]
|
324 |
+
else:
|
325 |
+
following_frames = video_state["origin_images"][
|
326 |
+
video_state["select_frame_number"]:
|
327 |
+
]
|
328 |
+
|
329 |
+
if interactive_state["multi_mask"]["masks"]:
|
330 |
+
if len(mask_dropdown) == 0:
|
331 |
+
mask_dropdown = ["mask_001"]
|
332 |
+
mask_dropdown.sort()
|
333 |
+
template_mask = interactive_state["multi_mask"]["masks"][
|
334 |
+
int(mask_dropdown[0].split("_")[1]) - 1
|
335 |
+
] * (int(mask_dropdown[0].split("_")[1]))
|
336 |
+
for i in range(1, len(mask_dropdown)):
|
337 |
+
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
|
338 |
+
template_mask = np.clip(
|
339 |
+
template_mask
|
340 |
+
+ interactive_state["multi_mask"]["masks"][mask_number]
|
341 |
+
* (mask_number + 1),
|
342 |
+
0,
|
343 |
+
mask_number + 1,
|
344 |
+
)
|
345 |
+
video_state["masks"][video_state["select_frame_number"]] = template_mask
|
346 |
+
else:
|
347 |
+
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
348 |
+
|
349 |
+
# operation error
|
350 |
+
if len(np.unique(template_mask)) == 1:
|
351 |
+
template_mask[0][0] = 1
|
352 |
+
operation_log = [
|
353 |
+
(
|
354 |
+
"Error! Please add at least one mask to track by clicking the left image.",
|
355 |
+
"Error",
|
356 |
+
),
|
357 |
+
("", ""),
|
358 |
+
]
|
359 |
+
# return video_output, video_state, interactive_state, operation_error
|
360 |
+
output_path = "./output/track/{}".format(video_state["video_name"])
|
361 |
+
fps = video_state["fps"]
|
362 |
+
masks, logits, painted_images = model.generator(
|
363 |
+
images=following_frames, template_mask=template_mask, write=True, fps=fps, output_path=output_path
|
364 |
+
)
|
365 |
+
# clear GPU memory
|
366 |
+
model.xmem.clear_memory()
|
367 |
+
|
368 |
+
if interactive_state["track_end_number"]:
|
369 |
+
video_state["masks"][
|
370 |
+
video_state["select_frame_number"]: interactive_state["track_end_number"]
|
371 |
+
] = masks
|
372 |
+
video_state["logits"][
|
373 |
+
video_state["select_frame_number"]: interactive_state["track_end_number"]
|
374 |
+
] = logits
|
375 |
+
video_state["painted_images"][
|
376 |
+
video_state["select_frame_number"]: interactive_state["track_end_number"]
|
377 |
+
] = painted_images
|
378 |
+
else:
|
379 |
+
video_state["masks"][video_state["select_frame_number"]:] = masks
|
380 |
+
video_state["logits"][video_state["select_frame_number"]:] = logits
|
381 |
+
video_state["painted_images"][
|
382 |
+
video_state["select_frame_number"]:
|
383 |
+
] = painted_images
|
384 |
+
|
385 |
+
interactive_state["inference_times"] += 1
|
386 |
+
|
387 |
+
print(
|
388 |
+
"For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(
|
389 |
+
interactive_state["inference_times"],
|
390 |
+
interactive_state["positive_click_times"]
|
391 |
+
+ interactive_state["negative_click_times"],
|
392 |
+
interactive_state["positive_click_times"],
|
393 |
+
interactive_state["negative_click_times"],
|
394 |
+
)
|
395 |
+
)
|
396 |
+
|
397 |
+
return output_path, video_state, interactive_state, operation_log
|
398 |
+
|
399 |
+
|
400 |
+
def blur_video(video_state, interactive_state, mask_dropdown):
|
401 |
+
operation_log = [("", ""), ("Removed the selected masks.", "Normal")]
|
402 |
+
|
403 |
+
frames = np.asarray(video_state["origin_images"])[
|
404 |
+
video_state["select_frame_number"]:interactive_state["track_end_number"]
|
405 |
+
]
|
406 |
+
fps = video_state["fps"]
|
407 |
+
output_path = "./output/blur/{}".format(video_state["video_name"])
|
408 |
+
blur_masks = np.asarray(video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]])
|
409 |
+
if len(mask_dropdown) == 0:
|
410 |
+
mask_dropdown = ["mask_001"]
|
411 |
+
mask_dropdown.sort()
|
412 |
+
# convert mask_dropdown to mask numbers
|
413 |
+
blur_mask_numbers = [
|
414 |
+
int(mask_dropdown[i].split("_")[1]) for i in range(len(mask_dropdown))
|
415 |
+
]
|
416 |
+
# interate through all masks and remove the masks that are not in mask_dropdown
|
417 |
+
unique_masks = np.unique(blur_masks)
|
418 |
+
num_masks = len(unique_masks) - 1
|
419 |
+
for i in range(1, num_masks + 1):
|
420 |
+
if i in blur_mask_numbers:
|
421 |
+
continue
|
422 |
+
blur_masks[blur_masks == i] = 0
|
423 |
+
|
424 |
+
# blur video
|
425 |
+
try:
|
426 |
+
blur_frames_and_write(
|
427 |
+
frames,
|
428 |
+
blur_masks,
|
429 |
+
ratio=interactive_state["resize_ratio"],
|
430 |
+
strength=interactive_state["blur_strength"],
|
431 |
+
fps=fps,
|
432 |
+
output_path=output_path
|
433 |
+
)
|
434 |
+
except Exception as e:
|
435 |
+
print("Exception ", e)
|
436 |
+
operation_log = [
|
437 |
+
(
|
438 |
+
"Error! You are trying to blur without masks input. Please track the selected mask first, and then press blur. To speed up, please use the resize ratio to scale down the image size.",
|
439 |
+
"Error",
|
440 |
+
),
|
441 |
+
("", ""),
|
442 |
+
]
|
443 |
+
|
444 |
+
return output_path, video_state, interactive_state, operation_log
|
445 |
+
|
446 |
+
|
447 |
+
# generate video after vos inference
|
448 |
+
def generate_video_from_frames(frames, output_path, fps=30):
|
449 |
+
"""
|
450 |
+
Generates a video from a list of frames.
|
451 |
+
|
452 |
+
Args:
|
453 |
+
frames (list of numpy arrays): The frames to include in the video.
|
454 |
+
output_path (str): The path to save the generated video.
|
455 |
+
fps (int, optional): The frame rate of the output video. Defaults to 30.
|
456 |
+
"""
|
457 |
+
|
458 |
+
frames = torch.from_numpy(np.asarray(frames))
|
459 |
+
if not os.path.exists(os.path.dirname(output_path)):
|
460 |
+
os.makedirs(os.path.dirname(output_path))
|
461 |
+
torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
|
462 |
+
return output_path
|
463 |
+
|
464 |
+
|
465 |
+
# convert to onnx quantized model
|
466 |
+
def convert_to_onnx(args, checkpoint, quantized=True):
|
467 |
+
"""
|
468 |
+
Convert the model to onnx format.
|
469 |
+
|
470 |
+
Args:
|
471 |
+
model (nn.Module): The model to convert.
|
472 |
+
output_path (str): The path to save the onnx model.
|
473 |
+
input_shape (tuple): The input shape of the model.
|
474 |
+
quantized (bool, optional): Whether to quantize the model. Defaults to True.
|
475 |
+
"""
|
476 |
+
onnx_output_path = f"{checkpoint.split('.')[-2]}.onnx"
|
477 |
+
quant_output_path = f"{checkpoint.split('.')[-2]}_quant.onnx"
|
478 |
+
|
479 |
+
print("Converting to ONNX quantized model...")
|
480 |
+
|
481 |
+
if not (os.path.exists(onnx_output_path)):
|
482 |
+
run_export(
|
483 |
+
model_type=args.sam_model_type,
|
484 |
+
checkpoint=checkpoint,
|
485 |
+
opset=16,
|
486 |
+
output=onnx_output_path,
|
487 |
+
return_single_mask=True
|
488 |
+
)
|
489 |
+
|
490 |
+
if quantized and not (os.path.exists(quant_output_path)):
|
491 |
+
quantize_dynamic(
|
492 |
+
model_input=onnx_output_path,
|
493 |
+
model_output=quant_output_path,
|
494 |
+
optimize_model=True,
|
495 |
+
per_channel=False,
|
496 |
+
reduce_range=False,
|
497 |
+
weight_type=QuantType.QUInt8,
|
498 |
+
)
|
499 |
+
|
500 |
+
return quant_output_path if quantized else onnx_output_path
|
501 |
+
|
502 |
+
|
503 |
+
# args, defined in track_anything.py
|
504 |
+
args = parse_augment()
|
505 |
+
|
506 |
+
# check and download checkpoints if needed
|
507 |
+
SAM_checkpoint_dict = {
|
508 |
+
"vit_h": "sam_vit_h_4b8939.pth",
|
509 |
+
"vit_l": "sam_vit_l_0b3195.pth",
|
510 |
+
"vit_b": "sam_vit_b_01ec64.pth",
|
511 |
+
"vit_t": "mobile_sam.pt",
|
512 |
+
}
|
513 |
+
SAM_checkpoint_url_dict = {
|
514 |
+
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
515 |
+
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
|
516 |
+
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
|
517 |
+
"vit_t": "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt",
|
518 |
+
}
|
519 |
+
sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
|
520 |
+
sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
|
521 |
+
xmem_checkpoint = "XMem-s012.pth"
|
522 |
+
xmem_checkpoint_url = (
|
523 |
+
"https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
|
524 |
+
)
|
525 |
+
|
526 |
+
# initialize SAM, XMem
|
527 |
+
folder = "checkpoints"
|
528 |
+
sam_pt_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
|
529 |
+
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
530 |
+
|
531 |
+
if args.sam_model_type == "vit_t":
|
532 |
+
sam_onnx_checkpoint = convert_to_onnx(args, sam_pt_checkpoint, quantized=True)
|
533 |
+
else:
|
534 |
+
sam_onnx_checkpoint = ""
|
535 |
+
|
536 |
+
model = TrackingAnything(sam_pt_checkpoint, sam_onnx_checkpoint, xmem_checkpoint, args)
|
537 |
+
|
538 |
+
title = """<p><h1 align="center">Blur-Anything</h1></p>
|
539 |
+
"""
|
540 |
+
description = """<p>Gradio demo for Blur Anything, a flexible and interactive
|
541 |
+
tool for video object tracking, segmentation, and blurring. To
|
542 |
+
use it, simply upload your video, or click one of the examples to
|
543 |
+
load them. Code: <a
|
544 |
+
href="https://github.com/Y-T-G/Blur-Anything">https://github.com/Y-T-G/Blur-Anything</a>
|
545 |
+
<a
|
546 |
+
href="https://huggingface.co/spaces/Y-T-G/Blur-Anything?duplicate=true"><img
|
547 |
+
style="display: inline; margin-top: 0em; margin-bottom: 0em"
|
548 |
+
src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
|
549 |
+
|
550 |
+
|
551 |
+
with gr.Blocks() as iface:
|
552 |
+
"""
|
553 |
+
state for
|
554 |
+
"""
|
555 |
+
click_state = gr.State([[], []])
|
556 |
+
interactive_state = gr.State(
|
557 |
+
{
|
558 |
+
"inference_times": 0,
|
559 |
+
"negative_click_times": 0,
|
560 |
+
"positive_click_times": 0,
|
561 |
+
"mask_save": args.mask_save,
|
562 |
+
"multi_mask": {"mask_names": [], "masks": []},
|
563 |
+
"track_end_number": None,
|
564 |
+
"resize_ratio": 1,
|
565 |
+
"blur_strength": 3,
|
566 |
+
}
|
567 |
+
)
|
568 |
+
|
569 |
+
video_state = gr.State(
|
570 |
+
{
|
571 |
+
"user_name": "",
|
572 |
+
"video_name": "",
|
573 |
+
"origin_images": None,
|
574 |
+
"painted_images": None,
|
575 |
+
"masks": None,
|
576 |
+
"blur_masks": None,
|
577 |
+
"logits": None,
|
578 |
+
"select_frame_number": 0,
|
579 |
+
"fps": 30,
|
580 |
+
}
|
581 |
+
)
|
582 |
+
gr.Markdown(title)
|
583 |
+
gr.Markdown(description)
|
584 |
+
with gr.Row():
|
585 |
+
# for user video input
|
586 |
+
with gr.Column():
|
587 |
+
with gr.Row():
|
588 |
+
video_input = gr.Video()
|
589 |
+
with gr.Column():
|
590 |
+
video_info = gr.Textbox(label="Video Info")
|
591 |
+
resize_info = gr.Textbox(
|
592 |
+
value="You can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.",
|
593 |
+
label="Tips for running this demo.",
|
594 |
+
)
|
595 |
+
resize_ratio_slider = gr.Slider(
|
596 |
+
minimum=0.02,
|
597 |
+
maximum=1,
|
598 |
+
step=0.02,
|
599 |
+
value=1,
|
600 |
+
label="Resize ratio",
|
601 |
+
visible=True,
|
602 |
+
)
|
603 |
+
|
604 |
+
with gr.Row():
|
605 |
+
# put the template frame under the radio button
|
606 |
+
with gr.Column():
|
607 |
+
# extract frames
|
608 |
+
with gr.Column():
|
609 |
+
extract_frames_button = gr.Button(
|
610 |
+
value="Get video info", interactive=True, variant="primary"
|
611 |
+
)
|
612 |
+
|
613 |
+
# click points settins, negative or positive, mode continuous or single
|
614 |
+
with gr.Row():
|
615 |
+
with gr.Row():
|
616 |
+
point_prompt = gr.Radio(
|
617 |
+
choices=["Positive", "Negative"],
|
618 |
+
value="Positive",
|
619 |
+
label="Point Prompt",
|
620 |
+
interactive=True,
|
621 |
+
visible=False,
|
622 |
+
)
|
623 |
+
remove_mask_button = gr.Button(
|
624 |
+
value="Remove mask", interactive=True, visible=False
|
625 |
+
)
|
626 |
+
clear_button_click = gr.Button(
|
627 |
+
value="Clear Clicks", interactive=True, visible=False
|
628 |
+
)
|
629 |
+
Add_mask_button = gr.Button(
|
630 |
+
value="Add mask", interactive=True, visible=False
|
631 |
+
)
|
632 |
+
template_frame = gr.Image(
|
633 |
+
type="pil",
|
634 |
+
interactive=True,
|
635 |
+
elem_id="template_frame",
|
636 |
+
visible=False,
|
637 |
+
)
|
638 |
+
image_selection_slider = gr.Slider(
|
639 |
+
minimum=1,
|
640 |
+
maximum=100,
|
641 |
+
step=1,
|
642 |
+
value=1,
|
643 |
+
label="Image Selection",
|
644 |
+
visible=False,
|
645 |
+
)
|
646 |
+
track_pause_number_slider = gr.Slider(
|
647 |
+
minimum=1,
|
648 |
+
maximum=100,
|
649 |
+
step=1,
|
650 |
+
value=1,
|
651 |
+
label="Track end frames",
|
652 |
+
visible=False,
|
653 |
+
)
|
654 |
+
|
655 |
+
with gr.Column():
|
656 |
+
run_status = gr.HighlightedText(
|
657 |
+
value=[
|
658 |
+
("Text", "Error"),
|
659 |
+
("to be", "Label 2"),
|
660 |
+
("highlighted", "Label 3"),
|
661 |
+
],
|
662 |
+
visible=False,
|
663 |
+
)
|
664 |
+
mask_dropdown = gr.Dropdown(
|
665 |
+
multiselect=True,
|
666 |
+
value=[],
|
667 |
+
label="Mask selection",
|
668 |
+
info=".",
|
669 |
+
visible=False,
|
670 |
+
)
|
671 |
+
video_output = gr.Video(visible=False)
|
672 |
+
with gr.Row():
|
673 |
+
tracking_video_predict_button = gr.Button(
|
674 |
+
value="Tracking", visible=False
|
675 |
+
)
|
676 |
+
blur_video_predict_button = gr.Button(
|
677 |
+
value="Blur", visible=False
|
678 |
+
)
|
679 |
+
with gr.Row():
|
680 |
+
blur_strength_slider = gr.Slider(
|
681 |
+
minimum=3,
|
682 |
+
maximum=15,
|
683 |
+
step=2,
|
684 |
+
value=3,
|
685 |
+
label="Blur Strength",
|
686 |
+
visible=False,
|
687 |
+
)
|
688 |
+
|
689 |
+
# first step: get the video information
|
690 |
+
extract_frames_button.click(
|
691 |
+
fn=get_frames_from_video,
|
692 |
+
inputs=[video_input, video_state],
|
693 |
+
outputs=[
|
694 |
+
video_state,
|
695 |
+
video_info,
|
696 |
+
template_frame,
|
697 |
+
image_selection_slider,
|
698 |
+
track_pause_number_slider,
|
699 |
+
point_prompt,
|
700 |
+
clear_button_click,
|
701 |
+
Add_mask_button,
|
702 |
+
template_frame,
|
703 |
+
tracking_video_predict_button,
|
704 |
+
video_output,
|
705 |
+
mask_dropdown,
|
706 |
+
remove_mask_button,
|
707 |
+
blur_video_predict_button,
|
708 |
+
blur_strength_slider,
|
709 |
+
run_status,
|
710 |
+
],
|
711 |
+
)
|
712 |
+
|
713 |
+
# second step: select images from slider
|
714 |
+
image_selection_slider.release(
|
715 |
+
fn=select_template,
|
716 |
+
inputs=[image_selection_slider, video_state, interactive_state],
|
717 |
+
outputs=[template_frame, video_state, interactive_state, run_status],
|
718 |
+
api_name="select_image",
|
719 |
+
)
|
720 |
+
track_pause_number_slider.release(
|
721 |
+
fn=set_end_number,
|
722 |
+
inputs=[track_pause_number_slider, video_state, interactive_state],
|
723 |
+
outputs=[interactive_state, run_status],
|
724 |
+
api_name="end_image",
|
725 |
+
)
|
726 |
+
resize_ratio_slider.release(
|
727 |
+
fn=get_resize_ratio,
|
728 |
+
inputs=[resize_ratio_slider, interactive_state],
|
729 |
+
outputs=[interactive_state],
|
730 |
+
api_name="resize_ratio",
|
731 |
+
)
|
732 |
+
|
733 |
+
blur_strength_slider.release(
|
734 |
+
fn=get_blur_strength,
|
735 |
+
inputs=[blur_strength_slider, interactive_state],
|
736 |
+
outputs=[interactive_state],
|
737 |
+
api_name="blur_strength",
|
738 |
+
)
|
739 |
+
|
740 |
+
# click select image to get mask using sam
|
741 |
+
template_frame.select(
|
742 |
+
fn=sam_refine,
|
743 |
+
inputs=[video_state, point_prompt, click_state, interactive_state],
|
744 |
+
outputs=[template_frame, video_state, interactive_state, run_status],
|
745 |
+
)
|
746 |
+
|
747 |
+
# add different mask
|
748 |
+
Add_mask_button.click(
|
749 |
+
fn=add_multi_mask,
|
750 |
+
inputs=[video_state, interactive_state, mask_dropdown],
|
751 |
+
outputs=[
|
752 |
+
interactive_state,
|
753 |
+
mask_dropdown,
|
754 |
+
template_frame,
|
755 |
+
click_state,
|
756 |
+
run_status,
|
757 |
+
],
|
758 |
+
)
|
759 |
+
|
760 |
+
remove_mask_button.click(
|
761 |
+
fn=remove_multi_mask,
|
762 |
+
inputs=[interactive_state, mask_dropdown],
|
763 |
+
outputs=[interactive_state, mask_dropdown, run_status],
|
764 |
+
)
|
765 |
+
|
766 |
+
# tracking video from select image and mask
|
767 |
+
tracking_video_predict_button.click(
|
768 |
+
fn=vos_tracking_video,
|
769 |
+
inputs=[video_state, interactive_state, mask_dropdown],
|
770 |
+
outputs=[video_output, video_state, interactive_state, run_status],
|
771 |
+
)
|
772 |
+
|
773 |
+
# tracking video from select image and mask
|
774 |
+
blur_video_predict_button.click(
|
775 |
+
fn=blur_video,
|
776 |
+
inputs=[video_state, interactive_state, mask_dropdown],
|
777 |
+
outputs=[video_output, video_state, interactive_state, run_status],
|
778 |
+
)
|
779 |
+
|
780 |
+
# click to get mask
|
781 |
+
mask_dropdown.change(
|
782 |
+
fn=show_mask,
|
783 |
+
inputs=[video_state, interactive_state, mask_dropdown],
|
784 |
+
outputs=[template_frame, run_status],
|
785 |
+
)
|
786 |
+
|
787 |
+
# clear input
|
788 |
+
video_input.clear(
|
789 |
+
lambda: (
|
790 |
+
{
|
791 |
+
"user_name": "",
|
792 |
+
"video_name": "",
|
793 |
+
"origin_images": None,
|
794 |
+
"painted_images": None,
|
795 |
+
"masks": None,
|
796 |
+
"blur_masks": None,
|
797 |
+
"logits": None,
|
798 |
+
"select_frame_number": 0,
|
799 |
+
"fps": 30,
|
800 |
+
},
|
801 |
+
{
|
802 |
+
"inference_times": 0,
|
803 |
+
"negative_click_times": 0,
|
804 |
+
"positive_click_times": 0,
|
805 |
+
"mask_save": args.mask_save,
|
806 |
+
"multi_mask": {"mask_names": [], "masks": []},
|
807 |
+
"track_end_number": 0,
|
808 |
+
"resize_ratio": 1,
|
809 |
+
"blur_strength": 3,
|
810 |
+
},
|
811 |
+
[[], []],
|
812 |
+
None,
|
813 |
+
None,
|
814 |
+
gr.update(visible=False),
|
815 |
+
gr.update(visible=False),
|
816 |
+
gr.update(visible=False),
|
817 |
+
gr.update(visible=False),
|
818 |
+
gr.update(visible=False),
|
819 |
+
gr.update(visible=False),
|
820 |
+
gr.update(visible=False),
|
821 |
+
gr.update(visible=False),
|
822 |
+
gr.update(visible=False),
|
823 |
+
gr.update(visible=False, value=[]),
|
824 |
+
gr.update(visible=False),
|
825 |
+
gr.update(visible=False),
|
826 |
+
gr.update(visible=False),
|
827 |
+
),
|
828 |
+
[],
|
829 |
+
[
|
830 |
+
video_state,
|
831 |
+
interactive_state,
|
832 |
+
click_state,
|
833 |
+
video_output,
|
834 |
+
template_frame,
|
835 |
+
tracking_video_predict_button,
|
836 |
+
image_selection_slider,
|
837 |
+
track_pause_number_slider,
|
838 |
+
point_prompt,
|
839 |
+
clear_button_click,
|
840 |
+
Add_mask_button,
|
841 |
+
template_frame,
|
842 |
+
tracking_video_predict_button,
|
843 |
+
video_output,
|
844 |
+
mask_dropdown,
|
845 |
+
remove_mask_button,
|
846 |
+
blur_video_predict_button,
|
847 |
+
blur_strength_slider,
|
848 |
+
run_status,
|
849 |
+
],
|
850 |
+
queue=False,
|
851 |
+
show_progress=False,
|
852 |
+
)
|
853 |
+
|
854 |
+
# points clear
|
855 |
+
clear_button_click.click(
|
856 |
+
fn=clear_click,
|
857 |
+
inputs=[
|
858 |
+
video_state,
|
859 |
+
click_state,
|
860 |
+
],
|
861 |
+
outputs=[template_frame, click_state, run_status],
|
862 |
+
)
|
863 |
+
# set example
|
864 |
+
gr.Markdown("## Examples")
|
865 |
+
gr.Examples(
|
866 |
+
examples=[
|
867 |
+
os.path.join(os.path.dirname(__file__), "./data/", test_sample)
|
868 |
+
for test_sample in [
|
869 |
+
"sample-1.mp4",
|
870 |
+
"sample-2.mp4",
|
871 |
+
]
|
872 |
+
],
|
873 |
+
fn=run_example,
|
874 |
+
inputs=[video_input],
|
875 |
+
outputs=[video_input],
|
876 |
+
)
|
877 |
+
iface.queue(concurrency_count=1)
|
878 |
+
iface.launch(
|
879 |
+
debug=True, enable_queue=True
|
880 |
+
)
|
data/sample-1.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dc49f2d9f5f00775248b8a66228f3e42304bbc391013d23ac66d21ba1f0e5fd2
|
3 |
+
size 664422
|
data/sample-2.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:45ba5eb410e9d25744946afe61abff9e2ab0916d2f206637636ae30d0decd5e9
|
3 |
+
size 1369798
|
export_onnx_model.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from mobile_sam import sam_model_registry
|
10 |
+
from mobile_sam.utils.onnx import SamOnnxModel
|
11 |
+
|
12 |
+
import argparse
|
13 |
+
import warnings
|
14 |
+
|
15 |
+
try:
|
16 |
+
import onnxruntime # type: ignore
|
17 |
+
|
18 |
+
onnxruntime_exists = True
|
19 |
+
except ImportError:
|
20 |
+
onnxruntime_exists = False
|
21 |
+
|
22 |
+
parser = argparse.ArgumentParser(
|
23 |
+
description="Export the SAM prompt encoder and mask decoder to an ONNX model."
|
24 |
+
)
|
25 |
+
|
26 |
+
parser.add_argument(
|
27 |
+
"--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
|
28 |
+
)
|
29 |
+
|
30 |
+
parser.add_argument(
|
31 |
+
"--output", type=str, required=True, help="The filename to save the ONNX model to."
|
32 |
+
)
|
33 |
+
|
34 |
+
parser.add_argument(
|
35 |
+
"--model-type",
|
36 |
+
type=str,
|
37 |
+
required=True,
|
38 |
+
help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
|
39 |
+
)
|
40 |
+
|
41 |
+
parser.add_argument(
|
42 |
+
"--return-single-mask",
|
43 |
+
action="store_true",
|
44 |
+
help=(
|
45 |
+
"If true, the exported ONNX model will only return the best mask, "
|
46 |
+
"instead of returning multiple masks. For high resolution images "
|
47 |
+
"this can improve runtime when upscaling masks is expensive."
|
48 |
+
),
|
49 |
+
)
|
50 |
+
|
51 |
+
parser.add_argument(
|
52 |
+
"--opset",
|
53 |
+
type=int,
|
54 |
+
default=16,
|
55 |
+
help="The ONNX opset version to use. Must be >=11",
|
56 |
+
)
|
57 |
+
|
58 |
+
parser.add_argument(
|
59 |
+
"--quantize-out",
|
60 |
+
type=str,
|
61 |
+
default=None,
|
62 |
+
help=(
|
63 |
+
"If set, will quantize the model and save it with this name. "
|
64 |
+
"Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
|
65 |
+
),
|
66 |
+
)
|
67 |
+
|
68 |
+
parser.add_argument(
|
69 |
+
"--gelu-approximate",
|
70 |
+
action="store_true",
|
71 |
+
help=(
|
72 |
+
"Replace GELU operations with approximations using tanh. Useful "
|
73 |
+
"for some runtimes that have slow or unimplemented erf ops, used in GELU."
|
74 |
+
),
|
75 |
+
)
|
76 |
+
|
77 |
+
parser.add_argument(
|
78 |
+
"--use-stability-score",
|
79 |
+
action="store_true",
|
80 |
+
help=(
|
81 |
+
"Replaces the model's predicted mask quality score with the stability "
|
82 |
+
"score calculated on the low resolution masks using an offset of 1.0. "
|
83 |
+
),
|
84 |
+
)
|
85 |
+
|
86 |
+
parser.add_argument(
|
87 |
+
"--return-extra-metrics",
|
88 |
+
action="store_true",
|
89 |
+
help=(
|
90 |
+
"The model will return five results: (masks, scores, stability_scores, "
|
91 |
+
"areas, low_res_logits) instead of the usual three. This can be "
|
92 |
+
"significantly slower for high resolution outputs."
|
93 |
+
),
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
+
def run_export(
|
98 |
+
model_type: str,
|
99 |
+
checkpoint: str,
|
100 |
+
output: str,
|
101 |
+
opset: int,
|
102 |
+
return_single_mask: bool,
|
103 |
+
gelu_approximate: bool = False,
|
104 |
+
use_stability_score: bool = False,
|
105 |
+
return_extra_metrics=False,
|
106 |
+
):
|
107 |
+
print("Loading model...")
|
108 |
+
sam = sam_model_registry[model_type](checkpoint=checkpoint)
|
109 |
+
|
110 |
+
onnx_model = SamOnnxModel(
|
111 |
+
model=sam,
|
112 |
+
return_single_mask=return_single_mask,
|
113 |
+
use_stability_score=use_stability_score,
|
114 |
+
return_extra_metrics=return_extra_metrics,
|
115 |
+
)
|
116 |
+
|
117 |
+
if gelu_approximate:
|
118 |
+
for n, m in onnx_model.named_modules():
|
119 |
+
if isinstance(m, torch.nn.GELU):
|
120 |
+
m.approximate = "tanh"
|
121 |
+
|
122 |
+
dynamic_axes = {
|
123 |
+
"point_coords": {1: "num_points"},
|
124 |
+
"point_labels": {1: "num_points"},
|
125 |
+
}
|
126 |
+
|
127 |
+
embed_dim = sam.prompt_encoder.embed_dim
|
128 |
+
embed_size = sam.prompt_encoder.image_embedding_size
|
129 |
+
mask_input_size = [4 * x for x in embed_size]
|
130 |
+
dummy_inputs = {
|
131 |
+
"image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
|
132 |
+
"point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
|
133 |
+
"point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
|
134 |
+
"mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
|
135 |
+
"has_mask_input": torch.tensor([1], dtype=torch.float),
|
136 |
+
"orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
|
137 |
+
}
|
138 |
+
|
139 |
+
_ = onnx_model(**dummy_inputs)
|
140 |
+
|
141 |
+
output_names = ["masks", "iou_predictions", "low_res_masks"]
|
142 |
+
|
143 |
+
with warnings.catch_warnings():
|
144 |
+
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
|
145 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
146 |
+
with open(output, "wb") as f:
|
147 |
+
print(f"Exporting onnx model to {output}...")
|
148 |
+
torch.onnx.export(
|
149 |
+
onnx_model,
|
150 |
+
tuple(dummy_inputs.values()),
|
151 |
+
f,
|
152 |
+
export_params=True,
|
153 |
+
verbose=False,
|
154 |
+
opset_version=opset,
|
155 |
+
do_constant_folding=True,
|
156 |
+
input_names=list(dummy_inputs.keys()),
|
157 |
+
output_names=output_names,
|
158 |
+
dynamic_axes=dynamic_axes,
|
159 |
+
)
|
160 |
+
|
161 |
+
if onnxruntime_exists:
|
162 |
+
ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
|
163 |
+
# set cpu provider default
|
164 |
+
providers = ["CPUExecutionProvider"]
|
165 |
+
ort_session = onnxruntime.InferenceSession(output, providers=providers)
|
166 |
+
_ = ort_session.run(None, ort_inputs)
|
167 |
+
print("Model has successfully been run with ONNXRuntime.")
|
168 |
+
|
169 |
+
|
170 |
+
def to_numpy(tensor):
|
171 |
+
return tensor.cpu().numpy()
|
172 |
+
|
173 |
+
|
174 |
+
if __name__ == "__main__":
|
175 |
+
args = parser.parse_args()
|
176 |
+
run_export(
|
177 |
+
model_type=args.model_type,
|
178 |
+
checkpoint=args.checkpoint,
|
179 |
+
output=args.output,
|
180 |
+
opset=args.opset,
|
181 |
+
return_single_mask=args.return_single_mask,
|
182 |
+
gelu_approximate=args.gelu_approximate,
|
183 |
+
use_stability_score=args.use_stability_score,
|
184 |
+
return_extra_metrics=args.return_extra_metrics,
|
185 |
+
)
|
186 |
+
|
187 |
+
if args.quantize_out is not None:
|
188 |
+
assert onnxruntime_exists, "onnxruntime is required to quantize the model."
|
189 |
+
from onnxruntime.quantization import QuantType # type: ignore
|
190 |
+
from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore
|
191 |
+
|
192 |
+
print(f"Quantizing model and writing to {args.quantize_out}...")
|
193 |
+
quantize_dynamic(
|
194 |
+
model_input=args.output,
|
195 |
+
model_output=args.quantize_out,
|
196 |
+
optimize_model=True,
|
197 |
+
per_channel=False,
|
198 |
+
reduce_range=False,
|
199 |
+
weight_type=QuantType.QUInt8,
|
200 |
+
)
|
201 |
+
print("Done!")
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
python3-opencv
|
poetry.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pyproject.toml
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.poetry]
|
2 |
+
name = "Blur-Anything"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "Track and blur any object or person in a video."
|
5 |
+
authors = ["Y-T-G <yaseensinbox@gmail.com>"]
|
6 |
+
license = "MIT"
|
7 |
+
readme = "README.md"
|
8 |
+
packages = [{include = "blur_anything"}]
|
9 |
+
|
10 |
+
[tool.poetry.dependencies]
|
11 |
+
python = "^3.9"
|
12 |
+
gradio = "^3.28.1"
|
13 |
+
numpy = "^1.24.3"
|
14 |
+
av = "^10.0.0"
|
15 |
+
torch = "^2.0.0"
|
16 |
+
opencv-python = "^4.7.0.72"
|
17 |
+
psutil = "^5.9.5"
|
18 |
+
tqdm = "^4.65.0"
|
19 |
+
matplotlib = "^3.7.1"
|
20 |
+
segment-anything = {git = "https://github.com/facebookresearch/segment-anything.git"}
|
21 |
+
torchvision = "^0.15.1"
|
22 |
+
pims = "^0.6.1"
|
23 |
+
mobile-sam = {git = "https://github.com/ChaoningZhang/MobileSAM.git"}
|
24 |
+
onnxruntime = "^1.15.1"
|
25 |
+
timm = "^0.9.5"
|
26 |
+
onnx = "^1.14.0"
|
27 |
+
|
28 |
+
|
29 |
+
[build-system]
|
30 |
+
requires = ["poetry-core"]
|
31 |
+
build-backend = "poetry.core.masonry.api"
|
requirements.txt
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==23.1.0
|
2 |
+
aiohttp==3.8.4
|
3 |
+
aiosignal==1.3.1
|
4 |
+
altair==4.2.2
|
5 |
+
anyio==3.6.2
|
6 |
+
async-timeout==4.0.2
|
7 |
+
attrs==23.1.0
|
8 |
+
av==10.0.0
|
9 |
+
certifi==2022.12.7
|
10 |
+
charset-normalizer==3.1.0
|
11 |
+
click==8.1.3
|
12 |
+
cmake==3.26.3
|
13 |
+
colorama==0.4.6
|
14 |
+
coloredlogs==15.0.1
|
15 |
+
contourpy==1.0.7
|
16 |
+
cycler==0.11.0
|
17 |
+
entrypoints==0.4
|
18 |
+
fastapi==0.95.1
|
19 |
+
ffmpy==0.3.0
|
20 |
+
filelock==3.12.0
|
21 |
+
flatbuffers==23.5.26
|
22 |
+
fonttools==4.39.3
|
23 |
+
frozenlist==1.3.3
|
24 |
+
fsspec==2023.4.0
|
25 |
+
gradio-client==0.1.4
|
26 |
+
gradio==3.28.3
|
27 |
+
h11==0.14.0
|
28 |
+
httpcore==0.17.0
|
29 |
+
httpx==0.24.0
|
30 |
+
huggingface-hub==0.14.1
|
31 |
+
humanfriendly==10.0
|
32 |
+
idna==3.4
|
33 |
+
imageio==2.28.1
|
34 |
+
importlib-resources==5.12.0
|
35 |
+
jinja2==3.1.2
|
36 |
+
jsonschema==4.17.3
|
37 |
+
kiwisolver==1.4.4
|
38 |
+
linkify-it-py==2.0.2
|
39 |
+
lit==16.0.2
|
40 |
+
markdown-it-py==2.2.0
|
41 |
+
markdown-it-py[linkify]==2.2.0
|
42 |
+
markupsafe==2.1.2
|
43 |
+
matplotlib==3.7.1
|
44 |
+
mdit-py-plugins==0.3.3
|
45 |
+
mdurl==0.1.2
|
46 |
+
mobile-sam @ git+https://github.com/ChaoningZhang/MobileSAM.git
|
47 |
+
mpmath==1.3.0
|
48 |
+
multidict==6.0.4
|
49 |
+
networkx==3.1
|
50 |
+
numpy==1.24.3
|
51 |
+
nvidia-cublas-cu11==11.10.3.66
|
52 |
+
nvidia-cuda-cupti-cu11==11.7.101
|
53 |
+
nvidia-cuda-nvrtc-cu11==11.7.99
|
54 |
+
nvidia-cuda-runtime-cu11==11.7.99
|
55 |
+
nvidia-cudnn-cu11==8.5.0.96
|
56 |
+
nvidia-cufft-cu11==10.9.0.58
|
57 |
+
nvidia-curand-cu11==10.2.10.91
|
58 |
+
nvidia-cusolver-cu11==11.4.0.1
|
59 |
+
nvidia-cusparse-cu11==11.7.4.91
|
60 |
+
nvidia-nccl-cu11==2.14.3
|
61 |
+
nvidia-nvtx-cu11==11.7.91
|
62 |
+
onnx==1.14.0
|
63 |
+
onnxruntime==1.15.1
|
64 |
+
opencv-python==4.7.0.72
|
65 |
+
orjson==3.8.11
|
66 |
+
packaging==23.1
|
67 |
+
pandas==2.0.1
|
68 |
+
pillow==9.5.0
|
69 |
+
pims==0.6.1
|
70 |
+
protobuf==4.24.0
|
71 |
+
psutil==5.9.5
|
72 |
+
pydantic==1.10.7
|
73 |
+
pydub==0.25.1
|
74 |
+
pygments==2.15.1
|
75 |
+
pyparsing==3.0.9
|
76 |
+
pyreadline3==3.4.1
|
77 |
+
pyrsistent==0.19.3
|
78 |
+
python-dateutil==2.8.2
|
79 |
+
python-multipart==0.0.6
|
80 |
+
pytz==2023.3
|
81 |
+
pyyaml==6.0
|
82 |
+
requests==2.30.0
|
83 |
+
safetensors==0.3.2
|
84 |
+
segment-anything @ git+https://github.com/facebookresearch/segment-anything.git
|
85 |
+
semantic-version==2.10.0
|
86 |
+
setuptools==67.7.2
|
87 |
+
six==1.16.0
|
88 |
+
slicerator==1.1.0
|
89 |
+
sniffio==1.3.0
|
90 |
+
starlette==0.26.1
|
91 |
+
sympy==1.11.1
|
92 |
+
timm==0.9.5
|
93 |
+
toolz==0.12.0
|
94 |
+
torch==2.0.0
|
95 |
+
torchvision==0.15.1
|
96 |
+
tqdm==4.65.0
|
97 |
+
triton==2.0.0
|
98 |
+
typing-extensions==4.5.0
|
99 |
+
tzdata==2023.3
|
100 |
+
uc-micro-py==1.0.2
|
101 |
+
urllib3==2.0.2
|
102 |
+
uvicorn==0.22.0
|
103 |
+
websockets==11.0.2
|
104 |
+
wheel==0.40.0
|
105 |
+
yarl==1.9.2
|
106 |
+
zipp==3.15.0
|
track_anything.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from tqdm import tqdm
|
3 |
+
|
4 |
+
from utils.interact_tools import SamControler
|
5 |
+
from tracker.base_tracker import BaseTracker
|
6 |
+
import numpy as np
|
7 |
+
import argparse
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
|
13 |
+
class TrackingAnything:
|
14 |
+
def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, xmem_checkpoint, args):
|
15 |
+
self.args = args
|
16 |
+
self.sam_pt_checkpoint = sam_pt_checkpoint
|
17 |
+
self.sam_onnx_checkpoint = sam_onnx_checkpoint
|
18 |
+
self.xmem_checkpoint = xmem_checkpoint
|
19 |
+
self.samcontroler = SamControler(
|
20 |
+
self.sam_pt_checkpoint, self.sam_onnx_checkpoint, args.sam_model_type, args.device
|
21 |
+
)
|
22 |
+
self.xmem = BaseTracker(self.xmem_checkpoint, device=args.device)
|
23 |
+
|
24 |
+
def first_frame_click(
|
25 |
+
self, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True
|
26 |
+
):
|
27 |
+
mask, logit, painted_image = self.samcontroler.first_frame_click(
|
28 |
+
image, points, labels, multimask
|
29 |
+
)
|
30 |
+
return mask, logit, painted_image
|
31 |
+
|
32 |
+
def generator(
|
33 |
+
self,
|
34 |
+
images: list,
|
35 |
+
template_mask: np.ndarray,
|
36 |
+
write: Optional[bool] = False,
|
37 |
+
fps: Optional[int] = "30",
|
38 |
+
output_path: Optional[str] = "tracking.mp4",
|
39 |
+
):
|
40 |
+
masks = []
|
41 |
+
logits = []
|
42 |
+
painted_images = []
|
43 |
+
|
44 |
+
if write:
|
45 |
+
size = images[0].shape[:2][::-1]
|
46 |
+
if not os.path.exists(os.path.dirname(output_path)):
|
47 |
+
os.makedirs(os.path.dirname(output_path))
|
48 |
+
writer = cv2.VideoWriter(
|
49 |
+
output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size
|
50 |
+
)
|
51 |
+
|
52 |
+
for i in tqdm(range(len(images)), desc="Tracking image"):
|
53 |
+
if i == 0:
|
54 |
+
mask, logit, painted_image = self.xmem.track(images[i], template_mask)
|
55 |
+
else:
|
56 |
+
mask, logit, painted_image = self.xmem.track(images[i])
|
57 |
+
|
58 |
+
masks.append(mask)
|
59 |
+
logits.append(logit)
|
60 |
+
|
61 |
+
if write:
|
62 |
+
writer.write(painted_image[:,:,::-1])
|
63 |
+
else:
|
64 |
+
painted_images.append(painted_image)
|
65 |
+
|
66 |
+
if write:
|
67 |
+
writer.release()
|
68 |
+
|
69 |
+
return masks, logits, painted_images
|
70 |
+
|
71 |
+
|
72 |
+
def parse_augment():
|
73 |
+
parser = argparse.ArgumentParser()
|
74 |
+
parser.add_argument("--device", type=str, default="cpu")
|
75 |
+
parser.add_argument("--sam_model_type", type=str, default="vit_t")
|
76 |
+
parser.add_argument(
|
77 |
+
"--port",
|
78 |
+
type=int,
|
79 |
+
default=6080,
|
80 |
+
help="only useful when running gradio applications",
|
81 |
+
)
|
82 |
+
parser.add_argument("--debug", action="store_true")
|
83 |
+
parser.add_argument("--mask_save", default=False)
|
84 |
+
args = parser.parse_args()
|
85 |
+
|
86 |
+
if args.debug:
|
87 |
+
print(args)
|
88 |
+
return args
|
tracker/base_tracker.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import for debugging
|
2 |
+
import os
|
3 |
+
import glob
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
# import for base_tracker
|
8 |
+
import torch
|
9 |
+
import yaml
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from tracker.model.network import XMem
|
12 |
+
from inference.inference_core import InferenceCore
|
13 |
+
from tracker.util.mask_mapper import MaskMapper
|
14 |
+
from torchvision import transforms
|
15 |
+
from tracker.util.range_transform import im_normalization
|
16 |
+
|
17 |
+
from utils.painter import mask_painter
|
18 |
+
|
19 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
20 |
+
|
21 |
+
|
22 |
+
class BaseTracker:
|
23 |
+
def __init__(
|
24 |
+
self, xmem_checkpoint, device, sam_model=None, model_type=None
|
25 |
+
) -> None:
|
26 |
+
"""
|
27 |
+
device: model device
|
28 |
+
xmem_checkpoint: checkpoint of XMem model
|
29 |
+
"""
|
30 |
+
# load configurations
|
31 |
+
with open(f"{dir_path}/config/config.yaml", "r") as stream:
|
32 |
+
config = yaml.safe_load(stream)
|
33 |
+
# initialise XMem
|
34 |
+
network = XMem(config, xmem_checkpoint, map_location=device).eval()
|
35 |
+
# initialise IncerenceCore
|
36 |
+
self.tracker = InferenceCore(network, config)
|
37 |
+
# data transformation
|
38 |
+
self.im_transform = transforms.Compose(
|
39 |
+
[
|
40 |
+
transforms.ToTensor(),
|
41 |
+
im_normalization,
|
42 |
+
]
|
43 |
+
)
|
44 |
+
self.device = device
|
45 |
+
|
46 |
+
# changable properties
|
47 |
+
self.mapper = MaskMapper()
|
48 |
+
self.initialised = False
|
49 |
+
|
50 |
+
# # SAM-based refinement
|
51 |
+
# self.sam_model = sam_model
|
52 |
+
# self.resizer = Resize([256, 256])
|
53 |
+
|
54 |
+
@torch.no_grad()
|
55 |
+
def resize_mask(self, mask):
|
56 |
+
# mask transform is applied AFTER mapper, so we need to post-process it in eval.py
|
57 |
+
h, w = mask.shape[-2:]
|
58 |
+
min_hw = min(h, w)
|
59 |
+
return F.interpolate(
|
60 |
+
mask,
|
61 |
+
(int(h / min_hw * self.size), int(w / min_hw * self.size)),
|
62 |
+
mode="nearest",
|
63 |
+
)
|
64 |
+
|
65 |
+
@torch.no_grad()
|
66 |
+
def track(self, frame, first_frame_annotation=None):
|
67 |
+
"""
|
68 |
+
Input:
|
69 |
+
frames: numpy arrays (H, W, 3)
|
70 |
+
logit: numpy array (H, W), logit
|
71 |
+
|
72 |
+
Output:
|
73 |
+
mask: numpy arrays (H, W)
|
74 |
+
logit: numpy arrays, probability map (H, W)
|
75 |
+
painted_image: numpy array (H, W, 3)
|
76 |
+
"""
|
77 |
+
|
78 |
+
if first_frame_annotation is not None: # first frame mask
|
79 |
+
# initialisation
|
80 |
+
mask, labels = self.mapper.convert_mask(first_frame_annotation)
|
81 |
+
mask = torch.Tensor(mask).to(self.device)
|
82 |
+
self.tracker.set_all_labels(list(self.mapper.remappings.values()))
|
83 |
+
else:
|
84 |
+
mask = None
|
85 |
+
labels = None
|
86 |
+
# prepare inputs
|
87 |
+
frame_tensor = self.im_transform(frame).to(self.device)
|
88 |
+
# track one frame
|
89 |
+
probs, _ = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
|
90 |
+
# # refine
|
91 |
+
# if first_frame_annotation is None:
|
92 |
+
# out_mask = self.sam_refinement(frame, logits[1], ti)
|
93 |
+
|
94 |
+
# convert to mask
|
95 |
+
out_mask = torch.argmax(probs, dim=0)
|
96 |
+
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
|
97 |
+
|
98 |
+
final_mask = np.zeros_like(out_mask)
|
99 |
+
|
100 |
+
# map back
|
101 |
+
for k, v in self.mapper.remappings.items():
|
102 |
+
final_mask[out_mask == v] = k
|
103 |
+
|
104 |
+
num_objs = final_mask.max()
|
105 |
+
painted_image = frame
|
106 |
+
for obj in range(1, num_objs + 1):
|
107 |
+
if np.max(final_mask == obj) == 0:
|
108 |
+
continue
|
109 |
+
painted_image = mask_painter(
|
110 |
+
painted_image, (final_mask == obj).astype("uint8"), mask_color=obj + 1
|
111 |
+
)
|
112 |
+
|
113 |
+
# print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')
|
114 |
+
|
115 |
+
return final_mask, final_mask, painted_image
|
116 |
+
|
117 |
+
@torch.no_grad()
|
118 |
+
def sam_refinement(self, frame, logits, ti):
|
119 |
+
"""
|
120 |
+
refine segmentation results with mask prompt
|
121 |
+
"""
|
122 |
+
# convert to 1, 256, 256
|
123 |
+
self.sam_model.set_image(frame)
|
124 |
+
mode = "mask"
|
125 |
+
logits = logits.unsqueeze(0)
|
126 |
+
logits = self.resizer(logits).cpu().numpy()
|
127 |
+
prompts = {"mask_input": logits} # 1 256 256
|
128 |
+
masks, scores, logits = self.sam_model.predict(
|
129 |
+
prompts, mode, multimask=True
|
130 |
+
) # masks (n, h, w), scores (n,), logits (n, 256, 256)
|
131 |
+
painted_image = mask_painter(
|
132 |
+
frame, masks[np.argmax(scores)].astype("uint8"), mask_alpha=0.8
|
133 |
+
)
|
134 |
+
painted_image = Image.fromarray(painted_image)
|
135 |
+
painted_image.save(f"/ssd1/gaomingqi/refine/{ti:05d}.png")
|
136 |
+
self.sam_model.reset_image()
|
137 |
+
|
138 |
+
@torch.no_grad()
|
139 |
+
def clear_memory(self):
|
140 |
+
self.tracker.clear_memory()
|
141 |
+
self.mapper.clear_labels()
|
142 |
+
torch.cuda.empty_cache()
|
tracker/config/config.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# config info for XMem
|
2 |
+
benchmark: False
|
3 |
+
disable_long_term: False
|
4 |
+
max_mid_term_frames: 10
|
5 |
+
min_mid_term_frames: 5
|
6 |
+
max_long_term_elements: 1000
|
7 |
+
num_prototypes: 128
|
8 |
+
top_k: 30
|
9 |
+
mem_every: 5
|
10 |
+
deep_update_every: -1
|
11 |
+
save_scores: False
|
12 |
+
flip: False
|
13 |
+
size: 480
|
14 |
+
enable_long_term: True
|
15 |
+
enable_long_term_count_usage: True
|
tracker/inference/__init__.py
ADDED
File without changes
|
tracker/inference/inference_core.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from inference.memory_manager import MemoryManager
|
2 |
+
from model.network import XMem
|
3 |
+
from model.aggregate import aggregate
|
4 |
+
|
5 |
+
from tracker.util.tensor_util import pad_divide_by, unpad
|
6 |
+
|
7 |
+
|
8 |
+
class InferenceCore:
|
9 |
+
def __init__(self, network: XMem, config):
|
10 |
+
self.config = config
|
11 |
+
self.network = network
|
12 |
+
self.mem_every = config["mem_every"]
|
13 |
+
self.deep_update_every = config["deep_update_every"]
|
14 |
+
self.enable_long_term = config["enable_long_term"]
|
15 |
+
|
16 |
+
# if deep_update_every < 0, synchronize deep update with memory frame
|
17 |
+
self.deep_update_sync = self.deep_update_every < 0
|
18 |
+
|
19 |
+
self.clear_memory()
|
20 |
+
self.all_labels = None
|
21 |
+
|
22 |
+
def clear_memory(self):
|
23 |
+
self.curr_ti = -1
|
24 |
+
self.last_mem_ti = 0
|
25 |
+
if not self.deep_update_sync:
|
26 |
+
self.last_deep_update_ti = -self.deep_update_every
|
27 |
+
self.memory = MemoryManager(config=self.config)
|
28 |
+
|
29 |
+
def update_config(self, config):
|
30 |
+
self.mem_every = config["mem_every"]
|
31 |
+
self.deep_update_every = config["deep_update_every"]
|
32 |
+
self.enable_long_term = config["enable_long_term"]
|
33 |
+
|
34 |
+
# if deep_update_every < 0, synchronize deep update with memory frame
|
35 |
+
self.deep_update_sync = self.deep_update_every < 0
|
36 |
+
self.memory.update_config(config)
|
37 |
+
|
38 |
+
def set_all_labels(self, all_labels):
|
39 |
+
# self.all_labels = [l.item() for l in all_labels]
|
40 |
+
self.all_labels = all_labels
|
41 |
+
|
42 |
+
def step(self, image, mask=None, valid_labels=None, end=False):
|
43 |
+
# image: 3*H*W
|
44 |
+
# mask: num_objects*H*W or None
|
45 |
+
self.curr_ti += 1
|
46 |
+
image, self.pad = pad_divide_by(image, 16)
|
47 |
+
image = image.unsqueeze(0) # add the batch dimension
|
48 |
+
|
49 |
+
is_mem_frame = (
|
50 |
+
(self.curr_ti - self.last_mem_ti >= self.mem_every) or (mask is not None)
|
51 |
+
) and (not end)
|
52 |
+
need_segment = (self.curr_ti > 0) and (
|
53 |
+
(valid_labels is None) or (len(self.all_labels) != len(valid_labels))
|
54 |
+
)
|
55 |
+
is_deep_update = (
|
56 |
+
(self.deep_update_sync and is_mem_frame)
|
57 |
+
or ( # synchronized
|
58 |
+
not self.deep_update_sync
|
59 |
+
and self.curr_ti - self.last_deep_update_ti >= self.deep_update_every
|
60 |
+
) # no-sync
|
61 |
+
) and (not end)
|
62 |
+
is_normal_update = (not self.deep_update_sync or not is_deep_update) and (
|
63 |
+
not end
|
64 |
+
)
|
65 |
+
|
66 |
+
key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(
|
67 |
+
image, need_ek=(self.enable_long_term or need_segment), need_sk=is_mem_frame
|
68 |
+
)
|
69 |
+
multi_scale_features = (f16, f8, f4)
|
70 |
+
|
71 |
+
# segment the current frame is needed
|
72 |
+
if need_segment:
|
73 |
+
memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
|
74 |
+
|
75 |
+
hidden, pred_logits_with_bg, pred_prob_with_bg = self.network.segment(
|
76 |
+
multi_scale_features,
|
77 |
+
memory_readout,
|
78 |
+
self.memory.get_hidden(),
|
79 |
+
h_out=is_normal_update,
|
80 |
+
strip_bg=False,
|
81 |
+
)
|
82 |
+
# remove batch dim
|
83 |
+
pred_prob_with_bg = pred_prob_with_bg[0]
|
84 |
+
pred_prob_no_bg = pred_prob_with_bg[1:]
|
85 |
+
|
86 |
+
pred_logits_with_bg = pred_logits_with_bg[0]
|
87 |
+
pred_logits_no_bg = pred_logits_with_bg[1:]
|
88 |
+
|
89 |
+
if is_normal_update:
|
90 |
+
self.memory.set_hidden(hidden)
|
91 |
+
else:
|
92 |
+
pred_prob_no_bg = (
|
93 |
+
pred_prob_with_bg
|
94 |
+
) = pred_logits_with_bg = pred_logits_no_bg = None
|
95 |
+
|
96 |
+
# use the input mask if any
|
97 |
+
if mask is not None:
|
98 |
+
mask, _ = pad_divide_by(mask, 16)
|
99 |
+
|
100 |
+
if pred_prob_no_bg is not None:
|
101 |
+
# if we have a predicted mask, we work on it
|
102 |
+
# make pred_prob_no_bg consistent with the input mask
|
103 |
+
mask_regions = mask.sum(0) > 0.5
|
104 |
+
pred_prob_no_bg[:, mask_regions] = 0
|
105 |
+
# shift by 1 because mask/pred_prob_no_bg do not contain background
|
106 |
+
mask = mask.type_as(pred_prob_no_bg)
|
107 |
+
if valid_labels is not None:
|
108 |
+
shift_by_one_non_labels = [
|
109 |
+
i
|
110 |
+
for i in range(pred_prob_no_bg.shape[0])
|
111 |
+
if (i + 1) not in valid_labels
|
112 |
+
]
|
113 |
+
# non-labelled objects are copied from the predicted mask
|
114 |
+
mask[shift_by_one_non_labels] = pred_prob_no_bg[
|
115 |
+
shift_by_one_non_labels
|
116 |
+
]
|
117 |
+
pred_prob_with_bg = aggregate(mask, dim=0)
|
118 |
+
|
119 |
+
# also create new hidden states
|
120 |
+
self.memory.create_hidden_state(len(self.all_labels), key)
|
121 |
+
|
122 |
+
# save as memory if needed
|
123 |
+
if is_mem_frame:
|
124 |
+
value, hidden = self.network.encode_value(
|
125 |
+
image,
|
126 |
+
f16,
|
127 |
+
self.memory.get_hidden(),
|
128 |
+
pred_prob_with_bg[1:].unsqueeze(0),
|
129 |
+
is_deep_update=is_deep_update,
|
130 |
+
)
|
131 |
+
self.memory.add_memory(
|
132 |
+
key,
|
133 |
+
shrinkage,
|
134 |
+
value,
|
135 |
+
self.all_labels,
|
136 |
+
selection=selection if self.enable_long_term else None,
|
137 |
+
)
|
138 |
+
self.last_mem_ti = self.curr_ti
|
139 |
+
|
140 |
+
if is_deep_update:
|
141 |
+
self.memory.set_hidden(hidden)
|
142 |
+
self.last_deep_update_ti = self.curr_ti
|
143 |
+
|
144 |
+
if pred_logits_with_bg is None:
|
145 |
+
return unpad(pred_prob_with_bg, self.pad), None
|
146 |
+
else:
|
147 |
+
return unpad(pred_prob_with_bg, self.pad), unpad(
|
148 |
+
pred_logits_with_bg, self.pad
|
149 |
+
)
|
tracker/inference/kv_memory_store.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
|
5 |
+
class KeyValueMemoryStore:
|
6 |
+
"""
|
7 |
+
Works for key/value pairs type storage
|
8 |
+
e.g., working and long-term memory
|
9 |
+
"""
|
10 |
+
|
11 |
+
"""
|
12 |
+
An object group is created when new objects enter the video
|
13 |
+
Objects in the same group share the same temporal extent
|
14 |
+
i.e., objects initialized in the same frame are in the same group
|
15 |
+
For DAVIS/interactive, there is only one object group
|
16 |
+
For YouTubeVOS, there can be multiple object groups
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, count_usage: bool):
|
20 |
+
self.count_usage = count_usage
|
21 |
+
|
22 |
+
# keys are stored in a single tensor and are shared between groups/objects
|
23 |
+
# values are stored as a list indexed by object groups
|
24 |
+
self.k = None
|
25 |
+
self.v = []
|
26 |
+
self.obj_groups = []
|
27 |
+
# for debugging only
|
28 |
+
self.all_objects = []
|
29 |
+
|
30 |
+
# shrinkage and selection are also single tensors
|
31 |
+
self.s = self.e = None
|
32 |
+
|
33 |
+
# usage
|
34 |
+
if self.count_usage:
|
35 |
+
self.use_count = self.life_count = None
|
36 |
+
|
37 |
+
def add(self, key, value, shrinkage, selection, objects: List[int]):
|
38 |
+
new_count = torch.zeros(
|
39 |
+
(key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32
|
40 |
+
)
|
41 |
+
new_life = (
|
42 |
+
torch.zeros(
|
43 |
+
(key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32
|
44 |
+
)
|
45 |
+
+ 1e-7
|
46 |
+
)
|
47 |
+
|
48 |
+
# add the key
|
49 |
+
if self.k is None:
|
50 |
+
self.k = key
|
51 |
+
self.s = shrinkage
|
52 |
+
self.e = selection
|
53 |
+
if self.count_usage:
|
54 |
+
self.use_count = new_count
|
55 |
+
self.life_count = new_life
|
56 |
+
else:
|
57 |
+
self.k = torch.cat([self.k, key], -1)
|
58 |
+
if shrinkage is not None:
|
59 |
+
self.s = torch.cat([self.s, shrinkage], -1)
|
60 |
+
if selection is not None:
|
61 |
+
self.e = torch.cat([self.e, selection], -1)
|
62 |
+
if self.count_usage:
|
63 |
+
self.use_count = torch.cat([self.use_count, new_count], -1)
|
64 |
+
self.life_count = torch.cat([self.life_count, new_life], -1)
|
65 |
+
|
66 |
+
# add the value
|
67 |
+
if objects is not None:
|
68 |
+
# When objects is given, v is a tensor; used in working memory
|
69 |
+
assert isinstance(value, torch.Tensor)
|
70 |
+
# First consume objects that are already in the memory bank
|
71 |
+
# cannot use set here because we need to preserve order
|
72 |
+
# shift by one as background is not part of value
|
73 |
+
remaining_objects = [obj - 1 for obj in objects]
|
74 |
+
for gi, group in enumerate(self.obj_groups):
|
75 |
+
for obj in group:
|
76 |
+
# should properly raise an error if there are overlaps in obj_groups
|
77 |
+
remaining_objects.remove(obj)
|
78 |
+
self.v[gi] = torch.cat([self.v[gi], value[group]], -1)
|
79 |
+
|
80 |
+
# If there are remaining objects, add them as a new group
|
81 |
+
if len(remaining_objects) > 0:
|
82 |
+
new_group = list(remaining_objects)
|
83 |
+
self.v.append(value[new_group])
|
84 |
+
self.obj_groups.append(new_group)
|
85 |
+
self.all_objects.extend(new_group)
|
86 |
+
|
87 |
+
assert (
|
88 |
+
sorted(self.all_objects) == self.all_objects
|
89 |
+
), "Objects MUST be inserted in sorted order "
|
90 |
+
else:
|
91 |
+
# When objects is not given, v is a list that already has the object groups sorted
|
92 |
+
# used in long-term memory
|
93 |
+
assert isinstance(value, list)
|
94 |
+
for gi, gv in enumerate(value):
|
95 |
+
if gv is None:
|
96 |
+
continue
|
97 |
+
if gi < self.num_groups:
|
98 |
+
self.v[gi] = torch.cat([self.v[gi], gv], -1)
|
99 |
+
else:
|
100 |
+
self.v.append(gv)
|
101 |
+
|
102 |
+
def update_usage(self, usage):
|
103 |
+
# increase all life count by 1
|
104 |
+
# increase use of indexed elements
|
105 |
+
if not self.count_usage:
|
106 |
+
return
|
107 |
+
|
108 |
+
self.use_count += usage.view_as(self.use_count)
|
109 |
+
self.life_count += 1
|
110 |
+
|
111 |
+
def sieve_by_range(self, start: int, end: int, min_size: int):
|
112 |
+
# keep only the elements *outside* of this range (with some boundary conditions)
|
113 |
+
# i.e., concat (a[:start], a[end:])
|
114 |
+
# min_size is only used for values, we do not sieve values under this size
|
115 |
+
# (because they are not consolidated)
|
116 |
+
|
117 |
+
if end == 0:
|
118 |
+
# negative 0 would not work as the end index!
|
119 |
+
self.k = self.k[:, :, :start]
|
120 |
+
if self.count_usage:
|
121 |
+
self.use_count = self.use_count[:, :, :start]
|
122 |
+
self.life_count = self.life_count[:, :, :start]
|
123 |
+
if self.s is not None:
|
124 |
+
self.s = self.s[:, :, :start]
|
125 |
+
if self.e is not None:
|
126 |
+
self.e = self.e[:, :, :start]
|
127 |
+
|
128 |
+
for gi in range(self.num_groups):
|
129 |
+
if self.v[gi].shape[-1] >= min_size:
|
130 |
+
self.v[gi] = self.v[gi][:, :, :start]
|
131 |
+
else:
|
132 |
+
self.k = torch.cat([self.k[:, :, :start], self.k[:, :, end:]], -1)
|
133 |
+
if self.count_usage:
|
134 |
+
self.use_count = torch.cat(
|
135 |
+
[self.use_count[:, :, :start], self.use_count[:, :, end:]], -1
|
136 |
+
)
|
137 |
+
self.life_count = torch.cat(
|
138 |
+
[self.life_count[:, :, :start], self.life_count[:, :, end:]], -1
|
139 |
+
)
|
140 |
+
if self.s is not None:
|
141 |
+
self.s = torch.cat([self.s[:, :, :start], self.s[:, :, end:]], -1)
|
142 |
+
if self.e is not None:
|
143 |
+
self.e = torch.cat([self.e[:, :, :start], self.e[:, :, end:]], -1)
|
144 |
+
|
145 |
+
for gi in range(self.num_groups):
|
146 |
+
if self.v[gi].shape[-1] >= min_size:
|
147 |
+
self.v[gi] = torch.cat(
|
148 |
+
[self.v[gi][:, :, :start], self.v[gi][:, :, end:]], -1
|
149 |
+
)
|
150 |
+
|
151 |
+
def remove_obsolete_features(self, max_size: int):
|
152 |
+
# normalize with life duration
|
153 |
+
usage = self.get_usage().flatten()
|
154 |
+
|
155 |
+
values, _ = torch.topk(
|
156 |
+
usage, k=(self.size - max_size), largest=False, sorted=True
|
157 |
+
)
|
158 |
+
survived = usage > values[-1]
|
159 |
+
|
160 |
+
self.k = self.k[:, :, survived]
|
161 |
+
self.s = self.s[:, :, survived] if self.s is not None else None
|
162 |
+
# Long-term memory does not store ek so this should not be needed
|
163 |
+
self.e = self.e[:, :, survived] if self.e is not None else None
|
164 |
+
if self.num_groups > 1:
|
165 |
+
raise NotImplementedError(
|
166 |
+
"""The current data structure does not support feature removal with
|
167 |
+
multiple object groups (e.g., some objects start to appear later in the video)
|
168 |
+
The indices for "survived" is based on keys but not all values are present for every key
|
169 |
+
Basically we need to remap the indices for keys to values
|
170 |
+
"""
|
171 |
+
)
|
172 |
+
for gi in range(self.num_groups):
|
173 |
+
self.v[gi] = self.v[gi][:, :, survived]
|
174 |
+
|
175 |
+
self.use_count = self.use_count[:, :, survived]
|
176 |
+
self.life_count = self.life_count[:, :, survived]
|
177 |
+
|
178 |
+
def get_usage(self):
|
179 |
+
# return normalized usage
|
180 |
+
if not self.count_usage:
|
181 |
+
raise RuntimeError("I did not count usage!")
|
182 |
+
else:
|
183 |
+
usage = self.use_count / self.life_count
|
184 |
+
return usage
|
185 |
+
|
186 |
+
def get_all_sliced(self, start: int, end: int):
|
187 |
+
# return k, sk, ek, usage in order, sliced by start and end
|
188 |
+
|
189 |
+
if end == 0:
|
190 |
+
# negative 0 would not work as the end index!
|
191 |
+
k = self.k[:, :, start:]
|
192 |
+
sk = self.s[:, :, start:] if self.s is not None else None
|
193 |
+
ek = self.e[:, :, start:] if self.e is not None else None
|
194 |
+
usage = self.get_usage()[:, :, start:]
|
195 |
+
else:
|
196 |
+
k = self.k[:, :, start:end]
|
197 |
+
sk = self.s[:, :, start:end] if self.s is not None else None
|
198 |
+
ek = self.e[:, :, start:end] if self.e is not None else None
|
199 |
+
usage = self.get_usage()[:, :, start:end]
|
200 |
+
|
201 |
+
return k, sk, ek, usage
|
202 |
+
|
203 |
+
def get_v_size(self, ni: int):
|
204 |
+
return self.v[ni].shape[2]
|
205 |
+
|
206 |
+
def engaged(self):
|
207 |
+
return self.k is not None
|
208 |
+
|
209 |
+
@property
|
210 |
+
def size(self):
|
211 |
+
if self.k is None:
|
212 |
+
return 0
|
213 |
+
else:
|
214 |
+
return self.k.shape[-1]
|
215 |
+
|
216 |
+
@property
|
217 |
+
def num_groups(self):
|
218 |
+
return len(self.v)
|
219 |
+
|
220 |
+
@property
|
221 |
+
def key(self):
|
222 |
+
return self.k
|
223 |
+
|
224 |
+
@property
|
225 |
+
def value(self):
|
226 |
+
return self.v
|
227 |
+
|
228 |
+
@property
|
229 |
+
def shrinkage(self):
|
230 |
+
return self.s
|
231 |
+
|
232 |
+
@property
|
233 |
+
def selection(self):
|
234 |
+
return self.e
|
tracker/inference/memory_manager.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
from inference.kv_memory_store import KeyValueMemoryStore
|
5 |
+
from model.memory_util import *
|
6 |
+
|
7 |
+
|
8 |
+
class MemoryManager:
|
9 |
+
"""
|
10 |
+
Manages all three memory stores and the transition between working/long-term memory
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(self, config):
|
14 |
+
self.hidden_dim = config["hidden_dim"]
|
15 |
+
self.top_k = config["top_k"]
|
16 |
+
|
17 |
+
self.enable_long_term = config["enable_long_term"]
|
18 |
+
self.enable_long_term_usage = config["enable_long_term_count_usage"]
|
19 |
+
if self.enable_long_term:
|
20 |
+
self.max_mt_frames = config["max_mid_term_frames"]
|
21 |
+
self.min_mt_frames = config["min_mid_term_frames"]
|
22 |
+
self.num_prototypes = config["num_prototypes"]
|
23 |
+
self.max_long_elements = config["max_long_term_elements"]
|
24 |
+
|
25 |
+
# dimensions will be inferred from input later
|
26 |
+
self.CK = self.CV = None
|
27 |
+
self.H = self.W = None
|
28 |
+
|
29 |
+
# The hidden state will be stored in a single tensor for all objects
|
30 |
+
# B x num_objects x CH x H x W
|
31 |
+
self.hidden = None
|
32 |
+
|
33 |
+
self.work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term)
|
34 |
+
if self.enable_long_term:
|
35 |
+
self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage)
|
36 |
+
|
37 |
+
self.reset_config = True
|
38 |
+
|
39 |
+
def update_config(self, config):
|
40 |
+
self.reset_config = True
|
41 |
+
self.hidden_dim = config["hidden_dim"]
|
42 |
+
self.top_k = config["top_k"]
|
43 |
+
|
44 |
+
assert self.enable_long_term == config["enable_long_term"], "cannot update this"
|
45 |
+
assert (
|
46 |
+
self.enable_long_term_usage == config["enable_long_term_count_usage"]
|
47 |
+
), "cannot update this"
|
48 |
+
|
49 |
+
self.enable_long_term_usage = config["enable_long_term_count_usage"]
|
50 |
+
if self.enable_long_term:
|
51 |
+
self.max_mt_frames = config["max_mid_term_frames"]
|
52 |
+
self.min_mt_frames = config["min_mid_term_frames"]
|
53 |
+
self.num_prototypes = config["num_prototypes"]
|
54 |
+
self.max_long_elements = config["max_long_term_elements"]
|
55 |
+
|
56 |
+
def _readout(self, affinity, v):
|
57 |
+
# this function is for a single object group
|
58 |
+
return v @ affinity
|
59 |
+
|
60 |
+
def match_memory(self, query_key, selection):
|
61 |
+
# query_key: B x C^k x H x W
|
62 |
+
# selection: B x C^k x H x W
|
63 |
+
num_groups = self.work_mem.num_groups
|
64 |
+
h, w = query_key.shape[-2:]
|
65 |
+
|
66 |
+
query_key = query_key.flatten(start_dim=2)
|
67 |
+
selection = selection.flatten(start_dim=2) if selection is not None else None
|
68 |
+
|
69 |
+
"""
|
70 |
+
Memory readout using keys
|
71 |
+
"""
|
72 |
+
|
73 |
+
if self.enable_long_term and self.long_mem.engaged():
|
74 |
+
# Use long-term memory
|
75 |
+
long_mem_size = self.long_mem.size
|
76 |
+
memory_key = torch.cat([self.long_mem.key, self.work_mem.key], -1)
|
77 |
+
shrinkage = torch.cat(
|
78 |
+
[self.long_mem.shrinkage, self.work_mem.shrinkage], -1
|
79 |
+
)
|
80 |
+
|
81 |
+
similarity = get_similarity(memory_key, shrinkage, query_key, selection)
|
82 |
+
work_mem_similarity = similarity[:, long_mem_size:]
|
83 |
+
long_mem_similarity = similarity[:, :long_mem_size]
|
84 |
+
|
85 |
+
# get the usage with the first group
|
86 |
+
# the first group always have all the keys valid
|
87 |
+
affinity, usage = do_softmax(
|
88 |
+
torch.cat(
|
89 |
+
[
|
90 |
+
long_mem_similarity[:, -self.long_mem.get_v_size(0) :],
|
91 |
+
work_mem_similarity,
|
92 |
+
],
|
93 |
+
1,
|
94 |
+
),
|
95 |
+
top_k=self.top_k,
|
96 |
+
inplace=True,
|
97 |
+
return_usage=True,
|
98 |
+
)
|
99 |
+
affinity = [affinity]
|
100 |
+
|
101 |
+
# compute affinity group by group as later groups only have a subset of keys
|
102 |
+
for gi in range(1, num_groups):
|
103 |
+
if gi < self.long_mem.num_groups:
|
104 |
+
# merge working and lt similarities before softmax
|
105 |
+
affinity_one_group = do_softmax(
|
106 |
+
torch.cat(
|
107 |
+
[
|
108 |
+
long_mem_similarity[:, -self.long_mem.get_v_size(gi) :],
|
109 |
+
work_mem_similarity[:, -self.work_mem.get_v_size(gi) :],
|
110 |
+
],
|
111 |
+
1,
|
112 |
+
),
|
113 |
+
top_k=self.top_k,
|
114 |
+
inplace=True,
|
115 |
+
)
|
116 |
+
else:
|
117 |
+
# no long-term memory for this group
|
118 |
+
affinity_one_group = do_softmax(
|
119 |
+
work_mem_similarity[:, -self.work_mem.get_v_size(gi) :],
|
120 |
+
top_k=self.top_k,
|
121 |
+
inplace=(gi == num_groups - 1),
|
122 |
+
)
|
123 |
+
affinity.append(affinity_one_group)
|
124 |
+
|
125 |
+
all_memory_value = []
|
126 |
+
for gi, gv in enumerate(self.work_mem.value):
|
127 |
+
# merge the working and lt values before readout
|
128 |
+
if gi < self.long_mem.num_groups:
|
129 |
+
all_memory_value.append(
|
130 |
+
torch.cat(
|
131 |
+
[self.long_mem.value[gi], self.work_mem.value[gi]], -1
|
132 |
+
)
|
133 |
+
)
|
134 |
+
else:
|
135 |
+
all_memory_value.append(gv)
|
136 |
+
|
137 |
+
"""
|
138 |
+
Record memory usage for working and long-term memory
|
139 |
+
"""
|
140 |
+
# ignore the index return for long-term memory
|
141 |
+
work_usage = usage[:, long_mem_size:]
|
142 |
+
self.work_mem.update_usage(work_usage.flatten())
|
143 |
+
|
144 |
+
if self.enable_long_term_usage:
|
145 |
+
# ignore the index return for working memory
|
146 |
+
long_usage = usage[:, :long_mem_size]
|
147 |
+
self.long_mem.update_usage(long_usage.flatten())
|
148 |
+
else:
|
149 |
+
# No long-term memory
|
150 |
+
similarity = get_similarity(
|
151 |
+
self.work_mem.key, self.work_mem.shrinkage, query_key, selection
|
152 |
+
)
|
153 |
+
|
154 |
+
if self.enable_long_term:
|
155 |
+
affinity, usage = do_softmax(
|
156 |
+
similarity,
|
157 |
+
inplace=(num_groups == 1),
|
158 |
+
top_k=self.top_k,
|
159 |
+
return_usage=True,
|
160 |
+
)
|
161 |
+
|
162 |
+
# Record memory usage for working memory
|
163 |
+
self.work_mem.update_usage(usage.flatten())
|
164 |
+
else:
|
165 |
+
affinity = do_softmax(
|
166 |
+
similarity,
|
167 |
+
inplace=(num_groups == 1),
|
168 |
+
top_k=self.top_k,
|
169 |
+
return_usage=False,
|
170 |
+
)
|
171 |
+
|
172 |
+
affinity = [affinity]
|
173 |
+
|
174 |
+
# compute affinity group by group as later groups only have a subset of keys
|
175 |
+
for gi in range(1, num_groups):
|
176 |
+
affinity_one_group = do_softmax(
|
177 |
+
similarity[:, -self.work_mem.get_v_size(gi) :],
|
178 |
+
top_k=self.top_k,
|
179 |
+
inplace=(gi == num_groups - 1),
|
180 |
+
)
|
181 |
+
affinity.append(affinity_one_group)
|
182 |
+
|
183 |
+
all_memory_value = self.work_mem.value
|
184 |
+
|
185 |
+
# Shared affinity within each group
|
186 |
+
all_readout_mem = torch.cat(
|
187 |
+
[self._readout(affinity[gi], gv) for gi, gv in enumerate(all_memory_value)],
|
188 |
+
0,
|
189 |
+
)
|
190 |
+
|
191 |
+
return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w)
|
192 |
+
|
193 |
+
def add_memory(self, key, shrinkage, value, objects, selection=None):
|
194 |
+
# key: 1*C*H*W
|
195 |
+
# value: 1*num_objects*C*H*W
|
196 |
+
# objects contain a list of object indices
|
197 |
+
if self.H is None or self.reset_config:
|
198 |
+
self.reset_config = False
|
199 |
+
self.H, self.W = key.shape[-2:]
|
200 |
+
self.HW = self.H * self.W
|
201 |
+
if self.enable_long_term:
|
202 |
+
# convert from num. frames to num. nodes
|
203 |
+
self.min_work_elements = self.min_mt_frames * self.HW
|
204 |
+
self.max_work_elements = self.max_mt_frames * self.HW
|
205 |
+
|
206 |
+
# key: 1*C*N
|
207 |
+
# value: num_objects*C*N
|
208 |
+
key = key.flatten(start_dim=2)
|
209 |
+
shrinkage = shrinkage.flatten(start_dim=2)
|
210 |
+
value = value[0].flatten(start_dim=2)
|
211 |
+
|
212 |
+
self.CK = key.shape[1]
|
213 |
+
self.CV = value.shape[1]
|
214 |
+
|
215 |
+
if selection is not None:
|
216 |
+
if not self.enable_long_term:
|
217 |
+
warnings.warn(
|
218 |
+
"the selection factor is only needed in long-term mode", UserWarning
|
219 |
+
)
|
220 |
+
selection = selection.flatten(start_dim=2)
|
221 |
+
|
222 |
+
self.work_mem.add(key, value, shrinkage, selection, objects)
|
223 |
+
|
224 |
+
# long-term memory cleanup
|
225 |
+
if self.enable_long_term:
|
226 |
+
# Do memory compressed if needed
|
227 |
+
if self.work_mem.size >= self.max_work_elements:
|
228 |
+
# print('remove memory')
|
229 |
+
# Remove obsolete features if needed
|
230 |
+
if self.long_mem.size >= (self.max_long_elements - self.num_prototypes):
|
231 |
+
self.long_mem.remove_obsolete_features(
|
232 |
+
self.max_long_elements - self.num_prototypes
|
233 |
+
)
|
234 |
+
|
235 |
+
self.compress_features()
|
236 |
+
|
237 |
+
def create_hidden_state(self, n, sample_key):
|
238 |
+
# n is the TOTAL number of objects
|
239 |
+
h, w = sample_key.shape[-2:]
|
240 |
+
if self.hidden is None:
|
241 |
+
self.hidden = torch.zeros(
|
242 |
+
(1, n, self.hidden_dim, h, w), device=sample_key.device
|
243 |
+
)
|
244 |
+
elif self.hidden.shape[1] != n:
|
245 |
+
self.hidden = torch.cat(
|
246 |
+
[
|
247 |
+
self.hidden,
|
248 |
+
torch.zeros(
|
249 |
+
(1, n - self.hidden.shape[1], self.hidden_dim, h, w),
|
250 |
+
device=sample_key.device,
|
251 |
+
),
|
252 |
+
],
|
253 |
+
1,
|
254 |
+
)
|
255 |
+
|
256 |
+
assert self.hidden.shape[1] == n
|
257 |
+
|
258 |
+
def set_hidden(self, hidden):
|
259 |
+
self.hidden = hidden
|
260 |
+
|
261 |
+
def get_hidden(self):
|
262 |
+
return self.hidden
|
263 |
+
|
264 |
+
def compress_features(self):
|
265 |
+
HW = self.HW
|
266 |
+
candidate_value = []
|
267 |
+
total_work_mem_size = self.work_mem.size
|
268 |
+
for gv in self.work_mem.value:
|
269 |
+
# Some object groups might be added later in the video
|
270 |
+
# So not all keys have values associated with all objects
|
271 |
+
# We need to keep track of the key->value validity
|
272 |
+
mem_size_in_this_group = gv.shape[-1]
|
273 |
+
if mem_size_in_this_group == total_work_mem_size:
|
274 |
+
# full LT
|
275 |
+
candidate_value.append(gv[:, :, HW : -self.min_work_elements + HW])
|
276 |
+
else:
|
277 |
+
# mem_size is smaller than total_work_mem_size, but at least HW
|
278 |
+
assert HW <= mem_size_in_this_group < total_work_mem_size
|
279 |
+
if mem_size_in_this_group > self.min_work_elements + HW:
|
280 |
+
# part of this object group still goes into LT
|
281 |
+
candidate_value.append(gv[:, :, HW : -self.min_work_elements + HW])
|
282 |
+
else:
|
283 |
+
# this object group cannot go to the LT at all
|
284 |
+
candidate_value.append(None)
|
285 |
+
|
286 |
+
# perform memory consolidation
|
287 |
+
prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
|
288 |
+
*self.work_mem.get_all_sliced(HW, -self.min_work_elements + HW),
|
289 |
+
candidate_value
|
290 |
+
)
|
291 |
+
|
292 |
+
# remove consolidated working memory
|
293 |
+
self.work_mem.sieve_by_range(
|
294 |
+
HW, -self.min_work_elements + HW, min_size=self.min_work_elements + HW
|
295 |
+
)
|
296 |
+
|
297 |
+
# add to long-term memory
|
298 |
+
self.long_mem.add(
|
299 |
+
prototype_key,
|
300 |
+
prototype_value,
|
301 |
+
prototype_shrinkage,
|
302 |
+
selection=None,
|
303 |
+
objects=None,
|
304 |
+
)
|
305 |
+
# print(f'long memory size: {self.long_mem.size}')
|
306 |
+
# print(f'work memory size: {self.work_mem.size}')
|
307 |
+
|
308 |
+
def consolidation(
|
309 |
+
self,
|
310 |
+
candidate_key,
|
311 |
+
candidate_shrinkage,
|
312 |
+
candidate_selection,
|
313 |
+
usage,
|
314 |
+
candidate_value,
|
315 |
+
):
|
316 |
+
# keys: 1*C*N
|
317 |
+
# values: num_objects*C*N
|
318 |
+
N = candidate_key.shape[-1]
|
319 |
+
|
320 |
+
# find the indices with max usage
|
321 |
+
_, max_usage_indices = torch.topk(
|
322 |
+
usage, k=self.num_prototypes, dim=-1, sorted=True
|
323 |
+
)
|
324 |
+
prototype_indices = max_usage_indices.flatten()
|
325 |
+
|
326 |
+
# Prototypes are invalid for out-of-bound groups
|
327 |
+
validity = [
|
328 |
+
prototype_indices >= (N - gv.shape[2]) if gv is not None else None
|
329 |
+
for gv in candidate_value
|
330 |
+
]
|
331 |
+
|
332 |
+
prototype_key = candidate_key[:, :, prototype_indices]
|
333 |
+
prototype_selection = (
|
334 |
+
candidate_selection[:, :, prototype_indices]
|
335 |
+
if candidate_selection is not None
|
336 |
+
else None
|
337 |
+
)
|
338 |
+
|
339 |
+
"""
|
340 |
+
Potentiation step
|
341 |
+
"""
|
342 |
+
similarity = get_similarity(
|
343 |
+
candidate_key, candidate_shrinkage, prototype_key, prototype_selection
|
344 |
+
)
|
345 |
+
|
346 |
+
# convert similarity to affinity
|
347 |
+
# need to do it group by group since the softmax normalization would be different
|
348 |
+
affinity = [
|
349 |
+
do_softmax(similarity[:, -gv.shape[2] :, validity[gi]])
|
350 |
+
if gv is not None
|
351 |
+
else None
|
352 |
+
for gi, gv in enumerate(candidate_value)
|
353 |
+
]
|
354 |
+
|
355 |
+
# some values can be have all False validity. Weed them out.
|
356 |
+
affinity = [
|
357 |
+
aff if aff is None or aff.shape[-1] > 0 else None for aff in affinity
|
358 |
+
]
|
359 |
+
|
360 |
+
# readout the values
|
361 |
+
prototype_value = [
|
362 |
+
self._readout(affinity[gi], gv) if affinity[gi] is not None else None
|
363 |
+
for gi, gv in enumerate(candidate_value)
|
364 |
+
]
|
365 |
+
|
366 |
+
# readout the shrinkage term
|
367 |
+
prototype_shrinkage = (
|
368 |
+
self._readout(affinity[0], candidate_shrinkage)
|
369 |
+
if candidate_shrinkage is not None
|
370 |
+
else None
|
371 |
+
)
|
372 |
+
|
373 |
+
return prototype_key, prototype_value, prototype_shrinkage
|
tracker/model/__init__.py
ADDED
File without changes
|
tracker/model/aggregate.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
|
5 |
+
# Soft aggregation from STM
|
6 |
+
def aggregate(prob, dim, return_logits=False):
|
7 |
+
new_prob = torch.cat(
|
8 |
+
[torch.prod(1 - prob, dim=dim, keepdim=True), prob], dim
|
9 |
+
).clamp(1e-7, 1 - 1e-7)
|
10 |
+
logits = torch.log((new_prob / (1 - new_prob)))
|
11 |
+
prob = F.softmax(logits, dim=dim)
|
12 |
+
|
13 |
+
if return_logits:
|
14 |
+
return logits, prob
|
15 |
+
else:
|
16 |
+
return prob
|
tracker/model/cbam.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class BasicConv(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
in_planes,
|
12 |
+
out_planes,
|
13 |
+
kernel_size,
|
14 |
+
stride=1,
|
15 |
+
padding=0,
|
16 |
+
dilation=1,
|
17 |
+
groups=1,
|
18 |
+
bias=True,
|
19 |
+
):
|
20 |
+
super(BasicConv, self).__init__()
|
21 |
+
self.out_channels = out_planes
|
22 |
+
self.conv = nn.Conv2d(
|
23 |
+
in_planes,
|
24 |
+
out_planes,
|
25 |
+
kernel_size=kernel_size,
|
26 |
+
stride=stride,
|
27 |
+
padding=padding,
|
28 |
+
dilation=dilation,
|
29 |
+
groups=groups,
|
30 |
+
bias=bias,
|
31 |
+
)
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
x = self.conv(x)
|
35 |
+
return x
|
36 |
+
|
37 |
+
|
38 |
+
class Flatten(nn.Module):
|
39 |
+
def forward(self, x):
|
40 |
+
return x.view(x.size(0), -1)
|
41 |
+
|
42 |
+
|
43 |
+
class ChannelGate(nn.Module):
|
44 |
+
def __init__(self, gate_channels, reduction_ratio=16, pool_types=["avg", "max"]):
|
45 |
+
super(ChannelGate, self).__init__()
|
46 |
+
self.gate_channels = gate_channels
|
47 |
+
self.mlp = nn.Sequential(
|
48 |
+
Flatten(),
|
49 |
+
nn.Linear(gate_channels, gate_channels // reduction_ratio),
|
50 |
+
nn.ReLU(),
|
51 |
+
nn.Linear(gate_channels // reduction_ratio, gate_channels),
|
52 |
+
)
|
53 |
+
self.pool_types = pool_types
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
channel_att_sum = None
|
57 |
+
for pool_type in self.pool_types:
|
58 |
+
if pool_type == "avg":
|
59 |
+
avg_pool = F.avg_pool2d(
|
60 |
+
x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
|
61 |
+
)
|
62 |
+
channel_att_raw = self.mlp(avg_pool)
|
63 |
+
elif pool_type == "max":
|
64 |
+
max_pool = F.max_pool2d(
|
65 |
+
x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
|
66 |
+
)
|
67 |
+
channel_att_raw = self.mlp(max_pool)
|
68 |
+
|
69 |
+
if channel_att_sum is None:
|
70 |
+
channel_att_sum = channel_att_raw
|
71 |
+
else:
|
72 |
+
channel_att_sum = channel_att_sum + channel_att_raw
|
73 |
+
|
74 |
+
scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
|
75 |
+
return x * scale
|
76 |
+
|
77 |
+
|
78 |
+
class ChannelPool(nn.Module):
|
79 |
+
def forward(self, x):
|
80 |
+
return torch.cat(
|
81 |
+
(torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
|
82 |
+
)
|
83 |
+
|
84 |
+
|
85 |
+
class SpatialGate(nn.Module):
|
86 |
+
def __init__(self):
|
87 |
+
super(SpatialGate, self).__init__()
|
88 |
+
kernel_size = 7
|
89 |
+
self.compress = ChannelPool()
|
90 |
+
self.spatial = BasicConv(
|
91 |
+
2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2
|
92 |
+
)
|
93 |
+
|
94 |
+
def forward(self, x):
|
95 |
+
x_compress = self.compress(x)
|
96 |
+
x_out = self.spatial(x_compress)
|
97 |
+
scale = torch.sigmoid(x_out) # broadcasting
|
98 |
+
return x * scale
|
99 |
+
|
100 |
+
|
101 |
+
class CBAM(nn.Module):
|
102 |
+
def __init__(
|
103 |
+
self,
|
104 |
+
gate_channels,
|
105 |
+
reduction_ratio=16,
|
106 |
+
pool_types=["avg", "max"],
|
107 |
+
no_spatial=False,
|
108 |
+
):
|
109 |
+
super(CBAM, self).__init__()
|
110 |
+
self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
|
111 |
+
self.no_spatial = no_spatial
|
112 |
+
if not no_spatial:
|
113 |
+
self.SpatialGate = SpatialGate()
|
114 |
+
|
115 |
+
def forward(self, x):
|
116 |
+
x_out = self.ChannelGate(x)
|
117 |
+
if not self.no_spatial:
|
118 |
+
x_out = self.SpatialGate(x_out)
|
119 |
+
return x_out
|
tracker/model/group_modules.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Group-specific modules
|
3 |
+
They handle features that also depends on the mask.
|
4 |
+
Features are typically of shape
|
5 |
+
batch_size * num_objects * num_channels * H * W
|
6 |
+
|
7 |
+
All of them are permutation equivariant w.r.t. to the num_objects dimension
|
8 |
+
"""
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
|
15 |
+
def interpolate_groups(g, ratio, mode, align_corners):
|
16 |
+
batch_size, num_objects = g.shape[:2]
|
17 |
+
g = F.interpolate(
|
18 |
+
g.flatten(start_dim=0, end_dim=1),
|
19 |
+
scale_factor=ratio,
|
20 |
+
mode=mode,
|
21 |
+
align_corners=align_corners,
|
22 |
+
)
|
23 |
+
g = g.view(batch_size, num_objects, *g.shape[1:])
|
24 |
+
return g
|
25 |
+
|
26 |
+
|
27 |
+
def upsample_groups(g, ratio=2, mode="bilinear", align_corners=False):
|
28 |
+
return interpolate_groups(g, ratio, mode, align_corners)
|
29 |
+
|
30 |
+
|
31 |
+
def downsample_groups(g, ratio=1 / 2, mode="area", align_corners=None):
|
32 |
+
return interpolate_groups(g, ratio, mode, align_corners)
|
33 |
+
|
34 |
+
|
35 |
+
class GConv2D(nn.Conv2d):
|
36 |
+
def forward(self, g):
|
37 |
+
batch_size, num_objects = g.shape[:2]
|
38 |
+
g = super().forward(g.flatten(start_dim=0, end_dim=1))
|
39 |
+
return g.view(batch_size, num_objects, *g.shape[1:])
|
40 |
+
|
41 |
+
|
42 |
+
class GroupResBlock(nn.Module):
|
43 |
+
def __init__(self, in_dim, out_dim):
|
44 |
+
super().__init__()
|
45 |
+
|
46 |
+
if in_dim == out_dim:
|
47 |
+
self.downsample = None
|
48 |
+
else:
|
49 |
+
self.downsample = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
|
50 |
+
|
51 |
+
self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
|
52 |
+
self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1)
|
53 |
+
|
54 |
+
def forward(self, g):
|
55 |
+
out_g = self.conv1(F.relu(g))
|
56 |
+
out_g = self.conv2(F.relu(out_g))
|
57 |
+
|
58 |
+
if self.downsample is not None:
|
59 |
+
g = self.downsample(g)
|
60 |
+
|
61 |
+
return out_g + g
|
62 |
+
|
63 |
+
|
64 |
+
class MainToGroupDistributor(nn.Module):
|
65 |
+
def __init__(self, x_transform=None, method="cat", reverse_order=False):
|
66 |
+
super().__init__()
|
67 |
+
|
68 |
+
self.x_transform = x_transform
|
69 |
+
self.method = method
|
70 |
+
self.reverse_order = reverse_order
|
71 |
+
|
72 |
+
def forward(self, x, g):
|
73 |
+
num_objects = g.shape[1]
|
74 |
+
|
75 |
+
if self.x_transform is not None:
|
76 |
+
x = self.x_transform(x)
|
77 |
+
|
78 |
+
if self.method == "cat":
|
79 |
+
if self.reverse_order:
|
80 |
+
g = torch.cat(
|
81 |
+
[g, x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)], 2
|
82 |
+
)
|
83 |
+
else:
|
84 |
+
g = torch.cat(
|
85 |
+
[x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1), g], 2
|
86 |
+
)
|
87 |
+
elif self.method == "add":
|
88 |
+
g = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) + g
|
89 |
+
else:
|
90 |
+
raise NotImplementedError
|
91 |
+
|
92 |
+
return g
|
tracker/model/losses.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
|
8 |
+
def dice_loss(input_mask, cls_gt):
|
9 |
+
num_objects = input_mask.shape[1]
|
10 |
+
losses = []
|
11 |
+
for i in range(num_objects):
|
12 |
+
mask = input_mask[:, i].flatten(start_dim=1)
|
13 |
+
# background not in mask, so we add one to cls_gt
|
14 |
+
gt = (cls_gt == (i + 1)).float().flatten(start_dim=1)
|
15 |
+
numerator = 2 * (mask * gt).sum(-1)
|
16 |
+
denominator = mask.sum(-1) + gt.sum(-1)
|
17 |
+
loss = 1 - (numerator + 1) / (denominator + 1)
|
18 |
+
losses.append(loss)
|
19 |
+
return torch.cat(losses).mean()
|
20 |
+
|
21 |
+
|
22 |
+
# https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch
|
23 |
+
class BootstrappedCE(nn.Module):
|
24 |
+
def __init__(self, start_warm, end_warm, top_p=0.15):
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
self.start_warm = start_warm
|
28 |
+
self.end_warm = end_warm
|
29 |
+
self.top_p = top_p
|
30 |
+
|
31 |
+
def forward(self, input, target, it):
|
32 |
+
if it < self.start_warm:
|
33 |
+
return F.cross_entropy(input, target), 1.0
|
34 |
+
|
35 |
+
raw_loss = F.cross_entropy(input, target, reduction="none").view(-1)
|
36 |
+
num_pixels = raw_loss.numel()
|
37 |
+
|
38 |
+
if it > self.end_warm:
|
39 |
+
this_p = self.top_p
|
40 |
+
else:
|
41 |
+
this_p = self.top_p + (1 - self.top_p) * (
|
42 |
+
(self.end_warm - it) / (self.end_warm - self.start_warm)
|
43 |
+
)
|
44 |
+
loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
|
45 |
+
return loss.mean(), this_p
|
46 |
+
|
47 |
+
|
48 |
+
class LossComputer:
|
49 |
+
def __init__(self, config):
|
50 |
+
super().__init__()
|
51 |
+
self.config = config
|
52 |
+
self.bce = BootstrappedCE(config["start_warm"], config["end_warm"])
|
53 |
+
|
54 |
+
def compute(self, data, num_objects, it):
|
55 |
+
losses = defaultdict(int)
|
56 |
+
|
57 |
+
b, t = data["rgb"].shape[:2]
|
58 |
+
|
59 |
+
losses["total_loss"] = 0
|
60 |
+
for ti in range(1, t):
|
61 |
+
for bi in range(b):
|
62 |
+
loss, p = self.bce(
|
63 |
+
data[f"logits_{ti}"][bi : bi + 1, : num_objects[bi] + 1],
|
64 |
+
data["cls_gt"][bi : bi + 1, ti, 0],
|
65 |
+
it,
|
66 |
+
)
|
67 |
+
losses["p"] += p / b / (t - 1)
|
68 |
+
losses[f"ce_loss_{ti}"] += loss / b
|
69 |
+
|
70 |
+
losses["total_loss"] += losses["ce_loss_%d" % ti]
|
71 |
+
losses[f"dice_loss_{ti}"] = dice_loss(
|
72 |
+
data[f"masks_{ti}"], data["cls_gt"][:, ti, 0]
|
73 |
+
)
|
74 |
+
losses["total_loss"] += losses[f"dice_loss_{ti}"]
|
75 |
+
|
76 |
+
return losses
|
tracker/model/memory_util.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
|
7 |
+
def get_similarity(mk, ms, qk, qe):
|
8 |
+
# used for training/inference and memory reading/memory potentiation
|
9 |
+
# mk: B x CK x [N] - Memory keys
|
10 |
+
# ms: B x 1 x [N] - Memory shrinkage
|
11 |
+
# qk: B x CK x [HW/P] - Query keys
|
12 |
+
# qe: B x CK x [HW/P] - Query selection
|
13 |
+
# Dimensions in [] are flattened
|
14 |
+
CK = mk.shape[1]
|
15 |
+
mk = mk.flatten(start_dim=2)
|
16 |
+
ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
|
17 |
+
qk = qk.flatten(start_dim=2)
|
18 |
+
qe = qe.flatten(start_dim=2) if qe is not None else None
|
19 |
+
|
20 |
+
if qe is not None:
|
21 |
+
# See appendix for derivation
|
22 |
+
# or you can just trust me ヽ(ー_ー )ノ
|
23 |
+
mk = mk.transpose(1, 2)
|
24 |
+
a_sq = mk.pow(2) @ qe
|
25 |
+
two_ab = 2 * (mk @ (qk * qe))
|
26 |
+
b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
|
27 |
+
similarity = -a_sq + two_ab - b_sq
|
28 |
+
else:
|
29 |
+
# similar to STCN if we don't have the selection term
|
30 |
+
a_sq = mk.pow(2).sum(1).unsqueeze(2)
|
31 |
+
two_ab = 2 * (mk.transpose(1, 2) @ qk)
|
32 |
+
similarity = -a_sq + two_ab
|
33 |
+
|
34 |
+
if ms is not None:
|
35 |
+
similarity = similarity * ms / math.sqrt(CK) # B*N*HW
|
36 |
+
else:
|
37 |
+
similarity = similarity / math.sqrt(CK) # B*N*HW
|
38 |
+
|
39 |
+
return similarity
|
40 |
+
|
41 |
+
|
42 |
+
def do_softmax(
|
43 |
+
similarity, top_k: Optional[int] = None, inplace=False, return_usage=False
|
44 |
+
):
|
45 |
+
# normalize similarity with top-k softmax
|
46 |
+
# similarity: B x N x [HW/P]
|
47 |
+
# use inplace with care
|
48 |
+
if top_k is not None:
|
49 |
+
values, indices = torch.topk(similarity, k=top_k, dim=1)
|
50 |
+
|
51 |
+
x_exp = values.exp_()
|
52 |
+
x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
|
53 |
+
if inplace:
|
54 |
+
similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW
|
55 |
+
affinity = similarity
|
56 |
+
else:
|
57 |
+
affinity = torch.zeros_like(similarity).scatter_(
|
58 |
+
1, indices, x_exp
|
59 |
+
) # B*N*HW
|
60 |
+
else:
|
61 |
+
maxes = torch.max(similarity, dim=1, keepdim=True)[0]
|
62 |
+
x_exp = torch.exp(similarity - maxes)
|
63 |
+
x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
|
64 |
+
affinity = x_exp / x_exp_sum
|
65 |
+
indices = None
|
66 |
+
|
67 |
+
if return_usage:
|
68 |
+
return affinity, affinity.sum(dim=2)
|
69 |
+
|
70 |
+
return affinity
|
71 |
+
|
72 |
+
|
73 |
+
def get_affinity(mk, ms, qk, qe):
|
74 |
+
# shorthand used in training with no top-k
|
75 |
+
similarity = get_similarity(mk, ms, qk, qe)
|
76 |
+
affinity = do_softmax(similarity)
|
77 |
+
return affinity
|
78 |
+
|
79 |
+
|
80 |
+
def readout(affinity, mv):
|
81 |
+
B, CV, T, H, W = mv.shape
|
82 |
+
|
83 |
+
mo = mv.view(B, CV, T * H * W)
|
84 |
+
mem = torch.bmm(mo, affinity)
|
85 |
+
mem = mem.view(B, CV, H, W)
|
86 |
+
|
87 |
+
return mem
|
tracker/model/modules.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
modules.py - This file stores the rather boring network blocks.
|
3 |
+
|
4 |
+
x - usually means features that only depends on the image
|
5 |
+
g - usually means features that also depends on the mask.
|
6 |
+
They might have an extra "group" or "num_objects" dimension, hence
|
7 |
+
batch_size * num_objects * num_channels * H * W
|
8 |
+
|
9 |
+
The trailing number of a variable usually denote the stride
|
10 |
+
|
11 |
+
"""
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
from model.group_modules import *
|
18 |
+
from model import resnet
|
19 |
+
from model.cbam import CBAM
|
20 |
+
|
21 |
+
|
22 |
+
class FeatureFusionBlock(nn.Module):
|
23 |
+
def __init__(self, x_in_dim, g_in_dim, g_mid_dim, g_out_dim):
|
24 |
+
super().__init__()
|
25 |
+
|
26 |
+
self.distributor = MainToGroupDistributor()
|
27 |
+
self.block1 = GroupResBlock(x_in_dim + g_in_dim, g_mid_dim)
|
28 |
+
self.attention = CBAM(g_mid_dim)
|
29 |
+
self.block2 = GroupResBlock(g_mid_dim, g_out_dim)
|
30 |
+
|
31 |
+
def forward(self, x, g):
|
32 |
+
batch_size, num_objects = g.shape[:2]
|
33 |
+
|
34 |
+
g = self.distributor(x, g)
|
35 |
+
g = self.block1(g)
|
36 |
+
r = self.attention(g.flatten(start_dim=0, end_dim=1))
|
37 |
+
r = r.view(batch_size, num_objects, *r.shape[1:])
|
38 |
+
|
39 |
+
g = self.block2(g + r)
|
40 |
+
|
41 |
+
return g
|
42 |
+
|
43 |
+
|
44 |
+
class HiddenUpdater(nn.Module):
|
45 |
+
# Used in the decoder, multi-scale feature + GRU
|
46 |
+
def __init__(self, g_dims, mid_dim, hidden_dim):
|
47 |
+
super().__init__()
|
48 |
+
self.hidden_dim = hidden_dim
|
49 |
+
|
50 |
+
self.g16_conv = GConv2D(g_dims[0], mid_dim, kernel_size=1)
|
51 |
+
self.g8_conv = GConv2D(g_dims[1], mid_dim, kernel_size=1)
|
52 |
+
self.g4_conv = GConv2D(g_dims[2], mid_dim, kernel_size=1)
|
53 |
+
|
54 |
+
self.transform = GConv2D(
|
55 |
+
mid_dim + hidden_dim, hidden_dim * 3, kernel_size=3, padding=1
|
56 |
+
)
|
57 |
+
|
58 |
+
nn.init.xavier_normal_(self.transform.weight)
|
59 |
+
|
60 |
+
def forward(self, g, h):
|
61 |
+
g = (
|
62 |
+
self.g16_conv(g[0])
|
63 |
+
+ self.g8_conv(downsample_groups(g[1], ratio=1 / 2))
|
64 |
+
+ self.g4_conv(downsample_groups(g[2], ratio=1 / 4))
|
65 |
+
)
|
66 |
+
|
67 |
+
g = torch.cat([g, h], 2)
|
68 |
+
|
69 |
+
# defined slightly differently than standard GRU,
|
70 |
+
# namely the new value is generated before the forget gate.
|
71 |
+
# might provide better gradient but frankly it was initially just an
|
72 |
+
# implementation error that I never bothered fixing
|
73 |
+
values = self.transform(g)
|
74 |
+
forget_gate = torch.sigmoid(values[:, :, : self.hidden_dim])
|
75 |
+
update_gate = torch.sigmoid(values[:, :, self.hidden_dim : self.hidden_dim * 2])
|
76 |
+
new_value = torch.tanh(values[:, :, self.hidden_dim * 2 :])
|
77 |
+
new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
|
78 |
+
|
79 |
+
return new_h
|
80 |
+
|
81 |
+
|
82 |
+
class HiddenReinforcer(nn.Module):
|
83 |
+
# Used in the value encoder, a single GRU
|
84 |
+
def __init__(self, g_dim, hidden_dim):
|
85 |
+
super().__init__()
|
86 |
+
self.hidden_dim = hidden_dim
|
87 |
+
self.transform = GConv2D(
|
88 |
+
g_dim + hidden_dim, hidden_dim * 3, kernel_size=3, padding=1
|
89 |
+
)
|
90 |
+
|
91 |
+
nn.init.xavier_normal_(self.transform.weight)
|
92 |
+
|
93 |
+
def forward(self, g, h):
|
94 |
+
g = torch.cat([g, h], 2)
|
95 |
+
|
96 |
+
# defined slightly differently than standard GRU,
|
97 |
+
# namely the new value is generated before the forget gate.
|
98 |
+
# might provide better gradient but frankly it was initially just an
|
99 |
+
# implementation error that I never bothered fixing
|
100 |
+
values = self.transform(g)
|
101 |
+
forget_gate = torch.sigmoid(values[:, :, : self.hidden_dim])
|
102 |
+
update_gate = torch.sigmoid(values[:, :, self.hidden_dim : self.hidden_dim * 2])
|
103 |
+
new_value = torch.tanh(values[:, :, self.hidden_dim * 2 :])
|
104 |
+
new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
|
105 |
+
|
106 |
+
return new_h
|
107 |
+
|
108 |
+
|
109 |
+
class ValueEncoder(nn.Module):
|
110 |
+
def __init__(self, value_dim, hidden_dim, single_object=False):
|
111 |
+
super().__init__()
|
112 |
+
|
113 |
+
self.single_object = single_object
|
114 |
+
network = resnet.resnet18(pretrained=True, extra_dim=1 if single_object else 2)
|
115 |
+
self.conv1 = network.conv1
|
116 |
+
self.bn1 = network.bn1
|
117 |
+
self.relu = network.relu # 1/2, 64
|
118 |
+
self.maxpool = network.maxpool
|
119 |
+
|
120 |
+
self.layer1 = network.layer1 # 1/4, 64
|
121 |
+
self.layer2 = network.layer2 # 1/8, 128
|
122 |
+
self.layer3 = network.layer3 # 1/16, 256
|
123 |
+
|
124 |
+
self.distributor = MainToGroupDistributor()
|
125 |
+
self.fuser = FeatureFusionBlock(1024, 256, value_dim, value_dim)
|
126 |
+
if hidden_dim > 0:
|
127 |
+
self.hidden_reinforce = HiddenReinforcer(value_dim, hidden_dim)
|
128 |
+
else:
|
129 |
+
self.hidden_reinforce = None
|
130 |
+
|
131 |
+
def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True):
|
132 |
+
# image_feat_f16 is the feature from the key encoder
|
133 |
+
if not self.single_object:
|
134 |
+
g = torch.stack([masks, others], 2)
|
135 |
+
else:
|
136 |
+
g = masks.unsqueeze(2)
|
137 |
+
g = self.distributor(image, g)
|
138 |
+
|
139 |
+
batch_size, num_objects = g.shape[:2]
|
140 |
+
g = g.flatten(start_dim=0, end_dim=1)
|
141 |
+
|
142 |
+
g = self.conv1(g)
|
143 |
+
g = self.bn1(g) # 1/2, 64
|
144 |
+
g = self.maxpool(g) # 1/4, 64
|
145 |
+
g = self.relu(g)
|
146 |
+
|
147 |
+
g = self.layer1(g) # 1/4
|
148 |
+
g = self.layer2(g) # 1/8
|
149 |
+
g = self.layer3(g) # 1/16
|
150 |
+
|
151 |
+
g = g.view(batch_size, num_objects, *g.shape[1:])
|
152 |
+
g = self.fuser(image_feat_f16, g)
|
153 |
+
|
154 |
+
if is_deep_update and self.hidden_reinforce is not None:
|
155 |
+
h = self.hidden_reinforce(g, h)
|
156 |
+
|
157 |
+
return g, h
|
158 |
+
|
159 |
+
|
160 |
+
class KeyEncoder(nn.Module):
|
161 |
+
def __init__(self):
|
162 |
+
super().__init__()
|
163 |
+
network = resnet.resnet50(pretrained=True)
|
164 |
+
self.conv1 = network.conv1
|
165 |
+
self.bn1 = network.bn1
|
166 |
+
self.relu = network.relu # 1/2, 64
|
167 |
+
self.maxpool = network.maxpool
|
168 |
+
|
169 |
+
self.res2 = network.layer1 # 1/4, 256
|
170 |
+
self.layer2 = network.layer2 # 1/8, 512
|
171 |
+
self.layer3 = network.layer3 # 1/16, 1024
|
172 |
+
|
173 |
+
def forward(self, f):
|
174 |
+
x = self.conv1(f)
|
175 |
+
x = self.bn1(x)
|
176 |
+
x = self.relu(x) # 1/2, 64
|
177 |
+
x = self.maxpool(x) # 1/4, 64
|
178 |
+
f4 = self.res2(x) # 1/4, 256
|
179 |
+
f8 = self.layer2(f4) # 1/8, 512
|
180 |
+
f16 = self.layer3(f8) # 1/16, 1024
|
181 |
+
|
182 |
+
return f16, f8, f4
|
183 |
+
|
184 |
+
|
185 |
+
class UpsampleBlock(nn.Module):
|
186 |
+
def __init__(self, skip_dim, g_up_dim, g_out_dim, scale_factor=2):
|
187 |
+
super().__init__()
|
188 |
+
self.skip_conv = nn.Conv2d(skip_dim, g_up_dim, kernel_size=3, padding=1)
|
189 |
+
self.distributor = MainToGroupDistributor(method="add")
|
190 |
+
self.out_conv = GroupResBlock(g_up_dim, g_out_dim)
|
191 |
+
self.scale_factor = scale_factor
|
192 |
+
|
193 |
+
def forward(self, skip_f, up_g):
|
194 |
+
skip_f = self.skip_conv(skip_f)
|
195 |
+
g = upsample_groups(up_g, ratio=self.scale_factor)
|
196 |
+
g = self.distributor(skip_f, g)
|
197 |
+
g = self.out_conv(g)
|
198 |
+
return g
|
199 |
+
|
200 |
+
|
201 |
+
class KeyProjection(nn.Module):
|
202 |
+
def __init__(self, in_dim, keydim):
|
203 |
+
super().__init__()
|
204 |
+
|
205 |
+
self.key_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
|
206 |
+
# shrinkage
|
207 |
+
self.d_proj = nn.Conv2d(in_dim, 1, kernel_size=3, padding=1)
|
208 |
+
# selection
|
209 |
+
self.e_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
|
210 |
+
|
211 |
+
nn.init.orthogonal_(self.key_proj.weight.data)
|
212 |
+
nn.init.zeros_(self.key_proj.bias.data)
|
213 |
+
|
214 |
+
def forward(self, x, need_s, need_e):
|
215 |
+
shrinkage = self.d_proj(x) ** 2 + 1 if (need_s) else None
|
216 |
+
selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None
|
217 |
+
|
218 |
+
return self.key_proj(x), shrinkage, selection
|
219 |
+
|
220 |
+
|
221 |
+
class Decoder(nn.Module):
|
222 |
+
def __init__(self, val_dim, hidden_dim):
|
223 |
+
super().__init__()
|
224 |
+
|
225 |
+
self.fuser = FeatureFusionBlock(1024, val_dim + hidden_dim, 512, 512)
|
226 |
+
if hidden_dim > 0:
|
227 |
+
self.hidden_update = HiddenUpdater([512, 256, 256 + 1], 256, hidden_dim)
|
228 |
+
else:
|
229 |
+
self.hidden_update = None
|
230 |
+
|
231 |
+
self.up_16_8 = UpsampleBlock(512, 512, 256) # 1/16 -> 1/8
|
232 |
+
self.up_8_4 = UpsampleBlock(256, 256, 256) # 1/8 -> 1/4
|
233 |
+
|
234 |
+
self.pred = nn.Conv2d(256, 1, kernel_size=3, padding=1, stride=1)
|
235 |
+
|
236 |
+
def forward(self, f16, f8, f4, hidden_state, memory_readout, h_out=True):
|
237 |
+
batch_size, num_objects = memory_readout.shape[:2]
|
238 |
+
|
239 |
+
if self.hidden_update is not None:
|
240 |
+
g16 = self.fuser(f16, torch.cat([memory_readout, hidden_state], 2))
|
241 |
+
else:
|
242 |
+
g16 = self.fuser(f16, memory_readout)
|
243 |
+
|
244 |
+
g8 = self.up_16_8(f8, g16)
|
245 |
+
g4 = self.up_8_4(f4, g8)
|
246 |
+
logits = self.pred(F.relu(g4.flatten(start_dim=0, end_dim=1)))
|
247 |
+
|
248 |
+
if h_out and self.hidden_update is not None:
|
249 |
+
g4 = torch.cat(
|
250 |
+
[g4, logits.view(batch_size, num_objects, 1, *logits.shape[-2:])], 2
|
251 |
+
)
|
252 |
+
hidden_state = self.hidden_update([g16, g8, g4], hidden_state)
|
253 |
+
else:
|
254 |
+
hidden_state = None
|
255 |
+
|
256 |
+
logits = F.interpolate(
|
257 |
+
logits, scale_factor=4, mode="bilinear", align_corners=False
|
258 |
+
)
|
259 |
+
logits = logits.view(batch_size, num_objects, *logits.shape[-2:])
|
260 |
+
|
261 |
+
return hidden_state, logits
|
tracker/model/network.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file defines XMem, the highest level nn.Module interface
|
3 |
+
During training, it is used by trainer.py
|
4 |
+
During evaluation, it is used by inference_core.py
|
5 |
+
|
6 |
+
It further depends on modules.py which gives more detailed implementations of sub-modules
|
7 |
+
"""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from model.aggregate import aggregate
|
13 |
+
from model.modules import *
|
14 |
+
from model.memory_util import *
|
15 |
+
|
16 |
+
|
17 |
+
class XMem(nn.Module):
|
18 |
+
def __init__(self, config, model_path=None, map_location=None):
|
19 |
+
"""
|
20 |
+
model_path/map_location are used in evaluation only
|
21 |
+
map_location is for converting models saved in cuda to cpu
|
22 |
+
"""
|
23 |
+
super().__init__()
|
24 |
+
model_weights = self.init_hyperparameters(config, model_path, map_location)
|
25 |
+
|
26 |
+
self.single_object = config.get("single_object", False)
|
27 |
+
print(f"Single object mode: {self.single_object}")
|
28 |
+
|
29 |
+
self.key_encoder = KeyEncoder()
|
30 |
+
self.value_encoder = ValueEncoder(
|
31 |
+
self.value_dim, self.hidden_dim, self.single_object
|
32 |
+
)
|
33 |
+
|
34 |
+
# Projection from f16 feature space to key/value space
|
35 |
+
self.key_proj = KeyProjection(1024, self.key_dim)
|
36 |
+
|
37 |
+
self.decoder = Decoder(self.value_dim, self.hidden_dim)
|
38 |
+
|
39 |
+
if model_weights is not None:
|
40 |
+
self.load_weights(model_weights, init_as_zero_if_needed=True)
|
41 |
+
|
42 |
+
def encode_key(self, frame, need_sk=True, need_ek=True):
|
43 |
+
# Determine input shape
|
44 |
+
if len(frame.shape) == 5:
|
45 |
+
# shape is b*t*c*h*w
|
46 |
+
need_reshape = True
|
47 |
+
b, t = frame.shape[:2]
|
48 |
+
# flatten so that we can feed them into a 2D CNN
|
49 |
+
frame = frame.flatten(start_dim=0, end_dim=1)
|
50 |
+
elif len(frame.shape) == 4:
|
51 |
+
# shape is b*c*h*w
|
52 |
+
need_reshape = False
|
53 |
+
else:
|
54 |
+
raise NotImplementedError
|
55 |
+
|
56 |
+
f16, f8, f4 = self.key_encoder(frame)
|
57 |
+
key, shrinkage, selection = self.key_proj(f16, need_sk, need_ek)
|
58 |
+
|
59 |
+
if need_reshape:
|
60 |
+
# B*C*T*H*W
|
61 |
+
key = key.view(b, t, *key.shape[-3:]).transpose(1, 2).contiguous()
|
62 |
+
if shrinkage is not None:
|
63 |
+
shrinkage = (
|
64 |
+
shrinkage.view(b, t, *shrinkage.shape[-3:])
|
65 |
+
.transpose(1, 2)
|
66 |
+
.contiguous()
|
67 |
+
)
|
68 |
+
if selection is not None:
|
69 |
+
selection = (
|
70 |
+
selection.view(b, t, *selection.shape[-3:])
|
71 |
+
.transpose(1, 2)
|
72 |
+
.contiguous()
|
73 |
+
)
|
74 |
+
|
75 |
+
# B*T*C*H*W
|
76 |
+
f16 = f16.view(b, t, *f16.shape[-3:])
|
77 |
+
f8 = f8.view(b, t, *f8.shape[-3:])
|
78 |
+
f4 = f4.view(b, t, *f4.shape[-3:])
|
79 |
+
|
80 |
+
return key, shrinkage, selection, f16, f8, f4
|
81 |
+
|
82 |
+
def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_update=True):
|
83 |
+
num_objects = masks.shape[1]
|
84 |
+
if num_objects != 1:
|
85 |
+
others = torch.cat(
|
86 |
+
[
|
87 |
+
torch.sum(
|
88 |
+
masks[:, [j for j in range(num_objects) if i != j]],
|
89 |
+
dim=1,
|
90 |
+
keepdim=True,
|
91 |
+
)
|
92 |
+
for i in range(num_objects)
|
93 |
+
],
|
94 |
+
1,
|
95 |
+
)
|
96 |
+
else:
|
97 |
+
others = torch.zeros_like(masks)
|
98 |
+
|
99 |
+
g16, h16 = self.value_encoder(
|
100 |
+
frame, image_feat_f16, h16, masks, others, is_deep_update
|
101 |
+
)
|
102 |
+
|
103 |
+
return g16, h16
|
104 |
+
|
105 |
+
# Used in training only.
|
106 |
+
# This step is replaced by MemoryManager in test time
|
107 |
+
def read_memory(
|
108 |
+
self, query_key, query_selection, memory_key, memory_shrinkage, memory_value
|
109 |
+
):
|
110 |
+
"""
|
111 |
+
query_key : B * CK * H * W
|
112 |
+
query_selection : B * CK * H * W
|
113 |
+
memory_key : B * CK * T * H * W
|
114 |
+
memory_shrinkage: B * 1 * T * H * W
|
115 |
+
memory_value : B * num_objects * CV * T * H * W
|
116 |
+
"""
|
117 |
+
batch_size, num_objects = memory_value.shape[:2]
|
118 |
+
memory_value = memory_value.flatten(start_dim=1, end_dim=2)
|
119 |
+
|
120 |
+
affinity = get_affinity(
|
121 |
+
memory_key, memory_shrinkage, query_key, query_selection
|
122 |
+
)
|
123 |
+
memory = readout(affinity, memory_value)
|
124 |
+
memory = memory.view(
|
125 |
+
batch_size, num_objects, self.value_dim, *memory.shape[-2:]
|
126 |
+
)
|
127 |
+
|
128 |
+
return memory
|
129 |
+
|
130 |
+
def segment(
|
131 |
+
self,
|
132 |
+
multi_scale_features,
|
133 |
+
memory_readout,
|
134 |
+
hidden_state,
|
135 |
+
selector=None,
|
136 |
+
h_out=True,
|
137 |
+
strip_bg=True,
|
138 |
+
):
|
139 |
+
|
140 |
+
hidden_state, logits = self.decoder(
|
141 |
+
*multi_scale_features, hidden_state, memory_readout, h_out=h_out
|
142 |
+
)
|
143 |
+
prob = torch.sigmoid(logits)
|
144 |
+
if selector is not None:
|
145 |
+
prob = prob * selector
|
146 |
+
|
147 |
+
logits, prob = aggregate(prob, dim=1, return_logits=True)
|
148 |
+
if strip_bg:
|
149 |
+
# Strip away the background
|
150 |
+
prob = prob[:, 1:]
|
151 |
+
|
152 |
+
return hidden_state, logits, prob
|
153 |
+
|
154 |
+
def forward(self, mode, *args, **kwargs):
|
155 |
+
if mode == "encode_key":
|
156 |
+
return self.encode_key(*args, **kwargs)
|
157 |
+
elif mode == "encode_value":
|
158 |
+
return self.encode_value(*args, **kwargs)
|
159 |
+
elif mode == "read_memory":
|
160 |
+
return self.read_memory(*args, **kwargs)
|
161 |
+
elif mode == "segment":
|
162 |
+
return self.segment(*args, **kwargs)
|
163 |
+
else:
|
164 |
+
raise NotImplementedError
|
165 |
+
|
166 |
+
def init_hyperparameters(self, config, model_path=None, map_location=None):
|
167 |
+
"""
|
168 |
+
Init three hyperparameters: key_dim, value_dim, and hidden_dim
|
169 |
+
If model_path is provided, we load these from the model weights
|
170 |
+
The actual parameters are then updated to the config in-place
|
171 |
+
|
172 |
+
Otherwise we load it either from the config or default
|
173 |
+
"""
|
174 |
+
if model_path is not None:
|
175 |
+
# load the model and key/value/hidden dimensions with some hacks
|
176 |
+
# config is updated with the loaded parameters
|
177 |
+
model_weights = torch.load(model_path, map_location=map_location)
|
178 |
+
self.key_dim = model_weights["key_proj.key_proj.weight"].shape[0]
|
179 |
+
self.value_dim = model_weights[
|
180 |
+
"value_encoder.fuser.block2.conv2.weight"
|
181 |
+
].shape[0]
|
182 |
+
self.disable_hidden = (
|
183 |
+
"decoder.hidden_update.transform.weight" not in model_weights
|
184 |
+
)
|
185 |
+
if self.disable_hidden:
|
186 |
+
self.hidden_dim = 0
|
187 |
+
else:
|
188 |
+
self.hidden_dim = (
|
189 |
+
model_weights["decoder.hidden_update.transform.weight"].shape[0]
|
190 |
+
// 3
|
191 |
+
)
|
192 |
+
print(
|
193 |
+
f"Hyperparameters read from the model weights: "
|
194 |
+
f"C^k={self.key_dim}, C^v={self.value_dim}, C^h={self.hidden_dim}"
|
195 |
+
)
|
196 |
+
else:
|
197 |
+
model_weights = None
|
198 |
+
# load dimensions from config or default
|
199 |
+
if "key_dim" not in config:
|
200 |
+
self.key_dim = 64
|
201 |
+
print(f"key_dim not found in config. Set to default {self.key_dim}")
|
202 |
+
else:
|
203 |
+
self.key_dim = config["key_dim"]
|
204 |
+
|
205 |
+
if "value_dim" not in config:
|
206 |
+
self.value_dim = 512
|
207 |
+
print(f"value_dim not found in config. Set to default {self.value_dim}")
|
208 |
+
else:
|
209 |
+
self.value_dim = config["value_dim"]
|
210 |
+
|
211 |
+
if "hidden_dim" not in config:
|
212 |
+
self.hidden_dim = 64
|
213 |
+
print(
|
214 |
+
f"hidden_dim not found in config. Set to default {self.hidden_dim}"
|
215 |
+
)
|
216 |
+
else:
|
217 |
+
self.hidden_dim = config["hidden_dim"]
|
218 |
+
|
219 |
+
self.disable_hidden = self.hidden_dim <= 0
|
220 |
+
|
221 |
+
config["key_dim"] = self.key_dim
|
222 |
+
config["value_dim"] = self.value_dim
|
223 |
+
config["hidden_dim"] = self.hidden_dim
|
224 |
+
|
225 |
+
return model_weights
|
226 |
+
|
227 |
+
def load_weights(self, src_dict, init_as_zero_if_needed=False):
|
228 |
+
# Maps SO weight (without other_mask) to MO weight (with other_mask)
|
229 |
+
for k in list(src_dict.keys()):
|
230 |
+
if k == "value_encoder.conv1.weight":
|
231 |
+
if src_dict[k].shape[1] == 4:
|
232 |
+
print("Converting weights from single object to multiple objects.")
|
233 |
+
pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device)
|
234 |
+
if not init_as_zero_if_needed:
|
235 |
+
print("Randomly initialized padding.")
|
236 |
+
nn.init.orthogonal_(pads)
|
237 |
+
else:
|
238 |
+
print("Zero-initialized padding.")
|
239 |
+
src_dict[k] = torch.cat([src_dict[k], pads], 1)
|
240 |
+
|
241 |
+
self.load_state_dict(src_dict)
|
tracker/model/resnet.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
resnet.py - A modified ResNet structure
|
3 |
+
We append extra channels to the first conv by some network surgery
|
4 |
+
"""
|
5 |
+
|
6 |
+
from collections import OrderedDict
|
7 |
+
import math
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.utils import model_zoo
|
12 |
+
|
13 |
+
|
14 |
+
def load_weights_add_extra_dim(target, source_state, extra_dim=1):
|
15 |
+
new_dict = OrderedDict()
|
16 |
+
|
17 |
+
for k1, v1 in target.state_dict().items():
|
18 |
+
if not "num_batches_tracked" in k1:
|
19 |
+
if k1 in source_state:
|
20 |
+
tar_v = source_state[k1]
|
21 |
+
|
22 |
+
if v1.shape != tar_v.shape:
|
23 |
+
# Init the new segmentation channel with zeros
|
24 |
+
# print(v1.shape, tar_v.shape)
|
25 |
+
c, _, w, h = v1.shape
|
26 |
+
pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device)
|
27 |
+
nn.init.orthogonal_(pads)
|
28 |
+
tar_v = torch.cat([tar_v, pads], 1)
|
29 |
+
|
30 |
+
new_dict[k1] = tar_v
|
31 |
+
|
32 |
+
target.load_state_dict(new_dict)
|
33 |
+
|
34 |
+
|
35 |
+
model_urls = {
|
36 |
+
"resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
|
37 |
+
"resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
|
42 |
+
return nn.Conv2d(
|
43 |
+
in_planes,
|
44 |
+
out_planes,
|
45 |
+
kernel_size=3,
|
46 |
+
stride=stride,
|
47 |
+
padding=dilation,
|
48 |
+
dilation=dilation,
|
49 |
+
bias=False,
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
class BasicBlock(nn.Module):
|
54 |
+
expansion = 1
|
55 |
+
|
56 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
|
57 |
+
super(BasicBlock, self).__init__()
|
58 |
+
self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
|
59 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
60 |
+
self.relu = nn.ReLU(inplace=True)
|
61 |
+
self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
|
62 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
63 |
+
self.downsample = downsample
|
64 |
+
self.stride = stride
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
residual = x
|
68 |
+
|
69 |
+
out = self.conv1(x)
|
70 |
+
out = self.bn1(out)
|
71 |
+
out = self.relu(out)
|
72 |
+
|
73 |
+
out = self.conv2(out)
|
74 |
+
out = self.bn2(out)
|
75 |
+
|
76 |
+
if self.downsample is not None:
|
77 |
+
residual = self.downsample(x)
|
78 |
+
|
79 |
+
out += residual
|
80 |
+
out = self.relu(out)
|
81 |
+
|
82 |
+
return out
|
83 |
+
|
84 |
+
|
85 |
+
class Bottleneck(nn.Module):
|
86 |
+
expansion = 4
|
87 |
+
|
88 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
|
89 |
+
super(Bottleneck, self).__init__()
|
90 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
91 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
92 |
+
self.conv2 = nn.Conv2d(
|
93 |
+
planes,
|
94 |
+
planes,
|
95 |
+
kernel_size=3,
|
96 |
+
stride=stride,
|
97 |
+
dilation=dilation,
|
98 |
+
padding=dilation,
|
99 |
+
bias=False,
|
100 |
+
)
|
101 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
102 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
103 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
104 |
+
self.relu = nn.ReLU(inplace=True)
|
105 |
+
self.downsample = downsample
|
106 |
+
self.stride = stride
|
107 |
+
|
108 |
+
def forward(self, x):
|
109 |
+
residual = x
|
110 |
+
|
111 |
+
out = self.conv1(x)
|
112 |
+
out = self.bn1(out)
|
113 |
+
out = self.relu(out)
|
114 |
+
|
115 |
+
out = self.conv2(out)
|
116 |
+
out = self.bn2(out)
|
117 |
+
out = self.relu(out)
|
118 |
+
|
119 |
+
out = self.conv3(out)
|
120 |
+
out = self.bn3(out)
|
121 |
+
|
122 |
+
if self.downsample is not None:
|
123 |
+
residual = self.downsample(x)
|
124 |
+
|
125 |
+
out += residual
|
126 |
+
out = self.relu(out)
|
127 |
+
|
128 |
+
return out
|
129 |
+
|
130 |
+
|
131 |
+
class ResNet(nn.Module):
|
132 |
+
def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
|
133 |
+
self.inplanes = 64
|
134 |
+
super(ResNet, self).__init__()
|
135 |
+
self.conv1 = nn.Conv2d(
|
136 |
+
3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False
|
137 |
+
)
|
138 |
+
self.bn1 = nn.BatchNorm2d(64)
|
139 |
+
self.relu = nn.ReLU(inplace=True)
|
140 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
141 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
142 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
143 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
144 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
145 |
+
|
146 |
+
for m in self.modules():
|
147 |
+
if isinstance(m, nn.Conv2d):
|
148 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
149 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / n))
|
150 |
+
elif isinstance(m, nn.BatchNorm2d):
|
151 |
+
m.weight.data.fill_(1)
|
152 |
+
m.bias.data.zero_()
|
153 |
+
|
154 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
|
155 |
+
downsample = None
|
156 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
157 |
+
downsample = nn.Sequential(
|
158 |
+
nn.Conv2d(
|
159 |
+
self.inplanes,
|
160 |
+
planes * block.expansion,
|
161 |
+
kernel_size=1,
|
162 |
+
stride=stride,
|
163 |
+
bias=False,
|
164 |
+
),
|
165 |
+
nn.BatchNorm2d(planes * block.expansion),
|
166 |
+
)
|
167 |
+
|
168 |
+
layers = [block(self.inplanes, planes, stride, downsample)]
|
169 |
+
self.inplanes = planes * block.expansion
|
170 |
+
for i in range(1, blocks):
|
171 |
+
layers.append(block(self.inplanes, planes, dilation=dilation))
|
172 |
+
|
173 |
+
return nn.Sequential(*layers)
|
174 |
+
|
175 |
+
|
176 |
+
def resnet18(pretrained=True, extra_dim=0):
|
177 |
+
model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
|
178 |
+
if pretrained:
|
179 |
+
load_weights_add_extra_dim(
|
180 |
+
model, model_zoo.load_url(model_urls["resnet18"]), extra_dim
|
181 |
+
)
|
182 |
+
return model
|
183 |
+
|
184 |
+
|
185 |
+
def resnet50(pretrained=True, extra_dim=0):
|
186 |
+
model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
|
187 |
+
if pretrained:
|
188 |
+
load_weights_add_extra_dim(
|
189 |
+
model, model_zoo.load_url(model_urls["resnet50"]), extra_dim
|
190 |
+
)
|
191 |
+
return model
|
tracker/model/trainer.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
trainer.py - warpper and utility functions for network training
|
3 |
+
Compute loss, back-prop, update parameters, logging, etc.
|
4 |
+
"""
|
5 |
+
import datetime
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.optim as optim
|
12 |
+
|
13 |
+
from model.network import XMem
|
14 |
+
from model.losses import LossComputer
|
15 |
+
from util.log_integrator import Integrator
|
16 |
+
from util.image_saver import pool_pairs
|
17 |
+
|
18 |
+
|
19 |
+
class XMemTrainer:
|
20 |
+
def __init__(self, config, logger=None, save_path=None, local_rank=0, world_size=1):
|
21 |
+
self.config = config
|
22 |
+
self.num_frames = config["num_frames"]
|
23 |
+
self.num_ref_frames = config["num_ref_frames"]
|
24 |
+
self.deep_update_prob = config["deep_update_prob"]
|
25 |
+
self.local_rank = local_rank
|
26 |
+
|
27 |
+
self.XMem = nn.parallel.DistributedDataParallel(
|
28 |
+
XMem(config).cuda(),
|
29 |
+
device_ids=[local_rank],
|
30 |
+
output_device=local_rank,
|
31 |
+
broadcast_buffers=False,
|
32 |
+
)
|
33 |
+
|
34 |
+
# Set up logger when local_rank=0
|
35 |
+
self.logger = logger
|
36 |
+
self.save_path = save_path
|
37 |
+
if logger is not None:
|
38 |
+
self.last_time = time.time()
|
39 |
+
self.logger.log_string(
|
40 |
+
"model_size",
|
41 |
+
str(sum([param.nelement() for param in self.XMem.parameters()])),
|
42 |
+
)
|
43 |
+
self.train_integrator = Integrator(
|
44 |
+
self.logger, distributed=True, local_rank=local_rank, world_size=world_size
|
45 |
+
)
|
46 |
+
self.loss_computer = LossComputer(config)
|
47 |
+
|
48 |
+
self.train()
|
49 |
+
self.optimizer = optim.AdamW(
|
50 |
+
filter(lambda p: p.requires_grad, self.XMem.parameters()),
|
51 |
+
lr=config["lr"],
|
52 |
+
weight_decay=config["weight_decay"],
|
53 |
+
)
|
54 |
+
self.scheduler = optim.lr_scheduler.MultiStepLR(
|
55 |
+
self.optimizer, config["steps"], config["gamma"]
|
56 |
+
)
|
57 |
+
if config["amp"]:
|
58 |
+
self.scaler = torch.cuda.amp.GradScaler()
|
59 |
+
|
60 |
+
# Logging info
|
61 |
+
self.log_text_interval = config["log_text_interval"]
|
62 |
+
self.log_image_interval = config["log_image_interval"]
|
63 |
+
self.save_network_interval = config["save_network_interval"]
|
64 |
+
self.save_checkpoint_interval = config["save_checkpoint_interval"]
|
65 |
+
if config["debug"]:
|
66 |
+
self.log_text_interval = self.log_image_interval = 1
|
67 |
+
|
68 |
+
def do_pass(self, data, max_it, it=0):
|
69 |
+
# No need to store the gradient outside training
|
70 |
+
torch.set_grad_enabled(self._is_train)
|
71 |
+
|
72 |
+
for k, v in data.items():
|
73 |
+
if type(v) != list and type(v) != dict and type(v) != int:
|
74 |
+
data[k] = v.cuda(non_blocking=True)
|
75 |
+
|
76 |
+
out = {}
|
77 |
+
frames = data["rgb"]
|
78 |
+
first_frame_gt = data["first_frame_gt"].float()
|
79 |
+
b = frames.shape[0]
|
80 |
+
num_filled_objects = [o.item() for o in data["info"]["num_objects"]]
|
81 |
+
num_objects = first_frame_gt.shape[2]
|
82 |
+
selector = data["selector"].unsqueeze(2).unsqueeze(2)
|
83 |
+
|
84 |
+
global_avg = 0
|
85 |
+
|
86 |
+
with torch.cuda.amp.autocast(enabled=self.config["amp"]):
|
87 |
+
# image features never change, compute once
|
88 |
+
key, shrinkage, selection, f16, f8, f4 = self.XMem("encode_key", frames)
|
89 |
+
|
90 |
+
filler_one = torch.zeros(1, dtype=torch.int64)
|
91 |
+
hidden = torch.zeros(
|
92 |
+
(b, num_objects, self.config["hidden_dim"], *key.shape[-2:])
|
93 |
+
)
|
94 |
+
v16, hidden = self.XMem(
|
95 |
+
"encode_value", frames[:, 0], f16[:, 0], hidden, first_frame_gt[:, 0]
|
96 |
+
)
|
97 |
+
values = v16.unsqueeze(3) # add the time dimension
|
98 |
+
|
99 |
+
for ti in range(1, self.num_frames):
|
100 |
+
if ti <= self.num_ref_frames:
|
101 |
+
ref_values = values
|
102 |
+
ref_keys = key[:, :, :ti]
|
103 |
+
ref_shrinkage = (
|
104 |
+
shrinkage[:, :, :ti] if shrinkage is not None else None
|
105 |
+
)
|
106 |
+
else:
|
107 |
+
# pick num_ref_frames random frames
|
108 |
+
# this is not very efficient but I think we would
|
109 |
+
# need broadcasting in gather which we don't have
|
110 |
+
indices = [
|
111 |
+
torch.cat(
|
112 |
+
[
|
113 |
+
filler_one,
|
114 |
+
torch.randperm(ti - 1)[: self.num_ref_frames - 1] + 1,
|
115 |
+
]
|
116 |
+
)
|
117 |
+
for _ in range(b)
|
118 |
+
]
|
119 |
+
ref_values = torch.stack(
|
120 |
+
[values[bi, :, :, indices[bi]] for bi in range(b)], 0
|
121 |
+
)
|
122 |
+
ref_keys = torch.stack(
|
123 |
+
[key[bi, :, indices[bi]] for bi in range(b)], 0
|
124 |
+
)
|
125 |
+
ref_shrinkage = (
|
126 |
+
torch.stack(
|
127 |
+
[shrinkage[bi, :, indices[bi]] for bi in range(b)], 0
|
128 |
+
)
|
129 |
+
if shrinkage is not None
|
130 |
+
else None
|
131 |
+
)
|
132 |
+
|
133 |
+
# Segment frame ti
|
134 |
+
memory_readout = self.XMem(
|
135 |
+
"read_memory",
|
136 |
+
key[:, :, ti],
|
137 |
+
selection[:, :, ti] if selection is not None else None,
|
138 |
+
ref_keys,
|
139 |
+
ref_shrinkage,
|
140 |
+
ref_values,
|
141 |
+
)
|
142 |
+
hidden, logits, masks = self.XMem(
|
143 |
+
"segment",
|
144 |
+
(f16[:, ti], f8[:, ti], f4[:, ti]),
|
145 |
+
memory_readout,
|
146 |
+
hidden,
|
147 |
+
selector,
|
148 |
+
h_out=(ti < (self.num_frames - 1)),
|
149 |
+
)
|
150 |
+
|
151 |
+
# No need to encode the last frame
|
152 |
+
if ti < (self.num_frames - 1):
|
153 |
+
is_deep_update = np.random.rand() < self.deep_update_prob
|
154 |
+
v16, hidden = self.XMem(
|
155 |
+
"encode_value",
|
156 |
+
frames[:, ti],
|
157 |
+
f16[:, ti],
|
158 |
+
hidden,
|
159 |
+
masks,
|
160 |
+
is_deep_update=is_deep_update,
|
161 |
+
)
|
162 |
+
values = torch.cat([values, v16.unsqueeze(3)], 3)
|
163 |
+
|
164 |
+
out[f"masks_{ti}"] = masks
|
165 |
+
out[f"logits_{ti}"] = logits
|
166 |
+
|
167 |
+
if self._do_log or self._is_train:
|
168 |
+
losses = self.loss_computer.compute(
|
169 |
+
{**data, **out}, num_filled_objects, it
|
170 |
+
)
|
171 |
+
|
172 |
+
# Logging
|
173 |
+
if self._do_log:
|
174 |
+
self.integrator.add_dict(losses)
|
175 |
+
if self._is_train:
|
176 |
+
if it % self.log_image_interval == 0 and it != 0:
|
177 |
+
if self.logger is not None:
|
178 |
+
images = {**data, **out}
|
179 |
+
size = (384, 384)
|
180 |
+
self.logger.log_cv2(
|
181 |
+
"train/pairs",
|
182 |
+
pool_pairs(images, size, num_filled_objects),
|
183 |
+
it,
|
184 |
+
)
|
185 |
+
|
186 |
+
if self._is_train:
|
187 |
+
|
188 |
+
if (it) % self.log_text_interval == 0 and it != 0:
|
189 |
+
time_spent = time.time() - self.last_time
|
190 |
+
|
191 |
+
if self.logger is not None:
|
192 |
+
self.logger.log_scalar(
|
193 |
+
"train/lr", self.scheduler.get_last_lr()[0], it
|
194 |
+
)
|
195 |
+
self.logger.log_metrics(
|
196 |
+
"train", "time", (time_spent) / self.log_text_interval, it
|
197 |
+
)
|
198 |
+
|
199 |
+
global_avg = 0.5 * (global_avg) + 0.5 * (time_spent)
|
200 |
+
eta_seconds = global_avg * (max_it - it) / 100
|
201 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
202 |
+
print(f"ETA: {eta_string}")
|
203 |
+
|
204 |
+
self.last_time = time.time()
|
205 |
+
self.train_integrator.finalize("train", it)
|
206 |
+
self.train_integrator.reset_except_hooks()
|
207 |
+
|
208 |
+
if it % self.save_network_interval == 0 and it != 0:
|
209 |
+
if self.logger is not None:
|
210 |
+
self.save_network(it)
|
211 |
+
|
212 |
+
if it % self.save_checkpoint_interval == 0 and it != 0:
|
213 |
+
if self.logger is not None:
|
214 |
+
self.save_checkpoint(it)
|
215 |
+
|
216 |
+
# Backward pass
|
217 |
+
self.optimizer.zero_grad(set_to_none=True)
|
218 |
+
if self.config["amp"]:
|
219 |
+
self.scaler.scale(losses["total_loss"]).backward()
|
220 |
+
self.scaler.step(self.optimizer)
|
221 |
+
self.scaler.update()
|
222 |
+
else:
|
223 |
+
losses["total_loss"].backward()
|
224 |
+
self.optimizer.step()
|
225 |
+
|
226 |
+
self.scheduler.step()
|
227 |
+
|
228 |
+
def save_network(self, it):
|
229 |
+
if self.save_path is None:
|
230 |
+
print("Saving has been disabled.")
|
231 |
+
return
|
232 |
+
|
233 |
+
os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
|
234 |
+
model_path = f"{self.save_path}_{it}.pth"
|
235 |
+
torch.save(self.XMem.module.state_dict(), model_path)
|
236 |
+
print(f"Network saved to {model_path}.")
|
237 |
+
|
238 |
+
def save_checkpoint(self, it):
|
239 |
+
if self.save_path is None:
|
240 |
+
print("Saving has been disabled.")
|
241 |
+
return
|
242 |
+
|
243 |
+
os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
|
244 |
+
checkpoint_path = f"{self.save_path}_checkpoint_{it}.pth"
|
245 |
+
checkpoint = {
|
246 |
+
"it": it,
|
247 |
+
"network": self.XMem.module.state_dict(),
|
248 |
+
"optimizer": self.optimizer.state_dict(),
|
249 |
+
"scheduler": self.scheduler.state_dict(),
|
250 |
+
}
|
251 |
+
torch.save(checkpoint, checkpoint_path)
|
252 |
+
print(f"Checkpoint saved to {checkpoint_path}.")
|
253 |
+
|
254 |
+
def load_checkpoint(self, path):
|
255 |
+
# This method loads everything and should be used to resume training
|
256 |
+
map_location = "cuda:%d" % self.local_rank
|
257 |
+
checkpoint = torch.load(path, map_location={"cuda:0": map_location})
|
258 |
+
|
259 |
+
it = checkpoint["it"]
|
260 |
+
network = checkpoint["network"]
|
261 |
+
optimizer = checkpoint["optimizer"]
|
262 |
+
scheduler = checkpoint["scheduler"]
|
263 |
+
|
264 |
+
map_location = "cuda:%d" % self.local_rank
|
265 |
+
self.XMem.module.load_state_dict(network)
|
266 |
+
self.optimizer.load_state_dict(optimizer)
|
267 |
+
self.scheduler.load_state_dict(scheduler)
|
268 |
+
|
269 |
+
print("Network weights, optimizer states, and scheduler states loaded.")
|
270 |
+
|
271 |
+
return it
|
272 |
+
|
273 |
+
def load_network_in_memory(self, src_dict):
|
274 |
+
self.XMem.module.load_weights(src_dict)
|
275 |
+
print("Network weight loaded from memory.")
|
276 |
+
|
277 |
+
def load_network(self, path):
|
278 |
+
# This method loads only the network weight and should be used to load a pretrained model
|
279 |
+
map_location = "cuda:%d" % self.local_rank
|
280 |
+
src_dict = torch.load(path, map_location={"cuda:0": map_location})
|
281 |
+
|
282 |
+
self.load_network_in_memory(src_dict)
|
283 |
+
print(f"Network weight loaded from {path}")
|
284 |
+
|
285 |
+
def train(self):
|
286 |
+
self._is_train = True
|
287 |
+
self._do_log = True
|
288 |
+
self.integrator = self.train_integrator
|
289 |
+
self.XMem.eval()
|
290 |
+
return self
|
291 |
+
|
292 |
+
def val(self):
|
293 |
+
self._is_train = False
|
294 |
+
self._do_log = True
|
295 |
+
self.XMem.eval()
|
296 |
+
return self
|
297 |
+
|
298 |
+
def test(self):
|
299 |
+
self._is_train = False
|
300 |
+
self._do_log = False
|
301 |
+
self.XMem.eval()
|
302 |
+
return self
|
tracker/util/__init__.py
ADDED
File without changes
|
tracker/util/mask_mapper.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
def all_to_onehot(masks, labels):
|
6 |
+
if len(masks.shape) == 3:
|
7 |
+
Ms = np.zeros(
|
8 |
+
(len(labels), masks.shape[0], masks.shape[1], masks.shape[2]),
|
9 |
+
dtype=np.uint8,
|
10 |
+
)
|
11 |
+
else:
|
12 |
+
Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
|
13 |
+
|
14 |
+
for ni, l in enumerate(labels):
|
15 |
+
Ms[ni] = (masks == l).astype(np.uint8)
|
16 |
+
|
17 |
+
return Ms
|
18 |
+
|
19 |
+
|
20 |
+
class MaskMapper:
|
21 |
+
"""
|
22 |
+
This class is used to convert a indexed-mask to a one-hot representation.
|
23 |
+
It also takes care of remapping non-continuous indices
|
24 |
+
It has two modes:
|
25 |
+
1. Default. Only masks with new indices are supposed to go into the remapper.
|
26 |
+
This is also the case for YouTubeVOS.
|
27 |
+
i.e., regions with index 0 are not "background", but "don't care".
|
28 |
+
|
29 |
+
2. Exhaustive. Regions with index 0 are considered "background".
|
30 |
+
Every single pixel is considered to be "labeled".
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self):
|
34 |
+
self.labels = []
|
35 |
+
self.remappings = {}
|
36 |
+
|
37 |
+
# if coherent, no mapping is required
|
38 |
+
self.coherent = True
|
39 |
+
|
40 |
+
def clear_labels(self):
|
41 |
+
self.labels = []
|
42 |
+
self.remappings = {}
|
43 |
+
# if coherent, no mapping is required
|
44 |
+
self.coherent = True
|
45 |
+
|
46 |
+
def convert_mask(self, mask, exhaustive=False):
|
47 |
+
# mask is in index representation, H*W numpy array
|
48 |
+
labels = np.unique(mask).astype(np.uint8)
|
49 |
+
labels = labels[labels != 0].tolist()
|
50 |
+
|
51 |
+
new_labels = list(set(labels) - set(self.labels))
|
52 |
+
if not exhaustive:
|
53 |
+
assert len(new_labels) == len(
|
54 |
+
labels
|
55 |
+
), "Old labels found in non-exhaustive mode"
|
56 |
+
|
57 |
+
# add new remappings
|
58 |
+
for i, l in enumerate(new_labels):
|
59 |
+
self.remappings[l] = i + len(self.labels) + 1
|
60 |
+
if self.coherent and i + len(self.labels) + 1 != l:
|
61 |
+
self.coherent = False
|
62 |
+
|
63 |
+
if exhaustive:
|
64 |
+
new_mapped_labels = range(1, len(self.labels) + len(new_labels) + 1)
|
65 |
+
else:
|
66 |
+
if self.coherent:
|
67 |
+
new_mapped_labels = new_labels
|
68 |
+
else:
|
69 |
+
new_mapped_labels = range(
|
70 |
+
len(self.labels) + 1, len(self.labels) + len(new_labels) + 1
|
71 |
+
)
|
72 |
+
|
73 |
+
self.labels.extend(new_labels)
|
74 |
+
mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()
|
75 |
+
|
76 |
+
# mask num_objects*H*W
|
77 |
+
return mask, new_mapped_labels
|
78 |
+
|
79 |
+
def remap_index_mask(self, mask):
|
80 |
+
# mask is in index representation, H*W numpy array
|
81 |
+
if self.coherent:
|
82 |
+
return mask
|
83 |
+
|
84 |
+
new_mask = np.zeros_like(mask)
|
85 |
+
for l, i in self.remappings.items():
|
86 |
+
new_mask[mask == i] = l
|
87 |
+
return new_mask
|
tracker/util/range_transform.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchvision.transforms as transforms
|
2 |
+
|
3 |
+
im_mean = (124, 116, 104)
|
4 |
+
|
5 |
+
im_normalization = transforms.Normalize(
|
6 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
7 |
+
)
|
8 |
+
|
9 |
+
inv_im_trans = transforms.Normalize(
|
10 |
+
mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
|
11 |
+
std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
|
12 |
+
)
|
tracker/util/tensor_util.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn.functional as F
|
2 |
+
|
3 |
+
|
4 |
+
def compute_tensor_iu(seg, gt):
|
5 |
+
intersection = (seg & gt).float().sum()
|
6 |
+
union = (seg | gt).float().sum()
|
7 |
+
|
8 |
+
return intersection, union
|
9 |
+
|
10 |
+
|
11 |
+
def compute_tensor_iou(seg, gt):
|
12 |
+
intersection, union = compute_tensor_iu(seg, gt)
|
13 |
+
iou = (intersection + 1e-6) / (union + 1e-6)
|
14 |
+
|
15 |
+
return iou
|
16 |
+
|
17 |
+
|
18 |
+
# STM
|
19 |
+
def pad_divide_by(in_img, d):
|
20 |
+
h, w = in_img.shape[-2:]
|
21 |
+
|
22 |
+
if h % d > 0:
|
23 |
+
new_h = h + d - h % d
|
24 |
+
else:
|
25 |
+
new_h = h
|
26 |
+
if w % d > 0:
|
27 |
+
new_w = w + d - w % d
|
28 |
+
else:
|
29 |
+
new_w = w
|
30 |
+
lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
|
31 |
+
lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
|
32 |
+
pad_array = (int(lw), int(uw), int(lh), int(uh))
|
33 |
+
out = F.pad(in_img, pad_array)
|
34 |
+
return out, pad_array
|
35 |
+
|
36 |
+
|
37 |
+
def unpad(img, pad):
|
38 |
+
if len(img.shape) == 4:
|
39 |
+
if pad[2] + pad[3] > 0:
|
40 |
+
img = img[:, :, pad[2] : -pad[3], :]
|
41 |
+
if pad[0] + pad[1] > 0:
|
42 |
+
img = img[:, :, :, pad[0] : -pad[1]]
|
43 |
+
elif len(img.shape) == 3:
|
44 |
+
if pad[2] + pad[3] > 0:
|
45 |
+
img = img[:, pad[2] : -pad[3], :]
|
46 |
+
if pad[0] + pad[1] > 0:
|
47 |
+
img = img[:, :, pad[0] : -pad[1]]
|
48 |
+
else:
|
49 |
+
raise NotImplementedError
|
50 |
+
return img
|
utils/base_segmenter.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
class BaseSegmenter:
|
6 |
+
def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device="cuda:0"):
|
7 |
+
"""
|
8 |
+
device: model device
|
9 |
+
SAM_checkpoint: path of SAM checkpoint
|
10 |
+
model_type: vit_b, vit_l, vit_h, vit_t
|
11 |
+
"""
|
12 |
+
print(f"Initializing BaseSegmenter to {device}")
|
13 |
+
assert model_type in [
|
14 |
+
"vit_b",
|
15 |
+
"vit_l",
|
16 |
+
"vit_h",
|
17 |
+
"vit_t",
|
18 |
+
], "model_type must be vit_b, vit_l, vit_h or vit_t"
|
19 |
+
|
20 |
+
self.device = device
|
21 |
+
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
|
22 |
+
|
23 |
+
if (model_type == "vit_t"):
|
24 |
+
from mobile_sam import sam_model_registry, SamPredictor
|
25 |
+
from onnxruntime import InferenceSession
|
26 |
+
self.ort_session = InferenceSession(sam_onnx_checkpoint)
|
27 |
+
self.predict = self.predict_onnx
|
28 |
+
else:
|
29 |
+
from segment_anything import sam_model_registry, SamPredictor
|
30 |
+
self.predict = self.predict_pt
|
31 |
+
|
32 |
+
self.model = sam_model_registry[model_type](checkpoint=sam_pt_checkpoint)
|
33 |
+
self.model.to(device=self.device)
|
34 |
+
self.predictor = SamPredictor(self.model)
|
35 |
+
self.embedded = False
|
36 |
+
|
37 |
+
@torch.no_grad()
|
38 |
+
def set_image(self, image: np.ndarray):
|
39 |
+
# PIL.open(image_path) 3channel: RGB
|
40 |
+
# image embedding: avoid encode the same image multiple times
|
41 |
+
self.orignal_image = image
|
42 |
+
if self.embedded:
|
43 |
+
print("repeat embedding, please reset_image.")
|
44 |
+
return
|
45 |
+
self.predictor.set_image(image)
|
46 |
+
self.image_embedding = self.predictor.get_image_embedding().cpu().numpy()
|
47 |
+
self.embedded = True
|
48 |
+
return
|
49 |
+
|
50 |
+
@torch.no_grad()
|
51 |
+
def reset_image(self):
|
52 |
+
# reset image embeding
|
53 |
+
self.predictor.reset_image()
|
54 |
+
self.embedded = False
|
55 |
+
|
56 |
+
def predict_pt(self, prompts, mode, multimask=True):
|
57 |
+
"""
|
58 |
+
image: numpy array, h, w, 3
|
59 |
+
prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
|
60 |
+
prompts['point_coords']: numpy array [N,2]
|
61 |
+
prompts['point_labels']: numpy array [1,N]
|
62 |
+
prompts['mask_input']: numpy array [1,256,256]
|
63 |
+
mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
|
64 |
+
mask_outputs: True (return 3 masks), False (return 1 mask only)
|
65 |
+
whem mask_outputs=True, mask_input=logits[np.argmax(scores), :, :][None, :, :]
|
66 |
+
"""
|
67 |
+
assert (
|
68 |
+
self.embedded
|
69 |
+
), "prediction is called before set_image (feature embedding)."
|
70 |
+
assert mode in ["point", "mask", "both"], "mode must be point, mask, or both"
|
71 |
+
|
72 |
+
if mode == "point":
|
73 |
+
masks, scores, logits = self.predictor.predict(
|
74 |
+
point_coords=prompts["point_coords"],
|
75 |
+
point_labels=prompts["point_labels"],
|
76 |
+
multimask_output=multimask,
|
77 |
+
)
|
78 |
+
elif mode == "mask":
|
79 |
+
masks, scores, logits = self.predictor.predict(
|
80 |
+
mask_input=prompts["mask_input"], multimask_output=multimask
|
81 |
+
)
|
82 |
+
elif mode == "both": # both
|
83 |
+
masks, scores, logits = self.predictor.predict(
|
84 |
+
point_coords=prompts["point_coords"],
|
85 |
+
point_labels=prompts["point_labels"],
|
86 |
+
mask_input=prompts["mask_input"],
|
87 |
+
multimask_output=multimask,
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
raise ("Not implement now!")
|
91 |
+
# masks (n, h, w), scores (n,), logits (n, 256, 256)
|
92 |
+
return masks, scores, logits
|
93 |
+
|
94 |
+
def predict_onnx(self, prompts, mode, multimask=True):
|
95 |
+
"""
|
96 |
+
image: numpy array, h, w, 3
|
97 |
+
prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
|
98 |
+
prompts['point_coords']: numpy array [N,2]
|
99 |
+
prompts['point_labels']: numpy array [1,N]
|
100 |
+
prompts['mask_input']: numpy array [1,256,256]
|
101 |
+
mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
|
102 |
+
mask_outputs: True (return 3 masks), False (return 1 mask only)
|
103 |
+
whem mask_outputs=True, mask_input=logits[np.argmax(scores), :, :][None, :, :]
|
104 |
+
"""
|
105 |
+
assert (
|
106 |
+
self.embedded
|
107 |
+
), "prediction is called before set_image (feature embedding)."
|
108 |
+
assert mode in ["point", "mask", "both"], "mode must be point, mask, or both"
|
109 |
+
|
110 |
+
if mode == "point":
|
111 |
+
ort_inputs = {
|
112 |
+
"image_embeddings": self.image_embedding,
|
113 |
+
"point_coords": prompts["point_coords"],
|
114 |
+
"point_labels": prompts["point_labels"],
|
115 |
+
"mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
|
116 |
+
"has_mask_input": np.zeros(1, dtype=np.float32),
|
117 |
+
"orig_im_size": prompts["orig_im_size"],
|
118 |
+
}
|
119 |
+
masks, scores, logits = self.ort_session.run(None, ort_inputs)
|
120 |
+
masks = masks > self.predictor.model.mask_threshold
|
121 |
+
|
122 |
+
elif mode == "mask":
|
123 |
+
ort_inputs = {
|
124 |
+
"image_embeddings": self.image_embedding,
|
125 |
+
"point_coords": np.zeros((len(prompts["point_labels"]), 2), dtype=np.float32),
|
126 |
+
"point_labels": prompts["point_labels"],
|
127 |
+
"mask_input": prompts["mask_input"],
|
128 |
+
"has_mask_input": np.ones(1, dtype=np.float32),
|
129 |
+
"orig_im_size": prompts["orig_im_size"],
|
130 |
+
}
|
131 |
+
masks, scores, logits = self.ort_session.run(None, ort_inputs)
|
132 |
+
masks = masks > self.predictor.model.mask_threshold
|
133 |
+
|
134 |
+
elif mode == "both": # both
|
135 |
+
ort_inputs = {
|
136 |
+
"image_embeddings": self.image_embedding,
|
137 |
+
"point_coords": prompts["point_coords"],
|
138 |
+
"point_labels": prompts["point_labels"],
|
139 |
+
"mask_input": prompts["mask_input"],
|
140 |
+
"has_mask_input": np.ones(1, dtype=np.float32),
|
141 |
+
"orig_im_size": prompts["orig_im_size"],
|
142 |
+
}
|
143 |
+
masks, scores, logits = self.ort_session.run(None, ort_inputs)
|
144 |
+
masks = masks > self.predictor.model.mask_threshold
|
145 |
+
|
146 |
+
else:
|
147 |
+
raise ("Not implement now!")
|
148 |
+
# masks (n, h, w), scores (n,), logits (n, 256, 256)
|
149 |
+
return masks[0], scores[0], logits[0]
|
utils/blur.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
# resize frames
|
7 |
+
def resize_frames(frames, size=None):
|
8 |
+
"""
|
9 |
+
size: (w, h)
|
10 |
+
"""
|
11 |
+
if size is not None:
|
12 |
+
frames = [cv2.resize(f, size) for f in frames]
|
13 |
+
frames = np.stack(frames, 0)
|
14 |
+
|
15 |
+
return frames
|
16 |
+
|
17 |
+
|
18 |
+
# resize frames
|
19 |
+
def resize_masks(masks, size=None):
|
20 |
+
"""
|
21 |
+
size: (w, h)
|
22 |
+
"""
|
23 |
+
if size is not None:
|
24 |
+
masks = [np.expand_dims(cv2.resize(m, size), 2) for m in masks]
|
25 |
+
masks = np.stack(masks, 0)
|
26 |
+
|
27 |
+
return masks
|
28 |
+
|
29 |
+
|
30 |
+
# apply gaussian blur to mask with defined strength
|
31 |
+
def apply_blur(frame, strength):
|
32 |
+
blurred = cv2.GaussianBlur(frame, (strength, strength), 0)
|
33 |
+
return blurred
|
34 |
+
|
35 |
+
|
36 |
+
# blur frames
|
37 |
+
def blur_frames_and_write(
|
38 |
+
frames, masks, ratio, strength, dilate_radius=15, fps=30, output_path="blurred.mp4"
|
39 |
+
):
|
40 |
+
assert frames.shape[:3] == masks.shape, "different size between frames and masks"
|
41 |
+
assert ratio > 0 and ratio <= 1, "ratio must in (0, 1]"
|
42 |
+
|
43 |
+
# --------------------
|
44 |
+
# pre-processing
|
45 |
+
# --------------------
|
46 |
+
masks = masks.copy()
|
47 |
+
masks = np.clip(masks, 0, 1)
|
48 |
+
kernel = cv2.getStructuringElement(2, (dilate_radius, dilate_radius))
|
49 |
+
masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
|
50 |
+
T, H, W = masks.shape
|
51 |
+
masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
|
52 |
+
# size: (w, h)
|
53 |
+
if ratio == 1:
|
54 |
+
size = (W, H)
|
55 |
+
binary_masks = masks
|
56 |
+
else:
|
57 |
+
size = [int(W * ratio), int(H * ratio)]
|
58 |
+
size = [
|
59 |
+
si + 1 if si % 2 > 0 else si for si in size
|
60 |
+
] # only consider even values
|
61 |
+
# shortest side should be larger than 50
|
62 |
+
if min(size) < 50:
|
63 |
+
ratio = 50.0 / min(H, W)
|
64 |
+
size = [int(W * ratio), int(H * ratio)]
|
65 |
+
binary_masks = resize_masks(masks, tuple(size))
|
66 |
+
frames = resize_frames(frames, tuple(size)) # T, H, W, 3
|
67 |
+
|
68 |
+
if not os.path.exists(os.path.dirname(output_path)):
|
69 |
+
os.makedirs(os.path.dirname(output_path))
|
70 |
+
writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size)
|
71 |
+
|
72 |
+
for frame, mask in zip(frames, binary_masks):
|
73 |
+
blurred_frame = apply_blur(frame, strength)
|
74 |
+
masked = cv2.bitwise_or(blurred_frame, blurred_frame, mask=mask)
|
75 |
+
processed = np.where(masked == (0, 0, 0), frame, masked)
|
76 |
+
|
77 |
+
writer.write(processed[:, :, ::-1])
|
78 |
+
|
79 |
+
writer.release()
|
80 |
+
|
81 |
+
return output_path
|
utils/interact_tools.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
from .base_segmenter import BaseSegmenter
|
4 |
+
from .painter import mask_painter, point_painter
|
5 |
+
|
6 |
+
|
7 |
+
mask_color = 3
|
8 |
+
mask_alpha = 0.7
|
9 |
+
contour_color = 1
|
10 |
+
contour_width = 5
|
11 |
+
point_color_ne = 8
|
12 |
+
point_color_ps = 50
|
13 |
+
point_alpha = 0.9
|
14 |
+
point_radius = 15
|
15 |
+
contour_color = 2
|
16 |
+
contour_width = 5
|
17 |
+
|
18 |
+
|
19 |
+
class SamControler:
|
20 |
+
def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device):
|
21 |
+
"""
|
22 |
+
initialize sam controler
|
23 |
+
"""
|
24 |
+
|
25 |
+
self.sam_controler = BaseSegmenter(sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device)
|
26 |
+
self.onnx = model_type == "vit_t"
|
27 |
+
|
28 |
+
def first_frame_click(
|
29 |
+
self,
|
30 |
+
image: np.ndarray,
|
31 |
+
points: np.ndarray,
|
32 |
+
labels: np.ndarray,
|
33 |
+
multimask=True,
|
34 |
+
mask_color=3,
|
35 |
+
):
|
36 |
+
"""
|
37 |
+
it is used in first frame in video
|
38 |
+
return: mask, logit, painted image(mask+point)
|
39 |
+
"""
|
40 |
+
# self.sam_controler.set_image(image)
|
41 |
+
neg_flag = labels[-1]
|
42 |
+
|
43 |
+
if self.onnx:
|
44 |
+
onnx_coord = np.concatenate([points, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
|
45 |
+
onnx_label = np.concatenate([labels, np.array([-1])], axis=0)[None, :].astype(np.float32)
|
46 |
+
onnx_coord = self.sam_controler.predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
|
47 |
+
prompts = {
|
48 |
+
"point_coords": onnx_coord,
|
49 |
+
"point_labels": onnx_label,
|
50 |
+
"orig_im_size": np.array(image.shape[:2], dtype=np.float32),
|
51 |
+
}
|
52 |
+
|
53 |
+
else:
|
54 |
+
prompts = {
|
55 |
+
"point_coords": points,
|
56 |
+
"point_labels": labels,
|
57 |
+
}
|
58 |
+
|
59 |
+
if neg_flag == 1:
|
60 |
+
# find positive
|
61 |
+
masks, scores, logits = self.sam_controler.predict(
|
62 |
+
prompts, "point", multimask
|
63 |
+
)
|
64 |
+
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
65 |
+
|
66 |
+
prompts["mask_input"] = np.expand_dims(logit[None, :, :], 0)
|
67 |
+
masks, scores, logits = self.sam_controler.predict(
|
68 |
+
prompts, "both", multimask
|
69 |
+
)
|
70 |
+
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
71 |
+
|
72 |
+
else:
|
73 |
+
# find neg
|
74 |
+
masks, scores, logits = self.sam_controler.predict(
|
75 |
+
prompts, "point", multimask
|
76 |
+
)
|
77 |
+
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
78 |
+
|
79 |
+
assert len(points) == len(labels)
|
80 |
+
|
81 |
+
painted_image = mask_painter(
|
82 |
+
image,
|
83 |
+
mask.astype("uint8"),
|
84 |
+
mask_color,
|
85 |
+
mask_alpha,
|
86 |
+
contour_color,
|
87 |
+
contour_width,
|
88 |
+
)
|
89 |
+
painted_image = point_painter(
|
90 |
+
painted_image,
|
91 |
+
np.squeeze(points[np.argwhere(labels > 0)], axis=1),
|
92 |
+
point_color_ne,
|
93 |
+
point_alpha,
|
94 |
+
point_radius,
|
95 |
+
contour_color,
|
96 |
+
contour_width,
|
97 |
+
)
|
98 |
+
painted_image = point_painter(
|
99 |
+
painted_image,
|
100 |
+
np.squeeze(points[np.argwhere(labels < 1)], axis=1),
|
101 |
+
point_color_ps,
|
102 |
+
point_alpha,
|
103 |
+
point_radius,
|
104 |
+
contour_color,
|
105 |
+
contour_width,
|
106 |
+
)
|
107 |
+
painted_image = Image.fromarray(painted_image)
|
108 |
+
|
109 |
+
return mask, logit, painted_image
|
utils/painter.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
|
6 |
+
def colormap(rgb=True):
|
7 |
+
color_list = np.array(
|
8 |
+
[
|
9 |
+
0.000,
|
10 |
+
0.000,
|
11 |
+
0.000,
|
12 |
+
1.000,
|
13 |
+
1.000,
|
14 |
+
1.000,
|
15 |
+
1.000,
|
16 |
+
0.498,
|
17 |
+
0.313,
|
18 |
+
0.392,
|
19 |
+
0.581,
|
20 |
+
0.929,
|
21 |
+
0.000,
|
22 |
+
0.447,
|
23 |
+
0.741,
|
24 |
+
0.850,
|
25 |
+
0.325,
|
26 |
+
0.098,
|
27 |
+
0.929,
|
28 |
+
0.694,
|
29 |
+
0.125,
|
30 |
+
0.494,
|
31 |
+
0.184,
|
32 |
+
0.556,
|
33 |
+
0.466,
|
34 |
+
0.674,
|
35 |
+
0.188,
|
36 |
+
0.301,
|
37 |
+
0.745,
|
38 |
+
0.933,
|
39 |
+
0.635,
|
40 |
+
0.078,
|
41 |
+
0.184,
|
42 |
+
0.300,
|
43 |
+
0.300,
|
44 |
+
0.300,
|
45 |
+
0.600,
|
46 |
+
0.600,
|
47 |
+
0.600,
|
48 |
+
1.000,
|
49 |
+
0.000,
|
50 |
+
0.000,
|
51 |
+
1.000,
|
52 |
+
0.500,
|
53 |
+
0.000,
|
54 |
+
0.749,
|
55 |
+
0.749,
|
56 |
+
0.000,
|
57 |
+
0.000,
|
58 |
+
1.000,
|
59 |
+
0.000,
|
60 |
+
0.000,
|
61 |
+
0.000,
|
62 |
+
1.000,
|
63 |
+
0.667,
|
64 |
+
0.000,
|
65 |
+
1.000,
|
66 |
+
0.333,
|
67 |
+
0.333,
|
68 |
+
0.000,
|
69 |
+
0.333,
|
70 |
+
0.667,
|
71 |
+
0.000,
|
72 |
+
0.333,
|
73 |
+
1.000,
|
74 |
+
0.000,
|
75 |
+
0.667,
|
76 |
+
0.333,
|
77 |
+
0.000,
|
78 |
+
0.667,
|
79 |
+
0.667,
|
80 |
+
0.000,
|
81 |
+
0.667,
|
82 |
+
1.000,
|
83 |
+
0.000,
|
84 |
+
1.000,
|
85 |
+
0.333,
|
86 |
+
0.000,
|
87 |
+
1.000,
|
88 |
+
0.667,
|
89 |
+
0.000,
|
90 |
+
1.000,
|
91 |
+
1.000,
|
92 |
+
0.000,
|
93 |
+
0.000,
|
94 |
+
0.333,
|
95 |
+
0.500,
|
96 |
+
0.000,
|
97 |
+
0.667,
|
98 |
+
0.500,
|
99 |
+
0.000,
|
100 |
+
1.000,
|
101 |
+
0.500,
|
102 |
+
0.333,
|
103 |
+
0.000,
|
104 |
+
0.500,
|
105 |
+
0.333,
|
106 |
+
0.333,
|
107 |
+
0.500,
|
108 |
+
0.333,
|
109 |
+
0.667,
|
110 |
+
0.500,
|
111 |
+
0.333,
|
112 |
+
1.000,
|
113 |
+
0.500,
|
114 |
+
0.667,
|
115 |
+
0.000,
|
116 |
+
0.500,
|
117 |
+
0.667,
|
118 |
+
0.333,
|
119 |
+
0.500,
|
120 |
+
0.667,
|
121 |
+
0.667,
|
122 |
+
0.500,
|
123 |
+
0.667,
|
124 |
+
1.000,
|
125 |
+
0.500,
|
126 |
+
1.000,
|
127 |
+
0.000,
|
128 |
+
0.500,
|
129 |
+
1.000,
|
130 |
+
0.333,
|
131 |
+
0.500,
|
132 |
+
1.000,
|
133 |
+
0.667,
|
134 |
+
0.500,
|
135 |
+
1.000,
|
136 |
+
1.000,
|
137 |
+
0.500,
|
138 |
+
0.000,
|
139 |
+
0.333,
|
140 |
+
1.000,
|
141 |
+
0.000,
|
142 |
+
0.667,
|
143 |
+
1.000,
|
144 |
+
0.000,
|
145 |
+
1.000,
|
146 |
+
1.000,
|
147 |
+
0.333,
|
148 |
+
0.000,
|
149 |
+
1.000,
|
150 |
+
0.333,
|
151 |
+
0.333,
|
152 |
+
1.000,
|
153 |
+
0.333,
|
154 |
+
0.667,
|
155 |
+
1.000,
|
156 |
+
0.333,
|
157 |
+
1.000,
|
158 |
+
1.000,
|
159 |
+
0.667,
|
160 |
+
0.000,
|
161 |
+
1.000,
|
162 |
+
0.667,
|
163 |
+
0.333,
|
164 |
+
1.000,
|
165 |
+
0.667,
|
166 |
+
0.667,
|
167 |
+
1.000,
|
168 |
+
0.667,
|
169 |
+
1.000,
|
170 |
+
1.000,
|
171 |
+
1.000,
|
172 |
+
0.000,
|
173 |
+
1.000,
|
174 |
+
1.000,
|
175 |
+
0.333,
|
176 |
+
1.000,
|
177 |
+
1.000,
|
178 |
+
0.667,
|
179 |
+
1.000,
|
180 |
+
0.167,
|
181 |
+
0.000,
|
182 |
+
0.000,
|
183 |
+
0.333,
|
184 |
+
0.000,
|
185 |
+
0.000,
|
186 |
+
0.500,
|
187 |
+
0.000,
|
188 |
+
0.000,
|
189 |
+
0.667,
|
190 |
+
0.000,
|
191 |
+
0.000,
|
192 |
+
0.833,
|
193 |
+
0.000,
|
194 |
+
0.000,
|
195 |
+
1.000,
|
196 |
+
0.000,
|
197 |
+
0.000,
|
198 |
+
0.000,
|
199 |
+
0.167,
|
200 |
+
0.000,
|
201 |
+
0.000,
|
202 |
+
0.333,
|
203 |
+
0.000,
|
204 |
+
0.000,
|
205 |
+
0.500,
|
206 |
+
0.000,
|
207 |
+
0.000,
|
208 |
+
0.667,
|
209 |
+
0.000,
|
210 |
+
0.000,
|
211 |
+
0.833,
|
212 |
+
0.000,
|
213 |
+
0.000,
|
214 |
+
1.000,
|
215 |
+
0.000,
|
216 |
+
0.000,
|
217 |
+
0.000,
|
218 |
+
0.167,
|
219 |
+
0.000,
|
220 |
+
0.000,
|
221 |
+
0.333,
|
222 |
+
0.000,
|
223 |
+
0.000,
|
224 |
+
0.500,
|
225 |
+
0.000,
|
226 |
+
0.000,
|
227 |
+
0.667,
|
228 |
+
0.000,
|
229 |
+
0.000,
|
230 |
+
0.833,
|
231 |
+
0.000,
|
232 |
+
0.000,
|
233 |
+
1.000,
|
234 |
+
0.143,
|
235 |
+
0.143,
|
236 |
+
0.143,
|
237 |
+
0.286,
|
238 |
+
0.286,
|
239 |
+
0.286,
|
240 |
+
0.429,
|
241 |
+
0.429,
|
242 |
+
0.429,
|
243 |
+
0.571,
|
244 |
+
0.571,
|
245 |
+
0.571,
|
246 |
+
0.714,
|
247 |
+
0.714,
|
248 |
+
0.714,
|
249 |
+
0.857,
|
250 |
+
0.857,
|
251 |
+
0.857,
|
252 |
+
]
|
253 |
+
).astype(np.float32)
|
254 |
+
color_list = color_list.reshape((-1, 3)) * 255
|
255 |
+
if not rgb:
|
256 |
+
color_list = color_list[:, ::-1]
|
257 |
+
return color_list
|
258 |
+
|
259 |
+
|
260 |
+
color_list = colormap()
|
261 |
+
color_list = color_list.astype("uint8").tolist()
|
262 |
+
|
263 |
+
|
264 |
+
def vis_add_mask(image, mask, color, alpha):
|
265 |
+
color = np.array(color_list[color])
|
266 |
+
mask = mask > 0.5
|
267 |
+
image[mask] = image[mask] * (1 - alpha) + color * alpha
|
268 |
+
return image.astype("uint8")
|
269 |
+
|
270 |
+
|
271 |
+
def point_painter(
|
272 |
+
input_image,
|
273 |
+
input_points,
|
274 |
+
point_color=5,
|
275 |
+
point_alpha=0.9,
|
276 |
+
point_radius=15,
|
277 |
+
contour_color=2,
|
278 |
+
contour_width=5,
|
279 |
+
):
|
280 |
+
h, w = input_image.shape[:2]
|
281 |
+
point_mask = np.zeros((h, w)).astype("uint8")
|
282 |
+
for point in input_points:
|
283 |
+
point_mask[point[1], point[0]] = 1
|
284 |
+
|
285 |
+
kernel = cv2.getStructuringElement(2, (point_radius, point_radius))
|
286 |
+
point_mask = cv2.dilate(point_mask, kernel)
|
287 |
+
|
288 |
+
contour_radius = (contour_width - 1) // 2
|
289 |
+
dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3)
|
290 |
+
dist_transform_back = cv2.distanceTransform(1 - point_mask, cv2.DIST_L2, 3)
|
291 |
+
dist_map = dist_transform_fore - dist_transform_back
|
292 |
+
# ...:::!!!:::...
|
293 |
+
contour_radius += 2
|
294 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
295 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
296 |
+
contour_mask[contour_mask > 0.5] = 1.0
|
297 |
+
|
298 |
+
# paint mask
|
299 |
+
painted_image = vis_add_mask(
|
300 |
+
input_image.copy(), point_mask, point_color, point_alpha
|
301 |
+
)
|
302 |
+
# paint contour
|
303 |
+
painted_image = vis_add_mask(
|
304 |
+
painted_image.copy(), 1 - contour_mask, contour_color, 1
|
305 |
+
)
|
306 |
+
return painted_image
|
307 |
+
|
308 |
+
|
309 |
+
def mask_painter(
|
310 |
+
input_image,
|
311 |
+
input_mask,
|
312 |
+
mask_color=5,
|
313 |
+
mask_alpha=0.7,
|
314 |
+
contour_color=1,
|
315 |
+
contour_width=3,
|
316 |
+
):
|
317 |
+
assert (
|
318 |
+
input_image.shape[:2] == input_mask.shape
|
319 |
+
), "different shape between image and mask"
|
320 |
+
# 0: background, 1: foreground
|
321 |
+
mask = np.clip(input_mask, 0, 1)
|
322 |
+
contour_radius = (contour_width - 1) // 2
|
323 |
+
|
324 |
+
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
|
325 |
+
dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
|
326 |
+
dist_map = dist_transform_fore - dist_transform_back
|
327 |
+
# ...:::!!!:::...
|
328 |
+
contour_radius += 2
|
329 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
330 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
331 |
+
contour_mask[contour_mask > 0.5] = 1.0
|
332 |
+
|
333 |
+
# paint mask
|
334 |
+
painted_image = vis_add_mask(
|
335 |
+
input_image.copy(), mask.copy(), mask_color, mask_alpha
|
336 |
+
)
|
337 |
+
# paint contour
|
338 |
+
painted_image = vis_add_mask(
|
339 |
+
painted_image.copy(), 1 - contour_mask, contour_color, 1
|
340 |
+
)
|
341 |
+
|
342 |
+
return painted_image
|
343 |
+
|
344 |
+
|
345 |
+
def background_remover(input_image, input_mask):
|
346 |
+
"""
|
347 |
+
input_image: H, W, 3, np.array
|
348 |
+
input_mask: H, W, np.array
|
349 |
+
|
350 |
+
image_wo_background: PIL.Image
|
351 |
+
"""
|
352 |
+
assert (
|
353 |
+
input_image.shape[:2] == input_mask.shape
|
354 |
+
), "different shape between image and mask"
|
355 |
+
# 0: background, 1: foreground
|
356 |
+
mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2) * 255
|
357 |
+
image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4
|
358 |
+
image_wo_background = Image.fromarray(image_wo_background).convert("RGBA")
|
359 |
+
|
360 |
+
return image_wo_background
|