Yinhong Liu
committed on
Commit
·
e5487ed
1
Parent(s):
79baf4f
fix sd3 pipeline
Browse files — sid/pipeline_sid_sd3.py (+16, −12)
sid/pipeline_sid_sd3.py
CHANGED
|
@@ -759,28 +759,32 @@ class SiDSD3Pipeline(
|
|
| 759 |
raise ValueError(f"Unknown noise_type: {noise_type}")
|
| 760 |
|
| 761 |
# Compute t value, normalized to [0, 1]
|
| 762 |
-
|
|
|
|
|
|
|
|
|
|
| 763 |
if use_sd3_shift:
|
| 764 |
shift = 3.0
|
| 765 |
t_val = shift * t_val / (1 + (shift - 1) * t_val)
|
|
|
|
| 766 |
t = torch.full((latents.shape[0],), t_val, device=latents.device, dtype=latents.dtype)
|
| 767 |
-
|
| 768 |
if t.numel() > 1:
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
flow_pred = self.transformer(
|
| 775 |
-
hidden_states=
|
| 776 |
encoder_hidden_states=prompt_embeds,
|
|
|
|
| 777 |
pooled_projections=pooled_prompt_embeds,
|
| 778 |
-
timestep=
|
| 779 |
return_dict=False,
|
| 780 |
)[0]
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
# 5. Decode latent to image
|
| 785 |
image = self.vae.decode((D_x / self.vae.config.scaling_factor) + self.vae.config.shift_factor, return_dict=False)[0]
|
| 786 |
image = self.image_processor.postprocess(image, output_type=output_type)
|
|
|
|
| 759 |
raise ValueError(f"Unknown noise_type: {noise_type}")
|
| 760 |
|
| 761 |
# Compute t value, normalized to [0, 1]
|
| 762 |
+
init_timesteps = 999
|
| 763 |
+
scalar_t = float(init_timesteps) * (1.0 - float(i) / float(num_inference_steps))
|
| 764 |
+
t_val = scalar_t / 999.0
|
| 765 |
+
# t_val = 1.0 - float(i) / float(num_inference_steps)
|
| 766 |
if use_sd3_shift:
|
| 767 |
shift = 3.0
|
| 768 |
t_val = shift * t_val / (1 + (shift - 1) * t_val)
|
| 769 |
+
|
| 770 |
t = torch.full((latents.shape[0],), t_val, device=latents.device, dtype=latents.dtype)
|
| 771 |
+
t_flattern = t.flatten()
|
| 772 |
if t.numel() > 1:
|
| 773 |
+
t = t.view(-1, 1, 1, 1)
|
| 774 |
+
|
| 775 |
+
latents = (1.0 - t) * D_x + t * noise
|
| 776 |
+
latent_model_input = latents
|
| 777 |
+
|
| 778 |
flow_pred = self.transformer(
|
| 779 |
+
hidden_states=latent_model_input,
|
| 780 |
encoder_hidden_states=prompt_embeds,
|
| 781 |
+
#encoder_attention_mask=prompt_attention_mask,
|
| 782 |
pooled_projections=pooled_prompt_embeds,
|
| 783 |
+
timestep=1000*t_flattern,
|
| 784 |
return_dict=False,
|
| 785 |
)[0]
|
| 786 |
+
D_x = latents - (t * flow_pred if torch.numel(t) == 1 else t.view(-1, 1, 1, 1) * flow_pred)
|
| 787 |
+
|
|
|
|
| 788 |
# 5. Decode latent to image
|
| 789 |
image = self.vae.decode((D_x / self.vae.config.scaling_factor) + self.vae.config.shift_factor, return_dict=False)[0]
|
| 790 |
image = self.image_processor.postprocess(image, output_type=output_type)
|