Spaces: Running on Zero
10x demo speedup (#1) — opened by cbensimon (HF Staff)
aoti.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from huggingface_hub import hf_hub_download
|
6 |
+
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
|
7 |
+
from spaces.zero.torch.aoti import ZeroGPUWeights
|
8 |
+
|
9 |
+
|
10 |
+
def aoti_load(module: torch.nn.Module, repo_id: str):
    """Swap the forwards of *module*'s repeated blocks for AoT-compiled ones.

    For every block class name listed in ``module._repeated_blocks``, the
    matching ``<name>.pt2`` AOTInductor archive is downloaded from *repo_id*,
    and each sub-module of that class gets its ``forward`` replaced by a
    ``ZeroGPUCompiledModel`` bound to the block's own weights.

    NOTE(review): relies on the private ``_repeated_blocks`` attribute of the
    diffusers transformer — confirm it exists on the pinned version.
    """
    block_names = module._repeated_blocks
    # Download every archive up front so a failed download happens before
    # any forward has been patched.
    archives = {name: hf_hub_download(repo_id, f'{name}.pt2') for name in block_names}
    for name, archive in archives.items():
        for sub in module.modules():
            if type(sub).__name__ != name:
                continue
            sub.forward = ZeroGPUCompiledModel(archive, ZeroGPUWeights(sub.state_dict()))
|
app.py
CHANGED
@@ -1,9 +1,11 @@
|
|
|
|
|
|
|
|
1 |
import sys
|
2 |
sys.path.append('./')
|
3 |
|
4 |
import gradio as gr
|
5 |
import spaces
|
6 |
-
import os
|
7 |
import sys
|
8 |
import subprocess
|
9 |
import numpy as np
|
@@ -61,10 +63,11 @@ canny = CannyDetector()
|
|
61 |
anyline = AnylineDetector.from_pretrained("TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline")
|
62 |
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
pipe.
|
|
|
68 |
|
69 |
def convert_from_image_to_cv2(img: Image) -> np.ndarray:
    """Convert a PIL image (RGB) into an OpenCV-style BGR numpy array."""
    rgb = np.array(img)
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
|
|
|
1 |
+
import os
# HACK: upgrade the `spaces` package at startup so that the ZeroGPU AoTI
# helpers imported later (spaces.zero.torch.aoti) are available on the Space.
# Prefer pinning the required version in requirements.txt instead of a
# runtime pip install — TODO confirm why the upgrade must happen here.
os.system("pip install --upgrade spaces")
|
3 |
+
|
4 |
import sys
|
5 |
sys.path.append('./')
|
6 |
|
7 |
import gradio as gr
|
8 |
import spaces
|
|
|
9 |
import sys
|
10 |
import subprocess
|
11 |
import numpy as np
|
|
|
63 |
anyline = AnylineDetector.from_pretrained("TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline")
|
64 |
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
65 |
|
66 |
+
# Side-effect import: fa3 registers the "flash::flash_attn_func" custom op
# at module level; the name itself is never referenced afterwards.
import fa3
from aoti import aoti_load

# Fuse the Q/K/V projections first — presumably the pre-compiled AoTI
# artifacts in 'zerogpu-aoti/FLUX.1' were exported from the fused layout,
# so the order of these two calls matters (verify against the export script).
pipe.transformer.fuse_qkv_projections()
aoti_load(pipe.transformer, 'zerogpu-aoti/FLUX.1')
|
71 |
|
72 |
def convert_from_image_to_cv2(img: Image) -> np.ndarray:
    """Return the OpenCV (BGR) array representation of a PIL RGB image."""
    as_array = np.array(img)
    return cv2.cvtColor(as_array, cv2.COLOR_RGB2BGR)
|
fa3.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from kernels import get_kernel
|
6 |
+
|
7 |
+
|
8 |
+
_flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func
|
9 |
+
|
10 |
+
|
11 |
+
@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
|
12 |
+
def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
13 |
+
outputs, lse = _flash_attn_func(q, k, v)
|
14 |
+
return outputs
|
15 |
+
|
16 |
+
@flash_attn_func.register_fake
|
17 |
+
def _(q, k, v, **kwargs):
|
18 |
+
return torch.empty_like(q).contiguous()
|
requirements.txt
CHANGED
@@ -14,4 +14,5 @@ xformers
|
|
14 |
sentencepiece
|
15 |
peft
|
16 |
scipy
|
17 |
-
scikit-image
|
|
|
|
14 |
sentencepiece
|
15 |
peft
|
16 |
scipy
|
17 |
+
scikit-image
|
18 |
+
kernels
|