Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,263 Bytes
ebbd677 b22b80e b05966a f9f24d7 f5a3617 ee02270 6684d62 746f9fc afa2559 b05966a b22b80e b05966a 13352a6 746f9fc 4b0fe46 ee4246b f61fb8b ee02270 4b0fe46 ee02270 7bf5ca7 b22b80e 7bf5ca7 6684d62 7bf5ca7 6684d62 452391a b22b80e 4b0fe46 7bf5ca7 4b0fe46 7bf5ca7 4b0fe46 7bf5ca7 b22b80e 7bf5ca7 4b0fe46 104a64a 9c2430d ee02270 9c2430d 7bf5ca7 9c2430d 1b06612 9c2430d ee02270 b22b80e 4b0fe46 7bf5ca7 4b0fe46 7bf5ca7 4b0fe46 7bf5ca7 4b0fe46 7bf5ca7 4b0fe46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from optimization import compile_transformer
from hub_utils import _push_compiled_graph_to_hub
from huggingface_hub import whoami
import time
from fa3 import FlashFusedFluxAttnProcessor3_0
# --- Model Loading ---
# bf16 halves memory vs fp32 while keeping enough dynamic range for diffusion.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the model pipeline
# NOTE: runs at import time — the whole app blocks until the weights are loaded.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", torch_dtype=dtype).to(device)
# Fuse the Q/K/V projections into one matmul, then install the fused
# FlashAttention-3 processor (from fa3.py) that expects fused projections.
pipe.transformer.fuse_qkv_projections()
pipe.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
@spaces.GPU(duration=1200)
def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
    """Ahead-of-time compile the Flux transformer and push the archive to the Hub.

    Args:
        repo_id: Destination Hub repository id (e.g. "user/repo").
        filename: Path inside the repo for the compiled graph; must end in ".pt2".
        oauth_token: OAuth token of the logged-in Gradio user.

    Returns:
        A markdown link to the resulting commit when the push helper returns a
        commit object, otherwise whatever the helper returned.

    Raises:
        NotImplementedError: If ``filename`` does not end with ".pt2".
        gr.Error: If the token is invalid or compilation/push fails.
    """
    if not filename.endswith(".pt2"):
        raise NotImplementedError("The filename must end with a `.pt2` extension.")
    try:
        # whoami() throws on an invalid token — validates the login up front.
        _ = whoami(oauth_token.token)
        # --- Ahead-of-time compilation ---
        start = time.perf_counter()
        compiled_transformer = compile_transformer(pipe, prompt="prompt")
        end = time.perf_counter()
        # BUG FIX: was `start - end`, which always printed a negative duration.
        print(f"Compilation took: {end - start} seconds.")
        token = oauth_token.token
        out = _push_compiled_graph_to_hub(
            compiled_transformer.archive_file, repo_id=repo_id, token=token, path_in_repo=filename
        )
        # Prefer a clickable markdown link when a commit object is returned.
        if not isinstance(out, str) and hasattr(out, "commit_url"):
            commit_url = out.commit_url
            return f"[{commit_url}]({commit_url})"
        return out
    except Exception as e:
        # Chain the original exception so the traceback keeps the real cause.
        raise gr.Error(
            f"""Oops, you forgot to login. Please use the login button on the top left to migrate your repo {e}"""
        ) from e
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
# --- Main UI: repo/filename inputs plus a button that compiles and pushes ---
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            "## Compile [Flux.1-Dev](https://hf.co/black-forest-labs/Flux.1-Dev) graph ahead of time & push to the Hub"
        )
        gr.Markdown(
            "Enter a **repo_id** and **filename**. This repo automatically compiles the Flux.1-Dev model ahead of time. Read more about this in [this post](https://huggingface.co/blog/zerogpu-aoti)."
        )
        gr.Markdown("Depending on the model, it can take some time (2-10 mins) to compile.")
        repo_id = gr.Textbox(label="repo_id", placeholder="e.g. sayakpaul/qwen-aot")
        filename = gr.Textbox(label="filename", placeholder="e.g. compiled.pt2")
        run = gr.Button("Push graph to Hub", variant="primary")
        markdown_out = gr.Markdown(label="Output")
        # Wire the button to the compile-and-push handler defined above.
        run.click(push_to_hub, inputs=[repo_id, filename], outputs=[markdown_out])
def swap_visibilty(profile: gr.OAuthProfile | None):
    """Pick the main column's CSS class based on login state.

    Returns a ``gr.update`` carrying ``main_ui_logged_in`` when a profile is
    present, otherwise ``main_ui_logged_out`` (which dims and disables the UI).
    """
    if profile:
        return gr.update(elem_classes=["main_ui_logged_in"])
    return gr.update(elem_classes=["main_ui_logged_out"])
# CSS for the login wrapper: dim and disable the main UI until logged in.
css_login = """
.main_ui_logged_out{opacity: 0.3; pointer-events: none; margin: 0 auto; max-width: 520px}
"""
# --- Login gate: wraps the main demo and keeps it disabled until OAuth login ---
with gr.Blocks(css=css_login) as demo_login:
    gr.LoginButton()
    # Start in the logged-out (dimmed, non-interactive) state.
    with gr.Column(elem_classes="main_ui_logged_out") as main_ui:
        demo.render()
    # On page load, check the OAuth profile and swap the column's CSS class.
    demo_login.load(fn=swap_visibilty, outputs=main_ui)
demo_login.queue()
demo_login.launch()