patrickvonplaten committed on
Commit
f56edba
1 Parent(s): 9c60619
Files changed (1) hide show
  1. convert_flax_to_pt.py +191 -14
convert_flax_to_pt.py CHANGED
@@ -1,30 +1,207 @@
1
- #!/usr/bin/env python3
2
  import argparse
3
- from huggingface_hub import HfApi
 
 
 
 
4
 
 
 
5
 
6
def main(api, model_id):
    """Return the names of every branch of ``model_id`` except ``main``.

    Args:
        api: Object exposing ``list_repo_refs(model_id)``; the result must
            have a ``branches`` iterable whose items carry a ``name``.
        model_id: Repository identifier forwarded to ``list_repo_refs``.

    Returns:
        A list of branch names with duplicates removed; order is unspecified.
    """
    refs = api.list_repo_refs(model_id)
    # The set de-duplicates names; "main" is dropped afterwards.
    branch_names = {branch.name for branch in refs.branches}
    branch_names.discard("main")
    return list(branch_names)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
if __name__ == "__main__":
    # Command-line entry point: print the non-``main`` branches of one repo.
    DESCRIPTION = """
    Simple utility to get all branches from a repo
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "--model_id",
        type=str,
        # Without a model id there is nothing to query, so fail at parse time
        # instead of passing ``None`` to the API.
        required=True,
        help="The name of the model on the hub to retrieve the branches from. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )

    args = parser.parse_args()
    model_id = args.model_id
    api = HfApi()
    branches = main(api, model_id)

    # Stay silent for repositories that only have the ``main`` branch.
    if branches:
        print(f"{model_id}: {branches}")
 
 
1
  import argparse
2
+ import json
3
+ import os
4
+ import shutil
5
+ from tempfile import TemporaryDirectory
6
+ from typing import List, Optional
7
 
8
+ from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
9
+ from huggingface_hub.file_download import repo_folder_name
10
 
 
 
 
11
 
12
class AlreadyExists(Exception):
    """Raised when an open pull request with the same title already exists."""
14
+
15
+
16
# Pipeline class names treated as Stable-Diffusion-like by
# ``is_index_stable_diffusion_like``. Kept as a tuple (not a set) so that
# unhashable ``_class_name`` values simply compare unequal.
_STABLE_DIFFUSION_LIKE_CLASSES = (
    "AltDiffusionImg2ImgPipeline",
    "AltDiffusionPipeline",
    "CycleDiffusionPipeline",
    "StableDiffusionImageVariationPipeline",
    "StableDiffusionImg2ImgPipeline",
    "StableDiffusionInpaintPipeline",
    "StableDiffusionInpaintPipelineLegacy",
    "StableDiffusionPipeline",
    "StableDiffusionPipelineSafe",
    "StableDiffusionUpscalePipeline",
    "VersatileDiffusionDualGuidedPipeline",
    "VersatileDiffusionImageVariationPipeline",
    "VersatileDiffusionPipeline",
    "VersatileDiffusionTextToImagePipeline",
    "OnnxStableDiffusionImg2ImgPipeline",
    "OnnxStableDiffusionInpaintPipeline",
    "OnnxStableDiffusionInpaintPipelineLegacy",
    "OnnxStableDiffusionPipeline",
    "StableDiffusionOnnxPipeline",
    "FlaxStableDiffusionPipeline",
)


def is_index_stable_diffusion_like(config_dict):
    """Tell whether a model-index dict names a Stable-Diffusion-like pipeline.

    Args:
        config_dict: Mapping that may contain a ``"_class_name"`` entry.

    Returns:
        ``True`` when ``config_dict["_class_name"]`` is one of the known
        compatible pipeline class names, ``False`` otherwise (including when
        the key is absent).
    """
    # A missing key yields ``None``, which is not in the tuple -> ``False``.
    return config_dict.get("_class_name") in _STABLE_DIFFUSION_LIKE_CLASSES
43
+
44
+
45
def convert_single(model_id: str, folder: str) -> tuple:
    """Prepare the commit operation that updates ``model_index.json``.

    Downloads ``model_index.json`` of ``model_id``, checks that its
    ``feature_extractor`` entry ends with ``"CLIPFeatureExtractor"`` and, if
    so, writes a converted copy into ``folder`` via ``convert_file``.

    Args:
        model_id: Hub repository to download ``model_index.json`` from.
        folder: Existing local directory that receives the converted file.

    Returns:
        ``(operations, model_type)`` where ``operations`` is a one-element
        list of ``CommitOperationAdd`` and ``model_type`` is the value
        returned by ``convert_file``; ``(False, False)`` when there is
        nothing to convert.
    """
    config_file = "model_index.json"
    model_index_file = hf_hub_download(repo_id=model_id, filename=config_file)

    with open(model_index_file, "r", encoding="utf-8") as f:
        index_dict = json.load(f)

    feature_extractor = index_dict.get("feature_extractor")
    # An empty entry is treated like a missing one rather than crashing on
    # the ``[-1]`` lookup below.
    if not feature_extractor:
        print(f"{model_id} has no feature extractor")
        return False, False

    if feature_extractor[-1] != "CLIPFeatureExtractor":
        print(f"{model_id} is not out of date or is not CLIP")
        return False, False

    new_config_file = os.path.join(folder, config_file)
    # ``convert_file`` returns a truthy model-type label on success.
    model_type = convert_file(model_index_file, new_config_file)
    if not model_type:
        return False, False

    operations = [CommitOperationAdd(path_in_repo=config_file, path_or_fileobj=new_config_file)]
    return operations, model_type
71
+
72
+
73
def convert_file(
    old_config: str,
    new_config: str,
):
    """Rewrite a model-index JSON so its feature extractor is ``CLIPImageProcessor``.

    Reads the JSON object at ``old_config``, replaces the last element of its
    ``"feature_extractor"`` list with ``"CLIPImageProcessor"`` and writes the
    result to ``new_config`` (2-space indent, sorted keys, trailing newline).

    Args:
        old_config: Path of the JSON file to read.
        new_config: Path of the JSON file to write.

    Returns:
        The string ``"Stable Diffusion"`` on success, or ``False`` when the
        config has no non-empty ``"feature_extractor"`` list (nothing is
        written in that case).
    """
    with open(old_config, "r", encoding="utf-8") as f:
        config = json.load(f)

    feature_extractor = config.get("feature_extractor")
    # Callers treat a falsy result as "nothing to convert", so report a
    # non-matching config that way instead of raising KeyError/IndexError.
    if not isinstance(feature_extractor, list) or not feature_extractor:
        print("No matching config")
        return False

    feature_extractor[-1] = "CLIPImageProcessor"

    with open(new_config, "w", encoding="utf-8") as f:
        f.write(json.dumps(config, indent=2, sort_keys=True) + "\n")

    return "Stable Diffusion"
93
+
94
+
95
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Find an open pull request on ``model_id`` whose title equals ``pr_title``.

    Args:
        api: Client exposing ``get_repo_discussions(repo_id=...)``.
        model_id: Repository whose discussions are inspected.
        pr_title: Exact title to look for.

    Returns:
        The first discussion that is open, is a pull request and has the
        given title; ``None`` when there is no match or the discussions
        cannot be listed.
    """
    try:
        # The loop stays inside the ``try``: if the API hands back a lazy
        # iterator, fetch errors only surface while iterating.
        for discussion in api.get_repo_discussions(repo_id=model_id):
            if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
                return discussion
    except Exception:
        # Best effort: a repository whose discussions cannot be read is
        # treated as having no previous pull request.
        return None
    return None
103
+
104
+
105
def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
    """Convert the Flax weights of ``model_id`` to PyTorch and open a config-fix PR.

    The model is loaded with ``from_flax=True`` and saved into a temporary
    folder in four flavours (default and ``fp16`` variant, each with and
    without ``safe_serialization``). That folder is uploaded to the
    repository, and then a pull request replacing ``CLIPFeatureExtractor``
    with ``CLIPImageProcessor`` in ``model_index.json`` is created when
    applicable.

    Args:
        api: Hub client used for all repository operations.
        model_id: Repository to convert.
        force: When ``True``, create the pull request even if an open one
            with the same title already exists.

    Returns:
        The result of ``api.create_commit`` when a pull request was opened,
        otherwise ``None``.

    Raises:
        AlreadyExists: An open pull request with the same title exists and
            ``force`` is ``False``.
    """
    pr_title = "Fix deprecation warning by changing `CLIPFeatureExtractor` to `CLIPImageProcessor`."
    info = api.model_info(model_id)
    filenames = {sibling.rfilename for sibling in info.siblings}

    # A repository with a model index is handled as a full pipeline.
    is_sd = "model_index.json" in filenames

    # NOTE(review): `StableDiffusionPipeline` and `ControlNetModel` are not
    # imported in the visible part of this file; presumably they come from
    # `diffusers` - confirm they are in scope before running.
    if is_sd:
        model = StableDiffusionPipeline.from_pretrained(model_id, from_flax=True)
    else:
        model = ControlNetModel.from_pretrained(model_id, from_flax=True)

    # The temporary directory (and everything below it) is removed when the
    # ``with`` block exits, on success and on error alike.
    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)

        model.save_pretrained(folder)
        model.save_pretrained(folder, safe_serialization=True)

        if is_sd:
            # Imported here because the module does not import torch itself.
            import torch

            model.to(torch_dtype=torch.float16)
        else:
            model.half()

        model.save_pretrained(folder, variant="fp16")
        model.save_pretrained(folder, safe_serialization=True, variant="fp16")

        api.upload_folder(
            folder_path=folder,
            repo_id=model_id,
            repo_type="model",
        )

        pr = previous_pr(api, model_id, pr_title)
        if pr is not None and not force:
            url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
            raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")

        operations, _model_type = convert_single(model_id, folder)
        if not operations:
            print(f"No files to convert for {model_id}")
            return None

        new_pr = api.create_commit(
            repo_id=model_id,
            operations=operations,
            commit_message=pr_title,
            commit_description=_deprecation_pr_description(model_id),
            create_pr=True,
        )
        print(f"Pr created at {new_pr.pr_url}")
        return new_pr


def _deprecation_pr_description(model_id: str) -> str:
    """Build the pull-request body explaining the ``CLIPImageProcessor`` change."""
    # The part before the first "/" is used to greet the repository owner.
    contributor = model_id.split("/")[0]
    return (
        f"Hey {contributor} 👋, \n\n Your model repository seems to contain logic to load a feature extractor that is deprecated, which you should notice by seeing the warning: "
        "\n\n ```\ntransformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. "
        f"Please use CLIPImageProcessor instead. warnings.warn(\n``` \n\n when running `pipe = DiffusionPipeline.from_pretrained({model_id})`. "
        "This PR makes sure that the warning does not show anymore by replacing `CLIPFeatureExtractor` with `CLIPImageProcessor`. This will certainly not change or break your checkpoint, but only "
        "make sure that everything is up to date. \n\n Best, the 🧨 Diffusers team."
    )
184
 
185
 
186
if __name__ == "__main__":
    # Command-line entry point: run ``convert`` for a single repository.
    # NOTE(review): the description below still talks about `safetensors`
    # conversion, while ``convert`` loads Flax weights - confirm the wording.
    DESCRIPTION = """
    Simple utility tool to convert automatically some weights on the hub to `safetensors` format.
    It is PyTorch exclusive for now.
    It works by downloading the weights (PT), converting them locally, and uploading them back
    as a PR on the hub.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "model_id",
        type=str,
        help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Create the PR even if it already exists or if the model was already converted.",
    )

    args = parser.parse_args()
    model_id = args.model_id
    api = HfApi()
    convert(api, model_id, force=args.force)