import argparse
import json
import os
import shutil
import subprocess
from collections import defaultdict
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional, Set, Tuple

import torch
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file

ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]


def convert_generic(
    model_id: str, *, revision: Optional[str], folder: str, filenames: Set[str], token: Optional[str]
) -> ConversionResult:
    operations = []
    errors = []

    # Reference invocation of Apple's Core ML exporter:
    # python3 -m python_coreml_stable_diffusion.torch2coreml \
    #     --model-version stabilityai/sdxl-turbo \
    #     -o packages/sdxl-turbo \
    #     --convert-unet --convert-text-encoder --convert-vae-decoder --chunk-unet \
    #     --attention-implementation ORIGINAL \
    #     --bundle-resources-for-swift-cli \
    #     --quantize-nbits 2
    print("Starting conversion")
    # Convert the requested model and bundle the resources for the Swift CLI.
    # The bundle is written to `<folder>/Resources`.
    subprocess.run(
        [
            "python3",
            "-m",
            "python_coreml_stable_diffusion.torch2coreml",
            "--model-version",
            model_id,
            "-o",
            folder,
            "--convert-unet",
            "--convert-text-encoder",
            "--convert-vae-decoder",
            "--chunk-unet",
            "--attention-implementation",
            "ORIGINAL",
            "--bundle-resources-for-swift-cli",
        ],
        check=True,
    )
    print("Done")

    # `CommitOperationAdd` only accepts single files, so add one operation per file
    # in the generated `Resources` bundle rather than pointing it at the directory.
    resources_dir = os.path.join(folder, "Resources")
    for root, _, files in os.walk(resources_dir):
        for name in files:
            local_path = os.path.join(root, name)
            path_in_repo = os.path.join("Resources", os.path.relpath(local_path, resources_dir))
            operations.append(CommitOperationAdd(path_in_repo=path_in_repo, path_or_fileobj=local_path))

    # extensions = set([".bin", ".ckpt"])
    # for filename in filenames:
    #     prefix, ext = os.path.splitext(filename)
    #     if ext in extensions:
    #         pt_filename = hf_hub_download(
    #             model_id, revision=revision, filename=filename, token=token, cache_dir=folder
    #         )
    #         dirname, raw_filename = os.path.split(filename)
    #         if raw_filename == "pytorch_model.bin":
    #             # XXX: This is a special case to handle `transformers` and the
    #             # `transformers` part of the model which is actually loaded by `transformers`.
    #             sf_in_repo = os.path.join(dirname, "model.safetensors")
    #         else:
    #             sf_in_repo = f"{prefix}.safetensors"
    #         sf_filename = os.path.join(folder, sf_in_repo)
    #         try:
    #             convert_file(pt_filename, sf_filename, discard_names=[])
    #             operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
    #         except Exception as e:
    #             errors.append((pt_filename, e))
    return operations, errors


def quantize(
    api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False
) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
    pr_title = "Adding `CoreML` variant of this model"
    # info = api.model_info(model_id, revision=revision)
    # filenames = set(s.rfilename for s in info.siblings)
    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)
        new_pr = None
        errors = []
        try:
            operations, errors = convert_generic(
                model_id, revision=revision, folder=folder, filenames={"pytorch_model.bin"}, token=api.token
            )
            new_pr = api.create_commit(
                repo_id=model_id,
                revision=revision,
                operations=operations,
                commit_message=pr_title,
                commit_description="Add CoreML variant of this model",
                create_pr=True,
            )
            print(f"Pr created at {new_pr.pr_url}")
        finally:
            shutil.rmtree(folder)
        return new_pr, errors


if __name__ == "__main__":
    DESCRIPTION = """
    Simple utility tool to automatically convert some weights on the Hub to the `CoreML` format.
    It is PyTorch exclusive for now.
    It works by downloading the weights (PT), converting them locally, and uploading them back as a PR on the Hub.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "model_id",
        type=str,
        help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )
    parser.add_argument(
        "--revision",
        type=str,
        help="The revision to convert",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Create the PR even if it already exists or if the model was already converted.",
    )
    parser.add_argument(
        "-y",
        action="store_true",
        help="Ignore safety prompt",
    )
    args = parser.parse_args()
    model_id = args.model_id
    api = HfApi()
    if args.y:
        txt = "y"
    else:
        txt = input(
            "This conversion script will unpickle a pickled file, which is inherently unsafe. If you do not trust this file, we invite you to use"
            " https://huggingface.co/spaces/safetensors/convert or google colab or other hosted solution to avoid potential issues with this file."
            " Continue [Y/n] ?"
        )
    if txt.lower() in {"", "y"}:
        commit_info, errors = quantize(api, model_id, revision=args.revision, force=args.force)
        string = f"""
        ### Success 🔥
        Yay! This model was successfully converted and a PR was opened using your token, here:
        [{commit_info.pr_url}]({commit_info.pr_url})
        """
        if errors:
            string += "\nErrors during conversion:\n"
            string += "\n".join(
                f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
            )
        print(string)
    else:
        print(f"Answer was `{txt}`, aborting.")
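

# Optional sketch, not wired into the flow above: instead of building one
# `CommitOperationAdd` per file, `HfApi.upload_folder` can push the whole generated
# `Resources` bundle to an auto-created PR in a single call. The helper name
# `upload_resources_as_pr` is hypothetical and only illustrates the idea.
#
# def upload_resources_as_pr(api: HfApi, model_id: str, folder: str):
#     return api.upload_folder(
#         repo_id=model_id,
#         folder_path=os.path.join(folder, "Resources"),
#         path_in_repo="Resources",
#         commit_message="Adding `CoreML` variant of this model",
#         create_pr=True,
#     )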