silentchen commited on
Commit
03200f4
1 Parent(s): 883a17d

torch.no_grad

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -58,12 +58,14 @@ class Blocks(gr.Blocks):
58
  config[k] = v
59
 
60
  return config
 
61
  def optimize_all(xm, models, initial_noise, noise_start_t, diffusion, latent_model, device, prompt, instruction, rand_seed):
62
  state = {}
63
  out_gen_1, out_gen_2, out_gen_3, out_gen_4, state = generate_3d_with_shap_e(xm, diffusion, latent_model, device, prompt, rand_seed, state)
64
  edited_1, edited_2, edited_3, edited_4, state = _3d_editing(xm, models, diffusion, initial_noise, noise_start_t, device, instruction, rand_seed, state)
65
  print(state)
66
  return out_gen_1, out_gen_2, out_gen_3, out_gen_4, edited_1, edited_2, edited_3, edited_4
 
67
  def generate_3d_with_shap_e(xm, diffusion, latent_model, device, prompt, rand_seed, state):
68
  set_seed(rand_seed)
69
  batch_size = 4
@@ -117,7 +119,7 @@ def generate_3d_with_shap_e(xm, diffusion, latent_model, device, prompt, rand_se
117
  mesh_path.append(output_path_tmp)
118
 
119
  return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state
120
-
121
  def _3d_editing(xm, models, diffusion, initial_noise, start_t, device, instruction, rand_seed, state):
122
  set_seed(rand_seed)
123
  mesh_path = []
@@ -262,7 +264,7 @@ def main():
262
  initial_noise = dict()
263
  noise_start_t = dict()
264
  editing_types = ['rainbow', 'santa_hat', 'lego', 'golden', 'wooden', 'cyber']
265
-
266
  for editing_type in editing_types:
267
  tmp_model = load_model('text300M', device=device)
268
  with torch.no_grad():
@@ -279,7 +281,8 @@ def main():
279
  initial_noise[editing_type] = noise_initial
280
  noise_start_t[editing_type] = ckp['t_start']
281
  models[editing_type] = tmp_model
282
-
 
283
  with Blocks(
284
  css=css,
285
  analytics_enabled=False,
@@ -352,7 +355,7 @@ def main():
352
  with gr.Column():
353
  gr.Examples(
354
  examples=[
355
- [ "a corgi",
356
  "Make the color of it look like rainbow",
357
  456,
358
  ],
@@ -369,7 +372,7 @@ def main():
369
 
370
 
371
  demo.queue(max_size=10, api_open=False)
372
- demo.launch(share=True, show_api=False, show_error=True)
373
 
374
  if __name__ == '__main__':
375
  main()
 
58
  config[k] = v
59
 
60
  return config
61
+ @torch.no_grad()
62
  def optimize_all(xm, models, initial_noise, noise_start_t, diffusion, latent_model, device, prompt, instruction, rand_seed):
63
  state = {}
64
  out_gen_1, out_gen_2, out_gen_3, out_gen_4, state = generate_3d_with_shap_e(xm, diffusion, latent_model, device, prompt, rand_seed, state)
65
  edited_1, edited_2, edited_3, edited_4, state = _3d_editing(xm, models, diffusion, initial_noise, noise_start_t, device, instruction, rand_seed, state)
66
  print(state)
67
  return out_gen_1, out_gen_2, out_gen_3, out_gen_4, edited_1, edited_2, edited_3, edited_4
68
+ @torch.no_grad()
69
  def generate_3d_with_shap_e(xm, diffusion, latent_model, device, prompt, rand_seed, state):
70
  set_seed(rand_seed)
71
  batch_size = 4
 
119
  mesh_path.append(output_path_tmp)
120
 
121
  return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state
122
+ @torch.no_grad()
123
  def _3d_editing(xm, models, diffusion, initial_noise, start_t, device, instruction, rand_seed, state):
124
  set_seed(rand_seed)
125
  mesh_path = []
 
264
  initial_noise = dict()
265
  noise_start_t = dict()
266
  editing_types = ['rainbow', 'santa_hat', 'lego', 'golden', 'wooden', 'cyber']
267
+ # prepare models
268
  for editing_type in editing_types:
269
  tmp_model = load_model('text300M', device=device)
270
  with torch.no_grad():
 
281
  initial_noise[editing_type] = noise_initial
282
  noise_start_t[editing_type] = ckp['t_start']
283
  models[editing_type] = tmp_model
284
+ # del models
285
+ # models = None
286
  with Blocks(
287
  css=css,
288
  analytics_enabled=False,
 
355
  with gr.Column():
356
  gr.Examples(
357
  examples=[
358
+ ["a corgi",
359
  "Make the color of it look like rainbow",
360
  456,
361
  ],
 
372
 
373
 
374
  demo.queue(max_size=10, api_open=False)
375
+ demo.launch(share=False, show_api=False, show_error=True)
376
 
377
  if __name__ == '__main__':
378
  main()