FrozenBurning committed on
Commit a699001
1 Parent(s): d199afa

Create app.py

Files changed (1)
  1. app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
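+ # Gradio demo for SceneDreamer: fetches the released checkpoint, builds the
+ # frozen generator, and exposes BEV world generation plus fly-through rendering.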
+ import sys
+ import os
+ import html
+ import glob
+ import uuid
+ import hashlib
+ import requests
+ from tqdm import tqdm
+
+ os.system("git clone https://github.com/FrozenBurning/SceneDreamer.git")
+
+ import torch
+
+ pretrained_model = dict(file_url='https://drive.google.com/uc?id=1IFu1vNrgF1EaRqPizyEgN_5Vt7Fyg0Mj',
+                         alt_url='', file_size=330571863, file_md5='13b7ae859b28b37479ec84f1449d07fc7',
+                         file_path='./scenedreamer_released.pt',)
+
+
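+ # Download helper: streams the file in chunks with a tqdm progress bar, retries
+ # up to `num_attempts` times, validates size/MD5, and follows Google Drive's
+ # "confirm" link when the response is the small virus-scan interstitial page.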
+ def download_file(session, file_spec, use_alt_url=False, chunk_size=128, num_attempts=10):
+     file_path = file_spec['file_path']
+     if use_alt_url:
+         file_url = file_spec['alt_url']
+     else:
+         file_url = file_spec['file_url']
+
+     file_dir = os.path.dirname(file_path)
+     tmp_path = file_path + '.tmp.' + uuid.uuid4().hex
+     if file_dir:
+         os.makedirs(file_dir, exist_ok=True)
+
+     progress_bar = tqdm(total=file_spec['file_size'], unit='B', unit_scale=True)
+     for attempts_left in reversed(range(num_attempts)):
+         data_size = 0
+         progress_bar.reset()
+         try:
+             # Download.
+             data_md5 = hashlib.md5()
+             with session.get(file_url, stream=True) as res:
+                 res.raise_for_status()
+                 with open(tmp_path, 'wb') as f:
+                     for chunk in res.iter_content(chunk_size=chunk_size << 10):
+                         progress_bar.update(len(chunk))
+                         f.write(chunk)
+                         data_size += len(chunk)
+                         data_md5.update(chunk)
+
+             # Validate.
+             if 'file_size' in file_spec and data_size != file_spec['file_size']:
+                 raise IOError('Incorrect file size', file_path)
+             if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']:
+                 raise IOError('Incorrect file MD5', file_path)
+             break
+
+         except Exception as e:
+             # print(e)
+             # Last attempt => raise error.
+             if not attempts_left:
+                 raise
+
+             # Handle Google Drive virus checker nag.
+             if data_size > 0 and data_size < 8192:
+                 with open(tmp_path, 'rb') as f:
+                     data = f.read()
+                 links = [html.unescape(link) for link in data.decode('utf-8').split('"') if 'confirm=t' in link]
+                 if len(links) == 1:
+                     file_url = requests.compat.urljoin(file_url, links[0])
+                     continue
+
+     progress_bar.close()
+
+     # Rename temp file to the correct name.
+     os.replace(tmp_path, file_path)  # atomic
+
+     # Attempt to clean up any leftover temps.
+     for filename in glob.glob(file_path + '.tmp.*'):
+         try:
+             os.remove(filename)
+         except:
+             pass
+
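+ # Fetch the released checkpoint once at start-up. A failure is only printed here;
+ # torch.load below will then fail if the checkpoint file is missing.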
+ print('Downloading SceneDreamer pretrained model...')
+ with requests.Session() as session:
+     try:
+         download_file(session, pretrained_model)
+     except:
+         print('Google Drive download failed.\n')
+
+
+
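+ # Inference-side imports. These assume the imaginaire package from the cloned
+ # SceneDreamer repo is importable in this environment (e.g. already on sys.path).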
+ import os
+ import torch
+ import argparse
+ from imaginaire.config import Config
+ from imaginaire.utils.cudnn import init_cudnn
+ from imaginaire.utils.dataset import get_test_dataloader
+ from imaginaire.utils.distributed import init_dist
+ from imaginaire.utils.gpu_affinity import set_affinity
+ from imaginaire.utils.io import get_checkpoint as get_checkpoint
+ from imaginaire.utils.logging import init_logging
+ from imaginaire.utils.trainer import \
+     (get_model_optimizer_and_scheduler, set_random_seed)
+ import imaginaire.config
+ import gradio as gr
+ from PIL import Image
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='Training')
+     parser.add_argument('--config', type=str, default='./configs/scenedreamer_inference.yaml',
+                         help='Path to the training config file.')
+     parser.add_argument('--checkpoint', default='./scenedreamer_released.pt',
+                         help='Checkpoint path.')
+     parser.add_argument('--output_dir', type=str, default='./test/',
+                         help='Location to save the image outputs')
+     parser.add_argument('--seed', type=int, default=8888,
+                         help='Random seed.')
+     args = parser.parse_args()
+     return args
+
+
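+ # One-time setup: read the inference config, build the generator, restore the
+ # released weights, and freeze the model for inference.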
+ args = parse_args()
+ set_random_seed(args.seed, by_rank=False)
+ cfg = Config(args.config)
+
+ # Initialize cudnn.
+ init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)
+
+ # Initialize data loaders and models.
+ net_G = get_model_optimizer_and_scheduler(cfg, seed=args.seed, generator_only=True)
+
+ if args.checkpoint == '':
+     raise NotImplementedError("No checkpoint is provided for inference!")
+
+ # Load checkpoint.
+ # trainer.load_checkpoint(cfg, args.checkpoint)
+ checkpoint = torch.load(args.checkpoint, map_location='cpu')
+ net_G.load_state_dict(checkpoint['net_G'])
+
+ # Do inference.
+ net_G = net_G.module
+ net_G.eval()
+ for name, param in net_G.named_parameters():
+     param.requires_grad = False
+ torch.cuda.empty_cache()
+ world_dir = os.path.join(args.output_dir)
+ os.makedirs(world_dir, exist_ok=True)
+
+
+
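+ # Demo step 1: procedurally generate a BEV scene representation (semantic map
+ # + height map) for the given seed by shelling out to terrain_generator.py.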
+ def get_bev(seed):
+     print('[PCGGenerator] Generating BEV scene representation...')
+     os.system('python terrain_generator.py --size {} --seed {} --outdir {}'.format(net_G.voxel.sample_size, seed, world_dir))
+     heightmap_path = os.path.join(world_dir, 'heightmap.png')
+     semantic_path = os.path.join(world_dir, 'semanticmap.png')
+     heightmap = Image.open(heightmap_path)
+     semantic = Image.open(semantic_path)
+     return semantic, heightmap
+
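+ # Demo step 2: seed the RNGs, rebuild the voxel world, sample a style code, and
+ # render a camera fly-through; returns the path of the rendered video.
+ # Note: `num_frames` comes from the slider but is not used inside this function.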
+ def get_video(seed, num_frames):
+     device = torch.device('cuda')
+     rng_cuda = torch.Generator(device=device)
+     rng_cuda = rng_cuda.manual_seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     net_G.voxel.next_world(device, world_dir, checkpoint)
+     cam_mode = cfg.inference_args.camera_mode
+     current_outdir = os.path.join(world_dir, 'camera_{:02d}'.format(cam_mode))
+     os.makedirs(current_outdir, exist_ok=True)
+     z = torch.empty(1, net_G.style_dims, dtype=torch.float32, device=device)
+     z.normal_(generator=rng_cuda)
+     net_G.inference_givenstyle(z, current_outdir, **vars(cfg.inference_args))
+     return os.path.join(current_outdir, 'rgb_render.mp4')
+
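+ # Description panel shown in the left column of the UI.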
+ markdown = f'''
+ # SceneDreamer: Unbounded 3D Scene Generation from 2D Image Collections
+
+ Authored by Zhaoxi Chen, Guangcong Wang, Ziwei Liu
+
+ ### Useful links:
+ - [Official Github Repo](https://github.com/FrozenBurning/SceneDreamer)
+ - [Project Page](https://scene-dreamer.github.io/)
+ - [arXiv Link](https://arxiv.org/abs/2302.01330)
+
+ Licensed under the S-Lab License.
+
+ First use the button "Generate BEV" to randomly sample a 3D world represented by a height map and a semantic map. Then push the button "Render" to generate a camera trajectory flying through the world.
+ '''
+
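+ # Gradio UI: description on the left; semantic/height maps, rendered video,
+ # sliders, and the "Generate BEV" / "Render" buttons on the right.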
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown(markdown)
+         with gr.Column():
+             with gr.Row():
+                 with gr.Column():
+                     semantic = gr.Image(type="pil", shape=(2048, 2048))
+                 with gr.Column():
+                     height = gr.Image(type="pil", shape=(2048, 2048))
+             with gr.Row():
+                 # with gr.Column():
+                 #     image = gr.Image(type='pil', shape=(540, 960))
+                 with gr.Column():
+                     video = gr.Video()
+             with gr.Row():
+                 num_frames = gr.Slider(minimum=40, maximum=200, value=40, label='Number of frames for video generation')
+                 user_seed = gr.Slider(minimum=0, maximum=999999, value=8888, label='Random seed to control styles and scenes')
+
+             with gr.Row():
+                 btn = gr.Button(value="Generate BEV")
+                 btn_2 = gr.Button(value="Render")
+
+     btn.click(get_bev, [user_seed], [semantic, height])
+     btn_2.click(get_video, [user_seed, num_frames], [video])
+
+ demo.launch(debug=True)