File size: 4,919 Bytes
601c9fd 4b23311 601c9fd 4b23311 1ea8dd9 4b23311 601c9fd 4b23311 601c9fd 4b23311 bf03ba4 601c9fd 4b23311 601c9fd 4b23311 601c9fd 4b23311 601c9fd 4b23311 601c9fd 59db7fd 4b23311 1ea8dd9 4b23311 1ea8dd9 4b23311 1ea8dd9 4b23311 1ea8dd9 4b23311 601c9fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import json
import re

import gradio as gr
from huggingface_hub import model_info, hf_hub_download
def format_size(num: int) -> str:
    """Render a byte count as a short human-readable string.

    The value is repeatedly scaled down by 1000 (decimal/SI steps) until it
    drops below 1000, then printed with one decimal place followed by the unit
    letter: "" for plain bytes, then K, M, G, T, P, E, Z, and finally Y for
    anything that is still >= 1000 after the Z step.
    Adapted from https://stackoverflow.com/a/1094933
    """
    units = ("", "K", "M", "G", "T", "P", "E", "Z")
    value = float(num)
    idx = 0
    # `not (... < ...)` rather than `>=` so NaN keeps being scaled through to "Y".
    while idx < len(units) and not (abs(value) < 1000.0):
        value /= 1000.0
        idx += 1
    if idx == len(units):
        return f"{value:.1f}Y"
    return f"{value:3.1f}{units[idx]}"
def format_output(memory_mapping):
    """Render component sizes as a markdown bullet list.

    Args:
        memory_mapping: Either a mapping of component name -> size in bytes, or
            an iterable of ``(component, size_in_bytes)`` pairs. ``None`` or an
            empty container yields an empty string.

    Returns:
        One ``* <component>: <human-readable size>`` line per entry, each
        terminated by a newline.
    """
    if not memory_mapping:
        return ""
    # Iterating a dict directly yields only its keys; unpack (key, value) pairs instead.
    pairs = memory_mapping.items() if hasattr(memory_mapping, "items") else memory_mapping
    return "".join(f"* {component}: {format_size(memory)}\n" for component, memory in pairs)
def load_model_index(pipeline_id, token=None, revision=None):
    """Download and parse a pipeline repo's ``model_index.json``.

    Args:
        pipeline_id: Hub repo id of the pipeline.
        token: Optional Hub token (for private repos).
        revision: Optional repo revision.

    Returns:
        The parsed JSON content as a dict.
    """
    index_path = hf_hub_download(repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token)
    # JSON is UTF-8 by spec; don't rely on the platform's default text encoding.
    with open(index_path, "r", encoding="utf-8") as f:
        return json.load(f)
# Shard suffix appended by sharded checkpoints, e.g. the "-00001-of-00002" in
# "model.fp16-00001-of-00002.safetensors".
_SHARD_SUFFIX_RE = re.compile(r"-\d+-of-\d+$")

# Component folders excluded from the report (substring match, so "tokenizer"
# also covers "tokenizer_2").
_SKIPPED_COMPONENTS = ("scheduler", "feature_extractor", "safety_checker", "tokenizer")


def _weight_variant(file_name):
    """Return the variant tag embedded in a weight file name (e.g. "fp16"), or None.

    Handles plain ("model.safetensors"), variant ("model.fp16.safetensors") and
    sharded ("model.fp16-00001-of-00002.safetensors") names.
    """
    stem = file_name.rsplit(".", 1)[0]  # drop the extension
    stem = _SHARD_SUFFIX_RE.sub("", stem)
    _, dot, variant = stem.rpartition(".")
    return variant if dot else None


def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
    """Compute the size of each weight-bearing component folder of a Diffusers pipeline repo.

    Args:
        pipeline_id: Hub repo id, e.g. "runwayml/stable-diffusion-v1-5".
        token: Hub token for private repos; "" (empty textbox) is treated as None.
        variant: Weight variant such as "fp16"; "fp32" or "" selects the weight
            files that carry no variant tag in their name.
        revision: Repo revision; "" is treated as None.
        extension: Weight file extension to count, e.g. ".safetensors" or ".bin".

    Returns:
        A markdown bullet list with one line per component folder listed in
        ``model_index.json``. Each size is the sum over all of that component's
        matching weight files, so sharded checkpoints are counted in full.
    """
    # Gradio textboxes deliver "" for empty fields; the Hub client expects None.
    if token == "":
        token = None
    if revision == "":
        revision = None
    # Default-precision weight files have no variant tag in their names.
    if variant in ("", "fp32"):
        variant = None
    print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")

    files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
    index_dict = load_model_index(pipeline_id, token=token, revision=revision)

    component_wise_memory = {}
    for file_obj in files_in_repo:
        path_parts = file_obj.rfilename.split("/")
        # Only files sitting directly inside a component folder are considered.
        if len(path_parts) != 2:
            continue
        component, file_name = path_parts
        if component not in index_dict or any(skip in component for skip in _SKIPPED_COMPONENTS):
            continue
        if file_name.endswith(".json") or not file_name.endswith(extension):
            continue
        # Skip EMA weight files.
        if "ema" in file_name:
            continue
        # Exact variant match: otherwise the default selection would also pick up
        # fp16 files (and vice versa), double counting a component.
        if _weight_variant(file_name) != variant:
            continue
        component_wise_memory[component] = component_wise_memory.get(component, 0) + (file_obj.size or 0)

    # Sorted (component, size) pairs give a stable display order.
    return format_output(sorted(component_wise_memory.items()))
# Web UI: one textbox/dropdown per argument of get_component_wise_memory, in order.
demo = gr.Interface(
    title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
    description=(
        "Sizes are the total sizes of the selected weight files per component, "
        "reported in decimal units (K = 10^3 bytes, M = 10^6, G = 10^9)."
    ),
    fn=get_component_wise_memory,
    inputs=[
        gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
        gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
        gr.components.Dropdown(
            ["fp32", "fp16"],
            value="fp32",
            label="variant",
            info="Precision to use for calculation.",
        ),
        gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
        gr.components.Dropdown(
            [".bin", ".safetensors"],
            value=".safetensors",
            label="extension",
            info="Extension to use.",
        ),
    ],
    outputs="markdown",
    # Each example row must supply one value per input above.
    examples=[
        ["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
        ["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
    ],
    theme=gr.themes.Soft(),
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()