DeepBeepMeep committed
Commit · e0a0e76 · 1 Parent(s): fcbb34f

RIFLEx support

Files changed:
- README.md +6 -12
- gradio_server.py +4 -4
- wan/image2video.py +1 -1
- wan/modules/model.py +2 -2
- wan/text2video.py +1 -1
README.md CHANGED

@@ -45,25 +45,22 @@ It is an illustration on how one can set up on an existing model some fast and p
 
 For more information on how to use the mmpg module, please go to: https://github.com/deepbeepmeep/mmgp
 
-You will find the original
+You will find the original Wan2.1 Video repository here: https://github.com/Wan-Video/Wan2.1
+
 
 
 
 ## Installation Guide for Linux and Windows
 
-We provide an `environment.yml` file for setting up a Conda environment.
-Conda's installation instructions are available [here](https://docs.anaconda.com/free/miniconda/index.html).
 
 This app has been tested on Python 3.10 / 2.6.0 / Cuda 12.4.\
 
-```shell
-# 1 - conda. Prepare and activate a conda environment
-conda env create -f environment.yml
-conda activate Wan2
 
-
+```shell
+# 0 Create a Python 3.10.9 environment or a venv using python
+conda create -name Wan2GP python==3.10.9 #if you have conda
 
-# 1
+# 1 Install pytorch 2.6.0
 pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124
 
 
@@ -81,9 +78,6 @@ pip install -e .
 # 3.2 optional Flash attention support (easy to install on Linux but much harder on Windows)
 python -m pip install flash-attn==2.7.2.post1
 
-
-
-
 ```
 
 Note that *Flash attention* and *Sage attention* are quite complex to install on Windows but offers a better memory management (and consequently longer videos) than the default *sdpa attention*.
gradio_server.py CHANGED

@@ -975,9 +975,9 @@ def create_demo():
 gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance : 24 GB of VRAM (RTX 3090 / RTX 4090), the limits are as follows:")
 gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
 gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
-gr.Markdown("- 1280 x 720 with a 14B model:
-gr.Markdown("Note that the VAE stages (encoding / decoding at image2video ) or just the decoding at text2video will create a temporary VRAM
-gr.Markdown("It is not recommmended to generate a video longer than 8s even if there is still some VRAM left as some artifact may appear")
+gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
+gr.Markdown("Note that the VAE stages (encoding / decoding at image2video ) or just the decoding at text2video will create a temporary VRAM peaks (up to 12GB for 420P and 22 GB for 720P)")
+gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifact may appear")
 gr.Markdown("Please note that if your turn on compilation, the first generation step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")

@@ -996,7 +996,7 @@ def create_demo():
 index = 0 if index ==0 else index
 transformer_t2v_choice = gr.Dropdown(
 choices=[
-("WAN 2.1 1.3B Text to Video 16 bits - the small model for fast generations with low VRAM requirements", 0),
+("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0),
 ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
 ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
 ],
wan/image2video.py CHANGED

@@ -289,7 +289,7 @@ class WanI2V:
 # sample videos
 latent = noise
 
-freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k =
+freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 6 if enable_RIFLEx else None )
 
 arg_c = {
 'context': [context[0]],
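A note on the arithmetic in the new call: Wan's causal VAE compresses time by a factor of 4 (the first frame is kept as-is), so a request for `frame_num` pixel frames becomes `(frame_num - 1)/4 + 1` latent frames, and this commit enables RIFLEx by hard-coding the intrinsic-frequency index to 6. A minimal sketch of that bookkeeping, assuming the 4x temporal stride of the Wan2.1 VAE; the helper name `rope_freq_args` is illustrative and not part of the commit:

```python
# Illustrative helper (not in the repo): reproduces the frame bookkeeping used
# in the get_rope_freqs call above, assuming Wan's 4x temporal VAE stride.
def rope_freq_args(frame_num: int, enable_RIFLEx: bool):
    nb_latent_frames = int((frame_num - 1) / 4 + 1)
    # The commit hard-codes k = 6 when RIFLEx is enabled; None keeps plain RoPE.
    riflex_k = 6 if enable_RIFLEx else None
    return nb_latent_frames, riflex_k

print(rope_freq_args(81, False))   # (21, None) -> ~5 s at 16 fps, RIFLEx off
print(rope_freq_args(129, True))   # (33, 6)    -> ~8 s at 16 fps, RIFLEx on
```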
wan/modules/model.py CHANGED

@@ -654,8 +654,8 @@ class WanModel(ModelMixin, ConfigMixin):
 assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
 
 
-freqs = torch.cat([
-rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ), #44
+freqs = torch.cat([
+rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)), #44
 rope_params(1024, 2 * (d // 6)), #42
 rope_params(1024, 2 * (d // 6)) #42
 ],dim=1)
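The change above only touches the temporal slice of the RoPE table: the `d - 4 * (d // 6)` channels (44 for `d = 128`, per the `#44` comment) now come from `rope_params_riflex` when a RIFLEx index is given, while the two `2 * (d // 6)`-channel spatial tables are unchanged. The diff does not show `rope_params_riflex` itself; the sketch below is a hedged reconstruction of what such a helper typically does following the RIFLEx recipe: build the standard rotary table, then lower the k-th frequency so that one period roughly covers the test-time latent length `L_test`. Treat the function bodies (including the 0.9 safety factor) as assumptions; only the names and signatures come from the diff.

```python
import torch

def rope_params(max_seq_len, dim, theta=10000):
    # Standard 1D RoPE table: one complex rotation per (position, frequency) pair.
    assert dim % 2 == 0
    inv_freq = 1.0 / torch.pow(theta, torch.arange(0, dim, 2, dtype=torch.float64) / dim)
    freqs = torch.outer(torch.arange(max_seq_len), inv_freq)
    return torch.polar(torch.ones_like(freqs), freqs)

def rope_params_riflex(max_seq_len, dim, theta=10000, L_test=30, k=6):
    # RIFLEx-style variant (sketch, assumed): same table, except the k-th
    # "intrinsic" frequency is lowered so that L_test latent frames fit within
    # roughly one period, suppressing temporal repetition at longer lengths.
    assert dim % 2 == 0
    inv_freq = 1.0 / torch.pow(theta, torch.arange(0, dim, 2, dtype=torch.float64) / dim)
    inv_freq[k - 1] = 0.9 * 2 * torch.pi / L_test  # keep L_test inside ~one cycle
    freqs = torch.outer(torch.arange(max_seq_len), inv_freq)
    return torch.polar(torch.ones_like(freqs), freqs)
```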
wan/text2video.py CHANGED

@@ -240,7 +240,7 @@ class WanT2V:
 # k, N_k = identify_k(10000, 44, 26)
 # print(f"value nb latent frames={nf}, k={k}, n_k={N_k}")
 
-freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k =
+freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 6 if enable_RIFLEx else None )
 
 arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
 arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
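The commented-out `identify_k(10000, 44, 26)` line kept above hints at how the hard-coded `k = 6` relates to the RIFLEx procedure: compute the period of each temporal RoPE component (base 10000, 44 temporal channels) and pick the component whose period best matches a target latent length N. The repository's own helper is not shown in this diff, and the commit ultimately just passes `RIFLEx_k = 6`; the sketch below follows the procedure described in the RIFLEx paper and should be read as an assumption, not the project's code.

```python
import math

def identify_k(b: float, d: int, N: int):
    # Sketch of the RIFLEx intrinsic-frequency search (assumed, not the repo's code):
    # compute the period, in latent frames, of each of the d/2 temporal RoPE
    # components, then return the 1-indexed component whose period is closest to N.
    periods = []
    for j in range(1, d // 2 + 1):
        theta_j = 1.0 / (b ** (2 * (j - 1) / d))      # angular frequency of component j
        periods.append(round(2 * math.pi / theta_j))  # its period in latent frames
    k = min(range(len(periods)), key=lambda i: abs(periods[i] - N)) + 1
    return k, periods[k - 1]

# Same arguments as the commented-out call in the diff above.
k, N_k = identify_k(10000, 44, 26)
print(f"k={k}, N_k={N_k}")
```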