import torch
from typing import List, Optional, Sequence

import tops
from torchvision.transforms.functional import InterpolationMode, resize
from densepose.data.utils import get_class_to_mesh_name_mapping
from densepose import add_densepose_config
from densepose.structures import DensePoseEmbeddingPredictorOutput
from densepose.vis.extractor import DensePoseOutputsExtractor
from densepose.modeling import build_densepose_embedder
from detectron2.config import get_cfg
from detectron2.data.transforms import ResizeShortestEdge
from detectron2.checkpoint.detection_checkpoint import DetectionCheckpointer
from detectron2.modeling import build_model

# Maps a DensePose CSE config URL to the matching pretrained checkpoint URL.
model_urls = {
    "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml":
        "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x/250713061/model_final_1d3314.pkl",
    "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml":
        "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl",
}


def cse_det_to_global(boxes_XYXY, S: torch.Tensor, imshape):
    """Paste per-detection binary masks into full-image boolean masks.

    Args:
        boxes_XYXY: (N, 4) detection boxes in (x0, y0, x1, y1) pixel
            coordinates, assumed to lie inside the image.
        S: (N, h, w) per-detection segmentation logits/labels; values > 0
            are treated as foreground.
        imshape: (H, W) of the target image.

    Returns:
        (N, H, W) boolean tensor; each slice holds one detection's mask
        resized (nearest-neighbor) into its box location.
    """
    assert len(S.shape) == 3
    H, W = imshape
    N = len(boxes_XYXY)
    segmentation = torch.zeros((N, H, W), dtype=torch.bool, device=S.device)
    boxes_XYXY = boxes_XYXY.long()
    for i in range(N):
        x0, y0, x1, y1 = boxes_XYXY[i]
        assert x0 >= 0 and y0 >= 0
        assert x1 <= imshape[1]
        assert y1 <= imshape[0]
        h = y1 - y0
        w = x1 - x0
        # Nearest-neighbor keeps the mask binary; > 0 selects foreground.
        segmentation[i:i+1, y0:y1, x0:x1] = resize(
            S[i:i+1], (h, w), interpolation=InterpolationMode.NEAREST) > 0
    return segmentation


class CSEDetector:
    """DensePose CSE (Continuous Surface Embeddings) person detector.

    Downloads a detectron2 DensePose config + checkpoint, builds the model,
    and exposes it as a callable that returns per-instance segmentations and
    surface embeddings (see ``forward``).
    """

    def __init__(
        self,
        cfg_url: str = "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
        # NOTE: a tuple (not a list) so the default is immutable — mutable
        # default arguments are shared across calls.
        cfg_2_download: Sequence[str] = (
            "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
            "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml",
            "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml"),
        score_thres: float = 0.9,
        nms_thresh: Optional[float] = None,
    ) -> None:
        """Build the detector.

        Args:
            cfg_url: main DensePose config; must be a key of ``model_urls``.
            cfg_2_download: extra config files the main config inherits from;
                downloaded so ``merge_from_file`` can resolve them.
            score_thres: ROI head score threshold at test time.
            nms_thresh: optional NMS threshold override; config default is
                kept when None.
        """
        with tops.logger.capture_log_stdout():
            cfg = get_cfg()
            self.device = tops.get_device()
            add_densepose_config(cfg)
        cfg_path = tops.download_file(cfg_url)
        for p in cfg_2_download:
            tops.download_file(p)
        with tops.logger.capture_log_stdout():
            cfg.merge_from_file(cfg_path)
        assert cfg_url in model_urls, cfg_url
        model_path = tops.download_file(model_urls[cfg_url])
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
        if nms_thresh is not None:
            cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = nms_thresh
        cfg.MODEL.WEIGHTS = str(model_path)
        cfg.MODEL.DEVICE = str(self.device)
        cfg.freeze()
        # Checkpoint loading is chatty; capture its stdout into the logger.
        with tops.logger.capture_log_stdout():
            self.model = build_model(cfg)
            self.model.eval()
            DetectionCheckpointer(self.model).load(str(model_path))
        self.input_format = cfg.INPUT.FORMAT
        self.densepose_extractor = DensePoseOutputsExtractor()
        self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
        self.embedder = build_densepose_embedder(cfg)
        # Per-mesh vertex embedding tables, moved to the inference device.
        self.mesh_vertex_embeddings = {
            mesh_name: self.embedder(mesh_name).to(self.device)
            for mesh_name in self.class_to_mesh_name.values()
            if self.embedder.has_embeddings(mesh_name)
        }
        self.cfg = cfg
        # Human SMPL mesh embedding table; ``forward`` asserts this mesh.
        self.embed_map = self.mesh_vertex_embeddings["smpl_27554"]
        tops.logger.log("CSEDetector built.")

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def resize_im(self, im):
        """Resize (C, H, W) image to the model's test-time input size."""
        H, W = im.shape[1:]
        newH, newW = ResizeShortestEdge.get_output_shape(
            H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
        return resize(
            im, (newH, newW), InterpolationMode.BILINEAR, antialias=True)
@torch.no_grad() def forward(self, im): assert im.dtype == torch.uint8 if self.input_format == "BGR": im = im.flip(0) H, W = im.shape[1:] im = self.resize_im(im) output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"] scores = output.get("scores") if len(scores) == 0: return dict( instance_segmentation=torch.empty((0, 0, 112, 112), dtype=torch.bool, device=im.device), instance_embedding=torch.empty((0, 16, 112, 112), dtype=torch.float32, device=im.device), embed_map=self.mesh_vertex_embeddings["smpl_27554"], bbox_XYXY=torch.empty((0, 4), dtype=torch.long, device=im.device), im_segmentation=torch.empty((0, H, W), dtype=torch.bool, device=im.device), scores=torch.empty((0), dtype=torch.float, device=im.device) ) pred_densepose, boxes_xywh, classes = self.densepose_extractor(output) assert isinstance(pred_densepose, DensePoseEmbeddingPredictorOutput), pred_densepose S = pred_densepose.coarse_segm.argmax(dim=1) # Segmentation channel Nx2xHxW (2 because only 2 classes) E = pred_densepose.embedding mesh_name = self.class_to_mesh_name[classes[0]] assert mesh_name == "smpl_27554" x0, y0, w, h = [boxes_xywh[:, i] for i in range(4)] boxes_XYXY = torch.stack((x0, y0, x0+w, y0+h), dim=-1) boxes_XYXY = boxes_XYXY.round_().long() non_empty_boxes = (boxes_XYXY[:, :2] == boxes_XYXY[:, 2:]).any(dim=1).logical_not() S = S[non_empty_boxes] E = E[non_empty_boxes] boxes_XYXY = boxes_XYXY[non_empty_boxes] scores = scores[non_empty_boxes] im_segmentation = cse_det_to_global(boxes_XYXY, S, [H, W]) return dict( instance_segmentation=S, instance_embedding=E, bbox_XYXY=boxes_XYXY, im_segmentation=im_segmentation, scores=scores.view(-1))