app.py CHANGED
@@ -6,6 +6,7 @@ from optimization import compile_transformer
 from hub_utils import _push_compiled_graph_to_hub
 from huggingface_hub import whoami
 import time
+from fa3 import FlashFusedFluxAttnProcessor3_0
 
 # --- Model Loading ---
 dtype = torch.bfloat16
@@ -13,7 +14,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Load the model pipeline
 pipe = DiffusionPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", torch_dtype=dtype).to(device)
-
+pipe.transformer.fuse_qkv_projections()
+pipe.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
 
 @spaces.GPU(duration=1200)
 def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken, progress=gr.Progress(track_tqdm=True)):
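
Note on ordering: `fuse_qkv_projections()` merges the transformer's separate q/k/v (and added q/k/v) linear layers into single `to_qkv`/`to_added_qkv` projections, so it must run before installing the new processor, which reads `attn.to_qkv` instead of `attn.to_q`/`to_k`/`to_v`. A minimal A/B sketch of the swap, assuming a CUDA device supported by the FA3 kernel; the `sample` helper, prompt, seed, and step count are illustrative, not part of the commit:

    import torch
    from diffusers import DiffusionPipeline
    from fa3 import FlashFusedFluxAttnProcessor3_0

    pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/Flux.1-Dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    def sample(seed=0):
        # fixed seed so the SDPA and FA3 runs are comparable
        g = torch.Generator("cuda").manual_seed(seed)
        return pipe("a cat", num_inference_steps=4, generator=g).images[0]

    reference = sample()                     # stock SDPA attention
    pipe.transformer.fuse_qkv_projections()  # creates attn.to_qkv / attn.to_added_qkv
    pipe.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
    flash = sample()                         # FA3 attention
    # the two images should match up to bf16/kernel numerical noise
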
fa3.py ADDED
@@ -0,0 +1,112 @@
+import torch
+from kernels import get_kernel
+
+
+_flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func
+
+
+@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
+def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+    outputs, lse = _flash_attn_func(q, k, v)
+    return outputs
+
+@flash_attn_func.register_fake
+def _(q, k, v, **kwargs):
+    # the underlying kernel has two outputs (the custom op returns only the first):
+    # 1. output: (batch, seq_len, num_heads, head_dim)
+    # 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
+    meta_q = torch.empty_like(q).contiguous()
+    return meta_q  # , q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)
+
+# Copied FusedFluxAttnProcessor2_0 but using flash v3 instead of SDPA
+class FlashFusedFluxAttnProcessor3_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __call__(
+        self,
+        attn,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor | None = None,
+        attention_mask: torch.FloatTensor | None = None,
+        image_rotary_emb: torch.Tensor | None = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from diffusers.models.embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+
# NB: transposes are necessary to match expected SDPA input shape
|
| 90 |
+
hidden_states = flash_attn_func(
|
| 91 |
+
query.transpose(1, 2),
|
| 92 |
+
key.transpose(1, 2),
|
| 93 |
+
value.transpose(1, 2))[0].transpose(1, 2)
|
| 94 |
+
|
| 95 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 96 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 97 |
+
|
| 98 |
+
if encoder_hidden_states is not None:
|
| 99 |
+
encoder_hidden_states, hidden_states = (
|
| 100 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
| 101 |
+
hidden_states[:, encoder_hidden_states.shape[1] :],
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# linear proj
|
| 105 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 106 |
+
# dropout
|
| 107 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 108 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 109 |
+
|
| 110 |
+
return hidden_states, encoder_hidden_states
|
| 111 |
+
else:
|
| 112 |
+
return hidden_states
|
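
Why the custom-op wrapper: `torch.library.custom_op` plus the `register_fake` meta implementation (which only reports the output shape) is what lets `torch.compile`, used by the Space's `compile_transformer` path, trace through the FA3 kernel without executing it. A minimal smoke-test sketch, assuming Flux-like shapes (24 heads of dim 128) and a GPU supported by the vllm-flash-attn3 kernel; the `attn` wrapper and shapes are illustrative:

    import torch
    from fa3 import flash_attn_func

    # (batch, seq_len, num_heads, head_dim) -- the layout the kernel expects
    q = torch.randn(1, 128, 24, 128, device="cuda", dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    def attn(q, k, v):
        return flash_attn_func(q, k, v)

    eager = attn(q, k, v)
    compiled = torch.compile(attn, fullgraph=True)  # traces via the fake impl

    # the same kernel runs in both paths, so results should agree
    assert torch.allclose(eager, compiled(q, k, v))
    print(eager.shape)  # torch.Size([1, 128, 24, 128]), matching the fake impl
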