justinblalock87
Final
08328de
import os
import shutil
from tempfile import TemporaryDirectory
from typing import List, Optional, Tuple
import subprocess
from huggingface_hub import CommitOperationAdd, HfApi
from huggingface_hub.file_download import repo_folder_name
ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]
def convert_to_core_ml(
model_id: str, folder: str, token: Optional[str], model_version: str, additional_args: str
) -> ConversionResult:
command = ["python3", "-m" , "python_coreml_stable_diffusion.torch2coreml", "--model-version", model_version, "-o", folder]
additional_args = additional_args
if additional_args == "":
# Set default args
additional_args = f"--convert-unet --convert-text-encoder --convert-vae-decoder --attention-implementation SPLIT_EINSUM --quantize-nbits 6"
command.extend(additional_args.split(" "))
print("Starting conversion: ", command)
subprocess.run(command)
print("Done")
api = HfApi(token=token)
api.upload_folder(
folder_path=folder,
repo_id=model_id,
path_in_repo="models",
repo_type="model",
)
def quantize(
api: "HfApi", model_id: str, model_version: str, additional_args: str
) -> None:
with TemporaryDirectory() as d:
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
os.makedirs(folder)
try:
convert_to_core_ml(model_id, folder, token=api.token, model_version=model_version, additional_args=additional_args)
finally:
shutil.rmtree(folder)