multimodalart HF staff commited on
Commit
918aa0f
1 Parent(s): 5e6effb

devices fix

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -75,11 +75,11 @@ def run_all(prompt, steps, n_images, weight, clip_guided):
75
 
76
  for prompt in prompts:
77
  txt, weight = parse_prompt(prompt)
78
- target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
79
  weights.append(weight)
80
 
81
  target_embeds = torch.cat(target_embeds)
82
- weights = torch.tensor(weights, device=device)
83
  if weights.sum().abs() < 1e-3:
84
  raise RuntimeError('The weights must not sum to 0.')
85
  weights /= weights.sum().abs()
 
75
 
76
  for prompt in prompts:
77
  txt, weight = parse_prompt(prompt)
78
+ target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to('cuda')).float())
79
  weights.append(weight)
80
 
81
  target_embeds = torch.cat(target_embeds)
82
+ weights = torch.tensor(weights, device='cuda')
83
  if weights.sum().abs() < 1e-3:
84
  raise RuntimeError('The weights must not sum to 0.')
85
  weights /= weights.sum().abs()