# sat3density / test.py
# Hosting-page metadata: uploader "venite", commit "initial" (f670afc), file size 11.3 kB.
import importlib
import os
import os.path as osp
import sys
import warnings
import torch
import options
from utils import log
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from matplotlib.widgets import Cursor
from PIL import Image
from scipy.interpolate import interp1d, splev, splprep
from torch.utils.data import default_convert,default_collate
import torchvision
from model.geometry_transform import render_sat,render
import cv2
import imageio
def get_checkpoint(opt):
    """Expand a known wandb run id in ``opt.test_ckpt_path`` to its checkpoint path.

    If ``opt.test_ckpt_path`` equals one of the run ids listed below, it is
    overwritten in place with the corresponding checkpoint file path. Any other
    value is left untouched. Nothing is returned.
    """
    # Run id -> checkpoint file saved by that run.
    run_checkpoints = {
        '2u87bj8w': 'wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth',
        '2cqv8uh4': 'wandb/run-20230303_142752-2cqv8uh4/files/checkpoint/model.pth',
    }
    resolved = run_checkpoints.get(opt.test_ckpt_path)
    if resolved is not None:
        opt.test_ckpt_path = resolved
def img_read(img, size=None, datatype='RGB'):
    """Load an image file and convert it with torchvision's ``ToTensor``.

    Args:
        img: Path (or file object) accepted by ``PIL.Image.open``.
        size: Optional target size. An ``int`` is expanded to a square
            ``(size, size)``; any other truthy value is passed to
            ``Image.resize`` as-is. A falsy value skips resizing.
        datatype: ``'RGB'`` loads a colour image and resizes with bicubic
            resampling; any other value loads a single-channel ``"L"`` image
            and resizes with nearest-neighbour resampling.

    Returns:
        The result of ``transforms.ToTensor()`` applied to the loaded image.
    """
    is_rgb = datatype == 'RGB'
    picture = Image.open(img).convert('RGB' if is_rgb else "L")
    if size:
        target = (size, size) if type(size) is int else size
        resample = Image.BICUBIC if is_rgb else Image.NEAREST
        picture = picture.resize(size=target, resample=resample)
    return transforms.ToTensor()(picture)
def select_points(sat_image):
    """Let the user draw a path on the satellite image and return it.

    Opens an interactive matplotlib window showing ``sat_image``. While the
    left mouse button is held and the pointer moves inside the image, the
    integer pixel positions under the pointer are recorded. After the window is
    closed, duplicate positions are removed (first occurrence wins) and a
    smoothing spline is fitted through the remaining points.

    Args:
        sat_image: Image data accepted by ``Axes.imshow``.

    Returns:
        A tuple ``(pixels, angles, smooth_path)``:
        ``pixels`` is an array of the unique ``(x, y)`` positions in drawing
        order; ``smooth_path`` holds 80 points sampled along the fitted spline;
        ``angles`` holds the ``arctan2`` heading between consecutive points of
        ``smooth_path``.

    Raises:
        ValueError: If fewer than four distinct points were drawn, which is
            too few for the cubic spline fitted by ``splprep``.
    """
    fig = plt.figure()
    fig.set_size_inches(1, 1, forward=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    ax.imshow(sat_image)

    coords = []

    def ondrag(event):
        # Only record while the left button is held down.
        if event.button != 1:
            return
        # xdata/ydata are None when the pointer is outside the axes.
        if event.xdata is None or event.ydata is None:
            return
        x, y = int(event.xdata), int(event.ydata)
        coords.append((x, y))
        ax.plot([x], [y], 'o', color='red')
        fig.canvas.draw_idle()

    fig.add_axes(ax)
    # Keep a reference so the cursor widget stays alive while the window is open.
    cursor = Cursor(ax, useblit=True, color='red', linewidth=1)
    fig.canvas.mpl_connect('motion_notify_event', ondrag)
    plt.show()
    plt.close()

    # Order-preserving de-duplication of the recorded positions.
    pixels = list(dict.fromkeys(coords))
    print(pixels)
    if len(pixels) < 4:
        raise ValueError(
            'Need at least 4 distinct points to fit the path spline, got {}'.format(len(pixels))
        )
    pixels = np.array(pixels)

    tck, u = splprep(pixels.T, s=25, per=0)
    u_new = np.linspace(u.min(), u.max(), 80)
    x_new, y_new = splev(u_new, tck)
    smooth_path = np.array([x_new, y_new]).T
    angles = np.arctan2(y_new[1:] - y_new[:-1], x_new[1:] - x_new[:-1])
    return pixels, angles, smooth_path
def volume2pyvista(volume_data):
    """Wrap a 3-D array in a PyVista uniform grid with unit spacing.

    Args:
        volume_data: Array whose ``shape`` gives the grid dimensions and whose
            values become the ``'values'`` point data (flattened in Fortran
            order).

    Returns:
        The PyVista grid object, with origin ``(0, 0, 0)`` and spacing
        ``(1, 1, 1)``.
    """
    import pyvista as pv

    # Newer PyVista releases renamed UniformGrid to ImageData and later removed
    # the old name; prefer the new class and fall back for older installs.
    grid_cls = pv.ImageData if hasattr(pv, 'ImageData') else pv.UniformGrid
    grid = grid_cls()
    grid.dimensions = volume_data.shape
    grid.spacing = (1, 1, 1)
    grid.origin = (0, 0, 0)
    grid.point_data['values'] = volume_data.flatten(order='F')
    return grid
def img_pair2vid(sat_list, save_dir, media_path='interpolation.mp4', fps=12.0, frame_size=(512, 128)):
    """Assemble image files into an MP4 video.

    Args:
        sat_list: File names (relative to ``save_dir``) of the frames, in the
            order they should appear in the video.
        save_dir: Directory containing the frame files.
        media_path: Output video path.
        fps: Frame rate passed to ``cv2.VideoWriter``.
        frame_size: Frame size passed to ``cv2.VideoWriter``.
            NOTE(review): frames are written as loaded and are not resized to
            this size -- confirm the inputs already match.

    Files that ``cv2.imread`` cannot decode are skipped with a message instead
    of being passed to the writer.
    """
    fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
    out = cv2.VideoWriter(media_path, fourcc, fps, frame_size)
    try:
        for frame_name in sat_list:
            frame_path = os.path.join(save_dir, frame_name)
            frame = cv2.imread(frame_path)
            if frame is None:
                # imread returns None for missing or non-image files.
                print('Skipping unreadable frame: {}'.format(frame_path))
                continue
            out.write(frame)
    finally:
        # Always finalise the container, even if reading or writing fails.
        out.release()
@torch.no_grad()
def test_vid(model, opt):
    """Render views along a user-drawn path and save the results.

    Loads the generator weights from ``opt.test_ckpt_path``, reads the
    satellite image (``opt.demo_img``) and the style panorama (``opt.sty_img``),
    asks the user to draw a path on the satellite image, and runs the model once
    per selected pixel with ``opt.origin_H_W`` set from that pixel.

    Outputs written under ``opt.save_dir``:
        * ``volume.vtk`` -- the voxel output (values clipped to at most 10).
        * ``rendered_images/`` -- predicted RGB image per path point.
        * ``rendered_depth/`` -- colour-mapped depth per path point.
        * ``rendered_images+depths/`` -- RGB stacked above depth.
        * ``sat_images/`` -- satellite image with the path drawn up to each point.

    Args:
        model: Model wrapper exposing ``netG``, ``set_input``, ``forward``,
            ``sky_histc`` and ``out_put`` (with ``pred``, ``depth``, ``voxel``).
        opt: Options object; ``opt.origin_H_W`` is overwritten for every
            rendered point.
    """
    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
    model.netG.load_state_dict(ckpt['netG'])
    model.netG.eval()

    demo_imgpath = opt.demo_img
    sty_imgpath = opt.sty_img
    # Without an explicit sky mask path, derive it from the style image path.
    if opt.sky_img is None:
        sky_imgpath = opt.sty_img.replace('image', 'sky')
    else:
        sky_imgpath = opt.sky_img

    sat = img_read(demo_imgpath, size=opt.data.sat_size)
    pano = img_read(sty_imgpath, size=opt.data.pano_size)

    input_dict = {'sat': sat, 'pano': pano, 'paths': demo_imgpath}
    if opt.data.sky_mask:
        sky = img_read(sky_imgpath, size=opt.data.pano_size, datatype='L')
        input_a = pano * sky
        # Per-channel histograms (first 10 bins dropped), concatenated in
        # channel order 2, 1, 0.
        sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
        input_dict['sky_histc'] = sky_histc
        input_dict['sky_mask'] = sky

    # Add a leading batch dimension to every tensor input.
    for key, value in input_dict.items():
        if isinstance(value, torch.Tensor):
            input_dict[key] = value.unsqueeze(0)

    model.set_input(input_dict)
    model.style_temp = model.sky_histc

    pixels, _angles, _smooth_path = select_points(sat_image=sat.permute(1, 2, 0).numpy())

    rendered_image_list = []
    rendered_depth_list = []
    for x, y in pixels:
        opt.origin_H_W = [(y-128)/128, (x-128)/128]  # TODO: hard code should be removed in the future
        print('Rendering at ({}, {})'.format(x, y))
        model.forward(opt)

        rgb = model.out_put.pred[0].clamp(min=0, max=1.0).cpu().numpy().transpose((1, 2, 0))
        rgb = np.array(rgb*255, dtype=np.uint8)
        rendered_image_list.append(rgb)
        rendered_depth_list.append(model.out_put.depth[0, 0].cpu().numpy())

    # Result is unused here; the call is kept in case render_sat has side effects.
    render_sat(opt, model.out_put.voxel)

    volume_data = model.out_put.voxel[0].cpu().numpy().transpose((1, 2, 0))
    volume_data = np.clip(volume_data, None, 10)
    volume_export = volume2pyvista(volume_data)

    os.makedirs(opt.save_dir, exist_ok=True)
    volume_export.save(os.path.join(opt.save_dir, 'volume.vtk'))

    # save rendered images
    os.makedirs(osp.join(opt.save_dir, 'rendered_images'), exist_ok=True)
    for i, img in enumerate(rendered_image_list):
        plt.imsave(osp.join(opt.save_dir, 'rendered_images', '{:05d}.png'.format(i)), img)

    os.makedirs(osp.join(opt.save_dir, 'rendered_depth'), exist_ok=True)
    os.makedirs(osp.join(opt.save_dir,
                         'rendered_images+depths'), exist_ok=True)
    for i, img in enumerate(rendered_depth_list):
        peak = img.max()
        # Avoid dividing by zero (NaNs) when the depth map is entirely zero.
        scaled = img/peak*255 if peak > 0 else np.zeros_like(img)
        depth = np.array(scaled, dtype=np.uint8)
        # NOTE(review): applyColorMap returns BGR while plt.imsave expects RGB,
        # so red and blue appear swapped in the saved files -- confirm whether
        # that is intended before changing it.
        depth = cv2.applyColorMap(depth, cv2.COLORMAP_TURBO)
        plt.imsave(osp.join(opt.save_dir, 'rendered_depth', '{:05d}.png'.format(i)), depth)
        image_and_depth = np.concatenate((rendered_image_list[i], depth), axis=0)
        plt.imsave(osp.join(opt.save_dir, 'rendered_images+depths', '{:05d}.png'.format(i)), image_and_depth)

    os.makedirs(osp.join(opt.save_dir, 'sat_images'), exist_ok=True)
    # The satellite image is identical for every frame, so convert it once.
    sat_rgb = np.array(sat.permute(1, 2, 0).numpy()*255, dtype=np.uint8)
    for i, (x, y) in enumerate(pixels):
        fig = plt.figure()
        fig.set_size_inches(1, 1, forward=False)
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        ax.imshow(sat_rgb)
        # Path travelled so far, plus a marker at the current position.
        ax.plot(pixels[:i+1, 0], pixels[:i+1, 1], '-', color='red')
        ax.plot(x, y, 'o', color='red', markersize=2)
        fig.add_axes(ax)
        fig.savefig(osp.join(opt.save_dir, 'sat_images', '{:05d}.png'.format(i)), bbox_inches='tight', pad_inches=0, dpi=256)
        # Close each figure so they do not accumulate in memory.
        plt.close(fig)
    print('Done')
@torch.no_grad()
def test_interpolation(model, opt):
    """Render a style interpolation between two panoramas and save it as a video.

    Loads the generator weights from ``opt.test_ckpt_path``, renders the
    geometry for the satellite image ``opt.demo_img`` at the origin derived
    from pixel ``(128, 128)``, encodes the sky histograms of ``opt.sty_img1``
    and ``opt.sty_img2`` into style codes, and linearly blends between the two
    codes over 60 steps. Each step's output is saved as ``imgNNN.png`` in
    ``opt.save_dir`` and the frames are then combined into
    ``interpolation.mp4`` in the same directory.

    The sky mask for each style image is read from the style image path with
    ``'image'`` replaced by ``'sky'``.

    Args:
        model: Model wrapper exposing ``netG``, ``set_input``, ``real_A`` and
            ``device``.
        opt: Options object; ``opt.origin_H_W`` is overwritten.
    """
    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
    model.netG.load_state_dict(ckpt['netG'])
    model.netG.eval()

    def _sky_histogram(pano, pano_path):
        # Mask the panorama with its sky mask and concatenate the per-channel
        # histograms (first 10 bins dropped) in channel order 2, 1, 0.
        sky = img_read(pano_path.replace('image', 'sky'), size=opt.data.pano_size, datatype='L')
        masked = pano * sky
        return torch.cat([masked[c].histc()[10:] for c in reversed(range(3))])

    sat = img_read(opt.demo_img, size=opt.data.sat_size)
    pano1 = img_read(opt.sty_img1, size=opt.data.pano_size)
    pano2 = img_read(opt.sty_img2, size=opt.data.pano_size)

    input_dict = {'sat': sat, 'paths': opt.demo_img}

    sky_histc1 = _sky_histogram(pano1, opt.sty_img1)
    sky_histc2 = _sky_histogram(pano2, opt.sty_img2)

    # Add a leading batch dimension to every tensor input.
    for key, value in input_dict.items():
        if isinstance(value, torch.Tensor):
            input_dict[key] = value.unsqueeze(0)

    model.set_input(input_dict)

    x, y = 128, 128
    opt.origin_H_W = [(y-128)/128, (x-128)/128]
    print(opt.origin_H_W)

    estimated_height = model.netG.depth_model(model.real_A)
    geo_outputs = render(opt, model.real_A, estimated_height, model.netG.pano_direction, PE=model.netG.PE)
    generator_inputs, opacity, depth = geo_outputs['rgb'], geo_outputs['opacity'], geo_outputs['depth']
    if model.netG.gen_cfg.cat_opa:
        generator_inputs = torch.cat((generator_inputs, opacity), dim=1)
    if model.netG.gen_cfg.cat_depth:
        generator_inputs = torch.cat((generator_inputs, depth), dim=1)

    _, _, z1 = model.netG.style_encode(sky_histc1.unsqueeze(0).to(model.device))
    _, _, z2 = model.netG.style_encode(sky_histc2.unsqueeze(0).to(model.device))

    # Make sure the output directory exists even when called without main().
    os.makedirs(opt.save_dir, exist_ok=True)

    num_inter = 60
    frame_names = []
    for i in range(num_inter):
        # Blend weight runs from 0 (pure style 1) to 1 (pure style 2).
        z = z1 * (1-i/(num_inter-1)) + z2 * (i/(num_inter-1))
        z = model.netG.style_model(z)
        output_RGB = model.netG.denoise_model(generator_inputs, z)

        save_img = output_RGB.cpu()
        name = 'img{:03d}.png'.format(i)
        torchvision.utils.save_image(save_img, os.path.join(opt.save_dir, name))
        frame_names.append(name)

    # Use exactly the frames written above rather than re-listing the
    # directory, so stale or unrelated files never end up in the video.
    media_path = os.path.join(opt.save_dir, 'interpolation.mp4')
    img_pair2vid(frame_names, opt.save_dir, media_path)
    print('Done, save 2 ', media_path)
def main():
    """Command-line entry point: parse options, build the model, run a task.

    Supported values of ``opt.task`` are ``'test_vid'`` and
    ``'test_interpolation'``. ``opt.save_dir`` is deleted if it already exists
    before the task runs, so all arguments are validated first.

    Raises:
        ValueError: If ``opt.save_dir`` is missing, or if
            ``'test_interpolation'`` is requested without both
            ``opt.sty_img1`` and ``opt.sty_img2``.
        RuntimeError: If ``opt.task`` is not a supported task.
    """
    log.process(os.getpid())
    log.title("[{}] (PyTorch code for testing Sat2Density and debug".format(sys.argv[0]))

    opt_cmd = options.parse_arguments(sys.argv[1:])
    opt = options.set(opt_cmd=opt_cmd)
    opt.isTrain = False
    opt.name = opt.yaml if opt.name is None else opt.name
    opt.batch_size = 1

    # Validate everything up front: the save directory is wiped further down,
    # so a bad invocation must fail before anything is deleted.
    if opt.save_dir is None:
        raise ValueError("Please specify the save dir")
    if opt.task not in ('test_vid', 'test_interpolation'):
        raise RuntimeError("Unknown task: {!r}".format(opt.task))
    if opt.task == 'test_interpolation' and not (opt.sty_img1 and opt.sty_img2):
        raise ValueError("test_interpolation requires both sty_img1 and sty_img2")

    get_checkpoint(opt)

    mode = importlib.import_module("model.{}".format(opt.model))
    m = mode.Model(opt)
    m.build_networks(opt)

    # Start from a clean output directory.
    if os.path.exists(opt.save_dir):
        import shutil
        shutil.rmtree(opt.save_dir)

    if opt.task == 'test_vid':
        test_vid(m, opt)
    else:
        os.makedirs(opt.save_dir, exist_ok=True)
        test_interpolation(m, opt)
# Run the entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()