# export / onnx_export.py
# Felix Marty
# working version?
# be527a9
# raw history blame
# No virus
# 4.87 kB
from optimum.exporters.tasks import TasksManager
from optimum.exporters.onnx import OnnxConfigWithPast, export, validate_model_outputs
from tempfile import TemporaryDirectory
from transformers import AutoConfig, AutoTokenizer, is_torch_available
from pathlib import Path
import os
import shutil
import argparse
from typing import Optional
from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download, get_repo_discussions
from huggingface_hub.file_download import repo_folder_name
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Return the open pull request titled ``pr_title`` on ``model_id``, if any.

    Args:
        api: Hub client exposing ``get_repo_discussions``.
        model_id: Repository whose discussions are scanned.
        pr_title: Exact title the pull request must carry.

    Returns:
        The first discussion that is open, is a pull request and has the
        requested title. ``None`` when no discussion matches or when the
        discussions could not be listed.
    """
    try:
        # The listing may be produced lazily, in which case a failed request
        # only surfaces while iterating. Keeping the loop inside the ``try``
        # makes both an eager and a lazy failure mean "no previous PR".
        for discussion in api.get_repo_discussions(repo_id=model_id):
            if (
                discussion.status == "open"
                and discussion.is_pull_request
                and discussion.title == pr_title
            ):
                return discussion
    except Exception:
        # Best effort: an unreadable discussion list is treated as no PR.
        return None
    return None
def convert_onnx(model_id: str, task: str, folder: str):
    """Export ``model_id`` to ONNX inside ``folder`` and list the files to upload.

    The PyTorch model is loaded for ``task``, exported to ``model.onnx`` with
    the default opset of its ONNX config, and then handed to
    ``validate_model_outputs`` together with the config's tolerance.

    Args:
        model_id: Model identifier passed to the model and tokenizer loaders.
        task: Task name used to select the model class and the ONNX config.
        folder: Directory that receives ``model.onnx``.

    Returns:
        A list with one ``CommitOperationAdd`` per entry found in ``folder``
        once the export is done.

    Raises:
        ValueError: If a pad token id is needed and loading the tokenizer (or
            reading its pad token id) fails.
    """
    # Allocate the model
    model = TasksManager.get_model_from_task(task, model_id, framework="pt")
    model_type = model.config.model_type.replace("_", "-")
    model_name = getattr(model, "name", None)

    onnx_config_constructor = TasksManager.get_exporter_config_constructor(
        model_type, "onnx", task=task, model_name=model_name
    )
    onnx_config = onnx_config_constructor(model.config)

    needs_pad_token_id = (
        isinstance(onnx_config, OnnxConfigWithPast)
        and getattr(model.config, "pad_token_id", None) is None
        and task in ["sequence_classification"]
    )
    if needs_pad_token_id:
        # Fall back to the tokenizer's pad token when the config has none.
        try:
            tok = AutoTokenizer.from_pretrained(model_id)
            model.config.pad_token_id = tok.pad_token_id
        except Exception as err:
            # Chain the original failure so the root cause stays visible.
            raise ValueError(
                "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
            ) from err

    # Export with the default opset declared by the ONNX config.
    opset = onnx_config.DEFAULT_ONNX_OPSET
    output = Path(folder).joinpath("model.onnx")

    # Only the output names are needed afterwards, for validation.
    _, onnx_outputs = export(
        model,
        onnx_config,
        opset,
        output,
    )

    # The tolerance is either a single value or a mapping keyed by the task
    # name without its "-with-past" suffix.
    atol = onnx_config.ATOL_FOR_VALIDATION
    if isinstance(atol, dict):
        atol = atol[task.replace("-with-past", "")]

    validate_model_outputs(onnx_config, model, output, onnx_outputs, atol)
    print(f"All good, model saved at: {output}")

    operations = [
        CommitOperationAdd(path_in_repo=file_name, path_or_fileobj=os.path.join(folder, file_name))
        for file_name in os.listdir(folder)
    ]
    return operations
def convert(api: "HfApi", model_id: str, task:str, force: bool=False) -> Optional["CommitInfo"]:
    """Convert ``model_id`` to ONNX and open a pull request with the result.

    Args:
        api: Hub client used to read the repository and create the commit.
        model_id: Repository to convert.
        task: Task name forwarded to ``convert_onnx``.
        force: When true, skip the "already converted" and "already has an
            open PR" checks and always convert.

    Returns:
        The value returned by ``api.create_commit`` for the new pull request.

    Raises:
        Exception: If ``force`` is false and the repository already contains
            ``model.onnx`` or already has an open pull request with the same
            title.
    """
    pr_title = "Adding ONNX file of this model"
    info = api.model_info(model_id)
    filenames = {s.rfilename for s in info.siblings}

    # Run the cheap guards before allocating anything, and only query the
    # discussions when the answer can actually change the outcome.
    if not force:
        if "model.onnx" in filenames:
            raise Exception(f"Model {model_id} is already converted, skipping..")
        pr = previous_pr(api, model_id, pr_title)
        if pr is not None:
            url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
            raise Exception(f"Model {model_id} already has an open PR check out {url}")

    # The temporary directory removes itself, and the export folder in it,
    # on exit, including when the conversion or the upload raises.
    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)
        operations = convert_onnx(model_id, task, folder)
        new_pr = api.create_commit(
            repo_id=model_id,
            operations=operations,
            commit_message=pr_title,
            create_pr=True,
        )
    return new_pr
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run one conversion.
    DESCRIPTION = """
    Simple utility tool to convert automatically a model on the hub to onnx format.
    It is PyTorch exclusive for now.
    It works by downloading the weights (PT), converting them locally, and uploading them back
    as a PR on the hub.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    # Both the model and the task are mandatory: without them ``convert``
    # would receive ``None`` and fail later with an unrelated error.
    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        help="The task the model is performing",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Create the PR even if it already exists of if the model was already converted.",
    )
    args = parser.parse_args()
    api = HfApi()
    convert(api, args.model_id, task=args.task, force=args.force)