|
import gradio as gr |
|
|
|
|
|
import numpy as np |
|
import cv2 |
|
from tqdm import tqdm |
|
|
|
import torch |
|
from pytorch3d.io.obj_io import load_obj |
|
import tempfile |
|
import main_mcc |
|
import mcc_model |
|
import util.misc as misc |
|
from engine_mcc import prepare_data |
|
from plyfile import PlyData, PlyElement |
|
|
|
def _decode_colors(color_pred, temperature, regress_color):
    """Convert the color channels of one prediction chunk into (-1, 3) RGB values.

    Args:
        color_pred: ``pred[..., 1:]`` from the MCC model for one chunk.
        temperature: softmax temperature for the binned (classification) path.
            Values <= 0 select the most likely bin, which is the T -> 0 limit
            of the tempered softmax, instead of dividing by zero.
        regress_color: when true the channels are already RGB values and are
            only reshaped.

    Returns:
        Tensor of shape (-1, 3). For the binned path the values lie in [0, 1].
    """
    if regress_color:
        return color_pred.reshape((-1, 3))

    logits = color_pred.reshape((-1, 3, 256))
    bins = torch.linspace(0, 1, 256, device=color_pred.device)
    if temperature > 0:
        # Expected intensity under the tempered distribution over 256 bins.
        probs = torch.softmax(logits / temperature, dim=2)
        return (probs * bins).sum(dim=2)
    # T -> 0: the tempered softmax collapses onto its highest-scoring bin.
    return bins[logits.argmax(dim=2)]


def run_inference(model, samples, device, temperature, args):
    """Query the MCC model for occupancy and color on the points built from *samples*.

    Args:
        model: model returned by ``mcc_model.get_mcc_model``.
        samples: ``[[seen_xyz, seen_rgb], [gt_xyz, gt_rgb]]`` as consumed by
            ``engine_mcc.prepare_data``.
        device: device string forwarded to ``prepare_data`` ("cuda" or "cpu").
        temperature: color softmax temperature; <= 0 means "most likely bin".
        args: parsed MCC arguments; ``args.regress_color`` selects direct RGB
            regression instead of 256-bin color classification.

    Returns:
        ``(colors, occupancy, query_xyz)`` as NumPy arrays:
        - ``colors``: per-query RGB reshaped to ``(-1, 3)``.
        - ``occupancy``: sigmoid occupancy probabilities, with the chunks
          concatenated along dim 1 (presumably ``(batch, num_queries)``;
          confirm against the model).
        - ``query_xyz``: the query coordinates produced by ``prepare_data``.
    """
    model.eval()

    seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(
        samples, device, is_train=False, args=args, is_viz=True
    )

    # Queries are decoded in chunks of this size to bound memory per forward pass.
    max_n_unseen_fwd = 2000
    # Reset any encoder cache left by a previous call. cache_enc=True below
    # presumably lets the model reuse encoder features across chunks.
    model.cached_enc_feat = None

    pred_occupy = []
    pred_colors = []
    with torch.no_grad():
        for p_start in range(0, unseen_xyz.shape[1], max_n_unseen_fwd):
            p_end = p_start + max_n_unseen_fwd
            cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
            # Color/occupancy targets are unknown at inference time, so zeros are
            # fed. Fresh tensors leave unseen_rgb/labels unmodified.
            cur_unseen_rgb = torch.zeros_like(unseen_rgb[:, p_start:p_end])
            cur_labels = torch.zeros_like(labels[:, p_start:p_end])

            _, pred = model(
                seen_images=seen_images,
                seen_xyz=seen_xyz,
                unseen_xyz=cur_unseen_xyz,
                unseen_rgb=cur_unseen_rgb,
                unseen_occupy=cur_labels,
                cache_enc=True,
                valid_seen_xyz=valid_seen_xyz,
            )

            # Channel 0 is the occupancy logit; the remaining channels encode color.
            pred_occupy.append(pred[..., 0])
            pred_colors.append(_decode_colors(pred[..., 1:], temperature, args.regress_color))

        occupancy = torch.sigmoid(torch.cat(pred_occupy, dim=1))
        colors = torch.cat(pred_colors, dim=0)

    return colors.cpu().numpy(), occupancy.cpu().numpy(), unseen_xyz.cpu().numpy()
|
|
|
def pad_image(im, value):
    """Pad an (H, W, C) tensor on the bottom or right so it becomes square.

    Args:
        im: tensor indexed as (height, width, channels).
        value: fill value for the padded region (e.g. ``float('inf')`` for
            point maps, ``0`` for colors).

    Returns:
        A (max(H, W), max(H, W), C) tensor. The original content stays in the
        top-left corner. An already-square input is returned unchanged in
        value. The padding is created on ``im``'s device, so non-CPU inputs
        also work. Its dtype is the default float dtype, as before.
    """
    height, width, channels = im.shape[0], im.shape[1], im.shape[2]
    if height > width:
        # Taller than wide: append columns on the right.
        pad = torch.full((height, height - width, channels), float(value), device=im.device)
        return torch.cat([im, pad], dim=1)
    # Wider than tall (or square, in which case zero rows are appended).
    pad = torch.full((width - height, width, channels), float(value), device=im.device)
    return torch.cat([im, pad], dim=0)
|
|
|
|
|
def normalize(seen_xyz):
    """Rescale and recenter a point map using statistics of its finite points only.

    First the whole map is divided by the mean of the per-axis sample
    standard deviations of the finite points. Then it is shifted so that the
    (re-selected) finite points have zero mean. Non-finite entries, such as
    the +inf used to mark masked-out pixels, go through the arithmetic but are
    excluded from both statistics.
    """
    def _finite_rows(xyz):
        # A point is kept when the sum of its coordinates is finite.
        return xyz[torch.isfinite(xyz.sum(dim=-1))]

    spread = (_finite_rows(seen_xyz).var(dim=0) ** 0.5).mean()
    rescaled = seen_xyz / spread
    centroid = _finite_rows(rescaled).mean(axis=0)
    return rescaled - centroid
|
|
|
def _file_path(file_input):
    """Return a filesystem path for a Gradio file input (tempfile wrapper or plain path)."""
    return getattr(file_input, "name", file_input)


def _load_mcc_model(args, device):
    """Build the MCC model, move it to *device* and load its weights via ``misc.load_model``."""
    model = mcc_model.get_mcc_model(
        occupancy_weight=1.0,
        rgb_weight=0.01,
        args=args,
    )
    if device == "cuda":
        model = model.cuda()
    # NOTE(review): misc.load_model receives only `args`, so the checkpoint
    # location must come from the parsed CLI arguments. Confirm the app is
    # launched with it set; otherwise the weights may not be restored.
    misc.load_model(args=args, model_without_ddp=model, optimizer=None, loss_scaler=None)
    return model


def _prepare_seen_inputs(image, point_cloud_path, seg_path):
    """Mask, crop, pad and resize the observed view into MCC's "seen" inputs.

    Args:
        image: RGB image array of shape (H, W, 3) or (H, W, 4).
        point_cloud_path: ``.obj`` file with exactly H * W vertices, one per pixel.
        seg_path: mask image; non-zero pixels mark the object.

    Returns:
        ``(seen_xyz, seen_rgb)``:
        - ``seen_xyz``: (1, 112, 112, 3). Pixels outside the mask, and the
          padding, are set to +inf before resizing.
        - ``seen_rgb``: (1, 3, 800, 800) with values in [0, 1].

    Raises:
        ValueError: unreadable mask, vertex-count mismatch, or empty mask.
    """
    if image.ndim == 3:
        # gr.Image already yields RGB, so no channel reversal is needed.
        # Drop a possible alpha channel.
        image = image[..., :3]
    seen_rgb = torch.tensor(image).float() / 255
    H, W = seen_rgb.shape[:2]

    verts = load_obj(point_cloud_path)[0]
    if verts.shape[0] != H * W:
        raise ValueError(
            f"Point cloud has {verts.shape[0]} vertices, expected {H * W} (one per image pixel)."
        )
    seen_xyz = verts.reshape(H, W, 3)

    seg = cv2.imread(seg_path, cv2.IMREAD_UNCHANGED)
    if seg is None:
        raise ValueError(f"Could not read segmentation file: {seg_path}")
    mask = torch.tensor(cv2.resize(seg, (W, H))).bool()
    if mask.dim() == 3:
        # Multi-channel mask image: treat any non-zero channel as foreground.
        # This assumes the background is zero in every channel.
        mask = mask.any(dim=-1)
    if not mask.any():
        raise ValueError("Segmentation mask is empty; nothing to reconstruct.")

    seen_xyz[~mask] = float('inf')
    seen_xyz = normalize(seen_xyz)

    # Bounding box of the mask, grown by a fixed pixel margin. The top/left
    # edges are clamped at 0; slicing clamps bottom/right automatically.
    margin = 40
    bottom, right = mask.nonzero().max(dim=0).values.tolist()
    top, left = mask.nonzero().min(dim=0).values.tolist()
    top, left = max(top - margin, 0), max(left - margin, 0)
    bottom, right = bottom + margin, right + margin

    seen_xyz = seen_xyz[top:bottom + 1, left:right + 1]
    seen_rgb = seen_rgb[top:bottom + 1, left:right + 1]

    seen_xyz = pad_image(seen_xyz, float('inf'))
    seen_rgb = pad_image(seen_rgb, 0)

    seen_rgb = torch.nn.functional.interpolate(
        seen_rgb.permute(2, 0, 1)[None],
        size=[800, 800],
        mode="bilinear",
        align_corners=False,
    )
    seen_xyz = torch.nn.functional.interpolate(
        seen_xyz.permute(2, 0, 1)[None],
        size=[112, 112],
        mode="bilinear",
        align_corners=False,
    ).permute(0, 2, 3, 1)
    return seen_xyz, seen_rgb


def _write_ply(xyz, colors):
    """Write points (N, 3) with float colors in [0, 1] (N, 3) to a temporary text PLY file.

    Colors are scaled to 0-255, rounded and clipped before the uint8 cast,
    so values cannot wrap around. Returns the path of the file, which is not
    deleted automatically.
    """
    colors_u8 = np.clip(np.rint(colors * 255), 0, 255).astype(np.uint8)
    vertex = np.rec.fromarrays(
        [xyz[:, 0], xyz[:, 1], xyz[:, 2], colors_u8[:, 0], colors_u8[:, 1], colors_u8[:, 2]],
        names='x, y, z, red, green, blue',
        formats='f8, f8, f8, u1, u1, u1',
    )
    element = PlyElement.describe(vertex, 'vertex')
    with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as f:
        PlyData([element], text=True).write(f)
    return f.name


def infer(
    image,
    point_cloud,
    seg,
    granularity,
    temperature,
):
    """Gradio handler: reconstruct a colored point cloud and return the path of a PLY file.

    Args:
        image: RGB image array from ``gr.Image``.
        point_cloud: ``.obj`` file (path or Gradio tempfile wrapper) whose
            vertices are the image's per-pixel 3D points.
        seg: segmentation mask file (path or Gradio tempfile wrapper).
        granularity: forwarded as ``args.viz_granularity``; presumably the
            spacing of the query grid (see note below).
        temperature: color softmax temperature forwarded to ``run_inference``.

    Returns:
        Path of a temporary text PLY file. It holds the query points whose
        predicted occupancy exceeds 0.1, colored by the predicted RGB.

    Raises:
        ValueError: if an input is missing or cannot be used.
    """
    if image is None or point_cloud is None or seg is None:
        raise ValueError("An input image, a point cloud file and a segmentation file are all required.")

    occupancy_threshold = 0.1
    device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = main_mcc.get_args_parser()
    parser.set_defaults(eval=True)
    args = parser.parse_args()
    # Forward the UI granularity slider; it was previously ignored.
    # NOTE(review): assumes prepare_data(..., is_viz=True) reads
    # args.viz_granularity for the query grid. Confirm against engine_mcc.
    args.viz_granularity = granularity

    model = _load_mcc_model(args, device)
    seen_xyz, seen_rgb = _prepare_seen_inputs(image, _file_path(point_cloud), _file_path(seg))

    samples = [
        [seen_xyz, seen_rgb],
        # Zero placeholder for the ground-truth slot expected by prepare_data.
        [torch.zeros((20000, 3)), torch.zeros((20000, 3))],
    ]

    pred_colors, pred_occupy, unseen_xyz = run_inference(model, samples, device, temperature, args)

    keep = pred_occupy > occupancy_threshold
    points = unseen_xyz[keep]
    # pred_colors is flattened to (N, 3); add the batch axis back to match `keep`.
    colors = pred_colors[None, ...][keep]

    return _write_ply(points, colors)
|
|
|
|
|
# Web UI: the image, the per-pixel point cloud (.obj) and the segmentation mask
# go into `infer`, which returns a PLY file.
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(label="Input Image"),
        gr.File(label="Pointcloud File"),
        gr.File(label="Segmentation File"),
        gr.Slider(minimum=0.05, maximum=0.5, step=0.05, value=0.2, label="Granularity"),
        gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.1, label="Temperature"),
    ],
    # gr.outputs.* is the deprecated legacy namespace. gr.File serves as an
    # output directly, matching the input components above. `infer` writes a
    # .ply file, hence the label.
    outputs=[gr.File(label="Point Cloud (PLY)")],
    examples=[["demo/quest2.jpg", "demo/quest2.obj", "demo/quest2_seg.png", 0.2, 0.1]],
    cache_examples=True,
)
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|