# Find3D / inference / inference.py
# Author: ziqima — initial commit 4893ce0 (3.92 kB), as shown in the Hugging Face file view.
import torch
import numpy as np
import matplotlib.pyplot as plt
from inference.utils import get_seg_color, load_model, preprocess_pcd, encode_text
# Torch device used for model inputs: the first CUDA GPU when available, else the CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
def pred_3d_upsample(
    pred,  # n_subsampled_pts, feat_dim
    part_text_embeds,  # n_parts, feat_dim
    temperature,
    xyz_sub,  # n_subsampled_pts, 3
    xyz_full,  # n_pts, 3
    N_CHUNKS=1,
):
    """Transfer part predictions from a subsampled point cloud to the full cloud.

    Each subsampled point is scored against every part text embedding. Each
    full-resolution point then copies the scores of its nearest subsampled
    point (Euclidean distance, no interpolation).

    Args:
        pred: per-point features of the subsampled cloud, (n_sub, feat_dim).
        part_text_embeds: one embedding per queried part, (n_parts, feat_dim).
        temperature: multiplier applied to the logits before the softmax
            (despite the name it scales rather than divides; callers in this
            module pass ``exp(model.ln_logit_scale)``).
        xyz_sub: coordinates of the subsampled points, (n_sub, 3), row-aligned
            with ``pred``.
        xyz_full: coordinates of the full cloud, (n_pts, 3). Leading singleton
            dims (e.g. a batch dim of 1) are squeezed away.
        N_CHUNKS: number of chunks the full cloud is split into for the
            nearest-neighbour search, bounding the size of the distance matrix.

    Returns:
        all_logits: (n_pts, n_parts) raw logits of each point's nearest
            subsampled point.
        all_probs: (n_pts, n_parts + 1) softmax probabilities, where column 0
            is the "unlabeled" class.
        pred_full: (n_pts,) CPU tensor of argmax labels; 0 means unlabeled and
            k >= 1 means row k - 1 of ``part_text_embeds``.

    Raises:
        ValueError: if ``N_CHUNKS`` is smaller than 1.
    """
    if N_CHUNKS < 1:
        raise ValueError(f"N_CHUNKS must be >= 1, got {N_CHUNKS}")

    # Squeeze away batch dims, but keep the point axis even for a single point:
    # a bare squeeze() would turn a (1, 3) cloud into a (3,) vector.
    xyz_full = xyz_full.squeeze()
    if xyz_full.dim() == 1:
        xyz_full = xyz_full.unsqueeze(0)

    logits = pred @ part_text_embeds.T  # n_sub, n_parts
    # Prepend a constant zero logit as the "unlabeled" class. With a positive
    # temperature, a point is assigned a part only when that part's logit is
    # strictly positive and beats every other part. new_zeros matches the
    # device and dtype of the logits.
    unlabeled = logits.new_zeros(logits.shape[0], 1)
    logits_prepend0 = torch.cat([unlabeled, logits], dim=1)
    pred_softmax = torch.softmax(logits_prepend0 * temperature, dim=1)

    # Nearest subsampled point for every full-resolution point, done in chunks
    # so the (chunk, n_sub) distance matrix stays bounded. The exact
    # (non-matmul) cdist mode avoids the precision loss of the
    # ||a||^2 - 2ab + ||b||^2 expansion, which could flip near-ties.
    chunk_len = xyz_full.shape[0] // N_CHUNKS + 1
    nn_idx_chunks = []
    for cur_chunk in torch.split(xyz_full, chunk_len):
        cur_chunk = cur_chunk.to(device=xyz_sub.device, dtype=xyz_sub.dtype)
        dists = torch.cdist(
            cur_chunk, xyz_sub, compute_mode="donot_use_mm_for_euclid_dist"
        )  # chunk, n_sub
        nn_idx_chunks.append(dists.argmin(dim=1))
    all_nn_idxs = torch.cat(nn_idx_chunks, dim=0)

    # Each full point takes its nearest subsampled point's logits and probabilities.
    all_probs = pred_softmax[all_nn_idxs]
    all_logits = logits[all_nn_idxs]
    # 0 is unlabeled; 1, ..., n_parts correspond to actual part assignments.
    pred_full = all_probs.argmax(dim=1).cpu()
    return all_logits, all_probs, pred_full
def get_segmentation_rgb(model, data, N_CHUNKS=5):  # evaluate loader can only have batch size=1
    """Run the model on one preprocessed point cloud and colour it by part label.

    Tensor entries of ``data`` whose key does not contain "full" are moved to
    DEVICE in place; the "full" entries stay where they are. The model output
    is matched against ``data['label_embeds']`` and upsampled from
    ``data["coord"]`` to ``data["xyz_full"]`` with ``pred_3d_upsample``.

    Args:
        model: network exposing ``ln_logit_scale`` and callable as ``model(x=data)``.
        data: dict with at least "label_embeds", "coord" and "xyz_full".
        N_CHUNKS: chunk count forwarded to the nearest-neighbour search.

    Returns:
        Whatever ``get_seg_color`` produces for the per-point labels
        (presumably one RGB colour per full-resolution point — confirm in
        inference.utils).
    """
    logit_scale = np.exp(model.ln_logit_scale.item())
    with torch.no_grad():
        for name, value in list(data.items()):
            if not isinstance(value, torch.Tensor) or "full" in name:
                continue
            data[name] = value.to(DEVICE)
        features = model(x=data)
        # Only the hard labels are needed here; logits and probabilities are dropped.
        _, _, labels_full = pred_3d_upsample(
            features,  # n_subsampled_pts, feat_dim
            data['label_embeds'],  # n_parts, feat_dim
            logit_scale,
            data["coord"],
            data["xyz_full"],  # n_pts, 3
            N_CHUNKS=N_CHUNKS,
        )
        return get_seg_color(labels_full.cpu())
def get_heatmap_rgb(model, data, N_CHUNKS=5):  # evaluate loader can only have batch size=1
    """Run the model on one preprocessed point cloud and colour it by query score.

    Tensor entries of ``data`` whose key does not contain "full" are moved to
    DEVICE in place. The upsampled raw logits are squeezed and mapped through
    matplotlib's ``jet`` colormap, keeping only the RGB channels.

    NOTE(review): the raw logits are not rescaled to [0, 1], so matplotlib
    maps out-of-range values to the colormap's end colours. This is
    meaningful for a single query (see ``get_heatmap``); with several queries
    the squeezed logits stay 2-D and the output shape is not a per-point RGB.

    Args:
        model: network exposing ``ln_logit_scale`` and callable as ``model(x=data)``.
        data: dict with at least "label_embeds", "coord" and "xyz_full".
        N_CHUNKS: chunk count forwarded to the nearest-neighbour search.

    Returns:
        A CPU float64 tensor of RGB values built from the colormap output.
    """
    logit_scale = np.exp(model.ln_logit_scale.item())
    with torch.no_grad():
        for name, value in list(data.items()):
            if not isinstance(value, torch.Tensor) or "full" in name:
                continue
            data[name] = value.to(DEVICE)
        features = model(x=data)
        # Only the raw logits are needed for the heatmap.
        logits_full, _, _ = pred_3d_upsample(
            features,  # n_subsampled_pts, feat_dim
            data['label_embeds'],  # n_parts, feat_dim
            logit_scale,
            data["coord"],
            data["xyz_full"],  # n_pts, 3
            N_CHUNKS=N_CHUNKS,
        )
        score_array = logits_full.squeeze().cpu().numpy()
        rgba = plt.cm.jet(score_array)
        return torch.tensor(rgba[:, :3]).squeeze()
def segment_obj(xyz, rgb, normal, queries):
    """Segment a raw point cloud into the given text-queried parts.

    Loads the model on every call, converts ``xyz``, ``rgb`` and ``normal`` to
    float tensors on DEVICE, preprocesses them, attaches the encoded
    ``queries`` as "label_embeds", and returns ``get_segmentation_rgb``'s result.

    Args:
        xyz: point coordinates (anything ``torch.tensor`` accepts).
        rgb: per-point colours, same length as ``xyz`` (assumed — TODO confirm).
        normal: per-point normals, same length as ``xyz`` (assumed — TODO confirm).
        queries: list of part-name strings passed to ``encode_text``.
    """
    model = load_model()
    xyz_t, rgb_t, normal_t = (
        torch.tensor(array).float().to(DEVICE) for array in (xyz, rgb, normal)
    )
    data_dict = preprocess_pcd(xyz_t, rgb_t, normal_t)
    data_dict["label_embeds"] = encode_text(queries)
    return get_segmentation_rgb(model, data_dict)
def get_heatmap(xyz, rgb, normal, query):
    """Colour a raw point cloud by how strongly each point matches one text query.

    Loads the model on every call, converts ``xyz``, ``rgb`` and ``normal`` to
    float tensors on DEVICE, preprocesses them, attaches the encoding of the
    single ``query`` as "label_embeds", and returns ``get_heatmap_rgb``'s result.

    Args:
        xyz: point coordinates (anything ``torch.tensor`` accepts).
        rgb: per-point colours, same length as ``xyz`` (assumed — TODO confirm).
        normal: per-point normals, same length as ``xyz`` (assumed — TODO confirm).
        query: a single part-name string.
    """
    model = load_model()
    xyz_t, rgb_t, normal_t = (
        torch.tensor(array).float().to(DEVICE) for array in (xyz, rgb, normal)
    )
    data_dict = preprocess_pcd(xyz_t, rgb_t, normal_t)
    # encode_text receives a list, so the single query is wrapped.
    data_dict["label_embeds"] = encode_text([query])
    return get_heatmap_rgb(model, data_dict)