jimmycarter
commited on
Upload pipeline.py
Browse files- pipeline.py +0 -99
pipeline.py
CHANGED
@@ -825,105 +825,6 @@ class FluxTransformer2DModelWithMasking(
|
|
825 |
|
826 |
return Transformer2DModelOutput(sample=output)
|
827 |
|
828 |
-
|
829 |
-
if __name__ == "__main__":
|
830 |
-
dtype = torch.bfloat16
|
831 |
-
bsz = 2
|
832 |
-
img = torch.rand((bsz, 16, 64, 64)).to("cuda", dtype=dtype)
|
833 |
-
timestep = torch.tensor([0.5, 0.5]).to("cuda", dtype=torch.float32)
|
834 |
-
pooled = torch.rand(bsz, 768).to("cuda", dtype=dtype)
|
835 |
-
text = torch.rand((bsz, 512, 4096)).to("cuda", dtype=dtype)
|
836 |
-
attn_mask = torch.tensor([[1.0] * 384 + [0.0] * 128] * bsz).to(
|
837 |
-
"cuda", dtype=dtype
|
838 |
-
) # Last 128 positions are masked
|
839 |
-
|
840 |
-
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
841 |
-
latents = latents.view(
|
842 |
-
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
|
843 |
-
)
|
844 |
-
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
845 |
-
latents = latents.reshape(
|
846 |
-
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
|
847 |
-
)
|
848 |
-
|
849 |
-
return latents
|
850 |
-
|
851 |
-
def _prepare_latent_image_ids(
|
852 |
-
batch_size, height, width, device="cuda", dtype=dtype
|
853 |
-
):
|
854 |
-
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
855 |
-
latent_image_ids[..., 1] = (
|
856 |
-
latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
857 |
-
)
|
858 |
-
latent_image_ids[..., 2] = (
|
859 |
-
latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
860 |
-
)
|
861 |
-
|
862 |
-
latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
|
863 |
-
latent_image_ids.shape
|
864 |
-
)
|
865 |
-
|
866 |
-
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
867 |
-
latent_image_ids = latent_image_ids.reshape(
|
868 |
-
batch_size,
|
869 |
-
latent_image_id_height * latent_image_id_width,
|
870 |
-
latent_image_id_channels,
|
871 |
-
)
|
872 |
-
|
873 |
-
return latent_image_ids.to(device=device, dtype=dtype)
|
874 |
-
|
875 |
-
txt_ids = torch.zeros(bsz, text.shape[1], 3).to(device="cuda", dtype=dtype)
|
876 |
-
|
877 |
-
vae_scale_factor = 16
|
878 |
-
height = 2 * (int(512) // vae_scale_factor)
|
879 |
-
width = 2 * (int(512) // vae_scale_factor)
|
880 |
-
img_ids = _prepare_latent_image_ids(bsz, height, width)
|
881 |
-
img = _pack_latents(img, img.shape[0], 16, height, width)
|
882 |
-
|
883 |
-
# Gotta go fast
|
884 |
-
transformer = FluxTransformer2DModelWithMasking.from_config(
|
885 |
-
{
|
886 |
-
"attention_head_dim": 128,
|
887 |
-
"guidance_embeds": True,
|
888 |
-
"in_channels": 64,
|
889 |
-
"joint_attention_dim": 4096,
|
890 |
-
"num_attention_heads": 24,
|
891 |
-
"num_layers": 4,
|
892 |
-
"num_single_layers": 8,
|
893 |
-
"patch_size": 1,
|
894 |
-
"pooled_projection_dim": 768,
|
895 |
-
}
|
896 |
-
).to("cuda", dtype=dtype)
|
897 |
-
|
898 |
-
guidance = torch.tensor([2.0], device="cuda")
|
899 |
-
guidance = guidance.expand(bsz)
|
900 |
-
|
901 |
-
with torch.no_grad():
|
902 |
-
no_mask = transformer(
|
903 |
-
img,
|
904 |
-
encoder_hidden_states=text,
|
905 |
-
pooled_projections=pooled,
|
906 |
-
timestep=timestep,
|
907 |
-
img_ids=img_ids,
|
908 |
-
txt_ids=txt_ids,
|
909 |
-
guidance=guidance,
|
910 |
-
)
|
911 |
-
mask = transformer(
|
912 |
-
img,
|
913 |
-
encoder_hidden_states=text,
|
914 |
-
pooled_projections=pooled,
|
915 |
-
timestep=timestep,
|
916 |
-
img_ids=img_ids,
|
917 |
-
txt_ids=txt_ids,
|
918 |
-
guidance=guidance,
|
919 |
-
attention_mask=attn_mask,
|
920 |
-
)
|
921 |
-
|
922 |
-
assert torch.allclose(no_mask.sample, mask.sample) is False
|
923 |
-
print("Attention masking test ran OK. Differences in output were detected.")
|
924 |
-
|
925 |
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
926 |
-
|
927 |
EXAMPLE_DOC_STRING = """
|
928 |
Examples:
|
929 |
```py
|
|
|
825 |
|
826 |
return Transformer2DModelOutput(sample=output)
|
827 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
828 |
EXAMPLE_DOC_STRING = """
|
829 |
Examples:
|
830 |
```py
|