yfdeng's picture
init
744eb4e
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from collections import abc
# from pointnet2_ops import pointnet2_utils
# def fps(data, number):
# '''
# data B N 3
# number int
# '''
# fps_idx = pointnet2_utils.furthest_point_sample(data, number)
# fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
# return fps_data
def index_points(points, idx):
"""
Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S]
Return:
new_points:, indexed points data, [B, S, C]
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
new_points = points[batch_indices, idx, :]
return new_points
def fps(xyz, npoint):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
distance = torch.ones(B, N).to(device) * 1e10
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
batch_indices = torch.arange(B, dtype=torch.long).to(device)
for i in range(npoint):
centroids[:, i] = farthest
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
dist = torch.sum((xyz - centroid) ** 2, -1)
distance = torch.min(distance, dist)
farthest = torch.max(distance, -1)[1]
return index_points(xyz, centroids)
def worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)
def build_lambda_sche(opti, config):
if config.get('decay_step') is not None:
lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay)
scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd)
else:
raise NotImplementedError()
return scheduler
def build_lambda_bnsche(model, config):
if config.get('decay_step') is not None:
bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay)
bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd)
else:
raise NotImplementedError()
return bnm_scheduler
def set_random_seed(seed, deterministic=False):
"""Set random seed.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Default: False.
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
if cuda_deterministic: # slower, more reproducible
cudnn.deterministic = True
cudnn.benchmark = False
else: # faster, less reproducible
cudnn.deterministic = False
cudnn.benchmark = True
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def is_seq_of(seq, expected_type, seq_type=None):
"""Check whether it is a sequence of some type.
Args:
seq (Sequence): The sequence to be checked.
expected_type (type): Expected type of sequence items.
seq_type (type, optional): Expected sequence type.
Returns:
bool: Whether the sequence is valid.
"""
if seq_type is None:
exp_seq_type = abc.Sequence
else:
assert isinstance(seq_type, type)
exp_seq_type = seq_type
if not isinstance(seq, exp_seq_type):
return False
for item in seq:
if not isinstance(item, expected_type):
return False
return True
def set_bn_momentum_default(bn_momentum):
def fn(m):
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
m.momentum = bn_momentum
return fn
class BNMomentumScheduler(object):
def __init__(
self, model, bn_lambda, last_epoch=-1,
setter=set_bn_momentum_default
):
if not isinstance(model, nn.Module):
raise RuntimeError(
"Class '{}' is not a PyTorch nn Module".format(
type(model).__name__
)
)
self.model = model
self.setter = setter
self.lmbd = bn_lambda
self.step(last_epoch + 1)
self.last_epoch = last_epoch
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.last_epoch = epoch
self.model.apply(self.setter(self.lmbd(epoch)))
def get_momentum(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
return self.lmbd(epoch)
def seprate_point_cloud(xyz, num_points, crop, fixed_points = None, padding_zeros = False):
'''
seprate point cloud: usage : using to generate the incomplete point cloud with a setted number.
'''
_,n,c = xyz.shape
assert n == num_points
assert c == 3
if crop == num_points:
return xyz, None
INPUT = []
CROP = []
for points in xyz:
if isinstance(crop,list):
num_crop = random.randint(crop[0],crop[1])
else:
num_crop = crop
points = points.unsqueeze(0)
if fixed_points is None:
center = F.normalize(torch.randn(1,1,3),p=2,dim=-1).cuda()
else:
if isinstance(fixed_points,list):
fixed_point = random.sample(fixed_points,1)[0]
else:
fixed_point = fixed_points
center = fixed_point.reshape(1,1,3).cuda()
distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1) # 1 1 2048
idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048
if padding_zeros:
input_data = points.clone()
input_data[0, idx[:num_crop]] = input_data[0,idx[:num_crop]] * 0
else:
input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3
crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0)
if isinstance(crop,list):
INPUT.append(fps(input_data,2048))
CROP.append(fps(crop_data,2048))
else:
INPUT.append(input_data)
CROP.append(crop_data)
input_data = torch.cat(INPUT,dim=0)# B N 3
crop_data = torch.cat(CROP,dim=0)# B M 3
return input_data.contiguous(), crop_data.contiguous()
def get_ptcloud_img(ptcloud):
fig = plt.figure(figsize=(8, 8))
x, z, y = ptcloud.transpose(1, 0)
ax = fig.gca(projection=Axes3D.name, adjustable='box')
ax.axis('off')
# ax.axis('scaled')
ax.view_init(30, 45)
max, min = np.max(ptcloud), np.min(ptcloud)
ax.set_xbound(min, max)
ax.set_ybound(min, max)
ax.set_zbound(min, max)
ax.scatter(x, y, z, zdir='z', c=x, cmap='jet')
fig.canvas.draw()
img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
return img
def visualize_KITTI(path, data_list, titles = ['input','pred'], cmap=['bwr','autumn'], zdir='y',
xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1) ):
fig = plt.figure(figsize=(6*len(data_list),6))
cmax = data_list[-1][:,0].max()
for i in range(len(data_list)):
data = data_list[i][:-2048] if i == 1 else data_list[i]
color = data[:,0] /cmax
ax = fig.add_subplot(1, len(data_list) , i + 1, projection='3d')
ax.view_init(30, -120)
b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], zdir=zdir, c=color,vmin=-1,vmax=1 ,cmap = cmap[0],s=4,linewidth=0.05, edgecolors = 'black')
ax.set_title(titles[i])
ax.set_axis_off()
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_zlim(zlim)
plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0)
if not os.path.exists(path):
os.makedirs(path)
pic_path = path + '.png'
fig.savefig(pic_path)
np.save(os.path.join(path, 'input.npy'), data_list[0].numpy())
np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy())
plt.close(fig)
def random_dropping(pc, e):
up_num = max(64, 768 // (e//50 + 1))
pc = pc
random_num = torch.randint(1, up_num, (1,1))[0,0]
pc = fps(pc, random_num)
padding = torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device)
pc = torch.cat([pc, padding], dim = 1)
return pc
def random_scale(partial, scale_range=[0.8, 1.2]):
scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0]
return partial * scale