File size: 4,096 Bytes
8b7a3d1
 
 
 
 
a3f0757
8b7a3d1
 
 
 
 
a60b0e7
8b7a3d1
 
 
 
 
06affa1
8b7a3d1
a11d22c
06affa1
a3f0757
 
8b7a3d1
a3f0757
8b7a3d1
 
bd420fb
8b7a3d1
 
 
 
a3f0757
 
 
 
8b7a3d1
 
a3f0757
8b7a3d1
 
 
07a02e5
06affa1
 
8b7a3d1
 
 
 
 
 
 
 
 
 
 
 
 
67d646a
8b7a3d1
 
bd420fb
8b7a3d1
a3f0757
8b7a3d1
06affa1
a3f0757
8b7a3d1
 
 
 
a60b0e7
 
 
a3f0757
a60b0e7
 
 
8b7a3d1
 
 
 
a60b0e7
 
 
 
 
 
07a02e5
a3f0757
 
8b7a3d1
db0cd98
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/env python

from __future__ import annotations

import os
from subprocess import getoutput

import gradio as gr
import torch

from app_inference import create_inference_demo
from app_system_monitor import create_monitor_demo
from app_training import create_training_demo
from app_upload import create_upload_demo
from inference import InferencePipeline
from trainer import Trainer

TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/)'

# The public Space this app was published as. When SPACE_ID matches it, the
# app is running as the shared UI, and the run buttons are disabled below.
ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
SPACE_ID = os.getenv('SPACE_ID')
IS_SHARED_UI = SPACE_ID == ORIGINAL_SPACE_ID

# Raw `nvidia-smi` output. getoutput runs through the shell and captures stdout
# and stderr, so if the command is missing this holds the shell's error text.
# Only a substring check is done on it later.
GPU_DATA = getoutput('nvidia-smi')

SHARED_UI_WARNING = f'''## Attention - Training doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.

<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
'''

# Link to the Space's settings page only when running on Spaces as a duplicate.
# On the shared Space or off Spaces, it is plain text.
if os.getenv('SYSTEM') == 'spaces' and not IS_SHARED_UI:
    SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
else:
    SETTINGS = 'Settings'

# Plain string: there is nothing to interpolate, so no f-prefix (ruff F541).
INVALID_GPU_WARNING = '''## Attention - the specified GPU is invalid. Training may not work. Make sure you have selected a `T4 GPU` for this task.'''

CUDA_NOT_AVAILABLE_WARNING = f'''## Attention - Running on CPU.
<center>
You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
You can use "T4 small/medium" to run this demo.
</center>
'''

HF_TOKEN_NOT_SPECIFIED_WARNING = f'''The environment variable `HF_TOKEN` is not specified. Feel free to specify your Hugging Face token with write permission if you don't want to manually provide it for every run.

You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>. You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
'''

# Optional Hugging Face token. It is None when the environment variable is unset.
HF_TOKEN = os.getenv('HF_TOKEN')


def show_warning(warning_text: str) -> gr.Blocks:
    """Build a boxed Markdown panel that displays ``warning_text``.

    A new ``gr.Blocks`` is created. It contains a single ``gr.Box``, which
    wraps one ``gr.Markdown`` component rendering the given text.

    Args:
        warning_text: Markdown text to display. The callers in this file also
            embed inline HTML in it.

    Returns:
        The ``gr.Blocks`` that holds the panel. The callers in this file ignore
        the return value and invoke this inside an outer ``gr.Blocks`` context.
        They presumably rely on Gradio 3.x merging nested Blocks children into
        the parent. Verify this when upgrading Gradio, since ``gr.Box`` is gone
        in later releases.
    """
    with gr.Blocks() as warning_panel, gr.Box():
        gr.Markdown(warning_text)
    return warning_panel


# One inference pipeline and one trainer are shared across the tabs. The Train
# and Run tabs both receive the same `pipe` instance.
pipe = InferencePipeline(HF_TOKEN)
trainer = Trainer()

with gr.Blocks(css='style.css') as demo:
    # Show at most one environment warning, checked in priority order:
    # shared UI, then no CUDA device, then a GPU that is not a T4.
    if IS_SHARED_UI:
        show_warning(SHARED_UI_WARNING)
    elif not torch.cuda.is_available():
        show_warning(CUDA_NOT_AVAILABLE_WARNING)
    elif 'T4' not in GPU_DATA:
        # GPU_DATA is raw `nvidia-smi` text, so a substring test is used.
        show_warning(INVALID_GPU_WARNING)

    gr.Markdown(TITLE)
    with gr.Tabs():
        with gr.TabItem('Train'):
            create_training_demo(trainer,
                                 pipe,
                                 disable_run_button=IS_SHARED_UI)
        with gr.TabItem('Run'):
            create_inference_demo(pipe,
                                  HF_TOKEN,
                                  disable_run_button=IS_SHARED_UI)
        with gr.TabItem('Upload'):
            gr.Markdown('''
            - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
            ''')
            create_upload_demo(disable_run_button=IS_SHARED_UI)

    # Create the Row only when it will contain the monitor, so that no empty
    # row is added to the layout when the monitor is disabled.
    if not IS_SHARED_UI and not os.getenv('DISABLE_SYSTEM_MONITOR'):
        with gr.Row():
            with gr.Accordion(label='System info', open=False):
                create_monitor_demo()

    if not HF_TOKEN:
        show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)

# Launch only when executed as a script (`python app.py`, as Spaces does).
# Importing the module, e.g. Gradio reload mode or reusing `demo`, then does
# not start a blocking server. max_size=1 caps the queue at one stored event.
if __name__ == '__main__':
    demo.queue(api_open=False, max_size=1).launch()