justinblalock87 commited on
Commit
43c0fb7
1 Parent(s): dc19714

Add quantize

Browse files
__pycache__/app.cpython-38.pyc ADDED
Binary file (3.27 kB). View file
 
__pycache__/quantize.cpython-38.pyc ADDED
Binary file (4.69 kB). View file
 
app.py CHANGED
@@ -1,7 +1,98 @@
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ from datetime import datetime
3
+ import os
4
+ from typing import Optional
5
  import gradio as gr
6
 
7
+ import quantize
8
+ from huggingface_hub import HfApi, Repository
9
 
10
+
11
# Dataset repo intended for logging conversion requests (cloning is currently disabled below).
DATASET_REPO_URL = "https://huggingface.co/datasets/safetensors/conversions"
DATA_FILENAME = "data.csv"
DATA_FILE = os.path.join("data", DATA_FILENAME)

# Bot token, expected to be injected via environment (e.g. Space secrets); may be None locally.
HF_TOKEN = os.environ.get("HF_TOKEN")

repo: Optional[Repository] = None
# TODO
# NOTE(review): deliberately dead (`False and ...`) — drop the `False` to re-enable dataset logging.
if False and HF_TOKEN:
    repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, token=HF_TOKEN)
21
+
22
+
23
def run(model_id: str, is_private: bool, token: Optional[str] = None) -> str:
    """Quantize/convert ``model_id`` to CoreML and open a PR on the Hub.

    Parameters
    ----------
    model_id: repo id of the model on the Hub.
    is_private: whether the source repo is private (the user's token is used then).
    token: user-supplied read token; only required for private models.

    Returns a Markdown string describing success, conversion errors, or failure.
    """
    if model_id == "":
        return """
### Invalid input 🐞

Please fill a token and model_id.
"""
    try:
        # Private repos need the caller's token; public ones are handled by the bot token.
        if is_private:
            api = HfApi(token=token)
        else:
            api = HfApi(token=HF_TOKEN)
        hf_is_private = api.model_info(repo_id=model_id).private
        if is_private and not hf_is_private:
            # This model is NOT private
            # Change the token so we make the PR on behalf of the bot.
            api = HfApi(token=HF_TOKEN)

        print("is_private", is_private)

        commit_info, errors = quantize.quantize(api=api, model_id=model_id)
        print("[commit_info]", commit_info)

        string = f"""
### Success 🔥
Yay! This model was successfully converted and a PR was open using your token, here:
[{commit_info.pr_url}]({commit_info.pr_url})
"""
        if errors:
            string += "\nErrors during conversion:\n"
            # Bug fix: report the offending filename instead of the hardcoded "(unknown)".
            string += "\n".join(
                f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
            )
        return string
    except Exception as e:
        # Surface any failure (auth, network, conversion) back to the UI as Markdown.
        return f"""
### Error 😢😢😢

{e}
"""
62
+
63
+
64
# Markdown shown above the form in the Gradio UI.
DESCRIPTION = """
The steps are the following:
- Paste a read-access token from hf.co/settings/tokens. Read access is enough given that we will open a PR against the source repo.
- Input a model id from the Hub
- Click "Submit"
- That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the opened PR 🔥
⚠️ For now only `pytorch_model.bin` files are supported but we'll extend in the future.
"""

# Page title and flagging policy for the Gradio app.
title="Quantize model and convert to CoreML"
allow_flagging="never"
75
+
76
def token_text(visible=False):
    """Build a single-line textbox for the user's HF token.

    ``visible`` controls whether the field is shown; the ``is_private`` checkbox
    toggles it via ``is_private.change(lambda s: token_text(s), ...)``.
    """
    # Bug fix: honor the `visible` argument — it was previously ignored
    # (hardcoded visible=True), so toggling "Private model" had no effect.
    return gr.Text(max_lines=1, label="your_hf_token", visible=visible, value="")
78
+
79
# Top-level Gradio UI: input column (model id, private flag, token, buttons)
# next to an output column that shows the Markdown returned by `run`.
with gr.Blocks(title=title) as demo:
    description = gr.Markdown(f"""# {title}""")
    description = gr.Markdown(DESCRIPTION)

    with gr.Row() as r:
        with gr.Column() as c:
            model_id = gr.Text(max_lines=1, label="model_id", value="jblalock30/coreml")
            is_private = gr.Checkbox(label="Private model")
            token = token_text()
            with gr.Row() as c:
                clean = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")

        with gr.Column() as d:
            output = gr.Markdown(value="hi")

    # Show/hide the token box when the private checkbox changes.
    is_private.change(lambda s: token_text(s), inputs=is_private, outputs=token)
    # Run conversions one at a time — each one is heavy (subprocess + upload).
    submit.click(run, inputs=[model_id, is_private, token], outputs=output, concurrency_limit=1)

demo.queue(max_size=10).launch(show_api=True)
ml-stable-diffusion ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit f7d2f5e9fb0681b15770943e492bf2f6dc3414f3
quantize.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import shutil
5
+ from collections import defaultdict
6
+ from tempfile import TemporaryDirectory
7
+ from typing import Dict, List, Optional, Set, Tuple
8
+ import subprocess
9
+
10
+ import torch
11
+
12
+ from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
13
+ from huggingface_hub.file_download import repo_folder_name
14
+ from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file
15
+
16
+ ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]
17
+
18
def convert_generic(
    model_id: str, *, revision: Optional[str] = None, folder: str, filenames: Set[str], token: Optional[str]
) -> ConversionResult:
    """Convert ``model_id`` to CoreML via apple/ml-stable-diffusion and stage results.

    Runs ``python_coreml_stable_diffusion.torch2coreml`` writing into ``folder``,
    then stages ``folder/Resources`` as a commit operation.

    Parameters
    ----------
    model_id: Hub repo id to convert.
    revision: optional git revision (currently unused by the conversion CLI).
    folder: scratch directory the CLI writes its outputs into.
    filenames: candidate weight files (currently unused; kept for interface parity).
    token: HF token (currently unused; kept for interface parity).

    Returns ``(operations, errors)`` where errors is a list of (name, exception).
    """
    operations = []
    errors = []

    print("Starting conversion")
    # Bug fix: convert the requested model instead of the hardcoded "stabilityai/sd-turbo".
    result = subprocess.run(
        [
            "python3", "-m", "python_coreml_stable_diffusion.torch2coreml",
            "--model-version", model_id,
            "-o", folder,
            "--convert-unet", "--convert-text-encoder", "--convert-vae-decoder",
            "--chunk-unet",
            "--attention-implementation", "ORIGINAL",
            "--bundle-resources-for-swift-cli",
        ]
    )
    print("Done")

    if result.returncode != 0:
        # Bug fix: the return code was previously ignored, so a failed conversion
        # would stage an empty/partial Resources folder. Report it instead.
        errors.append((model_id, Exception(f"torch2coreml exited with code {result.returncode}")))
        return operations, errors

    operations.append(CommitOperationAdd(path_in_repo='Resources', path_or_fileobj=f'{folder}/Resources'))
    return operations, errors
63
+
64
def quantize(
    api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False
) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
    """Convert ``model_id`` to CoreML in a temp dir and open a PR with the results.

    Parameters
    ----------
    api: authenticated HfApi client used for the commit.
    model_id: Hub repo id to convert.
    revision: optional source revision to convert.
    force: kept for interface compatibility (currently unused).

    Returns ``(commit_info, errors)``; ``commit_info`` is the created PR.
    """
    pr_title = "Adding `CoreML` variant of this model"
    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)
        new_pr = None
        # Bug fix: `errors` was unbound when convert_generic raised, so the
        # `return new_pr, errors` below crashed with NameError, masking the real error.
        errors = []
        try:
            operations, errors = convert_generic(
                model_id, revision=revision, folder=folder, filenames={"pytorch_model.bin"}, token=api.token
            )

            new_pr = api.create_commit(
                repo_id=model_id,
                revision=revision,
                operations=operations,
                commit_message=pr_title,
                commit_description="Add CoreML variant of this model",
                create_pr=True,
            )
            print(f"Pr created at {new_pr.pr_url}")
        finally:
            # TemporaryDirectory removes `d` anyway; delete eagerly and tolerate
            # a partially-removed tree.
            shutil.rmtree(folder, ignore_errors=True)
        return new_pr, errors
93
+
94
+
95
if __name__ == "__main__":
    # CLI entry point: convert one Hub model to CoreML and open a PR.
    DESCRIPTION = """
    Simple utility tool to convert automatically some weights on the hub to `CoreML` format.
    It is PyTorch exclusive for now.
    It works by downloading the weights (PT), converting them locally, and uploading them back
    as a PR on the hub.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "model_id",
        type=str,
        help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )
    parser.add_argument(
        "--revision",
        type=str,
        help="The revision to convert",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        # Typo fix: "of if" -> "or if".
        help="Create the PR even if it already exists or if the model was already converted.",
    )
    parser.add_argument(
        "-y",
        action="store_true",
        help="Ignore safety prompt",
    )
    args = parser.parse_args()
    model_id = args.model_id
    api = HfApi()
    if args.y:
        txt = "y"
    else:
        # Unpickling arbitrary Hub weights is unsafe; make the user opt in.
        txt = input(
            "This conversion script will unpickle a pickled file, which is inherently unsafe. If you do not trust this file, we invite you to use"
            " https://huggingface.co/spaces/safetensors/convert or google colab or other hosted solution to avoid potential issues with this file."
            " Continue [Y/n] ?"
        )
    if txt.lower() in {"", "y"}:
        # Bug fix: `convert` does not exist in this module — call `quantize`.
        commit_info, errors = quantize(api, model_id, revision=args.revision, force=args.force)
        string = f"""
### Success 🔥
Yay! This model was successfully converted and a PR was open using your token, here:
[{commit_info.pr_url}]({commit_info.pr_url})
"""
        if errors:
            string += "\nErrors during conversion:\n"
            # Bug fix: include the filename instead of the hardcoded "(unknown)".
            string += "\n".join(
                f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
            )
        print(string)
    else:
        print(f"Answer was `{txt}` aborting.")
requirements.txt ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.24.1
2
+ aiofiles==23.2.1
3
+ altair==5.3.0
4
+ annotated-types==0.6.0
5
+ antlr4-python3-runtime==4.9.3
6
+ anyio==4.3.0
7
+ attrs==23.1.0
8
+ cattrs==23.1.2
9
+ certifi==2023.11.17
10
+ charset-normalizer==3.3.2
11
+ click==8.1.7
12
+ contourpy==1.1.1
13
+ coremltools==7.1
14
+ cycler==0.12.1
15
+ diffusers==0.22.3
16
+ exceptiongroup==1.1.3
17
+ fastapi==0.110.2
18
+ ffmpy==0.3.2
19
+ filelock==3.13.1
20
+ fonttools==4.44.0
21
+ fsspec==2023.12.2
22
+ gradio==4.27.0
23
+ gradio_client==0.15.1
24
+ h11==0.14.0
25
+ httpcore==1.0.5
26
+ httpx==0.27.0
27
+ huggingface-hub==0.22.2
28
+ idna==3.6
29
+ importlib-metadata==6.8.0
30
+ importlib-resources==6.1.1
31
+ iniconfig==2.0.0
32
+ invisible-watermark==0.2.0
33
+ Jinja2==3.1.2
34
+ joblib==1.3.2
35
+ jsonschema==4.21.1
36
+ jsonschema-specifications==2023.12.1
37
+ kiwisolver==1.4.5
38
+ lightning-utilities==0.11.2
39
+ markdown-it-py==3.0.0
40
+ MarkupSafe==2.1.3
41
+ matplotlib==3.7.3
42
+ mdurl==0.1.2
43
+ mpmath==1.3.0
44
+ networkx==3.1
45
+ ninja==1.11.1.1
46
+ numpy==1.23.5
47
+ omegaconf==2.3.0
48
+ opencv-python==4.8.1.78
49
+ orjson==3.10.1
50
+ packaging==23.2
51
+ pandas==2.0.3
52
+ Pillow==10.1.0
53
+ pkgutil_resolve_name==1.3.10
54
+ pluggy==1.3.0
55
+ protobuf==3.20.3
56
+ psutil==5.9.6
57
+ pyaml==23.9.7
58
+ pydantic==2.7.0
59
+ pydantic_core==2.18.1
60
+ pydub==0.25.1
61
+ Pygments==2.17.2
62
+ pyparsing==3.1.1
63
+ pytest==7.4.3
64
+ -e git+https://github.com/apple/ml-stable-diffusion.git@f7d2f5e9fb0681b15770943e492bf2f6dc3414f3#egg=python_coreml_stable_diffusion
65
+ python-dateutil==2.8.2
66
+ python-multipart==0.0.9
67
+ pytz==2024.1
68
+ PyWavelets==1.4.1
69
+ PyYAML==6.0.1
70
+ quanto==0.1.0
71
+ referencing==0.34.0
72
+ regex==2023.12.25
73
+ requests==2.31.0
74
+ rich==13.7.1
75
+ rpds-py==0.18.0
76
+ ruff==0.4.1
77
+ safetensors==0.4.1
78
+ scikit-learn==1.1.2
79
+ scipy==1.10.1
80
+ semantic-version==2.10.0
81
+ sentencepiece==0.2.0
82
+ shellingham==1.5.4
83
+ six==1.16.0
84
+ sniffio==1.3.1
85
+ starlette==0.37.2
86
+ sympy==1.12
87
+ threadpoolctl==3.2.0
88
+ tokenizers==0.14.1
89
+ tomli==2.0.1
90
+ tomlkit==0.12.0
91
+ toolz==0.12.1
92
+ torch==2.1.0
93
+ torchao==0.1
94
+ torchmetrics==1.3.2
95
+ tqdm==4.66.1
96
+ transformers==4.34.1
97
+ typer==0.12.3
98
+ typing_extensions==4.9.0
99
+ tzdata==2024.1
100
+ urllib3==2.1.0
101
+ uvicorn==0.29.0
102
+ websockets==11.0.3
103
+ zipp==3.17.0
104
+ setuptools_rust
105
+ pytorch_lightning