George committed on
Commit
61214ab
1 Parent(s): 6c635c4

change to clean model

conda.txt DELETED
@@ -1,5 +0,0 @@
1
- python=3.9
2
- pytorch=1.12.1
3
- torchvision=0.13.1
4
- torchaudio=0.12.1
5
- cudatoolkit=11.3.1
gender_age.py DELETED
@@ -1,93 +0,0 @@
1
- import cv2
2
- import shutil
3
- import numpy as np
4
- from dataclasses import dataclass
5
- from tqdm import tqdm
6
- from mivolo.predictor import Predictor
7
- from utils import *
8
-
9
- import warnings
10
- warnings.filterwarnings("ignore")
11
-
12
-
13
- @dataclass
14
- class Cfg:
15
- detector_weights: str
16
- checkpoint: str
17
- device: str = "cuda"
18
- with_persons: bool = True
19
- disable_faces: bool = False
20
- draw: bool = True
21
-
22
-
23
- class ValidImgDetector:
24
-
25
- predictor = None
26
-
27
- def __init__(self):
28
- detector_path = "./model/yolov8x_person_face.pt"
29
- age_gender_path = "./model/model_imdb_cross_person_4.22_99.46.pth.tar"
30
- predictor_cfg = Cfg(detector_path, age_gender_path)
31
- self.predictor = Predictor(predictor_cfg)
32
-
33
- def _detect(
34
- self,
35
- image: np.ndarray,
36
- score_threshold: float,
37
- iou_threshold: float,
38
- mode: str,
39
- predictor: Predictor
40
- ) -> np.ndarray:
41
- # input is rgb image, output must be rgb too
42
- predictor.detector.detector_kwargs['conf'] = score_threshold
43
- predictor.detector.detector_kwargs['iou'] = iou_threshold
44
-
45
- if mode == "Use persons and faces":
46
- use_persons = True
47
- disable_faces = False
48
- elif mode == "Use persons only":
49
- use_persons = True
50
- disable_faces = True
51
- elif mode == "Use faces only":
52
- use_persons = False
53
- disable_faces = False
54
-
55
- predictor.age_gender_model.meta.use_persons = use_persons
56
- predictor.age_gender_model.meta.disable_faces = disable_faces
57
-
58
- # image = image[:, :, ::-1] # RGB -> BGR
59
- detected_objects, _ = predictor.recognize(image)
60
-
61
- has_child, has_female, has_male = False, False, False
62
- if len(detected_objects.ages) > 0:
63
- has_child = min(detected_objects.ages) < 18
64
- has_female = 'female' in detected_objects.genders
65
- has_male = 'male' in detected_objects.genders
66
-
67
- return has_child, has_female, has_male
68
-
69
- def valid_img(self, img_path):
70
- image = cv2.imread(img_path)
71
- has_child, has_female, has_male = self._detect(
72
- image, 0.4, 0.7, "Use persons and faces", self.predictor)
73
- return (not has_child) and (has_female) and (not has_male)
74
-
75
-
76
- def filter_img():
77
- detector = ValidImgDetector()
78
- create_dir('./output/valid')
79
- create_dir('./output/invalid')
80
-
81
- for _, _, files in os.walk('./images'):
82
- for file in tqdm(files):
83
- if file.endswith('.jpg'):
84
- src_path = f"./images/{file}"
85
- dst_path = "./output/invalid"
86
- if detector.valid_img(src_path):
87
- dst_path = "./output/valid"
88
-
89
- shutil.move(src_path, dst_path)
90
-
91
-
92
- if __name__ == "__main__":
93
- filter_img()
human_detect.py DELETED
@@ -1,38 +0,0 @@
1
- import torch
2
- import torchvision.transforms as transforms
3
- from PIL import Image
4
- from torchvision.models.detection import fasterrcnn_resnet50_fpn
5
-
6
-
7
- def has_person(image_path):
8
- # Load the pretrained Faster R-CNN model
9
- model = fasterrcnn_resnet50_fpn(pretrained=True)
10
- model.eval()
11
-
12
- # Load and preprocess the image
13
- img = Image.open(image_path)
14
- transform = transforms.Compose([transforms.ToTensor()])
15
- input_tensor = transform(img)
16
- input_batch = input_tensor.unsqueeze(0)
17
-
18
- # Run model inference
19
- with torch.no_grad():
20
- output = model(input_batch)
21
-
22
- # Parse the model outputs
23
- labels = output[0]['labels'].numpy()
24
- scores = output[0]['scores'].numpy()
25
-
26
- # Check whether a person was detected (label == 1 is the "person" class)
27
- person_detected = any(label == 1 and score >
28
- 0.5 for label, score in zip(labels, scores))
29
-
30
- return person_detected
31
-
32
-
33
- if __name__ == "__main__":
34
- image_path = './images/test.jpg'
35
- if has_person(image_path):
36
- print("图片中检测到人体。")
37
- else:
38
- print("图片中没有检测到人体。")
item2pic.py DELETED
@@ -1,173 +0,0 @@
1
- import json
2
- import requests
3
- from bs4 import BeautifulSoup
4
- from selenium import webdriver
5
- from tqdm import tqdm
6
- from utils import *
7
-
8
- DEBUG_MODE = False
9
-
10
-
11
- def add_to_failist(urlog, file_path="./output/failist.txt"):
12
- with open(file_path, 'a', encoding='utf-8') as file:
13
- file.write(urlog + "\n")
14
-
15
- if DEBUG_MODE:
16
- print(urlog)
17
-
18
-
19
- def download_image(img_dir='./images'):
20
- create_dir(img_dir)
21
- image_urls = load_urls()
22
- print('Downloading images...')
23
- trytime = 0
24
- while len(image_urls) > 0:
25
- failist = []
26
- for img in tqdm(image_urls):
27
- sleeps(0.5 + 0.1 * trytime, 1.0 + 0.1 * trytime)
28
- response = requests.get(img['url'], stream=True)
29
-
30
- if response.status_code == 200:
31
- # Build the image filename from the URL
32
- image_filename = f'{img_dir}/{img["pid"]}_{img["url"].split("/")[-1]}'
33
-
34
- # Open the file in binary write mode and write the image data
35
- with open(image_filename, 'wb') as file:
36
- for chunk in response.iter_content(chunk_size=8192):
37
- file.write(chunk)
38
-
39
- if DEBUG_MODE:
40
- print(f"{image_filename} 下载完成!")
41
-
42
- elif response.status_code == 420:
43
- failist.append(img)
44
-
45
- else:
46
- add_to_failist(
47
- f"下载 {img['url']} 失败: HTTP 错误码 {response.status_code}")
48
-
49
- trytime += 1
50
- print(
51
- f'[{len(failist)} / {len(image_urls)}] images failed to download in attempt [{trytime}].')
52
- image_urls = failist
53
-
54
- print('Download complete!')
55
-
56
-
57
- def fix_url(link):
58
- tmp_url = link.get('src')
59
-
60
- if tmp_url[:2] == '//':
61
- tmp_url = 'https:' + tmp_url
62
-
63
- if '.png_' in tmp_url:
64
- tmp_url = tmp_url.split('.png_')[0] + '.png'
65
-
66
- elif '.gif_' in tmp_url:
67
- tmp_url = tmp_url.split('.gif_')[0] + '.gif'
68
-
69
- else:
70
- tmp_url = tmp_url.split('.jpg_')[0] + '.jpg'
71
-
72
- return tmp_url
73
-
74
-
75
- def get_pics(id):
76
- sleeps(1.0, 1.5)
77
- # selenium
78
- option = webdriver.ChromeOptions()
79
- option.add_experimental_option('excludeSwitches', ['enable-automation'])
80
- option.add_argument("--disable-blink-features=AutomationControlled")
81
- # option.add_argument('--headless')
82
- browser = webdriver.Chrome(options=option)
83
- browser.get(f'https://www.taobao.com/list/item/{id}.htm')
84
- # browser.minimize_window()
85
- browser.maximize_window()
86
-
87
- skip_captcha()
88
-
89
- # bs4
90
- soup = BeautifulSoup(browser.page_source, 'html.parser')
91
- srcs = set()
92
-
93
- try:
94
- for link in soup.find_all('img', class_='item-thumbnail'):
95
- srcs.add(fix_url(link))
96
-
97
- for link in soup.find_all('img', class_='property-img'):
98
- srcs.add(fix_url(link))
99
-
100
- for link in soup.find('div', class_='detail-content').find('p').find_all('img'):
101
- srcs.add(fix_url(link))
102
-
103
- except Exception as err:
104
- print("Error: ", err)
105
-
106
- return srcs
107
-
108
-
109
- def load_items(items_jsonl_path='./output/items.jsonl'):
110
- ids = []
111
- with open(items_jsonl_path, 'r', encoding='utf-8') as items_jsonl:
112
- for line in items_jsonl:
113
- # Parse the JSON string into a Python object
114
- data = json.loads(line)
115
- # Get the value of the 'id' key and append it to the list
116
- id_value = data.get('id')
117
- if id_value is not None:
118
- ids.append(id_value)
119
-
120
- return ids
121
-
122
-
123
- def get_img_urls(ids, images_jsonl_path="./output/images.jsonl"):
124
- for id in ids:
125
- urls = get_pics(id)
126
- with open(images_jsonl_path, 'a', encoding='utf-8') as images_jsonl:
127
- for url in urls:
128
- img = {
129
- 'url': url,
130
- 'pid': id
131
- }
132
- json.dump(img, images_jsonl)
133
- images_jsonl.write('\n')
134
-
135
-
136
- def load_urls(images_jsonl_path="./output/images.jsonl"):
137
- urls = []
138
- with open(images_jsonl_path, 'r', encoding='utf-8') as items_jsonl:
139
- for line in items_jsonl:
140
- # Parse the JSON string into a Python object
141
- data = json.loads(line)
142
- tmp_dict = {
143
- 'url': data.get('url'),
144
- 'pid': data.get('pid')
145
- }
146
- if tmp_dict is not None:
147
- urls.append(tmp_dict)
148
-
149
- return urls
150
-
151
-
152
- def item_to_pic():
153
- create_dir('./images')
154
- ids = load_items()
155
- get_img_urls(ids)
156
- rm_duplicates_by_key(
157
- jsonl_path='./output/images.jsonl',
158
- key_to_check='url',
159
- failist_path='./output/duplicate_img.txt'
160
- )
161
- download_image()
162
-
163
-
164
- if __name__ == "__main__":
165
- # create_dir('./images')
166
- # ids = load_items()
167
- # get_img_urls(ids)
168
- # rm_duplicates_by_key(
169
- # jsonl_path='./output/images.jsonl',
170
- # key_to_check='url',
171
- # failist_path='./output/duplicate_img.txt'
172
- # )
173
- download_image()
main.py DELETED
@@ -1,9 +0,0 @@
1
- from product2item import product_to_items
2
- from item2pic import item_to_pic
3
- from gender_age import filter_img
4
-
5
-
6
- if __name__ == "__main__":
7
- product_to_items()
8
- item_to_pic()
9
- filter_img()
mivolo/data/data_reader.py DELETED
@@ -1,125 +0,0 @@
1
- import os
2
- from collections import defaultdict
3
- from dataclasses import dataclass, field
4
- from enum import Enum
5
- from typing import Dict, List, Optional, Tuple
6
-
7
- import pandas as pd
8
-
9
- IMAGES_EXT: Tuple = (".jpeg", ".jpg", ".png", ".webp", ".bmp", ".gif")
10
- VIDEO_EXT: Tuple = (".mp4", ".avi", ".mov", ".mkv", ".webm")
11
-
12
-
13
- @dataclass
14
- class PictureInfo:
15
- image_path: str
16
- age: Optional[str]  # age, or an age range in "start;end" format, or "-1"
17
- gender: Optional[str]  # "M" or "F" or "-1"
18
- bbox: List[int] = field(default_factory=lambda: [-1, -1, -1, -1]) # face bbox: xyxy
19
- person_bbox: List[int] = field(default_factory=lambda: [-1, -1, -1, -1]) # person bbox: xyxy
20
-
21
- @property
22
- def has_person_bbox(self) -> bool:
23
- return any(coord != -1 for coord in self.person_bbox)
24
-
25
- @property
26
- def has_face_bbox(self) -> bool:
27
- return any(coord != -1 for coord in self.bbox)
28
-
29
- def has_gt(self, only_age: bool = False) -> bool:
30
- if only_age:
31
- return self.age != "-1"
32
- else:
33
- return not (self.age == "-1" and self.gender == "-1")
34
-
35
- def clear_person_bbox(self):
36
- self.person_bbox = [-1, -1, -1, -1]
37
-
38
- def clear_face_bbox(self):
39
- self.bbox = [-1, -1, -1, -1]
40
-
41
-
42
- class AnnotType(Enum):
43
- ORIGINAL = "original"
44
- PERSONS = "persons"
45
- NONE = "none"
46
-
47
- @classmethod
48
- def _missing_(cls, value):
49
- print(f"WARN: Unknown annotation type {value}.")
50
- return AnnotType.NONE
51
-
52
-
53
- def get_all_files(path: str, extensions: Tuple = IMAGES_EXT):
54
- files_all = []
55
- for root, subFolders, files in os.walk(path):
56
- for name in files:
57
- # skip Linux ".directory" entries, which are still regular files
58
- if "directory" not in name and sum([ext.lower() in name.lower() for ext in extensions]) > 0:
59
- files_all.append(os.path.join(root, name))
60
- return files_all
61
-
62
-
63
- class InputType(Enum):
64
- Image = 0
65
- Video = 1
66
- VideoStream = 2
67
-
68
-
69
- def get_input_type(input_path: str) -> InputType:
70
- if os.path.isdir(input_path):
71
- print("Input is a folder, only images will be processed")
72
- return InputType.Image
73
- elif os.path.isfile(input_path):
74
- if input_path.endswith(VIDEO_EXT):
75
- return InputType.Video
76
- if input_path.endswith(IMAGES_EXT):
77
- return InputType.Image
78
- else:
79
- raise ValueError(
80
- f"Unknown or unsupported input file format {input_path}, \
81
- supported video formats: {VIDEO_EXT}, \
82
- supported image formats: {IMAGES_EXT}"
83
- )
84
- elif input_path.startswith("http") and not input_path.endswith(IMAGES_EXT):
85
- return InputType.VideoStream
86
- else:
87
- raise ValueError(f"Unknown input {input_path}")
88
-
89
-
90
- def read_csv_annotation_file(annotation_file: str, images_dir: str, ignore_without_gt=False):
91
- bboxes_per_image: Dict[str, List[PictureInfo]] = defaultdict(list)
92
-
93
- df = pd.read_csv(annotation_file, sep=",")
94
-
95
- annot_type = AnnotType("persons") if "person_x0" in df.columns else AnnotType("original")
96
- print(f"Reading {annotation_file} (type: {annot_type})...")
97
-
98
- missing_images = 0
99
- for index, row in df.iterrows():
100
- img_path = os.path.join(images_dir, row["img_name"])
101
- if not os.path.exists(img_path):
102
- missing_images += 1
103
- continue
104
-
105
- face_x1, face_y1, face_x2, face_y2 = row["face_x0"], row["face_y0"], row["face_x1"], row["face_y1"]
106
- age, gender = str(row["age"]), str(row["gender"])
107
-
108
- if ignore_without_gt and (age == "-1" or gender == "-1"):
109
- continue
110
-
111
- if annot_type == AnnotType.PERSONS:
112
- p_x1, p_y1, p_x2, p_y2 = row["person_x0"], row["person_y0"], row["person_x1"], row["person_y1"]
113
- person_bbox = list(map(int, [p_x1, p_y1, p_x2, p_y2]))
114
- else:
115
- person_bbox = [-1, -1, -1, -1]
116
-
117
- bbox = list(map(int, [face_x1, face_y1, face_x2, face_y2]))
118
- pic_info = PictureInfo(img_path, age, gender, bbox, person_bbox)
119
- assert isinstance(pic_info.person_bbox, list)
120
-
121
- bboxes_per_image[img_path].append(pic_info)
122
-
123
- if missing_images > 0:
124
- print(f"WARNING: Missing images: {missing_images}/{len(df)}")
125
- return bboxes_per_image, annot_type
mivolo/data/dataset/__init__.py DELETED
@@ -1,64 +0,0 @@
1
- from typing import Tuple
2
-
3
- import torch
4
- from mivolo.model.mi_volo import MiVOLO
5
-
6
- from .age_gender_dataset import AgeGenderDataset
7
- from .age_gender_loader import create_loader
8
- from .classification_dataset import AdienceDataset, FairFaceDataset
9
-
10
- DATASET_CLASS_MAP = {
11
- "utk": AgeGenderDataset,
12
- "lagenda": AgeGenderDataset,
13
- "imdb": AgeGenderDataset,
14
- "adience": AdienceDataset,
15
- "fairface": FairFaceDataset,
16
- }
17
-
18
-
19
- def build(
20
- name: str,
21
- images_path: str,
22
- annotations_path: str,
23
- split: str,
24
- mivolo_model: MiVOLO,
25
- workers: int,
26
- batch_size: int,
27
- ) -> Tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]:
28
-
29
- dataset_class = DATASET_CLASS_MAP[name]
30
-
31
- dataset: torch.utils.data.Dataset = dataset_class(
32
- images_path=images_path,
33
- annotations_path=annotations_path,
34
- name=name,
35
- split=split,
36
- target_size=mivolo_model.input_size,
37
- max_age=mivolo_model.meta.max_age,
38
- min_age=mivolo_model.meta.min_age,
39
- model_with_persons=mivolo_model.meta.with_persons_model,
40
- use_persons=mivolo_model.meta.use_persons,
41
- disable_faces=mivolo_model.meta.disable_faces,
42
- only_age=mivolo_model.meta.only_age,
43
- )
44
-
45
- data_config = mivolo_model.data_config
46
-
47
- in_chans = 3 if not mivolo_model.meta.with_persons_model else 6
48
- input_size = (in_chans, mivolo_model.input_size, mivolo_model.input_size)
49
-
50
- dataset_loader: torch.utils.data.DataLoader = create_loader(
51
- dataset,
52
- input_size=input_size,
53
- batch_size=batch_size,
54
- mean=data_config["mean"],
55
- std=data_config["std"],
56
- num_workers=workers,
57
- crop_pct=data_config["crop_pct"],
58
- crop_mode=data_config["crop_mode"],
59
- pin_memory=False,
60
- device=mivolo_model.device,
61
- target_type=dataset.target_dtype,
62
- )
63
-
64
- return dataset, dataset_loader
mivolo/data/dataset/age_gender_dataset.py DELETED
@@ -1,194 +0,0 @@
1
- import logging
2
- from typing import Any, List, Optional, Set
3
-
4
- import cv2
5
- import numpy as np
6
- import torch
7
- from mivolo.data.dataset.reader_age_gender import ReaderAgeGender
8
- from PIL import Image
9
- from torchvision import transforms
10
-
11
- _logger = logging.getLogger("AgeGenderDataset")
12
-
13
-
14
- class AgeGenderDataset(torch.utils.data.Dataset):
15
- def __init__(
16
- self,
17
- images_path,
18
- annotations_path,
19
- name=None,
20
- split="train",
21
- load_bytes=False,
22
- img_mode="RGB",
23
- transform=None,
24
- is_training=False,
25
- seed=1234,
26
- target_size=224,
27
- min_age=None,
28
- max_age=None,
29
- model_with_persons=False,
30
- use_persons=False,
31
- disable_faces=False,
32
- only_age=False,
33
- ):
34
- reader = ReaderAgeGender(
35
- images_path,
36
- annotations_path,
37
- split=split,
38
- seed=seed,
39
- target_size=target_size,
40
- with_persons=use_persons,
41
- disable_faces=disable_faces,
42
- only_age=only_age,
43
- )
44
-
45
- self.name = name
46
- self.model_with_persons = model_with_persons
47
- self.reader = reader
48
- self.load_bytes = load_bytes
49
- self.img_mode = img_mode
50
- self.transform = transform
51
- self._consecutive_errors = 0
52
- self.is_training = is_training
53
- self.random_flip = 0.0
54
-
55
- # Setting up classes.
56
- # If min and max ages are passed, use them so that validation gets the same preprocessing
57
- self.max_age: float = None
58
- self.min_age: float = None
59
- self.avg_age: float = None
60
- self.set_ages_min_max(min_age, max_age)
61
-
62
- self.genders = ["M", "F"]
63
- self.num_classes_gender = len(self.genders)
64
-
65
- self.age_classes: Optional[List[str]] = self.set_age_classes()
66
-
67
- self.num_classes_age = 1 if self.age_classes is None else len(self.age_classes)
68
- self.num_classes: int = self.num_classes_age + self.num_classes_gender
69
- self.target_dtype = torch.float32
70
-
71
- def set_age_classes(self) -> Optional[List[str]]:
72
- return None # for regression dataset
73
-
74
- def set_ages_min_max(self, min_age: Optional[float], max_age: Optional[float]):
75
-
76
- assert all(age is None for age in [min_age, max_age]) or all(
77
- age is not None for age in [min_age, max_age]
78
- ), "Both min and max age must be passed or none of them"
79
-
80
- if max_age is not None and min_age is not None:
81
- _logger.info(f"Received predefined min_age {min_age} and max_age {max_age}")
82
- self.max_age = max_age
83
- self.min_age = min_age
84
- else:
85
- # collect statistics from loaded dataset
86
- all_ages_set: Set[int] = set()
87
- for img_path, image_samples in self.reader._ann.items():
88
- for image_sample_info in image_samples:
89
- if image_sample_info.age == "-1":
90
- continue
91
- age = round(float(image_sample_info.age))
92
- all_ages_set.add(age)
93
-
94
- self.max_age = max(all_ages_set)
95
- self.min_age = min(all_ages_set)
96
-
97
- self.avg_age = (self.max_age + self.min_age) / 2.0
98
-
99
- def _norm_age(self, age):
100
- return (age - self.avg_age) / (self.max_age - self.min_age)
101
-
102
- def parse_gender(self, _gender: str) -> float:
103
- if _gender != "-1":
104
- gender = float(0 if _gender == "M" or _gender == "0" else 1)
105
- else:
106
- gender = -1
107
- return gender
108
-
109
- def parse_target(self, _age: str, gender: str) -> List[Any]:
110
- if _age != "-1":
111
- age = round(float(_age))
112
- age = self._norm_age(float(age))
113
- else:
114
- age = -1
115
-
116
- target: List[float] = [age, self.parse_gender(gender)]
117
- return target
118
-
119
- @property
120
- def transform(self):
121
- return self._transform
122
-
123
- @transform.setter
124
- def transform(self, transform):
125
- # Disable pretrained monkey-patched transforms
126
- if not transform:
127
- return
128
-
129
- _trans = []
130
- for trans in transform.transforms:
131
- if "Resize" in str(trans):
132
- continue
133
- if "Crop" in str(trans):
134
- continue
135
- _trans.append(trans)
136
- self._transform = transforms.Compose(_trans)
137
-
138
- def apply_tranforms(self, image: Optional[np.ndarray]) -> np.ndarray:
139
- if image is None:
140
- return None
141
-
142
- if self.transform is None:
143
- return image
144
-
145
- image = convert_to_pil(image, self.img_mode)
146
- for trans in self.transform.transforms:
147
- image = trans(image)
148
- return image
149
-
150
- def __getitem__(self, index):
151
- # get preprocessed face and person crops (np.ndarray)
152
- # resize + pad, for person crops: cut off other bboxes
153
- images, target = self.reader[index]
154
-
155
- target = self.parse_target(*target)
156
-
157
- if self.model_with_persons:
158
- face_image, person_image = images
159
- person_image: np.ndarray = self.apply_tranforms(person_image)
160
- else:
161
- face_image = images[0]
162
- person_image = None
163
-
164
- face_image: np.ndarray = self.apply_tranforms(face_image)
165
-
166
- if person_image is not None:
167
- img = np.concatenate([face_image, person_image], axis=0)
168
- else:
169
- img = face_image
170
-
171
- return img, target
172
-
173
- def __len__(self):
174
- return len(self.reader)
175
-
176
- def filename(self, index, basename=False, absolute=False):
177
- return self.reader.filename(index, basename, absolute)
178
-
179
- def filenames(self, basename=False, absolute=False):
180
- return self.reader.filenames(basename, absolute)
181
-
182
-
183
- def convert_to_pil(cv_im: Optional[np.ndarray], img_mode: str = "RGB") -> "Image":
184
- if cv_im is None:
185
- return None
186
-
187
- if img_mode == "RGB":
188
- cv_im = cv2.cvtColor(cv_im, cv2.COLOR_BGR2RGB)
189
- else:
190
- raise Exception("Incorrect image mode has been passed!")
191
-
192
- cv_im = np.ascontiguousarray(cv_im)
193
- pil_image = Image.fromarray(cv_im)
194
- return pil_image
mivolo/data/dataset/age_gender_loader.py DELETED
@@ -1,169 +0,0 @@
1
- """
2
- Code adapted from timm https://github.com/huggingface/pytorch-image-models
3
-
4
- Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
- """
6
-
7
- import logging
8
- from contextlib import suppress
9
- from functools import partial
10
- from itertools import repeat
11
-
12
- import numpy as np
13
- import torch
14
- import torch.utils.data
15
- from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16
- from timm.data.dataset import IterableImageDataset
17
- from timm.data.loader import PrefetchLoader, _worker_init
18
- from timm.data.transforms_factory import create_transform
19
-
20
- _logger = logging.getLogger(__name__)
21
-
22
-
23
- def fast_collate(batch, target_dtype=torch.uint8):
24
- """A fast collation function optimized for uint8 images (np array or torch) and target_dtype targets (labels)"""
25
- assert isinstance(batch[0], tuple)
26
- batch_size = len(batch)
27
- if isinstance(batch[0][0], np.ndarray):
28
- targets = torch.tensor([b[1] for b in batch], dtype=target_dtype)
29
- assert len(targets) == batch_size
30
- tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
31
- for i in range(batch_size):
32
- tensor[i] += torch.from_numpy(batch[i][0])
33
- return tensor, targets
34
- else:
35
- raise ValueError(f"Incorrect batch type: {type(batch[0][0])}")
36
-
37
-
38
- def adapt_to_chs(x, n):
39
- if not isinstance(x, (tuple, list)):
40
- x = tuple(repeat(x, n))
41
- elif len(x) != n:
42
- # doubled channels
43
- if len(x) * 2 == n:
44
- x = np.concatenate((x, x))
45
- _logger.warning(f"Pretrained mean/std different shape than model (doubled channels), using concat: {x}.")
46
- else:
47
- x_mean = np.mean(x).item()
48
- x = (x_mean,) * n
49
- _logger.warning(f"Pretrained mean/std different shape than model, using avg value {x}.")
50
- else:
51
- assert len(x) == n, "normalization stats must match image channels"
52
- return x
53
-
54
-
55
- class PrefetchLoaderForMultiInput(PrefetchLoader):
56
- def __init__(
57
- self,
58
- loader,
59
- mean=IMAGENET_DEFAULT_MEAN,
60
- std=IMAGENET_DEFAULT_STD,
61
- channels=3,
62
- device=torch.device("cuda"),
63
- img_dtype=torch.float32,
64
- ):
65
-
66
- mean = adapt_to_chs(mean, channels)
67
- std = adapt_to_chs(std, channels)
68
- normalization_shape = (1, channels, 1, 1)
69
-
70
- self.loader = loader
71
- self.device = device
72
- self.img_dtype = img_dtype
73
- self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
74
- self.std = torch.tensor([x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
75
-
76
- self.is_cuda = torch.cuda.is_available() and device.type == "cuda"
77
-
78
- def __iter__(self):
79
- first = True
80
- if self.is_cuda:
81
- stream = torch.cuda.Stream()
82
- stream_context = partial(torch.cuda.stream, stream=stream)
83
- else:
84
- stream = None
85
- stream_context = suppress
86
-
87
- for next_input, next_target in self.loader:
88
-
89
- with stream_context():
90
- next_input = next_input.to(device=self.device, non_blocking=True)
91
- next_target = next_target.to(device=self.device, non_blocking=True)
92
- next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
93
-
94
- if not first:
95
- yield input, target # noqa: F823, F821
96
- else:
97
- first = False
98
-
99
- if stream is not None:
100
- torch.cuda.current_stream().wait_stream(stream)
101
-
102
- input = next_input
103
- target = next_target
104
-
105
- yield input, target
106
-
107
-
108
- def create_loader(
109
- dataset,
110
- input_size,
111
- batch_size,
112
- mean=IMAGENET_DEFAULT_MEAN,
113
- std=IMAGENET_DEFAULT_STD,
114
- num_workers=1,
115
- crop_pct=None,
116
- crop_mode=None,
117
- pin_memory=False,
118
- img_dtype=torch.float32,
119
- device=torch.device("cuda"),
120
- persistent_workers=True,
121
- worker_seeding="all",
122
- target_type=torch.int64,
123
- ):
124
-
125
- transform = create_transform(
126
- input_size,
127
- is_training=False,
128
- use_prefetcher=True,
129
- mean=mean,
130
- std=std,
131
- crop_pct=crop_pct,
132
- crop_mode=crop_mode,
133
- )
134
- dataset.transform = transform
135
-
136
- if isinstance(dataset, IterableImageDataset):
137
- # give Iterable datasets early knowledge of num_workers so that sample estimates
138
- # are correct before worker processes are launched
139
- dataset.set_loader_cfg(num_workers=num_workers)
140
- raise ValueError("Incorrect dataset type: IterableImageDataset")
141
-
142
- loader_class = torch.utils.data.DataLoader
143
- loader_args = dict(
144
- batch_size=batch_size,
145
- shuffle=False,
146
- num_workers=num_workers,
147
- sampler=None,
148
- collate_fn=lambda batch: fast_collate(batch, target_dtype=target_type),
149
- pin_memory=pin_memory,
150
- drop_last=False,
151
- worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
152
- persistent_workers=persistent_workers,
153
- )
154
- try:
155
- loader = loader_class(dataset, **loader_args)
156
- except TypeError:
157
- loader_args.pop("persistent_workers") # only in Pytorch 1.7+
158
- loader = loader_class(dataset, **loader_args)
159
-
160
- loader = PrefetchLoaderForMultiInput(
161
- loader,
162
- mean=mean,
163
- std=std,
164
- channels=input_size[0],
165
- device=device,
166
- img_dtype=img_dtype,
167
- )
168
-
169
- return loader
mivolo/data/dataset/classification_dataset.py DELETED
@@ -1,48 +0,0 @@
1
- from typing import Any, List, Optional
2
-
3
- import torch
4
-
5
- from .age_gender_dataset import AgeGenderDataset
6
-
7
-
8
- class ClassificationDataset(AgeGenderDataset):
9
- def __init__(self, *args, **kwargs):
10
- super().__init__(*args, **kwargs)
11
-
12
- self.target_dtype = torch.int32
13
-
14
- def set_age_classes(self) -> Optional[List[str]]:
15
- raise NotImplementedError
16
-
17
- def parse_target(self, age: str, gender: str) -> List[Any]:
18
- assert self.age_classes is not None
19
- if age != "-1":
20
- assert age in self.age_classes, f"Unknown category in {self.name} dataset: {age}"
21
- age_ind = self.age_classes.index(age)
22
- else:
23
- age_ind = -1
24
-
25
- target: List[int] = [age_ind, int(self.parse_gender(gender))]
26
- return target
27
-
28
-
29
- class FairFaceDataset(ClassificationDataset):
30
- def set_age_classes(self) -> Optional[List[str]]:
31
- age_classes = ["0;2", "3;9", "10;19", "20;29", "30;39", "40;49", "50;59", "60;69", "70;120"]
32
- # a[i-1] <= v < a[i] => age_classes[i-1]
33
- self._intervals = torch.tensor([0, 3, 10, 20, 30, 40, 50, 60, 70])
34
-
35
- return age_classes
36
-
37
-
38
- class AdienceDataset(ClassificationDataset):
39
- def __init__(self, *args, **kwargs):
40
- super().__init__(*args, **kwargs)
41
-
42
- self.target_dtype = torch.int32
43
-
44
- def set_age_classes(self) -> Optional[List[str]]:
45
- age_classes = ["0;2", "4;6", "8;12", "15;20", "25;32", "38;43", "48;53", "60;100"]
46
- # a[i-1] <= v < a[i] => age_classes[i-1]
47
- self._intervals = torch.tensor([0, 4, 7, 14, 24, 36, 46, 57])
48
- return age_classes
mivolo/data/dataset/reader_age_gender.py DELETED
@@ -1,490 +0,0 @@
1
- import logging
2
- import os
3
- from functools import partial
4
- from multiprocessing.pool import ThreadPool
5
- from typing import Dict, List, Optional, Tuple
6
-
7
- import cv2
8
- import numpy as np
9
- from mivolo.data.data_reader import AnnotType, PictureInfo, get_all_files, read_csv_annotation_file
10
- from mivolo.data.misc import IOU, class_letterbox, cropout_black_parts
11
- from timm.data.readers.reader import Reader
12
- from tqdm import tqdm
13
-
14
- CROP_ROUND_TOL = 0.3
15
- MIN_PERSON_SIZE = 100
16
- MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4
17
-
18
- _logger = logging.getLogger("ReaderAgeGender")
19
-
20
-
21
- class ReaderAgeGender(Reader):
22
- """
23
- Reader for almost original imdb-wiki cleaned dataset.
24
- Two changes:
25
- 1. Your annotation must be in ./annotation subdir of dataset root
26
- 2. Images must be in images subdir
27
-
28
- """
29
-
30
- def __init__(
31
- self,
32
- images_path,
33
- annotations_path,
34
- split="validation",
35
- target_size=224,
36
- min_size=5,
37
- seed=1234,
38
- with_persons=False,
39
- min_person_size=MIN_PERSON_SIZE,
40
- disable_faces=False,
41
- only_age=False,
42
- min_person_aftercut_ratio=MIN_PERSON_CROP_AFTERCUT_RATIO,
43
- crop_round_tol=CROP_ROUND_TOL,
44
- ):
45
- super().__init__()
46
-
47
- self.with_persons = with_persons
48
- self.disable_faces = disable_faces
49
- self.only_age = only_age
50
-
51
- # can be only black for now, even though it's not very good with further normalization
52
- self.crop_out_color = (0, 0, 0)
53
-
54
- self.empty_crop = np.ones((target_size, target_size, 3)) * self.crop_out_color
55
- self.empty_crop = self.empty_crop.astype(np.uint8)
56
-
57
- self.min_person_size = min_person_size
58
- self.min_person_aftercut_ratio = min_person_aftercut_ratio
59
- self.crop_round_tol = crop_round_tol
60
-
61
- self.split = split
62
- self.min_size = min_size
63
- self.seed = seed
64
- self.target_size = target_size
65
-
66
- # Reading annotations. There can be multiple files if annotations_path is a directory
67
- self._ann: Dict[str, List[PictureInfo]] = {} # list of samples for each image
68
- self._associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
69
- self._faces_list: List[Tuple[str, int]] = [] # samples from this list will be loaded in __getitem__
70
-
71
- self._read_annotations(images_path, annotations_path)
72
- _logger.info(f"Dataset length: {len(self._faces_list)} crops")
73
-
74
- def __getitem__(self, index):
75
- return self._read_img_and_label(index)
76
-
77
- def __len__(self):
78
- return len(self._faces_list)
79
-
80
- def _filename(self, index, basename=False, absolute=False):
81
- img_p = self._faces_list[index][0]
82
- return os.path.basename(img_p) if basename else img_p
83
-
84
- def _read_annotations(self, images_path, csvs_path):
85
- self._ann = {}
86
- self._faces_list = []
87
- self._associated_objects = {}
88
-
89
- csvs = get_all_files(csvs_path, [".csv"])
90
- csvs = [c for c in csvs if self.split in os.path.basename(c)]
91
-
92
- # load annotations per image
93
- for csv in csvs:
94
- db, ann_type = read_csv_annotation_file(csv, images_path)
95
- if self.with_persons and ann_type != AnnotType.PERSONS:
96
- raise ValueError(
97
- f"Annotation type in file {csv} contains no persons, "
98
- f"but annotations with persons are requested."
99
- )
100
- self._ann.update(db)
101
-
102
- if len(self._ann) == 0:
103
- raise ValueError("Annotations are empty!")
104
-
105
- self._ann, self._associated_objects = self.prepare_annotations()
106
- images_list = list(self._ann.keys())
107
-
108
- for img_path in images_list:
109
- for index, image_sample_info in enumerate(self._ann[img_path]):
110
- assert image_sample_info.has_gt(
111
- self.only_age
112
- ), "Annotations must be checked with self.prepare_annotations() func"
113
- self._faces_list.append((img_path, index))
114
-
115
- def _read_img_and_label(self, index):
116
- if not isinstance(index, int):
117
- raise TypeError("ReaderAgeGender expected index to be integer")
118
-
119
- img_p, face_index = self._faces_list[index]
120
- ann: PictureInfo = self._ann[img_p][face_index]
121
- img = cv2.imread(img_p)
122
-
123
- face_empty = True
124
- if ann.has_face_bbox and not (self.with_persons and self.disable_faces):
125
- face_crop, face_empty = self._get_crop(ann.bbox, img)
126
-
127
- if not self.with_persons and face_empty:
128
- # model without persons
129
- raise ValueError("Annotations must be checked with self.prepare_annotations() func")
130
-
131
- if face_empty:
132
- face_crop = self.empty_crop
133
-
134
- person_empty = True
135
- if self.with_persons or self.disable_faces:
136
- if ann.has_person_bbox:
137
- # cut off all associated objects from person crop
138
- objects = self._associated_objects[img_p][face_index]
139
- person_crop, person_empty = self._get_crop(
140
- ann.person_bbox,
141
- img,
142
- crop_out_color=self.crop_out_color,
143
- asced_objects=objects,
144
- )
145
-
146
- if face_empty and person_empty:
147
- raise ValueError("Annotations must be checked with self.prepare_annotations() func")
148
-
149
- if person_empty:
150
- person_crop = self.empty_crop
151
-
152
- return (face_crop, person_crop), [ann.age, ann.gender]
153
-
154
- def _get_crop(
155
- self,
156
- bbox,
157
- img,
158
- asced_objects=None,
159
- crop_out_color=(0, 0, 0),
160
- ) -> Tuple[np.ndarray, bool]:
161
-
162
- empty_bbox = False
163
-
164
- xmin, ymin, xmax, ymax = bbox
165
- assert not (
166
- ymax - ymin < self.min_size or xmax - xmin < self.min_size
167
- ), "Annotations must be checked with self.prepare_annotations() func"
168
-
169
- crop = img[ymin:ymax, xmin:xmax]
170
-
171
- if asced_objects:
172
- # cut off other objects for person crop
173
- crop, empty_bbox = _cropout_asced_objs(
174
- asced_objects,
175
- bbox,
176
- crop.copy(),
177
- crop_out_color=crop_out_color,
178
- min_person_size=self.min_person_size,
179
- crop_round_tol=self.crop_round_tol,
180
- min_person_aftercut_ratio=self.min_person_aftercut_ratio,
181
- )
182
- if empty_bbox:
183
- crop = self.empty_crop
184
-
185
- crop = class_letterbox(crop, new_shape=(self.target_size, self.target_size), color=crop_out_color)
186
- return crop, empty_bbox
187
-
188
- def prepare_annotations(self):
189
-
190
- good_anns: Dict[str, List[PictureInfo]] = {}
191
- all_associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
192
-
193
- if not self.with_persons:
194
- # remove all persons
195
- for img_path, bboxes in self._ann.items():
196
- for sample in bboxes:
197
- sample.clear_person_bbox()
198
-
199
- # check dataset and collect associated_objects
200
- verify_images_func = partial(
201
- verify_images,
202
- min_size=self.min_size,
203
- min_person_size=self.min_person_size,
204
- with_persons=self.with_persons,
205
- disable_faces=self.disable_faces,
206
- crop_round_tol=self.crop_round_tol,
207
- min_person_aftercut_ratio=self.min_person_aftercut_ratio,
208
- only_age=self.only_age,
209
- )
210
- num_threads = min(8, os.cpu_count())
211
-
212
- all_msgs = []
213
- broken = 0
214
- skipped = 0
215
- all_skipped_crops = 0
216
- desc = "Check annotations..."
217
- with ThreadPool(num_threads) as pool:
218
- pbar = tqdm(
219
- pool.imap_unordered(verify_images_func, list(self._ann.items())),
220
- desc=desc,
221
- total=len(self._ann),
222
- )
223
-
224
- for (img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops) in pbar:
225
- broken += 1 if is_corrupted else 0
226
- all_msgs.extend(msgs)
227
- all_skipped_crops += skipped_crops
228
- skipped += 1 if is_empty_annotations else 0
229
- if img_info is not None:
230
- img_path, img_samples = img_info
231
- good_anns[img_path] = img_samples
232
- all_associated_objects.update({img_path: associated_objects})
233
-
234
- pbar.desc = (
235
- f"{desc} {skipped} images skipped ({all_skipped_crops} crops are incorrect); "
236
- f"{broken} images corrupted"
237
- )
238
-
239
- pbar.close()
240
-
241
- for msg in all_msgs:
242
- print(msg)
243
- print(f"\nLeft images: {len(good_anns)}")
244
-
245
- return good_anns, all_associated_objects
246
-
247
-
248
- def verify_images(
249
- img_info,
250
- min_size: int,
251
- min_person_size: int,
252
- with_persons: bool,
253
- disable_faces: bool,
254
- crop_round_tol: float,
255
- min_person_aftercut_ratio: float,
256
- only_age: bool,
257
- ):
258
- # If crop is too small, if image can not be read or if image does not exist
259
- # then filter out this sample
260
-
261
- disable_faces = disable_faces and with_persons
262
- kwargs = dict(
263
- min_person_size=min_person_size,
264
- disable_faces=disable_faces,
265
- with_persons=with_persons,
266
- crop_round_tol=crop_round_tol,
267
- min_person_aftercut_ratio=min_person_aftercut_ratio,
268
- only_age=only_age,
269
- )
270
-
271
- def bbox_correct(bbox, min_size, im_h, im_w) -> Tuple[bool, List[int]]:
272
- ymin, ymax, xmin, xmax = _correct_bbox(bbox, im_h, im_w)
273
- crop_h, crop_w = ymax - ymin, xmax - xmin
274
- if crop_h < min_size or crop_w < min_size:
275
- return False, [-1, -1, -1, -1]
276
- bbox = [xmin, ymin, xmax, ymax]
277
- return True, bbox
278
-
279
- msgs = []
280
- skipped_crops = 0
281
- is_corrupted = False
282
- is_empty_annotations = False
283
-
284
- img_path: str = img_info[0]
285
- img_samples: List[PictureInfo] = img_info[1]
286
- try:
287
- im_cv = cv2.imread(img_path)
288
- im_h, im_w = im_cv.shape[:2]
289
- except Exception:
290
- msgs.append(f"Can not load image {img_path}")
291
- is_corrupted = True
292
- return None, {}, msgs, is_corrupted, is_empty_annotations, skipped_crops
293
-
294
- out_samples: List[PictureInfo] = []
295
- for sample in img_samples:
296
- # correct face bbox
297
- if sample.has_face_bbox:
298
- is_correct, sample.bbox = bbox_correct(sample.bbox, min_size, im_h, im_w)
299
- if not is_correct and sample.has_gt(only_age):
300
- msgs.append("Small face. Passing..")
301
- skipped_crops += 1
302
-
303
- # correct person bbox
304
- if sample.has_person_bbox:
305
- is_correct, sample.person_bbox = bbox_correct(
306
- sample.person_bbox, max(min_person_size, min_size), im_h, im_w
307
- )
308
- if not is_correct and sample.has_gt(only_age):
309
- msgs.append(f"Small person {img_path}. Passing..")
310
- skipped_crops += 1
311
-
312
- if sample.has_face_bbox or sample.has_person_bbox:
313
- out_samples.append(sample)
314
- elif sample.has_gt(only_age):
315
- msgs.append("Sample has no face and no body. Passing..")
316
- skipped_crops += 1
317
-
318
- # sort so that samples with undefined age and gender come last
319
- out_samples = sorted(out_samples, key=lambda sample: 1 if not sample.has_gt(only_age) else 0)
320
-
321
- # for each person find other faces and persons bboxes, intersected with it
322
- associated_objects: Dict[int, List[List[int]]] = find_associated_objects(out_samples, only_age=only_age)
323
-
324
- out_samples, associated_objects, skipped_crops = filter_bad_samples(
325
- out_samples, associated_objects, im_cv, msgs, skipped_crops, **kwargs
326
- )
327
-
328
- out_img_info: Optional[Tuple[str, List]] = (img_path, out_samples)
329
- if len(out_samples) == 0:
330
- out_img_info = None
331
- is_empty_annotations = True
332
-
333
- return out_img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops
334
-
335
-
336
- def filter_bad_samples(
337
- out_samples: List[PictureInfo],
338
- associated_objects: dict,
339
- im_cv: np.ndarray,
340
- msgs: List[str],
341
- skipped_crops: int,
342
- **kwargs,
343
- ):
344
- with_persons, disable_faces, min_person_size, crop_round_tol, min_person_aftercut_ratio, only_age = (
345
- kwargs["with_persons"],
346
- kwargs["disable_faces"],
347
- kwargs["min_person_size"],
348
- kwargs["crop_round_tol"],
349
- kwargs["min_person_aftercut_ratio"],
350
- kwargs["only_age"],
351
- )
352
-
353
- # keep only samples with annotations
354
- inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_gt(only_age)]
355
- out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
356
-
357
- if kwargs["disable_faces"]:
358
- # clear all faces
359
- for ind, sample in enumerate(out_samples):
360
- sample.clear_face_bbox()
361
-
362
- # left only samples with person_bbox
363
- inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_person_bbox]
364
- out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
365
-
366
- if with_persons or disable_faces:
367
- # check that the preprocessing func
368
- # _cropout_asced_objs() returns a non-empty person_image for each out sample
369
-
370
- inds = []
371
- for ind, sample in enumerate(out_samples):
372
- person_empty = True
373
- if sample.has_person_bbox:
374
- xmin, ymin, xmax, ymax = sample.person_bbox
375
- crop = im_cv[ymin:ymax, xmin:xmax]
376
- # cut off all associated objects from person crop
377
- _, person_empty = _cropout_asced_objs(
378
- associated_objects[ind],
379
- sample.person_bbox,
380
- crop.copy(),
381
- min_person_size=min_person_size,
382
- crop_round_tol=crop_round_tol,
383
- min_person_aftercut_ratio=min_person_aftercut_ratio,
384
- )
385
-
386
- if person_empty and not sample.has_face_bbox:
387
- msgs.append("Small person after preprocessing. Passing..")
388
- skipped_crops += 1
389
- else:
390
- inds.append(ind)
391
- out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
392
-
393
- assert len(associated_objects) == len(out_samples)
394
- return out_samples, associated_objects, skipped_crops
395
-
396
-
397
- def _filter_by_ind(out_samples, associated_objects, inds):
398
- _associated_objects = {}
399
- _out_samples = []
400
- for ind, sample in enumerate(out_samples):
401
- if ind in inds:
402
- _associated_objects[len(_out_samples)] = associated_objects[ind]
403
- _out_samples.append(sample)
404
-
405
- return _out_samples, _associated_objects
406
-
407
-
408
- def find_associated_objects(
409
- image_samples: List[PictureInfo], iou_thresh=0.0001, only_age=False
410
- ) -> Dict[int, List[List[int]]]:
411
- """
412
- For each person (which has gt age and gt gender) find other faces and persons bboxes, intersected with it
413
- """
414
- associated_objects: Dict[int, List[List[int]]] = {}
415
-
416
- for iindex, image_sample_info in enumerate(image_samples):
417
- # add own face
418
- associated_objects[iindex] = [image_sample_info.bbox] if image_sample_info.has_face_bbox else []
419
-
420
- if not image_sample_info.has_person_bbox or not image_sample_info.has_gt(only_age):
421
- # if the sample has no gt, it will not be used
422
- continue
423
-
424
- iperson_box = image_sample_info.person_bbox
425
- for jindex, other_image_sample in enumerate(image_samples):
426
- if iindex == jindex:
427
- continue
428
- if other_image_sample.has_face_bbox:
429
- jface_bbox = other_image_sample.bbox
430
- iou = _get_iou(jface_bbox, iperson_box)
431
- if iou >= iou_thresh:
432
- associated_objects[iindex].append(jface_bbox)
433
- if other_image_sample.has_person_bbox:
434
- jperson_bbox = other_image_sample.person_bbox
435
- iou = _get_iou(jperson_bbox, iperson_box)
436
- if iou >= iou_thresh:
437
- associated_objects[iindex].append(jperson_bbox)
438
-
439
- return associated_objects
440
-
441
-
442
- def _cropout_asced_objs(
443
- asced_objects,
444
- person_bbox,
445
- crop,
446
- min_person_size,
447
- crop_round_tol,
448
- min_person_aftercut_ratio,
449
- crop_out_color=(0, 0, 0),
450
- ):
451
- empty = False
452
- xmin, ymin, xmax, ymax = person_bbox
453
-
454
- for a_obj in asced_objects:
455
- aobj_xmin, aobj_ymin, aobj_xmax, aobj_ymax = a_obj
456
-
457
- aobj_ymin = int(max(aobj_ymin - ymin, 0))
458
- aobj_xmin = int(max(aobj_xmin - xmin, 0))
459
- aobj_ymax = int(min(aobj_ymax - ymin, ymax - ymin))
460
- aobj_xmax = int(min(aobj_xmax - xmin, xmax - xmin))
461
-
462
- crop[aobj_ymin:aobj_ymax, aobj_xmin:aobj_xmax] = crop_out_color
463
-
464
- crop, cropped_ratio = cropout_black_parts(crop, crop_round_tol)
465
- if (
466
- crop.shape[0] < min_person_size or crop.shape[1] < min_person_size
467
- ) or cropped_ratio < min_person_aftercut_ratio:
468
- crop = None
469
- empty = True
470
-
471
- return crop, empty
472
-
473
-
474
- def _correct_bbox(bbox, h, w):
475
- xmin, ymin, xmax, ymax = bbox
476
- ymin = min(max(ymin, 0), h)
477
- ymax = min(max(ymax, 0), h)
478
- xmin = min(max(xmin, 0), w)
479
- xmax = min(max(xmax, 0), w)
480
- return ymin, ymax, xmin, xmax
481
-
482
-
483
- def _get_iou(bbox1, bbox2):
484
- xmin1, ymin1, xmax1, ymax1 = bbox1
485
- xmin2, ymin2, xmax2, ymax2 = bbox2
486
- iou = IOU(
487
- [ymin1, xmin1, ymax1, xmax1],
488
- [ymin2, xmin2, ymax2, xmax2],
489
- )
490
- return iou
mivolo/data/misc.py DELETED
@@ -1,264 +0,0 @@
1
- import argparse
2
- import ast
3
- import re
4
- from typing import List, Optional, Tuple, Union
5
-
6
- import cv2
7
- import numpy as np
8
- import torch
9
- import torchvision.transforms.functional as F
10
- from scipy.optimize import linear_sum_assignment
11
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
12
-
13
- CROP_ROUND_RATE = 0.1
14
- MIN_PERSON_CROP_NONZERO = 0.5
15
-
16
-
17
- def aggregate_votes_winsorized(ages, max_age_dist=6):
18
- # Replace any annotation that is more than a max_age_dist away from the median
19
- # with median + max_age_dist if above, or median - max_age_dist if below
20
- median = np.median(ages)
21
- ages = np.clip(ages, median - max_age_dist, median + max_age_dist)
22
- return np.mean(ages)
23
-
24
-
25
- def cropout_black_parts(img, tol=0.3):
26
- # Create a binary mask of zero pixels
27
- zero_pixels_mask = np.all(img == 0, axis=2)
28
- # Calculate the threshold for zero pixels in rows and columns
29
- threshold = img.shape[0] - img.shape[0] * tol
30
- # Calculate row sums and column sums of zero pixels mask
31
- row_sums = np.sum(zero_pixels_mask, axis=1)
32
- col_sums = np.sum(zero_pixels_mask, axis=0)
33
- # Find the first and last rows with zero pixel sums above the threshold
34
- start_row = np.argmin(row_sums > threshold)
35
- end_row = img.shape[0] - np.argmin(row_sums[::-1] > threshold)
36
- # Find the first and last columns with zero pixel sums above the threshold
37
- start_col = np.argmin(col_sums > threshold)
38
- end_col = img.shape[1] - np.argmin(col_sums[::-1] > threshold)
39
- # Crop the image
40
- cropped_img = img[start_row:end_row, start_col:end_col, :]
41
- area = cropped_img.shape[0] * cropped_img.shape[1]
42
- area_orig = img.shape[0] * img.shape[1]
43
- return cropped_img, area / area_orig
44
-
45
-
46
- def natural_key(string_):
47
- """See http://www.codinghorror.com/blog/archives/001018.html"""
48
- return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
49
-
50
-
51
- def add_bool_arg(parser, name, default=False, help=""):
52
- dest_name = name.replace("-", "_")
53
- group = parser.add_mutually_exclusive_group(required=False)
54
- group.add_argument("--" + name, dest=dest_name, action="store_true", help=help)
55
- group.add_argument("--no-" + name, dest=dest_name, action="store_false", help=help)
56
- parser.set_defaults(**{dest_name: default})
57
-
58
-
59
- def cumulative_score(pred_ages, gt_ages, L, tol=1e-6):
60
- n = pred_ages.shape[0]
61
- num_correct = torch.sum(torch.abs(pred_ages - gt_ages) <= L + tol)
62
- cs_score = num_correct / n
63
- return cs_score
64
-
65
-
66
- def cumulative_error(pred_ages, gt_ages, L, tol=1e-6):
67
- n = pred_ages.shape[0]
68
- num_correct = torch.sum(torch.abs(pred_ages - gt_ages) >= L + tol)
69
- cs_score = num_correct / n
70
- return cs_score
71
-
72
-
73
- class ParseKwargs(argparse.Action):
74
- def __call__(self, parser, namespace, values, option_string=None):
75
- kw = {}
76
- for value in values:
77
- key, value = value.split("=")
78
- try:
79
- kw[key] = ast.literal_eval(value)
80
- except ValueError:
81
- kw[key] = str(value) # fallback to string (avoid need to escape on command line)
82
- setattr(namespace, self.dest, kw)
83
-
84
-
85
- def box_iou(box1, box2, over_second=False):
86
- """
87
- Return intersection-over-union (Jaccard index) of boxes.
88
- If over_second == True, return mean(intersection-over-union, (inter / area2))
89
-
90
- Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
91
-
92
- Arguments:
93
- box1 (Tensor[N, 4])
94
- box2 (Tensor[M, 4])
95
- Returns:
96
- iou (Tensor[N, M]): the NxM matrix containing the pairwise
97
- IoU values for every element in boxes1 and boxes2
98
- """
99
-
100
- def box_area(box):
101
- # box = 4xn
102
- return (box[2] - box[0]) * (box[3] - box[1])
103
-
104
- area1 = box_area(box1.T)
105
- area2 = box_area(box2.T)
106
-
107
- # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
108
- inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
109
-
110
- iou = inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
111
- if over_second:
112
- return (inter / area2 + iou) / 2 # mean(inter / area2, iou)
113
- else:
114
- return iou
115
-
116
-
117
- def split_batch(bs: int, dev: int) -> Tuple[int, int]:
118
- full_bs = (bs // dev) * dev
119
- part_bs = bs - full_bs
120
- return full_bs, part_bs
121
-
122
-
123
- def assign_faces(
124
- persons_bboxes: List[torch.tensor], faces_bboxes: List[torch.tensor], iou_thresh: float = 0.0001
125
- ) -> Tuple[List[Optional[int]], List[int]]:
126
- """
127
- Assign person to each face if it is possible.
128
- Return:
129
- - assigned_faces List[Optional[int]]: mapping of face_ind to person_ind
130
- ( assigned_faces[face_ind] = person_ind ). person_ind can be None
131
- - unassigned_persons_inds List[int]: persons indexes without any assigned face
132
- """
133
-
134
- assigned_faces: List[Optional[int]] = [None for _ in range(len(faces_bboxes))]
135
- unassigned_persons_inds: List[int] = [p_ind for p_ind in range(len(persons_bboxes))]
136
-
137
- if len(persons_bboxes) == 0 or len(faces_bboxes) == 0:
138
- return assigned_faces, unassigned_persons_inds
139
-
140
- cost_matrix = box_iou(torch.stack(persons_bboxes), torch.stack(faces_bboxes), over_second=True).cpu().numpy()
141
- persons_indexes, face_indexes = [], []
142
-
143
- if len(cost_matrix) > 0:
144
- persons_indexes, face_indexes = linear_sum_assignment(cost_matrix, maximize=True)
145
-
146
- matched_persons = set()
147
- for person_idx, face_idx in zip(persons_indexes, face_indexes):
148
- ciou = cost_matrix[person_idx][face_idx]
149
- if ciou > iou_thresh:
150
- if person_idx in matched_persons:
151
- # Person can not be assigned twice, in reality this should not happen
152
- continue
153
- assigned_faces[face_idx] = person_idx
154
- matched_persons.add(person_idx)
155
-
156
- unassigned_persons_inds = [p_ind for p_ind in range(len(persons_bboxes)) if p_ind not in matched_persons]
157
-
158
- return assigned_faces, unassigned_persons_inds
159
-
160
-
161
- def class_letterbox(im, new_shape=(640, 640), color=(0, 0, 0), scaleup=True):
162
- # Resize and pad image while meeting stride-multiple constraints
163
- shape = im.shape[:2] # current shape [height, width]
164
- if isinstance(new_shape, int):
165
- new_shape = (new_shape, new_shape)
166
-
167
- if im.shape[0] == new_shape[0] and im.shape[1] == new_shape[1]:
168
- return im
169
-
170
- # Scale ratio (new / old)
171
- r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
172
- if not scaleup: # only scale down, do not scale up (for better val mAP)
173
- r = min(r, 1.0)
174
-
175
- # Compute padding
176
- # ratio = r, r # width, height ratios
177
- new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
178
- dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
179
-
180
- dw /= 2 # divide padding into 2 sides
181
- dh /= 2
182
-
183
- if shape[::-1] != new_unpad: # resize
184
- im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
185
- top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
186
- left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
187
- im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
188
- return im
189
-
190
-
191
- def prepare_classification_images(
192
- img_list: List[Optional[np.ndarray]],
193
- target_size: int = 224,
194
- mean=IMAGENET_DEFAULT_MEAN,
195
- std=IMAGENET_DEFAULT_STD,
196
- device=None,
197
- ) -> torch.tensor:
198
-
199
- prepared_images: List[torch.tensor] = []
200
-
201
- for img in img_list:
202
- if img is None:
203
- img = torch.zeros((3, target_size, target_size), dtype=torch.float32)
204
- img = F.normalize(img, mean=mean, std=std)
205
- img = img.unsqueeze(0)
206
- prepared_images.append(img)
207
- continue
208
- img = class_letterbox(img, new_shape=(target_size, target_size))
209
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
210
-
211
- img = img / 255.0
212
- img = (img - mean) / std
213
- img = img.astype(dtype=np.float32)
214
-
215
- img = img.transpose((2, 0, 1))
216
- img = np.ascontiguousarray(img)
217
- img = torch.from_numpy(img)
218
- img = img.unsqueeze(0)
219
-
220
- prepared_images.append(img)
221
-
222
- prepared_input = torch.concat(prepared_images)
223
-
224
- if device:
225
- prepared_input = prepared_input.to(device)
226
-
227
- return prepared_input
228
-
229
-
230
- def IOU(bb1: Union[tuple, list], bb2: Union[tuple, list], norm_second_bbox: bool = False) -> float:
231
- # expects [ymin, xmin, ymax, xmax]; absolute or relative coordinates both work
232
- assert bb1[1] < bb1[3]
233
- assert bb1[0] < bb1[2]
234
- assert bb2[1] < bb2[3]
235
- assert bb2[0] < bb2[2]
236
-
237
- # determine the coordinates of the intersection rectangle
238
- x_left = max(bb1[1], bb2[1])
239
- y_top = max(bb1[0], bb2[0])
240
- x_right = min(bb1[3], bb2[3])
241
- y_bottom = min(bb1[2], bb2[2])
242
-
243
- if x_right < x_left or y_bottom < y_top:
244
- return 0.0
245
-
246
- # The intersection of two axis-aligned bounding boxes is always an
247
- # axis-aligned bounding box
248
- intersection_area = (x_right - x_left) * (y_bottom - y_top)
249
- # compute the area of both AABBs
250
- bb1_area = (bb1[3] - bb1[1]) * (bb1[2] - bb1[0])
251
- bb2_area = (bb2[3] - bb2[1]) * (bb2[2] - bb2[0])
252
- if not norm_second_bbox:
253
- # compute the intersection over union by taking the intersection
254
- # area and dividing it by the sum of prediction + ground-truth
255
- # areas minus the intersection area
256
- iou = intersection_area / float(bb1_area + bb2_area - intersection_area)
257
- else:
258
- # for cases where we check whether the second bbox lies inside the first one
259
- iou = intersection_area / float(bb2_area)
260
-
261
- assert iou >= 0.0
262
- assert iou <= 1.01
263
-
264
- return iou
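A minimal sketch of how the helpers above behave on dummy data. It assumes they live in mivolo/data/misc.py (the import path other modules in this diff use); the box coordinates and image size are illustrative values only.

import numpy as np
from mivolo.data.misc import IOU, class_letterbox, prepare_classification_images

person_box = (0, 0, 200, 100)   # [ymin, xmin, ymax, xmax]
face_box = (10, 30, 60, 70)
print(IOU(person_box, face_box))                          # plain IoU: 0.1
print(IOU(person_box, face_box, norm_second_bbox=True))   # fraction of the face inside the person box: 1.0

image = np.zeros((480, 640, 3), dtype=np.uint8)
padded = class_letterbox(image, new_shape=(224, 224))     # keeps aspect ratio, pads the rest
print(padded.shape)                                       # (224, 224, 3)

batch = prepare_classification_images([image, None], target_size=224)
print(batch.shape)                                        # torch.Size([2, 3, 224, 224]); None becomes a blank placeholder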
mivolo/model/create_timm_model.py DELETED
@@ -1,107 +0,0 @@
1
- """
2
- Code adapted from timm https://github.com/huggingface/pytorch-image-models
3
-
4
- Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
- """
6
-
7
- import os
8
- from typing import Any, Dict, Optional, Union
9
-
10
- import timm
11
-
12
- # register new models
13
- from mivolo.model.mivolo_model import * # noqa: F403, F401
14
- from timm.layers import set_layer_config
15
- from timm.models._factory import parse_model_name
16
- from timm.models._helpers import load_state_dict, remap_checkpoint
17
- from timm.models._hub import load_model_config_from_hf
18
- from timm.models._pretrained import PretrainedCfg, split_model_name_tag
19
- from timm.models._registry import is_model, model_entrypoint
20
-
21
-
22
- def load_checkpoint(
23
- model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None
24
- ):
25
- if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"):
26
- # numpy checkpoint, try to load via model specific load_pretrained fn
27
- if hasattr(model, "load_pretrained"):
28
- timm.models._model_builder.load_pretrained(checkpoint_path)
29
- else:
30
- raise NotImplementedError("Model cannot load numpy checkpoint")
31
- return
32
- state_dict = load_state_dict(checkpoint_path, use_ema)
33
- if remap:
34
- state_dict = remap_checkpoint(model, state_dict)
35
- if filter_keys:
36
- for sd_key in list(state_dict.keys()):
37
- for filter_key in filter_keys:
38
- if filter_key in sd_key:
39
- if sd_key in state_dict:
40
- del state_dict[sd_key]
41
-
42
- rep = []
43
- if state_dict_map is not None:
44
- # 'patch_embed.conv1.' : 'patch_embed.conv.'
45
- for state_k in list(state_dict.keys()):
46
- for target_k, target_v in state_dict_map.items():
47
- if target_v in state_k:
48
- target_name = state_k.replace(target_v, target_k)
49
- state_dict[target_name] = state_dict[state_k]
50
- rep.append(state_k)
51
- for r in rep:
52
- if r in state_dict:
53
- del state_dict[r]
54
-
55
- incompatible_keys = model.load_state_dict(state_dict, strict=strict if filter_keys is None else False)
56
- return incompatible_keys
57
-
58
-
59
- def create_model(
60
- model_name: str,
61
- pretrained: bool = False,
62
- pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
63
- pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
64
- checkpoint_path: str = "",
65
- scriptable: Optional[bool] = None,
66
- exportable: Optional[bool] = None,
67
- no_jit: Optional[bool] = None,
68
- filter_keys=None,
69
- state_dict_map=None,
70
- **kwargs,
71
- ):
72
- """Create a model
73
- Lookup model's entrypoint function and pass relevant args to create a new model.
74
- """
75
- # Parameters that aren't supported by all models or are intended to only override model defaults if set
76
- # should default to None in command line args/cfg. Remove them if they are present and not set so that
77
- # non-supporting models don't break and default args remain in effect.
78
- kwargs = {k: v for k, v in kwargs.items() if v is not None}
79
-
80
- model_source, model_name = parse_model_name(model_name)
81
- if model_source == "hf-hub":
82
- assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub."
83
- # For model names specified in the form `hf-hub:path/architecture_name@revision`,
84
- # load model weights + pretrained_cfg from Hugging Face hub.
85
- pretrained_cfg, model_name = load_model_config_from_hf(model_name)
86
- else:
87
- model_name, pretrained_tag = split_model_name_tag(model_name)
88
- if not pretrained_cfg:
89
- # a valid pretrained_cfg argument takes priority over tag in model name
90
- pretrained_cfg = pretrained_tag
91
-
92
- if not is_model(model_name):
93
- raise RuntimeError("Unknown model (%s)" % model_name)
94
-
95
- create_fn = model_entrypoint(model_name)
96
- with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
97
- model = create_fn(
98
- pretrained=pretrained,
99
- pretrained_cfg=pretrained_cfg,
100
- pretrained_cfg_overlay=pretrained_cfg_overlay,
101
- **kwargs,
102
- )
103
-
104
- if checkpoint_path:
105
- load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map)
106
-
107
- return model
 
mivolo/model/cross_bottleneck_attn.py DELETED
@@ -1,116 +0,0 @@
1
- """
2
- Code based on timm https://github.com/huggingface/pytorch-image-models
3
-
4
- Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
- """
6
-
7
- import torch
8
- import torch.nn as nn
9
- from timm.layers.bottleneck_attn import PosEmbedRel
10
- from timm.layers.helpers import make_divisible
11
- from timm.layers.mlp import Mlp
12
- from timm.layers.trace_utils import _assert
13
- from timm.layers.weight_init import trunc_normal_
14
-
15
-
16
- class CrossBottleneckAttn(nn.Module):
17
- def __init__(
18
- self,
19
- dim,
20
- dim_out=None,
21
- feat_size=None,
22
- stride=1,
23
- num_heads=4,
24
- dim_head=None,
25
- qk_ratio=1.0,
26
- qkv_bias=False,
27
- scale_pos_embed=False,
28
- ):
29
- super().__init__()
30
- assert feat_size is not None, "A concrete feature size matching expected input (H, W) is required"
31
- dim_out = dim_out or dim
32
- assert dim_out % num_heads == 0
33
-
34
- self.num_heads = num_heads
35
- self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
36
- self.dim_head_v = dim_out // self.num_heads
37
- self.dim_out_qk = num_heads * self.dim_head_qk
38
- self.dim_out_v = num_heads * self.dim_head_v
39
- self.scale = self.dim_head_qk**-0.5
40
- self.scale_pos_embed = scale_pos_embed
41
-
42
- self.qkv_f = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
43
- self.qkv_p = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
44
-
45
- # NOTE I'm only supporting relative pos embedding for now
46
- self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)
47
-
48
- self.norm = nn.LayerNorm([self.dim_out_v * 2, *feat_size])
49
- mlp_ratio = 4
50
- self.mlp = Mlp(
51
- in_features=self.dim_out_v * 2,
52
- hidden_features=int(dim * mlp_ratio),
53
- act_layer=nn.GELU,
54
- out_features=dim_out,
55
- drop=0,
56
- use_conv=True,
57
- )
58
-
59
- self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
60
- self.reset_parameters()
61
-
62
- def reset_parameters(self):
63
- trunc_normal_(self.qkv_f.weight, std=self.qkv_f.weight.shape[1] ** -0.5) # fan-in
64
- trunc_normal_(self.qkv_p.weight, std=self.qkv_p.weight.shape[1] ** -0.5) # fan-in
65
- trunc_normal_(self.pos_embed.height_rel, std=self.scale)
66
- trunc_normal_(self.pos_embed.width_rel, std=self.scale)
67
-
68
- def get_qkv(self, x, qvk_conv):
69
- B, C, H, W = x.shape
70
-
71
- x = qvk_conv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
72
-
73
- q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)
74
-
75
- q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
76
- k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k
77
- v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
78
-
79
- return q, k, v
80
-
81
- def apply_attn(self, q, k, v, B, H, W, dropout=None):
82
- if self.scale_pos_embed:
83
- attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W
84
- else:
85
- attn = (q @ k) * self.scale + self.pos_embed(q)
86
- attn = attn.softmax(dim=-1)
87
- if dropout:
88
- attn = dropout(attn)
89
-
90
- out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
91
- return out
92
-
93
- def forward(self, x):
94
- B, C, H, W = x.shape
95
-
96
- dim = int(C / 2)
97
- x1 = x[:, :dim, :, :]
98
- x2 = x[:, dim:, :, :]
99
-
100
- _assert(H == self.pos_embed.height, "")
101
- _assert(W == self.pos_embed.width, "")
102
-
103
- q_f, k_f, v_f = self.get_qkv(x1, self.qkv_f)
104
- q_p, k_p, v_p = self.get_qkv(x2, self.qkv_p)
105
-
106
- # person to face
107
- out_f = self.apply_attn(q_f, k_p, v_p, B, H, W)
108
- # face to person
109
- out_p = self.apply_attn(q_p, k_f, v_f, B, H, W)
110
-
111
- x_pf = torch.cat((out_f, out_p), dim=1) # B, dim_out * 2, H, W
112
- x_pf = self.norm(x_pf)
113
- x_pf = self.mlp(x_pf) # B, dim_out, H, W
114
-
115
- out = self.pool(x_pf)
116
- return out
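A small shape-check sketch for CrossBottleneckAttn: the module expects the face and person feature maps concatenated along the channel dimension (so the input has 2 * dim channels); the dim, feat_size and batch size below are illustrative values.

import torch
from mivolo.model.cross_bottleneck_attn import CrossBottleneckAttn

attn = CrossBottleneckAttn(dim=384, dim_out=384, num_heads=1, feat_size=(14, 14))
x = torch.randn(2, 2 * 384, 14, 14)   # [face features | person features] along channels
out = attn(x)
print(out.shape)                      # torch.Size([2, 384, 14, 14])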
mivolo/model/mi_volo.py DELETED
@@ -1,229 +0,0 @@
1
- import logging
2
- from typing import Optional
3
-
4
- import numpy as np
5
- import torch
6
- from mivolo.data.misc import prepare_classification_images
7
- from mivolo.model.create_timm_model import create_model
8
- from mivolo.structures import PersonAndFaceCrops, PersonAndFaceResult
9
- from timm.data import resolve_data_config
10
-
11
- _logger = logging.getLogger("MiVOLO")
12
- has_compile = hasattr(torch, "compile")
13
-
14
-
15
- class Meta:
16
- def __init__(self):
17
- self.min_age = None
18
- self.max_age = None
19
- self.avg_age = None
20
- self.num_classes = None
21
-
22
- self.in_chans = 3
23
- self.with_persons_model = False
24
- self.disable_faces = False
25
- self.use_persons = True
26
- self.only_age = False
27
-
28
- self.num_classes_gender = 2
29
-
30
- def load_from_ckpt(self, ckpt_path: str, disable_faces: bool = False, use_persons: bool = True) -> "Meta":
31
-
32
- state = torch.load(ckpt_path, map_location="cpu")
33
-
34
- self.min_age = state["min_age"]
35
- self.max_age = state["max_age"]
36
- self.avg_age = state["avg_age"]
37
- self.only_age = state["no_gender"]
38
-
39
- only_age = state["no_gender"]
40
-
41
- self.disable_faces = disable_faces
42
- if "with_persons_model" in state:
43
- self.with_persons_model = state["with_persons_model"]
44
- else:
45
- self.with_persons_model = True if "patch_embed.conv1.0.weight" in state["state_dict"] else False
46
-
47
- self.num_classes = 1 if only_age else 3
48
- self.in_chans = 3 if not self.with_persons_model else 6
49
- self.use_persons = use_persons and self.with_persons_model
50
-
51
- if not self.with_persons_model and self.disable_faces:
52
- raise ValueError("You can not use disable-faces for faces-only model")
53
- if self.with_persons_model and self.disable_faces and not self.use_persons:
54
- raise ValueError("You can not disable faces and persons together")
55
-
56
- return self
57
-
58
- def __str__(self):
59
- attrs = vars(self)
60
- attrs.update({"use_person_crops": self.use_person_crops, "use_face_crops": self.use_face_crops})
61
- return ", ".join("%s: %s" % item for item in attrs.items())
62
-
63
- @property
64
- def use_person_crops(self) -> bool:
65
- return self.with_persons_model and self.use_persons
66
-
67
- @property
68
- def use_face_crops(self) -> bool:
69
- return not self.disable_faces or not self.with_persons_model
70
-
71
-
72
- class MiVOLO:
73
- def __init__(
74
- self,
75
- ckpt_path: str,
76
- device: str = "cuda",
77
- half: bool = True,
78
- disable_faces: bool = False,
79
- use_persons: bool = True,
80
- verbose: bool = False,
81
- torchcompile: Optional[str] = None,
82
- ):
83
- self.verbose = verbose
84
- self.device = torch.device(device)
85
- self.half = half and self.device.type != "cpu"
86
-
87
- self.meta: Meta = Meta().load_from_ckpt(ckpt_path, disable_faces, use_persons)
88
- if self.verbose:
89
- _logger.info(f"Model meta:\n{str(self.meta)}")
90
-
91
- model_name = "mivolo_d1_224"
92
- self.model = create_model(
93
- model_name=model_name,
94
- num_classes=self.meta.num_classes,
95
- in_chans=self.meta.in_chans,
96
- pretrained=False,
97
- checkpoint_path=ckpt_path,
98
- filter_keys=["fds."],
99
- )
100
- self.param_count = sum([m.numel() for m in self.model.parameters()])
101
- _logger.info(f"Model {model_name} created, param count: {self.param_count}")
102
-
103
- self.data_config = resolve_data_config(
104
- model=self.model,
105
- verbose=verbose,
106
- use_test_size=True,
107
- )
108
- self.data_config["crop_pct"] = 1.0
109
- c, h, w = self.data_config["input_size"]
110
- assert h == w, "Incorrect data_config"
111
- self.input_size = w
112
-
113
- self.model = self.model.to(self.device)
114
-
115
- if torchcompile:
116
- assert has_compile, "A version of torch w/ torch.compile() is required for --compile, possibly a nightly."
117
- torch._dynamo.reset()
118
- self.model = torch.compile(self.model, backend=torchcompile)
119
-
120
- self.model.eval()
121
- if self.half:
122
- self.model = self.model.half()
123
-
124
- def warmup(self, batch_size: int, steps=10):
125
- if self.meta.with_persons_model:
126
- input_size = (6, self.input_size, self.input_size)
127
- else:
128
- input_size = self.data_config["input_size"]
129
-
130
- input = torch.randn((batch_size,) + tuple(input_size)).to(self.device)
131
-
132
- for _ in range(steps):
133
- out = self.inference(input) # noqa: F841
134
-
135
- if torch.cuda.is_available():
136
- torch.cuda.synchronize()
137
-
138
- def inference(self, model_input: torch.tensor) -> torch.tensor:
139
-
140
- with torch.no_grad():
141
- if self.half:
142
- model_input = model_input.half()
143
- output = self.model(model_input)
144
- return output
145
-
146
- def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
147
- if detected_bboxes.n_objects == 0:
148
- return
149
-
150
- faces_input, person_input, faces_inds, bodies_inds = self.prepare_crops(image, detected_bboxes)
151
-
152
- if self.meta.with_persons_model:
153
- model_input = torch.cat((faces_input, person_input), dim=1)
154
- else:
155
- model_input = faces_input
156
- output = self.inference(model_input)
157
-
158
- # write gender and age results into detected_bboxes
159
- self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds)
160
-
161
- def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds):
162
- if self.meta.only_age:
163
- age_output = output
164
- gender_probs, gender_indx = None, None
165
- else:
166
- age_output = output[:, 2]
167
- gender_output = output[:, :2].softmax(-1)
168
- gender_probs, gender_indx = gender_output.topk(1)
169
-
170
- assert output.shape[0] == len(faces_inds) == len(bodies_inds)
171
-
172
- # per face
173
- for index in range(output.shape[0]):
174
- face_ind = faces_inds[index]
175
- body_ind = bodies_inds[index]
176
-
177
- # get_age
178
- age = age_output[index].item()
179
- age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age
180
- age = round(age, 2)
181
-
182
- detected_bboxes.set_age(face_ind, age)
183
- detected_bboxes.set_age(body_ind, age)
184
-
185
- _logger.info(f"\tage: {age}")
186
-
187
- if gender_probs is not None:
188
- gender = "male" if gender_indx[index].item() == 0 else "female"
189
- gender_score = gender_probs[index].item()
190
-
191
- _logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]")
192
-
193
- detected_bboxes.set_gender(face_ind, gender, gender_score)
194
- detected_bboxes.set_gender(body_ind, gender, gender_score)
195
-
196
- def prepare_crops(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
197
-
198
- if self.meta.use_person_crops and self.meta.use_face_crops:
199
- detected_bboxes.associate_faces_with_persons()
200
-
201
- crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image)
202
- (bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies(
203
- self.meta.use_person_crops, self.meta.use_face_crops
204
- )
205
-
206
- if not self.meta.use_face_crops:
207
- assert all(f is None for f in faces_crops)
208
-
209
- faces_input = prepare_classification_images(
210
- faces_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
211
- )
212
-
213
- if not self.meta.use_person_crops:
214
- assert all(p is None for p in bodies_crops)
215
-
216
- person_input = prepare_classification_images(
217
- bodies_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
218
- )
219
-
220
- _logger.info(
221
- f"faces_input: {faces_input.shape if faces_input is not None else None}, "
222
- f"person_input: {person_input.shape if person_input is not None else None}"
223
- )
224
-
225
- return faces_input, person_input, faces_inds, bodies_inds
226
-
227
-
228
- if __name__ == "__main__":
229
- model = MiVOLO("../pretrained/checkpoint-377.pth.tar", half=True, device="cuda:0")
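A sketch of a full single-image pass with the detector plus MiVOLO; the weight paths below are placeholders, and the image is assumed to be read with OpenCV (BGR), matching how crops are handled elsewhere in this code.

import cv2
from mivolo.model.mi_volo import MiVOLO
from mivolo.model.yolo_detector import Detector

detector = Detector("person_face_weights.pt", device="cuda:0")                      # placeholder path
age_gender = MiVOLO("age_gender_checkpoint.pth.tar", device="cuda:0", half=True)    # placeholder path

image = cv2.imread("example.jpg")
detected = detector.predict(image)
age_gender.predict(image, detected)    # fills detected.ages / detected.genders in place
print(list(zip(detected.ages, detected.genders)))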
mivolo/model/mivolo_model.py DELETED
@@ -1,402 +0,0 @@
1
- """
2
- Code adapted from timm https://github.com/huggingface/pytorch-image-models
3
-
4
- Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
- """
6
-
7
- import torch
8
- import torch.nn as nn
9
- from mivolo.model.cross_bottleneck_attn import CrossBottleneckAttn
10
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11
- from timm.layers import trunc_normal_
12
- from timm.models._builder import build_model_with_cfg
13
- from timm.models._registry import register_model
14
- from timm.models.volo import VOLO
15
-
16
- __all__ = ["MiVOLOModel"] # model_registry will add each entrypoint fn to this
17
-
18
-
19
- def _cfg(url="", **kwargs):
20
- return {
21
- "url": url,
22
- "num_classes": 1000,
23
- "input_size": (3, 224, 224),
24
- "pool_size": None,
25
- "crop_pct": 0.96,
26
- "interpolation": "bicubic",
27
- "fixed_input_size": True,
28
- "mean": IMAGENET_DEFAULT_MEAN,
29
- "std": IMAGENET_DEFAULT_STD,
30
- "first_conv": None,
31
- "classifier": ("head", "aux_head"),
32
- **kwargs,
33
- }
34
-
35
-
36
- default_cfgs = {
37
- "mivolo_d1_224": _cfg(
38
- url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar", crop_pct=0.96
39
- ),
40
- "mivolo_d1_384": _cfg(
41
- url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar",
42
- crop_pct=1.0,
43
- input_size=(3, 384, 384),
44
- ),
45
- "mivolo_d2_224": _cfg(
46
- url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar", crop_pct=0.96
47
- ),
48
- "mivolo_d2_384": _cfg(
49
- url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar",
50
- crop_pct=1.0,
51
- input_size=(3, 384, 384),
52
- ),
53
- "mivolo_d3_224": _cfg(
54
- url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar", crop_pct=0.96
55
- ),
56
- "mivolo_d3_448": _cfg(
57
- url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar",
58
- crop_pct=1.0,
59
- input_size=(3, 448, 448),
60
- ),
61
- "mivolo_d4_224": _cfg(
62
- url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar", crop_pct=0.96
63
- ),
64
- "mivolo_d4_448": _cfg(
65
- url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar",
66
- crop_pct=1.15,
67
- input_size=(3, 448, 448),
68
- ),
69
- "mivolo_d5_224": _cfg(
70
- url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar", crop_pct=0.96
71
- ),
72
- "mivolo_d5_448": _cfg(
73
- url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar",
74
- crop_pct=1.15,
75
- input_size=(3, 448, 448),
76
- ),
77
- "mivolo_d5_512": _cfg(
78
- url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar",
79
- crop_pct=1.15,
80
- input_size=(3, 512, 512),
81
- ),
82
- }
83
-
84
-
85
- def get_output_size(input_shape, conv_layer):
86
- padding = conv_layer.padding
87
- dilation = conv_layer.dilation
88
- kernel_size = conv_layer.kernel_size
89
- stride = conv_layer.stride
90
-
91
- output_size = [
92
- ((input_shape[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) // stride[i]) + 1 for i in range(2)
93
- ]
94
- return output_size
95
-
96
-
97
- def get_output_size_module(input_size, stem):
98
- output_size = input_size
99
-
100
- for module in stem:
101
- if isinstance(module, nn.Conv2d):
102
- output_size = [
103
- (
104
- (output_size[i] + 2 * module.padding[i] - module.dilation[i] * (module.kernel_size[i] - 1) - 1)
105
- // module.stride[i]
106
- )
107
- + 1
108
- for i in range(2)
109
- ]
110
-
111
- return output_size
112
-
113
-
114
- class PatchEmbed(nn.Module):
115
- """Image to Patch Embedding."""
116
-
117
- def __init__(
118
- self, img_size=224, stem_conv=False, stem_stride=1, patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384
119
- ):
120
- super().__init__()
121
- assert patch_size in [4, 8, 16]
122
- assert in_chans in [3, 6]
123
- self.with_persons_model = in_chans == 6
124
- self.use_cross_attn = True
125
-
126
- if stem_conv:
127
- if not self.with_persons_model:
128
- self.conv = self.create_stem(stem_stride, in_chans, hidden_dim)
129
- else:
130
- self.conv = True # just to match interface
131
- # split
132
- self.conv1 = self.create_stem(stem_stride, 3, hidden_dim)
133
- self.conv2 = self.create_stem(stem_stride, 3, hidden_dim)
134
- else:
135
- self.conv = None
136
-
137
- if self.with_persons_model:
138
-
139
- self.proj1 = nn.Conv2d(
140
- hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
141
- )
142
- self.proj2 = nn.Conv2d(
143
- hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
144
- )
145
-
146
- stem_out_shape = get_output_size_module((img_size, img_size), self.conv1)
147
- self.proj_output_size = get_output_size(stem_out_shape, self.proj1)
148
-
149
- self.map = CrossBottleneckAttn(embed_dim, dim_out=embed_dim, num_heads=1, feat_size=self.proj_output_size)
150
-
151
- else:
152
- self.proj = nn.Conv2d(
153
- hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
154
- )
155
-
156
- self.patch_dim = img_size // patch_size
157
- self.num_patches = self.patch_dim**2
158
-
159
- def create_stem(self, stem_stride, in_chans, hidden_dim):
160
- return nn.Sequential(
161
- nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False), # 112x112
162
- nn.BatchNorm2d(hidden_dim),
163
- nn.ReLU(inplace=True),
164
- nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
165
- nn.BatchNorm2d(hidden_dim),
166
- nn.ReLU(inplace=True),
167
- nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
168
- nn.BatchNorm2d(hidden_dim),
169
- nn.ReLU(inplace=True),
170
- )
171
-
172
- def forward(self, x):
173
- if self.conv is not None:
174
- if self.with_persons_model:
175
- x1 = x[:, :3]
176
- x2 = x[:, 3:]
177
-
178
- x1 = self.conv1(x1)
179
- x1 = self.proj1(x1)
180
-
181
- x2 = self.conv2(x2)
182
- x2 = self.proj2(x2)
183
-
184
- x = torch.cat([x1, x2], dim=1)
185
- x = self.map(x)
186
- else:
187
- x = self.conv(x)
188
- x = self.proj(x) # B, C, H, W
189
-
190
- return x
191
-
192
-
193
- class MiVOLOModel(VOLO):
194
- """
195
- Vision Outlooker, the main class of our model
196
- """
197
-
198
- def __init__(
199
- self,
200
- layers,
201
- img_size=224,
202
- in_chans=3,
203
- num_classes=1000,
204
- global_pool="token",
205
- patch_size=8,
206
- stem_hidden_dim=64,
207
- embed_dims=None,
208
- num_heads=None,
209
- downsamples=(True, False, False, False),
210
- outlook_attention=(True, False, False, False),
211
- mlp_ratio=3.0,
212
- qkv_bias=False,
213
- drop_rate=0.0,
214
- attn_drop_rate=0.0,
215
- drop_path_rate=0.0,
216
- norm_layer=nn.LayerNorm,
217
- post_layers=("ca", "ca"),
218
- use_aux_head=True,
219
- use_mix_token=False,
220
- pooling_scale=2,
221
- ):
222
- super().__init__(
223
- layers,
224
- img_size,
225
- in_chans,
226
- num_classes,
227
- global_pool,
228
- patch_size,
229
- stem_hidden_dim,
230
- embed_dims,
231
- num_heads,
232
- downsamples,
233
- outlook_attention,
234
- mlp_ratio,
235
- qkv_bias,
236
- drop_rate,
237
- attn_drop_rate,
238
- drop_path_rate,
239
- norm_layer,
240
- post_layers,
241
- use_aux_head,
242
- use_mix_token,
243
- pooling_scale,
244
- )
245
-
246
- self.patch_embed = PatchEmbed(
247
- stem_conv=True,
248
- stem_stride=2,
249
- patch_size=patch_size,
250
- in_chans=in_chans,
251
- hidden_dim=stem_hidden_dim,
252
- embed_dim=embed_dims[0],
253
- )
254
-
255
- trunc_normal_(self.pos_embed, std=0.02)
256
- self.apply(self._init_weights)
257
-
258
- def forward_features(self, x):
259
- x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
260
-
261
- # step2: tokens learning in the two stages
262
- x = self.forward_tokens(x)
263
-
264
- # step3: post network, apply class attention or not
265
- if self.post_network is not None:
266
- x = self.forward_cls(x)
267
- x = self.norm(x)
268
- return x
269
-
270
- def forward_head(self, x, pre_logits: bool = False, targets=None, epoch=None):
271
- if self.global_pool == "avg":
272
- out = x.mean(dim=1)
273
- elif self.global_pool == "token":
274
- out = x[:, 0]
275
- else:
276
- out = x
277
- if pre_logits:
278
- return out
279
-
280
- features = out
281
- fds_enabled = hasattr(self, "_fds_forward")
282
- if fds_enabled:
283
- features = self._fds_forward(features, targets, epoch)
284
-
285
- out = self.head(features)
286
- if self.aux_head is not None:
287
- # generate classes in all feature tokens, see token labeling
288
- aux = self.aux_head(x[:, 1:])
289
- out = out + 0.5 * aux.max(1)[0]
290
-
291
- return (out, features) if (fds_enabled and self.training) else out
292
-
293
- def forward(self, x, targets=None, epoch=None):
294
- """simplified forward (without mix token training)"""
295
- x = self.forward_features(x)
296
- x = self.forward_head(x, targets=targets, epoch=epoch)
297
- return x
298
-
299
-
300
- def _create_mivolo(variant, pretrained=False, **kwargs):
301
- if kwargs.get("features_only", None):
302
- raise RuntimeError("features_only not implemented for Vision Transformer models.")
303
- return build_model_with_cfg(MiVOLOModel, variant, pretrained, **kwargs)
304
-
305
-
306
- @register_model
307
- def mivolo_d1_224(pretrained=False, **kwargs):
308
- model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
309
- model = _create_mivolo("mivolo_d1_224", pretrained=pretrained, **model_args)
310
- return model
311
-
312
-
313
- @register_model
314
- def mivolo_d1_384(pretrained=False, **kwargs):
315
- model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
316
- model = _create_mivolo("mivolo_d1_384", pretrained=pretrained, **model_args)
317
- return model
318
-
319
-
320
- @register_model
321
- def mivolo_d2_224(pretrained=False, **kwargs):
322
- model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
323
- model = _create_mivolo("mivolo_d2_224", pretrained=pretrained, **model_args)
324
- return model
325
-
326
-
327
- @register_model
328
- def mivolo_d2_384(pretrained=False, **kwargs):
329
- model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
330
- model = _create_mivolo("mivolo_d2_384", pretrained=pretrained, **model_args)
331
- return model
332
-
333
-
334
- @register_model
335
- def mivolo_d3_224(pretrained=False, **kwargs):
336
- model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
337
- model = _create_mivolo("mivolo_d3_224", pretrained=pretrained, **model_args)
338
- return model
339
-
340
-
341
- @register_model
342
- def mivolo_d3_448(pretrained=False, **kwargs):
343
- model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
344
- model = _create_mivolo("mivolo_d3_448", pretrained=pretrained, **model_args)
345
- return model
346
-
347
-
348
- @register_model
349
- def mivolo_d4_224(pretrained=False, **kwargs):
350
- model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
351
- model = _create_mivolo("mivolo_d4_224", pretrained=pretrained, **model_args)
352
- return model
353
-
354
-
355
- @register_model
356
- def mivolo_d4_448(pretrained=False, **kwargs):
357
- """VOLO-D4 model, Params: 193M"""
358
- model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
359
- model = _create_mivolo("mivolo_d4_448", pretrained=pretrained, **model_args)
360
- return model
361
-
362
-
363
- @register_model
364
- def mivolo_d5_224(pretrained=False, **kwargs):
365
- model_args = dict(
366
- layers=(12, 12, 20, 4),
367
- embed_dims=(384, 768, 768, 768),
368
- num_heads=(12, 16, 16, 16),
369
- mlp_ratio=4,
370
- stem_hidden_dim=128,
371
- **kwargs
372
- )
373
- model = _create_mivolo("mivolo_d5_224", pretrained=pretrained, **model_args)
374
- return model
375
-
376
-
377
- @register_model
378
- def mivolo_d5_448(pretrained=False, **kwargs):
379
- model_args = dict(
380
- layers=(12, 12, 20, 4),
381
- embed_dims=(384, 768, 768, 768),
382
- num_heads=(12, 16, 16, 16),
383
- mlp_ratio=4,
384
- stem_hidden_dim=128,
385
- **kwargs
386
- )
387
- model = _create_mivolo("mivolo_d5_448", pretrained=pretrained, **model_args)
388
- return model
389
-
390
-
391
- @register_model
392
- def mivolo_d5_512(pretrained=False, **kwargs):
393
- model_args = dict(
394
- layers=(12, 12, 20, 4),
395
- embed_dims=(384, 768, 768, 768),
396
- num_heads=(12, 16, 16, 16),
397
- mlp_ratio=4,
398
- stem_hidden_dim=128,
399
- **kwargs
400
- )
401
- model = _create_mivolo("mivolo_d5_512", pretrained=pretrained, **model_args)
402
- return model
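Because each entrypoint above is decorated with @register_model, the models should also be reachable through timm's own factory once this module has been imported; a quick shape sanity check, assuming the timm version pinned in requirements.txt.

import timm
import torch
import mivolo.model.mivolo_model  # noqa: F401  (importing registers the mivolo_* entrypoints)

model = timm.create_model("mivolo_d1_224", pretrained=False, num_classes=3, in_chans=3)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 3])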
mivolo/model/yolo_detector.py DELETED
@@ -1,48 +0,0 @@
1
- import os
2
- from typing import Dict, Union
3
-
4
- import numpy as np
5
- import PIL
6
- import torch
7
- from mivolo.structures import PersonAndFaceResult
8
- from ultralytics import YOLO
9
- # from ultralytics.yolo.engine.results import Results
10
-
11
- # due to an ultralytics bug, CUBLAS_WORKSPACE_CONFIG must be unset after importing the module
12
- os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
13
-
14
-
15
- class Detector:
16
- def __init__(
17
- self,
18
- weights: str,
19
- device: str = "cuda",
20
- half: bool = True,
21
- verbose: bool = False,
22
- conf_thresh: float = 0.4,
23
- iou_thresh: float = 0.7,
24
- ):
25
- self.yolo = YOLO(weights)
26
- self.yolo.fuse()
27
-
28
- self.device = torch.device(device)
29
- self.half = half and self.device.type != "cpu"
30
-
31
- if self.half:
32
- self.yolo.model = self.yolo.model.half()
33
-
34
- self.detector_names: Dict[int, str] = self.yolo.model.names
35
-
36
- # init yolo.predictor
37
- self.detector_kwargs = {
38
- "conf": conf_thresh, "iou": iou_thresh, "half": self.half, "verbose": verbose}
39
- # self.yolo.predict(**self.detector_kwargs)
40
-
41
- def predict(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult:
42
- results = self.yolo.predict(image, **self.detector_kwargs)[0]
43
- return PersonAndFaceResult(results)
44
-
45
- def track(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult:
46
- results = self.yolo.track(
47
- image, persist=True, **self.detector_kwargs)[0]
48
- return PersonAndFaceResult(results)
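A minimal sketch of the Detector on its own; the weights path is a placeholder for a YOLOv8 checkpoint trained on person and face classes, and the thresholds shown are simply the defaults made explicit.

import cv2
from mivolo.model.yolo_detector import Detector

detector = Detector("person_face_weights.pt", device="cpu", half=False, conf_thresh=0.4, iou_thresh=0.7)
result = detector.predict(cv2.imread("example.jpg"))
print(result.n_objects, result.get_bboxes_inds("face"), result.get_bboxes_inds("person"))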
mivolo/predictor.py DELETED
@@ -1,68 +0,0 @@
1
- from collections import defaultdict
2
- from typing import Dict, Generator, List, Optional, Tuple
3
-
4
- import cv2
5
- import numpy as np
6
- import tqdm
7
- from mivolo.model.mi_volo import MiVOLO
8
- from mivolo.model.yolo_detector import Detector
9
- from mivolo.structures import AGE_GENDER_TYPE, PersonAndFaceResult
10
-
11
-
12
- class Predictor:
13
- def __init__(self, config, verbose: bool = False):
14
- self.detector = Detector(config.detector_weights, config.device, verbose=verbose)
15
- self.age_gender_model = MiVOLO(
16
- config.checkpoint,
17
- config.device,
18
- half=True,
19
- use_persons=config.with_persons,
20
- disable_faces=config.disable_faces,
21
- verbose=verbose,
22
- )
23
- self.draw = config.draw
24
-
25
- def recognize(self, image: np.ndarray) -> Tuple[PersonAndFaceResult, Optional[np.ndarray]]:
26
- detected_objects: PersonAndFaceResult = self.detector.predict(image)
27
- self.age_gender_model.predict(image, detected_objects)
28
-
29
- out_im = None
30
- if self.draw:
31
- # plot results on image
32
- out_im = detected_objects.plot()
33
-
34
- return detected_objects, out_im
35
-
36
- def recognize_video(self, source: str) -> Generator:
37
- video_capture = cv2.VideoCapture(source)
38
- if not video_capture.isOpened():
39
- raise ValueError(f"Failed to open video source {source}")
40
-
41
- detected_objects_history: Dict[int, List[AGE_GENDER_TYPE]] = defaultdict(list)
42
-
43
- total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
44
- for _ in tqdm.tqdm(range(total_frames)):
45
- ret, frame = video_capture.read()
46
- if not ret:
47
- break
48
-
49
- detected_objects: PersonAndFaceResult = self.detector.track(frame)
50
- self.age_gender_model.predict(frame, detected_objects)
51
-
52
- current_frame_objs = detected_objects.get_results_for_tracking()
53
- cur_persons: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[0]
54
- cur_faces: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[1]
55
-
56
- # add cur_persons and cur_faces to the history
57
- for guid, data in cur_persons.items():
58
- # skip entries with missing age or gender; they are not useful for tracking
59
- if None not in data:
60
- detected_objects_history[guid].append(data)
61
- for guid, data in cur_faces.items():
62
- if None not in data:
63
- detected_objects_history[guid].append(data)
64
-
65
- detected_objects.set_tracked_age_gender(detected_objects_history)
66
- if self.draw:
67
- frame = detected_objects.plot()
68
- yield detected_objects_history, frame
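recognize_video is a generator that yields the accumulated per-id history together with the (optionally annotated) frame. A minimal consumer sketch follows; the config object is any namespace exposing the attributes Predictor reads, and the weight/video paths are placeholders.

from types import SimpleNamespace
from mivolo.predictor import Predictor

config = SimpleNamespace(
    detector_weights="person_face_weights.pt",     # placeholder
    checkpoint="age_gender_checkpoint.pth.tar",    # placeholder
    device="cuda:0",
    with_persons=True,
    disable_faces=False,
    draw=False,
)
predictor = Predictor(config)

for history, frame in predictor.recognize_video("clip.mp4"):
    pass  # history: Dict[track_id, List[(age, gender)]], frame: current frame (annotated if draw=True)
print({track_id: len(votes) for track_id, votes in history.items()})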
mivolo/structures.py DELETED
@@ -1,464 +0,0 @@
1
- import math
2
- import os
3
- from copy import deepcopy
4
- from typing import Dict, List, Optional, Tuple
5
-
6
- import cv2
7
- import numpy as np
8
- import torch
9
- from mivolo.data.misc import aggregate_votes_winsorized, assign_faces, box_iou, cropout_black_parts
10
- from ultralytics.yolo.engine.results import Results
11
- from ultralytics.yolo.utils.plotting import Annotator, colors
12
-
13
- # due to an ultralytics bug, CUBLAS_WORKSPACE_CONFIG must be unset after importing the module
14
- os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
15
-
16
- AGE_GENDER_TYPE = Tuple[float, str]
17
-
18
-
19
- class PersonAndFaceCrops:
20
- def __init__(self):
21
- # int: index of person along results
22
- self.crops_persons: Dict[int, np.ndarray] = {}
23
-
24
- # int: index of face along results
25
- self.crops_faces: Dict[int, np.ndarray] = {}
26
-
27
- # int: index of face along results
28
- self.crops_faces_wo_body: Dict[int, np.ndarray] = {}
29
-
30
- # int: index of person along results
31
- self.crops_persons_wo_face: Dict[int, np.ndarray] = {}
32
-
33
- def _add_to_output(
34
- self, crops: Dict[int, np.ndarray], out_crops: List[np.ndarray], out_crop_inds: List[Optional[int]]
35
- ):
36
- inds_to_add = list(crops.keys())
37
- crops_to_add = list(crops.values())
38
- out_crops.extend(crops_to_add)
39
- out_crop_inds.extend(inds_to_add)
40
-
41
- def _get_all_faces(
42
- self, use_persons: bool, use_faces: bool
43
- ) -> Tuple[List[Optional[int]], List[Optional[np.ndarray]]]:
44
- """
45
- Returns
46
- if use_persons and use_faces
47
- faces: faces_with_bodies + faces_without_bodies + [None] * len(crops_persons_wo_face)
48
- if use_persons and not use_faces
49
- faces: [None] * n_persons
50
- if not use_persons and use_faces:
51
- faces: faces_with_bodies + faces_without_bodies
52
- """
53
-
54
- def add_none_to_output(faces_inds, faces_crops, num):
55
- faces_inds.extend([None for _ in range(num)])
56
- faces_crops.extend([None for _ in range(num)])
57
-
58
- faces_inds: List[Optional[int]] = []
59
- faces_crops: List[Optional[np.ndarray]] = []
60
-
61
- if not use_faces:
62
- add_none_to_output(faces_inds, faces_crops, len(self.crops_persons) + len(self.crops_persons_wo_face))
63
- return faces_inds, faces_crops
64
-
65
- self._add_to_output(self.crops_faces, faces_crops, faces_inds)
66
- self._add_to_output(self.crops_faces_wo_body, faces_crops, faces_inds)
67
-
68
- if use_persons:
69
- add_none_to_output(faces_inds, faces_crops, len(self.crops_persons_wo_face))
70
-
71
- return faces_inds, faces_crops
72
-
73
- def _get_all_bodies(
74
- self, use_persons: bool, use_faces: bool
75
- ) -> Tuple[List[Optional[int]], List[Optional[np.ndarray]]]:
76
- """
77
- Returns
78
- if use_persons and use_faces
79
- persons: bodies_with_faces + [None] * len(faces_without_bodies) + bodies_without_faces
80
- if use_persons and not use_faces
81
- persons: bodies_with_faces + bodies_without_faces
82
- if not use_persons and use_faces
83
- persons: [None] * n_faces
84
- """
85
-
86
- def add_none_to_output(bodies_inds, bodies_crops, num):
87
- bodies_inds.extend([None for _ in range(num)])
88
- bodies_crops.extend([None for _ in range(num)])
89
-
90
- bodies_inds: List[Optional[int]] = []
91
- bodies_crops: List[Optional[np.ndarray]] = []
92
-
93
- if not use_persons:
94
- add_none_to_output(bodies_inds, bodies_crops, len(self.crops_faces) + len(self.crops_faces_wo_body))
95
- return bodies_inds, bodies_crops
96
-
97
- self._add_to_output(self.crops_persons, bodies_crops, bodies_inds)
98
- if use_faces:
99
- add_none_to_output(bodies_inds, bodies_crops, len(self.crops_faces_wo_body))
100
-
101
- self._add_to_output(self.crops_persons_wo_face, bodies_crops, bodies_inds)
102
-
103
- return bodies_inds, bodies_crops
104
-
105
- def get_faces_with_bodies(self, use_persons: bool, use_faces: bool):
106
- """
107
- Return
108
- faces: faces_with_bodies, faces_without_bodies, [None] * len(crops_persons_wo_face)
109
- persons: bodies_with_faces, [None] * len(faces_without_bodies), bodies_without_faces
110
- """
111
-
112
- bodies_inds, bodies_crops = self._get_all_bodies(use_persons, use_faces)
113
- faces_inds, faces_crops = self._get_all_faces(use_persons, use_faces)
114
-
115
- return (bodies_inds, bodies_crops), (faces_inds, faces_crops)
116
-
117
- def save(self, out_dir="output"):
118
- ind = 0
119
- os.makedirs(out_dir, exist_ok=True)
120
- for crops in [self.crops_persons, self.crops_faces, self.crops_faces_wo_body, self.crops_persons_wo_face]:
121
- for crop in crops.values():
122
- if crop is None:
123
- continue
124
- out_name = os.path.join(out_dir, f"{ind}_crop.jpg")
125
- cv2.imwrite(out_name, crop)
126
- ind += 1
127
-
128
-
129
- class PersonAndFaceResult:
130
- def __init__(self, results: Results):
131
-
132
- self.yolo_results = results
133
- names = set(results.names.values())
134
- assert "person" in names and "face" in names
135
-
136
- # initially no faces and persons are associated to each other
137
- self.face_to_person_map: Dict[int, Optional[int]] = {ind: None for ind in self.get_bboxes_inds("face")}
138
- self.unassigned_persons_inds: List[int] = self.get_bboxes_inds("person")
139
- n_objects = len(self.yolo_results.boxes)
140
- self.ages: List[Optional[float]] = [None for _ in range(n_objects)]
141
- self.genders: List[Optional[str]] = [None for _ in range(n_objects)]
142
- self.gender_scores: List[Optional[float]] = [None for _ in range(n_objects)]
143
-
144
- @property
145
- def n_objects(self) -> int:
146
- return len(self.yolo_results.boxes)
147
-
148
- def get_bboxes_inds(self, category: str) -> List[int]:
149
- bboxes: List[int] = []
150
- for ind, det in enumerate(self.yolo_results.boxes):
151
- name = self.yolo_results.names[int(det.cls)]
152
- if name == category:
153
- bboxes.append(ind)
154
-
155
- return bboxes
156
-
157
- def get_distance_to_center(self, bbox_ind: int) -> float:
158
- """
159
- Calculate the Euclidean distance between the bbox center and the image center.
160
- """
161
- im_h, im_w = self.yolo_results[bbox_ind].orig_shape
162
- x1, y1, x2, y2 = self.get_bbox_by_ind(bbox_ind).cpu().numpy()
163
- center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
164
- dist = math.dist([center_x, center_y], [im_w / 2, im_h / 2])
165
- return dist
166
-
167
- def plot(
168
- self,
169
- conf=False,
170
- line_width=None,
171
- font_size=None,
172
- font="Arial.ttf",
173
- pil=False,
174
- img=None,
175
- labels=True,
176
- boxes=True,
177
- probs=True,
178
- ages=True,
179
- genders=True,
180
- gender_probs=False,
181
- ):
182
- """
183
- Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
184
- Args:
185
- conf (bool): Whether to plot the detection confidence score.
186
- line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
187
- font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
188
- font (str): The font to use for the text.
189
- pil (bool): Whether to return the image as a PIL Image.
190
- img (numpy.ndarray): Plot to another image. if not, plot to original image.
191
- labels (bool): Whether to plot the label of bounding boxes.
192
- boxes (bool): Whether to plot the bounding boxes.
193
- probs (bool): Whether to plot classification probability
194
- ages (bool): Whether to plot the age of bounding boxes.
195
- genders (bool): Whether to plot the genders of bounding boxes.
196
- gender_probs (bool): Whether to plot gender classification probability
197
- Returns:
198
- (numpy.ndarray): A numpy array of the annotated image.
199
- """
200
-
201
- # return self.yolo_results.plot()
202
- colors_by_ind = {}
203
- for face_ind, person_ind in self.face_to_person_map.items():
204
- if person_ind is not None:
205
- colors_by_ind[face_ind] = face_ind + 2
206
- colors_by_ind[person_ind] = face_ind + 2
207
- else:
208
- colors_by_ind[face_ind] = 0
209
- for person_ind in self.unassigned_persons_inds:
210
- colors_by_ind[person_ind] = 1
211
-
212
- names = self.yolo_results.names
213
- annotator = Annotator(
214
- deepcopy(self.yolo_results.orig_img if img is None else img),
215
- line_width,
216
- font_size,
217
- font,
218
- pil,
219
- example=names,
220
- )
221
- pred_boxes, show_boxes = self.yolo_results.boxes, boxes
222
- pred_probs, show_probs = self.yolo_results.probs, probs
223
-
224
- if pred_boxes and show_boxes:
225
- for bb_ind, (d, age, gender, gender_score) in enumerate(
226
- zip(pred_boxes, self.ages, self.genders, self.gender_scores)
227
- ):
228
- c, conf, guid = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
229
- name = ("" if guid is None else f"id:{guid} ") + names[c]
230
- label = (f"{name} {conf:.2f}" if conf else name) if labels else None
231
- if ages and age is not None:
232
- label += f" {age:.1f}"
233
- if genders and gender is not None:
234
- label += f" {'F' if gender == 'female' else 'M'}"
235
- if gender_probs and gender_score is not None:
236
- label += f" ({gender_score:.1f})"
237
- annotator.box_label(d.xyxy.squeeze(), label, color=colors(colors_by_ind[bb_ind], True))
238
-
239
- if pred_probs is not None and show_probs:
240
- text = f"{', '.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)}, "
241
- annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
242
-
243
- return annotator.result()
244
-
245
- def set_tracked_age_gender(self, tracked_objects: Dict[int, List[AGE_GENDER_TYPE]]):
246
- """
247
- Update age and gender for objects based on history from tracked_objects.
248
- Args:
249
- tracked_objects (dict[int, list[AGE_GENDER_TYPE]]): info about tracked objects by guid
250
- """
251
-
252
- for face_ind, person_ind in self.face_to_person_map.items():
253
- pguid = self._get_id_by_ind(person_ind)
254
- fguid = self._get_id_by_ind(face_ind)
255
-
256
- if fguid == -1 and pguid == -1:
257
- # YOLO might not assign ids for some objects in some cases:
258
- # https://github.com/ultralytics/ultralytics/issues/3830
259
- continue
260
- age, gender = self._gather_tracking_result(tracked_objects, fguid, pguid)
261
- if age is None or gender is None:
262
- continue
263
- self.set_age(face_ind, age)
264
- self.set_gender(face_ind, gender, 1.0)
265
- if pguid != -1:
266
- self.set_gender(person_ind, gender, 1.0)
267
- self.set_age(person_ind, age)
268
-
269
- for person_ind in self.unassigned_persons_inds:
270
- pid = self._get_id_by_ind(person_ind)
271
- if pid == -1:
272
- continue
273
- age, gender = self._gather_tracking_result(tracked_objects, -1, pid)
274
- if age is None or gender is None:
275
- continue
276
- self.set_gender(person_ind, gender, 1.0)
277
- self.set_age(person_ind, age)
278
-
279
- def _get_id_by_ind(self, ind: Optional[int] = None) -> int:
280
- if ind is None:
281
- return -1
282
- obj_id = self.yolo_results.boxes[ind].id
283
- if obj_id is None:
284
- return -1
285
- return obj_id.item()
286
-
287
- def get_bbox_by_ind(self, ind: int, im_h: int = None, im_w: int = None) -> torch.tensor:
288
- bb = self.yolo_results.boxes[ind].xyxy.squeeze().type(torch.int32)
289
- if im_h is not None and im_w is not None:
290
- bb[0] = torch.clamp(bb[0], min=0, max=im_w - 1)
291
- bb[1] = torch.clamp(bb[1], min=0, max=im_h - 1)
292
- bb[2] = torch.clamp(bb[2], min=0, max=im_w - 1)
293
- bb[3] = torch.clamp(bb[3], min=0, max=im_h - 1)
294
- return bb
295
-
296
- def set_age(self, ind: Optional[int], age: float):
297
- if ind is not None:
298
- self.ages[ind] = age
299
-
300
- def set_gender(self, ind: Optional[int], gender: str, gender_score: float):
301
- if ind is not None:
302
- self.genders[ind] = gender
303
- self.gender_scores[ind] = gender_score
304
-
305
- @staticmethod
306
- def _gather_tracking_result(
307
- tracked_objects: Dict[int, List[AGE_GENDER_TYPE]],
308
- fguid: int = -1,
309
- pguid: int = -1,
310
- minimum_sample_size: int = 10,
311
- ) -> AGE_GENDER_TYPE:
312
-
313
- assert fguid != -1 or pguid != -1, "Incorrect tracking behaviour"
314
-
315
- face_ages = [r[0] for r in tracked_objects[fguid] if r[0] is not None] if fguid in tracked_objects else []
316
- face_genders = [r[1] for r in tracked_objects[fguid] if r[1] is not None] if fguid in tracked_objects else []
317
- person_ages = [r[0] for r in tracked_objects[pguid] if r[0] is not None] if pguid in tracked_objects else []
318
- person_genders = [r[1] for r in tracked_objects[pguid] if r[1] is not None] if pguid in tracked_objects else []
319
-
320
- if not face_ages and not person_ages: # both empty
321
- return None, None
322
-
323
- # You can play here with different aggregation strategies
324
- # Face ages - predictions based on face or face + person, depends on history of object
325
- # Person ages - predictions based on person or face + person, depends on history of object
326
-
327
- if len(person_ages + face_ages) >= minimum_sample_size:
328
- age = aggregate_votes_winsorized(person_ages + face_ages)
329
- else:
330
- face_age = np.mean(face_ages) if face_ages else None
331
- person_age = np.mean(person_ages) if person_ages else None
332
- if face_age is None:
333
- face_age = person_age
334
- if person_age is None:
335
- person_age = face_age
336
- age = (face_age + person_age) / 2.0
337
-
338
- genders = face_genders + person_genders
339
- assert len(genders) > 0
340
- # take mode of genders
341
- gender = max(set(genders), key=genders.count)
342
-
343
- return age, gender
344
-
345
- def get_results_for_tracking(self) -> Tuple[Dict[int, AGE_GENDER_TYPE], Dict[int, AGE_GENDER_TYPE]]:
346
- """
347
- Get objects from current frame
348
- """
349
- persons: Dict[int, AGE_GENDER_TYPE] = {}
350
- faces: Dict[int, AGE_GENDER_TYPE] = {}
351
-
352
- names = self.yolo_results.names
353
- pred_boxes = self.yolo_results.boxes
354
- for _, (det, age, gender, _) in enumerate(zip(pred_boxes, self.ages, self.genders, self.gender_scores)):
355
- if det.id is None:
356
- continue
357
- cat_id, _, guid = int(det.cls), float(det.conf), int(det.id.item())
358
- name = names[cat_id]
359
- if name == "person":
360
- persons[guid] = (age, gender)
361
- elif name == "face":
362
- faces[guid] = (age, gender)
363
-
364
- return persons, faces
365
-
366
- def associate_faces_with_persons(self):
367
- face_bboxes_inds: List[int] = self.get_bboxes_inds("face")
368
- person_bboxes_inds: List[int] = self.get_bboxes_inds("person")
369
-
370
- face_bboxes: List[torch.tensor] = [self.get_bbox_by_ind(ind) for ind in face_bboxes_inds]
371
- person_bboxes: List[torch.tensor] = [self.get_bbox_by_ind(ind) for ind in person_bboxes_inds]
372
-
373
- self.face_to_person_map = {ind: None for ind in face_bboxes_inds}
374
- assigned_faces, unassigned_persons_inds = assign_faces(person_bboxes, face_bboxes)
375
-
376
- for face_ind, person_ind in enumerate(assigned_faces):
377
- face_ind = face_bboxes_inds[face_ind]
378
- person_ind = person_bboxes_inds[person_ind] if person_ind is not None else None
379
- self.face_to_person_map[face_ind] = person_ind
380
-
381
- self.unassigned_persons_inds = [person_bboxes_inds[person_ind] for person_ind in unassigned_persons_inds]
382
-
383
- def crop_object(
384
- self, full_image: np.ndarray, ind: int, cut_other_classes: Optional[List[str]] = None
385
- ) -> Optional[np.ndarray]:
386
-
387
- IOU_THRESH = 0.000001
388
- MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4
389
- CROP_ROUND_RATE = 0.3
390
- MIN_PERSON_SIZE = 50
391
-
392
- obj_bbox = self.get_bbox_by_ind(ind, *full_image.shape[:2])
393
- x1, y1, x2, y2 = obj_bbox
394
- cur_cat = self.yolo_results.names[int(self.yolo_results.boxes[ind].cls)]
395
- # get crop of face or person
396
- obj_image = full_image[y1:y2, x1:x2].copy()
397
- crop_h, crop_w = obj_image.shape[:2]
398
-
399
- if cur_cat == "person" and (crop_h < MIN_PERSON_SIZE or crop_w < MIN_PERSON_SIZE):
400
- return None
401
-
402
- if not cut_other_classes:
403
- return obj_image
404
-
405
- # calc iou between obj_bbox and other bboxes
406
- other_bboxes: List[torch.tensor] = [
407
- self.get_bbox_by_ind(other_ind, *full_image.shape[:2]) for other_ind in range(len(self.yolo_results.boxes))
408
- ]
409
-
410
- iou_matrix = box_iou(torch.stack([obj_bbox]), torch.stack(other_bboxes)).cpu().numpy()[0]
411
-
412
- # cut out other objects in case of intersection
413
- for other_ind, (det, iou) in enumerate(zip(self.yolo_results.boxes, iou_matrix)):
414
- other_cat = self.yolo_results.names[int(det.cls)]
415
- if ind == other_ind or iou < IOU_THRESH or other_cat not in cut_other_classes:
416
- continue
417
- o_x1, o_y1, o_x2, o_y2 = det.xyxy.squeeze().type(torch.int32)
418
-
419
- # remap current_person_bbox to reference_person_bbox coordinates
420
- o_x1 = max(o_x1 - x1, 0)
421
- o_y1 = max(o_y1 - y1, 0)
422
- o_x2 = min(o_x2 - x1, crop_w)
423
- o_y2 = min(o_y2 - y1, crop_h)
424
-
425
- if other_cat != "face":
426
- if (o_y1 / crop_h) < CROP_ROUND_RATE:
427
- o_y1 = 0
428
- if ((crop_h - o_y2) / crop_h) < CROP_ROUND_RATE:
429
- o_y2 = crop_h
430
- if (o_x1 / crop_w) < CROP_ROUND_RATE:
431
- o_x1 = 0
432
- if ((crop_w - o_x2) / crop_w) < CROP_ROUND_RATE:
433
- o_x2 = crop_w
434
-
435
- obj_image[o_y1:o_y2, o_x1:o_x2] = 0
436
-
437
- obj_image, remain_ratio = cropout_black_parts(obj_image, CROP_ROUND_RATE)
438
- if remain_ratio < MIN_PERSON_CROP_AFTERCUT_RATIO:
439
- return None
440
-
441
- return obj_image
442
-
443
- def collect_crops(self, image) -> PersonAndFaceCrops:
444
-
445
- crops_data = PersonAndFaceCrops()
446
- for face_ind, person_ind in self.face_to_person_map.items():
447
- face_image = self.crop_object(image, face_ind, cut_other_classes=[])
448
-
449
- if person_ind is None:
450
- crops_data.crops_faces_wo_body[face_ind] = face_image
451
- continue
452
-
453
- person_image = self.crop_object(image, person_ind, cut_other_classes=["face", "person"])
454
-
455
- crops_data.crops_faces[face_ind] = face_image
456
- crops_data.crops_persons[person_ind] = person_image
457
-
458
- for person_ind in self.unassigned_persons_inds:
459
- person_image = self.crop_object(image, person_ind, cut_other_classes=["face", "person"])
460
- crops_data.crops_persons_wo_face[person_ind] = person_image
461
-
462
- # uncomment to save preprocessed crops
463
- # crops_data.save()
464
- return crops_data
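The typical flow through PersonAndFaceResult, continuing from a detector prediction as in the earlier sketches (the weights and image paths remain placeholders): associate faces with bodies, then collect the crops that feed the age/gender model.

import cv2
from mivolo.model.yolo_detector import Detector

image = cv2.imread("example.jpg")
detected = Detector("person_face_weights.pt", device="cpu", half=False).predict(image)

detected.associate_faces_with_persons()            # fills face_to_person_map / unassigned_persons_inds
crops = detected.collect_crops(image)              # PersonAndFaceCrops with the four crop dictionaries
(body_inds, bodies), (face_inds, faces) = crops.get_faces_with_bodies(use_persons=True, use_faces=True)
print(detected.face_to_person_map, len(faces), len(bodies))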
mivolo/version.py DELETED
@@ -1 +0,0 @@
1
- __version__ = "0.3.0dev"
model/model_imdb_cross_person_4.22_99.46.pth.tar → model_imdb_cross_person_4.22_99.46.pth.tar RENAMED
File without changes
product2item.py DELETED
@@ -1,95 +0,0 @@
- import json
- from urllib.parse import quote
- from tqdm import tqdm
- from bs4 import BeautifulSoup
- from selenium import webdriver
- from utils import *
-
- MAX_PAGE = 618
-
-
- def append_dict_to_jsonl(dictionary, file_path='./output/items.jsonl'):
-     with open(file_path, 'a', encoding='utf-8') as jsonl_file:
-         json.dump(dictionary, jsonl_file, ensure_ascii=False)
-         jsonl_file.write('\n')
-
-
- def get_second_links(keyword):
-     # selenium
-     option = webdriver.ChromeOptions()
-     option.add_experimental_option('excludeSwitches', ['enable-automation'])
-     option.add_argument("--disable-blink-features=AutomationControlled")
-     # option.add_argument('--headless')
-     browser = webdriver.Chrome(options=option)
-     browser.get(f'https://www.taobao.com/list/product/{quote(keyword)}.htm')
-     # browser.minimize_window()
-     browser.maximize_window()
-
-     skip_captcha()
-
-     # Scroll through every item on the product page until all of them have loaded
-     for i in tqdm(range(1, MAX_PAGE + 1)):
-         browser.execute_script(f'window.scrollTo(0, {i * 500})')
-         sleeps(0.5, 1.0)
-         page_str = str(browser.page_source)
-         if "<title>taobao | 淘寶</title>" in page_str:  # verification page title
-             print('Captcha encountered...')
-             return []
-
-         if "已加载全部商品" in page_str:  # page banner: "all items loaded"
-             print('All items loaded!')
-             break
-
-         if "加载错误,请重试" in page_str:  # page banner: "load error, please retry"
-             print('Load error, crawl aborted')
-             break
-
-     html_content = browser.page_source
-
-     # bs4
-     soup = BeautifulSoup(html_content, 'html.parser')
-     return [link.get('href') for link in soup.find_all('a', class_='item')]
-
-
- def read_lines_to_array(file_path):
-     create_dir('./' + os.path.dirname(file_path))
-     lines_array = []
-     with open(file_path, 'r', encoding='utf-8') as file:
-         for line in file:
-             lines_array.append(line.strip())
-
-     return lines_array
-
-
- def product_to_items():
-     keywords = read_lines_to_array('./input/keywords.txt')
-     create_dir('./output')
-
-     for key in keywords:
-         urls = list(get_second_links(key))
-         print(f'Saving url into jsonl for keyword [{key}]')
-         for url in tqdm(urls):
-             tmp_dict = {
-                 'keyword': key,
-                 'id': url.split('.htm?spm=')[0].split('//www.taobao.com/list/item/')[1]
-             }
-             append_dict_to_jsonl(tmp_dict)
-
-     rm_duplicates_by_key()
-
-
- if __name__ == "__main__":
-     product_to_items()
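A side note on the id parsing above: the chained split() calls assume an exact URL prefix. A more defensive variant could look like the sketch below, assuming item URLs keep the /list/item/<id>.htm shape (the helper name is illustrative only).

import re
from typing import Optional

def extract_item_id(url: str) -> Optional[str]:
    # Assumes URLs of the form //www.taobao.com/list/item/<id>.htm?spm=...
    match = re.search(r'/list/item/([^/.]+)\.htm', url)
    return match.group(1) if match else None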
 
 
requirements.txt DELETED
@@ -1,9 +0,0 @@
- requests
- beautifulsoup4
- selenium
- Cython==0.29.28
- ultralytics
- timm==0.8.13.dev0
- omegaconf
- tqdm
- opencv-python
 
 
utils.py DELETED
@@ -1,59 +0,0 @@
- import os
- import json
- import random
- from time import sleep
-
-
- def create_dir(dir_path):
-     if not os.path.exists(dir_path):
-         os.makedirs(dir_path)
-
-
- def skip_captcha():
-     print('Crawling links...')
-
-
- def sleeps(a, b):
-     if a > 0 and b > a:
-         sleep((b - a) * random.random() + a)
-
-     else:
-         print('Invalid params!')
-
-
- def save_to_file(data_list, file_path='./output/items.jsonl'):
-     with open(file_path, 'w', encoding='utf-8') as jsonl_file:
-         for data in data_list:
-             json.dump(data, jsonl_file, ensure_ascii=(
-                 file_path != './output/items.jsonl'))
-             jsonl_file.write('\n')
-
-
- def rm_duplicates_by_key(jsonl_path='./output/items.jsonl', key_to_check='id', failist_path='./output/duplicate_id.txt'):
-     print('Removing duplicates...')
-     if not os.path.exists(jsonl_path):
-         print('jsonl not exist')
-         return
-
-     data_set = set()
-     unique_data = []
-     duplicates = set()
-
-     with open(jsonl_path, 'r', encoding='utf-8') as jsonl_file:
-         for line in jsonl_file:
-             data = json.loads(line)
-
-             # Extract the value of the specified key; it is used to detect duplicates
-             key_value = data.get(key_to_check)
-
-             # If this key value has already been seen, the record is a duplicate
-             if key_value in data_set:
-                 duplicates.add(key_value)
-                 continue
-             else:
-                 data_set.add(key_value)
-                 unique_data.append(data)
-
-     save_to_file(unique_data, file_path=jsonl_path)
-     save_to_file(duplicates, file_path=failist_path)
-     print('Duplicates removed!')
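For reference, a small usage sketch of the deduplication helper above; the file path and the "url" field are illustrative only.

from utils import rm_duplicates_by_key

# Deduplicate another jsonl by its "url" field; duplicate key values go to a side file.
rm_duplicates_by_key(jsonl_path='./output/details.jsonl',
                     key_to_check='url',
                     failist_path='./output/duplicate_url.txt')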
 
 
model/yolov8x_person_face.pt → yolov8x_person_face.pt RENAMED
File without changes