rahul7star committed
Commit acb6e24 · verified · 1 Parent(s): cd5ed7e

Upload 27 files

.gitignore ADDED
@@ -0,0 +1,182 @@
1
+ hf_download/
2
+ outputs/
3
+ repo/
4
+ loras/
5
+ queue.json
6
+ settings.json
7
+ # Byte-compiled / optimized / DLL files
8
+ __pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+
12
+ # C extensions
13
+ *.so
14
+
15
+ # Distribution / packaging
16
+ .Python
17
+ build/
18
+ develop-eggs/
19
+ dist/
20
+ downloads/
21
+ eggs/
22
+ .eggs/
23
+ lib/
24
+ lib64/
25
+ parts/
26
+ sdist/
27
+ var/
28
+ wheels/
29
+ share/python-wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+ MANIFEST
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+ cover/
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Django stuff:
65
+ *.log
66
+ local_settings.py
67
+ db.sqlite3
68
+ db.sqlite3-journal
69
+
70
+ # Flask stuff:
71
+ instance/
72
+ .webassets-cache
73
+
74
+ # Scrapy stuff:
75
+ .scrapy
76
+
77
+ # Sphinx documentation
78
+ docs/_build/
79
+
80
+ # PyBuilder
81
+ .pybuilder/
82
+ target/
83
+
84
+ # Jupyter Notebook
85
+ .ipynb_checkpoints
86
+
87
+ # IPython
88
+ profile_default/
89
+ ipython_config.py
90
+
91
+ # pyenv
92
+ # For a library or package, you might want to ignore these files since the code is
93
+ # intended to run in multiple environments; otherwise, check them in:
94
+ # .python-version
95
+
96
+ # pipenv
97
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
99
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
100
+ # install all needed dependencies.
101
+ #Pipfile.lock
102
+
103
+ # UV
104
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ #uv.lock
108
+
109
+ # poetry
110
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
111
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
112
+ # commonly ignored for libraries.
113
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
114
+ #poetry.lock
115
+
116
+ # pdm
117
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
118
+ #pdm.lock
119
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
120
+ # in version control.
121
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
122
+ .pdm.toml
123
+ .pdm-python
124
+ .pdm-build/
125
+
126
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
127
+ __pypackages__/
128
+
129
+ # Celery stuff
130
+ celerybeat-schedule
131
+ celerybeat.pid
132
+
133
+ # SageMath parsed files
134
+ *.sage.py
135
+
136
+ # Environments
137
+ .env
138
+ .venv
139
+ env/
140
+ venv/
141
+ ENV/
142
+ env.bak/
143
+ venv.bak/
144
+
145
+ # Spyder project settings
146
+ .spyderproject
147
+ .spyproject
148
+
149
+ # Rope project settings
150
+ .ropeproject
151
+
152
+ # mkdocs documentation
153
+ /site
154
+
155
+ # mypy
156
+ .mypy_cache/
157
+ .dmypy.json
158
+ dmypy.json
159
+
160
+ # Pyre type checker
161
+ .pyre/
162
+
163
+ # pytype static type analyzer
164
+ .pytype/
165
+
166
+ # Cython debug symbols
167
+ cython_debug/
168
+
169
+ # PyCharm
170
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
171
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
172
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
173
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
174
+ .idea/
175
+
176
+ # Ruff stuff:
177
+ .ruff_cache/
178
+
179
+ # PyPI configuration file
180
+ .pypirc
181
+
182
+ temp/
README.md ADDED
@@ -0,0 +1,87 @@
1
+ # FramePack Studio
2
+
3
+ FramePack Studio is an enhanced version of the FramePack demo script, designed to create intricate video scenes with improved prompt adherence. This is very much a work in progress; expect some bugs and broken features.
4
+ ![screencapture-127-0-0-1-7860-2025-05-04-20_13_58](https://github.com/user-attachments/assets/8fcb90af-8c3f-47ca-8f23-61d9b59438ae)
5
+
6
+
7
+ ## Current Features
8
+
9
+ - **F1 and Original FramePack Models**: Run both in a single queue
10
+ - **Timestamped Prompts**: Define different prompts for specific time segments in your video
11
+ - **Prompt Blending**: Define the blending time between timestamped prompts
12
+ - **Basic LoRA Support**: Works with most (all?) Hunyuan LoRAs
13
+ - **Queue System**: Process multiple generation jobs without blocking the interface
14
+ - **Metadata Saving/Import**: Prompt and seed are encoded into the output PNG; all other generation metadata is saved in a JSON file
15
+ - **I2V and T2V**: Works with or without an input image to allow for more flexibility when working with standard LoRAs
16
+ - **Latent Image Options**: When using T2V you can generate based on a black, white, green screen or pure noise image
17
+
18
+
19
+ ## Fresh Installation
20
+
21
+ ### Prerequisites
22
+
23
+ - Python 3.10+
24
+ - CUDA-compatible GPU with at least 8GB VRAM (16GB+ recommended)
25
+
26
+ ### Setup
27
+
28
+ Install via the Pinokio community script "FP-Studio" or:
29
+
30
+ 1. Clone the repository:
31
+ ```bash
32
+ git clone https://github.com/colinurbs/FramePack-Studio.git
33
+ cd FramePack-Studio
34
+ ```
35
+
36
+ 2. Install dependencies:
37
+ ```bash
38
+ pip install -r requirements.txt
39
+ ```
40
+
41
+ ## Usage
42
+
43
+ Run the studio interface:
44
+
45
+ ```bash
46
+ python studio.py
47
+ ```
48
+
49
+ Additional command line options:
50
+ - `--share`: Create a public Gradio link to share your interface
51
+ - `--server`: Specify the server address (default: 0.0.0.0)
52
+ - `--port`: Specify a custom port
53
+ - `--inbrowser`: Automatically open the interface in your browser
54
+
55
+ ## LoRAs
56
+
57
+ Add LoRAs to the /loras/ folder at the root of the installation. Select the LoRAs you wish to load and set the weights for each generation.
58
+
59
+ NOTE: Slow LoRA loading is a known issue
60
+
61
+ ## Working with Timestamped Prompts
62
+
63
+ You can create videos with changing prompts over time using the following syntax:
64
+
65
+ ```
66
+ [0s: A serene forest with sunlight filtering through the trees ]
67
+ [5s: A deer appears in the clearing ]
68
+ [10s: The deer drinks from a small stream ]
69
+ ```
70
+
71
+ Each timestamp defines when that prompt should start influencing the generation. The system will (hopefully) smoothly transition between prompts for a cohesive video.
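For illustration, prompts in this format can be split into `(start_time, prompt)` pairs with a few lines of Python. This is only a sketch of the syntax described above; the function name and return format are ours, not the project's actual parser:

```python
import re

def parse_timestamped_prompts(text):
    """Turn '[<seconds>s: <prompt>]' segments into sorted (start_time, prompt) pairs."""
    pattern = re.compile(r"\[\s*(\d+(?:\.\d+)?)s\s*:\s*(.*?)\s*\]", re.DOTALL)
    return sorted(((float(t), p) for t, p in pattern.findall(text)), key=lambda s: s[0])

example = """
[0s: A serene forest with sunlight filtering through the trees ]
[5s: A deer appears in the clearing ]
[10s: The deer drinks from a small stream ]
"""
# -> [(0.0, 'A serene forest ...'), (5.0, 'A deer appears ...'), (10.0, 'The deer drinks ...')]
print(parse_timestamped_prompts(example))
```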
72
+
73
+ ## Credits
74
+ Many thanks to [Lvmin Zhang](https://github.com/lllyasviel) for the absolutely amazing work on the original [FramePack](https://github.com/lllyasviel/FramePack) code!
75
+
76
+ Thanks to [Rickard Edén](https://github.com/neph1) for the LoRA code and their general contributions to this growing FramePack scene!
77
+
78
+ Thanks to everyone who has joined the Discord, reported a bug, submitted a PR or helped with testing!
79
+
80
+
81
+
82
+ @article{zhang2025framepack,
83
+ title={Packing Input Frame Contexts in Next-Frame Prediction Models for Video Generation},
84
+ author={Lvmin Zhang and Maneesh Agrawala},
85
+ journal={arXiv},
86
+ year={2025}
87
+ }
diffusers_helper/bucket_tools.py ADDED
@@ -0,0 +1,94 @@
1
+ bucket_options = {
2
+ 128: [
3
+ (96, 160),
4
+ (112, 144),
5
+ (128, 128),
6
+ (144, 112),
7
+ (160, 96),
8
+ ],
9
+ 256: [
10
+ (192, 320),
11
+ (224, 288),
12
+ (256, 256),
13
+ (288, 224),
14
+ (320, 192),
15
+ ],
16
+ 384: [
17
+ (256, 512),
18
+ (320, 448),
19
+ (384, 384),
20
+ (448, 320),
21
+ (512, 256),
22
+ ],
23
+ 512: [
24
+ (352, 704),
25
+ (384, 640),
26
+ (448, 576),
27
+ (512, 512),
28
+ (576, 448),
29
+ (640, 384),
30
+ (704, 352),
31
+ ],
32
+ 640: [
33
+ (416, 960),
34
+ (448, 864),
35
+ (480, 832),
36
+ (512, 768),
37
+ (544, 704),
38
+ (576, 672),
39
+ (608, 640),
40
+ (640, 640),
41
+ (640, 608),
42
+ (672, 576),
43
+ (704, 544),
44
+ (768, 512),
45
+ (832, 480),
46
+ (864, 448),
47
+ (960, 416),
48
+ ],
49
+ 768: [
50
+ (512, 1024),
51
+ (576, 896),
52
+ (640, 832),
53
+ (704, 768),
54
+ (768, 768),
55
+ (768, 704),
56
+ (832, 640),
57
+ (896, 576),
58
+ (1024, 512),
59
+ ],
60
+ }
61
+
62
+
63
+ def find_nearest_bucket(h, w, resolution=640):
64
+ # Use the provided resolution or find the closest available bucket size
65
+ # print(f"find_nearest_bucket called with h={h}, w={w}, resolution={resolution}")
66
+
67
+ if resolution not in bucket_options:
68
+ # Find the closest available resolution
69
+ available_resolutions = list(bucket_options.keys())
70
+ closest_resolution = min(available_resolutions, key=lambda x: abs(x - resolution))
71
+ # print(f"Resolution {resolution} not found in bucket options, using closest available: {closest_resolution}")
72
+ resolution = closest_resolution
73
+ # else:
74
+ # print(f"Resolution {resolution} found in bucket options")
75
+
76
+ # Calculate the aspect ratio of the input image
77
+ input_aspect_ratio = w / h if h > 0 else 1.0
78
+ # print(f"Input aspect ratio: {input_aspect_ratio:.4f}")
79
+
80
+ min_diff = float('inf')
81
+ best_bucket = None
82
+
83
+ # Find the bucket size with the closest aspect ratio to the input image
84
+ for (bucket_h, bucket_w) in bucket_options[resolution]:
85
+ bucket_aspect_ratio = bucket_w / bucket_h if bucket_h > 0 else 1.0
86
+ # Calculate the difference in aspect ratios
87
+ diff = abs(bucket_aspect_ratio - input_aspect_ratio)
88
+ if diff < min_diff:
89
+ min_diff = diff
90
+ best_bucket = (bucket_h, bucket_w)
91
+ # print(f" Checking bucket ({bucket_h}, {bucket_w}), aspect ratio={bucket_aspect_ratio:.4f}, diff={diff:.4f}, current best={best_bucket}")
92
+
93
+ # print(f"Using resolution {resolution}, selected bucket: {best_bucket}")
94
+ return best_bucket
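A quick usage sketch for `find_nearest_bucket`; the input sizes are arbitrary example values, and the expected outputs follow from the bucket tables above:

```python
from diffusers_helper.bucket_tools import find_nearest_bucket

# A 720x1280 (h, w) landscape frame has aspect ratio ~1.78; within the 640 group,
# (480, 832) is the closest match (832/480 ~= 1.73), so it is returned.
print(find_nearest_bucket(720, 1280, resolution=640))   # (480, 832)

# Unlisted resolutions fall back to the nearest available group,
# e.g. resolution=600 is mapped to the 640 group before aspect matching.
print(find_nearest_bucket(1080, 1080, resolution=600))  # (640, 640)
```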
diffusers_helper/clip_vision.py ADDED
@@ -0,0 +1,12 @@
1
+ import numpy as np
2
+
3
+
4
+ def hf_clip_vision_encode(image, feature_extractor, image_encoder):
5
+ assert isinstance(image, np.ndarray)
6
+ assert image.ndim == 3 and image.shape[2] == 3
7
+ assert image.dtype == np.uint8
8
+
9
+ preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
10
+ image_encoder_output = image_encoder(**preprocessed)
11
+
12
+ return image_encoder_output
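A minimal sketch of how `hf_clip_vision_encode` is called; the CLIP checkpoint below is only an example stand-in for whatever image encoder the studio actually downloads:

```python
import numpy as np
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers_helper.clip_vision import hf_clip_vision_encode

# Example components; FramePack's real checkpoints may differ.
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

image = np.zeros((480, 832, 3), dtype=np.uint8)  # HWC uint8, as the asserts require
output = hf_clip_vision_encode(image, feature_extractor, image_encoder)
print(output.last_hidden_state.shape)  # per-patch features used as image conditioning
```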
diffusers_helper/dit_common.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch
2
+ import accelerate.accelerator
3
+
4
+ from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous
5
+
6
+
7
+ accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x
8
+
9
+
10
+ def LayerNorm_forward(self, x):
11
+ return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x)
12
+
13
+
14
+ LayerNorm.forward = LayerNorm_forward
15
+ torch.nn.LayerNorm.forward = LayerNorm_forward
16
+
17
+
18
+ def FP32LayerNorm_forward(self, x):
19
+ origin_dtype = x.dtype
20
+ return torch.nn.functional.layer_norm(
21
+ x.float(),
22
+ self.normalized_shape,
23
+ self.weight.float() if self.weight is not None else None,
24
+ self.bias.float() if self.bias is not None else None,
25
+ self.eps,
26
+ ).to(origin_dtype)
27
+
28
+
29
+ FP32LayerNorm.forward = FP32LayerNorm_forward
30
+
31
+
32
+ def RMSNorm_forward(self, hidden_states):
33
+ input_dtype = hidden_states.dtype
34
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
35
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
36
+
37
+ if self.weight is None:
38
+ return hidden_states.to(input_dtype)
39
+
40
+ return hidden_states.to(input_dtype) * self.weight.to(input_dtype)
41
+
42
+
43
+ RMSNorm.forward = RMSNorm_forward
44
+
45
+
46
+ def AdaLayerNormContinuous_forward(self, x, conditioning_embedding):
47
+ emb = self.linear(self.silu(conditioning_embedding))
48
+ scale, shift = emb.chunk(2, dim=1)
49
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
50
+ return x
51
+
52
+
53
+ AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward
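Importing this module is enough to apply the patched `forward` implementations; a small check under that assumption:

```python
import torch
from diffusers.models.normalization import RMSNorm
import diffusers_helper.dit_common  # noqa: F401  -- importing applies the patches above

norm = RMSNorm(8, eps=1e-6, elementwise_affine=True)
x = torch.randn(2, 4, 8, dtype=torch.float16)
# The patched RMSNorm computes the variance in float32, then casts back to the input dtype.
print(norm(x).dtype)  # torch.float16
```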
diffusers_helper/gradio/progress_bar.py ADDED
@@ -0,0 +1,86 @@
1
+ progress_html = '''
2
+ <div class="loader-container">
3
+ <div class="loader"></div>
4
+ <div class="progress-container">
5
+ <progress value="*number*" max="100"></progress>
6
+ </div>
7
+ <span>*text*</span>
8
+ </div>
9
+ '''
10
+
11
+ css = '''
12
+ .loader-container {
13
+ display: flex; /* Use flex to align items horizontally */
14
+ align-items: center; /* Center items vertically within the container */
15
+ white-space: nowrap; /* Prevent line breaks within the container */
16
+ }
17
+
18
+ .loader {
19
+ border: 8px solid #f3f3f3; /* Light grey */
20
+ border-top: 8px solid #3498db; /* Blue */
21
+ border-radius: 50%;
22
+ width: 30px;
23
+ height: 30px;
24
+ animation: spin 2s linear infinite;
25
+ }
26
+
27
+ @keyframes spin {
28
+ 0% { transform: rotate(0deg); }
29
+ 100% { transform: rotate(360deg); }
30
+ }
31
+
32
+ /* Style the progress bar */
33
+ progress {
34
+ appearance: none; /* Remove default styling */
35
+ height: 20px; /* Set the height of the progress bar */
36
+ border-radius: 5px; /* Round the corners of the progress bar */
37
+ background-color: #f3f3f3; /* Light grey background */
38
+ width: 100%;
39
+ vertical-align: middle !important;
40
+ }
41
+
42
+ /* Style the progress bar container */
43
+ .progress-container {
44
+ margin-left: 20px;
45
+ margin-right: 20px;
46
+ flex-grow: 1; /* Allow the progress container to take up remaining space */
47
+ }
48
+
49
+ /* Set the color of the progress bar fill */
50
+ progress::-webkit-progress-value {
51
+ background-color: #3498db; /* Blue color for the fill */
52
+ }
53
+
54
+ progress::-moz-progress-bar {
55
+ background-color: #3498db; /* Blue color for the fill in Firefox */
56
+ }
57
+
58
+ /* Style the text on the progress bar */
59
+ progress::after {
60
+ content: attr(value) "%"; /* Display the progress value followed by '%' */
61
+ position: absolute;
62
+ top: 50%;
63
+ left: 50%;
64
+ transform: translate(-50%, -50%);
65
+ color: white; /* Set text color */
66
+ font-size: 14px; /* Set font size */
67
+ }
68
+
69
+ /* Style other texts */
70
+ .loader-container > span {
71
+ margin-left: 5px; /* Add spacing between the progress bar and the text */
72
+ }
73
+
74
+ .no-generating-animation > .generating {
75
+ display: none !important;
76
+ }
77
+
78
+ '''
79
+
80
+
81
+ def make_progress_bar_html(number, text):
82
+ return progress_html.replace('*number*', str(number)).replace('*text*', text)
83
+
84
+
85
+ def make_progress_bar_css():
86
+ return css
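A minimal sketch of wiring these helpers into a Gradio Blocks app (the layout here is illustrative, not the studio's actual UI):

```python
import gradio as gr
from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html

with gr.Blocks(css=make_progress_bar_css()) as demo:
    # The worker updates this HTML component with new percentages and status text.
    progress = gr.HTML(make_progress_bar_html(0, "Waiting..."),
                       elem_classes="no-generating-animation")

demo.launch()
```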
diffusers_helper/hf_login.py ADDED
@@ -0,0 +1,21 @@
1
+ import os
2
+
3
+
4
+ def login(token):
5
+ from huggingface_hub import login
6
+ import time
7
+
8
+ while True:
9
+ try:
10
+ login(token)
11
+ print('HF login ok.')
12
+ break
13
+ except Exception as e:
14
+ print(f'HF login failed: {e}. Retrying')
15
+ time.sleep(0.5)
16
+
17
+
18
+ hf_token = os.environ.get('HF_TOKEN', None)
19
+
20
+ if hf_token is not None:
21
+ login(hf_token)
diffusers_helper/hunyuan.py ADDED
@@ -0,0 +1,111 @@
1
+ import torch
2
+
3
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
4
+ from diffusers_helper.utils import crop_or_pad_yield_mask
5
+
6
+
7
+ @torch.no_grad()
8
+ def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256):
9
+ assert isinstance(prompt, str)
10
+
11
+ prompt = [prompt]
12
+
13
+ # LLAMA
14
+
15
+ prompt_llama = [DEFAULT_PROMPT_TEMPLATE["template"].format(p) for p in prompt]
16
+ crop_start = DEFAULT_PROMPT_TEMPLATE["crop_start"]
17
+
18
+ llama_inputs = tokenizer(
19
+ prompt_llama,
20
+ padding="max_length",
21
+ max_length=max_length + crop_start,
22
+ truncation=True,
23
+ return_tensors="pt",
24
+ return_length=False,
25
+ return_overflowing_tokens=False,
26
+ return_attention_mask=True,
27
+ )
28
+
29
+ llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
30
+ llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
31
+ llama_attention_length = int(llama_attention_mask.sum())
32
+
33
+ llama_outputs = text_encoder(
34
+ input_ids=llama_input_ids,
35
+ attention_mask=llama_attention_mask,
36
+ output_hidden_states=True,
37
+ )
38
+
39
+ llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length]
40
+ # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:]
41
+ llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length]
42
+
43
+ assert torch.all(llama_attention_mask.bool())
44
+
45
+ # CLIP
46
+
47
+ clip_l_input_ids = tokenizer_2(
48
+ prompt,
49
+ padding="max_length",
50
+ max_length=77,
51
+ truncation=True,
52
+ return_overflowing_tokens=False,
53
+ return_length=False,
54
+ return_tensors="pt",
55
+ ).input_ids
56
+ clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output
57
+
58
+ return llama_vec, clip_l_pooler
59
+
60
+
61
+ @torch.no_grad()
62
+ def vae_decode_fake(latents):
63
+ latent_rgb_factors = [
64
+ [-0.0395, -0.0331, 0.0445],
65
+ [0.0696, 0.0795, 0.0518],
66
+ [0.0135, -0.0945, -0.0282],
67
+ [0.0108, -0.0250, -0.0765],
68
+ [-0.0209, 0.0032, 0.0224],
69
+ [-0.0804, -0.0254, -0.0639],
70
+ [-0.0991, 0.0271, -0.0669],
71
+ [-0.0646, -0.0422, -0.0400],
72
+ [-0.0696, -0.0595, -0.0894],
73
+ [-0.0799, -0.0208, -0.0375],
74
+ [0.1166, 0.1627, 0.0962],
75
+ [0.1165, 0.0432, 0.0407],
76
+ [-0.2315, -0.1920, -0.1355],
77
+ [-0.0270, 0.0401, -0.0821],
78
+ [-0.0616, -0.0997, -0.0727],
79
+ [0.0249, -0.0469, -0.1703]
80
+ ] # From comfyui
81
+
82
+ latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]
83
+
84
+ weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
85
+ bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
86
+
87
+ images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
88
+ images = images.clamp(0.0, 1.0)
89
+
90
+ return images
91
+
92
+
93
+ @torch.no_grad()
94
+ def vae_decode(latents, vae, image_mode=False):
95
+ latents = latents / vae.config.scaling_factor
96
+
97
+ if not image_mode:
98
+ image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
99
+ else:
100
+ latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2)
101
+ image = [vae.decode(l.unsqueeze(2)).sample for l in latents]
102
+ image = torch.cat(image, dim=2)
103
+
104
+ return image
105
+
106
+
107
+ @torch.no_grad()
108
+ def vae_encode(image, vae):
109
+ latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
110
+ latents = latents * vae.config.scaling_factor
111
+ return latents
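A hedged sketch of the VAE helpers above. The checkpoint id is the community HunyuanVideo repo and may differ from what the studio actually downloads; running it needs a CUDA GPU with enough VRAM:

```python
import torch
from diffusers import AutoencoderKLHunyuanVideo
from diffusers_helper.hunyuan import vae_encode, vae_decode

vae = AutoencoderKLHunyuanVideo.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16
).cuda()

# B, C, T, H, W video in [-1, 1]; 5 frames compress to 2 latent frames.
video = torch.rand(1, 3, 5, 480, 832, dtype=torch.float16, device="cuda") * 2 - 1
latents = vae_encode(video, vae)   # scaled by vae.config.scaling_factor
frames = vae_decode(latents, vae)  # back to pixel space
print(latents.shape, frames.shape)
```

`encode_prompt_conds` follows the same pattern once the Llama and CLIP text encoders and their tokenizers have been loaded.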
diffusers_helper/k_diffusion/uni_pc_fm.py ADDED
@@ -0,0 +1,141 @@
1
+ # Better Flow Matching UniPC by Lvmin Zhang
2
+ # (c) 2025
3
+ # CC BY-SA 4.0
4
+ # Attribution-ShareAlike 4.0 International Licence
5
+
6
+
7
+ import torch
8
+
9
+ from tqdm.auto import trange
10
+
11
+
12
+ def expand_dims(v, dims):
13
+ return v[(...,) + (None,) * (dims - 1)]
14
+
15
+
16
+ class FlowMatchUniPC:
17
+ def __init__(self, model, extra_args, variant='bh1'):
18
+ self.model = model
19
+ self.variant = variant
20
+ self.extra_args = extra_args
21
+
22
+ def model_fn(self, x, t):
23
+ return self.model(x, t, **self.extra_args)
24
+
25
+ def update_fn(self, x, model_prev_list, t_prev_list, t, order):
26
+ assert order <= len(model_prev_list)
27
+ dims = x.dim()
28
+
29
+ t_prev_0 = t_prev_list[-1]
30
+ lambda_prev_0 = - torch.log(t_prev_0)
31
+ lambda_t = - torch.log(t)
32
+ model_prev_0 = model_prev_list[-1]
33
+
34
+ h = lambda_t - lambda_prev_0
35
+
36
+ rks = []
37
+ D1s = []
38
+ for i in range(1, order):
39
+ t_prev_i = t_prev_list[-(i + 1)]
40
+ model_prev_i = model_prev_list[-(i + 1)]
41
+ lambda_prev_i = - torch.log(t_prev_i)
42
+ rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
43
+ rks.append(rk)
44
+ D1s.append((model_prev_i - model_prev_0) / rk)
45
+
46
+ rks.append(1.)
47
+ rks = torch.tensor(rks, device=x.device)
48
+
49
+ R = []
50
+ b = []
51
+
52
+ hh = -h[0]
53
+ h_phi_1 = torch.expm1(hh)
54
+ h_phi_k = h_phi_1 / hh - 1
55
+
56
+ factorial_i = 1
57
+
58
+ if self.variant == 'bh1':
59
+ B_h = hh
60
+ elif self.variant == 'bh2':
61
+ B_h = torch.expm1(hh)
62
+ else:
63
+ raise NotImplementedError('Bad variant!')
64
+
65
+ for i in range(1, order + 1):
66
+ R.append(torch.pow(rks, i - 1))
67
+ b.append(h_phi_k * factorial_i / B_h)
68
+ factorial_i *= (i + 1)
69
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
70
+
71
+ R = torch.stack(R)
72
+ b = torch.tensor(b, device=x.device)
73
+
74
+ use_predictor = len(D1s) > 0
75
+
76
+ if use_predictor:
77
+ D1s = torch.stack(D1s, dim=1)
78
+ if order == 2:
79
+ rhos_p = torch.tensor([0.5], device=b.device)
80
+ else:
81
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
82
+ else:
83
+ D1s = None
84
+ rhos_p = None
85
+
86
+ if order == 1:
87
+ rhos_c = torch.tensor([0.5], device=b.device)
88
+ else:
89
+ rhos_c = torch.linalg.solve(R, b)
90
+
91
+ x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0
92
+
93
+ if use_predictor:
94
+ pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
95
+ else:
96
+ pred_res = 0
97
+
98
+ x_t = x_t_ - expand_dims(B_h, dims) * pred_res
99
+ model_t = self.model_fn(x_t, t)
100
+
101
+ if D1s is not None:
102
+ corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0]))
103
+ else:
104
+ corr_res = 0
105
+
106
+ D1_t = (model_t - model_prev_0)
107
+ x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
108
+
109
+ return x_t, model_t
110
+
111
+ def sample(self, x, sigmas, callback=None, disable_pbar=False):
112
+ order = min(3, len(sigmas) - 2)
113
+ model_prev_list, t_prev_list = [], []
114
+ for i in trange(len(sigmas) - 1, disable=disable_pbar):
115
+ vec_t = sigmas[i].expand(x.shape[0])
116
+
117
+ if i == 0:
118
+ model_prev_list = [self.model_fn(x, vec_t)]
119
+ t_prev_list = [vec_t]
120
+ elif i < order:
121
+ init_order = i
122
+ x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
123
+ model_prev_list.append(model_x)
124
+ t_prev_list.append(vec_t)
125
+ else:
126
+ x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
127
+ model_prev_list.append(model_x)
128
+ t_prev_list.append(vec_t)
129
+
130
+ model_prev_list = model_prev_list[-order:]
131
+ t_prev_list = t_prev_list[-order:]
132
+
133
+ if callback is not None:
134
+ callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
135
+
136
+ return model_prev_list[-1]
137
+
138
+
139
+ def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
140
+ assert variant in ['bh1', 'bh2']
141
+ return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)
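The sampler only needs a callable `model(x, t, **extra_args)` and a decreasing sigma schedule; a toy run with a dummy model, so it executes without real weights:

```python
import torch
from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc

def toy_model(x, t, **kwargs):
    # Stand-in for the transformer wrapper; a real model predicts the flow here.
    return torch.zeros_like(x)

noise = torch.randn(1, 16, 8, 8)
sigmas = torch.linspace(1.0, 0.0, 26)  # 25 steps from pure noise towards clean
out = sample_unipc(toy_model, noise, sigmas, extra_args={}, disable=True, variant='bh1')
print(out.shape)  # same shape as the input noise
```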
diffusers_helper/k_diffusion/wrapper.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+
3
+
4
+ def append_dims(x, target_dims):
5
+ return x[(...,) + (None,) * (target_dims - x.ndim)]
6
+
7
+
8
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
9
+ if guidance_rescale == 0:
10
+ return noise_cfg
11
+
12
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
13
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
14
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
15
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg
16
+ return noise_cfg
17
+
18
+
19
+ def fm_wrapper(transformer, t_scale=1000.0):
20
+ def k_model(x, sigma, **extra_args):
21
+ dtype = extra_args['dtype']
22
+ cfg_scale = extra_args['cfg_scale']
23
+ cfg_rescale = extra_args['cfg_rescale']
24
+ concat_latent = extra_args['concat_latent']
25
+
26
+ original_dtype = x.dtype
27
+ sigma = sigma.float()
28
+
29
+ x = x.to(dtype)
30
+ timestep = (sigma * t_scale).to(dtype)
31
+
32
+ if concat_latent is None:
33
+ hidden_states = x
34
+ else:
35
+ hidden_states = torch.cat([x, concat_latent.to(x)], dim=1)
36
+
37
+ pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float()
38
+
39
+ if cfg_scale == 1.0:
40
+ pred_negative = torch.zeros_like(pred_positive)
41
+ else:
42
+ pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float()
43
+
44
+ pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative)
45
+ pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale)
46
+
47
+ x0 = x.float() - pred.float() * append_dims(sigma, x.ndim)
48
+
49
+ return x0.to(dtype=original_dtype)
50
+
51
+ return k_model
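The wrapper expects a specific set of keyword arguments; a sketch with a dummy transformer standing in for the packed HunyuanVideo model:

```python
import torch
from diffusers_helper.k_diffusion.wrapper import fm_wrapper

def dummy_transformer(hidden_states=None, timestep=None, return_dict=False, **cond):
    # Placeholder for the real transformer.
    return (torch.zeros_like(hidden_states),)

k_model = fm_wrapper(dummy_transformer)

x = torch.randn(1, 16, 4, 8, 8)
sigma = torch.tensor([0.8])
x0 = k_model(
    x, sigma,
    dtype=torch.float32,     # compute dtype for the transformer call
    cfg_scale=1.0,           # 1.0 skips the negative (unconditional) pass
    cfg_rescale=0.0,         # no CFG rescale
    concat_latent=None,      # optional latent concatenated on the channel dim
    positive={},             # conditioning kwargs forwarded to the transformer
    negative={},
)
print(x0.shape)  # predicted x0, same shape as x
```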
diffusers_helper/lora_utils.py ADDED
@@ -0,0 +1,131 @@
1
+ from pathlib import Path, PurePath
2
+ from typing import Dict, List, Optional, Union
3
+ from diffusers.loaders.lora_pipeline import _fetch_state_dict
4
+ from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers
5
+ from diffusers.utils.peft_utils import set_weights_and_activate_adapters
6
+ from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
7
+ import torch
8
+
9
+ def load_lora(transformer, lora_path: Path, weight_name: Optional[str] = "pytorch_lora_weights.safetensors"):
10
+ """
11
+ Load LoRA weights into the transformer model.
12
+
13
+ Args:
14
+ transformer: The transformer model to which LoRA weights will be applied.
15
+ lora_path (Path): Path to the LoRA weights file.
16
+ weight_name (Optional[str]): Name of the weight to load.
17
+
18
+ """
19
+
20
+ state_dict = _fetch_state_dict(
21
+ lora_path,
22
+ weight_name,
23
+ True,
24
+ True,
25
+ None,
26
+ None,
27
+ None,
28
+ None,
29
+ None,
30
+ None,
31
+ None,
32
+ None)
33
+
34
+ state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
35
+
36
+ # should weight_name even be Optional[str] or just str?
37
+ # For now, we assume it is never None
38
+ # The module name in the state_dict must not include a . in the name
39
+ # See https://github.com/pytorch/pytorch/pull/6639/files#diff-4be56271f7bfe650e3521c81fd363da58f109cd23ee80d243156d2d6ccda6263R133-R134
40
+ adapter_name = PurePath(str(weight_name).replace('_DOT_', '.')).stem.replace('.', '_DOT_')
41
+ if '_DOT_' in adapter_name:
42
+ print(
43
+ f"LoRA file '{weight_name}' contains a '.' in the name. " +
44
+ 'This may cause issues. Consider renaming the file.' +
45
+ f" Using '{adapter_name}' as the adapter name to be safe."
46
+ )
47
+
48
+ # Check if adapter already exists and delete it if it does
49
+ if hasattr(transformer, 'peft_config') and adapter_name in transformer.peft_config:
50
+ print(f"Adapter '{adapter_name}' already exists. Removing it before loading again.")
51
+ # Use delete_adapters (plural) instead of delete_adapter
52
+ transformer.delete_adapters([adapter_name])
53
+
54
+ # Load the adapter with the original name
55
+ transformer.load_lora_adapter(state_dict, network_alphas=None, adapter_name=adapter_name)
56
+ print(f"LoRA weights '{adapter_name}' loaded successfully.")
57
+
58
+ return transformer
59
+
60
+ def unload_all_loras(transformer):
61
+ """
62
+ Completely unload all LoRA adapters from the transformer model.
63
+ """
64
+ if hasattr(transformer, 'peft_config') and transformer.peft_config:
65
+ # Get all adapter names
66
+ adapter_names = list(transformer.peft_config.keys())
67
+
68
+ if adapter_names:
69
+ print(f"Removing all LoRA adapters: {', '.join(adapter_names)}")
70
+ # Delete all adapters
71
+ transformer.delete_adapters(adapter_names)
72
+
73
+ # Force cleanup of any remaining adapter references
74
+ if hasattr(transformer, 'active_adapter'):
75
+ transformer.active_adapter = None
76
+
77
+ # Clear any cached states
78
+ for module in transformer.modules():
79
+ if hasattr(module, 'lora_A'):
80
+ if isinstance(module.lora_A, dict):
81
+ module.lora_A.clear()
82
+ if hasattr(module, 'lora_B'):
83
+ if isinstance(module.lora_B, dict):
84
+ module.lora_B.clear()
85
+ if hasattr(module, 'scaling'):
86
+ if isinstance(module.scaling, dict):
87
+ module.scaling.clear()
88
+
89
+ print("All LoRA adapters have been completely removed.")
90
+ else:
91
+ print("No LoRA adapters found to remove.")
92
+ else:
93
+ print("Model doesn't have any LoRA adapters or peft_config.")
94
+
95
+ # Force garbage collection
96
+ import gc
97
+ gc.collect()
98
+ if torch.cuda.is_available():
99
+ torch.cuda.empty_cache()
100
+
101
+ return transformer
102
+
103
+
104
+ # TODO(neph1): remove when HunyuanVideoTransformer3DModelPacked is in _SET_ADAPTER_SCALE_FN_MAPPING
105
+ def set_adapters(
106
+ transformer,
107
+ adapter_names: Union[List[str], str],
108
+ weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
109
+ ):
110
+
111
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
112
+
113
+ # Expand weights into a list, one entry per adapter
114
+ # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
115
+ if not isinstance(weights, list):
116
+ weights = [weights] * len(adapter_names)
117
+
118
+ if len(adapter_names) != len(weights):
119
+ raise ValueError(
120
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
121
+ )
122
+
123
+ # Set None values to default of 1.0
124
+ # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
125
+ weights = [w if w is not None else 1.0 for w in weights]
126
+
127
+ # e.g. [{...}, 7] -> [{expanded dict...}, 7]
128
+ scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING["HunyuanVideoTransformer3DModel"]
129
+ weights = scale_expansion_fn(transformer, weights)
130
+
131
+ set_weights_and_activate_adapters(transformer, adapter_names, weights)
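One detail worth illustrating is the adapter-name convention used by `load_lora`: the adapter name becomes part of PyTorch module/parameter names, which cannot contain `.`, so any dots in the file stem are replaced. The snippet mirrors the exact expression above; the file name is an arbitrary example:

```python
from pathlib import PurePath

weight_name = "style.v1.safetensors"
adapter_name = PurePath(weight_name.replace('_DOT_', '.')).stem.replace('.', '_DOT_')
print(adapter_name)  # style_DOT_v1
```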
diffusers_helper/memory.py ADDED
@@ -0,0 +1,134 @@
1
+ # By lllyasviel
2
+
3
+
4
+ import torch
5
+
6
+
7
+ cpu = torch.device('cpu')
8
+ gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
9
+ gpu_complete_modules = []
10
+
11
+
12
+ class DynamicSwapInstaller:
13
+ @staticmethod
14
+ def _install_module(module: torch.nn.Module, **kwargs):
15
+ original_class = module.__class__
16
+ module.__dict__['forge_backup_original_class'] = original_class
17
+
18
+ def hacked_get_attr(self, name: str):
19
+ if '_parameters' in self.__dict__:
20
+ _parameters = self.__dict__['_parameters']
21
+ if name in _parameters:
22
+ p = _parameters[name]
23
+ if p is None:
24
+ return None
25
+ if p.__class__ == torch.nn.Parameter:
26
+ return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
27
+ else:
28
+ return p.to(**kwargs)
29
+ if '_buffers' in self.__dict__:
30
+ _buffers = self.__dict__['_buffers']
31
+ if name in _buffers:
32
+ return _buffers[name].to(**kwargs)
33
+ return super(original_class, self).__getattr__(name)
34
+
35
+ module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
36
+ '__getattr__': hacked_get_attr,
37
+ })
38
+
39
+ return
40
+
41
+ @staticmethod
42
+ def _uninstall_module(module: torch.nn.Module):
43
+ if 'forge_backup_original_class' in module.__dict__:
44
+ module.__class__ = module.__dict__.pop('forge_backup_original_class')
45
+ return
46
+
47
+ @staticmethod
48
+ def install_model(model: torch.nn.Module, **kwargs):
49
+ for m in model.modules():
50
+ DynamicSwapInstaller._install_module(m, **kwargs)
51
+ return
52
+
53
+ @staticmethod
54
+ def uninstall_model(model: torch.nn.Module):
55
+ for m in model.modules():
56
+ DynamicSwapInstaller._uninstall_module(m)
57
+ return
58
+
59
+
60
+ def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
61
+ if hasattr(model, 'scale_shift_table'):
62
+ model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
63
+ return
64
+
65
+ for k, p in model.named_modules():
66
+ if hasattr(p, 'weight'):
67
+ p.to(target_device)
68
+ return
69
+
70
+
71
+ def get_cuda_free_memory_gb(device=None):
72
+ if device is None:
73
+ device = gpu
74
+
75
+ memory_stats = torch.cuda.memory_stats(device)
76
+ bytes_active = memory_stats['active_bytes.all.current']
77
+ bytes_reserved = memory_stats['reserved_bytes.all.current']
78
+ bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
79
+ bytes_inactive_reserved = bytes_reserved - bytes_active
80
+ bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
81
+ return bytes_total_available / (1024 ** 3)
82
+
83
+
84
+ def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
85
+ print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')
86
+
87
+ for m in model.modules():
88
+ if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
89
+ torch.cuda.empty_cache()
90
+ return
91
+
92
+ if hasattr(m, 'weight'):
93
+ m.to(device=target_device)
94
+
95
+ model.to(device=target_device)
96
+ torch.cuda.empty_cache()
97
+ return
98
+
99
+
100
+ def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
101
+ print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')
102
+
103
+ for m in model.modules():
104
+ if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
105
+ torch.cuda.empty_cache()
106
+ return
107
+
108
+ if hasattr(m, 'weight'):
109
+ m.to(device=cpu)
110
+
111
+ model.to(device=cpu)
112
+ torch.cuda.empty_cache()
113
+ return
114
+
115
+
116
+ def unload_complete_models(*args):
117
+ for m in gpu_complete_modules + list(args):
118
+ m.to(device=cpu)
119
+ print(f'Unloaded {m.__class__.__name__} as complete.')
120
+
121
+ gpu_complete_modules.clear()
122
+ torch.cuda.empty_cache()
123
+ return
124
+
125
+
126
+ def load_model_as_complete(model, target_device, unload=True):
127
+ if unload:
128
+ unload_complete_models()
129
+
130
+ model.to(device=target_device)
131
+ print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')
132
+
133
+ gpu_complete_modules.append(model)
134
+ return
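A brief sketch of the memory helpers (requires a CUDA device, since the module resolves `gpu` at import time); the tiny model here is only a stand-in:

```python
import torch
from diffusers_helper.memory import DynamicSwapInstaller, get_cuda_free_memory_gb, cpu, gpu

# Keep weights resident on the CPU but materialize them on the GPU on access.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).to(cpu)
DynamicSwapInstaller.install_model(model, device=gpu)

out = model(torch.randn(1, 64, device=gpu))
print(out.device, f'{get_cuda_free_memory_gb(gpu):.2f} GB free')

DynamicSwapInstaller.uninstall_model(model)
```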
diffusers_helper/models/hunyuan_video_packed.py ADDED
@@ -0,0 +1,1032 @@
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import einops
5
+ import torch.nn as nn
6
+ import numpy as np
7
+
8
+ from diffusers.loaders import FromOriginalModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.loaders import PeftAdapterMixin
11
+ from diffusers.utils import logging
12
+ from diffusers.models.attention import FeedForward
13
+ from diffusers.models.attention_processor import Attention
14
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
15
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from diffusers_helper.dit_common import LayerNorm
18
+ from diffusers_helper.utils import zero_module
19
+
20
+
21
+ enabled_backends = []
22
+
23
+ if torch.backends.cuda.flash_sdp_enabled():
24
+ enabled_backends.append("flash")
25
+ if torch.backends.cuda.math_sdp_enabled():
26
+ enabled_backends.append("math")
27
+ if torch.backends.cuda.mem_efficient_sdp_enabled():
28
+ enabled_backends.append("mem_efficient")
29
+ if torch.backends.cuda.cudnn_sdp_enabled():
30
+ enabled_backends.append("cudnn")
31
+
32
+ print("Currently enabled native sdp backends:", enabled_backends)
33
+
34
+ try:
35
+ # raise NotImplementedError
36
+ from xformers.ops import memory_efficient_attention as xformers_attn_func
37
+ print('Xformers is installed!')
38
+ except Exception:
39
+ print('Xformers is not installed!')
40
+ xformers_attn_func = None
41
+
42
+ try:
43
+ # raise NotImplementedError
44
+ from flash_attn import flash_attn_varlen_func, flash_attn_func
45
+ print('Flash Attn is installed!')
46
+ except Exception:
47
+ print('Flash Attn is not installed!')
48
+ flash_attn_varlen_func = None
49
+ flash_attn_func = None
50
+
51
+ try:
52
+ # raise NotImplementedError
53
+ from sageattention import sageattn_varlen, sageattn
54
+ print('Sage Attn is installed!')
55
+ except Exception:
56
+ print('Sage Attn is not installed!')
57
+ sageattn_varlen = None
58
+ sageattn = None
59
+
60
+
61
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62
+
63
+
64
+ def pad_for_3d_conv(x, kernel_size):
65
+ b, c, t, h, w = x.shape
66
+ pt, ph, pw = kernel_size
67
+ pad_t = (pt - (t % pt)) % pt
68
+ pad_h = (ph - (h % ph)) % ph
69
+ pad_w = (pw - (w % pw)) % pw
70
+ return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate')
71
+
72
+
73
+ def center_down_sample_3d(x, kernel_size):
74
+ # pt, ph, pw = kernel_size
75
+ # cp = (pt * ph * pw) // 2
76
+ # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
77
+ # xc = xp[cp]
78
+ # return xc
79
+ return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
80
+
81
+
82
+ def get_cu_seqlens(text_mask, img_len):
83
+ batch_size = text_mask.shape[0]
84
+ text_len = text_mask.sum(dim=1)
85
+ max_len = text_mask.shape[1] + img_len
86
+
87
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
88
+
89
+ for i in range(batch_size):
90
+ s = text_len[i] + img_len
91
+ s1 = i * max_len + s
92
+ s2 = (i + 1) * max_len
93
+ cu_seqlens[2 * i + 1] = s1
94
+ cu_seqlens[2 * i + 2] = s2
95
+
96
+ return cu_seqlens
97
+
98
+
99
+ def apply_rotary_emb_transposed(x, freqs_cis):
100
+ cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
101
+ x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
102
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
103
+ out = x.float() * cos + x_rotated.float() * sin
104
+ out = out.to(x)
105
+ return out
106
+
107
+
108
+ def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
109
+ if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
110
+ if sageattn is not None:
111
+ x = sageattn(q, k, v, tensor_layout='NHD')
112
+ return x
113
+
114
+ if flash_attn_func is not None:
115
+ x = flash_attn_func(q, k, v)
116
+ return x
117
+
118
+ if xformers_attn_func is not None:
119
+ x = xformers_attn_func(q, k, v)
120
+ return x
121
+
122
+ x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
123
+ return x
124
+
125
+ batch_size = q.shape[0]
126
+ q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
127
+ k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
128
+ v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
129
+ if sageattn_varlen is not None:
130
+ x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
131
+ elif flash_attn_varlen_func is not None:
132
+ x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
133
+ else:
134
+ raise NotImplementedError('No Attn Installed!')
135
+ x = x.view(batch_size, max_seqlen_q, *x.shape[2:])
136
+ return x
137
+
138
+
139
+ class HunyuanAttnProcessorFlashAttnDouble:
140
+ def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
141
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
142
+
143
+ query = attn.to_q(hidden_states)
144
+ key = attn.to_k(hidden_states)
145
+ value = attn.to_v(hidden_states)
146
+
147
+ query = query.unflatten(2, (attn.heads, -1))
148
+ key = key.unflatten(2, (attn.heads, -1))
149
+ value = value.unflatten(2, (attn.heads, -1))
150
+
151
+ query = attn.norm_q(query)
152
+ key = attn.norm_k(key)
153
+
154
+ query = apply_rotary_emb_transposed(query, image_rotary_emb)
155
+ key = apply_rotary_emb_transposed(key, image_rotary_emb)
156
+
157
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
158
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
159
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
160
+
161
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
162
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
163
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
164
+
165
+ encoder_query = attn.norm_added_q(encoder_query)
166
+ encoder_key = attn.norm_added_k(encoder_key)
167
+
168
+ query = torch.cat([query, encoder_query], dim=1)
169
+ key = torch.cat([key, encoder_key], dim=1)
170
+ value = torch.cat([value, encoder_value], dim=1)
171
+
172
+ hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
173
+ hidden_states = hidden_states.flatten(-2)
174
+
175
+ txt_length = encoder_hidden_states.shape[1]
176
+ hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
177
+
178
+ hidden_states = attn.to_out[0](hidden_states)
179
+ hidden_states = attn.to_out[1](hidden_states)
180
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
181
+
182
+ return hidden_states, encoder_hidden_states
183
+
184
+
185
+ class HunyuanAttnProcessorFlashAttnSingle:
186
+ def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
187
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
188
+
189
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
190
+
191
+ query = attn.to_q(hidden_states)
192
+ key = attn.to_k(hidden_states)
193
+ value = attn.to_v(hidden_states)
194
+
195
+ query = query.unflatten(2, (attn.heads, -1))
196
+ key = key.unflatten(2, (attn.heads, -1))
197
+ value = value.unflatten(2, (attn.heads, -1))
198
+
199
+ query = attn.norm_q(query)
200
+ key = attn.norm_k(key)
201
+
202
+ txt_length = encoder_hidden_states.shape[1]
203
+
204
+ query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
205
+ key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
206
+
207
+ hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
208
+ hidden_states = hidden_states.flatten(-2)
209
+
210
+ hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
211
+
212
+ return hidden_states, encoder_hidden_states
213
+
214
+
215
+ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
216
+ def __init__(self, embedding_dim, pooled_projection_dim):
217
+ super().__init__()
218
+
219
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
220
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
221
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
222
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
223
+
224
+ def forward(self, timestep, guidance, pooled_projection):
225
+ timesteps_proj = self.time_proj(timestep)
226
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
227
+
228
+ guidance_proj = self.time_proj(guidance)
229
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
230
+
231
+ time_guidance_emb = timesteps_emb + guidance_emb
232
+
233
+ pooled_projections = self.text_embedder(pooled_projection)
234
+ conditioning = time_guidance_emb + pooled_projections
235
+
236
+ return conditioning
237
+
238
+
239
+ class CombinedTimestepTextProjEmbeddings(nn.Module):
240
+ def __init__(self, embedding_dim, pooled_projection_dim):
241
+ super().__init__()
242
+
243
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
244
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
245
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
246
+
247
+ def forward(self, timestep, pooled_projection):
248
+ timesteps_proj = self.time_proj(timestep)
249
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
250
+
251
+ pooled_projections = self.text_embedder(pooled_projection)
252
+
253
+ conditioning = timesteps_emb + pooled_projections
254
+
255
+ return conditioning
256
+
257
+
258
+ class HunyuanVideoAdaNorm(nn.Module):
259
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
260
+ super().__init__()
261
+
262
+ out_features = out_features or 2 * in_features
263
+ self.linear = nn.Linear(in_features, out_features)
264
+ self.nonlinearity = nn.SiLU()
265
+
266
+ def forward(
267
+ self, temb: torch.Tensor
268
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
269
+ temb = self.linear(self.nonlinearity(temb))
270
+ gate_msa, gate_mlp = temb.chunk(2, dim=-1)
271
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
272
+ return gate_msa, gate_mlp
273
+
274
+
275
+ class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
276
+ def __init__(
277
+ self,
278
+ num_attention_heads: int,
279
+ attention_head_dim: int,
280
+ mlp_width_ratio: str = 4.0,
281
+ mlp_drop_rate: float = 0.0,
282
+ attention_bias: bool = True,
283
+ ) -> None:
284
+ super().__init__()
285
+
286
+ hidden_size = num_attention_heads * attention_head_dim
287
+
288
+ self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
289
+ self.attn = Attention(
290
+ query_dim=hidden_size,
291
+ cross_attention_dim=None,
292
+ heads=num_attention_heads,
293
+ dim_head=attention_head_dim,
294
+ bias=attention_bias,
295
+ )
296
+
297
+ self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
298
+ self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
299
+
300
+ self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
301
+
302
+ def forward(
303
+ self,
304
+ hidden_states: torch.Tensor,
305
+ temb: torch.Tensor,
306
+ attention_mask: Optional[torch.Tensor] = None,
307
+ ) -> torch.Tensor:
308
+ norm_hidden_states = self.norm1(hidden_states)
309
+
310
+ attn_output = self.attn(
311
+ hidden_states=norm_hidden_states,
312
+ encoder_hidden_states=None,
313
+ attention_mask=attention_mask,
314
+ )
315
+
316
+ gate_msa, gate_mlp = self.norm_out(temb)
317
+ hidden_states = hidden_states + attn_output * gate_msa
318
+
319
+ ff_output = self.ff(self.norm2(hidden_states))
320
+ hidden_states = hidden_states + ff_output * gate_mlp
321
+
322
+ return hidden_states
323
+
324
+
325
+ class HunyuanVideoIndividualTokenRefiner(nn.Module):
326
+ def __init__(
327
+ self,
328
+ num_attention_heads: int,
329
+ attention_head_dim: int,
330
+ num_layers: int,
331
+ mlp_width_ratio: float = 4.0,
332
+ mlp_drop_rate: float = 0.0,
333
+ attention_bias: bool = True,
334
+ ) -> None:
335
+ super().__init__()
336
+
337
+ self.refiner_blocks = nn.ModuleList(
338
+ [
339
+ HunyuanVideoIndividualTokenRefinerBlock(
340
+ num_attention_heads=num_attention_heads,
341
+ attention_head_dim=attention_head_dim,
342
+ mlp_width_ratio=mlp_width_ratio,
343
+ mlp_drop_rate=mlp_drop_rate,
344
+ attention_bias=attention_bias,
345
+ )
346
+ for _ in range(num_layers)
347
+ ]
348
+ )
349
+
350
+ def forward(
351
+ self,
352
+ hidden_states: torch.Tensor,
353
+ temb: torch.Tensor,
354
+ attention_mask: Optional[torch.Tensor] = None,
355
+ ) -> None:
356
+ self_attn_mask = None
357
+ if attention_mask is not None:
358
+ batch_size = attention_mask.shape[0]
359
+ seq_len = attention_mask.shape[1]
360
+ attention_mask = attention_mask.to(hidden_states.device).bool()
361
+ self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
362
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
363
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
364
+ self_attn_mask[:, :, :, 0] = True
365
+
366
+ for block in self.refiner_blocks:
367
+ hidden_states = block(hidden_states, temb, self_attn_mask)
368
+
369
+ return hidden_states
370
+
371
+
372
+ class HunyuanVideoTokenRefiner(nn.Module):
373
+ def __init__(
374
+ self,
375
+ in_channels: int,
376
+ num_attention_heads: int,
377
+ attention_head_dim: int,
378
+ num_layers: int,
379
+ mlp_ratio: float = 4.0,
380
+ mlp_drop_rate: float = 0.0,
381
+ attention_bias: bool = True,
382
+ ) -> None:
383
+ super().__init__()
384
+
385
+ hidden_size = num_attention_heads * attention_head_dim
386
+
387
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
388
+ embedding_dim=hidden_size, pooled_projection_dim=in_channels
389
+ )
390
+ self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
391
+ self.token_refiner = HunyuanVideoIndividualTokenRefiner(
392
+ num_attention_heads=num_attention_heads,
393
+ attention_head_dim=attention_head_dim,
394
+ num_layers=num_layers,
395
+ mlp_width_ratio=mlp_ratio,
396
+ mlp_drop_rate=mlp_drop_rate,
397
+ attention_bias=attention_bias,
398
+ )
399
+
400
+ def forward(
401
+ self,
402
+ hidden_states: torch.Tensor,
403
+ timestep: torch.LongTensor,
404
+ attention_mask: Optional[torch.LongTensor] = None,
405
+ ) -> torch.Tensor:
406
+ if attention_mask is None:
407
+ pooled_projections = hidden_states.mean(dim=1)
408
+ else:
409
+ original_dtype = hidden_states.dtype
410
+ mask_float = attention_mask.float().unsqueeze(-1)
411
+ pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
412
+ pooled_projections = pooled_projections.to(original_dtype)
413
+
414
+ temb = self.time_text_embed(timestep, pooled_projections)
415
+ hidden_states = self.proj_in(hidden_states)
416
+ hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
417
+
418
+ return hidden_states
419
+
420
+
421
+ class HunyuanVideoRotaryPosEmbed(nn.Module):
422
+ def __init__(self, rope_dim, theta):
423
+ super().__init__()
424
+ self.DT, self.DY, self.DX = rope_dim
425
+ self.theta = theta
426
+
427
+ @torch.no_grad()
428
+ def get_frequency(self, dim, pos):
429
+ T, H, W = pos.shape
430
+ freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
431
+ freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
432
+ return freqs.cos(), freqs.sin()
433
+
434
+ @torch.no_grad()
435
+ def forward_inner(self, frame_indices, height, width, device):
436
+ GT, GY, GX = torch.meshgrid(
437
+ frame_indices.to(device=device, dtype=torch.float32),
438
+ torch.arange(0, height, device=device, dtype=torch.float32),
439
+ torch.arange(0, width, device=device, dtype=torch.float32),
440
+ indexing="ij"
441
+ )
442
+
443
+ FCT, FST = self.get_frequency(self.DT, GT)
444
+ FCY, FSY = self.get_frequency(self.DY, GY)
445
+ FCX, FSX = self.get_frequency(self.DX, GX)
446
+
447
+ result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
448
+
449
+ return result.to(device)
450
+
451
+ @torch.no_grad()
452
+ def forward(self, frame_indices, height, width, device):
453
+ frame_indices = frame_indices.unbind(0)
454
+ results = [self.forward_inner(f, height, width, device) for f in frame_indices]
455
+ results = torch.stack(results, dim=0)
456
+ return results
457
+
458
+
459
+ class AdaLayerNormZero(nn.Module):
460
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
461
+ super().__init__()
462
+ self.silu = nn.SiLU()
463
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
464
+ if norm_type == "layer_norm":
465
+ self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
466
+ else:
467
+ raise ValueError(f"unknown norm_type {norm_type}")
468
+
469
+ def forward(
470
+ self,
471
+ x: torch.Tensor,
472
+ emb: Optional[torch.Tensor] = None,
473
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
474
+ emb = emb.unsqueeze(-2)
475
+ emb = self.linear(self.silu(emb))
476
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
477
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
478
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
479
+
480
+
481
+ class AdaLayerNormZeroSingle(nn.Module):
482
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
483
+ super().__init__()
484
+
485
+ self.silu = nn.SiLU()
486
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
487
+ if norm_type == "layer_norm":
488
+ self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
489
+ else:
490
+ raise ValueError(f"unknown norm_type {norm_type}")
491
+
492
+ def forward(
493
+ self,
494
+ x: torch.Tensor,
495
+ emb: Optional[torch.Tensor] = None,
496
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
497
+ emb = emb.unsqueeze(-2)
498
+ emb = self.linear(self.silu(emb))
499
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
500
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
501
+ return x, gate_msa
502
+
503
+
504
+ class AdaLayerNormContinuous(nn.Module):
505
+ def __init__(
506
+ self,
507
+ embedding_dim: int,
508
+ conditioning_embedding_dim: int,
509
+ elementwise_affine=True,
510
+ eps=1e-5,
511
+ bias=True,
512
+ norm_type="layer_norm",
513
+ ):
514
+ super().__init__()
515
+ self.silu = nn.SiLU()
516
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
517
+ if norm_type == "layer_norm":
518
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
519
+ else:
520
+ raise ValueError(f"unknown norm_type {norm_type}")
521
+
522
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
523
+ emb = emb.unsqueeze(-2)
524
+ emb = self.linear(self.silu(emb))
525
+ scale, shift = emb.chunk(2, dim=-1)
526
+ x = self.norm(x) * (1 + scale) + shift
527
+ return x
528
+
529
+
530
+ class HunyuanVideoSingleTransformerBlock(nn.Module):
531
+ def __init__(
532
+ self,
533
+ num_attention_heads: int,
534
+ attention_head_dim: int,
535
+ mlp_ratio: float = 4.0,
536
+ qk_norm: str = "rms_norm",
537
+ ) -> None:
538
+ super().__init__()
539
+
540
+ hidden_size = num_attention_heads * attention_head_dim
541
+ mlp_dim = int(hidden_size * mlp_ratio)
542
+
543
+ self.attn = Attention(
544
+ query_dim=hidden_size,
545
+ cross_attention_dim=None,
546
+ dim_head=attention_head_dim,
547
+ heads=num_attention_heads,
548
+ out_dim=hidden_size,
549
+ bias=True,
550
+ processor=HunyuanAttnProcessorFlashAttnSingle(),
551
+ qk_norm=qk_norm,
552
+ eps=1e-6,
553
+ pre_only=True,
554
+ )
555
+
556
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
557
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
558
+ self.act_mlp = nn.GELU(approximate="tanh")
559
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
560
+
561
+ def forward(
562
+ self,
563
+ hidden_states: torch.Tensor,
564
+ encoder_hidden_states: torch.Tensor,
565
+ temb: torch.Tensor,
566
+ attention_mask: Optional[torch.Tensor] = None,
567
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
568
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
569
+ text_seq_length = encoder_hidden_states.shape[1]
570
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
571
+
572
+ residual = hidden_states
573
+
574
+ # 1. Input normalization
575
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
576
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
577
+
578
+ norm_hidden_states, norm_encoder_hidden_states = (
579
+ norm_hidden_states[:, :-text_seq_length, :],
580
+ norm_hidden_states[:, -text_seq_length:, :],
581
+ )
582
+
583
+ # 2. Attention
584
+ attn_output, context_attn_output = self.attn(
585
+ hidden_states=norm_hidden_states,
586
+ encoder_hidden_states=norm_encoder_hidden_states,
587
+ attention_mask=attention_mask,
588
+ image_rotary_emb=image_rotary_emb,
589
+ )
590
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
591
+
592
+ # 3. Modulation and residual connection
593
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
594
+ hidden_states = gate * self.proj_out(hidden_states)
595
+ hidden_states = hidden_states + residual
596
+
597
+ hidden_states, encoder_hidden_states = (
598
+ hidden_states[:, :-text_seq_length, :],
599
+ hidden_states[:, -text_seq_length:, :],
600
+ )
601
+ return hidden_states, encoder_hidden_states
602
+
603
+
604
+ class HunyuanVideoTransformerBlock(nn.Module):
605
+ def __init__(
606
+ self,
607
+ num_attention_heads: int,
608
+ attention_head_dim: int,
609
+ mlp_ratio: float,
610
+ qk_norm: str = "rms_norm",
611
+ ) -> None:
612
+ super().__init__()
613
+
614
+ hidden_size = num_attention_heads * attention_head_dim
615
+
616
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
617
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
618
+
619
+ self.attn = Attention(
620
+ query_dim=hidden_size,
621
+ cross_attention_dim=None,
622
+ added_kv_proj_dim=hidden_size,
623
+ dim_head=attention_head_dim,
624
+ heads=num_attention_heads,
625
+ out_dim=hidden_size,
626
+ context_pre_only=False,
627
+ bias=True,
628
+ processor=HunyuanAttnProcessorFlashAttnDouble(),
629
+ qk_norm=qk_norm,
630
+ eps=1e-6,
631
+ )
632
+
633
+ self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
634
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
635
+
636
+ self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
637
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
638
+
639
+ def forward(
640
+ self,
641
+ hidden_states: torch.Tensor,
642
+ encoder_hidden_states: torch.Tensor,
643
+ temb: torch.Tensor,
644
+ attention_mask: Optional[torch.Tensor] = None,
645
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
646
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
647
+ # 1. Input normalization
648
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
649
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
650
+
651
+ # 2. Joint attention
652
+ attn_output, context_attn_output = self.attn(
653
+ hidden_states=norm_hidden_states,
654
+ encoder_hidden_states=norm_encoder_hidden_states,
655
+ attention_mask=attention_mask,
656
+ image_rotary_emb=freqs_cis,
657
+ )
658
+
659
+ # 3. Modulation and residual connection
660
+ hidden_states = hidden_states + attn_output * gate_msa
661
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
662
+
663
+ norm_hidden_states = self.norm2(hidden_states)
664
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
665
+
666
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
667
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
668
+
669
+ # 4. Feed-forward
670
+ ff_output = self.ff(norm_hidden_states)
671
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
672
+
673
+ hidden_states = hidden_states + gate_mlp * ff_output
674
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
675
+
676
+ return hidden_states, encoder_hidden_states
677
+
678
+
679
+ class ClipVisionProjection(nn.Module):
680
+ def __init__(self, in_channels, out_channels):
681
+ super().__init__()
682
+ self.up = nn.Linear(in_channels, out_channels * 3)
683
+ self.down = nn.Linear(out_channels * 3, out_channels)
684
+
685
+ def forward(self, x):
686
+ projected_x = self.down(nn.functional.silu(self.up(x)))
687
+ return projected_x
688
+
689
+
690
+ class HunyuanVideoPatchEmbed(nn.Module):
691
+ def __init__(self, patch_size, in_chans, embed_dim):
692
+ super().__init__()
693
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
694
+
695
+
696
+ class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
697
+ def __init__(self, inner_dim):
698
+ super().__init__()
699
+ self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
700
+ self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
701
+ self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
702
+
703
+ @torch.no_grad()
704
+ def initialize_weight_from_another_conv3d(self, another_layer):
705
+ weight = another_layer.weight.detach().clone()
706
+ bias = another_layer.bias.detach().clone()
707
+
708
+ sd = {
709
+ 'proj.weight': weight.clone(),
710
+ 'proj.bias': bias.clone(),
711
+ 'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0,
712
+ 'proj_2x.bias': bias.clone(),
713
+ 'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0,
714
+ 'proj_4x.bias': bias.clone(),
715
+ }
716
+
717
+ sd = {k: v.clone() for k, v in sd.items()}
718
+
719
+ self.load_state_dict(sd)
720
+ return
721
+
722
+
723
+ class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
724
+ @register_to_config
725
+ def __init__(
726
+ self,
727
+ in_channels: int = 16,
728
+ out_channels: int = 16,
729
+ num_attention_heads: int = 24,
730
+ attention_head_dim: int = 128,
731
+ num_layers: int = 20,
732
+ num_single_layers: int = 40,
733
+ num_refiner_layers: int = 2,
734
+ mlp_ratio: float = 4.0,
735
+ patch_size: int = 2,
736
+ patch_size_t: int = 1,
737
+ qk_norm: str = "rms_norm",
738
+ guidance_embeds: bool = True,
739
+ text_embed_dim: int = 4096,
740
+ pooled_projection_dim: int = 768,
741
+ rope_theta: float = 256.0,
742
+ rope_axes_dim: Tuple[int] = (16, 56, 56),
743
+ has_image_proj=False,
744
+ image_proj_dim=1152,
745
+ has_clean_x_embedder=False,
746
+ ) -> None:
747
+ super().__init__()
748
+
749
+ inner_dim = num_attention_heads * attention_head_dim
750
+ out_channels = out_channels or in_channels
751
+
752
+ # 1. Latent and condition embedders
753
+ self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
754
+ self.context_embedder = HunyuanVideoTokenRefiner(
755
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
756
+ )
757
+ self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
758
+
759
+ self.clean_x_embedder = None
760
+ self.image_projection = None
761
+
762
+ # 2. RoPE
763
+ self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
764
+
765
+ # 3. Dual stream transformer blocks
766
+ self.transformer_blocks = nn.ModuleList(
767
+ [
768
+ HunyuanVideoTransformerBlock(
769
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
770
+ )
771
+ for _ in range(num_layers)
772
+ ]
773
+ )
774
+
775
+ # 4. Single stream transformer blocks
776
+ self.single_transformer_blocks = nn.ModuleList(
777
+ [
778
+ HunyuanVideoSingleTransformerBlock(
779
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
780
+ )
781
+ for _ in range(num_single_layers)
782
+ ]
783
+ )
784
+
785
+ # 5. Output projection
786
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
787
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
788
+
789
+ self.inner_dim = inner_dim
790
+ self.use_gradient_checkpointing = False
791
+ self.enable_teacache = False
792
+
793
+ if has_image_proj:
794
+ self.install_image_projection(image_proj_dim)
795
+
796
+ if has_clean_x_embedder:
797
+ self.install_clean_x_embedder()
798
+
799
+ self.high_quality_fp32_output_for_inference = False
800
+
801
+ def install_image_projection(self, in_channels):
802
+ self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim)
803
+ self.config['has_image_proj'] = True
804
+ self.config['image_proj_dim'] = in_channels
805
+
806
+ def install_clean_x_embedder(self):
807
+ self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
808
+ self.config['has_clean_x_embedder'] = True
809
+
810
+ def enable_gradient_checkpointing(self):
811
+ self.use_gradient_checkpointing = True
812
+ print('self.use_gradient_checkpointing = True')
813
+
814
+ def disable_gradient_checkpointing(self):
815
+ self.use_gradient_checkpointing = False
816
+ print('self.use_gradient_checkpointing = False')
817
+
818
+ def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
819
+ self.enable_teacache = enable_teacache
820
+ self.cnt = 0
821
+ self.num_steps = num_steps
822
+ self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
823
+ self.accumulated_rel_l1_distance = 0
824
+ self.previous_modulated_input = None
825
+ self.previous_residual = None
826
+ self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])
827
+
828
+ def gradient_checkpointing_method(self, block, *args):
829
+ if self.use_gradient_checkpointing:
830
+ result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
831
+ else:
832
+ result = block(*args)
833
+ return result
834
+
835
+ def process_input_hidden_states(
836
+ self,
837
+ latents, latent_indices=None,
838
+ clean_latents=None, clean_latent_indices=None,
839
+ clean_latents_2x=None, clean_latent_2x_indices=None,
840
+ clean_latents_4x=None, clean_latent_4x_indices=None
841
+ ):
842
+ hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
843
+ B, C, T, H, W = hidden_states.shape
844
+
845
+ if latent_indices is None:
846
+ latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)
847
+
848
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
849
+
850
+ rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
851
+ rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
852
+
853
+ if clean_latents is not None and clean_latent_indices is not None:
854
+ clean_latents = clean_latents.to(hidden_states)
855
+ clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
856
+ clean_latents = clean_latents.flatten(2).transpose(1, 2)
857
+
858
+ clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
859
+ clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)
860
+
861
+ hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
862
+ rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)
863
+
864
+ if clean_latents_2x is not None and clean_latent_2x_indices is not None:
865
+ clean_latents_2x = clean_latents_2x.to(hidden_states)
866
+ clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
867
+ clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
868
+ clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
869
+
870
+ clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device)
871
+ clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
872
+ clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
873
+ clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)
874
+
875
+ hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
876
+ rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)
877
+
878
+ if clean_latents_4x is not None and clean_latent_4x_indices is not None:
879
+ clean_latents_4x = clean_latents_4x.to(hidden_states)
880
+ clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
881
+ clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
882
+ clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
883
+
884
+ clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device)
885
+ clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
886
+ clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
887
+ clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)
888
+
889
+ hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
890
+ rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)
891
+
892
+ return hidden_states, rope_freqs
893
+
894
+ def forward(
895
+ self,
896
+ hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance,
897
+ latent_indices=None,
898
+ clean_latents=None, clean_latent_indices=None,
899
+ clean_latents_2x=None, clean_latent_2x_indices=None,
900
+ clean_latents_4x=None, clean_latent_4x_indices=None,
901
+ image_embeddings=None,
902
+ attention_kwargs=None, return_dict=True
903
+ ):
904
+
905
+ if attention_kwargs is None:
906
+ attention_kwargs = {}
907
+
908
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
909
+ p, p_t = self.config['patch_size'], self.config['patch_size_t']
910
+ post_patch_num_frames = num_frames // p_t
911
+ post_patch_height = height // p
912
+ post_patch_width = width // p
913
+ original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
914
+
915
+ hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices)
916
+
917
+ temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
918
+ encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask)
919
+
920
+ if self.image_projection is not None:
921
+ assert image_embeddings is not None, 'You must use image embeddings!'
922
+ extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
923
+ extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device)
924
+
925
+ # Image tokens must be concatenated before (not after) the text tokens so that valid tokens stay contiguous at the front, which the attention masking below relies on.
926
+ encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
927
+ encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
928
+
929
+ with torch.no_grad():
930
+ if batch_size == 1:
931
+ # When batch size is 1, no attention masks or varlen attention functions are needed: cropping to the valid text length is mathematically equivalent to masking.
932
+ # If a masked implementation gives a different result than this cropping, that implementation is incorrect.
933
+ text_len = encoder_attention_mask.sum().item()
934
+ encoder_hidden_states = encoder_hidden_states[:, :text_len]
935
+ attention_mask = None, None, None, None
936
+ else:
937
+ img_seq_len = hidden_states.shape[1]
938
+ txt_seq_len = encoder_hidden_states.shape[1]
939
+
940
+ cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
941
+ cu_seqlens_kv = cu_seqlens_q
942
+ max_seqlen_q = img_seq_len + txt_seq_len
943
+ max_seqlen_kv = max_seqlen_q
944
+
945
+ attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
946
+
947
+ if self.enable_teacache:
948
+ modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
949
+
950
+ if self.cnt == 0 or self.cnt == self.num_steps-1:
951
+ should_calc = True
952
+ self.accumulated_rel_l1_distance = 0
953
+ else:
954
+ curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
955
+ self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
956
+ should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
957
+
958
+ if should_calc:
959
+ self.accumulated_rel_l1_distance = 0
960
+
961
+ self.previous_modulated_input = modulated_inp
962
+ self.cnt += 1
963
+
964
+ if self.cnt == self.num_steps:
965
+ self.cnt = 0
966
+
967
+ if not should_calc:
968
+ hidden_states = hidden_states + self.previous_residual
969
+ else:
970
+ ori_hidden_states = hidden_states.clone()
971
+
972
+ for block_id, block in enumerate(self.transformer_blocks):
973
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
974
+ block,
975
+ hidden_states,
976
+ encoder_hidden_states,
977
+ temb,
978
+ attention_mask,
979
+ rope_freqs
980
+ )
981
+
982
+ for block_id, block in enumerate(self.single_transformer_blocks):
983
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
984
+ block,
985
+ hidden_states,
986
+ encoder_hidden_states,
987
+ temb,
988
+ attention_mask,
989
+ rope_freqs
990
+ )
991
+
992
+ self.previous_residual = hidden_states - ori_hidden_states
993
+ else:
994
+ for block_id, block in enumerate(self.transformer_blocks):
995
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
996
+ block,
997
+ hidden_states,
998
+ encoder_hidden_states,
999
+ temb,
1000
+ attention_mask,
1001
+ rope_freqs
1002
+ )
1003
+
1004
+ for block_id, block in enumerate(self.single_transformer_blocks):
1005
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1006
+ block,
1007
+ hidden_states,
1008
+ encoder_hidden_states,
1009
+ temb,
1010
+ attention_mask,
1011
+ rope_freqs
1012
+ )
1013
+
1014
+ hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)
1015
+
1016
+ hidden_states = hidden_states[:, -original_context_length:, :]
1017
+
1018
+ if self.high_quality_fp32_output_for_inference:
1019
+ hidden_states = hidden_states.to(dtype=torch.float32)
1020
+ if self.proj_out.weight.dtype != torch.float32:
1021
+ self.proj_out.to(dtype=torch.float32)
1022
+
1023
+ hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)
1024
+
1025
+ hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)',
1026
+ t=post_patch_num_frames, h=post_patch_height, w=post_patch_width,
1027
+ pt=p_t, ph=p, pw=p)
1028
+
1029
+ if return_dict:
1030
+ return Transformer2DModelOutput(sample=hidden_states)
1031
+
1032
+ return hidden_states,
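
Note on the TeaCache branch in HunyuanVideoTransformer3DModelPacked.forward above: each step compares the modulated input of the first transformer block against the previous step, rescales the relative L1 distance with the polynomial from initialize_teacache, and skips all transformer blocks (reusing the cached residual) until the accumulated distance crosses rel_l1_thresh. The sketch below restates that decision rule in isolation; it is illustrative only, not part of the uploaded file, and the helper name is made up.

import numpy as np

# Polynomial copied from initialize_teacache above.
teacache_rescale = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])

def teacache_should_compute(modulated_inp, prev_inp, accumulated, step, num_steps, rel_l1_thresh=0.15):
    # First and last steps are always computed and reset the accumulator.
    if step == 0 or step == num_steps - 1:
        return True, 0.0
    # Relative L1 change of the modulated input since the previous step.
    rel_l1 = ((modulated_inp - prev_inp).abs().mean() / prev_inp.abs().mean()).item()
    accumulated += float(teacache_rescale(rel_l1))
    # Recompute (and reset) once the accumulated, rescaled drift crosses the threshold;
    # otherwise the caller reuses previous_residual and skips every block.
    if accumulated >= rel_l1_thresh:
        return True, 0.0
    return False, accumulated
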
diffusers_helper/pipelines/k_diffusion_hunyuan.py ADDED
@@ -0,0 +1,120 @@
1
+ import torch
2
+ import math
3
+
4
+ from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc
5
+ from diffusers_helper.k_diffusion.wrapper import fm_wrapper
6
+ from diffusers_helper.utils import repeat_to_batch_size
7
+
8
+
9
+ def flux_time_shift(t, mu=1.15, sigma=1.0):
10
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
11
+
12
+
13
+ def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
14
+ k = (y2 - y1) / (x2 - x1)
15
+ b = y1 - k * x1
16
+ mu = k * context_length + b
17
+ mu = min(mu, math.log(exp_max))
18
+ return mu
19
+
20
+
21
+ def get_flux_sigmas_from_mu(n, mu):
22
+ sigmas = torch.linspace(1, 0, steps=n + 1)
23
+ sigmas = flux_time_shift(sigmas, mu=mu)
24
+ return sigmas
25
+
26
+
27
+ @torch.inference_mode()
28
+ def sample_hunyuan(
29
+ transformer,
30
+ sampler='unipc',
31
+ initial_latent=None,
32
+ concat_latent=None,
33
+ strength=1.0,
34
+ width=512,
35
+ height=512,
36
+ frames=16,
37
+ real_guidance_scale=1.0,
38
+ distilled_guidance_scale=6.0,
39
+ guidance_rescale=0.0,
40
+ shift=None,
41
+ num_inference_steps=25,
42
+ batch_size=None,
43
+ generator=None,
44
+ prompt_embeds=None,
45
+ prompt_embeds_mask=None,
46
+ prompt_poolers=None,
47
+ negative_prompt_embeds=None,
48
+ negative_prompt_embeds_mask=None,
49
+ negative_prompt_poolers=None,
50
+ dtype=torch.bfloat16,
51
+ device=None,
52
+ negative_kwargs=None,
53
+ callback=None,
54
+ **kwargs,
55
+ ):
56
+ device = device or transformer.device
57
+
58
+ if batch_size is None:
59
+ batch_size = int(prompt_embeds.shape[0])
60
+
61
+ latents = torch.randn((batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device).to(device=device, dtype=torch.float32)
62
+
63
+ B, C, T, H, W = latents.shape
64
+ seq_length = T * H * W // 4
65
+
66
+ if shift is None:
67
+ mu = calculate_flux_mu(seq_length, exp_max=7.0)
68
+ else:
69
+ mu = math.log(shift)
70
+
71
+ sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)
72
+
73
+ k_model = fm_wrapper(transformer)
74
+
75
+ if initial_latent is not None:
76
+ sigmas = sigmas * strength
77
+ first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
78
+ initial_latent = initial_latent.to(device=device, dtype=torch.float32)
79
+ latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma
80
+
81
+ if concat_latent is not None:
82
+ concat_latent = concat_latent.to(latents)
83
+
84
+ distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)
85
+
86
+ prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
87
+ prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
88
+ prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
89
+ negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
90
+ negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
91
+ negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
92
+ concat_latent = repeat_to_batch_size(concat_latent, batch_size)
93
+
94
+ sampler_kwargs = dict(
95
+ dtype=dtype,
96
+ cfg_scale=real_guidance_scale,
97
+ cfg_rescale=guidance_rescale,
98
+ concat_latent=concat_latent,
99
+ positive=dict(
100
+ pooled_projections=prompt_poolers,
101
+ encoder_hidden_states=prompt_embeds,
102
+ encoder_attention_mask=prompt_embeds_mask,
103
+ guidance=distilled_guidance,
104
+ **kwargs,
105
+ ),
106
+ negative=dict(
107
+ pooled_projections=negative_prompt_poolers,
108
+ encoder_hidden_states=negative_prompt_embeds,
109
+ encoder_attention_mask=negative_prompt_embeds_mask,
110
+ guidance=distilled_guidance,
111
+ **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
112
+ )
113
+ )
114
+
115
+ if sampler == 'unipc':
116
+ results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
117
+ else:
118
+ raise NotImplementedError(f'Sampler {sampler} is not supported.')
119
+
120
+ return results
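
A quick worked example (not part of the diff) of how the schedule helpers above behave: the shift parameter mu grows linearly with the packed latent sequence length and is clamped at log(exp_max), which pushes more of the sigma schedule toward high noise for larger videos. The numbers below assume a 512x512, 16-frame request with 25 inference steps.

import math
import torch

def flux_time_shift(t, mu=1.15, sigma=1.0):
    # Same warp as above; larger mu keeps sigmas closer to 1 for longer.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

# frames=16, height=width=512 -> latents of shape (1, 16, 4, 64, 64),
# so seq_length = 4 * 64 * 64 // 4 = 4096 and mu evaluates to exactly 1.15
# (well under the exp_max clamp of log(7.0) ~= 1.95).
seq_length = 4 * 64 * 64 // 4
k = (1.15 - 0.5) / (4096 - 256)
mu = min(k * seq_length + (0.5 - k * 256), math.log(7.0))
sigmas = flux_time_shift(torch.linspace(1, 0, steps=26), mu=mu)  # 25 steps -> 26 sigmas
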
diffusers_helper/thread_utils.py ADDED
@@ -0,0 +1,76 @@
1
+ import time
2
+
3
+ from threading import Thread, Lock
4
+
5
+
6
+ class Listener:
7
+ task_queue = []
8
+ lock = Lock()
9
+ thread = None
10
+
11
+ @classmethod
12
+ def _process_tasks(cls):
13
+ while True:
14
+ task = None
15
+ with cls.lock:
16
+ if cls.task_queue:
17
+ task = cls.task_queue.pop(0)
18
+
19
+ if task is None:
20
+ time.sleep(0.001)
21
+ continue
22
+
23
+ func, args, kwargs = task
24
+ try:
25
+ func(*args, **kwargs)
26
+ except Exception as e:
27
+ print(f"Error in listener thread: {e}")
28
+
29
+ @classmethod
30
+ def add_task(cls, func, *args, **kwargs):
31
+ with cls.lock:
32
+ cls.task_queue.append((func, args, kwargs))
33
+
34
+ if cls.thread is None:
35
+ cls.thread = Thread(target=cls._process_tasks, daemon=True)
36
+ cls.thread.start()
37
+
38
+
39
+ def async_run(func, *args, **kwargs):
40
+ Listener.add_task(func, *args, **kwargs)
41
+
42
+
43
+ class FIFOQueue:
44
+ def __init__(self):
45
+ self.queue = []
46
+ self.lock = Lock()
47
+
48
+ def push(self, item):
49
+ with self.lock:
50
+ self.queue.append(item)
51
+
52
+ def pop(self):
53
+ with self.lock:
54
+ if self.queue:
55
+ return self.queue.pop(0)
56
+ return None
57
+
58
+ def top(self):
59
+ with self.lock:
60
+ if self.queue:
61
+ return self.queue[0]
62
+ return None
63
+
64
+ def next(self):
65
+ while True:
66
+ with self.lock:
67
+ if self.queue:
68
+ return self.queue.pop(0)
69
+
70
+ time.sleep(0.001)
71
+
72
+
73
+ class AsyncStream:
74
+ def __init__(self):
75
+ self.input_queue = FIFOQueue()
76
+ self.output_queue = FIFOQueue()
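
For context, the helpers above are typically used together: async_run pushes work onto the single daemon listener thread, while an AsyncStream carries events between the worker and the caller. The snippet below is a hypothetical usage sketch; the 'progress'/'end' event names are made up for illustration.

from diffusers_helper.thread_utils import AsyncStream, async_run

stream = AsyncStream()

def worker():
    # The worker reports through the output queue; 'end' is an illustrative sentinel.
    stream.output_queue.push(('progress', 0.5))
    stream.output_queue.push(('end', None))

async_run(worker)  # executes on the shared listener thread

while True:
    event, payload = stream.output_queue.next()  # blocks until an item is available
    if event == 'end':
        break
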
diffusers_helper/utils.py ADDED
@@ -0,0 +1,613 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ import random
5
+ import glob
6
+ import torch
7
+ import einops
8
+ import numpy as np
9
+ import datetime
10
+ import torchvision
11
+
12
+ import safetensors.torch as sf
13
+ from PIL import Image
14
+
15
+
16
+ def min_resize(x, m):
17
+ if x.shape[0] < x.shape[1]:
18
+ s0 = m
19
+ s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
20
+ else:
21
+ s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
22
+ s1 = m
23
+ new_max = max(s1, s0)
24
+ raw_max = max(x.shape[0], x.shape[1])
25
+ if new_max < raw_max:
26
+ interpolation = cv2.INTER_AREA
27
+ else:
28
+ interpolation = cv2.INTER_LANCZOS4
29
+ y = cv2.resize(x, (s1, s0), interpolation=interpolation)
30
+ return y
31
+
32
+
33
+ def d_resize(x, y):
34
+ H, W, C = y.shape
35
+ new_min = min(H, W)
36
+ raw_min = min(x.shape[0], x.shape[1])
37
+ if new_min < raw_min:
38
+ interpolation = cv2.INTER_AREA
39
+ else:
40
+ interpolation = cv2.INTER_LANCZOS4
41
+ y = cv2.resize(x, (W, H), interpolation=interpolation)
42
+ return y
43
+
44
+
45
+ def resize_and_center_crop(image, target_width, target_height):
46
+ if target_height == image.shape[0] and target_width == image.shape[1]:
47
+ return image
48
+
49
+ pil_image = Image.fromarray(image)
50
+ original_width, original_height = pil_image.size
51
+ scale_factor = max(target_width / original_width, target_height / original_height)
52
+ resized_width = int(round(original_width * scale_factor))
53
+ resized_height = int(round(original_height * scale_factor))
54
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
55
+ left = (resized_width - target_width) / 2
56
+ top = (resized_height - target_height) / 2
57
+ right = (resized_width + target_width) / 2
58
+ bottom = (resized_height + target_height) / 2
59
+ cropped_image = resized_image.crop((left, top, right, bottom))
60
+ return np.array(cropped_image)
61
+
62
+
63
+ def resize_and_center_crop_pytorch(image, target_width, target_height):
64
+ B, C, H, W = image.shape
65
+
66
+ if H == target_height and W == target_width:
67
+ return image
68
+
69
+ scale_factor = max(target_width / W, target_height / H)
70
+ resized_width = int(round(W * scale_factor))
71
+ resized_height = int(round(H * scale_factor))
72
+
73
+ resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)
74
+
75
+ top = (resized_height - target_height) // 2
76
+ left = (resized_width - target_width) // 2
77
+ cropped = resized[:, :, top:top + target_height, left:left + target_width]
78
+
79
+ return cropped
80
+
81
+
82
+ def resize_without_crop(image, target_width, target_height):
83
+ if target_height == image.shape[0] and target_width == image.shape[1]:
84
+ return image
85
+
86
+ pil_image = Image.fromarray(image)
87
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
88
+ return np.array(resized_image)
89
+
90
+
91
+ def just_crop(image, w, h):
92
+ if h == image.shape[0] and w == image.shape[1]:
93
+ return image
94
+
95
+ original_height, original_width = image.shape[:2]
96
+ k = min(original_height / h, original_width / w)
97
+ new_width = int(round(w * k))
98
+ new_height = int(round(h * k))
99
+ x_start = (original_width - new_width) // 2
100
+ y_start = (original_height - new_height) // 2
101
+ cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
102
+ return cropped_image
103
+
104
+
105
+ def write_to_json(data, file_path):
106
+ temp_file_path = file_path + ".tmp"
107
+ with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
108
+ json.dump(data, temp_file, indent=4)
109
+ os.replace(temp_file_path, file_path)
110
+ return
111
+
112
+
113
+ def read_from_json(file_path):
114
+ with open(file_path, 'rt', encoding='utf-8') as file:
115
+ data = json.load(file)
116
+ return data
117
+
118
+
119
+ def get_active_parameters(m):
120
+ return {k: v for k, v in m.named_parameters() if v.requires_grad}
121
+
122
+
123
+ def cast_training_params(m, dtype=torch.float32):
124
+ result = {}
125
+ for n, param in m.named_parameters():
126
+ if param.requires_grad:
127
+ param.data = param.to(dtype)
128
+ result[n] = param
129
+ return result
130
+
131
+
132
+ def separate_lora_AB(parameters, B_patterns=None):
133
+ parameters_normal = {}
134
+ parameters_B = {}
135
+
136
+ if B_patterns is None:
137
+ B_patterns = ['.lora_B.', '__zero__']
138
+
139
+ for k, v in parameters.items():
140
+ if any(B_pattern in k for B_pattern in B_patterns):
141
+ parameters_B[k] = v
142
+ else:
143
+ parameters_normal[k] = v
144
+
145
+ return parameters_normal, parameters_B
146
+
147
+
148
+ def set_attr_recursive(obj, attr, value):
149
+ attrs = attr.split(".")
150
+ for name in attrs[:-1]:
151
+ obj = getattr(obj, name)
152
+ setattr(obj, attrs[-1], value)
153
+ return
154
+
155
+
156
+ def print_tensor_list_size(tensors):
157
+ total_size = 0
158
+ total_elements = 0
159
+
160
+ if isinstance(tensors, dict):
161
+ tensors = tensors.values()
162
+
163
+ for tensor in tensors:
164
+ total_size += tensor.nelement() * tensor.element_size()
165
+ total_elements += tensor.nelement()
166
+
167
+ total_size_MB = total_size / (1024 ** 2)
168
+ total_elements_B = total_elements / 1e9
169
+
170
+ print(f"Total number of tensors: {len(tensors)}")
171
+ print(f"Total size of tensors: {total_size_MB:.2f} MB")
172
+ print(f"Total number of parameters: {total_elements_B:.3f} billion")
173
+ return
174
+
175
+
176
+ @torch.no_grad()
177
+ def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
178
+ batch_size = a.size(0)
179
+
180
+ if b is None:
181
+ b = torch.zeros_like(a)
182
+
183
+ if mask_a is None:
184
+ mask_a = torch.rand(batch_size) < probability_a
185
+
186
+ mask_a = mask_a.to(a.device)
187
+ mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
188
+ result = torch.where(mask_a, a, b)
189
+ return result
190
+
191
+
192
+ @torch.no_grad()
193
+ def zero_module(module):
194
+ for p in module.parameters():
195
+ p.detach().zero_()
196
+ return module
197
+
198
+
199
+ @torch.no_grad()
200
+ def supress_lower_channels(m, k, alpha=0.01):
201
+ data = m.weight.data.clone()
202
+
203
+ assert int(data.shape[1]) >= k
204
+
205
+ data[:, :k] = data[:, :k] * alpha
206
+ m.weight.data = data.contiguous().clone()
207
+ return m
208
+
209
+
210
+ def freeze_module(m):
211
+ if not hasattr(m, '_forward_inside_frozen_module'):
212
+ m._forward_inside_frozen_module = m.forward
213
+ m.requires_grad_(False)
214
+ m.forward = torch.no_grad()(m.forward)
215
+ return m
216
+
217
+
218
+ def get_latest_safetensors(folder_path):
219
+ safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))
220
+
221
+ if not safetensors_files:
222
+ raise ValueError('No file to resume!')
223
+
224
+ latest_file = max(safetensors_files, key=os.path.getmtime)
225
+ latest_file = os.path.abspath(os.path.realpath(latest_file))
226
+ return latest_file
227
+
228
+
229
+ def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
230
+ tags = tags_str.split(', ')
231
+ tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
232
+ prompt = ', '.join(tags)
233
+ return prompt
234
+
235
+
236
+ def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
237
+ numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
238
+ if round_to_int:
239
+ numbers = np.round(numbers).astype(int)
240
+ return numbers.tolist()
241
+
242
+
243
+ def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
244
+ edges = np.linspace(0, 1, n + 1)
245
+ points = np.random.uniform(edges[:-1], edges[1:])
246
+ numbers = inclusive + (exclusive - inclusive) * points
247
+ if round_to_int:
248
+ numbers = np.round(numbers).astype(int)
249
+ return numbers.tolist()
250
+
251
+
252
+ def soft_append_bcthw(history, current, overlap=0):
253
+ if overlap <= 0:
254
+ return torch.cat([history, current], dim=2)
255
+
256
+ assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
257
+ assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
258
+
259
+ weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
260
+ blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
261
+ output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
262
+
263
+ return output.to(history)
264
+
265
+
266
+ def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
267
+ b, c, t, h, w = x.shape
268
+
269
+ per_row = b
270
+ for p in [6, 5, 4, 3, 2]:
271
+ if b % p == 0:
272
+ per_row = p
273
+ break
274
+
275
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
276
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
277
+ x = x.detach().cpu().to(torch.uint8)
278
+ x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
279
+ torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
280
+ return x
281
+
282
+
283
+ def save_bcthw_as_png(x, output_filename):
284
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
285
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
286
+ x = x.detach().cpu().to(torch.uint8)
287
+ x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
288
+ torchvision.io.write_png(x, output_filename)
289
+ return output_filename
290
+
291
+
292
+ def save_bchw_as_png(x, output_filename):
293
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
294
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
295
+ x = x.detach().cpu().to(torch.uint8)
296
+ x = einops.rearrange(x, 'b c h w -> c h (b w)')
297
+ torchvision.io.write_png(x, output_filename)
298
+ return output_filename
299
+
300
+
301
+ def add_tensors_with_padding(tensor1, tensor2):
302
+ if tensor1.shape == tensor2.shape:
303
+ return tensor1 + tensor2
304
+
305
+ shape1 = tensor1.shape
306
+ shape2 = tensor2.shape
307
+
308
+ new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
309
+
310
+ padded_tensor1 = torch.zeros(new_shape)
311
+ padded_tensor2 = torch.zeros(new_shape)
312
+
313
+ padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
314
+ padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
315
+
316
+ result = padded_tensor1 + padded_tensor2
317
+ return result
318
+
319
+
320
+ def print_free_mem():
321
+ torch.cuda.empty_cache()
322
+ free_mem, total_mem = torch.cuda.mem_get_info(0)
323
+ free_mem_mb = free_mem / (1024 ** 2)
324
+ total_mem_mb = total_mem / (1024 ** 2)
325
+ print(f"Free memory: {free_mem_mb:.2f} MB")
326
+ print(f"Total memory: {total_mem_mb:.2f} MB")
327
+ return
328
+
329
+
330
+ def print_gpu_parameters(device, state_dict, log_count=1):
331
+ summary = {"device": device, "keys_count": len(state_dict)}
332
+
333
+ logged_params = {}
334
+ for i, (key, tensor) in enumerate(state_dict.items()):
335
+ if i >= log_count:
336
+ break
337
+ logged_params[key] = tensor.flatten()[:3].tolist()
338
+
339
+ summary["params"] = logged_params
340
+
341
+ print(str(summary))
342
+ return
343
+
344
+
345
+ def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
346
+ from PIL import Image, ImageDraw, ImageFont
347
+
348
+ txt = Image.new("RGB", (width, height), color="white")
349
+ draw = ImageDraw.Draw(txt)
350
+ font = ImageFont.truetype(font_path, size=size)
351
+
352
+ if text == '':
353
+ return np.array(txt)
354
+
355
+ # Split text into lines that fit within the image width
356
+ lines = []
357
+ words = text.split()
358
+ current_line = words[0]
359
+
360
+ for word in words[1:]:
361
+ line_with_word = f"{current_line} {word}"
362
+ if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
363
+ current_line = line_with_word
364
+ else:
365
+ lines.append(current_line)
366
+ current_line = word
367
+
368
+ lines.append(current_line)
369
+
370
+ # Draw the text line by line
371
+ y = 0
372
+ line_height = draw.textbbox((0, 0), "A", font=font)[3]
373
+
374
+ for line in lines:
375
+ if y + line_height > height:
376
+ break # stop drawing if the next line will be outside the image
377
+ draw.text((0, y), line, fill="black", font=font)
378
+ y += line_height
379
+
380
+ return np.array(txt)
381
+
382
+
383
+ def blue_mark(x):
384
+ x = x.copy()
385
+ c = x[:, :, 2]
386
+ b = cv2.blur(c, (9, 9))
387
+ x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
388
+ return x
389
+
390
+
391
+ def green_mark(x):
392
+ x = x.copy()
393
+ x[:, :, 2] = -1
394
+ x[:, :, 0] = -1
395
+ return x
396
+
397
+
398
+ def frame_mark(x):
399
+ x = x.copy()
400
+ x[:64] = -1
401
+ x[-64:] = -1
402
+ x[:, :8] = 1
403
+ x[:, -8:] = 1
404
+ return x
405
+
406
+
407
+ @torch.inference_mode()
408
+ def pytorch2numpy(imgs):
409
+ results = []
410
+ for x in imgs:
411
+ y = x.movedim(0, -1)
412
+ y = y * 127.5 + 127.5
413
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
414
+ results.append(y)
415
+ return results
416
+
417
+
418
+ @torch.inference_mode()
419
+ def numpy2pytorch(imgs):
420
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
421
+ h = h.movedim(-1, 1)
422
+ return h
423
+
424
+
425
+ @torch.no_grad()
426
+ def duplicate_prefix_to_suffix(x, count, zero_out=False):
427
+ if zero_out:
428
+ return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
429
+ else:
430
+ return torch.cat([x, x[:count]], dim=0)
431
+
432
+
433
+ def weighted_mse(a, b, weight):
434
+ return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
435
+
436
+
437
+ def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
438
+ x = (x - x_min) / (x_max - x_min)
439
+ x = max(0.0, min(x, 1.0))
440
+ x = x ** sigma
441
+ return y_min + x * (y_max - y_min)
442
+
443
+
444
+ def expand_to_dims(x, target_dims):
445
+ return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))
446
+
447
+
448
+ def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
449
+ if tensor is None:
450
+ return None
451
+
452
+ first_dim = tensor.shape[0]
453
+
454
+ if first_dim == batch_size:
455
+ return tensor
456
+
457
+ if batch_size % first_dim != 0:
458
+ raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")
459
+
460
+ repeat_times = batch_size // first_dim
461
+
462
+ return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))
463
+
464
+
465
+ def dim5(x):
466
+ return expand_to_dims(x, 5)
467
+
468
+
469
+ def dim4(x):
470
+ return expand_to_dims(x, 4)
471
+
472
+
473
+ def dim3(x):
474
+ return expand_to_dims(x, 3)
475
+
476
+
477
+ def crop_or_pad_yield_mask(x, length):
478
+ B, F, C = x.shape
479
+ device = x.device
480
+ dtype = x.dtype
481
+
482
+ if F < length:
483
+ y = torch.zeros((B, length, C), dtype=dtype, device=device)
484
+ mask = torch.zeros((B, length), dtype=torch.bool, device=device)
485
+ y[:, :F, :] = x
486
+ mask[:, :F] = True
487
+ return y, mask
488
+
489
+ return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
490
+
491
+
492
+ def extend_dim(x, dim, minimal_length, zero_pad=False):
493
+ original_length = int(x.shape[dim])
494
+
495
+ if original_length >= minimal_length:
496
+ return x
497
+
498
+ if zero_pad:
499
+ padding_shape = list(x.shape)
500
+ padding_shape[dim] = minimal_length - original_length
501
+ padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
502
+ else:
503
+ idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
504
+ last_element = x[idx]
505
+ padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)
506
+
507
+ return torch.cat([x, padding], dim=dim)
508
+
509
+
510
+ def lazy_positional_encoding(t, repeats=None):
511
+ if not isinstance(t, list):
512
+ t = [t]
513
+
514
+ from diffusers.models.embeddings import get_timestep_embedding
515
+
516
+ te = torch.tensor(t)
517
+ te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
518
+
519
+ if repeats is None:
520
+ return te
521
+
522
+ te = te[:, None, :].expand(-1, repeats, -1)
523
+
524
+ return te
525
+
526
+
527
+ def state_dict_offset_merge(A, B, C=None):
528
+ result = {}
529
+ keys = A.keys()
530
+
531
+ for key in keys:
532
+ A_value = A[key]
533
+ B_value = B[key].to(A_value)
534
+
535
+ if C is None:
536
+ result[key] = A_value + B_value
537
+ else:
538
+ C_value = C[key].to(A_value)
539
+ result[key] = A_value + B_value - C_value
540
+
541
+ return result
542
+
543
+
544
+ def state_dict_weighted_merge(state_dicts, weights):
545
+ if len(state_dicts) != len(weights):
546
+ raise ValueError("Number of state dictionaries must match number of weights")
547
+
548
+ if not state_dicts:
549
+ return {}
550
+
551
+ total_weight = sum(weights)
552
+
553
+ if total_weight == 0:
554
+ raise ValueError("Sum of weights cannot be zero")
555
+
556
+ normalized_weights = [w / total_weight for w in weights]
557
+
558
+ keys = state_dicts[0].keys()
559
+ result = {}
560
+
561
+ for key in keys:
562
+ result[key] = state_dicts[0][key] * normalized_weights[0]
563
+
564
+ for i in range(1, len(state_dicts)):
565
+ state_dict_value = state_dicts[i][key].to(result[key])
566
+ result[key] += state_dict_value * normalized_weights[i]
567
+
568
+ return result
569
+
570
+
571
+ def group_files_by_folder(all_files):
572
+ grouped_files = {}
573
+
574
+ for file in all_files:
575
+ folder_name = os.path.basename(os.path.dirname(file))
576
+ if folder_name not in grouped_files:
577
+ grouped_files[folder_name] = []
578
+ grouped_files[folder_name].append(file)
579
+
580
+ list_of_lists = list(grouped_files.values())
581
+ return list_of_lists
582
+
583
+
584
+ def generate_timestamp():
585
+ now = datetime.datetime.now()
586
+ timestamp = now.strftime('%y%m%d_%H%M%S')
587
+ milliseconds = f"{int(now.microsecond / 1000):03d}"
588
+ random_number = random.randint(0, 9999)
589
+ return f"{timestamp}_{milliseconds}_{random_number}"
590
+
591
+
592
+ def write_PIL_image_with_png_info(image, metadata, path):
593
+ from PIL.PngImagePlugin import PngInfo
594
+
595
+ png_info = PngInfo()
596
+ for key, value in metadata.items():
597
+ png_info.add_text(key, value)
598
+
599
+ image.save(path, "PNG", pnginfo=png_info)
600
+ return image
601
+
602
+
603
+ def torch_safe_save(content, path):
604
+ torch.save(content, path + '_tmp')
605
+ os.replace(path + '_tmp', path)
606
+ return path
607
+
608
+
609
+ def move_optimizer_to_device(optimizer, device):
610
+ for state in optimizer.state.values():
611
+ for k, v in state.items():
612
+ if isinstance(v, torch.Tensor):
613
+ state[k] = v.to(device)
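
One behaviour in utils.py worth spelling out is soft_append_bcthw: the overlapping frames are cross-faded with linearly decaying weights, so the result has history_frames + current_frames - overlap frames. A small self-check, assuming the file is importable as diffusers_helper.utils:

import torch
from diffusers_helper.utils import soft_append_bcthw

history = torch.zeros(1, 3, 8, 4, 4)  # 8 frames of zeros
current = torch.ones(1, 3, 8, 4, 4)   # 8 frames of ones
out = soft_append_bcthw(history, current, overlap=4)

assert out.shape[2] == 8 + 8 - 4
# The 4 blended frames ramp from history toward current:
print(out[0, 0, 4:8, 0, 0])  # tensor([0.0000, 0.3333, 0.6667, 1.0000])
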
modules/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # modules/__init__.py
2
+
modules/generators/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from .original_generator import OriginalModelGenerator
2
+ from .f1_generator import F1ModelGenerator
3
+
4
+ def create_model_generator(model_type, **kwargs):
5
+ """
6
+ Create a model generator based on the model type.
7
+
8
+ Args:
9
+ model_type: The type of model to create ("Original" or "F1")
10
+ **kwargs: Additional arguments to pass to the model generator constructor
11
+
12
+ Returns:
13
+ A model generator instance
14
+
15
+ Raises:
16
+ ValueError: If the model type is not supported
17
+ """
18
+ if model_type == "Original":
19
+ return OriginalModelGenerator(**kwargs)
20
+ elif model_type == "F1":
21
+ return F1ModelGenerator(**kwargs)
22
+ else:
23
+ raise ValueError(f"Unsupported model type: {model_type}")
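
A hypothetical wiring of this factory, for orientation only: the encoder, tokenizer and VAE objects are assumed to be loaded elsewhere in the app and passed straight through to BaseModelGenerator.__init__.

from modules.generators import create_model_generator

def build_generator(model_type, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
                    vae, image_encoder, feature_extractor, settings, high_vram=False):
    generator = create_model_generator(
        model_type,  # "Original" or "F1"
        text_encoder=text_encoder, text_encoder_2=text_encoder_2,
        tokenizer=tokenizer, tokenizer_2=tokenizer_2,
        vae=vae, image_encoder=image_encoder, feature_extractor=feature_extractor,
        high_vram=high_vram, settings=settings,
    )
    generator.load_model()
    return generator
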
modules/generators/base_generator.py ADDED
@@ -0,0 +1,228 @@
1
+ import torch
2
+ from abc import ABC, abstractmethod
3
+ from diffusers_helper import lora_utils
4
+
5
+ class BaseModelGenerator(ABC):
6
+ """
7
+ Base class for model generators.
8
+ This defines the common interface that all model generators must implement.
9
+ """
10
+
11
+ def __init__(self,
12
+ text_encoder,
13
+ text_encoder_2,
14
+ tokenizer,
15
+ tokenizer_2,
16
+ vae,
17
+ image_encoder,
18
+ feature_extractor,
19
+ high_vram=False,
20
+ prompt_embedding_cache=None,
21
+ settings=None):
22
+ """
23
+ Initialize the base model generator.
24
+
25
+ Args:
26
+ text_encoder: The text encoder model
27
+ text_encoder_2: The second text encoder model
28
+ tokenizer: The tokenizer for the first text encoder
29
+ tokenizer_2: The tokenizer for the second text encoder
30
+ vae: The VAE model
31
+ image_encoder: The image encoder model
32
+ feature_extractor: The feature extractor
33
+ high_vram: Whether high VRAM mode is enabled
34
+ prompt_embedding_cache: Cache for prompt embeddings
35
+ settings: Application settings
36
+ """
37
+ self.text_encoder = text_encoder
38
+ self.text_encoder_2 = text_encoder_2
39
+ self.tokenizer = tokenizer
40
+ self.tokenizer_2 = tokenizer_2
41
+ self.vae = vae
42
+ self.image_encoder = image_encoder
43
+ self.feature_extractor = feature_extractor
44
+ self.high_vram = high_vram
45
+ self.prompt_embedding_cache = prompt_embedding_cache or {}
46
+ self.settings = settings
47
+ self.transformer = None
48
+ self.gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
+ self.cpu = torch.device("cpu")
50
+
51
+ @abstractmethod
52
+ def load_model(self):
53
+ """
54
+ Load the transformer model.
55
+ This method should be implemented by each specific model generator.
56
+ """
57
+ pass
58
+
59
+ @abstractmethod
60
+ def get_model_name(self):
61
+ """
62
+ Get the name of the model.
63
+ This method should be implemented by each specific model generator.
64
+ """
65
+ pass
66
+
67
+ def unload_loras(self):
68
+ """
69
+ Unload all LoRAs from the transformer model.
70
+ """
71
+ if self.transformer is not None:
72
+ print(f"Unloading all LoRAs from {self.get_model_name()} model")
73
+ self.transformer = lora_utils.unload_all_loras(self.transformer)
74
+ self.verify_lora_state("After unloading LoRAs")
75
+ import gc
76
+ gc.collect()
77
+ if torch.cuda.is_available():
78
+ torch.cuda.empty_cache()
79
+
80
+ def verify_lora_state(self, label=""):
81
+ """
82
+ Debug function to verify the state of LoRAs in the transformer model.
83
+ """
84
+ if self.transformer is None:
85
+ print(f"[{label}] Transformer is None, cannot verify LoRA state")
86
+ return
87
+
88
+ has_loras = False
89
+ if hasattr(self.transformer, 'peft_config'):
90
+ adapter_names = list(self.transformer.peft_config.keys()) if self.transformer.peft_config else []
91
+ if adapter_names:
92
+ has_loras = True
93
+ print(f"[{label}] Transformer has LoRAs: {', '.join(adapter_names)}")
94
+ else:
95
+ print(f"[{label}] Transformer has no LoRAs in peft_config")
96
+ else:
97
+ print(f"[{label}] Transformer has no peft_config attribute")
98
+
99
+ # Check for any LoRA modules
100
+ for name, module in self.transformer.named_modules():
101
+ if hasattr(module, 'lora_A') and module.lora_A:
102
+ has_loras = True
103
+ # print(f"[{label}] Found lora_A in module {name}")
104
+ if hasattr(module, 'lora_B') and module.lora_B:
105
+ has_loras = True
106
+ # print(f"[{label}] Found lora_B in module {name}")
107
+
108
+ if not has_loras:
109
+ print(f"[{label}] No LoRA components found in transformer")
110
+
111
+ def move_lora_adapters_to_device(self, target_device):
112
+ """
113
+ Move all LoRA adapters in the transformer model to the specified device.
114
+ This handles the PEFT implementation of LoRA.
115
+ """
116
+ if self.transformer is None:
117
+ return
118
+
119
+ print(f"Moving all LoRA adapters to {target_device}")
120
+
121
+ # First, find all modules with LoRA adapters
122
+ lora_modules = []
123
+ for name, module in self.transformer.named_modules():
124
+ if hasattr(module, 'active_adapter') and hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
125
+ lora_modules.append((name, module))
126
+
127
+ # Now move all LoRA components to the target device
128
+ for name, module in lora_modules:
129
+ # Get the active adapter name
130
+ active_adapter = module.active_adapter
131
+
132
+ # Move the LoRA layers to the target device
133
+ if active_adapter is not None:
134
+ if isinstance(module.lora_A, torch.nn.ModuleDict):
135
+ # Handle ModuleDict case (PEFT implementation)
136
+ for adapter_name in list(module.lora_A.keys()):
137
+ # Move lora_A
138
+ if adapter_name in module.lora_A:
139
+ module.lora_A[adapter_name] = module.lora_A[adapter_name].to(target_device)
140
+
141
+ # Move lora_B
142
+ if adapter_name in module.lora_B:
143
+ module.lora_B[adapter_name] = module.lora_B[adapter_name].to(target_device)
144
+
145
+ # Move scaling
146
+ if hasattr(module, 'scaling') and isinstance(module.scaling, dict) and adapter_name in module.scaling:
147
+ if isinstance(module.scaling[adapter_name], torch.Tensor):
148
+ module.scaling[adapter_name] = module.scaling[adapter_name].to(target_device)
149
+ else:
150
+ # Handle direct attribute case
151
+ if hasattr(module, 'lora_A') and module.lora_A is not None:
152
+ module.lora_A = module.lora_A.to(target_device)
153
+ if hasattr(module, 'lora_B') and module.lora_B is not None:
154
+ module.lora_B = module.lora_B.to(target_device)
155
+ if hasattr(module, 'scaling') and module.scaling is not None:
156
+ if isinstance(module.scaling, torch.Tensor):
157
+ module.scaling = module.scaling.to(target_device)
158
+
159
+ print(f"Moved all LoRA adapters to {target_device}")
160
+
161
+ def load_loras(self, selected_loras, lora_folder, lora_loaded_names, lora_values=None):
162
+ """
163
+ Load LoRAs into the transformer model.
164
+
165
+ Args:
166
+ selected_loras: List of LoRA names to load
167
+ lora_folder: Folder containing the LoRA files
168
+ lora_loaded_names: List of loaded LoRA names
169
+ lora_values: Optional list of LoRA strength values
170
+ """
171
+ if self.transformer is None:
172
+ print("Cannot load LoRAs: Transformer model is not loaded")
173
+ return
174
+
175
+ import os
176
+
177
+ # Ensure all LoRAs are unloaded first
178
+ self.unload_loras()
179
+
180
+ # Load each selected LoRA
181
+ for lora_name in selected_loras:
182
+ try:
183
+ idx = lora_loaded_names.index(lora_name)
184
+ lora_file = None
185
+ for ext in [".safetensors", ".pt"]:
186
+ # Find any file that starts with the lora_name and ends with the extension
187
+ matching_files = [f for f in os.listdir(lora_folder)
188
+ if f.startswith(lora_name) and f.endswith(ext)]
189
+ if matching_files:
190
+ lora_file = matching_files[0] # Use the first matching file
191
+ break
192
+
193
+ if lora_file:
194
+ print(f"Loading LoRA {lora_file} to {self.get_model_name()} model")
195
+ self.transformer = lora_utils.load_lora(self.transformer, lora_folder, lora_file)
196
+
197
+ # Set LoRA strength if provided
198
+ if lora_values and idx < len(lora_values):
199
+ lora_strength = float(lora_values[idx])
200
+ print(f"Setting LoRA {lora_name} strength to {lora_strength}")
201
+
202
+ # Set scaling for this LoRA by iterating through modules
203
+ for name, module in self.transformer.named_modules():
204
+ if hasattr(module, 'scaling'):
205
+ if isinstance(module.scaling, dict):
206
+ # Handle ModuleDict case (PEFT implementation)
207
+ if lora_name in module.scaling:
208
+ if isinstance(module.scaling[lora_name], torch.Tensor):
209
+ module.scaling[lora_name] = torch.tensor(
210
+ lora_strength, device=module.scaling[lora_name].device
211
+ )
212
+ else:
213
+ module.scaling[lora_name] = lora_strength
214
+ else:
215
+ # Handle direct attribute case for scaling if needed
216
+ if isinstance(module.scaling, torch.Tensor):
217
+ module.scaling = torch.tensor(
218
+ lora_strength, device=module.scaling.device
219
+ )
220
+ else:
221
+ module.scaling = lora_strength
222
+ else:
223
+ print(f"LoRA file for {lora_name} not found!")
224
+ except Exception as e:
225
+ print(f"Error loading LoRA {lora_name}: {e}")
226
+
227
+ # Verify LoRA state after loading
228
+ self.verify_lora_state("After loading LoRAs")
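A minimal usage sketch of the LoRA helpers above, assuming they live on the shared base generator that the F1 and Original generators added below inherit from; the high_vram/gpu constructor keywords and the "my_lora" file name are illustrative assumptions, not confirmed by this commit:

    import torch
    from modules.generators.original_generator import OriginalModelGenerator

    # Constructor kwargs (high_vram, gpu) and the "my_lora" name are assumptions.
    generator = OriginalModelGenerator(high_vram=False, gpu=torch.device("cuda"))
    generator.load_model()

    # Load one LoRA at 0.8 strength, keep its adapter weights on the GPU next
    # to the transformer, then print a summary of what is attached.
    generator.load_loras(
        selected_loras=["my_lora"],
        lora_folder="loras",
        lora_loaded_names=["my_lora"],
        lora_values=[0.8],
    )
    generator.move_lora_adapters_to_device(torch.device("cuda"))
    generator.verify_lora_state("After setup")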
modules/generators/f1_generator.py ADDED
@@ -0,0 +1,219 @@
1
+ import torch
2
+ from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
3
+ from diffusers_helper.memory import DynamicSwapInstaller
4
+ from .base_generator import BaseModelGenerator
5
+
6
+ class F1ModelGenerator(BaseModelGenerator):
7
+ """
8
+ Model generator for the F1 HunyuanVideo model.
9
+ """
10
+
11
+ def __init__(self, **kwargs):
12
+ """
13
+ Initialize the F1 model generator.
14
+ """
15
+ super().__init__(**kwargs)
16
+ self.model_name = "F1"
17
+ self.model_path = 'lllyasviel/FramePack_F1_I2V_HY_20250503'
18
+
19
+ def get_model_name(self):
20
+ """
21
+ Get the name of the model.
22
+ """
23
+ return self.model_name
24
+
25
+ def load_model(self):
26
+ """
27
+ Load the F1 transformer model.
28
+ """
29
+ print(f"Loading {self.model_name} Transformer...")
30
+
31
+ # Create the transformer model
32
+ self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
33
+ self.model_path,
34
+ torch_dtype=torch.bfloat16
35
+ ).cpu()
36
+
37
+ # Configure the model
38
+ self.transformer.eval()
39
+ self.transformer.to(dtype=torch.bfloat16)
40
+ self.transformer.requires_grad_(False)
41
+
42
+ # Set up dynamic swap if not in high VRAM mode
43
+ if not self.high_vram:
44
+ DynamicSwapInstaller.install_model(self.transformer, device=self.gpu)
45
+
46
+ print(f"{self.model_name} Transformer Loaded.")
47
+ return self.transformer
48
+
49
+ def prepare_history_latents(self, height, width):
50
+ """
51
+ Prepare the history latents tensor for the F1 model.
52
+
53
+ Args:
54
+ height: The height of the image
55
+ width: The width of the image
56
+
57
+ Returns:
58
+ The initialized history latents tensor
59
+ """
60
+ return torch.zeros(
61
+ size=(1, 16, 16 + 2 + 1, height // 8, width // 8),
62
+ dtype=torch.float32
63
+ ).cpu()
64
+
65
+ def initialize_with_start_latent(self, history_latents, start_latent):
66
+ """
67
+ Initialize the history latents with the start latent for the F1 model.
68
+
69
+ Args:
70
+ history_latents: The history latents
71
+ start_latent: The start latent
72
+
73
+ Returns:
74
+ The initialized history latents
75
+ """
76
+ # Add the start frame to history_latents
77
+ return torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
78
+
79
+ def get_latent_paddings(self, total_latent_sections):
80
+ """
81
+ Get the latent paddings for the F1 model.
82
+
83
+ Args:
84
+ total_latent_sections: The total number of latent sections
85
+
86
+ Returns:
87
+ A list of latent paddings
88
+ """
89
+ # F1 model uses a fixed approach with just 0 for last section and 1 for others
90
+ return [1] * (total_latent_sections - 1) + [0]
91
+
92
+ def prepare_indices(self, latent_padding_size, latent_window_size):
93
+ """
94
+ Prepare the indices for the F1 model.
95
+
96
+ Args:
97
+ latent_padding_size: The size of the latent padding
98
+ latent_window_size: The size of the latent window
99
+
100
+ Returns:
101
+ A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices)
102
+ """
103
+ # F1 model uses a different indices approach
104
+ # latent_window_sizeが4.5の場合は特別に5を使用
105
+ effective_window_size = 5 if latent_window_size == 4.5 else int(latent_window_size)
106
+ indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
107
+ clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
108
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
109
+
110
+ return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices
111
+
112
+ def prepare_clean_latents(self, start_latent, history_latents):
113
+ """
114
+ Prepare the clean latents for the F1 model.
115
+
116
+ Args:
117
+ start_latent: The start latent
118
+ history_latents: The history latents
119
+
120
+ Returns:
121
+ A tuple of (clean_latents, clean_latents_2x, clean_latents_4x)
122
+ """
123
+ # For F1, we take the last frames for clean latents
124
+ clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2)
125
+ # For F1, we prepend the start latent to clean_latents_1x
126
+ clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
127
+
128
+ return clean_latents, clean_latents_2x, clean_latents_4x
129
+
130
+ def update_history_latents(self, history_latents, generated_latents):
131
+ """
132
+ Update the history latents with the generated latents for the F1 model.
133
+
134
+ Args:
135
+ history_latents: The history latents
136
+ generated_latents: The generated latents
137
+
138
+ Returns:
139
+ The updated history latents
140
+ """
141
+ # For F1, we append new frames to the end
142
+ return torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
143
+
144
+ def get_real_history_latents(self, history_latents, total_generated_latent_frames):
145
+ """
146
+ Get the real history latents for the F1 model.
147
+
148
+ Args:
149
+ history_latents: The history latents
150
+ total_generated_latent_frames: The total number of generated latent frames
151
+
152
+ Returns:
153
+ The real history latents
154
+ """
155
+ # For F1, we take frames from the end
156
+ return history_latents[:, :, -total_generated_latent_frames:, :, :]
157
+
158
+ def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
159
+ """
160
+ Update the history pixels with the current pixels for the F1 model.
161
+
162
+ Args:
163
+ history_pixels: The history pixels
164
+ current_pixels: The current pixels
165
+ overlapped_frames: The number of overlapped frames
166
+
167
+ Returns:
168
+ The updated history pixels
169
+ """
170
+ from diffusers_helper.utils import soft_append_bcthw
171
+ # For F1 model, history_pixels is first, current_pixels is second
172
+ return soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)
173
+
174
+ def get_section_latent_frames(self, latent_window_size, is_last_section):
175
+ """
176
+ Get the number of section latent frames for the F1 model.
177
+
178
+ Args:
179
+ latent_window_size: The size of the latent window
180
+ is_last_section: Whether this is the last section
181
+
182
+ Returns:
183
+ The number of section latent frames
184
+ """
185
+ return latent_window_size * 2
186
+
187
+ def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
188
+ """
189
+ Get the current pixels for the F1 model.
190
+
191
+ Args:
192
+ real_history_latents: The real history latents
193
+ section_latent_frames: The number of section latent frames
194
+ vae: The VAE model
195
+
196
+ Returns:
197
+ The current pixels
198
+ """
199
+ from diffusers_helper.hunyuan import vae_decode
200
+ # For F1, we take frames from the end
201
+ return vae_decode(real_history_latents[:, :, -section_latent_frames:], vae).cpu()
202
+
203
+ def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
204
+ """
205
+ Format the position description for the F1 model.
206
+
207
+ Args:
208
+ total_generated_latent_frames: The total number of generated latent frames
209
+ current_pos: The current position in seconds
210
+ original_pos: The original position in seconds
211
+ current_prompt: The current prompt
212
+
213
+ Returns:
214
+ The formatted position description
215
+ """
216
+ return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
217
+ f'Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). '
218
+ f'Current position: {current_pos:.2f}s. '
219
+ f'using prompt: {current_prompt[:256]}...')
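A self-contained sketch of the latent bookkeeping F1ModelGenerator performs (torch only; the 512x512 size, 9-frame window, and three sections are arbitrary example values): the history starts as 16 + 2 + 1 zero context frames, the start latent is appended, each generated section is appended at the end, and only the tail is treated as real content.

    import torch

    height, width, latent_window_size = 512, 512, 9
    start_latent = torch.randn(1, 16, 1, height // 8, width // 8)

    # prepare_history_latents(): 16 + 2 + 1 zero-filled context frames.
    history = torch.zeros(1, 16, 16 + 2 + 1, height // 8, width // 8)
    # initialize_with_start_latent(): append the encoded start frame.
    history = torch.cat([history, start_latent.to(history)], dim=2)

    total_generated = 1  # the start frame counts as one latent frame
    for _ in range(3):   # three generation sections
        generated = torch.randn(1, 16, latent_window_size, height // 8, width // 8)
        # update_history_latents(): F1 appends the new frames at the end.
        history = torch.cat([history, generated.to(history)], dim=2)
        total_generated += latent_window_size
    # get_real_history_latents(): only the tail is real content; the leading
    # zeros stay in place as clean-latent context for the next section.
    real_history = history[:, :, -total_generated:, :, :]
    print(real_history.shape)  # torch.Size([1, 16, 28, 64, 64])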
modules/generators/original_generator.py ADDED
@@ -0,0 +1,202 @@
1
+ import torch
2
+ from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
3
+ from diffusers_helper.memory import DynamicSwapInstaller
4
+ from .base_generator import BaseModelGenerator
5
+
6
+ class OriginalModelGenerator(BaseModelGenerator):
7
+ """
8
+ Model generator for the Original HunyuanVideo model.
9
+ """
10
+
11
+ def __init__(self, **kwargs):
12
+ """
13
+ Initialize the Original model generator.
14
+ """
15
+ super().__init__(**kwargs)
16
+ self.model_name = "Original"
17
+ self.model_path = 'lllyasviel/FramePackI2V_HY'
18
+
19
+ def get_model_name(self):
20
+ """
21
+ Get the name of the model.
22
+ """
23
+ return self.model_name
24
+
25
+ def load_model(self):
26
+ """
27
+ Load the Original transformer model.
28
+ """
29
+ print(f"Loading {self.model_name} Transformer...")
30
+
31
+ # Create the transformer model
32
+ self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
33
+ self.model_path,
34
+ torch_dtype=torch.bfloat16
35
+ ).cpu()
36
+
37
+ # Configure the model
38
+ self.transformer.eval()
39
+ self.transformer.to(dtype=torch.bfloat16)
40
+ self.transformer.requires_grad_(False)
41
+
42
+ # Set up dynamic swap if not in high VRAM mode
43
+ if not self.high_vram:
44
+ DynamicSwapInstaller.install_model(self.transformer, device=self.gpu)
45
+
46
+ print(f"{self.model_name} Transformer Loaded.")
47
+ return self.transformer
48
+
49
+ def prepare_history_latents(self, height, width):
50
+ """
51
+ Prepare the history latents tensor for the Original model.
52
+
53
+ Args:
54
+ height: The height of the image
55
+ width: The width of the image
56
+
57
+ Returns:
58
+ The initialized history latents tensor
59
+ """
60
+ return torch.zeros(
61
+ size=(1, 16, 1 + 2 + 16, height // 8, width // 8),
62
+ dtype=torch.float32
63
+ ).cpu()
64
+
65
+ def get_latent_paddings(self, total_latent_sections):
66
+ """
67
+ Get the latent paddings for the Original model.
68
+
69
+ Args:
70
+ total_latent_sections: The total number of latent sections
71
+
72
+ Returns:
73
+ A list of latent paddings
74
+ """
75
+ # Original model uses reversed latent paddings
76
+ if total_latent_sections > 4:
77
+ return [3] + [2] * (total_latent_sections - 3) + [1, 0]
78
+ else:
79
+ return list(reversed(range(total_latent_sections)))
80
+
81
+ def prepare_indices(self, latent_padding_size, latent_window_size):
82
+ """
83
+ Prepare the indices for the Original model.
84
+
85
+ Args:
86
+ latent_padding_size: The size of the latent padding
87
+ latent_window_size: The size of the latent window
88
+
89
+ Returns:
90
+ A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices)
91
+ """
92
+ indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
93
+ clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
94
+ clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
95
+
96
+ return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices
97
+
98
+ def prepare_clean_latents(self, start_latent, history_latents):
99
+ """
100
+ Prepare the clean latents for the Original model.
101
+
102
+ Args:
103
+ start_latent: The start latent
104
+ history_latents: The history latents
105
+
106
+ Returns:
107
+ A tuple of (clean_latents, clean_latents_2x, clean_latents_4x)
108
+ """
109
+ clean_latents_pre = start_latent.to(history_latents)
110
+ clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
111
+ clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
112
+
113
+ return clean_latents, clean_latents_2x, clean_latents_4x
114
+
115
+ def update_history_latents(self, history_latents, generated_latents):
116
+ """
117
+ Update the history latents with the generated latents for the Original model.
118
+
119
+ Args:
120
+ history_latents: The history latents
121
+ generated_latents: The generated latents
122
+
123
+ Returns:
124
+ The updated history latents
125
+ """
126
+ # For Original model, we prepend the generated latents
127
+ return torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
128
+
129
+ def get_real_history_latents(self, history_latents, total_generated_latent_frames):
130
+ """
131
+ Get the real history latents for the Original model.
132
+
133
+ Args:
134
+ history_latents: The history latents
135
+ total_generated_latent_frames: The total number of generated latent frames
136
+
137
+ Returns:
138
+ The real history latents
139
+ """
140
+ return history_latents[:, :, :total_generated_latent_frames, :, :]
141
+
142
+ def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
143
+ """
144
+ Update the history pixels with the current pixels for the Original model.
145
+
146
+ Args:
147
+ history_pixels: The history pixels
148
+ current_pixels: The current pixels
149
+ overlapped_frames: The number of overlapped frames
150
+
151
+ Returns:
152
+ The updated history pixels
153
+ """
154
+ from diffusers_helper.utils import soft_append_bcthw
155
+ # For Original model, current_pixels is first, history_pixels is second
156
+ return soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
157
+
158
+ def get_section_latent_frames(self, latent_window_size, is_last_section):
159
+ """
160
+ Get the number of section latent frames for the Original model.
161
+
162
+ Args:
163
+ latent_window_size: The size of the latent window
164
+ is_last_section: Whether this is the last section
165
+
166
+ Returns:
167
+ The number of section latent frames
168
+ """
169
+ return latent_window_size * 2
170
+
171
+ def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
172
+ """
173
+ Get the current pixels for the Original model.
174
+
175
+ Args:
176
+ real_history_latents: The real history latents
177
+ section_latent_frames: The number of section latent frames
178
+ vae: The VAE model
179
+
180
+ Returns:
181
+ The current pixels
182
+ """
183
+ from diffusers_helper.hunyuan import vae_decode
184
+ return vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
185
+
186
+ def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
187
+ """
188
+ Format the position description for the Original model.
189
+
190
+ Args:
191
+ total_generated_latent_frames: The total number of generated latent frames
192
+ current_pos: The current position in seconds
193
+ original_pos: The original position in seconds
194
+ current_prompt: The current prompt
195
+
196
+ Returns:
197
+ The formatted position description
198
+ """
199
+ return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
200
+ f'Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). '
201
+ f'Current position: {current_pos:.2f}s (original: {original_pos:.2f}s). '
202
+ f'using prompt: {current_prompt[:256]}...')
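The two generators differ mainly in ordering: the Original model samples sections in reverse time (reversed padding schedule, new frames prepended) while F1 samples forward (constant padding of 1 until the last section, new frames appended). A small standalone comparison of the two padding schedules, with the logic copied from get_latent_paddings above:

    def original_paddings(total_latent_sections):
        if total_latent_sections > 4:
            return [3] + [2] * (total_latent_sections - 3) + [1, 0]
        return list(reversed(range(total_latent_sections)))

    def f1_paddings(total_latent_sections):
        return [1] * (total_latent_sections - 1) + [0]

    for n in (3, 6):
        print(n, original_paddings(n), f1_paddings(n))
    # 3 [2, 1, 0] [1, 1, 0]
    # 6 [3, 2, 2, 2, 1, 0] [1, 1, 1, 1, 1, 0]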
modules/interface.py ADDED
@@ -0,0 +1,771 @@
1
+ import gradio as gr
2
+ import time
3
+ import datetime
4
+ import random
5
+ import json
6
+ import os
7
+ import shutil
8
+ from typing import List, Dict, Any, Optional
9
+ from PIL import Image
10
+ import numpy as np
11
+ import base64
12
+ import io
13
+
14
+ from modules.video_queue import JobStatus, Job
15
+ from modules.prompt_handler import get_section_boundaries, get_quick_prompts, parse_timestamped_prompt
16
+ from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html
17
+ from diffusers_helper.bucket_tools import find_nearest_bucket
18
+
19
+ def create_interface(
20
+ process_fn,
21
+ monitor_fn,
22
+ end_process_fn,
23
+ update_queue_status_fn,
24
+ load_lora_file_fn,
25
+ job_queue,
26
+ settings,
27
+ default_prompt: str = '[1s: The person waves hello] [3s: The person jumps up and down] [5s: The person does a dance]',
28
+ lora_names: list = [],
29
+ lora_values: list = []
30
+ ):
31
+ """
32
+ Create the Gradio interface for the video generation application
33
+
34
+ Args:
35
+ process_fn: Function to process a new job
36
+ monitor_fn: Function to monitor an existing job
37
+         update_queue_status_fn: Function to update the queue status display
+         load_lora_file_fn: Function to load an uploaded LoRA file
+         job_queue: Job queue instance shared with the worker
+         settings: Settings instance providing persisted defaults
38
+ update_queue_status_fn: Function to update the queue status display
39
+         lora_names: List of loaded LoRA names
+         lora_values: Optional list of default LoRA strength values
40
+ lora_names: List of loaded LoRA names
41
+
42
+ Returns:
43
+ Gradio Blocks interface
44
+ """
45
+ # Get section boundaries and quick prompts
46
+ section_boundaries = get_section_boundaries()
47
+ quick_prompts = get_quick_prompts()
48
+
49
+ # Create the interface
50
+ css = make_progress_bar_css()
51
+ css += """
52
+ /* Image container styling - more aggressive approach */
53
+ .contain-image, .contain-image > div, .contain-image > div > img {
54
+ object-fit: contain !important;
55
+ }
56
+
57
+ /* Target all images in the contain-image class and its children */
58
+ .contain-image img,
59
+ .contain-image > div > img,
60
+ .contain-image * img {
61
+ object-fit: contain !important;
62
+ width: 100% !important;
63
+ height: 100% !important;
64
+ max-height: 100% !important;
65
+ max-width: 100% !important;
66
+ }
67
+
68
+ /* Additional selectors to override Gradio defaults */
69
+ .gradio-container img,
70
+ .gradio-container .svelte-1b5oq5x,
71
+ .gradio-container [data-testid="image"] img {
72
+ object-fit: contain !important;
73
+ }
74
+
75
+ /* Toolbar styling */
76
+ #fixed-toolbar {
77
+ position: fixed;
78
+ top: 0;
79
+ left: 0;
80
+ width: 100vw;
81
+ z-index: 1000;
82
+ background: rgb(11, 15, 25);
83
+ color: #fff;
84
+ padding: 10px 20px;
85
+ display: flex;
86
+ align-items: center;
87
+ gap: 16px;
88
+ box-shadow: 0 2px 8px rgba(0,0,0,0.1);
89
+ border-bottom: 1px solid #4f46e5;
90
+ }
91
+
92
+ /* Button styling */
93
+ #toolbar-add-to-queue-btn button {
94
+ font-size: 14px !important;
95
+ padding: 4px 16px !important;
96
+ height: 32px !important;
97
+ min-width: 80px !important;
98
+ }
99
+ .narrow-button {
100
+ min-width: 40px !important;
101
+ width: 40px !important;
102
+ padding: 0 !important;
103
+ margin: 0 !important;
104
+ }
105
+ .gr-button-primary {
106
+ color: white;
107
+ }
108
+
109
+ /* Layout adjustments */
110
+ body, .gradio-container {
111
+ padding-top: 40px !important;
112
+ }
113
+ """
114
+
115
+ # Get the theme from settings
116
+ current_theme = settings.get("gradio_theme", "default") # Use default if not found
117
+ block = gr.Blocks(css=css, title="FramePack Studio", theme=current_theme).queue()
118
+
119
+ with block:
120
+
121
+ with gr.Row(elem_id="fixed-toolbar"):
122
+ gr.Markdown("<h1 style='margin:0;color:white;'>FramePack Studio</h1>")
123
+ # with gr.Column(scale=1):
124
+ # queue_stats_display = gr.Markdown("<p style='margin:0;color:white;'>Queue: 0 | Completed: 0</p>")
125
+ with gr.Column(scale=0):
126
+ refresh_stats_btn = gr.Button("⟳", elem_id="refresh-stats-btn")
127
+
128
+
129
+
130
+
131
+ with gr.Tabs():
132
+ with gr.Tab("Generate", id="generate_tab"):
133
+ with gr.Row():
134
+ with gr.Column(scale=2):
135
+ model_type = gr.Radio(
136
+ choices=["Original", "F1"],
137
+ value="Original",
138
+ label="Model",
139
+ info="Select which model to use for generation"
140
+ )
141
+ input_image = gr.Image(
142
+ sources='upload',
143
+ type="numpy",
144
+ label="Image (optional)",
145
+ height=420,
146
+ elem_classes="contain-image",
147
+ image_mode="RGB",
148
+ show_download_button=False,
149
+ show_label=True,
150
+ container=True
151
+ )
152
+
153
+ with gr.Accordion("Latent Image Options", open=False):
154
+ latent_type = gr.Dropdown(
155
+ ["Black", "White", "Noise", "Green Screen"], label="Latent Image", value="Black", info="Used as a starting point if no image is provided"
156
+ )
157
+
158
+ prompt = gr.Textbox(label="Prompt", value=default_prompt)
159
+
160
+ with gr.Accordion("Prompt Parameters", open=False):
161
+ n_prompt = gr.Textbox(label="Negative Prompt", value="", visible=True) # Make visible for both models
162
+
163
+ blend_sections = gr.Slider(
164
+ minimum=0, maximum=10, value=4, step=1,
165
+ label="Number of sections to blend between prompts"
166
+ )
167
+ with gr.Accordion("Generation Parameters", open=True):
168
+ with gr.Row():
169
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1)
170
+ total_second_length = gr.Slider(label="Video Length (Seconds)", minimum=1, maximum=120, value=6, step=0.1)
171
+ with gr.Group():
172
+ with gr.Row("Resolution"):
173
+ resolutionW = gr.Slider(
174
+ label="Width", minimum=128, maximum=768, value=640, step=32,
175
+ info="Nearest valid width will be used."
176
+ )
177
+ resolutionH = gr.Slider(
178
+ label="Height", minimum=128, maximum=768, value=640, step=32,
179
+ info="Nearest valid height will be used."
180
+ )
181
+ resolution_text = gr.Markdown(value="<div style='text-align:right; padding:5px 15px 5px 5px;'>Selected bucket for resolution: 640 x 640</div>", label="", show_label=False)
182
+ def on_input_image_change(img):
183
+ if img is not None:
184
+ return gr.update(info="Nearest valid bucket size will be used. Height will be adjusted automatically."), gr.update(visible=False)
185
+ else:
186
+ return gr.update(info="Nearest valid width will be used."), gr.update(visible=True)
187
+ input_image.change(fn=on_input_image_change, inputs=[input_image], outputs=[resolutionW, resolutionH])
188
+ def on_resolution_change(img, resolutionW, resolutionH):
189
+ out_bucket_resH, out_bucket_resW = [640, 640]
190
+ if img is not None:
191
+ H, W, _ = img.shape
192
+ out_bucket_resH, out_bucket_resW = find_nearest_bucket(H, W, resolution=resolutionW)
193
+ else:
194
+                             out_bucket_resH, out_bucket_resW = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)
195
+ return gr.update(value=f"<div style='text-align:right; padding:5px 15px 5px 5px;'>Selected bucket for resolution: {out_bucket_resW} x {out_bucket_resH}</div>")
196
+ resolutionW.change(fn=on_resolution_change, inputs=[input_image, resolutionW, resolutionH], outputs=[resolution_text], show_progress="hidden")
197
+ resolutionH.change(fn=on_resolution_change, inputs=[input_image, resolutionW, resolutionH], outputs=[resolution_text], show_progress="hidden")
198
+ with gr.Row("LoRAs"):
199
+ lora_selector = gr.Dropdown(
200
+ choices=lora_names,
201
+ label="Select LoRAs to Load",
202
+ multiselect=True,
203
+ value=[],
204
+ info="Select one or more LoRAs to use for this job"
205
+ )
206
+ lora_names_states = gr.State(lora_names)
207
+ lora_sliders = {}
208
+ for lora in lora_names:
209
+ lora_sliders[lora] = gr.Slider(
210
+ minimum=0.0, maximum=2.0, value=1.0, step=0.01,
211
+ label=f"{lora} Weight", visible=False, interactive=True
212
+ )
213
+
214
+ with gr.Row("Metadata"):
215
+ json_upload = gr.File(
216
+ label="Upload Metadata JSON (optional)",
217
+ file_types=[".json"],
218
+ type="filepath",
219
+ height=100,
220
+ )
221
+ with gr.Row("TeaCache"):
222
+ use_teacache = gr.Checkbox(label='Use TeaCache', value=True, info='Faster speed, but often makes hands and fingers slightly worse.')
223
+
224
+ with gr.Row():
225
+ seed = gr.Number(label="Seed", value=31337, precision=0)
226
+ randomize_seed = gr.Checkbox(label="Randomize", value=False, info="Generate a new random seed for each job")
227
+
228
+ with gr.Accordion("Advanced Parameters", open=False):
229
+ latent_window_size = gr.Slider(label="Latent Window Size", minimum=1, maximum=33, value=9, step=1, visible=True, info='Change at your own risk, very experimental') # Should not change
230
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=1.0, step=0.01, visible=False) # Should not change
231
+ gs = gr.Slider(label="Distilled CFG Scale", minimum=1.0, maximum=32.0, value=10.0, step=0.01)
232
+ rs = gr.Slider(label="CFG Re-Scale", minimum=0.0, maximum=1.0, value=0.0, step=0.01, visible=False) # Should not change
233
+
234
+ with gr.Column():
235
+ preview_image = gr.Image(
236
+ label="Next Latents",
237
+ height=150,
238
+ visible=True,
239
+ type="numpy",
240
+ interactive=False,
241
+ elem_classes="contain-image",
242
+ image_mode="RGB"
243
+ )
244
+ result_video = gr.Video(label="Finished Frames", autoplay=True, show_share_button=False, height=256, loop=True)
245
+ progress_desc = gr.Markdown('', elem_classes='no-generating-animation')
246
+ progress_bar = gr.HTML('', elem_classes='no-generating-animation')
247
+
248
+ with gr.Row():
249
+ current_job_id = gr.Textbox(label="Current Job ID", visible=True, interactive=True)
250
+ end_button = gr.Button(value="Cancel Current Job", interactive=True)
251
+ start_button = gr.Button(value="Add to Queue", elem_id="toolbar-add-to-queue-btn")
252
+
253
+ with gr.Tab("Queue"):
254
+ with gr.Row():
255
+ with gr.Column():
256
+ # Create a container for the queue status
257
+ with gr.Row():
258
+ queue_status = gr.DataFrame(
259
+ headers=["Job ID", "Type", "Status", "Created", "Started", "Completed", "Elapsed"], # Removed Preview header
260
+ datatype=["str", "str", "str", "str", "str", "str", "str"], # Removed image datatype
261
+ label="Job Queue"
262
+ )
263
+ with gr.Row():
264
+ refresh_button = gr.Button("Refresh Queue")
265
+ # Connect the refresh button (Moved inside 'with block')
266
+ refresh_button.click(
267
+ fn=update_queue_status_fn, # Use the function passed in
268
+ inputs=[],
269
+ outputs=[queue_status]
270
+ )
271
+ # Create a container for thumbnails (kept for potential future use, though not displayed in DataFrame)
272
+ with gr.Row():
273
+ thumbnail_container = gr.Column()
274
+ thumbnail_container.elem_classes = ["thumbnail-container"]
275
+
276
+ # Add CSS for thumbnails
277
+ with gr.TabItem("Outputs"):
278
+ outputDirectory_video = settings.get("output_dir", settings.default_settings['output_dir'])
279
+ outputDirectory_metadata = settings.get("metadata_dir", settings.default_settings['metadata_dir'])
280
+ def get_gallery_items():
281
+ items = []
282
+ for f in os.listdir(outputDirectory_metadata):
283
+ if f.endswith(".png"):
284
+ prefix = os.path.splitext(f)[0]
285
+ latest_video = get_latest_video_version(prefix)
286
+ if latest_video:
287
+ video_path = os.path.join(outputDirectory_video, latest_video)
288
+ mtime = os.path.getmtime(video_path)
289
+ preview_path = os.path.join(outputDirectory_metadata, f)
290
+ items.append((preview_path, prefix, mtime))
291
+ items.sort(key=lambda x: x[2], reverse=True)
292
+ return [(i[0], i[1]) for i in items]
293
+ def get_latest_video_version(prefix):
294
+ max_number = -1
295
+ selected_file = None
296
+ for f in os.listdir(outputDirectory_video):
297
+ if f.startswith(prefix + "_") and f.endswith(".mp4"):
298
+ num = int(f.replace(prefix + "_", '').replace(".mp4", ''))
299
+ if num > max_number:
300
+ max_number = num
301
+ selected_file = f
302
+ return selected_file
303
+ def load_video_and_info_from_prefix(prefix):
304
+ video_file = get_latest_video_version(prefix)
305
+ if not video_file:
306
+                                 return None, "Video file not found."
307
+ video_path = os.path.join(outputDirectory_video, video_file)
308
+ json_path = os.path.join(outputDirectory_metadata, prefix) + ".json"
309
+ info = {"description": "no info"}
310
+ if os.path.exists(json_path):
311
+ with open(json_path, "r", encoding="utf-8") as f:
312
+ info = json.load(f)
313
+ return video_path, json.dumps(info, indent=2, ensure_ascii=False)
314
+ gallery_items_state = gr.State(get_gallery_items())
315
+ with gr.Row():
316
+ with gr.Column(scale=2):
317
+ thumbs = gr.Gallery(
318
+ # value=[i[0] for i in get_gallery_items()],
319
+ columns=[4],
320
+ allow_preview=False,
321
+ object_fit="cover",
322
+ height="auto"
323
+ )
324
+ refresh_button = gr.Button("Update")
325
+ with gr.Column(scale=5):
326
+ video_out = gr.Video(sources=[], autoplay=True, loop=True, visible=False)
327
+ with gr.Column(scale=1):
328
+ info_out = gr.Textbox(label="Generation info", visible=False)
329
+ def refresh_gallery():
330
+ new_items = get_gallery_items()
331
+ return gr.update(value=[i[0] for i in new_items]), new_items
332
+ refresh_button.click(fn=refresh_gallery, outputs=[thumbs, gallery_items_state])
333
+ def on_select(evt: gr.SelectData, gallery_items):
334
+ prefix = gallery_items[evt.index][1]
335
+ video, info = load_video_and_info_from_prefix(prefix)
336
+ return gr.update(value=video, visible=True), gr.update(value=info, visible=True)
337
+ thumbs.select(fn=on_select, inputs=[gallery_items_state], outputs=[video_out, info_out])
338
+ with gr.Tab("Settings"):
339
+ with gr.Row():
340
+ with gr.Column():
341
+ save_metadata = gr.Checkbox(
342
+ label="Save Metadata",
343
+ info="Save to JSON file",
344
+                         value=settings.get("save_metadata", True),
345
+ )
346
+ gpu_memory_preservation = gr.Slider(
347
+ label="GPU Inference Preserved Memory (GB) (larger means slower)",
348
+ minimum=1,
349
+ maximum=128,
350
+ step=0.1,
351
+ value=settings.get("gpu_memory_preservation", 6),
352
+ info="Set this number to a larger value if you encounter OOM. Larger value causes slower speed."
353
+ )
354
+ mp4_crf = gr.Slider(
355
+ label="MP4 Compression",
356
+ minimum=0,
357
+ maximum=100,
358
+ step=1,
359
+ value=settings.get("mp4_crf", 16),
360
+ info="Lower means better quality. 0 is uncompressed. Change to 16 if you get black outputs."
361
+ )
362
+ clean_up_videos = gr.Checkbox(
363
+ label="Clean up video files",
364
+ value=settings.get("clean_up_videos", True),
365
+ info="If checked, only the final video will be kept after generation."
366
+ )
367
+ output_dir = gr.Textbox(
368
+ label="Output Directory",
369
+ value=settings.get("output_dir"),
370
+ placeholder="Path to save generated videos"
371
+ )
372
+ metadata_dir = gr.Textbox(
373
+ label="Metadata Directory",
374
+ value=settings.get("metadata_dir"),
375
+ placeholder="Path to save metadata files"
376
+ )
377
+ lora_dir = gr.Textbox(
378
+ label="LoRA Directory",
379
+ value=settings.get("lora_dir"),
380
+ placeholder="Path to LoRA models"
381
+ )
382
+ gradio_temp_dir = gr.Textbox(label="Gradio Temporary Directory", value=settings.get("gradio_temp_dir"))
383
+ auto_save = gr.Checkbox(
384
+ label="Auto-save settings",
385
+ value=settings.get("auto_save_settings", True)
386
+ )
387
+ # Add Gradio Theme Dropdown
388
+ gradio_themes = ["default", "base", "soft", "glass", "mono", "huggingface"]
389
+ theme_dropdown = gr.Dropdown(
390
+ label="Theme",
391
+ choices=gradio_themes,
392
+ value=settings.get("gradio_theme", "soft"),
393
+ info="Select the Gradio UI theme. Requires restart."
394
+ )
395
+ save_btn = gr.Button("Save Settings")
396
+ cleanup_btn = gr.Button("Clean Up Temporary Files")
397
+ status = gr.HTML("")
398
+ cleanup_output = gr.Textbox(label="Cleanup Status", interactive=False)
399
+
400
+ def save_settings(save_metadata, gpu_memory_preservation, mp4_crf, clean_up_videos, output_dir, metadata_dir, lora_dir, gradio_temp_dir, auto_save, selected_theme):
401
+ try:
402
+ settings.save_settings(
403
+ save_metadata=save_metadata,
404
+ gpu_memory_preservation=gpu_memory_preservation,
405
+ mp4_crf=mp4_crf,
406
+ clean_up_videos=clean_up_videos,
407
+ output_dir=output_dir,
408
+ metadata_dir=metadata_dir,
409
+ lora_dir=lora_dir,
410
+ gradio_temp_dir=gradio_temp_dir,
411
+ auto_save_settings=auto_save,
412
+ gradio_theme=selected_theme
413
+ )
414
+ return "<p style='color:green;'>Settings saved successfully! Restart required for theme change.</p>"
415
+ except Exception as e:
416
+ return f"<p style='color:red;'>Error saving settings: {str(e)}</p>"
417
+
418
+ save_btn.click(
419
+ fn=save_settings,
420
+ inputs=[save_metadata, gpu_memory_preservation, mp4_crf, clean_up_videos, output_dir, metadata_dir, lora_dir, gradio_temp_dir, auto_save, theme_dropdown],
421
+ outputs=[status]
422
+ )
423
+
424
+ def cleanup_temp_files():
425
+ """Clean up temporary files and folders in the Gradio temp directory"""
426
+ temp_dir = settings.get("gradio_temp_dir")
427
+ if not temp_dir or not os.path.exists(temp_dir):
428
+ return "No temporary directory found or directory does not exist."
429
+
430
+ try:
431
+ # Get all items in the temp directory
432
+ items = os.listdir(temp_dir)
433
+ removed_count = 0
434
+ print(f"Finding items in {temp_dir}")
435
+ for item in items:
436
+ item_path = os.path.join(temp_dir, item)
437
+ try:
438
+ if os.path.isfile(item_path) or os.path.islink(item_path):
439
+ print(f"Removing {item_path}")
440
+ os.remove(item_path)
441
+ removed_count += 1
442
+ elif os.path.isdir(item_path):
443
+ print(f"Removing directory {item_path}")
444
+ shutil.rmtree(item_path)
445
+ removed_count += 1
446
+ except Exception as e:
447
+ print(f"Error removing {item_path}: {e}")
448
+
449
+ return f"Cleaned up {removed_count} temporary files/folders."
450
+ except Exception as e:
451
+ return f"Error cleaning up temporary files: {str(e)}"
452
+
453
+ # --- Event Handlers and Connections (Now correctly indented) ---
454
+
455
+ # Connect the main process function (wrapper for adding to queue)
456
+ def process_with_queue_update(model_type, *args):
457
+ # Extract all arguments (ensure order matches inputs lists)
458
+ input_image, prompt_text, n_prompt, seed_value, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, randomize_seed_checked, save_metadata_checked, blend_sections, latent_type, clean_up_videos, selected_loras, resolutionW, resolutionH, *lora_args = args
459
+
460
+ # DO NOT parse the prompt here. Parsing happens once in the worker.
461
+
462
+ # Use the current seed value as is for this job
463
+ # Call the process function with all arguments
464
+ # Pass the model_type and the ORIGINAL prompt_text string to the backend process function
465
+ result = process_fn(model_type, input_image, prompt_text, n_prompt, seed_value, total_second_length, # Pass original prompt_text string
466
+ latent_window_size, steps, cfg, gs, rs,
467
+ use_teacache, blend_sections, latent_type, clean_up_videos, selected_loras, resolutionW, resolutionH, *lora_args)
468
+
469
+ # If randomize_seed is checked, generate a new random seed for the next job
470
+ new_seed_value = None
471
+ if randomize_seed_checked:
472
+ new_seed_value = random.randint(0, 21474)
473
+ print(f"Generated new seed for next job: {new_seed_value}")
474
+
475
+ # If a job ID was created, automatically start monitoring it and update queue
476
+ if result and result[1]: # Check if job_id exists in results
477
+ job_id = result[1]
478
+ queue_status_data = update_queue_status_fn()
479
+
480
+ # Add the new seed value to the results if randomize is checked
481
+ if new_seed_value is not None:
482
+ return [result[0], job_id, result[2], result[3], result[4], result[5], result[6], queue_status_data, new_seed_value]
483
+ else:
484
+ return [result[0], job_id, result[2], result[3], result[4], result[5], result[6], queue_status_data, gr.update()]
485
+
486
+ # If no job ID was created, still return the new seed if randomize is checked
487
+ if new_seed_value is not None:
488
+ return result + [update_queue_status_fn(), new_seed_value]
489
+ else:
490
+ return result + [update_queue_status_fn(), gr.update()]
491
+
492
+ # Custom end process function that ensures the queue is updated
493
+ def end_process_with_update():
494
+ queue_status_data = end_process_fn()
495
+ # Make sure to return the queue status data
496
+ return queue_status_data
497
+
498
+ # --- Inputs Lists ---
499
+ # --- Inputs for Original Model ---
500
+ ips = [
501
+ input_image,
502
+ prompt,
503
+ n_prompt,
504
+ seed,
505
+ total_second_length,
506
+ latent_window_size,
507
+ steps,
508
+ cfg,
509
+ gs,
510
+ rs,
511
+ gpu_memory_preservation,
512
+ use_teacache,
513
+ mp4_crf,
514
+ randomize_seed,
515
+ save_metadata,
516
+ blend_sections,
517
+ latent_type,
518
+ clean_up_videos,
519
+ lora_selector,
520
+ resolutionW,
521
+ resolutionH,
522
+ lora_names_states
523
+ ]
524
+ # Add LoRA sliders to the input list
525
+ ips.extend([lora_sliders[lora] for lora in lora_names])
526
+
527
+
528
+ # --- Connect Buttons ---
529
+ start_button.click(
530
+ # Pass the selected model type from the radio buttons
531
+ fn=lambda selected_model, *args: process_with_queue_update(selected_model, *args),
532
+ inputs=[model_type] + ips,
533
+ outputs=[result_video, current_job_id, preview_image, progress_desc, progress_bar, start_button, end_button, queue_status, seed]
534
+ )
535
+
536
+ # Connect the end button to cancel the current job and update the queue
537
+ end_button.click(
538
+ fn=end_process_with_update,
539
+ outputs=[queue_status]
540
+ )
541
+
542
+ # --- Connect Monitoring ---
543
+ # Auto-monitor the current job when job_id changes
544
+ # Monitor original tab
545
+ current_job_id.change(
546
+ fn=monitor_fn,
547
+ inputs=[current_job_id],
548
+ outputs=[result_video, current_job_id, preview_image, progress_desc, progress_bar, start_button, end_button]
549
+ )
550
+
551
+ cleanup_btn.click(
552
+ fn=cleanup_temp_files,
553
+ outputs=[cleanup_output]
554
+ )
555
+
556
+
557
+ # --- Connect Queue Refresh ---
558
+ refresh_stats_btn.click(
559
+ fn=lambda: update_queue_status_fn(), # Use update_queue_status_fn passed in
560
+ inputs=None,
561
+ outputs=[queue_status] # Removed queue_stats_display from outputs
562
+ )
563
+
564
+ # Set up auto-refresh for queue status (using a timer)
565
+ refresh_timer = gr.Number(value=0, visible=False)
566
+ def refresh_timer_fn():
567
+ """Updates the timer value periodically to trigger queue refresh"""
568
+ return int(time.time())
569
+ # This timer seems unused, maybe intended for block.load()? Keeping definition for now.
570
+ # refresh_timer.change(
571
+ # fn=update_queue_status_fn, # Use the function passed in
572
+ # outputs=[queue_status] # Update shared queue status display
573
+ # )
574
+
575
+ # --- Connect LoRA UI ---
576
+ # Function to update slider visibility based on selection
577
+ def update_lora_sliders(selected_loras):
578
+ updates = []
579
+ # Need to handle potential missing keys if lora_names changes dynamically
580
+ # For now, assume lora_names passed to create_interface is static
581
+ for lora in lora_names:
582
+ updates.append(gr.update(visible=(lora in selected_loras)))
583
+ # Ensure the output list matches the number of sliders defined
584
+ num_sliders = len(lora_sliders)
585
+ return updates[:num_sliders] # Return only updates for existing sliders
586
+
587
+ # Connect the dropdown to the sliders
588
+ lora_selector.change(
589
+ fn=update_lora_sliders,
590
+ inputs=[lora_selector],
591
+ outputs=[lora_sliders[lora] for lora in lora_names] # Assumes lora_sliders keys match lora_names
592
+ )
593
+
594
+
595
+ # --- Connect Metadata Loading ---
596
+ # Function to load metadata from JSON file
597
+ def load_metadata_from_json(json_path):
598
+ if not json_path:
599
+ # Return updates for all potentially affected components
600
+ num_orig_sliders = len(lora_sliders)
601
+ return [gr.update()] * (2 + num_orig_sliders)
602
+
603
+ try:
604
+ with open(json_path, 'r') as f:
605
+ metadata = json.load(f)
606
+
607
+ prompt_val = metadata.get('prompt')
608
+ seed_val = metadata.get('seed')
609
+
610
+ # Check for LoRA values in metadata
611
+ lora_weights = metadata.get('loras', {}) # Changed key to 'loras' based on studio.py worker
612
+
613
+ print(f"Loaded metadata from JSON: {json_path}")
614
+ print(f"Prompt: {prompt_val}, Seed: {seed_val}")
615
+
616
+ # Update the UI components
617
+ updates = [
618
+ gr.update(value=prompt_val) if prompt_val else gr.update(),
619
+ gr.update(value=seed_val) if seed_val is not None else gr.update()
620
+ ]
621
+
622
+ # Update LoRA sliders if they exist in metadata
623
+ for lora in lora_names:
624
+ if lora in lora_weights:
625
+ updates.append(gr.update(value=lora_weights[lora]))
626
+ else:
627
+ updates.append(gr.update()) # No change if LoRA not in metadata
628
+
629
+ # Ensure the number of updates matches the number of outputs
630
+ num_orig_sliders = len(lora_sliders)
631
+ return updates[:2 + num_orig_sliders] # Return updates for prompt, seed, and sliders
632
+
633
+ except Exception as e:
634
+ print(f"Error loading metadata: {e}")
635
+ num_orig_sliders = len(lora_sliders)
636
+ return [gr.update()] * (2 + num_orig_sliders)
637
+
638
+
639
+ # Connect JSON metadata loader for Original tab
640
+ json_upload.change(
641
+ fn=load_metadata_from_json,
642
+ inputs=[json_upload],
643
+ outputs=[prompt, seed] + [lora_sliders[lora] for lora in lora_names]
644
+ )
645
+
646
+
647
+ # --- Helper Functions (defined within create_interface scope if needed by handlers) ---
648
+ # Function to get queue statistics
649
+ def get_queue_stats():
650
+ try:
651
+ # Get all jobs from the queue
652
+ jobs = job_queue.get_all_jobs()
653
+
654
+ # Count jobs by status
655
+ status_counts = {
656
+ "QUEUED": 0,
657
+ "RUNNING": 0,
658
+ "COMPLETED": 0,
659
+ "FAILED": 0,
660
+ "CANCELLED": 0
661
+ }
662
+
663
+ for job in jobs:
664
+ if hasattr(job, 'status'):
665
+ status = str(job.status) # Use str() for safety
666
+ if status in status_counts:
667
+ status_counts[status] += 1
668
+
669
+ # Format the display text
670
+ stats_text = f"Queue: {status_counts['QUEUED']} | Running: {status_counts['RUNNING']} | Completed: {status_counts['COMPLETED']} | Failed: {status_counts['FAILED']} | Cancelled: {status_counts['CANCELLED']}"
671
+
672
+ return f"<p style='margin:0;color:white;'>{stats_text}</p>"
673
+
674
+ except Exception as e:
675
+ print(f"Error getting queue stats: {e}")
676
+ return "<p style='margin:0;color:white;'>Error loading queue stats</p>"
677
+
678
+ # Add footer with social links
679
+ with gr.Row(elem_id="footer"):
680
+ with gr.Column(scale=1):
681
+ gr.HTML("""
682
+ <div style="text-align: center; padding: 20px; color: #666;">
683
+ <div style="margin-top: 10px;">
684
+ <a href="https://patreon.com/Colinu" target="_blank" style="margin: 0 10px; color: #666; text-decoration: none;">
685
+ <i class="fab fa-patreon"></i>Support on Patreon
686
+ </a>
687
+ <a href="https://discord.gg/MtuM7gFJ3V" target="_blank" style="margin: 0 10px; color: #666; text-decoration: none;">
688
+ <i class="fab fa-discord"></i> Discord
689
+ </a>
690
+ <a href="https://github.com/colinurbs/FramePack-Studio" target="_blank" style="margin: 0 10px; color: #666; text-decoration: none;">
691
+ <i class="fab fa-github"></i> GitHub
692
+ </a>
693
+ </div>
694
+ </div>
695
+ """)
696
+
697
+ # Add CSS for footer
698
+
699
+ return block
700
+
701
+
702
+ # --- Top-level Helper Functions (Used by Gradio callbacks, must be defined outside create_interface) ---
703
+
704
+ def format_queue_status(jobs):
705
+ """Format job data for display in the queue status table"""
706
+ rows = []
707
+ for job in jobs:
708
+ created = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(job.created_at)) if job.created_at else ""
709
+ started = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(job.started_at)) if job.started_at else ""
710
+ completed = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(job.completed_at)) if job.completed_at else ""
711
+
712
+ # Calculate elapsed time
713
+ elapsed_time = ""
714
+ if job.started_at:
715
+ if job.completed_at:
716
+ start_datetime = datetime.datetime.fromtimestamp(job.started_at)
717
+ complete_datetime = datetime.datetime.fromtimestamp(job.completed_at)
718
+ elapsed_seconds = (complete_datetime - start_datetime).total_seconds()
719
+ elapsed_time = f"{elapsed_seconds:.2f}s"
720
+ else:
721
+ # For running jobs, calculate elapsed time from now
722
+ start_datetime = datetime.datetime.fromtimestamp(job.started_at)
723
+ current_datetime = datetime.datetime.now()
724
+ elapsed_seconds = (current_datetime - start_datetime).total_seconds()
725
+ elapsed_time = f"{elapsed_seconds:.2f}s (running)"
726
+
727
+ # Get generation type from job data
728
+ generation_type = getattr(job, 'generation_type', 'Original')
729
+
730
+ # Removed thumbnail processing
731
+
732
+ rows.append([
733
+ job.id[:6] + '...',
734
+ generation_type,
735
+ job.status.value,
736
+ created,
737
+ started,
738
+ completed,
739
+ elapsed_time
740
+ # Removed thumbnail from row data
741
+ ])
742
+ return rows
743
+
744
+ # Create the queue status update function (wrapper around format_queue_status)
745
+ def update_queue_status_with_thumbnails(): # Function name is now slightly misleading, but keep for now to avoid breaking clicks
746
+ # This function is likely called by the refresh button and potentially the timer
747
+ # It needs access to the job_queue object
748
+ # Assuming job_queue is accessible globally or passed appropriately
749
+ # For now, let's assume it's globally accessible as defined in studio.py
750
+ # If not, this needs adjustment based on how job_queue is managed.
751
+ try:
752
+ # Need access to the global job_queue instance from studio.py
753
+ # This might require restructuring or passing job_queue differently.
754
+ # For now, assuming it's accessible (this might fail if run standalone)
755
+ from __main__ import job_queue # Attempt to import from main script scope
756
+
757
+ jobs = job_queue.get_all_jobs()
758
+ for job in jobs:
759
+ if job.status == JobStatus.PENDING:
760
+ job.queue_position = job_queue.get_queue_position(job.id)
761
+
762
+ if job_queue.current_job:
763
+ job_queue.current_job.status = JobStatus.RUNNING
764
+
765
+ return format_queue_status(jobs)
766
+ except ImportError:
767
+ print("Error: Could not import job_queue. Queue status update might fail.")
768
+ return [] # Return empty list on error
769
+ except Exception as e:
770
+ print(f"Error updating queue status: {e}")
771
+ return []
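A hypothetical metadata file that load_metadata_from_json above could read back; the prompt/seed/loras keys are inferred from that function, the exact schema written by the worker is not part of this diff, so treat the keys and the LoRA name as assumptions:

    import json

    metadata = {
        "prompt": "[0s: The person waves hello] [2s: The person jumps up and down]",
        "seed": 31337,
        "loras": {"my_style_lora": 0.8},  # hypothetical LoRA name -> weight
    }

    with open("example_metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)
    # Uploading example_metadata.json through the "Upload Metadata JSON (optional)"
    # field would restore the prompt, the seed, and any matching LoRA slider value.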
modules/prompt_handler.py ADDED
@@ -0,0 +1,164 @@
1
+ import re
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional
4
+
5
+
6
+ @dataclass
7
+ class PromptSection:
8
+ """Represents a section of the prompt with specific timing information"""
9
+ prompt: str
10
+ start_time: float = 0 # in seconds
11
+ end_time: Optional[float] = None # in seconds, None means until the end
12
+
13
+
14
+ def snap_to_section_boundaries(prompt_sections: List[PromptSection], latent_window_size: int, fps: int = 30) -> List[PromptSection]:
15
+ """
16
+ Adjust timestamps to align with model's internal section boundaries
17
+
18
+ Args:
19
+ prompt_sections: List of PromptSection objects
20
+ latent_window_size: Size of the latent window used in the model
21
+ fps: Frames per second (default: 30)
22
+
23
+ Returns:
24
+ List of PromptSection objects with aligned timestamps
25
+ """
26
+ section_duration = (latent_window_size * 4 - 3) / fps # Duration of one section in seconds
27
+
28
+ aligned_sections = []
29
+ for section in prompt_sections:
30
+ # Snap start time to nearest section boundary
31
+ aligned_start = round(section.start_time / section_duration) * section_duration
32
+
33
+ # Snap end time to nearest section boundary
34
+ aligned_end = None
35
+ if section.end_time is not None:
36
+ aligned_end = round(section.end_time / section_duration) * section_duration
37
+
38
+ # Ensure minimum section length
39
+ if aligned_end is not None and aligned_end <= aligned_start:
40
+ aligned_end = aligned_start + section_duration
41
+
42
+ aligned_sections.append(PromptSection(
43
+ prompt=section.prompt,
44
+ start_time=aligned_start,
45
+ end_time=aligned_end
46
+ ))
47
+
48
+ return aligned_sections
49
+
50
+
51
+ def parse_timestamped_prompt(prompt_text: str, total_duration: float, latent_window_size: int = 9, generation_type: str = "Original") -> List[PromptSection]:
52
+ """
53
+ Parse a prompt with timestamps in the format [0s-2s: text] or [3s: text]
54
+
55
+ Args:
56
+ prompt_text: The input prompt text with optional timestamp sections
57
+ total_duration: Total duration of the video in seconds
58
+ latent_window_size: Size of the latent window used in the model
59
+ generation_type: Type of generation ("Original" or "F1")
60
+
61
+ Returns:
62
+ List of PromptSection objects with timestamps aligned to section boundaries
63
+ and reversed to account for reverse generation (only for Original type)
64
+ """
65
+ # Default prompt for the entire duration if no timestamps are found
66
+ if "[" not in prompt_text or "]" not in prompt_text:
67
+ return [PromptSection(prompt=prompt_text.strip())]
68
+
69
+ sections = []
70
+ # Find all timestamp sections [time: text]
71
+ timestamp_pattern = r'\[(\d+(?:\.\d+)?s)(?:-(\d+(?:\.\d+)?s))?\s*:\s*(.*?)\]'
72
+ regular_text = prompt_text
73
+
74
+ for match in re.finditer(timestamp_pattern, prompt_text):
75
+ start_time_str = match.group(1)
76
+ end_time_str = match.group(2)
77
+ section_text = match.group(3).strip()
78
+
79
+ # Convert time strings to seconds
80
+ start_time = float(start_time_str.rstrip('s'))
81
+ end_time = float(end_time_str.rstrip('s')) if end_time_str else None
82
+
83
+ sections.append(PromptSection(
84
+ prompt=section_text,
85
+ start_time=start_time,
86
+ end_time=end_time
87
+ ))
88
+
89
+ # Remove the processed section from regular_text
90
+ regular_text = regular_text.replace(match.group(0), "")
91
+
92
+ # If there's any text outside of timestamp sections, use it as a default for the entire duration
93
+ regular_text = regular_text.strip()
94
+ if regular_text:
95
+ sections.append(PromptSection(
96
+ prompt=regular_text,
97
+ start_time=0,
98
+ end_time=None
99
+ ))
100
+
101
+ # Sort sections by start time
102
+ sections.sort(key=lambda x: x.start_time)
103
+
104
+ # Fill in end times if not specified
105
+ for i in range(len(sections) - 1):
106
+ if sections[i].end_time is None:
107
+ sections[i].end_time = sections[i+1].start_time
108
+
109
+ # Set the last section's end time to the total duration if not specified
110
+ if sections and sections[-1].end_time is None:
111
+ sections[-1].end_time = total_duration
112
+
113
+ # Snap timestamps to section boundaries
114
+ sections = snap_to_section_boundaries(sections, latent_window_size)
115
+
116
+ # Only reverse timestamps for Original generation type
117
+ if generation_type == "Original":
118
+ # Now reverse the timestamps to account for reverse generation
119
+ reversed_sections = []
120
+ for section in sections:
121
+ reversed_start = total_duration - section.end_time if section.end_time is not None else 0
122
+ reversed_end = total_duration - section.start_time
123
+ reversed_sections.append(PromptSection(
124
+ prompt=section.prompt,
125
+ start_time=reversed_start,
126
+ end_time=reversed_end
127
+ ))
128
+
129
+ # Sort the reversed sections by start time
130
+ reversed_sections.sort(key=lambda x: x.start_time)
131
+ return reversed_sections
132
+
133
+ return sections
134
+
135
+
136
+ def get_section_boundaries(latent_window_size: int = 9, count: int = 10) -> str:
137
+ """
138
+ Calculate and format section boundaries for UI display
139
+
140
+ Args:
141
+ latent_window_size: Size of the latent window used in the model
142
+ count: Number of boundaries to display
143
+
144
+ Returns:
145
+ Formatted string of section boundaries
146
+ """
147
+ section_duration = (latent_window_size * 4 - 3) / 30
148
+ return ", ".join([f"{i*section_duration:.1f}s" for i in range(count)])
149
+
150
+
151
+ def get_quick_prompts() -> List[List[str]]:
152
+ """
153
+ Get a list of example timestamped prompts
154
+
155
+ Returns:
156
+ List of example prompts formatted for Gradio Dataset
157
+ """
158
+ prompts = [
159
+ '[0s: The person waves hello] [2s: The person jumps up and down] [4s: The person does a spin]',
160
+ '[0s: The person raises both arms slowly] [2s: The person claps hands enthusiastically]',
161
+ '[0s: Person gives thumbs up] [1.1s: Person smiles and winks] [2.2s: Person shows two thumbs down]',
162
+ '[0s: Person looks surprised] [1.1s: Person raises arms above head] [2.2s-3.3s: Person puts hands on hips]'
163
+ ]
164
+ return [[x] for x in prompts]
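A short example of the parsing and boundary snapping above, assuming the repository root is on sys.path so modules.prompt_handler imports cleanly; with latent_window_size=9 one section covers (9*4-3)/30 = 1.1 seconds, so whole-second markers snap to multiples of 1.1s:

    from modules.prompt_handler import parse_timestamped_prompt, get_section_boundaries

    sections = parse_timestamped_prompt(
        "[1s: The person waves hello] [3s: The person jumps up and down]",
        total_duration=6.0,
        latent_window_size=9,
        generation_type="F1",   # F1 keeps timestamps in forward order
    )
    for s in sections:
        print(f"{s.start_time:.1f}s - {s.end_time:.1f}s: {s.prompt}")
    # 1.1s - 3.3s: The person waves hello
    # 3.3s - 5.5s: The person jumps up and down

    print(get_section_boundaries(latent_window_size=9, count=6))
    # 0.0s, 1.1s, 2.2s, 3.3s, 4.4s, 5.5s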
modules/settings.py ADDED
@@ -0,0 +1,76 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Dict, Any, Optional
4
+ import os
5
+
6
+ class Settings:
7
+ def __init__(self):
8
+ # Get the project root directory (the parent of the modules/ folder containing settings.py)
9
+ project_root = Path(__file__).parent.parent
10
+
11
+ self.settings_file = project_root / ".framepack" / "settings.json"
12
+ self.settings_file.parent.mkdir(parents=True, exist_ok=True)
13
+
14
+ # Set default paths relative to project root
15
+ self.default_settings = {
16
+ "save_metadata": True,
17
+ "gpu_memory_preservation": 6,
18
+ "output_dir": str(project_root / "outputs"),
19
+ "metadata_dir": str(project_root / "outputs"),
20
+ "lora_dir": str(project_root / "loras"),
21
+ "gradio_temp_dir": str(project_root / "temp"),
22
+ "auto_save_settings": True,
23
+ "gradio_theme": "base",
24
+ "mp4_crf": 16,
25
+ "clean_up_videos": True
26
+ }
27
+ self.settings = self.load_settings()
28
+
29
+ def load_settings(self) -> Dict[str, Any]:
30
+ """Load settings from file or return defaults"""
31
+ if self.settings_file.exists():
32
+ try:
33
+ with open(self.settings_file, 'r') as f:
34
+ loaded_settings = json.load(f)
35
+ # Merge with defaults to ensure all settings exist
36
+ settings = self.default_settings.copy()
37
+ settings.update(loaded_settings)
38
+ return settings
39
+ except Exception as e:
40
+ print(f"Error loading settings: {e}")
41
+ return self.default_settings.copy()
42
+ return self.default_settings.copy()
43
+
44
+ def save_settings(self, **kwargs):
45
+ """Save settings to file. Accepts keyword arguments for any settings to update."""
46
+ # Update self.settings with any provided keyword arguments
47
+ self.settings.update(kwargs)
48
+ # Ensure all default fields are present
49
+ for k, v in self.default_settings.items():
50
+ self.settings.setdefault(k, v)
51
+
52
+ # Ensure directories exist for relevant fields
53
+ for dir_key in ["output_dir", "metadata_dir", "lora_dir", "gradio_temp_dir"]:
54
+ dir_path = self.settings.get(dir_key)
55
+ if dir_path:
56
+ os.makedirs(dir_path, exist_ok=True)
57
+
58
+ # Save to file
59
+ with open(self.settings_file, 'w') as f:
60
+ json.dump(self.settings, f, indent=4)
61
+
62
+ def get(self, key: str, default: Any = None) -> Any:
63
+ """Get a setting value"""
64
+ return self.settings.get(key, default)
65
+
66
+ def set(self, key: str, value: Any) -> None:
67
+ """Set a setting value"""
68
+ self.settings[key] = value
69
+ if self.settings.get("auto_save_settings", True):
70
+ self.save_settings()
71
+
72
+ def update(self, settings: Dict[str, Any]) -> None:
73
+ """Update multiple settings at once"""
74
+ self.settings.update(settings)
75
+ if self.settings.get("auto_save_settings", True):
76
+ self.save_settings()
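Note (illustrative, not part of the upload): typical usage of the Settings class defined above, assuming the project root is on the Python path so modules.settings is importable.

    from modules.settings import Settings

    settings = Settings()                        # loads .framepack/settings.json, falling back to defaults
    print(settings.get("output_dir"))            # e.g. <project_root>/outputs
    settings.set("mp4_crf", 18)                  # persisted immediately while auto_save_settings is True
    settings.update({"clean_up_videos": False})  # bulk update, also auto-saved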
modules/video_queue.py ADDED
@@ -0,0 +1,341 @@
1
+ import threading
2
+ import time
3
+ import uuid
4
+ from dataclasses import dataclass, field
5
+ from enum import Enum
6
+ from typing import Dict, Any, Optional, List
7
+ import queue as queue_module # Renamed to avoid conflicts
8
+ import io
9
+ import base64
10
+ from PIL import Image
11
+ import numpy as np
12
+
13
+ from diffusers_helper.thread_utils import AsyncStream
14
+
15
+
16
+ # Simple LIFO queue implementation to avoid dependency on queue.LifoQueue
17
+ class SimpleLifoQueue:
18
+ def __init__(self):
19
+ self._queue = []
20
+ self._mutex = threading.Lock()
21
+ self._not_empty = threading.Condition(self._mutex)
22
+
23
+ def put(self, item):
24
+ with self._mutex:
25
+ self._queue.append(item)
26
+ self._not_empty.notify()
27
+
28
+ def get(self):
29
+ with self._not_empty:
30
+ while not self._queue:
31
+ self._not_empty.wait()
32
+ return self._queue.pop()
33
+
34
+ def task_done(self):
35
+ pass # For compatibility with queue.Queue
36
+
37
+
38
+ class JobStatus(Enum):
39
+ PENDING = "pending"
40
+ RUNNING = "running"
41
+ COMPLETED = "completed"
42
+ FAILED = "failed"
43
+ CANCELLED = "cancelled"
44
+
45
+
46
+ @dataclass
47
+ class Job:
48
+ id: str
49
+ params: Dict[str, Any]
50
+ status: JobStatus = JobStatus.PENDING
51
+ created_at: float = field(default_factory=time.time)
52
+ started_at: Optional[float] = None
53
+ completed_at: Optional[float] = None
54
+ error: Optional[str] = None
55
+ result: Optional[str] = None
56
+ progress_data: Optional[Dict] = None
57
+ queue_position: Optional[int] = None
58
+ stream: Optional[Any] = None
59
+ input_image: Optional[np.ndarray] = None
60
+ latent_type: Optional[str] = None
61
+ thumbnail: Optional[str] = None
62
+ generation_type: Optional[str] = None # Added generation_type
63
+
64
+ def __post_init__(self):
65
+ # Store generation type
66
+ self.generation_type = self.params.get('model_type', 'Original') # Initialize generation_type
67
+
68
+ # Store input image or latent type
69
+ if 'input_image' in self.params and self.params['input_image'] is not None:
70
+ self.input_image = self.params['input_image']
71
+ # Create thumbnail
72
+ img = Image.fromarray(self.input_image)
73
+ img.thumbnail((100, 100))
74
+ buffered = io.BytesIO()
75
+ img.save(buffered, format="PNG")
76
+ self.thumbnail = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}"
77
+ elif 'latent_type' in self.params:
78
+ self.latent_type = self.params['latent_type']
79
+ # Create a colored square based on latent type
80
+ color_map = {
81
+ "Black": (0, 0, 0),
82
+ "White": (255, 255, 255),
83
+ "Noise": (128, 128, 128),
84
+ "Green Screen": (0, 177, 64)
85
+ }
86
+ color = color_map.get(self.latent_type, (0, 0, 0))
87
+ img = Image.new('RGB', (100, 100), color)
88
+ buffered = io.BytesIO()
89
+ img.save(buffered, format="PNG")
90
+ self.thumbnail = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}"
91
+
92
+
93
+ class VideoJobQueue:
94
+ def __init__(self):
95
+ self.queue = queue_module.Queue() # Using standard Queue instead of LifoQueue
96
+ self.jobs = {}
97
+ self.current_job = None
98
+ self.lock = threading.Lock()
99
+ self.worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
100
+ self.worker_thread.start()
101
+ self.worker_function = None # Will be set from outside
102
+ self.is_processing = False # Flag to track if we're currently processing a job
103
+
104
+ def set_worker_function(self, worker_function):
105
+ """Set the worker function to use for processing jobs"""
106
+ self.worker_function = worker_function
107
+
108
+ def add_job(self, params):
109
+ """Add a job to the queue and return its ID"""
110
+ job_id = str(uuid.uuid4())
111
+ job = Job(
112
+ id=job_id,
113
+ params=params,
114
+ status=JobStatus.PENDING,
115
+ created_at=time.time(),
116
+ progress_data={},
117
+ stream=AsyncStream()
118
+ )
119
+
120
+ with self.lock:
121
+ print(f"Adding job {job_id} to queue, current job is {self.current_job.id if self.current_job else 'None'}")
122
+ self.jobs[job_id] = job
123
+ self.queue.put(job_id)
124
+
125
+ return job_id
126
+
127
+ def get_job(self, job_id):
128
+ """Get job by ID"""
129
+ with self.lock:
130
+ return self.jobs.get(job_id)
131
+
132
+ def get_all_jobs(self):
133
+ """Get all jobs"""
134
+ with self.lock:
135
+ return list(self.jobs.values())
136
+
137
+ def cancel_job(self, job_id):
138
+ """Cancel a pending job"""
139
+ with self.lock:
140
+ job = self.jobs.get(job_id)
141
+ if job and job.status == JobStatus.PENDING:
142
+ job.status = JobStatus.CANCELLED
143
+ job.completed_at = time.time() # Mark completion time
144
+ return True
145
+ elif job and job.status == JobStatus.RUNNING:
146
+ # Send cancel signal to the job's stream
147
+ job.stream.input_queue.push('end')
148
+ # Mark job as cancelled (this will be confirmed when the worker processes the end signal)
149
+ job.status = JobStatus.CANCELLED
150
+ job.completed_at = time.time() # Mark completion time
151
+ return True
152
+ return False
153
+
154
+ def get_queue_position(self, job_id):
155
+ """Get position in queue (0 = currently running)"""
156
+ with self.lock:
157
+ job = self.jobs.get(job_id)
158
+ if not job:
159
+ return None
160
+
161
+ if job.status == JobStatus.RUNNING:
162
+ return 0
163
+
164
+ if job.status != JobStatus.PENDING:
165
+ return None
166
+
167
+ # Count pending jobs ahead in queue
168
+ position = 1 # Start at 1 because 0 means running
169
+ for j in self.jobs.values():
170
+ if (j.status == JobStatus.PENDING and
171
+ j.created_at < job.created_at):
172
+ position += 1
173
+ return position
174
+
175
+ def update_job_progress(self, job_id, progress_data):
176
+ """Update job progress data"""
177
+ with self.lock:
178
+ job = self.jobs.get(job_id)
179
+ if job:
180
+ job.progress_data = progress_data
181
+
182
+ def _worker_loop(self):
183
+ """Worker thread that processes jobs from the queue"""
184
+ while True:
185
+ try:
186
+ # Get the next job ID from the queue
187
+ try:
188
+ job_id = self.queue.get(block=True, timeout=1.0) # Use timeout to allow periodic checks
189
+ except queue_module.Empty:
190
+ # No jobs in queue, just continue the loop
191
+ continue
192
+
193
+ with self.lock:
194
+ job = self.jobs.get(job_id)
195
+ if not job:
196
+ self.queue.task_done()
197
+ continue
198
+
199
+ # Skip cancelled jobs
200
+ if job.status == JobStatus.CANCELLED:
201
+ self.queue.task_done()
202
+ continue
203
+
204
+ # If we're already processing a job, wait for it to complete
205
+ if self.is_processing:
206
+ # Put the job back in the queue
207
+ self.queue.put(job_id)
208
+ self.queue.task_done()
209
+ time.sleep(0.1) # Small delay to prevent busy waiting
210
+ continue
211
+
212
+ print(f"Starting job {job_id}, current job was {self.current_job.id if self.current_job else 'None'}")
213
+ job.status = JobStatus.RUNNING
214
+ job.started_at = time.time()
215
+ self.current_job = job
216
+ self.is_processing = True
217
+
218
+ job_completed = False
219
+
220
+ try:
221
+ if self.worker_function is None:
222
+ raise ValueError("Worker function not set. Call set_worker_function() first.")
223
+
224
+ # Start the worker function with the job parameters
225
+ from diffusers_helper.thread_utils import async_run
226
+ async_run(
227
+ self.worker_function,
228
+ **job.params,
229
+ job_stream=job.stream
230
+ )
231
+
232
+ # Process the results from the stream
233
+ output_filename = None
234
+
235
+ # Set a maximum time to wait for the job to complete
236
+ max_wait_time = 3600 # 1 hour in seconds
237
+ start_time = time.time()
238
+ last_activity_time = time.time()
239
+
240
+ while True:
241
+ # Check if job has been cancelled before processing next output
242
+ with self.lock:
243
+ if job.status == JobStatus.CANCELLED:
244
+ print(f"Job {job_id} was cancelled, breaking out of processing loop")
245
+ job_completed = True
246
+ break
247
+
248
+ # Check if we've been waiting too long without any activity
249
+ current_time = time.time()
250
+ if current_time - start_time > max_wait_time:
251
+ print(f"Job {job_id} timed out after {max_wait_time} seconds")
252
+ with self.lock:
253
+ job.status = JobStatus.FAILED
254
+ job.error = "Job timed out"
255
+ job.completed_at = time.time()
256
+ job_completed = True
257
+ break
258
+
259
+ # Check for inactivity (no output for a while)
260
+ if current_time - last_activity_time > 60: # 1 minute of inactivity
261
+ print(f"Checking if job {job_id} is still active...")
262
+ # Just a periodic check, don't break yet
263
+
264
+ try:
265
+ # Try to get data from the queue with a non-blocking approach
266
+ flag, data = job.stream.output_queue.next()
267
+
268
+ # Update activity time since we got some data
269
+ last_activity_time = time.time()
270
+
271
+ if flag == 'file':
272
+ output_filename = data
273
+ with self.lock:
274
+ job.result = output_filename
275
+
276
+ elif flag == 'progress':
277
+ preview, desc, html = data
278
+ with self.lock:
279
+ job.progress_data = {
280
+ 'preview': preview,
281
+ 'desc': desc,
282
+ 'html': html
283
+ }
284
+
285
+ elif flag == 'end':
286
+ print(f"Received end signal for job {job_id}")
287
+ job_completed = True
288
+ break
289
+
290
+ except IndexError:
291
+ # Queue is empty, wait a bit and try again
292
+ time.sleep(0.1)
293
+ continue
294
+ except Exception as e:
295
+ print(f"Error processing job output: {e}")
296
+ # Wait a bit before trying again
297
+ time.sleep(0.1)
298
+ continue
299
+ except Exception as e:
300
+ import traceback
301
+ traceback.print_exc()
302
+ print(f"Error processing job {job_id}: {e}")
303
+ with self.lock:
304
+ job.status = JobStatus.FAILED
305
+ job.error = str(e)
306
+ job.completed_at = time.time()
307
+ job_completed = True
308
+
309
+ finally:
310
+ with self.lock:
311
+ # Make sure we properly clean up the job state
312
+ if job.status == JobStatus.RUNNING:
313
+ if job_completed:
314
+ job.status = JobStatus.COMPLETED
315
+ else:
316
+ # Something went wrong but we didn't mark it as completed
317
+ job.status = JobStatus.FAILED
318
+ job.error = "Job processing was interrupted"
319
+
320
+ job.completed_at = time.time()
321
+
322
+ print(f"Finishing job {job_id} with status {job.status}")
323
+ self.is_processing = False
324
+ self.current_job = None
325
+ self.queue.task_done()
326
+
327
+ except Exception as e:
328
+ import traceback
329
+ traceback.print_exc()
330
+ print(f"Error in worker loop: {e}")
331
+
332
+ # Make sure we reset processing state if there was an error
333
+ with self.lock:
334
+ self.is_processing = False
335
+ if self.current_job:
336
+ self.current_job.status = JobStatus.FAILED
337
+ self.current_job.error = f"Worker loop error: {str(e)}"
338
+ self.current_job.completed_at = time.time()
339
+ self.current_job = None
340
+
341
+ time.sleep(0.5) # Prevent tight loop on error
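Note (illustrative, not part of the upload): minimal wiring of the queue defined above with a stand-in worker. In the application the real worker is studio.worker; the dummy below only signals completion through the job's AsyncStream.

    from modules.video_queue import VideoJobQueue

    def dummy_worker(job_stream=None, **params):
        # A real worker pushes ('file', path) and ('progress', ...) tuples before this.
        job_stream.output_queue.push(('end', None))

    queue = VideoJobQueue()
    queue.set_worker_function(dummy_worker)
    job_id = queue.add_job({"model_type": "Original", "latent_type": "Black"})
    print(queue.get_job(job_id).status)          # PENDING, then COMPLETED once the worker loop picks it up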
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ accelerate==1.6.0
2
+ diffusers==0.33.1
3
+ transformers==4.46.2
4
+ gradio==5.25.2
5
+ sentencepiece==0.2.0
6
+ pillow==11.1.0
7
+ av==12.1.0
8
+ numpy==1.26.2
9
+ scipy==1.12.0
10
+ requests==2.31.0
11
+ torchsde==0.2.6
12
+ jinja2>=3.1.2
13
+ torchvision
14
+ einops
15
+ opencv-contrib-python
16
+ safetensors
17
+ peft
studio.py ADDED
@@ -0,0 +1,1012 @@
1
+ from diffusers_helper.hf_login import login
2
+
3
+ import json
4
+ import os
5
+ from pathlib import PurePath
6
+ import time
7
+ import argparse
8
+ import traceback
9
+ import einops
10
+ import numpy as np
11
+ import torch
12
+ import datetime
13
+
14
+ os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download')))
15
+
16
+ import gradio as gr
17
+ from PIL import Image
18
+ from PIL.PngImagePlugin import PngInfo
19
+ from diffusers import AutoencoderKLHunyuanVideo
20
+ from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer
21
+ from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode, vae_decode_fake
22
+ from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, generate_timestamp
23
+ from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
24
+ from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
25
+ from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete
26
+ from diffusers_helper.thread_utils import AsyncStream
27
+ from diffusers_helper.gradio.progress_bar import make_progress_bar_html
28
+ from transformers import SiglipImageProcessor, SiglipVisionModel
29
+ from diffusers_helper.clip_vision import hf_clip_vision_encode
30
+ from diffusers_helper.bucket_tools import find_nearest_bucket
31
+ from diffusers_helper import lora_utils
32
+ from diffusers_helper.lora_utils import load_lora, unload_all_loras
33
+
34
+ # Import model generators
35
+ from modules.generators import create_model_generator
36
+
37
+ # Global cache for prompt embeddings
38
+ prompt_embedding_cache = {}
39
+ # Import from modules
40
+ from modules.video_queue import VideoJobQueue, JobStatus
41
+ from modules.prompt_handler import parse_timestamped_prompt
42
+ from modules.interface import create_interface, format_queue_status
43
+ from modules.settings import Settings
44
+
45
+ # ADDED: Debug function to verify LoRA state
46
+ def verify_lora_state(transformer, label=""):
47
+ """Debug function to verify the state of LoRAs in a transformer model"""
48
+ if transformer is None:
49
+ print(f"[{label}] Transformer is None, cannot verify LoRA state")
50
+ return
51
+
52
+ has_loras = False
53
+ if hasattr(transformer, 'peft_config'):
54
+ adapter_names = list(transformer.peft_config.keys()) if transformer.peft_config else []
55
+ if adapter_names:
56
+ has_loras = True
57
+ print(f"[{label}] Transformer has LoRAs: {', '.join(adapter_names)}")
58
+ else:
59
+ print(f"[{label}] Transformer has no LoRAs in peft_config")
60
+ else:
61
+ print(f"[{label}] Transformer has no peft_config attribute")
62
+
63
+ # Check for any LoRA modules
64
+ for name, module in transformer.named_modules():
65
+ if hasattr(module, 'lora_A') and module.lora_A:
66
+ has_loras = True
67
+ # print(f"[{label}] Found lora_A in module {name}")
68
+ if hasattr(module, 'lora_B') and module.lora_B:
69
+ has_loras = True
70
+ # print(f"[{label}] Found lora_B in module {name}")
71
+
72
+ if not has_loras:
73
+ print(f"[{label}] No LoRA components found in transformer")
74
+
75
+
76
+ parser = argparse.ArgumentParser()
77
+ parser.add_argument('--share', action='store_true')
78
+ parser.add_argument("--server", type=str, default='0.0.0.0')
79
+ parser.add_argument("--port", type=int, required=False)
80
+ parser.add_argument("--inbrowser", action='store_true')
81
+ parser.add_argument("--lora", type=str, default=None, help="Lora path (comma separated for multiple)")
82
+ args = parser.parse_args()
83
+
84
+ print(args)
85
+
86
+ free_mem_gb = get_cuda_free_memory_gb(gpu)
87
+ high_vram = free_mem_gb > 60
88
+
89
+ print(f'Free VRAM {free_mem_gb} GB')
90
+ print(f'High-VRAM Mode: {high_vram}')
91
+
92
+ # Load models
93
+ text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu()
94
+ text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
95
+ tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer')
96
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2')
97
+ vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu()
98
+
99
+ feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor')
100
+ image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu()
101
+
102
+ # Initialize model generator placeholder
103
+ current_generator = None # Will hold the currently active model generator
104
+
105
+ # Load models based on VRAM availability later
106
+
107
+ # Configure models
108
+ vae.eval()
109
+ text_encoder.eval()
110
+ text_encoder_2.eval()
111
+ image_encoder.eval()
112
+
113
+ if not high_vram:
114
+ vae.enable_slicing()
115
+ vae.enable_tiling()
116
+
117
+
118
+ vae.to(dtype=torch.float16)
119
+ image_encoder.to(dtype=torch.float16)
120
+ text_encoder.to(dtype=torch.float16)
121
+ text_encoder_2.to(dtype=torch.float16)
122
+
123
+ vae.requires_grad_(False)
124
+ text_encoder.requires_grad_(False)
125
+ text_encoder_2.requires_grad_(False)
126
+ image_encoder.requires_grad_(False)
127
+
128
+ # Create lora directory if it doesn't exist
129
+ lora_dir = os.path.join(os.path.dirname(__file__), 'loras')
130
+ os.makedirs(lora_dir, exist_ok=True)
131
+
132
+ # Initialize LoRA support - moved scanning after settings load
133
+ lora_names = []
134
+ lora_values = [] # This seems unused for population, might be related to weights later
135
+
136
+ script_dir = os.path.dirname(os.path.abspath(__file__))
137
+
138
+ # Define default LoRA folder path relative to the script directory (used if setting is missing)
139
+ default_lora_folder = os.path.join(script_dir, "loras")
140
+ os.makedirs(default_lora_folder, exist_ok=True) # Ensure default exists
141
+
142
+ if not high_vram:
143
+ # DynamicSwapInstaller is same as huggingface's enable_sequential_offload but 3x faster
144
+ DynamicSwapInstaller.install_model(text_encoder, device=gpu)
145
+ else:
146
+ text_encoder.to(gpu)
147
+ text_encoder_2.to(gpu)
148
+ image_encoder.to(gpu)
149
+ vae.to(gpu)
150
+
151
+ stream = AsyncStream()
152
+
153
+ outputs_folder = './outputs/'
154
+ os.makedirs(outputs_folder, exist_ok=True)
155
+
156
+ # Initialize settings
157
+ settings = Settings()
158
+
159
+ # --- Populate LoRA names AFTER settings are loaded ---
160
+ lora_folder_from_settings: str = settings.get("lora_dir", default_lora_folder) # Use setting, fallback to default
161
+ print(f"Scanning for LoRAs in: {lora_folder_from_settings}")
162
+ if os.path.isdir(lora_folder_from_settings):
163
+ try:
164
+ lora_files = [f for f in os.listdir(lora_folder_from_settings)
165
+ if f.endswith('.safetensors') or f.endswith('.pt')]
166
+ for lora_file in lora_files:
167
+ lora_name = PurePath(lora_file).stem
168
+ lora_names.append(lora_name) # Get name without extension
169
+ print(f"Found LoRAs: {lora_names}")
170
+ except Exception as e:
171
+ print(f"Error scanning LoRA directory '{lora_folder_from_settings}': {e}")
172
+ else:
173
+ print(f"LoRA directory not found: {lora_folder_from_settings}")
174
+ # --- End LoRA population ---
175
+
176
+
177
+ # Create job queue
178
+ job_queue = VideoJobQueue()
179
+
180
+
181
+
182
+ # Function to load a LoRA file
183
+ def load_lora_file(lora_file: str | PurePath):
184
+ if not lora_file:
185
+ return None, "No file selected"
186
+
187
+ try:
188
+ # Get the filename from the path
189
+ lora_path = PurePath(lora_file)
190
+ lora_name = lora_path.name
191
+
192
+ # Copy the file to the lora directory
193
+ lora_dest = PurePath(lora_dir, lora_path.name)  # join with just the file name so the copy lands inside lora_dir
194
+ import shutil
195
+ shutil.copy(lora_file, lora_dest)
196
+
197
+ # Load the LoRA
198
+ global current_generator, lora_names
199
+ if current_generator is None:
200
+ return None, "Error: No model loaded to apply LoRA to. Generate something first."
201
+
202
+ # Unload any existing LoRAs first
203
+ current_generator.unload_loras()
204
+
205
+ # Load the single LoRA
206
+ selected_loras = [lora_path.stem]
207
+ current_generator.load_loras(selected_loras, lora_dir, selected_loras)
208
+
209
+ # Add to lora_names if not already there
210
+ lora_base_name = lora_path.stem
211
+ if lora_base_name not in lora_names:
212
+ lora_names.append(lora_base_name)
213
+
214
+ # Get the current device of the transformer
215
+ device = next(current_generator.transformer.parameters()).device
216
+
217
+ # Move all LoRA adapters to the same device as the base model
218
+ current_generator.move_lora_adapters_to_device(device)
219
+
220
+ print(f"Loaded LoRA: {lora_name} to {current_generator.get_model_name()} model")
221
+
222
+ return gr.update(choices=lora_names), f"Successfully loaded LoRA: {lora_name}"
223
+ except Exception as e:
224
+ print(f"Error loading LoRA: {e}")
225
+ return None, f"Error loading LoRA: {e}"
226
+
227
+ @torch.no_grad()
228
+ def get_cached_or_encode_prompt(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, target_device):
229
+ """
230
+ Retrieves prompt embeddings from cache or encodes them if not found.
231
+ Stores encoded embeddings (on CPU) in the cache.
232
+ Returns embeddings moved to the target_device.
233
+ """
234
+ if prompt in prompt_embedding_cache:
235
+ print(f"Cache hit for prompt: {prompt[:60]}...")
236
+ llama_vec_cpu, llama_mask_cpu, clip_l_pooler_cpu = prompt_embedding_cache[prompt]
237
+ # Move cached embeddings (from CPU) to the target device
238
+ llama_vec = llama_vec_cpu.to(target_device)
239
+ llama_attention_mask = llama_mask_cpu.to(target_device) if llama_mask_cpu is not None else None
240
+ clip_l_pooler = clip_l_pooler_cpu.to(target_device)
241
+ return llama_vec, llama_attention_mask, clip_l_pooler
242
+ else:
243
+ print(f"Cache miss for prompt: {prompt[:60]}...")
244
+ llama_vec, clip_l_pooler = encode_prompt_conds(
245
+ prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2
246
+ )
247
+ llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
248
+ # Store CPU copies in cache
249
+ prompt_embedding_cache[prompt] = (llama_vec.cpu(), llama_attention_mask.cpu() if llama_attention_mask is not None else None, clip_l_pooler.cpu())
250
+ # Return embeddings already on the target device (as encode_prompt_conds uses the model's device)
251
+ return llama_vec, llama_attention_mask, clip_l_pooler
252
+
253
+ @torch.no_grad()
254
+ def worker(
255
+ model_type,
256
+ input_image,
257
+ prompt_text,
258
+ n_prompt,
259
+ seed,
260
+ total_second_length,
261
+ latent_window_size,
262
+ steps,
263
+ cfg,
264
+ gs,
265
+ rs,
266
+ use_teacache,
267
+ blend_sections,
268
+ latent_type,
269
+ selected_loras,
270
+ has_input_image,
271
+ lora_values=None,
272
+ job_stream=None,
273
+ output_dir=None,
274
+ metadata_dir=None,
275
+ resolutionW=640, # Add resolution parameter with default value
276
+ resolutionH=640,
277
+ lora_loaded_names=[]
278
+ ):
279
+ global high_vram, current_generator
280
+
281
+ # Ensure any existing LoRAs are unloaded from the current generator
282
+ if current_generator is not None:
283
+ print("Unloading any existing LoRAs before starting new job")
284
+ current_generator.unload_loras()
285
+ import gc
286
+ gc.collect()
287
+ if torch.cuda.is_available():
288
+ torch.cuda.empty_cache()
289
+
290
+ stream_to_use = job_stream if job_stream is not None else stream
291
+
292
+ total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
293
+ total_latent_sections = int(max(round(total_latent_sections), 1))
294
+
295
+ # --- Total progress tracking ---
296
+ total_steps = total_latent_sections * steps # Total diffusion steps over all segments
297
+ step_durations = [] # Rolling history of recent step durations for ETA
298
+ last_step_time = time.time()
299
+
300
+ # Parse the timestamped prompt with boundary snapping and reversing
301
+ # prompt_text should now be the original string from the job queue
302
+ prompt_sections = parse_timestamped_prompt(prompt_text, total_second_length, latent_window_size, model_type)
303
+ job_id = generate_timestamp()
304
+
305
+ stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))
306
+
307
+ try:
308
+ if not high_vram:
309
+ # Unload everything *except* the potentially active transformer
310
+ unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae)
311
+ if current_generator is not None and current_generator.transformer is not None:
312
+ offload_model_from_device_for_memory_preservation(current_generator.transformer, target_device=gpu, preserved_memory_gb=8)
313
+
314
+ # --- Model Loading / Switching ---
315
+ print(f"Worker starting for model type: {model_type}")
316
+
317
+ # Create the appropriate model generator
318
+ new_generator = create_model_generator(
319
+ model_type,
320
+ text_encoder=text_encoder,
321
+ text_encoder_2=text_encoder_2,
322
+ tokenizer=tokenizer,
323
+ tokenizer_2=tokenizer_2,
324
+ vae=vae,
325
+ image_encoder=image_encoder,
326
+ feature_extractor=feature_extractor,
327
+ high_vram=high_vram,
328
+ prompt_embedding_cache=prompt_embedding_cache,
329
+ settings=settings
330
+ )
331
+
332
+ # Update the global generator
333
+ current_generator = new_generator
334
+
335
+ # Load the transformer model
336
+ current_generator.load_model()
337
+
338
+ # Ensure the model has no LoRAs loaded
339
+ print(f"Ensuring {model_type} model has no LoRAs loaded")
340
+ current_generator.unload_loras()
341
+
342
+ # Pre-encode all prompts
343
+ stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding all prompts...'))))
344
+
345
+ if not high_vram:
346
+ fake_diffusers_current_device(text_encoder, gpu)
347
+ load_model_as_complete(text_encoder_2, target_device=gpu)
348
+
349
+ # PROMPT BLENDING: Pre-encode all prompts and store in a list in order
350
+ unique_prompts = []
351
+ for section in prompt_sections:
352
+ if section.prompt not in unique_prompts:
353
+ unique_prompts.append(section.prompt)
354
+
355
+ encoded_prompts = {}
356
+ for prompt in unique_prompts:
357
+ # Use the helper function for caching and encoding
358
+ llama_vec, llama_attention_mask, clip_l_pooler = get_cached_or_encode_prompt(
359
+ prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, gpu
360
+ )
361
+ encoded_prompts[prompt] = (llama_vec, llama_attention_mask, clip_l_pooler)
362
+
363
+ # PROMPT BLENDING: Build a list of (start_section_idx, prompt) for each prompt
364
+ prompt_change_indices = []
365
+ last_prompt = None
366
+ for idx, section in enumerate(prompt_sections):
367
+ if section.prompt != last_prompt:
368
+ prompt_change_indices.append((idx, section.prompt))
369
+ last_prompt = section.prompt
370
+
371
+ # Encode negative prompt
372
+ if cfg == 1:
373
+ llama_vec_n, llama_attention_mask_n, clip_l_pooler_n = (
374
+ torch.zeros_like(encoded_prompts[prompt_sections[0].prompt][0]),
375
+ torch.zeros_like(encoded_prompts[prompt_sections[0].prompt][1]),
376
+ torch.zeros_like(encoded_prompts[prompt_sections[0].prompt][2])
377
+ )
378
+ else:
379
+ # Use the helper function for caching and encoding negative prompt
380
+ llama_vec_n, llama_attention_mask_n, clip_l_pooler_n = get_cached_or_encode_prompt(
381
+ n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, gpu
382
+ )
383
+
384
+ # Processing input image
385
+ stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Image processing ...'))))
386
+
387
+ H, W, _ = input_image.shape
388
+ height, width = find_nearest_bucket(H, W, resolution=resolutionW if has_input_image else (resolutionH+resolutionW)/2)
389
+ input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
390
+
391
+ if settings.get("save_metadata"):
392
+ metadata = PngInfo()
393
+ # prompt_text should be a string here now
394
+ metadata.add_text("prompt", prompt_text)
395
+ metadata.add_text("seed", str(seed))
396
+ Image.fromarray(input_image_np).save(os.path.join(metadata_dir, f'{job_id}.png'), pnginfo=metadata)
397
+
398
+ metadata_dict = {
399
+ "prompt": prompt_text, # Use the original string
400
+ "seed": seed,
401
+ "total_second_length": total_second_length,
402
+ "steps": steps,
403
+ "cfg": cfg,
404
+ "gs": gs,
405
+ "rs": rs,
406
+ "latent_type" : latent_type,
407
+ "blend_sections": blend_sections,
408
+ "latent_window_size": latent_window_size,
409
+ "mp4_crf": settings.get("mp4_crf"),
410
+ "timestamp": time.time(),
411
+ "resolutionW": resolutionW, # Add resolution to metadata
412
+ "resolutionH": resolutionH,
413
+ "model_type": model_type # Add model type to metadata
414
+ }
415
+ # Add LoRA information to metadata if LoRAs are used
416
+ def ensure_list(x):
417
+ if isinstance(x, list):
418
+ return x
419
+ elif x is None:
420
+ return []
421
+ else:
422
+ return [x]
423
+
424
+ selected_loras = ensure_list(selected_loras)
425
+ lora_values = ensure_list(lora_values)
426
+
427
+ if selected_loras and len(selected_loras) > 0:
428
+ lora_data = {}
429
+ for lora_name in selected_loras:
430
+ try:
431
+ idx = lora_loaded_names.index(lora_name)
432
+ weight = lora_values[idx] if lora_values and idx < len(lora_values) else 1.0
433
+ if isinstance(weight, list):
434
+ weight_value = weight[0] if weight and len(weight) > 0 else 1.0
435
+ else:
436
+ weight_value = weight
437
+ lora_data[lora_name] = float(weight_value)
438
+ except ValueError:
439
+ lora_data[lora_name] = 1.0
440
+ metadata_dict["loras"] = lora_data
441
+
442
+ with open(os.path.join(metadata_dir, f'{job_id}.json'), 'w') as f:
443
+ json.dump(metadata_dict, f, indent=2)
444
+ else:
445
+ Image.fromarray(input_image_np).save(os.path.join(metadata_dir, f'{job_id}.png'))
446
+
447
+ input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
448
+ input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
449
+
450
+ # VAE encoding
451
+ stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))
452
+
453
+ if not high_vram:
454
+ load_model_as_complete(vae, target_device=gpu)
455
+
456
+ start_latent = vae_encode(input_image_pt, vae)
457
+
458
+ # CLIP Vision
459
+ stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
460
+
461
+ if not high_vram:
462
+ load_model_as_complete(image_encoder, target_device=gpu)
463
+
464
+ image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
465
+ image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
466
+
467
+ # Dtype
468
+ for prompt_key in encoded_prompts:
469
+ llama_vec, llama_attention_mask, clip_l_pooler = encoded_prompts[prompt_key]
470
+ llama_vec = llama_vec.to(current_generator.transformer.dtype)
471
+ clip_l_pooler = clip_l_pooler.to(current_generator.transformer.dtype)
472
+ encoded_prompts[prompt_key] = (llama_vec, llama_attention_mask, clip_l_pooler)
473
+
474
+ llama_vec_n = llama_vec_n.to(current_generator.transformer.dtype)
475
+ clip_l_pooler_n = clip_l_pooler_n.to(current_generator.transformer.dtype)
476
+ image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(current_generator.transformer.dtype)
477
+
478
+ # Sampling
479
+ stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
480
+
481
+ rnd = torch.Generator("cpu").manual_seed(seed)
482
+ num_frames = latent_window_size * 4 - 3
483
+
484
+ # Initialize history latents based on model type
485
+ history_latents = current_generator.prepare_history_latents(height, width)
486
+
487
+ # For F1 model, initialize with start latent
488
+ if model_type == "F1":
489
+ history_latents = current_generator.initialize_with_start_latent(history_latents, start_latent)
490
+ total_generated_latent_frames = 1 # Start with 1 for F1 model since it includes the first frame
491
+
492
+ history_pixels = None
493
+ if model_type == "Original":
494
+ total_generated_latent_frames = 0
495
+
496
+ # Get latent paddings from the generator
497
+ latent_paddings = current_generator.get_latent_paddings(total_latent_sections)
498
+
499
+ # PROMPT BLENDING: Track section index
500
+ section_idx = 0
501
+
502
+ # Load LoRAs if selected
503
+ if selected_loras:
504
+ current_generator.load_loras(selected_loras, lora_folder_from_settings, lora_loaded_names, lora_values)
505
+
506
+ # --- Callback for progress ---
507
+ def callback(d):
508
+ nonlocal last_step_time, step_durations
509
+ now_time = time.time()
510
+ # Record duration between diffusion steps (skip first where duration may include setup)
511
+ if last_step_time is not None:
512
+ step_delta = now_time - last_step_time
513
+ if step_delta > 0:
514
+ step_durations.append(step_delta)
515
+ if len(step_durations) > 30: # Keep only recent 30 steps
516
+ step_durations.pop(0)
517
+ last_step_time = now_time
518
+ avg_step = sum(step_durations) / len(step_durations) if step_durations else 0.0
519
+
520
+ preview = d['denoised']
521
+ preview = vae_decode_fake(preview)
522
+ preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
523
+ preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
524
+
525
+ # --- Progress & ETA logic ---
526
+ # Current segment progress
527
+ current_step = d['i'] + 1
528
+ percentage = int(100.0 * current_step / steps)
529
+
530
+ # Total progress
531
+ total_steps_done = section_idx * steps + current_step
532
+ total_percentage = int(100.0 * total_steps_done / total_steps)
533
+
534
+ # ETA calculations
535
+ def fmt_eta(sec):
536
+ try:
537
+ return str(datetime.timedelta(seconds=int(sec)))
538
+ except Exception:
539
+ return "--:--"
540
+
541
+ segment_eta = (steps - current_step) * avg_step if avg_step else 0
542
+ total_eta = (total_steps - total_steps_done) * avg_step if avg_step else 0
543
+
544
+ segment_hint = f'Sampling {current_step}/{steps} ETA {fmt_eta(segment_eta)}'
545
+ total_hint = f'Total {total_steps_done}/{total_steps} ETA {fmt_eta(total_eta)}'
546
+
547
+ current_pos = (total_generated_latent_frames * 4 - 3) / 30
548
+ original_pos = total_second_length - current_pos
549
+ if current_pos < 0: current_pos = 0
550
+ if original_pos < 0: original_pos = 0
551
+
552
+ hint = segment_hint # deprecated variable kept to minimise other code changes
553
+ desc = current_generator.format_position_description(
554
+ total_generated_latent_frames,
555
+ current_pos,
556
+ original_pos,
557
+ current_prompt
558
+ )
559
+
560
+ progress_data = {
561
+ 'preview': preview,
562
+ 'desc': desc,
563
+ 'html': make_progress_bar_html(percentage, segment_hint) + make_progress_bar_html(total_percentage, total_hint)
564
+ }
565
+ if job_stream is not None:
566
+ job = job_queue.get_job(job_id)
567
+ if job:
568
+ job.progress_data = progress_data
569
+
570
+ stream_to_use.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, segment_hint) + make_progress_bar_html(total_percentage, total_hint))))
571
+
572
+ # --- Main generation loop ---
573
+ for latent_padding in latent_paddings:
574
+ is_last_section = latent_padding == 0
575
+ latent_padding_size = latent_padding * latent_window_size
576
+
577
+ if stream_to_use.input_queue.top() == 'end':
578
+ stream_to_use.output_queue.push(('end', None))
579
+ return
580
+
581
+ current_time_position = (total_generated_latent_frames * 4 - 3) / 30 # in seconds
582
+ if current_time_position < 0:
583
+ current_time_position = 0.01
584
+
585
+ # Find the appropriate prompt for this section
586
+ current_prompt = prompt_sections[0].prompt # Default to first prompt
587
+ for section in prompt_sections:
588
+ if section.start_time <= current_time_position and (section.end_time is None or current_time_position < section.end_time):
589
+ current_prompt = section.prompt
590
+ break
591
+
592
+ # PROMPT BLENDING: Find if we're in a blend window
593
+ blend_alpha = None
594
+ prev_prompt = current_prompt
595
+ next_prompt = current_prompt
596
+
597
+ # Only try to blend if we have prompt change indices and multiple sections
598
+ if prompt_change_indices and len(prompt_sections) > 1:
599
+ for i, (change_idx, prompt) in enumerate(prompt_change_indices):
600
+ if section_idx < change_idx:
601
+ prev_prompt = prompt_change_indices[i - 1][1] if i > 0 else prompt
602
+ next_prompt = prompt
603
+ blend_start = change_idx
604
+ blend_end = change_idx + blend_sections
605
+ if section_idx >= change_idx and section_idx < blend_end:
606
+ blend_alpha = (section_idx - change_idx + 1) / blend_sections
607
+ break
608
+ elif section_idx == change_idx:
609
+ # At the exact change, start blending
610
+ if i > 0:
611
+ prev_prompt = prompt_change_indices[i - 1][1]
612
+ next_prompt = prompt
613
+ blend_alpha = 1.0 / blend_sections
614
+ else:
615
+ prev_prompt = prompt
616
+ next_prompt = prompt
617
+ blend_alpha = None
618
+ break
619
+ else:
620
+ # After last change, no blending
621
+ prev_prompt = current_prompt
622
+ next_prompt = current_prompt
623
+ blend_alpha = None
624
+
625
+ # Get the encoded prompt for this section
626
+ if blend_alpha is not None and prev_prompt != next_prompt:
627
+ # Blend embeddings
628
+ prev_llama_vec, prev_llama_attention_mask, prev_clip_l_pooler = encoded_prompts[prev_prompt]
629
+ next_llama_vec, next_llama_attention_mask, next_clip_l_pooler = encoded_prompts[next_prompt]
630
+ llama_vec = (1 - blend_alpha) * prev_llama_vec + blend_alpha * next_llama_vec
631
+ llama_attention_mask = prev_llama_attention_mask # usually same
632
+ clip_l_pooler = (1 - blend_alpha) * prev_clip_l_pooler + blend_alpha * next_clip_l_pooler
633
+ print(f"Blending prompts: '{prev_prompt[:30]}...' -> '{next_prompt[:30]}...', alpha={blend_alpha:.2f}")
634
+ else:
635
+ llama_vec, llama_attention_mask, clip_l_pooler = encoded_prompts[current_prompt]
636
+
637
+ original_time_position = total_second_length - current_time_position
638
+ if original_time_position < 0:
639
+ original_time_position = 0
640
+
641
+ print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}, '
642
+ f'time position: {current_time_position:.2f}s (original: {original_time_position:.2f}s), '
643
+ f'using prompt: {current_prompt[:60]}...')
644
+
645
+ # Prepare indices using the generator
646
+ clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices = current_generator.prepare_indices(latent_padding_size, latent_window_size)
647
+
648
+ # Prepare clean latents using the generator
649
+ clean_latents, clean_latents_2x, clean_latents_4x = current_generator.prepare_clean_latents(start_latent, history_latents)
650
+
651
+ # Print debug info
652
+ print(f"{model_type} model section {section_idx+1}/{total_latent_sections}, latent_padding={latent_padding}")
653
+
654
+ if not high_vram:
655
+ # Unload VAE etc. before loading transformer
656
+ unload_complete_models(vae, text_encoder, text_encoder_2, image_encoder)
657
+ move_model_to_device_with_memory_preservation(current_generator.transformer, target_device=gpu, preserved_memory_gb=settings.get("gpu_memory_preservation"))
658
+ if selected_loras:
659
+ current_generator.move_lora_adapters_to_device(gpu)
660
+
661
+ if use_teacache:
662
+ current_generator.transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
663
+ else:
664
+ current_generator.transformer.initialize_teacache(enable_teacache=False)
665
+
666
+ generated_latents = sample_hunyuan(
667
+ transformer=current_generator.transformer,
668
+ sampler='unipc',
669
+ width=width,
670
+ height=height,
671
+ frames=num_frames,
672
+ real_guidance_scale=cfg,
673
+ distilled_guidance_scale=gs,
674
+ guidance_rescale=rs,
675
+ num_inference_steps=steps,
676
+ generator=rnd,
677
+ prompt_embeds=llama_vec,
678
+ prompt_embeds_mask=llama_attention_mask,
679
+ prompt_poolers=clip_l_pooler,
680
+ negative_prompt_embeds=llama_vec_n,
681
+ negative_prompt_embeds_mask=llama_attention_mask_n,
682
+ negative_prompt_poolers=clip_l_pooler_n,
683
+ device=gpu,
684
+ dtype=torch.bfloat16,
685
+ image_embeddings=image_encoder_last_hidden_state,
686
+ latent_indices=latent_indices,
687
+ clean_latents=clean_latents,
688
+ clean_latent_indices=clean_latent_indices,
689
+ clean_latents_2x=clean_latents_2x,
690
+ clean_latent_2x_indices=clean_latent_2x_indices,
691
+ clean_latents_4x=clean_latents_4x,
692
+ clean_latent_4x_indices=clean_latent_4x_indices,
693
+ callback=callback,
694
+ )
695
+
696
+ total_generated_latent_frames += int(generated_latents.shape[2])
697
+ # Update history latents using the generator
698
+ history_latents = current_generator.update_history_latents(history_latents, generated_latents)
699
+
700
+ if not high_vram:
701
+ if selected_loras:
702
+ current_generator.move_lora_adapters_to_device(cpu)
703
+ offload_model_from_device_for_memory_preservation(current_generator.transformer, target_device=gpu, preserved_memory_gb=8)
704
+ load_model_as_complete(vae, target_device=gpu)
705
+
706
+ # Get real history latents using the generator
707
+ real_history_latents = current_generator.get_real_history_latents(history_latents, total_generated_latent_frames)
708
+
709
+ if history_pixels is None:
710
+ history_pixels = vae_decode(real_history_latents, vae).cpu()
711
+ else:
712
+ section_latent_frames = current_generator.get_section_latent_frames(latent_window_size, is_last_section)
713
+ overlapped_frames = latent_window_size * 4 - 3
714
+
715
+ # Get current pixels using the generator
716
+ current_pixels = current_generator.get_current_pixels(real_history_latents, section_latent_frames, vae)
717
+
718
+ # Update history pixels using the generator
719
+ history_pixels = current_generator.update_history_pixels(history_pixels, current_pixels, overlapped_frames)
720
+
721
+ print(f"{model_type} model section {section_idx+1}/{total_latent_sections}, history_pixels shape: {history_pixels.shape}")
722
+
723
+ if not high_vram:
724
+ unload_complete_models()
725
+
726
+ output_filename = os.path.join(output_dir, f'{job_id}_{total_generated_latent_frames}.mp4')
727
+ save_bcthw_as_mp4(history_pixels, output_filename, fps=30, crf=settings.get("mp4_crf"))
728
+ print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
729
+ stream_to_use.output_queue.push(('file', output_filename))
730
+
731
+ if is_last_section:
732
+ break
733
+
734
+ section_idx += 1 # PROMPT BLENDING: increment section index
735
+
736
+ # Unload all LoRAs after generation completed
737
+ if selected_loras:
738
+ print("Unloading all LoRAs after generation completed")
739
+ current_generator.unload_loras()
740
+ import gc
741
+ gc.collect()
742
+ if torch.cuda.is_available():
743
+ torch.cuda.empty_cache()
744
+
745
+ except:
746
+ traceback.print_exc()
747
+ # Unload all LoRAs after error
748
+ if current_generator is not None and selected_loras:
749
+ print("Unloading all LoRAs after error")
750
+ current_generator.unload_loras()
751
+ import gc
752
+ gc.collect()
753
+ if torch.cuda.is_available():
754
+ torch.cuda.empty_cache()
755
+
756
+ stream_to_use.output_queue.push(('error', f"Error during generation: {traceback.format_exc()}"))
757
+ if not high_vram:
758
+ # Ensure all models including the potentially active transformer are unloaded on error
759
+ unload_complete_models(
760
+ text_encoder, text_encoder_2, image_encoder, vae,
761
+ current_generator.transformer if current_generator else None
762
+ )
763
+
764
+ if settings.get("clean_up_videos"):
765
+ try:
766
+ video_files = [
767
+ f for f in os.listdir(output_dir)
768
+ if f.startswith(f"{job_id}_") and f.endswith(".mp4")
769
+ ]
770
+ print(f"Video files found for cleanup: {video_files}")
771
+ if video_files:
772
+ def get_frame_count(filename):
773
+ try:
774
+ # Handles filenames like jobid_123.mp4
775
+ return int(filename.replace(f"{job_id}_", "").replace(".mp4", ""))
776
+ except Exception:
777
+ return -1
778
+ video_files_sorted = sorted(video_files, key=get_frame_count)
779
+ print(f"Sorted video files: {video_files_sorted}")
780
+ final_video = video_files_sorted[-1]
781
+ for vf in video_files_sorted[:-1]:
782
+ full_path = os.path.join(output_dir, vf)
783
+ try:
784
+ os.remove(full_path)
785
+ print(f"Deleted intermediate video: {full_path}")
786
+ except Exception as e:
787
+ print(f"Failed to delete {full_path}: {e}")
788
+ except Exception as e:
789
+ print(f"Error during video cleanup: {e}")
790
+
791
+ # Final verification of LoRA state
792
+ if current_generator and current_generator.transformer:
793
+ verify_lora_state(current_generator.transformer, "Worker end")
794
+
795
+ stream_to_use.output_queue.push(('end', None))
796
+ return
797
+
798
+
799
+
800
+ # Set the worker function for the job queue
801
+ job_queue.set_worker_function(worker)
802
+
803
+
804
+ def process(
805
+ model_type,
806
+ input_image,
807
+ prompt_text,
808
+ n_prompt,
809
+ seed,
810
+ total_second_length,
811
+ latent_window_size,
812
+ steps,
813
+ cfg,
814
+ gs,
815
+ rs,
816
+ use_teacache,
817
+ blend_sections,
818
+ latent_type,
819
+ clean_up_videos,
820
+ selected_loras,
821
+ resolutionW,
822
+ resolutionH,
823
+ lora_loaded_names,
824
+ *lora_values
825
+ ):
826
+
827
+ # If no input image was provided, create a default image
828
+ # based on the selected latent_type (black, white, noise, or green screen)
829
+ has_input_image = True
830
+ if input_image is None:
831
+ has_input_image = False
832
+ default_height, default_width = resolutionH, resolutionW
833
+ if latent_type == "White":
834
+ # Create a white image
835
+ input_image = np.ones((default_height, default_width, 3), dtype=np.uint8) * 255
836
+ print("No input image provided. Using a blank white image.")
837
+
838
+ elif latent_type == "Noise":
839
+ # Create a noise image
840
+ input_image = np.random.randint(0, 256, (default_height, default_width, 3), dtype=np.uint8)
841
+ print("No input image provided. Using a random noise image.")
842
+
843
+ elif latent_type == "Green Screen":
844
+ # Create a green screen image with standard chroma key green (0, 177, 64)
845
+ input_image = np.zeros((default_height, default_width, 3), dtype=np.uint8)
846
+ input_image[:, :, 1] = 177 # Green channel
847
+ input_image[:, :, 2] = 64 # Blue channel
848
+ # Red channel remains 0
849
+ print("No input image provided. Using a standard chroma key green screen.")
850
+
851
+ else: # Default to "Black" or any other value
852
+ # Create a black image
853
+ input_image = np.zeros((default_height, default_width, 3), dtype=np.uint8)
854
+ print(f"No input image provided. Using a blank black image (latent_type: {latent_type}).")
855
+
856
+
857
+ # Create job parameters
858
+ job_params = {
859
+ 'model_type': model_type,
860
+ 'input_image': input_image.copy(), # Make a copy to avoid reference issues
861
+ 'prompt_text': prompt_text,
862
+ 'n_prompt': n_prompt,
863
+ 'seed': seed,
864
+ 'total_second_length': total_second_length,
865
+ 'latent_window_size': latent_window_size,
866
+ 'latent_type': latent_type,
867
+ 'steps': steps,
868
+ 'cfg': cfg,
869
+ 'gs': gs,
870
+ 'rs': rs,
871
+ 'blend_sections': blend_sections,
872
+ 'use_teacache': use_teacache,
873
+ 'selected_loras': selected_loras,
874
+ 'has_input_image': has_input_image,
875
+ 'output_dir': settings.get("output_dir"),
876
+ 'metadata_dir': settings.get("metadata_dir"),
877
+ 'resolutionW': resolutionW, # Add resolution parameter
878
+ 'resolutionH': resolutionH,
879
+ 'lora_loaded_names': lora_loaded_names
880
+ }
881
+
882
+ # Add LoRA values if provided - extract them from the tuple
883
+ if lora_values:
884
+ # Convert tuple to list
885
+ lora_values_list = list(lora_values)
886
+ job_params['lora_values'] = lora_values_list
887
+
888
+ # Add job to queue
889
+ job_id = job_queue.add_job(job_params)
890
+
891
+ # Set the generation_type attribute on the job object directly
892
+ job = job_queue.get_job(job_id)
893
+ if job:
894
+ job.generation_type = model_type # Set generation_type to model_type for display in queue
895
+ print(f"Added job {job_id} to queue")
896
+
897
+ queue_status = update_queue_status()
898
+ # Return immediately after adding to queue
899
+ return None, job_id, None, '', f'Job added to queue. Job ID: {job_id}', gr.update(interactive=True), gr.update(interactive=True)
900
+
901
+
902
+
903
+ def end_process():
904
+ """Cancel the current running job and update the queue status"""
905
+ print("Cancelling current job")
906
+ with job_queue.lock:
907
+ if job_queue.current_job:
908
+ job_id = job_queue.current_job.id
909
+ print(f"Cancelling job {job_id}")
910
+
911
+ # Send the end signal to the job's stream
912
+ if job_queue.current_job.stream:
913
+ job_queue.current_job.stream.input_queue.push('end')
914
+
915
+ # Mark the job as cancelled
916
+ job_queue.current_job.status = JobStatus.CANCELLED
917
+ job_queue.current_job.completed_at = time.time() # Set completion time
918
+
919
+ # Force an update to the queue status
920
+ return update_queue_status()
921
+
922
+
923
+ def update_queue_status():
924
+ """Update queue status and refresh job positions"""
925
+ jobs = job_queue.get_all_jobs()
926
+ for job in jobs:
927
+ if job.status == JobStatus.PENDING:
928
+ job.queue_position = job_queue.get_queue_position(job.id)
929
+
930
+ # Make sure to update current running job info
931
+ if job_queue.current_job:
932
+ # Make sure the running job is showing status = RUNNING
933
+ job_queue.current_job.status = JobStatus.RUNNING
934
+
935
+ return format_queue_status(jobs)
936
+
937
+
938
+ def monitor_job(job_id):
939
+ """
940
+ Monitor a specific job and update the UI with the latest video segment as soon as it's available.
941
+ """
942
+ if not job_id:
943
+ yield None, None, None, '', 'No job ID provided', gr.update(interactive=True), gr.update(interactive=True)
944
+ return
945
+
946
+ last_video = None # Track the last video file shown
947
+
948
+ while True:
949
+ job = job_queue.get_job(job_id)
950
+ if not job:
951
+ yield None, job_id, None, '', 'Job not found', gr.update(interactive=True), gr.update(interactive=True)
952
+ return
953
+
954
+ # If a new video file is available, yield it immediately
955
+ if job.result and job.result != last_video:
956
+ last_video = job.result
957
+ # You can also update preview/progress here if desired
958
+ yield last_video, job_id, gr.update(visible=True), '', '', gr.update(interactive=True), gr.update(interactive=True)
959
+
960
+ # Handle job status and progress
961
+ if job.status == JobStatus.PENDING:
962
+ position = job_queue.get_queue_position(job_id)
963
+ yield last_video, job_id, gr.update(visible=True), '', f'Waiting in queue. Position: {position}', gr.update(interactive=True), gr.update(interactive=True)
964
+
965
+ elif job.status == JobStatus.RUNNING:
966
+ if job.progress_data and 'preview' in job.progress_data:
967
+ preview = job.progress_data.get('preview')
968
+ desc = job.progress_data.get('desc', '')
969
+ html = job.progress_data.get('html', '')
970
+ yield last_video, job_id, gr.update(visible=True, value=preview), desc, html, gr.update(interactive=True), gr.update(interactive=True)
971
+ else:
972
+ yield last_video, job_id, gr.update(visible=True), '', 'Processing...', gr.update(interactive=True), gr.update(interactive=True)
973
+
974
+ elif job.status == JobStatus.COMPLETED:
975
+ # Show the final video
976
+ yield last_video, job_id, gr.update(visible=True), '', '', gr.update(interactive=True), gr.update(interactive=True)
977
+ break
978
+
979
+ elif job.status == JobStatus.FAILED:
980
+ yield last_video, job_id, gr.update(visible=True), '', f'Error: {job.error}', gr.update(interactive=True), gr.update(interactive=True)
981
+ break
982
+
983
+ elif job.status == JobStatus.CANCELLED:
984
+ yield last_video, job_id, gr.update(visible=True), '', 'Job cancelled', gr.update(interactive=True), gr.update(interactive=True)
985
+ break
986
+
987
+ # Wait a bit before checking again
988
+ time.sleep(0.5)
989
+
990
+
991
+ # Set Gradio temporary directory from settings
992
+ os.environ["GRADIO_TEMP_DIR"] = settings.get("gradio_temp_dir")
993
+
994
+ # Create the interface
995
+ interface = create_interface(
996
+ process_fn=process,
997
+ monitor_fn=monitor_job,
998
+ end_process_fn=end_process,
999
+ update_queue_status_fn=update_queue_status,
1000
+ load_lora_file_fn=load_lora_file,
1001
+ job_queue=job_queue,
1002
+ settings=settings,
1003
+ lora_names=lora_names # Explicitly pass the found LoRA names
1004
+ )
1005
+
1006
+ # Launch the interface
1007
+ interface.launch(
1008
+ server_name=args.server,
1009
+ server_port=args.port,
1010
+ share=args.share,
1011
+ inbrowser=args.inbrowser
1012
+ )