import functools
import random

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

from model.model import PointSemSeg, Find3D

# Device that the Find3D model, the SigLIP text encoder and the preprocessed
# point tensors are moved to.
# NOTE: a CPU fallback (`"cuda:0" if torch.cuda.is_available() else "cpu"`) was
# sketched in an earlier commented-out line; it is not enabled here.
DEVICE = "cuda:0"

# Segmentation palette shared by get_seg_color (RGB in [0, 1]) and get_legend
# (colour names); entry i of each list describes the same colour.
# Row 0 is reserved for unlabeled points; rows 1.. are used for parts in order.
_SEG_CMAP = [
    [1, 1, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1],
    [0, 1, 1], [0.5, 0.5, 0.5], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5],
    [0.1, 0.2, 0.3], [0.2, 0.5, 0.3], [0.6, 0.3, 0.2], [0.5, 0.3, 0.5],
    [0.6, 0.7, 0.2], [0.5, 0.8, 0.3],
]
_LEGEND_COLORS = [
    "white", "red", "green", "blue", "yellow", "magenta", "cyan", "grey",
    "olive", "purple", "teal", "navy", "darkgreen", "brown", "pinkpurple",
    "yellowgreen", "limegreen",
]

# Preprocessing constants used by preprocess_pcd. The original author notes
# this preprocessing matches what is done before training, so presumably these
# must stay in sync with the training pipeline.
_NUM_POINTS = 5000
_GRID_SIZE = 0.02

# Text encoder; its text features have dim 768 (matching dim_output in
# load_model). An alternative noted by the author: "google/siglip-so400m-patch14-384".
_SIGLIP_NAME = "google/siglip-base-patch16-224"


def get_seg_color(labels):
    """Map integer part labels to per-point RGB colours.

    Label 0 (unlabeled) maps to white. Label k >= 1 maps to palette row k.
    Labels past the end of the palette reuse the part colours cyclically
    (white is never reused), matching the cycling in get_legend.

    Args:
        labels: tensor of non-negative integer labels, e.g. shape (n_pts,).

    Returns:
        Float tensor of shape ``labels.shape + (3,)`` on ``labels.device``.
    """
    if labels.numel() == 0:
        # labels.max() would raise on an empty tensor.
        return torch.zeros((*labels.shape, 3), device=labels.device)

    labels = labels.long()
    num_classes = int(labels.max()) + 1

    # Build the palette on the labels' device so the matmul below also works
    # for CUDA labels.
    palette = torch.tensor(_SEG_CMAP, device=labels.device)
    n_part_colors = palette.shape[0] - 1
    rows = torch.arange(num_classes, device=labels.device)
    rows[1:] = (rows[1:] - 1) % n_part_colors + 1  # cycle part colours, keep row 0 white
    cmap = palette[rows]  # (num_classes, 3)

    # Each one-hot row selects exactly one colour row. F.one_hot also rejects
    # negative labels instead of silently wrapping.
    onehot = F.one_hot(labels, num_classes=num_classes).to(cmap.dtype)
    return torch.matmul(onehot, cmap)


def get_legend(parts):
    """Return a space-separated ``"color:part"`` legend for get_seg_color output.

    The i-th part (0-based) corresponds to label i + 1, so "white" (unlabeled)
    is skipped. Colour names cycle if there are more parts than palette
    entries, consistent with get_seg_color.
    """
    n_part_colors = len(_LEGEND_COLORS) - 1
    return " ".join(
        f"{_LEGEND_COLORS[i % n_part_colors + 1]}:{part}"
        for i, part in enumerate(parts)
    )


def load_model():
    """Load the pretrained Find3D checkpoint in eval mode on DEVICE."""
    model = Find3D.from_pretrained("ziqima/find3d-checkpt0", dim_output=768)
    # NOTE: to use a local checkpoint instead, the author previously used
    # model.load_state_dict(torch.load("find3d_checkpoint.pth")["model_state_dict"]).
    model.eval()
    model = model.to(DEVICE)
    return model


def set_seed(seed):
    """Seed torch (and CUDA when DEVICE is not CPU), numpy and ``random``."""
    torch.manual_seed(seed)
    if DEVICE != "cpu":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def fnv_hash_vec(arr):
    """Hash each row of a 2-D integer array to a single uint64 key.

    Uses the 64-bit FNV offset basis and prime on whole 64-bit column values.
    Each step multiplies and then XORs, which is FNV-1 ordering (historically
    labelled "FNV64-1A"). uint64 arithmetic wraps on overflow, as FNV intends.

    Callers should pass non-negative integer coordinates, because negative
    values wrap when cast to uint64. grid_sample_numpy shifts coordinates to
    start at 0 before hashing.

    Returns:
        1-D uint64 array with one key per row of ``arr``.
    """
    assert arr.ndim == 2
    arr = arr.astype(np.uint64)  # astype copies, so the caller's array is untouched
    fnv_prime = np.uint64(1099511628211)
    hashed_arr = np.full(
        arr.shape[0], np.uint64(14695981039346656037), dtype=np.uint64
    )
    for j in range(arr.shape[1]):
        hashed_arr *= fnv_prime
        hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j])
    return hashed_arr


def grid_sample_numpy(xyz, rgb, normal, grid_size):
    """Voxel-downsample a point cloud, keeping one random point per occupied voxel.

    Original author's note: the output count "should hopefully be 5000 or close".
    With the 5000-point input from preprocess_pcd, that means few points are
    expected to share a voxel. This is not verified here.

    Args:
        xyz, rgb, normal: per-point torch tensors with matching first dimension
            (any device; they are moved to CPU for numpy processing).
        grid_size: voxel edge length, in the same units as ``xyz``.

    Returns:
        ``(xyz, rgb, normal, grid_coord)`` torch tensors on DEVICE, one row per
        occupied voxel. ``grid_coord`` holds integer voxel indices shifted so
        the minimum along each axis is 0.
    """
    xyz = xyz.cpu().numpy()
    rgb = rgb.cpu().numpy()
    normal = normal.cpu().numpy()

    grid_coord = np.floor(xyz / np.array(grid_size)).astype(int)
    grid_coord -= grid_coord.min(0)  # non-negative indices for fnv_hash_vec

    # Sort points by voxel key so each voxel's points form one contiguous run.
    key = fnv_hash_vec(grid_coord)
    idx_sort = np.argsort(key)
    _, count = np.unique(key[idx_sort], return_counts=True)

    # Start offset of each voxel's run, plus a random offset within that run.
    idx_select = (
        np.cumsum(np.insert(count, 0, 0)[0:-1])
        + np.random.randint(0, count.max(), count.size) % count
    )
    idx_unique = idx_sort[idx_select]

    def _select_to_device(values):
        # Fancy indexing returns a fresh array, so from_numpy does not alias inputs.
        return torch.from_numpy(values[idx_unique]).to(DEVICE)

    return (
        _select_to_device(xyz),
        _select_to_device(rgb),
        _select_to_device(normal),
        _select_to_device(grid_coord),
    )


@functools.lru_cache(maxsize=1)
def _load_siglip(device):
    """Load the SigLIP model (moved to ``device``) and its tokenizer once, then cache."""
    siglip = AutoModel.from_pretrained(_SIGLIP_NAME).to(device)
    tokenizer = AutoTokenizer.from_pretrained(_SIGLIP_NAME)
    return siglip, tokenizer


def encode_text(texts):
    """Encode text prompts with SigLIP and return L2-normalized text features.

    The model and tokenizer are loaded on the first call and reused afterwards.
    The cache is keyed on DEVICE, so changing DEVICE triggers a reload.

    Args:
        texts: a string or list of strings accepted by the SigLIP tokenizer.

    Returns:
        Tensor of text features on DEVICE, each row divided by its L2 norm
        (+1e-12 to avoid division by zero).
    """
    siglip, tokenizer = _load_siglip(DEVICE)
    inputs = tokenizer(texts, padding="max_length", return_tensors="pt")
    inputs = {key: value.to(DEVICE) for key, value in inputs.items()}
    with torch.no_grad():
        text_feat = siglip.get_text_features(**inputs)
    return text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-12)


def preprocess_pcd(xyz, rgb, normal):
    """Normalize, resample and voxelize a point cloud into a Find3D input dict.

    The original author notes this is the same preprocessing used before
    training, so the steps and constants should not be changed on their own.

    Args:
        xyz: (N, >=3) float tensor of point coordinates. It is not modified.
        rgb: per-point colours in [0, 1]; presumably (N, 3).
        normal: per-point normals; presumably (N, 3).

    Returns:
        dict with keys:
          "coord": voxel-sampled coordinates, on DEVICE.
          "feat": per sampled point, rgb mapped to [-1, 1] concatenated with
              the normal, on DEVICE.
          "grid_coord": integer voxel indices, on DEVICE.
          "xyz_full": all N points after the same normalization, axis swap
              and shifts, on the input's device.
          "offset": 1-element CPU tensor holding the number of sampled points.
    """
    assert rgb.max() <= 1  # rgb must be in [0, 1]

    # Normalize: center at the mean and scale so the largest |coordinate| is
    # 0.75. This is done out-of-place so the caller's tensor is left untouched.
    center = xyz.mean(0)
    scale = float((xyz - center).abs().max())
    if scale == 0:  # degenerate cloud (all points identical): avoid 0 / 0
        scale = 1.0
    xyz = (xyz - center) * (0.75 / scale)

    # Axis swap: (x, y, z) -> (-x, z, y).
    xyz = torch.stack([-xyz[:, 0], xyz[:, 2], xyz[:, 1]], dim=1)

    # First center shift. Only max z is zeroed, so z is shifted by min_z / 2
    # (unlike the second shift below, which leaves z untouched). Kept as-is
    # to match training. NOTE(review): confirm this asymmetry is intended.
    xyz_min = xyz.min(dim=0)[0]
    xyz_max = xyz.max(dim=0)[0]
    xyz_max[2] = 0
    xyz -= (xyz_min + xyz_max) / 2

    # Resample to exactly _NUM_POINTS points, with replacement. This
    # subsamples when N > _NUM_POINTS and upsamples with repeats when N is smaller.
    if xyz.shape[0] != _NUM_POINTS:
        random_indices = torch.randint(0, xyz.shape[0], (_NUM_POINTS,))
        pts_xyz = xyz[random_indices]
        pts_rgb = rgb[random_indices]
        pts_normal = normal[random_indices]
    else:
        pts_xyz, pts_rgb, pts_normal = xyz, rgb, normal

    coord, color, sampled_normal, grid_coord = grid_sample_numpy(
        pts_xyz, pts_rgb, pts_normal, _GRID_SIZE
    )

    # Second center shift in x/y only (z component forced to 0).
    xyz_min = coord.min(dim=0)[0]
    xyz_max = coord.max(dim=0)[0]
    xyz_min[2] = 0
    xyz_max[2] = 0
    shift = (xyz_min + xyz_max) / 2
    coord -= shift
    # `coord` is on DEVICE, but xyz stays on the input's device; move shift
    # so CPU inputs do not hit a device mismatch.
    xyz -= shift.to(xyz.device)

    color = color / 0.5 - 1  # map colours from [0, 1] to [-1, 1]
    feat = torch.cat([color, sampled_normal], dim=1)

    return {
        "coord": coord,
        "feat": feat,
        "grid_coord": grid_coord,
        "xyz_full": xyz,
        "offset": torch.tensor([coord.shape[0]]),
    }