cbensimon (HF Staff) committed
Commit d70d883 · Parent: bd3cfcb
Files changed (4)
  1. aoti.py +17 -0
  2. app.py +9 -1
  3. fa3.py +115 -0
  4. requirements.txt +2 -1
aoti.py ADDED
@@ -0,0 +1,17 @@
+ """
+ """
+
+ import torch
+ from huggingface_hub import hf_hub_download
+ from spaces.zero.torch.aoti import ZeroGPUCompiledModel
+ from spaces.zero.torch.aoti import ZeroGPUWeights
+
+
+ def aoti_load(module: torch.nn.Module, repo_id: str):
+     repeated_blocks = module._repeated_blocks
+     aoti_files = {name: hf_hub_download(repo_id, f'{name}.pt2') for name in repeated_blocks}
+     for block_name, aoti_file in aoti_files.items():
+         for block in module.modules():
+             if block.__class__.__name__ == block_name:
+                 weights = ZeroGPUWeights(block.state_dict())
+                 block.forward = ZeroGPUCompiledModel(aoti_file, weights)
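
For context (not part of the commit): a minimal sketch of how `aoti_load` is meant to be called. It assumes a diffusers `FluxPipeline` whose transformer lists its repeated block classes in `_repeated_blocks`; the checkpoint name is illustrative, while the artifact repo `zerogpu-aoti/FLUX.1` is the one app.py below uses:

```python
import torch
from diffusers import FluxPipeline

from aoti import aoti_load

# Hypothetical checkpoint; any Flux pipeline with a matching transformer would do
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Fuse QKV first so each block's state_dict matches the weights the .pt2 graphs expect
pipe.transformer.fuse_qkv_projections()
aoti_load(pipe.transformer, "zerogpu-aoti/FLUX.1")  # replaces each repeated block's forward
```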
app.py CHANGED
@@ -1,9 +1,11 @@
+ import os
+ os.system("pip install --upgrade spaces")
+
  import sys
  sys.path.append('./')

  import gradio as gr
  import spaces
- import os
  import sys
  import subprocess
  import numpy as np
@@ -57,6 +59,12 @@ canny = CannyDetector()
  anyline = AnylineDetector.from_pretrained("TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline")
  open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")

+ import fa3
+ from aoti import aoti_load
+
+ pipe.transformer.fuse_qkv_projections()
+ aoti_load(pipe.transformer, 'zerogpu-aoti/FLUX.1')
+
  def convert_from_image_to_cv2(img: Image) -> np.ndarray:
      return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

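Ordering in the hunks above matters: `import fa3` registers the `flash::flash_attn_func` custom op that the AOT-compiled graphs presumably reference, so it has to happen before `aoti_load` deserializes them, and `fuse_qkv_projections()` has to precede `aoti_load` so that the `ZeroGPUWeights` captured from each block's state dict are the fused ones. A quick hypothetical sanity check of the op registration:

```python
import torch
import fa3  # the custom op is registered as a side effect of the import

# After the import, the op resolves through the torch.ops namespace
assert hasattr(torch.ops.flash, "flash_attn_func")
```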
fa3.py ADDED
@@ -0,0 +1,115 @@
+ """
+ """
+
+ import torch
+ from kernels import get_kernel
+
+
+ _flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func
+
+
+ @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
+ def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+     outputs, lse = _flash_attn_func(q, k, v)
+     return outputs
+
+ @flash_attn_func.register_fake
+ def _(q, k, v, **kwargs):
+     # two outputs:
+     # 1. output: (batch, seq_len, num_heads, head_dim)
+     # 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
+     meta_q = torch.empty_like(q).contiguous()
+     return meta_q #, q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)
+
+ # Copied from FusedFluxAttnProcessor2_0 but using flash-attn v3 instead of SDPA
+ class FlashFusedFluxAttnProcessor3_0:
+     """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+     def __call__(
+         self,
+         attn,
+         hidden_states: torch.FloatTensor,
+         encoder_hidden_states: torch.FloatTensor | None = None,
+         attention_mask: torch.FloatTensor | None = None,
+         image_rotary_emb: torch.Tensor | None = None,
+     ) -> torch.FloatTensor:
+         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+         # `sample` projections.
+         qkv = attn.to_qkv(hidden_states)
+         split_size = qkv.shape[-1] // 3
+         query, key, value = torch.split(qkv, split_size, dim=-1)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         if attn.norm_q is not None:
+             query = attn.norm_q(query)
+         if attn.norm_k is not None:
+             key = attn.norm_k(key)
+
+         # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+         # `context` projections.
+         if encoder_hidden_states is not None:
+             encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+             split_size = encoder_qkv.shape[-1] // 3
+             (
+                 encoder_hidden_states_query_proj,
+                 encoder_hidden_states_key_proj,
+                 encoder_hidden_states_value_proj,
+             ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+             encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                 batch_size, -1, attn.heads, head_dim
+             ).transpose(1, 2)
+             encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                 batch_size, -1, attn.heads, head_dim
+             ).transpose(1, 2)
+             encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                 batch_size, -1, attn.heads, head_dim
+             ).transpose(1, 2)
+
+             if attn.norm_added_q is not None:
+                 encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+             if attn.norm_added_k is not None:
+                 encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+             # attention
+             query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+             value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+         if image_rotary_emb is not None:
+             from diffusers.models.embeddings import apply_rotary_emb
+
+             query = apply_rotary_emb(query, image_rotary_emb)
+             key = apply_rotary_emb(key, image_rotary_emb)
+
+         # NB: flash-attn expects (batch, seq, heads, head_dim), hence the transposes; the custom op returns one tensor, so no [0] indexing
+         hidden_states = flash_attn_func(
+             query.transpose(1, 2),
+             key.transpose(1, 2),
+             value.transpose(1, 2)).transpose(1, 2)
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         if encoder_hidden_states is not None:
+             encoder_hidden_states, hidden_states = (
+                 hidden_states[:, : encoder_hidden_states.shape[1]],
+                 hidden_states[:, encoder_hidden_states.shape[1] :],
+             )
+
+             # linear proj
+             hidden_states = attn.to_out[0](hidden_states)
+             # dropout
+             hidden_states = attn.to_out[1](hidden_states)
+             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+             return hidden_states, encoder_hidden_states
+         else:
+             return hidden_states
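
Note that app.py only imports this module; nothing there instantiates `FlashFusedFluxAttnProcessor3_0` (the compiled graphs call the registered op directly). For an eager, uncompiled run the processor could plausibly be attached by hand; a sketch assuming diffusers' standard `set_attn_processor` API and an already-built `pipe`:

```python
from fa3 import FlashFusedFluxAttnProcessor3_0

# The processor reads attn.to_qkv / attn.to_added_qkv, so fuse projections first
pipe.transformer.fuse_qkv_projections()
pipe.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
```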
requirements.txt CHANGED
@@ -14,4 +14,5 @@ xformers
  sentencepiece
  peft
  scipy
- scikit-image
+ scikit-image
+ kernels