File size: 5,598 Bytes
a699001
1299110
a699001
 
 
 
 
 
 
 
8e5d73b
f285f50
8b9a803
a699001
 
 
8b9a803
 
a699001
 
 
 
 
 
8b9a803
 
 
 
 
 
 
 
 
 
 
 
 
a699001
 
5d7e7c2
a699001
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b9a803
 
 
 
 
a699001
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d7e7c2
a699001
 
 
5d7e7c2
f0b144f
a699001
 
 
 
 
 
 
f0b144f
 
a699001
 
 
 
 
5d7e7c2
a699001
 
 
 
 
 
 
 
 
 
f0b144f
 
 
 
 
a699001
 
 
 
 
 
 
 
 
57a8718
a699001
57a8718
a699001
 
 
 
f0b144f
a699001
f0b144f
 
 
 
a699001
 
 
 
 
 
f0b144f
a699001
5d7e7c2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import glob
import hashlib
import html
import os
import subprocess
import sys
import uuid

import requests
from tqdm import tqdm

# Fetch the SceneDreamer sources into the working directory and run its
# installer. Each step goes through os.system, whose exit status is discarded,
# so a failing step is not detected here.
for _setup_cmd in (
    "git clone https://github.com/FrozenBurning/SceneDreamer.git",
    "cp -r SceneDreamer/* ./",
    "bash install.sh",
):
    os.system(_setup_cmd)


import os
import torch
import torch.nn as nn
import importlib
import argparse
from imaginaire.config import Config
from imaginaire.utils.cudnn import init_cudnn
import gradio as gr
from PIL import Image


class WrappedModel(nn.Module):
    r"""Minimal ``nn.Module`` shell that holds another module as ``self.module``.

    Because the inner module is registered under the attribute name
    ``module``, every parameter/buffer key in this wrapper's state dict carries
    a ``module.`` prefix. NOTE(review): presumably this matches checkpoints
    saved from a wrapped generator — confirm against the checkpoint format.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        r"""Forward all positional and keyword arguments to the inner module."""
        inner = self.module
        return inner(*args, **kwargs)

def parse_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with attributes ``config``, ``checkpoint``,
        ``output_dir`` (default ``'./test/'``) and ``seed`` (default 8888).
    """
    option_specs = (
        ('--config', dict(type=str,
                          default='./configs/scenedreamer_inference.yaml',
                          help='Path to the training config file.')),
        ('--checkpoint', dict(default='./scenedreamer_released.pt',
                              help='Checkpoint path.')),
        ('--output_dir', dict(type=str, default='./test/',
                              help='Location to save the image outputs')),
        ('--seed', dict(type=int, default=8888,
                        help='Random seed.')),
    )
    cli = argparse.ArgumentParser(description='Training')
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


args = parse_args()
cfg = Config(args.config)

# Fail fast: refuse to run without a checkpoint before building anything on
# the GPU.
if args.checkpoint == '':
    raise NotImplementedError("No checkpoint is provided for inference!")

# Initialize cudnn.
init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

# Build the generator class named by the config and move it to the GPU.
lib_G = importlib.import_module(cfg.gen.type)
net_G = lib_G.Generator(cfg.gen, cfg.data)
net_G = net_G.to('cuda')
# Wrap so state-dict keys gain a 'module.' prefix; presumably this matches the
# key layout of checkpoint['net_G'] — TODO confirm.
net_G = WrappedModel(net_G)

# Load checkpoint. The dict is kept as a module-level global because
# get_video() passes it on to net_G.voxel.next_world().
checkpoint = torch.load(args.checkpoint, map_location='cpu')
net_G.load_state_dict(checkpoint['net_G'])

# Unwrap and freeze the generator for inference.
net_G = net_G.module
net_G.eval()
net_G.requires_grad_(False)
torch.cuda.empty_cache()

# Directory where generated BEV maps and renders are written.
world_dir = args.output_dir
os.makedirs(world_dir, exist_ok=True)



def _open_image_copy(path):
    """Load the image at *path* fully into memory and close the file."""
    with Image.open(path) as img:
        return img.copy()


def get_bev(seed):
    """Sample a new world layout with ``terrain_generator.py`` and load its maps.

    Runs ``terrain_generator.py`` in a child process with the generator's
    ``voxel.sample_size``, the given seed and ``world_dir`` as output
    directory. It then reads back ``colormap.png`` and ``heightmap.png`` from
    ``world_dir``.

    Args:
        seed: Seed for the terrain generator. It comes from a Gradio slider
            (and may arrive as a float), so it is cast to ``int``.

    Returns:
        Tuple ``(semantic, heightmap)`` of PIL images.

    Raises:
        subprocess.CalledProcessError: if the generator script exits non-zero.
            Without this, maps left over from an earlier run would be shown
            silently.
    """
    print('[PCGGenerator] Generating BEV scene representation...')
    # An argument list (no shell) keeps a value posted to the Gradio API from
    # injecting shell syntax; sys.executable runs the same interpreter.
    subprocess.run(
        [
            sys.executable, 'terrain_generator.py',
            '--size', str(net_G.voxel.sample_size),
            '--seed', str(int(seed)),
            '--outdir', world_dir,
        ],
        check=True,
    )
    semantic = _open_image_copy(os.path.join(world_dir, 'colormap.png'))
    heightmap = _open_image_copy(os.path.join(world_dir, 'heightmap.png'))
    return semantic, heightmap

def get_video(seed, num_frames, reso_h, reso_w):
    """Render a camera fly-through of the current world and return the video path.

    Args:
        seed: Seeds the global torch/CUDA RNGs and the CUDA generator used to
            sample the latent ``z`` (sized by ``net_G.style_dims``).
        num_frames: Stored in ``cfg.inference_args.cam_maxstep``.
        reso_h: Render height, stored in ``cfg.inference_args.resolution_hw``.
        reso_w: Render width, stored in ``cfg.inference_args.resolution_hw``.

    All four values come from Gradio sliders and are cast to ``int``. A slider
    or a direct API call can deliver floats, which
    ``torch.Generator.manual_seed`` rejects and which would otherwise leak into
    the frame count and resolution.

    Returns:
        Path to ``rgb_render.mp4`` inside ``world_dir/camera_<cam_mode>``
        (assumed to be written by ``net_G.inference_givenstyle``).
    """
    seed = int(seed)
    num_frames = int(num_frames)
    reso_h, reso_w = int(reso_h), int(reso_w)

    device = torch.device('cuda')
    rng_cuda = torch.Generator(device=device)
    rng_cuda = rng_cuda.manual_seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Passes world_dir (where get_bev writes its maps) and the module-level
    # checkpoint dict to the voxel module; presumably this (re)builds the
    # world from those maps — TODO confirm against next_world().
    net_G.voxel.next_world(device, world_dir, checkpoint)

    # NOTE(review): these writes mutate the shared module-level cfg, so
    # concurrent requests could observe each other's settings.
    cam_mode = cfg.inference_args.camera_mode
    cfg.inference_args.cam_maxstep = num_frames
    cfg.inference_args.resolution_hw = [reso_h, reso_w]
    current_outdir = os.path.join(world_dir, 'camera_{:02d}'.format(cam_mode))
    os.makedirs(current_outdir, exist_ok=True)

    # Latent drawn from the seeded CUDA generator so renders are reproducible.
    z = torch.empty(1, net_G.style_dims, dtype=torch.float32, device=device)
    z.normal_(generator=rng_cuda)
    net_G.inference_givenstyle(z, current_outdir, **vars(cfg.inference_args))
    return os.path.join(current_outdir, 'rgb_render.mp4')

# Intro text rendered by gr.Markdown in the left column of the demo.
# (Plain string: the former f-prefix had no placeholders.)
markdown = '''
  # SceneDreamer: Unbounded 3D Scene Generation from 2D Image Collections

  Authored by Zhaoxi Chen, Guangcong Wang, Ziwei Liu

  ### Useful links:
  - [Official Github Repo](https://github.com/FrozenBurning/SceneDreamer)
  - [Project Page](https://scene-dreamer.github.io/)
  - [arXiv Link](https://arxiv.org/abs/2302.01330)

  Licensed under the S-Lab License.


  We offer a sampled scene whose bird's-eye-view (BEV) maps are shown on the right. You can also use the button "Generate BEV" to randomly sample a new 3D world represented by a height map and a semantic map. Note that this takes a long time.

  To render video, push the button "Render" to generate a camera trajectory flying through the world. You can specify rendering options as shown below!
'''

# Gradio UI: intro text and BEV previews/video output on top, rendering
# controls below, and two buttons wired to get_bev / get_video.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(markdown)
        with gr.Column():
            # Initial previews are the BEV maps currently in world_dir (the
            # pre-sampled scene mentioned in the intro text); built from
            # world_dir so --output_dir is honoured.
            with gr.Row():
                with gr.Column():
                    semantic = gr.Image(value=os.path.join(world_dir, 'colormap.png'), type="pil", shape=(512, 512))
                with gr.Column():
                    height = gr.Image(value=os.path.join(world_dir, 'heightmap.png'), type="pil", shape=(512, 512))
            with gr.Row():
                with gr.Column():
                    video = gr.Video()
    with gr.Row():
        num_frames = gr.Slider(minimum=10, maximum=200, value=20, step=1, label='Number of rendered frames')
        user_seed = gr.Slider(minimum=0, maximum=999999, value=8888, step=1, label='Random seed')
        resolution_h = gr.Slider(minimum=256, maximum=2160, value=270, step=1, label='Height of rendered image')
        resolution_w = gr.Slider(minimum=256, maximum=3840, value=480, step=1, label='Width of rendered image')

    with gr.Row():
        btn = gr.Button(value="Generate BEV")
        btn_2 = gr.Button(value="Render")

    btn.click(get_bev, [user_seed], [semantic, height])
    btn_2.click(get_video, [user_seed, num_frames, resolution_h, resolution_w], [video])

demo.launch(debug=True)