tfwang commited on
Commit
430429e
1 Parent(s): 780da15

Update glide_text2im/glide_util.py

Browse files
Files changed (1) hide show
  1. glide_text2im/glide_util.py +2 -2
glide_text2im/glide_util.py CHANGED
@@ -44,7 +44,7 @@ def sample(
44
  uncond_ref = th.ones_like(cond_ref)
45
 
46
  model_kwargs = {}
47
- model_kwargs['ref'] = th.cat([cond_ref, uncond_ref], 0).cuda()
48
 
49
  def cfg_model_fn(x_t, ts, **kwargs):
50
  half = x_t[: len(x_t) // 2]
@@ -60,7 +60,7 @@ def sample(
60
 
61
 
62
  if upsample_enabled:
63
- model_kwargs['low_res'] = prompt['low_res'].cuda()
64
  noise = th.randn((batch_size, 3, side_y, side_x), device=device) * upsample_temp
65
  model_fn = glide_model # just use the base model, no need for CFG.
66
  model_kwargs['ref'] = model_kwargs['ref'][:batch_size]
44
  uncond_ref = th.ones_like(cond_ref)
45
 
46
  model_kwargs = {}
47
+ model_kwargs['ref'] = th.cat([cond_ref, uncond_ref], 0).to(device)
48
 
49
  def cfg_model_fn(x_t, ts, **kwargs):
50
  half = x_t[: len(x_t) // 2]
60
 
61
 
62
  if upsample_enabled:
63
+ model_kwargs['low_res'] = prompt['low_res'].to(device)
64
  noise = th.randn((batch_size, 3, side_y, side_x), device=device) * upsample_temp
65
  model_fn = glide_model # just use the base model, no need for CFG.
66
  model_kwargs['ref'] = model_kwargs['ref'][:batch_size]