# Captured from a Hugging Face Spaces page whose status read: "Runtime error".
import importlib | |
import os | |
import os.path as osp | |
import sys | |
import warnings | |
import torch | |
import options | |
from utils import log | |
warnings.filterwarnings("ignore") | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torchvision.transforms as transforms | |
from matplotlib.widgets import Cursor | |
from PIL import Image | |
from scipy.interpolate import interp1d, splev, splprep | |
from torch.utils.data import default_convert,default_collate | |
import torchvision | |
from model.geometry_transform import render_sat,render | |
import cv2 | |
import imageio | |
def get_checkpoint(opt):
    """Expand a short wandb run id in ``opt.test_ckpt_path`` to a checkpoint path.

    If ``opt.test_ckpt_path`` equals one of the known run ids below, it is
    overwritten in place with that run's ``model.pth`` path. Any other value
    (for example an explicit file path) is left untouched.

    Returns:
        None; the result is communicated through ``opt``.
    """
    known_runs = (
        ('2u87bj8w', 'wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth'),
        ('2cqv8uh4', 'wandb/run-20230303_142752-2cqv8uh4/files/checkpoint/model.pth'),
    )
    for run_id, ckpt_file in known_runs:
        if opt.test_ckpt_path == run_id:
            opt.test_ckpt_path = ckpt_file
            break
def img_read(img, size=None, datatype='RGB'):
    """Read an image file and convert it to a float tensor.

    Args:
        img: anything ``PIL.Image.open`` accepts (a path or a file object).
        size: optional output size. An int ``n`` means ``(n, n)``. Any other
            truthy value is passed to ``PIL.Image.resize`` unchanged, which
            reads it as ``(width, height)``. Falsy values skip resizing.
        datatype: ``'RGB'`` loads a 3-channel colour image and resizes it
            bicubically. Any other value loads a single-channel ``'L'`` image
            and resizes it with nearest-neighbour, so mask values such as the
            sky mask are not blended.

    Returns:
        ``torchvision.transforms.ToTensor`` applied to the image: a float
        tensor of shape (C, H, W) with values scaled to [0, 1].
    """
    is_rgb = datatype == 'RGB'
    # Use a context manager so the file handle opened by PIL is released
    # once ``convert`` has decoded the pixels into a new image.
    with Image.open(img) as src:
        pil_img = src.convert('RGB' if is_rgb else 'L')
    if size:
        if isinstance(size, int):
            size = (size, size)
        pil_img = pil_img.resize(size=size, resample=Image.BICUBIC if is_rgb else Image.NEAREST)
    return transforms.ToTensor()(pil_img)
def select_points(sat_image):
    """Let the user drag a path over ``sat_image`` and return it as pixel coordinates.

    Opens an interactive matplotlib window showing ``sat_image``. Every mouse
    motion event while the left button is held over the image records the
    integer pixel under the cursor. After the window is closed, repeated
    pixels are removed (the first occurrence is kept), the remaining pixels
    are printed, and a smoothing spline is fitted through them.

    Args:
        sat_image: image array accepted by ``Axes.imshow`` (``test_vid``
            passes an H x W x 3 array).

    Returns:
        pixels: (N, 2) integer array of unique ``(x, y)`` pixels in drag order.
        angles: (79,) array of headings ``arctan2(dy, dx)`` between
            consecutive points of the smoothed path.
        smooth_path: (80, 2) array of ``(x, y)`` samples on the spline.

    Raises:
        ValueError: if fewer than 4 distinct pixels were selected. The
            default cubic spline in ``splprep`` needs more points than its degree.
    """
    fig = plt.figure()
    fig.set_size_inches(1, 1, forward=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    ax.imshow(sat_image)
    coords = []

    def ondrag(event):
        # Record only while the left button is held over the image. Outside
        # the axes ``xdata``/``ydata`` are None and ``int()`` would raise.
        if event.button != 1 or event.inaxes is not ax:
            return
        x, y = int(event.xdata), int(event.ydata)
        coords.append((x, y))
        ax.plot([x], [y], 'o', color='red')
        fig.canvas.draw_idle()

    fig.add_axes(ax)
    # Keep a reference: matplotlib widgets stop responding once they are
    # garbage-collected, so this name must stay alive until plt.show() returns.
    cursor = Cursor(ax, useblit=True, color='red', linewidth=1)  # noqa: F841
    fig.canvas.mpl_connect('motion_notify_event', ondrag)
    plt.show()
    plt.close()

    # Drop repeated pixels while preserving first-visit order.
    pixels = list(dict.fromkeys(coords))
    print(pixels)
    if len(pixels) < 4:
        raise ValueError(
            'Select at least 4 distinct points by dragging with the left mouse '
            'button; got {}.'.format(len(pixels)))
    pixels = np.array(pixels)
    tck, u = splprep(pixels.T, s=25, per=0)
    u_new = np.linspace(u.min(), u.max(), 80)
    x_new, y_new = splev(u_new, tck)
    smooth_path = np.array([x_new, y_new]).T
    angles = np.arctan2(y_new[1:] - y_new[:-1], x_new[1:] - x_new[:-1])
    return pixels, angles, smooth_path
def volume2pyvista(volume_data):
    """Wrap a 3-D array as a PyVista uniform grid, e.g. for ``.vtk`` export.

    The grid has one point per array element (``dimensions`` equals
    ``volume_data.shape``), unit spacing and its origin at (0, 0, 0). The
    values are stored as point data named ``'values'``.

    Args:
        volume_data: 3-D numpy array.

    Returns:
        A ``pyvista.ImageData``, or ``pyvista.UniformGrid`` on PyVista
        releases that predate ``ImageData``.
    """
    import pyvista as pv

    # PyVista 0.40 renamed ``UniformGrid`` to ``ImageData`` and deprecated the
    # old name. Prefer the new class and fall back for older releases.
    grid_cls = getattr(pv, 'ImageData', None) or pv.UniformGrid
    grid = grid_cls()
    grid.dimensions = volume_data.shape
    grid.spacing = (1, 1, 1)
    grid.origin = (0, 0, 0)
    # VTK orders points with x varying fastest, i.e. Fortran (column-major) order.
    grid.point_data['values'] = volume_data.flatten(order='F')
    return grid
def img_pair2vid(sat_list, save_dir, media_path='interpolation.mp4', fps=12.0, frame_size=None):
    """Encode a sequence of image files into an MPEG-4 (``MP4V``) video.

    Args:
        sat_list: frame file names relative to ``save_dir``, in playback order.
        save_dir: directory containing the frames.
        media_path: output video path.
        fps: frame rate of the output video (previously fixed at 12).
        frame_size: ``(width, height)`` of the video. By default it is taken
            from the first readable frame. Frames of a different size are
            resized to match, because ``cv2.VideoWriter`` does not raise on a
            size mismatch and typically produces an empty or unplayable file.

    Frames that ``cv2.imread`` cannot decode are skipped with a message. If
    no frame can be read, no video file is created.
    """
    fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
    writer = None
    try:
        for name in sat_list:
            frame = cv2.imread(os.path.join(save_dir, name))
            if frame is None:
                # imread signals missing or non-image files by returning None.
                print('Skipping unreadable frame: {}'.format(name))
                continue
            if writer is None:
                if frame_size is None:
                    frame_size = (frame.shape[1], frame.shape[0])
                frame_size = tuple(frame_size)
                writer = cv2.VideoWriter(media_path, fourcc, fps, frame_size)
            if (frame.shape[1], frame.shape[0]) != frame_size:
                frame = cv2.resize(frame, frame_size)
            writer.write(frame)
    finally:
        # Release even if an error occurs so the file handle is closed and
        # the container is finalised.
        if writer is not None:
            writer.release()
def _vid_build_inputs(opt):
    # Load the satellite, style and (optional) sky-mask images into an unbatched input dict.
    demo_imgpath = opt.demo_img
    sty_imgpath = opt.sty_img
    sky_imgpath = opt.sty_img.replace('image', 'sky') if opt.sky_img is None else opt.sky_img
    sat = img_read(demo_imgpath, size=opt.data.sat_size)
    pano = img_read(sty_imgpath, size=opt.data.pano_size)
    input_dict = {'sat': sat, 'pano': pano, 'paths': demo_imgpath}
    if opt.data.sky_mask:
        sky = img_read(sky_imgpath, size=opt.data.pano_size, datatype='L')
        masked = pano * sky
        # Per-channel histograms (channels in reverse order) with the lowest
        # 10 bins dropped. Presumably those bins are dominated by the zeroed
        # non-sky pixels; TODO confirm.
        input_dict['sky_histc'] = torch.cat([masked[c].histc()[10:] for c in reversed(range(3))])
        input_dict['sky_mask'] = sky
    return sat, input_dict


def _vid_depth_to_rgb(depth):
    # Scale a depth map to uint8 and colour it with OpenCV's TURBO map, returned in RGB order.
    max_depth = depth.max()
    if max_depth == 0:
        # Avoid 0/0 -> NaN (and garbage after the uint8 cast) on an all-zero map.
        depth_u8 = np.zeros(depth.shape, dtype=np.uint8)
    else:
        depth_u8 = np.array(depth / max_depth * 255, dtype=np.uint8)
    colored = cv2.applyColorMap(depth_u8, cv2.COLORMAP_TURBO)
    # applyColorMap returns BGR, but plt.imsave and the RGB renders expect RGB.
    return cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)


def _vid_save_trajectory_frames(out_dir, sat, pixels):
    # Save one satellite image per trajectory point, with the path drawn up to that point.
    sat_rgb = np.array(sat.permute(1, 2, 0).numpy() * 255, dtype=np.uint8)
    for i, (x, y) in enumerate(pixels):
        fig = plt.figure()
        fig.set_size_inches(1, 1, forward=False)
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        ax.imshow(sat_rgb)
        ax.plot(pixels[:i + 1, 0], pixels[:i + 1, 1], '-', color='red')
        ax.plot(x, y, 'o', color='red', markersize=2)
        fig.add_axes(ax)
        fig.savefig(osp.join(out_dir, '{:05d}.png'.format(i)),
                    bbox_inches='tight', pad_inches=0, dpi=256)
        # Close each figure; otherwise one figure per frame stays in memory.
        plt.close(fig)


def test_vid(model, opt):
    """Render panoramas along a user-drawn trajectory and save them to disk.

    Loads ``opt.demo_img`` (satellite) and ``opt.sty_img`` (style panorama).
    The sky mask is ``opt.sky_img``, or the style path with ``'image'``
    replaced by ``'sky'``, and is used only when ``opt.data.sky_mask`` is set.
    The user then drags a path over the satellite image (``select_points``),
    and the model is run once per selected pixel. Outputs under
    ``opt.save_dir``:

    - ``volume.vtk``: voxel volume from the last forward pass, clipped above at 10.
    - ``rendered_images/NNNNN.png``: rendered RGB panoramas.
    - ``rendered_depth/NNNNN.png``: depth maps, max-normalised, TURBO-coloured.
    - ``rendered_images+depths/NNNNN.png``: each RGB stacked above its depth.
    - ``sat_images/NNNNN.png``: the satellite image with the path so far.

    Args:
        model: model wrapper. Judging by usage here, it exposes ``netG``,
            ``set_input``, ``forward``, ``sky_histc`` and ``out_put``.
        opt: parsed options. ``opt.origin_H_W`` is overwritten for every frame.
    """
    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
    model.netG.load_state_dict(ckpt['netG'])
    model.netG.eval()

    sat, input_dict = _vid_build_inputs(opt)
    # Add a leading batch dimension to every tensor input.
    for key, value in input_dict.items():
        if isinstance(value, torch.Tensor):
            input_dict[key] = value.unsqueeze(0)
    model.set_input(input_dict)
    model.style_temp = model.sky_histc
    pixels, angles, smooth_path = select_points(sat_image=sat.permute(1, 2, 0).numpy())

    # Map pixel coordinates to [-1, 1) about the image centre, using the real
    # image size instead of a hard-coded 128. This is identical for 256-px input.
    _, sat_h, sat_w = sat.shape
    half_h, half_w = sat_h / 2, sat_w / 2

    rendered_image_list = []
    rendered_depth_list = []
    for x, y in pixels:
        opt.origin_H_W = [(y - half_h) / half_h, (x - half_w) / half_w]
        print('Rendering at ({}, {})'.format(x, y))
        model.forward(opt)
        rgb = model.out_put.pred[0].clamp(min=0, max=1.0).cpu().numpy().transpose((1, 2, 0))
        rendered_image_list.append(np.array(rgb * 255, dtype=np.uint8))
        rendered_depth_list.append(model.out_put.depth[0, 0].cpu().numpy())

    # NOTE(review): the results are unused. The call is kept only in case
    # render_sat has side effects; confirm and remove.
    sat_opacity, sat_depth = render_sat(opt, model.out_put.voxel)

    volume_data = model.out_put.voxel[0].cpu().numpy().transpose((1, 2, 0))
    volume_data = np.clip(volume_data, None, 10)
    volume_export = volume2pyvista(volume_data)
    os.makedirs(opt.save_dir, exist_ok=True)
    volume_export.save(os.path.join(opt.save_dir, 'volume.vtk'))

    image_dir = osp.join(opt.save_dir, 'rendered_images')
    depth_dir = osp.join(opt.save_dir, 'rendered_depth')
    both_dir = osp.join(opt.save_dir, 'rendered_images+depths')
    os.makedirs(image_dir, exist_ok=True)
    for i, rgb in enumerate(rendered_image_list):
        plt.imsave(osp.join(image_dir, '{:05d}.png'.format(i)), rgb)
    os.makedirs(depth_dir, exist_ok=True)
    os.makedirs(both_dir, exist_ok=True)
    for i, (rgb, depth) in enumerate(zip(rendered_image_list, rendered_depth_list)):
        depth_rgb = _vid_depth_to_rgb(depth)
        plt.imsave(osp.join(depth_dir, '{:05d}.png'.format(i)), depth_rgb)
        plt.imsave(osp.join(both_dir, '{:05d}.png'.format(i)),
                   np.concatenate((rgb, depth_rgb), axis=0))

    sat_dir = osp.join(opt.save_dir, 'sat_images')
    os.makedirs(sat_dir, exist_ok=True)
    _vid_save_trajectory_frames(sat_dir, sat, pixels)
    print('Done')
def test_interpolation(model, opt):
    """Render the centre view of ``opt.demo_img`` while blending two sky styles.

    For each of ``opt.sty_img1`` and ``opt.sty_img2``, the sky mask is read
    from the same path with ``'image'`` replaced by ``'sky'``. The masked
    panorama's histogram is encoded to a style code, and the two codes are
    linearly interpolated over 60 steps. Each step is decoded by the
    generator and saved as ``img000.png`` ... ``img059.png`` in
    ``opt.save_dir``. Those frames are then encoded into
    ``opt.save_dir/interpolation.mp4``.

    ``opt.save_dir`` is expected to exist already (``main`` creates it).
    ``opt.origin_H_W`` is overwritten with the image centre ``[0.0, 0.0]``.
    """
    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
    model.netG.load_state_dict(ckpt['netG'])
    model.netG.eval()

    sat = img_read(opt.demo_img, size=opt.data.sat_size)
    pano1 = img_read(opt.sty_img1, size=opt.data.pano_size)
    pano2 = img_read(opt.sty_img2, size=opt.data.pano_size)

    def sky_histogram(pano, pano_path):
        # Histogram of the sky-masked panorama: channels in reverse order,
        # with the lowest 10 bins dropped.
        sky = img_read(pano_path.replace('image', 'sky'), size=opt.data.pano_size, datatype='L')
        masked = pano * sky
        return torch.cat([masked[c].histc()[10:] for c in reversed(range(3))])

    sky_histc1 = sky_histogram(pano1, opt.sty_img1)
    sky_histc2 = sky_histogram(pano2, opt.sty_img2)

    model.set_input({'sat': sat.unsqueeze(0), 'paths': opt.demo_img})
    # Pixel (128, 128) of a 256-px image, i.e. the normalised image centre.
    opt.origin_H_W = [0.0, 0.0]
    print(opt.origin_H_W)

    num_inter = 60
    frame_names = []
    # Pure inference: skip autograd bookkeeping for every generated frame.
    with torch.no_grad():
        estimated_height = model.netG.depth_model(model.real_A)
        geo_outputs = render(opt, model.real_A, estimated_height,
                             model.netG.pano_direction, PE=model.netG.PE)
        generator_inputs = geo_outputs['rgb']
        if model.netG.gen_cfg.cat_opa:
            generator_inputs = torch.cat((generator_inputs, geo_outputs['opacity']), dim=1)
        if model.netG.gen_cfg.cat_depth:
            generator_inputs = torch.cat((generator_inputs, geo_outputs['depth']), dim=1)
        _, _, z1 = model.netG.style_encode(sky_histc1.unsqueeze(0).to(model.device))
        _, _, z2 = model.netG.style_encode(sky_histc2.unsqueeze(0).to(model.device))

        for i in range(num_inter):
            t = i / (num_inter - 1)
            z = model.netG.style_model(z1 * (1 - t) + z2 * t)
            output_RGB = model.netG.denoise_model(generator_inputs, z)
            name = 'img{:03d}.png'.format(i)
            torchvision.utils.save_image(output_RGB.cpu(), os.path.join(opt.save_dir, name))
            frame_names.append(name)

    # Encode exactly the frames written above, not whatever else is in save_dir.
    media_path = os.path.join(opt.save_dir, 'interpolation.mp4')
    img_pair2vid(frame_names, opt.save_dir, media_path)
    print('Done, save 2 ', media_path)
def main():
    """Command-line entry point for the Sat2Density demo tasks.

    Parses options from ``sys.argv``, builds the model module
    ``model.<opt.model>`` and runs ``opt.task``, which must be
    ``'test_vid'`` or ``'test_interpolation'``.

    Warning: an existing ``opt.save_dir`` is deleted recursively before the
    task runs.

    Raises:
        Exception: if ``opt.save_dir`` is not set.
        RuntimeError: if ``opt.task`` is not a supported task.
        ValueError: if ``test_interpolation`` is requested without both
            ``opt.sty_img1`` and ``opt.sty_img2``, or if ``opt.save_dir`` is
            the current working directory or one of its parents.
    """
    log.process(os.getpid())
    log.title("[{}] (PyTorch code for testing Sat2Density and debug".format(sys.argv[0]))
    opt_cmd = options.parse_arguments(sys.argv[1:])
    opt = options.set(opt_cmd=opt_cmd)
    opt.isTrain = False
    opt.name = opt.yaml if opt.name is None else opt.name
    opt.batch_size = 1
    if opt.save_dir is None:
        raise Exception("Please specify the save dir")
    # Validate before the slow model build and before save_dir is wiped.
    # Otherwise an unknown task would delete save_dir and then do nothing.
    if opt.task not in ('test_vid', 'test_interpolation'):
        raise RuntimeError("Unknown task: {}".format(opt.task))
    if opt.task == 'test_interpolation' and not (opt.sty_img1 and opt.sty_img2):
        # Explicit check rather than ``assert``, which is stripped under ``python -O``.
        raise ValueError("test_interpolation requires both opt.sty_img1 and opt.sty_img2")

    get_checkpoint(opt)
    mode = importlib.import_module("model.{}".format(opt.model))
    m = mode.Model(opt)
    m.build_networks(opt)

    if os.path.exists(opt.save_dir):
        import shutil
        save_dir = os.path.realpath(opt.save_dir)
        cwd = os.path.realpath(os.getcwd())
        # Refuse to wipe the working directory or a parent of it (e.g.
        # ``save_dir='.'``), which would delete the code being run.
        if cwd == save_dir or cwd.startswith(save_dir.rstrip(os.sep) + os.sep):
            raise ValueError(
                "Refusing to delete save_dir {!r}: it contains the current "
                "working directory".format(opt.save_dir))
        shutil.rmtree(opt.save_dir)

    if opt.task == 'test_vid':
        test_vid(m, opt)
    else:  # 'test_interpolation', validated above
        os.makedirs(opt.save_dir, exist_ok=True)
        test_interpolation(m, opt)
# Run the command-line driver only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()