import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from optimization import compile_transformer
from hub_utils import _push_compiled_graph_to_hub
from huggingface_hub import whoami
import time
from fa3 import FlashFusedFluxAttnProcessor3_0

# --- Model Loading ---
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model pipeline
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype).to(device)
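# Fuse the QKV projections and switch to the FlashAttention-3 attention processor for faster inference.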
pipe.transformer.fuse_qkv_projections()
pipe.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

# Request a ZeroGPU slot of up to 1200 seconds; AoT compilation can take several minutes.
@spaces.GPU(duration=1200)
def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
    if not filename.endswith(".pt2"):
        raise ValueError("The filename must end with a `.pt2` extension.")

    # This will raise if the token is invalid.
    try:
        _ = whoami(oauth_token.token)

        # --- Ahead-of-time compilation ---
        # The placeholder prompt just provides example inputs for the export/compile step.
        start = time.perf_counter()
        compiled_transformer = compile_transformer(pipe, prompt="prompt")
        end = time.perf_counter()
        print(f"Compilation took: {end - start} seconds.")

        token = oauth_token.token
        out = _push_compiled_graph_to_hub(
            compiled_transformer.archive_file, repo_id=repo_id, token=token, path_in_repo=filename
        )
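        # The helper may return either a plain string or a CommitInfo-style object with a `commit_url`.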
        if not isinstance(out, str) and hasattr(out, "commit_url"):
            commit_url = out.commit_url
            return f"[{commit_url}]({commit_url})"
        else:
            return out
    except Exception as e:
        raise gr.Error(
            f"Something went wrong. Make sure you are logged in via the login button on the top left, then try again. Details: {e}"
        )


css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            "## Compile the [FLUX.1-dev](https://hf.co/black-forest-labs/FLUX.1-dev) graph ahead of time & push it to the Hub"
        )
        gr.Markdown(
            "Enter a **repo_id** and **filename**. This Space compiles the FLUX.1-dev transformer ahead of time and pushes the compiled graph to the Hub. Read more in [this post](https://huggingface.co/blog/zerogpu-aoti)."
        )
        gr.Markdown("Compilation can take a while (roughly 2-10 minutes), depending on the model.")

        repo_id = gr.Textbox(label="repo_id", placeholder="e.g. sayakpaul/qwen-aot")
        filename = gr.Textbox(label="filename", placeholder="e.g. compiled.pt2")

        run = gr.Button("Push graph to Hub", variant="primary")

        markdown_out = gr.Markdown(label="Output")

    run.click(push_to_hub, inputs=[repo_id, filename], outputs=[markdown_out])


# Toggle the CSS class on the main UI depending on whether the user is logged in.
def swap_visibility(profile: gr.OAuthProfile | None):
    return gr.update(elem_classes=["main_ui_logged_in"]) if profile else gr.update(elem_classes=["main_ui_logged_out"])


css_login = """
.main_ui_logged_out{opacity: 0.3; pointer-events: none; margin: 0 auto; max-width: 520px}
"""
with gr.Blocks(css=css_login) as demo_login:
    gr.LoginButton()
    with gr.Column(elem_classes="main_ui_logged_out") as main_ui:
        demo.render()
    demo_login.load(fn=swap_visibility, outputs=main_ui)

demo_login.queue()
demo_login.launch()
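
# For reference, a downstream consumer could later fetch and load the pushed archive
# roughly like this (a sketch, assuming the `.pt2` file is an AOTInductor package,
# which is what `compile_transformer` is expected to produce; `<your-repo-id>` is a
# placeholder):
#
#     from huggingface_hub import hf_hub_download
#     import torch
#
#     path = hf_hub_download(repo_id="<your-repo-id>", filename="compiled.pt2")
#     compiled_module = torch._inductor.aoti_load_package(path)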