from pathlib import Path
from typing import Union, Optional
import numpy as np
import torch
import tops
import torchvision.transforms.functional as F
from motpy import Detection, MultiObjectTracker
from dp2.utils import load_config
from dp2.infer import build_trained_generator
from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection


def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
    cfg = load_config(cfg_path)
    G = build_trained_generator(cfg)
    tops.logger.log(f"Loaded generator from: {cfg_path}")
    return G
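

# Example (the config path below is an assumption about the DeepPrivacy2
# repository layout, not something defined in this file):
#
#     G = load_generator_from_cfg_path("configs/fdf/stylegan.py")  # hypothetical path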


class Anonymizer:
    """Anonymizes a uint8 CHW image by replacing each detected region
    (face, person, CSE-based person, or vehicle) with output from the
    corresponding generator. Detection types without a generator are skipped."""

    def __init__(
        self,
        detector,
        load_cache: bool = False,
        person_G_cfg: Optional[Union[str, Path]] = None,
        cse_person_G_cfg: Optional[Union[str, Path]] = None,
        face_G_cfg: Optional[Union[str, Path]] = None,
        car_G_cfg: Optional[Union[str, Path]] = None,
    ) -> None:
        self.detector = detector
        # One (optional) generator per detection type.
        self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]}
        self.load_cache = load_cache
        if cse_person_G_cfg is not None:
            self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
        if person_G_cfg is not None:
            self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
        if face_G_cfg is not None:
            self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
        if car_G_cfg is not None:
            self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)

    def initialize_tracker(self, fps: float):
        # Box tracker used to keep the same latent (i.e. the same synthesized
        # identity) for each tracked object across video frames.
        self.tracker = MultiObjectTracker(dt=1 / fps)
        self.track_to_z_idx = dict()

    def reset_tracker(self):
        self.track_to_z_idx = dict()
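
    # Typical per-clip usage (sketch; `frames` and `fps` come from a video
    # reader outside this class):
    #
    #     anonymizer.initialize_tracker(fps)
    #     out = [anonymizer(f, truncation_value=0.5, multi_modal_truncation=False, amp=True) for f in frames]
    #     anonymizer.reset_tracker()  # before the next clip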

    def forward_G(
        self,
        G,
        batch,
        multi_modal_truncation: bool,
        amp: bool,
        z_idx: Optional[np.ndarray],
        truncation_value: float,
        idx: int,
        all_styles=None,
    ):
        # Map the uint8 image to [-1, 1]; the condition masks out the region
        # that the generator should synthesize.
        batch["img"] = F.normalize(batch["img"].float(), [0.5 * 255] * 3, [0.5 * 255] * 3)
        batch["condition"] = batch["mask"].float() * batch["img"]
        with torch.cuda.amp.autocast(amp):
            z = None
            if z_idx is not None:
                # Seed the latent from the track's z index so the same track
                # keeps the same identity across frames.
                state = np.random.RandomState(seed=z_idx[idx])
                z = state.normal(size=(1, G.z_channels)).astype(np.float32)
                z = tops.to_cuda(torch.from_numpy(z))
            if all_styles is not None:
                anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
            elif multi_modal_truncation:
                w_indices = None
                if z_idx is not None:
                    w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
                anonymized_im = G.multi_modal_truncate(
                    **batch, truncation_value=truncation_value,
                    w_indices=w_indices, z=z,
                )["img"]
            else:
                anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
        # Back from [-1, 1] to [0, 255].
        anonymized_im = (anonymized_im + 1).div(2).clamp(0, 1).mul(255)
        return anonymized_im

    def anonymize_detections(
        self,
        im,
        detection,
        update_identity=None,
        **synthesis_kwargs,
    ):
        G = self.generators[type(detection)]
        if G is None:
            return im
        C, H, W = im.shape
        if update_identity is None:
            update_identity = [True for _ in range(len(detection))]
        for idx in range(len(detection)):
            if not update_identity[idx]:
                continue
            # Crop around the detection; the crop box may extend past the
            # image borders, in which case it is padded.
            batch = detection.get_crop(idx, im)
            x0, y0, x1, y1 = batch.pop("boxes")[0]
            batch = {k: tops.to_cuda(v) for k, v in batch.items()}
            anonymized_im = self.forward_G(G, batch, **synthesis_kwargs, idx=idx)
            # Resize the synthesized crop and its mask back to the box size.
            gim = F.resize(anonymized_im[0], (y1 - y0, x1 - x0), interpolation=F.InterpolationMode.BICUBIC, antialias=True)
            mask = F.resize(batch["mask"][0], (y1 - y0, x1 - x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
            # Remove any padding that extended past the image borders.
            pad = [max(-x0, 0), max(-y0, 0), max(x1 - W, 0), max(y1 - H, 0)]

            def remove_pad(x):
                return x[..., pad[1]:x.shape[-2] - pad[3], pad[0]:x.shape[-1] - pad[2]]

            gim = remove_pad(gim)
            mask = remove_pad(mask) > 0.5
            # Clamp the box to the image, invert the mask to select the pixels
            # to overwrite, and paste the synthesized pixels into the image.
            x0, y0 = max(x0, 0), max(y0, 0)
            x1, y1 = min(x1, W), min(y1, H)
            mask = mask.logical_not()[None].repeat(3, 1, 1)
            im[:, y0:y1, x0:x1][mask] = gim[mask].round().clamp(0, 255).byte()
        return im
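
    # Padding bookkeeping example (hypothetical values): for a 100x100 image
    # and an expanded box (x0, y0, x1, y1) = (-10, 20, 110, 120), the crop
    # extends 10 px past the left and right borders and 20 px past the bottom,
    # so pad == [10, 0, 10, 20]. remove_pad trims the 100x120 resized crop to
    # 80x100, which exactly matches the clamped region im[:, 20:100, 0:100].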

    def visualize_detection(self, im: torch.Tensor, cache_id: Optional[str] = None) -> torch.Tensor:
        all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
        im = im.cpu()
        for det in all_detections:
            im = det.visualize(im)
        return im

    def forward(self, im: torch.Tensor, cache_id: Optional[str] = None, track=True, detections=None, **synthesis_kwargs) -> torch.Tensor:
        assert im.dtype == torch.uint8
        im = tops.to_cuda(im)
        all_detections = detections
        if detections is None:
            if self.load_cache:
                all_detections = self.detector.forward_and_cache(im, cache_id)
            else:
                all_detections = self.detector(im)
        if hasattr(self, "tracker") and track:
            # Match current boxes to tracks and assign each track a fixed
            # random seed, so identities stay consistent across frames.
            for det in all_detections:
                det.pre_process()
            boxes = np.concatenate([det.boxes for det in all_detections])
            boxes = [Detection(box) for box in boxes]
            self.tracker.step(boxes)
            track_ids = self.tracker.detections_matched_ids
            z_idx = []
            for track_id in track_ids:
                if track_id not in self.track_to_z_idx:
                    self.track_to_z_idx[track_id] = np.random.randint(0, 2**32 - 1)
                z_idx.append(self.track_to_z_idx[track_id])
            z_idx = np.array(z_idx)
            idx_offset = 0
        for detection in all_detections:
            zs = None
            if hasattr(self, "tracker") and track:
                zs = z_idx[idx_offset:idx_offset + len(detection)]
                idx_offset += len(detection)
            im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)
        return im.cpu()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
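

# Minimal usage sketch. The detector is a stand-in (building one from a dp2
# config happens outside this file), and the generator config path is an
# assumption about the repository layout:
#
#     from PIL import Image
#     detector = ...  # any detector returning dp2 detection structures
#     anonymizer = Anonymizer(detector, face_G_cfg="configs/fdf/stylegan.py")
#     im = torch.from_numpy(np.array(Image.open("image.png"))).permute(2, 0, 1)
#     out = anonymizer(im, truncation_value=0.5, multi_modal_truncation=False, amp=True)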