trhacknon hysts HF staff committed on
Commit
c50a2e8
0 Parent(s):

Duplicate from Tune-A-Video-library/Tune-A-Video-inference

Browse files

Co-authored-by: hysts <hysts@users.noreply.huggingface.co>

Files changed (15) hide show
  1. .gitattributes +35 -0
  2. .gitignore +164 -0
  3. .gitmodules +3 -0
  4. .pre-commit-config.yaml +37 -0
  5. .style.yapf +5 -0
  6. Dockerfile +57 -0
  7. LICENSE +21 -0
  8. README.md +12 -0
  9. Tune-A-Video +1 -0
  10. app.py +217 -0
  11. inference.py +109 -0
  12. packages.txt +1 -0
  13. patch +15 -0
  14. requirements.txt +19 -0
  15. style.css +3 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.whl filter=lfs diff=lfs merge=lfs -text
2
+ *.7z filter=lfs diff=lfs merge=lfs -text
3
+ *.arrow filter=lfs diff=lfs merge=lfs -text
4
+ *.bin filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
7
+ *.ftz filter=lfs diff=lfs merge=lfs -text
8
+ *.gz filter=lfs diff=lfs merge=lfs -text
9
+ *.h5 filter=lfs diff=lfs merge=lfs -text
10
+ *.joblib filter=lfs diff=lfs merge=lfs -text
11
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
12
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
13
+ *.model filter=lfs diff=lfs merge=lfs -text
14
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
15
+ *.npy filter=lfs diff=lfs merge=lfs -text
16
+ *.npz filter=lfs diff=lfs merge=lfs -text
17
+ *.onnx filter=lfs diff=lfs merge=lfs -text
18
+ *.ot filter=lfs diff=lfs merge=lfs -text
19
+ *.parquet filter=lfs diff=lfs merge=lfs -text
20
+ *.pb filter=lfs diff=lfs merge=lfs -text
21
+ *.pickle filter=lfs diff=lfs merge=lfs -text
22
+ *.pkl filter=lfs diff=lfs merge=lfs -text
23
+ *.pt filter=lfs diff=lfs merge=lfs -text
24
+ *.pth filter=lfs diff=lfs merge=lfs -text
25
+ *.rar filter=lfs diff=lfs merge=lfs -text
26
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
27
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ checkpoints/
2
+ experiments/
3
+
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "Tune-A-Video"]
2
+ path = Tune-A-Video
3
+ url = https://github.com/showlab/Tune-A-Video
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: patch
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
Dockerfile ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
2
+ ENV DEBIAN_FRONTEND=noninteractive
3
+ RUN apt-get update && \
4
+ apt-get upgrade -y && \
5
+ apt-get install -y --no-install-recommends \
6
+ git \
7
+ git-lfs \
8
+ wget \
9
+ curl \
10
+ # ffmpeg \
11
+ ffmpeg \
12
+ x264 \
13
+ # python build dependencies \
14
+ build-essential \
15
+ libssl-dev \
16
+ zlib1g-dev \
17
+ libbz2-dev \
18
+ libreadline-dev \
19
+ libsqlite3-dev \
20
+ libncursesw5-dev \
21
+ xz-utils \
22
+ tk-dev \
23
+ libxml2-dev \
24
+ libxmlsec1-dev \
25
+ libffi-dev \
26
+ liblzma-dev && \
27
+ apt-get clean && \
28
+ rm -rf /var/lib/apt/lists/*
29
+
30
+ RUN useradd -m -u 1000 user
31
+ USER user
32
+ ENV HOME=/home/user \
33
+ PATH=/home/user/.local/bin:${PATH}
34
+ WORKDIR ${HOME}/app
35
+
36
+ RUN curl https://pyenv.run | bash
37
+ ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
38
+ ENV PYTHON_VERSION=3.10.9
39
+ RUN pyenv install ${PYTHON_VERSION} && \
40
+ pyenv global ${PYTHON_VERSION} && \
41
+ pyenv rehash && \
42
+ pip install --no-cache-dir -U pip setuptools wheel
43
+
44
+ RUN pip install --no-cache-dir -U torch==1.13.1 torchvision==0.14.1
45
+ COPY --chown=1000 requirements.txt /tmp/requirements.txt
46
+ RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
47
+
48
+ COPY --chown=1000 . ${HOME}/app
49
+ RUN cd Tune-A-Video && patch -p1 < ../patch
50
+ ENV PYTHONPATH=${HOME}/app \
51
+ PYTHONUNBUFFERED=1 \
52
+ GRADIO_ALLOW_FLAGGING=never \
53
+ GRADIO_NUM_PORTS=1 \
54
+ GRADIO_SERVER_NAME=0.0.0.0 \
55
+ GRADIO_THEME=huggingface \
56
+ SYSTEM=spaces
57
+ CMD ["python", "app.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Tune-A-Video Inference
3
+ emoji: 🐠
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ duplicated_from: Tune-A-Video-library/Tune-A-Video-inference
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Tune-A-Video ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit b2c8c3eeac0df5c5d9eccc4dd2153e17b83c638c
app.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+
9
+ from inference import InferencePipeline
10
+
11
+
12
class InferenceUtil:
    """Fetches display metadata (base model, training prompt) from a model
    repo's card for the UI."""

    def __init__(self, hf_token: str | None):
        # Hub token; needed only for private or gated model repos.
        self.hf_token = hf_token

    def load_model_info(self, model_id: str) -> tuple[str, str]:
        """Return ``(base_model, training_prompt)`` read from *model_id*'s
        model card, or two empty strings when the card cannot be loaded."""
        try:
            card = InferencePipeline.get_model_card(model_id, self.hf_token)
        except Exception:  # best effort: unreadable/missing card -> blanks
            return '', ''
        meta = card.data
        base_model = getattr(meta, 'base_model', '')
        training_prompt = getattr(meta, 'training_prompt', '')
        return base_model, training_prompt
24
+
25
+
26
# Gradio demo wiring: builds the UI, hooks the model dropdown to the
# model-card metadata loader, and routes prompt submission / button clicks
# to the shared InferencePipeline.
TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/)'
HF_TOKEN = os.getenv('HF_TOKEN')  # optional Hub token for private repos
pipe = InferencePipeline(HF_TOKEN)  # lazily loads and caches the model
app = InferenceUtil(HF_TOKEN)  # model-card metadata helper

with gr.Blocks(css='style.css') as demo:
    gr.Markdown(TITLE)

    with gr.Row():
        with gr.Column():
            with gr.Box():
                # Fine-tuned Tune-A-Video checkpoints available for inference.
                model_id = gr.Dropdown(
                    label='Model ID',
                    choices=[
                        'Tune-A-Video-library/a-man-is-surfing',
                        'Tune-A-Video-library/mo-di-bear-guitar',
                        'Tune-A-Video-library/redshift-man-skiing',
                    ],
                    value='Tune-A-Video-library/a-man-is-surfing')
                # Read-only info pulled from the selected model's card.
                with gr.Accordion(
                        label=
                        'Model info (Base model and prompt used for training)',
                        open=False):
                    with gr.Row():
                        base_model_used_for_training = gr.Text(
                            label='Base model', interactive=False)
                        prompt_used_for_training = gr.Text(
                            label='Training prompt', interactive=False)
            prompt = gr.Textbox(label='Prompt',
                                max_lines=1,
                                placeholder='Example: "A panda is surfing"')
            video_length = gr.Slider(label='Video length',
                                     minimum=4,
                                     maximum=12,
                                     step=1,
                                     value=8)
            fps = gr.Slider(label='FPS',
                            minimum=1,
                            maximum=12,
                            step=1,
                            value=1)
            seed = gr.Slider(label='Seed',
                             minimum=0,
                             maximum=100000,
                             step=1,
                             value=0)
            with gr.Accordion('Other Parameters', open=False):
                num_steps = gr.Slider(label='Number of Steps',
                                      minimum=0,
                                      maximum=100,
                                      step=1,
                                      value=50)
                guidance_scale = gr.Slider(label='CFG Scale',
                                           minimum=0,
                                           maximum=50,
                                           step=0.1,
                                           value=7.5)

            run_button = gr.Button('Generate')

            gr.Markdown('''
            - It takes a few minutes to download model first.
            - Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
            ''')
        with gr.Column():
            result = gr.Video(label='Result')
    with gr.Row():
        # Each example row matches the `inputs` list below:
        # (model_id, prompt, video_length, fps, seed, num_steps, guidance_scale)
        examples = [
            [
                'Tune-A-Video-library/a-man-is-surfing',
                'A panda is surfing.',
                8,
                1,
                3,
                50,
                7.5,
            ],
            [
                'Tune-A-Video-library/a-man-is-surfing',
                'A racoon is surfing, cartoon style.',
                8,
                1,
                3,
                50,
                7.5,
            ],
            [
                'Tune-A-Video-library/mo-di-bear-guitar',
                'a handsome prince is playing guitar, modern disney style.',
                8,
                1,
                123,
                50,
                7.5,
            ],
            [
                'Tune-A-Video-library/mo-di-bear-guitar',
                'a magical princess is playing guitar, modern disney style.',
                8,
                1,
                123,
                50,
                7.5,
            ],
            [
                'Tune-A-Video-library/mo-di-bear-guitar',
                'a rabbit is playing guitar, modern disney style.',
                8,
                1,
                123,
                50,
                7.5,
            ],
            [
                'Tune-A-Video-library/mo-di-bear-guitar',
                'a baby is playing guitar, modern disney style.',
                8,
                1,
                123,
                50,
                7.5,
            ],
            [
                'Tune-A-Video-library/redshift-man-skiing',
                '(redshift style) spider man is skiing.',
                8,
                1,
                123,
                50,
                7.5,
            ],
            [
                'Tune-A-Video-library/redshift-man-skiing',
                '(redshift style) black widow is skiing.',
                8,
                1,
                123,
                50,
                7.5,
            ],
            [
                'Tune-A-Video-library/redshift-man-skiing',
                '(redshift style) batman is skiing.',
                8,
                1,
                123,
                50,
                7.5,
            ],
            [
                'Tune-A-Video-library/redshift-man-skiing',
                '(redshift style) hulk is skiing.',
                8,
                1,
                123,
                50,
                7.5,
            ],
        ]
        # On Spaces (SYSTEM=spaces) example outputs are pre-computed and cached.
        gr.Examples(examples=examples,
                    inputs=[
                        model_id,
                        prompt,
                        video_length,
                        fps,
                        seed,
                        num_steps,
                        guidance_scale,
                    ],
                    outputs=result,
                    fn=pipe.run,
                    cache_examples=os.getenv('SYSTEM') == 'spaces')

    # Selecting a different model refreshes the model-card info fields.
    model_id.change(fn=app.load_model_info,
                    inputs=model_id,
                    outputs=[
                        base_model_used_for_training,
                        prompt_used_for_training,
                    ])
    inputs = [
        model_id,
        prompt,
        video_length,
        fps,
        seed,
        num_steps,
        guidance_scale,
    ]
    # Enter in the prompt box and the Generate button both trigger inference.
    prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
    run_button.click(fn=pipe.run, inputs=inputs, outputs=result)

demo.queue().launch()
inference.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import pathlib
5
+ import sys
6
+ import tempfile
7
+
8
+ import gradio as gr
9
+ import imageio
10
+ import PIL.Image
11
+ import torch
12
+ from diffusers.utils.import_utils import is_xformers_available
13
+ from einops import rearrange
14
+ from huggingface_hub import ModelCard
15
+
16
+ sys.path.append('Tune-A-Video')
17
+
18
+ from tuneavideo.models.unet import UNet3DConditionModel
19
+ from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
20
+
21
+
22
class InferencePipeline:
    """Text-to-video inference wrapper around a Tune-A-Video pipeline.

    The diffusion pipeline is loaded lazily on first use and cached per
    ``model_id``; loading a different model replaces the cached pipeline.
    """

    def __init__(self, hf_token: str | None = None):
        # Hub token for private/gated repos (may be None for public models).
        self.hf_token = hf_token
        # Cached TuneAVideoPipeline, or None while nothing is loaded.
        self.pipe = None
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        # Repo id of the currently loaded pipeline (None when unloaded).
        self.model_id = None

    def clear(self) -> None:
        """Drop the cached pipeline and release GPU/CPU memory."""
        self.model_id = None
        del self.pipe
        self.pipe = None
        torch.cuda.empty_cache()
        gc.collect()

    @staticmethod
    def check_if_model_is_local(model_id: str) -> bool:
        """Return True when *model_id* is an existing local filesystem path."""
        return pathlib.Path(model_id).exists()

    @staticmethod
    def get_model_card(model_id: str,
                       hf_token: str | None = None) -> ModelCard:
        """Load the model card from a local checkout or from the Hub."""
        if InferencePipeline.check_if_model_is_local(model_id):
            card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
        else:
            card_path = model_id
        return ModelCard.load(card_path, token=hf_token)

    @staticmethod
    def get_base_model_info(model_id: str, hf_token: str | None = None) -> str:
        """Return the ``base_model`` field of *model_id*'s model card."""
        card = InferencePipeline.get_model_card(model_id, hf_token)
        return card.data.base_model

    def load_pipe(self, model_id: str) -> None:
        """Load and cache the pipeline for *model_id*; no-op when cached."""
        if model_id == self.model_id:
            return
        base_model_id = self.get_base_model_info(model_id, self.hf_token)
        # Fine-tuned 3D UNet comes from the Tune-A-Video repo; the rest of
        # the pipeline (VAE, text encoder, ...) comes from the base model.
        unet = UNet3DConditionModel.from_pretrained(
            model_id,
            subfolder='unet',
            torch_dtype=torch.float16,
            use_auth_token=self.hf_token)
        pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
                                                  unet=unet,
                                                  torch_dtype=torch.float16,
                                                  use_auth_token=self.hf_token)
        pipe = pipe.to(self.device)
        if is_xformers_available():
            pipe.unet.enable_xformers_memory_efficient_attention()
        self.pipe = pipe
        self.model_id = model_id  # type: ignore

    def run(
        self,
        model_id: str,
        prompt: str,
        video_length: int,
        fps: int,
        seed: int,
        n_steps: int,
        guidance_scale: float,
    ) -> str:
        """Generate a video for *prompt* and return the path of the .mp4 file.

        Note: the original annotated this as returning ``PIL.Image.Image``,
        but it has always returned the output file path (a ``str``), which is
        what ``gr.Video`` consumes.

        Raises:
            gr.Error: when no CUDA device is available.
        """
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')

        self.load_pipe(model_id)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        out = self.pipe(
            prompt,
            video_length=video_length,
            width=512,
            height=512,
            num_inference_steps=n_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )  # type: ignore

        # (c, t, h, w) float tensor in [0, 1] -> (t, h, w, c) uint8 frames.
        frames = rearrange(out.videos[0], 'c t h w -> t h w c')
        frames = (frames * 255).to(torch.uint8).numpy()

        out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
        out_file.close()  # only the name is needed; imageio reopens the path
        writer = imageio.get_writer(out_file.name, fps=fps)
        try:
            for frame in frames:
                writer.append_data(frame)
        finally:
            writer.close()  # always release the ffmpeg writer process

        return out_file.name
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
patch ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/train_tuneavideo.py b/train_tuneavideo.py
2
+ index 66d51b2..86b2a5d 100644
3
+ --- a/train_tuneavideo.py
4
+ +++ b/train_tuneavideo.py
5
+ @@ -94,8 +94,8 @@ def main(
6
+
7
+ # Handle the output folder creation
8
+ if accelerator.is_main_process:
9
+ - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
10
+ - output_dir = os.path.join(output_dir, now)
11
+ + #now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
12
+ + #output_dir = os.path.join(output_dir, now)
13
+ os.makedirs(output_dir, exist_ok=True)
14
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
15
+
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.15.0
2
+ bitsandbytes==0.35.4
3
+ decord==0.6.0
4
+ diffusers[torch]==0.11.1
5
+ einops==0.6.0
6
+ ftfy==6.1.1
7
+ gradio==3.18.0
8
+ huggingface-hub==0.12.0
9
+ imageio==2.25.0
10
+ imageio-ffmpeg==0.4.8
11
+ omegaconf==2.3.0
12
+ Pillow==9.4.0
13
+ python-slugify==7.0.0
14
+ tensorboard==2.11.2
15
+ torch==1.13.1
16
+ torchvision==0.14.1
17
+ transformers==4.26.0
18
+ triton==2.0.0.dev20221202
19
+ xformers==0.0.16
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }