Phung commited on
Commit
b815839
β€’
1 Parent(s): 3a7c3d3

Update gligen/ldm/models/diffusion/plms.py

Browse files
gligen/ldm/models/diffusion/plms.py CHANGED
@@ -119,11 +119,11 @@ class PLMSSampler(object):
119
 
120
  def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
121
  if index1 < 10:
122
- loss_scale = 3
123
- max_iter = 5
124
  elif index1 < 20:
125
- loss_scale = 2
126
- max_iter = 5
127
  else:
128
  loss_scale = 0.8
129
  max_iter = 1
@@ -156,7 +156,8 @@ class PLMSSampler(object):
156
  x = x - grad_cond
157
  x = x.detach()
158
  iteration += 1
159
- torch.cuda.empty_cache()
 
160
  return x
161
 
162
  def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'):
 
119
 
120
  def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
121
  if index1 < 10:
122
+ loss_scale = 4
123
+ max_iter = 1
124
  elif index1 < 20:
125
+ loss_scale = 3
126
+ max_iter = 1
127
  else:
128
  loss_scale = 0.8
129
  max_iter = 1
 
156
  x = x - grad_cond
157
  x = x.detach()
158
  iteration += 1
159
+
160
+ torch.cuda.empty_cache()
161
  return x
162
 
163
  def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'):