Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| import math | |
| from prettytable import PrettyTable | |
def count_parameters(model):
    """Print a summary of trainable parameters and return their total count.

    Every parameter with ``requires_grad`` set contributes to the returned
    total. Only tensors with more than 100000 elements are listed as rows in the
    printed table, which keeps the printout short.

    Args:
        model: an object exposing ``named_parameters()`` (e.g. a torch module).

    Returns:
        int: number of trainable scalar parameters.
    """
    trainable = [
        (pname, tensor.numel())
        for pname, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    summary = PrettyTable(["Modules", "Parameters"])
    for pname, count in trainable:
        if count > 100000:
            summary.add_row([pname, count])

    grand_total = sum(count for _, count in trainable)
    print(summary)
    print('total params: %.2f M' % (grand_total / 1000000.0))
    return grand_total
def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False):
    """Encode 2-D point coordinates with fixed sine/cosine features.

    Args:
        xy: tensor of shape (B, S, 2); channel 0 is x and channel 1 is y.
        C: output feature size. It must be a multiple of 4, giving C // 4
            frequencies for each of sin(x), cos(x), sin(y) and cos(y).
        temperature: base of the geometric frequency progression.
        dtype: NOTE(review): currently ignored. It is overwritten with
            ``xy.dtype`` below, so the output always has xy's dtype. It is
            kept for backward compatibility.
        cat_coords: if True, append the raw xy coordinates to the features.

    Returns:
        Tensor of shape (B, S, C), or (B, S, C + 2) when ``cat_coords`` is
        True. The features are laid out as
        [sin(x*w), cos(x*w), sin(y*w), cos(y*w)].

    Raises:
        AssertionError: if the last dim of ``xy`` is not 2, or C is not a
            multiple of 4.
    """
    device = xy.device
    dtype = xy.dtype
    B, S, D = xy.shape
    assert(D==2)
    x = xy[:, :, 0]
    y = xy[:, :, 1]
    assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    n_freq = C // 4
    # Normalise exponents to [0, 1]. Guard n_freq == 1 (C == 4): the
    # denominator n_freq - 1 would be 0, and arange(1) / 0 gives NaN for
    # every feature.
    omega = torch.arange(n_freq, device=device) / max(n_freq - 1, 1)
    omega = 1. / (temperature ** omega)
    # Outer product of each coordinate with the frequency vector: (B*S, C//4).
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    pe = pe.reshape(B, S, C).type(dtype)
    if cat_coords:
        pe = torch.cat([pe, xy], dim=2)  # B,S,C+2
    return pe
class SimplePool():
    """Fixed-capacity FIFO pool of recent items.

    Items are appended by ``update``. Once ``pool_size`` items are held, each
    new item evicts the oldest one. ``version`` selects the array backend used
    by ``mean`` and ``fetch``: 'pt' (torch) or 'np' (numpy).
    """

    def __init__(self, pool_size, version='pt'):
        """Create an empty pool.

        Args:
            pool_size: maximum number of items kept.
            version: 'pt' or 'np'. Any other value prints a message and fails
                an assertion.
        """
        self.pool_size = pool_size
        self.version = version
        self.items = []
        if not (version == 'pt' or version == 'np'):
            # The message previously never interpolated `version`.
            print('version = %s; please choose pt or np' % version)
            assert False, 'please choose pt or np'

    def __len__(self):
        """Number of items currently held."""
        return len(self.items)

    def mean(self, min_size=1):
        """Return sum-of-all-elements / number-of-items, or NaN if too few items.

        For scalar items this is the ordinary mean.

        Args:
            min_size: minimum number of items required. The string 'half'
                means ``pool_size / 2``.

        Returns:
            For 'np': a numpy scalar, or ``np.nan``.
            For 'pt': a 0-d tensor, or a NaN 0-d tensor.
        """
        if min_size == 'half':
            pool_size_thresh = self.pool_size / 2
        else:
            pool_size_thresh = min_size
        # Also require a non-empty pool. With min_size <= 0, an empty pool
        # would otherwise divide by zero (np) or stack an empty list (pt).
        enough = len(self.items) > 0 and len(self.items) >= pool_size_thresh
        if self.version == 'np':
            if enough:
                return np.sum(self.items) / float(len(self.items))
            return np.nan
        if self.version == 'pt':
            if enough:
                # torch.sum rejects a Python list (it raised TypeError), so
                # stack the items into a single tensor first.
                return torch.stack(self.items).sum() / float(len(self.items))
            # torch.from_numpy(np.nan) raised TypeError because np.nan is a
            # float, not an ndarray. Build a 0-d NaN tensor directly instead.
            return torch.tensor(np.nan)

    def sample(self, with_replacement=True):
        """Return one uniformly random item.

        The item is removed from the pool when ``with_replacement`` is False.
        """
        idx = np.random.randint(len(self.items))
        if with_replacement:
            return self.items[idx]
        else:
            return self.items.pop(idx)

    def fetch(self, num=None):
        """Return the pooled items stacked into one tensor/array.

        Args:
            num: if given, return ``num`` rows drawn uniformly *with
                replacement* instead of all items.
        """
        if self.version == 'pt':
            item_array = torch.stack(self.items)
        elif self.version == 'np':
            item_array = np.stack(self.items)
        if num is not None:
            # there better be some items
            assert(len(self.items) >= num)
            # Only reachable when asserts are disabled (python -O). In that
            # case, return every item when fewer than `num` are available.
            if len(self.items) < num:
                return item_array
            else:
                idxs = np.random.randint(len(self.items), size=num)
                return item_array[idxs]
        else:
            return item_array

    def is_full(self):
        """True when the pool holds exactly ``pool_size`` items."""
        full = len(self.items) == self.pool_size
        return full

    def empty(self):
        """Drop all items."""
        self.items = []

    def update(self, items):
        """Append each of ``items``, evicting the oldest entry when full.

        Returns:
            The pool's internal item list (not a copy).
        """
        for item in items:
            if len(self.items) < self.pool_size:
                # the pool is not full, so let's add this in
                self.items.append(item)
            else:
                # the pool is full: pop from the front, then add to the back
                self.items.pop(0)
                self.items.append(item)
        return self.items
def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False):
    """Select ``npoint`` indices per batch element by iterative farthest-point sampling.

    Args:
        xyz: tensor of shape [B, N, C] (C is probably 3). It is cast to float
            internally.
        npoint: number of indices to select for each batch element.
        include_ends: if True, the first pick is forced to index 0 and the
            second to index N-1.
        deterministic: if True, start from index 0 rather than a random index.

    Returns:
        LongTensor of shape [B, npoint] holding the sampled point indices.
    """
    dev = xyz.device
    batch, num_pts, dims = xyz.shape
    pts = xyz.float()

    chosen = torch.zeros(batch, npoint, dtype=torch.long, device=dev)
    # Squared distance from each point to its nearest already-chosen point.
    min_sq_dist = torch.full((batch, num_pts), 1e10, device=dev)
    # randint(0, 1) always yields 0; the call is still made so that the global
    # RNG advances exactly as in the random-start case.
    upper = 1 if deterministic else num_pts
    current = torch.randint(0, upper, (batch,), dtype=torch.long).to(dev)
    rows = torch.arange(batch, dtype=torch.long, device=dev)

    for step in range(npoint):
        if include_ends and step < 2:
            current = 0 if step == 0 else num_pts - 1
        chosen[:, step] = current
        anchor = pts[rows, current, :].view(batch, 1, dims)
        sq_dist = ((pts - anchor) ** 2).sum(-1)
        closer = sq_dist < min_sq_dist
        min_sq_dist[closer] = sq_dist[closer]
        current = min_sq_dist.max(-1)[1]
        if npoint > num_pts:
            # Once every point has been chosen, all distances are zero;
            # jitter them so the remaining picks are random.
            min_sq_dist += torch.randn_like(min_sq_dist)
    return chosen
def farthest_point_sample_py(xyz, npoint):
    """NumPy farthest-point sampling for a single point cloud.

    Args:
        xyz: array of shape (N, C).
        npoint: number of indices to select. If it exceeds N, the extra picks
            are effectively random (distances are jittered each step).

    Returns:
        int32 array of shape (npoint,) with the sampled indices. The first
        index is drawn uniformly at random.
    """
    n_pts, dims = xyz.shape
    picked = np.zeros(npoint, dtype=np.int32)
    nearest = np.ones(n_pts) * 1e10
    cur = np.random.randint(0, n_pts, dtype=np.int32)
    for k in range(npoint):
        picked[k] = cur
        d2 = ((xyz - xyz[cur, :].reshape(1, dims)) ** 2).sum(-1)
        upd = d2 < nearest
        nearest[upd] = d2[upd]
        cur = nearest.argmax(-1)
        if npoint > n_pts:
            nearest += np.random.randn(*nearest.shape)
    return picked