silentchen committed
Commit 007806e
1 Parent(s): fd16ff8

update space

Files changed (1)
  1. app.py +15 -5
app.py CHANGED
@@ -153,9 +153,14 @@ def main():
     editing_types = ['rainbow', 'santa_hat', 'lego', 'golden', 'wooden', 'cyber']
     # prepare models
     for editing_type in editing_types:
-        tmp_model = load_model('text300M', device=torch.device('cpu'))
+        tmp_model = model_from_config(load_config('text300M'), device=device)
+        # print(model_name, kwargs)
+        # print(model)
+
+        # xm = load_model('transmitter', de
+        tmp_model = load_model('text300M', device=device)
         with torch.no_grad():
-            new_proj = nn.Linear(1024 * 2, 1024, device=torch.device('cpu'), dtype=tmp_model.wrapped.input_proj.weight.dtype)
+            new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=tmp_model.wrapped.input_proj.weight.dtype)
             new_proj.weight = nn.Parameter(torch.zeros_like(new_proj.weight))
             new_proj.weight[:, :1024].copy_(tmp_model.wrapped.input_proj.weight) #
             new_proj.bias = nn.Parameter(torch.zeros_like(new_proj.bias))
@@ -164,10 +169,13 @@ def main():

         ckp = torch.load(hf_hub_download(repo_id='silentchen/Shap_Editor', subfolder='single', filename='{}.pt'.format(editing_type)), map_location='cpu')
         tmp_model.load_state_dict(ckp['model'])
-        noise_initial = ckp['initial_noise']['noise'].to(torch.device('cpu'))
+        tmp_model.eval()
+        # print("loaded latent model")
+        tmp_model.to(device)
+        noise_initial = ckp['initial_noise']['noise'].to(device)
         initial_noise[editing_type] = noise_initial
         noise_start_t[editing_type] = ckp['t_start']
-        models[editing_type] = tmp_model
+        models[editing_type] = tmp_model.to(device)
     @torch.no_grad()
     def optimize_all(prompt, instruction,
                      rand_seed):
@@ -279,12 +287,14 @@ def main():
         os.makedirs(general_save_path, exist_ok=True)
         for i, latent in enumerate(state['latent']):
             latent = latent.to(device)
-            text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction]))
+            text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction])).to(device)
             print("shape of latent: ", latent.clone().unsqueeze(0).shape, "instruction: ", instruction)
             ref_latent = latent.clone().unsqueeze(0).to(device)
             t_1 = torch.randint(noise_start_t_e_type, noise_start_t_e_type + 1, (1,), device=device).long()

             noise_input = diffusion.q_sample(ref_latent, t_1, noise=noise_initial)
+            print("noise_input:", noise_input.device)
+
             out_1 = diffusion.p_mean_variance(model, noise_input, t_1, clip_denoised=True,
                                               model_kwargs=text_embeddings_clip,
                                               condition_latents=ref_latent)
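
Note on the first hunk: widening a pretrained nn.Linear with zero-initialized extra columns is a standard way to add a conditioning input without disturbing the pretrained mapping. A minimal standalone sketch of the pattern follows; the widen_input_proj helper is hypothetical, and in app.py the widened layer's parameters are overwritten immediately afterwards by tmp_model.load_state_dict(ckp['model']), so the zero init mainly fixes the module's shape.

    import torch
    import torch.nn as nn

    # Hypothetical helper illustrating the widening pattern from the hunk:
    # double a pretrained Linear's input width (1024 -> 2048 for text300M's
    # input_proj) so it can take a concatenated input, zero-initializing the
    # new columns so the extra inputs initially contribute nothing.
    def widen_input_proj(old_proj: nn.Linear, device: torch.device) -> nn.Linear:
        in_dim = old_proj.in_features
        new_proj = nn.Linear(in_dim * 2, old_proj.out_features,
                             device=device, dtype=old_proj.weight.dtype)
        with torch.no_grad():
            new_proj.weight.zero_()
            new_proj.weight[:, :in_dim].copy_(old_proj.weight)
            new_proj.bias.zero_()  # the diff also zero-initializes the bias
        return new_proj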
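
Note on the third hunk: the edit is applied by noising the source latent ref_latent to the fixed step t_start stored in the checkpoint (the torch.randint(t, t + 1, ...) call is deterministic) using the stored noise_initial, then running a single denoising prediction. Below is a sketch of the standard DDPM forward-noising step that diffusion.q_sample computes, assuming alphas_cumprod is the schedule's cumulative product of (1 - beta_t); the actual implementation lives in the repo's diffusion module.

    import torch

    # Sketch of DDPM forward noising:
    #   x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
    # `alphas_cumprod` is an assumed 1-D tensor of cumulative products of (1 - beta_t).
    def q_sample(x_start: torch.Tensor, t: torch.Tensor, noise: torch.Tensor,
                 alphas_cumprod: torch.Tensor) -> torch.Tensor:
        # Broadcast a_bar_t over all non-batch dimensions of x_start.
        a_bar = alphas_cumprod[t].view(-1, *([1] * (x_start.dim() - 1)))
        return a_bar.sqrt() * x_start + (1.0 - a_bar).sqrt() * noise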