DeepBeepMeep committed
Commit e0a0e76 · 1 Parent(s): fcbb34f

RIFLEx support

README.md CHANGED
@@ -45,25 +45,22 @@ It is an illustration on how one can set up on an existing model some fast and p
 
 For more information on how to use the mmpg module, please go to: https://github.com/deepbeepmeep/mmgp
 
-You will find the original Hunyuan Video repository here: https://github.com/deepbeepmeep/Wan2GP
+You will find the original Wan2.1 Video repository here: https://github.com/Wan-Video/Wan2.1
+
 
 
 
 ## Installation Guide for Linux and Windows
 
-We provide an `environment.yml` file for setting up a Conda environment.
-Conda's installation instructions are available [here](https://docs.anaconda.com/free/miniconda/index.html).
 
 This app has been tested on Python 3.10 / 2.6.0 / Cuda 12.4.\
 
-```shell
-# 1 - conda. Prepare and activate a conda environment
-conda env create -f environment.yml
-conda activate Wan2
 
-# OR
+```shell
+# 0 Create a Python 3.10.9 environment or a venv using python
+conda create -name Wan2GP python==3.10.9 #if you have conda
 
-# 1 - venv. Alternatively create a python 3.10 venv and then do the following
+# 1 Install pytorch 2.6.0
 pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124
 
 
@@ -81,9 +78,6 @@ pip install -e .
 # 3.2 optional Flash attention support (easy to install on Linux but much harder on Windows)
 python -m pip install flash-attn==2.7.2.post1
 
-
-
-
 ```
 
 Note that *Flash attention* and *Sage attention* are quite complex to install on Windows but offers a better memory management (and consequently longer videos) than the default *sdpa attention*.
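The revised README drops the Conda `environment.yml` flow in favor of a plain Python 3.10.9 environment plus the pinned PyTorch 2.6.0 / CUDA 12.4 wheel (the added `conda create -name Wan2GP ...` line presumably intends conda's `-n`/`--name` flag). A minimal post-install sanity check, assuming only that `torch` was installed and using the version strings the README targets, might look like:

```python
import torch

# Quick environment check for the setup described in the README
# (Python 3.10 / torch 2.6.0 / CUDA 12.4). Adjust the expected values
# if you installed a different wheel.
print("torch:", torch.__version__)          # expected to start with "2.6.0"
print("cuda :", torch.version.cuda)         # expected "12.4" for the cu124 wheel
print("gpu  :", torch.cuda.is_available())  # True if a CUDA device is visible
```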
gradio_server.py CHANGED
@@ -975,9 +975,9 @@ def create_demo():
 gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance : 24 GB of VRAM (RTX 3090 / RTX 4090), the limits are as follows:")
 gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
 gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
-gr.Markdown("- 1280 x 720 with a 14B model: 192 frames (8s): 11 GB of VRAM")
-gr.Markdown("Note that the VAE stages (encoding / decoding at image2video ) or just the decoding at text2video will create a temporary VRAM peak (up to 12GB for 420P and 22 GB for 720P)")
-gr.Markdown("It is not recommmended to generate a video longer than 8s even if there is still some VRAM left as some artifact may appear")
+gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
+gr.Markdown("Note that the VAE stages (encoding / decoding at image2video ) or just the decoding at text2video will create a temporary VRAM peaks (up to 12GB for 420P and 22 GB for 720P)")
+gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifact may appear")
 gr.Markdown("Please note that if your turn on compilation, the first generation step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
 
 
@@ -996,7 +996,7 @@ def create_demo():
 index = 0 if index ==0 else index
 transformer_t2v_choice = gr.Dropdown(
 choices=[
-("WAN 2.1 1.3B Text to Video 16 bits - the small model for fast generations with low VRAM requirements", 0),
+("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0),
 ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
 ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
 ],
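For context on the dropdown change above: Gradio's `Dropdown` accepts `(label, value)` tuples, so the label shown to the user maps to the integer model index used elsewhere in `gradio_server.py`. A stripped-down sketch, not the app's actual layout and with shortened labels, would be:

```python
import gradio as gr

# Minimal sketch of the transformer_t2v_choice dropdown: each choice is a
# (label, value) tuple, so the UI displays the label while callbacks receive
# the integer index of the selected model.
with gr.Blocks() as demo:
    transformer_t2v_choice = gr.Dropdown(
        choices=[
            ("WAN 2.1 1.3B Text to Video 16 bits (recommended)", 0),
            ("WAN 2.1 14B Text to Video 16 bits", 1),
            ("WAN 2.1 14B Text to Video quantized to 8 bits", 2),
        ],
        value=0,
        label="Text to Video transformer",
    )

# demo.launch()  # uncomment to try the sketch locally
```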
wan/image2video.py CHANGED
@@ -289,7 +289,7 @@ class WanI2V:
 # sample videos
 latent = noise
 
-freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 4 if enable_RIFLEx else None )
+freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 6 if enable_RIFLEx else None )
 
 arg_c = {
 'context': [context[0]],
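The `nb_latent_frames` argument passed to `get_rope_freqs` reflects the VAE's temporal compression: one latent frame for the first video frame plus one per four subsequent frames, which is what `int((frame_num - 1)/4 + 1)` computes. A small worked check, using the frame counts quoted in the `gradio_server.py` notes:

```python
# Worked check of the latent-frame count used for the RoPE frequencies.
# Assumes the temporal compression factor of 4 implied by the expression in
# the diff; the frame counts below are the ones quoted in gradio_server.py.
def nb_latent_frames(frame_num: int) -> int:
    return int((frame_num - 1) / 4 + 1)

assert nb_latent_frames(80) == 20    # 5 s clip  -> 20 latent frames
assert nb_latent_frames(128) == 32   # 8 s clip  -> 32 latent frames
```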
wan/modules/model.py CHANGED
@@ -654,8 +654,8 @@ class WanModel(ModelMixin, ConfigMixin):
 assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
 
 
-freqs = torch.cat([
-rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ), #44
+freqs = torch.cat([
+rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)), #44
 rope_params(1024, 2 * (d // 6)), #42
 rope_params(1024, 2 * (d // 6)) #42
 ],dim=1)
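The head dimension `d` is split so the temporal RoPE axis gets `d - 4 * (d // 6)` channels and each spatial axis gets `2 * (d // 6)` (for `d = 128` that is 44 + 42 + 42, matching the inline comments). With the change above, the temporal block only goes through `rope_params_riflex` when a RIFLEx `k` is supplied. As an illustration of the RIFLEx idea, and not necessarily the exact body of `rope_params_riflex` in `wan/modules/model.py`, a sketch could look like:

```python
import torch

# Illustrative sketch of RIFLEx-style temporal RoPE frequencies: compute the
# standard inverse frequencies, then lower the k-th ("intrinsic") component so
# that it completes less than one full period over the L_test latent frames.
# The exact implementation in the repository may differ.
def rope_params_riflex_sketch(max_seq_len, dim, theta=10000, L_test=20, k=6):
    assert dim % 2 == 0
    inv_freq = 1.0 / torch.pow(
        theta, torch.arange(0, dim, 2, dtype=torch.float64) / dim
    )
    # RIFLEx adjustment: cap the intrinsic frequency (index k-1) relative to
    # the test length; the 0.9 factor leaves a small safety margin.
    inv_freq[k - 1] = 0.9 * 2 * torch.pi / L_test
    angles = torch.outer(torch.arange(max_seq_len, dtype=torch.float64), inv_freq)
    return torch.polar(torch.ones_like(angles), angles)  # complex rotary factors
```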
wan/text2video.py CHANGED
@@ -240,7 +240,7 @@ class WanT2V:
 # k, N_k = identify_k(10000, 44, 26)
 # print(f"value nb latent frames={nf}, k={k}, n_k={N_k}")
 
-freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 4 if enable_RIFLEx else None )
+freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 6 if enable_RIFLEx else None )
 
 arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
 arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
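Both pipelines now request `RIFLEx_k = 6` (rather than 4) when `enable_RIFLEx` is set; the commented-out `identify_k(10000, 44, 26)` call above appears to derive k from the RoPE base (10000), the 44 temporal channels, and a latent frame count. A hypothetical helper, not part of the repository, showing when a caller might set `enable_RIFLEx` under the assumption that the model handles roughly 20 latent frames natively:

```python
# Hypothetical helper (not in the repository): enable RIFLEx only when the
# requested clip is longer than the length the model handles natively.
# native_latent_frames is an assumption here; tune it to the model you use.
def should_enable_riflex(frame_num: int, native_latent_frames: int = 20) -> bool:
    latent_frames = int((frame_num - 1) / 4 + 1)
    return latent_frames > native_latent_frames

# Example: a 5 s / 80-frame request stays off, an 8 s / 128-frame one turns it on.
assert should_enable_riflex(80) is False
assert should_enable_riflex(128) is True
```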