# SceneDreamer / app.py
# Author: FrozenBurning. Commit f285f50 ("Update app.py"), 7.73 kB.
import glob
import hashlib
import html
import os
import subprocess
import sys
import uuid

import requests
from tqdm import tqdm
os.system("git clone https://github.com/FrozenBurning/SceneDreamer.git")
os.system("cp -r SceneDreamer/* ./")
os.system("bash install.sh")
pretrained_model = dict(file_url='https://drive.google.com/uc?id=1IFu1vNrgF1EaRqPizyEgN_5Vt7Fyg0Mj',
alt_url='', file_size=330571863,
file_path='./scenedreamer_released.pt',)
def download_file(session, file_spec, use_alt_url=False, chunk_size=128, num_attempts=10):
    """Download the file described by *file_spec*, retrying on failure.

    The payload is streamed into a uniquely named temporary file next to the
    destination. It is moved into place with ``os.replace`` only after an
    attempt passes the optional size/MD5 checks, so an interrupted or invalid
    download never overwrites ``file_path``.

    Args:
        session: ``requests.Session`` used for the HTTP GETs.
        file_spec: Dict with ``'file_path'`` (destination) and ``'file_url'``
            (or ``'alt_url'``). Optional ``'file_size'`` (bytes) and
            ``'file_md5'`` (hex digest) are validated when present.
        use_alt_url: Fetch from ``file_spec['alt_url']`` instead of
            ``file_spec['file_url']``.
        chunk_size: Streaming chunk size in KiB.
        num_attempts: Maximum number of download attempts.

    Raises:
        Whatever the final failed attempt raised, e.g. ``requests.HTTPError``
        from ``raise_for_status`` or ``IOError`` on a size/MD5 mismatch.
    """
    file_path = file_spec['file_path']
    file_url = file_spec['alt_url'] if use_alt_url else file_spec['file_url']
    file_dir = os.path.dirname(file_path)
    tmp_path = file_path + '.tmp.' + uuid.uuid4().hex
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)

    # 'file_size' is optional (see validation below), so don't require it for
    # the progress bar either; tqdm accepts total=None.
    progress_bar = tqdm(total=file_spec.get('file_size'), unit='B', unit_scale=True)
    try:
        for attempts_left in reversed(range(num_attempts)):
            data_size = 0
            progress_bar.reset()
            try:
                # Download.
                data_md5 = hashlib.md5()
                with session.get(file_url, stream=True) as res:
                    res.raise_for_status()
                    with open(tmp_path, 'wb') as f:
                        for chunk in res.iter_content(chunk_size=chunk_size << 10):
                            progress_bar.update(len(chunk))
                            f.write(chunk)
                            data_size += len(chunk)
                            data_md5.update(chunk)
                # Validate.
                if 'file_size' in file_spec and data_size != file_spec['file_size']:
                    raise IOError('Incorrect file size', file_path)
                if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']:
                    raise IOError('Incorrect file MD5', file_path)
                break
            except Exception:
                # Last attempt => drop the partial temp file and raise.
                if not attempts_left:
                    try:
                        os.remove(tmp_path)
                    except OSError:
                        pass
                    raise
                # Handle Google Drive virus checker nag: a small response is
                # treated as an HTML page, and a single 'confirm=t' link in it
                # becomes the URL for the next attempt.
                if 0 < data_size < 8192:
                    with open(tmp_path, 'rb') as f:
                        data = f.read()
                    # errors='replace': a non-UTF-8 body must not escape the
                    # retry loop with a UnicodeDecodeError.
                    text = data.decode('utf-8', errors='replace')
                    links = [html.unescape(link) for link in text.split('"') if 'confirm=t' in link]
                    if len(links) == 1:
                        file_url = requests.compat.urljoin(file_url, links[0])
                        continue
    finally:
        # Close the bar on success *and* when the final attempt raises.
        progress_bar.close()

    # Rename temp file to the correct name.
    os.replace(tmp_path, file_path)  # atomic

    # Attempt to clean up any leftover temps (best effort).
    for filename in glob.glob(file_path + '.tmp.*'):
        try:
            os.remove(filename)
        except OSError:
            pass
# Fetch the released checkpoint unless an intact copy is already present
# (its size matches the expected ``file_size``).
_ckpt_path = pretrained_model['file_path']
if os.path.isfile(_ckpt_path) and os.path.getsize(_ckpt_path) == pretrained_model['file_size']:
    print('Found existing SceneDreamer pretrained model, skipping download.')
else:
    print('Downloading SceneDreamer pretrained model...')
    with requests.Session() as session:
        try:
            download_file(session, pretrained_model)
        except Exception as err:
            # Best effort: keep starting the app, but surface *why* it failed
            # instead of swallowing every exception (incl. KeyboardInterrupt).
            print('Google Drive download failed.\n')
            print('Reason: {!r}'.format(err))
import os
import torch
import torch.nn as nn
import importlib
import argparse
from imaginaire.config import Config
from imaginaire.utils.cudnn import init_cudnn
import gradio as gr
from PIL import Image
class WrappedModel(nn.Module):
    r"""Minimal ``nn.Module`` shell holding another module as ``self.module``.

    Calling the wrapper forwards all positional and keyword arguments to the
    wrapped module and returns its result unchanged. When ``module`` is an
    ``nn.Module`` it is registered as a submodule, so its state-dict keys are
    prefixed with ``module.``.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        r"""Delegate the forward pass to the wrapped module."""
        inner = self.module
        return inner(*args, **kwargs)
def parse_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with ``config``, ``checkpoint``, ``output_dir`` and
        ``seed`` attributes.
    """
    cli = argparse.ArgumentParser(description='Training')
    option_specs = (
        ('--config', dict(type=str,
                          default='./configs/scenedreamer_inference.yaml',
                          help='Path to the training config file.')),
        ('--checkpoint', dict(default='./scenedreamer_released.pt',
                              help='Checkpoint path.')),
        ('--output_dir', dict(type=str, default='./test/',
                              help='Location to save the image outputs')),
        ('--seed', dict(type=int, default=8888,
                        help='Random seed.')),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
args = parse_args()
# Fail fast, before spending time and GPU memory on building the generator.
if args.checkpoint == '':
    raise NotImplementedError("No checkpoint is provided for inference!")
cfg = Config(args.config)
# Initialize cudnn.
init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)
# Build the generator class named by cfg.gen.type and move it to the GPU.
lib_G = importlib.import_module(cfg.gen.type)
net_G = lib_G.Generator(cfg.gen, cfg.data)
net_G = net_G.to('cuda')
# Load the checkpoint's 'net_G' state dict through WrappedModel, whose keys are
# prefixed with 'module.' (presumably matching how it was saved — TODO confirm).
net_G = WrappedModel(net_G)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(args.checkpoint, map_location='cpu')
net_G.load_state_dict(checkpoint['net_G'])
# Unwrap and freeze for inference.
net_G = net_G.module
net_G.eval()
net_G.requires_grad_(False)
torch.cuda.empty_cache()
# `checkpoint` stays referenced on purpose: get_video passes it to next_world.
world_dir = args.output_dir
os.makedirs(world_dir, exist_ok=True)
def get_bev(seed):
    """Sample a new world and return its ``(semantic, heightmap)`` BEV images.

    Runs ``terrain_generator.py`` in a subprocess with the generator's voxel
    sample size, *seed* and ``world_dir``, then reads ``colormap.png`` and
    ``heightmap.png`` from ``world_dir`` (presumably written by that script —
    TODO confirm).

    Args:
        seed: Random seed from the UI slider. Gradio documents slider values as
            floats, so it is cast to ``int`` for the command line.

    Raises:
        RuntimeError: if ``terrain_generator.py`` exits with a non-zero status.
            Without this check, stale images from a previous run would be
            returned silently.
    """
    print('[PCGGenerator] Generating BEV scene representation...')
    # Argument list (no shell) run with the app's own interpreter.
    cmd = [
        sys.executable, 'terrain_generator.py',
        '--size', str(net_G.voxel.sample_size),
        '--seed', str(int(seed)),
        '--outdir', world_dir,
    ]
    result = subprocess.run(cmd)
    if result.returncode != 0:
        raise RuntimeError(
            'terrain_generator.py failed with exit code {}'.format(result.returncode))
    heightmap_path = os.path.join(world_dir, 'heightmap.png')
    semantic_path = os.path.join(world_dir, 'colormap.png')
    # Copy eagerly and close the files: the next run overwrites these PNGs.
    with Image.open(heightmap_path) as img:
        heightmap = img.copy()
    with Image.open(semantic_path) as img:
        semantic = img.copy()
    return semantic, heightmap
def get_video(seed, num_frames):
    """Render a camera fly-through of the current world; return the video path.

    Passes ``world_dir`` and the loaded ``checkpoint`` to
    ``net_G.voxel.next_world``, presumably to load the BEV maps produced by
    ``get_bev`` (TODO confirm). It then samples a style code ``z`` from a CUDA
    generator seeded with *seed* and renders with ``cfg.inference_args``.

    Args:
        seed: Random seed from the UI slider. Gradio documents slider values as
            floats, and ``torch.Generator.manual_seed`` expects an integer, so
            the seed is cast to ``int``.
        num_frames: Value of the "Number of rendered frames" slider.
            NOTE(review): currently unused. Nothing here forwards it to the
            renderer, so the slider has no effect. Wiring it up needs the
            matching key in ``cfg.inference_args`` / ``inference_givenstyle``
            (TODO confirm against the SceneDreamer sources).

    Returns:
        Path of ``rgb_render.mp4`` under ``world_dir/camera_XX``, which
        ``inference_givenstyle`` is expected to write (not verified here).
    """
    seed = int(seed)
    device = torch.device('cuda')
    rng_cuda = torch.Generator(device=device)
    rng_cuda.manual_seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    net_G.voxel.next_world(device, world_dir, checkpoint)
    cam_mode = cfg.inference_args.camera_mode
    current_outdir = os.path.join(world_dir, 'camera_{:02d}'.format(cam_mode))
    os.makedirs(current_outdir, exist_ok=True)
    # Style code drawn from the seeded generator so renders are reproducible.
    z = torch.empty(1, net_G.style_dims, dtype=torch.float32, device=device)
    z.normal_(generator=rng_cuda)
    net_G.inference_givenstyle(z, current_outdir, **vars(cfg.inference_args))
    return os.path.join(current_outdir, 'rgb_render.mp4')
# Intro text shown at the top of the demo page (rendered by gr.Markdown).
# A plain string literal: nothing is interpolated, so no f-prefix is needed.
markdown = '''
# SceneDreamer: Unbounded 3D Scene Generation from 2D Image Collections
Authored by Zhaoxi Chen, Guangcong Wang, Ziwei Liu
### Useful links:
- [Official Github Repo](https://github.com/FrozenBurning/SceneDreamer)
- [Project Page](https://scene-dreamer.github.io/)
- [arXiv Link](https://arxiv.org/abs/2302.01330)
Licensed under the S-Lab License.
First use the button "Generate BEV" to randomly sample a 3D world represented by a height map and a semantic map. Then push the button "Render" to generate a camera trajectory flying through the world.
'''
# Demo layout: intro text and outputs (BEV maps + rendered video) on top,
# sliders and action buttons below.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(markdown)
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    semantic = gr.Image(type="pil", shape=(2048, 2048))
                with gr.Column():
                    height = gr.Image(type="pil", shape=(2048, 2048))
            with gr.Row():
                with gr.Column():
                    video = gr.Video()
    with gr.Row():
        # step=1 keeps both sliders on whole numbers; frame count and seed are
        # integer quantities.
        num_frames = gr.Slider(minimum=10, maximum=200, value=10, step=1,
                               label='Number of rendered frames')
        user_seed = gr.Slider(minimum=0, maximum=999999, value=8888, step=1,
                              label='Random seed')
    with gr.Row():
        btn = gr.Button(value="Generate BEV")
        btn_2 = gr.Button(value="Render")
    # "Generate BEV" fills the two map images; "Render" fills the video.
    btn.click(get_bev, [user_seed], [semantic, height])
    btn_2.click(get_video, [user_seed, num_frames], [video])
demo.launch(debug=True)