Felix Marty committed
Commit be527a9
1 Parent(s): f75daf5

working version?

Files changed (1):
  1. onnx_export.py (+58 -54)
onnx_export.py CHANGED
@@ -4,9 +4,7 @@ from optimum.exporters.onnx import OnnxConfigWithPast, export, validate_model_outputs
 
 from tempfile import TemporaryDirectory
 
-from transformers import AutoConfig, is_torch_available
-
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoTokenizer, is_torch_available
 
 from pathlib import Path
 
@@ -29,55 +27,54 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
     return discussion
 
 def convert_onnx(model_id: str, task: str, folder: str):
-    model_class = TasksManager.get_model_class_for_task(task)
-    config = AutoConfig.from_pretrained(model_id)
-    model = model_class.from_config(config)
-
-    device = "cpu" # ?
-
-    # Dynamic axes aren't supported for YOLO-like models. This means they cannot be exported to ONNX on CUDA devices.
-    # See: https://github.com/ultralytics/yolov5/pull/8378
-    if model.__class__.__name__.startswith("Yolos") and device != "cpu":
-        return
-
-    onnx_config_class_constructor = TasksManager.get_exporter_config_constructor(model_type=config.model_type, exporter="onnx", task=task, model_name=model_id)
-    onnx_config = onnx_config_class_constructor(model.config)
-
-    # We need to set this to some value to be able to test the outputs values for batch size > 1.
-    if (
-        isinstance(onnx_config, OnnxConfigWithPast)
-        and getattr(model.config, "pad_token_id", None) is None
-        and task == "sequence-classification"
-    ):
-        model.config.pad_token_id = 0
-
-    if is_torch_available():
-        from optimum.exporters.onnx.utils import TORCH_VERSION
-
-        if not onnx_config.is_torch_support_available:
-            print(
-                "Skipping due to incompatible PyTorch version. Minimum required is"
-                f" {onnx_config.MIN_TORCH_VERSION}, got: {TORCH_VERSION}"
-            )
-
-    onnx_inputs, onnx_outputs = export(
-        model, onnx_config, onnx_config.DEFAULT_ONNX_OPSET, Path(folder), device=device
-    )
-    atol = onnx_config.ATOL_FOR_VALIDATION
-    if isinstance(atol, dict):
-        atol = atol[task.replace("-with-past", "")]
-    validate_model_outputs(
-        onnx_config,
-        model,
-        Path(folder),
-        onnx_outputs,
-        atol,
-    )
-
-    # TODO: iterate in folder and add all
-    operations = [CommitOperationAdd(path_in_repo=local.split("/")[-1], path_or_fileobj=local) for local in local_filenames]
-
-    return operations
+
+    # Allocate the model
+    model = TasksManager.get_model_from_task(task, model_id, framework="pt")
+    model_type = model.config.model_type.replace("_", "-")
+    model_name = getattr(model, "name", None)
+
+    onnx_config_constructor = TasksManager.get_exporter_config_constructor(
+        model_type, "onnx", task=task, model_name=model_name
+    )
+    onnx_config = onnx_config_constructor(model.config)
+
+    needs_pad_token_id = (
+        isinstance(onnx_config, OnnxConfigWithPast)
+        and getattr(model.config, "pad_token_id", None) is None
+        and task in ["sequence_classification"]
+    )
+    if needs_pad_token_id:
+        # if args.pad_token_id is not None:
+        #     model.config.pad_token_id = args.pad_token_id
+        try:
+            tok = AutoTokenizer.from_pretrained(model_id)
+            model.config.pad_token_id = tok.pad_token_id
+        except Exception:
+            raise ValueError(
+                "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
+            )
+
+    # Ensure the requested opset is sufficient
+    opset = onnx_config.DEFAULT_ONNX_OPSET
+
+    output = Path(folder).joinpath("model.onnx")
+    onnx_inputs, onnx_outputs = export(
+        model,
+        onnx_config,
+        opset,
+        output,
+    )
+
+    atol = onnx_config.ATOL_FOR_VALIDATION
+    if isinstance(atol, dict):
+        atol = atol[task.replace("-with-past", "")]
+
+    validate_model_outputs(onnx_config, model, output, onnx_outputs, atol)
+    print(f"All good, model saved at: {output}")
+
+    operations = [CommitOperationAdd(path_in_repo=file_name, path_or_fileobj=os.path.join(folder, file_name)) for file_name in os.listdir(folder)]
+
+    return operations
 
 
 def convert(api: "HfApi", model_id: str, task:str, force: bool=False) -> Optional["CommitInfo"]:
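For reference, a minimal sketch of how the rewritten convert_onnx could be exercised on its own, outside of convert(). The model id "gpt2" and the task "causal-lm" are illustrative placeholders, not values taken from this commit:

    # Hypothetical standalone driver for convert_onnx.
    from tempfile import TemporaryDirectory

    with TemporaryDirectory() as folder:
        operations = convert_onnx("gpt2", "causal-lm", folder)
        # Each CommitOperationAdd stages one exported file (e.g. model.onnx).
        for op in operations:
            print(op.path_in_repo)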
@@ -98,7 +95,14 @@ def convert(api: "HfApi", model_id: str, task:str, force: bool=False) -> Optional["CommitInfo"]:
             new_pr = pr
             raise Exception(f"Model {model_id} already has an open PR check out {url}")
         else:
-            convert_onnx(model_id, task, folder)
+            operations = convert_onnx(model_id, task, folder)
+
+            new_pr = api.create_commit(
+                repo_id=model_id,
+                operations=operations,
+                commit_message=pr_title,
+                create_pr=True,
+            )
     finally:
         shutil.rmtree(folder)
     return new_pr
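With create_pr=True, HfApi.create_commit pushes the staged operations to a fresh draft pull request rather than to the main branch, and returns a CommitInfo describing it. A minimal sketch of consuming that return value, assuming convert() is importable from this script (the model id and task are placeholders):

    from huggingface_hub import HfApi

    api = HfApi()
    new_pr = convert(api, "gpt2", "causal-lm")
    if new_pr is not None:
        # CommitInfo exposes the URL of the PR opened by create_pr=True.
        print(new_pr.pr_url)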
@@ -113,12 +117,12 @@ if __name__ == "__main__":
     """
     parser = argparse.ArgumentParser(description=DESCRIPTION)
     parser.add_argument(
-        "model_id",
+        "--model_id",
         type=str,
         help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
     )
     parser.add_argument(
-        "task",
+        "--task",
         type=str,
         help="The task the model is performing",
     )
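Note that turning model_id and task from positional arguments into --model_id and --task flags changes the invocation from `python onnx_export.py gpt2 sequence-classification` to `python onnx_export.py --model_id gpt2 --task sequence-classification`; since neither flag is declared with required=True, argparse now defaults both to None when they are omitted.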
 