Commit eb851d8 • 1 parent: f23a5b3
justinblalock87 committed: Additional args

Files changed:
- __pycache__/app.cpython-38.pyc +0 -0
- __pycache__/quantize.cpython-38.pyc +0 -0
- app.py +17 -41
- quantize.py +11 -64
__pycache__/app.cpython-38.pyc
CHANGED
Binary files a/__pycache__/app.cpython-38.pyc and b/__pycache__/app.cpython-38.pyc differ
__pycache__/quantize.cpython-38.pyc
CHANGED
Binary files a/__pycache__/quantize.cpython-38.pyc and b/__pycache__/quantize.cpython-38.pyc differ
app.py
CHANGED
@@ -14,13 +14,7 @@ DATA_FILE = os.path.join("data", DATA_FILENAME)

 HF_TOKEN = os.environ.get("HF_TOKEN")

-
-# TODO
-if False and HF_TOKEN:
-    repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, token=HF_TOKEN)
-
-
-def run(model_id: str, model_version: str, is_private: bool, token: Optional[str] = None) -> str:
+def run(model_id: str, model_version: str, additional_args: str, is_private: bool, token: Optional[str] = None) -> str:
     if model_id == "":
         return """
         ### Invalid input 🐞
@@ -28,38 +22,19 @@ def run(model_id: str, model_version: str, is_private: bool, token: Optional[str
         Please fill a token and model_id.
         """
     login(token=token)
-
-
-
-
-
-
-
-
-
-
-
-
-
-        commit_info, errors = quantize.quantize(api=api, model_id=model_id, model_version=model_version)
-        print("[commit_info]", commit_info)
-
-
-        string = f"""
-        ### Success 🔥
-        Yay! This model was successfully converted and a PR was open using your token, here:
-        [{commit_info.pr_url}]({commit_info.pr_url})
-        """
-        if errors:
-            string += "\nErrors during conversion:\n"
-            string += "\n".join(f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors)
-        return string
-    except Exception as e:
-        return f"""
-        ### Error 😢😢😢
-
-        {e}
-        """
+    if is_private:
+        api = HfApi(token=token)
+    else:
+        api = HfApi(token=HF_TOKEN)
+    hf_is_private = api.model_info(repo_id=model_id).private
+    if is_private and not hf_is_private:
+        # This model is NOT private
+        # Change the token so we make the PR on behalf of the bot.
+        api = HfApi(token=HF_TOKEN)
+
+    print("is_private", is_private)
+
+    quantize.quantize(api=api, model_id=model_id, model_version=model_version, additional_args=additional_args)


 DESCRIPTION = """
@@ -85,6 +60,7 @@ with gr.Blocks(title=title) as demo:
     with gr.Column() as c:
         model_id = gr.Text(max_lines=1, label="model_id", value="jblalock30/coreml")
         model_version = gr.Text(max_lines=1, label="model_version", value="stabilityai/sd-turbo")
+        additional_args = gr.Text(max_lines=1, label="additional_args", value="--quantize-nbits 2 --convert-unet --convert-text-encoder --convert-vae-decoder --chunk-unet --attention-implementation ORIGINAL")
         is_private = gr.Checkbox(label="Private model")
         token = token_text()
     with gr.Row() as c:
@@ -92,9 +68,9 @@ with gr.Blocks(title=title) as demo:
         submit = gr.Button("Submit", variant="primary")

     with gr.Column() as d:
-        output = gr.Markdown(
+        output = gr.Markdown()

     is_private.change(lambda s: token_text(s), inputs=is_private, outputs=token)
-    submit.click(run, inputs=[model_id, model_version, is_private, token], outputs=output, concurrency_limit=1)
+    submit.click(run, inputs=[model_id, model_version, additional_args, is_private, token], outputs=output, concurrency_limit=1)

 demo.queue(max_size=10).launch(show_api=True)
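Because the demo launches with show_api=True, the updated handler can also be driven from a script. Below is a minimal client-side sketch, assuming a hypothetical Space id ("jblalock30/coreml-space") and Gradio's default endpoint name for the run handler; neither name is confirmed by this commit.

from gradio_client import Client

# Connect to the Space (hypothetical id) and call the submit handler with the
# same five inputs the UI now wires up: model_id, model_version,
# additional_args, is_private, token.
client = Client("jblalock30/coreml-space")
result = client.predict(
    "jblalock30/coreml",     # model_id: destination repo for the converted files
    "stabilityai/sd-turbo",  # model_version: source checkpoint to convert
    "--quantize-nbits 2 --convert-unet --convert-text-encoder --convert-vae-decoder --chunk-unet --attention-implementation ORIGINAL",
    False,                   # is_private
    "hf_xxx",                # placeholder user token
    api_name="/run",
)
print(result)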
quantize.py
CHANGED
@@ -16,90 +16,37 @@ from safetensors.torch import _find_shared_tensors, _is_complete, load_file, sav
 ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]

 def convert_generic(
-    model_id: str,
+    model_id: str, folder: str, token: Optional[str], model_version: str, additional_args: str
 ) -> ConversionResult:
-
-
+
+    command = ["python3", "-m", "python_coreml_stable_diffusion.torch2coreml", "--model-version", model_version, "-o", folder]
+    command.extend(additional_args.split(" "))

-
-    # --model-version stabilityai/sdxl-turbo \
-    # -o packages/sdxl-turbo \
-    # --convert-unet --convert-text-encoder --convert-vae-decoder --chunk-unet --attention-implementation ORIGINAL \
-    # --bundle-resources-for-swift-cli \
-    # --quantize-nbits 2
+    print("Starting conversion")

-
-
-    # subprocess.run(["python3", "-m", "python_coreml_stable_diffusion.torch2coreml", "--model-version", "stabilityai/sd-turbo", "-o", folder, "--convert-unet", "--convert-text-encoder", "--convert-vae-decoder", "--chunk-unet", "--attention-implementation", "ORIGINAL", "--bundle-resources-for-swift-cli"])
-    subprocess.run(["python3", "-m", "python_coreml_stable_diffusion.torch2coreml", "--model-version", model_version, "-o", folder, "--convert-unet", "--convert-text-encoder", "--convert-vae-decoder", "--chunk-unet", "--attention-implementation", "ORIGINAL", "--quantize-nbits", "2"])
-    # with open(f'{folder}/newfile.txt', 'w') as f:
-    #     f.write('Hello, World!')
+    subprocess.run(command)

     print("Done")

     api = HfApi(token=token)
     api.upload_folder(
         folder_path=folder,
-        repo_id=
+        repo_id=model_id,
         path_in_repo="models",
         repo_type="model",
     )
-    # for filename in os.listdir(folder):
-    #     print(filename)
-
-    #     operations.append(CommitOperationAdd(path_in_repo=f'models/{filename}', path_or_fileobj=f'{folder}/{filename}'))
-
-    # extensions = set([".bin", ".ckpt"])
-    # for filename in filenames:
-    #     prefix, ext = os.path.splitext(filename)
-    #     if ext in extensions:
-    #         pt_filename = hf_hub_download(
-    #             model_id, revision=revision, filename=filename, token=token, cache_dir=folder
-    #         )
-    #         dirname, raw_filename = os.path.split(filename)
-    #         if raw_filename == "pytorch_model.bin":
-    #             # XXX: This is a special case to handle `transformers` and the
-    #             # `transformers` part of the model which is actually loaded by `transformers`.
-    #             sf_in_repo = os.path.join(dirname, "model.safetensors")
-    #         else:
-    #             sf_in_repo = f"{prefix}.safetensors"
-    #         sf_filename = os.path.join(folder, sf_in_repo)
-    #         try:
-    #             convert_file(pt_filename, sf_filename, discard_names=[])
-    #             operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
-    #         except Exception as e:
-    #             errors.append((pt_filename, e))
-    return operations, errors

 def quantize(
-    api: "HfApi", model_id: str, model_version: str,
-) ->
-
-    # info = api.model_info(model_id, revision=revision)
-    # filenames = set(s.rfilename for s in info.siblings)
-
+    api: "HfApi", model_id: str, model_version: str, additional_args: str
+) -> None:
+
     with TemporaryDirectory() as d:
         folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
         os.makedirs(folder)
-        new_pr = None
         try:
-
-            pr = None
-
-            operations, errors = convert_generic(model_id, model_version=model_version, revision=revision, folder=folder, filenames={"pytorch_model.bin"}, token=api.token)
-
-            # new_pr = api.create_commit(
-            #     repo_id=model_id,
-            #     revision=revision,
-            #     operations=operations,
-            #     commit_message=pr_title,
-            #     commit_description="Add CoreML variant of this model",
-            #     create_pr=True,
-            # )
-            # print(f"Pr created at {new_pr.pr_url}")
+            convert_generic(model_id, folder, token=api.token, model_version=model_version, additional_args=additional_args)
         finally:
             shutil.rmtree(folder)
-        return new_pr, errors


 if __name__ == "__main__":
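One note on the new argument plumbing: convert_generic tokenizes additional_args with a plain str.split(" "), which yields empty tokens on doubled spaces and cannot handle quoted values. A hardened sketch of the same command construction, not part of this commit (build_command is a hypothetical helper, and it assumes the python_coreml_stable_diffusion package is installed):

import shlex
import subprocess

def build_command(model_version: str, folder: str, additional_args: str) -> list:
    # Base invocation mirrors the diff: torch2coreml with the source checkpoint
    # and an output folder, then the user-supplied flags appended.
    command = [
        "python3", "-m", "python_coreml_stable_diffusion.torch2coreml",
        "--model-version", model_version,
        "-o", folder,
    ]
    # shlex.split respects quoting and collapses repeated whitespace, where
    # additional_args.split(" ") would emit empty-string arguments.
    command.extend(shlex.split(additional_args))
    return command

cmd = build_command("stabilityai/sd-turbo", "out", "--quantize-nbits 2 --convert-unet")
# check=True surfaces converter failures; the bare subprocess.run(command) in
# the diff ignores a non-zero exit status.
subprocess.run(cmd, check=True)

For well-formed input this produces the same argv list the diff builds, so it would be a drop-in change rather than a behavioral one.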