File size: 3,899 Bytes
f75daf5
 
04c8db1
 
f75daf5
c43dedc
 
 
 
f75daf5
9005b4d
f75daf5
025ff03
7567dc4
04c8db1
f75daf5
 
 
 
 
 
04c8db1
 
 
 
 
f75daf5
 
04c8db1
9005b4d
 
 
 
12e01c3
9005b4d
be527a9
04c8db1
 
 
 
 
 
 
 
 
3166d00
04c8db1
 
 
 
 
 
3166d00
04c8db1
 
 
 
 
 
 
 
be527a9
f75daf5
 
04c8db1
c43dedc
 
 
 
 
04c8db1
f75daf5
 
 
 
7567dc4
 
f75daf5
 
 
 
 
 
 
025ff03
f75daf5
 
 
04c8db1
46363ea
04c8db1
f75daf5
9005b4d
be527a9
7567dc4
171b6b3
7567dc4
171b6b3
7567dc4
171b6b3
7567dc4
be527a9
 
 
 
7567dc4
be527a9
 
f75daf5
 
7804c1f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import shutil
from tempfile import TemporaryDirectory
from typing import List, Optional, Tuple

from huggingface_hub import (
    CommitOperationAdd,
    HfApi,
)
from huggingface_hub.file_download import repo_folder_name
from optimum.exporters.onnx import main_export

SPACES_URL = "https://huggingface.co/spaces/onnx/export"


def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Return the open pull request titled *pr_title* on *model_id*, if any.

    Args:
        api: ``HfApi`` client used to list repository discussions.
        model_id: Repository id whose discussions are searched.
        pr_title: Exact title the pull request must carry.

    Returns:
        The first open pull request whose title matches exactly, or
        ``None`` when there is no match or the discussions cannot be
        fetched (e.g. the repository does not exist or has discussions
        disabled).
    """
    try:
        discussions = api.get_repo_discussions(repo_id=model_id)
    except Exception:
        # Best-effort lookup: any failure to list discussions is treated
        # as "no previous PR" rather than aborting the conversion flow.
        return None
    for discussion in discussions:
        if (
            discussion.status == "open"
            and discussion.is_pull_request
            and discussion.title == pr_title
        ):
            return discussion
    return None  # explicit: no open PR with this title was found


def export_and_git_add(model_id: str, task: str, folder: str, opset: int) -> List:
    """Export *model_id* to ONNX into *folder* and build commit operations.

    Args:
        model_id: Repository id of the model to export.
        task: Task name forwarded to the ONNX exporter.
        folder: Local directory the exporter writes into.
        opset: ONNX opset version to target.

    Returns:
        A list of ``CommitOperationAdd`` operations, one per entry in
        *folder*. A single exported file is placed at the repository
        root; multiple files are grouped under an ``onnx/`` subfolder.
    """
    main_export(
        model_name_or_path=model_id,
        output=folder,
        task=task,
        opset=opset,
    )

    # List the directory once and reuse it, so the file count and the
    # commit operations are guaranteed to agree.
    entries = os.listdir(folder)

    # Hidden files and non-files are excluded from the count only; the
    # operations below intentionally cover every directory entry.
    n_files = sum(
        1
        for name in entries
        if os.path.isfile(os.path.join(folder, name)) and not name.startswith(".")
    )

    # Hub repo paths always use forward slashes, regardless of the local
    # OS separator — do not build path_in_repo with os.path.join.
    prefix = "" if n_files == 1 else "onnx/"
    return [
        CommitOperationAdd(
            path_in_repo=f"{prefix}{name}",
            path_or_fileobj=os.path.join(folder, name),
        )
        for name in entries
    ]


def convert(
    api: "HfApi",
    model_id: str,
    task: str,
    force: bool = False,
    opset: Optional[int] = None,
) -> Tuple[str, "CommitInfo"]:
    """Export *model_id* to ONNX and open a pull request adding the files.

    Args:
        api: Authenticated ``HfApi`` client used for all hub operations.
        model_id: Repository id of the model to convert.
        task: Task name forwarded to the ONNX exporter.
        force: When True, export even if the repo already contains
            ``model.onnx`` or already has an open conversion PR.
        opset: ONNX opset version to target; ``None`` lets the exporter
            pick its default.

    Returns:
        A ``("0", commit_info)`` tuple where ``commit_info`` is the
        result of ``create_commit`` for the newly opened PR. (The first
        element is the string ``"0"``, kept for backward compatibility.)

    Raises:
        Exception: If the model is already converted, or already has an
            open conversion PR, and *force* is not set.
    """
    pr_title = "Adding ONNX file of this model"
    info = api.model_info(model_id)
    filenames = set(s.rfilename for s in info.siblings)

    requesting_user = api.whoami()["name"]

    with TemporaryDirectory() as d:
        # Namespace the export folder by repo id to avoid collisions.
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)
        new_pr = None
        try:
            pr = previous_pr(api, model_id, pr_title)
            if "model.onnx" in filenames and not force:
                raise Exception(f"Model {model_id} is already converted, skipping the export.")
            elif pr is not None and not force:
                url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
                new_pr = pr
                raise Exception(
                    f"Model {model_id} already has an open PR check out [{url}]({url})"
                )
            else:
                operations = export_and_git_add(model_id, task, folder, opset)

                commit_description = f"""
Beep boop I am the [ONNX export bot 🤖🏎️]({SPACES_URL}). On behalf of [{requesting_user}](https://huggingface.co/{requesting_user}), I would like to add to this repository the model converted to ONNX.

What is ONNX? It stands for "Open Neural Network Exchange", and is the most commonly used open standard for machine learning interoperability. You can find out more at [onnx.ai](https://onnx.ai/)!

The exported ONNX model can be then be consumed by various backends as TensorRT or TVM, or simply be used in a few lines with 🤗 Optimum through ONNX Runtime, check out how [here](https://huggingface.co/docs/optimum/main/en/onnxruntime/usage_guides/models)!
                """
                new_pr = api.create_commit(
                    repo_id=model_id,
                    operations=operations,
                    commit_message=pr_title,
                    commit_description=commit_description,
                    create_pr=True,
                )
        finally:
            # TemporaryDirectory removes everything on exit anyway; delete
            # the (potentially large) export folder eagerly to free disk
            # space as soon as possible, even when an exception is raised.
            shutil.rmtree(folder)
        return "0", new_pr