jimmycarter commited on
Commit
bf281ac
·
verified ·
1 Parent(s): e3d7ec7

Upload pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +0 -99
pipeline.py CHANGED
@@ -825,105 +825,6 @@ class FluxTransformer2DModelWithMasking(
825
 
826
  return Transformer2DModelOutput(sample=output)
827
 
828
-
829
- if __name__ == "__main__":
830
- dtype = torch.bfloat16
831
- bsz = 2
832
- img = torch.rand((bsz, 16, 64, 64)).to("cuda", dtype=dtype)
833
- timestep = torch.tensor([0.5, 0.5]).to("cuda", dtype=torch.float32)
834
- pooled = torch.rand(bsz, 768).to("cuda", dtype=dtype)
835
- text = torch.rand((bsz, 512, 4096)).to("cuda", dtype=dtype)
836
- attn_mask = torch.tensor([[1.0] * 384 + [0.0] * 128] * bsz).to(
837
- "cuda", dtype=dtype
838
- ) # Last 128 positions are masked
839
-
840
- def _pack_latents(latents, batch_size, num_channels_latents, height, width):
841
- latents = latents.view(
842
- batch_size, num_channels_latents, height // 2, 2, width // 2, 2
843
- )
844
- latents = latents.permute(0, 2, 4, 1, 3, 5)
845
- latents = latents.reshape(
846
- batch_size, (height // 2) * (width // 2), num_channels_latents * 4
847
- )
848
-
849
- return latents
850
-
851
- def _prepare_latent_image_ids(
852
- batch_size, height, width, device="cuda", dtype=dtype
853
- ):
854
- latent_image_ids = torch.zeros(height // 2, width // 2, 3)
855
- latent_image_ids[..., 1] = (
856
- latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
857
- )
858
- latent_image_ids[..., 2] = (
859
- latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
860
- )
861
-
862
- latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
863
- latent_image_ids.shape
864
- )
865
-
866
- latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
867
- latent_image_ids = latent_image_ids.reshape(
868
- batch_size,
869
- latent_image_id_height * latent_image_id_width,
870
- latent_image_id_channels,
871
- )
872
-
873
- return latent_image_ids.to(device=device, dtype=dtype)
874
-
875
- txt_ids = torch.zeros(bsz, text.shape[1], 3).to(device="cuda", dtype=dtype)
876
-
877
- vae_scale_factor = 16
878
- height = 2 * (int(512) // vae_scale_factor)
879
- width = 2 * (int(512) // vae_scale_factor)
880
- img_ids = _prepare_latent_image_ids(bsz, height, width)
881
- img = _pack_latents(img, img.shape[0], 16, height, width)
882
-
883
- # Gotta go fast
884
- transformer = FluxTransformer2DModelWithMasking.from_config(
885
- {
886
- "attention_head_dim": 128,
887
- "guidance_embeds": True,
888
- "in_channels": 64,
889
- "joint_attention_dim": 4096,
890
- "num_attention_heads": 24,
891
- "num_layers": 4,
892
- "num_single_layers": 8,
893
- "patch_size": 1,
894
- "pooled_projection_dim": 768,
895
- }
896
- ).to("cuda", dtype=dtype)
897
-
898
- guidance = torch.tensor([2.0], device="cuda")
899
- guidance = guidance.expand(bsz)
900
-
901
- with torch.no_grad():
902
- no_mask = transformer(
903
- img,
904
- encoder_hidden_states=text,
905
- pooled_projections=pooled,
906
- timestep=timestep,
907
- img_ids=img_ids,
908
- txt_ids=txt_ids,
909
- guidance=guidance,
910
- )
911
- mask = transformer(
912
- img,
913
- encoder_hidden_states=text,
914
- pooled_projections=pooled,
915
- timestep=timestep,
916
- img_ids=img_ids,
917
- txt_ids=txt_ids,
918
- guidance=guidance,
919
- attention_mask=attn_mask,
920
- )
921
-
922
- assert torch.allclose(no_mask.sample, mask.sample) is False
923
- print("Attention masking test ran OK. Differences in output were detected.")
924
-
925
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
926
-
927
  EXAMPLE_DOC_STRING = """
928
  Examples:
929
  ```py
 
825
 
826
  return Transformer2DModelOutput(sample=output)
827
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
828
  EXAMPLE_DOC_STRING = """
829
  Examples:
830
  ```py