File size: 3,151 Bytes
cbdc5ec
 
88775d1
cbdc5ec
 
66662af
cbdc5ec
 
88775d1
 
cbdc5ec
 
 
88775d1
cbdc5ec
88775d1
cbdc5ec
dfc18f3
 
cbdc5ec
66662af
 
cbdc5ec
 
 
 
 
 
 
66662af
cbdc5ec
 
 
 
1041ffd
cbdc5ec
 
 
0b0a7a0
 
cbdc5ec
 
 
 
 
f227033
cbdc5ec
 
 
 
 
 
 
 
 
 
1041ffd
cbdc5ec
 
 
 
 
 
1041ffd
 
 
 
cbdc5ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
705ab71
cbdc5ec
aab64ea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import csv
from datetime import datetime
import os
from typing import Optional
import gradio as gr

from convert import convert
from huggingface_hub import HfApi, Repository


# Hub dataset that the (currently disabled) logging in `run` appends to.
DATASET_REPO_URL = "https://huggingface.co/datasets/safetensors/conversions"
DATA_FILENAME = "data.csv"
# CSV path inside the local clone of the dataset, which lives in ./data.
DATA_FILE = os.path.join("data", DATA_FILENAME)

# Token used to clone/push the dataset; None when HF_TOKEN is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# Local clone of the dataset repository, or None when it is unavailable.
# TODO: cloning is switched off by the `False and` guard below — presumably
# for the same LFS bug that disables dataset logging in `run`; confirm.
repo: Optional[Repository] = None
if False and HF_TOKEN:
    repo = Repository(
        local_dir="data",
        clone_from=DATASET_REPO_URL,
        token=HF_TOKEN,
    )


def run(token: str, model_id: str) -> str:
    """Convert a Hub model to safetensors and open a PR with the result.

    This is the Gradio callback behind the interface: it never raises and
    always returns Markdown for the output panel.

    Args:
        token: Hugging Face access token supplied by the user. It is passed to
            ``HfApi``, so it is the identity used to open the PR.
        model_id: Repository id of the model on the Hub.

    Returns:
        Markdown describing one of three outcomes: invalid (empty) input,
        success (with the PR URL and any per-file conversion errors), or the
        message of the exception raised while converting.
    """
    # Treat missing and whitespace-only fields as empty so they are rejected
    # here instead of surfacing later as an opaque Hub API error.
    token = (token or "").strip()
    model_id = (model_id or "").strip()
    if not token or not model_id:
        return "### Invalid input 🐞\n\nPlease fill a token and model_id.\n"

    # All messages are built from unindented lines: in Markdown, a line
    # indented by 4+ spaces is rendered as a code block.
    try:
        api = HfApi(token=token)
        is_private = api.model_info(repo_id=model_id).private
        print("is_private", is_private)

        commit_info, errors = convert(api=api, model_id=model_id)
        print("[commit_info]", commit_info)
        pr_url = commit_info.pr_url

        # Save conversions of public models in a (public) dataset.
        # TODO: disabled (`False and`) because of an LFS bug.
        if False and repo is not None and not is_private:
            repo.git_pull(rebase=True)
            print("pulled")
            # The csv module requires newline="" so it controls line endings
            # itself (otherwise blank rows appear on some platforms).
            with open(DATA_FILE, "a", newline="", encoding="utf-8") as csvfile:
                writer = csv.DictWriter(
                    csvfile, fieldnames=["model_id", "pr_url", "time"]
                )
                writer.writerow(
                    {
                        "model_id": model_id,
                        "pr_url": pr_url,
                        "time": str(datetime.now()),
                    }
                )
            commit_url = repo.push_to_hub()
            print("[dataset]", commit_url)

        lines = [
            "### Success 🔥",
            "",
            "Yay! This model was successfully converted and a PR was open using your token, here:",
            "",
            f"[{pr_url}]({pr_url})",
        ]
        if errors:
            # Blank separator plus "- " bullets keep each error on its own
            # line; consecutive plain lines would merge into one paragraph.
            lines += ["", "Errors during conversion:", ""]
            lines.extend(
                f"- Error while converting {filename}: {err}, skipped conversion"
                for filename, err in errors
            )
        return "\n".join(lines) + "\n"
    except Exception as e:
        # Top-level UI boundary: report the failure to the user, but keep the
        # full traceback in the server logs instead of discarding it.
        import traceback

        traceback.print_exc()
        return f"### Error 😢😢😢\n\n{e}\n"


# Usage instructions rendered above the form.
DESCRIPTION = """
The steps are the following:

- Paste a read-access token from hf.co/settings/tokens. Read access is enough given that we will open a PR against the source repo.
- Input a model id from the Hub
- Click "Submit"
- That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the opened PR 🔥

⚠️ For now only `pytorch_model.bin` files are supported but we'll extend in the future.
"""

# Two text fields feed `run(token, model_id)` positionally; its Markdown
# return value is shown in the single output panel.
demo = gr.Interface(
    fn=run,
    inputs=[
        gr.Text(max_lines=1, label="your_hf_token"),
        gr.Text(max_lines=1, label="model_id"),
    ],
    outputs=[gr.Markdown(label="output")],
    title="Convert any model to Safetensors and open a PR",
    description=DESCRIPTION,
    article=(
        "Check out the [Safetensors repo on GitHub]"
        "(https://github.com/huggingface/safetensors)"
    ),
    allow_flagging="never",
)
# Queue requests: at most 10 waiting, handled by a single worker.
demo = demo.queue(max_size=10, concurrency_count=1)

demo.launch(show_api=True)