huy hysts HF staff commited on
Commit
5dfd462
0 Parent(s):

Duplicate from Tune-A-Video-library/Tune-A-Video-Training-UI

Browse files

Co-authored-by: hysts <hysts@users.noreply.huggingface.co>

Files changed (22) hide show
  1. .gitattributes +35 -0
  2. .gitignore +164 -0
  3. .gitmodules +3 -0
  4. .pre-commit-config.yaml +37 -0
  5. .style.yapf +5 -0
  6. Dockerfile +57 -0
  7. LICENSE +21 -0
  8. README.md +12 -0
  9. Tune-A-Video +1 -0
  10. app.py +84 -0
  11. app_inference.py +170 -0
  12. app_training.py +135 -0
  13. app_upload.py +106 -0
  14. constants.py +10 -0
  15. inference.py +109 -0
  16. packages.txt +1 -0
  17. patch +15 -0
  18. requirements.txt +19 -0
  19. style.css +3 -0
  20. trainer.py +166 -0
  21. uploader.py +44 -0
  22. utils.py +65 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.whl filter=lfs diff=lfs merge=lfs -text
2
+ *.7z filter=lfs diff=lfs merge=lfs -text
3
+ *.arrow filter=lfs diff=lfs merge=lfs -text
4
+ *.bin filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
7
+ *.ftz filter=lfs diff=lfs merge=lfs -text
8
+ *.gz filter=lfs diff=lfs merge=lfs -text
9
+ *.h5 filter=lfs diff=lfs merge=lfs -text
10
+ *.joblib filter=lfs diff=lfs merge=lfs -text
11
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
12
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
13
+ *.model filter=lfs diff=lfs merge=lfs -text
14
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
15
+ *.npy filter=lfs diff=lfs merge=lfs -text
16
+ *.npz filter=lfs diff=lfs merge=lfs -text
17
+ *.onnx filter=lfs diff=lfs merge=lfs -text
18
+ *.ot filter=lfs diff=lfs merge=lfs -text
19
+ *.parquet filter=lfs diff=lfs merge=lfs -text
20
+ *.pb filter=lfs diff=lfs merge=lfs -text
21
+ *.pickle filter=lfs diff=lfs merge=lfs -text
22
+ *.pkl filter=lfs diff=lfs merge=lfs -text
23
+ *.pt filter=lfs diff=lfs merge=lfs -text
24
+ *.pth filter=lfs diff=lfs merge=lfs -text
25
+ *.rar filter=lfs diff=lfs merge=lfs -text
26
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
27
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ checkpoints/
2
+ experiments/
3
+
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "Tune-A-Video"]
2
+ path = Tune-A-Video
3
+ url = https://github.com/showlab/Tune-A-Video
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: patch
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
Dockerfile ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
2
+ ENV DEBIAN_FRONTEND=noninteractive
3
+ RUN apt-get update && \
4
+ apt-get upgrade -y && \
5
+ apt-get install -y --no-install-recommends \
6
+ git \
7
+ git-lfs \
8
+ wget \
9
+ curl \
10
+ # ffmpeg \
11
+ ffmpeg \
12
+ x264 \
13
+ # python build dependencies \
14
+ build-essential \
15
+ libssl-dev \
16
+ zlib1g-dev \
17
+ libbz2-dev \
18
+ libreadline-dev \
19
+ libsqlite3-dev \
20
+ libncursesw5-dev \
21
+ xz-utils \
22
+ tk-dev \
23
+ libxml2-dev \
24
+ libxmlsec1-dev \
25
+ libffi-dev \
26
+ liblzma-dev && \
27
+ apt-get clean && \
28
+ rm -rf /var/lib/apt/lists/*
29
+
30
+ RUN useradd -m -u 1000 user
31
+ USER user
32
+ ENV HOME=/home/user \
33
+ PATH=/home/user/.local/bin:${PATH}
34
+ WORKDIR ${HOME}/app
35
+
36
+ RUN curl https://pyenv.run | bash
37
+ ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
38
+ ENV PYTHON_VERSION=3.10.9
39
+ RUN pyenv install ${PYTHON_VERSION} && \
40
+ pyenv global ${PYTHON_VERSION} && \
41
+ pyenv rehash && \
42
+ pip install --no-cache-dir -U pip setuptools wheel
43
+
44
+ RUN pip install --no-cache-dir -U torch==1.13.1 torchvision==0.14.1
45
+ COPY --chown=1000 requirements.txt /tmp/requirements.txt
46
+ RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
47
+
48
+ COPY --chown=1000 . ${HOME}/app
49
+ RUN cd Tune-A-Video && patch -p1 < ../patch
50
+ ENV PYTHONPATH=${HOME}/app \
51
+ PYTHONUNBUFFERED=1 \
52
+ GRADIO_ALLOW_FLAGGING=never \
53
+ GRADIO_NUM_PORTS=1 \
54
+ GRADIO_SERVER_NAME=0.0.0.0 \
55
+ GRADIO_THEME=huggingface \
56
+ SYSTEM=spaces
57
+ CMD ["python", "app.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Tune-A-Video Training UI
3
+ emoji: ⚡
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ duplicated_from: Tune-A-Video-library/Tune-A-Video-Training-UI
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Tune-A-Video ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit b2c8c3eeac0df5c5d9eccc4dd2153e17b83c638c
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from subprocess import getoutput
7
+
8
+ import gradio as gr
9
+ import torch
10
+
11
+ from app_inference import create_inference_demo
12
+ from app_training import create_training_demo
13
+ from app_upload import create_upload_demo
14
+ from inference import InferencePipeline
15
+ from trainer import Trainer
16
+
17
+ TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/) UI'
18
+
19
+ ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
20
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
21
+ GPU_DATA = getoutput('nvidia-smi')
22
+ SHARED_UI_WARNING = f'''## Attention - Training doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
23
+
24
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
25
+ '''
26
+
27
+ if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
28
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
29
+ else:
30
+ SETTINGS = 'Settings'
31
+
32
+ INVALID_GPU_WARNING = f'''## Attention - the specified GPU is invalid. Training may not work. Make sure you have selected a `T4 GPU` for this task.'''
33
+
34
+ CUDA_NOT_AVAILABLE_WARNING = f'''## Attention - Running on CPU.
35
+ <center>
36
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
37
+ You can use "T4 small/medium" to run this demo.
38
+ </center>
39
+ '''
40
+
41
+ HF_TOKEN_NOT_SPECIFIED_WARNING = f'''The environment variable `HF_TOKEN` is not specified. Feel free to specify your Hugging Face token with write permission if you don't want to manually provide it for every run.
42
+ <center>
43
+ You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
44
+ You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
45
+ </center>
46
+ '''
47
+
48
+ HF_TOKEN = os.getenv('HF_TOKEN')
49
+
50
+
51
+ def show_warning(warning_text: str) -> gr.Blocks:
52
+ with gr.Blocks() as demo:
53
+ with gr.Box():
54
+ gr.Markdown(warning_text)
55
+ return demo
56
+
57
+
58
+ pipe = InferencePipeline(HF_TOKEN)
59
+ trainer = Trainer(HF_TOKEN)
60
+
61
+ with gr.Blocks(css='style.css') as demo:
62
+ if SPACE_ID == ORIGINAL_SPACE_ID:
63
+ show_warning(SHARED_UI_WARNING)
64
+ elif not torch.cuda.is_available():
65
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
66
+ elif (not 'T4' in GPU_DATA):
67
+ show_warning(INVALID_GPU_WARNING)
68
+
69
+ gr.Markdown(TITLE)
70
+ with gr.Tabs():
71
+ with gr.TabItem('Train'):
72
+ create_training_demo(trainer, pipe)
73
+ with gr.TabItem('Run'):
74
+ create_inference_demo(pipe, HF_TOKEN)
75
+ with gr.TabItem('Upload'):
76
+ gr.Markdown('''
77
+ - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
78
+ ''')
79
+ create_upload_demo(HF_TOKEN)
80
+
81
+ if not HF_TOKEN:
82
+ show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
83
+
84
+ demo.queue(max_size=1).launch(share=False)
app_inference.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import enum
6
+
7
+ import gradio as gr
8
+ from huggingface_hub import HfApi
9
+
10
+ from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
11
+ from inference import InferencePipeline
12
+ from utils import find_exp_dirs
13
+
14
+
15
+ class ModelSource(enum.Enum):
16
+ HUB_LIB = UploadTarget.MODEL_LIBRARY.value
17
+ LOCAL = 'Local'
18
+
19
+
20
+ class InferenceUtil:
21
+ def __init__(self, hf_token: str | None):
22
+ self.hf_token = hf_token
23
+
24
+ def load_hub_model_list(self) -> dict:
25
+ api = HfApi(token=self.hf_token)
26
+ choices = [
27
+ info.modelId
28
+ for info in api.list_models(author=MODEL_LIBRARY_ORG_NAME)
29
+ ]
30
+ return gr.update(choices=choices,
31
+ value=choices[0] if choices else None)
32
+
33
+ @staticmethod
34
+ def load_local_model_list() -> dict:
35
+ choices = find_exp_dirs()
36
+ return gr.update(choices=choices,
37
+ value=choices[0] if choices else None)
38
+
39
+ def reload_model_list(self, model_source: str) -> dict:
40
+ if model_source == ModelSource.HUB_LIB.value:
41
+ return self.load_hub_model_list()
42
+ elif model_source == ModelSource.LOCAL.value:
43
+ return self.load_local_model_list()
44
+ else:
45
+ raise ValueError
46
+
47
+ def load_model_info(self, model_id: str) -> tuple[str, str]:
48
+ try:
49
+ card = InferencePipeline.get_model_card(model_id, self.hf_token)
50
+ except Exception:
51
+ return '', ''
52
+ base_model = getattr(card.data, 'base_model', '')
53
+ training_prompt = getattr(card.data, 'training_prompt', '')
54
+ return base_model, training_prompt
55
+
56
+ def reload_model_list_and_update_model_info(
57
+ self, model_source: str) -> tuple[dict, str, str]:
58
+ model_list_update = self.reload_model_list(model_source)
59
+ model_list = model_list_update['choices']
60
+ model_info = self.load_model_info(model_list[0] if model_list else '')
61
+ return model_list_update, *model_info
62
+
63
+
64
+ def create_inference_demo(pipe: InferencePipeline,
65
+ hf_token: str | None = None) -> gr.Blocks:
66
+ app = InferenceUtil(hf_token)
67
+
68
+ with gr.Blocks() as demo:
69
+ with gr.Row():
70
+ with gr.Column():
71
+ with gr.Box():
72
+ model_source = gr.Radio(
73
+ label='Model Source',
74
+ choices=[_.value for _ in ModelSource],
75
+ value=ModelSource.HUB_LIB.value)
76
+ reload_button = gr.Button('Reload Model List')
77
+ model_id = gr.Dropdown(label='Model ID',
78
+ choices=None,
79
+ value=None)
80
+ with gr.Accordion(
81
+ label=
82
+ 'Model info (Base model and prompt used for training)',
83
+ open=False):
84
+ with gr.Row():
85
+ base_model_used_for_training = gr.Text(
86
+ label='Base model', interactive=False)
87
+ prompt_used_for_training = gr.Text(
88
+ label='Training prompt', interactive=False)
89
+ prompt = gr.Textbox(
90
+ label='Prompt',
91
+ max_lines=1,
92
+ placeholder='Example: "A panda is surfing"')
93
+ video_length = gr.Slider(label='Video length',
94
+ minimum=4,
95
+ maximum=12,
96
+ step=1,
97
+ value=8)
98
+ fps = gr.Slider(label='FPS',
99
+ minimum=1,
100
+ maximum=12,
101
+ step=1,
102
+ value=1)
103
+ seed = gr.Slider(label='Seed',
104
+ minimum=0,
105
+ maximum=100000,
106
+ step=1,
107
+ value=0)
108
+ with gr.Accordion('Other Parameters', open=False):
109
+ num_steps = gr.Slider(label='Number of Steps',
110
+ minimum=0,
111
+ maximum=100,
112
+ step=1,
113
+ value=50)
114
+ guidance_scale = gr.Slider(label='CFG Scale',
115
+ minimum=0,
116
+ maximum=50,
117
+ step=0.1,
118
+ value=7.5)
119
+
120
+ run_button = gr.Button('Generate')
121
+
122
+ gr.Markdown('''
123
+ - After training, you can press "Reload Model List" button to load your trained model names.
124
+ - It takes a few minutes to download model first.
125
+ - Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
126
+ ''')
127
+ with gr.Column():
128
+ result = gr.Video(label='Result')
129
+
130
+ model_source.change(fn=app.reload_model_list_and_update_model_info,
131
+ inputs=model_source,
132
+ outputs=[
133
+ model_id,
134
+ base_model_used_for_training,
135
+ prompt_used_for_training,
136
+ ])
137
+ reload_button.click(fn=app.reload_model_list_and_update_model_info,
138
+ inputs=model_source,
139
+ outputs=[
140
+ model_id,
141
+ base_model_used_for_training,
142
+ prompt_used_for_training,
143
+ ])
144
+ model_id.change(fn=app.load_model_info,
145
+ inputs=model_id,
146
+ outputs=[
147
+ base_model_used_for_training,
148
+ prompt_used_for_training,
149
+ ])
150
+ inputs = [
151
+ model_id,
152
+ prompt,
153
+ video_length,
154
+ fps,
155
+ seed,
156
+ num_steps,
157
+ guidance_scale,
158
+ ]
159
+ prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
160
+ run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
161
+ return demo
162
+
163
+
164
+ if __name__ == '__main__':
165
+ import os
166
+
167
+ hf_token = os.getenv('HF_TOKEN')
168
+ pipe = InferencePipeline(hf_token)
169
+ demo = create_inference_demo(pipe, hf_token)
170
+ demo.queue(max_size=10).launch(share=False)
app_training.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+
9
+ from constants import MODEL_LIBRARY_ORG_NAME, SAMPLE_MODEL_REPO, UploadTarget
10
+ from inference import InferencePipeline
11
+ from trainer import Trainer
12
+
13
+
14
+ def create_training_demo(trainer: Trainer,
15
+ pipe: InferencePipeline | None = None) -> gr.Blocks:
16
+ hf_token = os.getenv('HF_TOKEN')
17
+ with gr.Blocks() as demo:
18
+ with gr.Row():
19
+ with gr.Column():
20
+ with gr.Box():
21
+ gr.Markdown('Training Data')
22
+ training_video = gr.File(label='Training video')
23
+ training_prompt = gr.Textbox(
24
+ label='Training prompt',
25
+ max_lines=1,
26
+ placeholder='A man is surfing')
27
+ gr.Markdown('''
28
+ - Upload a video and write a `Training Prompt` that describes the video.
29
+ ''')
30
+
31
+ with gr.Column():
32
+ with gr.Box():
33
+ gr.Markdown('Training Parameters')
34
+ with gr.Row():
35
+ base_model = gr.Text(
36
+ label='Base Model',
37
+ value='CompVis/stable-diffusion-v1-4',
38
+ max_lines=1)
39
+ resolution = gr.Dropdown(choices=['512', '768'],
40
+ value='512',
41
+ label='Resolution',
42
+ visible=False)
43
+
44
+ input_token = gr.Text(label='Hugging Face Write Token',
45
+ placeholder='',
46
+ visible=False if hf_token else True)
47
+ with gr.Accordion('Advanced settings', open=False):
48
+ num_training_steps = gr.Number(
49
+ label='Number of Training Steps',
50
+ value=300,
51
+ precision=0)
52
+ learning_rate = gr.Number(label='Learning Rate',
53
+ value=0.000035)
54
+ gradient_accumulation = gr.Number(
55
+ label='Number of Gradient Accumulation',
56
+ value=1,
57
+ precision=0)
58
+ seed = gr.Slider(label='Seed',
59
+ minimum=0,
60
+ maximum=100000,
61
+ step=1,
62
+ randomize=True,
63
+ value=0)
64
+ fp16 = gr.Checkbox(label='FP16', value=True)
65
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam',
66
+ value=False)
67
+ checkpointing_steps = gr.Number(
68
+ label='Checkpointing Steps',
69
+ value=1000,
70
+ precision=0)
71
+ validation_epochs = gr.Number(
72
+ label='Validation Epochs', value=100, precision=0)
73
+ gr.Markdown('''
74
+ - The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
75
+ - Expected time to train a model for 300 steps: ~20 minutes with T4
76
+ - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
77
+ ''')
78
+
79
+ with gr.Row():
80
+ with gr.Column():
81
+ gr.Markdown('Output Model')
82
+ output_model_name = gr.Text(label='Name of your model',
83
+ placeholder='The surfer man',
84
+ max_lines=1)
85
+ validation_prompt = gr.Text(
86
+ label='Validation Prompt',
87
+ placeholder=
88
+ 'prompt to test the model, e.g: a dog is surfing')
89
+ with gr.Column():
90
+ gr.Markdown('Upload Settings')
91
+ with gr.Row():
92
+ upload_to_hub = gr.Checkbox(label='Upload model to Hub',
93
+ value=True)
94
+ use_private_repo = gr.Checkbox(label='Private', value=True)
95
+ delete_existing_repo = gr.Checkbox(
96
+ label='Delete existing repo of the same name',
97
+ value=False)
98
+ upload_to = gr.Radio(
99
+ label='Upload to',
100
+ choices=[_.value for _ in UploadTarget],
101
+ value=UploadTarget.MODEL_LIBRARY.value)
102
+
103
+ remove_gpu_after_training = gr.Checkbox(
104
+ label='Remove GPU after training',
105
+ value=False,
106
+ interactive=bool(os.getenv('SPACE_ID')),
107
+ visible=False)
108
+ run_button = gr.Button('Start Training')
109
+
110
+ with gr.Box():
111
+ gr.Markdown('Output message')
112
+ output_message = gr.Markdown()
113
+
114
+ if pipe is not None:
115
+ run_button.click(fn=pipe.clear)
116
+ run_button.click(
117
+ fn=trainer.run,
118
+ inputs=[
119
+ training_video, training_prompt, output_model_name,
120
+ delete_existing_repo, validation_prompt, base_model,
121
+ resolution, num_training_steps, learning_rate,
122
+ gradient_accumulation, seed, fp16, use_8bit_adam,
123
+ checkpointing_steps, validation_epochs, upload_to_hub,
124
+ use_private_repo, delete_existing_repo, upload_to,
125
+ remove_gpu_after_training, input_token
126
+ ],
127
+ outputs=output_message)
128
+ return demo
129
+
130
+
131
+ if __name__ == '__main__':
132
+ hf_token = os.getenv('HF_TOKEN')
133
+ trainer = Trainer(hf_token)
134
+ demo = create_training_demo(trainer)
135
+ demo.queue(max_size=1).launch(share=False)
app_upload.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import pathlib
6
+
7
+ import gradio as gr
8
+ import slugify
9
+
10
+ from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
11
+ from uploader import Uploader
12
+ from utils import find_exp_dirs
13
+
14
+
15
+ class ModelUploader(Uploader):
16
+ def upload_model(
17
+ self,
18
+ folder_path: str,
19
+ repo_name: str,
20
+ upload_to: str,
21
+ private: bool,
22
+ delete_existing_repo: bool,
23
+ input_token: str | None = None,
24
+ ) -> str:
25
+ if not folder_path:
26
+ raise ValueError
27
+ if not repo_name:
28
+ repo_name = pathlib.Path(folder_path).name
29
+ repo_name = slugify.slugify(repo_name)
30
+
31
+ if upload_to == UploadTarget.PERSONAL_PROFILE.value:
32
+ organization = ''
33
+ elif upload_to == UploadTarget.MODEL_LIBRARY.value:
34
+ organization = MODEL_LIBRARY_ORG_NAME
35
+ else:
36
+ raise ValueError
37
+
38
+ return self.upload(folder_path,
39
+ repo_name,
40
+ organization=organization,
41
+ private=private,
42
+ delete_existing_repo=delete_existing_repo,
43
+ input_token=input_token)
44
+
45
+
46
+ def load_local_model_list() -> dict:
47
+ choices = find_exp_dirs()
48
+ return gr.update(choices=choices, value=choices[0] if choices else None)
49
+
50
+
51
+ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
52
+ uploader = ModelUploader(hf_token)
53
+ model_dirs = find_exp_dirs()
54
+
55
+ with gr.Blocks() as demo:
56
+ with gr.Box():
57
+ gr.Markdown('Local Models')
58
+ reload_button = gr.Button('Reload Model List')
59
+ model_dir = gr.Dropdown(
60
+ label='Model names',
61
+ choices=model_dirs,
62
+ value=model_dirs[0] if model_dirs else None)
63
+ with gr.Box():
64
+ gr.Markdown('Upload Settings')
65
+ with gr.Row():
66
+ use_private_repo = gr.Checkbox(label='Private', value=True)
67
+ delete_existing_repo = gr.Checkbox(
68
+ label='Delete existing repo of the same name', value=False)
69
+ upload_to = gr.Radio(label='Upload to',
70
+ choices=[_.value for _ in UploadTarget],
71
+ value=UploadTarget.MODEL_LIBRARY.value)
72
+ model_name = gr.Textbox(label='Model Name')
73
+ input_token = gr.Text(label='Hugging Face Write Token',
74
+ placeholder='',
75
+ visible=False if hf_token else True)
76
+ upload_button = gr.Button('Upload')
77
+ gr.Markdown(f'''
78
+ - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
79
+ ''')
80
+ with gr.Box():
81
+ gr.Markdown('Output message')
82
+ output_message = gr.Markdown()
83
+
84
+ reload_button.click(fn=load_local_model_list,
85
+ inputs=None,
86
+ outputs=model_dir)
87
+ upload_button.click(fn=uploader.upload_model,
88
+ inputs=[
89
+ model_dir,
90
+ model_name,
91
+ upload_to,
92
+ use_private_repo,
93
+ delete_existing_repo,
94
+ input_token,
95
+ ],
96
+ outputs=output_message)
97
+
98
+ return demo
99
+
100
+
101
+ if __name__ == '__main__':
102
+ import os
103
+
104
+ hf_token = os.getenv('HF_TOKEN')
105
+ demo = create_upload_demo(hf_token)
106
+ demo.queue(max_size=1).launch(share=False)
constants.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+
3
+
4
+ class UploadTarget(enum.Enum):
5
+ PERSONAL_PROFILE = 'Personal Profile'
6
+ MODEL_LIBRARY = 'Tune-A-Video Library'
7
+
8
+
9
+ MODEL_LIBRARY_ORG_NAME = 'Tune-A-Video-library'
10
+ SAMPLE_MODEL_REPO = 'Tune-A-Video-library/a-man-is-surfing'
inference.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import pathlib
5
+ import sys
6
+ import tempfile
7
+
8
+ import gradio as gr
9
+ import imageio
10
+ import PIL.Image
11
+ import torch
12
+ from diffusers.utils.import_utils import is_xformers_available
13
+ from einops import rearrange
14
+ from huggingface_hub import ModelCard
15
+
16
+ sys.path.append('Tune-A-Video')
17
+
18
+ from tuneavideo.models.unet import UNet3DConditionModel
19
+ from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
20
+
21
+
22
+ class InferencePipeline:
23
+ def __init__(self, hf_token: str | None = None):
24
+ self.hf_token = hf_token
25
+ self.pipe = None
26
+ self.device = torch.device(
27
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
28
+ self.model_id = None
29
+
30
+ def clear(self) -> None:
31
+ self.model_id = None
32
+ del self.pipe
33
+ self.pipe = None
34
+ torch.cuda.empty_cache()
35
+ gc.collect()
36
+
37
+ @staticmethod
38
+ def check_if_model_is_local(model_id: str) -> bool:
39
+ return pathlib.Path(model_id).exists()
40
+
41
+ @staticmethod
42
+ def get_model_card(model_id: str,
43
+ hf_token: str | None = None) -> ModelCard:
44
+ if InferencePipeline.check_if_model_is_local(model_id):
45
+ card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
46
+ else:
47
+ card_path = model_id
48
+ return ModelCard.load(card_path, token=hf_token)
49
+
50
+ @staticmethod
51
+ def get_base_model_info(model_id: str, hf_token: str | None = None) -> str:
52
+ card = InferencePipeline.get_model_card(model_id, hf_token)
53
+ return card.data.base_model
54
+
55
+ def load_pipe(self, model_id: str) -> None:
56
+ if model_id == self.model_id:
57
+ return
58
+ base_model_id = self.get_base_model_info(model_id, self.hf_token)
59
+ unet = UNet3DConditionModel.from_pretrained(
60
+ model_id,
61
+ subfolder='unet',
62
+ torch_dtype=torch.float16,
63
+ use_auth_token=self.hf_token)
64
+ pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
65
+ unet=unet,
66
+ torch_dtype=torch.float16,
67
+ use_auth_token=self.hf_token)
68
+ pipe = pipe.to(self.device)
69
+ if is_xformers_available():
70
+ pipe.unet.enable_xformers_memory_efficient_attention()
71
+ self.pipe = pipe
72
+ self.model_id = model_id # type: ignore
73
+
74
+ def run(
75
+ self,
76
+ model_id: str,
77
+ prompt: str,
78
+ video_length: int,
79
+ fps: int,
80
+ seed: int,
81
+ n_steps: int,
82
+ guidance_scale: float,
83
+ ) -> PIL.Image.Image:
84
+ if not torch.cuda.is_available():
85
+ raise gr.Error('CUDA is not available.')
86
+
87
+ self.load_pipe(model_id)
88
+
89
+ generator = torch.Generator(device=self.device).manual_seed(seed)
90
+ out = self.pipe(
91
+ prompt,
92
+ video_length=video_length,
93
+ width=512,
94
+ height=512,
95
+ num_inference_steps=n_steps,
96
+ guidance_scale=guidance_scale,
97
+ generator=generator,
98
+ ) # type: ignore
99
+
100
+ frames = rearrange(out.videos[0], 'c t h w -> t h w c')
101
+ frames = (frames * 255).to(torch.uint8).numpy()
102
+
103
+ out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
104
+ writer = imageio.get_writer(out_file.name, fps=fps)
105
+ for frame in frames:
106
+ writer.append_data(frame)
107
+ writer.close()
108
+
109
+ return out_file.name
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
patch ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/train_tuneavideo.py b/train_tuneavideo.py
2
+ index 66d51b2..86b2a5d 100644
3
+ --- a/train_tuneavideo.py
4
+ +++ b/train_tuneavideo.py
5
+ @@ -94,8 +94,8 @@ def main(
6
+
7
+ # Handle the output folder creation
8
+ if accelerator.is_main_process:
9
+ - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
10
+ - output_dir = os.path.join(output_dir, now)
11
+ + #now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
12
+ + #output_dir = os.path.join(output_dir, now)
13
+ os.makedirs(output_dir, exist_ok=True)
14
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
15
+
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.15.0
2
+ bitsandbytes==0.35.4
3
+ decord==0.6.0
4
+ diffusers[torch]==0.11.1
5
+ einops==0.6.0
6
+ ftfy==6.1.1
7
+ gradio==3.16.2
8
+ huggingface-hub==0.12.0
9
+ imageio==2.25.0
10
+ imageio-ffmpeg==0.4.8
11
+ omegaconf==2.3.0
12
+ Pillow==9.4.0
13
+ python-slugify==7.0.0
14
+ tensorboard==2.11.2
15
+ torch==1.13.1
16
+ torchvision==0.14.1
17
+ transformers==4.26.0
18
+ triton==2.0.0.dev20221202
19
+ xformers==0.0.16
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
trainer.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import os
5
+ import pathlib
6
+ import shlex
7
+ import shutil
8
+ import subprocess
9
+ import sys
10
+
11
+ import gradio as gr
12
+ import slugify
13
+ import torch
14
+ from huggingface_hub import HfApi
15
+ from omegaconf import OmegaConf
16
+
17
+ from app_upload import ModelUploader
18
+ from utils import save_model_card
19
+
20
+ sys.path.append('Tune-A-Video')
21
+
22
+ URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
23
+ ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
24
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
25
+
26
+
27
+ class Trainer:
28
+ def __init__(self, hf_token: str | None = None):
29
+ self.hf_token = hf_token
30
+ self.model_uploader = ModelUploader(hf_token)
31
+
32
+ self.checkpoint_dir = pathlib.Path('checkpoints')
33
+ self.checkpoint_dir.mkdir(exist_ok=True)
34
+
35
+ def download_base_model(self, base_model_id: str) -> str:
36
+ model_dir = self.checkpoint_dir / base_model_id
37
+ if not model_dir.exists():
38
+ org_name = base_model_id.split('/')[0]
39
+ org_dir = self.checkpoint_dir / org_name
40
+ org_dir.mkdir(exist_ok=True)
41
+ subprocess.run(shlex.split(
42
+ f'git clone https://huggingface.co/{base_model_id}'),
43
+ cwd=org_dir)
44
+ return model_dir.as_posix()
45
+
46
+ def join_model_library_org(self, token: str) -> None:
47
+ subprocess.run(
48
+ shlex.split(
49
+ f'curl -X POST -H "Authorization: Bearer {token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
50
+ ))
51
+
52
+ def run(
53
+ self,
54
+ training_video: str,
55
+ training_prompt: str,
56
+ output_model_name: str,
57
+ overwrite_existing_model: bool,
58
+ validation_prompt: str,
59
+ base_model: str,
60
+ resolution_s: str,
61
+ n_steps: int,
62
+ learning_rate: float,
63
+ gradient_accumulation: int,
64
+ seed: int,
65
+ fp16: bool,
66
+ use_8bit_adam: bool,
67
+ checkpointing_steps: int,
68
+ validation_epochs: int,
69
+ upload_to_hub: bool,
70
+ use_private_repo: bool,
71
+ delete_existing_repo: bool,
72
+ upload_to: str,
73
+ remove_gpu_after_training: bool,
74
+ input_token: str,
75
+ ) -> str:
76
+ if SPACE_ID == ORIGINAL_SPACE_ID:
77
+ raise gr.Error(
78
+ 'This Space does not work on this Shared UI. Duplicate the Space and attribute a GPU'
79
+ )
80
+ if not torch.cuda.is_available():
81
+ raise gr.Error('CUDA is not available.')
82
+ if training_video is None:
83
+ raise gr.Error('You need to upload a video.')
84
+ if not training_prompt:
85
+ raise gr.Error('The training prompt is missing.')
86
+ if not validation_prompt:
87
+ raise gr.Error('The validation prompt is missing.')
88
+
89
+ resolution = int(resolution_s)
90
+
91
+ if not output_model_name:
92
+ timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
93
+ output_model_name = f'tune-a-video-{timestamp}'
94
+ output_model_name = slugify.slugify(output_model_name)
95
+
96
+ repo_dir = pathlib.Path(__file__).parent
97
+ output_dir = repo_dir / 'experiments' / output_model_name
98
+ if overwrite_existing_model or upload_to_hub:
99
+ shutil.rmtree(output_dir, ignore_errors=True)
100
+ output_dir.mkdir(parents=True)
101
+
102
+ if upload_to_hub:
103
+ self.join_model_library_org(
104
+ self.hf_token if self.hf_token else input_token)
105
+
106
+ config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml')
107
+ config.pretrained_model_path = self.download_base_model(base_model)
108
+ config.output_dir = output_dir.as_posix()
109
+ config.train_data.video_path = training_video.name # type: ignore
110
+ config.train_data.prompt = training_prompt
111
+ config.train_data.n_sample_frames = 8
112
+ config.train_data.width = resolution
113
+ config.train_data.height = resolution
114
+ config.train_data.sample_start_idx = 0
115
+ config.train_data.sample_frame_rate = 1
116
+ config.validation_data.prompts = [validation_prompt]
117
+ config.validation_data.video_length = 8
118
+ config.validation_data.width = resolution
119
+ config.validation_data.height = resolution
120
+ config.validation_data.num_inference_steps = 50
121
+ config.validation_data.guidance_scale = 7.5
122
+ config.learning_rate = learning_rate
123
+ config.gradient_accumulation_steps = gradient_accumulation
124
+ config.train_batch_size = 1
125
+ config.max_train_steps = n_steps
126
+ config.checkpointing_steps = checkpointing_steps
127
+ config.validation_steps = validation_epochs
128
+ config.seed = seed
129
+ config.mixed_precision = 'fp16' if fp16 else ''
130
+ config.use_8bit_adam = use_8bit_adam
131
+
132
+ config_path = output_dir / 'config.yaml'
133
+ with open(config_path, 'w') as f:
134
+ OmegaConf.save(config, f)
135
+
136
+ command = f'accelerate launch Tune-A-Video/train_tuneavideo.py --config {config_path}'
137
+ subprocess.run(shlex.split(command))
138
+ save_model_card(save_dir=output_dir,
139
+ base_model=base_model,
140
+ training_prompt=training_prompt,
141
+ test_prompt=validation_prompt,
142
+ test_image_dir='samples')
143
+
144
+ message = 'Training completed!'
145
+ print(message)
146
+
147
+ if upload_to_hub:
148
+ upload_message = self.model_uploader.upload_model(
149
+ folder_path=output_dir.as_posix(),
150
+ repo_name=output_model_name,
151
+ upload_to=upload_to,
152
+ private=use_private_repo,
153
+ delete_existing_repo=delete_existing_repo,
154
+ input_token=input_token)
155
+ print(upload_message)
156
+ message = message + '\n' + upload_message
157
+
158
+ if remove_gpu_after_training:
159
+ space_id = os.getenv('SPACE_ID')
160
+ if space_id:
161
+ api = HfApi(
162
+ token=self.hf_token if self.hf_token else input_token)
163
+ api.request_space_hardware(repo_id=space_id,
164
+ hardware='cpu-basic')
165
+
166
+ return message
uploader.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from huggingface_hub import HfApi
4
+
5
+
6
+ class Uploader:
7
+ def __init__(self, hf_token: str | None):
8
+ self.hf_token = hf_token
9
+
10
+ def upload(self,
11
+ folder_path: str,
12
+ repo_name: str,
13
+ organization: str = '',
14
+ repo_type: str = 'model',
15
+ private: bool = True,
16
+ delete_existing_repo: bool = False,
17
+ input_token: str | None = None) -> str:
18
+
19
+ api = HfApi(token=self.hf_token if self.hf_token else input_token)
20
+
21
+ if not folder_path:
22
+ raise ValueError
23
+ if not repo_name:
24
+ raise ValueError
25
+ if not organization:
26
+ organization = api.whoami()['name']
27
+
28
+ repo_id = f'{organization}/{repo_name}'
29
+ if delete_existing_repo:
30
+ try:
31
+ api.delete_repo(repo_id, repo_type=repo_type)
32
+ except Exception:
33
+ pass
34
+ try:
35
+ api.create_repo(repo_id, repo_type=repo_type, private=private)
36
+ api.upload_folder(repo_id=repo_id,
37
+ folder_path=folder_path,
38
+ path_in_repo='.',
39
+ repo_type=repo_type)
40
+ url = f'https://huggingface.co/{repo_id}'
41
+ message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
42
+ except Exception as e:
43
+ message = str(e)
44
+ return message
utils.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import pathlib
4
+
5
+
6
+ def find_exp_dirs() -> list[str]:
7
+ repo_dir = pathlib.Path(__file__).parent
8
+ exp_root_dir = repo_dir / 'experiments'
9
+ if not exp_root_dir.exists():
10
+ return []
11
+ exp_dirs = sorted(exp_root_dir.glob('*'))
12
+ exp_dirs = [
13
+ exp_dir for exp_dir in exp_dirs
14
+ if (exp_dir / 'model_index.json').exists()
15
+ ]
16
+ return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
17
+
18
+
19
+ def save_model_card(
20
+ save_dir: pathlib.Path,
21
+ base_model: str,
22
+ training_prompt: str,
23
+ test_prompt: str = '',
24
+ test_image_dir: str = '',
25
+ ) -> None:
26
+ image_str = ''
27
+ if test_prompt and test_image_dir:
28
+ image_paths = sorted((save_dir / test_image_dir).glob('*.gif'))
29
+ if image_paths:
30
+ image_path = image_paths[-1]
31
+ rel_path = image_path.relative_to(save_dir)
32
+ image_str = f'''## Samples
33
+ Test prompt: {test_prompt}
34
+
35
+ ![{image_path.stem}]({rel_path})'''
36
+
37
+ model_card = f'''---
38
+ license: creativeml-openrail-m
39
+ base_model: {base_model}
40
+ training_prompt: {training_prompt}
41
+ tags:
42
+ - stable-diffusion
43
+ - stable-diffusion-diffusers
44
+ - text-to-image
45
+ - diffusers
46
+ - text-to-video
47
+ - tune-a-video
48
+ inference: false
49
+ ---
50
+
51
+ # Tune-A-Video - {save_dir.name}
52
+
53
+ ## Model description
54
+ - Base model: [{base_model}](https://huggingface.co/{base_model})
55
+ - Training prompt: {training_prompt}
56
+
57
+ {image_str}
58
+
59
+ ## Related papers:
60
+ - [Tune-A-Video](https://arxiv.org/abs/2212.11565): One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation
61
+ - [Stable-Diffusion](https://arxiv.org/abs/2112.10752): High-Resolution Image Synthesis with Latent Diffusion Models
62
+ '''
63
+
64
+ with open(save_dir / 'README.md', 'w') as f:
65
+ f.write(model_card)