import numpy as np
import copy
import argparse
import os, sys
import open3d as o3d
from PIL import Image
import cv2
import torch

sys.path.append("../")

from lib.extractMatchTop import getPerspKeypoints, getPerspKeypointsEnsemble, siftMatching
from lib.model_test import D2Net
#### CUDA ####
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
#### Argument Parsing ####
parser = argparse.ArgumentParser(description='RoRD ICP evaluation')
parser.add_argument(
    '--rgb1', type=str, default='rgb/rgb2_1.jpg',
    help='path to the first RGB image'
)
parser.add_argument(
    '--rgb2', type=str, default='rgb/rgb2_2.jpg',
    help='path to the second RGB image'
)
parser.add_argument(
    '--depth1', type=str, default='depth/depth2_1.png',
    help='path to the first depth image'
)
parser.add_argument(
    '--depth2', type=str, default='depth/depth2_2.png',
    help='path to the second depth image'
)
parser.add_argument(
    '--model_rord', type=str, default='../models/rord.pth',
    help='path to the RoRD model for evaluation'
)
parser.add_argument(
    '--model_d2', type=str,
    help='path to the vanilla D2-Net model for evaluation'
)
parser.add_argument(
    '--model_ens', action='store_true',
    help='use an ensemble of the RoRD and D2-Net models'
)
parser.add_argument(
    '--sift', action='store_true',
    help='use SIFT matching instead of learned descriptors'
)
parser.add_argument(
    '--camera_file', type=str, default='../configs/camera.txt',
    help='path to the camera intrinsics file. In order: focal_x, focal_y, center_x, center_y, scaling_factor.'
)
parser.add_argument(
    '--viz3d', action='store_true',
    help='visualize the point cloud registration'
)
args = parser.parse_args()
if args.model_ens:  # Change default paths accordingly for the ensemble
    model1_ens = '../../models/rord.pth'
    model2_ens = '../../models/d2net.pth'
def draw_registration_result(source, target, transformation):
    # Appends to the global trgSph list built in __main__, so the
    # correspondence spheres are drawn together with the registered clouds.
    source_temp = copy.deepcopy(source)
    target_temp = copy.deepcopy(target)
    source_temp.transform(transformation)
    target_temp += source_temp
    # o3d.io.write_point_cloud("registered.pcd", target_temp)
    trgSph.append(source_temp); trgSph.append(target_temp)
    axis1 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
    axis2 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
    axis2.transform(transformation)
    trgSph.append(axis1); trgSph.append(axis2)
    print("Showing registered PointCloud.")
    o3d.visualization.draw_geometries(trgSph)
def readDepth(depthFile):
    depth = Image.open(depthFile)
    if depth.mode != "I":
        raise Exception("Depth image is not in intensity format")
    return np.asarray(depth)
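
# Note: PIL mode "I" is 32-bit integer pixels; 16-bit depth PNGs are typically
# promoted to this mode on load. Raw depth values are divided by
# scaling_factor below, so e.g. millimeter depth with scaling_factor = 1000
# yields metric depth (an assumption about the dataset, not enforced here).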
def readCamera(camera):
    with open(camera, "rt") as file:
        contents = file.read().split()
    focalX = float(contents[0])
    focalY = float(contents[1])
    centerX = float(contents[2])
    centerY = float(contents[3])
    scalingFactor = float(contents[4])
    return focalX, focalY, centerX, centerY, scalingFactor
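
# getPointCloud back-projects each pixel (u, v) with depth Z through the
# standard pinhole model:
#     X = (u - centerX) * Z / focalX
#     Y = (v - centerY) * Z / focalY
# It also records, for every matched keypoint in pts, its index into the
# point list (corIdx) and its 3D location (corPts), so the 2D matches can
# later be turned into 3D correspondences.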
def getPointCloud(rgbFile, depthFile, pts):
    # Relies on the camera intrinsics read into module-level variables
    # (focalX, focalY, centerX, centerY, scalingFactor) in __main__.
    thresh = 15.0  # discard points farther than 15 m

    depth = readDepth(depthFile)
    rgb = Image.open(rgbFile)

    points = []
    colors = []

    corIdx = [-1] * len(pts)
    corPts = [None] * len(pts)
    ptIdx = 0

    for v in range(depth.shape[0]):
        for u in range(depth.shape[1]):
            Z = depth[v, u] / scalingFactor
            if Z == 0: continue
            if Z > thresh: continue

            X = (u - centerX) * Z / focalX
            Y = (v - centerY) * Z / focalY

            points.append((X, Y, Z))
            colors.append(rgb.getpixel((u, v)))

            if (u, v) in pts:
                index = pts.index((u, v))
                corIdx[index] = ptIdx
                corPts[index] = (X, Y, Z)

            ptIdx += 1

    points = np.asarray(points)
    colors = np.asarray(colors)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors / 255)
    return pcd, corIdx, corPts
def convertPts(A):
    # A is a pair of coordinate arrays (X, Y); round each match to integer
    # pixel coordinates so they can be looked up in the depth image.
    X = A[0]; Y = A[1]
    pts = [(int(float(x)), int(float(y))) for x, y in zip(X, Y)]
    return pts
def getSphere(pts):
    # Place a small red sphere at each 3D correspondence for visualization.
    sphs = []
    for ele in pts:
        if ele is not None:
            sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.03)
            sphere.paint_uniform_color([0.9, 0.2, 0])
            trans = np.identity(4)
            trans[0, 3] = ele[0]
            trans[1, 3] = ele[1]
            trans[2, 3] = ele[2]
            sphere.transform(trans)
            sphs.append(sphere)
    return sphs
def get3dCor(src, trg):
    # Keep only matches where both keypoints landed on valid depth.
    corr = []
    for sId, tId in zip(src, trg):
        if sId != -1 and tId != -1:
            corr.append((sId, tId))
    corr = np.asarray(corr)
    return corr
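
# The resulting (sId, tId) pairs index into the source and target point
# clouds; this is exactly the layout o3d.utility.Vector2iVector expects for
# the correspondence set in the transformation estimation below.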
if __name__ == "__main__":
    focalX, focalY, centerX, centerY, scalingFactor = readCamera(args.camera_file)

    # Each RGB image is expected to have a companion .npy file (same base
    # name) providing the homography used for the orthographic view.
    rgb_name_src = os.path.basename(args.rgb1)
    H_name_src = os.path.splitext(rgb_name_src)[0] + '.npy'
    srcH = os.path.join(os.path.dirname(args.rgb1), H_name_src)

    rgb_name_trg = os.path.basename(args.rgb2)
    H_name_trg = os.path.splitext(rgb_name_trg)[0] + '.npy'
    trgH = os.path.join(os.path.dirname(args.rgb2), H_name_trg)
    # Select a matching front end. Explicit flags are checked first because
    # --model_rord has a default value and would otherwise shadow them; the
    # models are loaded lazily so a missing --model_d2 cannot break the
    # default RoRD run.
    if args.sift:
        srcPts, trgPts, matchImg, matchImgOrtho = siftMatching(args.rgb1, args.rgb2, srcH, trgH, device)
    elif args.model_ens:
        model1 = D2Net(model_file=model1_ens).to(device)
        model2 = D2Net(model_file=model2_ens).to(device)
        srcPts, trgPts, matchImg, matchImgOrtho = getPerspKeypointsEnsemble(model1, model2, args.rgb1, args.rgb2, srcH, trgH, device)
    elif args.model_d2:
        model = D2Net(model_file=args.model_d2).to(device)
        srcPts, trgPts, matchImg, matchImgOrtho = getPerspKeypoints(args.rgb1, args.rgb2, srcH, trgH, model, device)
    else:
        model = D2Net(model_file=args.model_rord).to(device)
        srcPts, trgPts, matchImg, matchImgOrtho = getPerspKeypoints(args.rgb1, args.rgb2, srcH, trgH, model, device)
    #### Visualization ####
    print("\nShowing matches in perspective and orthographic view. Press any key to continue.\n")
    cv2.imshow('Orthographic view', matchImgOrtho)
    cv2.imshow('Perspective view', matchImg)
    cv2.waitKey()

    srcPts = convertPts(srcPts)
    trgPts = convertPts(trgPts)

    srcCld, srcIdx, srcCor = getPointCloud(args.rgb1, args.depth1, srcPts)
    trgCld, trgIdx, trgCor = getPointCloud(args.rgb2, args.depth2, trgPts)
    srcSph = getSphere(srcCor)
    trgSph = getSphere(trgCor)
    axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
    srcSph.append(srcCld); srcSph.append(axis)
    trgSph.append(trgCld); trgSph.append(axis)

    corr = get3dCor(srcIdx, trgIdx)
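
    # TransformationEstimationPointToPoint computes the rigid transform
    # (rotation + translation) minimizing the squared distance between
    # corresponding points, solved in closed form (an SVD/Umeyama-style
    # alignment). Despite the script name, no ICP iterations run here: the
    # 3D matches are used directly as correspondences.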
    p2p = o3d.registration.TransformationEstimationPointToPoint()
    # Note: recent Open3D releases moved this module to o3d.pipelines.registration.
    trans_init = p2p.compute_transformation(srcCld, trgCld, o3d.utility.Vector2iVector(corr))

    print("Transformation matrix:\n", trans_init)
    if args.viz3d:
        # o3d.visualization.draw_geometries(srcSph)
        # o3d.visualization.draw_geometries(trgSph)
        draw_registration_result(srcCld, trgCld, trans_init)
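
# Example invocation (assuming this script is saved as register.py in the
# repo's demo directory; the file name and paths are illustrative):
#   python register.py --rgb1 rgb/rgb2_1.jpg --rgb2 rgb/rgb2_2.jpg \
#       --depth1 depth/depth2_1.png --depth2 depth/depth2_2.png \
#       --model_rord ../models/rord.pth --viz3d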