Update app.py
app.py
CHANGED
@@ -2,10 +2,10 @@ from huggingface_hub import model_info, hf_hub_download
 import gradio as gr
 import json

+component_filter = ["scheduler", "safety_checker", "tokenizer"]

 def format_size(num: int) -> str:
     """Format size in bytes into a human-readable string.
-
     Taken from https://stackoverflow.com/a/1094933
     """
     num_f = float(num)
@@ -43,10 +43,25 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
     files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
     index_dict = load_model_index(pipeline_id, token=token, revision=revision)

-
-
+    # Check if all the concerned components have the checkpoints in the requested "variant" and "extension".
+    index_filter = component_filter.copy()
+    index_filter.extend(["_class_name", "_diffusers_version"])
+    for current_component in index_dict:
+        if current_component not in index_filter:
+            current_component_fileobjs = list(filter(lambda x: current_component in x.rfilename, files_in_repo))
+            if current_component_fileobjs:
+                current_component_filenames = [fileobj.rfilename for fileobj in current_component_fileobjs]
+                condition = lambda filename: extension in filename and variant in filename if variant is not None else lambda filename: extension in filename
+                variant_present_with_extension = any(condition(filename) for filename in current_component_filenames)
+                if not variant_present_with_extension:
+                    raise ValueError(f"Requested extension ({extension}) and variant ({variant}) not present for {current_component}. Available files for this component:\n{current_component_filenames}.")
+            else:
+                raise ValueError(f"Problem with {current_component}.")
+

     # Handle text encoder separately when it's sharded.
+    is_text_encoder_shared = any(".index.json" in file_obj.rfilename for file_obj in files_in_repo)
+    component_wise_memory = {}
     if is_text_encoder_shared:
         for current_file in files_in_repo:
             if "text_encoder" in current_file.rfilename:
@@ -60,10 +75,7 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
                     else:
                         component_wise_memory["text_encoder"] += selected_file.size

-    print(component_wise_memory)
-
     # Handle pipeline components.
-    component_filter = ["scheduler", "feature_extractor", "safety_checker", "tokenizer"]
     if is_text_encoder_shared:
         component_filter.append("text_encoder")

@@ -87,37 +99,4 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
         print(selected_file.rfilename)
         component_wise_memory[component] = selected_file.size

-    return format_output(pipeline_id, component_wise_memory)
-
-
-gr.Interface(
-    title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
-    description="Sizes will be reported in GB. Pipelines containing text encoders with sharded checkpoints are also supported (PixArt-Alpha, for example) 🤗",
-    fn=get_component_wise_memory,
-    inputs=[
-        gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
-        gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
-        gr.components.Dropdown(
-            [
-                "fp32",
-                "fp16",
-            ],
-            label="variant",
-            info="Precision to use for calculation.",
-        ),
-        gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
-        gr.components.Dropdown(
-            [".bin", ".safetensors"],
-            label="extension",
-            info="Extension to use.",
-        ),
-    ],
-    outputs=[gr.Markdown(label="Output")],
-    examples=[
-        ["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
-        ["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
-        ["PixArt-alpha/PixArt-XL-2-1024-MS", None, "fp32", None, ".safetensors"],
-    ],
-    theme=gr.themes.Soft(),
-    allow_flagging=False,
-).launch(show_error=True)
+    return format_output(pipeline_id, component_wise_memory)
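
A few notes on the change follow.

component_filter now lives at module scope, so the existing component_filter.append("text_encoder") inside get_component_wise_memory mutates shared state: once a pipeline with a sharded text encoder has been processed, the appended entry persists for later requests in the same process. A minimal sketch of a per-call copy, mirroring the index_filter = component_filter.copy() pattern the diff itself introduces (build_component_filter is an illustrative name, not part of the commit):

    component_filter = ["scheduler", "safety_checker", "tokenizer"]

    def build_component_filter(is_text_encoder_sharded):
        # Work on a copy so the module-level list is never mutated between calls.
        local_filter = component_filter.copy()
        if is_text_encoder_sharded:
            local_filter.append("text_encoder")
        return local_filter

    print(build_component_filter(True))   # base filter plus "text_encoder"
    print(build_component_filter(False))  # base filter unchanged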
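
In the new availability check, condition is written as lambda filename: extension in filename and variant in filename if variant is not None else lambda filename: extension in filename. Because a lambda body extends over the whole conditional expression, this is a single lambda; when variant is None, calling it returns the inner lambda object (which is truthy) rather than evaluating extension in filename. A small sketch of the same check as a plain function (matches is an illustrative name, and the filenames are hypothetical):

    from typing import Optional

    def matches(filename: str, extension: str, variant: Optional[str]) -> bool:
        # True when the file has the requested extension and, if given, the variant tag.
        if variant is not None:
            return extension in filename and variant in filename
        return extension in filename

    filenames = ["unet/diffusion_pytorch_model.fp16.safetensors", "unet/config.json"]
    print(any(matches(f, ".safetensors", "fp16") for f in filenames))  # True
    print(any(matches(f, ".bin", None) for f in filenames))            # False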
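
The sharded text encoder path is triggered by the presence of any *.index.json file in the repo, after which the sizes of the matching text_encoder shard files are accumulated. A toy illustration of that bookkeeping with a hypothetical file listing (names and byte sizes are made up):

    # Hypothetical listing of a sharded text encoder checkpoint, sizes in bytes.
    files = {
        "text_encoder/model.safetensors.index.json": 120_000,
        "text_encoder/model-00001-of-00002.safetensors": 9_980_000_000,
        "text_encoder/model-00002-of-00002.safetensors": 1_090_000_000,
    }
    is_text_encoder_sharded = any(".index.json" in name for name in files)
    total = sum(
        size for name, size in files.items()
        if "text_encoder" in name and name.endswith(".safetensors")
    )
    print(is_text_encoder_sharded, total)  # True 11070000000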
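
All of the reported sizes come from Hub metadata rather than from downloading any weights: model_info(..., files_metadata=True) lists the repo files as siblings entries that carry rfilename and size, which the app sums per component. A stripped-down sketch of that idea, grouping sizes by top-level folder (the grouping shown here is illustrative, not the app's exact code):

    from huggingface_hub import model_info

    info = model_info("runwayml/stable-diffusion-v1-5", files_metadata=True)
    sizes = {}
    for sibling in info.siblings:
        if not sibling.rfilename.endswith(".safetensors"):
            continue  # count only the requested serialization format
        component = sibling.rfilename.split("/")[0]  # e.g. "unet", "vae", "text_encoder"
        sizes[component] = sizes.get(component, 0) + (sibling.size or 0)
    print(sizes)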