Your Name commited on
Commit
8b73ab4
1 Parent(s): cf373a3
Files changed (5) hide show
  1. .gitignore +2 -0
  2. Dockerfile +3 -0
  3. app.py +124 -204
  4. requirements.txt +1 -0
  5. test.py +60 -12
.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  *.pyc
2
  *.pth
 
 
 
1
  *.pyc
2
  *.pth
3
+ *.whl
4
+ *.mp4
Dockerfile CHANGED
@@ -38,6 +38,9 @@ RUN pip install --no-cache-dir --upgrade -r requirements.txt
38
 
39
  RUN wget https://www.dropbox.com/scl/fi/105qy7mkqfjcmnfd3tmv0/edit.pth?rlkey=qcd67cdrqz4jra0p3er966iuk -O clevr.pth
40
 
 
 
 
41
  ENV TORCH_EXTENSIONS_DIR=/home/user/.cache
42
 
43
 
 
38
 
39
  RUN wget https://www.dropbox.com/scl/fi/105qy7mkqfjcmnfd3tmv0/edit.pth?rlkey=qcd67cdrqz4jra0p3er966iuk -O clevr.pth
40
 
41
+ RUN wget https://www.dropbox.com/scl/fi/k5qc5y5rmhuru5eztegbn/gradio_draggable-0.0.1-py3-none-any.whl
42
+ RUN pip install gradio_draggable-0.0.1-py3-none-any.whl
43
+
44
  ENV TORCH_EXTENSIONS_DIR=/home/user/.cache
45
 
46
 
app.py CHANGED
@@ -1,17 +1,20 @@
1
- print('start!', flush=True)
2
  import gradio as gr
3
  from models import build_model
4
  from PIL import Image
5
  import numpy as np
6
  import torchvision
 
7
  import ninja
8
  import torch
9
  from tqdm import trange
10
  import imageio
11
  import requests
12
  import argparse
 
 
 
 
13
 
14
- print('load!', flush=True)
15
  checkpoint = 'clevr.pth'
16
  state = torch.load(checkpoint, map_location='cpu')
17
  G = build_model(**state['model_kwargs_init']['generator_smooth'])
@@ -23,7 +26,25 @@ G_kwargs= dict(noise_mode='const',
23
  fused_modulate=False,
24
  impl='cuda',
25
  fp16_res=None)
26
- print('load finish', flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def trans(x, y, z, length):
29
  w = h = length
@@ -31,8 +52,29 @@ def trans(x, y, z, length):
31
  y = 0.5 * h - 128 + (y/9 + .5) * 256
32
  z = z / 9 * 256
33
  return x, y, z
34
- def get_bev_from_objs(objs, length=256, scale = 6):
35
- h, w = length, length *scale
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  nc = 14
37
  canvas = np.zeros([h, w, nc])
38
  xx = np.ones([h,w]).cumsum(0)
@@ -57,216 +99,94 @@ def get_bev_from_objs(objs, length=256, scale = 6):
57
  mask = ((xx-x)**2 + (y-yy)**2) ** 0.5 <= z
58
  canvas[mask] = feat
59
  canvas = np.transpose(canvas, [2, 0, 1]).astype(np.float32)
60
- rotate_angle = 0
61
- canvas = torchvision.transforms.functional.rotate(torch.tensor(canvas), rotate_angle).numpy()
62
  return canvas
63
-
64
- # COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'brown', 'blue']
65
- COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'purple', 'blue']
66
- SHAPE_NAME_LIST = ['cube', 'sphere', 'cylinder']
67
- MATERIAL_NAME_LIST = ['rubber', 'metal']
68
-
69
- xy_lib = dict()
70
- xy_lib['B'] = [
71
- [-2, -1],
72
- [-1, -1],
73
- [-2, 0],
74
- [-2, 1],
75
- [-1, .5],
76
- [0, 1],
77
- [0, 0],
78
- [0, -1],
79
- [0, 2],
80
- [-1, 2],
81
- [-2, 2]
82
- ]
83
- xy_lib['B'] = [
84
- [-2.5, 1.25],
85
- [-2, 2],
86
- [-2, 0.5],
87
- [-2, -0.75],
88
- [-1, -1],
89
- [-1, 2],
90
- [-1, 0],
91
- [-1, 2],
92
- [0, 1],
93
- [0, 0],
94
- [0, -1],
95
- [0, 2],
96
- # [-1, 2],
97
-
98
- ]
99
- xy_lib['B'] = [
100
- [-2.5, 1.25],
101
- [-2, 2],
102
- [-2, 0.5],
103
- [-2, -1],
104
- [-1, -1.25],
105
- [-1, 2],
106
- [-1, 0],
107
- [-1, 2],
108
- [0, 1],
109
- [0, 0],
110
- [0, -1.25],
111
- [0, 2],
112
- # [-1, 2],
113
-
114
- ]
115
- xy_lib['R'] = [
116
- [0, -1],
117
- [0, 0],
118
- [0, 1],
119
- [0, 2],
120
- [-1, -1],
121
- # [-1, 2],
122
- [-2, -1],
123
- [-2, 0],
124
- [-2.25, 2],
125
- [-1, 1]
126
- ]
127
- xy_lib['C'] = [
128
- [0, -1],
129
- [0, 0],
130
- [0, 1],
131
- [0, 2],
132
- [-1, -1],
133
- [-1, 2],
134
- [-2, -1],
135
- # [-2, .5],
136
- [-2, 2],
137
- # [-1, .5]
138
- ]
139
- xy_lib['s'] = [
140
- [0, -1],
141
- [0, 0],
142
- [0, 2],
143
- [-1, -1],
144
- [-1, 2],
145
- [-2, -1],
146
- [-2, 1],
147
- [-2, 2],
148
- [-1, .5]
149
- ]
150
-
151
- xy_lib['F'] = [
152
- [0, -1],
153
- [0, 0],
154
- [0, 1],
155
- [0, 2],
156
- [-1, -1],
157
- # [-1, 2],
158
- [-2, -1],
159
- [-2, .5],
160
- # [-2, 2],
161
- [-1, .5]
162
- ]
163
-
164
- xy_lib['c'] = [
165
- [0.8,1],
166
- # [-0.8,1],
167
- [0,0.1],
168
- [0,1.9],
169
- ]
170
-
171
- xy_lib['e'] = [
172
- [0, -1],
173
- [0, 0],
174
- [0, 1],
175
- [0, 2],
176
- [-1, -1],
177
- [-1, 2],
178
- [-2, -1],
179
- [-2, .5],
180
- [-2, 2],
181
- [-1, .5]
182
- ]
183
- xy_lib['n'] = [
184
- [0,1],
185
- [0,-1],
186
- [0,0.1],
187
- [0,1.9],
188
- [-1,0],
189
- [-2,1],
190
- [-3,-1],
191
- [-3,1],
192
- [-3,0.1],
193
- [-3,1.9],
194
- ]
195
- offset_x = dict(B=4, R=4, C=4, F=4, c=3, s=4, e=4, n=4.8)
196
- s = 'BeRFsCene'
197
- objs = []
198
- offset = 2
199
- for idx, c in enumerate(s):
200
- xy = xy_lib[c]
201
-
202
-
203
- color = np.random.choice(COLOR_NAME_LIST)
204
- for i in range(len(xy)):
205
- # while 1:
206
- # is_ok = 1
207
- # x, y =
208
-
209
- # for prev_x, prev_y in zip(xpool, ypool):
210
- x, y = xy[i]
211
- y *= 1.5
212
- y -= 0.5
213
- x -= offset
214
- z = 0.35
215
- # if idx<4:
216
- # color = np.random.choice(COLOR_NAME_LIST[:-1])
217
- # else:
218
- # color = 'blue'
219
- shape = 'cube'
220
- material = 'rubber'
221
- rot = 0
222
- objs.append([x, y, z, shape, color, material, rot])
223
- offset += offset_x[c]
224
- Image.fromarray((255 * .8 - get_bev_from_objs(objs)[0] *.8 * 255).astype(np.uint8))
225
-
226
- batch_size = 1
227
- code = torch.randn(1, G.z_dim).cuda()
228
- to_pil = torchvision.transforms.ToPILImage()
229
- large_bevs = torch.tensor(get_bev_from_objs(objs)).cuda()[None]
230
- bevs = large_bevs[..., 0: 0+256]
231
- RT = torch.tensor([[ -1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, -0.8660,
232
- 10.3923, 0.0000, -0.8660, -0.5000, 6.0000, 0.0000, 0.0000,
233
- 0.0000, 1.0000, 262.5000, 0.0000, 32.0000, 0.0000, 262.5000,
234
- 32.0000, 0.0000, 0.0000, 1.0000]], device='cuda')
235
-
236
- print('prepare finish', flush=True)
237
-
238
- def predict(name):
239
- print('inference', name, flush=True)
240
  gen = G(code, RT, bevs)
241
  rgb = gen['gen_output']['image'][0] * .5 + .5
242
- print('inference', name, flush=True)
243
  return to_pil(rgb)
244
 
245
- # to_pil(rgb).save('tmp.png')
246
- # save_path = '/mnt/petrelfs/zhangqihang/code/3d-scene-gen/tmp.png'
247
- # return [save_path]
248
-
249
- URL = "https://source.unsplash.com/random/500x500/?nature,fruit"
250
-
251
- def refresh(name):
252
- image = Image.open(requests.get(URL, stream=True).raw)
253
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
  with gr.Blocks() as demo:
256
- gr.HTML(
 
 
 
 
 
 
 
 
257
  """
258
- BerfScene demo
259
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
- # gallery = gr.Image(show_label=False)
262
- image = gr.Image(show_label=False)
263
- btn = gr.Button("Result")
264
 
265
- x = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
266
- btn.click(fn=predict, inputs=x, outputs=image)
267
- # demo.load(fn=refresh, inputs=x, outputs=gallery, show_progress=False, every=1)
268
 
269
- # btn.click(fn=predict, inputs=num_frames, outputs=gallery, postprocess=False)
270
 
271
  parser = argparse.ArgumentParser()
272
  parser.add_argument('--port', type=int, help='The port number', default=7860)
 
 
1
  import gradio as gr
2
  from models import build_model
3
  from PIL import Image
4
  import numpy as np
5
  import torchvision
6
+ import math
7
  import ninja
8
  import torch
9
  from tqdm import trange
10
  import imageio
11
  import requests
12
  import argparse
13
+ import imageio
14
+ from scipy.spatial.transform import Rotation
15
+
16
+ from gradio_draggable import Draggable
17
 
 
18
  checkpoint = 'clevr.pth'
19
  state = torch.load(checkpoint, map_location='cpu')
20
  G = build_model(**state['model_kwargs_init']['generator_smooth'])
 
26
  fused_modulate=False,
27
  impl='cuda',
28
  fp16_res=None)
29
+ print('prepare finish', flush=True)
30
+
31
+
32
+ COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'purple', 'blue']
33
+ SHAPE_NAME_LIST = ['cube', 'sphere', 'cylinder']
34
+ MATERIAL_NAME_LIST = ['rubber', 'metal']
35
+
36
+ canvas_x = 800
37
+ canvas_y = 200
38
+ batch_size = 1
39
+ code = torch.randn(1, G.z_dim).cuda()
40
+ to_pil = torchvision.transforms.ToPILImage()
41
+
42
+ RT = torch.tensor([[ -1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, -0.8660,
43
+ 10.3923, 0.0000, -0.8660, -0.5000, 6.0000, 0.0000, 0.0000,
44
+ 0.0000, 1.0000, 262.5000, 0.0000, 32.0000, 0.0000, 262.5000,
45
+ 32.0000, 0.0000, 0.0000, 1.0000]], device='cuda')
46
+
47
+ obj_dict = {}
48
 
49
  def trans(x, y, z, length):
50
  w = h = length
 
52
  y = 0.5 * h - 128 + (y/9 + .5) * 256
53
  z = z / 9 * 256
54
  return x, y, z
55
+
56
+ def objs_to_canvas(lst, length=256, scale = 2.6):
57
+ objs = []
58
+ for each in lst:
59
+ x, y, obj_id = each['x'], each['y'], each['id']
60
+
61
+ if obj_id not in obj_dict:
62
+ color = np.random.choice(COLOR_NAME_LIST)
63
+ shape = 'cube'
64
+ material = 'rubber'
65
+ rot = 0
66
+ obj_dict[obj_id] = [color, shape, material, rot]
67
+
68
+ color, shape, material, rot = obj_dict[obj_id]
69
+ x = -x / canvas_x * 16
70
+ y = y / canvas_y * 2
71
+ y *= 2
72
+ x += 1.0
73
+ y -= 1.5
74
+ z = 0.35
75
+ objs.append([x, y, z, shape, color, material, rot])
76
+
77
+ h, w = length, int(length *scale)
78
  nc = 14
79
  canvas = np.zeros([h, w, nc])
80
  xx = np.ones([h,w]).cumsum(0)
 
99
  mask = ((xx-x)**2 + (y-yy)**2) ** 0.5 <= z
100
  canvas[mask] = feat
101
  canvas = np.transpose(canvas, [2, 0, 1]).astype(np.float32)
 
 
102
  return canvas
103
+
104
+ @torch.no_grad()
105
+ def predict_local_view(lst):
106
+ canvas = torch.tensor(objs_to_canvas(lst)).cuda()[None]
107
+ bevs = canvas[..., 0: 0+256]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  gen = G(code, RT, bevs)
109
  rgb = gen['gen_output']['image'][0] * .5 + .5
 
110
  return to_pil(rgb)
111
 
112
+ @torch.no_grad()
113
+ def predict_local_view_video(lst):
114
+ canvas = torch.tensor(objs_to_canvas(lst)).cuda()[None]
115
+ bevs = canvas[..., 0: 0+256]
116
+ RT_array = np.array(RT[0].cpu())
117
+ rot = RT_array[:16].reshape(4,4)
118
+ trans = RT_array[16:]
119
+ rot_new = rot.copy()
120
+ r = Rotation.from_matrix(rot[:3, :3])
121
+ angles = r.as_euler("zyx",degrees=True)
122
+ v_mean, h_mean = angles[1], angles[2]
123
+
124
+ writer = imageio.get_writer('tmp.mp4', fps=25)
125
+ for t in np.linspace(0, 1, 50):
126
+ angles[1] = 0.5 * np.cos(t * 2 * math.pi) + v_mean
127
+ angles[2] = 1 * np.sin(t * 2 * math.pi) + h_mean
128
+ r = Rotation.from_euler("zyx",angles,degrees=True)
129
+ rot_new[:3,:3] = r.as_matrix()
130
+ new_RT = torch.tensor(np.concatenate([rot_new.flatten(), trans])[None]).cuda().float()
131
+ gen = G(code, new_RT, bevs)
132
+ rgb = gen['gen_output']['image'][0] * .5 + .5
133
+ writer.append_data(np.array(to_pil(rgb)))
134
+ writer.close()
135
+ return 'tmp.mp4'
136
+
137
+ @torch.no_grad()
138
+ def predict_global_view(lst):
139
+ canvas = torch.tensor(objs_to_canvas(lst)).cuda()[None]
140
+ length = canvas.shape[-1]
141
+ lines = []
142
+ for i in trange(0, length - 256, 10):
143
+ bevs = canvas[..., i: i+256]
144
+ gen = G(code, RT, bevs)
145
+ start = 128 if i > 0 else 0
146
+ lines.append(gen['gen_output']['image'][0, ..., start:128+32])
147
+ rgb = torch.cat(lines, 2)*.5+.5
148
+ return to_pil(rgb)
149
 
150
  with gr.Blocks() as demo:
151
+ gr.Markdown(
152
+ """
153
+ # BerfScene: Bev-conditioned Equivariant Radiance Fields for Infinite 3D Scene Generation
154
+ Qihang Zhang, Yinghao Xu, Yujun Shen, Bo Dai, Bolei Zhou*, Ceyuan Yang* (*Corresponding Author)<br>
155
+ [Arxiv Report](https://arxiv.org/abs/2312.02136) | [Project Page](https://zqh0253.github.io/BerfScene/) | [Github](https://github.com/zqh0253/BerfScene)
156
+ """
157
+ )
158
+
159
+ gr.Markdown(
160
  """
161
+ ### Quick Start
162
+ 1. Drag and place objects in the canvas.
163
+ 2. Click `Add object` to insert object into the canvas.
164
+ 3. Click `Reset` to clean the canvas.
165
+ 4. Click `Get local view` to synthesize local 3D scenes.
166
+ 5. Click `Get global view` to synthesize global 3D scenes.
167
+ """
168
+ )
169
+
170
+ with gr.Row():
171
+ with gr.Column():
172
+
173
+ drag = Draggable()
174
+ with gr.Row():
175
+ submit_btn_local = gr.Button("Get local view", variant='primary')
176
+ submit_btn_global = gr.Button("Get global view", variant='primary')
177
+
178
+ with gr.Column():
179
+ with gr.Row():
180
+ single_view_image = gr.Image(label='single view', interactive=False)
181
+ single_view_video = gr.Video(label='mutli-view', interactive=False, autoplay=True)
182
+
183
+ global_view_image = gr.Image(label='global view', interactive=False)
184
 
 
 
 
185
 
186
+ submit_btn_local.click(fn=predict_local_view, inputs=drag, outputs=single_view_image)
187
+ submit_btn_local.click(fn=predict_local_view_video, inputs=drag, outputs=single_view_video)
188
+ submit_btn_global.click(fn=predict_global_view, inputs=drag, outputs=global_view_image)
189
 
 
190
 
191
  parser = argparse.ArgumentParser()
192
  parser.add_argument('--port', type=int, help='The port number', default=7860)
requirements.txt CHANGED
@@ -19,4 +19,5 @@ lmdb
19
  matplotlib
20
  einops
21
  imageio
 
22
  gradio
 
19
  matplotlib
20
  einops
21
  imageio
22
+ imageio-ffmpeg
23
  gradio
test.py CHANGED
@@ -1,18 +1,66 @@
1
  import gradio as gr
2
- import requests
3
- from PIL import Image
4
- URL = "https://source.unsplash.com/random/500x500/?nature,fruit"
5
 
 
 
 
6
 
7
- def refresh():
8
- image = Image.open(requests.get(URL, stream=True).raw)
9
- return image
10
 
 
 
 
 
11
 
12
- with gr.Blocks() as blocks:
13
- image = gr.Image(show_label=False)
14
- blocks.load(fn=refresh, inputs=None, outputs=image,
15
- show_progress=False, every=1)
16
 
17
- blocks.queue(api_open=False)
18
- blocks.launch(server_name='0.0.0.0', server_port=10093)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
2
 
3
+ def update_position(data):
4
+ # data will be the position of the rectangle, expected to be a JSON string
5
+ return data # Here you can parse and use the position data as needed
6
 
7
+ html_code = """
8
+ <div id="canvas-container"></div>
 
9
 
10
+ <script>
11
+ document.getElementById('canvas-container').innerHTML = `
12
+ <canvas id="canvas" width="500" height="500"></canvas>
13
+ `;
14
 
15
+ const canvas = document.getElementById('canvas');
16
+ const ctx = canvas.getContext('2d');
17
+ const rect = { x: 50, y: 50, width: 100, height: 50, isDragging: false };
 
18
 
19
+ function draw() {
20
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
21
+ ctx.fillStyle = 'blue';
22
+ ctx.fillRect(rect.x, rect.y, rect.width, rect.height);
23
+ }
24
+
25
+ function sendData() {
26
+ GradioApp.send({x: rect.x, y: rect.y});
27
+ }
28
+
29
+ function mouseDown(e) {
30
+ if (e.offsetX >= rect.x && e.offsetX <= rect.x + rect.width &&
31
+ e.offsetY >= rect.y && e.offsetY <= rect.y + rect.height) {
32
+ rect.isDragging = true;
33
+ }
34
+ }
35
+
36
+ function mouseMove(e) {
37
+ if (rect.isDragging) {
38
+ rect.x = e.offsetX - rect.width / 2;
39
+ rect.y = e.offsetY - rect.height / 2;
40
+ draw();
41
+ sendData();
42
+ }
43
+ }
44
+
45
+ function mouseUp() {
46
+ rect.isDragging = false;
47
+ sendData();
48
+ }
49
+
50
+ canvas.addEventListener('mousedown', mouseDown);
51
+ canvas.addEventListener('mousemove', mouseMove);
52
+ canvas.addEventListener('mouseup', mouseUp);
53
+
54
+ draw();
55
+ </script>
56
+ """
57
+
58
+ interface = gr.Interface(
59
+ fn=update_position,
60
+ inputs=gr.HTML(),
61
+ outputs="json",
62
+ allow_flagging="never",
63
+ live=True
64
+ )
65
+
66
+ interface.launch(server_name='0.0.0.0', server_port=7860)