|
import argparse |
|
import json |
|
import os |
|
import shutil |
|
from diffusers.pipelines.stable_diffusion import safety_checker |
|
import torch |
|
from tempfile import TemporaryDirectory |
|
from typing import List, Optional |
|
from diffusers import StableDiffusionPipeline, ControlNetModel |
|
|
|
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download |
|
from huggingface_hub.file_download import repo_folder_name |
|
|
|
|
|
def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
    """Convert a Flax diffusers repo on the Hub to PyTorch weights and open a PR with them.

    Repos containing a ``model_index.json`` file are loaded as a full
    ``StableDiffusionPipeline`` (without a safety checker). Any other repo is loaded
    as a single ``ControlNetModel``. Weights are read from Flax (``from_flax=True``)
    and saved into a temporary folder four times:

    * the default ``save_pretrained`` format and ``safetensors``, in the dtype they
      were loaded with;
    * the same two formats again as the ``fp16`` variant, after casting to float16.

    The folder is then uploaded to the same repo as a pull request.

    Args:
        api: Hub client used to read repo metadata and to upload the converted files.
        model_id: Hub repo id of the model to convert.
        force: Currently unused; kept so existing callers keep working.

    Returns:
        The value returned by ``api.upload_folder`` for the PR commit (a
        ``CommitInfo`` in recent ``huggingface_hub`` versions).
    """
    info = api.model_info(model_id)
    # `siblings` may be None in some huggingface_hub responses; treat that as "no files".
    filenames = {sibling.rfilename for sibling in (info.siblings or [])}
    is_sd = "model_index.json" in filenames

    if is_sd:
        model = StableDiffusionPipeline.from_pretrained(model_id, from_flax=True, safety_checker=None)
    else:
        model = ControlNetModel.from_pretrained(model_id, from_flax=True)

    with TemporaryDirectory() as tmp_dir:
        # repo_folder_name appends the plural "s" itself, so pass the singular repo type.
        folder = os.path.join(tmp_dir, repo_folder_name(repo_id=model_id, repo_type="model"))
        os.makedirs(folder)

        # Weights in the loaded dtype, in both serialization formats.
        model.save_pretrained(folder)
        model.save_pretrained(folder, safe_serialization=True)

        # Cast in place to float16 and save alongside the originals as the "fp16" variant.
        if is_sd:
            model.to(torch_dtype=torch.float16)
        else:
            model.half()

        model.save_pretrained(folder, variant="fp16")
        model.save_pretrained(folder, safe_serialization=True, variant="fp16")

        # Upload before the temporary directory is removed on leaving the `with` block.
        commit_info = api.upload_folder(
            folder_path=folder,
            repo_id=model_id,
            repo_type="model",
            create_pr=True,
        )

    print(model_id)
    return commit_info
|
|
|
if __name__ == "__main__":
    # The description must match what `convert` actually does: it loads *Flax*
    # weights and re-saves them as PyTorch `.bin` / `safetensors` (plus fp16 variants).
    DESCRIPTION = """
    Simple utility tool to convert diffusers weights on the hub from Flax to PyTorch.
    It loads the Flax weights of a Stable Diffusion pipeline (repos with a `model_index.json`)
    or of a ControlNet model, saves them locally as PyTorch `.bin` and `safetensors` files
    (plus `fp16` variants), and uploads them back to the same repo as a PR on the hub.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "model_id",
        type=str,
        help="The id of the diffusers repo on the hub holding Flax weights to convert. E.g. `my-org/my-flax-controlnet`",
    )
    args = parser.parse_args()

    model_id = args.model_id
    api = HfApi()
    convert(api, model_id)
|
|