Yinhong Liu committed on
Commit
e5487ed
·
1 Parent(s): 79baf4f

fix sd3 pipeline

Browse files
Files changed (1) hide show
  1. sid/pipeline_sid_sd3.py +16 -12
sid/pipeline_sid_sd3.py CHANGED
@@ -759,28 +759,32 @@ class SiDSD3Pipeline(
759
  raise ValueError(f"Unknown noise_type: {noise_type}")
760
 
761
  # Compute t value, normalized to [0, 1]
762
- t_val = 1.0 - float(i) / float(num_inference_steps)
 
 
 
763
  if use_sd3_shift:
764
  shift = 3.0
765
  t_val = shift * t_val / (1 + (shift - 1) * t_val)
 
766
  t = torch.full((latents.shape[0],), t_val, device=latents.device, dtype=latents.dtype)
767
- t_flatten = t.flatten()
768
  if t.numel() > 1:
769
- t_view = t.view(-1, 1, 1, 1)
770
- else:
771
- t_view = t
772
- # SiD update
773
- latents = (1.0 - t_view) * D_x + t_view * noise
774
  flow_pred = self.transformer(
775
- hidden_states=latents,
776
  encoder_hidden_states=prompt_embeds,
 
777
  pooled_projections=pooled_prompt_embeds,
778
- timestep=t_flatten,
779
  return_dict=False,
780
  )[0]
781
-
782
- D_x = latents - (t_view * flow_pred if torch.numel(t_view) == 1 else t_view.view(-1, 1, 1, 1) * flow_pred)
783
-
784
  # 5. Decode latent to image
785
  image = self.vae.decode((D_x / self.vae.config.scaling_factor) + self.vae.config.shift_factor, return_dict=False)[0]
786
  image = self.image_processor.postprocess(image, output_type=output_type)
 
759
  raise ValueError(f"Unknown noise_type: {noise_type}")
760
 
761
  # Compute t value, normalized to [0, 1]
762
+ init_timesteps = 999
763
+ scalar_t = float(init_timesteps) * (1.0 - float(i) / float(num_inference_steps))
764
+ t_val = scalar_t / 999.0
765
+ # t_val = 1.0 - float(i) / float(num_inference_steps)
766
  if use_sd3_shift:
767
  shift = 3.0
768
  t_val = shift * t_val / (1 + (shift - 1) * t_val)
769
+
770
  t = torch.full((latents.shape[0],), t_val, device=latents.device, dtype=latents.dtype)
771
+ t_flattern = t.flatten()
772
  if t.numel() > 1:
773
+ t = t.view(-1, 1, 1, 1)
774
+
775
+ latents = (1.0 - t) * D_x + t * noise
776
+ latent_model_input = latents
777
+
778
  flow_pred = self.transformer(
779
+ hidden_states=latent_model_input,
780
  encoder_hidden_states=prompt_embeds,
781
+ #encoder_attention_mask=prompt_attention_mask,
782
  pooled_projections=pooled_prompt_embeds,
783
+ timestep=1000*t_flattern,
784
  return_dict=False,
785
  )[0]
786
+ D_x = latents - (t * flow_pred if torch.numel(t) == 1 else t.view(-1, 1, 1, 1) * flow_pred)
787
+
 
788
  # 5. Decode latent to image
789
  image = self.vae.decode((D_x / self.vae.config.scaling_factor) + self.vae.config.shift_factor, return_dict=False)[0]
790
  image = self.image_processor.postprocess(image, output_type=output_type)