# Copyright (c) Facebook, Inc. and its affiliates.
import torch

from densepose.structures.data_relative import DensePoseDataRelative


class DensePoseList:
    """
    Container pairing per-instance DensePose annotations with their boxes.

    Holds one entry per instance: a ``DensePoseDataRelative`` (or ``None`` for an
    instance without DensePose data) and the corresponding row of
    ``boxes_xyxy_abs``. On construction every non-``None`` entry and the boxes
    tensor are moved to ``device``.
    """

    _TORCH_DEVICE_CPU = torch.device("cpu")

    def __init__(self, densepose_datas, boxes_xyxy_abs, image_size_hw, device=_TORCH_DEVICE_CPU):
        """
        Args:
            densepose_datas: sequence of ``DensePoseDataRelative`` or ``None``,
                one per instance.
            boxes_xyxy_abs: tensor with one box per instance; its length must
                equal ``len(densepose_datas)``. Presumably (N, 4) absolute XYXY
                coordinates as the name suggests — not checked here.
            image_size_hw: image size as ``(height, width)``.
            device: device that data entries and boxes are moved to.

        Raises:
            AssertionError: on a length mismatch, or when an entry is neither a
                ``DensePoseDataRelative`` nor ``None``.
        """
        assert len(densepose_datas) == len(
            boxes_xyxy_abs
        ), "Attempt to initialize DensePoseList with {} DensePose datas " "and {} boxes".format(
            len(densepose_datas), len(boxes_xyxy_abs)
        )
        self.densepose_datas = []
        for densepose_data in densepose_datas:
            assert isinstance(densepose_data, DensePoseDataRelative) or densepose_data is None, (
                "Attempt to initialize DensePoseList with DensePose datas "
                "of type {}, expected DensePoseDataRelative".format(type(densepose_data))
            )
            # Missing annotations stay None; present ones are moved to the target device.
            densepose_data_ondevice = (
                densepose_data.to(device) if densepose_data is not None else None
            )
            self.densepose_datas.append(densepose_data_ondevice)
        self.boxes_xyxy_abs = boxes_xyxy_abs.to(device)
        self.image_size_hw = image_size_hw
        self.device = device

    def to(self, device):
        """
        Return this list on ``device``.

        Returns ``self`` unchanged when ``self.device == device``; otherwise builds
        a new ``DensePoseList`` whose entries and boxes are moved to ``device``.
        """
        if self.device == device:
            return self
        return DensePoseList(self.densepose_datas, self.boxes_xyxy_abs, self.image_size_hw, device)

    def __iter__(self):
        """Iterate over the per-instance entries (``DensePoseDataRelative`` or ``None``)."""
        return iter(self.densepose_datas)

    def __len__(self):
        """Return the number of instances."""
        return len(self.densepose_datas)

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_instances={}, ".format(len(self.densepose_datas))
        s += "image_width={}, ".format(self.image_size_hw[1])
        s += "image_height={})".format(self.image_size_hw[0])
        return s

    def __getitem__(self, item):
        """
        Select instances.

        Args:
            item: one of
                - ``int``: returns the single entry (``DensePoseDataRelative`` or
                  ``None``);
                - ``slice``;
                - boolean mask tensor (one flag per instance);
                - any other iterable of integer indices (including an integer
                  tensor).
                Every form except ``int`` returns a new ``DensePoseList`` with
                the matching boxes, the same image size and the same device.
        """
        if isinstance(item, int):
            return self.densepose_datas[item]

        if isinstance(item, slice):
            densepose_datas_rel = self.densepose_datas[item]
        elif isinstance(item, torch.Tensor) and (item.dtype == torch.bool):
            # Convert the mask to Python bools in one call instead of iterating
            # the tensor: per-element iteration creates a 0-d tensor per entry
            # and forces a host sync per entry when the mask lives on a GPU.
            densepose_datas_rel = [
                self.densepose_datas[i] for i, keep in enumerate(item.tolist()) if keep
            ]
        else:
            # Same reasoning for integer index tensors: fetch the indices once.
            indices = item.tolist() if isinstance(item, torch.Tensor) else item
            densepose_datas_rel = [self.densepose_datas[i] for i in indices]

        # Boxes are indexed with the original `item` so tensor indexing rules apply.
        boxes_xyxy_abs = self.boxes_xyxy_abs[item]
        return DensePoseList(densepose_datas_rel, boxes_xyxy_abs, self.image_size_hw, self.device)