Spaces:
Runtime error
Runtime error
+++ venv/lib/python3.10/site-packages/selfcontact/body_segmentation.py | |
# | |
# Contact: ps-license@tuebingen.mpg.de | |
+from pathlib import Path | |
+ | |
import torch | |
import trimesh | |
import torch.nn as nn | |
from .utils.mesh import winding_numbers | |
+ | |
+def load_pkl(path): | |
+ with open(path, "rb") as fin: | |
+ return pickle.load(fin) | |
+ | |
+ | |
+def save_pkl(obj, path): | |
+ with open(path, "wb") as fout: | |
+ pickle.dump(obj, fout) | |
+ | |
+ | |
class BodySegment(nn.Module): | |
def __init__(self, | |
name, | |
self.register_buffer('segment_faces', segment_faces) | |
# create vector to select vertices form faces | |
- tri_vidx = [] | |
- for ii in range(faces.max().item()+1): | |
- tri_vidx += [torch.nonzero(faces==ii)[0].tolist()] | |
+ segments_folder = Path(segments_folder) | |
+ tri_vidx_path = segments_folder / "tri_vidx.pkl" | |
+ if not tri_vidx_path.is_file(): | |
+ tri_vidx = [] | |
+ for ii in range(faces.max().item()+1): | |
+ tri_vidx += [torch.nonzero(faces==ii)[0].tolist()] | |
+ | |
+ save_pkl(tri_vidx, tri_vidx_path) | |
+ else: | |
+ tri_vidx = load_pkl(tri_vidx_path) | |
+ | |
self.register_buffer('tri_vidx', torch.tensor(tri_vidx)) | |
def create_band_faces(self): | |
self.segmentation = {} | |
for idx, name in enumerate(names): | |
self.segmentation[name] = BodySegment(name, faces, segments_folder, | |
- model_type).to('cuda') | |
+ model_type).to(device) | |
def batch_has_self_isec_verts(self, vertices): | |
""" | |
+++ venv/lib/python3.10/site-packages/selfcontact/selfcontact.py | |
test_segments=True, | |
compute_hd=False, | |
buffer_geodists=False, | |
+ device="cuda", | |
): | |
super().__init__() | |
if self.test_segments: | |
sxseg = pickle.load(open(segments_bounds_path, 'rb')) | |
self.segments = BatchBodySegment( | |
- [x for x in sxseg.keys()], faces, segments_folder, self.model_type | |
+ [x for x in sxseg.keys()], faces, segments_folder, self.model_type, device=device, | |
) | |
# load regressor to get high density mesh | |
torch.tensor(hd_operator['values']), | |
torch.Size(hd_operator['size'])) | |
self.register_buffer('hd_operator', | |
- torch.tensor(hd_operator).float()) | |
+ hd_operator.clone().detach().float()) | |
with open(point_vert_corres_path, 'rb') as f: | |
hd_geovec = pickle.load(f)['faces_vert_is_sampled_from'] | |
# split because of memory into two chunks | |
exterior = torch.zeros((bs, nv), device=vertices.device, | |
dtype=torch.bool) | |
- exterior[:, :5000] = winding_numbers(vertices[:,:5000,:], | |
+ exterior[:, :3000] = winding_numbers(vertices[:,:3000,:], | |
triangles).le(0.99) | |
- exterior[:, 5000:] = winding_numbers(vertices[:,5000:,:], | |
+ exterior[:, 3000:6000] = winding_numbers(vertices[:,3000:6000,:], | |
+ triangles).le(0.99) | |
+ exterior[:, 6000:9000] = winding_numbers(vertices[:,6000:9000,:], | |
+ triangles).le(0.99) | |
+ exterior[:, 9000:] = winding_numbers(vertices[:,9000:,:], | |
triangles).le(0.99) | |
# check if intersections happen within segments | |
# split because of memory into two chunks | |
exterior = torch.zeros((bs, np), device=points.device, | |
dtype=torch.bool) | |
- exterior[:, :6000] = winding_numbers(points[:,:6000,:], | |
+ exterior[:, :3000] = winding_numbers(points[:,:3000,:], | |
+ triangles).le(0.99) | |
+ exterior[:, 3000:6000] = winding_numbers(points[:,3000:6000,:], | |
triangles).le(0.99) | |
- exterior[:, 6000:] = winding_numbers(points[:,6000:,:], | |
+ exterior[:, 6000:9000] = winding_numbers(points[:,6000:9000,:], | |
+ triangles).le(0.99) | |
+ exterior[:, 9000:] = winding_numbers(points[:,9000:,:], | |
triangles).le(0.99) | |
return exterior | |
return hd_v2v_mins, hd_exteriors, hd_points, hd_faces_in_contacts | |
+ def verts_in_contact(self, vertices, return_idx=False): | |
+ | |
+ # get pairwise distances of vertices | |
+ v2v = self.get_pairwise_dists(vertices, vertices, squared=True) | |
+ | |
+ # mask v2v with eucledean and geodesic dsitance | |
+ euclmask = v2v < self.euclthres**2 | |
+ mask = euclmask * self.geomask | |
+ | |
+ # find closes vertex in contact | |
+ in_contact = mask.sum(1) > 0 | |
+ | |
+ if return_idx: | |
+ in_contact = torch.where(in_contact) | |
+ | |
+ return in_contact | |
+ | |
class SelfContactSmall(nn.Module): | |
+++ venv/lib/python3.10/site-packages/selfcontact/utils/mesh.py | |
if valid_vals > 0: | |
loss = (mask * dists).sum() / valid_vals | |
else: | |
- loss = torch.Tensor([0]).cuda() | |
+ loss = mask.new_tensor([0]) | |
return loss | |
def batch_index_select(inp, dim, index): | |
xx = torch.bmm(x, x.transpose(2, 1)) | |
yy = torch.bmm(y, y.transpose(2, 1)) | |
zz = torch.bmm(x, y.transpose(2, 1)) | |
+ use_cuda = x.device.type == "cuda" | |
if use_cuda: | |
dtype = torch.cuda.LongTensor | |
else: | |