File size: 2,184 Bytes
066cf1b
f56edba
 
 
40d1ba9
d7c590b
f56edba
 
d7c590b
066cf1b
f56edba
 
066cf1b
 
f56edba
 
 
 
 
 
 
40d1ba9
f56edba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7c590b
f56edba
d7c590b
066cf1b
 
 
f56edba
 
 
 
066cf1b
 
 
f56edba
066cf1b
f56edba
 
066cf1b
 
 
d7c590b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
import json
import os
import shutil
from diffusers.pipelines.stable_diffusion import safety_checker
import torch
from tempfile import TemporaryDirectory
from typing import List, Optional
from diffusers import StableDiffusionPipeline, ControlNetModel

from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name


def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
    """Convert a Flax diffusers repo on the Hub to PyTorch weights and open a PR with them.

    The repo is treated as a full Stable Diffusion pipeline when it contains a
    ``model_index.json`` file, and as a single ``ControlNetModel`` otherwise.
    The weights are loaded from Flax (``from_flax=True``) and saved locally twice:
    once with the default serialization and once with ``safetensors``. The model
    is then cast to float16 and saved again as the ``fp16`` variant in both
    formats. The resulting folder is uploaded to ``model_id`` as a pull request.

    Args:
        api: ``HfApi`` client used to read the repo info and upload the files
            (presumably needs write access to open the PR — confirm for your setup).
        model_id: Repo id on the Hub, e.g. ``"user/model-name"``.
        force: Currently unused; accepted for interface compatibility.

    Returns:
        The value returned by ``api.upload_folder`` for the created PR commit
        (a ``CommitInfo`` in recent ``huggingface_hub`` versions).
    """
    info = api.model_info(model_id)
    # `ModelInfo.siblings` is Optional; treat a missing file list as empty.
    filenames = {sibling.rfilename for sibling in info.siblings or []}

    # A diffusers pipeline repo is identified by its top-level model_index.json.
    is_sd = "model_index.json" in filenames

    if is_sd:
        # safety_checker=None skips loading that component entirely.
        model = StableDiffusionPipeline.from_pretrained(model_id, from_flax=True, safety_checker=None)
    else:
        model = ControlNetModel.from_pretrained(model_id, from_flax=True)

    with TemporaryDirectory() as d:
        # repo_folder_name pluralizes repo_type itself ("model" -> "models--...").
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="model"))
        os.makedirs(folder)

        # Full-precision weights: default serialization, then safetensors.
        model.save_pretrained(folder)
        model.save_pretrained(folder, safe_serialization=True)

        # A pipeline is cast through .to(torch_dtype=...); a single model via .half().
        if is_sd:
            model.to(torch_dtype=torch.float16)
        else:
            model.half()

        # Half-precision weights, stored under the "fp16" variant suffix.
        model.save_pretrained(folder, variant="fp16")
        model.save_pretrained(folder, safe_serialization=True, variant="fp16")

        commit_info = api.upload_folder(
            folder_path=folder,
            repo_id=model_id,
            repo_type="model",
            create_pr=True,
        )
        print(model_id)

    return commit_info

if __name__ == "__main__":
    # Command-line entry point: convert a single Hub repo given as a positional argument.
    DESCRIPTION = """
    Simple utility tool to convert Flax diffusers weights on the hub to PyTorch.
    It loads the repo's Flax weights (a full Stable Diffusion pipeline when the repo
    has a `model_index.json`, otherwise a single ControlNet model), saves them locally
    with the default serialization and as `safetensors`, in both full precision and
    an `fp16` variant, and uploads the result back as a PR on the hub.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "model_id",
        type=str,
        help="The id of the repo on the hub to convert (a Stable Diffusion pipeline or a ControlNet model with Flax weights). E.g. `user/model-name`",
    )
    args = parser.parse_args()
    model_id = args.model_id
    api = HfApi()
    convert(api, model_id)