rayli commited on
Commit
392bda6
1 Parent(s): f32f04f

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +1 -1
model.py CHANGED
@@ -1587,7 +1587,7 @@ class UNet2DDragConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMi
1587
 
1588
  bsz, num_drags, drag_dim = drags.shape
1589
  assert num_drags == self.num_drags
1590
- if (self.train and self.drag_dropout_prob > 0) or force_drop_ids is not None:
1591
  if force_drop_ids is None:
1592
  drop_ids = torch.rand(bsz, device=x_cond_extra.device) < self.drag_dropout_prob
1593
  else:
 
1587
 
1588
  bsz, num_drags, drag_dim = drags.shape
1589
  assert num_drags == self.num_drags
1590
+ if (self.training and self.drag_dropout_prob > 0) or force_drop_ids is not None:
1591
  if force_drop_ids is None:
1592
  drop_ids = torch.rand(bsz, device=x_cond_extra.device) < self.drag_dropout_prob
1593
  else: