sayakpaul's picture
sayakpaul HF staff
fix
1632490
raw
history blame
12.8 kB
from huggingface_hub import hf_hub_download, model_info
import gradio as gr
import json
# Keys of `model_index.json` that are excluded from the memory calculation:
# components without weights (scheduler, feature extractor, tokenizers) and
# pipeline metadata entries.
COMPONENT_FILTER = [
    "scheduler",
    "feature_extractor",
    "tokenizer",
    "tokenizer_2",
    "_class_name",
    "_diffusers_version",
]
# Usage notes rendered as Markdown below the form.
ARTICLE = """
## Notes on how to use the `controlnet_id` and `t2i_adapter_id` fields

Both `controlnet_id` and `t2i_adapter_id` fields support passing multiple checkpoint ids,
e.g., "thibaud/controlnet-openpose-sdxl-1.0,diffusers/controlnet-canny-sdxl-1.0". For
`t2i_adapter_id`, this could look like: "TencentARC/t2iadapter_keypose_sd14v1,TencentARC/t2iadapter_depth_sd14v1".

Users should take care of passing the underlying base `pipeline_id` appropriately. For example,
passing `pipeline_id` as "runwayml/stable-diffusion-v1-5" and `controlnet_id` as "thibaud/controlnet-openpose-sdxl-1.0"
won't result in an error, but these two checkpoints aren't meant to be compatible. You should pass
a `controlnet_id` that is compatible with "runwayml/stable-diffusion-v1-5".

For further clarification on this topic, feel free to open a [discussion](https://huggingface.co/spaces/diffusers/compute-pipeline-size/discussions).

📔 Also, note that the `revision` field only applies to `pipeline_id`. It won't have any effect on the
`controlnet_id` or `t2i_adapter_id`.
"""
# Precision variants offered in the UI. "fp32" comes first: it is the untagged
# default, and the remaining entries are the tags looked for in filenames.
ALLOWED_VARIANTS = ["fp32", "fp16", "bf16"]
def format_size(num: int) -> str:
    """Render a byte count as a compact string using decimal (powers of 1000) prefixes.

    For example, 1500 becomes "1.5K" and 2_500_000 becomes "2.5M". Values too large
    for the "Z" prefix fall through to "Y".

    Adapted from https://stackoverflow.com/a/1094933
    """
    units = ("", "K", "M", "G", "T", "P", "E", "Z")
    scaled = float(num)
    step = 0
    # Keep dividing by 1000 until the magnitude fits below 1000 or the unit table is exhausted.
    while step < len(units) and not abs(scaled) < 1000.0:
        scaled /= 1000.0
        step += 1
    if step == len(units):
        return f"{scaled:.1f}Y"
    return f"{scaled:3.1f}{units[step]}"
def format_output(pipeline_id, memory_mapping, variant=None, controlnet_mapping=None, t2i_adapter_mapping=None):
    """Render computed checkpoint sizes as a Markdown report.

    Args:
        pipeline_id: Hub id of the pipeline; used as the top-level heading.
        memory_mapping: {component name: size in bytes} for the pipeline's components.
            May be empty or None, in which case only the heading is emitted.
        variant: precision label shown in each heading; None is displayed as "fp32".
        controlnet_mapping: optional {ControlNet repo id: size in bytes}.
        t2i_adapter_mapping: optional {T2I-Adapter repo id: size in bytes}.

    Returns:
        A Markdown string with one heading per section and one bullet per entry,
        sizes formatted with `format_size`.
    """
    label = "fp32" if variant is None else variant

    def _bullets(mapping):
        # One "* name: size" line per entry; None/empty mappings yield nothing.
        return [f"* {name}: {format_size(size)}\n" for name, size in (mapping or {}).items()]

    parts = [f"## {pipeline_id} ({label})\n"]
    parts.extend(_bullets(memory_mapping))
    if controlnet_mapping:
        parts.append(f"\n## ControlNet(s) ({label})\n")
        parts.extend(_bullets(controlnet_mapping))
    if t2i_adapter_mapping:
        parts.append(f"\n## T2I-Adapter(s) ({label})\n")
        parts.extend(_bullets(t2i_adapter_mapping))
    return "".join(parts)
def load_model_index(pipeline_id, token=None, revision=None):
    """Download a pipeline's `model_index.json` from the Hub and return the parsed JSON.

    Args:
        pipeline_id: Hub repository id of the pipeline.
        token: optional Hub access token (needed for private/gated repos).
        revision: optional branch, tag, or commit to read from.

    Returns:
        The decoded JSON content (expected to be a dict mapping component names
        to their declarations).
    """
    index_path = hf_hub_download(repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token)
    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
    with open(index_path, encoding="utf-8") as index_file:
        return json.load(index_file)
def get_individual_model_memory(id, token, variant, extension):
    """Return the size in bytes of a standalone checkpoint (e.g. a ControlNet or T2I-Adapter).

    Args:
        id: Hub repository id. (Shadows the builtin `id`; the name is kept so
            keyword callers keep working.)
        token: optional Hub access token for private/gated repos.
        variant: precision tag (e.g. "fp16", "bf16") that must appear in the
            filename, or a falsy value for default weights. Files tagged with any
            other entry of ALLOWED_VARIANTS are then excluded.
        extension: required filename suffix, e.g. ".safetensors" or ".bin".

    Returns:
        The size of the first matching file, in the order the Hub lists the repo's files.

    Raises:
        ValueError: if no file in the repository matches `extension`/`variant`.
    """
    files_in_repo = model_info(id, token=token, files_metadata=True).siblings

    def _is_candidate(filename):
        # Suffix match (not substring) so shard index files such as
        # "diffusion_pytorch_model.safetensors.index.json" are never taken as weights.
        if not filename.endswith(extension):
            return False
        if variant:
            return variant in filename
        # ALLOWED_VARIANTS[1:] are the tagged variants (everything except "fp32").
        return all(other not in filename for other in ALLOWED_VARIANTS[1:])

    candidates = [file_obj for file_obj in files_in_repo if _is_candidate(file_obj.rfilename)]
    if not candidates:
        if variant:
            raise ValueError(f"Requested variant ({variant}) for {id} couldn't be found with {extension} extension.")
        raise ValueError(f"No file for {id} could be found with {extension} extension without specified variants.")
    return candidates[0].size
def _split_checkpoint_ids(ids):
    """Split a comma-separated string of Hub repo ids into a list, trimming whitespace and dropping empty entries."""
    if not ids:
        return []
    return [part.strip() for part in ids.split(",") if part.strip()]


def _matches_variant(filename, variant):
    """Return True if `filename` belongs to the requested precision `variant`.

    `variant=None` means the default (fp32) weights, so files tagged with any
    other allowed variant (ALLOWED_VARIANTS[1:], e.g. "fp16"/"bf16") are rejected.
    """
    if variant is not None:
        return variant in filename
    return all(other not in filename for other in ALLOWED_VARIANTS[1:])


def _component_weight_size(component, files_in_repo, variant, extension):
    """Return the size in bytes of one pipeline component's weights, summed over shards.

    Only files directly inside the `<component>/` folder are counted, and EMA copies
    are skipped. Returns None when the folder holds matching files but none of them
    is selectable (e.g. only EMA or nested files).

    Raises:
        ValueError: if the component folder is missing, or has no file with the
            requested `extension` and `variant`.
    """
    import re

    prefix = f"{component}/"
    # Prefix match so e.g. "text_encoder" does not also pick up "text_encoder_2/..." files.
    component_files = [file_obj for file_obj in files_in_repo if file_obj.rfilename.startswith(prefix)]
    if not component_files:
        raise ValueError(f"Problem with {component}.")

    filenames = [file_obj.rfilename for file_obj in component_files]
    if not any(name.endswith(extension) and _matches_variant(name, variant) for name in filenames):
        formatted_filenames = ", ".join(filenames)
        raise ValueError(
            f"Requested extension ({extension}) and variant ({variant}) not present for {component}."
            f" Available files for this component: {formatted_filenames}."
        )

    selected = [
        file_obj
        for file_obj in component_files
        if file_obj.rfilename.count("/") == 1
        and file_obj.rfilename.endswith(extension)
        and _matches_variant(file_obj.rfilename, variant)
        and "ema" not in file_obj.rfilename
    ]
    if not selected:
        return None
    # Sharded checkpoints are named like "model-00001-of-00004.safetensors"; all shards count.
    shards = [file_obj for file_obj in selected if re.search(r"-\d{5}-of-\d{5}", file_obj.rfilename)]
    if shards:
        return sum(file_obj.size for file_obj in shards)
    # Unsharded component: one checkpoint is expected; if several match, the last one listed is reported.
    return selected[-1].size


def get_component_wise_memory(
    pipeline_id,
    controlnet_id=None,
    t2i_adapter_id=None,
    token=None,
    variant=None,
    revision=None,
    extension=".safetensors",
):
    """Compute per-component checkpoint sizes of a Diffusers pipeline and render them as Markdown.

    Args:
        pipeline_id: Hub id of the pipeline (must contain `model_index.json`).
        controlnet_id: optional comma-separated ControlNet repo ids.
        t2i_adapter_id: optional comma-separated T2I-Adapter repo ids. Mutually
            exclusive with `controlnet_id`.
        token: optional Hub access token; "" is treated as None.
        variant: "fp32" (or None/"") for untagged weights, otherwise a tag such as
            "fp16"/"bf16" that must appear in weight filenames.
        revision: optional revision for `pipeline_id` only; "" is treated as None.
        extension: weight file suffix; None/"" falls back to ".safetensors".

    Returns:
        The Markdown report produced by `format_output`.

    Raises:
        ValueError: if both adapter kinds are given, or a required component has no
            weights matching `extension`/`variant`.
    """
    # Gradio sends "" for empty textboxes and None for unselected radio buttons.
    controlnet_ids = _split_checkpoint_ids(controlnet_id)
    t2i_adapter_ids = _split_checkpoint_ids(t2i_adapter_id)
    if controlnet_ids and t2i_adapter_ids:
        raise ValueError("Both `controlnet_id` and `t2i_adapter_id` cannot be provided.")
    token = token or None
    revision = revision or None
    # fp32 weights carry no variant tag in their filenames.
    if variant in ("fp32", ""):
        variant = None
    extension = extension or ".safetensors"

    # Handle ControlNet and T2I-Adapter checkpoints (always read from their default revision).
    controlnet_mapping = t2i_adapter_mapping = None
    if controlnet_ids:
        controlnet_mapping = {
            id_: get_individual_model_memory(id_, token=token, variant=variant, extension=extension)
            for id_ in controlnet_ids
        }
    elif t2i_adapter_ids:
        t2i_adapter_mapping = {
            id_: get_individual_model_memory(id_, token=token, variant=variant, extension=extension)
            for id_ in t2i_adapter_ids
        }

    print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")
    # Load pipeline metadata.
    files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
    index_dict = load_model_index(pipeline_id, token=token, revision=revision)
    print(f"Index dict: {index_dict}")

    # COMPONENT_FILTER is only read here, never mutated, so no state leaks between requests.
    component_wise_memory = {}
    for component, spec in index_dict.items():
        # Loadable components are declared as a 2-item [library, class] list; other keys are metadata.
        if component in COMPONENT_FILTER or not isinstance(spec, list) or len(spec) != 2:
            continue
        # Presumably an unset optional component (serialized as [null, null]); it has no folder to measure.
        if spec == [None, None]:
            continue
        size = _component_weight_size(component, files_in_repo, variant, extension)
        if size is not None:
            component_wise_memory[component] = size

    return format_output(pipeline_id, component_wise_memory, variant, controlnet_mapping, t2i_adapter_mapping)
# Gradio UI: input form, worked examples, usage notes, and the button that runs the calculation.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Column():
        gr.Markdown(
            """<img src="https://huggingface.co/spaces/hf-accelerate/model-memory-usage/resolve/main/measure_model_size.png" style="float: left;" width="150" height="175"><h1>🧨 Diffusers Pipeline Memory Calculator</h1>
This tool will help you to gauge the memory requirements of a Diffusers pipeline. Pipelines containing text encoders with sharded checkpoints are also supported
(PixArt-Alpha, for example) 🤗 See instructions below the form on how to pass `controlnet_id` or `t2i_adapter_id`. When performing inference, expect to add up to an
additional 20% to this as found by [EleutherAI](https://blog.eleuther.ai/transformer-math/). You can click on one of the examples below the "Calculate Memory Usage" button
to get started. Design adapted from [this Space](https://huggingface.co/spaces/hf-accelerate/model-memory-usage).
"""
        )
        out_text = gr.Markdown()
        with gr.Row():
            pipeline_id = gr.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5")
        with gr.Row():
            controlnet_id = gr.Textbox(lines=1, label="controlnet_id", info="Example: lllyasviel/sd-controlnet-canny")
            t2i_adapter_id = gr.Textbox(
                lines=1, label="t2i_adapter_id", info="Example: TencentARC/t2iadapter_color_sd14v1"
            )
        with gr.Row():
            token = gr.Textbox(lines=1, label="hf_token", info="Pass this in case of private/gated repositories.")
            # Default selections so the handler never receives None when the user skips a radio.
            variant = gr.Radio(
                ALLOWED_VARIANTS,
                value="fp32",
                label="variant",
                info="Precision to use for calculation.",
            )
            revision = gr.Textbox(lines=1, label="revision", info="Repository revision to use.")
            extension = gr.Radio(
                [".bin", ".safetensors"],
                value=".safetensors",
                label="extension",
                info="Extension to use.",
            )
        with gr.Row():
            btn = gr.Button("Calculate Memory Usage")
        gr.Markdown("## Examples")
        # Each row follows the order of the inputs list passed below.
        gr.Examples(
            [
                ["runwayml/stable-diffusion-v1-5", None, None, None, "fp32", None, ".safetensors"],
                ["PixArt-alpha/PixArt-XL-2-1024-MS", None, None, None, "fp32", None, ".safetensors"],
                [
                    "runwayml/stable-diffusion-v1-5",
                    "lllyasviel/sd-controlnet-canny",
                    None,
                    None,
                    "fp32",
                    None,
                    ".safetensors",
                ],
                [
                    "stabilityai/stable-diffusion-xl-base-1.0",
                    None,
                    "TencentARC/t2i-adapter-lineart-sdxl-1.0,TencentARC/t2i-adapter-canny-sdxl-1.0",
                    None,
                    "fp16",
                    None,
                    ".safetensors",
                ],
                ["stabilityai/stable-cascade", None, None, None, "bf16", None, ".safetensors"],
                ["Deci/DeciDiffusion-v2-0", None, None, None, "fp32", None, ".safetensors"],
            ],
            [pipeline_id, controlnet_id, t2i_adapter_id, token, variant, revision, extension],
            out_text,
            get_component_wise_memory,
            cache_examples=False,
        )
        gr.Markdown(ARTICLE)
    btn.click(
        get_component_wise_memory,
        inputs=[pipeline_id, controlnet_id, t2i_adapter_id, token, variant, revision, extension],
        outputs=[out_text],
        api_name=False,
    )
demo.launch(show_error=True)