"""Gradio app that compiles the Flux.1-Dev transformer ahead of time and pushes the graph to the Hub.

The UI is made of two nested ``gr.Blocks``: ``demo`` holds the actual form
(repo id, filename, push button) and ``demo_login`` wraps it together with a
login button, dimming the form until a user profile is available.
"""

# NOTE(review): the original import order is preserved on purpose (``spaces``
# first). ZeroGPU setups commonly expect ``spaces`` to be imported before
# torch -- confirm before reordering these lines.
import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from optimization import compile_transformer
from hub_utils import _push_compiled_graph_to_hub
from huggingface_hub import whoami
import time
from fa3 import FlashFluxAttnProcessor3_0

# --- Model Loading ---
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model pipeline
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", torch_dtype=dtype).to(device)
pipe.transformer.set_attn_processor(FlashFluxAttnProcessor3_0())


@spaces.GPU(duration=1200)
def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
    """Compile the pipeline's transformer ahead of time and upload the archive to the Hub.

    Args:
        repo_id: Target Hub repository, forwarded as ``repo_id`` to the upload helper.
        filename: Destination path inside the repository; must end with ``.pt2``.
        oauth_token: OAuth token object injected by Gradio for the logged-in user.

    Returns:
        Whatever ``_push_compiled_graph_to_hub`` returns; it is displayed in the
        Markdown output component of the UI.

    Raises:
        NotImplementedError: If ``filename`` does not end with ``.pt2``.
        gr.Error: If the token cannot be validated, or if compilation or the
            upload fails. The two cases carry different messages.
    """
    if not filename.endswith(".pt2"):
        raise NotImplementedError("The filename must end with a `.pt2` extension.")

    # Step 1: authentication only. Keeping this in its own ``try`` ensures that
    # later compile/upload failures are not misreported as a login problem.
    try:
        token = oauth_token.token
        # this will throw if token is invalid
        whoami(token)
    except Exception as err:
        raise gr.Error(
            f"Oops, you forgot to login. Please use the login button on the top left to migrate your repo {err}"
        ) from err

    # Step 2: compile and push. Failures are still surfaced to the UI as
    # ``gr.Error`` (as before), but with a message that names the real cause.
    try:
        # --- Ahead-of-time compilation ---
        start = time.perf_counter()
        compiled_transformer = compile_transformer(pipe, prompt="prompt")
        print(f"Compilation took: {time.perf_counter() - start} seconds.")

        return _push_compiled_graph_to_hub(
            compiled_transformer.archive_file,
            repo_id=repo_id,
            token=token,
            path_in_repo=filename,
        )
    except Exception as err:
        raise gr.Error(f"Compiling or pushing the graph failed: {err}") from err


css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

# Inner UI: the form itself. It is rendered inside ``demo_login`` below.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            "## Compile [Flux.1-Dev](https://hf.co/black-forest-labs/Flux.1-Dev) graph ahead of time & push to the Hub"
        )
        gr.Markdown(
            "Enter a **repo_id** and **filename**. This repo automatically compiles the Flux.1-Dev model ahead of time. Read more about this in [this post](https://huggingface.co/blog/zerogpu-aoti)."
        )
        gr.Markdown("Depending on the model, it can take some time (2-10 mins) to compile.")

        repo_id = gr.Textbox(label="repo_id", placeholder="e.g. sayakpaul/qwen-aot")
        filename = gr.Textbox(label="filename", placeholder="e.g. compiled.pt2")
        run = gr.Button("Push graph to Hub", variant="primary")

        markdown_out = gr.Markdown(label="Output")

        # ``oauth_token`` is not listed in ``inputs``; it is supplied through the
        # ``gr.OAuthToken`` annotation on ``push_to_hub``.
        run.click(push_to_hub, inputs=[repo_id, filename], outputs=[markdown_out])


def swap_visibilty(profile: gr.OAuthProfile | None):
    """Return a component update that switches the main UI's CSS class by login state.

    A truthy ``profile`` yields the ``main_ui_logged_in`` class (for which
    ``css_login`` defines no rule, so the dimming is dropped); otherwise the
    ``main_ui_logged_out`` class is applied, which dims the form and disables
    pointer events.
    """
    return gr.update(elem_classes=["main_ui_logged_in"]) if profile else gr.update(elem_classes=["main_ui_logged_out"])


css_login = """
.main_ui_logged_out{opacity: 0.3; pointer-events: none; margin: 0 auto; max-width: 520px}
"""

# Outer UI: login button plus the form, which starts in the logged-out style.
with gr.Blocks(css=css_login) as demo_login:
    gr.LoginButton()
    with gr.Column(elem_classes="main_ui_logged_out") as main_ui:
        demo.render()
    # On page load, update the wrapper column's class from the visitor's profile.
    demo_login.load(fn=swap_visibilty, outputs=main_ui)

demo_login.queue()
demo_login.launch()