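"""Evaluate generated point clouds against a reference set.

Loads .ply point clouds from two folders ("fake" and "real"), downsamples,
normalizes and aligns them, then reports MMD-CD, COV-CD, and JSD averaged
over several random subsets.
"""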
import argparse
import multiprocessing
import os
import random
import warnings
from pathlib import Path

import numpy as np
import torch
from chamfer_distance import ChamferDistance
from lightning_fabric import seed_everything
from plyfile import PlyData
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

from eval.eval_pc_set import *

N_POINTS = 2000
def find_files(folder, extension):
    """Return sorted paths of files in `folder` whose names end with `extension`."""
    return sorted([Path(os.path.join(folder, f)) for f in os.listdir(folder) if f.endswith(extension)])


def read_ply(path):
    """Read the vertex positions of a .ply file as an (N, 3) array."""
    with open(path, 'rb') as f:
        plydata = PlyData.read(f)
        x = np.array(plydata['vertex']['x'])
        y = np.array(plydata['vertex']['y'])
        z = np.array(plydata['vertex']['z'])
        vertex = np.stack([x, y, z], axis=1)
    return vertex
def distChamfer(a, b):
    """Batched Chamfer distances from the pairwise squared-distance matrix.

    a, b: (B, N, 3) tensors. Returns, for every point, the squared distance to its
    nearest neighbour in the other cloud (two (B, N) tensors).
    """
    x, y = a, b
    bs, num_points, points_dim = x.size()
    xx = torch.bmm(x, x.transpose(2, 1))
    yy = torch.bmm(y, y.transpose(2, 1))
    zz = torch.bmm(x, y.transpose(2, 1))
    diag_ind = torch.arange(0, num_points).to(a).long()
    rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
    ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
    # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>
    P = (rx.transpose(2, 1) + ry - 2 * zz)
    return P.min(1)[0], P.min(2)[0]
def _pairwise_CD(sample_pcs, ref_pcs, batch_size):
    """Compute the (N_sample, N_ref) matrix of Chamfer distances between every
    generated cloud and every reference cloud, batching over the reference set."""
    N_sample = sample_pcs.shape[0]
    N_ref = ref_pcs.shape[0]
    all_cd = []
    matched_gt = []
    pbar = tqdm(range(N_sample))
    chamfer_dist = ChamferDistance()
    for sample_b_start in pbar:
        sample_batch = sample_pcs[sample_b_start]
        cd_lst = []
        for ref_b_start in range(0, N_ref, batch_size):
            ref_b_end = min(N_ref, ref_b_start + batch_size)
            ref_batch = ref_pcs[ref_b_start:ref_b_end]
            batch_size_ref = ref_batch.size(0)
            # Broadcast the single sample cloud against the reference batch.
            sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
            sample_batch_exp = sample_batch_exp.contiguous()
            dl, dr, idx1, idx2 = chamfer_dist(sample_batch_exp, ref_batch)
            cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
        cd_lst = torch.cat(cd_lst, dim=1)
        all_cd.append(cd_lst)
        hit = np.argmin(cd_lst.detach().cpu().numpy()[0])
        matched_gt.append(hit)
        pbar.set_postfix({"cov": len(np.unique(matched_gt)) * 1.0 / N_ref})
    all_cd = torch.cat(all_cd, dim=0)  # N_sample, N_ref
    return all_cd
def compute_cov_mmd(sample_pcs, ref_pcs, batch_size):
    """Coverage (COV-CD) and Minimum Matching Distance (MMD-CD).

    COV-CD: fraction of reference clouds that are the nearest neighbour (in CD)
    of at least one sample. MMD-CD: for each reference cloud, the CD to its
    closest sample, averaged over the reference set.
    Returns the metric dict and, per sample, the index of its matched reference.
    """
    all_dist = _pairwise_CD(sample_pcs, ref_pcs, batch_size)
    N_sample, N_ref = all_dist.size(0), all_dist.size(1)
    _, min_idx = torch.min(all_dist, dim=1)  # nearest reference for every sample
    min_val, _ = torch.min(all_dist, dim=0)  # nearest sample for every reference
    mmd = min_val.mean()
    cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
    cov = torch.tensor(cov).to(all_dist)
    return {
        'MMD-CD': mmd.item(),
        'COV-CD': cov.item(),
    }, min_idx.cpu().numpy()
def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, in_unit_sphere, resolution=28):
    """Compute the JSD between two sets of point clouds, as introduced in the paper
    "Learning Representations and Generative Models for 3D Point Clouds".

    Args:
        sample_pcs: (np.ndarray S1 x R1 x 3) S1 point clouds, each of R1 points.
        ref_pcs: (np.ndarray S2 x R2 x 3) S2 point clouds, each of R2 points.
        in_unit_sphere: (bool) whether to warn if the clouds leave the unit sphere.
        resolution: (int) grid resolution. Affects granularity of measurements.
    """
    sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
    ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
    return jensen_shannon_divergence(sample_grid_var, ref_grid_var)
def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
    """Given a collection of point clouds, estimate the entropy of the random variables
    corresponding to occupancy-grid activation patterns.

    Inputs:
        pclouds: (numpy array) #point-clouds x points-per-point-cloud x 3
        grid_resolution: (int) size of the occupancy grid that will be used.
    """
    epsilon = 10e-4
    bound = 1 + epsilon
    if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
        print(abs(np.max(pclouds)), abs(np.min(pclouds)))
        warnings.warn('Point-clouds are not in unit cube.')
    if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
        warnings.warn('Point-clouds are not in unit sphere.')
    grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
    grid_coordinates = grid_coordinates.reshape(-1, 3)
    grid_counters = np.zeros(len(grid_coordinates))
    grid_bernoulli_rvars = np.zeros(len(grid_coordinates))
    nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)
    for pc in pclouds:
        # Snap every point to its nearest grid cell.
        _, indices = nn.kneighbors(pc)
        indices = np.squeeze(indices)
        for i in indices:
            grid_counters[i] += 1
        indices = np.unique(indices)
        for i in indices:
            grid_bernoulli_rvars[i] += 1
    acc_entropy = 0.0
    n = float(len(pclouds))
    for g in grid_bernoulli_rvars:
        p = 0.0
        if g > 0:
            p = float(g) / n
        acc_entropy += entropy([p, 1.0 - p])
    return acc_entropy / len(grid_counters), grid_counters
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
    """Return the center coordinates of each cell of a 3D grid with resolution^3 cells.

    As written, the grid spans [-1, 1] along each axis (a cube of side 2 centred at
    the origin). If clip_sphere is True, the "corner" cells with norm > 0.5 are dropped.
    """
    grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
    spacing = 1.0 / float(resolution - 1) * 2
    for i in range(resolution):
        for j in range(resolution):
            for k in range(resolution):
                grid[i, j, k, 0] = i * spacing - 0.5 * 2
                grid[i, j, k, 1] = j * spacing - 0.5 * 2
                grid[i, j, k, 2] = k * spacing - 0.5 * 2
    if clip_sphere:
        grid = grid.reshape(-1, 3)
        grid = grid[np.linalg.norm(grid, axis=1) <= 0.5]
    return grid, spacing
def jensen_shannon_divergence(P, Q):
    if np.any(P < 0) or np.any(Q < 0):
        raise ValueError('Negative values.')
    if len(P) != len(Q):
        raise ValueError('Non equal size.')
    P_ = P / np.sum(P)  # Ensure probabilities.
    Q_ = Q / np.sum(Q)
    e1 = entropy(P_, base=2)
    e2 = entropy(Q_, base=2)
    e_sum = entropy((P_ + Q_) / 2.0, base=2)
    res = e_sum - ((e1 + e2) / 2.0)
    res2 = _jsdiv(P_, Q_)
    if not np.allclose(res, res2, atol=10e-5, rtol=0):
        warnings.warn('Numerical values of two JSD methods don\'t agree.')
    return res
def _jsdiv(P, Q):
    """Another way of computing JSD, used as a numerical cross-check."""
    def _kldiv(A, B):
        a = A.copy()
        b = B.copy()
        idx = np.logical_and(a > 0, b > 0)
        a = a[idx]
        b = b[idx]
        return np.sum(a * np.log2(a / b))

    P_ = P / np.sum(P)
    Q_ = Q / np.sum(Q)
    M = 0.5 * (P_ + Q_)
    return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
def downsample_pc(points, n):
    """Randomly sample n points without replacement."""
    sample_idx = random.sample(list(range(points.shape[0])), n)
    return points[sample_idx]


def normalize_pc(points):
    # Center on the mean, then scale to fit the unit cube.
    mean = np.mean(points, axis=0)
    points = points - mean
    scale = np.max(np.abs(points))
    points = points / scale
    return points
def align_pc(points):
    # 1. Center the point cloud
    centroid = np.mean(points, axis=0)
    centered_points = points - centroid
    # 2. Compute the three edge lengths of the axis-aligned bounding box
    min_coords = np.min(centered_points, axis=0)
    max_coords = np.max(centered_points, axis=0)
    dimensions = max_coords - min_coords
    # 3. Sort axes by edge length to get the axis order
    axis_order = np.argsort(dimensions)[::-1]  # from longest to shortest
    # 4. Build the permutation matrix (longest edge -> x, shortest -> y, medium -> z)
    perm_matrix = np.zeros((3, 3))
    perm_matrix[0, axis_order[0]] = 1  # longest edge -> x
    perm_matrix[1, axis_order[2]] = 1  # shortest edge -> y
    perm_matrix[2, axis_order[1]] = 1  # medium edge -> z
    # 5. Apply the permutation
    aligned_points = np.dot(centered_points, perm_matrix.T)
    # 6. Flip z if needed so the cloud faces a consistent direction
    if np.mean(aligned_points[:, 2]) < 0:
        aligned_points[:, 2] *= -1
    return aligned_points
def collect_pc(cad_folder):
    """Load the final point cloud of a CAD folder, then downsample and normalize it."""
    pc_path = find_files(os.path.join(cad_folder, 'pcd'), 'final_pcd.ply')
    if len(pc_path) == 0:
        return []
    pc_path = pc_path[-1]  # final pcd
    pc = read_ply(pc_path)
    if pc.shape[0] > N_POINTS:
        pc = downsample_pc(pc, N_POINTS)
    pc = normalize_pc(pc)
    return pc


def collect_pc2(cad_folder):
    """Load a .ply file, then downsample, normalize and axis-align it."""
    pc = read_ply(cad_folder)
    if pc.shape[0] > N_POINTS:
        pc = downsample_pc(pc, N_POINTS)
    pc = normalize_pc(pc)
    pc = align_pc(pc)
    return pc
theta_x = np.radians(90)   # Rotation angle around the X axis
theta_y = np.radians(90)   # Rotation angle around the Y axis
theta_z = np.radians(180)  # Rotation angle around the Z axis

# Create the individual rotation matrices
Rx = np.array([[1, 0, 0],
               [0, np.cos(theta_x), -np.sin(theta_x)],
               [0, np.sin(theta_x), np.cos(theta_x)]])
Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)],
               [0, 1, 0],
               [-np.sin(theta_y), 0, np.cos(theta_y)]])
Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
               [np.sin(theta_z), np.cos(theta_z), 0],
               [0, 0, 1]])
rotation_matrix = np.dot(np.dot(Rz, Ry), Rx)
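# With these angles the composition Rz @ Ry @ Rx is, up to floating-point error,
# the fixed axis permutation/flip (x, y, z) -> (-y, z, -x).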
def collect_pc3(cad_folder):
    """Load a .ply file, downsample, normalize, and apply the fixed rotation above."""
    pc = read_ply(cad_folder)
    if pc.shape[0] > N_POINTS:
        pc = downsample_pc(pc, N_POINTS)
    pc = normalize_pc(pc)
    # Transpose the rotation matrix so the points (rows) are rotated correctly.
    rotated_point_cloud = np.dot(pc, rotation_matrix.T).astype(np.float32)
    return rotated_point_cloud
def load_data_with_prefix(root_folder, prefix):
    data_files = []
    # Walk through the directory tree starting from the root folder
    for root, dirs, files in os.walk(root_folder):
        for filename in files:
            # Keep files whose names end with the given suffix (e.g. '.ply')
            if filename.endswith(prefix):
                file_path = os.path.join(root, filename)
                data_files.append(file_path)
    data_files.sort()
    return data_files
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--fake", type=str, help="folder with generated .ply point clouds")
    parser.add_argument("--real", type=str, help="folder with reference .ply point clouds")
    parser.add_argument("--n_test", type=int, default=1000)
    parser.add_argument("--multi", type=float, default=3)
    parser.add_argument("--times", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    args = parser.parse_args()
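    # Example invocation (script name and paths are illustrative placeholders):
    #   python eval_pc.py --fake ./generated_plys --real ./reference_plys --n_test 1000 --times 10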
    seed_everything(0)
    print("n_test: {}, multiplier: {}, repeat times: {}".format(args.n_test, args.multi, args.times))
    args.output = args.fake + '_results.txt'
    # Load reference pcd
    num_cpus = multiprocessing.cpu_count()
    ref_pcs = []
    gt_shape_paths = load_data_with_prefix(args.real, '.ply')
    load_iter = multiprocessing.Pool(num_cpus).imap(collect_pc2, gt_shape_paths)
    for pc in tqdm(load_iter, total=len(gt_shape_paths)):
        if len(pc) > 0:
            ref_pcs.append(pc)
    ref_pcs = np.stack(ref_pcs, axis=0)
    print("real point clouds: {}".format(ref_pcs.shape))
    # Load fake pcd
    sample_pcs = []
    shape_paths = load_data_with_prefix(args.fake, '.ply')
    load_iter = multiprocessing.Pool(num_cpus).imap(collect_pc2, shape_paths)
    for pc in tqdm(load_iter, total=len(shape_paths)):
        if len(pc) > 0:
            sample_pcs.append(pc)
    sample_pcs = np.stack(sample_pcs, axis=0)
    print("fake point clouds: {}".format(sample_pcs.shape))
    # Testing
    cov_on_gt = []
    fp = open(args.output, "w")
    result_list = []
    for i in range(args.times):
        print("iteration {}...".format(i))
        # Draw multi * n_test generated clouds and n_test reference clouds.
        select_idx1 = random.sample(list(range(len(sample_pcs))), int(args.multi * args.n_test))
        rand_sample_pcs = sample_pcs[select_idx1]
        select_idx2 = random.sample(list(range(len(ref_pcs))), args.n_test)
        rand_ref_pcs = ref_pcs[select_idx2]
        jsd = jsd_between_point_cloud_sets(rand_sample_pcs, rand_ref_pcs, in_unit_sphere=False)
        with torch.no_grad():
            rand_sample_pcs = torch.tensor(rand_sample_pcs).cuda().float()
            rand_ref_pcs = torch.tensor(rand_ref_pcs).cuda().float()
            result, idx = compute_cov_mmd(rand_sample_pcs, rand_ref_pcs, batch_size=args.batch_size)
        result.update({"JSD": jsd})
        cov_on_gt.extend(list(np.array(select_idx2)[np.unique(idx)]))
        print(result)
        print(result, file=fp)
        result_list.append(result)
    # Average the metrics over all repetitions.
    avg_result = {}
    for k in result_list[0].keys():
        avg_result.update({"avg-" + k: np.mean([x[k] for x in result_list])})
    print("average result:")
    print(avg_result)
    print(avg_result, file=fp)
    fp.close()
    # Record which ground-truth shapes were covered at least once.
    cov_on_gt = list(set(cov_on_gt))
    cov_on_gt = [gt_shape_paths[i] for i in cov_on_gt]
    np.save(args.fake + '_cov_on_gt.npy', cov_on_gt)
if __name__ == '__main__':
    main()