import argparse
import json
import os
import shutil
import subprocess
from collections import defaultdict
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional, Set, Tuple

import torch
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file
ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]


def convert_generic(
    model_id: str, *, revision: Optional[str], folder: str, filenames: Set[str], token: Optional[str]
) -> ConversionResult:
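    """Convert `model_id` to Core ML inside `folder` and return the commit operations to push.

    The heavy lifting is delegated to Apple's `python_coreml_stable_diffusion.torch2coreml`
    converter; the resulting `Resources` bundle is staged for the Hub pull request.
    """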
operations = []
errors = []
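    # Reference invocation kept for context: it converts `sdxl-turbo` with 2-bit
    # quantization, whereas the call below converts `sd-turbo` without `--quantize-nbits`.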
# python3 -m python_coreml_stable_diffusion.torch2coreml \
# --model-version stabilityai/sdxl-turbo \
# -o packages/sdxl-turbo \
# --convert-unet --convert-text-encoder --convert-vae-decoder --chunk-unet --attention-implementation ORIGINAL \
# --bundle-resources-for-swift-cli \
# --quantize-nbits 2
print("Starting conversion") #
subprocess.run(["python3", "-m" , "python_coreml_stable_diffusion.torch2coreml", "--model-version", "stabilityai/sd-turbo", "-o", folder, "--convert-unet", "--convert-text-encoder", "--convert-vae-decoder", "--chunk-unet", "--attention-implementation", "ORIGINAL", "--bundle-resources-for-swift-cli"])
# with open(f'{folder}/newfile.txt', 'w') as f:
# f.write('Hello, World!')
print("Done")
    # `CommitOperationAdd` expects individual files, so stage every file of the
    # generated `Resources` bundle rather than the directory itself.
    resources_dir = os.path.join(folder, "Resources")
    for root, _dirs, files in os.walk(resources_dir):
        for name in files:
            local_path = os.path.join(root, name)
            rel_path = os.path.relpath(local_path, resources_dir).replace(os.sep, "/")
            operations.append(CommitOperationAdd(path_in_repo=f"Resources/{rel_path}", path_or_fileobj=local_path))
# extensions = set([".bin", ".ckpt"])
# for filename in filenames:
# prefix, ext = os.path.splitext(filename)
# if ext in extensions:
# pt_filename = hf_hub_download(
# model_id, revision=revision, filename=filename, token=token, cache_dir=folder
# )
# dirname, raw_filename = os.path.split(filename)
# if raw_filename == "pytorch_model.bin":
# # XXX: This is a special case to handle `transformers` and the
# # `transformers` part of the model which is actually loaded by `transformers`.
# sf_in_repo = os.path.join(dirname, "model.safetensors")
# else:
# sf_in_repo = f"{prefix}.safetensors"
# sf_filename = os.path.join(folder, sf_in_repo)
# try:
# convert_file(pt_filename, sf_filename, discard_names=[])
# operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
# except Exception as e:
# errors.append((pt_filename, e))
    return operations, errors


def quantize(
    api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False
) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
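    """Convert `model_id` to Core ML and open a pull request on the Hub with the resulting bundle."""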
pr_title = "Adding `CoreML` variant of this model"
# info = api.model_info(model_id, revision=revision)
# filenames = set(s.rfilename for s in info.siblings)
with TemporaryDirectory() as d:
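        # All intermediate artifacts live in a per-repo temporary folder, removed in the `finally` block below.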
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
os.makedirs(folder)
        new_pr = None
        try:
            operations, errors = convert_generic(
                model_id, revision=revision, folder=folder, filenames={"pytorch_model.bin"}, token=api.token
            )
new_pr = api.create_commit(
repo_id=model_id,
revision=revision,
operations=operations,
commit_message=pr_title,
commit_description="Add CoreML variant of this model",
create_pr=True,
)
print(f"Pr created at {new_pr.pr_url}")
finally:
shutil.rmtree(folder)
    return new_pr, errors


if __name__ == "__main__":
    DESCRIPTION = """
    Simple utility tool to automatically convert some weights on the hub to the `CoreML` format.
    It is PyTorch exclusive for now.
    It works by downloading the weights (PT), converting them locally, and uploading them back
    as a PR on the hub.
    """
parser = argparse.ArgumentParser(description=DESCRIPTION)
parser.add_argument(
"model_id",
type=str,
help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
)
parser.add_argument(
"--revision",
type=str,
help="The revision to convert",
)
parser.add_argument(
"--force",
action="store_true",
help="Create the PR even if it already exists of if the model was already converted.",
)
parser.add_argument(
"-y",
action="store_true",
help="Ignore safety prompt",
)
args = parser.parse_args()
model_id = args.model_id
api = HfApi()
if args.y:
txt = "y"
else:
        txt = input(
            "This conversion script will unpickle a pickled file, which is inherently unsafe. If you do not trust this file, we invite you to use"
            " https://huggingface.co/spaces/safetensors/convert, Google Colab, or another hosted solution to avoid potential issues with this file."
            " Continue [Y/n]?"
        )
if txt.lower() in {"", "y"}:
        commit_info, errors = quantize(api, model_id, revision=args.revision, force=args.force)
string = f"""
### Success 🔥
        Yay! This model was successfully converted and a PR was opened using your token, here:
[{commit_info.pr_url}]({commit_info.pr_url})
"""
if errors:
string += "\nErrors during conversion:\n"
string += "\n".join(
f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
)
print(string)
else:
print(f"Answer was `{txt}` aborting.")