|
import argparse |
|
import json |
|
import os |
|
import shutil |
|
from collections import defaultdict |
|
from tempfile import TemporaryDirectory |
|
from typing import Dict, List, Optional, Set, Tuple |
|
import subprocess |
|
|
|
import torch |
|
|
|
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download |
|
from huggingface_hub.file_download import repo_folder_name |
|
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file |
|
|
|
# Shape of a conversion run's result: the hub commit operations to upload,
# plus (filename, exception) pairs for inputs whose conversion was skipped.
ConversionResult = Tuple[
    List["CommitOperationAdd"],
    List[Tuple[str, "Exception"]],
]
|
|
|
def _iter_bundle_files(root: str, prefix: str) -> List[Tuple[str, str]]:
    """Return sorted ``(path_in_repo, local_path)`` pairs for every file under ``root``.

    ``path_in_repo`` is ``prefix`` joined with the file's path relative to
    ``root``, always using ``/`` separators as hub repo paths require.
    """
    pairs = []
    for dirpath, dirnames, files in os.walk(root):
        dirnames.sort()  # make traversal (and thus commit content order) deterministic
        for name in sorted(files):
            local_path = os.path.join(dirpath, name)
            rel = os.path.relpath(local_path, root).replace(os.sep, "/")
            pairs.append((f"{prefix}/{rel}", local_path))
    return pairs


def convert_generic(
    model_id: str,
    *,
    revision: Optional[str] = None,
    folder: str,
    filenames: Set[str],
    token: Optional[str],
) -> "ConversionResult":
    """Convert ``model_id`` to Core ML and return the commit operations to upload.

    Runs ``python_coreml_stable_diffusion.torch2coreml`` in a subprocess with
    ``folder`` as its output directory. Every file of the ``Resources`` bundle
    it writes there is then turned into one ``CommitOperationAdd``.

    Args:
        model_id: Hub repo id, passed to the converter as ``--model-version``.
        revision: Accepted for interface compatibility. NOTE(review): it is not
            forwarded to the converter, so the converted weights come from the
            converter's default revision, not necessarily ``revision``.
        folder: Existing local directory used as the converter's output dir.
        filenames: Unused; kept for interface compatibility.
        token: Unused; kept for interface compatibility.

    Returns:
        ``(operations, errors)``: one ``CommitOperationAdd`` per bundle file
        (under ``Resources/`` in the repo), and an error list that this
        implementation leaves empty.

    Raises:
        subprocess.CalledProcessError: if the converter exits with a non-zero status.
        FileNotFoundError: if the converter did not produce ``folder/Resources``.
    """
    import sys  # local import: the file-level imports do not include sys

    operations = []
    errors = []

    print("Starting conversion")
    # sys.executable (rather than "python3" from PATH) guarantees the converter
    # runs with the same interpreter and installed packages as this script.
    subprocess.run(
        [
            sys.executable,
            "-m",
            "python_coreml_stable_diffusion.torch2coreml",
            "--model-version",
            model_id,
            "-o",
            folder,
            "--convert-unet",
            "--convert-text-encoder",
            "--convert-vae-decoder",
            "--chunk-unet",
            "--attention-implementation",
            "ORIGINAL",
            "--bundle-resources-for-swift-cli",
        ],
        # Fail loudly instead of going on to upload a missing or partial bundle.
        check=True,
    )
    print("Done")

    resources_dir = os.path.join(folder, "Resources")
    if not os.path.isdir(resources_dir):
        raise FileNotFoundError(f"Conversion did not produce a Resources bundle at {resources_dir!r}")

    # CommitOperationAdd only accepts individual files, so add the bundle file by file.
    for path_in_repo, local_path in _iter_bundle_files(resources_dir, "Resources"):
        operations.append(CommitOperationAdd(path_in_repo=path_in_repo, path_or_fileobj=local_path))

    return operations, errors
|
|
|
def quantize(
    api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False
) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
    """Convert ``model_id`` to Core ML and open a PR with the result on the hub.

    Args:
        api: Hub client; its ``token`` is passed to ``convert_generic`` and it
            is used to create the PR commit.
        model_id: Hub repo id to convert and to open the PR against.
        revision: Revision passed to ``create_commit`` (and to ``convert_generic``).
        force: Currently unused. NOTE(review): the ``--force`` CLI flag describes
            an "already converted / PR exists" check that is not implemented here.

    Returns:
        The ``CommitInfo`` of the created PR and the conversion errors reported
        by ``convert_generic``.

    Raises:
        RuntimeError: if the conversion yielded no files to upload.
    """
    pr_title = "Adding `CoreML` variant of this model"

    # TemporaryDirectory deletes everything it contains on exit, including when
    # an exception propagates, so no separate cleanup step is needed.
    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)

        operations, errors = convert_generic(
            model_id,
            revision=revision,
            folder=folder,
            filenames={"pytorch_model.bin"},
            token=api.token,
        )
        if not operations:
            # Don't open an empty PR when the conversion produced nothing.
            raise RuntimeError(f"Conversion of {model_id} produced no files to upload: {errors}")

        # The commit must be created while the temporary files still exist.
        new_pr = api.create_commit(
            repo_id=model_id,
            revision=revision,
            operations=operations,
            commit_message=pr_title,
            commit_description="Add CoreML variant of this model",
            create_pr=True,
        )
        print(f"Pr created at {new_pr.pr_url}")

    return new_pr, errors
|
|
|
|
|
if __name__ == "__main__":
    DESCRIPTION = """
    Simple utility tool to automatically convert a model on the hub to `CoreML` format.
    It works by running the Core ML conversion locally and uploading the resulting
    `Resources` bundle back as a PR on the hub.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "model_id",
        type=str,
        help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )
    parser.add_argument(
        "--revision",
        type=str,
        help="The revision to convert",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Create the PR even if it already exists or if the model was already converted.",
    )
    parser.add_argument(
        "-y",
        action="store_true",
        help="Ignore safety prompt",
    )
    args = parser.parse_args()
    model_id = args.model_id
    api = HfApi()

    # -y skips the interactive confirmation by pretending the user answered "y".
    if args.y:
        txt = "y"
    else:
        txt = input(
            "This conversion script will unpickle a pickled file, which is inherently unsafe. If you do not trust this file, we invite you to use"
            " https://huggingface.co/spaces/safetensors/convert or google colab or other hosted solution to avoid potential issues with this file."
            " Continue [Y/n] ?"
        )

    # An empty answer (plain Enter) counts as "yes", matching the [Y/n] prompt.
    if txt.lower() in {"", "y"}:
        # quantize() is this file's conversion entry point; no `convert` function exists.
        commit_info, errors = quantize(api, model_id, revision=args.revision, force=args.force)
        string = f"""
        ### Success 🔥
        Yay! This model was successfully converted and a PR was open using your token, here:
        [{commit_info.pr_url}]({commit_info.pr_url})
        """
        if errors:
            string += "\nErrors during conversion:\n"
            string += "\n".join(
                f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
            )
        print(string)
    else:
        print(f"Answer was `{txt}` aborting.")