10x demo speedup

#1
by cbensimon (HF Staff) - opened
Files changed (4)
  1. aoti.py +17 -0
  2. app.py +8 -5
  3. fa3.py +18 -0
  4. requirements.txt +2 -1
aoti.py ADDED
@@ -0,0 +1,17 @@
+"""
+"""
+
+import torch
+from huggingface_hub import hf_hub_download
+from spaces.zero.torch.aoti import ZeroGPUCompiledModel
+from spaces.zero.torch.aoti import ZeroGPUWeights
+
+
+def aoti_load(module: torch.nn.Module, repo_id: str):
+    repeated_blocks = module._repeated_blocks
+    aoti_files = {name: hf_hub_download(repo_id, f'{name}.pt2') for name in repeated_blocks}
+    for block_name, aoti_file in aoti_files.items():
+        for block in module.modules():
+            if block.__class__.__name__ == block_name:
+                weights = ZeroGPUWeights(block.state_dict())
+                block.forward = ZeroGPUCompiledModel(aoti_file, weights)
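
`aoti_load` downloads one pre-exported AOTInductor artifact (`<BlockName>.pt2`) per repeated transformer block class and swaps each matching block's `forward` for a `ZeroGPUCompiledModel` bound to that block's weights. A minimal standalone sketch of how it might be used, assuming a ZeroGPU Space with a FLUX pipeline (the pipeline id below is illustrative; the Space builds its own `pipe` earlier in app.py, and the `zerogpu-aoti/FLUX.1` repo id comes from this PR):

```python
import torch
from diffusers import FluxPipeline
from aoti import aoti_load

# Illustrative pipeline; the Space attaches its own ControlNet setup.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# QKV fusion first, presumably so the module structure matches how the
# .pt2 artifacts were exported (the PR calls it right before aoti_load).
pipe.transformer.fuse_qkv_projections()
aoti_load(pipe.transformer, "zerogpu-aoti/FLUX.1")
```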
app.py CHANGED
@@ -1,9 +1,11 @@
+import os
+os.system("pip install --upgrade spaces")
+
 import sys
 sys.path.append('./')
 
 import gradio as gr
 import spaces
-import os
 import sys
 import subprocess
 import numpy as np
@@ -61,10 +63,11 @@ canny = CannyDetector()
 anyline = AnylineDetector.from_pretrained("TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline")
 open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
 
-torch.backends.cuda.matmul.allow_tf32 = True
-pipe.vae.enable_tiling()
-pipe.vae.enable_slicing()
-pipe.enable_model_cpu_offload() # for saving memory
+import fa3
+from aoti import aoti_load
+
+pipe.transformer.fuse_qkv_projections()
+aoti_load(pipe.transformer, 'zerogpu-aoti/FLUX.1')
 
 def convert_from_image_to_cv2(img: Image) -> np.ndarray:
     return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
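
Net effect of the app.py change: VAE tiling/slicing and model CPU offload are dropped, and the transformer instead runs the pre-compiled AOTI blocks with the FA3 kernel (importing `fa3` registers the `flash::flash_attn_func` custom op, which the exported graphs presumably call). A hypothetical sanity check, not part of the PR, to confirm the blocks were actually swapped:

```python
# Hypothetical check: after aoti_load, each repeated transformer block's
# forward should have been replaced by a ZeroGPUCompiledModel.
from spaces.zero.torch.aoti import ZeroGPUCompiledModel

swapped = [m for m in pipe.transformer.modules()
           if isinstance(m.forward, ZeroGPUCompiledModel)]
print(f"{len(swapped)} transformer blocks now run AOT-compiled forwards")
```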
fa3.py ADDED
@@ -0,0 +1,18 @@
+"""
+"""
+
+import torch
+from kernels import get_kernel
+
+
+_flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func
+
+
+@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
+def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+    outputs, lse = _flash_attn_func(q, k, v)
+    return outputs
+
+@flash_attn_func.register_fake
+def _(q, k, v, **kwargs):
+    return torch.empty_like(q).contiguous()
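
Wrapping the Hub kernel in `torch.library.custom_op` with a `register_fake` implementation is what lets the Flash Attention 3 call trace through `torch.compile`/AOTInductor: the fake kernel supplies output metadata without running the real CUDA kernel. A small sketch of exercising the op, not part of the PR, with illustrative `(batch, seqlen, heads, head_dim)` shapes:

```python
import torch
from fa3 import flash_attn_func  # importing fa3 registers flash::flash_attn_func

# Illustrative shapes; FA3 expects (batch, seqlen, num_heads, head_dim).
q = torch.randn(1, 1024, 24, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

@torch.compile
def attend(q, k, v):
    return flash_attn_func(q, k, v)

print(attend(q, k, v).shape)  # torch.Size([1, 1024, 24, 128])
```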
requirements.txt CHANGED
@@ -14,4 +14,5 @@ xformers
 sentencepiece
 peft
 scipy
-scikit-image
+scikit-image
+kernels