File size: 5,597 Bytes
1e7dfc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import csv
import os
from datetime import datetime
from typing import Optional, Union

import gradio as gr
from huggingface_hub import HfApi, Repository

from onnx_export import convert

from apscheduler.schedulers.background import BackgroundScheduler

# Hub dataset in which successful exports are recorded, and the path of the
# CSV file inside that dataset.
DATASET_REPO_URL = "https://huggingface.co/datasets/optimum/exporters"
DATA_FILENAME = "data.csv"
DATA_FILE = os.path.join("data", DATA_FILENAME)

# Local directory that holds the clone of the dataset repository.
DATADIR = "exporters_data"

# Token read from the environment; None when HF_WRITE_TOKEN is not set.
HF_TOKEN = os.getenv("HF_WRITE_TOKEN")

# Recording exports is currently switched off: `repo` stays None unless the
# clone below is re-enabled.
repo: Optional[Repository] = None
# if HF_TOKEN:
#     repo = Repository(local_dir=DATADIR, clone_from=DATASET_REPO_URL, token=HF_TOKEN)


def _record_export(model_id: str, pr_url: str) -> None:
    """Append one export record to the tracking dataset and push it.

    Does nothing when the module-level ``repo`` is None. Any failure is
    printed and swallowed on purpose: this runs after the ONNX pull request
    has been created, so a bookkeeping problem must not be presented to the
    user as a failed export.

    Args:
        model_id: Identifier of the exported model.
        pr_url: URL of the pull request that holds the export.
    """
    if repo is None:
        return
    try:
        repo.git_pull(rebase=True)
        # newline="" lets the csv module control line endings itself (this is
        # what its documentation asks for); utf-8 keeps the file
        # independent of the platform's default encoding.
        with open(
            os.path.join(DATADIR, DATA_FILE), "a", newline="", encoding="utf-8"
        ) as csvfile:
            writer = csv.DictWriter(
                csvfile, fieldnames=["model_id", "pr_url", "time"]
            )
            writer.writerow(
                {
                    "model_id": model_id,
                    "pr_url": pr_url,
                    "time": str(datetime.now()),
                }
            )
        commit_url = repo.push_to_hub()
        print("[dataset]", commit_url)
    except Exception as e:
        print("[dataset] failed to record export:", e)


def onnx_export(token: str, model_id: str, task: str, opset: Union[int, str]) -> str:
    """Export a Hub model to ONNX and return a Markdown status message.

    Args:
        token: Hugging Face token used to build the ``HfApi`` client.
        model_id: Identifier of the model to export.
        task: Task name forwarded unchanged to ``convert``.
        opset: ONNX opset as an int or a string; an empty/blank value (or
            None) means "no explicit opset" and is forwarded as None.

    Returns:
        A Markdown string: an "Invalid input" notice when the token or the
        model id is blank, an error notice when the opset is not an integer,
        the error string returned by ``convert`` when it differs from "0",
        a success message with the pull request URL otherwise, or
        ``"#### Error: ..."`` when an exception is raised along the way.
    """
    # Surrounding whitespace (typical of copy/paste) is never part of a token
    # or a model id, and a blank value must be treated as missing.
    token = (token or "").strip()
    model_id = (model_id or "").strip()
    if not token or not model_id:
        return """
        ### Invalid input 🐞

        Please fill a token and model name.
        """

    # Validate the opset up front so a typo yields a clear message instead of
    # a raw int() parsing error.
    if isinstance(opset, str):
        opset = opset.strip()
    if opset is None or opset == "":
        opset = None
    else:
        try:
            opset = int(opset)
        except (TypeError, ValueError):
            return (
                f"#### Error: invalid ONNX opset {opset!r}, "
                "expected an integer or an empty value."
            )

    try:
        api = HfApi(token=token)

        # `convert` reports success with the string "0"; anything else is an
        # error message meant to be shown as-is.
        error, commit_info = convert(api=api, model_id=model_id, task=task, opset=opset)
        if error != "0":
            return error

        print("[commit_info]", commit_info)

        # save in a private dataset (best effort, never fails the export)
        _record_export(model_id, commit_info.pr_url)

        # The revision contains "/" characters, which must be percent-encoded
        # to form a single path segment of the tree URL.
        pr_revision = commit_info.pr_revision.replace("/", "%2F")

        return f"#### This model was successfully exported and a PR was open using your token, here: [{commit_info.pr_url}]({commit_info.pr_url}). If you would like to use the exported model without waiting for the PR to be approved, head to https://huggingface.co/{model_id}/tree/{pr_revision}"
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        return f"#### Error: {e}"


# HTML snippet for the banner image: a block-level <div> that is half the
# page width and centered through automatic left/right margins.
# NOTE: the constant name is spelled `TTILE_IMAGE` (sic); the UI code refers
# to it under this exact spelling, so it is kept as is.
TTILE_IMAGE = """
<div
    style="
        display: block;
        margin-left: auto;
        margin-right: auto;
        width: 50%;
    "
>
<img src="https://i.ibb.co/m5VnjSsQ/Blue-and-White-Illustrative-Profile-Twitter-Header.png"/>
</div>
"""

# HTML snippet for the page heading: an inline-flex container wrapping a
# single bold <h1> with the app title.
TITLE = """
<div
    style="
        display: inline-flex;
        align-items: center;
        text-align: center;
        max-width: 1400px;
        gap: 0.8rem;
        font-size: 2.2rem;
    "
>
<h1 style="font-weight: 900; margin-bottom: 10px; margin-top: 10px;">
    Export transformers model to ONNX with HF Optimum exporters.
</h1>
</div>
"""

# Markdown text explaining what the app does and how to use it.
# The tokens URL is written as an explicit Markdown link because, for some
# reason, https://huggingface.co/settings/tokens is not showing as a link by
# default.
DESCRIPTION = """
This Space enables automatic export of Hugging Face transformers PyTorch models to [ONNX](https://onnx.ai/). It creates a pull request on the target model repository, allowing model owners to review and merge the ONNX export, making their models accessible across a wide range of devices and platforms.

Once exported, the model can be seamlessly integrated with [HF Optimum](https://huggingface.co/docs/optimum/), maintaining compatibility with the transformers API. For detailed implementation, check out [this comprehensive guide](https://huggingface.co/docs/optimum/main/en/onnxruntime/usage_guides/models).

Quick Start Guide:
1. Obtain a read-access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens) (read access is sufficient for PR creation)
2. Enter a model ID from the Hub (e.g., [textattack/distilbert-base-cased-CoLA](https://huggingface.co/textattack/distilbert-base-cased-CoLA))
3. Click "Export to ONNX"
4. Done! You'll receive feedback on the export status and, if successful, the URL of the created pull request

Important Note: For models exceeding 2 GB, the ONNX export will be saved in an `onnx/` subfolder. When loading such models with Optimum, remember to include the `subfolder="onnx"` parameter."""

# Build the Gradio page: banner and title on top, then a row with the usage
# instructions on the left and the export form on the right.
with gr.Blocks() as demo:
    gr.HTML(TTILE_IMAGE)
    gr.HTML(TITLE)

    with gr.Row():
        with gr.Column(scale=50):
            gr.Markdown(DESCRIPTION)

        with gr.Column(scale=50):
            input_token = gr.Textbox(
                max_lines=1,
                label="Hugging Face token",
                # The token is a credential: mask it so it is not readable on
                # screen (screen sharing, screenshots, shoulder surfing).
                type="password",
            )
            input_model = gr.Textbox(
                max_lines=1,
                label="Model name",
                placeholder="textattack/distilbert-base-cased-CoLA",
            )
            input_task = gr.Textbox(
                value="auto",
                max_lines=1,
                label='Task (can be left to "auto", will be automatically inferred)',
            )
            onnx_opset = gr.Textbox(
                placeholder="for example 14, can be left blank",
                max_lines=1,
                label="ONNX opset (optional, can be left blank)",
            )

            btn = gr.Button("Export to ONNX")
            output = gr.Markdown(label="Output")

    # The four form values are passed positionally to onnx_export, and the
    # string it returns is rendered in the Markdown output.
    btn.click(
        fn=onnx_export,
        inputs=[input_token, input_model, input_task, onnx_opset],
        outputs=output,
    )

def restart_space():
    """Request a factory reboot of the "onnx/export" Space using HF_TOKEN."""
    hub_client = HfApi()
    hub_client.restart_space(
        repo_id="onnx/export",
        token=HF_TOKEN,
        factory_reboot=True,
    )

# Call restart_space on a fixed interval of 21600 seconds (6 hours), then
# start serving the Gradio app.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, trigger="interval", seconds=6 * 60 * 60)
scheduler.start()

demo.launch()