silentchen committed on
Commit
5d94b0f
1 Parent(s): 4533120

update space

Browse files
Files changed (1) hide show
  1. app.py +158 -155
app.py CHANGED
@@ -62,158 +62,6 @@ class Blocks(gr.Blocks):
62
 
63
  return config
64
 
65
- @torch.no_grad()
66
- def optimize_all(xm, models, initial_noise, noise_start_t, diffusion, latent_model, device, prompt, instruction, rand_seed):
67
- state = {}
68
- out_gen_1, out_gen_2, out_gen_3, out_gen_4, state = generate_3d_with_shap_e(xm, diffusion, latent_model, device, prompt, rand_seed, state)
69
- edited_1, edited_2, edited_3, edited_4, state = _3d_editing(xm, models, diffusion, initial_noise, noise_start_t, device, instruction, rand_seed, state)
70
- print(state)
71
- return out_gen_1, out_gen_2, out_gen_3, out_gen_4, edited_1, edited_2, edited_3, edited_4
72
-
73
-
74
- @spaces.GPU()
75
- @torch.no_grad()
76
- def generate_3d_with_shap_e(xm, diffusion, latent_model, device, prompt, rand_seed, state):
77
- print("Check if I can use partial")
78
- set_seed(rand_seed)
79
- batch_size = 4
80
- guidance_scale = 15.0
81
- xm.renderer.volume.bbox_max = torch.tensor([1.0, 1.0, 1.0]).to(device)
82
- xm.renderer.volume.bbox_min = torch.tensor([-1.0, -1.0, -1.0]).to(device)
83
- xm.renderer.volume.bbox = torch.stack([xm.renderer.volume.bbox_min, xm.renderer.volume.bbox_max])
84
-
85
- print("prompt: ", prompt, "rand_seed: ", rand_seed, "state:", state)
86
- latents = sample_latents(
87
- batch_size=batch_size,
88
- model=latent_model,
89
- diffusion=diffusion,
90
- guidance_scale=guidance_scale,
91
- model_kwargs=dict(texts=[prompt] * batch_size),
92
- progress=True,
93
- clip_denoised=True,
94
- use_fp16=True,
95
- use_karras=True,
96
- karras_steps=64,
97
- sigma_min=1e-3,
98
- sigma_max=160,
99
- s_churn=0,
100
- )
101
- prompt_hash = str(hashlib.sha256((prompt + '_' + str(rand_seed)).encode('utf-8')).hexdigest())
102
- mesh_path = []
103
- output_path = './logs'
104
- os.makedirs(os.path.join(output_path, 'source'), exist_ok=True)
105
- state['latent'] = []
106
- state['prompt'] = prompt
107
- state['rand_seed_1'] = rand_seed
108
- for i, latent in enumerate(latents):
109
-
110
- output_path_tmp = os.path.join(output_path, 'source', '{}_{}.obj'.format(prompt_hash, i))
111
- t_obj = decode_latent_mesh(xm, latent).tri_mesh()
112
- with open(output_path_tmp, 'w') as f:
113
- t_obj.write_obj(f)
114
-
115
- mesh = trimesh.load_mesh(output_path_tmp)
116
- angle = np.radians(180)
117
- axis = [0, 1, 0]
118
- rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis)
119
- mesh.apply_transform(rotation_matrix)
120
- angle = np.radians(90)
121
- axis = [1, 0, 0]
122
- rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis)
123
- mesh.apply_transform(rotation_matrix)
124
- output_path_tmp = os.path.join(output_path, 'source', '{}_{}.obj'.format(prompt_hash, i))
125
- mesh.export(output_path_tmp)
126
- state['latent'].append(latent.clone().detach().cpu())
127
- mesh_path.append(output_path_tmp)
128
- del latents
129
- return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state
130
-
131
- @spaces.GPU()
132
- @torch.no_grad()
133
- def _3d_editing(xm, models, diffusion, initial_noise, start_t, device, instruction, rand_seed, state):
134
- set_seed(rand_seed)
135
- mesh_path = []
136
- prompt = state['prompt']
137
- rand_seed_1 = state['rand_seed_1']
138
- print("prompt: ", prompt, "rand_seed: ", rand_seed, "instruction:", instruction, "state:", state)
139
- prompt_hash = str(hashlib.sha256((prompt + '_' + str(rand_seed_1) + '_' + instruction + '_' + str(rand_seed)).encode('utf-8')).hexdigest())
140
- if 'santa' in instruction:
141
- e_type = 'santa_hat'
142
- elif 'rainbow' in instruction:
143
- e_type = 'rainbow'
144
- elif 'gold' in instruction:
145
- e_type = 'golden'
146
- elif 'lego' in instruction:
147
- e_type = 'lego'
148
- elif 'wooden' in instruction:
149
- e_type = 'wooden'
150
- elif 'cyber' in instruction:
151
- e_type = 'cyber'
152
-
153
- model = load_model('text300M', device=device)
154
- with torch.no_grad():
155
- new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=model.wrapped.input_proj.weight.dtype)
156
- new_proj.weight = nn.Parameter(torch.zeros_like(new_proj.weight))
157
- new_proj.weight[:, :1024].copy_(model.wrapped.input_proj.weight) #
158
- new_proj.bias = nn.Parameter(torch.zeros_like(new_proj.bias))
159
- new_proj.bias[:1024].copy_(model.wrapped.input_proj.bias)
160
- model.wrapped.input_proj = new_proj
161
-
162
- ckp = torch.load(hf_hub_download(repo_id='silentchen/Shap_Editor', subfolder='single', filename='{}.pt'.format(e_type)), map_location='cpu')
163
- model.load_state_dict(ckp['model'])
164
-
165
- noise_initial = initial_noise[e_type].to(device)
166
- noise_start_t = start_t[e_type]
167
- general_save_path = './logs/edited'
168
- os.makedirs(general_save_path, exist_ok=True)
169
- for i, latent in enumerate(state['latent']):
170
- latent = latent.to(device)
171
- text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction]))
172
- print("shape of latent: ", latent.clone().unsqueeze(0).shape, "instruction: ", instruction)
173
- ref_latent = latent.clone().unsqueeze(0)
174
- t_1 = torch.randint(noise_start_t, noise_start_t + 1, (1,), device=device).long()
175
-
176
- noise_input = diffusion.q_sample(ref_latent, t_1, noise=noise_initial)
177
- out_1 = diffusion.p_mean_variance(model, noise_input, t_1, clip_denoised=True,
178
- model_kwargs=text_embeddings_clip,
179
- condition_latents=ref_latent)
180
-
181
- updated_latents = out_1['pred_xstart']
182
-
183
- if 'santa' in instruction:
184
- xm.renderer.volume.bbox_max = torch.tensor([1.0, 1.0, 1.25]).to(device)
185
- xm.renderer.volume.bbox_min = torch.tensor([-1.0, -1.0, -1]).to(device)
186
- xm.renderer.volume.bbox = torch.stack([xm.renderer.volume.bbox_min, xm.renderer.volume.bbox_max])
187
-
188
- else:
189
- xm.renderer.volume.bbox_max = torch.tensor([1.0, 1.0, 1.0]).to(device)
190
- xm.renderer.volume.bbox_min = torch.tensor([-1.0, -1.0, -1.0]).to(device)
191
- xm.renderer.volume.bbox = torch.stack([xm.renderer.volume.bbox_min, xm.renderer.volume.bbox_max])
192
-
193
- for latent_idx, updated_latent in enumerate(updated_latents):
194
- output_path = os.path.join(general_save_path, '{}_{}.obj'.format(prompt_hash, i))
195
-
196
- t = decode_latent_mesh(xm, updated_latent).tri_mesh()
197
- with open(output_path, 'w') as f:
198
- t.write_obj(f)
199
- mesh = trimesh.load_mesh(output_path)
200
-
201
- angle = np.radians(180)
202
- axis = [0, 1, 0]
203
-
204
- rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis)
205
- mesh.apply_transform(rotation_matrix)
206
- angle = np.radians(90)
207
- axis = [1, 0, 0]
208
-
209
- rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis)
210
- mesh.apply_transform(rotation_matrix)
211
-
212
- output_path = os.path.join(general_save_path, '{}_{}.obj'.format(prompt_hash, i))
213
- mesh.export(output_path)
214
- mesh_path.append(output_path)
215
- return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state
216
-
217
  def main():
218
 
219
  css = """
@@ -320,6 +168,161 @@ def main():
320
  initial_noise[editing_type] = noise_initial
321
  noise_start_t[editing_type] = ckp['t_start']
322
  models[editing_type] = tmp_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
  del models
325
  models = None
@@ -388,13 +391,13 @@ def main():
388
  rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random seed")
389
 
390
  gen_btn.click(
391
- fn=partial(generate_3d_with_shap_e, xm, diffusion, latent_model, device),
392
  inputs=[prompt, rand_seed, state],
393
  outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
394
  queue=False)
395
 
396
  apply_btn.click(
397
- fn=partial(_3d_editing, xm, models, diffusion, initial_noise, noise_start_t, device),
398
  inputs=[
399
  editing_choice[0], rand_seed, state
400
  ],
@@ -416,7 +419,7 @@ def main():
416
  ],
417
  inputs=[prompt, editing_choice[0], rand_seed],
418
  outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, edited_1, edited_2, edited_3, edited_4],
419
- fn=partial(optimize_all, xm, models, initial_noise, noise_start_t, diffusion, latent_model, device),
420
  cache_examples=True,
421
  )
422
 
 
62
 
63
  return config
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def main():
66
 
67
  css = """
 
168
  initial_noise[editing_type] = noise_initial
169
  noise_start_t[editing_type] = ckp['t_start']
170
  models[editing_type] = tmp_model
171
@torch.no_grad()
def optimize_all(prompt, instruction, rand_seed):
    """Run generation followed by editing in one call.

    Starts from an empty state dict, passes it through
    ``generate_3d_with_shap_e`` and then ``_3d_editing`` (both with the same
    ``rand_seed``), and prints the final state.

    Args:
        prompt: Text prompt forwarded to ``generate_3d_with_shap_e``.
        instruction: Editing instruction forwarded to ``_3d_editing``.
        rand_seed: Seed forwarded to both stages.

    Returns:
        An 8-tuple: the four results of the generation stage followed by the
        four results of the editing stage. The state dict is not returned.
    """
    print("Optimizing all")
    # Each stage returns (result_1, result_2, result_3, result_4, state).
    *generated, session = generate_3d_with_shap_e(prompt, rand_seed, {})
    *edited, session = _3d_editing(instruction, rand_seed, session)
    print(session)
    return (*generated, *edited)
180
+
181
@spaces.GPU()
@torch.no_grad()
def generate_3d_with_shap_e(prompt, rand_seed, state):
    """Sample four latents for ``prompt`` and export each one as an OBJ mesh.

    Reads ``xm``, ``device``, ``latent_model`` and ``diffusion`` from the
    enclosing scope. Files are written to ``./logs/source`` and named after a
    SHA-256 digest of ``prompt`` and ``rand_seed`` plus the sample index.

    Args:
        prompt: Text prompt used to condition sampling.
        rand_seed: Seed passed to ``set_seed`` and mixed into the file names.
        state: Dict updated in place with ``'latent'`` (list of CPU copies of
            the sampled latents), ``'prompt'`` and ``'rand_seed_1'``.

    Returns:
        A 5-tuple: the four exported mesh paths followed by ``state``.
    """
    print("Check if I can use partial")
    set_seed(rand_seed)
    batch_size = 4
    guidance_scale = 15.0

    # Reset the renderer volume to the [-1, 1] cube on every axis.
    volume = xm.renderer.volume
    volume.bbox_max = torch.tensor([1.0, 1.0, 1.0]).to(device)
    volume.bbox_min = torch.tensor([-1.0, -1.0, -1.0]).to(device)
    volume.bbox = torch.stack([volume.bbox_min, volume.bbox_max])

    print("prompt: ", prompt, "rand_seed: ", rand_seed, "state:", state)
    sampler_options = dict(
        batch_size=batch_size,
        model=latent_model,
        diffusion=diffusion,
        guidance_scale=guidance_scale,
        model_kwargs=dict(texts=[prompt] * batch_size),
        progress=True,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )
    latents = sample_latents(**sampler_options)

    digest = hashlib.sha256((prompt + '_' + str(rand_seed)).encode('utf-8')).hexdigest()
    source_dir = os.path.join('./logs', 'source')
    os.makedirs(source_dir, exist_ok=True)

    state['latent'] = []
    state['prompt'] = prompt
    state['rand_seed_1'] = rand_seed

    # Applied in this order: 180 degrees about Y, then 90 degrees about X.
    rotations = ((180, [0, 1, 0]), (90, [1, 0, 0]))

    mesh_path = []
    for idx, latent in enumerate(latents):
        obj_path = os.path.join(source_dir, '{}_{}.obj'.format(digest, idx))

        # Decode before opening the file so a decode failure leaves no file.
        decoded = decode_latent_mesh(xm, latent).tri_mesh()
        with open(obj_path, 'w') as handle:
            decoded.write_obj(handle)

        # Reload through trimesh, rotate, and overwrite the same file.
        mesh = trimesh.load_mesh(obj_path)
        for degrees, axis in rotations:
            mesh.apply_transform(
                trimesh.transformations.rotation_matrix(np.radians(degrees), axis)
            )
        mesh.export(obj_path)

        state['latent'].append(latent.clone().detach().cpu())
        mesh_path.append(obj_path)

    del latents
    return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state
236
+
237
@spaces.GPU()
@torch.no_grad()
def _3d_editing(instruction, rand_seed, state):
    """Apply a text-driven edit to the latents stored in ``state``.

    Reads ``xm``, ``device``, ``diffusion``, ``initial_noise`` and
    ``noise_start_t`` from the enclosing scope. The editing type is chosen by
    keyword match on ``instruction``. The matching checkpoint is downloaded
    from the ``silentchen/Shap_Editor`` hub repo, and each edited latent is
    exported as an OBJ mesh under ``./logs/edited``.

    Args:
        instruction: Editing instruction. Must contain one of the keywords
            ``santa``, ``rainbow``, ``gold``, ``lego``, ``wooden``, ``cyber``.
        rand_seed: Seed passed to ``set_seed`` and mixed into the file names.
        state: Dict with ``'prompt'``, ``'rand_seed_1'`` and ``'latent'``, as
            filled in by ``generate_3d_with_shap_e``.

    Returns:
        A 5-tuple: the first four edited mesh paths followed by ``state``.

    Raises:
        ValueError: If ``instruction`` matches no known editing keyword.
        KeyError: If ``state`` lacks one of the required keys.
    """
    set_seed(rand_seed)
    mesh_path = []
    prompt = state['prompt']
    rand_seed_1 = state['rand_seed_1']
    print("prompt: ", prompt, "rand_seed: ", rand_seed, "instruction:", instruction, "state:", state)
    prompt_hash = hashlib.sha256(
        (prompt + '_' + str(rand_seed_1) + '_' + instruction + '_' + str(rand_seed)).encode('utf-8')
    ).hexdigest()

    # Keyword -> checkpoint name; the first matching keyword wins.
    edit_keywords = (
        ('santa', 'santa_hat'),
        ('rainbow', 'rainbow'),
        ('gold', 'golden'),
        ('lego', 'lego'),
        ('wooden', 'wooden'),
        ('cyber', 'cyber'),
    )
    e_type = next((name for keyword, name in edit_keywords if keyword in instruction), None)
    if e_type is None:
        # Fail before loading any model rather than with an unbound-name
        # error later on.
        raise ValueError(
            'Unsupported editing instruction: {!r}; expected one containing any of {}'.format(
                instruction, ', '.join(keyword for keyword, _ in edit_keywords)))

    model = load_model('text300M', device=device)
    with torch.no_grad():
        # Widen the input projection from 1024 to 2048 input features. The
        # original weights fill the first 1024 columns and the rest start at
        # zero, before the edit checkpoint is loaded on top.
        old_proj = model.wrapped.input_proj
        new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=old_proj.weight.dtype)
        new_proj.weight = nn.Parameter(torch.zeros_like(new_proj.weight))
        new_proj.weight[:, :1024].copy_(old_proj.weight)
        new_proj.bias = nn.Parameter(torch.zeros_like(new_proj.bias))
        new_proj.bias[:1024].copy_(old_proj.bias)
        model.wrapped.input_proj = new_proj

    ckp = torch.load(
        hf_hub_download(repo_id='silentchen/Shap_Editor', subfolder='single', filename='{}.pt'.format(e_type)),
        map_location='cpu')
    model.load_state_dict(ckp['model'])

    noise_initial = initial_noise[e_type].to(device)
    # Use a distinct local name: assigning to `noise_start_t` here would make
    # it local to this function, so reading the enclosing-scope dict on the
    # right-hand side would raise UnboundLocalError.
    start_t = noise_start_t[e_type]
    general_save_path = './logs/edited'
    os.makedirs(general_save_path, exist_ok=True)

    # The bounding box depends only on the instruction, so set it once. The
    # santa edit gets a larger upper bound on the third axis.
    upper_z = 1.25 if 'santa' in instruction else 1.0
    volume = xm.renderer.volume
    volume.bbox_max = torch.tensor([1.0, 1.0, upper_z]).to(device)
    volume.bbox_min = torch.tensor([-1.0, -1.0, -1.0]).to(device)
    volume.bbox = torch.stack([volume.bbox_min, volume.bbox_max])

    # Applied in this order: 180 degrees about Y, then 90 degrees about X.
    rotations = ((180, [0, 1, 0]), (90, [1, 0, 0]))

    for i, latent in enumerate(state['latent']):
        latent = latent.to(device)
        text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction]))
        ref_latent = latent.clone().unsqueeze(0)
        print("shape of latent: ", ref_latent.shape, "instruction: ", instruction)
        # The range [start_t, start_t + 1) has one value, so t_1 is always
        # start_t.
        t_1 = torch.randint(start_t, start_t + 1, (1,), device=device).long()

        noise_input = diffusion.q_sample(ref_latent, t_1, noise=noise_initial)
        out_1 = diffusion.p_mean_variance(model, noise_input, t_1, clip_denoised=True,
                                          model_kwargs=text_embeddings_clip,
                                          condition_latents=ref_latent)

        updated_latents = out_1['pred_xstart']

        # File names use the source index `i`; ref_latent has a leading
        # dimension of 1, so one mesh per source latent is expected.
        for updated_latent in updated_latents:
            output_path = os.path.join(general_save_path, '{}_{}.obj'.format(prompt_hash, i))

            t = decode_latent_mesh(xm, updated_latent).tri_mesh()
            with open(output_path, 'w') as f:
                t.write_obj(f)

            # Reload through trimesh, rotate, and overwrite the same file.
            mesh = trimesh.load_mesh(output_path)
            for degrees, axis in rotations:
                mesh.apply_transform(
                    trimesh.transformations.rotation_matrix(np.radians(degrees), axis))
            mesh.export(output_path)
            mesh_path.append(output_path)

    return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state
325
+
326
 
327
  del models
328
  models = None
 
391
  rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random seed")
392
 
393
  gen_btn.click(
394
+ fn=generate_3d_with_shap_e,
395
  inputs=[prompt, rand_seed, state],
396
  outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
397
  queue=False)
398
 
399
  apply_btn.click(
400
+ fn=_3d_editing,
401
  inputs=[
402
  editing_choice[0], rand_seed, state
403
  ],
 
419
  ],
420
  inputs=[prompt, editing_choice[0], rand_seed],
421
  outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, edited_1, edited_2, edited_3, edited_4],
422
+ fn=optimize_all,
423
  cache_examples=True,
424
  )
425