Felix Marty commited on
Commit
04c8db1
1 Parent(s): ff649d1
Files changed (2) hide show
  1. app.py +21 -10
  2. onnx_export.py +52 -26
app.py CHANGED
@@ -1,12 +1,12 @@
1
  import csv
2
- import datetime
3
  import os
 
4
  from typing import Optional
5
- import gradio as gr
6
 
7
- from onnx_export import convert
8
  from huggingface_hub import HfApi, Repository
9
 
 
10
 
11
  DATASET_REPO_URL = "https://huggingface.co/datasets/optimum/exporters"
12
  DATA_FILENAME = "data.csv"
@@ -20,6 +20,7 @@ repo: Optional[Repository] = None
20
  if HF_TOKEN:
21
  repo = Repository(local_dir=DATADIR, clone_from=DATASET_REPO_URL, token=HF_TOKEN)
22
 
 
23
  def onnx_export(token: str, model_id: str, task: str) -> str:
24
  if token == "" or model_id == "":
25
  return """
@@ -33,7 +34,7 @@ def onnx_export(token: str, model_id: str, task: str) -> str:
33
  error, commit_info = convert(api=api, model_id=model_id, task=task)
34
  if error != "0":
35
  return error
36
-
37
  print("[commit_info]", commit_info)
38
 
39
  # save in a private dataset
@@ -57,6 +58,7 @@ def onnx_export(token: str, model_id: str, task: str) -> str:
57
  except Exception as e:
58
  return f"#### Error: {e}"
59
 
 
60
  TTILE_IMAGE = """
61
  <div
62
  style="
@@ -111,14 +113,23 @@ with gr.Blocks() as demo:
111
 
112
  with gr.Column():
113
  input_token = gr.Textbox(max_lines=1, label="Hugging Face token")
114
- input_model = gr.Textbox(max_lines=1, label="Model name", placeholder="textattack/distilbert-base-cased-CoLA")
115
- input_task = gr.Textbox(value="auto", max_lines=1, label="Task (can be left to \"auto\", will be automatically inferred)")
 
 
 
 
 
 
 
 
116
 
117
  btn = gr.Button("Convert to ONNX")
118
  output = gr.Markdown(label="Output")
119
-
120
-
121
- btn.click(fn=onnx_export, inputs=[input_token, input_model, input_task], outputs=output)
 
122
 
123
  """
124
  demo = gr.Interface(
@@ -136,4 +147,4 @@ demo = gr.Interface(
136
  )
137
  """
138
 
139
- demo.launch()
 
1
  import csv
 
2
  import os
3
+ from datetime import datetime
4
  from typing import Optional
 
5
 
6
+ import gradio as gr
7
  from huggingface_hub import HfApi, Repository
8
 
9
+ from onnx_export import convert
10
 
11
  DATASET_REPO_URL = "https://huggingface.co/datasets/optimum/exporters"
12
  DATA_FILENAME = "data.csv"
 
20
  if HF_TOKEN:
21
  repo = Repository(local_dir=DATADIR, clone_from=DATASET_REPO_URL, token=HF_TOKEN)
22
 
23
+
24
  def onnx_export(token: str, model_id: str, task: str) -> str:
25
  if token == "" or model_id == "":
26
  return """
 
34
  error, commit_info = convert(api=api, model_id=model_id, task=task)
35
  if error != "0":
36
  return error
37
+
38
  print("[commit_info]", commit_info)
39
 
40
  # save in a private dataset
 
58
  except Exception as e:
59
  return f"#### Error: {e}"
60
 
61
+
62
  TTILE_IMAGE = """
63
  <div
64
  style="
 
113
 
114
  with gr.Column():
115
  input_token = gr.Textbox(max_lines=1, label="Hugging Face token")
116
+ input_model = gr.Textbox(
117
+ max_lines=1,
118
+ label="Model name",
119
+ placeholder="textattack/distilbert-base-cased-CoLA",
120
+ )
121
+ input_task = gr.Textbox(
122
+ value="auto",
123
+ max_lines=1,
124
+ label='Task (can be left to "auto", will be automatically inferred)',
125
+ )
126
 
127
  btn = gr.Button("Convert to ONNX")
128
  output = gr.Markdown(label="Output")
129
+
130
+ btn.click(
131
+ fn=onnx_export, inputs=[input_token, input_model, input_task], outputs=output
132
+ )
133
 
134
  """
135
  demo = gr.Interface(
 
147
  )
148
  """
149
 
150
+ demo.launch()
onnx_export.py CHANGED
@@ -1,33 +1,35 @@
1
- from optimum.exporters.tasks import TasksManager
2
-
3
- from optimum.exporters.onnx import OnnxConfigWithPast, export, validate_model_outputs
4
-
5
- from tempfile import TemporaryDirectory
6
-
7
- from transformers import AutoConfig, AutoTokenizer, is_torch_available
8
-
9
- from pathlib import Path
10
-
11
  import os
12
  import shutil
13
- import argparse
14
-
15
- from typing import Optional, Tuple, List
16
 
17
- from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download, get_repo_discussions
 
18
  from huggingface_hub.file_download import repo_folder_name
 
 
 
 
19
 
20
  SPACES_URL = "https://huggingface.co/spaces/optimum/exporters"
21
 
 
22
  def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
23
  try:
24
  discussions = api.get_repo_discussions(repo_id=model_id)
25
  except Exception:
26
  return None
27
  for discussion in discussions:
28
- if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
 
 
 
 
29
  return discussion
30
 
 
31
  def convert_onnx(model_id: str, task: str, folder: str) -> List:
32
 
33
  # Allocate the model
@@ -46,7 +48,7 @@ def convert_onnx(model_id: str, task: str, folder: str) -> List:
46
  and task in ["sequence_classification"]
47
  )
48
  if needs_pad_token_id:
49
- #if args.pad_token_id is not None:
50
  # model.config.pad_token_id = args.pad_token_id
51
  try:
52
  tok = AutoTokenizer.from_pretrained(model_id)
@@ -76,18 +78,37 @@ def convert_onnx(model_id: str, task: str, folder: str) -> List:
76
  print(f"All good, model saved at: {output}")
77
  except ValueError:
78
  print(f"An error occured, but the model was saved at: {output.as_posix()}")
79
-
80
- n_files = len([name for name in os.listdir(folder) if os.path.isfile(os.path.join(folder, name)) and not name.startswith(".")])
81
-
 
 
 
 
 
 
82
  if n_files == 1:
83
- operations = [CommitOperationAdd(path_in_repo=file_name, path_or_fileobj=os.path.join(folder, file_name)) for file_name in os.listdir(folder)]
 
 
 
 
 
84
  else:
85
- operations = [CommitOperationAdd(path_in_repo=os.path.join("onnx", file_name), path_or_fileobj=os.path.join(folder, file_name)) for file_name in os.listdir(folder)]
86
-
 
 
 
 
 
 
87
  return operations
88
 
89
 
90
- def convert(api: "HfApi", model_id: str, task: str, force: bool = False) -> Tuple[int, "CommitInfo"]:
 
 
91
  pr_title = "Adding ONNX file of this model"
92
  info = api.model_info(model_id)
93
  filenames = set(s.rfilename for s in info.siblings)
@@ -98,7 +119,10 @@ def convert(api: "HfApi", model_id: str, task: str, force: bool = False) -> Tupl
98
  try:
99
  task = TasksManager.infer_task_from_model(model_id)
100
  except Exception as e:
101
- return f"### Error: {e}. Please pass explicitely the task as it could not be infered.", None
 
 
 
102
 
103
  with TemporaryDirectory() as d:
104
  folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
@@ -111,7 +135,9 @@ def convert(api: "HfApi", model_id: str, task: str, force: bool = False) -> Tupl
111
  elif pr is not None and not force:
112
  url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
113
  new_pr = pr
114
- raise Exception(f"Model {model_id} already has an open PR check out {url}")
 
 
115
  else:
116
  operations = convert_onnx(model_id, task, folder)
117
 
@@ -159,4 +185,4 @@ if __name__ == "__main__":
159
  )
160
  args = parser.parse_args()
161
  api = HfApi()
162
- convert(api, args.model_id, task=args.task, force=args.force)
 
1
+ import argparse
 
 
 
 
 
 
 
 
 
2
  import os
3
  import shutil
4
+ from pathlib import Path
5
+ from tempfile import TemporaryDirectory
6
+ from typing import List, Optional, Tuple
7
 
8
+ from huggingface_hub import (CommitOperationAdd, HfApi, get_repo_discussions,
9
+ hf_hub_download)
10
  from huggingface_hub.file_download import repo_folder_name
11
+ from optimum.exporters.onnx import (OnnxConfigWithPast, export,
12
+ validate_model_outputs)
13
+ from optimum.exporters.tasks import TasksManager
14
+ from transformers import AutoConfig, AutoTokenizer, is_torch_available
15
 
16
  SPACES_URL = "https://huggingface.co/spaces/optimum/exporters"
17
 
18
+
19
  def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
20
  try:
21
  discussions = api.get_repo_discussions(repo_id=model_id)
22
  except Exception:
23
  return None
24
  for discussion in discussions:
25
+ if (
26
+ discussion.status == "open"
27
+ and discussion.is_pull_request
28
+ and discussion.title == pr_title
29
+ ):
30
  return discussion
31
 
32
+
33
  def convert_onnx(model_id: str, task: str, folder: str) -> List:
34
 
35
  # Allocate the model
 
48
  and task in ["sequence_classification"]
49
  )
50
  if needs_pad_token_id:
51
+ # if args.pad_token_id is not None:
52
  # model.config.pad_token_id = args.pad_token_id
53
  try:
54
  tok = AutoTokenizer.from_pretrained(model_id)
 
78
  print(f"All good, model saved at: {output}")
79
  except ValueError:
80
  print(f"An error occured, but the model was saved at: {output.as_posix()}")
81
+
82
+ n_files = len(
83
+ [
84
+ name
85
+ for name in os.listdir(folder)
86
+ if os.path.isfile(os.path.join(folder, name)) and not name.startswith(".")
87
+ ]
88
+ )
89
+
90
  if n_files == 1:
91
+ operations = [
92
+ CommitOperationAdd(
93
+ path_in_repo=file_name, path_or_fileobj=os.path.join(folder, file_name)
94
+ )
95
+ for file_name in os.listdir(folder)
96
+ ]
97
  else:
98
+ operations = [
99
+ CommitOperationAdd(
100
+ path_in_repo=os.path.join("onnx", file_name),
101
+ path_or_fileobj=os.path.join(folder, file_name),
102
+ )
103
+ for file_name in os.listdir(folder)
104
+ ]
105
+
106
  return operations
107
 
108
 
109
+ def convert(
110
+ api: "HfApi", model_id: str, task: str, force: bool = False
111
+ ) -> Tuple[int, "CommitInfo"]:
112
  pr_title = "Adding ONNX file of this model"
113
  info = api.model_info(model_id)
114
  filenames = set(s.rfilename for s in info.siblings)
 
119
  try:
120
  task = TasksManager.infer_task_from_model(model_id)
121
  except Exception as e:
122
+ return (
123
+ f"### Error: {e}. Please pass explicitely the task as it could not be infered.",
124
+ None,
125
+ )
126
 
127
  with TemporaryDirectory() as d:
128
  folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
 
135
  elif pr is not None and not force:
136
  url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
137
  new_pr = pr
138
+ raise Exception(
139
+ f"Model {model_id} already has an open PR check out {url}"
140
+ )
141
  else:
142
  operations = convert_onnx(model_id, task, folder)
143
 
 
185
  )
186
  args = parser.parse_args()
187
  api = HfApi()
188
+ convert(api, args.model_id, task=args.task, force=args.force)