sayakpaul committed
Commit 746f9fc · 1 Parent(s): 1b06612
Files changed (2)
  1. app.py +3 -1
  2. fa3.py +112 -0
app.py CHANGED
@@ -6,6 +6,7 @@ from optimization import compile_transformer
  from hub_utils import _push_compiled_graph_to_hub
  from huggingface_hub import whoami
  import time
+ from fa3 import FlashFusedFluxAttnProcessor3_0

  # --- Model Loading ---
  dtype = torch.bfloat16
@@ -13,7 +14,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"

  # Load the model pipeline
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", torch_dtype=dtype).to(device)
-
+ pipe.transformer.fuse_qkv_projections()
+ pipe.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

  @spaces.GPU(duration=1200)
  def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken, progress=gr.Progress(track_tqdm=True)):
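The two added lines fuse the separate q/k/v linear layers into single `to_qkv` / `to_added_qkv` projections and then route every attention block through the FlashAttention-3 processor defined in the new fa3.py below. A rough sanity-check sketch, not part of this commit (the prompt, step count, and seed are arbitrary; it assumes a CUDA GPU supported by the vllm-flash-attn3 kernel), comparing latents from the default SDPA path against the swapped-in FA3 processor:

    import torch
    from diffusers import DiffusionPipeline
    from fa3 import FlashFusedFluxAttnProcessor3_0

    pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/Flux.1-Dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Reference latents with the stock SDPA attention processors.
    ref = pipe("a photo of a cat", num_inference_steps=4,
               generator=torch.manual_seed(0), output_type="latent").images

    # Same seed, but with fused QKV projections + the FA3 processor.
    pipe.transformer.fuse_qkv_projections()
    pipe.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
    out = pipe("a photo of a cat", num_inference_steps=4,
               generator=torch.manual_seed(0), output_type="latent").images

    # A different attention kernel in bf16 won't match bit-for-bit; expect small drift.
    print((ref.float() - out.float()).abs().max())

Exact equality is not expected here; the point of the check is only that the numerical drift between the two attention backends stays small.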
fa3.py ADDED
@@ -0,0 +1,112 @@
+ import torch
+ from kernels import get_kernel
+
+
+ _flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func
+
+
+ @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
+ def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+     outputs, lse = _flash_attn_func(q, k, v)
+     return outputs
+
+ @flash_attn_func.register_fake
+ def _(q, k, v, **kwargs):
+     # the underlying kernel returns two outputs; only the first is exposed here:
+     # 1. output: (batch, seq_len, num_heads, head_dim)
+     # 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
+     meta_q = torch.empty_like(q).contiguous()
+     return meta_q  # , q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)
+
+ # Copied from diffusers' FusedFluxAttnProcessor2_0, but using FlashAttention-3 instead of SDPA
+ class FlashFusedFluxAttnProcessor3_0:
+     """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+     def __call__(
+         self,
+         attn,
+         hidden_states: torch.FloatTensor,
+         encoder_hidden_states: torch.FloatTensor | None = None,
+         attention_mask: torch.FloatTensor | None = None,
+         image_rotary_emb: torch.Tensor | None = None,
+     ) -> torch.FloatTensor:
+         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+         # `sample` projections.
+         qkv = attn.to_qkv(hidden_states)
+         split_size = qkv.shape[-1] // 3
+         query, key, value = torch.split(qkv, split_size, dim=-1)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         if attn.norm_q is not None:
+             query = attn.norm_q(query)
+         if attn.norm_k is not None:
+             key = attn.norm_k(key)
+
+         # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+         # `context` projections.
+         if encoder_hidden_states is not None:
+             encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+             split_size = encoder_qkv.shape[-1] // 3
+             (
+                 encoder_hidden_states_query_proj,
+                 encoder_hidden_states_key_proj,
+                 encoder_hidden_states_value_proj,
+             ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+             encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                 batch_size, -1, attn.heads, head_dim
+             ).transpose(1, 2)
+             encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                 batch_size, -1, attn.heads, head_dim
+             ).transpose(1, 2)
+             encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                 batch_size, -1, attn.heads, head_dim
+             ).transpose(1, 2)
+
+             if attn.norm_added_q is not None:
+                 encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+             if attn.norm_added_k is not None:
+                 encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+             # attention
+             query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+             value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+         if image_rotary_emb is not None:
+             from diffusers.models.embeddings import apply_rotary_emb
+
+             query = apply_rotary_emb(query, image_rotary_emb)
+             key = apply_rotary_emb(key, image_rotary_emb)
+
+         # NB: flash_attn expects (batch, seq_len, heads, head_dim), so transpose from the SDPA layout (batch, heads, seq_len, head_dim) and back
+         hidden_states = flash_attn_func(
+             query.transpose(1, 2),
+             key.transpose(1, 2),
+             value.transpose(1, 2)).transpose(1, 2)
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         if encoder_hidden_states is not None:
+             encoder_hidden_states, hidden_states = (
+                 hidden_states[:, : encoder_hidden_states.shape[1]],
+                 hidden_states[:, encoder_hidden_states.shape[1] :],
+             )
+
+             # linear proj
+             hidden_states = attn.to_out[0](hidden_states)
+             # dropout
+             hidden_states = attn.to_out[1](hidden_states)
+             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+             return hidden_states, encoder_hidden_states
+         else:
+             return hidden_states
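A note on the `torch.library.custom_op` / `register_fake` pair at the top of fa3.py: the downloaded FA3 kernel is opaque to `torch.compile`, so it is wrapped as a custom op with a shape-only "fake" implementation, which lets fake-tensor tracing (used by `torch.compile` and ahead-of-time export) propagate shapes without actually running the kernel. Below is a self-contained sketch of the same pattern with `F.scaled_dot_product_attention` standing in for the FA3 kernel so it runs on any machine; the `demo::opaque_attn` name and the `attend` helper are made up for illustration and are not part of this commit.

    import torch
    import torch.nn.functional as F

    # Stand-in for the FA3 kernel: an opaque op torch.compile cannot trace into.
    @torch.library.custom_op("demo::opaque_attn", mutates_args=())
    def opaque_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # inputs/outputs use the flash-attn layout: (batch, seq_len, heads, head_dim)
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        ).transpose(1, 2)
        return out.contiguous()

    # Shape-only fake implementation so the op works under FakeTensors,
    # which is what torch.compile uses to propagate shapes while tracing.
    @opaque_attn.register_fake
    def _(q, k, v):
        return torch.empty_like(q).contiguous()

    @torch.compile(fullgraph=True)  # fullgraph=True fails loudly if tracing breaks
    def attend(q, k, v):
        return opaque_attn(q, k, v)

    q = k = v = torch.randn(1, 16, 8, 64)
    print(attend(q, k, v).shape)  # torch.Size([1, 16, 8, 64])

Like the fake registration in fa3.py, the fake here returns only the attention output (the softmax LSE is dropped), which is why the processor can use the op's result directly.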