import importlib
import os

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from scipy.interpolate import splev, splprep

import options


def get_single(sat_img, style_img, x_offset, y_offset):
    # Find the demo case that the selected style image belongs to, so the
    # matching sky mask can be loaded.
    name = ''
    for case in [n for n in os.listdir('demo_img') if 'case' in n]:
        style = Image.open('demo_img/{}/groundview.image.png'.format(case)).convert('RGB')
        style = np.array(style)
        if (style == style_img).all():
            name = case
            break

    input_dict = {}
    trans = transforms.ToTensor()
    input_dict['sat'] = trans(sat_img)
    input_dict['pano'] = trans(style_img)
    input_dict['paths'] = "demo.png"

    # Mask the panorama to its sky region and build a per-channel color
    # histogram; the lowest bins are dropped because they are dominated by the
    # masked-out (black) pixels. This histogram serves as the style code.
    sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L"))
    input_a = input_dict['pano'] * sky
    sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
    input_dict['sky_histc'] = sky_histc
    input_dict['sky_mask'] = sky

    # Add a batch dimension to every tensor input.
    for key in input_dict.keys():
        if isinstance(input_dict[key], torch.Tensor):
            input_dict[key] = input_dict[key].unsqueeze(0)

    args = [
        "--yaml=sat2density_cvact",
        "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth",
        "--task=test_vid",
        "--demo_img=demo_img/case1/satview-input.png",
        "--sty_img=demo_img/case1/groundview.image.png",
        "--save_dir=output",
    ]
    opt_cmd = options.parse_arguments(args=args)
    opt = options.set(opt_cmd=opt_cmd)
    opt.isTrain = False
    opt.name = opt.yaml if opt.name is None else opt.name
    opt.batch_size = 1

    # Build the generator and load the pretrained weights.
    m = importlib.import_module("model.{}".format(opt.model))
    model = m.Model(opt)
    # m.load_dataset(opt)
    model.build_networks(opt)
    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
    model.netG.load_state_dict(ckpt['netG'])
    model.netG.eval()

    model.set_input(input_dict)
    model.style_temp = model.sky_histc

    # Map the slider values (in [0, 1]) to normalized coordinates in [-1, 1];
    # y is negated because the slider measures from the bottom of the image
    # while image rows grow downward.
    opt.origin_H_W = [-(y_offset * 256 - 128) / 128, (x_offset * 256 - 128) / 128]  # TODO: remove the hard-coded 256x256 image size
    model.forward(opt)

    rgb = model.out_put.pred[0].clamp(min=0, max=1.0).cpu().detach().numpy().transpose((1, 2, 0))
    rgb = np.array(rgb * 255, dtype=np.uint8)
    return rgb


def get_video(sat_img, style_img, positions):
    # Find the demo case that the selected style image belongs to (same lookup
    # as in get_single).
    name = ''
    for case in [n for n in os.listdir('demo_img') if 'case' in n]:
        style = Image.open('demo_img/{}/groundview.image.png'.format(case)).convert('RGB')
        style = np.array(style)
        if (style == style_img).all():
            name = case
            break

    input_dict = {}
    trans = transforms.ToTensor()
    input_dict['sat'] = trans(sat_img)
    input_dict['pano'] = trans(style_img)
    input_dict['paths'] = "demo.png"

    sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L"))
    input_a = input_dict['pano'] * sky
    sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
    input_dict['sky_histc'] = sky_histc
    input_dict['sky_mask'] = sky

    for key in input_dict.keys():
        if isinstance(input_dict[key], torch.Tensor):
            input_dict[key] = input_dict[key].unsqueeze(0)

    args = [
        "--yaml=sat2density_cvact",
        "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth",
        "--task=test_vid",
        "--demo_img=demo_img/case1/satview-input.png",
        "--sty_img=demo_img/case1/groundview.image.png",
        "--save_dir=output",
    ]
    opt_cmd = options.parse_arguments(args=args)
    opt = options.set(opt_cmd=opt_cmd)
    opt.isTrain = False
    opt.name = opt.yaml if opt.name is None else opt.name
    opt.batch_size = 1

    m = importlib.import_module("model.{}".format(opt.model))
    model = m.Model(opt)
    # m.load_dataset(opt)
    model.build_networks(opt)
    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
    model.netG.load_state_dict(ckpt['netG'])
    model.netG.eval()

    model.set_input(input_dict)
    model.style_temp = model.sky_histc

    # Drop duplicate clicks while preserving order, then fit a smoothing
    # spline through the waypoints and resample 80 camera positions along it.
    pixels = np.array(list(dict.fromkeys(positions)))
    tck, u = splprep(pixels.T, s=25, per=0)
    u_new = np.linspace(u.min(), u.max(), 80)
    x_new, y_new = splev(u_new, tck)
    smooth_path = np.array([x_new, y_new]).T

    rendered_image_list = []
    rendered_depth_list = []
    for x, y in smooth_path:
        opt.origin_H_W = [(y - 128) / 128, (x - 128) / 128]  # TODO: remove the hard-coded 256x256 image size
        print('Rendering at ({}, {})'.format(x, y))
        model.forward(opt)
        rgb = model.out_put.pred[0].clamp(min=0, max=1.0).cpu().detach().numpy().transpose((1, 2, 0))
        rgb = np.array(rgb * 255, dtype=np.uint8)
        rendered_image_list.append(rgb)
        rendered_depth_list.append(model.out_put.depth[0, 0].cpu().detach().numpy())

    # Set the video frame rate, width, and height.
    output_video_path = 'output_video.mp4'
    frame_rate = 15
    frame_width = 512
    frame_height = 128

    # Create an OpenCV video writer with the MPEG-4 ('mp4v') codec.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (frame_width, frame_height))

    # Write the rendered frames, converting from RGB to OpenCV's BGR order.
    for image_np in rendered_image_list:
        out.write(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))

    # Release the video writer.
    out.release()
    return output_video_path


def copy_image(image):
    return image


def show_image_and_point(image, x, y):
    # Convert the normalized slider values to pixel coordinates; y is flipped
    # because slider values grow upward while image rows grow downward.
    x = int(x * image.shape[1])
    y = image.shape[0] - int(y * image.shape[0])
    # Highlight the render position with a filled circle in a mask overlay.
    mask = np.zeros(image.shape[:2])
    radius = min(image.shape[0], image.shape[1]) // 60
    for i in range(x - radius - 2, x + radius + 2):
        for j in range(y - radius - 2, y + radius + 2):
            if (i - x) ** 2 + (j - y) ** 2 <= radius ** 2:
                mask[j, i] = 1
    return (image, [(mask, 'render point')])


def add_select_point(image, evt: gr.SelectData, state1):
    # Append the clicked pixel to the trajectory state and draw a black dot
    # on the image as visual feedback.
    if state1 is None:
        state1 = []
    x, y = evt.index
    state1.append((x, y))
    print(state1)
    radius = min(image.shape[0], image.shape[1]) // 60
    for i in range(x - radius - 2, x + radius + 2):
        for j in range(y - radius - 2, y + radius + 2):
            if (i - x) ** 2 + (j - y) ** 2 <= radius ** 2:
                image[j, i, :] = 0
    return image, state1


def reset_select_points(image):
    # Restore the clean satellite image and clear the recorded trajectory.
    return image, []


with gr.Blocks() as demo:
    gr.Markdown("# Sat2Density Demos")
    gr.Markdown("### Select/upload the satellite image and select the style image")
    with gr.Row():
        with gr.Column():
            sat_img = gr.Image(source='upload', shape=[256, 256], interactive=True)
            img_examples = gr.Examples(
                examples=['demo_img/{}/satview-input.png'.format(i) for i in os.listdir('demo_img') if 'case' in i],
                inputs=sat_img, outputs=None, examples_per_page=20)
        with gr.Column():
            style_img = gr.Image()
            style_examples = gr.Examples(
                examples=['demo_img/{}/groundview.image.png'.format(i) for i in os.listdir('demo_img') if 'case' in i],
                inputs=style_img, outputs=None, examples_per_page=20)

    gr.Markdown("### Select a point to generate a single ground-view image")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    slider_x = gr.Slider(0.2, 0.8, 0.5, label="x-axis position")
                    slider_y = gr.Slider(0.2, 0.8, 0.5, label="y-axis position")
                    btn_single = gr.Button(value="demo1")
                annotation_image = gr.AnnotatedImage()
        out_single = gr.Image()

    gr.Markdown("### Draw a trajectory on the map to generate a video")
    state_select_points = gr.State()
    with gr.Row():
        with gr.Column():
            draw_img = gr.Image(shape=[256, 256], interactive=True)
        with gr.Column():
            out_video = gr.Video()
            reset_btn = gr.Button(value="Reset")
            btn_video = gr.Button(value="demo1")

    # Event wiring.
    sat_img.change(copy_image, inputs=sat_img, outputs=draw_img)
    draw_img.select(add_select_point, [draw_img, state_select_points],
                    [draw_img, state_select_points])
    sat_img.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image)
    slider_x.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image,
                    show_progress='hidden')
    slider_y.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image,
                    show_progress='hidden')
    btn_single.click(get_single, inputs=[sat_img, style_img, slider_x, slider_y], outputs=out_single)
    reset_btn.click(reset_select_points, [sat_img], [draw_img, state_select_points])
    btn_video.click(get_video, inputs=[sat_img, style_img, state_select_points], outputs=out_video)  # trigger video rendering

demo.launch()
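
# Usage note (a sketch; the filename `demo.py` is an assumption, not fixed by
# this script): running the file starts a local Gradio server, e.g.
#
#   python demo.py
#
# The demo expects `demo_img/case*/` folders containing `satview-input.png`,
# `groundview.image.png`, and `groundview.sky.png`, plus the generator
# checkpoint at the `--test_ckpt_path` given above. To expose a shareable
# public link, Gradio also supports `demo.launch(share=True)`.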