import sys
import os
import html
import glob
import uuid
import hashlib

import requests
from tqdm import tqdm

# NOTE(review): the repository URL is missing from this command in the
# reviewed source, so as written `git clone` has nothing to clone. Restore the
# intended URL (and any follow-up install steps) before deploying.
os.system("git clone")

# torch is imported only after the setup command above, as in the original.
import torch

# Download spec consumed by download_file().
# NOTE(review): `file_url` is empty and `file_path` is a bare directory; both
# look like placeholders and must be filled in for the download to work.
# NOTE(review): `file_md5` is 33 characters long, but an MD5 hex digest is
# always 32, so the MD5 validation below can never succeed with this value.
pretrained_model = dict(
    file_url='',
    alt_url='',
    file_size=330571863,
    file_md5='13b7ae859b28b37479ec84f1449d07fc7',
    file_path='./',
)


def download_file(session, file_spec, use_alt_url=False, chunk_size=128, num_attempts=10):
    """Download the file described by *file_spec* and validate it.

    The payload is streamed into a uniquely named temp file next to the
    target, checked against the optional ``file_size`` / ``file_md5`` entries,
    and then moved over ``file_spec['file_path']`` with ``os.replace``.

    Args:
        session: Object with a ``requests.Session``-style ``get`` method.
        file_spec: Dict with ``file_path`` and ``file_url`` (or ``alt_url``),
            plus optional ``file_size`` and ``file_md5`` used for validation.
        use_alt_url: Download from ``file_spec['alt_url']`` instead of
            ``file_spec['file_url']``.
        chunk_size: Streaming chunk size in KiB (shifted left by 10 bits).
        num_attempts: Maximum number of download attempts.

    Raises:
        IOError: If the size or MD5 check still fails on the last attempt.
        Exception: Whatever the final attempt raised (HTTP errors, etc.).
    """
    file_path = file_spec['file_path']
    file_url = file_spec['alt_url'] if use_alt_url else file_spec['file_url']
    file_dir = os.path.dirname(file_path)
    tmp_path = file_path + '.tmp.' + uuid.uuid4().hex
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)

    # `file_size` is optional for validation, so do not require it here either.
    progress_bar = tqdm(total=file_spec.get('file_size'), unit='B', unit_scale=True)
    try:
        for attempts_left in reversed(range(num_attempts)):
            data_size = 0
            progress_bar.reset()
            try:
                # Download.
                data_md5 = hashlib.md5()
                with session.get(file_url, stream=True) as res:
                    res.raise_for_status()
                    with open(tmp_path, 'wb') as f:
                        for chunk in res.iter_content(chunk_size=chunk_size << 10):
                            progress_bar.update(len(chunk))
                            f.write(chunk)
                            data_size += len(chunk)
                            data_md5.update(chunk)

                # Validate.
                if 'file_size' in file_spec and data_size != file_spec['file_size']:
                    raise IOError('Incorrect file size', file_path)
                if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']:
                    raise IOError('Incorrect file MD5', file_path)
                break

            except Exception:
                # Last attempt => raise error.
                if not attempts_left:
                    raise

                # Handle Google Drive virus checker nag: a tiny response is
                # treated as an HTML page containing a confirmation link.
                if 0 < data_size < 8192:
                    with open(tmp_path, 'rb') as f:
                        data = f.read()
                    links = [
                        html.unescape(link)
                        for link in data.decode('utf-8').split('"')
                        if 'confirm=t' in link
                    ]
                    if len(links) == 1:
                        file_url = requests.compat.urljoin(file_url, links[0])
                        continue
    finally:
        # Close the bar even when the final attempt re-raises.
        progress_bar.close()

    # Rename temp file to the correct name.
    os.replace(tmp_path, file_path)  # atomic

    # Attempt to clean up any leftover temps (best effort).
    for filename in glob.glob(glob.escape(file_path) + '.tmp.*'):
        try:
            os.remove(filename)
        except OSError:
            pass


print('Downloading SceneDreamer pretrained model...')
with requests.Session() as session:
    try:
        download_file(session, pretrained_model)
    except Exception as err:
        # Best effort: report the failure and let the script continue.
        print('Google Drive download failed.\n')
        print(err)

import argparse

from imaginaire.config import Config
from imaginaire.utils.cudnn import init_cudnn
from imaginaire.utils.dataset import get_test_dataloader
from imaginaire.utils.distributed import init_dist
from imaginaire.utils.gpu_affinity import set_affinity
# NOTE(review): the module path was missing in the reviewed source;
# `imaginaire.utils.io` is assumed here -- confirm against the package.
from imaginaire.utils.io import get_checkpoint as get_checkpoint
from imaginaire.utils.logging import init_logging
from imaginaire.utils.trainer import \
    (get_model_optimizer_and_scheduler, set_random_seed)
import imaginaire.config
import gradio as gr
from PIL import Image


def parse_args():
    """Parse the command-line options of the demo and return the namespace."""
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--config', type=str,
                        default='./configs/scenedreamer_inference.yaml',
                        help='Path to the training config file.')
    # NOTE(review): the default is a bare directory and looks like a
    # placeholder; it presumably should match pretrained_model['file_path'].
    parser.add_argument('--checkpoint', default='./',
                        help='Checkpoint path.')
    parser.add_argument('--output_dir', type=str, default='./test/',
                        help='Location to save the image outputs')
    parser.add_argument('--seed', type=int, default=8888,
                        help='Random seed.')
    args = parser.parse_args()
    return args


args = parse_args()
set_random_seed(args.seed, by_rank=False)
cfg = Config(args.config)

# Initialize cudnn.
init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

# Initialize data loaders and models.
net_G = get_model_optimizer_and_scheduler(cfg, seed=args.seed, generator_only=True)

if args.checkpoint == '':
    raise NotImplementedError("No checkpoint is provided for inference!")

# Load checkpoint.
# trainer.load_checkpoint(cfg, args.checkpoint)
checkpoint = torch.load(args.checkpoint, map_location='cpu')
net_G.load_state_dict(checkpoint['net_G'])

# Do inference.
# Unwrap the generator from its wrapper and freeze it for inference.
net_G = net_G.module
net_G.eval()
for param in net_G.parameters():
    param.requires_grad = False
torch.cuda.empty_cache()

# All generated maps and renders are written below this directory.
world_dir = os.path.join(args.output_dir)
os.makedirs(world_dir, exist_ok=True)


def get_bev(seed):
    """Generate the BEV scene maps for *seed* and return them as images.

    Runs the external generator command, then loads ``semanticmap.png`` and
    ``heightmap.png`` from ``world_dir``.

    Args:
        seed: Random seed from the UI slider; coerced to ``int``.

    Returns:
        Tuple ``(semantic, heightmap)`` of PIL images.
    """
    seed = int(seed)  # slider values may arrive as floats
    print('[PCGGenerator] Generating BEV scene representation...')
    # NOTE(review): the script name is missing after `python` in the reviewed
    # source; restore the intended generator script before deploying.
    os.system('python --size {} --seed {} --outdir {}'.format(net_G.voxel.sample_size, seed, world_dir))
    heightmap_path = os.path.join(world_dir, 'heightmap.png')
    semantic_path = os.path.join(world_dir, 'semanticmap.png')
    heightmap = Image.open(heightmap_path)
    semantic = Image.open(semantic_path)
    return semantic, heightmap


def get_video(seed, num_frames):
    """Render a fly-through video of the current world and return its path.

    Seeds the torch RNGs, loads the world from ``world_dir`` into the
    generator's voxel module, samples a style code and calls
    ``net_G.inference_givenstyle`` with the config's inference arguments.

    Args:
        seed: Random seed from the UI slider; coerced to ``int``.
        num_frames: Accepted from the UI, but not used by the visible code.

    Returns:
        Path of ``rgb_render.mp4`` inside the per-camera output directory.
    """
    seed = int(seed)  # torch seeding requires an integer
    device = torch.device('cuda')
    rng_cuda = torch.Generator(device=device)
    rng_cuda = rng_cuda.manual_seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    net_G.voxel.next_world(device, world_dir, checkpoint)
    cam_mode = cfg.inference_args.camera_mode
    current_outdir = os.path.join(world_dir, 'camera_{:02d}'.format(cam_mode))
    os.makedirs(current_outdir, exist_ok=True)
    z = torch.empty(1, net_G.style_dims, dtype=torch.float32, device=device)
    z.normal_(generator=rng_cuda)
    net_G.inference_givenstyle(z, current_outdir, **vars(cfg.inference_args))
    return os.path.join(current_outdir, 'rgb_render.mp4')


# NOTE(review): the link targets below were missing in the reviewed source;
# fill in the repository, project page and arXiv URLs.
markdown = '''
# SceneDreamer: Unbounded 3D Scene Generation from 2D Image Collections
Authored by Zhaoxi Chen, Guangcong Wang, Ziwei Liu
### Useful links:
- [Official Github Repo]()
- [Project Page]()
- [arXiv Link]()

Licensed under the S-Lab License.

First use the button "Generate BEV" to randomly sample a 3D world represented by a height map and a semantic map. Then push the button "Render" to generate a camera trajectory flying through the world.
'''

# NOTE(review): widget nesting is reconstructed from statement order because
# the reviewed source had lost its indentation; confirm the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(markdown)
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    semantic = gr.Image(type="pil", shape=(2048, 2048))
                with gr.Column():
                    height = gr.Image(type="pil", shape=(2048, 2048))
            with gr.Row():
                with gr.Column():
                    video = gr.Video()
            with gr.Row():
                num_frames = gr.Slider(minimum=40, maximum=200, value=40,
                                       label='Number of frames for video generation')
                user_seed = gr.Slider(minimum=0, maximum=999999, value=8888,
                                      label='Random seed to control styles and scenes')
            with gr.Row():
                btn = gr.Button(value="Generate BEV")
                btn_2 = gr.Button(value="Render")

    # get_bev returns (semantic, heightmap), matching the output order here.
    btn.click(get_bev, [user_seed], [semantic, height])
    btn_2.click(get_video, [user_seed, num_frames], [video])

demo.launch(debug=True)