Spaces:
Sleeping
Sleeping
Duplicate from Tune-A-Video-library/Tune-A-Video-Training-UI
Browse filesCo-authored-by: hysts <hysts@users.noreply.huggingface.co>
- .gitattributes +35 -0
- .gitignore +164 -0
- .gitmodules +3 -0
- .pre-commit-config.yaml +37 -0
- .style.yapf +5 -0
- Dockerfile +57 -0
- LICENSE +21 -0
- README.md +12 -0
- Tune-A-Video +1 -0
- app.py +76 -0
- app_inference.py +170 -0
- app_training.py +140 -0
- app_upload.py +100 -0
- constants.py +10 -0
- inference.py +109 -0
- packages.txt +1 -0
- patch +15 -0
- requirements.txt +19 -0
- style.css +3 -0
- trainer.py +156 -0
- uploader.py +42 -0
- utils.py +65 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.whl filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
27 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
checkpoints/
|
2 |
+
experiments/
|
3 |
+
|
4 |
+
|
5 |
+
# Byte-compiled / optimized / DLL files
|
6 |
+
__pycache__/
|
7 |
+
*.py[cod]
|
8 |
+
*$py.class
|
9 |
+
|
10 |
+
# C extensions
|
11 |
+
*.so
|
12 |
+
|
13 |
+
# Distribution / packaging
|
14 |
+
.Python
|
15 |
+
build/
|
16 |
+
develop-eggs/
|
17 |
+
dist/
|
18 |
+
downloads/
|
19 |
+
eggs/
|
20 |
+
.eggs/
|
21 |
+
lib/
|
22 |
+
lib64/
|
23 |
+
parts/
|
24 |
+
sdist/
|
25 |
+
var/
|
26 |
+
wheels/
|
27 |
+
share/python-wheels/
|
28 |
+
*.egg-info/
|
29 |
+
.installed.cfg
|
30 |
+
*.egg
|
31 |
+
MANIFEST
|
32 |
+
|
33 |
+
# PyInstaller
|
34 |
+
# Usually these files are written by a python script from a template
|
35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
36 |
+
*.manifest
|
37 |
+
*.spec
|
38 |
+
|
39 |
+
# Installer logs
|
40 |
+
pip-log.txt
|
41 |
+
pip-delete-this-directory.txt
|
42 |
+
|
43 |
+
# Unit test / coverage reports
|
44 |
+
htmlcov/
|
45 |
+
.tox/
|
46 |
+
.nox/
|
47 |
+
.coverage
|
48 |
+
.coverage.*
|
49 |
+
.cache
|
50 |
+
nosetests.xml
|
51 |
+
coverage.xml
|
52 |
+
*.cover
|
53 |
+
*.py,cover
|
54 |
+
.hypothesis/
|
55 |
+
.pytest_cache/
|
56 |
+
cover/
|
57 |
+
|
58 |
+
# Translations
|
59 |
+
*.mo
|
60 |
+
*.pot
|
61 |
+
|
62 |
+
# Django stuff:
|
63 |
+
*.log
|
64 |
+
local_settings.py
|
65 |
+
db.sqlite3
|
66 |
+
db.sqlite3-journal
|
67 |
+
|
68 |
+
# Flask stuff:
|
69 |
+
instance/
|
70 |
+
.webassets-cache
|
71 |
+
|
72 |
+
# Scrapy stuff:
|
73 |
+
.scrapy
|
74 |
+
|
75 |
+
# Sphinx documentation
|
76 |
+
docs/_build/
|
77 |
+
|
78 |
+
# PyBuilder
|
79 |
+
.pybuilder/
|
80 |
+
target/
|
81 |
+
|
82 |
+
# Jupyter Notebook
|
83 |
+
.ipynb_checkpoints
|
84 |
+
|
85 |
+
# IPython
|
86 |
+
profile_default/
|
87 |
+
ipython_config.py
|
88 |
+
|
89 |
+
# pyenv
|
90 |
+
# For a library or package, you might want to ignore these files since the code is
|
91 |
+
# intended to run in multiple environments; otherwise, check them in:
|
92 |
+
# .python-version
|
93 |
+
|
94 |
+
# pipenv
|
95 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
96 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
97 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
98 |
+
# install all needed dependencies.
|
99 |
+
#Pipfile.lock
|
100 |
+
|
101 |
+
# poetry
|
102 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
103 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
104 |
+
# commonly ignored for libraries.
|
105 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
106 |
+
#poetry.lock
|
107 |
+
|
108 |
+
# pdm
|
109 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
110 |
+
#pdm.lock
|
111 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
112 |
+
# in version control.
|
113 |
+
# https://pdm.fming.dev/#use-with-ide
|
114 |
+
.pdm.toml
|
115 |
+
|
116 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
117 |
+
__pypackages__/
|
118 |
+
|
119 |
+
# Celery stuff
|
120 |
+
celerybeat-schedule
|
121 |
+
celerybeat.pid
|
122 |
+
|
123 |
+
# SageMath parsed files
|
124 |
+
*.sage.py
|
125 |
+
|
126 |
+
# Environments
|
127 |
+
.env
|
128 |
+
.venv
|
129 |
+
env/
|
130 |
+
venv/
|
131 |
+
ENV/
|
132 |
+
env.bak/
|
133 |
+
venv.bak/
|
134 |
+
|
135 |
+
# Spyder project settings
|
136 |
+
.spyderproject
|
137 |
+
.spyproject
|
138 |
+
|
139 |
+
# Rope project settings
|
140 |
+
.ropeproject
|
141 |
+
|
142 |
+
# mkdocs documentation
|
143 |
+
/site
|
144 |
+
|
145 |
+
# mypy
|
146 |
+
.mypy_cache/
|
147 |
+
.dmypy.json
|
148 |
+
dmypy.json
|
149 |
+
|
150 |
+
# Pyre type checker
|
151 |
+
.pyre/
|
152 |
+
|
153 |
+
# pytype static type analyzer
|
154 |
+
.pytype/
|
155 |
+
|
156 |
+
# Cython debug symbols
|
157 |
+
cython_debug/
|
158 |
+
|
159 |
+
# PyCharm
|
160 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
161 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
162 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
163 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
164 |
+
#.idea/
|
.gitmodules
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "Tune-A-Video"]
|
2 |
+
path = Tune-A-Video
|
3 |
+
url = https://github.com/showlab/Tune-A-Video
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
exclude: patch
|
2 |
+
repos:
|
3 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
4 |
+
rev: v4.2.0
|
5 |
+
hooks:
|
6 |
+
- id: check-executables-have-shebangs
|
7 |
+
- id: check-json
|
8 |
+
- id: check-merge-conflict
|
9 |
+
- id: check-shebang-scripts-are-executable
|
10 |
+
- id: check-toml
|
11 |
+
- id: check-yaml
|
12 |
+
- id: double-quote-string-fixer
|
13 |
+
- id: end-of-file-fixer
|
14 |
+
- id: mixed-line-ending
|
15 |
+
args: ['--fix=lf']
|
16 |
+
- id: requirements-txt-fixer
|
17 |
+
- id: trailing-whitespace
|
18 |
+
- repo: https://github.com/myint/docformatter
|
19 |
+
rev: v1.4
|
20 |
+
hooks:
|
21 |
+
- id: docformatter
|
22 |
+
args: ['--in-place']
|
23 |
+
- repo: https://github.com/pycqa/isort
|
24 |
+
rev: 5.12.0
|
25 |
+
hooks:
|
26 |
+
- id: isort
|
27 |
+
- repo: https://github.com/pre-commit/mirrors-mypy
|
28 |
+
rev: v0.991
|
29 |
+
hooks:
|
30 |
+
- id: mypy
|
31 |
+
args: ['--ignore-missing-imports']
|
32 |
+
additional_dependencies: ['types-python-slugify']
|
33 |
+
- repo: https://github.com/google/yapf
|
34 |
+
rev: v0.32.0
|
35 |
+
hooks:
|
36 |
+
- id: yapf
|
37 |
+
args: ['--parallel', '--in-place']
|
.style.yapf
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[style]
|
2 |
+
based_on_style = pep8
|
3 |
+
blank_line_before_nested_class_or_def = false
|
4 |
+
spaces_before_comment = 2
|
5 |
+
split_before_logical_operator = true
|
Dockerfile
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
|
2 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
3 |
+
RUN apt-get update && \
|
4 |
+
apt-get upgrade -y && \
|
5 |
+
apt-get install -y --no-install-recommends \
|
6 |
+
git \
|
7 |
+
git-lfs \
|
8 |
+
wget \
|
9 |
+
curl \
|
10 |
+
# ffmpeg \
|
11 |
+
ffmpeg \
|
12 |
+
x264 \
|
13 |
+
# python build dependencies \
|
14 |
+
build-essential \
|
15 |
+
libssl-dev \
|
16 |
+
zlib1g-dev \
|
17 |
+
libbz2-dev \
|
18 |
+
libreadline-dev \
|
19 |
+
libsqlite3-dev \
|
20 |
+
libncursesw5-dev \
|
21 |
+
xz-utils \
|
22 |
+
tk-dev \
|
23 |
+
libxml2-dev \
|
24 |
+
libxmlsec1-dev \
|
25 |
+
libffi-dev \
|
26 |
+
liblzma-dev && \
|
27 |
+
apt-get clean && \
|
28 |
+
rm -rf /var/lib/apt/lists/*
|
29 |
+
|
30 |
+
RUN useradd -m -u 1000 user
|
31 |
+
USER user
|
32 |
+
ENV HOME=/home/user \
|
33 |
+
PATH=/home/user/.local/bin:${PATH}
|
34 |
+
WORKDIR ${HOME}/app
|
35 |
+
|
36 |
+
RUN curl https://pyenv.run | bash
|
37 |
+
ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
|
38 |
+
ENV PYTHON_VERSION=3.10.9
|
39 |
+
RUN pyenv install ${PYTHON_VERSION} && \
|
40 |
+
pyenv global ${PYTHON_VERSION} && \
|
41 |
+
pyenv rehash && \
|
42 |
+
pip install --no-cache-dir -U pip setuptools wheel
|
43 |
+
|
44 |
+
RUN pip install --no-cache-dir -U torch==1.13.1 torchvision==0.14.1
|
45 |
+
COPY --chown=1000 requirements.txt /tmp/requirements.txt
|
46 |
+
RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
|
47 |
+
|
48 |
+
COPY --chown=1000 . ${HOME}/app
|
49 |
+
RUN cd Tune-A-Video && patch -p1 < ../patch
|
50 |
+
ENV PYTHONPATH=${HOME}/app \
|
51 |
+
PYTHONUNBUFFERED=1 \
|
52 |
+
GRADIO_ALLOW_FLAGGING=never \
|
53 |
+
GRADIO_NUM_PORTS=1 \
|
54 |
+
GRADIO_SERVER_NAME=0.0.0.0 \
|
55 |
+
GRADIO_THEME=huggingface \
|
56 |
+
SYSTEM=spaces
|
57 |
+
CMD ["python", "app.py"]
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 hysts
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Tune-A-Video Training UI
|
3 |
+
emoji: ⚡
|
4 |
+
colorFrom: red
|
5 |
+
colorTo: purple
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
license: mit
|
9 |
+
duplicated_from: Tune-A-Video-library/Tune-A-Video-Training-UI
|
10 |
+
---
|
11 |
+
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
Tune-A-Video
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit b2c8c3eeac0df5c5d9eccc4dd2153e17b83c638c
|
app.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import os
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from app_inference import create_inference_demo
|
11 |
+
from app_training import create_training_demo
|
12 |
+
from app_upload import create_upload_demo
|
13 |
+
from inference import InferencePipeline
|
14 |
+
from trainer import Trainer
|
15 |
+
|
16 |
+
TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/) Training UI'
|
17 |
+
|
18 |
+
ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
|
19 |
+
SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
|
20 |
+
SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU. (Please note that there seems to be an issue with training on the A10G GPU now. The model doesn't learn anything when trained on A10G. Training on T4 works perfectly fine and inference works fine on both.)
|
21 |
+
|
22 |
+
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
|
23 |
+
'''
|
24 |
+
|
25 |
+
if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
|
26 |
+
SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
|
27 |
+
else:
|
28 |
+
SETTINGS = 'Settings'
|
29 |
+
CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
|
30 |
+
<center>
|
31 |
+
You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
|
32 |
+
You can use "T4 small/medium" or "A10G small/large" to run this demo.
|
33 |
+
</center>
|
34 |
+
'''
|
35 |
+
|
36 |
+
HF_TOKEN_NOT_SPECIFIED_WARNING = f'''# Attention - The environment variable `HF_TOKEN` is not specified. Please specify your Hugging Face token with write permission as the value of it.
|
37 |
+
<center>
|
38 |
+
You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
|
39 |
+
You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
|
40 |
+
</center>
|
41 |
+
'''
|
42 |
+
|
43 |
+
HF_TOKEN = os.getenv('HF_TOKEN')
|
44 |
+
|
45 |
+
|
46 |
+
def show_warning(warning_text: str) -> gr.Blocks:
|
47 |
+
with gr.Blocks() as demo:
|
48 |
+
with gr.Box():
|
49 |
+
gr.Markdown(warning_text)
|
50 |
+
return demo
|
51 |
+
|
52 |
+
|
53 |
+
pipe = InferencePipeline(HF_TOKEN)
|
54 |
+
trainer = Trainer(HF_TOKEN)
|
55 |
+
|
56 |
+
with gr.Blocks(css='style.css') as demo:
|
57 |
+
if os.getenv('IS_SHARED_UI'):
|
58 |
+
show_warning(SHARED_UI_WARNING)
|
59 |
+
if not torch.cuda.is_available():
|
60 |
+
show_warning(CUDA_NOT_AVAILABLE_WARNING)
|
61 |
+
if not HF_TOKEN:
|
62 |
+
show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
|
63 |
+
|
64 |
+
gr.Markdown(TITLE)
|
65 |
+
with gr.Tabs():
|
66 |
+
with gr.TabItem('Train'):
|
67 |
+
create_training_demo(trainer, pipe)
|
68 |
+
with gr.TabItem('Test'):
|
69 |
+
create_inference_demo(pipe, HF_TOKEN)
|
70 |
+
with gr.TabItem('Upload'):
|
71 |
+
gr.Markdown('''
|
72 |
+
- You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
|
73 |
+
''')
|
74 |
+
create_upload_demo(HF_TOKEN)
|
75 |
+
|
76 |
+
demo.queue(max_size=1).launch(share=False)
|
app_inference.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import enum
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
from huggingface_hub import HfApi
|
9 |
+
|
10 |
+
from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
|
11 |
+
from inference import InferencePipeline
|
12 |
+
from utils import find_exp_dirs
|
13 |
+
|
14 |
+
|
15 |
+
class ModelSource(enum.Enum):
|
16 |
+
HUB_LIB = UploadTarget.MODEL_LIBRARY.value
|
17 |
+
LOCAL = 'Local'
|
18 |
+
|
19 |
+
|
20 |
+
class InferenceUtil:
|
21 |
+
def __init__(self, hf_token: str | None):
|
22 |
+
self.hf_token = hf_token
|
23 |
+
|
24 |
+
def load_hub_model_list(self) -> dict:
|
25 |
+
api = HfApi(token=self.hf_token)
|
26 |
+
choices = [
|
27 |
+
info.modelId
|
28 |
+
for info in api.list_models(author=MODEL_LIBRARY_ORG_NAME)
|
29 |
+
]
|
30 |
+
return gr.update(choices=choices,
|
31 |
+
value=choices[0] if choices else None)
|
32 |
+
|
33 |
+
@staticmethod
|
34 |
+
def load_local_model_list() -> dict:
|
35 |
+
choices = find_exp_dirs()
|
36 |
+
return gr.update(choices=choices,
|
37 |
+
value=choices[0] if choices else None)
|
38 |
+
|
39 |
+
def reload_model_list(self, model_source: str) -> dict:
|
40 |
+
if model_source == ModelSource.HUB_LIB.value:
|
41 |
+
return self.load_hub_model_list()
|
42 |
+
elif model_source == ModelSource.LOCAL.value:
|
43 |
+
return self.load_local_model_list()
|
44 |
+
else:
|
45 |
+
raise ValueError
|
46 |
+
|
47 |
+
def load_model_info(self, model_id: str) -> tuple[str, str]:
|
48 |
+
try:
|
49 |
+
card = InferencePipeline.get_model_card(model_id, self.hf_token)
|
50 |
+
except Exception:
|
51 |
+
return '', ''
|
52 |
+
base_model = getattr(card.data, 'base_model', '')
|
53 |
+
training_prompt = getattr(card.data, 'training_prompt', '')
|
54 |
+
return base_model, training_prompt
|
55 |
+
|
56 |
+
def reload_model_list_and_update_model_info(
|
57 |
+
self, model_source: str) -> tuple[dict, str, str]:
|
58 |
+
model_list_update = self.reload_model_list(model_source)
|
59 |
+
model_list = model_list_update['choices']
|
60 |
+
model_info = self.load_model_info(model_list[0] if model_list else '')
|
61 |
+
return model_list_update, *model_info
|
62 |
+
|
63 |
+
|
64 |
+
def create_inference_demo(pipe: InferencePipeline,
|
65 |
+
hf_token: str | None = None) -> gr.Blocks:
|
66 |
+
app = InferenceUtil(hf_token)
|
67 |
+
|
68 |
+
with gr.Blocks() as demo:
|
69 |
+
with gr.Row():
|
70 |
+
with gr.Column():
|
71 |
+
with gr.Box():
|
72 |
+
model_source = gr.Radio(
|
73 |
+
label='Model Source',
|
74 |
+
choices=[_.value for _ in ModelSource],
|
75 |
+
value=ModelSource.HUB_LIB.value)
|
76 |
+
reload_button = gr.Button('Reload Model List')
|
77 |
+
model_id = gr.Dropdown(label='Model ID',
|
78 |
+
choices=None,
|
79 |
+
value=None)
|
80 |
+
with gr.Accordion(
|
81 |
+
label=
|
82 |
+
'Model info (Base model and prompt used for training)',
|
83 |
+
open=False):
|
84 |
+
with gr.Row():
|
85 |
+
base_model_used_for_training = gr.Text(
|
86 |
+
label='Base model', interactive=False)
|
87 |
+
prompt_used_for_training = gr.Text(
|
88 |
+
label='Training prompt', interactive=False)
|
89 |
+
prompt = gr.Textbox(
|
90 |
+
label='Prompt',
|
91 |
+
max_lines=1,
|
92 |
+
placeholder='Example: "A panda is surfing"')
|
93 |
+
video_length = gr.Slider(label='Video length',
|
94 |
+
minimum=4,
|
95 |
+
maximum=12,
|
96 |
+
step=1,
|
97 |
+
value=8)
|
98 |
+
fps = gr.Slider(label='FPS',
|
99 |
+
minimum=1,
|
100 |
+
maximum=12,
|
101 |
+
step=1,
|
102 |
+
value=1)
|
103 |
+
seed = gr.Slider(label='Seed',
|
104 |
+
minimum=0,
|
105 |
+
maximum=100000,
|
106 |
+
step=1,
|
107 |
+
value=0)
|
108 |
+
with gr.Accordion('Other Parameters', open=False):
|
109 |
+
num_steps = gr.Slider(label='Number of Steps',
|
110 |
+
minimum=0,
|
111 |
+
maximum=100,
|
112 |
+
step=1,
|
113 |
+
value=50)
|
114 |
+
guidance_scale = gr.Slider(label='CFG Scale',
|
115 |
+
minimum=0,
|
116 |
+
maximum=50,
|
117 |
+
step=0.1,
|
118 |
+
value=7.5)
|
119 |
+
|
120 |
+
run_button = gr.Button('Generate')
|
121 |
+
|
122 |
+
gr.Markdown('''
|
123 |
+
- After training, you can press "Reload Model List" button to load your trained model names.
|
124 |
+
- It takes a few minutes to download model first.
|
125 |
+
- Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
|
126 |
+
''')
|
127 |
+
with gr.Column():
|
128 |
+
result = gr.Video(label='Result')
|
129 |
+
|
130 |
+
model_source.change(fn=app.reload_model_list_and_update_model_info,
|
131 |
+
inputs=model_source,
|
132 |
+
outputs=[
|
133 |
+
model_id,
|
134 |
+
base_model_used_for_training,
|
135 |
+
prompt_used_for_training,
|
136 |
+
])
|
137 |
+
reload_button.click(fn=app.reload_model_list_and_update_model_info,
|
138 |
+
inputs=model_source,
|
139 |
+
outputs=[
|
140 |
+
model_id,
|
141 |
+
base_model_used_for_training,
|
142 |
+
prompt_used_for_training,
|
143 |
+
])
|
144 |
+
model_id.change(fn=app.load_model_info,
|
145 |
+
inputs=model_id,
|
146 |
+
outputs=[
|
147 |
+
base_model_used_for_training,
|
148 |
+
prompt_used_for_training,
|
149 |
+
])
|
150 |
+
inputs = [
|
151 |
+
model_id,
|
152 |
+
prompt,
|
153 |
+
video_length,
|
154 |
+
fps,
|
155 |
+
seed,
|
156 |
+
num_steps,
|
157 |
+
guidance_scale,
|
158 |
+
]
|
159 |
+
prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
|
160 |
+
run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
|
161 |
+
return demo
|
162 |
+
|
163 |
+
|
164 |
+
if __name__ == '__main__':
|
165 |
+
import os
|
166 |
+
|
167 |
+
hf_token = os.getenv('HF_TOKEN')
|
168 |
+
pipe = InferencePipeline(hf_token)
|
169 |
+
demo = create_inference_demo(pipe, hf_token)
|
170 |
+
demo.queue(max_size=10).launch(share=False)
|
app_training.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import os
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
from constants import MODEL_LIBRARY_ORG_NAME, SAMPLE_MODEL_REPO, UploadTarget
|
10 |
+
from inference import InferencePipeline
|
11 |
+
from trainer import Trainer
|
12 |
+
|
13 |
+
|
14 |
+
def create_training_demo(trainer: Trainer,
|
15 |
+
pipe: InferencePipeline | None = None) -> gr.Blocks:
|
16 |
+
with gr.Blocks() as demo:
|
17 |
+
with gr.Row():
|
18 |
+
with gr.Column():
|
19 |
+
with gr.Box():
|
20 |
+
gr.Markdown('Training Data')
|
21 |
+
training_video = gr.File(label='Training video')
|
22 |
+
training_prompt = gr.Textbox(
|
23 |
+
label='Training prompt',
|
24 |
+
max_lines=1,
|
25 |
+
placeholder='A man is surfing')
|
26 |
+
gr.Markdown('''
|
27 |
+
- Upload a video and write a prompt describing the video.
|
28 |
+
''')
|
29 |
+
with gr.Box():
|
30 |
+
gr.Markdown('Output Model')
|
31 |
+
output_model_name = gr.Text(label='Name of your model',
|
32 |
+
max_lines=1)
|
33 |
+
delete_existing_model = gr.Checkbox(
|
34 |
+
label='Delete existing model of the same name',
|
35 |
+
value=False)
|
36 |
+
validation_prompt = gr.Text(label='Validation Prompt')
|
37 |
+
with gr.Box():
|
38 |
+
gr.Markdown('Upload Settings')
|
39 |
+
with gr.Row():
|
40 |
+
upload_to_hub = gr.Checkbox(
|
41 |
+
label='Upload model to Hub', value=True)
|
42 |
+
use_private_repo = gr.Checkbox(label='Private',
|
43 |
+
value=True)
|
44 |
+
delete_existing_repo = gr.Checkbox(
|
45 |
+
label='Delete existing repo of the same name',
|
46 |
+
value=False)
|
47 |
+
upload_to = gr.Radio(
|
48 |
+
label='Upload to',
|
49 |
+
choices=[_.value for _ in UploadTarget],
|
50 |
+
value=UploadTarget.MODEL_LIBRARY.value)
|
51 |
+
gr.Markdown(f'''
|
52 |
+
- By default, trained models will be uploaded to [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (see [this example model](https://huggingface.co/{SAMPLE_MODEL_REPO})).
|
53 |
+
- You can also choose "Personal Profile", in which case, the model will be uploaded to https://huggingface.co/{{your_username}}/{{model_name}}.
|
54 |
+
''')
|
55 |
+
|
56 |
+
with gr.Box():
|
57 |
+
gr.Markdown('Training Parameters')
|
58 |
+
with gr.Row():
|
59 |
+
base_model = gr.Text(label='Base Model',
|
60 |
+
value='CompVis/stable-diffusion-v1-4',
|
61 |
+
max_lines=1)
|
62 |
+
resolution = gr.Dropdown(choices=['512', '768'],
|
63 |
+
value='512',
|
64 |
+
label='Resolution',
|
65 |
+
visible=False)
|
66 |
+
num_training_steps = gr.Number(
|
67 |
+
label='Number of Training Steps', value=300, precision=0)
|
68 |
+
learning_rate = gr.Number(label='Learning Rate',
|
69 |
+
value=0.000035)
|
70 |
+
gradient_accumulation = gr.Number(
|
71 |
+
label='Number of Gradient Accumulation',
|
72 |
+
value=1,
|
73 |
+
precision=0)
|
74 |
+
seed = gr.Slider(label='Seed',
|
75 |
+
minimum=0,
|
76 |
+
maximum=100000,
|
77 |
+
step=1,
|
78 |
+
value=0)
|
79 |
+
fp16 = gr.Checkbox(label='FP16', value=True)
|
80 |
+
use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=False)
|
81 |
+
checkpointing_steps = gr.Number(label='Checkpointing Steps',
|
82 |
+
value=1000,
|
83 |
+
precision=0)
|
84 |
+
validation_epochs = gr.Number(label='Validation Epochs',
|
85 |
+
value=100,
|
86 |
+
precision=0)
|
87 |
+
gr.Markdown('''
|
88 |
+
- The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
|
89 |
+
- It takes a few minutes to download the base model first.
|
90 |
+
- Expected time to train a model for 300 steps: 20 minutes with T4, 8 minutes with A10G, (4 minutes with A100)
|
91 |
+
- It takes a few minutes to upload your trained model.
|
92 |
+
- You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
|
93 |
+
- You can check the training status by pressing the "Open logs" button if you are running this on your Space.
|
94 |
+
''')
|
95 |
+
|
96 |
+
remove_gpu_after_training = gr.Checkbox(
|
97 |
+
label='Remove GPU after training',
|
98 |
+
value=False,
|
99 |
+
interactive=bool(os.getenv('SPACE_ID')),
|
100 |
+
visible=False)
|
101 |
+
run_button = gr.Button('Start Training')
|
102 |
+
|
103 |
+
with gr.Box():
|
104 |
+
gr.Markdown('Output message')
|
105 |
+
output_message = gr.Markdown()
|
106 |
+
|
107 |
+
if pipe is not None:
|
108 |
+
run_button.click(fn=pipe.clear)
|
109 |
+
run_button.click(fn=trainer.run,
|
110 |
+
inputs=[
|
111 |
+
training_video,
|
112 |
+
training_prompt,
|
113 |
+
output_model_name,
|
114 |
+
delete_existing_model,
|
115 |
+
validation_prompt,
|
116 |
+
base_model,
|
117 |
+
resolution,
|
118 |
+
num_training_steps,
|
119 |
+
learning_rate,
|
120 |
+
gradient_accumulation,
|
121 |
+
seed,
|
122 |
+
fp16,
|
123 |
+
use_8bit_adam,
|
124 |
+
checkpointing_steps,
|
125 |
+
validation_epochs,
|
126 |
+
upload_to_hub,
|
127 |
+
use_private_repo,
|
128 |
+
delete_existing_repo,
|
129 |
+
upload_to,
|
130 |
+
remove_gpu_after_training,
|
131 |
+
],
|
132 |
+
outputs=output_message)
|
133 |
+
return demo
|
134 |
+
|
135 |
+
|
136 |
+
if __name__ == '__main__':
|
137 |
+
hf_token = os.getenv('HF_TOKEN')
|
138 |
+
trainer = Trainer(hf_token)
|
139 |
+
demo = create_training_demo(trainer)
|
140 |
+
demo.queue(max_size=1).launch(share=False)
|
app_upload.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import pathlib
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import slugify
|
9 |
+
|
10 |
+
from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
|
11 |
+
from uploader import Uploader
|
12 |
+
from utils import find_exp_dirs
|
13 |
+
|
14 |
+
|
15 |
+
class ModelUploader(Uploader):
|
16 |
+
def upload_model(
|
17 |
+
self,
|
18 |
+
folder_path: str,
|
19 |
+
repo_name: str,
|
20 |
+
upload_to: str,
|
21 |
+
private: bool,
|
22 |
+
delete_existing_repo: bool,
|
23 |
+
) -> str:
|
24 |
+
if not folder_path:
|
25 |
+
raise ValueError
|
26 |
+
if not repo_name:
|
27 |
+
repo_name = pathlib.Path(folder_path).name
|
28 |
+
repo_name = slugify.slugify(repo_name)
|
29 |
+
|
30 |
+
if upload_to == UploadTarget.PERSONAL_PROFILE.value:
|
31 |
+
organization = ''
|
32 |
+
elif upload_to == UploadTarget.MODEL_LIBRARY.value:
|
33 |
+
organization = MODEL_LIBRARY_ORG_NAME
|
34 |
+
else:
|
35 |
+
raise ValueError
|
36 |
+
|
37 |
+
return self.upload(folder_path,
|
38 |
+
repo_name,
|
39 |
+
organization=organization,
|
40 |
+
private=private,
|
41 |
+
delete_existing_repo=delete_existing_repo)
|
42 |
+
|
43 |
+
|
44 |
+
def load_local_model_list() -> dict:
|
45 |
+
choices = find_exp_dirs()
|
46 |
+
return gr.update(choices=choices, value=choices[0] if choices else None)
|
47 |
+
|
48 |
+
|
49 |
+
def create_upload_demo(hf_token: str | None) -> gr.Blocks:
|
50 |
+
uploader = ModelUploader(hf_token)
|
51 |
+
model_dirs = find_exp_dirs()
|
52 |
+
|
53 |
+
with gr.Blocks() as demo:
|
54 |
+
with gr.Box():
|
55 |
+
gr.Markdown('Local Models')
|
56 |
+
reload_button = gr.Button('Reload Model List')
|
57 |
+
model_dir = gr.Dropdown(
|
58 |
+
label='Model names',
|
59 |
+
choices=model_dirs,
|
60 |
+
value=model_dirs[0] if model_dirs else None)
|
61 |
+
with gr.Box():
|
62 |
+
gr.Markdown('Upload Settings')
|
63 |
+
with gr.Row():
|
64 |
+
use_private_repo = gr.Checkbox(label='Private', value=True)
|
65 |
+
delete_existing_repo = gr.Checkbox(
|
66 |
+
label='Delete existing repo of the same name', value=False)
|
67 |
+
upload_to = gr.Radio(label='Upload to',
|
68 |
+
choices=[_.value for _ in UploadTarget],
|
69 |
+
value=UploadTarget.MODEL_LIBRARY.value)
|
70 |
+
model_name = gr.Textbox(label='Model Name')
|
71 |
+
upload_button = gr.Button('Upload')
|
72 |
+
gr.Markdown(f'''
|
73 |
+
- You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
|
74 |
+
''')
|
75 |
+
with gr.Box():
|
76 |
+
gr.Markdown('Output message')
|
77 |
+
output_message = gr.Markdown()
|
78 |
+
|
79 |
+
reload_button.click(fn=load_local_model_list,
|
80 |
+
inputs=None,
|
81 |
+
outputs=model_dir)
|
82 |
+
upload_button.click(fn=uploader.upload_model,
|
83 |
+
inputs=[
|
84 |
+
model_dir,
|
85 |
+
model_name,
|
86 |
+
upload_to,
|
87 |
+
use_private_repo,
|
88 |
+
delete_existing_repo,
|
89 |
+
],
|
90 |
+
outputs=output_message)
|
91 |
+
|
92 |
+
return demo
|
93 |
+
|
94 |
+
|
95 |
+
if __name__ == '__main__':
|
96 |
+
import os
|
97 |
+
|
98 |
+
hf_token = os.getenv('HF_TOKEN')
|
99 |
+
demo = create_upload_demo(hf_token)
|
100 |
+
demo.queue(max_size=1).launch(share=False)
|
constants.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import enum
|
2 |
+
|
3 |
+
|
4 |
+
class UploadTarget(enum.Enum):
|
5 |
+
PERSONAL_PROFILE = 'Personal Profile'
|
6 |
+
MODEL_LIBRARY = 'Tune-A-Video Library'
|
7 |
+
|
8 |
+
|
9 |
+
MODEL_LIBRARY_ORG_NAME = 'Tune-A-Video-library'
|
10 |
+
SAMPLE_MODEL_REPO = 'Tune-A-Video-library/a-man-is-surfing'
|
inference.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import gc
|
4 |
+
import pathlib
|
5 |
+
import sys
|
6 |
+
import tempfile
|
7 |
+
|
8 |
+
import gradio as gr
|
9 |
+
import imageio
|
10 |
+
import PIL.Image
|
11 |
+
import torch
|
12 |
+
from diffusers.utils.import_utils import is_xformers_available
|
13 |
+
from einops import rearrange
|
14 |
+
from huggingface_hub import ModelCard
|
15 |
+
|
16 |
+
sys.path.append('Tune-A-Video')
|
17 |
+
|
18 |
+
from tuneavideo.models.unet import UNet3DConditionModel
|
19 |
+
from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
|
20 |
+
|
21 |
+
|
22 |
+
class InferencePipeline:
|
23 |
+
def __init__(self, hf_token: str | None = None):
|
24 |
+
self.hf_token = hf_token
|
25 |
+
self.pipe = None
|
26 |
+
self.device = torch.device(
|
27 |
+
'cuda:0' if torch.cuda.is_available() else 'cpu')
|
28 |
+
self.model_id = None
|
29 |
+
|
30 |
+
def clear(self) -> None:
|
31 |
+
self.model_id = None
|
32 |
+
del self.pipe
|
33 |
+
self.pipe = None
|
34 |
+
torch.cuda.empty_cache()
|
35 |
+
gc.collect()
|
36 |
+
|
37 |
+
@staticmethod
|
38 |
+
def check_if_model_is_local(model_id: str) -> bool:
|
39 |
+
return pathlib.Path(model_id).exists()
|
40 |
+
|
41 |
+
@staticmethod
|
42 |
+
def get_model_card(model_id: str,
|
43 |
+
hf_token: str | None = None) -> ModelCard:
|
44 |
+
if InferencePipeline.check_if_model_is_local(model_id):
|
45 |
+
card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
|
46 |
+
else:
|
47 |
+
card_path = model_id
|
48 |
+
return ModelCard.load(card_path, token=hf_token)
|
49 |
+
|
50 |
+
@staticmethod
|
51 |
+
def get_base_model_info(model_id: str, hf_token: str | None = None) -> str:
|
52 |
+
card = InferencePipeline.get_model_card(model_id, hf_token)
|
53 |
+
return card.data.base_model
|
54 |
+
|
55 |
+
def load_pipe(self, model_id: str) -> None:
|
56 |
+
if model_id == self.model_id:
|
57 |
+
return
|
58 |
+
base_model_id = self.get_base_model_info(model_id, self.hf_token)
|
59 |
+
unet = UNet3DConditionModel.from_pretrained(
|
60 |
+
model_id,
|
61 |
+
subfolder='unet',
|
62 |
+
torch_dtype=torch.float16,
|
63 |
+
use_auth_token=self.hf_token)
|
64 |
+
pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
|
65 |
+
unet=unet,
|
66 |
+
torch_dtype=torch.float16,
|
67 |
+
use_auth_token=self.hf_token)
|
68 |
+
pipe = pipe.to(self.device)
|
69 |
+
if is_xformers_available():
|
70 |
+
pipe.unet.enable_xformers_memory_efficient_attention()
|
71 |
+
self.pipe = pipe
|
72 |
+
self.model_id = model_id # type: ignore
|
73 |
+
|
74 |
+
def run(
|
75 |
+
self,
|
76 |
+
model_id: str,
|
77 |
+
prompt: str,
|
78 |
+
video_length: int,
|
79 |
+
fps: int,
|
80 |
+
seed: int,
|
81 |
+
n_steps: int,
|
82 |
+
guidance_scale: float,
|
83 |
+
) -> PIL.Image.Image:
|
84 |
+
if not torch.cuda.is_available():
|
85 |
+
raise gr.Error('CUDA is not available.')
|
86 |
+
|
87 |
+
self.load_pipe(model_id)
|
88 |
+
|
89 |
+
generator = torch.Generator(device=self.device).manual_seed(seed)
|
90 |
+
out = self.pipe(
|
91 |
+
prompt,
|
92 |
+
video_length=video_length,
|
93 |
+
width=512,
|
94 |
+
height=512,
|
95 |
+
num_inference_steps=n_steps,
|
96 |
+
guidance_scale=guidance_scale,
|
97 |
+
generator=generator,
|
98 |
+
) # type: ignore
|
99 |
+
|
100 |
+
frames = rearrange(out.videos[0], 'c t h w -> t h w c')
|
101 |
+
frames = (frames * 255).to(torch.uint8).numpy()
|
102 |
+
|
103 |
+
out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
|
104 |
+
writer = imageio.get_writer(out_file.name, fps=fps)
|
105 |
+
for frame in frames:
|
106 |
+
writer.append_data(frame)
|
107 |
+
writer.close()
|
108 |
+
|
109 |
+
return out_file.name
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
ffmpeg
|
patch
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diff --git a/train_tuneavideo.py b/train_tuneavideo.py
|
2 |
+
index 66d51b2..86b2a5d 100644
|
3 |
+
--- a/train_tuneavideo.py
|
4 |
+
+++ b/train_tuneavideo.py
|
5 |
+
@@ -94,8 +94,8 @@ def main(
|
6 |
+
|
7 |
+
# Handle the output folder creation
|
8 |
+
if accelerator.is_main_process:
|
9 |
+
- now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
10 |
+
- output_dir = os.path.join(output_dir, now)
|
11 |
+
+ #now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
12 |
+
+ #output_dir = os.path.join(output_dir, now)
|
13 |
+
os.makedirs(output_dir, exist_ok=True)
|
14 |
+
OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
|
15 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.15.0
|
2 |
+
bitsandbytes==0.35.4
|
3 |
+
decord==0.6.0
|
4 |
+
diffusers[torch]==0.11.1
|
5 |
+
einops==0.6.0
|
6 |
+
ftfy==6.1.1
|
7 |
+
gradio==3.16.2
|
8 |
+
huggingface-hub==0.12.0
|
9 |
+
imageio==2.25.0
|
10 |
+
imageio-ffmpeg==0.4.8
|
11 |
+
omegaconf==2.3.0
|
12 |
+
Pillow==9.4.0
|
13 |
+
python-slugify==7.0.0
|
14 |
+
tensorboard==2.11.2
|
15 |
+
torch==1.13.1
|
16 |
+
torchvision==0.14.1
|
17 |
+
transformers==4.26.0
|
18 |
+
triton==2.0.0.dev20221202
|
19 |
+
xformers==0.0.16
|
style.css
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
h1 {
|
2 |
+
text-align: center;
|
3 |
+
}
|
trainer.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import datetime
|
4 |
+
import os
|
5 |
+
import pathlib
|
6 |
+
import shlex
|
7 |
+
import shutil
|
8 |
+
import subprocess
|
9 |
+
import sys
|
10 |
+
|
11 |
+
import gradio as gr
|
12 |
+
import slugify
|
13 |
+
import torch
|
14 |
+
from huggingface_hub import HfApi
|
15 |
+
from omegaconf import OmegaConf
|
16 |
+
|
17 |
+
from app_upload import ModelUploader
|
18 |
+
from utils import save_model_card
|
19 |
+
|
20 |
+
sys.path.append('Tune-A-Video')
|
21 |
+
|
22 |
+
URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
|
23 |
+
|
24 |
+
|
25 |
+
class Trainer:
|
26 |
+
def __init__(self, hf_token: str | None = None):
|
27 |
+
self.hf_token = hf_token
|
28 |
+
self.api = HfApi(token=hf_token)
|
29 |
+
self.model_uploader = ModelUploader(hf_token)
|
30 |
+
|
31 |
+
self.checkpoint_dir = pathlib.Path('checkpoints')
|
32 |
+
self.checkpoint_dir.mkdir(exist_ok=True)
|
33 |
+
|
34 |
+
def download_base_model(self, base_model_id: str) -> str:
|
35 |
+
model_dir = self.checkpoint_dir / base_model_id
|
36 |
+
if not model_dir.exists():
|
37 |
+
org_name = base_model_id.split('/')[0]
|
38 |
+
org_dir = self.checkpoint_dir / org_name
|
39 |
+
org_dir.mkdir(exist_ok=True)
|
40 |
+
subprocess.run(shlex.split(
|
41 |
+
f'git clone https://huggingface.co/{base_model_id}'),
|
42 |
+
cwd=org_dir)
|
43 |
+
return model_dir.as_posix()
|
44 |
+
|
45 |
+
def join_model_library_org(self) -> None:
|
46 |
+
subprocess.run(
|
47 |
+
shlex.split(
|
48 |
+
f'curl -X POST -H "Authorization: Bearer {self.hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
|
49 |
+
))
|
50 |
+
|
51 |
+
def run(
|
52 |
+
self,
|
53 |
+
training_video: str,
|
54 |
+
training_prompt: str,
|
55 |
+
output_model_name: str,
|
56 |
+
overwrite_existing_model: bool,
|
57 |
+
validation_prompt: str,
|
58 |
+
base_model: str,
|
59 |
+
resolution_s: str,
|
60 |
+
n_steps: int,
|
61 |
+
learning_rate: float,
|
62 |
+
gradient_accumulation: int,
|
63 |
+
seed: int,
|
64 |
+
fp16: bool,
|
65 |
+
use_8bit_adam: bool,
|
66 |
+
checkpointing_steps: int,
|
67 |
+
validation_epochs: int,
|
68 |
+
upload_to_hub: bool,
|
69 |
+
use_private_repo: bool,
|
70 |
+
delete_existing_repo: bool,
|
71 |
+
upload_to: str,
|
72 |
+
remove_gpu_after_training: bool,
|
73 |
+
) -> str:
|
74 |
+
if not torch.cuda.is_available():
|
75 |
+
raise gr.Error('CUDA is not available.')
|
76 |
+
if training_video is None:
|
77 |
+
raise gr.Error('You need to upload a video.')
|
78 |
+
if not training_prompt:
|
79 |
+
raise gr.Error('The training prompt is missing.')
|
80 |
+
if not validation_prompt:
|
81 |
+
raise gr.Error('The validation prompt is missing.')
|
82 |
+
|
83 |
+
resolution = int(resolution_s)
|
84 |
+
|
85 |
+
if not output_model_name:
|
86 |
+
timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
87 |
+
output_model_name = f'tune-a-video-{timestamp}'
|
88 |
+
output_model_name = slugify.slugify(output_model_name)
|
89 |
+
|
90 |
+
repo_dir = pathlib.Path(__file__).parent
|
91 |
+
output_dir = repo_dir / 'experiments' / output_model_name
|
92 |
+
if overwrite_existing_model or upload_to_hub:
|
93 |
+
shutil.rmtree(output_dir, ignore_errors=True)
|
94 |
+
output_dir.mkdir(parents=True)
|
95 |
+
|
96 |
+
if upload_to_hub:
|
97 |
+
self.join_model_library_org()
|
98 |
+
|
99 |
+
config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml')
|
100 |
+
config.pretrained_model_path = self.download_base_model(base_model)
|
101 |
+
config.output_dir = output_dir.as_posix()
|
102 |
+
config.train_data.video_path = training_video.name # type: ignore
|
103 |
+
config.train_data.prompt = training_prompt
|
104 |
+
config.train_data.n_sample_frames = 8
|
105 |
+
config.train_data.width = resolution
|
106 |
+
config.train_data.height = resolution
|
107 |
+
config.train_data.sample_start_idx = 0
|
108 |
+
config.train_data.sample_frame_rate = 1
|
109 |
+
config.validation_data.prompts = [validation_prompt]
|
110 |
+
config.validation_data.video_length = 8
|
111 |
+
config.validation_data.width = resolution
|
112 |
+
config.validation_data.height = resolution
|
113 |
+
config.validation_data.num_inference_steps = 50
|
114 |
+
config.validation_data.guidance_scale = 7.5
|
115 |
+
config.learning_rate = learning_rate
|
116 |
+
config.gradient_accumulation_steps = gradient_accumulation
|
117 |
+
config.train_batch_size = 1
|
118 |
+
config.max_train_steps = n_steps
|
119 |
+
config.checkpointing_steps = checkpointing_steps
|
120 |
+
config.validation_steps = validation_epochs
|
121 |
+
config.seed = seed
|
122 |
+
config.mixed_precision = 'fp16' if fp16 else ''
|
123 |
+
config.use_8bit_adam = use_8bit_adam
|
124 |
+
|
125 |
+
config_path = output_dir / 'config.yaml'
|
126 |
+
with open(config_path, 'w') as f:
|
127 |
+
OmegaConf.save(config, f)
|
128 |
+
|
129 |
+
command = f'accelerate launch Tune-A-Video/train_tuneavideo.py --config {config_path}'
|
130 |
+
subprocess.run(shlex.split(command))
|
131 |
+
save_model_card(save_dir=output_dir,
|
132 |
+
base_model=base_model,
|
133 |
+
training_prompt=training_prompt,
|
134 |
+
test_prompt=validation_prompt,
|
135 |
+
test_image_dir='samples')
|
136 |
+
|
137 |
+
message = 'Training completed!'
|
138 |
+
print(message)
|
139 |
+
|
140 |
+
if upload_to_hub:
|
141 |
+
upload_message = self.model_uploader.upload_model(
|
142 |
+
folder_path=output_dir.as_posix(),
|
143 |
+
repo_name=output_model_name,
|
144 |
+
upload_to=upload_to,
|
145 |
+
private=use_private_repo,
|
146 |
+
delete_existing_repo=delete_existing_repo)
|
147 |
+
print(upload_message)
|
148 |
+
message = message + '\n' + upload_message
|
149 |
+
|
150 |
+
if remove_gpu_after_training:
|
151 |
+
space_id = os.getenv('SPACE_ID')
|
152 |
+
if space_id:
|
153 |
+
self.api.request_space_hardware(repo_id=space_id,
|
154 |
+
hardware='cpu-basic')
|
155 |
+
|
156 |
+
return message
|
uploader.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from huggingface_hub import HfApi
|
4 |
+
|
5 |
+
|
6 |
+
class Uploader:
|
7 |
+
def __init__(self, hf_token: str | None):
|
8 |
+
self.api = HfApi(token=hf_token)
|
9 |
+
|
10 |
+
def get_username(self) -> str:
|
11 |
+
return self.api.whoami()['name']
|
12 |
+
|
13 |
+
def upload(self,
|
14 |
+
folder_path: str,
|
15 |
+
repo_name: str,
|
16 |
+
organization: str = '',
|
17 |
+
repo_type: str = 'model',
|
18 |
+
private: bool = True,
|
19 |
+
delete_existing_repo: bool = False) -> str:
|
20 |
+
if not folder_path:
|
21 |
+
raise ValueError
|
22 |
+
if not repo_name:
|
23 |
+
raise ValueError
|
24 |
+
if not organization:
|
25 |
+
organization = self.get_username()
|
26 |
+
repo_id = f'{organization}/{repo_name}'
|
27 |
+
if delete_existing_repo:
|
28 |
+
try:
|
29 |
+
self.api.delete_repo(repo_id, repo_type=repo_type)
|
30 |
+
except Exception:
|
31 |
+
pass
|
32 |
+
try:
|
33 |
+
self.api.create_repo(repo_id, repo_type=repo_type, private=private)
|
34 |
+
self.api.upload_folder(repo_id=repo_id,
|
35 |
+
folder_path=folder_path,
|
36 |
+
path_in_repo='.',
|
37 |
+
repo_type=repo_type)
|
38 |
+
url = f'https://huggingface.co/{repo_id}'
|
39 |
+
message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
|
40 |
+
except Exception as e:
|
41 |
+
message = str(e)
|
42 |
+
return message
|
utils.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import pathlib
|
4 |
+
|
5 |
+
|
6 |
+
def find_exp_dirs() -> list[str]:
|
7 |
+
repo_dir = pathlib.Path(__file__).parent
|
8 |
+
exp_root_dir = repo_dir / 'experiments'
|
9 |
+
if not exp_root_dir.exists():
|
10 |
+
return []
|
11 |
+
exp_dirs = sorted(exp_root_dir.glob('*'))
|
12 |
+
exp_dirs = [
|
13 |
+
exp_dir for exp_dir in exp_dirs
|
14 |
+
if (exp_dir / 'model_index.json').exists()
|
15 |
+
]
|
16 |
+
return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
|
17 |
+
|
18 |
+
|
19 |
+
def save_model_card(
|
20 |
+
save_dir: pathlib.Path,
|
21 |
+
base_model: str,
|
22 |
+
training_prompt: str,
|
23 |
+
test_prompt: str = '',
|
24 |
+
test_image_dir: str = '',
|
25 |
+
) -> None:
|
26 |
+
image_str = ''
|
27 |
+
if test_prompt and test_image_dir:
|
28 |
+
image_paths = sorted((save_dir / test_image_dir).glob('*.gif'))
|
29 |
+
if image_paths:
|
30 |
+
image_path = image_paths[-1]
|
31 |
+
rel_path = image_path.relative_to(save_dir)
|
32 |
+
image_str = f'''## Samples
|
33 |
+
Test prompt: {test_prompt}
|
34 |
+
|
35 |
+
![{image_path.stem}]({rel_path})'''
|
36 |
+
|
37 |
+
model_card = f'''---
|
38 |
+
license: creativeml-openrail-m
|
39 |
+
base_model: {base_model}
|
40 |
+
training_prompt: {training_prompt}
|
41 |
+
tags:
|
42 |
+
- stable-diffusion
|
43 |
+
- stable-diffusion-diffusers
|
44 |
+
- text-to-image
|
45 |
+
- diffusers
|
46 |
+
- text-to-video
|
47 |
+
- tune-a-video
|
48 |
+
inference: false
|
49 |
+
---
|
50 |
+
|
51 |
+
# Tune-A-Video - {save_dir.name}
|
52 |
+
|
53 |
+
## Model description
|
54 |
+
- Base model: [{base_model}](https://huggingface.co/{base_model})
|
55 |
+
- Training prompt: {training_prompt}
|
56 |
+
|
57 |
+
{image_str}
|
58 |
+
|
59 |
+
## Related papers:
|
60 |
+
- [Tune-A-Video](https://arxiv.org/abs/2212.11565): One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation
|
61 |
+
- [Stable-Diffusion](https://arxiv.org/abs/2112.10752): High-Resolution Image Synthesis with Latent Diffusion Models
|
62 |
+
'''
|
63 |
+
|
64 |
+
with open(save_dir / 'README.md', 'w') as f:
|
65 |
+
f.write(model_card)
|