mingshuaii hysts HF staff commited on
Commit
7537243
0 Parent(s):

Duplicate from Tune-A-Video-library/Tune-A-Video-Training-UI

Browse files

Co-authored-by: hysts <hysts@users.noreply.huggingface.co>

Files changed (24) hide show
  1. .gitattributes +35 -0
  2. .gitignore +164 -0
  3. .gitmodules +3 -0
  4. .pre-commit-config.yaml +37 -0
  5. .style.yapf +5 -0
  6. Dockerfile +59 -0
  7. LICENSE +21 -0
  8. README.md +12 -0
  9. Tune-A-Video +1 -0
  10. app.py +93 -0
  11. app_inference.py +172 -0
  12. app_system_monitor.py +87 -0
  13. app_training.py +155 -0
  14. app_upload.py +69 -0
  15. constants.py +11 -0
  16. inference.py +109 -0
  17. packages.txt +1 -0
  18. patch +15 -0
  19. requirements-monitor.txt +4 -0
  20. requirements.txt +19 -0
  21. style.css +3 -0
  22. trainer.py +145 -0
  23. uploader.py +63 -0
  24. utils.py +65 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.whl filter=lfs diff=lfs merge=lfs -text
2
+ *.7z filter=lfs diff=lfs merge=lfs -text
3
+ *.arrow filter=lfs diff=lfs merge=lfs -text
4
+ *.bin filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
7
+ *.ftz filter=lfs diff=lfs merge=lfs -text
8
+ *.gz filter=lfs diff=lfs merge=lfs -text
9
+ *.h5 filter=lfs diff=lfs merge=lfs -text
10
+ *.joblib filter=lfs diff=lfs merge=lfs -text
11
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
12
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
13
+ *.model filter=lfs diff=lfs merge=lfs -text
14
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
15
+ *.npy filter=lfs diff=lfs merge=lfs -text
16
+ *.npz filter=lfs diff=lfs merge=lfs -text
17
+ *.onnx filter=lfs diff=lfs merge=lfs -text
18
+ *.ot filter=lfs diff=lfs merge=lfs -text
19
+ *.parquet filter=lfs diff=lfs merge=lfs -text
20
+ *.pb filter=lfs diff=lfs merge=lfs -text
21
+ *.pickle filter=lfs diff=lfs merge=lfs -text
22
+ *.pkl filter=lfs diff=lfs merge=lfs -text
23
+ *.pt filter=lfs diff=lfs merge=lfs -text
24
+ *.pth filter=lfs diff=lfs merge=lfs -text
25
+ *.rar filter=lfs diff=lfs merge=lfs -text
26
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
27
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ checkpoints/
2
+ experiments/
3
+
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "Tune-A-Video"]
2
+ path = Tune-A-Video
3
+ url = https://github.com/showlab/Tune-A-Video
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: patch
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
Dockerfile ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
2
+ ENV DEBIAN_FRONTEND=noninteractive
3
+ RUN apt-get update && \
4
+ apt-get upgrade -y && \
5
+ apt-get install -y --no-install-recommends \
6
+ git \
7
+ git-lfs \
8
+ wget \
9
+ curl \
10
+ # ffmpeg \
11
+ ffmpeg \
12
+ x264 \
13
+ # python build dependencies \
14
+ build-essential \
15
+ libssl-dev \
16
+ zlib1g-dev \
17
+ libbz2-dev \
18
+ libreadline-dev \
19
+ libsqlite3-dev \
20
+ libncursesw5-dev \
21
+ xz-utils \
22
+ tk-dev \
23
+ libxml2-dev \
24
+ libxmlsec1-dev \
25
+ libffi-dev \
26
+ liblzma-dev && \
27
+ apt-get clean && \
28
+ rm -rf /var/lib/apt/lists/*
29
+
30
+ RUN useradd -m -u 1000 user
31
+ USER user
32
+ ENV HOME=/home/user \
33
+ PATH=/home/user/.local/bin:${PATH}
34
+ WORKDIR ${HOME}/app
35
+
36
+ RUN curl https://pyenv.run | bash
37
+ ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
38
+ ARG PYTHON_VERSION=3.10.11
39
+ RUN pyenv install ${PYTHON_VERSION} && \
40
+ pyenv global ${PYTHON_VERSION} && \
41
+ pyenv rehash && \
42
+ pip install --no-cache-dir -U pip setuptools wheel
43
+
44
+ RUN pip install --no-cache-dir -U torch==1.13.1 torchvision==0.14.1
45
+ COPY --chown=1000 requirements.txt /tmp/requirements.txt
46
+ RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
47
+ COPY --chown=1000 requirements-monitor.txt /tmp/requirements-monitor.txt
48
+ RUN pip install --no-cache-dir -U -r /tmp/requirements-monitor.txt
49
+
50
+ COPY --chown=1000 . ${HOME}/app
51
+ RUN cd Tune-A-Video && patch -p1 < ../patch
52
+ ENV PYTHONPATH=${HOME}/app \
53
+ PYTHONUNBUFFERED=1 \
54
+ GRADIO_ALLOW_FLAGGING=never \
55
+ GRADIO_NUM_PORTS=1 \
56
+ GRADIO_SERVER_NAME=0.0.0.0 \
57
+ GRADIO_THEME=huggingface \
58
+ SYSTEM=spaces
59
+ CMD ["python", "app.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Tune-A-Video Training UI
3
+ emoji: ⚡
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ duplicated_from: Tune-A-Video-library/Tune-A-Video-Training-UI
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Tune-A-Video ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit b2c8c3eeac0df5c5d9eccc4dd2153e17b83c638c
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from subprocess import getoutput
7
+
8
+ import gradio as gr
9
+ import torch
10
+
11
+ from app_inference import create_inference_demo
12
+ from app_system_monitor import create_monitor_demo
13
+ from app_training import create_training_demo
14
+ from app_upload import create_upload_demo
15
+ from inference import InferencePipeline
16
+ from trainer import Trainer
17
+
18
+ TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/)'
19
+
20
+ ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
21
+ SPACE_ID = os.getenv('SPACE_ID')
22
+ GPU_DATA = getoutput('nvidia-smi')
23
+ SHARED_UI_WARNING = f'''## Attention - Training doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
24
+
25
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
26
+ '''
27
+
28
+ IS_SHARED_UI = SPACE_ID == ORIGINAL_SPACE_ID
29
+ if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
30
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
31
+ else:
32
+ SETTINGS = 'Settings'
33
+
34
+ INVALID_GPU_WARNING = f'''## Attention - the specified GPU is invalid. Training may not work. Make sure you have selected a `T4 GPU` for this task.'''
35
+
36
+ CUDA_NOT_AVAILABLE_WARNING = f'''## Attention - Running on CPU.
37
+ <center>
38
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
39
+ You can use "T4 small/medium" to run this demo.
40
+ </center>
41
+ '''
42
+
43
+ HF_TOKEN_NOT_SPECIFIED_WARNING = f'''The environment variable `HF_TOKEN` is not specified. Feel free to specify your Hugging Face token with write permission if you don't want to manually provide it for every run.
44
+
45
+ You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>. You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
46
+ '''
47
+
48
+ HF_TOKEN = os.getenv('HF_TOKEN')
49
+
50
+
51
+ def show_warning(warning_text: str) -> gr.Blocks:
52
+ with gr.Blocks() as demo:
53
+ with gr.Box():
54
+ gr.Markdown(warning_text)
55
+ return demo
56
+
57
+
58
+ pipe = InferencePipeline(HF_TOKEN)
59
+ trainer = Trainer()
60
+
61
+ with gr.Blocks(css='style.css') as demo:
62
+ if IS_SHARED_UI:
63
+ show_warning(SHARED_UI_WARNING)
64
+ elif not torch.cuda.is_available():
65
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
66
+ elif 'T4' not in GPU_DATA:
67
+ show_warning(INVALID_GPU_WARNING)
68
+
69
+ gr.Markdown(TITLE)
70
+ with gr.Tabs():
71
+ with gr.TabItem('Train'):
72
+ create_training_demo(trainer,
73
+ pipe,
74
+ disable_run_button=IS_SHARED_UI)
75
+ with gr.TabItem('Run'):
76
+ create_inference_demo(pipe,
77
+ HF_TOKEN,
78
+ disable_run_button=IS_SHARED_UI)
79
+ with gr.TabItem('Upload'):
80
+ gr.Markdown('''
81
+ - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
82
+ ''')
83
+ create_upload_demo(disable_run_button=IS_SHARED_UI)
84
+
85
+ with gr.Row():
86
+ if not IS_SHARED_UI and not os.getenv('DISABLE_SYSTEM_MONITOR'):
87
+ with gr.Accordion(label='System info', open=False):
88
+ create_monitor_demo()
89
+
90
+ if not HF_TOKEN:
91
+ show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
92
+
93
+ demo.queue(api_open=False, max_size=1).launch()
app_inference.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import enum
6
+
7
+ import gradio as gr
8
+ from huggingface_hub import HfApi
9
+
10
+ from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
11
+ from inference import InferencePipeline
12
+ from utils import find_exp_dirs
13
+
14
+
15
+ class ModelSource(enum.Enum):
16
+ HUB_LIB = UploadTarget.MODEL_LIBRARY.value
17
+ LOCAL = 'Local'
18
+
19
+
20
+ class InferenceUtil:
21
+ def __init__(self, hf_token: str | None):
22
+ self.hf_token = hf_token
23
+
24
+ def load_hub_model_list(self) -> dict:
25
+ api = HfApi(token=self.hf_token)
26
+ choices = [
27
+ info.modelId
28
+ for info in api.list_models(author=MODEL_LIBRARY_ORG_NAME)
29
+ ]
30
+ return gr.update(choices=choices,
31
+ value=choices[0] if choices else None)
32
+
33
+ @staticmethod
34
+ def load_local_model_list() -> dict:
35
+ choices = find_exp_dirs()
36
+ return gr.update(choices=choices,
37
+ value=choices[0] if choices else None)
38
+
39
+ def reload_model_list(self, model_source: str) -> dict:
40
+ if model_source == ModelSource.HUB_LIB.value:
41
+ return self.load_hub_model_list()
42
+ elif model_source == ModelSource.LOCAL.value:
43
+ return self.load_local_model_list()
44
+ else:
45
+ raise ValueError
46
+
47
+ def load_model_info(self, model_id: str) -> tuple[str, str]:
48
+ try:
49
+ card = InferencePipeline.get_model_card(model_id, self.hf_token)
50
+ except Exception:
51
+ return '', ''
52
+ base_model = getattr(card.data, 'base_model', '')
53
+ training_prompt = getattr(card.data, 'training_prompt', '')
54
+ return base_model, training_prompt
55
+
56
+ def reload_model_list_and_update_model_info(
57
+ self, model_source: str) -> tuple[dict, str, str]:
58
+ model_list_update = self.reload_model_list(model_source)
59
+ model_list = model_list_update['choices']
60
+ model_info = self.load_model_info(model_list[0] if model_list else '')
61
+ return model_list_update, *model_info
62
+
63
+
64
+ def create_inference_demo(pipe: InferencePipeline,
65
+ hf_token: str | None = None,
66
+ disable_run_button: bool = False) -> gr.Blocks:
67
+ app = InferenceUtil(hf_token)
68
+
69
+ with gr.Blocks() as demo:
70
+ with gr.Row():
71
+ with gr.Column():
72
+ with gr.Box():
73
+ model_source = gr.Radio(
74
+ label='Model Source',
75
+ choices=[_.value for _ in ModelSource],
76
+ value=ModelSource.HUB_LIB.value)
77
+ reload_button = gr.Button('Reload Model List')
78
+ model_id = gr.Dropdown(label='Model ID',
79
+ choices=None,
80
+ value=None)
81
+ with gr.Accordion(
82
+ label=
83
+ 'Model info (Base model and prompt used for training)',
84
+ open=False):
85
+ with gr.Row():
86
+ base_model_used_for_training = gr.Text(
87
+ label='Base model', interactive=False)
88
+ prompt_used_for_training = gr.Text(
89
+ label='Training prompt', interactive=False)
90
+ prompt = gr.Textbox(
91
+ label='Prompt',
92
+ max_lines=1,
93
+ placeholder='Example: "A panda is surfing"')
94
+ video_length = gr.Slider(label='Video length',
95
+ minimum=4,
96
+ maximum=12,
97
+ step=1,
98
+ value=8)
99
+ fps = gr.Slider(label='FPS',
100
+ minimum=1,
101
+ maximum=12,
102
+ step=1,
103
+ value=1)
104
+ seed = gr.Slider(label='Seed',
105
+ minimum=0,
106
+ maximum=100000,
107
+ step=1,
108
+ value=0)
109
+ with gr.Accordion('Advanced options', open=False):
110
+ num_steps = gr.Slider(label='Number of Steps',
111
+ minimum=0,
112
+ maximum=100,
113
+ step=1,
114
+ value=50)
115
+ guidance_scale = gr.Slider(label='Guidance scale',
116
+ minimum=0,
117
+ maximum=50,
118
+ step=0.1,
119
+ value=7.5)
120
+
121
+ run_button = gr.Button('Generate',
122
+ interactive=not disable_run_button)
123
+
124
+ gr.Markdown('''
125
+ - After training, you can press "Reload Model List" button to load your trained model names.
126
+ - It takes a few minutes to download model first.
127
+ - Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
128
+ ''')
129
+ with gr.Column():
130
+ result = gr.Video(label='Result')
131
+
132
+ model_source.change(fn=app.reload_model_list_and_update_model_info,
133
+ inputs=model_source,
134
+ outputs=[
135
+ model_id,
136
+ base_model_used_for_training,
137
+ prompt_used_for_training,
138
+ ])
139
+ reload_button.click(fn=app.reload_model_list_and_update_model_info,
140
+ inputs=model_source,
141
+ outputs=[
142
+ model_id,
143
+ base_model_used_for_training,
144
+ prompt_used_for_training,
145
+ ])
146
+ model_id.change(fn=app.load_model_info,
147
+ inputs=model_id,
148
+ outputs=[
149
+ base_model_used_for_training,
150
+ prompt_used_for_training,
151
+ ])
152
+ inputs = [
153
+ model_id,
154
+ prompt,
155
+ video_length,
156
+ fps,
157
+ seed,
158
+ num_steps,
159
+ guidance_scale,
160
+ ]
161
+ prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
162
+ run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
163
+ return demo
164
+
165
+
166
+ if __name__ == '__main__':
167
+ import os
168
+
169
+ hf_token = os.getenv('HF_TOKEN')
170
+ pipe = InferencePipeline(hf_token)
171
+ demo = create_inference_demo(pipe, hf_token)
172
+ demo.queue(api_open=False, max_size=10).launch()
app_system_monitor.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import collections
6
+
7
+ import gradio as gr
8
+ import nvitop
9
+ import pandas as pd
10
+ import plotly.express as px
11
+ import psutil
12
+
13
+
14
+ class SystemMonitor:
15
+ MAX_SIZE = 61
16
+
17
+ def __init__(self):
18
+ self.devices = nvitop.Device.all()
19
+ self.cpu_memory_usage = collections.deque(
20
+ [0 for _ in range(self.MAX_SIZE)], maxlen=self.MAX_SIZE)
21
+ self.cpu_memory_usage_str = ''
22
+ self.gpu_memory_usage = collections.deque(
23
+ [0 for _ in range(self.MAX_SIZE)], maxlen=self.MAX_SIZE)
24
+ self.gpu_util = collections.deque([0 for _ in range(self.MAX_SIZE)],
25
+ maxlen=self.MAX_SIZE)
26
+ self.gpu_memory_usage_str = ''
27
+ self.gpu_util_str = ''
28
+
29
+ def update(self) -> None:
30
+ self.update_cpu()
31
+ self.update_gpu()
32
+
33
+ def update_cpu(self) -> None:
34
+ memory = psutil.virtual_memory()
35
+ self.cpu_memory_usage.append(memory.percent)
36
+ self.cpu_memory_usage_str = f'{memory.used / 1024**3:0.2f}GiB / {memory.total / 1024**3:0.2f}GiB ({memory.percent}%)'
37
+
38
+ def update_gpu(self) -> None:
39
+ if not self.devices:
40
+ return
41
+ device = self.devices[0]
42
+ self.gpu_memory_usage.append(device.memory_percent())
43
+ self.gpu_util.append(device.gpu_utilization())
44
+ self.gpu_memory_usage_str = f'{device.memory_usage()} ({device.memory_percent()}%)'
45
+ self.gpu_util_str = f'{device.gpu_utilization()}%'
46
+
47
+ def get_json(self) -> dict[str, str]:
48
+ return {
49
+ 'CPU memory usage': self.cpu_memory_usage_str,
50
+ 'GPU memory usage': self.gpu_memory_usage_str,
51
+ 'GPU Util': self.gpu_util_str,
52
+ }
53
+
54
+ def get_graph_data(self) -> dict[str, list[int | float]]:
55
+ return {
56
+ 'index': list(range(-self.MAX_SIZE + 1, 1)),
57
+ 'CPU memory usage': self.cpu_memory_usage,
58
+ 'GPU memory usage': self.gpu_memory_usage,
59
+ 'GPU Util': self.gpu_util,
60
+ }
61
+
62
+ def get_graph(self):
63
+ df = pd.DataFrame(self.get_graph_data())
64
+ return px.line(df,
65
+ x='index',
66
+ y=[
67
+ 'CPU memory usage',
68
+ 'GPU memory usage',
69
+ 'GPU Util',
70
+ ],
71
+ range_y=[-5,
72
+ 105]).update_layout(xaxis_title='Time',
73
+ yaxis_title='Percentage')
74
+
75
+
76
+ def create_monitor_demo() -> gr.Blocks:
77
+ monitor = SystemMonitor()
78
+ with gr.Blocks() as demo:
79
+ gr.JSON(value=monitor.update, every=1, visible=False)
80
+ gr.JSON(value=monitor.get_json, show_label=False, every=1)
81
+ gr.Plot(value=monitor.get_graph, show_label=False, every=1)
82
+ return demo
83
+
84
+
85
+ if __name__ == '__main__':
86
+ demo = create_monitor_demo()
87
+ demo.queue(api_open=False).launch()
app_training.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+
9
+ from constants import UploadTarget
10
+ from inference import InferencePipeline
11
+ from trainer import Trainer
12
+
13
+
14
+ def create_training_demo(trainer: Trainer,
15
+ pipe: InferencePipeline | None = None,
16
+ disable_run_button: bool = False) -> gr.Blocks:
17
+ def read_log() -> str:
18
+ with open(trainer.log_file) as f:
19
+ lines = f.readlines()
20
+ return ''.join(lines[-10:])
21
+
22
+ with gr.Blocks() as demo:
23
+ with gr.Row():
24
+ with gr.Column():
25
+ with gr.Box():
26
+ gr.Markdown('Training Data')
27
+ training_video = gr.File(label='Training video')
28
+ training_prompt = gr.Textbox(
29
+ label='Training prompt',
30
+ max_lines=1,
31
+ placeholder='A man is surfing')
32
+ gr.Markdown('''
33
+ - Upload a video and write a `Training Prompt` that describes the video.
34
+ ''')
35
+
36
+ with gr.Column():
37
+ with gr.Box():
38
+ gr.Markdown('Training Parameters')
39
+ with gr.Row():
40
+ base_model = gr.Text(
41
+ label='Base Model',
42
+ value='CompVis/stable-diffusion-v1-4',
43
+ max_lines=1)
44
+ resolution = gr.Dropdown(choices=['512', '768'],
45
+ value='512',
46
+ label='Resolution',
47
+ visible=False)
48
+
49
+ hf_token = gr.Text(label='Hugging Face Write Token',
50
+ type='password',
51
+ visible=os.getenv('HF_TOKEN') is None)
52
+ with gr.Accordion(label='Advanced options', open=False):
53
+ num_training_steps = gr.Number(
54
+ label='Number of Training Steps',
55
+ value=300,
56
+ precision=0)
57
+ learning_rate = gr.Number(label='Learning Rate',
58
+ value=0.000035)
59
+ gradient_accumulation = gr.Number(
60
+ label='Number of Gradient Accumulation',
61
+ value=1,
62
+ precision=0)
63
+ seed = gr.Slider(label='Seed',
64
+ minimum=0,
65
+ maximum=100000,
66
+ step=1,
67
+ randomize=True,
68
+ value=0)
69
+ fp16 = gr.Checkbox(label='FP16', value=True)
70
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam',
71
+ value=False)
72
+ checkpointing_steps = gr.Number(
73
+ label='Checkpointing Steps',
74
+ value=1000,
75
+ precision=0)
76
+ validation_epochs = gr.Number(
77
+ label='Validation Epochs', value=100, precision=0)
78
+ gr.Markdown('''
79
+ - The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
80
+ - Expected time to train a model for 300 steps: ~20 minutes with T4
81
+ - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
82
+ ''')
83
+
84
+ with gr.Row():
85
+ with gr.Column():
86
+ gr.Markdown('Output Model')
87
+ output_model_name = gr.Text(label='Name of your model',
88
+ placeholder='The surfer man',
89
+ max_lines=1)
90
+ validation_prompt = gr.Text(
91
+ label='Validation Prompt',
92
+ placeholder=
93
+ 'prompt to test the model, e.g: a dog is surfing')
94
+ with gr.Column():
95
+ gr.Markdown('Upload Settings')
96
+ with gr.Row():
97
+ upload_to_hub = gr.Checkbox(label='Upload model to Hub',
98
+ value=True)
99
+ use_private_repo = gr.Checkbox(label='Private', value=True)
100
+ delete_existing_repo = gr.Checkbox(
101
+ label='Delete existing repo of the same name',
102
+ value=False)
103
+ upload_to = gr.Radio(
104
+ label='Upload to',
105
+ choices=[_.value for _ in UploadTarget],
106
+ value=UploadTarget.MODEL_LIBRARY.value)
107
+
108
+ pause_space_after_training = gr.Checkbox(
109
+ label='Pause this Space after training',
110
+ value=False,
111
+ interactive=bool(os.getenv('SPACE_ID')),
112
+ visible=False)
113
+ run_button = gr.Button('Start Training',
114
+ interactive=not disable_run_button)
115
+
116
+ with gr.Box():
117
+ gr.Text(label='Log',
118
+ value=read_log,
119
+ lines=10,
120
+ max_lines=10,
121
+ every=1)
122
+
123
+ if pipe is not None:
124
+ run_button.click(fn=pipe.clear)
125
+ run_button.click(fn=trainer.run,
126
+ inputs=[
127
+ training_video,
128
+ training_prompt,
129
+ output_model_name,
130
+ delete_existing_repo,
131
+ validation_prompt,
132
+ base_model,
133
+ resolution,
134
+ num_training_steps,
135
+ learning_rate,
136
+ gradient_accumulation,
137
+ seed,
138
+ fp16,
139
+ use_8bit_adam,
140
+ checkpointing_steps,
141
+ validation_epochs,
142
+ upload_to_hub,
143
+ use_private_repo,
144
+ delete_existing_repo,
145
+ upload_to,
146
+ pause_space_after_training,
147
+ hf_token,
148
+ ])
149
+ return demo
150
+
151
+
152
+ if __name__ == '__main__':
153
+ trainer = Trainer()
154
+ demo = create_training_demo(trainer)
155
+ demo.queue(api_open=False, max_size=1).launch()
app_upload.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+
9
+ from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
10
+ from uploader import upload
11
+ from utils import find_exp_dirs
12
+
13
+
14
+ def load_local_model_list() -> dict:
15
+ choices = find_exp_dirs()
16
+ return gr.update(choices=choices, value=choices[0] if choices else None)
17
+
18
+
19
+ def create_upload_demo(disable_run_button: bool = False) -> gr.Blocks:
20
+ model_dirs = find_exp_dirs()
21
+
22
+ with gr.Blocks() as demo:
23
+ with gr.Box():
24
+ gr.Markdown('Local Models')
25
+ reload_button = gr.Button('Reload Model List')
26
+ model_dir = gr.Dropdown(
27
+ label='Model names',
28
+ choices=model_dirs,
29
+ value=model_dirs[0] if model_dirs else None)
30
+ with gr.Box():
31
+ gr.Markdown('Upload Settings')
32
+ with gr.Row():
33
+ use_private_repo = gr.Checkbox(label='Private', value=True)
34
+ delete_existing_repo = gr.Checkbox(
35
+ label='Delete existing repo of the same name', value=False)
36
+ upload_to = gr.Radio(label='Upload to',
37
+ choices=[_.value for _ in UploadTarget],
38
+ value=UploadTarget.MODEL_LIBRARY.value)
39
+ model_name = gr.Textbox(label='Model Name')
40
+ hf_token = gr.Text(label='Hugging Face Write Token',
41
+ type='password',
42
+ visible=os.getenv('HF_TOKEN') is None)
43
+ upload_button = gr.Button('Upload', interactive=not disable_run_button)
44
+ gr.Markdown(f'''
45
+ - You can upload your trained model to your personal profile (i.e. `https://huggingface.co/{{your_username}}/{{model_name}}`) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. `https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}`).
46
+ ''')
47
+ with gr.Box():
48
+ gr.Markdown('Output message')
49
+ output_message = gr.Markdown()
50
+
51
+ reload_button.click(fn=load_local_model_list,
52
+ inputs=None,
53
+ outputs=model_dir)
54
+ upload_button.click(fn=upload,
55
+ inputs=[
56
+ model_dir,
57
+ model_name,
58
+ upload_to,
59
+ use_private_repo,
60
+ delete_existing_repo,
61
+ hf_token,
62
+ ],
63
+ outputs=output_message)
64
+ return demo
65
+
66
+
67
+ if __name__ == '__main__':
68
+ demo = create_upload_demo()
69
+ demo.queue(api_open=False, max_size=1).launch()
constants.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+
3
+
4
+ class UploadTarget(enum.Enum):
5
+ PERSONAL_PROFILE = 'Personal Profile'
6
+ MODEL_LIBRARY = 'Tune-A-Video Library'
7
+
8
+
9
+ MODEL_LIBRARY_ORG_NAME = 'Tune-A-Video-library'
10
+ SAMPLE_MODEL_REPO = 'Tune-A-Video-library/a-man-is-surfing'
11
+ URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
inference.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import pathlib
5
+ import sys
6
+ import tempfile
7
+
8
+ import gradio as gr
9
+ import imageio
10
+ import PIL.Image
11
+ import torch
12
+ from diffusers.utils.import_utils import is_xformers_available
13
+ from einops import rearrange
14
+ from huggingface_hub import ModelCard
15
+
16
+ sys.path.append('Tune-A-Video')
17
+
18
+ from tuneavideo.models.unet import UNet3DConditionModel
19
+ from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
20
+
21
+
22
+ class InferencePipeline:
23
+ def __init__(self, hf_token: str | None = None):
24
+ self.hf_token = hf_token
25
+ self.pipe = None
26
+ self.device = torch.device(
27
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
28
+ self.model_id = None
29
+
30
+ def clear(self) -> None:
31
+ self.model_id = None
32
+ del self.pipe
33
+ self.pipe = None
34
+ torch.cuda.empty_cache()
35
+ gc.collect()
36
+
37
+ @staticmethod
38
+ def check_if_model_is_local(model_id: str) -> bool:
39
+ return pathlib.Path(model_id).exists()
40
+
41
+ @staticmethod
42
+ def get_model_card(model_id: str,
43
+ hf_token: str | None = None) -> ModelCard:
44
+ if InferencePipeline.check_if_model_is_local(model_id):
45
+ card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
46
+ else:
47
+ card_path = model_id
48
+ return ModelCard.load(card_path, token=hf_token)
49
+
50
+ @staticmethod
51
+ def get_base_model_info(model_id: str, hf_token: str | None = None) -> str:
52
+ card = InferencePipeline.get_model_card(model_id, hf_token)
53
+ return card.data.base_model
54
+
55
+ def load_pipe(self, model_id: str) -> None:
56
+ if model_id == self.model_id:
57
+ return
58
+ base_model_id = self.get_base_model_info(model_id, self.hf_token)
59
+ unet = UNet3DConditionModel.from_pretrained(
60
+ model_id,
61
+ subfolder='unet',
62
+ torch_dtype=torch.float16,
63
+ use_auth_token=self.hf_token)
64
+ pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
65
+ unet=unet,
66
+ torch_dtype=torch.float16,
67
+ use_auth_token=self.hf_token)
68
+ pipe = pipe.to(self.device)
69
+ if is_xformers_available():
70
+ pipe.unet.enable_xformers_memory_efficient_attention()
71
+ self.pipe = pipe
72
+ self.model_id = model_id # type: ignore
73
+
74
+ def run(
75
+ self,
76
+ model_id: str,
77
+ prompt: str,
78
+ video_length: int,
79
+ fps: int,
80
+ seed: int,
81
+ n_steps: int,
82
+ guidance_scale: float,
83
+ ) -> PIL.Image.Image:
84
+ if not torch.cuda.is_available():
85
+ raise gr.Error('CUDA is not available.')
86
+
87
+ self.load_pipe(model_id)
88
+
89
+ generator = torch.Generator(device=self.device).manual_seed(seed)
90
+ out = self.pipe(
91
+ prompt,
92
+ video_length=video_length,
93
+ width=512,
94
+ height=512,
95
+ num_inference_steps=n_steps,
96
+ guidance_scale=guidance_scale,
97
+ generator=generator,
98
+ ) # type: ignore
99
+
100
+ frames = rearrange(out.videos[0], 'c t h w -> t h w c')
101
+ frames = (frames * 255).to(torch.uint8).numpy()
102
+
103
+ out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
104
+ writer = imageio.get_writer(out_file.name, fps=fps)
105
+ for frame in frames:
106
+ writer.append_data(frame)
107
+ writer.close()
108
+
109
+ return out_file.name
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
patch ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/train_tuneavideo.py b/train_tuneavideo.py
2
+ index 66d51b2..86b2a5d 100644
3
+ --- a/train_tuneavideo.py
4
+ +++ b/train_tuneavideo.py
5
+ @@ -94,8 +94,8 @@ def main(
6
+
7
+ # Handle the output folder creation
8
+ if accelerator.is_main_process:
9
+ - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
10
+ - output_dir = os.path.join(output_dir, now)
11
+ + #now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
12
+ + #output_dir = os.path.join(output_dir, now)
13
+ os.makedirs(output_dir, exist_ok=True)
14
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
15
+
requirements-monitor.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ nvitop==1.1.1
2
+ pandas==2.0.0
3
+ plotly==5.14.1
4
+ psutil==5.9.4
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.18.0
2
+ bitsandbytes==0.37.2
3
+ decord==0.6.0
4
+ diffusers[torch]==0.11.1
5
+ einops==0.6.0
6
+ ftfy==6.1.1
7
+ gradio==3.24.1
8
+ huggingface-hub==0.13.4
9
+ imageio==2.27.0
10
+ imageio-ffmpeg==0.4.8
11
+ omegaconf==2.3.0
12
+ Pillow==9.5.0
13
+ python-slugify==8.0.1
14
+ tensorboard==2.11.2
15
+ torch==1.13.1
16
+ torchvision==0.14.1
17
+ transformers==4.26.0
18
+ triton==2.0.0
19
+ xformers==0.0.16
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
trainer.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import os
5
+ import pathlib
6
+ import shlex
7
+ import shutil
8
+ import subprocess
9
+ import sys
10
+
11
+ import slugify
12
+ import torch
13
+ from huggingface_hub import HfApi
14
+ from omegaconf import OmegaConf
15
+
16
+ from uploader import upload
17
+ from utils import save_model_card
18
+
19
+ sys.path.append('Tune-A-Video')
20
+
21
+
22
+ class Trainer:
23
+ def __init__(self):
24
+ self.checkpoint_dir = pathlib.Path('checkpoints')
25
+ self.checkpoint_dir.mkdir(exist_ok=True)
26
+
27
+ self.log_file = pathlib.Path('log.txt')
28
+ self.log_file.touch(exist_ok=True)
29
+
30
+ def download_base_model(self, base_model_id: str) -> str:
31
+ model_dir = self.checkpoint_dir / base_model_id
32
+ if not model_dir.exists():
33
+ org_name = base_model_id.split('/')[0]
34
+ org_dir = self.checkpoint_dir / org_name
35
+ org_dir.mkdir(exist_ok=True)
36
+ subprocess.run(shlex.split(
37
+ f'git clone https://huggingface.co/{base_model_id}'),
38
+ cwd=org_dir)
39
+ return model_dir.as_posix()
40
+
41
+ def run(
42
+ self,
43
+ training_video: str,
44
+ training_prompt: str,
45
+ output_model_name: str,
46
+ overwrite_existing_model: bool,
47
+ validation_prompt: str,
48
+ base_model: str,
49
+ resolution_s: str,
50
+ n_steps: int,
51
+ learning_rate: float,
52
+ gradient_accumulation: int,
53
+ seed: int,
54
+ fp16: bool,
55
+ use_8bit_adam: bool,
56
+ checkpointing_steps: int,
57
+ validation_epochs: int,
58
+ upload_to_hub: bool,
59
+ use_private_repo: bool,
60
+ delete_existing_repo: bool,
61
+ upload_to: str,
62
+ pause_space_after_training: bool,
63
+ hf_token: str,
64
+ ) -> None:
65
+ if not torch.cuda.is_available():
66
+ raise RuntimeError('CUDA is not available.')
67
+ if training_video is None:
68
+ raise ValueError('You need to upload a video.')
69
+ if not training_prompt:
70
+ raise ValueError('The training prompt is missing.')
71
+ if not validation_prompt:
72
+ raise ValueError('The validation prompt is missing.')
73
+
74
+ resolution = int(resolution_s)
75
+
76
+ if not output_model_name:
77
+ timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
78
+ output_model_name = f'tune-a-video-{timestamp}'
79
+ output_model_name = slugify.slugify(output_model_name)
80
+
81
+ repo_dir = pathlib.Path(__file__).parent
82
+ output_dir = repo_dir / 'experiments' / output_model_name
83
+ if overwrite_existing_model or upload_to_hub:
84
+ shutil.rmtree(output_dir, ignore_errors=True)
85
+ output_dir.mkdir(parents=True)
86
+
87
+ config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml')
88
+ config.pretrained_model_path = self.download_base_model(base_model)
89
+ config.output_dir = output_dir.as_posix()
90
+ config.train_data.video_path = training_video.name # type: ignore
91
+ config.train_data.prompt = training_prompt
92
+ config.train_data.n_sample_frames = 8
93
+ config.train_data.width = resolution
94
+ config.train_data.height = resolution
95
+ config.train_data.sample_start_idx = 0
96
+ config.train_data.sample_frame_rate = 1
97
+ config.validation_data.prompts = [validation_prompt]
98
+ config.validation_data.video_length = 8
99
+ config.validation_data.width = resolution
100
+ config.validation_data.height = resolution
101
+ config.validation_data.num_inference_steps = 50
102
+ config.validation_data.guidance_scale = 7.5
103
+ config.learning_rate = learning_rate
104
+ config.gradient_accumulation_steps = gradient_accumulation
105
+ config.train_batch_size = 1
106
+ config.max_train_steps = n_steps
107
+ config.checkpointing_steps = checkpointing_steps
108
+ config.validation_steps = validation_epochs
109
+ config.seed = seed
110
+ config.mixed_precision = 'fp16' if fp16 else ''
111
+ config.use_8bit_adam = use_8bit_adam
112
+
113
+ config_path = output_dir / 'config.yaml'
114
+ with open(config_path, 'w') as f:
115
+ OmegaConf.save(config, f)
116
+
117
+ command = f'accelerate launch Tune-A-Video/train_tuneavideo.py --config {config_path}'
118
+ with open(self.log_file, 'w') as f:
119
+ subprocess.run(shlex.split(command),
120
+ stdout=f,
121
+ stderr=subprocess.STDOUT,
122
+ text=True)
123
+ save_model_card(save_dir=output_dir,
124
+ base_model=base_model,
125
+ training_prompt=training_prompt,
126
+ test_prompt=validation_prompt,
127
+ test_image_dir='samples')
128
+
129
+ with open(self.log_file, 'a') as f:
130
+ f.write('Training completed!\n')
131
+
132
+ if upload_to_hub:
133
+ upload_message = upload(local_folder_path=output_dir.as_posix(),
134
+ target_repo_name=output_model_name,
135
+ upload_to=upload_to,
136
+ private=use_private_repo,
137
+ delete_existing_repo=delete_existing_repo,
138
+ hf_token=hf_token)
139
+ with open(self.log_file, 'a') as f:
140
+ f.write(upload_message)
141
+
142
+ if pause_space_after_training:
143
+ if space_id := os.getenv('SPACE_ID'):
144
+ api = HfApi(token=os.getenv('HF_TOKEN') or hf_token)
145
+ api.pause_space(repo_id=space_id)
uploader.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pathlib
5
+ import shlex
6
+ import subprocess
7
+
8
+ import slugify
9
+ from huggingface_hub import HfApi
10
+
11
+ from constants import (MODEL_LIBRARY_ORG_NAME, URL_TO_JOIN_MODEL_LIBRARY_ORG,
12
+ UploadTarget)
13
+
14
+
15
+ def join_model_library_org(hf_token: str) -> None:
16
+ subprocess.run(
17
+ shlex.split(
18
+ f'curl -X POST -H "Authorization: Bearer {hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
19
+ ))
20
+
21
+
22
+ def upload(local_folder_path: str,
23
+ target_repo_name: str,
24
+ upload_to: str,
25
+ private: bool = True,
26
+ delete_existing_repo: bool = False,
27
+ hf_token: str = '') -> str:
28
+ hf_token = os.getenv('HF_TOKEN') or hf_token
29
+ if not hf_token:
30
+ raise ValueError
31
+ api = HfApi(token=hf_token)
32
+
33
+ if not local_folder_path:
34
+ raise ValueError
35
+ if not target_repo_name:
36
+ target_repo_name = pathlib.Path(local_folder_path).name
37
+ target_repo_name = slugify.slugify(target_repo_name)
38
+
39
+ if upload_to == UploadTarget.PERSONAL_PROFILE.value:
40
+ organization = api.whoami()['name']
41
+ elif upload_to == UploadTarget.MODEL_LIBRARY.value:
42
+ organization = MODEL_LIBRARY_ORG_NAME
43
+ join_model_library_org(hf_token)
44
+ else:
45
+ raise ValueError
46
+
47
+ repo_id = f'{organization}/{target_repo_name}'
48
+ if delete_existing_repo:
49
+ try:
50
+ api.delete_repo(repo_id, repo_type='model')
51
+ except Exception:
52
+ pass
53
+ try:
54
+ api.create_repo(repo_id, repo_type='model', private=private)
55
+ api.upload_folder(repo_id=repo_id,
56
+ folder_path=local_folder_path,
57
+ path_in_repo='.',
58
+ repo_type='model')
59
+ url = f'https://huggingface.co/{repo_id}'
60
+ message = f'Your model was successfully uploaded to {url}.'
61
+ except Exception as e:
62
+ message = str(e)
63
+ return message
utils.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import pathlib
4
+
5
+
6
+ def find_exp_dirs() -> list[str]:
7
+ repo_dir = pathlib.Path(__file__).parent
8
+ exp_root_dir = repo_dir / 'experiments'
9
+ if not exp_root_dir.exists():
10
+ return []
11
+ exp_dirs = sorted(exp_root_dir.glob('*'))
12
+ exp_dirs = [
13
+ exp_dir for exp_dir in exp_dirs
14
+ if (exp_dir / 'model_index.json').exists()
15
+ ]
16
+ return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
17
+
18
+
19
+ def save_model_card(
20
+ save_dir: pathlib.Path,
21
+ base_model: str,
22
+ training_prompt: str,
23
+ test_prompt: str = '',
24
+ test_image_dir: str = '',
25
+ ) -> None:
26
+ image_str = ''
27
+ if test_prompt and test_image_dir:
28
+ image_paths = sorted((save_dir / test_image_dir).glob('*.gif'))
29
+ if image_paths:
30
+ image_path = image_paths[-1]
31
+ rel_path = image_path.relative_to(save_dir)
32
+ image_str = f'''## Samples
33
+ Test prompt: {test_prompt}
34
+
35
+ ![{image_path.stem}]({rel_path})'''
36
+
37
+ model_card = f'''---
38
+ license: creativeml-openrail-m
39
+ base_model: {base_model}
40
+ training_prompt: {training_prompt}
41
+ tags:
42
+ - stable-diffusion
43
+ - stable-diffusion-diffusers
44
+ - text-to-image
45
+ - diffusers
46
+ - text-to-video
47
+ - tune-a-video
48
+ inference: false
49
+ ---
50
+
51
+ # Tune-A-Video - {save_dir.name}
52
+
53
+ ## Model description
54
+ - Base model: [{base_model}](https://huggingface.co/{base_model})
55
+ - Training prompt: {training_prompt}
56
+
57
+ {image_str}
58
+
59
+ ## Related papers:
60
+ - [Tune-A-Video](https://arxiv.org/abs/2212.11565): One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation
61
+ - [Stable-Diffusion](https://arxiv.org/abs/2112.10752): High-Resolution Image Synthesis with Latent Diffusion Models
62
+ '''
63
+
64
+ with open(save_dir / 'README.md', 'w') as f:
65
+ f.write(model_card)