WeichenFan committed on
Commit
4a40efc
·
1 Parent(s): 0962cb3
README.md CHANGED
@@ -1,13 +1,60 @@
1
  ---
2
  title: Vchitect 2.0
3
- emoji: 🏃
4
- colorFrom: red
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.43.0
8
  app_file: app.py
9
  pinned: false
10
- license: artistic-2.0
11
  ---
 
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: Vchitect 2.0
3
+ emoji: 🐢
4
+ colorFrom: yellow
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 4.42.0
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
+ # Vchitect-XL
13
 
14
+ ## Installation
15
+
16
+ ### 1. Create a conda environment and install PyTorch
17
+
18
+ Note: You may want to adjust the CUDA version [according to your driver version](https://docs.nvidia.com/deploy/cuda-compatibility/#default-to-minor-version).
19
+
20
+ ```bash
21
+ conda create -n VchitectXL -y
22
+ conda activate VchitectXL
23
+ conda install python=3.11 pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y
24
+ ```
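+
+ Optionally, you can verify that the installed PyTorch build can see a CUDA device before continuing:
+
+ ```bash
+ # Should print the PyTorch version and "True" if a GPU is visible
+ python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+ ```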
25
+
26
+ ### 2. Install dependencies
27
+
28
+ ```bash
29
+ pip install -r requirements.txt
30
+ ```
31
+
32
+ ### 3. Install ``flash-attn``
33
+
34
+ ```bash
35
+ pip install flash-attn --no-build-isolation
36
+ ```
37
+
38
+ ### 4. Install [NVIDIA Apex](https://github.com/nvidia/apex)
39
+
40
+ ```bash
41
+ pip install ninja
42
+ git clone https://github.com/NVIDIA/apex
43
+ cd apex
44
+ # if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
45
+ pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
46
+ # otherwise
47
+ pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
48
+ ```
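+
+ Optionally, you can check that both packages import cleanly in the new environment (`flash_attn` and `apex` are the module names these packages install):
+
+ ```bash
+ # Both imports should succeed without errors
+ python -c "import flash_attn, apex; print('flash-attn and apex OK')"
+ ```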
49
+
50
+ ## Inference
51
+
52
+ ```bash
53
+ # Simple inference
54
+ test_file=$1
55
+ save_dir=$2
56
+ ckpt_path=$3
57
+
58
+ python inference.py --test_file "${test_file}" --save_dir "${save_dir}" --ckpt_path "${ckpt_path}"
59
+
60
+ ```
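+
+ Here `test_file` is a plain-text file with one prompt per line; `inference.py` reads it line by line and renders each prompt with five seeds. A minimal sketch, assuming a hypothetical `prompts.txt` and a locally downloaded checkpoint directory:
+
+ ```bash
+ # prompts.txt: one text prompt per line (file name is just an example)
+ cat > prompts.txt << 'EOF'
+ A sailboat gliding across a calm lake at sunset
+ A close-up of a blooming flower swaying in the wind
+ EOF
+
+ python inference.py --test_file prompts.txt --save_dir ./results --ckpt_path /path/to/Vchitect-XL-2B
+ ```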
app.py ADDED
@@ -0,0 +1,248 @@
1
+ import os
2
+ import threading
3
+ import time
4
+
5
+ import gradio as gr
6
+ import torch
7
+ # from diffusers import CogVideoXPipeline
8
+ from models.pipeline import VchitectXLPipeline
9
+ from diffusers.utils import export_to_video
10
+ from datetime import datetime, timedelta
11
+ # from openai import OpenAI
12
+ import spaces
13
+ import moviepy.editor as mp
14
+
15
+ import os
16
+ from huggingface_hub import login
17
+ login(token=os.getenv('HF_TOKEN'))
18
+
19
+ dtype = torch.float16
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ pipe = VchitectXLPipeline("Vchitect-XL/Vchitect-XL-2B",device)
22
+ #VchitectXLPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=dtype).to(device)
23
+ #
24
+ os.makedirs("./output", exist_ok=True)
25
+ os.makedirs("./gradio_tmp", exist_ok=True)
26
+
27
+ sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
28
+ For example, outputting "a beautiful morning in the woods with the sun peeking through the trees" will trigger your partner bot to output a video of a forest morning, as described. You will be prompted by people looking to create detailed, amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
29
+ There are a few rules to follow:
30
+ You will only ever output a single video description per user request.
31
+ When modifications are requested, you should not simply make the description longer. You should refactor the entire description to integrate the suggestions.
32
+ Other times the user will not want modifications, but will instead want a new video. In this case, you should ignore your previous conversation with the user.
33
+ Video descriptions must have the same number of words as the examples below. Extra words will be ignored.
34
+ """
35
+
36
+
37
+ # def convert_prompt(prompt: str, retry_times: int = 3) -> str:
38
+ # if not os.environ.get("OPENAI_API_KEY"):
39
+ # return prompt
40
+ # client = OpenAI()
41
+ # text = prompt.strip()
42
+
43
+ # for i in range(retry_times):
44
+ # response = client.chat.completions.create(
45
+ # messages=[
46
+ # {"role": "system", "content": sys_prompt},
47
+ # {
48
+ # "role": "user",
49
+ # "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
50
+ # },
51
+ # {
52
+ # "role": "assistant",
53
+ # "content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
54
+ # },
55
+ # {
56
+ # "role": "user",
57
+ # "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
58
+ # },
59
+ # {
60
+ # "role": "assistant",
61
+ # "content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
62
+ # },
63
+ # {
64
+ # "role": "user",
65
+ # "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
66
+ # },
67
+ # {
68
+ # "role": "assistant",
69
+ # "content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
70
+ # },
71
+ # {
72
+ # "role": "user",
73
+ # "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
74
+ # },
75
+ # ],
76
+ # model="glm-4-0520",
77
+ # temperature=0.01,
78
+ # top_p=0.7,
79
+ # stream=False,
80
+ # max_tokens=250,
81
+ # )
82
+ # if response.choices:
83
+ # return response.choices[0].message.content
84
+ # return prompt
85
+
86
+ @spaces.GPU(duration=120)
87
+ def infer(prompt: str, num_inference_steps: int, guidance_scale: float, progress=gr.Progress(track_tqdm=True)):
88
+ torch.cuda.empty_cache()
89
+ # video = pipe(
90
+ # prompt=prompt,
91
+ # num_videos_per_prompt=1,
92
+ # num_inference_steps=num_inference_steps,
93
+ # num_frames=49,
94
+ # guidance_scale=guidance_scale,
95
+ # ).frames[0]
96
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
97
+ video = pipe(
98
+ prompt,
99
+ negative_prompt="",
100
+ num_inference_steps=num_inference_steps,
101
+ guidance_scale=guidance_scale,
102
+ width=432,
103
+ height=240, #480x288 624x352 432x240 768x432
104
+ frames=16
105
+ )
106
+
107
+ return video
108
+
109
+
110
+ def save_video(tensor):
111
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
112
+ video_path = f"./output/{timestamp}.mp4"
113
+ os.makedirs(os.path.dirname(video_path), exist_ok=True)
114
+ export_to_video(tensor, video_path)
115
+ return video_path
116
+
117
+
118
+ def convert_to_gif(video_path):
119
+ clip = mp.VideoFileClip(video_path)
120
+ clip = clip.set_fps(8)
121
+ clip = clip.resize(height=240)
122
+ gif_path = video_path.replace(".mp4", ".gif")
123
+ clip.write_gif(gif_path, fps=8)
124
+ return gif_path
125
+
126
+
127
+ def delete_old_files():
128
+ while True:
129
+ now = datetime.now()
130
+ cutoff = now - timedelta(minutes=10)
131
+ directories = ["./output", "./gradio_tmp"]
132
+
133
+ for directory in directories:
134
+ for filename in os.listdir(directory):
135
+ file_path = os.path.join(directory, filename)
136
+ if os.path.isfile(file_path):
137
+ file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
138
+ if file_mtime < cutoff:
139
+ os.remove(file_path)
140
+ time.sleep(600)
141
+
142
+
143
+ threading.Thread(target=delete_old_files, daemon=True).start()
144
+
145
+ with gr.Blocks() as demo:
146
+ gr.Markdown("""
147
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
148
+ Vchitect-XL 2B Hugging Face Space🤗
149
+ </div>
150
+ <div style="text-align: center;">
151
+ <a href="https://huggingface.co/Vchitect-XL/Vchitect-XL-2B">🤗 2B Model Hub</a> |
152
+ <a href="https://vchitect.intern-ai.org.cn/">🌐 Website</a> |
153
+ </div>
154
+ <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
155
+ ⚠️ This demo is for academic research and experiential use only.
156
+ Users should strictly adhere to local laws and ethics.
157
+ </div>
158
+ """)
159
+ with gr.Row():
160
+ with gr.Column():
161
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=5)
162
+
163
+ # with gr.Row():
164
+ # gr.Markdown(
165
+ # "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one.")
166
+ # enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
167
+
168
+ with gr.Column():
169
+ # gr.Markdown("**Optional Parameters** (default values are recommended)<br>"
170
+ # "Increasing the number of inference steps will produce more detailed videos, but it will slow down the process.<br>"
171
+ # "50 steps are recommended for most cases.<br>"
172
+ # "For the 5B model, 50 steps will take approximately 350 seconds.")
173
+ with gr.Row():
174
+ num_inference_steps = gr.Number(label="Inference Steps", value=100)
175
+ guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
176
+ generate_button = gr.Button("🎬 Generate Video")
177
+
178
+ with gr.Column():
179
+ video_output = gr.Video(label="Vchitect-XL Generated Video", width=768, height=432)
180
+ with gr.Row():
181
+ download_video_button = gr.File(label="📥 Download Video", visible=False)
182
+ download_gif_button = gr.File(label="📥 Download GIF", visible=False)
183
+
184
+ # gr.Markdown("""
185
+ # <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
186
+ # <div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 20px;">
187
+ # Demo Videos with 50 Inference Steps and 6.0 Guidance Scale.
188
+ # </div>
189
+ # <tr>
190
+ # <td style="width: 25%; vertical-align: top; font-size: 0.8em;">
191
+ # <p>A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.</p>
192
+ # </td>
193
+ # <td style="width: 25%; vertical-align: top;">
194
+ # <video src="https://github.com/user-attachments/assets/ea3af39a-3160-4999-90ec-2f7863c5b0e9" width="100%" controls autoplay></video>
195
+ # </td>
196
+ # <td style="width: 25%; vertical-align: top; font-size: 0.8em;">
197
+ # <p>The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.</p>
198
+ # </td>
199
+ # <td style="width: 25%; vertical-align: top;">
200
+ # <video src="https://github.com/user-attachments/assets/9de41efd-d4d1-4095-aeda-246dd834e91d" width="100%" controls autoplay></video>
201
+ # </td>
202
+ # </tr>
203
+ # <tr>
204
+ # <td style="width: 25%; vertical-align: top; font-size: 0.8em;">
205
+ # <p>A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.</p>
206
+ # </td>
207
+ # <td style="width: 25%; vertical-align: top;">
208
+ # <video src="https://github.com/user-attachments/assets/941d6661-6a8d-4a1b-b912-59606f0b2841" width="100%" controls autoplay></video>
209
+ # </td>
210
+ # <td style="width: 25%; vertical-align: top; font-size: 0.8em;">
211
+ # <p>In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.</p>
212
+ # </td>
213
+ # <td style="width: 25%; vertical-align: top;">
214
+ # <video src="https://github.com/user-attachments/assets/938529c4-91ae-4f60-b96b-3c3947fa63cb" width="100%" controls autoplay></video>
215
+ # </td>
216
+ # </tr>
217
+ # </table>
218
+ # """)
219
+
220
+
221
+ def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
222
+ tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
223
+ video_path = save_video(tensor)
224
+ video_update = gr.update(visible=True, value=video_path)
225
+ gif_path = convert_to_gif(video_path)
226
+ gif_update = gr.update(visible=True, value=gif_path)
227
+
228
+ return video_path, video_update, gif_update
229
+
230
+
231
+ # def enhance_prompt_func(prompt):
232
+ # return convert_prompt(prompt, retry_times=1)
233
+
234
+
235
+ generate_button.click(
236
+ generate,
237
+ inputs=[prompt, num_inference_steps, guidance_scale],
238
+ outputs=[video_output, download_video_button, download_gif_button]
239
+ )
240
+
241
+ # enhance_button.click(
242
+ # enhance_prompt_func,
243
+ # inputs=[prompt],
244
+ # outputs=[prompt]
245
+ # )
246
+
247
+ if __name__ == "__main__":
248
+ demo.launch()
inference.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+ from models.pipeline import VchitectXLPipeline
3
+ import random
4
+ import numpy as np
5
+ import os
6
+
7
+ def set_seed(seed):
8
+ random.seed(seed)
9
+ os.environ['PYTHONHASHSEED'] = str(seed)
10
+ np.random.seed(seed)
11
+ torch.manual_seed(seed)
12
+ torch.cuda.manual_seed(seed)
13
+
14
+ def infer(args):
15
+ pipe = VchitectXLPipeline(args.ckpt_path)
16
+ idx = 0
17
+
18
+ with open(args.test_file,'r') as f:
19
+ for lines in f.readlines():
20
+ for seed in range(5):
21
+ set_seed(seed)
22
+ prompt = lines.strip('\n')
23
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
24
+ video = pipe(
25
+ prompt,
26
+ negative_prompt="",
27
+ num_inference_steps=50,
28
+ guidance_scale=7.5,
29
+ width=768,
30
+ height=432, #480x288 624x352 432x240 768x432
31
+ frames=40
32
+ )
33
+
34
+ images = video
35
+
36
+ from utils import save_as_mp4
37
+ import sys,os
38
+ duration = 1000 / 8
39
+
40
+ save_dir = args.save_dir
41
+ os.makedirs(save_dir,exist_ok=True)
42
+
43
+ idx += 1
44
+
45
+ save_as_mp4(images, os.path.join(save_dir, f"sample_{idx}_seed{seed}")+'.mp4', duration=duration)
46
+
47
+ import sys,os
48
+ import argparse
49
+
50
+ def main():
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument("--test_file", type=str)
53
+ parser.add_argument("--save_dir", type=str)
54
+ parser.add_argument("--ckpt_path", type=str)
55
+ args = parser.parse_known_args()[0]
56
+ infer(args)
57
+
58
+ if __name__ == "__main__":
59
+ main()
models/VchitectXL.py ADDED
@@ -0,0 +1,479 @@
1
+ # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, Optional, Union, Tuple, List
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
23
+ from models.blocks import JointTransformerBlock
24
+ # from diffusers.models.attention_processor import Attention, AttentionProcessor
25
+ from models.attention import Attention, AttentionProcessor
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+ from diffusers.models.normalization import AdaLayerNormContinuous
28
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
29
+ from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
30
+ from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
31
+
32
+ from einops import rearrange
33
+ from torch.distributed._tensor import Shard, Replicate
34
+ from torch.distributed.tensor.parallel import (
35
+ parallelize_module,
36
+ PrepareModuleOutput
37
+ )
38
+
39
+ #from models.layers import ParallelTimestepEmbedder, TransformerBlock, ParallelFinalLayer, Identity
40
+
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+
45
+ class VchitectXLTransformerModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
46
+ """
47
+ The Transformer model introduced in Stable Diffusion 3.
48
+
49
+ Reference: https://arxiv.org/abs/2403.03206
50
+
51
+ Parameters:
52
+ sample_size (`int`): The width of the latent images. This is fixed during training since
53
+ it is used to learn a number of position embeddings.
54
+ patch_size (`int`): Patch size to turn the input data into small patches.
55
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
56
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
57
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
58
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
59
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
60
+ caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
61
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
62
+ out_channels (`int`, defaults to 16): Number of output channels.
63
+
64
+ """
65
+
66
+ _supports_gradient_checkpointing = True
67
+
68
+ @register_to_config
69
+ def __init__(
70
+ self,
71
+ sample_size: int = 128,
72
+ patch_size: int = 2,
73
+ in_channels: int = 16,
74
+ num_layers: int = 18,
75
+ attention_head_dim: int = 64,
76
+ num_attention_heads: int = 18,
77
+ joint_attention_dim: int = 4096,
78
+ caption_projection_dim: int = 1152,
79
+ pooled_projection_dim: int = 2048,
80
+ out_channels: int = 16,
81
+ pos_embed_max_size: int = 96,
82
+ tp_size: int = 1,
83
+ rope_scaling_factor: float = 1.,
84
+ ):
85
+ super().__init__()
86
+ default_out_channels = in_channels
87
+ self.out_channels = out_channels if out_channels is not None else default_out_channels
88
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
89
+
90
+ self.pos_embed = PatchEmbed(
91
+ height=self.config.sample_size,
92
+ width=self.config.sample_size,
93
+ patch_size=self.config.patch_size,
94
+ in_channels=self.config.in_channels,
95
+ embed_dim=self.inner_dim,
96
+ pos_embed_max_size=pos_embed_max_size, # hard-code for now.
97
+ )
98
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
99
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
100
+ )
101
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
102
+ # `attention_head_dim` is doubled to account for the mixing.
103
+ # It needs to be crafted when we get the actual checkpoints.
104
+ self.transformer_blocks = nn.ModuleList(
105
+ [
106
+ JointTransformerBlock(
107
+ dim=self.inner_dim,
108
+ num_attention_heads=self.config.num_attention_heads,
109
+ attention_head_dim=self.inner_dim,
110
+ context_pre_only=i == num_layers - 1,
111
+ tp_size = tp_size
112
+ )
113
+ for i in range(self.config.num_layers)
114
+ ]
115
+ )
116
+
117
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
118
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
119
+
120
+ self.gradient_checkpointing = False
121
+
122
+ # Video param
123
+ # self.scatter_dim_zero = Identity()
124
+ self.freqs_cis = VchitectXLTransformerModel.precompute_freqs_cis(
125
+ self.inner_dim // self.config.num_attention_heads, 1000000, theta=1e6, rope_scaling_factor=rope_scaling_factor # todo max pos embeds
126
+ )
127
+
128
+ #self.vid_token = nn.Parameter(torch.empty(self.inner_dim))
129
+
130
+ @staticmethod
131
+ def tp_parallelize(model, tp_mesh):
132
+ for layer_id, transformer_block in enumerate(model.transformer_blocks):
133
+ layer_tp_plan = {
134
+ # Attention layer
135
+ "attn.gather_seq_scatter_hidden": PrepareModuleOutput(
136
+ output_layouts=Replicate(),
137
+ desired_output_layouts=Shard(-2)
138
+ ),
139
+ "attn.gather_hidden_scatter_seq": PrepareModuleOutput(
140
+ output_layouts=Shard(-2),
141
+ desired_output_layouts=Replicate(),
142
+ )
143
+ }
144
+ parallelize_module(
145
+ module=transformer_block,
146
+ device_mesh=tp_mesh,
147
+ parallelize_plan=layer_tp_plan
148
+ )
149
+ return model
150
+
151
+ @staticmethod
152
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, rope_scaling_factor: float = 1.0):
153
+ freqs = 1.0 / (theta ** (
154
+ torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
155
+ ))
156
+ t = torch.arange(end, device=freqs.device, dtype=torch.float)
157
+ t = t / rope_scaling_factor
158
+ freqs = torch.outer(t, freqs).float()
159
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
160
+ return freqs_cis
161
+
162
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
163
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
164
+ """
165
+ Sets the attention processor to use [feed forward
166
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
167
+
168
+ Parameters:
169
+ chunk_size (`int`, *optional*):
170
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
171
+ over each tensor of dim=`dim`.
172
+ dim (`int`, *optional*, defaults to `0`):
173
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
174
+ or dim=1 (sequence length).
175
+ """
176
+ if dim not in [0, 1]:
177
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
178
+
179
+ # By default chunk size is 1
180
+ chunk_size = chunk_size or 1
181
+
182
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
183
+ if hasattr(module, "set_chunk_feed_forward"):
184
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
185
+
186
+ for child in module.children():
187
+ fn_recursive_feed_forward(child, chunk_size, dim)
188
+
189
+ for module in self.children():
190
+ fn_recursive_feed_forward(module, chunk_size, dim)
191
+
192
+ @property
193
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
194
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
195
+ r"""
196
+ Returns:
197
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
198
+ indexed by their weight names.
199
+ """
200
+ # set recursively
201
+ processors = {}
202
+
203
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
204
+ if hasattr(module, "get_processor"):
205
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
206
+
207
+ for sub_name, child in module.named_children():
208
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
209
+
210
+ return processors
211
+
212
+ for name, module in self.named_children():
213
+ fn_recursive_add_processors(name, module, processors)
214
+
215
+ return processors
216
+
217
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
218
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
219
+ r"""
220
+ Sets the attention processor to use to compute attention.
221
+
222
+ Parameters:
223
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
224
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
225
+ for **all** `Attention` layers.
226
+
227
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
228
+ processor. This is strongly recommended when setting trainable attention processors.
229
+
230
+ """
231
+ count = len(self.attn_processors.keys())
232
+
233
+ if isinstance(processor, dict) and len(processor) != count:
234
+ raise ValueError(
235
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
236
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
237
+ )
238
+
239
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
240
+ if hasattr(module, "set_processor"):
241
+ if not isinstance(processor, dict):
242
+ module.set_processor(processor)
243
+ else:
244
+ module.set_processor(processor.pop(f"{name}.processor"))
245
+
246
+ for sub_name, child in module.named_children():
247
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
248
+
249
+ for name, module in self.named_children():
250
+ fn_recursive_attn_processor(name, module, processor)
251
+
252
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
253
+ def fuse_qkv_projections(self):
254
+ """
255
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
256
+ are fused. For cross-attention modules, key and value projection matrices are fused.
257
+
258
+ <Tip warning={true}>
259
+
260
+ This API is 🧪 experimental.
261
+
262
+ </Tip>
263
+ """
264
+ self.original_attn_processors = None
265
+
266
+ for _, attn_processor in self.attn_processors.items():
267
+ if "Added" in str(attn_processor.__class__.__name__):
268
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
269
+
270
+ self.original_attn_processors = self.attn_processors
271
+
272
+ for module in self.modules():
273
+ if isinstance(module, Attention):
274
+ module.fuse_projections(fuse=True)
275
+
276
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
277
+ def unfuse_qkv_projections(self):
278
+ """Disables the fused QKV projection if enabled.
279
+
280
+ <Tip warning={true}>
281
+
282
+ This API is 🧪 experimental.
283
+
284
+ </Tip>
285
+
286
+ """
287
+ if self.original_attn_processors is not None:
288
+ self.set_attn_processor(self.original_attn_processors)
289
+
290
+ def _set_gradient_checkpointing(self, module, value=False):
291
+ if hasattr(module, "gradient_checkpointing"):
292
+ module.gradient_checkpointing = value
293
+
294
+ def patchify_and_embed(self, x):
295
+ pH = pW = self.patch_size
296
+ B, F, C, H, W = x.size()
297
+ x = rearrange(x, "b f c h w -> (b f) c h w")
298
+ x = self.pos_embed(x) # [B L D]
299
+ # x = torch.cat([
300
+ # x,
301
+ # self.vid_token.view(1, 1, -1).expand(B*F, 1, -1),
302
+ # ], dim=1)
303
+
304
+ return x, F, [(H, W)] * B
305
+
306
+ def forward(
307
+ self,
308
+ hidden_states: torch.FloatTensor,
309
+ encoder_hidden_states: torch.FloatTensor = None,
310
+ pooled_projections: torch.FloatTensor = None,
311
+ timestep: torch.LongTensor = None,
312
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
313
+ return_dict: bool = True,
314
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
315
+ """
316
+ The [`VchitectXLTransformerModel`] forward method.
317
+
318
+ Args:
319
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
320
+ Input `hidden_states`.
321
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
322
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
323
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
324
+ from the embeddings of input conditions.
325
+ timestep ( `torch.LongTensor`):
326
+ Used to indicate denoising step.
327
+ joint_attention_kwargs (`dict`, *optional*):
328
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
329
+ `self.processor` in
330
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
331
+ return_dict (`bool`, *optional*, defaults to `True`):
332
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
333
+ tuple.
334
+
335
+ Returns:
336
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
337
+ `tuple` where the first element is the sample tensor.
338
+ """
339
+ if joint_attention_kwargs is not None:
340
+ joint_attention_kwargs = joint_attention_kwargs.copy()
341
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
342
+ else:
343
+ lora_scale = 1.0
344
+
345
+ # if USE_PEFT_BACKEND:
346
+ # # weight the lora layers by setting `lora_scale` for each PEFT layer
347
+ # scale_lora_layers(self, lora_scale)
348
+ # else:
349
+ # logger.warning(
350
+ # "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
351
+ # )
352
+
353
+ height, width = hidden_states.shape[-2:]
354
+
355
+ batch_size = hidden_states.shape[0]
356
+ hidden_states, F_num, _ = self.patchify_and_embed(hidden_states) # takes care of adding positional embeddings too.
357
+ full_seq = batch_size * F_num
358
+
359
+ self.freqs_cis = self.freqs_cis.to(hidden_states.device)
360
+ freqs_cis = self.freqs_cis
361
+ # seq_length = hidden_states.size(1)
362
+ # freqs_cis = self.freqs_cis[:hidden_states.size(1)*F_num]
363
+ temb = self.time_text_embed(timestep, pooled_projections)
364
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
365
+
366
+ # for block in self.transformer_blocks:
367
+ # if self.training and self.gradient_checkpointing:
368
+
369
+ # def create_custom_forward(module, return_dict=None):
370
+ # def custom_forward(*inputs):
371
+ # if return_dict is not None:
372
+ # return module(*inputs, return_dict=return_dict)
373
+ # else:
374
+ # return module(*inputs)
375
+
376
+ # return custom_forward
377
+
378
+ # ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
379
+ # hidden_states = torch.utils.checkpoint.checkpoint(
380
+ # create_custom_forward(block),
381
+ # hidden_states,
382
+ # encoder_hidden_states,
383
+ # temb,
384
+ # **ckpt_kwargs,
385
+ # )
386
+
387
+ # else:
388
+ # encoder_hidden_states, hidden_states = block(
389
+ # hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
390
+ # )
391
+
392
+ for block_idx, block in enumerate(self.transformer_blocks):
393
+ encoder_hidden_states, hidden_states = block(
394
+ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb.repeat(F_num,1), freqs_cis=freqs_cis, full_seqlen=full_seq, Frame=F_num
395
+ )
396
+
397
+ hidden_states = self.norm_out(hidden_states, temb)
398
+ hidden_states = self.proj_out(hidden_states)
399
+
400
+ # unpatchify
401
+ # hidden_states = hidden_states[:, :-1] #Drop the video token
402
+
403
+ # unpatchify
404
+ patch_size = self.config.patch_size
405
+ height = height // patch_size
406
+ width = width // patch_size
407
+
408
+ hidden_states = hidden_states.reshape(
409
+ shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
410
+ )
411
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
412
+ output = hidden_states.reshape(
413
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
414
+ )
415
+
416
+ if USE_PEFT_BACKEND:
417
+ # remove `lora_scale` from each PEFT layer
418
+ unscale_lora_layers(self, lora_scale)
419
+
420
+ if not return_dict:
421
+ return (output,)
422
+
423
+ return Transformer2DModelOutput(sample=output)
424
+
425
+ def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
426
+ return list(self.transformer_blocks)
427
+
428
+ @classmethod
429
+ def from_pretrained_temporal(cls, pretrained_model_path, torch_dtype, logger, subfolder=None, tp_size=1):
430
+
431
+ import os
432
+ import json
433
+
434
+ if subfolder is not None:
435
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
436
+
437
+ config_file = os.path.join(pretrained_model_path, 'config.json')
438
+
439
+ with open(config_file, "r") as f:
440
+ config = json.load(f)
441
+
442
+ config["tp_size"] = tp_size
443
+ from diffusers.utils import WEIGHTS_NAME
444
+ from safetensors.torch import load_file,load_model
445
+ model = cls.from_config(config)
446
+ # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
447
+
448
+ model_files = [
449
+ os.path.join(pretrained_model_path, 'diffusion_pytorch_model.bin'),
450
+ os.path.join(pretrained_model_path, 'diffusion_pytorch_model.safetensors')
451
+ ]
452
+
453
+ model_file = None
454
+
455
+ for fp in model_files:
456
+ if os.path.exists(fp):
457
+ model_file = fp
458
+
459
+ if not model_file:
460
+ raise RuntimeError(f"No model weights found in {pretrained_model_path}")
461
+
462
+ if not os.path.isfile(model_file):
463
+ raise RuntimeError(f"{model_file} does not exist")
464
+
465
+
466
+ state_dict = load_file(model_file,device="cpu")
467
+ m, u = model.load_state_dict(state_dict, strict=False)
468
+ model = model.to(torch_dtype)
469
+
470
+ params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
471
+ total_params = [p.numel() for n, p in model.named_parameters()]
472
+
473
+ if logger is not None:
474
+ logger.info(f"model_file: {model_file}")
475
+ logger.info(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
476
+ logger.info(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
477
+ logger.info(f"### Total Parameters: {sum(total_params) / 1e6} M")
478
+
479
+ return model
models/__init__.py ADDED
File without changes
models/__pycache__/VchitectXL.cpython-310.pyc ADDED
Binary file (15.6 kB).

models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (161 Bytes).

models/__pycache__/attention.cpython-310.pyc ADDED
Binary file (72.1 kB).

models/__pycache__/autoencoder_kl_temporal_decoder.cpython-310.pyc ADDED
Binary file (20.5 kB).

models/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (19.3 kB).

models/__pycache__/layers.cpython-310.pyc ADDED
Binary file (17.2 kB).

models/__pycache__/modeling_t5.cpython-310.pyc ADDED
Binary file (65.1 kB).

models/__pycache__/models.cpython-310.pyc ADDED
Binary file (8.81 kB).

models/__pycache__/motion_module.cpython-310.pyc ADDED
Binary file (8.3 kB).

models/__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (31.1 kB).

models/__pycache__/scheduling_ddim_cogvideox.cpython-310.pyc ADDED
Binary file (15.3 kB).

models/__pycache__/sd3_attention.cpython-310.pyc ADDED
Binary file (19.4 kB).

models/__pycache__/sd3_models.cpython-310.pyc ADDED
Binary file (17.2 kB).

models/__pycache__/sd3_sparse.cpython-310.pyc ADDED
Binary file (15.5 kB).

models/__pycache__/sd3_sparse_ae_temporal_pipeline.cpython-310.pyc ADDED
Binary file (31.9 kB).

models/__pycache__/sd3_sparse_i2v_pipeline.cpython-310.pyc ADDED
Binary file (31.9 kB).

models/__pycache__/sd3_sparse_init_pipeline.cpython-310.pyc ADDED
Binary file (32.7 kB).

models/__pycache__/sd3_sparse_pipeline.cpython-310.pyc ADDED
Binary file (31.4 kB).

models/__pycache__/sparse_attention.cpython-310.pyc ADDED
Binary file (72.1 kB).

models/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.66 kB).

models/attention.py ADDED
The diff for this file is too large to render.
 
models/blocks.py ADDED
@@ -0,0 +1,793 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import deprecate, logging
21
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
22
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
23
+ # from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
24
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
25
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
26
+ from models.attention import Attention, VchitectAttnProcessor
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
32
+ # "feed_forward_chunk_size" can be used to save memory
33
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
+ raise ValueError(
35
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
+ )
37
+
38
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
+ ff_output = torch.cat(
40
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
41
+ dim=chunk_dim,
42
+ )
43
+ return ff_output
44
+
45
+
46
+ @maybe_allow_in_graph
47
+ class GatedSelfAttentionDense(nn.Module):
48
+ r"""
49
+ A gated self-attention dense layer that combines visual features and object features.
50
+
51
+ Parameters:
52
+ query_dim (`int`): The number of channels in the query.
53
+ context_dim (`int`): The number of channels in the context.
54
+ n_heads (`int`): The number of heads to use for attention.
55
+ d_head (`int`): The number of channels in each head.
56
+ """
57
+
58
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
59
+ super().__init__()
60
+
61
+ # we need a linear projection since we need cat visual feature and obj feature
62
+ self.linear = nn.Linear(context_dim, query_dim)
63
+
64
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
65
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
66
+
67
+ self.norm1 = nn.LayerNorm(query_dim)
68
+ self.norm2 = nn.LayerNorm(query_dim)
69
+
70
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
71
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
72
+
73
+ self.enabled = True
74
+
75
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
76
+ if not self.enabled:
77
+ return x
78
+
79
+ n_visual = x.shape[1]
80
+ objs = self.linear(objs)
81
+
82
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
83
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
84
+
85
+ return x
86
+
87
+
88
+ @maybe_allow_in_graph
89
+ class JointTransformerBlock(nn.Module):
90
+ r"""
91
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
92
+
93
+ Reference: https://arxiv.org/abs/2403.03206
94
+
95
+ Parameters:
96
+ dim (`int`): The number of channels in the input and output.
97
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
98
+ attention_head_dim (`int`): The number of channels in each head.
99
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
100
+ processing of `context` conditions.
101
+ """
102
+
103
+ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False, tp_size=1):
104
+ super().__init__()
105
+
106
+ self.context_pre_only = context_pre_only
107
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
108
+
109
+ self.norm1 = AdaLayerNormZero(dim)
110
+
111
+ if context_norm_type == "ada_norm_continous":
112
+ self.norm1_context = AdaLayerNormContinuous(
113
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
114
+ )
115
+ elif context_norm_type == "ada_norm_zero":
116
+ self.norm1_context = AdaLayerNormZero(dim)
117
+ else:
118
+ raise ValueError(
119
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
120
+ )
121
+ # if hasattr(F, "scaled_dot_product_attention"):
122
+ # processor = VchitectAttnProcessor()
123
+ # else:
124
+ # raise ValueError(
125
+ # "The current PyTorch version does not support the `scaled_dot_product_attention` function."
126
+ # )
127
+ processor = VchitectAttnProcessor()
128
+ self.attn = Attention(
129
+ query_dim=dim,
130
+ cross_attention_dim=None,
131
+ added_kv_proj_dim=dim,
132
+ dim_head=attention_head_dim // num_attention_heads,
133
+ heads=num_attention_heads,
134
+ out_dim=attention_head_dim,
135
+ context_pre_only=context_pre_only,
136
+ bias=True,
137
+ processor=processor,
138
+ tp_size = tp_size
139
+ )
140
+
141
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
142
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
143
+
144
+ if not context_pre_only:
145
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
146
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
147
+ else:
148
+ self.norm2_context = None
149
+ self.ff_context = None
150
+
151
+ # let chunk size default to None
152
+ self._chunk_size = None
153
+ self._chunk_dim = 0
154
+
155
+
156
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
157
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
158
+ # Sets chunk feed-forward
159
+ self._chunk_size = chunk_size
160
+ self._chunk_dim = dim
161
+
162
+ def forward(
163
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, freqs_cis: torch.Tensor, full_seqlen: int, Frame: int
164
+ ):
165
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
166
+ if self.context_pre_only:
167
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
168
+ else:
169
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
170
+ encoder_hidden_states, emb=temb
171
+ )
172
+
173
+ # Attention.
174
+ attn_output, context_attn_output = self.attn(
175
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
176
+ freqs_cis=freqs_cis,
177
+ full_seqlen=full_seqlen,
178
+ Frame=Frame,
179
+ )
180
+
181
+ # Process attention outputs for the `hidden_states`.
182
+ attn_output = gate_msa.unsqueeze(1) * attn_output
183
+ hidden_states = hidden_states + attn_output
184
+
185
+ norm_hidden_states = self.norm2(hidden_states)
186
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
187
+ if self._chunk_size is not None:
188
+ # "feed_forward_chunk_size" can be used to save memory
189
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
190
+ else:
191
+ ff_output = self.ff(norm_hidden_states)
192
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
193
+
194
+ hidden_states = hidden_states + ff_output
195
+
196
+ # Process attention outputs for the `encoder_hidden_states`.
197
+ if self.context_pre_only:
198
+ encoder_hidden_states = None
199
+ else:
200
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
201
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
202
+
203
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
204
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
205
+ if self._chunk_size is not None:
206
+ # "feed_forward_chunk_size" can be used to save memory
207
+ context_ff_output = _chunked_feed_forward(
208
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
209
+ )
210
+ else:
211
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
212
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
213
+
214
+ return encoder_hidden_states, hidden_states
215
+
216
+
217
+ @maybe_allow_in_graph
218
+ class BasicTransformerBlock(nn.Module):
219
+ r"""
220
+ A basic Transformer block.
221
+
222
+ Parameters:
223
+ dim (`int`): The number of channels in the input and output.
224
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
225
+ attention_head_dim (`int`): The number of channels in each head.
226
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
227
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
228
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
229
+ num_embeds_ada_norm (:
230
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
231
+ attention_bias (:
232
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
233
+ only_cross_attention (`bool`, *optional*):
234
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
235
+ double_self_attention (`bool`, *optional*):
236
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
237
+ upcast_attention (`bool`, *optional*):
238
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
239
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
240
+ Whether to use learnable elementwise affine parameters for normalization.
241
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
242
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
243
+ final_dropout (`bool` *optional*, defaults to False):
244
+ Whether to apply a final dropout after the last feed-forward layer.
245
+ attention_type (`str`, *optional*, defaults to `"default"`):
246
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
247
+ positional_embeddings (`str`, *optional*, defaults to `None`):
248
+ The type of positional embeddings to apply to.
249
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
250
+ The maximum number of positional embeddings to apply.
251
+ """
252
+
253
+ def __init__(
254
+ self,
255
+ dim: int,
256
+ num_attention_heads: int,
257
+ attention_head_dim: int,
258
+ dropout=0.0,
259
+ cross_attention_dim: Optional[int] = None,
260
+ activation_fn: str = "geglu",
261
+ num_embeds_ada_norm: Optional[int] = None,
262
+ attention_bias: bool = False,
263
+ only_cross_attention: bool = False,
264
+ double_self_attention: bool = False,
265
+ upcast_attention: bool = False,
266
+ norm_elementwise_affine: bool = True,
267
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
268
+ norm_eps: float = 1e-5,
269
+ final_dropout: bool = False,
270
+ attention_type: str = "default",
271
+ positional_embeddings: Optional[str] = None,
272
+ num_positional_embeddings: Optional[int] = None,
273
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
274
+ ada_norm_bias: Optional[int] = None,
275
+ ff_inner_dim: Optional[int] = None,
276
+ ff_bias: bool = True,
277
+ attention_out_bias: bool = True,
278
+ ):
279
+ super().__init__()
280
+ self.only_cross_attention = only_cross_attention
281
+
282
+ # We keep these boolean flags for backward-compatibility.
283
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
284
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
285
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
286
+ self.use_layer_norm = norm_type == "layer_norm"
287
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
288
+
289
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
290
+ raise ValueError(
291
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
292
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
293
+ )
294
+
295
+ self.norm_type = norm_type
296
+ self.num_embeds_ada_norm = num_embeds_ada_norm
297
+
298
+ if positional_embeddings and (num_positional_embeddings is None):
299
+ raise ValueError(
300
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
301
+ )
302
+
303
+ if positional_embeddings == "sinusoidal":
304
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
305
+ else:
306
+ self.pos_embed = None
307
+
308
+ # Define 3 blocks. Each block has its own normalization layer.
309
+ # 1. Self-Attn
310
+ if norm_type == "ada_norm":
311
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
312
+ elif norm_type == "ada_norm_zero":
313
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
314
+ elif norm_type == "ada_norm_continuous":
315
+ self.norm1 = AdaLayerNormContinuous(
316
+ dim,
317
+ ada_norm_continous_conditioning_embedding_dim,
318
+ norm_elementwise_affine,
319
+ norm_eps,
320
+ ada_norm_bias,
321
+ "rms_norm",
322
+ )
323
+ else:
324
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
325
+
326
+ self.attn1 = Attention(
327
+ query_dim=dim,
328
+ heads=num_attention_heads,
329
+ dim_head=attention_head_dim,
330
+ dropout=dropout,
331
+ bias=attention_bias,
332
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
333
+ upcast_attention=upcast_attention,
334
+ out_bias=attention_out_bias,
335
+ )
336
+
337
+ # 2. Cross-Attn
338
+ if cross_attention_dim is not None or double_self_attention:
339
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
340
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
341
+ # the second cross attention block.
342
+ if norm_type == "ada_norm":
343
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
344
+ elif norm_type == "ada_norm_continuous":
345
+ self.norm2 = AdaLayerNormContinuous(
346
+ dim,
347
+ ada_norm_continous_conditioning_embedding_dim,
348
+ norm_elementwise_affine,
349
+ norm_eps,
350
+ ada_norm_bias,
351
+ "rms_norm",
352
+ )
353
+ else:
354
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
355
+
356
+ self.attn2 = Attention(
357
+ query_dim=dim,
358
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
359
+ heads=num_attention_heads,
360
+ dim_head=attention_head_dim,
361
+ dropout=dropout,
362
+ bias=attention_bias,
363
+ upcast_attention=upcast_attention,
364
+ out_bias=attention_out_bias,
365
+ ) # is self-attn if encoder_hidden_states is none
366
+ else:
367
+ self.norm2 = None
368
+ self.attn2 = None
369
+
370
+ # 3. Feed-forward
371
+ if norm_type == "ada_norm_continuous":
372
+ self.norm3 = AdaLayerNormContinuous(
373
+ dim,
374
+ ada_norm_continous_conditioning_embedding_dim,
375
+ norm_elementwise_affine,
376
+ norm_eps,
377
+ ada_norm_bias,
378
+ "layer_norm",
379
+ )
380
+
381
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]:
382
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
383
+ elif norm_type == "layer_norm_i2vgen":
384
+ self.norm3 = None
385
+
386
+ self.ff = FeedForward(
387
+ dim,
388
+ dropout=dropout,
389
+ activation_fn=activation_fn,
390
+ final_dropout=final_dropout,
391
+ inner_dim=ff_inner_dim,
392
+ bias=ff_bias,
393
+ )
394
+
395
+ # 4. Fuser
396
+ if attention_type == "gated" or attention_type == "gated-text-image":
397
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
398
+
399
+ # 5. Scale-shift for PixArt-Alpha.
400
+ if norm_type == "ada_norm_single":
401
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
402
+
403
+ # let chunk size default to None
404
+ self._chunk_size = None
405
+ self._chunk_dim = 0
406
+
407
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
408
+ # Sets chunk feed-forward
409
+ self._chunk_size = chunk_size
410
+ self._chunk_dim = dim
411
+
412
+ def forward(
413
+ self,
414
+ hidden_states: torch.Tensor,
415
+ attention_mask: Optional[torch.Tensor] = None,
416
+ encoder_hidden_states: Optional[torch.Tensor] = None,
417
+ encoder_attention_mask: Optional[torch.Tensor] = None,
418
+ timestep: Optional[torch.LongTensor] = None,
419
+ cross_attention_kwargs: Dict[str, Any] = None,
420
+ class_labels: Optional[torch.LongTensor] = None,
421
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
422
+ ) -> torch.Tensor:
423
+ if cross_attention_kwargs is not None:
424
+ if cross_attention_kwargs.get("scale", None) is not None:
425
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
426
+
427
+ # Notice that normalization is always applied before the real computation in the following blocks.
428
+ # 0. Self-Attention
429
+ batch_size = hidden_states.shape[0]
430
+
431
+ if self.norm_type == "ada_norm":
432
+ norm_hidden_states = self.norm1(hidden_states, timestep)
433
+ elif self.norm_type == "ada_norm_zero":
434
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
435
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
436
+ )
437
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
438
+ norm_hidden_states = self.norm1(hidden_states)
439
+ elif self.norm_type == "ada_norm_continuous":
440
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
441
+ elif self.norm_type == "ada_norm_single":
442
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
443
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
444
+ ).chunk(6, dim=1)
445
+ norm_hidden_states = self.norm1(hidden_states)
446
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
447
+ norm_hidden_states = norm_hidden_states.squeeze(1)
448
+ else:
449
+ raise ValueError("Incorrect norm used")
450
+
451
+ if self.pos_embed is not None:
452
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
453
+
454
+ # 1. Prepare GLIGEN inputs
455
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
456
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
457
+
458
+ attn_output = self.attn1(
459
+ norm_hidden_states,
460
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
461
+ attention_mask=attention_mask,
462
+ **cross_attention_kwargs,
463
+ )
464
+ if self.norm_type == "ada_norm_zero":
465
+ attn_output = gate_msa.unsqueeze(1) * attn_output
466
+ elif self.norm_type == "ada_norm_single":
467
+ attn_output = gate_msa * attn_output
468
+
469
+ hidden_states = attn_output + hidden_states
470
+ if hidden_states.ndim == 4:
471
+ hidden_states = hidden_states.squeeze(1)
472
+
473
+ # 1.2 GLIGEN Control
474
+ if gligen_kwargs is not None:
475
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
476
+
477
+ # 3. Cross-Attention
478
+ if self.attn2 is not None:
479
+ if self.norm_type == "ada_norm":
480
+ norm_hidden_states = self.norm2(hidden_states, timestep)
481
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
482
+ norm_hidden_states = self.norm2(hidden_states)
483
+ elif self.norm_type == "ada_norm_single":
484
+ # For PixArt norm2 isn't applied here:
485
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
486
+ norm_hidden_states = hidden_states
487
+ elif self.norm_type == "ada_norm_continuous":
488
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
489
+ else:
490
+ raise ValueError("Incorrect norm")
491
+
492
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
493
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
494
+
495
+ attn_output = self.attn2(
496
+ norm_hidden_states,
497
+ encoder_hidden_states=encoder_hidden_states,
498
+ attention_mask=encoder_attention_mask,
499
+ **cross_attention_kwargs,
500
+ )
501
+ hidden_states = attn_output + hidden_states
502
+
503
+ # 4. Feed-forward
504
+ # i2vgen doesn't have this norm 🤷‍♂️
505
+ if self.norm_type == "ada_norm_continuous":
506
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
507
+ elif not self.norm_type == "ada_norm_single":
508
+ norm_hidden_states = self.norm3(hidden_states)
509
+
510
+ if self.norm_type == "ada_norm_zero":
511
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
512
+
513
+ if self.norm_type == "ada_norm_single":
514
+ norm_hidden_states = self.norm2(hidden_states)
515
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
516
+
517
+ if self._chunk_size is not None:
518
+ # "feed_forward_chunk_size" can be used to save memory
519
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
520
+ else:
521
+ ff_output = self.ff(norm_hidden_states)
522
+
523
+ if self.norm_type == "ada_norm_zero":
524
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
525
+ elif self.norm_type == "ada_norm_single":
526
+ ff_output = gate_mlp * ff_output
527
+
528
+ hidden_states = ff_output + hidden_states
529
+ if hidden_states.ndim == 4:
530
+ hidden_states = hidden_states.squeeze(1)
531
+
532
+ return hidden_states
533
+
534
+
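A minimal smoke test of the block defined above, not part of the original diff. Assumptions: this file is importable as `models.attention` and the class is named `BasicTransformerBlock` (its header falls outside this excerpt); `diffusers` must be installed for the attention and norm helpers it uses.

```python
import torch
from models.attention import BasicTransformerBlock  # assumed module path / class name

block = BasicTransformerBlock(
    dim=320,                   # token channel width
    num_attention_heads=8,
    attention_head_dim=40,     # 8 heads * 40 dims == 320
    cross_attention_dim=None,  # defaults give a pure self-attention block
)
x = torch.randn(2, 64, 320)    # (batch, tokens, channels)
with torch.no_grad():
    y = block(x)
print(y.shape)                 # torch.Size([2, 64, 320])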
535
+ @maybe_allow_in_graph
536
+ class TemporalBasicTransformerBlock(nn.Module):
537
+ r"""
538
+ A basic Transformer block for video like data.
539
+
540
+ Parameters:
541
+ dim (`int`): The number of channels in the input and output.
542
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
543
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
544
+ attention_head_dim (`int`): The number of channels in each head.
545
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
546
+ """
547
+
548
+ def __init__(
549
+ self,
550
+ dim: int,
551
+ time_mix_inner_dim: int,
552
+ num_attention_heads: int,
553
+ attention_head_dim: int,
554
+ cross_attention_dim: Optional[int] = None,
555
+ ):
556
+ super().__init__()
557
+ self.is_res = dim == time_mix_inner_dim
558
+
559
+ self.norm_in = nn.LayerNorm(dim)
560
+
561
+ # Define 3 blocks. Each block has its own normalization layer.
562
+ # 1. Self-Attn
563
+ self.ff_in = FeedForward(
564
+ dim,
565
+ dim_out=time_mix_inner_dim,
566
+ activation_fn="geglu",
567
+ )
568
+
569
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
570
+ self.attn1 = Attention(
571
+ query_dim=time_mix_inner_dim,
572
+ heads=num_attention_heads,
573
+ dim_head=attention_head_dim,
574
+ cross_attention_dim=None,
575
+ )
576
+
577
+ # 2. Cross-Attn
578
+ if cross_attention_dim is not None:
579
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
580
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
581
+ # the second cross attention block.
582
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
583
+ self.attn2 = Attention(
584
+ query_dim=time_mix_inner_dim,
585
+ cross_attention_dim=cross_attention_dim,
586
+ heads=num_attention_heads,
587
+ dim_head=attention_head_dim,
588
+ ) # is self-attn if encoder_hidden_states is none
589
+ else:
590
+ self.norm2 = None
591
+ self.attn2 = None
592
+
593
+ # 3. Feed-forward
594
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
595
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
596
+
597
+ # let chunk size default to None
598
+ self._chunk_size = None
599
+ self._chunk_dim = None
600
+
601
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
602
+ # Sets chunk feed-forward
603
+ self._chunk_size = chunk_size
604
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
605
+ self._chunk_dim = 1
606
+
607
+ def forward(
608
+ self,
609
+ hidden_states: torch.Tensor,
610
+ num_frames: int,
611
+ encoder_hidden_states: Optional[torch.Tensor] = None,
612
+ ) -> torch.Tensor:
613
+ # Notice that normalization is always applied before the real computation in the following blocks.
614
+ # 0. Self-Attention
615
+ batch_size = hidden_states.shape[0]
616
+
617
+ batch_frames, seq_length, channels = hidden_states.shape
618
+ batch_size = batch_frames // num_frames
619
+
620
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
621
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
622
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
623
+
624
+ residual = hidden_states
625
+ hidden_states = self.norm_in(hidden_states)
626
+
627
+ if self._chunk_size is not None:
628
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
629
+ else:
630
+ hidden_states = self.ff_in(hidden_states)
631
+
632
+ if self.is_res:
633
+ hidden_states = hidden_states + residual
634
+
635
+ norm_hidden_states = self.norm1(hidden_states)
636
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
637
+ hidden_states = attn_output + hidden_states
638
+
639
+ # 3. Cross-Attention
640
+ if self.attn2 is not None:
641
+ norm_hidden_states = self.norm2(hidden_states)
642
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
643
+ hidden_states = attn_output + hidden_states
644
+
645
+ # 4. Feed-forward
646
+ norm_hidden_states = self.norm3(hidden_states)
647
+
648
+ if self._chunk_size is not None:
649
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
650
+ else:
651
+ ff_output = self.ff(norm_hidden_states)
652
+
653
+ if self.is_res:
654
+ hidden_states = ff_output + hidden_states
655
+ else:
656
+ hidden_states = ff_output
657
+
658
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
659
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
660
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
661
+
662
+ return hidden_states
663
+
664
+
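The temporal block above folds the frame axis out of the batch so that self-attention mixes frames at each spatial position, then folds it back. A plain-torch sketch of just that reshape (shapes are illustrative):

```python
import torch

batch_size, num_frames, seq_length, channels = 2, 8, 64, 320
x = torch.randn(batch_size * num_frames, seq_length, channels)

# (batch * frames, seq, ch) -> (batch * seq, frames, ch): attention now runs along time
t = x.reshape(batch_size, num_frames, seq_length, channels)
t = t.permute(0, 2, 1, 3).reshape(batch_size * seq_length, num_frames, channels)

# ... temporal self-attention and feed-forward would operate on `t` here ...

# fold back to the original (batch * frames, seq, ch) layout
y = t.reshape(batch_size, seq_length, num_frames, channels)
y = y.permute(0, 2, 1, 3).reshape(batch_size * num_frames, seq_length, channels)
assert y.shape == x.shape
```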
665
+ class SkipFFTransformerBlock(nn.Module):
666
+ def __init__(
667
+ self,
668
+ dim: int,
669
+ num_attention_heads: int,
670
+ attention_head_dim: int,
671
+ kv_input_dim: int,
672
+ kv_input_dim_proj_use_bias: bool,
673
+ dropout=0.0,
674
+ cross_attention_dim: Optional[int] = None,
675
+ attention_bias: bool = False,
676
+ attention_out_bias: bool = True,
677
+ ):
678
+ super().__init__()
679
+ if kv_input_dim != dim:
680
+ self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
681
+ else:
682
+ self.kv_mapper = None
683
+
684
+ self.norm1 = RMSNorm(dim, 1e-06)
685
+
686
+ self.attn1 = Attention(
687
+ query_dim=dim,
688
+ heads=num_attention_heads,
689
+ dim_head=attention_head_dim,
690
+ dropout=dropout,
691
+ bias=attention_bias,
692
+ cross_attention_dim=cross_attention_dim,
693
+ out_bias=attention_out_bias,
694
+ )
695
+
696
+ self.norm2 = RMSNorm(dim, 1e-06)
697
+
698
+ self.attn2 = Attention(
699
+ query_dim=dim,
700
+ cross_attention_dim=cross_attention_dim,
701
+ heads=num_attention_heads,
702
+ dim_head=attention_head_dim,
703
+ dropout=dropout,
704
+ bias=attention_bias,
705
+ out_bias=attention_out_bias,
706
+ )
707
+
708
+ def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
709
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
710
+
711
+ if self.kv_mapper is not None:
712
+ encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
713
+
714
+ norm_hidden_states = self.norm1(hidden_states)
715
+
716
+ attn_output = self.attn1(
717
+ norm_hidden_states,
718
+ encoder_hidden_states=encoder_hidden_states,
719
+ **cross_attention_kwargs,
720
+ )
721
+
722
+ hidden_states = attn_output + hidden_states
723
+
724
+ norm_hidden_states = self.norm2(hidden_states)
725
+
726
+ attn_output = self.attn2(
727
+ norm_hidden_states,
728
+ encoder_hidden_states=encoder_hidden_states,
729
+ **cross_attention_kwargs,
730
+ )
731
+
732
+ hidden_states = attn_output + hidden_states
733
+
734
+ return hidden_states
735
+
736
+
737
+ class FeedForward(nn.Module):
738
+ r"""
739
+ A feed-forward layer.
740
+
741
+ Parameters:
742
+ dim (`int`): The number of channels in the input.
743
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
744
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
745
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
746
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
747
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
748
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
749
+ """
750
+
751
+ def __init__(
752
+ self,
753
+ dim: int,
754
+ dim_out: Optional[int] = None,
755
+ mult: int = 4,
756
+ dropout: float = 0.0,
757
+ activation_fn: str = "geglu",
758
+ final_dropout: bool = False,
759
+ inner_dim=None,
760
+ bias: bool = True,
761
+ ):
762
+ super().__init__()
763
+ if inner_dim is None:
764
+ inner_dim = int(dim * mult)
765
+ dim_out = dim_out if dim_out is not None else dim
766
+
767
+ if activation_fn == "gelu":
768
+ act_fn = GELU(dim, inner_dim, bias=bias)
769
+ if activation_fn == "gelu-approximate":
770
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
771
+ elif activation_fn == "geglu":
772
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
773
+ elif activation_fn == "geglu-approximate":
774
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
775
+
776
+ self.net = nn.ModuleList([])
777
+ # project in
778
+ self.net.append(act_fn)
779
+ # project dropout
780
+ self.net.append(nn.Dropout(dropout))
781
+ # project out
782
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
783
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
784
+ if final_dropout:
785
+ self.net.append(nn.Dropout(dropout))
786
+
787
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
788
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
789
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
790
+ deprecate("scale", "1.0.0", deprecation_message)
791
+ for module in self.net:
792
+ hidden_states = module(hidden_states)
793
+ return hidden_states
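A short usage sketch for `FeedForward` (the module path is an assumption): with the default `"geglu"` activation the input is projected to twice the inner width and one half gates the other before the output projection.

```python
import torch
from models.attention import FeedForward  # assumed module path

ff = FeedForward(dim=320, mult=4, activation_fn="geglu")  # inner_dim = 4 * 320 = 1280
x = torch.randn(2, 64, 320)
with torch.no_grad():
    y = ff(x)
print(y.shape)  # torch.Size([2, 64, 320])
```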
models/modeling_t5.py ADDED
The diff for this file is too large to render. See raw diff
 
models/pipeline.py ADDED
@@ -0,0 +1,963 @@
1
+ # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ from transformers import (
20
+ CLIPTextModelWithProjection,
21
+ CLIPTokenizer,
22
+ T5TokenizerFast,
23
+ )
24
+ from models.modeling_t5 import T5EncoderModel
25
+ from models.VchitectXL import VchitectXLTransformerModel
26
+ from transformers import AutoTokenizer, PretrainedConfig, CLIPTextModel, CLIPTextModelWithProjection
27
+ from diffusers.image_processor import VaeImageProcessor
28
+ from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin
29
+ from diffusers.models.autoencoders import AutoencoderKL
30
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
31
+ from diffusers.utils import (
32
+ is_torch_xla_available,
33
+ logging,
34
+ replace_example_docstring,
35
+ )
36
+ from diffusers.utils.torch_utils import randn_tensor
37
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
38
+
39
+ from op_replace import replace_all_layernorms
40
+ if is_torch_xla_available():
41
+ import torch_xla.core.xla_model as xm
42
+ XLA_AVAILABLE = True
43
+ else:
44
+ XLA_AVAILABLE = False
45
+
46
+ import math
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+ EXAMPLE_DOC_STRING = """
51
+ Examples:
52
+ ```py
53
+ >>> import torch
54
+ >>> from diffusers import VchitectXLPipeline
55
+
56
+ >>> pipe = VchitectXLPipeline.from_pretrained(
57
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
58
+ ... )
59
+ >>> pipe.to("cuda")
60
+ >>> prompt = "A cat holding a sign that says hello world"
61
+ >>> image = pipe(prompt).images[0]
62
+ >>> image.save("sd3.png")
63
+ ```
64
+ """
65
+
66
+
67
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
68
+ def retrieve_timesteps(
69
+ scheduler,
70
+ num_inference_steps: Optional[int] = None,
71
+ device: Optional[Union[str, torch.device]] = None,
72
+ timesteps: Optional[List[int]] = None,
73
+ sigmas: Optional[List[float]] = None,
74
+ **kwargs,
75
+ ):
76
+ """
77
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
78
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
79
+
80
+ Args:
81
+ scheduler (`SchedulerMixin`):
82
+ The scheduler to get timesteps from.
83
+ num_inference_steps (`int`):
84
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
85
+ must be `None`.
86
+ device (`str` or `torch.device`, *optional*):
87
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
88
+ timesteps (`List[int]`, *optional*):
89
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
90
+ `num_inference_steps` and `sigmas` must be `None`.
91
+ sigmas (`List[float]`, *optional*):
92
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
93
+ `num_inference_steps` and `timesteps` must be `None`.
94
+
95
+ Returns:
96
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
97
+ second element is the number of inference steps.
98
+ """
99
+ if timesteps is not None and sigmas is not None:
100
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
101
+ if timesteps is not None:
102
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
103
+ if not accepts_timesteps:
104
+ raise ValueError(
105
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
106
+ f" timestep schedules. Please check whether you are using the correct scheduler."
107
+ )
108
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
109
+ timesteps = scheduler.timesteps
110
+ num_inference_steps = len(timesteps)
111
+ elif sigmas is not None:
112
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
113
+ if not accept_sigmas:
114
+ raise ValueError(
115
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
116
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
117
+ )
118
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
119
+ timesteps = scheduler.timesteps
120
+ num_inference_steps = len(timesteps)
121
+ else:
122
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
123
+ timesteps = scheduler.timesteps
124
+ return timesteps, num_inference_steps
125
+
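A usage sketch for `retrieve_timesteps` with the flow-matching scheduler this pipeline loads (assumes a recent `diffusers` release and that this file is importable as `models.pipeline`; values are illustrative):

```python
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from models.pipeline import retrieve_timesteps  # this file

scheduler = FlowMatchEulerDiscreteScheduler()
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler, num_inference_steps=28, device="cpu"
)
print(num_inference_steps)  # 28
print(timesteps[:3])        # first few entries of the descending timestep schedule
```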
126
+ def load_text_encoders(load_path, class_one, class_two, class_three, precision="fp16"):
127
+ text_encoder_one = class_one.from_pretrained(
128
+ load_path, subfolder="text_encoder", revision=None, variant=precision
129
+ )
130
+ text_encoder_two = class_two.from_pretrained(
131
+ load_path, subfolder="text_encoder_2", revision=None, variant=precision
132
+ )
133
+ text_encoder_three = class_three.from_pretrained(
134
+ load_path, subfolder="text_encoder_3", revision=None, variant=precision
135
+ )
136
+ return text_encoder_one, text_encoder_two, text_encoder_three
137
+
138
+
139
+ def import_model_class_from_model_name_or_path(
140
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
141
+ ):
142
+ text_encoder_config = PretrainedConfig.from_pretrained(
143
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
144
+ )
145
+ model_class = text_encoder_config.architectures[0]
146
+ if model_class == "CLIPTextModelWithProjection":
147
+ from transformers import CLIPTextModelWithProjection
148
+
149
+ return CLIPTextModelWithProjection
150
+ elif model_class == "T5EncoderModel":
151
+ from transformers import T5EncoderModel
152
+
153
+ return T5EncoderModel
154
+ else:
155
+ raise ValueError(f"{model_class} is not supported.")
156
+
157
+ class VchitectXLPipeline():
158
+ r"""
159
+ Args:
160
+ transformer ([`VchitectXLTransformerModel`]):
161
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
162
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
163
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
164
+ vae ([`AutoencoderKL`]):
165
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
166
+ text_encoder ([`CLIPTextModelWithProjection`]):
167
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
168
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
169
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
170
+ as its dimension.
171
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
172
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
173
+ specifically the
174
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
175
+ variant.
176
+ text_encoder_3 ([`T5EncoderModel`]):
177
+ Frozen text-encoder. Stable Diffusion 3 uses
178
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
179
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
180
+ tokenizer (`CLIPTokenizer`):
181
+ Tokenizer of class
182
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
183
+ tokenizer_2 (`CLIPTokenizer`):
184
+ Second Tokenizer of class
185
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
186
+ tokenizer_3 (`T5TokenizerFast`):
187
+ Tokenizer of class
188
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
189
+ """
190
+
191
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
192
+ _optional_components = []
193
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
194
+
195
+ def __init__(
196
+ self,
197
+ load_path = None,
198
+ device = None,
199
+ ):
200
+ super().__init__()
201
+
202
+ # Load the tokenizers
203
+ self.tokenizer = CLIPTokenizer.from_pretrained(
204
+ load_path,
205
+ subfolder="tokenizer",
206
+ revision=None,
207
+ )
208
+ self.tokenizer_2 = CLIPTokenizer.from_pretrained(
209
+ load_path,
210
+ subfolder="tokenizer_2",
211
+ revision=None,
212
+ )
213
+ self.tokenizer_3 = T5TokenizerFast.from_pretrained(
214
+ load_path,
215
+ subfolder="tokenizer_3",
216
+ revision=None,
217
+ )
218
+
219
+ # import correct text encoder classes
220
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
221
+ load_path, None
222
+ )
223
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
224
+ load_path, None, subfolder="text_encoder_2"
225
+ )
226
+ text_encoder_cls_three = import_model_class_from_model_name_or_path(
227
+ load_path, None, subfolder="text_encoder_3"
228
+ )
229
+ # Load scheduler and models
230
+ self.text_encoder, self.text_encoder_2, self.text_encoder_3 = load_text_encoders(
231
+ load_path, text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three, None
232
+ )
233
+ self.text_encoder, self.text_encoder_2, self.text_encoder_3 = self.text_encoder.to(device), self.text_encoder_2.to(device), self.text_encoder_3.to(device)
234
+
235
+ self.vae = AutoencoderKL.from_pretrained(
236
+ load_path,
237
+ subfolder="vae",
238
+ revision=None,
239
+ variant=None,
240
+ ).to(device)
241
+
242
+ # self.transformer = VchitectXLTransformerModel.from_pretrained_temporal(load_path,torch_dtype=torch.bfloat16,logger=None,subfolder="transformer").to(device)
243
+ self.transformer = VchitectXLTransformerModel.from_pretrained(load_path,torch_dtype=torch.bfloat16,subfolder="transformer").to(device)
244
+ self.transformer = replace_all_layernorms(self.transformer)
245
+ self.transformer.eval()
246
+
247
+ self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
248
+ load_path, subfolder="scheduler"
249
+ )
250
+
251
+ self.execution_device = "cuda"
252
+
253
+ self.vae_scale_factor = (
254
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
255
+ )
256
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
257
+ self.tokenizer_max_length = (
258
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
259
+ )
260
+ self.max_sequence_length_t5 = 256
261
+ self.default_sample_size = (
262
+ self.transformer.config.sample_size
263
+ if hasattr(self, "transformer") and self.transformer is not None
264
+ else 128
265
+ )
266
+
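For orientation, a back-of-the-envelope sketch of the latent geometry implied by the `vae_scale_factor` computed above (the VAE channel list is a typical SD-style assumption, not read from the checkpoint):

```python
block_out_channels = (128, 256, 512, 512)               # illustrative 4-level VAE config
vae_scale_factor = 2 ** (len(block_out_channels) - 1)   # -> 8
height, width = 432, 768
print(vae_scale_factor, height // vae_scale_factor, width // vae_scale_factor)  # 8 54 96
```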
267
+ def _get_t5_prompt_embeds(
268
+ self,
269
+ prompt: Union[str, List[str]] = None,
270
+ num_images_per_prompt: int = 1,
271
+ device: Optional[torch.device] = None,
272
+ dtype: Optional[torch.dtype] = None,
273
+ ):
274
+ device = device or self.execution_device
275
+ dtype = dtype or self.text_encoder.dtype
276
+
277
+ prompt = [prompt] if isinstance(prompt, str) else prompt
278
+ batch_size = len(prompt)
279
+
280
+ if self.text_encoder_3 is None:
281
+ return torch.zeros(
282
+ (batch_size, self.max_sequence_length_t5, self.transformer.config.joint_attention_dim),
283
+ device=device,
284
+ dtype=dtype,
285
+ )
286
+
287
+ text_inputs = self.tokenizer_3(
288
+ prompt,
289
+ padding="max_length",
290
+ max_length=self.max_sequence_length_t5,
291
+ truncation=True,
292
+ add_special_tokens=True,
293
+ return_tensors="pt",
294
+ )
295
+ text_input_ids = text_inputs.input_ids
296
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
297
+
298
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
299
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.max_sequence_length_t5 - 1 : -1])
300
+ logger.warning(
301
+ "The following part of your input was truncated because the T5 text encoder can only handle sequences up to"
302
+ f" {self.max_sequence_length_t5} tokens: {removed_text}"
303
+ )
304
+
305
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
306
+
307
+ dtype = self.text_encoder_3.dtype
308
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
309
+
310
+ _, seq_len, _ = prompt_embeds.shape
311
+
312
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
313
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
314
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
315
+
316
+ return prompt_embeds
317
+
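The "repeat then view" at the end of `_get_t5_prompt_embeds` duplicates each prompt's embedding once per requested sample. A plain-torch shape sketch (sizes are illustrative):

```python
import torch

batch_size, seq_len, dim, num_images_per_prompt = 2, 256, 4096, 3
prompt_embeds = torch.randn(batch_size, seq_len, dim)

prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
print(prompt_embeds.shape)  # torch.Size([6, 256, 4096])
```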
318
+ def _get_clip_prompt_embeds(
319
+ self,
320
+ prompt: Union[str, List[str]],
321
+ num_images_per_prompt: int = 1,
322
+ device: Optional[torch.device] = None,
323
+ clip_skip: Optional[int] = None,
324
+ clip_model_index: int = 0,
325
+ ):
326
+ device = device or self.execution_device
327
+
328
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
329
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
330
+
331
+ tokenizer = clip_tokenizers[clip_model_index]
332
+ text_encoder = clip_text_encoders[clip_model_index]
333
+
334
+ prompt = [prompt] if isinstance(prompt, str) else prompt
335
+ batch_size = len(prompt)
336
+
337
+ text_inputs = tokenizer(
338
+ prompt,
339
+ padding="max_length",
340
+ max_length=self.tokenizer_max_length,
341
+ truncation=True,
342
+ return_tensors="pt",
343
+ )
344
+
345
+ text_input_ids = text_inputs.input_ids
346
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
347
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
348
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
349
+ logger.warning(
350
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
351
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
352
+ )
353
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
354
+ pooled_prompt_embeds = prompt_embeds[0]
355
+
356
+ if clip_skip is None:
357
+ prompt_embeds = prompt_embeds.hidden_states[-2]
358
+ else:
359
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
360
+
361
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
362
+
363
+ _, seq_len, _ = prompt_embeds.shape
364
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
365
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
366
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
367
+
368
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
369
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
370
+
371
+ return prompt_embeds, pooled_prompt_embeds
372
+
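The `clip_skip` branch above selects which CLIP hidden state feeds the transformer: `hidden_states[-2]` is the penultimate layer, and `clip_skip=k` walks `k` layers further back. A tiny sketch of the indexing (the layer count is illustrative):

```python
hidden_states = [f"layer_{i}" for i in range(13)]  # stand-in for CLIP's hidden states

print(hidden_states[-2])                # 'layer_11' -- default: penultimate layer
clip_skip = 1
print(hidden_states[-(clip_skip + 2)])  # 'layer_10' -- one layer earlier
```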
373
+ def encode_prompt(
374
+ self,
375
+ prompt: Union[str, List[str]],
376
+ prompt_2: Union[str, List[str]],
377
+ prompt_3: Union[str, List[str]],
378
+ device: Optional[torch.device] = None,
379
+ num_images_per_prompt: int = 1,
380
+ do_classifier_free_guidance: bool = True,
381
+ negative_prompt: Optional[Union[str, List[str]]] = None,
382
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
383
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
384
+ prompt_embeds: Optional[torch.FloatTensor] = None,
385
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
386
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
387
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
388
+ clip_skip: Optional[int] = None,
389
+ ):
390
+ r"""
391
+
392
+ Args:
393
+ prompt (`str` or `List[str]`, *optional*):
394
+ prompt to be encoded
395
+ prompt_2 (`str` or `List[str]`, *optional*):
396
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
397
+ used in all text-encoders
398
+ prompt_3 (`str` or `List[str]`, *optional*):
399
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
400
+ used in all text-encoders
401
+ device: (`torch.device`):
402
+ torch device
403
+ num_images_per_prompt (`int`):
404
+ number of images that should be generated per prompt
405
+ do_classifier_free_guidance (`bool`):
406
+ whether to use classifier free guidance or not
407
+ negative_prompt (`str` or `List[str]`, *optional*):
408
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
409
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
410
+ less than `1`).
411
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
412
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
413
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
414
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
415
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
416
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
417
+ prompt_embeds (`torch.FloatTensor`, *optional*):
418
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
419
+ provided, text embeddings will be generated from `prompt` input argument.
420
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
421
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
422
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
423
+ argument.
424
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
425
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
426
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
427
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
428
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
429
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
430
+ input argument.
431
+ clip_skip (`int`, *optional*):
432
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
433
+ the output of the pre-final layer will be used for computing the prompt embeddings.
434
+ """
435
+ device = device or self.execution_device
436
+
437
+ prompt = [prompt] if isinstance(prompt, str) else prompt
438
+ if prompt is not None:
439
+ batch_size = len(prompt)
440
+ else:
441
+ batch_size = prompt_embeds.shape[0]
442
+
443
+ if prompt_embeds is None:
444
+ prompt_2 = prompt_2 or prompt
445
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
446
+
447
+ prompt_3 = prompt_3 or prompt
448
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
449
+
450
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
451
+ prompt=prompt,
452
+ device=device,
453
+ num_images_per_prompt=num_images_per_prompt,
454
+ clip_skip=clip_skip,
455
+ clip_model_index=0,
456
+ )
457
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
458
+ prompt=prompt_2,
459
+ device=device,
460
+ num_images_per_prompt=num_images_per_prompt,
461
+ clip_skip=clip_skip,
462
+ clip_model_index=1,
463
+ )
464
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
465
+
466
+ t5_prompt_embed = self._get_t5_prompt_embeds(
467
+ prompt=prompt_3,
468
+ num_images_per_prompt=num_images_per_prompt,
469
+ device=device,
470
+ )
471
+
472
+ clip_prompt_embeds = torch.nn.functional.pad(
473
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
474
+ )
475
+
476
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
477
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
478
+
479
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
480
+ negative_prompt = negative_prompt or ""
481
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
482
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
483
+
484
+ # normalize str to list
485
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
486
+ negative_prompt_2 = (
487
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
488
+ )
489
+ negative_prompt_3 = (
490
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
491
+ )
492
+
493
+ if prompt is not None and type(prompt) is not type(negative_prompt):
494
+ raise TypeError(
495
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
496
+ f" {type(prompt)}."
497
+ )
498
+ elif batch_size != len(negative_prompt):
499
+ raise ValueError(
500
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
501
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
502
+ " the batch size of `prompt`."
503
+ )
504
+
505
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
506
+ negative_prompt,
507
+ device=device,
508
+ num_images_per_prompt=num_images_per_prompt,
509
+ clip_skip=None,
510
+ clip_model_index=0,
511
+ )
512
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
513
+ negative_prompt_2,
514
+ device=device,
515
+ num_images_per_prompt=num_images_per_prompt,
516
+ clip_skip=None,
517
+ clip_model_index=1,
518
+ )
519
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
520
+
521
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
522
+ prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, device=device
523
+ )
524
+
525
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
526
+ negative_clip_prompt_embeds,
527
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
528
+ )
529
+
530
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
531
+ negative_pooled_prompt_embeds = torch.cat(
532
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
533
+ )
534
+
535
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
536
+
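A shape-only sketch of the joint text context assembled in `encode_prompt`: the two pooled CLIP vectors are concatenated on the channel axis, the 77-token CLIP sequence is zero-padded up to the T5 width, and the two sequences are concatenated along the token axis. The hidden sizes below (768, 1280, 4096) are typical values for this SD3-style encoder stack, assumed for illustration:

```python
import torch

clip_seq = torch.randn(1, 77, 768 + 1280)  # concatenated CLIP hidden states (assumed widths)
t5_seq = torch.randn(1, 256, 4096)         # T5 hidden states (max_sequence_length_t5 = 256)

clip_seq = torch.nn.functional.pad(clip_seq, (0, t5_seq.shape[-1] - clip_seq.shape[-1]))
context = torch.cat([clip_seq, t5_seq], dim=-2)
print(context.shape)  # torch.Size([1, 333, 4096])
```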
537
+ def check_inputs(
538
+ self,
539
+ prompt,
540
+ prompt_2,
541
+ prompt_3,
542
+ height,
543
+ width,
544
+ negative_prompt=None,
545
+ negative_prompt_2=None,
546
+ negative_prompt_3=None,
547
+ prompt_embeds=None,
548
+ negative_prompt_embeds=None,
549
+ pooled_prompt_embeds=None,
550
+ negative_pooled_prompt_embeds=None,
551
+ callback_on_step_end_tensor_inputs=None,
552
+ ):
553
+ if height % 8 != 0 or width % 8 != 0:
554
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
555
+
556
+ if callback_on_step_end_tensor_inputs is not None and not all(
557
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
558
+ ):
559
+ raise ValueError(
560
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
561
+ )
562
+
563
+ if prompt is not None and prompt_embeds is not None:
564
+ raise ValueError(
565
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
566
+ " only forward one of the two."
567
+ )
568
+ elif prompt_2 is not None and prompt_embeds is not None:
569
+ raise ValueError(
570
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
571
+ " only forward one of the two."
572
+ )
573
+ elif prompt_3 is not None and prompt_embeds is not None:
574
+ raise ValueError(
575
+ f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
576
+ " only forward one of the two."
577
+ )
578
+ elif prompt is None and prompt_embeds is None:
579
+ raise ValueError(
580
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
581
+ )
582
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
583
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
584
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
585
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
586
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
587
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
588
+
589
+ if negative_prompt is not None and negative_prompt_embeds is not None:
590
+ raise ValueError(
591
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
592
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
593
+ )
594
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
595
+ raise ValueError(
596
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
597
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
598
+ )
599
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
600
+ raise ValueError(
601
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
602
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
603
+ )
604
+
605
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
606
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
607
+ raise ValueError(
608
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
609
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
610
+ f" {negative_prompt_embeds.shape}."
611
+ )
612
+
613
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
614
+ raise ValueError(
615
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
616
+ )
617
+
618
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
619
+ raise ValueError(
620
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
621
+ )
622
+
623
+ def prepare_latents(
624
+ self,
625
+ batch_size,
626
+ num_channels_latents,
627
+ height,
628
+ width,
629
+ frames,
630
+ dtype,
631
+ device,
632
+ generator,
633
+ latents=None,
634
+ ):
635
+ if latents is not None:
636
+ return latents.to(device=device, dtype=dtype)
637
+ # latent shape, e.g. (1, 60, 16, 32, 32) = (batch, frames, channels, height, width)
638
+ shape = (
639
+ batch_size,
640
+ frames,
641
+ num_channels_latents,
642
+ int(height) // self.vae_scale_factor,
643
+ int(width) // self.vae_scale_factor,
644
+ )
645
+
646
+ if isinstance(generator, list) and len(generator) != batch_size:
647
+ raise ValueError(
648
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
649
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
650
+ )
651
+
652
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
653
+
654
+ return latents
655
+
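A shape sketch for the video latents prepared above (all numbers are illustrative): with a scale factor of 8, 24 frames and 16 latent channels, a 768x432 request yields a `(1, 24, 16, 54, 96)` tensor.

```python
import torch

batch_size, frames, num_channels_latents, vae_scale_factor = 1, 24, 16, 8
height, width = 432, 768
shape = (batch_size, frames, num_channels_latents,
         height // vae_scale_factor, width // vae_scale_factor)
latents = torch.randn(shape)
print(latents.shape)  # torch.Size([1, 24, 16, 54, 96])
```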
656
+ @property
657
+ def guidance_scale(self):
658
+ return self._guidance_scale
659
+
660
+ @property
661
+ def clip_skip(self):
662
+ return self._clip_skip
663
+
664
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
665
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
666
+ # corresponds to doing no classifier free guidance.
667
+ @property
668
+ def do_classifier_free_guidance(self):
669
+ return self._guidance_scale > 1
670
+
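For reference, the standard classifier-free guidance combination that `guidance_scale` feeds into later in the denoising loop (generic formulation, not copied from this file; tensors are dummies):

```python
import torch

guidance_scale = 7.0
noise_pred_uncond = torch.randn(1, 16, 54, 96)
noise_pred_text = torch.randn(1, 16, 54, 96)

# w = 1 disables guidance; larger w pulls the prediction toward the text-conditioned branch
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```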
671
+ @property
672
+ def joint_attention_kwargs(self):
673
+ return self._joint_attention_kwargs
674
+
675
+ @property
676
+ def num_timesteps(self):
677
+ return self._num_timesteps
678
+
679
+ @property
680
+ def interrupt(self):
681
+ return self._interrupt
682
+
683
+ @torch.no_grad()
684
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
685
+ def __call__(
686
+ self,
687
+ prompt: Union[str, List[str]] = None,
688
+ prompt_2: Optional[Union[str, List[str]]] = None,
689
+ prompt_3: Optional[Union[str, List[str]]] = None,
690
+ height: Optional[int] = None,
691
+ width: Optional[int] = None,
692
+ frames: Optional[int] = None,
693
+ num_inference_steps: int = 28,
694
+ timesteps: List[int] = None,
695
+ guidance_scale: float = 7.0,
696
+ negative_prompt: Optional[Union[str, List[str]]] = None,
697
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
698
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
699
+ num_images_per_prompt: Optional[int] = 1,
700
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
701
+ latents: Optional[torch.FloatTensor] = None,
702
+ prompt_embeds: Optional[torch.FloatTensor] = None,
703
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
704
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
705
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
706
+ output_type: Optional[str] = "pil",
707
+ return_dict: bool = True,
708
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
709
+ clip_skip: Optional[int] = None,
710
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
711
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
712
+ ):
713
+ r"""
714
+ Function invoked when calling the pipeline for generation.
715
+
716
+ Args:
717
+ prompt (`str` or `List[str]`, *optional*):
718
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
719
+ instead.
720
+ prompt_2 (`str` or `List[str]`, *optional*):
721
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
722
+ will be used instead.
723
+ prompt_3 (`str` or `List[str]`, *optional*):
724
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
725
+ will be used instead.
726
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
727
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
728
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
729
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
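+ frames (`int`, *optional*, defaults to 24):
+ The number of video frames to generate.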
730
+ num_inference_steps (`int`, *optional*, defaults to 28):
731
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
732
+ expense of slower inference.
733
+ timesteps (`List[int]`, *optional*):
734
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
735
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
736
+ passed will be used. Must be in descending order.
737
+ guidance_scale (`float`, *optional*, defaults to 7.0):
738
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
739
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
740
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
741
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
742
+ usually at the expense of lower image quality.
743
+ negative_prompt (`str` or `List[str]`, *optional*):
744
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
745
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
746
+ less than `1`).
747
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
748
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
749
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
750
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
751
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
752
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
753
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
754
+ The number of images to generate per prompt.
755
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
756
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
757
+ to make generation deterministic.
758
+ latents (`torch.FloatTensor`, *optional*):
759
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
760
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
761
+ tensor will be generated by sampling using the supplied random `generator`.
762
+ prompt_embeds (`torch.FloatTensor`, *optional*):
763
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
764
+ provided, text embeddings will be generated from `prompt` input argument.
765
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
766
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
767
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
768
+ argument.
769
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
770
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
771
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
772
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
773
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
774
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
775
+ input argument.
776
+ output_type (`str`, *optional*, defaults to `"pil"`):
777
+ The output format of the generated image. Choose between
778
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
779
+ return_dict (`bool`, *optional*, defaults to `True`):
780
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
781
+ of a plain tuple.
782
+ joint_attention_kwargs (`dict`, *optional*):
783
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
784
+ `self.processor` in
785
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
786
+ callback_on_step_end (`Callable`, *optional*):
787
+ A function that is called at the end of each denoising step during inference. The function is called
788
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
789
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
790
+ `callback_on_step_end_tensor_inputs`.
791
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
792
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
793
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
794
+ `._callback_tensor_inputs` attribute of your pipeline class.
795
+
796
+ Examples:
797
+
798
+ Returns:
799
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
800
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
801
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
802
+ """
803
+
+         height = height or self.default_sample_size * self.vae_scale_factor
+         width = width or self.default_sample_size * self.vae_scale_factor
+         frames = frames or 24
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt,
+             prompt_2,
+             prompt_3,
+             height,
+             width,
+             negative_prompt=negative_prompt,
+             negative_prompt_2=negative_prompt_2,
+             negative_prompt_3=negative_prompt_3,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+         )
+
+         self._guidance_scale = guidance_scale
+         self._clip_skip = clip_skip
+         self._joint_attention_kwargs = joint_attention_kwargs
+         self._interrupt = False
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self.execution_device
+
+         # 3. Encode input prompt
+         (
+             prompt_embeds,
+             negative_prompt_embeds,
+             pooled_prompt_embeds,
+             negative_pooled_prompt_embeds,
+         ) = self.encode_prompt(
+             prompt=prompt,
+             prompt_2=prompt_2,
+             prompt_3=prompt_3,
+             negative_prompt=negative_prompt,
+             negative_prompt_2=negative_prompt_2,
+             negative_prompt_3=negative_prompt_3,
+             do_classifier_free_guidance=self.do_classifier_free_guidance,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             device=device,
+             clip_skip=self.clip_skip,
+             num_images_per_prompt=num_images_per_prompt,
+         )
+
+         if self.do_classifier_free_guidance:
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+         # 4. Prepare timesteps
+         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+         self._num_timesteps = len(timesteps)
+
+         # 5. Prepare latent variables
+         num_channels_latents = self.transformer.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             frames,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 6. Denoising loop
+         # with self.progress_bar(total=num_inference_steps) as progress_bar:
+         from tqdm import tqdm
+         for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
+             if self.interrupt:
+                 continue
+
+             # expand the latents if we are doing classifier free guidance
+             latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+             timestep = t.expand(latents.shape[0])
+             # the unconditional and text-conditioned branches are run as two separate forward passes
+             noise_pred_uncond = self.transformer(
+                 hidden_states=latent_model_input[0, :].unsqueeze(0),
+                 timestep=timestep,
+                 encoder_hidden_states=prompt_embeds[0, :].unsqueeze(0),
+                 pooled_projections=pooled_prompt_embeds[0, :].unsqueeze(0),
+                 joint_attention_kwargs=self.joint_attention_kwargs,
+                 return_dict=False,
+             )[0]
+
+             noise_pred_text = self.transformer(
+                 hidden_states=latent_model_input[1, :].unsqueeze(0),
+                 timestep=timestep,
+                 encoder_hidden_states=prompt_embeds[1, :].unsqueeze(0),
+                 pooled_projections=pooled_prompt_embeds[1, :].unsqueeze(0),
+                 joint_attention_kwargs=self.joint_attention_kwargs,
+                 return_dict=False,
+             )[0]
+             # ramp the effective guidance scale over the course of sampling (cosine schedule)
+             self._guidance_scale = 1 + guidance_scale * (
+                 (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+             )
+             # perform guidance
+             if self.do_classifier_free_guidance:
+                 # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+             # compute the previous noisy sample x_t -> x_t-1
+             latents_dtype = latents.dtype
+             latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+             if latents.dtype != latents_dtype:
+                 if torch.backends.mps.is_available():
+                     # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                     latents = latents.to(latents_dtype)
+
+             if callback_on_step_end is not None:
+                 callback_kwargs = {}
+                 for k in callback_on_step_end_tensor_inputs:
+                     callback_kwargs[k] = locals()[k]
+                 callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                 latents = callback_outputs.pop("latents", latents)
+                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                 negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                 negative_pooled_prompt_embeds = callback_outputs.pop(
+                     "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                 )
+
+             # call the callback, if provided
+             # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+             #     progress_bar.update()
+
+             if XLA_AVAILABLE:
+                 xm.mark_step()
+
+         # if output_type == "latent":
+         #     image = latents
+
+         # else:
+         latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+         videos = []
+         for v_idx in range(latents.shape[1]):
+             image = self.vae.decode(latents[:, v_idx], return_dict=False)[0]
+             image = self.image_processor.postprocess(image, output_type=output_type)
+             videos.append(image[0])
+
+         return videos
+
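For orientation, here is a minimal sketch of how the `__call__` method above might be driven once a pipeline instance exists. The keyword names (`frames`, `guidance_scale`, `num_inference_steps`, `output_type`) come from the code above; the pipeline construction, prompts, and seed are placeholder assumptions, not part of this commit.

```python
import torch

# Hypothetical usage sketch -- assumes `pipe` is an already-constructed instance of
# the pipeline class this __call__ belongs to, living on a CUDA device.
generator = torch.Generator(device="cuda").manual_seed(42)

frames = pipe(
    prompt="a red panda drinking tea, cinematic lighting",  # placeholder prompt
    negative_prompt="blurry, low quality",
    num_inference_steps=50,   # passed to retrieve_timesteps above
    guidance_scale=7.5,       # base value; the loop layers a cosine ramp on top of it
    frames=24,                # the method defaults to 24 when this is None
    generator=generator,
    output_type="pil",
)

# __call__ returns a plain list of decoded frames (PIL images for output_type="pil"),
# one entry per video frame, which can then be written out with the helpers in utils.py.
print(f"decoded {len(frames)} frames")
```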
models/utils.py ADDED
@@ -0,0 +1,32 @@
+ import torch
+ from torch import nn
+
+
+ class RMSNorm(nn.Module):
+     """
+     Initialize the RMSNorm normalization layer.
+
+     Args:
+         dim (int): The dimension of the input tensor.
+         eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+     Attributes:
+         eps (float): A small value added to the denominator for numerical stability.
+         weight (nn.Parameter): Learnable scaling parameter.
+
+     """
+
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x: torch.Tensor):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x: torch.Tensor):
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+     def reset_parameters(self):
+         torch.nn.init.ones_(self.weight)  # type: ignore
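As a quick sanity check of the `RMSNorm` module added above (the tensor shapes and the import path are illustrative assumptions):

```python
import torch
from models.utils import RMSNorm  # module path as added in this commit

norm = RMSNorm(dim=1024)
x = torch.randn(2, 77, 1024)

# The forward pass normalizes over the last dimension in float32, casts back to
# the input dtype, then rescales by the learned per-channel weight.
y = norm(x)
assert y.shape == x.shape
```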
op_replace.py ADDED
@@ -0,0 +1,38 @@
+ import torch
+ try:
+     from apex.normalization import FusedLayerNorm
+ except ImportError:
+     try:
+         from xformers.triton import FusedLayerNorm
+     except ImportError:
+         FusedLayerNorm = None
+
+
+ def replace_all_layernorms(model):
+     if FusedLayerNorm is None:
+         print("WARNING: neither apex.normalization nor xformers.triton FusedLayerNorm is available, "
+               "skipping FusedLayerNorm")
+         return model
+     for name, module in model.named_children():
+         if isinstance(module, torch.nn.LayerNorm):
+             setattr(model, name, FusedLayerNorm(
+                 module.normalized_shape, module.eps, module.elementwise_affine))
+         else:
+             replace_all_layernorms(module)
+     return model
+
+
+ def replace_all_groupnorms(model):
+     try:
+         from apex.contrib.group_norm import GroupNorm
+     except ImportError:
+         print("WARNING: apex.contrib.group_norm not found, skipping apex GroupNorm")
+         return model
+     for name, module in model.named_children():
+         if isinstance(module, torch.nn.GroupNorm):
+             setattr(model, name, GroupNorm(
+                 module.num_groups, module.num_channels,
+                 eps=module.eps, affine=module.affine))
+         else:
+             replace_all_groupnorms(module)
+     return model
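A hedged sketch of how these replacement helpers might be applied; the toy `Sequential` below is only illustrative (in this repo the real targets would be the transformer and VAE sub-modules):

```python
import torch
from op_replace import replace_all_layernorms, replace_all_groupnorms

# Toy stand-in for a real network containing both norm types.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.LayerNorm(64),
    torch.nn.GroupNorm(num_groups=8, num_channels=64),
)

# Both helpers walk named_children() recursively and swap the norms in place,
# printing a warning and returning the model unchanged when the fused
# apex/xformers kernels are not installed.
model = replace_all_layernorms(model)
model = replace_all_groupnorms(model)
print(model)
```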
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ diffusers
+ accelerate
+ transformers
+ gradio
+ torch
+ torchvision
+ torchdiffeq
+ click
+ einops
+ moviepy
+ sentencepiece
+ Pillow==9.5.0
utils.py ADDED
@@ -0,0 +1,225 @@
+ import imageio
+ import numpy as np
+ from typing import List
+ from io import BytesIO
+ from PIL import Image
+ import subprocess
+ from time import sleep
+ import os
+ import torch
+ import torch.distributed as dist
+ from torch.distributed.fsdp import (
+     FullyShardedDataParallel as FSDP,
+     StateDictType, FullStateDictConfig,
+ )
+ from torch.distributed.checkpoint.state_dict import (
+     StateDictOptions,
+     get_model_state_dict,
+     get_optimizer_state_dict,
+     set_optimizer_state_dict
+ )
+
+ _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP = None, None
+ _LOCAL_RANK, _LOCAL_WORLD_SIZE = -1, -1
+
+ def images_to_gif_bytes(images: List, duration: int = 1000) -> bytes:
+     with BytesIO() as output_buffer:
+         # Save the first image
+         images[0].save(output_buffer,
+                        format='GIF',
+                        save_all=True,
+                        append_images=images[1:],
+                        duration=duration,
+                        loop=0)  # 0 means the GIF will loop indefinitely
+
+         # Get the byte array from the buffer
+         gif_bytes = output_buffer.getvalue()
+     return gif_bytes
+
+
+ def save_as_gif(images: List, file_path: str, duration: int = 1000):
+     with open(file_path, "wb") as f:
+         f.write(images_to_gif_bytes(images, duration))
+
+
+ def images_to_mp4_bytes(images: List[Image.Image], duration: float = 1000) -> bytes:
+     with BytesIO() as output_buffer:
+         with imageio.get_writer(output_buffer, format='mp4', fps=1 / (duration / 1000)) as writer:
+             for img in images:
+                 writer.append_data(np.array(img))
+         mp4_bytes = output_buffer.getvalue()
+     return mp4_bytes
+
+
+ def save_as_mp4(images: List[Image.Image], file_path: str, duration: float = 1000):
+     with open(file_path, "wb") as f:
+         f.write(images_to_mp4_bytes(images, duration))
+
+
+
+ def get_local_rank() -> int:
+     return _LOCAL_RANK
+
+
+ def get_local_world_size() -> int:
+     return _LOCAL_WORLD_SIZE
+
+
+ def _setup_dist_env_from_slurm(args):
+     while not os.environ.get("MASTER_ADDR", ""):
+         try:
+             os.environ["MASTER_ADDR"] = subprocess.check_output(
+                 "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" %
+                 os.environ['SLURM_NODELIST'],
+                 shell=True,
+             ).decode().strip()
+         except Exception:
+             pass
+         sleep(1)
+     os.environ["MASTER_PORT"] = str(int(args.master_port) + 1)
+     os.environ["RANK"] = os.environ["SLURM_PROCID"]
+     os.environ["WORLD_SIZE"] = os.environ["SLURM_NPROCS"]
+     os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
+     os.environ["LOCAL_WORLD_SIZE"] = os.environ["SLURM_NTASKS_PER_NODE"]
+
+ def init_process_groups(args):
+     if any([
+         x not in os.environ
+         for x in ["RANK", "WORLD_SIZE", "MASTER_PORT", "MASTER_ADDR"]
+     ]):
+         _setup_dist_env_from_slurm(args)
+
+     dist.init_process_group("nccl")
+     torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
+
+     global _LOCAL_RANK, _LOCAL_WORLD_SIZE
+     _LOCAL_RANK = int(os.environ["LOCAL_RANK"])
+     _LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"])
+
+     global _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP
+     local_ranks, local_world_sizes = [torch.empty(
+         [dist.get_world_size()], dtype=torch.long, device="cuda"
+     ) for _ in (0, 1)]
+     dist.all_gather_into_tensor(local_ranks, torch.tensor(get_local_rank(), device="cuda"))
+     dist.all_gather_into_tensor(local_world_sizes, torch.tensor(get_local_world_size(), device="cuda"))
+     local_ranks, local_world_sizes = local_ranks.tolist(), local_world_sizes.tolist()
+
+     node_ranks = [[0]]
+     for i in range(1, dist.get_world_size()):
+         if len(node_ranks[-1]) == local_world_sizes[i - 1]:
+             node_ranks.append([])
+         else:
+             assert local_world_sizes[i] == local_world_sizes[i - 1]
+         node_ranks[-1].append(i)
+     for ranks in node_ranks:
+         group = dist.new_group(ranks)
+         if dist.get_rank() in ranks:
+             assert _INTRA_NODE_PROCESS_GROUP is None
+             _INTRA_NODE_PROCESS_GROUP = group
+     assert _INTRA_NODE_PROCESS_GROUP is not None
+
+     if min(local_world_sizes) == max(local_world_sizes):
+         for i in range(get_local_world_size()):
+             group = dist.new_group(list(range(i, dist.get_world_size(), get_local_world_size())))
+             if i == get_local_rank():
+                 assert _INTER_NODE_PROCESS_GROUP is None
+                 _INTER_NODE_PROCESS_GROUP = group
+         assert _INTER_NODE_PROCESS_GROUP is not None
+
+
+ def get_intra_node_process_group():
+     assert _INTRA_NODE_PROCESS_GROUP is not None, \
+         "Intra-node process group is not initialized."
+     return _INTRA_NODE_PROCESS_GROUP
+
+
+ def get_inter_node_process_group():
+     assert _INTER_NODE_PROCESS_GROUP is not None, \
+         "Inter-node process group is not initialized."
+     return _INTER_NODE_PROCESS_GROUP
+
+
+ def save_model_fsdp_only(rank, model, output_folder, filename):
+     with FSDP.state_dict_type(
+         model,
+         StateDictType.FULL_STATE_DICT,
+         FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
+     ):
+         consolidated_model_state_dict = model.state_dict()
+         if rank == 0:
+             torch.save(
+                 consolidated_model_state_dict,
+                 os.path.join(output_folder, filename),
+             )
+     del consolidated_model_state_dict
+     dist.barrier()
+
+
+ def save_model(rank, model, output_folder, filename):
+     state_dict = get_model_state_dict(
+         model,
+         options=StateDictOptions(
+             full_state_dict=True,
+             cpu_offload=True,
+         ),
+     )
+     if rank == 0:
+         torch.save(state_dict, os.path.join(output_folder, filename))
+     del state_dict
+     dist.barrier()
+
+
+ def load_model(rank, model, output_folder, filename, strict=True, logger=None):
+     if rank == 0:
+         missing_keys, unexpected_keys = model.load_state_dict(
+             torch.load(os.path.join(output_folder, filename), map_location="cpu"),
+             strict=strict
+         )
+         if logger is not None:
+             logger.info("Model initialization result:")
+             logger.info(f"  Missing keys: {missing_keys}")
+             logger.info(f"  Unexpected keys: {unexpected_keys}")
+     dist.barrier()
+
+
+ def save_optimizer_fsdp_only(model, optimizer, output_folder, filename):
+     with FSDP.state_dict_type(
+         model,
+         StateDictType.LOCAL_STATE_DICT,
+     ):
+         torch.save(optimizer.state_dict(), os.path.join(output_folder, filename))
+     dist.barrier()
+
+
+ def load_optimizer_fsdp_only(optimizer, output_folder, filename):
+     optimizer.load_state_dict(
+         torch.load(os.path.join(output_folder, filename), map_location="cpu")
+     )
+     dist.barrier()
+
+
+ def save_optimizer(model, optimizer, output_folder, filename):
+     state_dict = get_optimizer_state_dict(
+         model,
+         optimizer,
+         options=StateDictOptions(
+             full_state_dict=False,
+             cpu_offload=True,
+         ),
+     )
+     torch.save(state_dict, os.path.join(output_folder, filename))
+     dist.barrier()
+
+
+ def load_optimizer(model, optimizer, output_folder, filename):
+     state_dict = torch.load(os.path.join(output_folder, filename), map_location="cpu")
+     set_optimizer_state_dict(
+         model,
+         optimizer,
+         optim_state_dict=state_dict,
+         options=StateDictOptions(
+             full_state_dict=False,
+             strict=True
+         ),
+     )
+     dist.barrier()
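To make the checkpoint helpers above concrete, here is a minimal sketch of how they might be wired into an FSDP training setup. Everything outside the imported helpers (the `args` namespace, the toy model, the paths and filenames) is a placeholder assumption; the matching `load_model` / `load_optimizer` helpers would be called the same way when resuming.

```python
# Launch with e.g.: torchrun --nproc_per_node=2 checkpoint_demo.py
import argparse
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from utils import init_process_groups, save_model, save_optimizer  # utils.py at the repo root

# Placeholder arguments; only `master_port` is read, and only on the SLURM fallback path.
args = argparse.Namespace(master_port=29500)
init_process_groups(args)          # NCCL init plus intra/inter-node process groups
rank = dist.get_rank()

# Tiny stand-in network; the real code would wrap the video transformer instead.
model = FSDP(torch.nn.Linear(128, 128).cuda())
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

os.makedirs("checkpoints", exist_ok=True)

# Full (unsharded) weights are gathered and written by rank 0 only; optimizer state
# stays sharded, so each rank writes its own file.
save_model(rank, model, output_folder="checkpoints", filename="model_step_1000.pt")
save_optimizer(model, optimizer, output_folder="checkpoints",
               filename=f"optim_rank{rank}_step_1000.pt")
```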