ironjr commited on
Commit
e687c57
1 Parent(s): 42080d8

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +1 -6
model.py CHANGED
@@ -78,7 +78,7 @@ class StreamMultiDiffusion(nn.Module):
78
  self.default_mask_strength = default_mask_strength
79
  self.default_prompt_strength = default_prompt_strength
80
  self.register_buffer('bootstrap_steps', (
81
- bootstrap_steps > torch.arange(len(t_index_list))).to(dtype=self.dtype, device=self.device))
82
  self.bootstrap_mix_steps = bootstrap_mix_steps
83
  self.register_buffer('bootstrap_mix_ratios', (
84
  bootstrap_mix_steps - torch.arange(len(t_index_list), device=self.device)).clip_(0, 1).to(self.dtype))
@@ -1091,8 +1091,6 @@ class StreamMultiDiffusion(nn.Module):
1091
  ) -> Tuple[torch.Tensor, torch.Tensor]:
1092
  p = self.num_layers
1093
  x_t_latent = x_t_latent.repeat_interleave(p, dim=0) # (T * p, 4, h, w)
1094
- print('111111111111111111111')
1095
-
1096
  if self.bootstrap_steps[0] > 0:
1097
  # Background bootstrapping.
1098
  bootstrap_latent = self.scheduler.add_noise(
@@ -1100,7 +1098,6 @@ class StreamMultiDiffusion(nn.Module):
1100
  self.stock_noise,
1101
  torch.tensor(self.sub_timesteps_tensor, device=self.device),
1102
  )
1103
- print('111111111111111111111', self.bootstrap_steps)
1104
 
1105
  x_t_latent = rearrange(x_t_latent, '(t p) c h w -> p t c h w', p=p)
1106
  bootstrap_mask = (
@@ -1109,11 +1106,9 @@ class StreamMultiDiffusion(nn.Module):
1109
  ) # (p, t, c, h, w)
1110
  x_t_latent = (1.0 - bootstrap_mask) * bootstrap_latent[None] + bootstrap_mask * x_t_latent
1111
  x_t_latent = rearrange(x_t_latent, 'p t c h w -> (t p) c h w')
1112
- print('222222222222222222222')
1113
 
1114
  # Centering.
1115
  x_t_latent = shift_to_mask_bbox_center(x_t_latent, rearrange(self.masks, 'p t c h w -> (t p) c h w'), reverse=True)
1116
- print('333333333333333333333')
1117
 
1118
  t_list = self.sub_timesteps_tensor_ # (T * p,)
1119
  if self.guidance_scale > 1.0 and self.cfg_type == 'initialize':
 
78
  self.default_mask_strength = default_mask_strength
79
  self.default_prompt_strength = default_prompt_strength
80
  self.register_buffer('bootstrap_steps', (
81
+ bootstrap_steps > torch.arange(len(t_index_list))).float().to(dtype=self.dtype, device=self.device))
82
  self.bootstrap_mix_steps = bootstrap_mix_steps
83
  self.register_buffer('bootstrap_mix_ratios', (
84
  bootstrap_mix_steps - torch.arange(len(t_index_list), device=self.device)).clip_(0, 1).to(self.dtype))
 
1091
  ) -> Tuple[torch.Tensor, torch.Tensor]:
1092
  p = self.num_layers
1093
  x_t_latent = x_t_latent.repeat_interleave(p, dim=0) # (T * p, 4, h, w)
 
 
1094
  if self.bootstrap_steps[0] > 0:
1095
  # Background bootstrapping.
1096
  bootstrap_latent = self.scheduler.add_noise(
 
1098
  self.stock_noise,
1099
  torch.tensor(self.sub_timesteps_tensor, device=self.device),
1100
  )
 
1101
 
1102
  x_t_latent = rearrange(x_t_latent, '(t p) c h w -> p t c h w', p=p)
1103
  bootstrap_mask = (
 
1106
  ) # (p, t, c, h, w)
1107
  x_t_latent = (1.0 - bootstrap_mask) * bootstrap_latent[None] + bootstrap_mask * x_t_latent
1108
  x_t_latent = rearrange(x_t_latent, 'p t c h w -> (t p) c h w')
 
1109
 
1110
  # Centering.
1111
  x_t_latent = shift_to_mask_bbox_center(x_t_latent, rearrange(self.masks, 'p t c h w -> (t p) c h w'), reverse=True)
 
1112
 
1113
  t_list = self.sub_timesteps_tensor_ # (T * p,)
1114
  if self.guidance_scale > 1.0 and self.cfg_type == 'initialize':