# Vincentqyw
# update: rord
# 49a0323
# raw
# history blame
# No virus
# 8.49 kB
import numpy as np
import argparse
import copy
import os, sys
import open3d as o3d
from sys import argv, exit
from PIL import Image
import math
from tqdm import tqdm
import cv2
sys.path.append("../../")
from lib.extractMatchTop import getPerspKeypoints, getPerspKeypointsEnsemble, siftMatching
import pandas as pd
import torch
from lib.model_test import D2Net
#### Cuda ####
# Use the first CUDA device when one is available, otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

#### Argument Parsing ####
# The arguments are parsed at import time; the functions and the main block
# below read the resulting module-level ``args``.
parser = argparse.ArgumentParser(
    description='RoRD ICP evaluation on a DiverseView dataset sequence.',
)

# Dataset location and output.
parser.add_argument(
    '--dataset',
    type=str,
    default='/scratch/udit/realsense/RoRD_data/preprocessed/',
    help='path to the dataset folder',
)
parser.add_argument('--sequence', type=str, default='data1')
parser.add_argument(
    '--output_dir',
    type=str,
    default='out',
    help='output directory for RT estimates',
)

# Matcher selection.
parser.add_argument(
    '--model_rord',
    type=str,
    help='path to the RoRD model for evaluation',
)
parser.add_argument(
    '--model_d2',
    type=str,
    help='path to the vanilla D2-Net model for evaluation',
)
parser.add_argument(
    '--model_ens',
    action='store_true',
    help='ensemble model of RoRD + D2-Net',
)
parser.add_argument(
    '--sift',
    action='store_true',
    help='Sift',
)

# Visualisation / logging.
parser.add_argument(
    '--viz3d',
    action='store_true',
    help='visualize the pointcloud registrations',
)
parser.add_argument(
    '--log_interval',
    type=int,
    default=9,
    help='Matched image logging interval',
)

# Camera and matching mode.
parser.add_argument(
    '--camera_file',
    type=str,
    default='../../configs/camera.txt',
    help='path to the camera intrinsics file. In order: focal_x, focal_y, center_x, center_y, scaling_factor.',
)
parser.add_argument(
    '--persp',
    action='store_true',
    default=False,
    help='Feature matching on perspective images.',
)

# ``fp16`` has no command-line flag; it is only exposed as ``args.fp16``.
parser.set_defaults(fp16=False)

args = parser.parse_args()

# Change default paths accordingly for ensemble.
# These two names exist only when --model_ens was passed.
if args.model_ens:
    model1_ens = '../../models/rord.pth'
    model2_ens = '../../models/d2net.pth'
def draw_registration_result(source, target, transformation):
    """Display source and target geometries after registration.

    Deep copies of ``source`` and ``target`` are made, so the objects passed
    in are not modified; the source copy is moved by ``transformation``.
    Two coordinate frames are shown as well: one left at the origin and one
    moved by ``transformation``.

    NOTE: the four geometries are appended to the module-level list
    ``trgSph`` (assigned by the script's main loop) and that whole list is
    drawn. The function therefore requires ``trgSph`` to exist and mutates it.

    Args:
        source: geometry to be moved by ``transformation`` before display.
        target: geometry displayed as-is.
        transformation: transform applied to the source copy and to the
            second coordinate frame.
    """
    moved_source = copy.deepcopy(source)
    target_copy = copy.deepcopy(target)
    moved_source.transform(transformation)
    trgSph.extend([moved_source, target_copy])

    origin_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
        size=0.5, origin=[0, 0, 0]
    )
    moved_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
        size=0.5, origin=[0, 0, 0]
    )
    moved_frame.transform(transformation)
    trgSph.extend([origin_frame, moved_frame])

    o3d.visualization.draw_geometries(trgSph)
def readDepth(depthFile):
    """Load the depth image at ``depthFile`` and return it as a NumPy array.

    The image is opened in a ``with`` block so the underlying file handle is
    released as soon as the pixel data has been converted, instead of staying
    open until garbage collection.

    Args:
        depthFile: path of the depth image.

    Returns:
        ``np.asarray`` of the image.

    Raises:
        ValueError: if the image's PIL mode is not ``"I"``. ``ValueError`` is
            a subclass of ``Exception``, so handlers written for the former
            generic exception still catch it.
    """
    with Image.open(depthFile) as depth:
        if depth.mode != "I":
            raise ValueError("Depth image is not in intensity format")
        # Convert while the file is still open; the array holds its own data.
        return np.asarray(depth)
def readCamera(camera):
    """Read five camera parameters from a whitespace-separated text file.

    The first five tokens of the file are converted to ``float`` and returned
    in file order; any further tokens are ignored.

    Args:
        camera: path of the camera file.

    Returns:
        Tuple ``(focalX, focalY, centerX, centerY, scalingFactor)``.

    Raises:
        IndexError: if the file holds fewer than five tokens.
        ValueError: if one of the first five tokens is not a number.
    """
    with open(camera, "rt") as handle:
        tokens = handle.read().split()

    # Index each token explicitly so a short file raises IndexError.
    return tuple(float(tokens[position]) for position in range(5))
def getPointCloud(rgbFile, depthFile, pts, thresh=15.0):
    """Back-project a depth image into a coloured point cloud and locate keypoints in it.

    Every pixel ``(u, v)`` whose scaled depth is non-zero and not above
    ``thresh`` becomes one 3D point:

        Z = depth[v, u] / scalingFactor
        X = (u - centerX) * Z / focalX
        Y = (v - centerY) * Z / focalY

    ``focalX``, ``focalY``, ``centerX``, ``centerY`` and ``scalingFactor`` are
    read from module-level globals (the main script assigns them from
    ``readCamera``). Pixels are visited row by row, so point indices follow
    that order.

    Args:
        rgbFile: path of the colour image; the colour of each kept pixel is
            read with ``getpixel``.
        depthFile: path of the depth image, loaded with ``readDepth``.
        pts: list of hashable ``(u, v)`` pixel coordinates (e.g. the tuples
            produced by ``convertPts``).
        thresh: largest ``Z`` that is kept. Defaults to 15.0, the value that
            used to be hard-coded.

    Returns:
        ``(pcd, corIdx, corPts)`` where ``pcd`` is the Open3D point cloud
        (colours divided by 255), ``corIdx[i]`` is the index in the cloud of
        the point at ``pts[i]`` or ``-1`` when that pixel was skipped, and
        ``corPts[i]`` is its ``(X, Y, Z)`` or ``None``. When a coordinate
        occurs several times in ``pts`` only its first position is filled.
    """
    depth = readDepth(depthFile)

    # Build the pixel -> keypoint-position table once. The previous code ran
    # ``(u, v) in pts`` and ``pts.index`` for every pixel, which scanned the
    # whole list each time. ``setdefault`` keeps the first position of a
    # duplicated coordinate, exactly what ``list.index`` returned.
    firstPosition = {}
    for position, pixel in enumerate(pts):
        firstPosition.setdefault(pixel, position)

    points = []
    colors = []
    corIdx = [-1] * len(pts)
    corPts = [None] * len(pts)

    # The context manager closes the colour image file once all pixels are read.
    with Image.open(rgbFile) as rgb:
        for v in range(depth.shape[0]):
            for u in range(depth.shape[1]):
                Z = depth[v, u] / scalingFactor
                if Z == 0 or Z > thresh:
                    continue

                X = (u - centerX) * Z / focalX
                Y = (v - centerY) * Z / focalY

                position = firstPosition.get((u, v))
                if position is not None:
                    # len(points) is the index this point is about to receive.
                    corIdx[position] = len(points)
                    corPts[position] = (X, Y, Z)

                points.append((X, Y, Z))
                colors.append(rgb.getpixel((u, v)))

    points = np.asarray(points)
    colors = np.asarray(colors)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors / 255)

    return pcd, corIdx, corPts
def convertPts(A):
    """Turn two parallel coordinate sequences into a list of integer pairs.

    ``A[0]`` holds the first coordinates and ``A[1]`` the second ones. Each
    value goes through ``int(float(value))``, i.e. it is truncated toward
    zero, and numeric strings are accepted.

    Args:
        A: indexable whose first two items are equally long sequences.

    Returns:
        List of ``(x, y)`` tuples, one per entry of ``A[0]``.

    Raises:
        IndexError: if ``A[1]`` is shorter than ``A[0]``.
    """
    xs = [int(float(value)) for value in A[0]]
    ys = [int(float(value)) for value in A[1]]

    # Index ys by position (rather than zip) so a short A[1] still raises.
    return [(x, ys[position]) for position, x in enumerate(xs)]
def getSphere(pts):
    """Create one small sphere mesh per 3D point.

    Entries of ``pts`` that are ``None`` are skipped. Every other entry gets
    a sphere of radius 0.03, painted with the colour ``[0.9, 0.2, 0]`` and
    translated to the entry's first three components.

    Args:
        pts: iterable of ``(x, y, z)`` triples or ``None``.

    Returns:
        List of Open3D triangle meshes, in the order of the non-``None`` entries.
    """
    markers = []
    for center in pts:
        if center is None:
            continue

        marker = o3d.geometry.TriangleMesh.create_sphere(radius=0.03)
        marker.paint_uniform_color([0.9, 0.2, 0])

        # Translation-only transform: identity with the last column set.
        shift = np.identity(4)
        shift[:3, 3] = (center[0], center[1], center[2])
        marker.transform(shift)

        markers.append(marker)
    return markers
def get3dCor(src, trg):
    """Pair up source and target indices, dropping unmatched entries.

    ``src`` and ``trg`` are walked in lockstep (stopping at the shorter one);
    a pair is kept only when neither index equals ``-1``.

    Args:
        src: sequence of source indices, ``-1`` meaning "no point".
        trg: sequence of target indices, ``-1`` meaning "no point".

    Returns:
        ``np.asarray`` of the kept ``(src, trg)`` pairs. With no kept pair
        this is an empty one-dimensional array.
    """
    pairs = [
        (source_id, target_id)
        for source_id, target_id in zip(src, trg)
        if source_id != -1 and target_id != -1
    ]
    return np.asarray(pairs)
if __name__ == "__main__":
    # Evaluation driver. For every row of the 'query' column and every other
    # column of the RGB/depth CSVs: match keypoints between the query image
    # and the database image, lift the matches to 3D using the depth images,
    # estimate a point-to-point transformation from those correspondences,
    # and save the transformations of each query to '<output_dir>/<queryId>.npy'.

    # Fail early when no matcher was selected; otherwise the loop below would
    # stop with a NameError on the never-assigned srcPts.
    if not (args.model_rord or args.model_d2 or args.model_ens or args.sift):
        parser.error('select a matcher: --model_rord, --model_d2, --model_ens or --sift')

    camera_file = args.camera_file
    # NOTE: plain concatenation, so --dataset is expected to end with '/'.
    seq_dir = args.dataset + args.sequence
    rgb_csv = seq_dir + '/rtImagesRgb.csv'
    depth_csv = seq_dir + '/rtImagesDepth.csv'
    # Depth paths from the CSV are appended to --dataset with os.path.dirname
    # applied twice.
    depth_root = os.path.dirname(os.path.dirname(args.dataset))

    # makedirs creates args.output_dir as a parent of 'vis' as well.
    vis_dir = os.path.join(args.output_dir, 'vis')
    os.makedirs(vis_dir, exist_ok=True)
    dir_name = args.output_dir

    # These five names are read as module-level globals by getPointCloud.
    focalX, focalY, centerX, centerY, scalingFactor = readCamera(camera_file)

    df_rgb = pd.read_csv(rgb_csv)
    df_dep = pd.read_csv(depth_csv)

    # model1 holds D2-Net weights and model2 RoRD weights, except in ensemble
    # mode where the two default ensemble checkpoints are used. The ensemble
    # branch of the dispatch below is reached only when neither --model_rord
    # nor --model_d2 is given, so the models are loaded once here instead of
    # on every image pair.
    use_ensemble = bool(args.model_ens) and not (args.model_rord or args.model_d2)
    if use_ensemble:
        model1 = D2Net(model_file=model1_ens).to(device)
        model2 = D2Net(model_file=model2_ens).to(device)
    else:
        model1 = D2Net(model_file=args.model_d2).to(device)
        model2 = D2Net(model_file=args.model_rord).to(device)

    queryId = 0
    for im_q, dep_q in tqdm(zip(df_rgb['query'], df_dep['query']), total=df_rgb.shape[0]):
        filter_list = []
        dbId = 0

        # Query-side paths do not depend on the database column.
        rgb_name_src = os.path.basename(im_q)
        H_name_src = os.path.splitext(rgb_name_src)[0] + '.npy'
        srcH = seq_dir + '/rgb/' + H_name_src
        srcImg = srcH.replace('.npy', '.jpg')
        depth_name_src = depth_root + '/' + dep_q

        # DataFrame.items() replaces iteritems(), which pandas 2.0 removed.
        for im_d, dep_d in tqdm(zip(df_rgb.items(), df_dep.items()), total=df_rgb.shape[1]):
            if im_d[0] == 'query':
                continue

            # NOTE(review): ``[1][1]`` reads the entry labelled 1 of each
            # column regardless of the current query row - confirm that this
            # matches the CSV layout.
            rgb_name_trg = os.path.basename(im_d[1][1])
            H_name_trg = os.path.splitext(rgb_name_trg)[0] + '.npy'
            trgH = seq_dir + '/rgb/' + H_name_trg
            trgImg = trgH.replace('.npy', '.jpg')

            # With --persp the matchers get no H files.
            if args.persp:
                HFile1, HFile2 = None, None
            else:
                HFile1, HFile2 = srcH, trgH

            if args.model_rord:
                srcPts, trgPts, matchImg, _ = getPerspKeypoints(
                    srcImg, trgImg, HFile1=HFile1, HFile2=HFile2, model=model2, device=device
                )
            elif args.model_d2:
                # The D2-Net weights live in model1; the --persp path used to
                # pass model2 (the RoRD weights) here by mistake.
                srcPts, trgPts, matchImg, _ = getPerspKeypoints(
                    srcImg, trgImg, HFile1=HFile1, HFile2=HFile2, model=model1, device=device
                )
            elif args.model_ens:
                # The ensemble matcher is always called with the H files.
                srcPts, trgPts, matchImg = getPerspKeypointsEnsemble(
                    model1, model2, srcImg, trgImg, srcH, trgH, device
                )
            else:
                # Only --sift is left (checked at the top of this block).
                srcPts, trgPts, matchImg, _ = siftMatching(
                    srcImg, trgImg, HFile1=HFile1, HFile2=HFile2, device=device
                )

            # A plain list is treated as "no usable matches" (presumably what
            # the matchers return on failure - confirm in lib.extractMatchTop);
            # the identity is recorded for this pair and dbId is not advanced.
            if isinstance(srcPts, list):
                print(np.identity(4))
                filter_list.append(np.identity(4))
                continue

            srcPts = convertPts(srcPts)
            trgPts = convertPts(trgPts)

            depth_name_trg = depth_root + '/' + dep_d[1][1]

            srcCld, srcIdx, srcCor = getPointCloud(srcImg, depth_name_src, srcPts)
            trgCld, trgIdx, trgCor = getPointCloud(trgImg, depth_name_trg, trgPts)

            # trgSph must stay a module-level name: draw_registration_result
            # appends to it.
            srcSph = getSphere(srcCor)
            trgSph = getSphere(trgCor)
            axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
            srcSph.append(srcCld); srcSph.append(axis)
            trgSph.append(trgCld); trgSph.append(axis)

            corr = get3dCor(srcIdx, trgIdx)

            p2p = o3d.pipelines.registration.TransformationEstimationPointToPoint()
            trans_init = p2p.compute_transformation(srcCld, trgCld, o3d.utility.Vector2iVector(corr))
            filter_list.append(trans_init)

            if args.viz3d:
                o3d.visualization.draw_geometries(srcSph)
                o3d.visualization.draw_geometries(trgSph)
                draw_registration_result(srcCld, trgCld, trans_init)

            # Save a match visualisation for every log_interval-th processed pair.
            if dbId % args.log_interval == 0:
                cv2.imwrite(vis_dir + "/matchImg.%02d.%02d.jpg" % (queryId, dbId // args.log_interval), matchImg)
            dbId += 1

        # Stack the per-pair matrices and move the pair axis last.
        RT = np.stack(filter_list).transpose(1, 2, 0)

        np.save(os.path.join(dir_name, str(queryId) + '.npy'), RT)
        queryId += 1
        print('-----check-------', RT.shape)