File size: 2,694 Bytes
036cfd1
656540b
036cfd1
 
 
19316e6
 
 
 
 
656540b
6a7afd0
 
 
 
 
19316e6
6a7afd0
 
 
 
 
19316e6
6a7afd0
656540b
036cfd1
 
6a7afd0
 
2f5a58e
19316e6
036cfd1
 
 
 
 
19316e6
036cfd1
 
 
 
2f5a58e
19316e6
036cfd1
 
 
 
 
19316e6
036cfd1
 
 
6a7afd0
19316e6
1a129b9
6a7afd0
036cfd1
19316e6
036cfd1
 
 
4b3c79b
036cfd1
 
 
e89bd1d
19316e6
 
e89bd1d
19316e6
 
e89bd1d
19316e6
 
036cfd1
 
be9d28f
036cfd1
be9d28f
 
 
 
 
 
 
 
 
 
 
 
 
8eadebf
be9d28f
 
 
 
 
6eaf80e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import datetime
import hashlib
import os
import subprocess
import tempfile
import time

import gradio as gr
from huggingface_hub import hf_hub_download

def md5(filename):
    r"""Return the MD5 checksum line for *filename*, laid out like ``md5sum``.

    The digest is computed in-process with ``hashlib`` instead of spawning
    the external ``md5sum`` binary, so the function no longer depends on
    that tool being installed. The file is read in fixed-size chunks so a
    large file is never held in memory all at once.

    Args:
        filename: Path of the file to hash (``str``, ``bytes`` or
            ``os.PathLike``).

    Returns:
        ``bytes`` of the form ``b"<32 hex digits>  <filename>\n"``, i.e. the
        same layout as the line the previous ``md5sum`` subprocess call
        produced for ordinary file names. Any special escaping ``md5sum``
        may apply to unusual names (for example ones containing
        backslashes or newlines) is not reproduced.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    digest = hashlib.md5()
    with open(filename, "rb") as stream:
        # iter(callable, sentinel) keeps reading 1 MiB chunks until EOF (b"").
        for chunk in iter(lambda: stream.read(1024 * 1024), b""):
            digest.update(chunk)
    return digest.hexdigest().encode("ascii") + b"  " + os.fsencode(filename) + b"\n"
    

def download_very_slow(repo_id):
    """Fetch ``pytorch_model.bin`` from *repo_id* in the "very slow" setup.

    Before downloading, ``HF_TRANSFER`` is removed from the process
    environment (if present) and ``HF_CHUNK_SIZE`` is set to ``"1024"``.
    Both changes are made on ``os.environ`` and are not undone afterwards.
    NOTE(review): presumably the huggingface_hub build in use reads these
    variables at download time -- confirm.

    The file is force-downloaded into a temporary directory that is removed
    again before this function returns.

    Args:
        repo_id: Repository id passed to ``hf_hub_download``.

    Returns:
        Whatever ``md5`` returns for the downloaded file.
    """
    if "HF_TRANSFER" in os.environ:
        del os.environ["HF_TRANSFER"]
    os.environ["HF_CHUNK_SIZE"] = "1024"

    with tempfile.TemporaryDirectory() as scratch_dir:
        local_path = hf_hub_download(
            repo_id,
            filename="pytorch_model.bin",
            force_download=True,
            cache_dir=scratch_dir,
        )
        # Hash while the temporary directory (and the file in it) still exists.
        checksum = md5(local_path)
    return checksum


def download_slow(repo_id):
    """Fetch ``pytorch_model.bin`` from *repo_id* in the "slow" setup.

    Before downloading, ``HF_TRANSFER`` is removed from the process
    environment (if present) and ``HF_CHUNK_SIZE`` is set to ``"10485760"``.
    Both changes are made on ``os.environ`` and are not undone afterwards.
    NOTE(review): presumably the huggingface_hub build in use reads these
    variables at download time -- confirm.

    The file is force-downloaded into a temporary directory that is removed
    again before this function returns.

    Args:
        repo_id: Repository id passed to ``hf_hub_download``.

    Returns:
        Whatever ``md5`` returns for the downloaded file.
    """
    os.environ["HF_CHUNK_SIZE"] = "10485760"
    if "HF_TRANSFER" in os.environ:
        del os.environ["HF_TRANSFER"]

    fetch_options = {
        "filename": "pytorch_model.bin",
        "force_download": True,
    }
    with tempfile.TemporaryDirectory() as cache_root:
        model_file = hf_hub_download(repo_id, cache_dir=cache_root, **fetch_options)
        # Hash before leaving the ``with`` block deletes the download.
        result = md5(model_file)
    return result


def download_fast(repo_id):
    """Fetch ``pytorch_model.bin`` from *repo_id* in the "fast" setup.

    Sets ``HF_TRANSFER`` to ``"1"`` in the process environment before
    downloading; the variable is left set afterwards and ``HF_CHUNK_SIZE``
    is not touched here.
    NOTE(review): presumably the huggingface_hub build in use reads this
    variable at download time to enable hf_transfer -- confirm.

    The file is force-downloaded into a temporary directory that is removed
    again before this function returns.

    Args:
        repo_id: Repository id passed to ``hf_hub_download``.

    Returns:
        Whatever ``md5`` returns for the downloaded file.
    """
    os.environ.update({"HF_TRANSFER": "1"})

    with tempfile.TemporaryDirectory() as tmp_cache:
        downloaded = hf_hub_download(
            repo_id,
            cache_dir=tmp_cache,
            force_download=True,
            filename="pytorch_model.bin",
        )
        # The checksum must be taken while the temporary directory exists.
        fingerprint = md5(downloaded)
    return fingerprint


def _timed(fetch, repo_id):
    """Call ``fetch(repo_id)`` and return ``(result, elapsed)``.

    ``elapsed`` is a ``datetime.timedelta`` so it renders the same way the
    report has always shown durations. It is measured with
    ``time.perf_counter()``, a monotonic clock, so system clock adjustments
    during a long download cannot distort (or make negative) the timing.
    """
    began = time.perf_counter()
    result = fetch(repo_id)
    elapsed = datetime.timedelta(seconds=time.perf_counter() - began)
    return result, elapsed


def download(repo_id):
    """Download the model from *repo_id* three ways and report the timings.

    Runs ``download_very_slow``, ``download_slow`` and ``download_fast`` one
    after another (in that order), timing each call and collecting the
    checksum each one returns.

    Args:
        repo_id: Repository id forwarded to each download function.

    Returns:
        A multi-line ``str`` listing, for each of the three variants, the
        elapsed time and the checksum value returned by that variant.

    Raises:
        Whatever the underlying download functions raise; nothing is caught
        here.
    """
    md5_very_slow, taken_very_slow = _timed(download_very_slow, repo_id)
    md5_slow, taken_slow = _timed(download_slow, repo_id)
    md5_fast, taken_fast = _timed(download_fast, repo_id)

    return f"""
Very slow (huggingface_hub previous to https://github.com/huggingface/huggingface_hub/pull/1267): {taken_very_slow}
MD5: {md5_very_slow}

Slow (huggingface_hub after): {taken_slow}
MD5: {md5_slow}

Fast (with hf_transfer): {taken_fast}
MD5: {md5_fast}

    """

# Repo ids offered as ready-made inputs through the gr.Examples widget below.
examples = ["gpt2", "openai/whisper-large-v2"]

# Build the UI: one row with two columns (input form on the left, result on
# the right), followed by a second row holding the examples widget.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: repo id text field plus the button that starts the run.
        with gr.Column():
            inputs = gr.Textbox(
                        label="Repo id",
                        value="gpt2",  # should be set to " " when plugged into a real API
                    )
            submit = gr.Button("Submit")
        # Right column: text box that receives the report string from `download`.
        with gr.Column():
            outputs = gr.Textbox(
                        label="Download speeds",
                )
    with gr.Row():
        # NOTE(review): with cache_examples=True Gradio presumably runs
        # `download` for every example up front to pre-compute its output,
        # i.e. three full downloads per example -- confirm this is intended.
        gr.Examples(examples=examples, inputs=[inputs], cache_examples=True, fn=download, outputs=[outputs])
    # Clicking "Submit" calls `download` with the repo id text and writes the
    # returned report into the output text box.
    submit.click(
            download,
            inputs=[inputs],
            outputs=[outputs],
        )
# Launch the app. There is no `if __name__ == "__main__"` guard, so this
# also runs whenever the module is imported.
demo.launch()