import logging import os import pickle import random import shutil import subprocess import SharedArray import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp def check_numpy_to_torch(x): if isinstance(x, np.ndarray): return torch.from_numpy(x).float(), True return x, False def limit_period(val, offset=0.5, period=np.pi): val, is_numpy = check_numpy_to_torch(val) ans = val - torch.floor(val / period + offset) * period return ans.numpy() if is_numpy else ans def drop_info_with_name(info, name): ret_info = {} keep_indices = [i for i, x in enumerate(info['name']) if x != name] for key in info.keys(): ret_info[key] = info[key][keep_indices] return ret_info def rotate_points_along_z(points, angle): """ Args: points: (B, N, 3 + C) angle: (B), angle along z-axis, angle increases x ==> y Returns: """ points, is_numpy = check_numpy_to_torch(points) angle, _ = check_numpy_to_torch(angle) cosa = torch.cos(angle) sina = torch.sin(angle) zeros = angle.new_zeros(points.shape[0]) ones = angle.new_ones(points.shape[0]) rot_matrix = torch.stack(( cosa, sina, zeros, -sina, cosa, zeros, zeros, zeros, ones ), dim=1).view(-1, 3, 3).float() points_rot = torch.matmul(points[:, :, 0:3], rot_matrix) points_rot = torch.cat((points_rot, points[:, :, 3:]), dim=-1) return points_rot.numpy() if is_numpy else points_rot def angle2matrix(angle): """ Args: angle: angle along z-axis, angle increases x ==> y Returns: rot_matrix: (3x3 Tensor) rotation matrix """ cosa = torch.cos(angle) sina = torch.sin(angle) rot_matrix = torch.tensor([ [cosa, -sina, 0], [sina, cosa, 0], [ 0, 0, 1] ]) return rot_matrix def mask_points_by_range(points, limit_range): mask = (points[:, 0] >= limit_range[0]) & (points[:, 0] <= limit_range[3]) \ & (points[:, 1] >= limit_range[1]) & (points[:, 1] <= limit_range[4]) return mask def get_voxel_centers(voxel_coords, downsample_times, voxel_size, point_cloud_range): """ Args: voxel_coords: (N, 3) downsample_times: voxel_size: point_cloud_range: Returns: """ assert voxel_coords.shape[1] == 3 voxel_centers = voxel_coords[:, [2, 1, 0]].float() # (xyz) voxel_size = torch.tensor(voxel_size, device=voxel_centers.device).float() * downsample_times pc_range = torch.tensor(point_cloud_range[0:3], device=voxel_centers.device).float() voxel_centers = (voxel_centers + 0.5) * voxel_size + pc_range return voxel_centers def create_logger(log_file=None, rank=0, log_level=logging.INFO): logger = logging.getLogger(__name__) logger.setLevel(log_level if rank == 0 else 'ERROR') formatter = logging.Formatter('%(asctime)s %(levelname)5s %(message)s') console = logging.StreamHandler() console.setLevel(log_level if rank == 0 else 'ERROR') console.setFormatter(formatter) logger.addHandler(console) if log_file is not None: file_handler = logging.FileHandler(filename=log_file) file_handler.setLevel(log_level if rank == 0 else 'ERROR') file_handler.setFormatter(formatter) logger.addHandler(file_handler) logger.propagate = False return logger def set_random_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def worker_init_fn(worker_id, seed=666): if seed is not None: random.seed(seed + worker_id) np.random.seed(seed + worker_id) torch.manual_seed(seed + worker_id) torch.cuda.manual_seed(seed + worker_id) torch.cuda.manual_seed_all(seed + worker_id) def get_pad_params(desired_size, cur_size): """ Get padding parameters for np.pad function Args: desired_size: int, Desired padded output size cur_size: int, Current size. Should always be less than or equal to cur_size Returns: pad_params: tuple(int), Number of values padded to the edges (before, after) """ assert desired_size >= cur_size # Calculate amount to pad diff = desired_size - cur_size pad_params = (0, diff) return pad_params def keep_arrays_by_name(gt_names, used_classes): inds = [i for i, x in enumerate(gt_names) if x in used_classes] inds = np.array(inds, dtype=np.int64) return inds def init_dist_slurm(tcp_port, local_rank, backend='nccl'): """ modified from https://github.com/open-mmlab/mmdetection Args: tcp_port: backend: Returns: """ proc_id = int(os.environ['SLURM_PROCID']) ntasks = int(os.environ['SLURM_NTASKS']) node_list = os.environ['SLURM_NODELIST'] num_gpus = torch.cuda.device_count() torch.cuda.set_device(proc_id % num_gpus) addr = subprocess.getoutput('scontrol show hostname {} | head -n1'.format(node_list)) os.environ['MASTER_PORT'] = str(tcp_port) os.environ['MASTER_ADDR'] = addr os.environ['WORLD_SIZE'] = str(ntasks) os.environ['RANK'] = str(proc_id) dist.init_process_group(backend=backend) total_gpus = dist.get_world_size() rank = dist.get_rank() return total_gpus, rank def init_dist_pytorch(tcp_port, local_rank, backend='nccl'): if mp.get_start_method(allow_none=True) is None: mp.set_start_method('spawn') # os.environ['MASTER_PORT'] = str(tcp_port) # os.environ['MASTER_ADDR'] = 'localhost' num_gpus = torch.cuda.device_count() torch.cuda.set_device(local_rank % num_gpus) dist.init_process_group( backend=backend, # init_method='tcp://127.0.0.1:%d' % tcp_port, # rank=local_rank, # world_size=num_gpus ) rank = dist.get_rank() return num_gpus, rank def get_dist_info(return_gpu_per_machine=False): if torch.__version__ < '1.0': initialized = dist._initialized else: if dist.is_available(): initialized = dist.is_initialized() else: initialized = False if initialized: rank = dist.get_rank() world_size = dist.get_world_size() else: rank = 0 world_size = 1 if return_gpu_per_machine: gpu_per_machine = torch.cuda.device_count() return rank, world_size, gpu_per_machine return rank, world_size def merge_results_dist(result_part, size, tmpdir): rank, world_size = get_dist_info() os.makedirs(tmpdir, exist_ok=True) dist.barrier() pickle.dump(result_part, open(os.path.join(tmpdir, 'result_part_{}.pkl'.format(rank)), 'wb')) dist.barrier() if rank != 0: return None part_list = [] for i in range(world_size): part_file = os.path.join(tmpdir, 'result_part_{}.pkl'.format(i)) part_list.append(pickle.load(open(part_file, 'rb'))) ordered_results = [] for res in zip(*part_list): ordered_results.extend(list(res)) ordered_results = ordered_results[:size] shutil.rmtree(tmpdir) return ordered_results def scatter_point_inds(indices, point_inds, shape): ret = -1 * torch.ones(*shape, dtype=point_inds.dtype, device=point_inds.device) ndim = indices.shape[-1] flattened_indices = indices.view(-1, ndim) slices = [flattened_indices[:, i] for i in range(ndim)] ret[slices] = point_inds return ret def generate_voxel2pinds(sparse_tensor): device = sparse_tensor.indices.device batch_size = sparse_tensor.batch_size spatial_shape = sparse_tensor.spatial_shape indices = sparse_tensor.indices.long() point_indices = torch.arange(indices.shape[0], device=device, dtype=torch.int32) output_shape = [batch_size] + list(spatial_shape) v2pinds_tensor = scatter_point_inds(indices, point_indices, output_shape) return v2pinds_tensor def sa_create(name, var): x = SharedArray.create(name, var.shape, dtype=var.dtype) x[...] = var[...] x.flags.writeable = False return x class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count