Ubuntu committed on
Commit
b7fe3c7
•
1 Parent(s): 968cf44

add revision option

Browse files
Files changed (2) hide show
  1. app.py +4 -2
  2. convert.py +15 -9
app.py CHANGED
@@ -19,7 +19,7 @@ if HF_TOKEN:
19
  repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, token=HF_TOKEN)
20
 
21
 
22
- def run(token: str, model_id: str) -> str:
23
  if token == "" or model_id == "":
24
  return """
25
  ### Invalid input 🐞
@@ -31,7 +31,7 @@ def run(token: str, model_id: str) -> str:
31
  is_private = api.model_info(repo_id=model_id).private
32
  print("is_private", is_private)
33
 
34
- commit_info = convert(api=api, model_id=model_id, force=True)
35
  print("[commit_info]", commit_info)
36
 
37
  # save in a (public) dataset:
@@ -72,6 +72,7 @@ The steps are the following:
72
 
73
  - Paste a read-access token from hf.co/settings/tokens. Read access is enough given that we will open a PR against the source repo.
74
  - Input a model id from the Hub
 
75
  - Click "Submit"
76
  - That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the opened PR 🔥
77
 
@@ -86,6 +87,7 @@ demo = gr.Interface(
86
  inputs=[
87
  gr.Text(max_lines=1, label="your_hf_token"),
88
  gr.Text(max_lines=1, label="model_id"),
 
89
  ],
90
  outputs=[gr.Markdown(label="output")],
91
  fn=run,
 
19
  repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, token=HF_TOKEN)
20
 
21
 
22
+ def run(token: str, model_id: str, revision: str = "main") -> str:
23
  if token == "" or model_id == "":
24
  return """
25
  ### Invalid input 🐞
 
31
  is_private = api.model_info(repo_id=model_id).private
32
  print("is_private", is_private)
33
 
34
+ commit_info = convert(api=api, model_id=model_id, revision=revision, force=True)
35
  print("[commit_info]", commit_info)
36
 
37
  # save in a (public) dataset:
 
72
 
73
  - Paste a read-access token from hf.co/settings/tokens. Read access is enough given that we will open a PR against the source repo.
74
  - Input a model id from the Hub
75
+ - Optionally select a revision like fp16
76
  - Click "Submit"
77
  - That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the opened PR 🔥
78
 
 
87
  inputs=[
88
  gr.Text(max_lines=1, label="your_hf_token"),
89
  gr.Text(max_lines=1, label="model_id"),
90
+ gr.Text(max_lines=1, label="revision", default="main"),
91
  ],
92
  outputs=[gr.Markdown(label="output")],
93
  fn=run,
convert.py CHANGED
@@ -51,15 +51,15 @@ def rename(pt_filename: str) -> str:
51
  return local
52
 
53
 
54
- def convert_multi(model_id: str, folder: str) -> List["CommitOperationAdd"]:
55
- filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json")
56
  with open(filename, "r") as f:
57
  data = json.load(f)
58
 
59
  filenames = set(data["weight_map"].values())
60
  local_filenames = []
61
  for filename in filenames:
62
- pt_filename = hf_hub_download(repo_id=model_id, filename=filename)
63
 
64
  sf_filename = rename(pt_filename)
65
  sf_filename = os.path.join(folder, sf_filename)
@@ -143,14 +143,14 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
143
  return discussion
144
 
145
 
146
- def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["CommitOperationAdd"]:
147
  operations = []
148
 
149
  extensions = set([".bin", ".ckpt"])
150
  for filename in filenames:
151
  prefix, ext = os.path.splitext(filename)
152
  if ext in extensions:
153
- pt_filename = hf_hub_download(model_id, filename=filename)
154
  dirname, raw_filename = os.path.split(filename)
155
  if raw_filename == "pytorch_model.bin":
156
  # XXX: This is a special case to handle `transformers` and the
@@ -164,9 +164,9 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["Co
164
  return operations
165
 
166
 
167
- def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
168
  pr_title = "Adding `safetensors` variant of this model"
169
- info = api.model_info(model_id)
170
 
171
  def is_valid_filename(filename):
172
  return len(filename.split("/")) > 1 or filename in ["pytorch_model.bin", "diffusion_pytorch_model.bin"]
@@ -190,7 +190,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi
190
  raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
191
  else:
192
  print("Convert generic")
193
- operations = convert_generic(model_id, folder, filenames)
194
 
195
  if operations:
196
  new_pr = api.create_commit(
@@ -225,7 +225,13 @@ if __name__ == "__main__":
225
  action="store_true",
226
  help="Create the PR even if it already exists of if the model was already converted.",
227
  )
 
 
 
 
 
228
  args = parser.parse_args()
229
  model_id = args.model_id
 
230
  api = HfApi()
231
- convert(api, model_id, force=args.force)
 
51
  return local
52
 
53
 
54
+ def convert_multi(model_id: str, folder: str, revision: str = "main") -> List["CommitOperationAdd"]:
55
+ filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", revision=revision)
56
  with open(filename, "r") as f:
57
  data = json.load(f)
58
 
59
  filenames = set(data["weight_map"].values())
60
  local_filenames = []
61
  for filename in filenames:
62
+ pt_filename = hf_hub_download(repo_id=model_id, filename=filename, revision=revision)
63
 
64
  sf_filename = rename(pt_filename)
65
  sf_filename = os.path.join(folder, sf_filename)
 
143
  return discussion
144
 
145
 
146
+ def convert_generic(model_id: str, folder: str, filenames: Set[str], revision: str = "main") -> List["CommitOperationAdd"]:
147
  operations = []
148
 
149
  extensions = set([".bin", ".ckpt"])
150
  for filename in filenames:
151
  prefix, ext = os.path.splitext(filename)
152
  if ext in extensions:
153
+ pt_filename = hf_hub_download(model_id, filename=filename, revision=revision)
154
  dirname, raw_filename = os.path.split(filename)
155
  if raw_filename == "pytorch_model.bin":
156
  # XXX: This is a special case to handle `transformers` and the
 
164
  return operations
165
 
166
 
167
+ def convert(api: "HfApi", model_id: str, force: bool = False, revision: str = "main") -> Optional["CommitInfo"]:
168
  pr_title = "Adding `safetensors` variant of this model"
169
+ info = api.model_info(model_id, revision=revision)
170
 
171
  def is_valid_filename(filename):
172
  return len(filename.split("/")) > 1 or filename in ["pytorch_model.bin", "diffusion_pytorch_model.bin"]
 
190
  raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
191
  else:
192
  print("Convert generic")
193
+ operations = convert_generic(model_id, folder, filenames, revision=revision)
194
 
195
  if operations:
196
  new_pr = api.create_commit(
 
225
  action="store_true",
226
  help="Create the PR even if it already exists of if the model was already converted.",
227
  )
228
+ parser.add_argument(
229
+ "revision",
230
+ default="main",
231
+ help="Branch to convert. E.g. main, fp16, bf16"
232
+ )
233
  args = parser.parse_args()
234
  model_id = args.model_id
235
+ revision = args.revision
236
  api = HfApi()
237
+ convert(api, model_id, force=args.force, revision=revision)