Spaces:
Runtime error
Runtime error
Commit
·
fd16ff8
1
Parent(s):
ab8308a
update space
Browse files
app.py
CHANGED
|
@@ -257,7 +257,8 @@ def main():
|
|
| 257 |
elif 'cyber' in instruction:
|
| 258 |
e_type = 'cyber'
|
| 259 |
|
| 260 |
-
model = models[e_type]
|
|
|
|
| 261 |
# model = load_model('text300M', device=device)
|
| 262 |
# with torch.no_grad():
|
| 263 |
# new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=model.wrapped.input_proj.weight.dtype)
|
|
@@ -280,7 +281,7 @@ def main():
|
|
| 280 |
latent = latent.to(device)
|
| 281 |
text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction]))
|
| 282 |
print("shape of latent: ", latent.clone().unsqueeze(0).shape, "instruction: ", instruction)
|
| 283 |
-
ref_latent = latent.clone().unsqueeze(0)
|
| 284 |
t_1 = torch.randint(noise_start_t_e_type, noise_start_t_e_type + 1, (1,), device=device).long()
|
| 285 |
|
| 286 |
noise_input = diffusion.q_sample(ref_latent, t_1, noise=noise_initial)
|
|
|
|
| 257 |
elif 'cyber' in instruction:
|
| 258 |
e_type = 'cyber'
|
| 259 |
|
| 260 |
+
model = models[e_type]
|
| 261 |
+
model = model.to(device)
|
| 262 |
# model = load_model('text300M', device=device)
|
| 263 |
# with torch.no_grad():
|
| 264 |
# new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=model.wrapped.input_proj.weight.dtype)
|
|
|
|
| 281 |
latent = latent.to(device)
|
| 282 |
text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction]))
|
| 283 |
print("shape of latent: ", latent.clone().unsqueeze(0).shape, "instruction: ", instruction)
|
| 284 |
+
ref_latent = latent.clone().unsqueeze(0).to(device)
|
| 285 |
t_1 = torch.randint(noise_start_t_e_type, noise_start_t_e_type + 1, (1,), device=device).long()
|
| 286 |
|
| 287 |
noise_input = diffusion.q_sample(ref_latent, t_1, noise=noise_initial)
|