Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	up
Browse files- app.py +25 -20
- hub_utils.py +6 -7
- optimization.py +14 -16
    	
        app.py
    CHANGED
    
    | @@ -11,15 +11,14 @@ dtype = torch.bfloat16 | |
| 11 | 
             
            device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 12 |  | 
| 13 | 
             
            # Load the model pipeline
         | 
| 14 | 
            -
            pipe = DiffusionPipeline.from_pretrained(
         | 
| 15 | 
            -
             | 
| 16 | 
            -
            ).to(device)
         | 
| 17 |  | 
| 18 | 
             
            @spaces.GPU(duration=120)
         | 
| 19 | 
            -
            def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
         | 
| 20 | 
             
                if not filename.endswith(".pt2"):
         | 
| 21 | 
             
                    raise NotImplementedError("The filename must end with a `.pt2` extension.")
         | 
| 22 | 
            -
             | 
| 23 | 
             
                # this will throw if token is invalid
         | 
| 24 | 
             
                try:
         | 
| 25 | 
             
                    _ = whoami(oauth_token.token)
         | 
| @@ -27,12 +26,9 @@ def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken): | |
| 27 | 
             
                    # --- Ahead-of-time compilation ---
         | 
| 28 | 
             
                    compiled_transformer = compile_transformer(pipe, prompt="prompt")
         | 
| 29 |  | 
| 30 | 
            -
                    token = oauth_token.token | 
| 31 | 
             
                    out = _push_compiled_graph_to_hub(
         | 
| 32 | 
            -
                        compiled_transformer.archive_file,
         | 
| 33 | 
            -
                        repo_id=repo_id,
         | 
| 34 | 
            -
                        token=token,
         | 
| 35 | 
            -
                        path_in_repo=filename
         | 
| 36 | 
             
                    )
         | 
| 37 | 
             
                    if not isinstance(out, str) and hasattr(out, "commit_url"):
         | 
| 38 | 
             
                        commit_url = out.commit_url
         | 
| @@ -40,9 +36,12 @@ def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken): | |
| 40 | 
             
                    else:
         | 
| 41 | 
             
                        return out
         | 
| 42 | 
             
                except Exception as e:
         | 
| 43 | 
            -
                    raise gr.Error( | 
| 44 | 
            -
             | 
| 45 | 
            -
             | 
|  | |
|  | |
|  | |
| 46 | 
             
            #col-container {
         | 
| 47 | 
             
                margin: 0 auto;
         | 
| 48 | 
             
                max-width: 520px;
         | 
| @@ -50,8 +49,12 @@ css=""" | |
| 50 | 
             
            """
         | 
| 51 | 
             
            with gr.Blocks(css=css) as demo:
         | 
| 52 | 
             
                with gr.Column(elem_id="col-container"):
         | 
| 53 | 
            -
                    gr.Markdown( | 
| 54 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
| 55 |  | 
| 56 | 
             
                    repo_id = gr.Textbox(label="repo_id", placeholder="e.g. sayakpaul/qwen-aot")
         | 
| 57 | 
             
                    filename = gr.Textbox(label="filename", placeholder="e.g. compiled.pt2")
         | 
| @@ -62,17 +65,19 @@ with gr.Blocks(css=css) as demo: | |
| 62 |  | 
| 63 | 
             
                run.click(push_to_hub, inputs=[repo_id, filename], outputs=[markdown_out])
         | 
| 64 |  | 
|  | |
| 65 | 
             
            def swap_visibilty(profile: gr.OAuthProfile | None):
         | 
| 66 | 
             
                return gr.update(elem_classes=["main_ui_logged_in"]) if profile else gr.update(elem_classes=["main_ui_logged_out"])
         | 
| 67 | 
            -
             | 
| 68 | 
            -
             | 
|  | |
| 69 | 
             
            .main_ui_logged_out{opacity: 0.3; pointer-events: none; margin: 0 auto; max-width: 520px}
         | 
| 70 | 
            -
             | 
| 71 | 
             
            with gr.Blocks(css=css_login) as demo_login:
         | 
| 72 | 
             
                gr.LoginButton()
         | 
| 73 | 
             
                with gr.Column(elem_classes="main_ui_logged_out") as main_ui:
         | 
| 74 | 
             
                    demo.render()
         | 
| 75 | 
             
                demo_login.load(fn=swap_visibilty, outputs=main_ui)
         | 
| 76 | 
            -
             | 
| 77 | 
             
            demo_login.queue()
         | 
| 78 | 
            -
            demo_login.launch()
         | 
|  | |
| 11 | 
             
            device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 12 |  | 
| 13 | 
             
            # Load the model pipeline
         | 
| 14 | 
            +
            pipe = DiffusionPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", torch_dtype=dtype).to(device)
         | 
| 15 | 
            +
             | 
|  | |
| 16 |  | 
| 17 | 
             
            @spaces.GPU(duration=120)
         | 
| 18 | 
            +
            def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken, progress=gr.Progress(track_tqdm=True)):
         | 
| 19 | 
             
                if not filename.endswith(".pt2"):
         | 
| 20 | 
             
                    raise NotImplementedError("The filename must end with a `.pt2` extension.")
         | 
| 21 | 
            +
             | 
| 22 | 
             
                # this will throw if token is invalid
         | 
| 23 | 
             
                try:
         | 
| 24 | 
             
                    _ = whoami(oauth_token.token)
         | 
|  | |
| 26 | 
             
                    # --- Ahead-of-time compilation ---
         | 
| 27 | 
             
                    compiled_transformer = compile_transformer(pipe, prompt="prompt")
         | 
| 28 |  | 
| 29 | 
            +
                    token = oauth_token.token
         | 
| 30 | 
             
                    out = _push_compiled_graph_to_hub(
         | 
| 31 | 
            +
                        compiled_transformer.archive_file, repo_id=repo_id, token=token, path_in_repo=filename
         | 
|  | |
|  | |
|  | |
| 32 | 
             
                    )
         | 
| 33 | 
             
                    if not isinstance(out, str) and hasattr(out, "commit_url"):
         | 
| 34 | 
             
                        commit_url = out.commit_url
         | 
|  | |
| 36 | 
             
                    else:
         | 
| 37 | 
             
                        return out
         | 
| 38 | 
             
                except Exception as e:
         | 
| 39 | 
            +
                    raise gr.Error(
         | 
| 40 | 
            +
                        f"""Oops, you forgot to login. Please use the loggin button on the top left to migrate your repo {e}"""
         | 
| 41 | 
            +
                    )
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            css = """
         | 
| 45 | 
             
            #col-container {
         | 
| 46 | 
             
                margin: 0 auto;
         | 
| 47 | 
             
                max-width: 520px;
         | 
|  | |
| 49 | 
             
            """
         | 
| 50 | 
             
            with gr.Blocks(css=css) as demo:
         | 
| 51 | 
             
                with gr.Column(elem_id="col-container"):
         | 
| 52 | 
            +
                    gr.Markdown(
         | 
| 53 | 
            +
                        "## Compile [Flux.1-Dev](https://hf.co/black-forest-labs/Flux.1-Dev) graph ahead of time & push to the Hub"
         | 
| 54 | 
            +
                    )
         | 
| 55 | 
            +
                    gr.Markdown(
         | 
| 56 | 
            +
                        "Enter a **repo_id** and **filename**. This repo automatically compiles the Flux.1-Dev model ahead of time. Read more about this in [this post](https://huggingface.co/blog/zerogpu-aoti)."
         | 
| 57 | 
            +
                    )
         | 
| 58 |  | 
| 59 | 
             
                    repo_id = gr.Textbox(label="repo_id", placeholder="e.g. sayakpaul/qwen-aot")
         | 
| 60 | 
             
                    filename = gr.Textbox(label="filename", placeholder="e.g. compiled.pt2")
         | 
|  | |
| 65 |  | 
| 66 | 
             
                run.click(push_to_hub, inputs=[repo_id, filename], outputs=[markdown_out])
         | 
| 67 |  | 
| 68 | 
            +
             | 
| 69 | 
             
            def swap_visibilty(profile: gr.OAuthProfile | None):
         | 
| 70 | 
             
                return gr.update(elem_classes=["main_ui_logged_in"]) if profile else gr.update(elem_classes=["main_ui_logged_out"])
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            css_login = """
         | 
| 74 | 
             
            .main_ui_logged_out{opacity: 0.3; pointer-events: none; margin: 0 auto; max-width: 520px}
         | 
| 75 | 
            +
            """
         | 
| 76 | 
             
            with gr.Blocks(css=css_login) as demo_login:
         | 
| 77 | 
             
                gr.LoginButton()
         | 
| 78 | 
             
                with gr.Column(elem_classes="main_ui_logged_out") as main_ui:
         | 
| 79 | 
             
                    demo.render()
         | 
| 80 | 
             
                demo_login.load(fn=swap_visibilty, outputs=main_ui)
         | 
| 81 | 
            +
             | 
| 82 | 
             
            demo_login.queue()
         | 
| 83 | 
            +
            demo_login.launch()
         | 
    	
        hub_utils.py
    CHANGED
    
    | @@ -1,10 +1,11 @@ | |
| 1 | 
             
            from io import BytesIO
         | 
| 2 | 
             
            from huggingface_hub import create_repo, upload_file
         | 
| 3 | 
            -
            import tempfile | 
| 4 | 
             
            import os
         | 
| 5 |  | 
| 6 | 
             
            DEFAULT_ARCHIVE_FILENAME = "archived_graph.pt2"
         | 
| 7 |  | 
|  | |
| 8 | 
             
            def _push_compiled_graph_to_hub(archive: BytesIO, repo_id, **kwargs):
         | 
| 9 | 
             
                if not isinstance(archive, BytesIO):
         | 
| 10 | 
             
                    raise NotImplementedError("Incorrect type of `archive` provided.")
         | 
| @@ -13,9 +14,7 @@ def _push_compiled_graph_to_hub(archive: BytesIO, repo_id, **kwargs): | |
| 13 | 
             
                private = kwargs.pop("private", False)
         | 
| 14 | 
             
                path_in_repo = kwargs.pop("path_in_repo", DEFAULT_ARCHIVE_FILENAME)
         | 
| 15 | 
             
                token = kwargs.pop("token")
         | 
| 16 | 
            -
                repo_id = create_repo(
         | 
| 17 | 
            -
                    repo_id, private=private, exist_ok=True, token=token
         | 
| 18 | 
            -
                ).repo_id
         | 
| 19 |  | 
| 20 | 
             
                with tempfile.TemporaryDirectory() as tmpdir:
         | 
| 21 | 
             
                    output_path = os.path.join(tmpdir, os.path.basename(path_in_repo))
         | 
| @@ -24,8 +23,8 @@ def _push_compiled_graph_to_hub(archive: BytesIO, repo_id, **kwargs): | |
| 24 |  | 
| 25 | 
             
                    try:
         | 
| 26 | 
             
                        info = upload_file(
         | 
| 27 | 
            -
                            repo_id=repo_id, | 
| 28 | 
            -
                            path_or_fileobj=output_path, | 
| 29 | 
             
                            path_in_repo=os.path.basename(path_in_repo),
         | 
| 30 | 
             
                            commit_message=commit_message,
         | 
| 31 | 
             
                            token=token,
         | 
| @@ -33,4 +32,4 @@ def _push_compiled_graph_to_hub(archive: BytesIO, repo_id, **kwargs): | |
| 33 | 
             
                        return info
         | 
| 34 | 
             
                    except Exception as e:
         | 
| 35 | 
             
                        print(f"File couldn't be pushed to the Hub with the following error: {e}.")
         | 
| 36 | 
            -
                        return e
         | 
|  | |
| 1 | 
             
            from io import BytesIO
         | 
| 2 | 
             
            from huggingface_hub import create_repo, upload_file
         | 
| 3 | 
            +
            import tempfile
         | 
| 4 | 
             
            import os
         | 
| 5 |  | 
| 6 | 
             
            DEFAULT_ARCHIVE_FILENAME = "archived_graph.pt2"
         | 
| 7 |  | 
| 8 | 
            +
             | 
| 9 | 
             
            def _push_compiled_graph_to_hub(archive: BytesIO, repo_id, **kwargs):
         | 
| 10 | 
             
                if not isinstance(archive, BytesIO):
         | 
| 11 | 
             
                    raise NotImplementedError("Incorrect type of `archive` provided.")
         | 
|  | |
| 14 | 
             
                private = kwargs.pop("private", False)
         | 
| 15 | 
             
                path_in_repo = kwargs.pop("path_in_repo", DEFAULT_ARCHIVE_FILENAME)
         | 
| 16 | 
             
                token = kwargs.pop("token")
         | 
| 17 | 
            +
                repo_id = create_repo(repo_id, private=private, exist_ok=True, token=token).repo_id
         | 
|  | |
|  | |
| 18 |  | 
| 19 | 
             
                with tempfile.TemporaryDirectory() as tmpdir:
         | 
| 20 | 
             
                    output_path = os.path.join(tmpdir, os.path.basename(path_in_repo))
         | 
|  | |
| 23 |  | 
| 24 | 
             
                    try:
         | 
| 25 | 
             
                        info = upload_file(
         | 
| 26 | 
            +
                            repo_id=repo_id,
         | 
| 27 | 
            +
                            path_or_fileobj=output_path,
         | 
| 28 | 
             
                            path_in_repo=os.path.basename(path_in_repo),
         | 
| 29 | 
             
                            commit_message=commit_message,
         | 
| 30 | 
             
                            token=token,
         | 
|  | |
| 32 | 
             
                        return info
         | 
| 33 | 
             
                    except Exception as e:
         | 
| 34 | 
             
                        print(f"File couldn't be pushed to the Hub with the following error: {e}.")
         | 
| 35 | 
            +
                        return e
         | 
    	
        optimization.py
    CHANGED
    
    | @@ -5,26 +5,27 @@ import spaces | |
| 5 | 
             
            import torch
         | 
| 6 | 
             
            from torch.utils._pytree import tree_map
         | 
| 7 |  | 
| 8 | 
            -
            P = ParamSpec( | 
| 9 |  | 
| 10 | 
            -
            TRANSFORMER_HIDDEN_DIM = torch.export.Dim( | 
| 11 |  | 
| 12 | 
             
            # Specific to Flux. More about this is available in
         | 
| 13 | 
             
            # https://huggingface.co/blog/zerogpu-aoti
         | 
| 14 | 
             
            TRANSFORMER_DYNAMIC_SHAPES = {
         | 
| 15 | 
            -
                 | 
| 16 | 
            -
                 | 
| 17 | 
             
            }
         | 
| 18 |  | 
| 19 | 
             
            INDUCTOR_CONFIGS = {
         | 
| 20 | 
            -
                 | 
| 21 | 
            -
                 | 
| 22 | 
            -
                 | 
| 23 | 
            -
                 | 
| 24 | 
            -
                 | 
| 25 | 
            -
                 | 
| 26 | 
             
            }
         | 
| 27 |  | 
|  | |
| 28 | 
             
            def compile_transformer(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
         | 
| 29 | 
             
                @spaces.GPU(duration=1500)
         | 
| 30 | 
             
                def f():
         | 
| @@ -35,12 +36,9 @@ def compile_transformer(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.k | |
| 35 | 
             
                    dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
         | 
| 36 |  | 
| 37 | 
             
                    exported = torch.export.export(
         | 
| 38 | 
            -
                        mod=pipeline.transformer, 
         | 
| 39 | 
            -
                        args=call.args, 
         | 
| 40 | 
            -
                        kwargs=call.kwargs,
         | 
| 41 | 
            -
                        dynamic_shapes=dynamic_shapes
         | 
| 42 | 
             
                    )
         | 
| 43 | 
             
                    return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
         | 
| 44 | 
            -
             | 
| 45 | 
             
                compiled_transformer = f()
         | 
| 46 | 
            -
                return compiled_transformer
         | 
|  | |
| 5 | 
             
            import torch
         | 
| 6 | 
             
            from torch.utils._pytree import tree_map
         | 
| 7 |  | 
| 8 | 
            +
            P = ParamSpec("P")
         | 
| 9 |  | 
| 10 | 
            +
            TRANSFORMER_HIDDEN_DIM = torch.export.Dim("hidden", min=4096, max=8212)
         | 
| 11 |  | 
| 12 | 
             
            # Specific to Flux. More about this is available in
         | 
| 13 | 
             
            # https://huggingface.co/blog/zerogpu-aoti
         | 
| 14 | 
             
            TRANSFORMER_DYNAMIC_SHAPES = {
         | 
| 15 | 
            +
                "hidden_states": {1: TRANSFORMER_HIDDEN_DIM},
         | 
| 16 | 
            +
                "img_ids": {0: TRANSFORMER_HIDDEN_DIM},
         | 
| 17 | 
             
            }
         | 
| 18 |  | 
| 19 | 
             
            INDUCTOR_CONFIGS = {
         | 
| 20 | 
            +
                "conv_1x1_as_mm": True,
         | 
| 21 | 
            +
                "epilogue_fusion": False,
         | 
| 22 | 
            +
                "coordinate_descent_tuning": True,
         | 
| 23 | 
            +
                "coordinate_descent_check_all_directions": True,
         | 
| 24 | 
            +
                "max_autotune": True,
         | 
| 25 | 
            +
                "triton.cudagraphs": True,
         | 
| 26 | 
             
            }
         | 
| 27 |  | 
| 28 | 
            +
             | 
| 29 | 
             
            def compile_transformer(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
         | 
| 30 | 
             
                @spaces.GPU(duration=1500)
         | 
| 31 | 
             
                def f():
         | 
|  | |
| 36 | 
             
                    dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
         | 
| 37 |  | 
| 38 | 
             
                    exported = torch.export.export(
         | 
| 39 | 
            +
                        mod=pipeline.transformer, args=call.args, kwargs=call.kwargs, dynamic_shapes=dynamic_shapes
         | 
|  | |
|  | |
|  | |
| 40 | 
             
                    )
         | 
| 41 | 
             
                    return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
         | 
| 42 | 
            +
             | 
| 43 | 
             
                compiled_transformer = f()
         | 
| 44 | 
            +
                return compiled_transformer
         | 
 
			
