quasi-physical-sims / models /data_utils_torch.py
meow
a
710e818
raw
history blame
62.1 kB
"""Mesh data utilities."""
from re import I
# import matplotlib.pyplot as plt
# from mpl_toolkits import mplot3d # pylint: disable=unused-import
# from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import networkx as nx
import numpy as np
import six
import os
import math
import torch
import models.fields as fields
# from options.options import opt
# from polygen_torch.
# from utils.constants import MASK_GRID_VALIE
try:
from torch_cluster import fps
except:
pass
MAX_RANGE = 0.1
MIN_RANGE = -0.1
# import open3d as o3d
def calculate_correspondences(last_mesh, bending_network, tot_timesteps, delta_bending):
# iterate all the timesteps and get the bended
timestep_to_pts = {
tot_timesteps - 1: last_mesh.detach().cpu().numpy()
}
if delta_bending:
for i_ts in range(tot_timesteps - 1, -1, -1):
if isinstance(bending_network, list):
tot_offsets = []
for i_bending, cur_bending_net in enumerate(bending_network):
cur_def_pts = cur_bending_net(last_mesh, i_ts)
tot_offsets.append(cur_def_pts - last_mesh)
tot_offsets = torch.stack(tot_offsets, dim=0)
tot_offsets = torch.sum(tot_offsets, dim=0)
last_mesh = last_mesh + tot_offsets
else:
last_mesh = bending_network(last_mesh, i_ts)
timestep_to_pts[i_ts - 1] = last_mesh.detach().cpu().numpy()
elif isinstance(bending_network, fields.BendingNetworkRigidTrans):
for i_ts in range(tot_timesteps - 1, -1, -1):
last_mesh = bending_network.forward_delta(last_mesh, i_ts)
timestep_to_pts[i_ts - 1] = last_mesh.detach().cpu().numpy()
else:
raise ValueError(f"the function is designed for delta bending")
return timestep_to_pts
# pass
def joint_infos_to_numpy(joint_infos):
joint_infos_np = []
for part_joint_info in joint_infos:
for zz in ["dir", "center"]:
# if isinstance(part_joint_info["axis"][zz], np.array):
part_joint_info["axis"][zz] = part_joint_info["axis"][zz].detach().numpy()
joint_infos_np.append(part_joint_info)
return joint_infos_np
def normalie_pc_bbox_batched(pc, rt_stats=False):
pc_min = torch.min(pc, dim=1, keepdim=True)[0]
pc_max = torch.max(pc, dim=1, keepdim=True)[0]
pc_center = 0.5 * (pc_min + pc_max)
pc = pc - pc_center
extents = pc_max - pc_min
scale = torch.sqrt(torch.sum(extents ** 2, dim=-1, keepdim=True))
pc = pc / torch.clamp(scale, min=1e-6)
if rt_stats:
return pc, pc_center, scale
else:
return pc
def scale_vertices_to_target_scale(vertices, target_scale):
# vertices: bsz x N x 3;
# target_scale: bsz x 1
# normalized_vertices = normalize_vertices_scale_torch(vertices)
normalized_vertices = normalie_pc_bbox_batched(vertices)
normalized_vertices = normalized_vertices * target_scale.unsqueeze(1)
return normalized_vertices
def compute_normals_o3d(verts, faces): #### assume no batching... ####
mesh = o3d.geometry.TriangleMesh()
# o3d_mesh_b.vertices = verts_b
# o3d_mesh_b.triangles = np.array(faces_b, dtype=np.long)
mesh.vertices = verts.detach().cpu().numpy()
mesh.triangles = faces.detach().cpu().numpy().astype(np.long)
verts_normals = mesh.compute_vertex_normals(normalized=True)
verts_normals = torch.from_numpy(verts_normals, dtype=torch.float32).cuda()
return verts_normals
def get_vals_via_nearest_neighbours(pts_src, pts_tar, val_tar):
### n_src x 3 ---> n_src x n_tar
dist_src_tar = torch.sum((pts_src.unsqueeze(-2) - pts_tar.unsqueeze(0)) ** 2, dim=-1)
minn_dists_src_tar, minn_dists_tar_idxes = torch.min(dist_src_tar, dim=-1) ### n_src
selected_src_val = batched_index_select(values=val_tar, indices=minn_dists_tar_idxes, dim=0) ### n_src x dim
return selected_src_val
## sample conected componetn start from selected_verts
def sample_bfs_component(selected_vert, faces, max_num_grids):
vert_idx_to_adj_verts = {}
for i_f, cur_f in enumerate(faces):
# for i0, v0 in enumerate(cur_f):
for i0 in range(len(cur_f)):
v0 = cur_f[i0] - 1
i1 = (i0 + 1) % len(cur_f)
v1 = cur_f[i1] - 1
if v0 not in vert_idx_to_adj_verts:
vert_idx_to_adj_verts[v0] = {v1: 1}
else:
vert_idx_to_adj_verts[v0][v1] = 1
if v1 not in vert_idx_to_adj_verts:
vert_idx_to_adj_verts[v1] = {v0: 1}
else:
vert_idx_to_adj_verts[v1][v0] = 1
vert_idx_to_visited = {} # whether visisted here #
vis_que = [selected_vert]
vert_idx_to_visited[selected_vert] = 1
visited = [selected_vert]
while len(vis_que) > 0 and len(visited) < max_num_grids:
cur_frnt_vert = vis_que[0]
vis_que.pop(0)
if cur_frnt_vert in vert_idx_to_adj_verts:
cur_frnt_vert_adjs = vert_idx_to_adj_verts[cur_frnt_vert]
for adj_vert in cur_frnt_vert_adjs:
if adj_vert not in vert_idx_to_visited:
vert_idx_to_visited[adj_vert] = 1
vis_que.append(adj_vert)
visited.append(adj_vert)
if len(visited) >= max_num_grids:
visited = visited[: max_num_grids - 1]
return visited
def select_faces_via_verts(selected_verts, faces):
if not isinstance(selected_verts, list):
selected_verts = selected_verts.tolist()
# selected_verts_dict = {ii: 1 for ii in selected_verts}
old_idx_to_new_idx = {v + 1: ii + 1 for ii, v in enumerate(selected_verts)} ####### v + 1: ii + 1 --> for the selected_verts
new_faces = []
for i_f, cur_f in enumerate(faces):
cur_new_f = []
valid = True
for cur_v in cur_f:
if cur_v in old_idx_to_new_idx:
cur_new_f.append(old_idx_to_new_idx[cur_v])
else:
valid = False
break
if valid:
new_faces.append(cur_new_f)
return new_faces
def convert_grid_content_to_grid_pts(content_value, grid_size):
flat_grid = torch.zeros([grid_size ** 3], dtype=torch.long)
cur_idx = flat_grid.size(0) - 1
while content_value > 0:
flat_grid[cur_idx] = content_value % grid_size
content_value = content_value // grid_size
cur_idx -= 1
grid_pts = flat_grid.contiguous().view(grid_size, grid_size, grid_size).contiguous()
return grid_pts
# 0.2
def warp_coord(sampled_gradients, val, reflect=False): # val from [0.0, 1.0] # from the 0.0
# assume single value as inputs
grad_values = sampled_gradients.tolist()
# mid_val
mid_val = grad_values[0] * 0.2 + grad_values[1] * 0.2 + grad_values[2] * 0.1
if reflect:
grad_values[3] = grad_values[1]
grad_values[4] = grad_values[0]
# if not reflect:
accum_val = 0.0
for i_val in range(len(grad_values)):
if val > 0.2 * (i_val + 1) and i_val < 4: # if i_val == 4, directly use the reamining length * corresponding gradient value
accum_val += grad_values[i_val] * 0.2
else:
accum_val += grad_values[i_val] * (val - 0.2 * i_val)
break
return accum_val # modified value
def random_shift(vertices, shift_factor=0.25):
"""Apply random shift to vertices."""
# max_shift_pos = tf.cast(255 - tf.reduce_max(vertices, axis=0), tf.float32)
# max_shift_pos = tf.maximum(max_shift_pos, 1e-9)
# max_shift_neg = tf.cast(tf.reduce_min(vertices, axis=0), tf.float32)
# max_shift_neg = tf.maximum(max_shift_neg, 1e-9)
# shift = tfd.TruncatedNormal(
# tf.zeros([1, 3]), shift_factor*255, -max_shift_neg,
# max_shift_pos).sample()
# shift = tf.cast(shift, tf.int32)
# vertices += shift
minn_tensor = torch.tensor([1e-9], dtype=torch.float32)
max_shift_pos = (255 - torch.max(vertices, dim=0)[0]).float()
max_shift_pos = torch.maximum(max_shift_pos, minn_tensor)
max_shift_neg = (torch.min(vertices, dim=0)[0]).float()
max_shift_neg = torch.maximum(max_shift_neg, minn_tensor)
shift = torch.zeros((1, 3), dtype=torch.float32)
# torch.nn.init.trunc_normal_(shift, 0., shift_factor * 255., -max_shift_neg, max_shift_pos)
for i_s in range(shift.size(-1)):
cur_axis_max_shift_neg = max_shift_neg[i_s].item()
cur_axis_max_shift_pos = max_shift_pos[i_s].item()
cur_axis_shift = torch.zeros((1,), dtype=torch.float32)
torch.nn.init.trunc_normal_(cur_axis_shift, 0., shift_factor * 255., -cur_axis_max_shift_neg, cur_axis_max_shift_pos)
shift[:, i_s] = cur_axis_shift.item()
shift = shift.long()
vertices += shift
return vertices
def safe_transpose(tsr, dima, dimb):
tsr = tsr.contiguous().transpose(dima, dimb).contiguous()
return tsr
def batched_index_select(values, indices, dim = 1):
value_dims = values.shape[(dim + 1):]
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
indices = indices[(..., *((None,) * len(value_dims)))]
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
value_expand_len = len(indices_shape) - (dim + 1)
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
value_expand_shape = [-1] * len(values.shape)
expand_slice = slice(dim, (dim + value_expand_len))
value_expand_shape[expand_slice] = indices.shape[expand_slice]
values = values.expand(*value_expand_shape)
dim += value_expand_len
return values.gather(dim, indices)
def read_obj_file_ours(obj_fn, sub_one=False):
vertices = []
faces = []
with open(obj_fn, "r") as rf:
for line in rf:
items = line.strip().split(" ")
if items[0] == 'v':
cur_verts = items[1:]
cur_verts = [float(vv) for vv in cur_verts]
vertices.append(cur_verts)
elif items[0] == 'f':
cur_faces = items[1:] # faces
cur_face_idxes = []
for cur_f in cur_faces:
try:
cur_f_idx = int(cur_f.split("/")[0])
except:
cur_f_idx = int(cur_f.split("//")[0])
cur_face_idxes.append(cur_f_idx if not sub_one else cur_f_idx - 1)
faces.append(cur_face_idxes)
rf.close()
vertices = np.array(vertices, dtype=np.float)
return vertices, faces
def read_obj_file(obj_file):
"""Read vertices and faces from already opened file."""
vertex_list = []
flat_vertices_list = []
flat_vertices_indices = {}
flat_triangles = []
for line in obj_file:
tokens = line.split()
if not tokens:
continue
line_type = tokens[0]
# We skip lines not starting with v or f.
if line_type == 'v': #
vertex_list.append([float(x) for x in tokens[1:]])
elif line_type == 'f':
triangle = []
for i in range(len(tokens) - 1):
vertex_name = tokens[i + 1]
if vertex_name in flat_vertices_indices: # triangles
triangle.append(flat_vertices_indices[vertex_name])
continue
flat_vertex = []
for index in six.ensure_str(vertex_name).split('/'):
if not index:
continue
# obj triangle indices are 1 indexed, so subtract 1 here.
flat_vertex += vertex_list[int(index) - 1]
flat_vertex_index = len(flat_vertices_list)
flat_vertices_list.append(flat_vertex)
# flat_vertex_index
flat_vertices_indices[vertex_name] = flat_vertex_index
triangle.append(flat_vertex_index)
flat_triangles.append(triangle)
return np.array(flat_vertices_list, dtype=np.float32), flat_triangles
def batched_index_select(values, indices, dim = 1):
value_dims = values.shape[(dim + 1):]
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
indices = indices[(..., *((None,) * len(value_dims)))]
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
value_expand_len = len(indices_shape) - (dim + 1)
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
value_expand_shape = [-1] * len(values.shape)
expand_slice = slice(dim, (dim + value_expand_len))
value_expand_shape[expand_slice] = indices.shape[expand_slice]
values = values.expand(*value_expand_shape)
dim += value_expand_len
return values.gather(dim, indices)
def safe_transpose(x, dim1, dim2):
x = x.contiguous().transpose(dim1, dim2).contiguous()
return x
def merge_meshes(vertices_list, faces_list):
tot_vertices = []
tot_faces = []
nn_verts = 0
for cur_vertices, cur_faces in zip(vertices_list, faces_list):
tot_vertices.append(cur_vertices)
new_cur_faces = []
for cur_face_idx in cur_faces:
new_cur_face_idx = [vert_idx + nn_verts for vert_idx in cur_face_idx]
new_cur_faces.append(new_cur_face_idx)
nn_verts += cur_vertices.shape[0]
tot_faces += new_cur_faces # get total-faces
tot_vertices = np.concatenate(tot_vertices, axis=0)
return tot_vertices, tot_faces
def read_obj(obj_path):
"""Open .obj file from the path provided and read vertices and faces."""
with open(obj_path) as obj_file:
return read_obj_file_ours(obj_path, sub_one=True)
# return read_obj_file(obj_file)
def center_vertices(vertices):
"""Translate the vertices so that bounding box is centered at zero."""
vert_min = vertices.min(axis=0)
vert_max = vertices.max(axis=0)
vert_center = 0.5 * (vert_min + vert_max)
return vertices - vert_center
def normalize_vertices_scale(vertices):
"""Scale the vertices so that the long diagonal of the bounding box is one."""
vert_min = vertices.min(axis=0)
vert_max = vertices.max(axis=0)
extents = vert_max - vert_min
scale = np.sqrt(np.sum(extents**2)) # normalize the diagonal line to 1.
return vertices / scale
def get_vertices_center(vertices):
vert_min = vertices.min(axis=0)
vert_max = vertices.max(axis=0)
vert_center = 0.5 * (vert_min + vert_max)
return vert_center
def get_batched_vertices_center(vertices):
vert_min = vertices.min(axis=1)
vert_max = vertices.max(axis=1)
vert_center = 0.5 * (vert_min + vert_max)
return vert_center
def get_vertices_scale(vertices):
vert_min = vertices.min(axis=0)
vert_max = vertices.max(axis=0)
extents = vert_max - vert_min
scale = np.sqrt(np.sum(extents**2))
return scale
def quantize_verts(verts, n_bits=8, min_range=None, max_range=None):
"""Convert vertices in [-1., 1.] to discrete values in [0, n_bits**2 - 1]."""
min_range = -0.5 if min_range is None else min_range
max_range = 0.5 if max_range is None else max_range
range_quantize = 2**n_bits - 1
verts_quantize = (verts - min_range) * range_quantize / (
max_range - min_range)
return verts_quantize.astype('int32')
def quantize_verts_torch(verts, n_bits=8, min_range=None, max_range=None):
min_range = -0.5 if min_range is None else min_range
max_range = 0.5 if max_range is None else max_range
range_quantize = 2**n_bits - 1
verts_quantize = (verts - min_range) * range_quantize / (
max_range - min_range)
return verts_quantize.long()
def dequantize_verts(verts, n_bits=8, add_noise=False, min_range=None, max_range=None):
"""Convert quantized vertices to floats."""
min_range = -0.5 if min_range is None else min_range
max_range = 0.5 if max_range is None else max_range
range_quantize = 2**n_bits - 1
verts = verts.astype('float32')
verts = verts * (max_range - min_range) / range_quantize + min_range
if add_noise:
verts += np.random.uniform(size=verts.shape) * (1 / range_quantize)
return verts
def dequantize_verts_torch(verts, n_bits=8, add_noise=False, min_range=None, max_range=None):
min_range = -0.5 if min_range is None else min_range
max_range = 0.5 if max_range is None else max_range
range_quantize = 2**n_bits - 1
verts = verts.float()
verts = verts * (max_range - min_range) / range_quantize + min_range
# if add_noise:
# verts += np.random.uniform(size=verts.shape) * (1 / range_quantize)
return verts
### dump vertices and faces to the obj file
def write_obj(vertices, faces, file_path, transpose=True, scale=1.):
"""Write vertices and faces to obj."""
if transpose:
vertices = vertices[:, [1, 2, 0]]
vertices *= scale
if faces is not None:
if min(min(faces)) == 0:
f_add = 1
else:
f_add = 0
else:
faces = []
with open(file_path, 'w') as f:
for v in vertices:
f.write('v {} {} {}\n'.format(v[0], v[1], v[2]))
for face in faces:
line = 'f'
for i in face:
line += ' {}'.format(i + f_add)
line += '\n'
f.write(line)
def face_to_cycles(face):
"""Find cycles in face."""
g = nx.Graph()
for v in range(len(face) - 1):
g.add_edge(face[v], face[v + 1])
g.add_edge(face[-1], face[0])
return list(nx.cycle_basis(g))
def flatten_faces(faces):
"""Converts from list of faces to flat face array with stopping indices."""
if not faces:
return np.array([0])
else:
l = [f + [-1] for f in faces[:-1]]
l += [faces[-1] + [-2]]
return np.array([item for sublist in l for item in sublist]) + 2 # pylint: disable=g-complex-comprehension
def unflatten_faces(flat_faces):
"""Converts from flat face sequence to a list of separate faces."""
def group(seq):
g = []
for el in seq:
if el == 0 or el == -1:
yield g
g = []
else:
g.append(el - 1)
yield g
outputs = list(group(flat_faces - 1))[:-1]
# Remove empty faces
return [o for o in outputs if len(o) > 2]
def quantize_process_mesh(vertices, faces, tris=None, quantization_bits=8, remove_du=True):
"""Quantize vertices, remove resulting duplicates and reindex faces."""
vertices = quantize_verts(vertices, quantization_bits)
vertices, inv = np.unique(vertices, axis=0, return_inverse=True) # return inverse and unique the vertices
#
if opt.dataset.sort_dist:
if opt.model.debug:
print("sorting via dist...")
vertices_max = np.max(vertices, axis=0)
vertices_min = np.min(vertices, axis=0)
dist_vertices = np.minimum(np.abs(vertices - np.array([[vertices_min[0], vertices_min[1], 0]])), np.abs(vertices - np.array([[vertices_max[0], vertices_max[1], 0]])))
dist_vertices = np.concatenate([dist_vertices[:, 0:1] + dist_vertices[:, 1:2], dist_vertices[:, 2:]], axis=-1)
sort_inds = np.lexsort(dist_vertices.T)
else:
# Sort vertices by z then y then x.
sort_inds = np.lexsort(vertices.T) # sorted indices...
vertices = vertices[sort_inds]
# Re-index faces and tris to re-ordered vertices.
faces = [np.argsort(sort_inds)[inv[f]] for f in faces]
if tris is not None:
tris = np.array([np.argsort(sort_inds)[inv[t]] for t in tris])
# Merging duplicate vertices and re-indexing the faces causes some faces to
# contain loops (e.g [2, 3, 5, 2, 4]). Split these faces into distinct
# sub-faces.
sub_faces = []
for f in faces:
cliques = face_to_cycles(f)
for c in cliques:
c_length = len(c)
# Only append faces with more than two verts.
if c_length > 2:
d = np.argmin(c)
# Cyclically permute faces just that first index is the smallest.
sub_faces.append([c[(d + i) % c_length] for i in range(c_length)])
faces = sub_faces
if tris is not None:
tris = np.array([v for v in tris if len(set(v)) == len(v)])
# Sort faces by lowest vertex indices. If two faces have the same lowest
# index then sort by next lowest and so on.
faces.sort(key=lambda f: tuple(sorted(f)))
if tris is not None:
tris = tris.tolist()
tris.sort(key=lambda f: tuple(sorted(f)))
tris = np.array(tris)
# After removing degenerate faces some vertices are now unreferenced. # Vertices
# Remove these. # Vertices
num_verts = vertices.shape[0]
# print(f"remove_du: {remove_du}")
if remove_du: ##### num_verts
print("Removing du..")
try:
vert_connected = np.equal(
np.arange(num_verts)[:, None], np.hstack(faces)[None]).any(axis=-1)
vertices = vertices[vert_connected]
# Re-index faces and tris to re-ordered vertices.
vert_indices = (
np.arange(num_verts) - np.cumsum(1 - vert_connected.astype('int')))
faces = [vert_indices[f].tolist() for f in faces]
except:
pass
if tris is not None:
tris = np.array([vert_indices[t].tolist() for t in tris])
return vertices, faces, tris
def process_mesh(vertices, faces, quantization_bits=8, recenter_mesh=True, remove_du=True):
"""Process mesh vertices and faces."""
# Transpose so that z-axis is vertical.
vertices = vertices[:, [2, 0, 1]]
if recenter_mesh:
# Translate the vertices so that bounding box is centered at zero.
vertices = center_vertices(vertices)
# Scale the vertices so that the long diagonal of the bounding box is equal
# to one.
vertices = normalize_vertices_scale(vertices)
# Quantize and sort vertices, remove resulting duplicates, sort and reindex
# faces.
vertices, faces, _ = quantize_process_mesh(
vertices, faces, quantization_bits=quantization_bits, remove_du=remove_du) ##### quantize_process_mesh
# unflatten_faces = np.array(faces, dtype=np.long) ### start from zero
# Flatten faces and add 'new face' = 1 and 'stop' = 0 tokens.
faces = flatten_faces(faces)
# Discard degenerate meshes without faces.
return {
'vertices': vertices,
'faces': faces,
}
def process_mesh_ours(vertices, faces, quantization_bits=8, recenter_mesh=True, remove_du=True):
"""Process mesh vertices and faces."""
# Transpose so that z-axis is vertical.
vertices = vertices[:, [2, 0, 1]]
if recenter_mesh:
# Translate the vertices so that bounding box is centered at zero.
vertices = center_vertices(vertices)
# Scale the vertices so that the long diagonal of the bounding box is equal
# to one.
vertices = normalize_vertices_scale(vertices)
# Quantize and sort vertices, remove resulting duplicates, sort and reindex
# faces.
quant_vertices, faces, _ = quantize_process_mesh(
vertices, faces, quantization_bits=quantization_bits, remove_du=remove_du) ##### quantize_process_mesh
vertices = dequantize_verts(quant_vertices) #### dequantize vertices ####
### vertices: nn_verts x 3
# try:
# # print("faces", faces)
# unflatten_faces = np.array(faces, dtype=np.long)
# except:
# print("faces", faces)
# raise ValueError("Something bad happened when processing meshes...")
# Flatten faces and add 'new face' = 1 and 'stop' = 0 tokens.
faces = flatten_faces(faces)
# Discard degenerate meshes without faces.
return {
'vertices': quant_vertices,
'faces': faces,
# 'unflatten_faces': unflatten_faces,
'dequant_vertices': vertices,
'class_label': 0
}
def read_mesh_from_obj_file(fn, quantization_bits=8, recenter_mesh=True, remove_du=True):
vertices, faces = read_obj(fn)
# print(vertices.shape)
mesh_dict = process_mesh_ours(vertices, faces, quantization_bits=quantization_bits, recenter_mesh=recenter_mesh, remove_du=remove_du)
return mesh_dict
def process_mesh_list(vertices, faces, quantization_bits=8, recenter_mesh=True):
"""Process mesh vertices and faces."""
vertices = [cur_vert[:, [2, 0, 1]] for cur_vert in vertices]
tot_vertices = np.concatenate(vertices, axis=0) # center and scale of those vertices
vert_center = get_vertices_center(tot_vertices)
vert_scale = get_vertices_scale(tot_vertices)
processed_vertices, processed_faces = [], []
for cur_verts, cur_faces in zip(vertices, faces):
# print(f"current vertices: {cur_verts.shape}, faces: {len(cur_faces)}")
normalized_verts = (cur_verts - vert_center) / vert_scale
cur_processed_verts, cur_processed_faces, _ = quantize_process_mesh(
normalized_verts, cur_faces, quantization_bits=quantization_bits
)
processed_vertices.append(cur_processed_verts)
processed_faces.append(cur_processed_faces)
vertices, faces = merge_meshes(processed_vertices, processed_faces)
faces = flatten_faces(faces=faces)
# Discard degenerate meshes without faces.
return {
'vertices': vertices,
'faces': faces,
}
def plot_sampled_meshes(v_sample, f_sample, sv_mesh_folder, cur_step=0, predict_joint=True,):
if not os.path.exists(sv_mesh_folder):
os.mkdir(sv_mesh_folder)
part_vertex_samples = [v_sample['left'], v_sample['rgt']]
part_face_samples = [f_sample['left'], f_sample['rgt']]
tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
tot_n_part = 2
if predict_joint:
pred_dir = v_sample['joint_dir']
pred_pvp = v_sample['joint_pvp']
print("pred_dir", pred_dir.shape, pred_dir)
print("pred_pvp", pred_pvp.shape, pred_pvp)
else:
pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)
tot_mesh_list = []
for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
mesh_list = []
for i_n in range(tot_n_samples):
mesh_list.append(
{
'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
'faces': unflatten_faces(
cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
}
)
tot_mesh_list.append(mesh_list)
# and write this obj file?
# write_obj(vertices, faces, file_path, transpose=True, scale=1.):
# write mesh objs
for i_n in range(tot_n_samples):
cur_mesh = mesh_list[i_n]
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=False, scale=1.)
###### plot mesh (predicted) ######
tot_samples_mesh_dict = []
for i_s in range(tot_n_samples):
cur_s_tot_vertices = []
cur_s_tot_faces = []
cur_s_n_vertices = 0
for i_p in range(tot_n_part):
cur_s_cur_part_mesh_dict = tot_mesh_list[i_p][i_s]
cur_s_cur_part_vertices, cur_s_cur_part_faces = cur_s_cur_part_mesh_dict['vertices'], \
cur_s_cur_part_mesh_dict['faces']
cur_s_cur_part_new_faces = []
for cur_s_cur_part_cur_face in cur_s_cur_part_faces:
cur_s_cur_part_cur_new_face = [fid + cur_s_n_vertices for fid in cur_s_cur_part_cur_face]
cur_s_cur_part_new_faces.append(cur_s_cur_part_cur_new_face)
cur_s_n_vertices += cur_s_cur_part_vertices.shape[0]
cur_s_tot_vertices.append(cur_s_cur_part_vertices)
cur_s_tot_faces += cur_s_cur_part_new_faces
cur_s_tot_vertices = np.concatenate(cur_s_tot_vertices, axis=0)
cur_s_mesh_dict = {
'vertices': cur_s_tot_vertices, 'faces': cur_s_tot_faces
}
tot_samples_mesh_dict.append(cur_s_mesh_dict)
for i_s in range(tot_n_samples):
cur_mesh = tot_samples_mesh_dict[i_s]
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_s}.obj")
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=False, scale=1.)
###### plot mesh (predicted) ######
###### plot mesh (translated) ######
tot_samples_mesh_dict = []
for i_s in range(tot_n_samples):
cur_s_tot_vertices = []
cur_s_tot_faces = []
cur_s_n_vertices = 0
cur_s_pred_pvp = pred_pvp[i_s]
for i_p in range(tot_n_part):
cur_s_cur_part_mesh_dict = tot_mesh_list[i_p][i_s]
cur_s_cur_part_vertices, cur_s_cur_part_faces = cur_s_cur_part_mesh_dict['vertices'], \
cur_s_cur_part_mesh_dict['faces']
cur_s_cur_part_new_faces = []
for cur_s_cur_part_cur_face in cur_s_cur_part_faces:
cur_s_cur_part_cur_new_face = [fid + cur_s_n_vertices for fid in cur_s_cur_part_cur_face]
cur_s_cur_part_new_faces.append(cur_s_cur_part_cur_new_face)
cur_s_n_vertices += cur_s_cur_part_vertices.shape[0]
if i_p == 1:
# min_rngs = cur_s_cur_part_vertices.min(1)
# max_rngs = cur_s_cur_part_vertices.max(1)
min_rngs = cur_s_cur_part_vertices.min(0)
max_rngs = cur_s_cur_part_vertices.max(0)
# shifted; cur_s_pred_pvp
# shifted = np.array([0., cur_s_pred_pvp[1] - max_rngs[1], cur_s_pred_pvp[2] - min_rngs[2]], dtype=np.float)
# shifted = np.reshape(shifted, [1, 3]) #
cur_s_pred_pvp = np.array([0., max_rngs[1], min_rngs[2]], dtype=np.float32)
pvp_sample_pred_err = np.sum((cur_s_pred_pvp - pred_pvp[i_s]) ** 2)
# print prediction err, pred pvp and real pvp
# print("cur_s, sample_pred_pvp_err:", pvp_sample_pred_err.item(), ";real val:", cur_s_pred_pvp, "; pred_val:", pred_pvp[i_s])
pred_pvp[i_s] = cur_s_pred_pvp
shifted = np.zeros((1, 3), dtype=np.float32)
cur_s_cur_part_vertices = cur_s_cur_part_vertices + shifted # shift vertices... # min_rngs
# shifted
cur_s_tot_vertices.append(cur_s_cur_part_vertices)
cur_s_tot_faces += cur_s_cur_part_new_faces
cur_s_tot_vertices = np.concatenate(cur_s_tot_vertices, axis=0)
cur_s_mesh_dict = {
'vertices': cur_s_tot_vertices, 'faces': cur_s_tot_faces
}
tot_samples_mesh_dict.append(cur_s_mesh_dict)
for i_s in range(tot_n_samples):
cur_mesh = tot_samples_mesh_dict[i_s]
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_s}_shifted.obj")
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=False, scale=1.)
###### plot mesh (translated) ######
###### plot mesh (rotated) ######
if predict_joint:
from revolute_transform import revoluteTransform
tot_samples_mesh_dict = []
for i_s in range(tot_n_samples):
cur_s_tot_vertices = []
cur_s_tot_faces = []
cur_s_n_vertices = 0
# cur_s_pred_dir = pred_dir[i_s]
cur_s_pred_pvp = pred_pvp[i_s]
print("current pred dir:", cur_s_pred_dir, "; current pred pvp:", cur_s_pred_pvp)
cur_s_pred_dir = np.array([1.0, 0.0, 0.0], dtype=np.float)
# cur_s_pred_pvp = cur_s_pred_pvp[[1, 2, 0]]
for i_p in range(tot_n_part):
cur_s_cur_part_mesh_dict = tot_mesh_list[i_p][i_s]
cur_s_cur_part_vertices, cur_s_cur_part_faces = cur_s_cur_part_mesh_dict['vertices'], \
cur_s_cur_part_mesh_dict['faces']
if i_p == 1:
cur_s_cur_part_vertices, _ = revoluteTransform(cur_s_cur_part_vertices, cur_s_pred_pvp, cur_s_pred_dir, 0.5 * np.pi) # reverse revolute vertices of the upper piece
cur_s_cur_part_vertices = cur_s_cur_part_vertices[:, :3] #
cur_s_cur_part_new_faces = []
for cur_s_cur_part_cur_face in cur_s_cur_part_faces:
cur_s_cur_part_cur_new_face = [fid + cur_s_n_vertices for fid in cur_s_cur_part_cur_face]
cur_s_cur_part_new_faces.append(cur_s_cur_part_cur_new_face)
cur_s_n_vertices += cur_s_cur_part_vertices.shape[0]
cur_s_tot_vertices.append(cur_s_cur_part_vertices)
# print(f"i_s: {i_s}, i_p: {i_p}, n_vertices: {cur_s_cur_part_vertices.shape[0]}")
cur_s_tot_faces += cur_s_cur_part_new_faces
cur_s_tot_vertices = np.concatenate(cur_s_tot_vertices, axis=0)
# print(f"i_s: {i_s}, n_cur_s_tot_vertices: {cur_s_tot_vertices.shape[0]}")
cur_s_mesh_dict = {
'vertices': cur_s_tot_vertices, 'faces': cur_s_tot_faces
}
tot_samples_mesh_dict.append(cur_s_mesh_dict)
# plot_meshes(tot_samples_mesh_dict, ax_lims=0.5, mesh_sv_fn=f"./figs/training_step_{n}_part_{tot_n_part}_rot.png") # plot the mesh;
for i_s in range(tot_n_samples):
cur_mesh = tot_samples_mesh_dict[i_s]
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
# rotated mesh fn
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_s}_rot.obj")
# write object in the file...
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=False, scale=1.)
def sample_pts_from_mesh(vertices, faces, npoints=512, minus_one=True):
return vertices
sampled_pcts = []
pts_to_seg_idx = []
# triangles and pts
minus_val = 0 if not minus_one else 1
for i in range(len(faces)): #
cur_face = faces[i]
n_tris = len(cur_face) - 2
v_as, v_bs, v_cs = [cur_face[0] for _ in range(n_tris)], cur_face[1: len(cur_face) - 1], cur_face[2: len(cur_face)]
for v_a, v_b, v_c in zip(v_as, v_bs, v_cs):
v_a, v_b, v_c = vertices[v_a - minus_val], vertices[v_b - minus_val], vertices[v_c - minus_val]
ab, ac = v_b - v_a, v_c - v_a
cos_ab_ac = (np.sum(ab * ac) / np.clip(np.sqrt(np.sum(ab ** 2)) * np.sqrt(np.sum(ac ** 2)), a_min=1e-9,
a_max=9999999.)).item()
sin_ab_ac = math.sqrt(min(max(0., 1. - cos_ab_ac ** 2), 1.))
cur_area = 0.5 * sin_ab_ac * np.sqrt(np.sum(ab ** 2)).item() * np.sqrt(np.sum(ac ** 2)).item()
cur_sampled_pts = int(cur_area * npoints)
cur_sampled_pts = 1 if cur_sampled_pts == 0 else cur_sampled_pts
# if cur_sampled_pts == 0:
tmp_x, tmp_y = np.random.uniform(0, 1., (cur_sampled_pts,)).tolist(), np.random.uniform(0., 1., (
cur_sampled_pts,)).tolist()
for xx, yy in zip(tmp_x, tmp_y):
sqrt_xx, sqrt_yy = math.sqrt(xx), math.sqrt(yy)
aa = 1. - sqrt_xx
bb = sqrt_xx * (1. - yy)
cc = yy * sqrt_xx
cur_pos = v_a * aa + v_b * bb + v_c * cc
sampled_pcts.append(cur_pos)
# pts_to_seg_idx.append(cur_tri_seg)
# seg_idx_to_sampled_pts = {}
sampled_pcts = np.array(sampled_pcts, dtype=np.float)
return sampled_pcts
def fps_fr_numpy(np_pts, n_sampling=4096):
pts = torch.from_numpy(np_pts).float().cuda()
pts_fps_idx = farthest_point_sampling(pts.unsqueeze(0), n_sampling=n_sampling) # farthes points sampling ##
pts = pts[pts_fps_idx].cpu()
return pts
def farthest_point_sampling(pos: torch.FloatTensor, n_sampling: int):
bz, N = pos.size(0), pos.size(1)
feat_dim = pos.size(-1)
device = pos.device
sampling_ratio = float(n_sampling / N)
pos_float = pos.float()
batch = torch.arange(bz, dtype=torch.long).view(bz, 1).to(device)
mult_one = torch.ones((N,), dtype=torch.long).view(1, N).to(device)
batch = batch * mult_one
batch = batch.view(-1)
pos_float = pos_float.contiguous().view(-1, feat_dim).contiguous() # (bz x N, 3)
# sampling_ratio = torch.tensor([sampling_ratio for _ in range(bz)], dtype=torch.float).to(device)
# batch = torch.zeros((N, ), dtype=torch.long, device=device)
sampled_idx = fps(pos_float, batch, ratio=sampling_ratio, random_start=False)
# shape of sampled_idx?
return sampled_idx
def plot_sampled_meshes_single_part(v_sample, f_sample, sv_mesh_folder, cur_step=0, predict_joint=True,):
if not os.path.exists(sv_mesh_folder):
os.mkdir(sv_mesh_folder)
part_vertex_samples = [v_sample]
part_face_samples = [f_sample]
tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
tot_n_part = 2
# not predict joints here
# if predict_joint:
# pred_dir = v_sample['joint_dir']
# pred_pvp = v_sample['joint_pvp']
# print("pred_dir", pred_dir.shape, pred_dir)
# print("pred_pvp", pred_pvp.shape, pred_pvp)
# else:
# pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)
tot_mesh_list = []
for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
mesh_list = []
for i_n in range(tot_n_samples):
mesh_list.append(
{
'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
'faces': unflatten_faces(
cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
}
)
tot_mesh_list.append(mesh_list)
for i_n in range(tot_n_samples):
cur_mesh = mesh_list[i_n]
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
# cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_mesh_vertices.shape[0]}, nn_faces: {len(cur_mesh_faces)}")
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=True, scale=1.)
def plot_sampled_meshes(v_sample, f_sample, sv_mesh_folder, cur_step=0, predict_joint=True,):
if not os.path.exists(sv_mesh_folder):
os.mkdir(sv_mesh_folder)
part_vertex_samples = [v_sample]
part_face_samples = [f_sample]
tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
# tot_n_part = 2
# not predict joints here
# if predict_joint:
# pred_dir = v_sample['joint_dir']
# pred_pvp = v_sample['joint_pvp']
# print("pred_dir", pred_dir.shape, pred_dir)
# print("pred_pvp", pred_pvp.shape, pred_pvp)
# else:
# pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)
tot_mesh_list = []
for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
mesh_list = []
for i_n in range(tot_n_samples):
mesh_list.append(
{
'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
'faces': unflatten_faces(
cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
}
)
tot_mesh_list.append(mesh_list)
for i_n in range(tot_n_samples):
cur_mesh = mesh_list[i_n]
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
# cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_mesh_vertices.shape[0]}, nn_faces: {len(cur_mesh_faces)}")
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=True, scale=1.)
def filter_masked_vertices(vertices, mask_indicator):
# vertices: n_verts x 3
mask_indicator = np.reshape(mask_indicator, (vertices.shape[0], 3))
tot_vertices = []
for i_v in range(vertices.shape[0]):
cur_vert = vertices[i_v]
cur_vert_indicator = mask_indicator[i_v][0].item()
if cur_vert_indicator < 0.5:
tot_vertices.append(cur_vert)
tot_vertices = np.array(tot_vertices)
return tot_vertices
def plot_sampled_ar_subd_meshes(v_sample, f_sample, sv_mesh_folder, cur_step=0, ):
if not os.path.exists(sv_mesh_folder): ### vertices_mask
os.mkdir(sv_mesh_folder)
### v_sample: bsz x nn_verts x 3
vertices_mask = v_sample['vertices_mask']
vertices = v_sample['vertices']
faces = f_sample['faces']
num_face_indices = f_sample['num_face_indices'] #### num_faces_indices
bsz = vertices.shape[0]
for i_bsz in range(bsz):
cur_vertices = vertices[i_bsz]
cur_vertices_mask = vertices_mask[i_bsz]
cur_faces = faces[i_bsz]
cur_num_face_indices = num_face_indices[i_bsz]
cur_nn_verts = cur_vertices_mask.sum(-1).item()
cur_nn_verts = int(cur_nn_verts)
cur_vertices = cur_vertices[:cur_nn_verts]
cur_faces = unflatten_faces(
cur_faces[:int(cur_num_face_indices)])
cur_num_faces = len(cur_faces)
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_inst_{i_bsz}.obj")
# cur_context_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}_context.obj")
print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_nn_verts}, nn_faces: {cur_num_faces}")
if cur_nn_verts > 0 and cur_num_faces > 0:
write_obj(cur_vertices, cur_faces, cur_mesh_sv_fn, transpose=True, scale=1.)
def plot_sampled_meshes_single_part_for_pretraining(v_sample, f_sample, context, sv_mesh_folder, cur_step=0, predict_joint=True, with_context=True):
if not os.path.exists(sv_mesh_folder):
os.mkdir(sv_mesh_folder)
part_vertex_samples = [v_sample]
part_face_samples = [f_sample]
context_vertices = [context['vertices']]
context_faces = [context['faces']]
context_vertices_mask = [context['vertices_mask']]
context_faces_mask = [context['faces_mask']]
tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
tot_n_part = 2
# not predict joints here
# if predict_joint:
# pred_dir = v_sample['joint_dir']
# pred_pvp = v_sample['joint_pvp']
# print("pred_dir", pred_dir.shape, pred_dir)
# print("pred_pvp", pred_pvp.shape, pred_pvp)
# else:
# pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)
#
tot_mesh_list = []
for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
mesh_list = []
context_mesh_list = []
for i_n in range(tot_n_samples):
mesh_list.append(
{
'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
'faces': unflatten_faces(
cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
}
)
cur_context_vertices = context_vertices[i_p][i_n]
cur_context_faces = context_faces[i_p][i_n]
cur_context_vertices_mask = context_vertices_mask[i_p][i_n]
cur_context_faces_mask = context_faces_mask[i_p][i_n]
cur_nn_vertices = np.sum(cur_context_vertices_mask).item()
cur_nn_faces = np.sum(cur_context_faces_mask).item()
cur_nn_vertices, cur_nn_faces = int(cur_nn_vertices), int(cur_nn_faces)
cur_context_vertices = cur_context_vertices[:cur_nn_vertices]
if 'mask_vertices_flat_indicator' in context:
cur_context_vertices_mask_indicator = context['mask_vertices_flat_indicator'][i_n]
cur_context_vertices_mask_indicator = cur_context_vertices_mask_indicator[:cur_nn_vertices * 3]
cur_context_vertices = filter_masked_vertices(cur_context_vertices, cur_context_vertices_mask_indicator)
cur_context_faces = cur_context_faces[:cur_nn_faces] # context faces...
context_mesh_dict = {
'vertices': dequantize_verts(cur_context_vertices, n_bits=8), 'faces': unflatten_faces(cur_context_faces)
}
context_mesh_list.append(context_mesh_dict)
tot_mesh_list.append(mesh_list)
# if with_context:
for i_n in range(tot_n_samples):
cur_mesh = mesh_list[i_n]
cur_context_mesh = context_mesh_list[i_n]
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
cur_context_vertices, cur_context_faces = cur_context_mesh['vertices'], cur_context_mesh['faces']
# cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
cur_context_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}_context.obj")
print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_mesh_vertices.shape[0]}, nn_faces: {len(cur_mesh_faces)}")
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=True, scale=1.)
if cur_context_vertices.shape[0] > 0 and len(cur_context_faces) > 0:
write_obj(cur_context_vertices, cur_context_faces, cur_context_mesh_sv_fn, transpose=True, scale=1.)
def plot_grids_for_pretraining_ml(v_sample, sv_mesh_folder="", cur_step=0, context=None):
if not os.path.exists(sv_mesh_folder):
os.mkdir(sv_mesh_folder)
mesh_list = []
context_mesh_list = []
tot_n_samples = v_sample['vertices'].shape[0]
for i_n in range(tot_n_samples):
mesh_list.append(
{
'vertices': v_sample['vertices'][i_n][:v_sample['num_vertices'][i_n]],
'faces': []
}
)
cur_context_vertices = context['vertices'][i_n]
cur_context_vertices_mask = context['vertices_mask'][i_n]
cur_nn_vertices = np.sum(cur_context_vertices_mask).item()
cur_nn_vertices = int(cur_nn_vertices)
cur_context_vertices = cur_context_vertices[:cur_nn_vertices]
if 'mask_vertices_flat_indicator' in context:
cur_context_vertices_mask_indicator = context['mask_vertices_flat_indicator'][i_n]
cur_context_vertices_mask_indicator = cur_context_vertices_mask_indicator[:cur_nn_vertices * 3]
cur_context_vertices = filter_masked_vertices(cur_context_vertices, cur_context_vertices_mask_indicator)
context_mesh_dict = {
'vertices': dequantize_verts(cur_context_vertices, n_bits=8), 'faces': []
}
context_mesh_list.append(context_mesh_dict)
# tot_mesh_list.append(mesh_list)
# if with_context:
for i_n in range(tot_n_samples):
cur_mesh = mesh_list[i_n]
cur_context_mesh = context_mesh_list[i_n]
cur_mesh_vertices = cur_mesh['vertices']
cur_context_vertices = cur_context_mesh['vertices']
# cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_n}.obj")
cur_context_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_n}_context.obj")
# print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_mesh_vertices.shape[0]}, nn_faces: {len(cur_mesh_faces)}")
print(f"saving the sample to {cur_mesh_sv_fn}, context sample to {cur_context_mesh_sv_fn}")
if cur_mesh_vertices.shape[0] > 0:
write_obj(cur_mesh_vertices, None, cur_mesh_sv_fn, transpose=True, scale=1.)
if cur_context_vertices.shape[0] > 0:
write_obj(cur_context_vertices, None, cur_context_mesh_sv_fn, transpose=True, scale=1.)
def get_grid_content_from_grids(grid_xyzs, grid_values, grid_size=2):
cur_bsz_grid_xyzs = grid_xyzs # grid_length x 3 # grids pts for a sinlge batch
cur_bsz_grid_values = grid_values # grid_length x gs x gs x gs
pts = []
for i_grid in range(cur_bsz_grid_xyzs.shape[0]): # cur_bsz_grid_xyzs
cur_grid_xyz = cur_bsz_grid_xyzs[i_grid].tolist()
if cur_grid_xyz[0] == -1 or cur_grid_xyz[1] == -1 or cur_grid_xyz[2] == -1:
break
if len(cur_bsz_grid_values.shape) > 1:
cur_grid_values = cur_bsz_grid_values[i_grid]
else:
cur_grid_content = cur_bsz_grid_values[i_grid].item()
if cur_grid_content >= MASK_GRID_VALIE:
continue
inde = 2
cur_grid_values = []
for i_s in range(grid_size ** 3):
cur_mod_value = cur_grid_content % inde
cur_grid_content = cur_grid_content // inde
cur_grid_values = [cur_mod_value] + cur_grid_values # higher values should be put to the front of the list
cur_grid_values = np.array(cur_grid_values, dtype=np.long)
cur_grid_values = np.reshape(cur_grid_values, (grid_size, grid_size, grid_size))
# if words
# flip words
for i_x in range(cur_grid_values.shape[0]):
for i_y in range(cur_grid_values.shape[1]):
for i_z in range(cur_grid_values.shape[2]):
cur_grid_xyz_value = int(cur_grid_values[i_x, i_y, i_z].item())
if cur_grid_xyz_value > 0.5:
cur_x, cur_y, cur_z = cur_grid_xyz[0] * grid_size + i_x, cur_grid_xyz[1] * grid_size + i_y, cur_grid_xyz[2] * grid_size + i_z
pts.append([cur_x, cur_y, cur_z])
return pts
def plot_grids_for_pretraining(v_sample, sv_mesh_folder="", cur_step=0, context=None, sv_mesh_fn=None):
##### plot grids
if not os.path.exists(sv_mesh_folder):
os.mkdir(sv_mesh_folder)
# part_vertex_samples = [v_sample] # vertex samples
# part_face_samples = [f_sample] # face samples
grid_xyzs = v_sample['grid_xyzs']
grid_values = v_sample['grid_values']
bsz = grid_xyzs.shape[0]
grid_size = opt.vertex_model.grid_size
for i_bsz in range(bsz):
cur_bsz_grid_xyzs = grid_xyzs[i_bsz] # grid_length x 3
cur_bsz_grid_values = grid_values[i_bsz] # grid_length x gs x gs x gs
pts = []
for i_grid in range(cur_bsz_grid_xyzs.shape[0]): # cur_bsz_grid_xyzs
cur_grid_xyz = cur_bsz_grid_xyzs[i_grid].tolist()
if cur_grid_xyz[0] == -1 or cur_grid_xyz[1] == -1 or cur_grid_xyz[2] == -1:
break
if len(cur_bsz_grid_values.shape) > 1:
cur_grid_values = cur_bsz_grid_values[i_grid]
else:
cur_grid_content = cur_bsz_grid_values[i_grid].item()
if cur_grid_content >= MASK_GRID_VALIE:
continue
inde = 2
cur_grid_values = []
for i_s in range(grid_size ** 3):
cur_mod_value = cur_grid_content % inde
cur_grid_content = cur_grid_content // inde
cur_grid_values = [cur_mod_value] + cur_grid_values # higher values should be put to the front of the list
cur_grid_values = np.array(cur_grid_values, dtype=np.long)
cur_grid_values = np.reshape(cur_grid_values, (grid_size, grid_size, grid_size))
# if
for i_x in range(cur_grid_values.shape[0]):
for i_y in range(cur_grid_values.shape[1]):
for i_z in range(cur_grid_values.shape[2]):
cur_grid_xyz_value = int(cur_grid_values[i_x, i_y, i_z].item())
if cur_grid_xyz_value > 0.5:
cur_x, cur_y, cur_z = cur_grid_xyz[0] * grid_size + i_x, cur_grid_xyz[1] * grid_size + i_y, cur_grid_xyz[2] * grid_size + i_z
pts.append([cur_x, cur_y, cur_z])
if len(pts) == 0:
print("zzz, len(pts) == 0")
continue
pts = np.array(pts, dtype=np.float32)
# pts = center_vertices(pts)
# pts = normalize_vertices_scale(pts)
pts = pts[:, [2, 1, 0]]
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}.obj" if sv_mesh_fn is None else sv_mesh_fn)
print(f"Saving obj to {cur_mesh_sv_fn}")
write_obj(pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)
def plot_grids_for_pretraining_obj_corpus(v_sample, sv_mesh_folder="", cur_step=0, context=None, sv_mesh_fn=None):
##### plot grids
if not os.path.exists(sv_mesh_folder):
os.mkdir(sv_mesh_folder)
# part_vertex_samples = [v_sample] # vertex samples
# part_face_samples = [f_sample] # face samples
grid_xyzs = v_sample['grid_xyzs']
grid_values = v_sample['grid_values']
bsz = grid_xyzs.shape[0]
grid_size = opt.vertex_model.grid_size
for i_bsz in range(bsz):
cur_bsz_grid_xyzs = grid_xyzs[i_bsz] # grid_length x 3
cur_bsz_grid_values = grid_values[i_bsz] # grid_length x gs x gs x gs
part_pts = []
pts = []
for i_grid in range(cur_bsz_grid_xyzs.shape[0]): # cur_bsz_grid_xyzs
cur_grid_xyz = cur_bsz_grid_xyzs[i_grid].tolist()
##### grid_xyz; grid_
if cur_grid_xyz[0] == -1 and cur_grid_xyz[1] == -1 and cur_grid_xyz[2] == -1:
part_pts.append(pts)
pts = []
continue
##### cur_grid_xyz... #####
elif not (cur_grid_xyz[0] >= 0 and cur_grid_xyz[1] >= 0 and cur_grid_xyz[2] >= 0):
continue
if len(cur_bsz_grid_values.shape) > 1:
cur_grid_values = cur_bsz_grid_values[i_grid]
else:
cur_grid_content = cur_bsz_grid_values[i_grid].item()
if cur_grid_content >= MASK_GRID_VALIE: # mask grid value
continue
inde = 2
cur_grid_values = []
for i_s in range(grid_size ** 3):
cur_mod_value = cur_grid_content % inde
cur_grid_content = cur_grid_content // inde
cur_grid_values = [cur_mod_value] + cur_grid_values # higher values should be put to the front of the list
cur_grid_values = np.array(cur_grid_values, dtype=np.long)
cur_grid_values = np.reshape(cur_grid_values, (grid_size, grid_size, grid_size))
# if
for i_x in range(cur_grid_values.shape[0]):
for i_y in range(cur_grid_values.shape[1]):
for i_z in range(cur_grid_values.shape[2]):
cur_grid_xyz_value = int(cur_grid_values[i_x, i_y, i_z].item())
##### gird-xyz-values #####
if cur_grid_xyz_value > 0.5: # cur_grid_xyz_value
cur_x, cur_y, cur_z = cur_grid_xyz[0] * grid_size + i_x, cur_grid_xyz[1] * grid_size + i_y, cur_grid_xyz[2] * grid_size + i_z
pts.append([cur_x, cur_y, cur_z])
if len(pts) > 0:
part_pts.append(pts)
pts = []
tot_nn_pts = sum([len(aa) for aa in part_pts])
if tot_nn_pts == 0:
print("zzz, tot_nn_pts == 0")
continue
for i_p, pts in enumerate(part_pts):
if len(pts) == 0:
continue
pts = np.array(pts, dtype=np.float32)
pts = center_vertices(pts)
# pts = normalize_vertices_scale(pts)
pts = pts[:, [2, 1, 0]]
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}_ip_{i_p}.obj" if sv_mesh_fn is None else sv_mesh_fn)
print(f"Saving {i_p}-th part obj to {cur_mesh_sv_fn}")
write_obj(pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)
def plot_grids_for_pretraining_obj_part(v_sample, sv_mesh_folder="", cur_step=0, context=None, sv_mesh_fn=None):
##### plot grids
if not os.path.exists(sv_mesh_folder):
os.mkdir(sv_mesh_folder)
# part_vertex_samples = [v_sample] # vertex samples
# part_face_samples = [f_sample] # face samples
grid_xyzs = v_sample['grid_xyzs']
grid_values = v_sample['grid_values']
bsz = grid_xyzs.shape[0]
grid_size = opt.vertex_model.grid_size
for i_bsz in range(bsz):
cur_bsz_grid_xyzs = grid_xyzs[i_bsz] # grid_length x 3
cur_bsz_grid_values = grid_values[i_bsz] # grid_length x gs x gs x gs
part_pts = []
pts = []
for i_grid in range(cur_bsz_grid_xyzs.shape[0]): # cur_bsz_grid_xyzs
cur_grid_xyz = cur_bsz_grid_xyzs[i_grid].tolist()
##### grid_xyz; grid_
if cur_grid_xyz[0] == -1 and cur_grid_xyz[1] == -1 and cur_grid_xyz[2] == -1:
part_pts.append(pts)
pts = []
break
elif cur_grid_xyz[0] == -1 and cur_grid_xyz[1] == -1 and cur_grid_xyz[2] == 0:
part_pts.append(pts)
pts = []
continue
##### cur_grid_xyz... #####
elif not (cur_grid_xyz[0] >= 0 and cur_grid_xyz[1] >= 0 and cur_grid_xyz[2] >= 0):
continue
if len(cur_bsz_grid_values.shape) > 1:
cur_grid_values = cur_bsz_grid_values[i_grid]
else:
cur_grid_content = cur_bsz_grid_values[i_grid].item()
if cur_grid_content >= MASK_GRID_VALIE: # invalid jor dummy content value s
continue
inde = 2
cur_grid_values = []
for i_s in range(grid_size ** 3):
cur_mod_value = cur_grid_content % inde
cur_grid_content = cur_grid_content // inde
cur_grid_values = [cur_mod_value] + cur_grid_values # higher values should be put to the front of the list
cur_grid_values = np.array(cur_grid_values, dtype=np.long)
cur_grid_values = np.reshape(cur_grid_values, (grid_size, grid_size, grid_size))
# if
for i_x in range(cur_grid_values.shape[0]):
for i_y in range(cur_grid_values.shape[1]):
for i_z in range(cur_grid_values.shape[2]):
cur_grid_xyz_value = int(cur_grid_values[i_x, i_y, i_z].item())
##### gird-xyz-values #####
if cur_grid_xyz_value > 0.5: # cur_grid_xyz_value
cur_x, cur_y, cur_z = cur_grid_xyz[0] * grid_size + i_x, cur_grid_xyz[1] * grid_size + i_y, cur_grid_xyz[2] * grid_size + i_z
pts.append([cur_x, cur_y, cur_z])
if len(pts) > 0:
part_pts.append(pts)
pts = []
tot_nn_pts = sum([len(aa) for aa in part_pts])
if tot_nn_pts == 0:
print("zzz, tot_nn_pts == 0")
continue
for i_p, pts in enumerate(part_pts):
if len(pts) == 0:
continue
pts = np.array(pts, dtype=np.float32)
pts = center_vertices(pts)
# pts = normalize_vertices_scale(pts)
pts = pts[:, [2, 1, 0]]
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}_ip_{i_p}.obj" if sv_mesh_fn is None else sv_mesh_fn)
print(f"Saving {i_p}-th part obj to {cur_mesh_sv_fn}")
write_obj(pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)
def plot_grids_for_pretraining_ml(v_sample, sv_mesh_folder="", cur_step=0, context=None):
if not os.path.exists(sv_mesh_folder):
os.mkdir(sv_mesh_folder)
# part_vertex_samples = [v_sample] # vertex samples
# part_face_samples = [f_sample] # face samples
grid_xyzs = v_sample['grid_xyzs']
grid_values = v_sample['grid_values']
context_grid_xyzs = context['grid_xyzs'] - 1
# context_grid_values = context['grid_content']
context_grid_values = context['mask_grid_content']
bsz = grid_xyzs.shape[0]
grid_size = opt.vertex_model.grid_size
for i_bsz in range(bsz):
cur_bsz_grid_pts = get_grid_content_from_grids(grid_xyzs[i_bsz], grid_values[i_bsz], grid_size=grid_size)
cur_context_grid_pts = get_grid_content_from_grids(context_grid_xyzs[i_bsz], context_grid_values[i_bsz], grid_size=grid_size)
if len(cur_bsz_grid_pts) > 0 and len(cur_context_grid_pts) > 0:
cur_bsz_grid_pts = np.array(cur_bsz_grid_pts, dtype=np.float32)
cur_bsz_grid_pts = center_vertices(cur_bsz_grid_pts)
cur_bsz_grid_pts = normalize_vertices_scale(cur_bsz_grid_pts)
cur_bsz_grid_pts = cur_bsz_grid_pts[:, [2, 1, 0]]
#### plot current mesh / sampled points ####
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}.obj")
print(f"Saving predicted grid points to {cur_mesh_sv_fn}")
write_obj(cur_bsz_grid_pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)
cur_context_grid_pts = np.array(cur_context_grid_pts, dtype=np.float32)
cur_context_grid_pts = center_vertices(cur_context_grid_pts)
cur_context_grid_pts = normalize_vertices_scale(cur_context_grid_pts)
cur_context_grid_pts = cur_context_grid_pts[:, [2, 1, 0]]
#### plot current mesh / sampled points ####
cur_context_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}_context.obj")
print(f"Saving context grid points to {cur_context_mesh_sv_fn}")
write_obj(cur_context_grid_pts, None, cur_context_mesh_sv_fn, transpose=True, scale=1.)
# print(f"Saving obj to {cur_mesh_sv_fn}")
# write_obj(pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)
# if len(cur_bsz_grid_pts) == 0:
# print("zzz, len(pts) == 0")
# continue
# pts = np.array(pts, dtype=np.float32)
# pts = center_vertices(pts)
# pts = normalize_vertices_scale(pts)
# pts = pts[:, [2, 1, 0]]
# cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}.obj")
# print(f"Saving obj to {cur_mesh_sv_fn}")
# write_obj(pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)
def plot_sampled_meshes_single_part_for_sampling(v_sample, f_sample, sv_mesh_folder, cur_step=0, predict_joint=True,):
if not os.path.exists(sv_mesh_folder):
os.mkdir(sv_mesh_folder)
part_vertex_samples = [v_sample]
part_face_samples = [f_sample]
tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
tot_n_part = 2
# not predict joints here
# if predict_joint:
# pred_dir = v_sample['joint_dir']
# pred_pvp = v_sample['joint_pvp']
# print("pred_dir", pred_dir.shape, pred_dir)
# print("pred_pvp", pred_pvp.shape, pred_pvp)
# else:
# pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)
tot_mesh_list = []
for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
mesh_list = []
for i_n in range(tot_n_samples):
mesh_list.append(
{
'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
'faces': unflatten_faces(
cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
}
)
tot_mesh_list.append(mesh_list)
for i_n in range(tot_n_samples):
cur_mesh = mesh_list[i_n]
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
# cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_mesh_vertices.shape[0]}, nn_faces: {len(cur_mesh_faces)}")
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=True, scale=1.)