multimodalart HF staff hysts HF staff commited on
Commit
37822b0
·
0 Parent(s):

Duplicate from Tune-A-Video-library/Tune-A-Video-Training-UI

Browse files

Co-authored-by: hysts <hysts@users.noreply.huggingface.co>

Files changed (22) hide show
  1. .gitattributes +35 -0
  2. .gitignore +164 -0
  3. .gitmodules +3 -0
  4. .pre-commit-config.yaml +37 -0
  5. .style.yapf +5 -0
  6. Dockerfile +57 -0
  7. LICENSE +21 -0
  8. README.md +12 -0
  9. Tune-A-Video +1 -0
  10. app.py +76 -0
  11. app_inference.py +170 -0
  12. app_training.py +140 -0
  13. app_upload.py +100 -0
  14. constants.py +10 -0
  15. inference.py +109 -0
  16. packages.txt +1 -0
  17. patch +15 -0
  18. requirements.txt +19 -0
  19. style.css +3 -0
  20. trainer.py +156 -0
  21. uploader.py +42 -0
  22. utils.py +65 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.whl filter=lfs diff=lfs merge=lfs -text
2
+ *.7z filter=lfs diff=lfs merge=lfs -text
3
+ *.arrow filter=lfs diff=lfs merge=lfs -text
4
+ *.bin filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
7
+ *.ftz filter=lfs diff=lfs merge=lfs -text
8
+ *.gz filter=lfs diff=lfs merge=lfs -text
9
+ *.h5 filter=lfs diff=lfs merge=lfs -text
10
+ *.joblib filter=lfs diff=lfs merge=lfs -text
11
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
12
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
13
+ *.model filter=lfs diff=lfs merge=lfs -text
14
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
15
+ *.npy filter=lfs diff=lfs merge=lfs -text
16
+ *.npz filter=lfs diff=lfs merge=lfs -text
17
+ *.onnx filter=lfs diff=lfs merge=lfs -text
18
+ *.ot filter=lfs diff=lfs merge=lfs -text
19
+ *.parquet filter=lfs diff=lfs merge=lfs -text
20
+ *.pb filter=lfs diff=lfs merge=lfs -text
21
+ *.pickle filter=lfs diff=lfs merge=lfs -text
22
+ *.pkl filter=lfs diff=lfs merge=lfs -text
23
+ *.pt filter=lfs diff=lfs merge=lfs -text
24
+ *.pth filter=lfs diff=lfs merge=lfs -text
25
+ *.rar filter=lfs diff=lfs merge=lfs -text
26
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
27
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ checkpoints/
2
+ experiments/
3
+
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "Tune-A-Video"]
2
+ path = Tune-A-Video
3
+ url = https://github.com/showlab/Tune-A-Video
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: patch
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
Dockerfile ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
2
+ ENV DEBIAN_FRONTEND=noninteractive
3
+ RUN apt-get update && \
4
+ apt-get upgrade -y && \
5
+ apt-get install -y --no-install-recommends \
6
+ git \
7
+ git-lfs \
8
+ wget \
9
+ curl \
10
+ # ffmpeg \
11
+ ffmpeg \
12
+ x264 \
13
+ # python build dependencies \
14
+ build-essential \
15
+ libssl-dev \
16
+ zlib1g-dev \
17
+ libbz2-dev \
18
+ libreadline-dev \
19
+ libsqlite3-dev \
20
+ libncursesw5-dev \
21
+ xz-utils \
22
+ tk-dev \
23
+ libxml2-dev \
24
+ libxmlsec1-dev \
25
+ libffi-dev \
26
+ liblzma-dev && \
27
+ apt-get clean && \
28
+ rm -rf /var/lib/apt/lists/*
29
+
30
+ RUN useradd -m -u 1000 user
31
+ USER user
32
+ ENV HOME=/home/user \
33
+ PATH=/home/user/.local/bin:${PATH}
34
+ WORKDIR ${HOME}/app
35
+
36
+ RUN curl https://pyenv.run | bash
37
+ ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
38
+ ENV PYTHON_VERSION=3.10.9
39
+ RUN pyenv install ${PYTHON_VERSION} && \
40
+ pyenv global ${PYTHON_VERSION} && \
41
+ pyenv rehash && \
42
+ pip install --no-cache-dir -U pip setuptools wheel
43
+
44
+ RUN pip install --no-cache-dir -U torch==1.13.1 torchvision==0.14.1
45
+ COPY --chown=1000 requirements.txt /tmp/requirements.txt
46
+ RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
47
+
48
+ COPY --chown=1000 . ${HOME}/app
49
+ RUN cd Tune-A-Video && patch -p1 < ../patch
50
+ ENV PYTHONPATH=${HOME}/app \
51
+ PYTHONUNBUFFERED=1 \
52
+ GRADIO_ALLOW_FLAGGING=never \
53
+ GRADIO_NUM_PORTS=1 \
54
+ GRADIO_SERVER_NAME=0.0.0.0 \
55
+ GRADIO_THEME=huggingface \
56
+ SYSTEM=spaces
57
+ CMD ["python", "app.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Tune-A-Video Training UI
3
+ emoji: ⚡
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ duplicated_from: Tune-A-Video-library/Tune-A-Video-Training-UI
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Tune-A-Video ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit b2c8c3eeac0df5c5d9eccc4dd2153e17b83c638c
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+ import torch
9
+
10
+ from app_inference import create_inference_demo
11
+ from app_training import create_training_demo
12
+ from app_upload import create_upload_demo
13
+ from inference import InferencePipeline
14
+ from trainer import Trainer
15
+
16
+ TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/) Training UI'
17
+
18
+ ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
19
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
20
+ SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU. (Please note that there seems to be an issue with training on the A10G GPU now. The model doesn't learn anything when trained on A10G. Training on T4 works perfectly fine and inference works fine on both.)
21
+
22
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
23
+ '''
24
+
25
+ if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
26
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
27
+ else:
28
+ SETTINGS = 'Settings'
29
+ CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
30
+ <center>
31
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
32
+ You can use "T4 small/medium" or "A10G small/large" to run this demo.
33
+ </center>
34
+ '''
35
+
36
+ HF_TOKEN_NOT_SPECIFIED_WARNING = f'''# Attention - The environment variable `HF_TOKEN` is not specified. Please specify your Hugging Face token with write permission as the value of it.
37
+ <center>
38
+ You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
39
+ You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
40
+ </center>
41
+ '''
42
+
43
+ HF_TOKEN = os.getenv('HF_TOKEN')
44
+
45
+
46
+ def show_warning(warning_text: str) -> gr.Blocks:
47
+ with gr.Blocks() as demo:
48
+ with gr.Box():
49
+ gr.Markdown(warning_text)
50
+ return demo
51
+
52
+
53
+ pipe = InferencePipeline(HF_TOKEN)
54
+ trainer = Trainer(HF_TOKEN)
55
+
56
+ with gr.Blocks(css='style.css') as demo:
57
+ if os.getenv('IS_SHARED_UI'):
58
+ show_warning(SHARED_UI_WARNING)
59
+ if not torch.cuda.is_available():
60
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
61
+ if not HF_TOKEN:
62
+ show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
63
+
64
+ gr.Markdown(TITLE)
65
+ with gr.Tabs():
66
+ with gr.TabItem('Train'):
67
+ create_training_demo(trainer, pipe)
68
+ with gr.TabItem('Test'):
69
+ create_inference_demo(pipe, HF_TOKEN)
70
+ with gr.TabItem('Upload'):
71
+ gr.Markdown('''
72
+ - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
73
+ ''')
74
+ create_upload_demo(HF_TOKEN)
75
+
76
+ demo.queue(max_size=1).launch(share=False)
app_inference.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import enum
6
+
7
+ import gradio as gr
8
+ from huggingface_hub import HfApi
9
+
10
+ from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
11
+ from inference import InferencePipeline
12
+ from utils import find_exp_dirs
13
+
14
+
15
+ class ModelSource(enum.Enum):
16
+ HUB_LIB = UploadTarget.MODEL_LIBRARY.value
17
+ LOCAL = 'Local'
18
+
19
+
20
+ class InferenceUtil:
21
+ def __init__(self, hf_token: str | None):
22
+ self.hf_token = hf_token
23
+
24
+ def load_hub_model_list(self) -> dict:
25
+ api = HfApi(token=self.hf_token)
26
+ choices = [
27
+ info.modelId
28
+ for info in api.list_models(author=MODEL_LIBRARY_ORG_NAME)
29
+ ]
30
+ return gr.update(choices=choices,
31
+ value=choices[0] if choices else None)
32
+
33
+ @staticmethod
34
+ def load_local_model_list() -> dict:
35
+ choices = find_exp_dirs()
36
+ return gr.update(choices=choices,
37
+ value=choices[0] if choices else None)
38
+
39
+ def reload_model_list(self, model_source: str) -> dict:
40
+ if model_source == ModelSource.HUB_LIB.value:
41
+ return self.load_hub_model_list()
42
+ elif model_source == ModelSource.LOCAL.value:
43
+ return self.load_local_model_list()
44
+ else:
45
+ raise ValueError
46
+
47
+ def load_model_info(self, model_id: str) -> tuple[str, str]:
48
+ try:
49
+ card = InferencePipeline.get_model_card(model_id, self.hf_token)
50
+ except Exception:
51
+ return '', ''
52
+ base_model = getattr(card.data, 'base_model', '')
53
+ training_prompt = getattr(card.data, 'training_prompt', '')
54
+ return base_model, training_prompt
55
+
56
+ def reload_model_list_and_update_model_info(
57
+ self, model_source: str) -> tuple[dict, str, str]:
58
+ model_list_update = self.reload_model_list(model_source)
59
+ model_list = model_list_update['choices']
60
+ model_info = self.load_model_info(model_list[0] if model_list else '')
61
+ return model_list_update, *model_info
62
+
63
+
64
+ def create_inference_demo(pipe: InferencePipeline,
65
+ hf_token: str | None = None) -> gr.Blocks:
66
+ app = InferenceUtil(hf_token)
67
+
68
+ with gr.Blocks() as demo:
69
+ with gr.Row():
70
+ with gr.Column():
71
+ with gr.Box():
72
+ model_source = gr.Radio(
73
+ label='Model Source',
74
+ choices=[_.value for _ in ModelSource],
75
+ value=ModelSource.HUB_LIB.value)
76
+ reload_button = gr.Button('Reload Model List')
77
+ model_id = gr.Dropdown(label='Model ID',
78
+ choices=None,
79
+ value=None)
80
+ with gr.Accordion(
81
+ label=
82
+ 'Model info (Base model and prompt used for training)',
83
+ open=False):
84
+ with gr.Row():
85
+ base_model_used_for_training = gr.Text(
86
+ label='Base model', interactive=False)
87
+ prompt_used_for_training = gr.Text(
88
+ label='Training prompt', interactive=False)
89
+ prompt = gr.Textbox(
90
+ label='Prompt',
91
+ max_lines=1,
92
+ placeholder='Example: "A panda is surfing"')
93
+ video_length = gr.Slider(label='Video length',
94
+ minimum=4,
95
+ maximum=12,
96
+ step=1,
97
+ value=8)
98
+ fps = gr.Slider(label='FPS',
99
+ minimum=1,
100
+ maximum=12,
101
+ step=1,
102
+ value=1)
103
+ seed = gr.Slider(label='Seed',
104
+ minimum=0,
105
+ maximum=100000,
106
+ step=1,
107
+ value=0)
108
+ with gr.Accordion('Other Parameters', open=False):
109
+ num_steps = gr.Slider(label='Number of Steps',
110
+ minimum=0,
111
+ maximum=100,
112
+ step=1,
113
+ value=50)
114
+ guidance_scale = gr.Slider(label='CFG Scale',
115
+ minimum=0,
116
+ maximum=50,
117
+ step=0.1,
118
+ value=7.5)
119
+
120
+ run_button = gr.Button('Generate')
121
+
122
+ gr.Markdown('''
123
+ - After training, you can press "Reload Model List" button to load your trained model names.
124
+ - It takes a few minutes to download model first.
125
+ - Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
126
+ ''')
127
+ with gr.Column():
128
+ result = gr.Video(label='Result')
129
+
130
+ model_source.change(fn=app.reload_model_list_and_update_model_info,
131
+ inputs=model_source,
132
+ outputs=[
133
+ model_id,
134
+ base_model_used_for_training,
135
+ prompt_used_for_training,
136
+ ])
137
+ reload_button.click(fn=app.reload_model_list_and_update_model_info,
138
+ inputs=model_source,
139
+ outputs=[
140
+ model_id,
141
+ base_model_used_for_training,
142
+ prompt_used_for_training,
143
+ ])
144
+ model_id.change(fn=app.load_model_info,
145
+ inputs=model_id,
146
+ outputs=[
147
+ base_model_used_for_training,
148
+ prompt_used_for_training,
149
+ ])
150
+ inputs = [
151
+ model_id,
152
+ prompt,
153
+ video_length,
154
+ fps,
155
+ seed,
156
+ num_steps,
157
+ guidance_scale,
158
+ ]
159
+ prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
160
+ run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
161
+ return demo
162
+
163
+
164
+ if __name__ == '__main__':
165
+ import os
166
+
167
+ hf_token = os.getenv('HF_TOKEN')
168
+ pipe = InferencePipeline(hf_token)
169
+ demo = create_inference_demo(pipe, hf_token)
170
+ demo.queue(max_size=10).launch(share=False)
app_training.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+
9
+ from constants import MODEL_LIBRARY_ORG_NAME, SAMPLE_MODEL_REPO, UploadTarget
10
+ from inference import InferencePipeline
11
+ from trainer import Trainer
12
+
13
+
14
+ def create_training_demo(trainer: Trainer,
15
+ pipe: InferencePipeline | None = None) -> gr.Blocks:
16
+ with gr.Blocks() as demo:
17
+ with gr.Row():
18
+ with gr.Column():
19
+ with gr.Box():
20
+ gr.Markdown('Training Data')
21
+ training_video = gr.File(label='Training video')
22
+ training_prompt = gr.Textbox(
23
+ label='Training prompt',
24
+ max_lines=1,
25
+ placeholder='A man is surfing')
26
+ gr.Markdown('''
27
+ - Upload a video and write a prompt describing the video.
28
+ ''')
29
+ with gr.Box():
30
+ gr.Markdown('Output Model')
31
+ output_model_name = gr.Text(label='Name of your model',
32
+ max_lines=1)
33
+ delete_existing_model = gr.Checkbox(
34
+ label='Delete existing model of the same name',
35
+ value=False)
36
+ validation_prompt = gr.Text(label='Validation Prompt')
37
+ with gr.Box():
38
+ gr.Markdown('Upload Settings')
39
+ with gr.Row():
40
+ upload_to_hub = gr.Checkbox(
41
+ label='Upload model to Hub', value=True)
42
+ use_private_repo = gr.Checkbox(label='Private',
43
+ value=True)
44
+ delete_existing_repo = gr.Checkbox(
45
+ label='Delete existing repo of the same name',
46
+ value=False)
47
+ upload_to = gr.Radio(
48
+ label='Upload to',
49
+ choices=[_.value for _ in UploadTarget],
50
+ value=UploadTarget.MODEL_LIBRARY.value)
51
+ gr.Markdown(f'''
52
+ - By default, trained models will be uploaded to [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (see [this example model](https://huggingface.co/{SAMPLE_MODEL_REPO})).
53
+ - You can also choose "Personal Profile", in which case, the model will be uploaded to https://huggingface.co/{{your_username}}/{{model_name}}.
54
+ ''')
55
+
56
+ with gr.Box():
57
+ gr.Markdown('Training Parameters')
58
+ with gr.Row():
59
+ base_model = gr.Text(label='Base Model',
60
+ value='CompVis/stable-diffusion-v1-4',
61
+ max_lines=1)
62
+ resolution = gr.Dropdown(choices=['512', '768'],
63
+ value='512',
64
+ label='Resolution',
65
+ visible=False)
66
+ num_training_steps = gr.Number(
67
+ label='Number of Training Steps', value=300, precision=0)
68
+ learning_rate = gr.Number(label='Learning Rate',
69
+ value=0.000035)
70
+ gradient_accumulation = gr.Number(
71
+ label='Number of Gradient Accumulation',
72
+ value=1,
73
+ precision=0)
74
+ seed = gr.Slider(label='Seed',
75
+ minimum=0,
76
+ maximum=100000,
77
+ step=1,
78
+ value=0)
79
+ fp16 = gr.Checkbox(label='FP16', value=True)
80
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=False)
81
+ checkpointing_steps = gr.Number(label='Checkpointing Steps',
82
+ value=1000,
83
+ precision=0)
84
+ validation_epochs = gr.Number(label='Validation Epochs',
85
+ value=100,
86
+ precision=0)
87
+ gr.Markdown('''
88
+ - The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
89
+ - It takes a few minutes to download the base model first.
90
+ - Expected time to train a model for 300 steps: 20 minutes with T4, 8 minutes with A10G, (4 minutes with A100)
91
+ - It takes a few minutes to upload your trained model.
92
+ - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
93
+ - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
94
+ ''')
95
+
96
+ remove_gpu_after_training = gr.Checkbox(
97
+ label='Remove GPU after training',
98
+ value=False,
99
+ interactive=bool(os.getenv('SPACE_ID')),
100
+ visible=False)
101
+ run_button = gr.Button('Start Training')
102
+
103
+ with gr.Box():
104
+ gr.Markdown('Output message')
105
+ output_message = gr.Markdown()
106
+
107
+ if pipe is not None:
108
+ run_button.click(fn=pipe.clear)
109
+ run_button.click(fn=trainer.run,
110
+ inputs=[
111
+ training_video,
112
+ training_prompt,
113
+ output_model_name,
114
+ delete_existing_model,
115
+ validation_prompt,
116
+ base_model,
117
+ resolution,
118
+ num_training_steps,
119
+ learning_rate,
120
+ gradient_accumulation,
121
+ seed,
122
+ fp16,
123
+ use_8bit_adam,
124
+ checkpointing_steps,
125
+ validation_epochs,
126
+ upload_to_hub,
127
+ use_private_repo,
128
+ delete_existing_repo,
129
+ upload_to,
130
+ remove_gpu_after_training,
131
+ ],
132
+ outputs=output_message)
133
+ return demo
134
+
135
+
136
+ if __name__ == '__main__':
137
+ hf_token = os.getenv('HF_TOKEN')
138
+ trainer = Trainer(hf_token)
139
+ demo = create_training_demo(trainer)
140
+ demo.queue(max_size=1).launch(share=False)
app_upload.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import pathlib
6
+
7
+ import gradio as gr
8
+ import slugify
9
+
10
+ from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
11
+ from uploader import Uploader
12
+ from utils import find_exp_dirs
13
+
14
+
15
+ class ModelUploader(Uploader):
16
+ def upload_model(
17
+ self,
18
+ folder_path: str,
19
+ repo_name: str,
20
+ upload_to: str,
21
+ private: bool,
22
+ delete_existing_repo: bool,
23
+ ) -> str:
24
+ if not folder_path:
25
+ raise ValueError
26
+ if not repo_name:
27
+ repo_name = pathlib.Path(folder_path).name
28
+ repo_name = slugify.slugify(repo_name)
29
+
30
+ if upload_to == UploadTarget.PERSONAL_PROFILE.value:
31
+ organization = ''
32
+ elif upload_to == UploadTarget.MODEL_LIBRARY.value:
33
+ organization = MODEL_LIBRARY_ORG_NAME
34
+ else:
35
+ raise ValueError
36
+
37
+ return self.upload(folder_path,
38
+ repo_name,
39
+ organization=organization,
40
+ private=private,
41
+ delete_existing_repo=delete_existing_repo)
42
+
43
+
44
+ def load_local_model_list() -> dict:
45
+ choices = find_exp_dirs()
46
+ return gr.update(choices=choices, value=choices[0] if choices else None)
47
+
48
+
49
+ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
50
+ uploader = ModelUploader(hf_token)
51
+ model_dirs = find_exp_dirs()
52
+
53
+ with gr.Blocks() as demo:
54
+ with gr.Box():
55
+ gr.Markdown('Local Models')
56
+ reload_button = gr.Button('Reload Model List')
57
+ model_dir = gr.Dropdown(
58
+ label='Model names',
59
+ choices=model_dirs,
60
+ value=model_dirs[0] if model_dirs else None)
61
+ with gr.Box():
62
+ gr.Markdown('Upload Settings')
63
+ with gr.Row():
64
+ use_private_repo = gr.Checkbox(label='Private', value=True)
65
+ delete_existing_repo = gr.Checkbox(
66
+ label='Delete existing repo of the same name', value=False)
67
+ upload_to = gr.Radio(label='Upload to',
68
+ choices=[_.value for _ in UploadTarget],
69
+ value=UploadTarget.MODEL_LIBRARY.value)
70
+ model_name = gr.Textbox(label='Model Name')
71
+ upload_button = gr.Button('Upload')
72
+ gr.Markdown(f'''
73
+ - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
74
+ ''')
75
+ with gr.Box():
76
+ gr.Markdown('Output message')
77
+ output_message = gr.Markdown()
78
+
79
+ reload_button.click(fn=load_local_model_list,
80
+ inputs=None,
81
+ outputs=model_dir)
82
+ upload_button.click(fn=uploader.upload_model,
83
+ inputs=[
84
+ model_dir,
85
+ model_name,
86
+ upload_to,
87
+ use_private_repo,
88
+ delete_existing_repo,
89
+ ],
90
+ outputs=output_message)
91
+
92
+ return demo
93
+
94
+
95
+ if __name__ == '__main__':
96
+ import os
97
+
98
+ hf_token = os.getenv('HF_TOKEN')
99
+ demo = create_upload_demo(hf_token)
100
+ demo.queue(max_size=1).launch(share=False)
constants.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+
3
+
4
+ class UploadTarget(enum.Enum):
5
+ PERSONAL_PROFILE = 'Personal Profile'
6
+ MODEL_LIBRARY = 'Tune-A-Video Library'
7
+
8
+
9
+ MODEL_LIBRARY_ORG_NAME = 'Tune-A-Video-library'
10
+ SAMPLE_MODEL_REPO = 'Tune-A-Video-library/a-man-is-surfing'
inference.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import pathlib
5
+ import sys
6
+ import tempfile
7
+
8
+ import gradio as gr
9
+ import imageio
10
+ import PIL.Image
11
+ import torch
12
+ from diffusers.utils.import_utils import is_xformers_available
13
+ from einops import rearrange
14
+ from huggingface_hub import ModelCard
15
+
16
+ sys.path.append('Tune-A-Video')
17
+
18
+ from tuneavideo.models.unet import UNet3DConditionModel
19
+ from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
20
+
21
+
22
+ class InferencePipeline:
23
+ def __init__(self, hf_token: str | None = None):
24
+ self.hf_token = hf_token
25
+ self.pipe = None
26
+ self.device = torch.device(
27
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
28
+ self.model_id = None
29
+
30
+ def clear(self) -> None:
31
+ self.model_id = None
32
+ del self.pipe
33
+ self.pipe = None
34
+ torch.cuda.empty_cache()
35
+ gc.collect()
36
+
37
+ @staticmethod
38
+ def check_if_model_is_local(model_id: str) -> bool:
39
+ return pathlib.Path(model_id).exists()
40
+
41
+ @staticmethod
42
+ def get_model_card(model_id: str,
43
+ hf_token: str | None = None) -> ModelCard:
44
+ if InferencePipeline.check_if_model_is_local(model_id):
45
+ card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
46
+ else:
47
+ card_path = model_id
48
+ return ModelCard.load(card_path, token=hf_token)
49
+
50
+ @staticmethod
51
+ def get_base_model_info(model_id: str, hf_token: str | None = None) -> str:
52
+ card = InferencePipeline.get_model_card(model_id, hf_token)
53
+ return card.data.base_model
54
+
55
+ def load_pipe(self, model_id: str) -> None:
56
+ if model_id == self.model_id:
57
+ return
58
+ base_model_id = self.get_base_model_info(model_id, self.hf_token)
59
+ unet = UNet3DConditionModel.from_pretrained(
60
+ model_id,
61
+ subfolder='unet',
62
+ torch_dtype=torch.float16,
63
+ use_auth_token=self.hf_token)
64
+ pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
65
+ unet=unet,
66
+ torch_dtype=torch.float16,
67
+ use_auth_token=self.hf_token)
68
+ pipe = pipe.to(self.device)
69
+ if is_xformers_available():
70
+ pipe.unet.enable_xformers_memory_efficient_attention()
71
+ self.pipe = pipe
72
+ self.model_id = model_id # type: ignore
73
+
74
+ def run(
75
+ self,
76
+ model_id: str,
77
+ prompt: str,
78
+ video_length: int,
79
+ fps: int,
80
+ seed: int,
81
+ n_steps: int,
82
+ guidance_scale: float,
83
+ ) -> PIL.Image.Image:
84
+ if not torch.cuda.is_available():
85
+ raise gr.Error('CUDA is not available.')
86
+
87
+ self.load_pipe(model_id)
88
+
89
+ generator = torch.Generator(device=self.device).manual_seed(seed)
90
+ out = self.pipe(
91
+ prompt,
92
+ video_length=video_length,
93
+ width=512,
94
+ height=512,
95
+ num_inference_steps=n_steps,
96
+ guidance_scale=guidance_scale,
97
+ generator=generator,
98
+ ) # type: ignore
99
+
100
+ frames = rearrange(out.videos[0], 'c t h w -> t h w c')
101
+ frames = (frames * 255).to(torch.uint8).numpy()
102
+
103
+ out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
104
+ writer = imageio.get_writer(out_file.name, fps=fps)
105
+ for frame in frames:
106
+ writer.append_data(frame)
107
+ writer.close()
108
+
109
+ return out_file.name
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
patch ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/train_tuneavideo.py b/train_tuneavideo.py
2
+ index 66d51b2..86b2a5d 100644
3
+ --- a/train_tuneavideo.py
4
+ +++ b/train_tuneavideo.py
5
+ @@ -94,8 +94,8 @@ def main(
6
+
7
+ # Handle the output folder creation
8
+ if accelerator.is_main_process:
9
+ - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
10
+ - output_dir = os.path.join(output_dir, now)
11
+ + #now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
12
+ + #output_dir = os.path.join(output_dir, now)
13
+ os.makedirs(output_dir, exist_ok=True)
14
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
15
+
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.15.0
2
+ bitsandbytes==0.35.4
3
+ decord==0.6.0
4
+ diffusers[torch]==0.11.1
5
+ einops==0.6.0
6
+ ftfy==6.1.1
7
+ gradio==3.16.2
8
+ huggingface-hub==0.12.0
9
+ imageio==2.25.0
10
+ imageio-ffmpeg==0.4.8
11
+ omegaconf==2.3.0
12
+ Pillow==9.4.0
13
+ python-slugify==7.0.0
14
+ tensorboard==2.11.2
15
+ torch==1.13.1
16
+ torchvision==0.14.1
17
+ transformers==4.26.0
18
+ triton==2.0.0.dev20221202
19
+ xformers==0.0.16
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
trainer.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import os
5
+ import pathlib
6
+ import shlex
7
+ import shutil
8
+ import subprocess
9
+ import sys
10
+
11
+ import gradio as gr
12
+ import slugify
13
+ import torch
14
+ from huggingface_hub import HfApi
15
+ from omegaconf import OmegaConf
16
+
17
+ from app_upload import ModelUploader
18
+ from utils import save_model_card
19
+
20
+ sys.path.append('Tune-A-Video')
21
+
22
+ URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
23
+
24
+
25
+ class Trainer:
26
+ def __init__(self, hf_token: str | None = None):
27
+ self.hf_token = hf_token
28
+ self.api = HfApi(token=hf_token)
29
+ self.model_uploader = ModelUploader(hf_token)
30
+
31
+ self.checkpoint_dir = pathlib.Path('checkpoints')
32
+ self.checkpoint_dir.mkdir(exist_ok=True)
33
+
34
+ def download_base_model(self, base_model_id: str) -> str:
35
+ model_dir = self.checkpoint_dir / base_model_id
36
+ if not model_dir.exists():
37
+ org_name = base_model_id.split('/')[0]
38
+ org_dir = self.checkpoint_dir / org_name
39
+ org_dir.mkdir(exist_ok=True)
40
+ subprocess.run(shlex.split(
41
+ f'git clone https://huggingface.co/{base_model_id}'),
42
+ cwd=org_dir)
43
+ return model_dir.as_posix()
44
+
45
+ def join_model_library_org(self) -> None:
46
+ subprocess.run(
47
+ shlex.split(
48
+ f'curl -X POST -H "Authorization: Bearer {self.hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
49
+ ))
50
+
51
+ def run(
52
+ self,
53
+ training_video: str,
54
+ training_prompt: str,
55
+ output_model_name: str,
56
+ overwrite_existing_model: bool,
57
+ validation_prompt: str,
58
+ base_model: str,
59
+ resolution_s: str,
60
+ n_steps: int,
61
+ learning_rate: float,
62
+ gradient_accumulation: int,
63
+ seed: int,
64
+ fp16: bool,
65
+ use_8bit_adam: bool,
66
+ checkpointing_steps: int,
67
+ validation_epochs: int,
68
+ upload_to_hub: bool,
69
+ use_private_repo: bool,
70
+ delete_existing_repo: bool,
71
+ upload_to: str,
72
+ remove_gpu_after_training: bool,
73
+ ) -> str:
74
+ if not torch.cuda.is_available():
75
+ raise gr.Error('CUDA is not available.')
76
+ if training_video is None:
77
+ raise gr.Error('You need to upload a video.')
78
+ if not training_prompt:
79
+ raise gr.Error('The training prompt is missing.')
80
+ if not validation_prompt:
81
+ raise gr.Error('The validation prompt is missing.')
82
+
83
+ resolution = int(resolution_s)
84
+
85
+ if not output_model_name:
86
+ timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
87
+ output_model_name = f'tune-a-video-{timestamp}'
88
+ output_model_name = slugify.slugify(output_model_name)
89
+
90
+ repo_dir = pathlib.Path(__file__).parent
91
+ output_dir = repo_dir / 'experiments' / output_model_name
92
+ if overwrite_existing_model or upload_to_hub:
93
+ shutil.rmtree(output_dir, ignore_errors=True)
94
+ output_dir.mkdir(parents=True)
95
+
96
+ if upload_to_hub:
97
+ self.join_model_library_org()
98
+
99
+ config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml')
100
+ config.pretrained_model_path = self.download_base_model(base_model)
101
+ config.output_dir = output_dir.as_posix()
102
+ config.train_data.video_path = training_video.name # type: ignore
103
+ config.train_data.prompt = training_prompt
104
+ config.train_data.n_sample_frames = 8
105
+ config.train_data.width = resolution
106
+ config.train_data.height = resolution
107
+ config.train_data.sample_start_idx = 0
108
+ config.train_data.sample_frame_rate = 1
109
+ config.validation_data.prompts = [validation_prompt]
110
+ config.validation_data.video_length = 8
111
+ config.validation_data.width = resolution
112
+ config.validation_data.height = resolution
113
+ config.validation_data.num_inference_steps = 50
114
+ config.validation_data.guidance_scale = 7.5
115
+ config.learning_rate = learning_rate
116
+ config.gradient_accumulation_steps = gradient_accumulation
117
+ config.train_batch_size = 1
118
+ config.max_train_steps = n_steps
119
+ config.checkpointing_steps = checkpointing_steps
120
+ config.validation_steps = validation_epochs
121
+ config.seed = seed
122
+ config.mixed_precision = 'fp16' if fp16 else ''
123
+ config.use_8bit_adam = use_8bit_adam
124
+
125
+ config_path = output_dir / 'config.yaml'
126
+ with open(config_path, 'w') as f:
127
+ OmegaConf.save(config, f)
128
+
129
+ command = f'accelerate launch Tune-A-Video/train_tuneavideo.py --config {config_path}'
130
+ subprocess.run(shlex.split(command))
131
+ save_model_card(save_dir=output_dir,
132
+ base_model=base_model,
133
+ training_prompt=training_prompt,
134
+ test_prompt=validation_prompt,
135
+ test_image_dir='samples')
136
+
137
+ message = 'Training completed!'
138
+ print(message)
139
+
140
+ if upload_to_hub:
141
+ upload_message = self.model_uploader.upload_model(
142
+ folder_path=output_dir.as_posix(),
143
+ repo_name=output_model_name,
144
+ upload_to=upload_to,
145
+ private=use_private_repo,
146
+ delete_existing_repo=delete_existing_repo)
147
+ print(upload_message)
148
+ message = message + '\n' + upload_message
149
+
150
+ if remove_gpu_after_training:
151
+ space_id = os.getenv('SPACE_ID')
152
+ if space_id:
153
+ self.api.request_space_hardware(repo_id=space_id,
154
+ hardware='cpu-basic')
155
+
156
+ return message
uploader.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from huggingface_hub import HfApi
4
+
5
+
6
+ class Uploader:
7
+ def __init__(self, hf_token: str | None):
8
+ self.api = HfApi(token=hf_token)
9
+
10
+ def get_username(self) -> str:
11
+ return self.api.whoami()['name']
12
+
13
+ def upload(self,
14
+ folder_path: str,
15
+ repo_name: str,
16
+ organization: str = '',
17
+ repo_type: str = 'model',
18
+ private: bool = True,
19
+ delete_existing_repo: bool = False) -> str:
20
+ if not folder_path:
21
+ raise ValueError
22
+ if not repo_name:
23
+ raise ValueError
24
+ if not organization:
25
+ organization = self.get_username()
26
+ repo_id = f'{organization}/{repo_name}'
27
+ if delete_existing_repo:
28
+ try:
29
+ self.api.delete_repo(repo_id, repo_type=repo_type)
30
+ except Exception:
31
+ pass
32
+ try:
33
+ self.api.create_repo(repo_id, repo_type=repo_type, private=private)
34
+ self.api.upload_folder(repo_id=repo_id,
35
+ folder_path=folder_path,
36
+ path_in_repo='.',
37
+ repo_type=repo_type)
38
+ url = f'https://huggingface.co/{repo_id}'
39
+ message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
40
+ except Exception as e:
41
+ message = str(e)
42
+ return message
utils.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import pathlib
4
+
5
+
6
+ def find_exp_dirs() -> list[str]:
7
+ repo_dir = pathlib.Path(__file__).parent
8
+ exp_root_dir = repo_dir / 'experiments'
9
+ if not exp_root_dir.exists():
10
+ return []
11
+ exp_dirs = sorted(exp_root_dir.glob('*'))
12
+ exp_dirs = [
13
+ exp_dir for exp_dir in exp_dirs
14
+ if (exp_dir / 'model_index.json').exists()
15
+ ]
16
+ return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
17
+
18
+
19
+ def save_model_card(
20
+ save_dir: pathlib.Path,
21
+ base_model: str,
22
+ training_prompt: str,
23
+ test_prompt: str = '',
24
+ test_image_dir: str = '',
25
+ ) -> None:
26
+ image_str = ''
27
+ if test_prompt and test_image_dir:
28
+ image_paths = sorted((save_dir / test_image_dir).glob('*.gif'))
29
+ if image_paths:
30
+ image_path = image_paths[-1]
31
+ rel_path = image_path.relative_to(save_dir)
32
+ image_str = f'''## Samples
33
+ Test prompt: {test_prompt}
34
+
35
+ ![{image_path.stem}]({rel_path})'''
36
+
37
+ model_card = f'''---
38
+ license: creativeml-openrail-m
39
+ base_model: {base_model}
40
+ training_prompt: {training_prompt}
41
+ tags:
42
+ - stable-diffusion
43
+ - stable-diffusion-diffusers
44
+ - text-to-image
45
+ - diffusers
46
+ - text-to-video
47
+ - tune-a-video
48
+ inference: false
49
+ ---
50
+
51
+ # Tune-A-Video - {save_dir.name}
52
+
53
+ ## Model description
54
+ - Base model: [{base_model}](https://huggingface.co/{base_model})
55
+ - Training prompt: {training_prompt}
56
+
57
+ {image_str}
58
+
59
+ ## Related papers:
60
+ - [Tune-A-Video](https://arxiv.org/abs/2212.11565): One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation
61
+ - [Stable-Diffusion](https://arxiv.org/abs/2112.10752): High-Resolution Image Synthesis with Latent Diffusion Models
62
+ '''
63
+
64
+ with open(save_dir / 'README.md', 'w') as f:
65
+ f.write(model_card)