Spaces:
Runtime error
Runtime error
import multiprocessing | |
import networkx as nx | |
import numpy as np | |
import argparse | |
import os | |
import trimesh | |
from tqdm import tqdm | |
import ray | |
from check_valid import check_step_valid_soild, load_data_with_prefix | |
from eval_brepgen import normalize_pc | |
from eval_unique_novel import * | |
def find_connected_components(matrix):
    """Split the nodes of an adjacency matrix into connected components.

    ``matrix[a][b]`` is treated as "there is an edge from a to b" whenever it
    is truthy. Components are discovered in increasing order of their
    smallest node, and nodes inside a component are listed in the order an
    explicit-stack depth-first traversal reaches them.

    Returns:
        A list of components, each a list of node indices.
    """
    size = len(matrix)
    seen = [False] * size
    groups = []
    for start in range(size):
        if seen[start]:
            continue
        group = []
        pending = [start]
        while pending:
            current = pending.pop()
            # A node may be pushed more than once before it is first popped.
            if seen[current]:
                continue
            seen[current] = True
            group.append(current)
            pending.extend(
                other for other in range(size)
                if matrix[current][other] and not seen[other]
            )
        groups.append(group)
    return groups
def compute_unique(graph_list, atol=None, is_use_ray=False, batch_size=100000, num_max_split_batch=128):
    """Find the graphs in ``graph_list`` that do not duplicate an earlier graph.

    Every unordered pair ``(i, j)`` with ``i < j`` is compared. Whenever a
    pair is judged identical, the later index ``j`` is marked as a duplicate
    and the pair is recorded.

    Args:
        graph_list: Graphs to compare. They are passed to
            ``is_graph_identical`` (presumably from ``eval_unique_novel``).
        atol: Tolerance forwarded to the comparison function.
        is_use_ray: If True, distribute the comparisons over Ray tasks via
            ``is_graph_identical_remote``.
        batch_size: Target number of pairs per Ray task.
        num_max_split_batch: Upper bound on the number of Ray tasks. When the
            target ``batch_size`` would need more tasks than this, batches are
            enlarged instead.

    Returns:
        ``(unique_graph_idx, identical_pairs)``:
        - ``unique_graph_idx`` is the ascending list of indices that never
          appeared as the second element of an identical pair.
        - ``identical_pairs`` is the list of identical ``(i, j)`` pairs, in
          comparison order. It is filled on both the serial and Ray paths.
    """
    n_graphs = len(graph_list)
    pair_0, pair_1 = np.triu_indices(n_graphs, k=1)
    check_pairs = np.column_stack((pair_0, pair_1))
    num_pairs = len(check_pairs)

    # Boolean mask instead of list.remove(): O(1) per hit instead of O(N).
    is_duplicate = np.zeros(n_graphs, dtype=bool)
    identical_pairs = []

    if not is_use_ray:
        for idx1, idx2 in tqdm(check_pairs):
            if is_graph_identical(graph_list[idx1], graph_list[idx2], atol=atol):
                is_duplicate[idx2] = True
                identical_pairs.append((idx1, idx2))
    elif num_pairs > 0:
        batch_size = max(1, batch_size)
        num_max_split_batch = max(1, num_max_split_batch)
        # Ceiling division. Floor division silently dropped the trailing pairs,
        # and dropped ALL pairs when num_pairs < batch_size.
        num_split_batch = -(-num_pairs // batch_size)
        if num_split_batch > num_max_split_batch:
            batch_size = -(-num_pairs // num_max_split_batch)
            num_split_batch = -(-num_pairs // batch_size)

        futures = []
        for i in tqdm(range(num_split_batch)):
            batch_pairs = check_pairs[i * batch_size: (i + 1) * batch_size]
            batch_graph_pair = [(graph_list[idx1], graph_list[idx2]) for idx1, idx2 in batch_pairs]
            futures.append(is_graph_identical_remote.remote(batch_graph_pair, atol))
        # Assumes each remote result is a per-pair sequence of booleans, in batch order.
        results = ray.get(futures)
        for batch_idx, batch_result in enumerate(tqdm(results)):
            offset = batch_idx * batch_size
            for idx, is_identical in enumerate(batch_result):
                if not is_identical:
                    continue
                idx1, idx2 = check_pairs[offset + idx]
                is_duplicate[idx2] = True
                identical_pairs.append((idx1, idx2))

    unique_graph_idx = [idx for idx in range(n_graphs) if not is_duplicate[idx]]
    return unique_graph_idx, identical_pairs
def test_check():
    """Smoke-test ``is_graph_identical`` on two relabelled copies of one graph.

    Both graphs use the same three random faces. The second graph swaps the
    order of faces 1 and 2 and relabels its single adjacency edge to match,
    so both describe the same face/adjacency structure and are presumably
    expected to be reported identical. The verdict is printed, not asserted.
    """
    random_faces = np.random.rand(3, 32, 32, 3)
    original_order = [0, 1, 2]
    swapped_order = [0, 2, 1]  # faces 1 and 2 exchanged
    graph_a = build_graph(random_faces[original_order], [[0, 1]])
    graph_b = build_graph(random_faces[swapped_order], [[0, 2]])
    same = is_graph_identical(graph_a, graph_b)
    # Check whether the two graphs are equal.
    if same:
        message = "Graphs are equal"
    else:
        message = "Graphs are not equal"
    print(message)
def load_data_from_npz(data_npz_file):
    """Load face geometry and face-adjacency pairs from a ``.npz`` file.

    Two layouts are recognised:

    * BrepGen output (contains ``face_edge_adj``). Faces come from
      ``pred_face`` and are returned as stored. Two faces are adjacent when
      their edge-id lists share at least one id.
    * Our output: either ``sample_points_faces`` + ``edge_face_connectivity``,
      or ``pred_face`` + ``pred_edge_face_connectivity``. Adjacency comes from
      ``(edge_id, face_id1, face_id2)`` rows. Face points are truncated to
      their first 3 channels and normalized with ``normalize_pc``.

    Returns:
        ``(face_points, faces_adj_pair)``, where ``faces_adj_pair`` is a list
        of two-element ``[face_idx1, face_idx2]`` lists.

    Raises:
        ValueError: if the file matches neither layout.
    """
    # allow_pickle=True is needed for object arrays (e.g. ragged adjacency
    # lists). Pickled arrays can execute code on load, so only use trusted files.
    # The context manager closes the underlying file. Indexing an NpzFile
    # materializes a full ndarray, so the arrays stay valid after closing.
    with np.load(data_npz_file, allow_pickle=True) as data_npz:
        # Brepgen
        if 'face_edge_adj' in data_npz:
            faces = data_npz['pred_face']
            face_edge_adj = data_npz['face_edge_adj']
            # Build each face's edge-id set once, not once per pair.
            face_edge_sets = [set(face_edges) for face_edges in face_edge_adj]
            num_faces = len(face_edge_sets)
            # face_idx1 < face_idx2 by construction, so each pair is visited once
            # and is already sorted. No duplicate check is needed.
            faces_adj_pair = [
                [face_idx1, face_idx2]
                for face_idx1 in range(num_faces)
                for face_idx2 in range(face_idx1 + 1, num_faces)
                if face_edge_sets[face_idx1] & face_edge_sets[face_idx2]
            ]
            return faces, faces_adj_pair

        # Ours
        if 'sample_points_faces' in data_npz:
            face_points = data_npz['sample_points_faces']  # Face sample points (num_faces*20*20*3)
            edge_face_connectivity = data_npz['edge_face_connectivity']  # (num_intersection, (id_edge, id_face1, id_face2))
        elif 'pred_face' in data_npz and 'pred_edge_face_connectivity' in data_npz:
            face_points = data_npz['pred_face']
            edge_face_connectivity = data_npz['pred_edge_face_connectivity']
        else:
            raise ValueError("Invalid data format")

    faces_adj_pair = [[face_idx1, face_idx2] for _edge_idx, face_idx1, face_idx2 in edge_face_connectivity]
    if face_points.shape[-1] != 3:
        # Keep only the xyz channels (any extra channels are dropped).
        face_points = face_points[..., :3]
    src_shape = face_points.shape
    face_points = normalize_pc(face_points.reshape(-1, 3)).reshape(src_shape)
    return face_points, faces_adj_pair
def _parse_args():
    """Define and parse this script's command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_root", type=str, required=True)
    parser.add_argument("--n_bit", type=int)
    parser.add_argument("--atol", type=float)
    parser.add_argument("--use_ray", action='store_true')
    parser.add_argument("--load_batch_size", type=int, default=100)
    parser.add_argument("--compute_batch_size", type=int, default=10000)
    parser.add_argument("--txt", type=str, default=None)
    # NOTE(review): parsed but not used anywhere in this script.
    parser.add_argument("--num_cpus", type=int, default=32)
    # Formerly hard-coded values; the defaults reproduce the previous behavior.
    parser.add_argument("--ray_temp_dir", type=str, default=r"/mnt/d/img2brep/ray_temp")
    parser.add_argument("--min_face", type=int, default=3)
    parser.add_argument("--max_face", type=int, default=30)
    return parser.parse_args()


def _load_graphs(data_npz_file_list, check_folders, n_bit, is_use_ray, load_batch_size):
    """Build graphs for the data files; return ``(graph_list, prefix_list)``."""
    if not is_use_ray:
        # Use the same positional arguments as the Ray path below. The previous
        # call passed n_bit in the check_folders slot and omitted n_bit.
        # NOTE(review): assumes load_and_build_graph has the same signature as
        # load_and_build_graph_remote — confirm in eval_unique_novel.
        graph_list, prefix_list = load_and_build_graph(data_npz_file_list, check_folders, n_bit)
        return graph_list, prefix_list

    ray.init()
    try:
        futures = [
            load_and_build_graph_remote.remote(
                data_npz_file_list[start: start + load_batch_size], check_folders, n_bit)
            for start in range(0, len(data_npz_file_list), load_batch_size)
        ]
        graph_list, prefix_list = [], []
        for future in tqdm(futures):
            graph_list_batch, prefix_list_batch = ray.get(future)
            graph_list.extend(graph_list_batch)
            prefix_list.extend(prefix_list_batch)
    finally:
        ray.shutdown()
    return graph_list, prefix_list


def _write_unique_report(graph_list, prefix_list, identical_pairs_txt, novel_txt,
                         atol, is_use_ray, compute_batch_size, min_face, max_face):
    """Run ``compute_unique`` per face count and write both report files.

    Graphs are only compared with graphs that have the same node (face) count,
    for counts in ``[min_face, max_face]``. Returns the indices into
    ``graph_list`` that were judged unique.
    """
    # Group (not sort) the graphs by their number of faces.
    graph_node_num = np.array([graph.number_of_nodes() for graph in graph_list], dtype=int)
    num_total = len(graph_list)
    num_skipped = int(np.count_nonzero((graph_node_num < min_face) | (graph_node_num > max_face)))
    if num_skipped:
        # These graphs are never evaluated but still count in the total ratio.
        print(f"Warning: {num_skipped} graphs have a face count outside "
              f"[{min_face}, {max_face}] and are not evaluated")

    unique_graph_idx_list = []
    with open(identical_pairs_txt, "w") as fp_identical_pairs, open(novel_txt, "w") as fp_novel:
        pbar = tqdm(range(min_face, max_face + 1))
        for num_face in pbar:
            print(f"Processing {num_face}")
            pbar.set_description(f"Processing {num_face}")
            print(f"face_num = {num_face}", file=fp_identical_pairs)
            hits_graph_idx = np.where(graph_node_num == num_face)[0]
            if len(hits_graph_idx) == 0:
                print(f"face_num = {num_face} has no data", file=fp_identical_pairs)
            else:
                hits_graph = [graph_list[idx] for idx in hits_graph_idx]
                hits_graph_prefix = [prefix_list[idx] for idx in hits_graph_idx]
                local_unique, identical_pairs = compute_unique(hits_graph, atol, is_use_ray, compute_batch_size)
                for unique_graph_idx in local_unique:
                    print(f"{hits_graph_prefix[unique_graph_idx]}", file=fp_novel)
                # Map local indices (into hits_graph) back to graph_list indices.
                unique_graph_idx_list.extend(hits_graph_idx[idx] for idx in local_unique)
                for idx1, idx2 in identical_pairs:
                    print(f"{hits_graph_prefix[idx1]} {hits_graph_prefix[idx2]}", file=fp_identical_pairs)
                local_ratio = len(local_unique) / len(hits_graph_idx)
                pbar.set_postfix({"Local Unique": local_ratio,
                                  "Total Unique": len(unique_graph_idx_list) / num_total, })
                print(f"Unique: {len(local_unique)}/{len(hits_graph_idx)}"
                      f"={local_ratio}", file=fp_identical_pairs)
            # Persist progress after every face count, as the old per-iteration
            # open/close did, so partial results survive a crash.
            fp_identical_pairs.flush()
            fp_novel.flush()
    return unique_graph_idx_list


def main():
    """Measure how many training shapes are unique among shapes with the same face count.

    Loads every ``data.npz`` under ``--train_root`` into graphs and compares
    graphs that share a face count, using either ``--n_bit`` quantization or
    an ``--atol`` tolerance (exactly one must be given). Writes
    ``<train_root>_identical_pairs_<n_bit>bit.txt`` (duplicate pairs plus a
    per-face-count summary) and ``<train_root>_novel_<n_bit>bit.txt``
    (prefixes of unique shapes). In ``--atol`` mode ``n_bit`` is ``-1``.

    Raises:
        ValueError: if neither or both of ``--n_bit`` / ``--atol`` are set.
    """
    args = _parse_args()
    train_data_root = args.train_root
    is_use_ray = args.use_ray

    # Compare with None so that explicit zeros (e.g. --atol 0) count as set.
    if args.n_bit is None and args.atol is None:
        raise ValueError("Must set either n_bit or atol")
    if args.n_bit is not None and args.atol is not None:
        raise ValueError("Cannot set both n_bit and atol")
    n_bit = args.n_bit if args.n_bit is not None else -1
    atol = args.atol  # None in n_bit mode

    check_folders = None
    if args.txt:
        with open(args.txt, "r") as f:
            check_folders = [line.strip() for line in f]

    ################################################## Unique #######################################################
    # Load all the data files
    print("Loading data files...")
    data_npz_file_list = sorted(load_data_with_prefix(train_data_root, 'data.npz'))
    graph_list, prefix_list = _load_graphs(
        data_npz_file_list, check_folders, n_bit, is_use_ray, args.load_batch_size)
    print(f"Loaded {len(graph_list)} data files")

    identical_pairs_txt = train_data_root + f"_identical_pairs_{n_bit}bit.txt"
    novel_txt = train_data_root + f"_novel_{n_bit}bit.txt"

    if is_use_ray:
        ray.init(_temp_dir=args.ray_temp_dir)
    try:
        unique_graph_idx_list = _write_unique_report(
            graph_list, prefix_list, identical_pairs_txt, novel_txt,
            atol, is_use_ray, args.compute_batch_size, args.min_face, args.max_face)
    finally:
        if is_use_ray:
            ray.shutdown()

    num_total = len(graph_list)
    num_unique = len(unique_graph_idx_list)
    unique_ratio = num_unique / num_total if num_total else 0.0
    print(f"Unique num: {num_unique}/{num_total}={unique_ratio}")
    print(f"Identical pairs are saved to {identical_pairs_txt}")
    print(f"Novel txt are saved to {novel_txt}")
    print("Done")
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()