File size: 3,963 Bytes
8b7a3d1
 
 
 
 
 
 
 
 
c0a7c3c
8b7a3d1
 
 
 
c0a7c3c
 
8b7a3d1
 
 
 
 
 
5eb79d4
dd59441
8b7a3d1
 
 
 
 
 
 
 
 
c0a7c3c
 
8b7a3d1
 
 
 
 
 
 
a3f0757
dd59441
 
8b7a3d1
 
c0a7c3c
 
8b7a3d1
 
 
 
c0a7c3c
 
8b7a3d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0a7c3c
8b7a3d1
5eb79d4
 
 
8b7a3d1
c0a7c3c
 
8b7a3d1
 
 
 
 
c0a7c3c
8b7a3d1
 
c0a7c3c
8b7a3d1
 
 
 
 
 
5eb79d4
8b7a3d1
 
 
 
 
 
 
 
 
 
 
db0cd98
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python

from __future__ import annotations

import pathlib

import gradio as gr
import slugify

from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
from uploader import Uploader
from utils import find_exp_dirs


class ModelUploader(Uploader):
    """Uploader that pushes a locally trained model folder to the Hub.

    Resolves the destination repository name and owner from the UI inputs,
    then delegates the actual transfer to ``Uploader.upload``.
    """

    def upload_model(
        self,
        folder_path: str,
        repo_name: str,
        upload_to: str,
        private: bool,
        delete_existing_repo: bool,
        input_hf_token: str | None = None,
        return_html_link: bool = True,
    ) -> str:
        """Upload the model stored in ``folder_path``.

        Args:
            folder_path: Local directory containing the model. Required.
            repo_name: Target repository name. When empty, the basename of
                ``folder_path`` is used. The name is slugified either way.
            upload_to: One of the ``UploadTarget`` values, selecting the
                personal profile or the ``MODEL_LIBRARY_ORG_NAME`` org.
            private: Whether the repository should be private; forwarded.
            delete_existing_repo: Forwarded to ``Uploader.upload``.
            input_hf_token: Optional token forwarded to ``Uploader.upload``.
            return_html_link: Forwarded to ``Uploader.upload``.

        Returns:
            The result of ``Uploader.upload``.

        Raises:
            ValueError: If ``folder_path`` is empty, if the repository name
                slugifies to an empty string, or if ``upload_to`` is not a
                recognized ``UploadTarget`` value.
        """
        if not folder_path:
            raise ValueError(
                'No model folder selected; choose a local model to upload.')
        if not repo_name:
            repo_name = pathlib.Path(folder_path).name
        repo_name = slugify.slugify(repo_name)
        # Slugifying can strip every character (e.g. a punctuation-only name,
        # or a folder path such as '/' whose basename is empty). Fail early
        # instead of handing an empty repo name to the upload.
        if not repo_name:
            raise ValueError(
                'Could not derive a valid repository name; '
                'please enter a model name.')

        if upload_to == UploadTarget.PERSONAL_PROFILE.value:
            # An empty organization selects the personal profile; the
            # namespace is resolved inside Uploader.upload.
            organization = ''
        elif upload_to == UploadTarget.MODEL_LIBRARY.value:
            organization = MODEL_LIBRARY_ORG_NAME
        else:
            raise ValueError(f'Unknown upload target: {upload_to!r}')

        return self.upload(folder_path,
                           repo_name,
                           organization=organization,
                           private=private,
                           delete_existing_repo=delete_existing_repo,
                           input_hf_token=input_hf_token,
                           return_html_link=return_html_link)


def load_local_model_list() -> dict:
    """Rescan the experiment directories and build a Dropdown update.

    The first discovered directory becomes the selected value. When no
    directories are found, the selection is cleared to ``None``.
    """
    found_dirs = find_exp_dirs()
    selected = None
    if found_dirs:
        selected = found_dirs[0]
    return gr.update(choices=found_dirs, value=selected)


def create_upload_demo(hf_token: str | None) -> gr.Blocks:
    """Build the Gradio Blocks UI for uploading locally trained models.

    Args:
        hf_token: Token passed to ``ModelUploader``. When it is falsy, a
            masked textbox is shown so the user can supply a write token.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` demo.
    """
    uploader = ModelUploader(hf_token)
    model_dirs = find_exp_dirs()

    with gr.Blocks() as demo:
        with gr.Box():
            gr.Markdown('Local Models')
            reload_button = gr.Button('Reload Model List')
            model_dir = gr.Dropdown(
                label='Model names',
                choices=model_dirs,
                value=model_dirs[0] if model_dirs else None)
        with gr.Box():
            gr.Markdown('Upload Settings')
            with gr.Row():
                use_private_repo = gr.Checkbox(label='Private', value=True)
                delete_existing_repo = gr.Checkbox(
                    label='Delete existing repo of the same name', value=False)
            upload_to = gr.Radio(label='Upload to',
                                 choices=[target.value for target in UploadTarget],
                                 value=UploadTarget.MODEL_LIBRARY.value)
            model_name = gr.Textbox(label='Model Name')
            # The write token is a secret, so mask it instead of echoing it in
            # cleartext. The box is only shown when no token was supplied at
            # startup.
            input_hf_token = gr.Text(label='Hugging Face Write Token',
                                     placeholder='',
                                     type='password',
                                     visible=not hf_token)
        upload_button = gr.Button('Upload')
        gr.Markdown(f'''
            - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
            ''')
        with gr.Box():
            gr.Markdown('Output message')
            output_message = gr.Markdown()

        reload_button.click(fn=load_local_model_list,
                            inputs=None,
                            outputs=model_dir)
        # Input order must match ModelUploader.upload_model's positional
        # parameters: folder_path, repo_name, upload_to, private,
        # delete_existing_repo, input_hf_token.
        upload_button.click(fn=uploader.upload_model,
                            inputs=[
                                model_dir,
                                model_name,
                                upload_to,
                                use_private_repo,
                                delete_existing_repo,
                                input_hf_token,
                            ],
                            outputs=output_message)

    return demo


if __name__ == '__main__':
    import os

    # A missing HF_TOKEN yields None, in which case the UI asks for a token.
    demo = create_upload_demo(os.environ.get('HF_TOKEN'))
    # Queue settings are forwarded as-is to gradio's Blocks.queue.
    demo.queue(api_open=False, max_size=1).launch()