Update app.py
Browse files
app.py
CHANGED
@@ -2,10 +2,10 @@ from huggingface_hub import model_info, hf_hub_download
|
|
2 |
import gradio as gr
|
3 |
import json
|
4 |
|
|
|
5 |
|
6 |
def format_size(num: int) -> str:
|
7 |
"""Format size in bytes into a human-readable string.
|
8 |
-
|
9 |
Taken from https://stackoverflow.com/a/1094933
|
10 |
"""
|
11 |
num_f = float(num)
|
@@ -43,10 +43,25 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
|
|
43 |
files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
|
44 |
index_dict = load_model_index(pipeline_id, token=token, revision=revision)
|
45 |
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
# Handle text encoder separately when it's sharded.
|
|
|
|
|
50 |
if is_text_encoder_shared:
|
51 |
for current_file in files_in_repo:
|
52 |
if "text_encoder" in current_file.rfilename:
|
@@ -60,10 +75,7 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
|
|
60 |
else:
|
61 |
component_wise_memory["text_encoder"] += selected_file.size
|
62 |
|
63 |
-
print(component_wise_memory)
|
64 |
-
|
65 |
# Handle pipeline components.
|
66 |
-
component_filter = ["scheduler", "feature_extractor", "safety_checker", "tokenizer"]
|
67 |
if is_text_encoder_shared:
|
68 |
component_filter.append("text_encoder")
|
69 |
|
@@ -87,37 +99,4 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
|
|
87 |
print(selected_file.rfilename)
|
88 |
component_wise_memory[component] = selected_file.size
|
89 |
|
90 |
-
return format_output(pipeline_id, component_wise_memory)
|
91 |
-
|
92 |
-
|
93 |
-
gr.Interface(
|
94 |
-
title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
|
95 |
-
description="Sizes will be reported in GB. Pipelines containing text encoders with sharded checkpoints are also supported (PixArt-Alpha, for example) 🤗",
|
96 |
-
fn=get_component_wise_memory,
|
97 |
-
inputs=[
|
98 |
-
gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
|
99 |
-
gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
|
100 |
-
gr.components.Dropdown(
|
101 |
-
[
|
102 |
-
"fp32",
|
103 |
-
"fp16",
|
104 |
-
],
|
105 |
-
label="variant",
|
106 |
-
info="Precision to use for calculation.",
|
107 |
-
),
|
108 |
-
gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
|
109 |
-
gr.components.Dropdown(
|
110 |
-
[".bin", ".safetensors"],
|
111 |
-
label="extension",
|
112 |
-
info="Extension to use.",
|
113 |
-
),
|
114 |
-
],
|
115 |
-
outputs=[gr.Markdown(label="Output")],
|
116 |
-
examples=[
|
117 |
-
["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
|
118 |
-
["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
|
119 |
-
["PixArt-alpha/PixArt-XL-2-1024-MS", None, "fp32", None, ".safetensors"],
|
120 |
-
],
|
121 |
-
theme=gr.themes.Soft(),
|
122 |
-
allow_flagging=False,
|
123 |
-
).launch(show_error=True)
|
|
|
2 |
import gradio as gr
|
3 |
import json
|
4 |
|
5 |
+
component_filter = ["scheduler", "safety_checker", "tokenizer"]
|
6 |
|
7 |
def format_size(num: int) -> str:
|
8 |
"""Format size in bytes into a human-readable string.
|
|
|
9 |
Taken from https://stackoverflow.com/a/1094933
|
10 |
"""
|
11 |
num_f = float(num)
|
|
|
43 |
files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
|
44 |
index_dict = load_model_index(pipeline_id, token=token, revision=revision)
|
45 |
|
46 |
+
# Check if all the concerned components have the checkpoints in the requested "variant" and "extension".
|
47 |
+
index_filter = component_filter.copy()
|
48 |
+
index_filter.extend(["_class_name", "_diffusers_version"])
|
49 |
+
for current_component in index_dict:
|
50 |
+
if current_component not in index_filter:
|
51 |
+
current_component_fileobjs = list(filter(lambda x: current_component in x.rfilename, files_in_repo))
|
52 |
+
if current_component_fileobjs:
|
53 |
+
current_component_filenames = [fileobj.rfilename for fileobj in current_component_fileobjs]
|
54 |
+
condition = lambda filename: extension in filename and variant in filename if variant is not None else lambda filename: extension in filename
|
55 |
+
variant_present_with_extension = any(condition(filename) for filename in current_component_filenames)
|
56 |
+
if not variant_present_with_extension:
|
57 |
+
raise ValueError(f"Requested extension ({extension}) and variant ({variant}) not present for {current_component}. Available files for this component:\n{current_component_filenames}.")
|
58 |
+
else:
|
59 |
+
raise ValueError(f"Problem with {current_component}.")
|
60 |
+
|
61 |
|
62 |
# Handle text encoder separately when it's sharded.
|
63 |
+
is_text_encoder_shared = any(".index.json" in file_obj.rfilename for file_obj in files_in_repo)
|
64 |
+
component_wise_memory = {}
|
65 |
if is_text_encoder_shared:
|
66 |
for current_file in files_in_repo:
|
67 |
if "text_encoder" in current_file.rfilename:
|
|
|
75 |
else:
|
76 |
component_wise_memory["text_encoder"] += selected_file.size
|
77 |
|
|
|
|
|
78 |
# Handle pipeline components.
|
|
|
79 |
if is_text_encoder_shared:
|
80 |
component_filter.append("text_encoder")
|
81 |
|
|
|
99 |
print(selected_file.rfilename)
|
100 |
component_wise_memory[component] = selected_file.size
|
101 |
|
102 |
+
return format_output(pipeline_id, component_wise_memory)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|