silentchen commited on
Commit
fd16ff8
1 Parent(s): ab8308a

update space

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -257,7 +257,8 @@ def main():
257
  elif 'cyber' in instruction:
258
  e_type = 'cyber'
259
 
260
- model = models[e_type].to(device)
 
261
  # model = load_model('text300M', device=device)
262
  # with torch.no_grad():
263
  # new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=model.wrapped.input_proj.weight.dtype)
@@ -280,7 +281,7 @@ def main():
280
  latent = latent.to(device)
281
  text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction]))
282
  print("shape of latent: ", latent.clone().unsqueeze(0).shape, "instruction: ", instruction)
283
- ref_latent = latent.clone().unsqueeze(0)
284
  t_1 = torch.randint(noise_start_t_e_type, noise_start_t_e_type + 1, (1,), device=device).long()
285
 
286
  noise_input = diffusion.q_sample(ref_latent, t_1, noise=noise_initial)
 
257
  elif 'cyber' in instruction:
258
  e_type = 'cyber'
259
 
260
+ model = models[e_type]
261
+ model = model.to(device)
262
  # model = load_model('text300M', device=device)
263
  # with torch.no_grad():
264
  # new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=model.wrapped.input_proj.weight.dtype)
 
281
  latent = latent.to(device)
282
  text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction]))
283
  print("shape of latent: ", latent.clone().unsqueeze(0).shape, "instruction: ", instruction)
284
+ ref_latent = latent.clone().unsqueeze(0).to(device)
285
  t_1 = torch.randint(noise_start_t_e_type, noise_start_t_e_type + 1, (1,), device=device).long()
286
 
287
  noise_input = diffusion.q_sample(ref_latent, t_1, noise=noise_initial)