George committed on
Commit
b5f33fd
1 Parent(s): 141cbe3

upl all codes

.gitignore ADDED
@@ -0,0 +1,170 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ */__pycache__/*
3
+ */*/__pycache__/*
4
+ */__init__.py
5
+ */*/__init__.py
6
+ images/*
7
+ *.py[cod]
8
+ *$py.class
9
+ input/*
10
+ output/*
11
+ images/*
12
+ .vscode/*
13
+ test.py
14
+ model/*
15
+
16
+ # C extensions
17
+ *.so
18
+
19
+ # Distribution / packaging
20
+ .Python
21
+ build/
22
+ develop-eggs/
23
+ dist/
24
+ downloads/
25
+ eggs/
26
+ .eggs/
27
+ lib/
28
+ lib64/
29
+ parts/
30
+ sdist/
31
+ var/
32
+ wheels/
33
+ share/python-wheels/
34
+ *.egg-info/
35
+ .installed.cfg
36
+ *.egg
37
+ MANIFEST
38
+
39
+ # PyInstaller
40
+ # Usually these files are written by a python script from a template
41
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
42
+ *.manifest
43
+ *.spec
44
+
45
+ # Installer logs
46
+ pip-log.txt
47
+ pip-delete-this-directory.txt
48
+
49
+ # Unit test / coverage reports
50
+ htmlcov/
51
+ .tox/
52
+ .nox/
53
+ .coverage
54
+ .coverage.*
55
+ .cache
56
+ nosetests.xml
57
+ coverage.xml
58
+ *.cover
59
+ *.py,cover
60
+ .hypothesis/
61
+ .pytest_cache/
62
+ cover/
63
+
64
+ # Translations
65
+ *.mo
66
+ *.pot
67
+
68
+ # Django stuff:
69
+ *.log
70
+ local_settings.py
71
+ db.sqlite3
72
+ db.sqlite3-journal
73
+
74
+ # Flask stuff:
75
+ instance/
76
+ .webassets-cache
77
+
78
+ # Scrapy stuff:
79
+ .scrapy
80
+
81
+ # Sphinx documentation
82
+ docs/_build/
83
+
84
+ # PyBuilder
85
+ .pybuilder/
86
+ target/
87
+
88
+ # Jupyter Notebook
89
+ .ipynb_checkpoints
90
+
91
+ # IPython
92
+ profile_default/
93
+ ipython_config.py
94
+
95
+ # pyenv
96
+ # For a library or package, you might want to ignore these files since the code is
97
+ # intended to run in multiple environments; otherwise, check them in:
98
+ # .python-version
99
+
100
+ # pipenv
101
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
103
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
104
+ # install all needed dependencies.
105
+ #Pipfile.lock
106
+
107
+ # poetry
108
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
109
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
110
+ # commonly ignored for libraries.
111
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
112
+ #poetry.lock
113
+
114
+ # pdm
115
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
116
+ #pdm.lock
117
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
118
+ # in version control.
119
+ # https://pdm.fming.dev/#use-with-ide
120
+ .pdm.toml
121
+
122
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
123
+ __pypackages__/
124
+
125
+ # Celery stuff
126
+ celerybeat-schedule
127
+ celerybeat.pid
128
+
129
+ # SageMath parsed files
130
+ *.sage.py
131
+
132
+ # Environments
133
+ .env
134
+ .venv
135
+ env/
136
+ venv/
137
+ ENV/
138
+ env.bak/
139
+ venv.bak/
140
+
141
+ # Spyder project settings
142
+ .spyderproject
143
+ .spyproject
144
+
145
+ # Rope project settings
146
+ .ropeproject
147
+
148
+ # mkdocs documentation
149
+ /site
150
+
151
+ # mypy
152
+ .mypy_cache/
153
+ .dmypy.json
154
+ dmypy.json
155
+
156
+ # Pyre type checker
157
+ .pyre/
158
+
159
+ # pytype static type analyzer
160
+ .pytype/
161
+
162
+ # Cython debug symbols
163
+ cython_debug/
164
+
165
+ # PyCharm
166
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
167
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
168
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
169
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
170
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,10 @@
1
  ---
2
  license: mit
3
  ---
4
+
5
+ ## Environment
6
+ ```
7
+ conda create -n tb --yes --file conda.txt
8
+ conda activate tb
9
+ pip install -r requirements.txt
10
+ ```
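A quick way to confirm that the resulting environment matches the pins in conda.txt (shown in the next file) is a version check. This is an optional sketch, not part of the commit; the expected values are taken from conda.txt.

```python
# Optional sanity check, assuming the "tb" environment created above is active.
import torch
import torchvision

print(torch.__version__)          # expected: 1.12.1 (per conda.txt)
print(torchvision.__version__)    # expected: 0.13.1 (per conda.txt)
print(torch.version.cuda)         # expected: 11.3 (cudatoolkit=11.3.1)
print(torch.cuda.is_available())  # True if a compatible NVIDIA driver is installed
```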
conda.txt ADDED
@@ -0,0 +1,5 @@
1
+ python=3.9
2
+ pytorch=1.12.1
3
+ torchvision=0.13.1
4
+ torchaudio=0.12.1
5
+ cudatoolkit=11.3.1
gender_age.py ADDED
@@ -0,0 +1,89 @@
1
+ import cv2
2
+ import os
+ import shutil
3
+ import numpy as np
4
+ from dataclasses import dataclass
5
+ from tqdm import tqdm
6
+ from mivolo.predictor import Predictor
7
+ from utils import *
8
+
9
+ import warnings
10
+ warnings.filterwarnings("ignore")
11
+
12
+
13
+ @dataclass
14
+ class Cfg:
15
+ detector_weights: str
16
+ checkpoint: str
17
+ device: str = "cuda"
18
+ with_persons: bool = True
19
+ disable_faces: bool = False
20
+ draw: bool = True
21
+
22
+
23
+ class ValidImgDetector:
24
+
25
+ predictor = None
26
+
27
+ def __init__(self):
28
+ detector_path = "./model/yolov8x_person_face.pt"
29
+ age_gender_path = "./model/model_imdb_cross_person_4.22_99.46.pth.tar"
30
+ predictor_cfg = Cfg(detector_path, age_gender_path)
31
+ self.predictor = Predictor(predictor_cfg)
32
+
33
+ def _detect(
34
+ self,
35
+ image: np.ndarray,
36
+ score_threshold: float,
37
+ iou_threshold: float,
38
+ mode: str,
39
+ predictor: Predictor
40
+ ) -> np.ndarray:
41
+ # input is rgb image, output must be rgb too
42
+ predictor.detector.detector_kwargs['conf'] = score_threshold
43
+ predictor.detector.detector_kwargs['iou'] = iou_threshold
44
+
45
+ if mode == "Use persons and faces":
46
+ use_persons = True
47
+ disable_faces = False
48
+ elif mode == "Use persons only":
49
+ use_persons = True
50
+ disable_faces = True
51
+ elif mode == "Use faces only":
52
+ use_persons = False
53
+ disable_faces = False
54
+
55
+ predictor.age_gender_model.meta.use_persons = use_persons
56
+ predictor.age_gender_model.meta.disable_faces = disable_faces
57
+
58
+ image = image[:, :, ::-1] # RGB -> BGR
59
+ detected_objects, _ = predictor.recognize(image)
60
+
61
+ has_child, has_female, has_male = False, False, False
62
+ if len(detected_objects.ages) > 0:
63
+ has_child = min(detected_objects.ages) < 18
64
+ has_female = 'female' in detected_objects.genders
65
+ has_male = 'male' in detected_objects.genders
66
+
67
+ return has_child, has_female, has_male
68
+
69
+ def valid_img(self, img_path):
70
+ image = cv2.imread(img_path)
71
+ has_child, has_female, has_male = self._detect(
72
+ image, 0.4, 0.7, "Use persons and faces", self.predictor)
73
+ return (not has_child) and (has_female) and (not has_male)
74
+
75
+
76
+ if __name__ == "__main__":
77
+ detector = ValidImgDetector()
78
+ create_dir('./output/valid')
79
+ create_dir('./output/invalid')
80
+
81
+ for root, _, files in os.walk('./images'):
82
+ for file in tqdm(files):
83
+ if file.endswith('.jpg'):
84
+ src_path = f"./images/{file}"
85
+ dst_path = "./output/invalid"
86
+ if detector.valid_img(src_path):
87
+ dst_path = "./output/valid"
88
+
89
+ shutil.move(src_path, dst_path)
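For reference, a minimal usage sketch of the class above. It is not part of the commit: the image path is hypothetical, and it assumes the script is run from the repository root so the ./model weight files resolve.

```python
# Minimal single-image usage sketch for ValidImgDetector (hypothetical path).
from gender_age import ValidImgDetector

detector = ValidImgDetector()  # loads ./model/yolov8x_person_face.pt and the MiVOLO checkpoint
if detector.valid_img("./images/example.jpg"):
    print("kept: no child, at least one female, no male detected")
else:
    print("filtered out")
```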
human_detect.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+ from torchvision.models.detection import fasterrcnn_resnet50_fpn
5
+
6
+
7
+ def has_person(image_path):
8
+ # Load the pretrained Faster R-CNN model
9
+ model = fasterrcnn_resnet50_fpn(pretrained=True)
10
+ model.eval()
11
+
12
+ # Load and preprocess the image
13
+ img = Image.open(image_path)
14
+ transform = transforms.Compose([transforms.ToTensor()])
15
+ input_tensor = transform(img)
16
+ input_batch = input_tensor.unsqueeze(0)
17
+
18
+ # Run inference
19
+ with torch.no_grad():
20
+ output = model(input_batch)
21
+
22
+ # Parse the outputs
23
+ labels = output[0]['labels'].numpy()
24
+ scores = output[0]['scores'].numpy()
25
+
26
+ # Check whether a person was detected (COCO label 1 is the person class)
27
+ person_detected = any(label == 1 and score >
28
+ 0.5 for label, score in zip(labels, scores))
29
+
30
+ return person_detected
31
+
32
+
33
+ if __name__ == "__main__":
34
+ image_path = './images/test.jpg'
35
+ if has_person(image_path):
36
+ print("Person detected in the image.")
37
+ else:
38
+ print("No person detected in the image.")
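One possible follow-up, sketched for reference only (not part of the commit): has_person() reloads the pretrained detector on every call, so caching the model at module level avoids that cost when scanning many images. has_person_cached is a hypothetical name.

```python
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models.detection import fasterrcnn_resnet50_fpn

# Load the pretrained detector once instead of on every call.
_model = fasterrcnn_resnet50_fpn(pretrained=True).eval()


def has_person_cached(image_path: str, threshold: float = 0.5) -> bool:
    # Force 3-channel RGB so grayscale/RGBA images also work.
    img = transforms.ToTensor()(Image.open(image_path).convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        out = _model(img)[0]
    # COCO label 1 is the person class.
    return any(
        label == 1 and score > threshold
        for label, score in zip(out["labels"].tolist(), out["scores"].tolist())
    )
```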
item2pic.py ADDED
@@ -0,0 +1,80 @@
1
+ import os
2
+ import json
3
+ import requests
4
+ from bs4 import BeautifulSoup
5
+ from selenium import webdriver
6
+ from utils import *
7
+
8
+
9
+ def download_image(url, save_path):
10
+ rand_sleep()
11
+ try:
12
+ # Send a GET request to download the image
13
+ response = requests.get(url)
14
+ response.raise_for_status()
15
+
16
+ # Make sure the directory for the save path exists
17
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
18
+
19
+ # Save the image
20
+ with open(save_path, 'wb') as file:
21
+ file.write(response.content)
22
+
23
+ print(f"Image downloaded and saved to {save_path}")
24
+ except requests.exceptions.HTTPError as errh:
25
+ print("Http Error:", errh)
26
+ except requests.exceptions.ConnectionError as errc:
27
+ print("Error Connecting:", errc)
28
+ except requests.exceptions.Timeout as errt:
29
+ print("Timeout Error:", errt)
30
+ except requests.exceptions.RequestException as err:
31
+ print("Oops, something else went wrong:", err)
32
+
33
+
34
+ def get_pics(id):
35
+ rand_sleep()
36
+ # selenium
37
+ option = webdriver.ChromeOptions()
38
+ option.add_experimental_option('excludeSwitches', ['enable-automation'])
39
+ option.add_argument("--disable-blink-features=AutomationControlled")
40
+ # option.add_argument('--headless')
41
+ browser = webdriver.Chrome(options=option)
42
+ browser.get(f'https://www.taobao.com/list/item/{id}.htm')
43
+ # browser.minimize_window()
44
+ browser.maximize_window()
45
+
46
+ skip_captcha()
47
+
48
+ # bs4
49
+ soup = BeautifulSoup(browser.page_source, 'html.parser')
50
+ srcs = set()
51
+
52
+ try:
53
+ for link in soup.find_all('img', class_='item-thumbnail'):
54
+ srcs.add('https:' + link.get('src').split('.jpg')[0] + '.jpg')
55
+
56
+ for link in soup.find_all('img', class_='property-img'):
57
+ srcs.add('https:' + link.get('src').split('.jpg')[0] + '.jpg')
58
+
59
+ for link in soup.find('div', class_='detail-content').find_all('img'):
60
+ srcs.add('https:' + link.get('src').split('.jpg')[0] + '.jpg')
61
+
62
+ except Exception as err:
63
+ print("Error: ", err)
64
+
65
+ return srcs
66
+
67
+
68
+ if __name__ == "__main__":
69
+ create_dir('./images')
70
+
71
+ with open('./output/items.jsonl', 'r', encoding='utf-8') as jsonl_file:
72
+ for line in jsonl_file:
73
+ # Parse the JSON string into a Python object
74
+ data = json.loads(line)
75
+ # Get the value of the 'id' key
76
+ id_value = data.get('id')
77
+ if id_value is not None:
78
+ pic_urls = get_pics(id_value)
79
+ for url in pic_urls:
80
+ download_image(url, f'./images/{os.path.basename(url)}')
mivolo/data/data_reader.py ADDED
@@ -0,0 +1,125 @@
1
+ import os
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass, field
4
+ from enum import Enum
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ import pandas as pd
8
+
9
+ IMAGES_EXT: Tuple = (".jpeg", ".jpg", ".png", ".webp", ".bmp", ".gif")
10
+ VIDEO_EXT: Tuple = (".mp4", ".avi", ".mov", ".mkv", ".webm")
11
+
12
+
13
+ @dataclass
14
+ class PictureInfo:
15
+ image_path: str
16
+ age: Optional[str] # age, or an age range in "start;end" format, or "-1"
17
+ gender: Optional[str] # "M" or "F" or "-1"
18
+ bbox: List[int] = field(default_factory=lambda: [-1, -1, -1, -1]) # face bbox: xyxy
19
+ person_bbox: List[int] = field(default_factory=lambda: [-1, -1, -1, -1]) # person bbox: xyxy
20
+
21
+ @property
22
+ def has_person_bbox(self) -> bool:
23
+ return any(coord != -1 for coord in self.person_bbox)
24
+
25
+ @property
26
+ def has_face_bbox(self) -> bool:
27
+ return any(coord != -1 for coord in self.bbox)
28
+
29
+ def has_gt(self, only_age: bool = False) -> bool:
30
+ if only_age:
31
+ return self.age != "-1"
32
+ else:
33
+ return not (self.age == "-1" and self.gender == "-1")
34
+
35
+ def clear_person_bbox(self):
36
+ self.person_bbox = [-1, -1, -1, -1]
37
+
38
+ def clear_face_bbox(self):
39
+ self.bbox = [-1, -1, -1, -1]
40
+
41
+
42
+ class AnnotType(Enum):
43
+ ORIGINAL = "original"
44
+ PERSONS = "persons"
45
+ NONE = "none"
46
+
47
+ @classmethod
48
+ def _missing_(cls, value):
49
+ print(f"WARN: Unknown annotation type {value}.")
50
+ return AnnotType.NONE
51
+
52
+
53
+ def get_all_files(path: str, extensions: Tuple = IMAGES_EXT):
54
+ files_all = []
55
+ for root, subFolders, files in os.walk(path):
56
+ for name in files:
57
+ # Linux quirk: skip ".directory" entries, which are regular files
58
+ if "directory" not in name and sum([ext.lower() in name.lower() for ext in extensions]) > 0:
59
+ files_all.append(os.path.join(root, name))
60
+ return files_all
61
+
62
+
63
+ class InputType(Enum):
64
+ Image = 0
65
+ Video = 1
66
+ VideoStream = 2
67
+
68
+
69
+ def get_input_type(input_path: str) -> InputType:
70
+ if os.path.isdir(input_path):
71
+ print("Input is a folder, only images will be processed")
72
+ return InputType.Image
73
+ elif os.path.isfile(input_path):
74
+ if input_path.endswith(VIDEO_EXT):
75
+ return InputType.Video
76
+ if input_path.endswith(IMAGES_EXT):
77
+ return InputType.Image
78
+ else:
79
+ raise ValueError(
80
+ f"Unknown or unsupported input file format {input_path}, \
81
+ supported video formats: {VIDEO_EXT}, \
82
+ supported image formats: {IMAGES_EXT}"
83
+ )
84
+ elif input_path.startswith("http") and not input_path.endswith(IMAGES_EXT):
85
+ return InputType.VideoStream
86
+ else:
87
+ raise ValueError(f"Unknown input {input_path}")
88
+
89
+
90
+ def read_csv_annotation_file(annotation_file: str, images_dir: str, ignore_without_gt=False):
91
+ bboxes_per_image: Dict[str, List[PictureInfo]] = defaultdict(list)
92
+
93
+ df = pd.read_csv(annotation_file, sep=",")
94
+
95
+ annot_type = AnnotType("persons") if "person_x0" in df.columns else AnnotType("original")
96
+ print(f"Reading {annotation_file} (type: {annot_type})...")
97
+
98
+ missing_images = 0
99
+ for index, row in df.iterrows():
100
+ img_path = os.path.join(images_dir, row["img_name"])
101
+ if not os.path.exists(img_path):
102
+ missing_images += 1
103
+ continue
104
+
105
+ face_x1, face_y1, face_x2, face_y2 = row["face_x0"], row["face_y0"], row["face_x1"], row["face_y1"]
106
+ age, gender = str(row["age"]), str(row["gender"])
107
+
108
+ if ignore_without_gt and (age == "-1" or gender == "-1"):
109
+ continue
110
+
111
+ if annot_type == AnnotType.PERSONS:
112
+ p_x1, p_y1, p_x2, p_y2 = row["person_x0"], row["person_y0"], row["person_x1"], row["person_y1"]
113
+ person_bbox = list(map(int, [p_x1, p_y1, p_x2, p_y2]))
114
+ else:
115
+ person_bbox = [-1, -1, -1, -1]
116
+
117
+ bbox = list(map(int, [face_x1, face_y1, face_x2, face_y2]))
118
+ pic_info = PictureInfo(img_path, age, gender, bbox, person_bbox)
119
+ assert isinstance(pic_info.person_bbox, list)
120
+
121
+ bboxes_per_image[img_path].append(pic_info)
122
+
123
+ if missing_images > 0:
124
+ print(f"WARNING: Missing images: {missing_images}/{len(df)}")
125
+ return bboxes_per_image, annot_type
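For reference, a hedged sketch of the CSV layout that read_csv_annotation_file() above expects. Column names come from the code; the file paths are hypothetical, and it assumes the mivolo package from this commit is importable.

```python
import pandas as pd

from mivolo.data.data_reader import read_csv_annotation_file

# One row per annotated face; person_x0/person_y0/person_x1/person_y1 columns
# are only required for the "persons" annotation type.
rows = [{
    "img_name": "0001.jpg",
    "face_x0": 10, "face_y0": 12, "face_x1": 80, "face_y1": 95,
    "age": "25", "gender": "F",
}]
pd.DataFrame(rows).to_csv("/tmp/annotation_example.csv", index=False)

# Images missing from images_dir are skipped and counted in the warning.
bboxes_per_image, annot_type = read_csv_annotation_file(
    "/tmp/annotation_example.csv", images_dir="/tmp/images"
)
print(annot_type)  # AnnotType.ORIGINAL, since there is no person_x0 column
```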
mivolo/data/dataset/__init__.py ADDED
@@ -0,0 +1,64 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from mivolo.model.mi_volo import MiVOLO
5
+
6
+ from .age_gender_dataset import AgeGenderDataset
7
+ from .age_gender_loader import create_loader
8
+ from .classification_dataset import AdienceDataset, FairFaceDataset
9
+
10
+ DATASET_CLASS_MAP = {
11
+ "utk": AgeGenderDataset,
12
+ "lagenda": AgeGenderDataset,
13
+ "imdb": AgeGenderDataset,
14
+ "adience": AdienceDataset,
15
+ "fairface": FairFaceDataset,
16
+ }
17
+
18
+
19
+ def build(
20
+ name: str,
21
+ images_path: str,
22
+ annotations_path: str,
23
+ split: str,
24
+ mivolo_model: MiVOLO,
25
+ workers: int,
26
+ batch_size: int,
27
+ ) -> Tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]:
28
+
29
+ dataset_class = DATASET_CLASS_MAP[name]
30
+
31
+ dataset: torch.utils.data.Dataset = dataset_class(
32
+ images_path=images_path,
33
+ annotations_path=annotations_path,
34
+ name=name,
35
+ split=split,
36
+ target_size=mivolo_model.input_size,
37
+ max_age=mivolo_model.meta.max_age,
38
+ min_age=mivolo_model.meta.min_age,
39
+ model_with_persons=mivolo_model.meta.with_persons_model,
40
+ use_persons=mivolo_model.meta.use_persons,
41
+ disable_faces=mivolo_model.meta.disable_faces,
42
+ only_age=mivolo_model.meta.only_age,
43
+ )
44
+
45
+ data_config = mivolo_model.data_config
46
+
47
+ in_chans = 3 if not mivolo_model.meta.with_persons_model else 6
48
+ input_size = (in_chans, mivolo_model.input_size, mivolo_model.input_size)
49
+
50
+ dataset_loader: torch.utils.data.DataLoader = create_loader(
51
+ dataset,
52
+ input_size=input_size,
53
+ batch_size=batch_size,
54
+ mean=data_config["mean"],
55
+ std=data_config["std"],
56
+ num_workers=workers,
57
+ crop_pct=data_config["crop_pct"],
58
+ crop_mode=data_config["crop_mode"],
59
+ pin_memory=False,
60
+ device=mivolo_model.device,
61
+ target_type=dataset.target_dtype,
62
+ )
63
+
64
+ return dataset, dataset_loader
mivolo/data/dataset/age_gender_dataset.py ADDED
@@ -0,0 +1,194 @@
1
+ import logging
2
+ from typing import Any, List, Optional, Set
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from mivolo.data.dataset.reader_age_gender import ReaderAgeGender
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+
11
+ _logger = logging.getLogger("AgeGenderDataset")
12
+
13
+
14
+ class AgeGenderDataset(torch.utils.data.Dataset):
15
+ def __init__(
16
+ self,
17
+ images_path,
18
+ annotations_path,
19
+ name=None,
20
+ split="train",
21
+ load_bytes=False,
22
+ img_mode="RGB",
23
+ transform=None,
24
+ is_training=False,
25
+ seed=1234,
26
+ target_size=224,
27
+ min_age=None,
28
+ max_age=None,
29
+ model_with_persons=False,
30
+ use_persons=False,
31
+ disable_faces=False,
32
+ only_age=False,
33
+ ):
34
+ reader = ReaderAgeGender(
35
+ images_path,
36
+ annotations_path,
37
+ split=split,
38
+ seed=seed,
39
+ target_size=target_size,
40
+ with_persons=use_persons,
41
+ disable_faces=disable_faces,
42
+ only_age=only_age,
43
+ )
44
+
45
+ self.name = name
46
+ self.model_with_persons = model_with_persons
47
+ self.reader = reader
48
+ self.load_bytes = load_bytes
49
+ self.img_mode = img_mode
50
+ self.transform = transform
51
+ self._consecutive_errors = 0
52
+ self.is_training = is_training
53
+ self.random_flip = 0.0
54
+
55
+ # Setting up classes.
56
+ # If min and max ages are passed, use them so validation gets the same preprocessing as training
57
+ self.max_age: float = None
58
+ self.min_age: float = None
59
+ self.avg_age: float = None
60
+ self.set_ages_min_max(min_age, max_age)
61
+
62
+ self.genders = ["M", "F"]
63
+ self.num_classes_gender = len(self.genders)
64
+
65
+ self.age_classes: Optional[List[str]] = self.set_age_classes()
66
+
67
+ self.num_classes_age = 1 if self.age_classes is None else len(self.age_classes)
68
+ self.num_classes: int = self.num_classes_age + self.num_classes_gender
69
+ self.target_dtype = torch.float32
70
+
71
+ def set_age_classes(self) -> Optional[List[str]]:
72
+ return None # for regression dataset
73
+
74
+ def set_ages_min_max(self, min_age: Optional[float], max_age: Optional[float]):
75
+
76
+ assert all(age is None for age in [min_age, max_age]) or all(
77
+ age is not None for age in [min_age, max_age]
78
+ ), "Both min and max age must be passed or none of them"
79
+
80
+ if max_age is not None and min_age is not None:
81
+ _logger.info(f"Received predefined min_age {min_age} and max_age {max_age}")
82
+ self.max_age = max_age
83
+ self.min_age = min_age
84
+ else:
85
+ # collect statistics from loaded dataset
86
+ all_ages_set: Set[int] = set()
87
+ for img_path, image_samples in self.reader._ann.items():
88
+ for image_sample_info in image_samples:
89
+ if image_sample_info.age == "-1":
90
+ continue
91
+ age = round(float(image_sample_info.age))
92
+ all_ages_set.add(age)
93
+
94
+ self.max_age = max(all_ages_set)
95
+ self.min_age = min(all_ages_set)
96
+
97
+ self.avg_age = (self.max_age + self.min_age) / 2.0
98
+
99
+ def _norm_age(self, age):
100
+ return (age - self.avg_age) / (self.max_age - self.min_age)
101
+
102
+ def parse_gender(self, _gender: str) -> float:
103
+ if _gender != "-1":
104
+ gender = float(0 if _gender == "M" or _gender == "0" else 1)
105
+ else:
106
+ gender = -1
107
+ return gender
108
+
109
+ def parse_target(self, _age: str, gender: str) -> List[Any]:
110
+ if _age != "-1":
111
+ age = round(float(_age))
112
+ age = self._norm_age(float(age))
113
+ else:
114
+ age = -1
115
+
116
+ target: List[float] = [age, self.parse_gender(gender)]
117
+ return target
118
+
119
+ @property
120
+ def transform(self):
121
+ return self._transform
122
+
123
+ @transform.setter
124
+ def transform(self, transform):
125
+ # Disable pretrained monkey-patched transforms
126
+ if not transform:
127
+ return
128
+
129
+ _trans = []
130
+ for trans in transform.transforms:
131
+ if "Resize" in str(trans):
132
+ continue
133
+ if "Crop" in str(trans):
134
+ continue
135
+ _trans.append(trans)
136
+ self._transform = transforms.Compose(_trans)
137
+
138
+ def apply_tranforms(self, image: Optional[np.ndarray]) -> np.ndarray:
139
+ if image is None:
140
+ return None
141
+
142
+ if self.transform is None:
143
+ return image
144
+
145
+ image = convert_to_pil(image, self.img_mode)
146
+ for trans in self.transform.transforms:
147
+ image = trans(image)
148
+ return image
149
+
150
+ def __getitem__(self, index):
151
+ # get preprocessed face and person crops (np.ndarray)
152
+ # resize + pad, for person crops: cut off other bboxes
153
+ images, target = self.reader[index]
154
+
155
+ target = self.parse_target(*target)
156
+
157
+ if self.model_with_persons:
158
+ face_image, person_image = images
159
+ person_image: np.ndarray = self.apply_tranforms(person_image)
160
+ else:
161
+ face_image = images[0]
162
+ person_image = None
163
+
164
+ face_image: np.ndarray = self.apply_tranforms(face_image)
165
+
166
+ if person_image is not None:
167
+ img = np.concatenate([face_image, person_image], axis=0)
168
+ else:
169
+ img = face_image
170
+
171
+ return img, target
172
+
173
+ def __len__(self):
174
+ return len(self.reader)
175
+
176
+ def filename(self, index, basename=False, absolute=False):
177
+ return self.reader.filename(index, basename, absolute)
178
+
179
+ def filenames(self, basename=False, absolute=False):
180
+ return self.reader.filenames(basename, absolute)
181
+
182
+
183
+ def convert_to_pil(cv_im: Optional[np.ndarray], img_mode: str = "RGB") -> "Image":
184
+ if cv_im is None:
185
+ return None
186
+
187
+ if img_mode == "RGB":
188
+ cv_im = cv2.cvtColor(cv_im, cv2.COLOR_BGR2RGB)
189
+ else:
190
+ raise Exception("Incorrect image mode has been passed!")
191
+
192
+ cv_im = np.ascontiguousarray(cv_im)
193
+ pil_image = Image.fromarray(cv_im)
194
+ return pil_image
mivolo/data/dataset/age_gender_loader.py ADDED
@@ -0,0 +1,169 @@
1
+ """
2
+ Code adapted from timm https://github.com/huggingface/pytorch-image-models
3
+
4
+ Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
+ """
6
+
7
+ import logging
8
+ from contextlib import suppress
9
+ from functools import partial
10
+ from itertools import repeat
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.utils.data
15
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16
+ from timm.data.dataset import IterableImageDataset
17
+ from timm.data.loader import PrefetchLoader, _worker_init
18
+ from timm.data.transforms_factory import create_transform
19
+
20
+ _logger = logging.getLogger(__name__)
21
+
22
+
23
+ def fast_collate(batch, target_dtype=torch.uint8):
24
+ """A fast collation function optimized for uint8 images (np array or torch) and target_dtype targets (labels)"""
25
+ assert isinstance(batch[0], tuple)
26
+ batch_size = len(batch)
27
+ if isinstance(batch[0][0], np.ndarray):
28
+ targets = torch.tensor([b[1] for b in batch], dtype=target_dtype)
29
+ assert len(targets) == batch_size
30
+ tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
31
+ for i in range(batch_size):
32
+ tensor[i] += torch.from_numpy(batch[i][0])
33
+ return tensor, targets
34
+ else:
35
+ raise ValueError(f"Incorrect batch type: {type(batch[0][0])}")
36
+
37
+
38
+ def adapt_to_chs(x, n):
39
+ if not isinstance(x, (tuple, list)):
40
+ x = tuple(repeat(x, n))
41
+ elif len(x) != n:
42
+ # doubled channels
43
+ if len(x) * 2 == n:
44
+ x = np.concatenate((x, x))
45
+ _logger.warning(f"Pretrained mean/std different shape than model (doubled channels), using concat: {x}.")
46
+ else:
47
+ x_mean = np.mean(x).item()
48
+ x = (x_mean,) * n
49
+ _logger.warning(f"Pretrained mean/std different shape than model, using avg value {x}.")
50
+ else:
51
+ assert len(x) == n, "normalization stats must match image channels"
52
+ return x
53
+
54
+
55
+ class PrefetchLoaderForMultiInput(PrefetchLoader):
56
+ def __init__(
57
+ self,
58
+ loader,
59
+ mean=IMAGENET_DEFAULT_MEAN,
60
+ std=IMAGENET_DEFAULT_STD,
61
+ channels=3,
62
+ device=torch.device("cuda"),
63
+ img_dtype=torch.float32,
64
+ ):
65
+
66
+ mean = adapt_to_chs(mean, channels)
67
+ std = adapt_to_chs(std, channels)
68
+ normalization_shape = (1, channels, 1, 1)
69
+
70
+ self.loader = loader
71
+ self.device = device
72
+ self.img_dtype = img_dtype
73
+ self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
74
+ self.std = torch.tensor([x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
75
+
76
+ self.is_cuda = torch.cuda.is_available() and device.type == "cuda"
77
+
78
+ def __iter__(self):
79
+ first = True
80
+ if self.is_cuda:
81
+ stream = torch.cuda.Stream()
82
+ stream_context = partial(torch.cuda.stream, stream=stream)
83
+ else:
84
+ stream = None
85
+ stream_context = suppress
86
+
87
+ for next_input, next_target in self.loader:
88
+
89
+ with stream_context():
90
+ next_input = next_input.to(device=self.device, non_blocking=True)
91
+ next_target = next_target.to(device=self.device, non_blocking=True)
92
+ next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
93
+
94
+ if not first:
95
+ yield input, target # noqa: F823, F821
96
+ else:
97
+ first = False
98
+
99
+ if stream is not None:
100
+ torch.cuda.current_stream().wait_stream(stream)
101
+
102
+ input = next_input
103
+ target = next_target
104
+
105
+ yield input, target
106
+
107
+
108
+ def create_loader(
109
+ dataset,
110
+ input_size,
111
+ batch_size,
112
+ mean=IMAGENET_DEFAULT_MEAN,
113
+ std=IMAGENET_DEFAULT_STD,
114
+ num_workers=1,
115
+ crop_pct=None,
116
+ crop_mode=None,
117
+ pin_memory=False,
118
+ img_dtype=torch.float32,
119
+ device=torch.device("cuda"),
120
+ persistent_workers=True,
121
+ worker_seeding="all",
122
+ target_type=torch.int64,
123
+ ):
124
+
125
+ transform = create_transform(
126
+ input_size,
127
+ is_training=False,
128
+ use_prefetcher=True,
129
+ mean=mean,
130
+ std=std,
131
+ crop_pct=crop_pct,
132
+ crop_mode=crop_mode,
133
+ )
134
+ dataset.transform = transform
135
+
136
+ if isinstance(dataset, IterableImageDataset):
137
+ # give Iterable datasets early knowledge of num_workers so that sample estimates
138
+ # are correct before worker processes are launched
139
+ dataset.set_loader_cfg(num_workers=num_workers)
140
+ raise ValueError("Incorrect dataset type: IterableImageDataset")
141
+
142
+ loader_class = torch.utils.data.DataLoader
143
+ loader_args = dict(
144
+ batch_size=batch_size,
145
+ shuffle=False,
146
+ num_workers=num_workers,
147
+ sampler=None,
148
+ collate_fn=lambda batch: fast_collate(batch, target_dtype=target_type),
149
+ pin_memory=pin_memory,
150
+ drop_last=False,
151
+ worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
152
+ persistent_workers=persistent_workers,
153
+ )
154
+ try:
155
+ loader = loader_class(dataset, **loader_args)
156
+ except TypeError:
157
+ loader_args.pop("persistent_workers") # only in Pytorch 1.7+
158
+ loader = loader_class(dataset, **loader_args)
159
+
160
+ loader = PrefetchLoaderForMultiInput(
161
+ loader,
162
+ mean=mean,
163
+ std=std,
164
+ channels=input_size[0],
165
+ device=device,
166
+ img_dtype=img_dtype,
167
+ )
168
+
169
+ return loader
mivolo/data/dataset/classification_dataset.py ADDED
@@ -0,0 +1,48 @@
1
+ from typing import Any, List, Optional
2
+
3
+ import torch
4
+
5
+ from .age_gender_dataset import AgeGenderDataset
6
+
7
+
8
+ class ClassificationDataset(AgeGenderDataset):
9
+ def __init__(self, *args, **kwargs):
10
+ super().__init__(*args, **kwargs)
11
+
12
+ self.target_dtype = torch.int32
13
+
14
+ def set_age_classes(self) -> Optional[List[str]]:
15
+ raise NotImplementedError
16
+
17
+ def parse_target(self, age: str, gender: str) -> List[Any]:
18
+ assert self.age_classes is not None
19
+ if age != "-1":
20
+ assert age in self.age_classes, f"Unknown category in {self.name} dataset: {age}"
21
+ age_ind = self.age_classes.index(age)
22
+ else:
23
+ age_ind = -1
24
+
25
+ target: List[int] = [age_ind, int(self.parse_gender(gender))]
26
+ return target
27
+
28
+
29
+ class FairFaceDataset(ClassificationDataset):
30
+ def set_age_classes(self) -> Optional[List[str]]:
31
+ age_classes = ["0;2", "3;9", "10;19", "20;29", "30;39", "40;49", "50;59", "60;69", "70;120"]
32
+ # a[i-1] <= v < a[i] => age_classes[i-1]
33
+ self._intervals = torch.tensor([0, 3, 10, 20, 30, 40, 50, 60, 70])
34
+
35
+ return age_classes
36
+
37
+
38
+ class AdienceDataset(ClassificationDataset):
39
+ def __init__(self, *args, **kwargs):
40
+ super().__init__(*args, **kwargs)
41
+
42
+ self.target_dtype = torch.int32
43
+
44
+ def set_age_classes(self) -> Optional[List[str]]:
45
+ age_classes = ["0;2", "4;6", "8;12", "15;20", "25;32", "38;43", "48;53", "60;100"]
46
+ # a[i-1] <= v < a[i] => age_classes[i-1]
47
+ self._intervals = torch.tensor([0, 4, 7, 14, 24, 36, 46, 57])
48
+ return age_classes
mivolo/data/dataset/reader_age_gender.py ADDED
@@ -0,0 +1,490 @@
1
+ import logging
2
+ import os
3
+ from functools import partial
4
+ from multiprocessing.pool import ThreadPool
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ import cv2
8
+ import numpy as np
9
+ from mivolo.data.data_reader import AnnotType, PictureInfo, get_all_files, read_csv_annotation_file
10
+ from mivolo.data.misc import IOU, class_letterbox, cropout_black_parts
11
+ from timm.data.readers.reader import Reader
12
+ from tqdm import tqdm
13
+
14
+ CROP_ROUND_TOL = 0.3
15
+ MIN_PERSON_SIZE = 100
16
+ MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4
17
+
18
+ _logger = logging.getLogger("ReaderAgeGender")
19
+
20
+
21
+ class ReaderAgeGender(Reader):
22
+ """
23
+ Reader for the (almost) original cleaned IMDB-WIKI dataset.
24
+ Two changes:
25
+ 1. Your annotations must be in the ./annotation subdir of the dataset root
26
+ 2. Images must be in the images subdir
27
+
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ images_path,
33
+ annotations_path,
34
+ split="validation",
35
+ target_size=224,
36
+ min_size=5,
37
+ seed=1234,
38
+ with_persons=False,
39
+ min_person_size=MIN_PERSON_SIZE,
40
+ disable_faces=False,
41
+ only_age=False,
42
+ min_person_aftercut_ratio=MIN_PERSON_CROP_AFTERCUT_RATIO,
43
+ crop_round_tol=CROP_ROUND_TOL,
44
+ ):
45
+ super().__init__()
46
+
47
+ self.with_persons = with_persons
48
+ self.disable_faces = disable_faces
49
+ self.only_age = only_age
50
+
51
+ # can only be black for now, even though that is not ideal for the later normalization
52
+ self.crop_out_color = (0, 0, 0)
53
+
54
+ self.empty_crop = np.ones((target_size, target_size, 3)) * self.crop_out_color
55
+ self.empty_crop = self.empty_crop.astype(np.uint8)
56
+
57
+ self.min_person_size = min_person_size
58
+ self.min_person_aftercut_ratio = min_person_aftercut_ratio
59
+ self.crop_round_tol = crop_round_tol
60
+
61
+ self.split = split
62
+ self.min_size = min_size
63
+ self.seed = seed
64
+ self.target_size = target_size
65
+
66
+ # Reading annotations. Can be multiple files if annotations_path is a dir
67
+ self._ann: Dict[str, List[PictureInfo]] = {} # list of samples for each image
68
+ self._associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
69
+ self._faces_list: List[Tuple[str, int]] = [] # samples from this list will be loaded in __getitem__
70
+
71
+ self._read_annotations(images_path, annotations_path)
72
+ _logger.info(f"Dataset length: {len(self._faces_list)} crops")
73
+
74
+ def __getitem__(self, index):
75
+ return self._read_img_and_label(index)
76
+
77
+ def __len__(self):
78
+ return len(self._faces_list)
79
+
80
+ def _filename(self, index, basename=False, absolute=False):
81
+ img_p = self._faces_list[index][0]
82
+ return os.path.basename(img_p) if basename else img_p
83
+
84
+ def _read_annotations(self, images_path, csvs_path):
85
+ self._ann = {}
86
+ self._faces_list = []
87
+ self._associated_objects = {}
88
+
89
+ csvs = get_all_files(csvs_path, [".csv"])
90
+ csvs = [c for c in csvs if self.split in os.path.basename(c)]
91
+
92
+ # load annotations per image
93
+ for csv in csvs:
94
+ db, ann_type = read_csv_annotation_file(csv, images_path)
95
+ if self.with_persons and ann_type != AnnotType.PERSONS:
96
+ raise ValueError(
97
+ f"Annotation type in file {csv} contains no persons, "
98
+ f"but annotations with persons are requested."
99
+ )
100
+ self._ann.update(db)
101
+
102
+ if len(self._ann) == 0:
103
+ raise ValueError("Annotations are empty!")
104
+
105
+ self._ann, self._associated_objects = self.prepare_annotations()
106
+ images_list = list(self._ann.keys())
107
+
108
+ for img_path in images_list:
109
+ for index, image_sample_info in enumerate(self._ann[img_path]):
110
+ assert image_sample_info.has_gt(
111
+ self.only_age
112
+ ), "Annotations must be checked with self.prepare_annotations() func"
113
+ self._faces_list.append((img_path, index))
114
+
115
+ def _read_img_and_label(self, index):
116
+ if not isinstance(index, int):
117
+ raise TypeError("ReaderAgeGender expected index to be integer")
118
+
119
+ img_p, face_index = self._faces_list[index]
120
+ ann: PictureInfo = self._ann[img_p][face_index]
121
+ img = cv2.imread(img_p)
122
+
123
+ face_empty = True
124
+ if ann.has_face_bbox and not (self.with_persons and self.disable_faces):
125
+ face_crop, face_empty = self._get_crop(ann.bbox, img)
126
+
127
+ if not self.with_persons and face_empty:
128
+ # model without persons
129
+ raise ValueError("Annotations must be checked with self.prepare_annotations() func")
130
+
131
+ if face_empty:
132
+ face_crop = self.empty_crop
133
+
134
+ person_empty = True
135
+ if self.with_persons or self.disable_faces:
136
+ if ann.has_person_bbox:
137
+ # cut off all associated objects from person crop
138
+ objects = self._associated_objects[img_p][face_index]
139
+ person_crop, person_empty = self._get_crop(
140
+ ann.person_bbox,
141
+ img,
142
+ crop_out_color=self.crop_out_color,
143
+ asced_objects=objects,
144
+ )
145
+
146
+ if face_empty and person_empty:
147
+ raise ValueError("Annotations must be checked with self.prepare_annotations() func")
148
+
149
+ if person_empty:
150
+ person_crop = self.empty_crop
151
+
152
+ return (face_crop, person_crop), [ann.age, ann.gender]
153
+
154
+ def _get_crop(
155
+ self,
156
+ bbox,
157
+ img,
158
+ asced_objects=None,
159
+ crop_out_color=(0, 0, 0),
160
+ ) -> Tuple[np.ndarray, bool]:
161
+
162
+ empty_bbox = False
163
+
164
+ xmin, ymin, xmax, ymax = bbox
165
+ assert not (
166
+ ymax - ymin < self.min_size or xmax - xmin < self.min_size
167
+ ), "Annotations must be checked with self.prepare_annotations() func"
168
+
169
+ crop = img[ymin:ymax, xmin:xmax]
170
+
171
+ if asced_objects:
172
+ # cut off other objects for person crop
173
+ crop, empty_bbox = _cropout_asced_objs(
174
+ asced_objects,
175
+ bbox,
176
+ crop.copy(),
177
+ crop_out_color=crop_out_color,
178
+ min_person_size=self.min_person_size,
179
+ crop_round_tol=self.crop_round_tol,
180
+ min_person_aftercut_ratio=self.min_person_aftercut_ratio,
181
+ )
182
+ if empty_bbox:
183
+ crop = self.empty_crop
184
+
185
+ crop = class_letterbox(crop, new_shape=(self.target_size, self.target_size), color=crop_out_color)
186
+ return crop, empty_bbox
187
+
188
+ def prepare_annotations(self):
189
+
190
+ good_anns: Dict[str, List[PictureInfo]] = {}
191
+ all_associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
192
+
193
+ if not self.with_persons:
194
+ # remove all persons
195
+ for img_path, bboxes in self._ann.items():
196
+ for sample in bboxes:
197
+ sample.clear_person_bbox()
198
+
199
+ # check dataset and collect associated_objects
200
+ verify_images_func = partial(
201
+ verify_images,
202
+ min_size=self.min_size,
203
+ min_person_size=self.min_person_size,
204
+ with_persons=self.with_persons,
205
+ disable_faces=self.disable_faces,
206
+ crop_round_tol=self.crop_round_tol,
207
+ min_person_aftercut_ratio=self.min_person_aftercut_ratio,
208
+ only_age=self.only_age,
209
+ )
210
+ num_threads = min(8, os.cpu_count())
211
+
212
+ all_msgs = []
213
+ broken = 0
214
+ skipped = 0
215
+ all_skipped_crops = 0
216
+ desc = "Check annotations..."
217
+ with ThreadPool(num_threads) as pool:
218
+ pbar = tqdm(
219
+ pool.imap_unordered(verify_images_func, list(self._ann.items())),
220
+ desc=desc,
221
+ total=len(self._ann),
222
+ )
223
+
224
+ for (img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops) in pbar:
225
+ broken += 1 if is_corrupted else 0
226
+ all_msgs.extend(msgs)
227
+ all_skipped_crops += skipped_crops
228
+ skipped += 1 if is_empty_annotations else 0
229
+ if img_info is not None:
230
+ img_path, img_samples = img_info
231
+ good_anns[img_path] = img_samples
232
+ all_associated_objects.update({img_path: associated_objects})
233
+
234
+ pbar.desc = (
235
+ f"{desc} {skipped} images skipped ({all_skipped_crops} crops are incorrect); "
236
+ f"{broken} images corrupted"
237
+ )
238
+
239
+ pbar.close()
240
+
241
+ for msg in all_msgs:
242
+ print(msg)
243
+ print(f"\nLeft images: {len(good_anns)}")
244
+
245
+ return good_anns, all_associated_objects
246
+
247
+
248
+ def verify_images(
249
+ img_info,
250
+ min_size: int,
251
+ min_person_size: int,
252
+ with_persons: bool,
253
+ disable_faces: bool,
254
+ crop_round_tol: float,
255
+ min_person_aftercut_ratio: float,
256
+ only_age: bool,
257
+ ):
258
+ # If crop is too small, if image can not be read or if image does not exist
259
+ # then filter out this sample
260
+
261
+ disable_faces = disable_faces and with_persons
262
+ kwargs = dict(
263
+ min_person_size=min_person_size,
264
+ disable_faces=disable_faces,
265
+ with_persons=with_persons,
266
+ crop_round_tol=crop_round_tol,
267
+ min_person_aftercut_ratio=min_person_aftercut_ratio,
268
+ only_age=only_age,
269
+ )
270
+
271
+ def bbox_correct(bbox, min_size, im_h, im_w) -> Tuple[bool, List[int]]:
272
+ ymin, ymax, xmin, xmax = _correct_bbox(bbox, im_h, im_w)
273
+ crop_h, crop_w = ymax - ymin, xmax - xmin
274
+ if crop_h < min_size or crop_w < min_size:
275
+ return False, [-1, -1, -1, -1]
276
+ bbox = [xmin, ymin, xmax, ymax]
277
+ return True, bbox
278
+
279
+ msgs = []
280
+ skipped_crops = 0
281
+ is_corrupted = False
282
+ is_empty_annotations = False
283
+
284
+ img_path: str = img_info[0]
285
+ img_samples: List[PictureInfo] = img_info[1]
286
+ try:
287
+ im_cv = cv2.imread(img_path)
288
+ im_h, im_w = im_cv.shape[:2]
289
+ except Exception:
290
+ msgs.append(f"Can not load image {img_path}")
291
+ is_corrupted = True
292
+ return None, {}, msgs, is_corrupted, is_empty_annotations, skipped_crops
293
+
294
+ out_samples: List[PictureInfo] = []
295
+ for sample in img_samples:
296
+ # correct face bbox
297
+ if sample.has_face_bbox:
298
+ is_correct, sample.bbox = bbox_correct(sample.bbox, min_size, im_h, im_w)
299
+ if not is_correct and sample.has_gt(only_age):
300
+ msgs.append("Small face. Passing..")
301
+ skipped_crops += 1
302
+
303
+ # correct person bbox
304
+ if sample.has_person_bbox:
305
+ is_correct, sample.person_bbox = bbox_correct(
306
+ sample.person_bbox, max(min_person_size, min_size), im_h, im_w
307
+ )
308
+ if not is_correct and sample.has_gt(only_age):
309
+ msgs.append(f"Small person {img_path}. Passing..")
310
+ skipped_crops += 1
311
+
312
+ if sample.has_face_bbox or sample.has_person_bbox:
313
+ out_samples.append(sample)
314
+ elif sample.has_gt(only_age):
315
+ msgs.append("Sample has no face and no body. Passing..")
316
+ skipped_crops += 1
317
+
318
+ # sort so that samples with undefined age and gender come last
319
+ out_samples = sorted(out_samples, key=lambda sample: 1 if not sample.has_gt(only_age) else 0)
320
+
321
+ # for each person, find other face and person bboxes that intersect with it
322
+ associated_objects: Dict[int, List[List[int]]] = find_associated_objects(out_samples, only_age=only_age)
323
+
324
+ out_samples, associated_objects, skipped_crops = filter_bad_samples(
325
+ out_samples, associated_objects, im_cv, msgs, skipped_crops, **kwargs
326
+ )
327
+
328
+ out_img_info: Optional[Tuple[str, List]] = (img_path, out_samples)
329
+ if len(out_samples) == 0:
330
+ out_img_info = None
331
+ is_empty_annotations = True
332
+
333
+ return out_img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops
334
+
335
+
336
+ def filter_bad_samples(
337
+ out_samples: List[PictureInfo],
338
+ associated_objects: dict,
339
+ im_cv: np.ndarray,
340
+ msgs: List[str],
341
+ skipped_crops: int,
342
+ **kwargs,
343
+ ):
344
+ with_persons, disable_faces, min_person_size, crop_round_tol, min_person_aftercut_ratio, only_age = (
345
+ kwargs["with_persons"],
346
+ kwargs["disable_faces"],
347
+ kwargs["min_person_size"],
348
+ kwargs["crop_round_tol"],
349
+ kwargs["min_person_aftercut_ratio"],
350
+ kwargs["only_age"],
351
+ )
352
+
353
+ # keep only samples with annotations
354
+ inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_gt(only_age)]
355
+ out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
356
+
357
+ if kwargs["disable_faces"]:
358
+ # clear all faces
359
+ for ind, sample in enumerate(out_samples):
360
+ sample.clear_face_bbox()
361
+
362
+ # keep only samples with person_bbox
363
+ inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_person_bbox]
364
+ out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
365
+
366
+ if with_persons or disable_faces:
367
+ # check that the preprocessing func
368
+ # _cropout_asced_objs() returns a non-empty person_image for each output sample
369
+
370
+ inds = []
371
+ for ind, sample in enumerate(out_samples):
372
+ person_empty = True
373
+ if sample.has_person_bbox:
374
+ xmin, ymin, xmax, ymax = sample.person_bbox
375
+ crop = im_cv[ymin:ymax, xmin:xmax]
376
+ # cut off all associated objects from person crop
377
+ _, person_empty = _cropout_asced_objs(
378
+ associated_objects[ind],
379
+ sample.person_bbox,
380
+ crop.copy(),
381
+ min_person_size=min_person_size,
382
+ crop_round_tol=crop_round_tol,
383
+ min_person_aftercut_ratio=min_person_aftercut_ratio,
384
+ )
385
+
386
+ if person_empty and not sample.has_face_bbox:
387
+ msgs.append("Small person after preprocessing. Passing..")
388
+ skipped_crops += 1
389
+ else:
390
+ inds.append(ind)
391
+ out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
392
+
393
+ assert len(associated_objects) == len(out_samples)
394
+ return out_samples, associated_objects, skipped_crops
395
+
396
+
397
+ def _filter_by_ind(out_samples, associated_objects, inds):
398
+ _associated_objects = {}
399
+ _out_samples = []
400
+ for ind, sample in enumerate(out_samples):
401
+ if ind in inds:
402
+ _associated_objects[len(_out_samples)] = associated_objects[ind]
403
+ _out_samples.append(sample)
404
+
405
+ return _out_samples, _associated_objects
406
+
407
+
408
+ def find_associated_objects(
409
+ image_samples: List[PictureInfo], iou_thresh=0.0001, only_age=False
410
+ ) -> Dict[int, List[List[int]]]:
411
+ """
412
+ For each person (which has gt age and gt gender) find other face and person bboxes that intersect with it
413
+ """
414
+ associated_objects: Dict[int, List[List[int]]] = {}
415
+
416
+ for iindex, image_sample_info in enumerate(image_samples):
417
+ # add own face
418
+ associated_objects[iindex] = [image_sample_info.bbox] if image_sample_info.has_face_bbox else []
419
+
420
+ if not image_sample_info.has_person_bbox or not image_sample_info.has_gt(only_age):
421
+ # if the sample has no gt, it will not be used
422
+ continue
423
+
424
+ iperson_box = image_sample_info.person_bbox
425
+ for jindex, other_image_sample in enumerate(image_samples):
426
+ if iindex == jindex:
427
+ continue
428
+ if other_image_sample.has_face_bbox:
429
+ jface_bbox = other_image_sample.bbox
430
+ iou = _get_iou(jface_bbox, iperson_box)
431
+ if iou >= iou_thresh:
432
+ associated_objects[iindex].append(jface_bbox)
433
+ if other_image_sample.has_person_bbox:
434
+ jperson_bbox = other_image_sample.person_bbox
435
+ iou = _get_iou(jperson_bbox, iperson_box)
436
+ if iou >= iou_thresh:
437
+ associated_objects[iindex].append(jperson_bbox)
438
+
439
+ return associated_objects
440
+
441
+
442
+ def _cropout_asced_objs(
443
+ asced_objects,
444
+ person_bbox,
445
+ crop,
446
+ min_person_size,
447
+ crop_round_tol,
448
+ min_person_aftercut_ratio,
449
+ crop_out_color=(0, 0, 0),
450
+ ):
451
+ empty = False
452
+ xmin, ymin, xmax, ymax = person_bbox
453
+
454
+ for a_obj in asced_objects:
455
+ aobj_xmin, aobj_ymin, aobj_xmax, aobj_ymax = a_obj
456
+
457
+ aobj_ymin = int(max(aobj_ymin - ymin, 0))
458
+ aobj_xmin = int(max(aobj_xmin - xmin, 0))
459
+ aobj_ymax = int(min(aobj_ymax - ymin, ymax - ymin))
460
+ aobj_xmax = int(min(aobj_xmax - xmin, xmax - xmin))
461
+
462
+ crop[aobj_ymin:aobj_ymax, aobj_xmin:aobj_xmax] = crop_out_color
463
+
464
+ crop, cropped_ratio = cropout_black_parts(crop, crop_round_tol)
465
+ if (
466
+ crop.shape[0] < min_person_size or crop.shape[1] < min_person_size
467
+ ) or cropped_ratio < min_person_aftercut_ratio:
468
+ crop = None
469
+ empty = True
470
+
471
+ return crop, empty
472
+
473
+
474
+ def _correct_bbox(bbox, h, w):
475
+ xmin, ymin, xmax, ymax = bbox
476
+ ymin = min(max(ymin, 0), h)
477
+ ymax = min(max(ymax, 0), h)
478
+ xmin = min(max(xmin, 0), w)
479
+ xmax = min(max(xmax, 0), w)
480
+ return ymin, ymax, xmin, xmax
481
+
482
+
483
+ def _get_iou(bbox1, bbox2):
484
+ xmin1, ymin1, xmax1, ymax1 = bbox1
485
+ xmin2, ymin2, xmax2, ymax2 = bbox2
486
+ iou = IOU(
487
+ [ymin1, xmin1, ymax1, xmax1],
488
+ [ymin2, xmin2, ymax2, xmax2],
489
+ )
490
+ return iou
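For intuition, a minimal self-contained check of the IoU arithmetic used by _get_iou above, with hand-picked (xmin, ymin, xmax, ymax) boxes and plain Python (no repo imports):

# two overlapping 10x10 boxes in (xmin, ymin, xmax, ymax) format
xmin1, ymin1, xmax1, ymax1 = 0, 0, 10, 10
xmin2, ymin2, xmax2, ymax2 = 5, 5, 15, 15
inter_w = min(xmax1, xmax2) - max(xmin1, xmin2)   # 5
inter_h = min(ymax1, ymax2) - max(ymin1, ymin2)   # 5
inter = inter_w * inter_h                         # 25
union = 10 * 10 + 10 * 10 - inter                 # 175
print(inter / union)                              # 0.142857... (~1/7)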
mivolo/data/misc.py ADDED
@@ -0,0 +1,264 @@
1
+ import argparse
2
+ import ast
3
+ import re
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ import torchvision.transforms.functional as F
10
+ from scipy.optimize import linear_sum_assignment
11
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
12
+
13
+ CROP_ROUND_RATE = 0.1
14
+ MIN_PERSON_CROP_NONZERO = 0.5
15
+
16
+
17
+ def aggregate_votes_winsorized(ages, max_age_dist=6):
18
+ # Replace any annotation that is more than a max_age_dist away from the median
19
+ # with median + max_age_dist if higher or median - max_age_dist if below
20
+ median = np.median(ages)
21
+ ages = np.clip(ages, median - max_age_dist, median + max_age_dist)
22
+ return np.mean(ages)
23
+
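A worked example of the winsorized vote aggregation above, with made-up annotator votes: the outlier 50 is clipped to median + 6 before averaging.

votes = [20, 22, 50]                      # median is 22
print(aggregate_votes_winsorized(votes))  # mean of [20, 22, 28] -> 23.33...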
24
+
25
+ def cropout_black_parts(img, tol=0.3):
26
+ # Create a binary mask of zero pixels
27
+ zero_pixels_mask = np.all(img == 0, axis=2)
28
+ # Calculate the threshold for zero pixels in rows and columns
29
+ threshold = img.shape[0] - img.shape[0] * tol
30
+ # Calculate row sums and column sums of zero pixels mask
31
+ row_sums = np.sum(zero_pixels_mask, axis=1)
32
+ col_sums = np.sum(zero_pixels_mask, axis=0)
33
+ # Find the first and last rows with zero pixel sums above the threshold
34
+ start_row = np.argmin(row_sums > threshold)
35
+ end_row = img.shape[0] - np.argmin(row_sums[::-1] > threshold)
36
+ # Find the first and last columns with zero pixel sums above the threshold
37
+ start_col = np.argmin(col_sums > threshold)
38
+ end_col = img.shape[1] - np.argmin(col_sums[::-1] > threshold)
39
+ # Crop the image
40
+ cropped_img = img[start_row:end_row, start_col:end_col, :]
41
+ area = cropped_img.shape[0] * cropped_img.shape[1]
42
+ area_orig = img.shape[0] * img.shape[1]
43
+ return cropped_img, area / area_orig
44
+
45
+
46
+ def natural_key(string_):
47
+ """See http://www.codinghorror.com/blog/archives/001018.html"""
48
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
49
+
50
+
51
+ def add_bool_arg(parser, name, default=False, help=""):
52
+ dest_name = name.replace("-", "_")
53
+ group = parser.add_mutually_exclusive_group(required=False)
54
+ group.add_argument("--" + name, dest=dest_name, action="store_true", help=help)
55
+ group.add_argument("--no-" + name, dest=dest_name, action="store_false", help=help)
56
+ parser.set_defaults(**{dest_name: default})
57
+
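A usage sketch for add_bool_arg: each call registers a paired --name / --no-name switch (the flag name below is illustrative).

parser = argparse.ArgumentParser()
add_bool_arg(parser, "with-persons", default=False, help="use person crops")
print(parser.parse_args(["--with-persons"]).with_persons)   # True
print(parser.parse_args([]).with_persons)                   # False (the default)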
58
+
59
+ def cumulative_score(pred_ages, gt_ages, L, tol=1e-6):
60
+ n = pred_ages.shape[0]
61
+ num_correct = torch.sum(torch.abs(pred_ages - gt_ages) <= L + tol)
62
+ cs_score = num_correct / n
63
+ return cs_score
64
+
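A quick numeric check of cumulative_score with made-up ages: two of the three predictions fall within L=3 years of the ground truth.

pred = torch.tensor([20.0, 30.0, 40.0])
gt = torch.tensor([22.0, 35.0, 41.0])
print(cumulative_score(pred, gt, L=3))   # tensor(0.6667)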
65
+
66
+ def cumulative_error(pred_ages, gt_ages, L, tol=1e-6):
67
+ n = pred_ages.shape[0]
68
+ num_correct = torch.sum(torch.abs(pred_ages - gt_ages) >= L + tol)
69
+ cs_score = num_correct / n
70
+ return cs_score
71
+
72
+
73
+ class ParseKwargs(argparse.Action):
74
+ def __call__(self, parser, namespace, values, option_string=None):
75
+ kw = {}
76
+ for value in values:
77
+ key, value = value.split("=")
78
+ try:
79
+ kw[key] = ast.literal_eval(value)
80
+ except ValueError:
81
+ kw[key] = str(value) # fallback to string (avoid need to escape on command line)
82
+ setattr(namespace, self.dest, kw)
83
+
84
+
85
+ def box_iou(box1, box2, over_second=False):
86
+ """
87
+ Return intersection-over-union (Jaccard index) of boxes.
88
+ If over_second == True, return mean(intersection-over-union, (inter / area2))
89
+
90
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
91
+
92
+ Arguments:
93
+ box1 (Tensor[N, 4])
94
+ box2 (Tensor[M, 4])
95
+ Returns:
96
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
97
+ IoU values for every element in boxes1 and boxes2
98
+ """
99
+
100
+ def box_area(box):
101
+ # box = 4xn
102
+ return (box[2] - box[0]) * (box[3] - box[1])
103
+
104
+ area1 = box_area(box1.T)
105
+ area2 = box_area(box2.T)
106
+
107
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
108
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
109
+
110
+ iou = inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
111
+ if over_second:
112
+ return (inter / area2 + iou) / 2 # mean(inter / area2, iou)
113
+ else:
114
+ return iou
115
+
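A small sanity check for box_iou with hand-picked xyxy boxes: the first pair overlaps by a 5x5 patch, the second pair not at all.

b1 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
b2 = torch.tensor([[5.0, 5.0, 15.0, 15.0], [20.0, 20.0, 30.0, 30.0]])
print(box_iou(b1, b2))                     # tensor([[0.1429, 0.0000]])
print(box_iou(b1, b2, over_second=True))   # mean of IoU and inter/area2 per pair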
116
+
117
+ def split_batch(bs: int, dev: int) -> Tuple[int, int]:
118
+ full_bs = (bs // dev) * dev
119
+ part_bs = bs - full_bs
120
+ return full_bs, part_bs
121
+
122
+
123
+ def assign_faces(
124
+ persons_bboxes: List[torch.tensor], faces_bboxes: List[torch.tensor], iou_thresh: float = 0.0001
125
+ ) -> Tuple[List[Optional[int]], List[int]]:
126
+ """
127
+ Assign person to each face if it is possible.
128
+ Return:
129
+ - assigned_faces List[Optional[int]]: mapping of face_ind to person_ind
130
+ ( assigned_faces[face_ind] = person_ind ). person_ind can be None
131
+ - unassigned_persons_inds List[int]: persons indexes without any assigned face
132
+ """
133
+
134
+ assigned_faces: List[Optional[int]] = [None for _ in range(len(faces_bboxes))]
135
+ unassigned_persons_inds: List[int] = [p_ind for p_ind in range(len(persons_bboxes))]
136
+
137
+ if len(persons_bboxes) == 0 or len(faces_bboxes) == 0:
138
+ return assigned_faces, unassigned_persons_inds
139
+
140
+ cost_matrix = box_iou(torch.stack(persons_bboxes), torch.stack(faces_bboxes), over_second=True).cpu().numpy()
141
+ persons_indexes, face_indexes = [], []
142
+
143
+ if len(cost_matrix) > 0:
144
+ persons_indexes, face_indexes = linear_sum_assignment(cost_matrix, maximize=True)
145
+
146
+ matched_persons = set()
147
+ for person_idx, face_idx in zip(persons_indexes, face_indexes):
148
+ ciou = cost_matrix[person_idx][face_idx]
149
+ if ciou > iou_thresh:
150
+ if person_idx in matched_persons:
151
+ # A person cannot be assigned twice; in practice this should not happen
152
+ continue
153
+ assigned_faces[face_idx] = person_idx
154
+ matched_persons.add(person_idx)
155
+
156
+ unassigned_persons_inds = [p_ind for p_ind in range(len(persons_bboxes)) if p_ind not in matched_persons]
157
+
158
+ return assigned_faces, unassigned_persons_inds
159
+
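A usage sketch for assign_faces with two hypothetical person boxes and one face box: the face lands inside the first person, so the second person stays unassigned.

persons = [torch.tensor([0.0, 0.0, 100.0, 200.0]), torch.tensor([300.0, 0.0, 400.0, 200.0])]
faces = [torch.tensor([20.0, 10.0, 60.0, 60.0])]
assigned_faces, unassigned = assign_faces(persons, faces)
print(assigned_faces, unassigned)   # face 0 -> person 0; person 1 left unassigned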
160
+
161
+ def class_letterbox(im, new_shape=(640, 640), color=(0, 0, 0), scaleup=True):
162
+ # Resize and pad image while meeting stride-multiple constraints
163
+ shape = im.shape[:2] # current shape [height, width]
164
+ if isinstance(new_shape, int):
165
+ new_shape = (new_shape, new_shape)
166
+
167
+ if im.shape[0] == new_shape[0] and im.shape[1] == new_shape[1]:
168
+ return im
169
+
170
+ # Scale ratio (new / old)
171
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
172
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
173
+ r = min(r, 1.0)
174
+
175
+ # Compute padding
176
+ # ratio = r, r # width, height ratios
177
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
178
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
179
+
180
+ dw /= 2 # divide padding into 2 sides
181
+ dh /= 2
182
+
183
+ if shape[::-1] != new_unpad: # resize
184
+ im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
185
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
186
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
187
+ im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
188
+ return im
189
+
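A sketch of class_letterbox on a dummy 480x640 image (sizes are illustrative): the image is scaled to 168x224 and padded with 28 black pixels top and bottom.

dummy = np.zeros((480, 640, 3), dtype=np.uint8)
print(class_letterbox(dummy, new_shape=(224, 224)).shape)   # (224, 224, 3)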
190
+
191
+ def prepare_classification_images(
192
+ img_list: List[Optional[np.ndarray]],
193
+ target_size: int = 224,
194
+ mean=IMAGENET_DEFAULT_MEAN,
195
+ std=IMAGENET_DEFAULT_STD,
196
+ device=None,
197
+ ) -> torch.tensor:
198
+
199
+ prepared_images: List[torch.tensor] = []
200
+
201
+ for img in img_list:
202
+ if img is None:
203
+ img = torch.zeros((3, target_size, target_size), dtype=torch.float32)
204
+ img = F.normalize(img, mean=mean, std=std)
205
+ img = img.unsqueeze(0)
206
+ prepared_images.append(img)
207
+ continue
208
+ img = class_letterbox(img, new_shape=(target_size, target_size))
209
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
210
+
211
+ img = img / 255.0
212
+ img = (img - mean) / std
213
+ img = img.astype(dtype=np.float32)
214
+
215
+ img = img.transpose((2, 0, 1))
216
+ img = np.ascontiguousarray(img)
217
+ img = torch.from_numpy(img)
218
+ img = img.unsqueeze(0)
219
+
220
+ prepared_images.append(img)
221
+
222
+ prepared_input = torch.concat(prepared_images)
223
+
224
+ if device:
225
+ prepared_input = prepared_input.to(device)
226
+
227
+ return prepared_input
228
+
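A usage sketch for prepare_classification_images: one random crop plus one missing crop are packed into a single normalized batch; the None entry becomes a zero image.

crops = [np.random.randint(0, 255, (180, 120, 3), dtype=np.uint8), None]
batch = prepare_classification_images(crops, target_size=224)
print(batch.shape)   # torch.Size([2, 3, 224, 224])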
229
+
230
+ def IOU(bb1: Union[tuple, list], bb2: Union[tuple, list], norm_second_bbox: bool = False) -> float:
231
+ # expects [ymin, xmin, ymax, xmax]; absolute or relative coordinates both work
232
+ assert bb1[1] < bb1[3]
233
+ assert bb1[0] < bb1[2]
234
+ assert bb2[1] < bb2[3]
235
+ assert bb2[0] < bb2[2]
236
+
237
+ # determine the coordinates of the intersection rectangle
238
+ x_left = max(bb1[1], bb2[1])
239
+ y_top = max(bb1[0], bb2[0])
240
+ x_right = min(bb1[3], bb2[3])
241
+ y_bottom = min(bb1[2], bb2[2])
242
+
243
+ if x_right < x_left or y_bottom < y_top:
244
+ return 0.0
245
+
246
+ # The intersection of two axis-aligned bounding boxes is always an
247
+ # axis-aligned bounding box
248
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
249
+ # compute the area of both AABBs
250
+ bb1_area = (bb1[3] - bb1[1]) * (bb1[2] - bb1[0])
251
+ bb2_area = (bb2[3] - bb2[1]) * (bb2[2] - bb2[0])
252
+ if not norm_second_bbox:
253
+ # compute the intersection over union by taking the intersection
254
+ # area and dividing it by the sum of prediction + ground-truth
255
+ # areas - the interesection area
256
+ iou = intersection_area / float(bb1_area + bb2_area - intersection_area)
257
+ else:
258
+ # for cases when we search if second bbox is inside first one
259
+ iou = intersection_area / float(bb2_area)
260
+
261
+ assert iou >= 0.0
262
+ assert iou <= 1.01
263
+
264
+ return iou
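A quick check of IOU in [ymin, xmin, ymax, xmax] format: with norm_second_bbox=True the value answers "how much of the second box lies inside the first".

big, small = [0, 0, 100, 100], [10, 10, 30, 30]
print(IOU(big, small))                          # 0.04  (plain IoU)
print(IOU(big, small, norm_second_bbox=True))   # 1.0   (small box fully inside the big one)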
mivolo/model/create_timm_model.py ADDED
@@ -0,0 +1,107 @@
1
+ """
2
+ Code adapted from timm https://github.com/huggingface/pytorch-image-models
3
+
4
+ Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
+ """
6
+
7
+ import os
8
+ from typing import Any, Dict, Optional, Union
9
+
10
+ import timm
11
+
12
+ # register new models
13
+ from mivolo.model.mivolo_model import * # noqa: F403, F401
14
+ from timm.layers import set_layer_config
15
+ from timm.models._factory import parse_model_name
16
+ from timm.models._helpers import load_state_dict, remap_checkpoint
17
+ from timm.models._hub import load_model_config_from_hf
18
+ from timm.models._pretrained import PretrainedCfg, split_model_name_tag
19
+ from timm.models._registry import is_model, model_entrypoint
20
+
21
+
22
+ def load_checkpoint(
23
+ model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None
24
+ ):
25
+ if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"):
26
+ # numpy checkpoint, try to load via model specific load_pretrained fn
27
+ if hasattr(model, "load_pretrained"):
28
+ timm.models._model_builder.load_pretrained(checkpoint_path)
29
+ else:
30
+ raise NotImplementedError("Model cannot load numpy checkpoint")
31
+ return
32
+ state_dict = load_state_dict(checkpoint_path, use_ema)
33
+ if remap:
34
+ state_dict = remap_checkpoint(model, state_dict)
35
+ if filter_keys:
36
+ for sd_key in list(state_dict.keys()):
37
+ for filter_key in filter_keys:
38
+ if filter_key in sd_key:
39
+ if sd_key in state_dict:
40
+ del state_dict[sd_key]
41
+
42
+ rep = []
43
+ if state_dict_map is not None:
44
+ # 'patch_embed.conv1.' : 'patch_embed.conv.'
45
+ for state_k in list(state_dict.keys()):
46
+ for target_k, target_v in state_dict_map.items():
47
+ if target_v in state_k:
48
+ target_name = state_k.replace(target_v, target_k)
49
+ state_dict[target_name] = state_dict[state_k]
50
+ rep.append(state_k)
51
+ for r in rep:
52
+ if r in state_dict:
53
+ del state_dict[r]
54
+
55
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict if filter_keys is None else False)
56
+ return incompatible_keys
57
+
58
+
59
+ def create_model(
60
+ model_name: str,
61
+ pretrained: bool = False,
62
+ pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
63
+ pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
64
+ checkpoint_path: str = "",
65
+ scriptable: Optional[bool] = None,
66
+ exportable: Optional[bool] = None,
67
+ no_jit: Optional[bool] = None,
68
+ filter_keys=None,
69
+ state_dict_map=None,
70
+ **kwargs,
71
+ ):
72
+ """Create a model
73
+ Lookup model's entrypoint function and pass relevant args to create a new model.
74
+ """
75
+ # Parameters that aren't supported by all models or are intended to only override model defaults if set
76
+ # should default to None in command line args/cfg. Remove them if they are present and not set so that
77
+ # non-supporting models don't break and default args remain in effect.
78
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
79
+
80
+ model_source, model_name = parse_model_name(model_name)
81
+ if model_source == "hf-hub":
82
+ assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub."
83
+ # For model names specified in the form `hf-hub:path/architecture_name@revision`,
84
+ # load model weights + pretrained_cfg from Hugging Face hub.
85
+ pretrained_cfg, model_name = load_model_config_from_hf(model_name)
86
+ else:
87
+ model_name, pretrained_tag = split_model_name_tag(model_name)
88
+ if not pretrained_cfg:
89
+ # a valid pretrained_cfg argument takes priority over tag in model name
90
+ pretrained_cfg = pretrained_tag
91
+
92
+ if not is_model(model_name):
93
+ raise RuntimeError("Unknown model (%s)" % model_name)
94
+
95
+ create_fn = model_entrypoint(model_name)
96
+ with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
97
+ model = create_fn(
98
+ pretrained=pretrained,
99
+ pretrained_cfg=pretrained_cfg,
100
+ pretrained_cfg_overlay=pretrained_cfg_overlay,
101
+ **kwargs,
102
+ )
103
+
104
+ if checkpoint_path:
105
+ load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map)
106
+
107
+ return model
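A usage sketch (no checkpoint loaded) that mirrors how MiVOLO builds the joint face+person variant elsewhere in this commit; in_chans=6 selects the two-branch patch embedding.

model = create_model("mivolo_d1_224", num_classes=3, in_chans=6, pretrained=False)
print(sum(p.numel() for p in model.parameters()))   # parameter count of the freshly built model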
mivolo/model/cross_bottleneck_attn.py ADDED
@@ -0,0 +1,116 @@
1
+ """
2
+ Code based on timm https://github.com/huggingface/pytorch-image-models
3
+
4
+ Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from timm.layers.bottleneck_attn import PosEmbedRel
10
+ from timm.layers.helpers import make_divisible
11
+ from timm.layers.mlp import Mlp
12
+ from timm.layers.trace_utils import _assert
13
+ from timm.layers.weight_init import trunc_normal_
14
+
15
+
16
+ class CrossBottleneckAttn(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim,
20
+ dim_out=None,
21
+ feat_size=None,
22
+ stride=1,
23
+ num_heads=4,
24
+ dim_head=None,
25
+ qk_ratio=1.0,
26
+ qkv_bias=False,
27
+ scale_pos_embed=False,
28
+ ):
29
+ super().__init__()
30
+ assert feat_size is not None, "A concrete feature size matching expected input (H, W) is required"
31
+ dim_out = dim_out or dim
32
+ assert dim_out % num_heads == 0
33
+
34
+ self.num_heads = num_heads
35
+ self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
36
+ self.dim_head_v = dim_out // self.num_heads
37
+ self.dim_out_qk = num_heads * self.dim_head_qk
38
+ self.dim_out_v = num_heads * self.dim_head_v
39
+ self.scale = self.dim_head_qk**-0.5
40
+ self.scale_pos_embed = scale_pos_embed
41
+
42
+ self.qkv_f = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
43
+ self.qkv_p = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
44
+
45
+ # NOTE I'm only supporting relative pos embedding for now
46
+ self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)
47
+
48
+ self.norm = nn.LayerNorm([self.dim_out_v * 2, *feat_size])
49
+ mlp_ratio = 4
50
+ self.mlp = Mlp(
51
+ in_features=self.dim_out_v * 2,
52
+ hidden_features=int(dim * mlp_ratio),
53
+ act_layer=nn.GELU,
54
+ out_features=dim_out,
55
+ drop=0,
56
+ use_conv=True,
57
+ )
58
+
59
+ self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
60
+ self.reset_parameters()
61
+
62
+ def reset_parameters(self):
63
+ trunc_normal_(self.qkv_f.weight, std=self.qkv_f.weight.shape[1] ** -0.5) # fan-in
64
+ trunc_normal_(self.qkv_p.weight, std=self.qkv_p.weight.shape[1] ** -0.5) # fan-in
65
+ trunc_normal_(self.pos_embed.height_rel, std=self.scale)
66
+ trunc_normal_(self.pos_embed.width_rel, std=self.scale)
67
+
68
+ def get_qkv(self, x, qvk_conv):
69
+ B, C, H, W = x.shape
70
+
71
+ x = qvk_conv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
72
+
73
+ q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)
74
+
75
+ q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
76
+ k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k
77
+ v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
78
+
79
+ return q, k, v
80
+
81
+ def apply_attn(self, q, k, v, B, H, W, dropout=None):
82
+ if self.scale_pos_embed:
83
+ attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W
84
+ else:
85
+ attn = (q @ k) * self.scale + self.pos_embed(q)
86
+ attn = attn.softmax(dim=-1)
87
+ if dropout:
88
+ attn = dropout(attn)
89
+
90
+ out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
91
+ return out
92
+
93
+ def forward(self, x):
94
+ B, C, H, W = x.shape
95
+
96
+ dim = int(C / 2)
97
+ x1 = x[:, :dim, :, :]
98
+ x2 = x[:, dim:, :, :]
99
+
100
+ _assert(H == self.pos_embed.height, "")
101
+ _assert(W == self.pos_embed.width, "")
102
+
103
+ q_f, k_f, v_f = self.get_qkv(x1, self.qkv_f)
104
+ q_p, k_p, v_p = self.get_qkv(x2, self.qkv_p)
105
+
106
+ # person to face
107
+ out_f = self.apply_attn(q_f, k_p, v_p, B, H, W)
108
+ # face to person
109
+ out_p = self.apply_attn(q_p, k_f, v_f, B, H, W)
110
+
111
+ x_pf = torch.cat((out_f, out_p), dim=1) # B, dim_out * 2, H, W
112
+ x_pf = self.norm(x_pf)
113
+ x_pf = self.mlp(x_pf) # B, dim_out, H, W
114
+
115
+ out = self.pool(x_pf)
116
+ return out
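A shape sketch for CrossBottleneckAttn with made-up sizes: the input concatenates a 64-channel face map and a 64-channel person map on an 8x8 grid, and the block returns a fused 64-channel map.

attn = CrossBottleneckAttn(dim=64, num_heads=4, feat_size=(8, 8))
x = torch.randn(1, 128, 8, 8)    # first 64 channels: face branch, last 64: person branch
print(attn(x).shape)             # torch.Size([1, 64, 8, 8])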
mivolo/model/mi_volo.py ADDED
@@ -0,0 +1,229 @@
1
+ import logging
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ from mivolo.data.misc import prepare_classification_images
7
+ from mivolo.model.create_timm_model import create_model
8
+ from mivolo.structures import PersonAndFaceCrops, PersonAndFaceResult
9
+ from timm.data import resolve_data_config
10
+
11
+ _logger = logging.getLogger("MiVOLO")
12
+ has_compile = hasattr(torch, "compile")
13
+
14
+
15
+ class Meta:
16
+ def __init__(self):
17
+ self.min_age = None
18
+ self.max_age = None
19
+ self.avg_age = None
20
+ self.num_classes = None
21
+
22
+ self.in_chans = 3
23
+ self.with_persons_model = False
24
+ self.disable_faces = False
25
+ self.use_persons = True
26
+ self.only_age = False
27
+
28
+ self.num_classes_gender = 2
29
+
30
+ def load_from_ckpt(self, ckpt_path: str, disable_faces: bool = False, use_persons: bool = True) -> "Meta":
31
+
32
+ state = torch.load(ckpt_path, map_location="cpu")
33
+
34
+ self.min_age = state["min_age"]
35
+ self.max_age = state["max_age"]
36
+ self.avg_age = state["avg_age"]
37
+ self.only_age = state["no_gender"]
38
+
39
+ only_age = state["no_gender"]
40
+
41
+ self.disable_faces = disable_faces
42
+ if "with_persons_model" in state:
43
+ self.with_persons_model = state["with_persons_model"]
44
+ else:
45
+ self.with_persons_model = True if "patch_embed.conv1.0.weight" in state["state_dict"] else False
46
+
47
+ self.num_classes = 1 if only_age else 3
48
+ self.in_chans = 3 if not self.with_persons_model else 6
49
+ self.use_persons = use_persons and self.with_persons_model
50
+
51
+ if not self.with_persons_model and self.disable_faces:
52
+ raise ValueError("You can not use disable-faces for faces-only model")
53
+ if self.with_persons_model and self.disable_faces and not self.use_persons:
54
+ raise ValueError("You can not disable faces and persons together")
55
+
56
+ return self
57
+
58
+ def __str__(self):
59
+ attrs = vars(self)
60
+ attrs.update({"use_person_crops": self.use_person_crops, "use_face_crops": self.use_face_crops})
61
+ return ", ".join("%s: %s" % item for item in attrs.items())
62
+
63
+ @property
64
+ def use_person_crops(self) -> bool:
65
+ return self.with_persons_model and self.use_persons
66
+
67
+ @property
68
+ def use_face_crops(self) -> bool:
69
+ return not self.disable_faces or not self.with_persons_model
70
+
71
+
72
+ class MiVOLO:
73
+ def __init__(
74
+ self,
75
+ ckpt_path: str,
76
+ device: str = "cuda",
77
+ half: bool = True,
78
+ disable_faces: bool = False,
79
+ use_persons: bool = True,
80
+ verbose: bool = False,
81
+ torchcompile: Optional[str] = None,
82
+ ):
83
+ self.verbose = verbose
84
+ self.device = torch.device(device)
85
+ self.half = half and self.device.type != "cpu"
86
+
87
+ self.meta: Meta = Meta().load_from_ckpt(ckpt_path, disable_faces, use_persons)
88
+ if self.verbose:
89
+ _logger.info(f"Model meta:\n{str(self.meta)}")
90
+
91
+ model_name = "mivolo_d1_224"
92
+ self.model = create_model(
93
+ model_name=model_name,
94
+ num_classes=self.meta.num_classes,
95
+ in_chans=self.meta.in_chans,
96
+ pretrained=False,
97
+ checkpoint_path=ckpt_path,
98
+ filter_keys=["fds."],
99
+ )
100
+ self.param_count = sum([m.numel() for m in self.model.parameters()])
101
+ _logger.info(f"Model {model_name} created, param count: {self.param_count}")
102
+
103
+ self.data_config = resolve_data_config(
104
+ model=self.model,
105
+ verbose=verbose,
106
+ use_test_size=True,
107
+ )
108
+ self.data_config["crop_pct"] = 1.0
109
+ c, h, w = self.data_config["input_size"]
110
+ assert h == w, "Incorrect data_config"
111
+ self.input_size = w
112
+
113
+ self.model = self.model.to(self.device)
114
+
115
+ if torchcompile:
116
+ assert has_compile, "A version of torch w/ torch.compile() is required for --compile, possibly a nightly."
117
+ torch._dynamo.reset()
118
+ self.model = torch.compile(self.model, backend=torchcompile)
119
+
120
+ self.model.eval()
121
+ if self.half:
122
+ self.model = self.model.half()
123
+
124
+ def warmup(self, batch_size: int, steps=10):
125
+ if self.meta.with_persons_model:
126
+ input_size = (6, self.input_size, self.input_size)
127
+ else:
128
+ input_size = self.data_config["input_size"]
129
+
130
+ input = torch.randn((batch_size,) + tuple(input_size)).to(self.device)
131
+
132
+ for _ in range(steps):
133
+ out = self.inference(input) # noqa: F841
134
+
135
+ if torch.cuda.is_available():
136
+ torch.cuda.synchronize()
137
+
138
+ def inference(self, model_input: torch.tensor) -> torch.tensor:
139
+
140
+ with torch.no_grad():
141
+ if self.half:
142
+ model_input = model_input.half()
143
+ output = self.model(model_input)
144
+ return output
145
+
146
+ def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
147
+ if detected_bboxes.n_objects == 0:
148
+ return
149
+
150
+ faces_input, person_input, faces_inds, bodies_inds = self.prepare_crops(image, detected_bboxes)
151
+
152
+ if self.meta.with_persons_model:
153
+ model_input = torch.cat((faces_input, person_input), dim=1)
154
+ else:
155
+ model_input = faces_input
156
+ output = self.inference(model_input)
157
+
158
+ # write gender and age results into detected_bboxes
159
+ self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds)
160
+
161
+ def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds):
162
+ if self.meta.only_age:
163
+ age_output = output
164
+ gender_probs, gender_indx = None, None
165
+ else:
166
+ age_output = output[:, 2]
167
+ gender_output = output[:, :2].softmax(-1)
168
+ gender_probs, gender_indx = gender_output.topk(1)
169
+
170
+ assert output.shape[0] == len(faces_inds) == len(bodies_inds)
171
+
172
+ # per face
173
+ for index in range(output.shape[0]):
174
+ face_ind = faces_inds[index]
175
+ body_ind = bodies_inds[index]
176
+
177
+ # get_age
178
+ age = age_output[index].item()
179
+ age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age
180
+ age = round(age, 2)
181
+
182
+ detected_bboxes.set_age(face_ind, age)
183
+ detected_bboxes.set_age(body_ind, age)
184
+
185
+ _logger.info(f"\tage: {age}")
186
+
187
+ if gender_probs is not None:
188
+ gender = "male" if gender_indx[index].item() == 0 else "female"
189
+ gender_score = gender_probs[index].item()
190
+
191
+ _logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]")
192
+
193
+ detected_bboxes.set_gender(face_ind, gender, gender_score)
194
+ detected_bboxes.set_gender(body_ind, gender, gender_score)
195
+
196
+ def prepare_crops(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
197
+
198
+ if self.meta.use_person_crops and self.meta.use_face_crops:
199
+ detected_bboxes.associate_faces_with_persons()
200
+
201
+ crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image)
202
+ (bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies(
203
+ self.meta.use_person_crops, self.meta.use_face_crops
204
+ )
205
+
206
+ if not self.meta.use_face_crops:
207
+ assert all(f is None for f in faces_crops)
208
+
209
+ faces_input = prepare_classification_images(
210
+ faces_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
211
+ )
212
+
213
+ if not self.meta.use_person_crops:
214
+ assert all(p is None for p in bodies_crops)
215
+
216
+ person_input = prepare_classification_images(
217
+ bodies_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
218
+ )
219
+
220
+ _logger.info(
221
+ f"faces_input: {faces_input.shape if faces_input is not None else None}, "
222
+ f"person_input: {person_input.shape if person_input is not None else None}"
223
+ )
224
+
225
+ return faces_input, person_input, faces_inds, bodies_inds
226
+
227
+
228
+ if __name__ == "__main__":
229
+ model = MiVOLO("../pretrained/checkpoint-377.pth.tar", half=True, device="cuda:0")
mivolo/model/mivolo_model.py ADDED
@@ -0,0 +1,402 @@
1
+ """
2
+ Code adapted from timm https://github.com/huggingface/pytorch-image-models
3
+
4
+ Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from mivolo.model.cross_bottleneck_attn import CrossBottleneckAttn
10
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11
+ from timm.layers import trunc_normal_
12
+ from timm.models._builder import build_model_with_cfg
13
+ from timm.models._registry import register_model
14
+ from timm.models.volo import VOLO
15
+
16
+ __all__ = ["MiVOLOModel"] # model_registry will add each entrypoint fn to this
17
+
18
+
19
+ def _cfg(url="", **kwargs):
20
+ return {
21
+ "url": url,
22
+ "num_classes": 1000,
23
+ "input_size": (3, 224, 224),
24
+ "pool_size": None,
25
+ "crop_pct": 0.96,
26
+ "interpolation": "bicubic",
27
+ "fixed_input_size": True,
28
+ "mean": IMAGENET_DEFAULT_MEAN,
29
+ "std": IMAGENET_DEFAULT_STD,
30
+ "first_conv": None,
31
+ "classifier": ("head", "aux_head"),
32
+ **kwargs,
33
+ }
34
+
35
+
36
+ default_cfgs = {
37
+ "mivolo_d1_224": _cfg(
38
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar", crop_pct=0.96
39
+ ),
40
+ "mivolo_d1_384": _cfg(
41
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar",
42
+ crop_pct=1.0,
43
+ input_size=(3, 384, 384),
44
+ ),
45
+ "mivolo_d2_224": _cfg(
46
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar", crop_pct=0.96
47
+ ),
48
+ "mivolo_d2_384": _cfg(
49
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar",
50
+ crop_pct=1.0,
51
+ input_size=(3, 384, 384),
52
+ ),
53
+ "mivolo_d3_224": _cfg(
54
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar", crop_pct=0.96
55
+ ),
56
+ "mivolo_d3_448": _cfg(
57
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar",
58
+ crop_pct=1.0,
59
+ input_size=(3, 448, 448),
60
+ ),
61
+ "mivolo_d4_224": _cfg(
62
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar", crop_pct=0.96
63
+ ),
64
+ "mivolo_d4_448": _cfg(
65
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar",
66
+ crop_pct=1.15,
67
+ input_size=(3, 448, 448),
68
+ ),
69
+ "mivolo_d5_224": _cfg(
70
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar", crop_pct=0.96
71
+ ),
72
+ "mivolo_d5_448": _cfg(
73
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar",
74
+ crop_pct=1.15,
75
+ input_size=(3, 448, 448),
76
+ ),
77
+ "mivolo_d5_512": _cfg(
78
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar",
79
+ crop_pct=1.15,
80
+ input_size=(3, 512, 512),
81
+ ),
82
+ }
83
+
84
+
85
+ def get_output_size(input_shape, conv_layer):
86
+ padding = conv_layer.padding
87
+ dilation = conv_layer.dilation
88
+ kernel_size = conv_layer.kernel_size
89
+ stride = conv_layer.stride
90
+
91
+ output_size = [
92
+ ((input_shape[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) // stride[i]) + 1 for i in range(2)
93
+ ]
94
+ return output_size
95
+
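A worked example of the conv output-size formula used above: the 7x7, stride-2, padding-3 stem conv maps 224x224 to 112x112.

conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
print(get_output_size((224, 224), conv))   # [112, 112] = ((224 + 2*3 - (7-1) - 1) // 2) + 1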
96
+
97
+ def get_output_size_module(input_size, stem):
98
+ output_size = input_size
99
+
100
+ for module in stem:
101
+ if isinstance(module, nn.Conv2d):
102
+ output_size = [
103
+ (
104
+ (output_size[i] + 2 * module.padding[i] - module.dilation[i] * (module.kernel_size[i] - 1) - 1)
105
+ // module.stride[i]
106
+ )
107
+ + 1
108
+ for i in range(2)
109
+ ]
110
+
111
+ return output_size
112
+
113
+
114
+ class PatchEmbed(nn.Module):
115
+ """Image to Patch Embedding."""
116
+
117
+ def __init__(
118
+ self, img_size=224, stem_conv=False, stem_stride=1, patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384
119
+ ):
120
+ super().__init__()
121
+ assert patch_size in [4, 8, 16]
122
+ assert in_chans in [3, 6]
123
+ self.with_persons_model = in_chans == 6
124
+ self.use_cross_attn = True
125
+
126
+ if stem_conv:
127
+ if not self.with_persons_model:
128
+ self.conv = self.create_stem(stem_stride, in_chans, hidden_dim)
129
+ else:
130
+ self.conv = True # just to match interface
131
+ # split
132
+ self.conv1 = self.create_stem(stem_stride, 3, hidden_dim)
133
+ self.conv2 = self.create_stem(stem_stride, 3, hidden_dim)
134
+ else:
135
+ self.conv = None
136
+
137
+ if self.with_persons_model:
138
+
139
+ self.proj1 = nn.Conv2d(
140
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
141
+ )
142
+ self.proj2 = nn.Conv2d(
143
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
144
+ )
145
+
146
+ stem_out_shape = get_output_size_module((img_size, img_size), self.conv1)
147
+ self.proj_output_size = get_output_size(stem_out_shape, self.proj1)
148
+
149
+ self.map = CrossBottleneckAttn(embed_dim, dim_out=embed_dim, num_heads=1, feat_size=self.proj_output_size)
150
+
151
+ else:
152
+ self.proj = nn.Conv2d(
153
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
154
+ )
155
+
156
+ self.patch_dim = img_size // patch_size
157
+ self.num_patches = self.patch_dim**2
158
+
159
+ def create_stem(self, stem_stride, in_chans, hidden_dim):
160
+ return nn.Sequential(
161
+ nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False), # 112x112
162
+ nn.BatchNorm2d(hidden_dim),
163
+ nn.ReLU(inplace=True),
164
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
165
+ nn.BatchNorm2d(hidden_dim),
166
+ nn.ReLU(inplace=True),
167
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
168
+ nn.BatchNorm2d(hidden_dim),
169
+ nn.ReLU(inplace=True),
170
+ )
171
+
172
+ def forward(self, x):
173
+ if self.conv is not None:
174
+ if self.with_persons_model:
175
+ x1 = x[:, :3]
176
+ x2 = x[:, 3:]
177
+
178
+ x1 = self.conv1(x1)
179
+ x1 = self.proj1(x1)
180
+
181
+ x2 = self.conv2(x2)
182
+ x2 = self.proj2(x2)
183
+
184
+ x = torch.cat([x1, x2], dim=1)
185
+ x = self.map(x)
186
+ else:
187
+ x = self.conv(x)
188
+ x = self.proj(x) # B, C, H, W
189
+
190
+ return x
191
+
192
+
193
+ class MiVOLOModel(VOLO):
194
+ """
195
+ Vision Outlooker, the main class of our model
196
+ """
197
+
198
+ def __init__(
199
+ self,
200
+ layers,
201
+ img_size=224,
202
+ in_chans=3,
203
+ num_classes=1000,
204
+ global_pool="token",
205
+ patch_size=8,
206
+ stem_hidden_dim=64,
207
+ embed_dims=None,
208
+ num_heads=None,
209
+ downsamples=(True, False, False, False),
210
+ outlook_attention=(True, False, False, False),
211
+ mlp_ratio=3.0,
212
+ qkv_bias=False,
213
+ drop_rate=0.0,
214
+ attn_drop_rate=0.0,
215
+ drop_path_rate=0.0,
216
+ norm_layer=nn.LayerNorm,
217
+ post_layers=("ca", "ca"),
218
+ use_aux_head=True,
219
+ use_mix_token=False,
220
+ pooling_scale=2,
221
+ ):
222
+ super().__init__(
223
+ layers,
224
+ img_size,
225
+ in_chans,
226
+ num_classes,
227
+ global_pool,
228
+ patch_size,
229
+ stem_hidden_dim,
230
+ embed_dims,
231
+ num_heads,
232
+ downsamples,
233
+ outlook_attention,
234
+ mlp_ratio,
235
+ qkv_bias,
236
+ drop_rate,
237
+ attn_drop_rate,
238
+ drop_path_rate,
239
+ norm_layer,
240
+ post_layers,
241
+ use_aux_head,
242
+ use_mix_token,
243
+ pooling_scale,
244
+ )
245
+
246
+ self.patch_embed = PatchEmbed(
247
+ stem_conv=True,
248
+ stem_stride=2,
249
+ patch_size=patch_size,
250
+ in_chans=in_chans,
251
+ hidden_dim=stem_hidden_dim,
252
+ embed_dim=embed_dims[0],
253
+ )
254
+
255
+ trunc_normal_(self.pos_embed, std=0.02)
256
+ self.apply(self._init_weights)
257
+
258
+ def forward_features(self, x):
259
+ x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
260
+
261
+ # step2: tokens learning in the two stages
262
+ x = self.forward_tokens(x)
263
+
264
+ # step3: post network, apply class attention or not
265
+ if self.post_network is not None:
266
+ x = self.forward_cls(x)
267
+ x = self.norm(x)
268
+ return x
269
+
270
+ def forward_head(self, x, pre_logits: bool = False, targets=None, epoch=None):
271
+ if self.global_pool == "avg":
272
+ out = x.mean(dim=1)
273
+ elif self.global_pool == "token":
274
+ out = x[:, 0]
275
+ else:
276
+ out = x
277
+ if pre_logits:
278
+ return out
279
+
280
+ features = out
281
+ fds_enabled = hasattr(self, "_fds_forward")
282
+ if fds_enabled:
283
+ features = self._fds_forward(features, targets, epoch)
284
+
285
+ out = self.head(features)
286
+ if self.aux_head is not None:
287
+ # generate classes in all feature tokens, see token labeling
288
+ aux = self.aux_head(x[:, 1:])
289
+ out = out + 0.5 * aux.max(1)[0]
290
+
291
+ return (out, features) if (fds_enabled and self.training) else out
292
+
293
+ def forward(self, x, targets=None, epoch=None):
294
+ """simplified forward (without mix token training)"""
295
+ x = self.forward_features(x)
296
+ x = self.forward_head(x, targets=targets, epoch=epoch)
297
+ return x
298
+
299
+
300
+ def _create_mivolo(variant, pretrained=False, **kwargs):
301
+ if kwargs.get("features_only", None):
302
+ raise RuntimeError("features_only not implemented for Vision Transformer models.")
303
+ return build_model_with_cfg(MiVOLOModel, variant, pretrained, **kwargs)
304
+
305
+
306
+ @register_model
307
+ def mivolo_d1_224(pretrained=False, **kwargs):
308
+ model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
309
+ model = _create_mivolo("mivolo_d1_224", pretrained=pretrained, **model_args)
310
+ return model
311
+
312
+
313
+ @register_model
314
+ def mivolo_d1_384(pretrained=False, **kwargs):
315
+ model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
316
+ model = _create_mivolo("mivolo_d1_384", pretrained=pretrained, **model_args)
317
+ return model
318
+
319
+
320
+ @register_model
321
+ def mivolo_d2_224(pretrained=False, **kwargs):
322
+ model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
323
+ model = _create_mivolo("mivolo_d2_224", pretrained=pretrained, **model_args)
324
+ return model
325
+
326
+
327
+ @register_model
328
+ def mivolo_d2_384(pretrained=False, **kwargs):
329
+ model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
330
+ model = _create_mivolo("mivolo_d2_384", pretrained=pretrained, **model_args)
331
+ return model
332
+
333
+
334
+ @register_model
335
+ def mivolo_d3_224(pretrained=False, **kwargs):
336
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
337
+ model = _create_mivolo("mivolo_d3_224", pretrained=pretrained, **model_args)
338
+ return model
339
+
340
+
341
+ @register_model
342
+ def mivolo_d3_448(pretrained=False, **kwargs):
343
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
344
+ model = _create_mivolo("mivolo_d3_448", pretrained=pretrained, **model_args)
345
+ return model
346
+
347
+
348
+ @register_model
349
+ def mivolo_d4_224(pretrained=False, **kwargs):
350
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
351
+ model = _create_mivolo("mivolo_d4_224", pretrained=pretrained, **model_args)
352
+ return model
353
+
354
+
355
+ @register_model
356
+ def mivolo_d4_448(pretrained=False, **kwargs):
357
+ """VOLO-D4 model, Params: 193M"""
358
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
359
+ model = _create_mivolo("mivolo_d4_448", pretrained=pretrained, **model_args)
360
+ return model
361
+
362
+
363
+ @register_model
364
+ def mivolo_d5_224(pretrained=False, **kwargs):
365
+ model_args = dict(
366
+ layers=(12, 12, 20, 4),
367
+ embed_dims=(384, 768, 768, 768),
368
+ num_heads=(12, 16, 16, 16),
369
+ mlp_ratio=4,
370
+ stem_hidden_dim=128,
371
+ **kwargs
372
+ )
373
+ model = _create_mivolo("mivolo_d5_224", pretrained=pretrained, **model_args)
374
+ return model
375
+
376
+
377
+ @register_model
378
+ def mivolo_d5_448(pretrained=False, **kwargs):
379
+ model_args = dict(
380
+ layers=(12, 12, 20, 4),
381
+ embed_dims=(384, 768, 768, 768),
382
+ num_heads=(12, 16, 16, 16),
383
+ mlp_ratio=4,
384
+ stem_hidden_dim=128,
385
+ **kwargs
386
+ )
387
+ model = _create_mivolo("mivolo_d5_448", pretrained=pretrained, **model_args)
388
+ return model
389
+
390
+
391
+ @register_model
392
+ def mivolo_d5_512(pretrained=False, **kwargs):
393
+ model_args = dict(
394
+ layers=(12, 12, 20, 4),
395
+ embed_dims=(384, 768, 768, 768),
396
+ num_heads=(12, 16, 16, 16),
397
+ mlp_ratio=4,
398
+ stem_hidden_dim=128,
399
+ **kwargs
400
+ )
401
+ model = _create_mivolo("mivolo_d5_512", pretrained=pretrained, **model_args)
402
+ return model
mivolo/model/yolo_detector.py ADDED
@@ -0,0 +1,48 @@
1
+ import os
2
+ from typing import Dict, Union
3
+
4
+ import numpy as np
5
+ import PIL
6
+ import torch
7
+ from mivolo.structures import PersonAndFaceResult
8
+ from ultralytics import YOLO
9
+ # from ultralytics.yolo.engine.results import Results
10
+
11
+ # because of an ultralytics bug it is important to unset CUBLAS_WORKSPACE_CONFIG after importing the module
12
+ os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
13
+
14
+
15
+ class Detector:
16
+ def __init__(
17
+ self,
18
+ weights: str,
19
+ device: str = "cuda",
20
+ half: bool = True,
21
+ verbose: bool = False,
22
+ conf_thresh: float = 0.4,
23
+ iou_thresh: float = 0.7,
24
+ ):
25
+ self.yolo = YOLO(weights)
26
+ self.yolo.fuse()
27
+
28
+ self.device = torch.device(device)
29
+ self.half = half and self.device.type != "cpu"
30
+
31
+ if self.half:
32
+ self.yolo.model = self.yolo.model.half()
33
+
34
+ self.detector_names: Dict[int, str] = self.yolo.model.names
35
+
36
+ # init yolo.predictor
37
+ self.detector_kwargs = {
38
+ "conf": conf_thresh, "iou": iou_thresh, "half": self.half, "verbose": verbose}
39
+ # self.yolo.predict(**self.detector_kwargs)
40
+
41
+ def predict(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult:
42
+ results = self.yolo.predict(image, **self.detector_kwargs)[0]
43
+ return PersonAndFaceResult(results)
44
+
45
+ def track(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult:
46
+ results = self.yolo.track(
47
+ image, persist=True, **self.detector_kwargs)[0]
48
+ return PersonAndFaceResult(results)
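A usage sketch for Detector; the weights file name and image path below are placeholders, not part of this commit.

detector = Detector("models/yolov8x_person_face.pt", device="cuda:0", verbose=True)   # placeholder weights path
faces_persons = detector.predict("examples/group.jpg")                                # placeholder image path
print(faces_persons.n_objects)   # number of detected "person" and "face" boxes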
mivolo/predictor.py ADDED
@@ -0,0 +1,68 @@
1
+ from collections import defaultdict
2
+ from typing import Dict, Generator, List, Optional, Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import tqdm
7
+ from mivolo.model.mi_volo import MiVOLO
8
+ from mivolo.model.yolo_detector import Detector
9
+ from mivolo.structures import AGE_GENDER_TYPE, PersonAndFaceResult
10
+
11
+
12
+ class Predictor:
13
+ def __init__(self, config, verbose: bool = False):
14
+ self.detector = Detector(config.detector_weights, config.device, verbose=verbose)
15
+ self.age_gender_model = MiVOLO(
16
+ config.checkpoint,
17
+ config.device,
18
+ half=True,
19
+ use_persons=config.with_persons,
20
+ disable_faces=config.disable_faces,
21
+ verbose=verbose,
22
+ )
23
+ self.draw = config.draw
24
+
25
+ def recognize(self, image: np.ndarray) -> Tuple[PersonAndFaceResult, Optional[np.ndarray]]:
26
+ detected_objects: PersonAndFaceResult = self.detector.predict(image)
27
+ self.age_gender_model.predict(image, detected_objects)
28
+
29
+ out_im = None
30
+ if self.draw:
31
+ # plot results on image
32
+ out_im = detected_objects.plot()
33
+
34
+ return detected_objects, out_im
35
+
36
+ def recognize_video(self, source: str) -> Generator:
37
+ video_capture = cv2.VideoCapture(source)
38
+ if not video_capture.isOpened():
39
+ raise ValueError(f"Failed to open video source {source}")
40
+
41
+ detected_objects_history: Dict[int, List[AGE_GENDER_TYPE]] = defaultdict(list)
42
+
43
+ total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
44
+ for _ in tqdm.tqdm(range(total_frames)):
45
+ ret, frame = video_capture.read()
46
+ if not ret:
47
+ break
48
+
49
+ detected_objects: PersonAndFaceResult = self.detector.track(frame)
50
+ self.age_gender_model.predict(frame, detected_objects)
51
+
52
+ current_frame_objs = detected_objects.get_results_for_tracking()
53
+ cur_persons: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[0]
54
+ cur_faces: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[1]
55
+
56
+ # add cur_persons and cur_faces to history
57
+ for guid, data in cur_persons.items():
58
+ # entries with missing age/gender are not useful for tracking, skip them
59
+ if None not in data:
60
+ detected_objects_history[guid].append(data)
61
+ for guid, data in cur_faces.items():
62
+ if None not in data:
63
+ detected_objects_history[guid].append(data)
64
+
65
+ detected_objects.set_tracked_age_gender(detected_objects_history)
66
+ if self.draw:
67
+ frame = detected_objects.plot()
68
+ yield detected_objects_history, frame
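A usage sketch for Predictor; the config here is a plain namespace with hypothetical paths, mirroring the attributes the constructor reads (detector_weights, checkpoint, device, with_persons, disable_faces, draw).

from types import SimpleNamespace

cfg = SimpleNamespace(
    detector_weights="models/yolov8x_person_face.pt",   # placeholder path
    checkpoint="models/mivolo_checkpoint.pth.tar",       # placeholder path
    device="cuda:0",
    with_persons=True,
    disable_faces=False,
    draw=True,
)
predictor = Predictor(cfg, verbose=True)
objects, annotated = predictor.recognize(cv2.imread("examples/group.jpg"))   # placeholder image path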
mivolo/structures.py ADDED
@@ -0,0 +1,464 @@
1
+ import math
2
+ import os
3
+ from copy import deepcopy
4
+ from typing import Dict, List, Optional, Tuple
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from mivolo.data.misc import aggregate_votes_winsorized, assign_faces, box_iou, cropout_black_parts
10
+ from ultralytics.yolo.engine.results import Results
11
+ from ultralytics.yolo.utils.plotting import Annotator, colors
12
+
13
+ # because of an ultralytics bug it is important to unset CUBLAS_WORKSPACE_CONFIG after importing the module
14
+ os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
15
+
16
+ AGE_GENDER_TYPE = Tuple[float, str]
17
+
18
+
19
+ class PersonAndFaceCrops:
20
+ def __init__(self):
21
+ # int: index of person along results
22
+ self.crops_persons: Dict[int, np.ndarray] = {}
23
+
24
+ # int: index of face along results
25
+ self.crops_faces: Dict[int, np.ndarray] = {}
26
+
27
+ # int: index of face along results
28
+ self.crops_faces_wo_body: Dict[int, np.ndarray] = {}
29
+
30
+ # int: index of person along results
31
+ self.crops_persons_wo_face: Dict[int, np.ndarray] = {}
32
+
33
+ def _add_to_output(
34
+ self, crops: Dict[int, np.ndarray], out_crops: List[np.ndarray], out_crop_inds: List[Optional[int]]
35
+ ):
36
+ inds_to_add = list(crops.keys())
37
+ crops_to_add = list(crops.values())
38
+ out_crops.extend(crops_to_add)
39
+ out_crop_inds.extend(inds_to_add)
40
+
41
+ def _get_all_faces(
42
+ self, use_persons: bool, use_faces: bool
43
+ ) -> Tuple[List[Optional[int]], List[Optional[np.ndarray]]]:
44
+ """
45
+ Returns
46
+ if use_persons and use_faces
47
+ faces: faces_with_bodies + faces_without_bodies + [None] * len(crops_persons_wo_face)
48
+ if use_persons and not use_faces
49
+ faces: [None] * n_persons
50
+ if not use_persons and use_faces:
51
+ faces: faces_with_bodies + faces_without_bodies
52
+ """
53
+
54
+ def add_none_to_output(faces_inds, faces_crops, num):
55
+ faces_inds.extend([None for _ in range(num)])
56
+ faces_crops.extend([None for _ in range(num)])
57
+
58
+ faces_inds: List[Optional[int]] = []
59
+ faces_crops: List[Optional[np.ndarray]] = []
60
+
61
+ if not use_faces:
62
+ add_none_to_output(faces_inds, faces_crops, len(self.crops_persons) + len(self.crops_persons_wo_face))
63
+ return faces_inds, faces_crops
64
+
65
+ self._add_to_output(self.crops_faces, faces_crops, faces_inds)
66
+ self._add_to_output(self.crops_faces_wo_body, faces_crops, faces_inds)
67
+
68
+ if use_persons:
69
+ add_none_to_output(faces_inds, faces_crops, len(self.crops_persons_wo_face))
70
+
71
+ return faces_inds, faces_crops
72
+
73
+ def _get_all_bodies(
74
+ self, use_persons: bool, use_faces: bool
75
+ ) -> Tuple[List[Optional[int]], List[Optional[np.ndarray]]]:
76
+ """
77
+ Returns
78
+ if use_persons and use_faces
79
+ persons: bodies_with_faces + [None] * len(faces_without_bodies) + bodies_without_faces
80
+ if use_persons and not use_faces
81
+ persons: bodies_with_faces + bodies_without_faces
82
+ if not use_persons and use_faces
83
+ persons: [None] * n_faces
84
+ """
85
+
86
+ def add_none_to_output(bodies_inds, bodies_crops, num):
87
+ bodies_inds.extend([None for _ in range(num)])
88
+ bodies_crops.extend([None for _ in range(num)])
89
+
90
+ bodies_inds: List[Optional[int]] = []
91
+ bodies_crops: List[Optional[np.ndarray]] = []
92
+
93
+ if not use_persons:
94
+ add_none_to_output(bodies_inds, bodies_crops, len(self.crops_faces) + len(self.crops_faces_wo_body))
95
+ return bodies_inds, bodies_crops
96
+
97
+ self._add_to_output(self.crops_persons, bodies_crops, bodies_inds)
98
+ if use_faces:
99
+ add_none_to_output(bodies_inds, bodies_crops, len(self.crops_faces_wo_body))
100
+
101
+ self._add_to_output(self.crops_persons_wo_face, bodies_crops, bodies_inds)
102
+
103
+ return bodies_inds, bodies_crops
104
+
105
+ def get_faces_with_bodies(self, use_persons: bool, use_faces: bool):
106
+ """
107
+ Return
108
+ faces: faces_with_bodies, faces_without_bodies, [None] * len(crops_persons_wo_face)
109
+ persons: bodies_with_faces, [None] * len(faces_without_bodies), bodies_without_faces
110
+ """
111
+
112
+ bodies_inds, bodies_crops = self._get_all_bodies(use_persons, use_faces)
113
+ faces_inds, faces_crops = self._get_all_faces(use_persons, use_faces)
114
+
115
+ return (bodies_inds, bodies_crops), (faces_inds, faces_crops)
116
+
117
+ def save(self, out_dir="output"):
118
+ ind = 0
119
+ os.makedirs(out_dir, exist_ok=True)
120
+ for crops in [self.crops_persons, self.crops_faces, self.crops_faces_wo_body, self.crops_persons_wo_face]:
121
+ for crop in crops.values():
122
+ if crop is None:
123
+ continue
124
+ out_name = os.path.join(out_dir, f"{ind}_crop.jpg")
125
+ cv2.imwrite(out_name, crop)
126
+ ind += 1
127
+
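A small sketch of the crop ordering contract described in the docstrings above, using dummy arrays: one matched body/face pair and nothing unmatched.

crops = PersonAndFaceCrops()
crops.crops_persons = {0: np.zeros((10, 10, 3), np.uint8)}
crops.crops_faces = {1: np.zeros((5, 5, 3), np.uint8)}
(b_inds, _), (f_inds, _) = crops.get_faces_with_bodies(use_persons=True, use_faces=True)
print(b_inds, f_inds)   # [0] [1]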
128
+
129
+ class PersonAndFaceResult:
130
+ def __init__(self, results: Results):
131
+
132
+ self.yolo_results = results
133
+ names = set(results.names.values())
134
+ assert "person" in names and "face" in names
135
+
136
+ # initially no faces and persons are associated to each other
137
+ self.face_to_person_map: Dict[int, Optional[int]] = {ind: None for ind in self.get_bboxes_inds("face")}
138
+ self.unassigned_persons_inds: List[int] = self.get_bboxes_inds("person")
139
+ n_objects = len(self.yolo_results.boxes)
140
+ self.ages: List[Optional[float]] = [None for _ in range(n_objects)]
141
+ self.genders: List[Optional[str]] = [None for _ in range(n_objects)]
142
+ self.gender_scores: List[Optional[float]] = [None for _ in range(n_objects)]
143
+
144
+ @property
145
+ def n_objects(self) -> int:
146
+ return len(self.yolo_results.boxes)
147
+
148
+ def get_bboxes_inds(self, category: str) -> List[int]:
149
+ bboxes: List[int] = []
150
+ for ind, det in enumerate(self.yolo_results.boxes):
151
+ name = self.yolo_results.names[int(det.cls)]
152
+ if name == category:
153
+ bboxes.append(ind)
154
+
155
+ return bboxes
156
+
157
+ def get_distance_to_center(self, bbox_ind: int) -> float:
158
+ """
159
+ Calculate the Euclidean distance between the bbox center and the image center.
160
+ """
161
+ im_h, im_w = self.yolo_results[bbox_ind].orig_shape
162
+ x1, y1, x2, y2 = self.get_bbox_by_ind(bbox_ind).cpu().numpy()
163
+ center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
164
+ dist = math.dist([center_x, center_y], [im_w / 2, im_h / 2])
165
+ return dist
166
+
167
+ def plot(
168
+ self,
169
+ conf=False,
170
+ line_width=None,
171
+ font_size=None,
172
+ font="Arial.ttf",
173
+ pil=False,
174
+ img=None,
175
+ labels=True,
176
+ boxes=True,
177
+ probs=True,
178
+ ages=True,
179
+ genders=True,
180
+ gender_probs=False,
181
+ ):
182
+ """
183
+ Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
184
+ Args:
185
+ conf (bool): Whether to plot the detection confidence score.
186
+ line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
187
+ font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
188
+ font (str): The font to use for the text.
189
+ pil (bool): Whether to return the image as a PIL Image.
190
+ img (numpy.ndarray): Image to plot on. If None, the original image is used.
191
+ labels (bool): Whether to plot the label of bounding boxes.
192
+ boxes (bool): Whether to plot the bounding boxes.
193
+ probs (bool): Whether to plot classification probability
194
+ ages (bool): Whether to plot the age of bounding boxes.
195
+ genders (bool): Whether to plot the genders of bounding boxes.
196
+ gender_probs (bool): Whether to plot gender classification probability
197
+ Returns:
198
+ (numpy.ndarray): A numpy array of the annotated image.
199
+ """
200
+
201
+ # return self.yolo_results.plot()
202
+ colors_by_ind = {}
203
+ for face_ind, person_ind in self.face_to_person_map.items():
204
+ if person_ind is not None:
205
+ colors_by_ind[face_ind] = face_ind + 2
206
+ colors_by_ind[person_ind] = face_ind + 2
207
+ else:
208
+ colors_by_ind[face_ind] = 0
209
+ for person_ind in self.unassigned_persons_inds:
210
+ colors_by_ind[person_ind] = 1
211
+
212
+ names = self.yolo_results.names
213
+ annotator = Annotator(
214
+ deepcopy(self.yolo_results.orig_img if img is None else img),
215
+ line_width,
216
+ font_size,
217
+ font,
218
+ pil,
219
+ example=names,
220
+ )
221
+ pred_boxes, show_boxes = self.yolo_results.boxes, boxes
222
+ pred_probs, show_probs = self.yolo_results.probs, probs
223
+
224
+ if pred_boxes and show_boxes:
225
+ for bb_ind, (d, age, gender, gender_score) in enumerate(
226
+ zip(pred_boxes, self.ages, self.genders, self.gender_scores)
227
+ ):
228
+ c, conf, guid = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
229
+ name = ("" if guid is None else f"id:{guid} ") + names[c]
230
+ label = (f"{name} {conf:.2f}" if conf else name) if labels else None
231
+ if ages and age is not None:
232
+ label += f" {age:.1f}"
233
+ if genders and gender is not None:
234
+ label += f" {'F' if gender == 'female' else 'M'}"
235
+ if gender_probs and gender_score is not None:
236
+ label += f" ({gender_score:.1f})"
237
+ annotator.box_label(d.xyxy.squeeze(), label, color=colors(colors_by_ind[bb_ind], True))
238
+
239
+ if pred_probs is not None and show_probs:
240
+ text = f"{', '.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)}, "
241
+ annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
242
+
243
+ return annotator.result()
244
+
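Usage sketch for plot(): assuming detected is a populated PersonAndFaceResult, the annotated frame comes back as a numpy array and can be written out with OpenCV (the path and variable names here are illustrative):

annotated = detected.plot(line_width=2, ages=True, genders=True)
cv2.imwrite("output/annotated.jpg", annotated)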
245
+ def set_tracked_age_gender(self, tracked_objects: Dict[int, List[AGE_GENDER_TYPE]]):
246
+ """
247
+ Update age and gender for objects based on history from tracked_objects.
248
+ Args:
249
+ tracked_objects (dict[int, list[AGE_GENDER_TYPE]]): info about tracked objects by guid
250
+ """
251
+
252
+ for face_ind, person_ind in self.face_to_person_map.items():
253
+ pguid = self._get_id_by_ind(person_ind)
254
+ fguid = self._get_id_by_ind(face_ind)
255
+
256
+ if fguid == -1 and pguid == -1:
257
+ # YOLO might not assign ids for some objects in some cases:
258
+ # https://github.com/ultralytics/ultralytics/issues/3830
259
+ continue
260
+ age, gender = self._gather_tracking_result(tracked_objects, fguid, pguid)
261
+ if age is None or gender is None:
262
+ continue
263
+ self.set_age(face_ind, age)
264
+ self.set_gender(face_ind, gender, 1.0)
265
+ if pguid != -1:
266
+ self.set_gender(person_ind, gender, 1.0)
267
+ self.set_age(person_ind, age)
268
+
269
+ for person_ind in self.unassigned_persons_inds:
270
+ pid = self._get_id_by_ind(person_ind)
271
+ if pid == -1:
272
+ continue
273
+ age, gender = self._gather_tracking_result(tracked_objects, -1, pid)
274
+ if age is None or gender is None:
275
+ continue
276
+ self.set_gender(person_ind, gender, 1.0)
277
+ self.set_age(person_ind, age)
278
+
279
+ def _get_id_by_ind(self, ind: Optional[int] = None) -> int:
280
+ if ind is None:
281
+ return -1
282
+ obj_id = self.yolo_results.boxes[ind].id
283
+ if obj_id is None:
284
+ return -1
285
+ return int(obj_id.item())
286
+
287
+ def get_bbox_by_ind(self, ind: int, im_h: Optional[int] = None, im_w: Optional[int] = None) -> torch.Tensor:
288
+ bb = self.yolo_results.boxes[ind].xyxy.squeeze().type(torch.int32)
289
+ if im_h is not None and im_w is not None:
290
+ bb[0] = torch.clamp(bb[0], min=0, max=im_w - 1)
291
+ bb[1] = torch.clamp(bb[1], min=0, max=im_h - 1)
292
+ bb[2] = torch.clamp(bb[2], min=0, max=im_w - 1)
293
+ bb[3] = torch.clamp(bb[3], min=0, max=im_h - 1)
294
+ return bb
295
+
296
+ def set_age(self, ind: Optional[int], age: float):
297
+ if ind is not None:
298
+ self.ages[ind] = age
299
+
300
+ def set_gender(self, ind: Optional[int], gender: str, gender_score: float):
301
+ if ind is not None:
302
+ self.genders[ind] = gender
303
+ self.gender_scores[ind] = gender_score
304
+
305
+ @staticmethod
306
+ def _gather_tracking_result(
307
+ tracked_objects: Dict[int, List[AGE_GENDER_TYPE]],
308
+ fguid: int = -1,
309
+ pguid: int = -1,
310
+ minimum_sample_size: int = 10,
311
+ ) -> AGE_GENDER_TYPE:
312
+
313
+ assert fguid != -1 or pguid != -1, "Incorrect tracking behaviour"
314
+
315
+ face_ages = [r[0] for r in tracked_objects[fguid] if r[0] is not None] if fguid in tracked_objects else []
316
+ face_genders = [r[1] for r in tracked_objects[fguid] if r[1] is not None] if fguid in tracked_objects else []
317
+ person_ages = [r[0] for r in tracked_objects[pguid] if r[0] is not None] if pguid in tracked_objects else []
318
+ person_genders = [r[1] for r in tracked_objects[pguid] if r[1] is not None] if pguid in tracked_objects else []
319
+
320
+ if not face_ages and not person_ages: # both empty
321
+ return None, None
322
+
323
+ # Different aggregation strategies can be plugged in here.
324
+ # Face ages: predictions made from the face, or face + person, depending on the object's history.
325
+ # Person ages: predictions made from the person, or face + person, depending on the object's history.
326
+
327
+ if len(person_ages + face_ages) >= minimum_sample_size:
328
+ age = aggregate_votes_winsorized(person_ages + face_ages)
329
+ else:
330
+ face_age = np.mean(face_ages) if face_ages else None
331
+ person_age = np.mean(person_ages) if person_ages else None
332
+ if face_age is None:
333
+ face_age = person_age
334
+ if person_age is None:
335
+ person_age = face_age
336
+ age = (face_age + person_age) / 2.0
337
+
338
+ genders = face_genders + person_genders
339
+ assert len(genders) > 0
340
+ # take mode of genders
341
+ gender = max(set(genders), key=genders.count)
342
+
343
+ return age, gender
344
+
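Worked example of the small-sample branch above: with face_ages = [31.0, 33.0] and person_ages = [38.0] there are only 3 votes, below minimum_sample_size=10, so the per-source means are averaged instead of winsorizing the pooled votes, and the gender is the mode of all gender votes:

face_age = (31.0 + 33.0) / 2         # 32.0
person_age = 38.0
age = (face_age + person_age) / 2.0  # 35.0

genders = ["male", "male", "female"]
gender = max(set(genders), key=genders.count)  # "male" (the mode)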
345
+ def get_results_for_tracking(self) -> Tuple[Dict[int, AGE_GENDER_TYPE], Dict[int, AGE_GENDER_TYPE]]:
346
+ """
347
+ Get objects from current frame
348
+ """
349
+ persons: Dict[int, AGE_GENDER_TYPE] = {}
350
+ faces: Dict[int, AGE_GENDER_TYPE] = {}
351
+
352
+ names = self.yolo_results.names
353
+ pred_boxes = self.yolo_results.boxes
354
+ for _, (det, age, gender, _) in enumerate(zip(pred_boxes, self.ages, self.genders, self.gender_scores)):
355
+ if det.id is None:
356
+ continue
357
+ cat_id, _, guid = int(det.cls), float(det.conf), int(det.id.item())
358
+ name = names[cat_id]
359
+ if name == "person":
360
+ persons[guid] = (age, gender)
361
+ elif name == "face":
362
+ faces[guid] = (age, gender)
363
+
364
+ return persons, faces
365
+
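A minimal sketch of how these per-frame results could feed the history consumed by set_tracked_age_gender above (the actual tracking loop lives elsewhere in the repo and is not shown in this diff; names here are illustrative):

from collections import defaultdict

history = defaultdict(list)  # track id -> list of (age, gender) observations

# per processed frame, after age/gender recognition has filled detected.ages / detected.genders:
persons, faces = detected.get_results_for_tracking()
for guid, age_gender in list(persons.items()) + list(faces.items()):
    history[guid].append(age_gender)

# smooth the current frame's predictions using the accumulated history
detected.set_tracked_age_gender(history)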
366
+ def associate_faces_with_persons(self):
367
+ face_bboxes_inds: List[int] = self.get_bboxes_inds("face")
368
+ person_bboxes_inds: List[int] = self.get_bboxes_inds("person")
369
+
370
+ face_bboxes: List[torch.Tensor] = [self.get_bbox_by_ind(ind) for ind in face_bboxes_inds]
371
+ person_bboxes: List[torch.Tensor] = [self.get_bbox_by_ind(ind) for ind in person_bboxes_inds]
372
+
373
+ self.face_to_person_map = {ind: None for ind in face_bboxes_inds}
374
+ assigned_faces, unassigned_persons_inds = assign_faces(person_bboxes, face_bboxes)
375
+
376
+ for face_ind, person_ind in enumerate(assigned_faces):
377
+ face_ind = face_bboxes_inds[face_ind]
378
+ person_ind = person_bboxes_inds[person_ind] if person_ind is not None else None
379
+ self.face_to_person_map[face_ind] = person_ind
380
+
381
+ self.unassigned_persons_inds = [person_bboxes_inds[person_ind] for person_ind in unassigned_persons_inds]
382
+
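assign_faces is imported from elsewhere in the repo and is not shown in this diff. As a reading aid only, here is a hedged sketch of one plausible greedy strategy with the same return contract (assigned_faces[i] is the person index matched to face i, or None, plus the indices of persons left without a face); the real implementation may use a different matching criterion, such as how much of the face box lies inside the person box:

from typing import List, Optional, Tuple

import torch
from torchvision.ops import box_iou


def assign_faces_sketch(
    person_bboxes: List[torch.Tensor], face_bboxes: List[torch.Tensor], iou_thresh: float = 1e-6
) -> Tuple[List[Optional[int]], List[int]]:
    assigned_faces: List[Optional[int]] = [None] * len(face_bboxes)
    taken = set()
    if person_bboxes and face_bboxes:
        # IoU between every face box and every person box
        ious = box_iou(torch.stack(face_bboxes).float(), torch.stack(person_bboxes).float())
        for face_ind in range(len(face_bboxes)):
            best = int(ious[face_ind].argmax())
            if float(ious[face_ind, best]) > iou_thresh and best not in taken:
                assigned_faces[face_ind] = best
                taken.add(best)
    unassigned_persons = [i for i in range(len(person_bboxes)) if i not in taken]
    return assigned_faces, unassigned_persons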
383
+ def crop_object(
384
+ self, full_image: np.ndarray, ind: int, cut_other_classes: Optional[List[str]] = None
385
+ ) -> Optional[np.ndarray]:
386
+
387
+ IOU_THRESH = 0.000001
388
+ MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4
389
+ CROP_ROUND_RATE = 0.3
390
+ MIN_PERSON_SIZE = 50
391
+
392
+ obj_bbox = self.get_bbox_by_ind(ind, *full_image.shape[:2])
393
+ x1, y1, x2, y2 = obj_bbox
394
+ cur_cat = self.yolo_results.names[int(self.yolo_results.boxes[ind].cls)]
395
+ # get crop of face or person
396
+ obj_image = full_image[y1:y2, x1:x2].copy()
397
+ crop_h, crop_w = obj_image.shape[:2]
398
+
399
+ if cur_cat == "person" and (crop_h < MIN_PERSON_SIZE or crop_w < MIN_PERSON_SIZE):
400
+ return None
401
+
402
+ if not cut_other_classes:
403
+ return obj_image
404
+
405
+ # calc iou between obj_bbox and other bboxes
406
+ other_bboxes: List[torch.Tensor] = [
407
+ self.get_bbox_by_ind(other_ind, *full_image.shape[:2]) for other_ind in range(len(self.yolo_results.boxes))
408
+ ]
409
+
410
+ iou_matrix = box_iou(torch.stack([obj_bbox]), torch.stack(other_bboxes)).cpu().numpy()[0]
411
+
412
+ # cut out other objects in case of intersection
413
+ for other_ind, (det, iou) in enumerate(zip(self.yolo_results.boxes, iou_matrix)):
414
+ other_cat = self.yolo_results.names[int(det.cls)]
415
+ if ind == other_ind or iou < IOU_THRESH or other_cat not in cut_other_classes:
416
+ continue
417
+ o_x1, o_y1, o_x2, o_y2 = det.xyxy.squeeze().type(torch.int32)
418
+
419
+ # remap current_person_bbox to reference_person_bbox coordinates
420
+ o_x1 = max(o_x1 - x1, 0)
421
+ o_y1 = max(o_y1 - y1, 0)
422
+ o_x2 = min(o_x2 - x1, crop_w)
423
+ o_y2 = min(o_y2 - y1, crop_h)
424
+
425
+ if other_cat != "face":
426
+ if (o_y1 / crop_h) < CROP_ROUND_RATE:
427
+ o_y1 = 0
428
+ if ((crop_h - o_y2) / crop_h) < CROP_ROUND_RATE:
429
+ o_y2 = crop_h
430
+ if (o_x1 / crop_w) < CROP_ROUND_RATE:
431
+ o_x1 = 0
432
+ if ((crop_w - o_x2) / crop_w) < CROP_ROUND_RATE:
433
+ o_x2 = crop_w
434
+
435
+ obj_image[o_y1:o_y2, o_x1:o_x2] = 0
436
+
437
+ obj_image, remain_ratio = cropout_black_parts(obj_image, CROP_ROUND_RATE)
438
+ if remain_ratio < MIN_PERSON_CROP_AFTERCUT_RATIO:
439
+ return None
440
+
441
+ return obj_image
442
+
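cropout_black_parts is also imported from elsewhere in the repo and is not part of this diff. A plausible minimal reading, matching how it is used above (return the crop with black borders trimmed plus the fraction of the original crop that survives), could look like the sketch below; the second parameter is accepted only to mirror the call site and is unused in this simplified version:

import numpy as np


def cropout_black_parts_sketch(img: np.ndarray, tol: float = 0.3):
    mask = img.sum(axis=2) > 0  # pixels that were not blacked out by the cut-out step
    if not mask.any():
        return img, 0.0
    ys, xs = np.where(mask)
    cropped = img[ys.min():ys.max() + 1, xs.min():xs.max() + 1]
    remain_ratio = (cropped.shape[0] * cropped.shape[1]) / float(img.shape[0] * img.shape[1])
    return cropped, remain_ratio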
443
+ def collect_crops(self, image) -> PersonAndFaceCrops:
444
+
445
+ crops_data = PersonAndFaceCrops()
446
+ for face_ind, person_ind in self.face_to_person_map.items():
447
+ face_image = self.crop_object(image, face_ind, cut_other_classes=[])
448
+
449
+ if person_ind is None:
450
+ crops_data.crops_faces_wo_body[face_ind] = face_image
451
+ continue
452
+
453
+ person_image = self.crop_object(image, person_ind, cut_other_classes=["face", "person"])
454
+
455
+ crops_data.crops_faces[face_ind] = face_image
456
+ crops_data.crops_persons[person_ind] = person_image
457
+
458
+ for person_ind in self.unassigned_persons_inds:
459
+ person_image = self.crop_object(image, person_ind, cut_other_classes=["face", "person"])
460
+ crops_data.crops_persons_wo_face[person_ind] = person_image
461
+
462
+ # uncomment to save preprocessed crops
463
+ # crops_data.save()
464
+ return crops_data
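End-to-end usage sketch for the classes above. It assumes an ultralytics YOLO checkpoint that detects both "person" and "face"; the weights and image paths are placeholders:

import cv2
from ultralytics import YOLO

detector = YOLO("models/yolov8x_person_face.pt")  # placeholder path to a person+face model
image = cv2.imread("input/photo.jpg")

results = detector.predict(image, verbose=False)[0]
detected = PersonAndFaceResult(results)
detected.associate_faces_with_persons()
crops = detected.collect_crops(image)
# crops.get_faces_with_bodies(use_persons=True, use_faces=True) now yields aligned
# face/body inputs for the downstream age and gender model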
mivolo/version.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.3.0dev"
product2item.py ADDED
@@ -0,0 +1,104 @@
1
+ import json
2
+ from urllib.parse import quote
3
+ from bs4 import BeautifulSoup
4
+ from selenium import webdriver
5
+ from utils import *
6
+
7
+
8
+ def append_dict_to_jsonl(dictionary, file_path='output/items.jsonl'):
9
+ with open(file_path, 'a', encoding='utf-8') as jsonl_file:
10
+ json.dump(dictionary, jsonl_file)
11
+ jsonl_file.write('\n')
12
+
13
+
14
+ def get_second_links(keyword):
15
+ # selenium
16
+ option = webdriver.ChromeOptions()
17
+ option.add_experimental_option('excludeSwitches', ['enable-automation'])
18
+ option.add_argument("--disable-blink-features=AutomationControlled")
19
+ # option.add_argument('--headless')
20
+ browser = webdriver.Chrome(options=option)
21
+ browser.get(f'https://www.taobao.com/list/product/{quote(keyword)}.htm')
22
+ # browser.minimize_window()
23
+ browser.maximize_window()
24
+
25
+ skip_captcha()
26
+
27
+ # Scroll through all items under the product page until every item has been loaded
28
+ i = 1
29
+ while True:  # keep scrolling until one of the break/return conditions below fires
30
+ browser.execute_script(
31
+ f'window.scrollTo(0, {i * 500})')
32
+ i += 1
33
+ rand_sleep()
34
+ page_str = str(browser.page_source)
35
+ if "<title>taobao | 淘寶</title>" in page_str:  # generic Taobao page title, no product list returned
36
+ return []
37
+
38
+ if "已加载全部商品" in page_str:  # "all items have been loaded"
39
+ break
40
+
41
+ if "加载错误,请重试" in page_str:  # "loading error, please retry"
42
+ break
43
+
44
+ html_content = browser.page_source
45
+
46
+ # bs4
47
+ soup = BeautifulSoup(html_content, 'html.parser')
48
+ return [link.get('href') for link in soup.find_all('a', class_='item')]
49
+
50
+
51
+ def read_lines_to_array(file_path):
52
+ create_dir('./' + os.path.dirname(file_path))
53
+ lines_array = []
54
+ with open(file_path, 'r', encoding='utf-8') as file:
55
+ for line in file:
56
+ lines_array.append(line.strip())
57
+
58
+ return lines_array
59
+
60
+
61
+ def save_to_file(data_list, file_path='output/items.jsonl'):
62
+ with open(file_path, 'w', encoding='utf-8') as jsonl_file:
63
+ for data in data_list:
64
+ json.dump(data, jsonl_file, ensure_ascii=(
65
+ file_path != 'output/items.jsonl'))  # ensure_ascii is disabled only for items.jsonl, so non-ASCII keywords are written unescaped there
66
+ jsonl_file.write('\n')
67
+
68
+
69
+ def rm_duplicates_by_key(file_path='output/items.jsonl', key_to_check='id'):
70
+ data_set = set()
71
+ unique_data = []
72
+ duplicates = set()
73
+
74
+ with open(file_path, 'r', encoding='utf-8') as jsonl_file:
75
+ for line in jsonl_file:
76
+ data = json.loads(line)
77
+
78
+ # 提取指定键值的值,并用作判断重复的标识
79
+ key_value = data.get(key_to_check)
80
+
81
+ # 如果标识值已存在,表示数据重复
82
+ if key_value in data_set:
83
+ duplicates.add(key_value)
84
+ continue
85
+ else:
86
+ data_set.add(key_value)
87
+ unique_data.append(data)
88
+
89
+ save_to_file(unique_data)
90
+ save_to_file(duplicates, file_path='output/duplicates.txt')
91
+
92
+
93
+ if __name__ == "__main__":
94
+ keywords = read_lines_to_array('input/keywords.txt')
95
+ create_dir('./output')
96
+
97
+ for key in keywords:
98
+ for url in get_second_links(key):
99
+ append_dict_to_jsonl({
100
+ 'keyword': key,
101
+ 'id': url.split('.htm?spm=')[0].split('//www.taobao.com/list/item/')[1]
102
+ })
103
+
104
+ rm_duplicates_by_key()
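Worked example of the id extraction in the __main__ block above (the URL is made up for illustration only):

url = "//www.taobao.com/list/item/1234567890.htm?spm=placeholder.1.2"
item_id = url.split('.htm?spm=')[0].split('//www.taobao.com/list/item/')[1]
print(item_id)  # -> 1234567890

rm_duplicates_by_key() then rewrites output/items.jsonl with one record per unique id and writes the duplicated ids to output/duplicates.txt.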
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ requests
2
+ beautifulsoup4
3
+ selenium
4
+ Cython==0.29.28
5
+ ultralytics
6
+ timm==0.8.13.dev0
7
+ omegaconf
utils.py ADDED
@@ -0,0 +1,16 @@
1
+ import os
2
+ import random
3
+ from time import sleep
4
+
5
+
6
+ def create_dir(dir_path):
7
+ if not os.path.exists(dir_path):
8
+ os.makedirs(dir_path)
9
+
10
+
11
+ def skip_captcha():
12
+ print('Skipping the captcha...')
13
+
14
+
15
+ def rand_sleep():
16
+ sleep(0.5 + random.random() * 0.5)  # random pause between 0.5 and 1.0 seconds