import logging
import os
from functools import partial
from multiprocessing.pool import ThreadPool
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
from mivolo.data.data_reader import AnnotType, PictureInfo, get_all_files, read_csv_annotation_file
from mivolo.data.misc import IOU, class_letterbox, cropout_black_parts
from timm.data.readers.reader import Reader
from tqdm import tqdm

CROP_ROUND_TOL = 0.3
MIN_PERSON_SIZE = 100
MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4

_logger = logging.getLogger("ReaderAgeGender")


class ReaderAgeGender(Reader):
    """
    Reader for the almost-original cleaned IMDB-WIKI dataset.
    Two changes from the original layout:
    1. Annotations must be in the ./annotation subdir of the dataset root.
    2. Images must be in the ./images subdir.

    A minimal usage sketch can be found at the bottom of this file.
    """

    def __init__(
        self,
        images_path,
        annotations_path,
        split="validation",
        target_size=224,
        min_size=5,
        seed=1234,
        with_persons=False,
        min_person_size=MIN_PERSON_SIZE,
        disable_faces=False,
        only_age=False,
        min_person_aftercut_ratio=MIN_PERSON_CROP_AFTERCUT_RATIO,
        crop_round_tol=CROP_ROUND_TOL,
    ):
        super().__init__()

        self.with_persons = with_persons
        self.disable_faces = disable_faces
        self.only_age = only_age

        # can only be black for now, even though black is not ideal for further normalization
        self.crop_out_color = (0, 0, 0)

        self.empty_crop = np.ones((target_size, target_size, 3)) * self.crop_out_color
        self.empty_crop = self.empty_crop.astype(np.uint8)

        self.min_person_size = min_person_size
        self.min_person_aftercut_ratio = min_person_aftercut_ratio
        self.crop_round_tol = crop_round_tol

        self.split = split
        self.min_size = min_size
        self.seed = seed
        self.target_size = target_size

        # Read annotations; there can be multiple files if annotations_path is a directory
        self._ann: Dict[str, List[PictureInfo]] = {}  # list of samples for each image
        self._associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
        self._faces_list: List[Tuple[str, int]] = []  # samples from this list are loaded in __getitem__

        self._read_annotations(images_path, annotations_path)
        _logger.info(f"Dataset length: {len(self._faces_list)} crops")

    def __getitem__(self, index):
        return self._read_img_and_label(index)

    def __len__(self):
        return len(self._faces_list)

    def _filename(self, index, basename=False, absolute=False):
        # signature matches the timm Reader API; `absolute` is unused here
        img_p = self._faces_list[index][0]
        return os.path.basename(img_p) if basename else img_p

def _read_annotations(self, images_path, csvs_path): | |
self._ann = {} | |
self._faces_list = [] | |
self._associated_objects = {} | |
csvs = get_all_files(csvs_path, [".csv"]) | |
csvs = [c for c in csvs if self.split in os.path.basename(c)] | |
# load annotations per image | |
for csv in csvs: | |
db, ann_type = read_csv_annotation_file(csv, images_path) | |
if self.with_persons and ann_type != AnnotType.PERSONS: | |
raise ValueError( | |
f"Annotation type in file {csv} contains no persons, " | |
f"but annotations with persons are requested." | |
) | |
self._ann.update(db) | |
if len(self._ann) == 0: | |
raise ValueError("Annotations are empty!") | |
self._ann, self._associated_objects = self.prepare_annotations() | |
images_list = list(self._ann.keys()) | |
for img_path in images_list: | |
for index, image_sample_info in enumerate(self._ann[img_path]): | |
assert image_sample_info.has_gt( | |
self.only_age | |
), "Annotations must be checked with self.prepare_annotations() func" | |
self._faces_list.append((img_path, index)) | |
    def _read_img_and_label(self, index):
        if not isinstance(index, int):
            raise TypeError("ReaderAgeGender expected index to be an integer")

        img_p, face_index = self._faces_list[index]
        ann: PictureInfo = self._ann[img_p][face_index]
        img = cv2.imread(img_p)

        face_empty = True
        if ann.has_face_bbox and not (self.with_persons and self.disable_faces):
            face_crop, face_empty = self._get_crop(ann.bbox, img)

        if not self.with_persons and face_empty:
            # model without persons
            raise ValueError("Annotations must be checked with self.prepare_annotations() func")

        if face_empty:
            face_crop = self.empty_crop

        person_empty = True
        if self.with_persons or self.disable_faces:
            if ann.has_person_bbox:
                # cut off all associated objects from the person crop
                objects = self._associated_objects[img_p][face_index]
                person_crop, person_empty = self._get_crop(
                    ann.person_bbox,
                    img,
                    crop_out_color=self.crop_out_color,
                    asced_objects=objects,
                )

        if face_empty and person_empty:
            raise ValueError("Annotations must be checked with self.prepare_annotations() func")

        if person_empty:
            person_crop = self.empty_crop

        return (face_crop, person_crop), [ann.age, ann.gender]

    def _get_crop(
        self,
        bbox,
        img,
        asced_objects=None,
        crop_out_color=(0, 0, 0),
    ) -> Tuple[np.ndarray, bool]:

        empty_bbox = False

        xmin, ymin, xmax, ymax = bbox
        assert not (
            ymax - ymin < self.min_size or xmax - xmin < self.min_size
        ), "Annotations must be checked with self.prepare_annotations() func"

        crop = img[ymin:ymax, xmin:xmax]

        if asced_objects:
            # cut off other objects for the person crop
            crop, empty_bbox = _cropout_asced_objs(
                asced_objects,
                bbox,
                crop.copy(),
                crop_out_color=crop_out_color,
                min_person_size=self.min_person_size,
                crop_round_tol=self.crop_round_tol,
                min_person_aftercut_ratio=self.min_person_aftercut_ratio,
            )
            if empty_bbox:
                crop = self.empty_crop

        crop = class_letterbox(crop, new_shape=(self.target_size, self.target_size), color=crop_out_color)
        return crop, empty_bbox

    def prepare_annotations(self):
        good_anns: Dict[str, List[PictureInfo]] = {}
        all_associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}

        if not self.with_persons:
            # remove all person bboxes
            for img_path, bboxes in self._ann.items():
                for sample in bboxes:
                    sample.clear_person_bbox()

        # check the dataset and collect associated objects
        verify_images_func = partial(
            verify_images,
            min_size=self.min_size,
            min_person_size=self.min_person_size,
            with_persons=self.with_persons,
            disable_faces=self.disable_faces,
            crop_round_tol=self.crop_round_tol,
            min_person_aftercut_ratio=self.min_person_aftercut_ratio,
            only_age=self.only_age,
        )

        num_threads = min(8, os.cpu_count())

        all_msgs = []
        broken = 0
        skipped = 0
        all_skipped_crops = 0
        desc = "Check annotations..."
        with ThreadPool(num_threads) as pool:
            pbar = tqdm(
                pool.imap_unordered(verify_images_func, list(self._ann.items())),
                desc=desc,
                total=len(self._ann),
            )
            for img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops in pbar:
                broken += 1 if is_corrupted else 0
                all_msgs.extend(msgs)
                all_skipped_crops += skipped_crops
                skipped += 1 if is_empty_annotations else 0

                if img_info is not None:
                    img_path, img_samples = img_info
                    good_anns[img_path] = img_samples
                    all_associated_objects.update({img_path: associated_objects})

                pbar.desc = (
                    f"{desc} {skipped} images skipped ({all_skipped_crops} crops are incorrect); "
                    f"{broken} images corrupted"
                )

            pbar.close()

        for msg in all_msgs:
            print(msg)
        print(f"\nImages left: {len(good_anns)}")

        return good_anns, all_associated_objects


def verify_images(
    img_info,
    min_size: int,
    min_person_size: int,
    with_persons: bool,
    disable_faces: bool,
    crop_round_tol: float,
    min_person_aftercut_ratio: float,
    only_age: bool,
):
    # Filter out a sample if its crop is too small, or if the image
    # cannot be read or does not exist.
    disable_faces = disable_faces and with_persons
    kwargs = dict(
        min_person_size=min_person_size,
        disable_faces=disable_faces,
        with_persons=with_persons,
        crop_round_tol=crop_round_tol,
        min_person_aftercut_ratio=min_person_aftercut_ratio,
        only_age=only_age,
    )

    def bbox_correct(bbox, min_size, im_h, im_w) -> Tuple[bool, List[int]]:
        ymin, ymax, xmin, xmax = _correct_bbox(bbox, im_h, im_w)
        crop_h, crop_w = ymax - ymin, xmax - xmin
        if crop_h < min_size or crop_w < min_size:
            return False, [-1, -1, -1, -1]
        bbox = [xmin, ymin, xmax, ymax]
        return True, bbox

    msgs = []
    skipped_crops = 0
    is_corrupted = False
    is_empty_annotations = False

    img_path: str = img_info[0]
    img_samples: List[PictureInfo] = img_info[1]
    try:
        im_cv = cv2.imread(img_path)
        im_h, im_w = im_cv.shape[:2]
    except Exception:
        msgs.append(f"Cannot load image {img_path}")
        is_corrupted = True
        return None, {}, msgs, is_corrupted, is_empty_annotations, skipped_crops

    out_samples: List[PictureInfo] = []
    for sample in img_samples:
        # correct face bbox
        if sample.has_face_bbox:
            is_correct, sample.bbox = bbox_correct(sample.bbox, min_size, im_h, im_w)
            if not is_correct and sample.has_gt(only_age):
                msgs.append("Small face. Skipping..")
                skipped_crops += 1

        # correct person bbox
        if sample.has_person_bbox:
            is_correct, sample.person_bbox = bbox_correct(
                sample.person_bbox, max(min_person_size, min_size), im_h, im_w
            )
            if not is_correct and sample.has_gt(only_age):
                msgs.append(f"Small person {img_path}. Skipping..")
                skipped_crops += 1

        if sample.has_face_bbox or sample.has_person_bbox:
            out_samples.append(sample)
        elif sample.has_gt(only_age):
            msgs.append("Sample has no face and no body. Skipping..")
            skipped_crops += 1

    # sort so that samples with undefined age and gender come last
    out_samples = sorted(out_samples, key=lambda sample: 1 if not sample.has_gt(only_age) else 0)

    # for each person, find other face and person bboxes that intersect with it
    associated_objects: Dict[int, List[List[int]]] = find_associated_objects(out_samples, only_age=only_age)

    out_samples, associated_objects, skipped_crops = filter_bad_samples(
        out_samples, associated_objects, im_cv, msgs, skipped_crops, **kwargs
    )

    out_img_info: Optional[Tuple[str, List]] = (img_path, out_samples)
    if len(out_samples) == 0:
        out_img_info = None
        is_empty_annotations = True

    return out_img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops


def filter_bad_samples(
    out_samples: List[PictureInfo],
    associated_objects: dict,
    im_cv: np.ndarray,
    msgs: List[str],
    skipped_crops: int,
    **kwargs,
):
    with_persons, disable_faces, min_person_size, crop_round_tol, min_person_aftercut_ratio, only_age = (
        kwargs["with_persons"],
        kwargs["disable_faces"],
        kwargs["min_person_size"],
        kwargs["crop_round_tol"],
        kwargs["min_person_aftercut_ratio"],
        kwargs["only_age"],
    )

    # keep only samples with annotations
    inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_gt(only_age)]
    out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)

    if kwargs["disable_faces"]:
        # clear all face bboxes
        for ind, sample in enumerate(out_samples):
            sample.clear_face_bbox()

        # keep only samples with a person bbox
        inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_person_bbox]
        out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)

    if with_persons or disable_faces:
        # check that the preprocessing func _cropout_asced_objs()
        # returns a non-empty person image for each output sample
        inds = []
        for ind, sample in enumerate(out_samples):
            person_empty = True
            if sample.has_person_bbox:
                xmin, ymin, xmax, ymax = sample.person_bbox
                crop = im_cv[ymin:ymax, xmin:xmax]
                # cut off all associated objects from the person crop
                _, person_empty = _cropout_asced_objs(
                    associated_objects[ind],
                    sample.person_bbox,
                    crop.copy(),
                    min_person_size=min_person_size,
                    crop_round_tol=crop_round_tol,
                    min_person_aftercut_ratio=min_person_aftercut_ratio,
                )

            if person_empty and not sample.has_face_bbox:
                msgs.append("Small person after preprocessing. Skipping..")
                skipped_crops += 1
            else:
                inds.append(ind)
        out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)

    assert len(associated_objects) == len(out_samples)
    return out_samples, associated_objects, skipped_crops


def _filter_by_ind(out_samples, associated_objects, inds):
    _associated_objects = {}
    _out_samples = []
    for ind, sample in enumerate(out_samples):
        if ind in inds:
            _associated_objects[len(_out_samples)] = associated_objects[ind]
            _out_samples.append(sample)

    return _out_samples, _associated_objects

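# A minimal worked example of _filter_by_ind (hypothetical values): keeping
# indices [0, 2] drops sample 1 and re-keys the associated-objects dict so
# that it stays aligned with the filtered sample list:
#
#   samples = ["s0", "s1", "s2"]
#   assoc = {0: [[0, 0, 10, 10]], 1: [], 2: [[5, 5, 20, 20]]}
#   _filter_by_ind(samples, assoc, [0, 2])
#   -> (["s0", "s2"], {0: [[0, 0, 10, 10]], 1: [[5, 5, 20, 20]]})
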

def find_associated_objects(
    image_samples: List[PictureInfo], iou_thresh=0.0001, only_age=False
) -> Dict[int, List[List[int]]]:
    """
    For each person (with gt age and gt gender), find other face and person bboxes that intersect with it.
    """
    associated_objects: Dict[int, List[List[int]]] = {}

    for iindex, image_sample_info in enumerate(image_samples):
        # add the sample's own face
        associated_objects[iindex] = [image_sample_info.bbox] if image_sample_info.has_face_bbox else []

        if not image_sample_info.has_person_bbox or not image_sample_info.has_gt(only_age):
            # samples without gt will not be used
            continue

        iperson_box = image_sample_info.person_bbox
        for jindex, other_image_sample in enumerate(image_samples):
            if iindex == jindex:
                continue
            if other_image_sample.has_face_bbox:
                jface_bbox = other_image_sample.bbox
                iou = _get_iou(jface_bbox, iperson_box)
                if iou >= iou_thresh:
                    associated_objects[iindex].append(jface_bbox)
            if other_image_sample.has_person_bbox:
                jperson_bbox = other_image_sample.person_bbox
                iou = _get_iou(jperson_bbox, iperson_box)
                if iou >= iou_thresh:
                    associated_objects[iindex].append(jperson_bbox)

    return associated_objects

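# Example: in a photo with two annotated people A and B whose boxes overlap,
# A's entry in associated_objects holds A's own face bbox plus any of B's
# face/person bboxes that intersect A's person bbox; _cropout_asced_objs()
# below paints all of these regions black inside A's person crop.
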

def _cropout_asced_objs(
    asced_objects,
    person_bbox,
    crop,
    min_person_size,
    crop_round_tol,
    min_person_aftercut_ratio,
    crop_out_color=(0, 0, 0),
):
    empty = False
    xmin, ymin, xmax, ymax = person_bbox

    for a_obj in asced_objects:
        aobj_xmin, aobj_ymin, aobj_xmax, aobj_ymax = a_obj

        # translate the object's bbox into the person crop's coordinate system
        # and clamp it to the crop boundaries
        aobj_ymin = int(max(aobj_ymin - ymin, 0))
        aobj_xmin = int(max(aobj_xmin - xmin, 0))
        aobj_ymax = int(min(aobj_ymax - ymin, ymax - ymin))
        aobj_xmax = int(min(aobj_xmax - xmin, xmax - xmin))

        crop[aobj_ymin:aobj_ymax, aobj_xmin:aobj_xmax] = crop_out_color

    crop, cropped_ratio = cropout_black_parts(crop, crop_round_tol)
    if (
        crop.shape[0] < min_person_size or crop.shape[1] < min_person_size
    ) or cropped_ratio < min_person_aftercut_ratio:
        crop = None
        empty = True

    return crop, empty

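# With the default thresholds: if the trimmed crop ends up smaller than
# MIN_PERSON_SIZE (100 px) on either side, or the ratio reported by
# cropout_black_parts() falls below MIN_PERSON_CROP_AFTERCUT_RATIO (0.4),
# the crop is discarded as empty.
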

def _correct_bbox(bbox, h, w):
    # clamps the bbox to the image; note that the input order is
    # (xmin, ymin, xmax, ymax) but the output order is (ymin, ymax, xmin, xmax)
    xmin, ymin, xmax, ymax = bbox
    ymin = min(max(ymin, 0), h)
    ymax = min(max(ymax, 0), h)
    xmin = min(max(xmin, 0), w)
    xmax = min(max(xmax, 0), w)
    return ymin, ymax, xmin, xmax


def _get_iou(bbox1, bbox2):
    xmin1, ymin1, xmax1, ymax1 = bbox1
    xmin2, ymin2, xmax2, ymax2 = bbox2
    iou = IOU(
        [ymin1, xmin1, ymax1, xmax1],
        [ymin2, xmin2, ymax2, xmax2],
    )
    return iou
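

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original module. The paths are
    # hypothetical and assume the layout from the class docstring: images in
    # ./images, csv annotations in ./annotation, with the split name appearing
    # in each annotation filename.
    reader = ReaderAgeGender(
        images_path="/data/imdb-wiki/images",
        annotations_path="/data/imdb-wiki/annotation",
        split="validation",
        with_persons=True,
    )
    # each item is ((face_crop, person_crop), [age, gender])
    (face_crop, person_crop), (age, gender) = reader[0]
    print(face_crop.shape, person_crop.shape, age, gender)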