silentchen commited on
Commit
1e2802d
1 Parent(s): 03200f4

read model every inference

Browse files
Files changed (1) hide show
  1. app.py +20 -9
app.py CHANGED
@@ -115,9 +115,9 @@ def generate_3d_with_shap_e(xm, diffusion, latent_model, device, prompt, rand_se
115
  mesh.apply_transform(rotation_matrix)
116
  output_path_tmp = os.path.join(output_path, 'source', '{}_{}.obj'.format(prompt_hash, i))
117
  mesh.export(output_path_tmp)
118
- state['latent'].append(latent.clone().detach())
119
  mesh_path.append(output_path_tmp)
120
-
121
  return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state
122
  @torch.no_grad()
123
  def _3d_editing(xm, models, diffusion, initial_noise, start_t, device, instruction, rand_seed, state):
@@ -140,8 +140,18 @@ def _3d_editing(xm, models, diffusion, initial_noise, start_t, device, instructi
140
  elif 'cyber' in instruction:
141
  e_type = 'cyber'
142
 
143
- # import pdb; pdb.set_trace()
144
- model = models[e_type].to(device)
 
 
 
 
 
 
 
 
 
 
145
  noise_initial = initial_noise[e_type].to(device)
146
  noise_start_t = start_t[e_type]
147
  general_save_path = './logs/edited'
@@ -266,9 +276,9 @@ def main():
266
  editing_types = ['rainbow', 'santa_hat', 'lego', 'golden', 'wooden', 'cyber']
267
  # prepare models
268
  for editing_type in editing_types:
269
- tmp_model = load_model('text300M', device=device)
270
  with torch.no_grad():
271
- new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=tmp_model.wrapped.input_proj.weight.dtype)
272
  new_proj.weight = nn.Parameter(torch.zeros_like(new_proj.weight))
273
  new_proj.weight[:, :1024].copy_(tmp_model.wrapped.input_proj.weight) #
274
  new_proj.bias = nn.Parameter(torch.zeros_like(new_proj.bias))
@@ -277,12 +287,13 @@ def main():
277
 
278
  ckp = torch.load(hf_hub_download(repo_id='silentchen/Shap_Editor', subfolder='single', filename='{}.pt'.format(editing_type)), map_location='cpu')
279
  tmp_model.load_state_dict(ckp['model'])
280
- noise_initial = ckp['initial_noise']['noise'].to(device)
281
  initial_noise[editing_type] = noise_initial
282
  noise_start_t[editing_type] = ckp['t_start']
283
  models[editing_type] = tmp_model
284
- # del models
285
- # models = None
 
286
  with Blocks(
287
  css=css,
288
  analytics_enabled=False,
 
115
  mesh.apply_transform(rotation_matrix)
116
  output_path_tmp = os.path.join(output_path, 'source', '{}_{}.obj'.format(prompt_hash, i))
117
  mesh.export(output_path_tmp)
118
+ state['latent'].append(latent.clone().detach().cpu())
119
  mesh_path.append(output_path_tmp)
120
+ del latents
121
  return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state
122
  @torch.no_grad()
123
  def _3d_editing(xm, models, diffusion, initial_noise, start_t, device, instruction, rand_seed, state):
 
140
  elif 'cyber' in instruction:
141
  e_type = 'cyber'
142
 
143
+ model = load_model('text300M', device=device)
144
+ with torch.no_grad():
145
+ new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=model.wrapped.input_proj.weight.dtype)
146
+ new_proj.weight = nn.Parameter(torch.zeros_like(new_proj.weight))
147
+ new_proj.weight[:, :1024].copy_(model.wrapped.input_proj.weight) #
148
+ new_proj.bias = nn.Parameter(torch.zeros_like(new_proj.bias))
149
+ new_proj.bias[:1024].copy_(model.wrapped.input_proj.bias)
150
+ model.wrapped.input_proj = new_proj
151
+
152
+ ckp = torch.load(hf_hub_download(repo_id='silentchen/Shap_Editor', subfolder='single', filename='{}.pt'.format(e_type)), map_location='cpu')
153
+ model.load_state_dict(ckp['model'])
154
+
155
  noise_initial = initial_noise[e_type].to(device)
156
  noise_start_t = start_t[e_type]
157
  general_save_path = './logs/edited'
 
276
  editing_types = ['rainbow', 'santa_hat', 'lego', 'golden', 'wooden', 'cyber']
277
  # prepare models
278
  for editing_type in editing_types:
279
+ tmp_model = load_model('text300M', device=torch.device('cpu'))
280
  with torch.no_grad():
281
+ new_proj = nn.Linear(1024 * 2, 1024, device=torch.device('cpu'), dtype=tmp_model.wrapped.input_proj.weight.dtype)
282
  new_proj.weight = nn.Parameter(torch.zeros_like(new_proj.weight))
283
  new_proj.weight[:, :1024].copy_(tmp_model.wrapped.input_proj.weight) #
284
  new_proj.bias = nn.Parameter(torch.zeros_like(new_proj.bias))
 
287
 
288
  ckp = torch.load(hf_hub_download(repo_id='silentchen/Shap_Editor', subfolder='single', filename='{}.pt'.format(editing_type)), map_location='cpu')
289
  tmp_model.load_state_dict(ckp['model'])
290
+ noise_initial = ckp['initial_noise']['noise'].to(torch.device('cpu'))
291
  initial_noise[editing_type] = noise_initial
292
  noise_start_t[editing_type] = ckp['t_start']
293
  models[editing_type] = tmp_model
294
+
295
+ del models
296
+ models = None
297
  with Blocks(
298
  css=css,
299
  analytics_enabled=False,