from pathlib import Path
from typing import Union, Optional
import numpy as np
import torch
import tops
import torchvision.transforms.functional as F
from motpy import Detection, MultiObjectTracker
from dp2.utils import load_config
from dp2.infer import build_trained_generator
from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection


def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
    cfg = load_config(cfg_path)
    G = build_trained_generator(cfg)
    tops.logger.log(f"Loaded generator from: {cfg_path}")
    return G


class Anonymizer:

    def __init__(
        self,
        detector,
        load_cache: bool = False,
        person_G_cfg: Optional[Union[str, Path]] = None,
        cse_person_G_cfg: Optional[Union[str, Path]] = None,
        face_G_cfg: Optional[Union[str, Path]] = None,
        car_G_cfg: Optional[Union[str, Path]] = None,
    ) -> None:
        self.detector = detector
        # One optional generator per detection type; detections whose generator
        # is None are left untouched by anonymize_detections.
        self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]}
        self.load_cache = load_cache
        if cse_person_G_cfg is not None:
            self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
        if person_G_cfg is not None:
            self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
        if face_G_cfg is not None:
            self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
        if car_G_cfg is not None:
            self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)
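
    # Identity tracking: a motpy multi-object tracker assigns a stable track id
    # to each detection, and track_to_z_idx maps each track id to a fixed latent
    # seed so an individual keeps the same synthesized identity across frames.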
    def initialize_tracker(self, fps: float):
        self.tracker = MultiObjectTracker(dt=1/fps)
        self.track_to_z_idx = dict()

    def reset_tracker(self):
        self.track_to_z_idx = dict()
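
    # forward_G runs one generator pass on a cropped detection: it normalizes
    # the crop, optionally draws a per-identity latent from z_idx[idx], and
    # returns the synthesized image scaled back to [0, 255].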
    def forward_G(
        self,
        G,
        batch,
        multi_modal_truncation: bool,
        amp: bool,
        z_idx: Optional[np.ndarray],
        truncation_value: float,
        idx: int,
        all_styles=None,
    ):
        # Map the uint8 crop from [0, 255] to [-1, 1], the generator's input range.
        batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
        batch["condition"] = batch["mask"].float() * batch["img"]
        with torch.cuda.amp.autocast(amp):
            z = None
            if z_idx is not None:
                # Seed the latent from the tracked identity index so the same
                # track is re-synthesized with the same identity in every frame.
                state = np.random.RandomState(seed=z_idx[idx])
                z = state.normal(size=(1, G.z_channels)).astype(np.float32)
                z = tops.to_cuda(torch.from_numpy(z))
            if all_styles is not None:
                anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
            elif multi_modal_truncation:
                w_indices = None
                if z_idx is not None:
                    w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
                anonymized_im = G.multi_modal_truncate(
                    **batch, truncation_value=truncation_value,
                    w_indices=w_indices,
                    z=z,
                )["img"]
            else:
                anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
        # Map the output back from [-1, 1] to [0, 255].
        anonymized_im = (anonymized_im + 1).div(2).clamp(0, 1).mul(255)
        return anonymized_im
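
    # anonymize_detections replaces each detected region with generator output:
    # crop -> synthesize -> resize -> strip padding -> paste masked pixels back.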
    @torch.no_grad()
    def anonymize_detections(
        self,
        im, detection,
        update_identity=None,
        **synthesis_kwargs
    ):
        G = self.generators[type(detection)]
        if G is None:
            # No generator loaded for this detection type; leave it as-is.
            return im
        C, H, W = im.shape
        if update_identity is None:
            update_identity = [True] * len(detection)
        for idx in range(len(detection)):
            if not update_identity[idx]:
                continue
            batch = detection.get_crop(idx, im)
            x0, y0, x1, y1 = batch.pop("boxes")[0]
            batch = {k: tops.to_cuda(v) for k, v in batch.items()}
            anonymized_im = self.forward_G(G, batch, **synthesis_kwargs, idx=idx)
            # Resize the synthesized crop and its mask back to the (possibly
            # padded) box size before pasting into the full image.
            gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.BICUBIC, antialias=True)
            mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
            # Remove padding added for boxes that extend outside the image.
            pad = [max(-x0, 0), max(-y0, 0)]
            pad = [*pad, max(x1-W, 0), max(y1-H, 0)]

            def remove_pad(x):
                return x[..., pad[1]:x.shape[-2]-pad[3], pad[0]:x.shape[-1]-pad[2]]

            gim = remove_pad(gim)
            mask = remove_pad(mask) > 0.5
            x0, y0 = max(x0, 0), max(y0, 0)
            x1, y1 = min(x1, W), min(y1, H)
            # Paste anonymized pixels only where the original mask is zero,
            # i.e. the region the generator was asked to in-paint.
            mask = mask.logical_not()[None].repeat(3, 1, 1)
            im[:, y0:y1, x0:x1][mask] = gim[mask].round().clamp(0, 255).byte()
        return im

    def visualize_detection(self, im: torch.Tensor, cache_id: Optional[str] = None) -> torch.Tensor:
        all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
        im = im.cpu()
        for det in all_detections:
            im = det.visualize(im)
        return im
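
    # forward is the main entry point: detect (or reuse cached/provided
    # detections), optionally track identities across frames, then anonymize
    # every detection with a consistent per-track latent seed.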
    @torch.no_grad()
    def forward(self, im: torch.Tensor, cache_id: Optional[str] = None, track: bool = True, detections=None, **synthesis_kwargs) -> torch.Tensor:
        assert im.dtype == torch.uint8
        im = tops.to_cuda(im)
        all_detections = detections
        if detections is None:
            if self.load_cache:
                all_detections = self.detector.forward_and_cache(im, cache_id)
            else:
                all_detections = self.detector(im)
        if hasattr(self, "tracker") and track:
            for detection in all_detections:
                detection.pre_process()
            boxes = np.concatenate([detection.boxes for detection in all_detections])
            boxes = [Detection(box) for box in boxes]
            self.tracker.step(boxes)
            track_ids = self.tracker.detections_matched_ids
            # Assign each track a fixed random latent seed the first time it is
            # seen, so the same identity is synthesized in later frames.
            z_idx = []
            for track_id in track_ids:
                if track_id not in self.track_to_z_idx:
                    self.track_to_z_idx[track_id] = np.random.randint(0, 2**32 - 1)
                z_idx.append(self.track_to_z_idx[track_id])
            z_idx = np.array(z_idx)
        idx_offset = 0
        for detection in all_detections:
            zs = None
            if hasattr(self, "tracker") and track:
                # Slice out the seeds belonging to this detection batch.
                zs = z_idx[idx_offset:idx_offset + len(detection)]
                idx_offset += len(detection)
            im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)
        return im.cpu()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
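
# Usage sketch (hypothetical wiring; the detector and the config path below are
# assumptions, not part of this module):
#   detector = ...  # e.g. a dp2 detection model
#   anonymizer = Anonymizer(detector, face_G_cfg="path/to/face_config.py")
#   anonymizer.initialize_tracker(fps=30)  # only needed for video
#   out = anonymizer(frame)  # frame: uint8 CHW torch.Tensor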