justinblalock87
committed on
Commit
•
43c0fb7
1
Parent(s):
dc19714
Add quantize
Browse files- __pycache__/app.cpython-38.pyc +0 -0
- __pycache__/quantize.cpython-38.pyc +0 -0
- app.py +95 -4
- ml-stable-diffusion +1 -0
- quantize.py +148 -0
- requirements.txt +105 -0
__pycache__/app.cpython-38.pyc
ADDED
Binary file (3.27 kB). View file
|
|
__pycache__/quantize.cpython-38.pyc
ADDED
Binary file (4.69 kB). View file
|
|
app.py
CHANGED
@@ -1,7 +1,98 @@
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
|
4 |
-
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# NOTE(review): csv/datetime appear unused in the visible code — presumably
# intended for the (disabled) conversion-logging feature below; confirm.
import csv
from datetime import datetime
import os
from typing import Optional
import gradio as gr

import quantize
from huggingface_hub import HfApi, Repository

# Dataset repo where conversion requests would be logged (feature disabled below).
DATASET_REPO_URL = "https://huggingface.co/datasets/safetensors/conversions"
DATA_FILENAME = "data.csv"
DATA_FILE = os.path.join("data", DATA_FILENAME)

# Bot token: used to open PRs on behalf of the Space for public models.
HF_TOKEN = os.environ.get("HF_TOKEN")

repo: Optional[Repository] = None
# TODO: logging is intentionally disabled via the `False and` guard;
# drop the `False` once the dataset repo above is ready to receive entries.
if False and HF_TOKEN:
    repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, token=HF_TOKEN)
+
def run(model_id: str, is_private: bool, token: Optional[str] = None) -> str:
    """Convert `model_id` to CoreML via `quantize.quantize` and report as markdown.

    Parameters:
        model_id: repo id on the Hub to convert; empty string is rejected.
        is_private: whether the source repo is private (uses the user's token).
        token: user-supplied HF token, only used when `is_private` is True.

    Returns:
        A markdown string describing success (with the PR URL) or the error.
        Never raises: all exceptions are rendered into the returned markdown.
    """
    if model_id == "":
        return """
### Invalid input 🐞

Please fill a token and model_id.
"""
    try:
        if is_private:
            api = HfApi(token=token)
        else:
            api = HfApi(token=HF_TOKEN)
        hf_is_private = api.model_info(repo_id=model_id).private
        if is_private and not hf_is_private:
            # This model is NOT private.
            # Change the token so we make the PR on behalf of the bot.
            api = HfApi(token=HF_TOKEN)

        print("is_private", is_private)

        commit_info, errors = quantize.quantize(api=api, model_id=model_id)
        print("[commit_info]", commit_info)

        string = f"""
### Success 🔥

Yay! This model was successfully converted and a PR was open using your token, here:

[{commit_info.pr_url}]({commit_info.pr_url})
"""
        if errors:
            string += "\nErrors during conversion:\n"
            # Fix: report the offending filename instead of discarding it as "(unknown)".
            string += "\n".join(
                f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
            )
        return string
    except Exception as e:
        # Top-level boundary: surface any failure to the UI rather than crashing.
        return f"""
### Error 😢😢😢

{e}
"""
|
62 |
+
|
63 |
+
|
64 |
+
# User-facing instructions rendered at the top of the Space.
# NOTE(review): the ⚠️ line still talks about `pytorch_model.bin` support
# (safetensors-converter wording) while this app performs a CoreML
# conversion — confirm the text is still accurate before shipping.
DESCRIPTION = """
The steps are the following:
- Paste a read-access token from hf.co/settings/tokens. Read access is enough given that we will open a PR against the source repo.
- Input a model id from the Hub
- Click "Submit"
- That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the opened PR 🔥
⚠️ For now only `pytorch_model.bin` files are supported but we'll extend in the future.
"""

title="Quantize model and convert to CoreML"
# NOTE(review): `allow_flagging` is a `gr.Interface` option; it looks unused
# with the `gr.Blocks` layout below — verify it can be removed.
allow_flagging="never"
|
75 |
+
|
76 |
+
def token_text(visible=False):
    """Build the HF-token textbox; `visible` controls whether it is shown.

    Fix: the original hard-coded `visible=True`, ignoring the parameter and
    defeating the `is_private.change(...)` toggle that re-renders this box.
    """
    return gr.Text(max_lines=1, label="your_hf_token", visible=visible, value="")
|
78 |
+
|
79 |
+
# Gradio UI: inputs (model id, privacy flag, token) on the left, markdown
# output on the right. NOTE(review): the source layout lost its indentation;
# the nesting below (buttons row inside the input column) is reconstructed —
# confirm against the rendered Space.
with gr.Blocks(title=title) as demo:
    # Page header plus usage instructions.
    description = gr.Markdown(f"""# {title}""")
    description = gr.Markdown(DESCRIPTION)

    with gr.Row() as r:
        with gr.Column() as c:
            # NOTE(review): default value looks like a developer test repo — confirm.
            model_id = gr.Text(max_lines=1, label="model_id", value="jblalock30/coreml")
            is_private = gr.Checkbox(label="Private model")
            token = token_text()
            with gr.Row() as c:
                clean = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")

        with gr.Column() as d:
            # NOTE(review): placeholder output text "hi" — likely leftover debug value.
            output = gr.Markdown(value="hi")

    # Re-render the token textbox whenever the "Private model" checkbox flips.
    is_private.change(lambda s: token_text(s), inputs=is_private, outputs=token)
    # One conversion at a time; result markdown lands in the output column.
    submit.click(run, inputs=[model_id, is_private, token], outputs=output, concurrency_limit=1)

demo.queue(max_size=10).launch(show_api=True)
|
ml-stable-diffusion
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit f7d2f5e9fb0681b15770943e492bf2f6dc3414f3
|
quantize.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
import json
import os
import shutil
from collections import defaultdict
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional, Set, Tuple
import subprocess

import torch

# NOTE(review): several names below (hf_hub_download, Discussion, the
# safetensors helpers, torch, json, defaultdict) only serve the commented-out
# legacy safetensors-conversion path in convert_generic — candidates for
# removal once that dead code is deleted.
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file

# (commit operations to upload, [(filename, exception), ...] per-file errors)
ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]
|
17 |
+
|
18 |
+
def convert_generic(
    model_id: str, *, revision: Optional[str] = None, folder: str, filenames: Set[str], token: Optional[str]
) -> ConversionResult:
    """Run Apple's ml-stable-diffusion CoreML conversion for `model_id`.

    Shells out to `python_coreml_stable_diffusion.torch2coreml`, writing the
    bundled Swift-CLI resources into `folder`, and returns the commit
    operations to upload plus any (identifier, exception) errors.

    Parameters:
        model_id: repo id on the Hub to convert (passed as --model-version).
        revision: accepted for interface compatibility; not forwarded to the
            converter. NOTE(review): confirm whether torch2coreml can pin a revision.
        folder: local working directory receiving the converted resources.
        filenames: accepted for interface compatibility; currently unused.
        token: accepted for interface compatibility; currently unused.

    Fixes vs. original: `revision=Optional[str]` made the *default value* the
    typing object rather than annotating the parameter; the model to convert
    was hard-coded to "stabilityai/sd-turbo" leaving `model_id` unused; the
    subprocess result was unchecked, so a failed conversion silently produced
    an empty commit.
    """
    operations: List[CommitOperationAdd] = []
    errors: List[Tuple[str, Exception]] = []

    # Equivalent CLI invocation:
    #   python3 -m python_coreml_stable_diffusion.torch2coreml \
    #     --model-version <model_id> -o <folder> \
    #     --convert-unet --convert-text-encoder --convert-vae-decoder --chunk-unet \
    #     --attention-implementation ORIGINAL --bundle-resources-for-swift-cli

    print("Starting conversion")
    try:
        subprocess.run(
            [
                "python3", "-m", "python_coreml_stable_diffusion.torch2coreml",
                "--model-version", model_id,
                "-o", folder,
                "--convert-unet",
                "--convert-text-encoder",
                "--convert-vae-decoder",
                "--chunk-unet",
                "--attention-implementation", "ORIGINAL",
                "--bundle-resources-for-swift-cli",
            ],
            check=True,  # surface converter failures instead of committing nothing
        )
    except subprocess.CalledProcessError as e:
        # Report the failure through the errors channel the callers already display.
        errors.append((model_id, e))
        return operations, errors
    print("Done")

    # Upload the whole bundled-resources directory produced by the converter.
    operations.append(CommitOperationAdd(path_in_repo='Resources', path_or_fileobj=f'{folder}/Resources'))
    return operations, errors
|
63 |
+
|
64 |
+
def quantize(
    api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False
) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
    """Convert `model_id` to a CoreML variant and open a PR on the source repo.

    Parameters:
        api: authenticated client used to create the commit / PR.
        model_id: repo id on the Hub to convert.
        revision: git revision the PR is opened against.
        force: accepted for CLI compatibility; currently unused.

    Returns:
        (CommitInfo of the created PR, list of (identifier, exception) errors).

    Fixes vs. original: removed dead locals (`operations = None`, `pr = None`);
    `errors` is pre-initialized so it can never be unbound; cleanup in
    `finally` no longer raises (which would mask the original exception).
    """
    pr_title = "Adding `CoreML` variant of this model"

    with TemporaryDirectory() as d:
        # Work in a throwaway folder laid out like the hub cache.
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)
        new_pr = None
        errors: List[Tuple[str, Exception]] = []
        try:
            operations, errors = convert_generic(
                model_id, revision=revision, folder=folder, filenames={"pytorch_model.bin"}, token=api.token
            )
            new_pr = api.create_commit(
                repo_id=model_id,
                revision=revision,
                operations=operations,
                commit_message=pr_title,
                commit_description="Add CoreML variant of this model",
                create_pr=True,
            )
            print(f"Pr created at {new_pr.pr_url}")
        finally:
            # Free the (potentially large) conversion artifacts promptly; the
            # TemporaryDirectory context would remove them anyway on exit.
            shutil.rmtree(folder, ignore_errors=True)
    return new_pr, errors
|
93 |
+
|
94 |
+
|
95 |
+
if __name__ == "__main__":
    # CLI entry point: convert one hub model and open a PR with the result.
    DESCRIPTION = """
    Simple utility tool to convert automatically some weights on the hub to `safetensors` format.
    It is PyTorch exclusive for now.
    It works by downloading the weights (PT), converting them locally, and uploading them back
    as a PR on the hub.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "model_id",
        type=str,
        help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )
    parser.add_argument(
        "--revision",
        type=str,
        help="The revision to convert",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Create the PR even if it already exists of if the model was already converted.",
    )
    parser.add_argument(
        "-y",
        action="store_true",
        help="Ignore safety prompt",
    )
    args = parser.parse_args()
    model_id = args.model_id
    api = HfApi()
    if args.y:
        txt = "y"
    else:
        # Unpickling arbitrary hub weights is unsafe; require explicit consent.
        txt = input(
            "This conversion script will unpickle a pickled file, which is inherently unsafe. If you do not trust this file, we invite you to use"
            " https://huggingface.co/spaces/safetensors/convert or google colab or other hosted solution to avoid potential issues with this file."
            " Continue [Y/n] ?"
        )
    if txt.lower() in {"", "y"}:
        # Fix: the original called `convert(...)`, which is not defined anywhere
        # in this module (NameError at runtime); the entry point is `quantize`.
        commit_info, errors = quantize(api=api, model_id=model_id, revision=args.revision, force=args.force)
        string = f"""
### Success 🔥
Yay! This model was successfully converted and a PR was open using your token, here:
[{commit_info.pr_url}]({commit_info.pr_url})
"""
        if errors:
            string += "\nErrors during conversion:\n"
            # Fix: report the offending filename instead of discarding it as "(unknown)".
            string += "\n".join(
                f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
            )
        print(string)
    else:
        print(f"Answer was `{txt}` aborting.")
|
requirements.txt
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.24.1
|
2 |
+
aiofiles==23.2.1
|
3 |
+
altair==5.3.0
|
4 |
+
annotated-types==0.6.0
|
5 |
+
antlr4-python3-runtime==4.9.3
|
6 |
+
anyio==4.3.0
|
7 |
+
attrs==23.1.0
|
8 |
+
cattrs==23.1.2
|
9 |
+
certifi==2023.11.17
|
10 |
+
charset-normalizer==3.3.2
|
11 |
+
click==8.1.7
|
12 |
+
contourpy==1.1.1
|
13 |
+
coremltools==7.1
|
14 |
+
cycler==0.12.1
|
15 |
+
diffusers==0.22.3
|
16 |
+
exceptiongroup==1.1.3
|
17 |
+
fastapi==0.110.2
|
18 |
+
ffmpy==0.3.2
|
19 |
+
filelock==3.13.1
|
20 |
+
fonttools==4.44.0
|
21 |
+
fsspec==2023.12.2
|
22 |
+
gradio==4.27.0
|
23 |
+
gradio_client==0.15.1
|
24 |
+
h11==0.14.0
|
25 |
+
httpcore==1.0.5
|
26 |
+
httpx==0.27.0
|
27 |
+
huggingface-hub==0.22.2
|
28 |
+
idna==3.6
|
29 |
+
importlib-metadata==6.8.0
|
30 |
+
importlib-resources==6.1.1
|
31 |
+
iniconfig==2.0.0
|
32 |
+
invisible-watermark==0.2.0
|
33 |
+
Jinja2==3.1.2
|
34 |
+
joblib==1.3.2
|
35 |
+
jsonschema==4.21.1
|
36 |
+
jsonschema-specifications==2023.12.1
|
37 |
+
kiwisolver==1.4.5
|
38 |
+
lightning-utilities==0.11.2
|
39 |
+
markdown-it-py==3.0.0
|
40 |
+
MarkupSafe==2.1.3
|
41 |
+
matplotlib==3.7.3
|
42 |
+
mdurl==0.1.2
|
43 |
+
mpmath==1.3.0
|
44 |
+
networkx==3.1
|
45 |
+
ninja==1.11.1.1
|
46 |
+
numpy==1.23.5
|
47 |
+
omegaconf==2.3.0
|
48 |
+
opencv-python==4.8.1.78
|
49 |
+
orjson==3.10.1
|
50 |
+
packaging==23.2
|
51 |
+
pandas==2.0.3
|
52 |
+
Pillow==10.1.0
|
53 |
+
pkgutil_resolve_name==1.3.10
|
54 |
+
pluggy==1.3.0
|
55 |
+
protobuf==3.20.3
|
56 |
+
psutil==5.9.6
|
57 |
+
pyaml==23.9.7
|
58 |
+
pydantic==2.7.0
|
59 |
+
pydantic_core==2.18.1
|
60 |
+
pydub==0.25.1
|
61 |
+
Pygments==2.17.2
|
62 |
+
pyparsing==3.1.1
|
63 |
+
pytest==7.4.3
|
64 |
+
-e git+https://github.com/apple/ml-stable-diffusion.git@f7d2f5e9fb0681b15770943e492bf2f6dc3414f3#egg=python_coreml_stable_diffusion
|
65 |
+
python-dateutil==2.8.2
|
66 |
+
python-multipart==0.0.9
|
67 |
+
pytz==2024.1
|
68 |
+
PyWavelets==1.4.1
|
69 |
+
PyYAML==6.0.1
|
70 |
+
quanto==0.1.0
|
71 |
+
referencing==0.34.0
|
72 |
+
regex==2023.12.25
|
73 |
+
requests==2.31.0
|
74 |
+
rich==13.7.1
|
75 |
+
rpds-py==0.18.0
|
76 |
+
ruff==0.4.1
|
77 |
+
safetensors==0.4.1
|
78 |
+
scikit-learn==1.1.2
|
79 |
+
scipy==1.10.1
|
80 |
+
semantic-version==2.10.0
|
81 |
+
sentencepiece==0.2.0
|
82 |
+
shellingham==1.5.4
|
83 |
+
six==1.16.0
|
84 |
+
sniffio==1.3.1
|
85 |
+
starlette==0.37.2
|
86 |
+
sympy==1.12
|
87 |
+
threadpoolctl==3.2.0
|
88 |
+
tokenizers==0.14.1
|
89 |
+
tomli==2.0.1
|
90 |
+
tomlkit==0.12.0
|
91 |
+
toolz==0.12.1
|
92 |
+
torch==2.1.0
|
93 |
+
torchao==0.1
|
94 |
+
torchmetrics==1.3.2
|
95 |
+
tqdm==4.66.1
|
96 |
+
transformers==4.34.1
|
97 |
+
typer==0.12.3
|
98 |
+
typing_extensions==4.9.0
|
99 |
+
tzdata==2024.1
|
100 |
+
urllib3==2.1.0
|
101 |
+
uvicorn==0.29.0
|
102 |
+
websockets==11.0.3
|
103 |
+
zipp==3.17.0
|
104 |
+
setuptools_rust
|
105 |
+
pytorch_lightning
|