weiyuchoumou526 committed
Commit 080c0c2 · 1 Parent(s): d3fea23

Initial commit

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +3 -0
  2. .gitignore +8 -0
  3. README.md +68 -7
  4. app.py +679 -0
  5. assets/rose_logo.png +3 -0
  6. assets/step1.png +3 -0
  7. assets/step2.png +3 -0
  8. assets/step3.png +3 -0
  9. configs/wan2.1/wan_civitai.yaml +39 -0
  10. inpainter/base_inpainter.py +374 -0
  11. requirements.txt +17 -0
  12. rose/__init__.py +0 -0
  13. rose/data/bucket_sampler.py +379 -0
  14. rose/data/dataset_image.py +76 -0
  15. rose/data/dataset_image_video.py +589 -0
  16. rose/data/dataset_video.py +262 -0
  17. rose/dist/__init__.py +43 -0
  18. rose/dist/fsdp.py +43 -0
  19. rose/dist/fuser.py +54 -0
  20. rose/dist/wan_xfuser.py +111 -0
  21. rose/models/__init__.py +6 -0
  22. rose/models/cache_utils.py +74 -0
  23. rose/models/diff_mask_predictor.py +42 -0
  24. rose/models/wan_image_encoder.py +553 -0
  25. rose/models/wan_text_encoder.py +376 -0
  26. rose/models/wan_transformer3d.py +1203 -0
  27. rose/models/wan_vae.py +705 -0
  28. rose/models/wan_xlm_roberta.py +170 -0
  29. rose/pipeline/__init__.py +6 -0
  30. rose/pipeline/pipeline_wan_fun.py +558 -0
  31. rose/pipeline/pipeline_wan_fun_control.py +723 -0
  32. rose/pipeline/pipeline_wan_fun_inpaint.py +729 -0
  33. rose/utils/__init__.py +0 -0
  34. rose/utils/discrete_sampler.py +46 -0
  35. rose/utils/fp8_optimization.py +56 -0
  36. rose/utils/lora_utils.py +516 -0
  37. rose/utils/utils.py +318 -0
  38. test_sample/test-sample0.mp4 +3 -0
  39. test_sample/test-sample1.mp4 +3 -0
  40. test_sample/test-sample2.mp4 +3 -0
  41. test_sample/test-sample3.mp4 +3 -0
  42. test_sample/test-sample4.mp4 +3 -0
  43. tools/__init__.py +0 -0
  44. tools/base_segmenter.py +129 -0
  45. tools/interact_tools.py +99 -0
  46. tools/mask_painter.py +288 -0
  47. tools/painter.py +215 -0
  48. track_anything.py +40 -0
  49. tracker/base_tracker.py +103 -0
  50. tracker/config/__init__.py +1 -0
.gitattributes CHANGED
@@ -31,5 +31,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,8 @@
+ __pycache__/
+ .vscode/
+ docs/
+ debug_images/
+ images/
+ result/
+ vots/
+ vots.py
README.md CHANGED
@@ -1,13 +1,74 @@
+
+
+
+
  ---
- title: ROSE
- emoji: 👁
- colorFrom: green
- colorTo: indigo
+ title: ROSE Awesome Space
+ emoji: 🚀
+ colorFrom: blue
+ colorTo: pink
  sdk: gradio
- sdk_version: 5.34.2
+ sdk_version: 4.15.0
  app_file: app.py
  pinned: false
- license: apache-2.0
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ## Get Started
+
+ 1. Install ProPainter Dependencies
+
+ You can follow the [Dependencies and Installation](https://github.com/Luo-Yihang/ProPainter-pr/tree/dev_yihang#dependencies-and-installation) guide.
+
+ 2. Install Demo Dependencies
+ ```shell
+ # install python dependencies
+ pip3 install -r requirements.txt
+
+ # run the demo
+ python app.py
+ ```
+
+ ## Usage Guidance
+ * Step 1: Upload your video and click the `Get video info` button.
+ ![Step 1](./assets/step1.png)
+
+ * Step 2:
+   1. *[Optional]* Specify the tracking period for the currently added mask by dragging the `Track start frame` or `Track end frame` slider.
+   2. Click the image on the left to select the mask area.
+   3. - Click `Add mask` if you are satisfied with the mask, or
+      - *[Optional]* Click `Clear clicks` if you want to reselect the mask area, or
+      - *[Optional]* Click `Remove mask` to remove all masks.
+   4. *[Optional]* Go back to step 2.1 to add another mask.
+ ![Step 2](./assets/step2.png)
+
+ * Step 3:
+   1. Click the `Tracking` button to track the masks for the whole video.
+   2. *[Optional]* Select the ProPainter parameters in the `ProPainter Parameters` dropdown.
+   3. Click `Inpainting` to get the inpainting results.
+ ![Step 3](./assets/step3.png)
+
+ *You can always refer to the `Highlighted Text` box on the page for guidance on the next step!*
+
+ ## Citation
+ If you find our repo useful for your research, please consider citing our paper:
+ ```bibtex
+ @inproceedings{zhou2023propainter,
+    title={{ProPainter}: Improving Propagation and Transformer for Video Inpainting},
+    author={Zhou, Shangchen and Li, Chongyi and Chan, Kelvin C.K. and Loy, Chen Change},
+    booktitle={Proceedings of IEEE International Conference on Computer Vision (ICCV)},
+    year={2023}
+ }
+ ```
+
+ ## License
+
+ This project is licensed under the <a rel="license" href="./LICENSE">NTU S-Lab License 1.0</a>. Redistribution and use should follow this license.
+
+ ## Acknowledgements
+
+ The project harnesses the capabilities of [Track Anything](https://github.com/gaomingqi/Track-Anything), [Segment Anything](https://github.com/facebookresearch/segment-anything) and [Cutie](https://github.com/hkchengrex/Cutie). Thanks for their awesome work.
app.py ADDED
@@ -0,0 +1,679 @@
1
+ import sys
2
+ sys.path.append("./")
3
+
4
+ import os
5
+ import json
6
+ import time
7
+ import psutil
8
+ import argparse
9
+
10
+ import cv2
11
+ import torch
12
+ import torchvision
13
+ import numpy as np
14
+ import gradio as gr
15
+
16
+ from tools.painter import mask_painter
17
+ from track_anything import TrackingAnything
18
+
19
+ from utils.misc import get_device
20
+ from utils.download_util import load_file_from_url
21
+ from transformers import AutoTokenizer, AutoModel
22
+ from omegaconf import OmegaConf
23
+ from torchvision.transforms import functional as TF
24
+ from torchvision.utils import save_image
25
+ from einops import rearrange
26
+ from PIL import Image
27
+
28
+ from rose.models import AutoencoderKLWan, CLIPModel, WanT5EncoderModel, WanTransformer3DModel
29
+ from rose.pipeline import WanFunInpaintPipeline
30
+ from diffusers import FlowMatchEulerDiscreteScheduler
31
+
32
+ def filter_kwargs(cls, kwargs):
33
+ import inspect
34
+ sig = inspect.signature(cls.__init__)
35
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
36
+ return {k: v for k, v in kwargs.items() if k in valid_params}
37
+
38
+ # pretrained_model_path = "./models/Diffusion_Transformer/Wan2.1-Fun-1.3B-InP"
39
+ pretrained_model_path = "alibaba-pai/Wan2.1-Fun-1.3B-InP"
40
+ transformer_path = "Kunbyte/ROSE"
41
+ # config_path = "configs/wan2.1/wan_civitai.yaml"
42
+ config_path = "./configs/wan2.1/wan_civitai.yaml"
43
+ config = OmegaConf.load(config_path)
44
+
45
+ tokenizer_subpath = config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')
46
+ tokenizer = AutoTokenizer.from_pretrained(f"{pretrained_model_path}/{tokenizer_subpath}")
47
+
48
+ text_encoder_subpath = config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')
49
+ text_encoder = WanT5EncoderModel.from_pretrained(
50
+ f"{pretrained_model_path}/{text_encoder_subpath}",
51
+ additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
52
+ low_cpu_mem_usage=True,
53
+ )
54
+
55
+ transformer_subpath = config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')
56
+ transformer3d = WanTransformer3DModel.from_pretrained(
57
+ f"{transformer_path}/{transformer_subpath}",
58
+ transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
59
+ )
60
+
61
+ image_encoder_subpath = config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')
62
+ clip_image_encoder = CLIPModel.from_pretrained(f"{pretrained_model_path}/{image_encoder_subpath}")
63
+
64
+ vae_subpath = config['vae_kwargs'].get('vae_subpath', 'vae')
65
+ vae = AutoencoderKLWan.from_pretrained(
66
+ f"{pretrained_model_path}/{vae_subpath}",
67
+ additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
68
+ )
69
+
70
+ noise_scheduler = FlowMatchEulerDiscreteScheduler(
71
+ **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs']))
72
+ )
73
+
74
+ # tokenizer = AutoTokenizer.from_pretrained(
75
+ # os.path.join(pretrained_model_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
76
+ # )
77
+ # text_encoder = WanT5EncoderModel.from_pretrained(
78
+ # os.path.join(pretrained_model_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
79
+ # additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
80
+ # low_cpu_mem_usage=True,
81
+ # )
82
+ # clip_image_encoder = CLIPModel.from_pretrained(
83
+ # os.path.join(pretrained_model_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
84
+ # )
85
+ # vae = AutoencoderKLWan.from_pretrained(
86
+ # os.path.join(pretrained_model_path, config['vae_kwargs'].get('vae_subpath', 'vae')),
87
+ # additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
88
+ # )
89
+ # transformer3d = WanTransformer3DModel.from_pretrained(
90
+ # os.path.join(transformer_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
91
+ # transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
92
+ # )
93
+ # noise_scheduler = FlowMatchEulerDiscreteScheduler(
94
+ # **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs']))
95
+ # )
96
+
97
+ pipeline = WanFunInpaintPipeline(
98
+ vae=vae,
99
+ text_encoder=text_encoder,
100
+ tokenizer=tokenizer,
101
+ transformer=transformer3d,
102
+ scheduler=noise_scheduler,
103
+ clip_image_encoder=clip_image_encoder
104
+ ).to("cuda", torch.float16)
105
+
106
+
107
+ def parse_augment():
108
+ parser = argparse.ArgumentParser()
109
+ parser.add_argument('--device', type=str, default=None)
110
+ parser.add_argument('--sam_model_type', type=str, default="vit_h")
111
+ parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications")
112
+ parser.add_argument('--mask_save', action='store_true', default=False)
113
+ args = parser.parse_args()
114
+
115
+ if not args.device:
116
+ args.device = str(get_device())
117
+
118
+ return args
119
+
120
+ # convert points input to prompt state
121
+ def get_prompt(click_state, click_input):
122
+ inputs = json.loads(click_input)
123
+ points = click_state[0]
124
+ labels = click_state[1]
125
+ for input in inputs:
126
+ points.append(input[:2])
127
+ labels.append(input[2])
128
+ click_state[0] = points
129
+ click_state[1] = labels
130
+ prompt = {
131
+ "prompt_type":["click"],
132
+ "input_point":click_state[0],
133
+ "input_label":click_state[1],
134
+ "multimask_output":"True",
135
+ }
136
+ return prompt
137
+
138
+ # extract frames from upload video
139
+ def get_frames_from_video(video_input, video_state):
140
+ """
141
+ Args:
142
+ video_path:str
143
+ timestamp:float64
144
+ Return
145
+ [[0:nearest_frame], [nearest_frame:], nearest_frame]
146
+ """
147
+ video_path = video_input
148
+ frames = []
149
+ user_name = time.time()
150
+ operation_log = [("[Must Do]", "Click image"), (": Video uploaded! Try to click the image shown in step2 to add masks.\n", None)]
151
+ try:
152
+ cap = cv2.VideoCapture(video_path)
153
+ fps = cap.get(cv2.CAP_PROP_FPS)
154
+ while cap.isOpened():
155
+ ret, frame = cap.read()
156
+ if ret == True:
157
+ current_memory_usage = psutil.virtual_memory().percent
158
+ frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
159
+ if current_memory_usage > 90:
160
+ operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")]
161
+ print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.")
162
+ break
163
+ else:
164
+ break
165
+ except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
166
+ print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
167
+ image_size = (frames[0].shape[0],frames[0].shape[1])
168
+ # initialize video_state
169
+ video_state = {
170
+ "user_name": user_name,
171
+ "video_name": os.path.split(video_path)[-1],
172
+ "origin_images": frames,
173
+ "painted_images": frames.copy(),
174
+ "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames),
175
+ "logits": [None]*len(frames),
176
+ "select_frame_number": 0,
177
+ "fps": fps
178
+ }
179
+ video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
180
+ model.samcontroler.sam_controler.reset_image()
181
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
182
+ return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
183
+ gr.update(visible=True), gr.update(visible=True), \
184
+ gr.update(visible=True), gr.update(visible=True),\
185
+ gr.update(visible=True), gr.update(visible=True), \
186
+ gr.update(visible=True), gr.update(visible=True), \
187
+ gr.update(visible=True), gr.update(visible=True), \
188
+ gr.update(visible=True), gr.update(visible=True, choices=[], value=[]), \
189
+ gr.update(visible=True, value=operation_log), gr.update(visible=True, value=operation_log)
190
+
191
+ # get the select frame from gradio slider
192
+ def select_template(image_selection_slider, video_state, interactive_state, mask_dropdown=None):  # mask_dropdown is unused here; the slider release event does not supply it
193
+
194
+ # images = video_state[1]
195
+ image_selection_slider -= 1
196
+ video_state["select_frame_number"] = image_selection_slider
197
+
198
+ # once select a new template frame, set the image in sam
199
+
200
+ model.samcontroler.sam_controler.reset_image()
201
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
202
+
203
+ operation_log = [("",""), ("Select tracking start frame {}. Try to click the image to add masks for tracking.".format(image_selection_slider),"Normal")]
204
+
205
+ return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log, operation_log
206
+
207
+ # set the tracking end frame
208
+ def get_end_number(track_pause_number_slider, video_state, interactive_state):
209
+ interactive_state["track_end_number"] = track_pause_number_slider
210
+ operation_log = [("",""),("Select tracking finish frame {}.Try to click the image to add masks for tracking.".format(track_pause_number_slider),"Normal")]
211
+
212
+ return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log, operation_log
213
+
214
+ # use sam to get the mask
215
+ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
216
+ """
217
+ Args:
218
+ template_frame: PIL.Image
219
+ point_prompt: flag for positive or negative button click
220
+ click_state: [[points], [labels]]
221
+ """
222
+ if point_prompt == "Positive":
223
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
224
+ interactive_state["positive_click_times"] += 1
225
+ else:
226
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
227
+ interactive_state["negative_click_times"] += 1
228
+
229
+ # prompt for sam model
230
+ model.samcontroler.sam_controler.reset_image()
231
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
232
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
233
+
234
+ mask, logit, painted_image = model.first_frame_click(
235
+ image=video_state["origin_images"][video_state["select_frame_number"]],
236
+ points=np.array(prompt["input_point"]),
237
+ labels=np.array(prompt["input_label"]),
238
+ multimask=prompt["multimask_output"],
239
+ )
240
+
241
+ video_state["masks"][video_state["select_frame_number"]] = mask
242
+ video_state["logits"][video_state["select_frame_number"]] = logit
243
+ video_state["painted_images"][video_state["select_frame_number"]] = painted_image
244
+
245
+ operation_log = [("[Must Do]", "Add mask"), (": add the current displayed mask for video segmentation.\n", None),
246
+ ("[Optional]", "Remove mask"), (": remove all added masks.\n", None),
247
+ ("[Optional]", "Clear clicks"), (": clear current displayed mask.\n", None),
248
+ ("[Optional]", "Click image"), (": Try to click the image shown in step2 if you want to generate more masks.\n", None)]
249
+ return painted_image, video_state, interactive_state, operation_log, operation_log
250
+
251
+ def add_multi_mask(video_state, interactive_state, mask_dropdown):
252
+ try:
253
+ mask = video_state["masks"][video_state["select_frame_number"]]
254
+ interactive_state["multi_mask"]["masks"].append(mask)
255
+ interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
256
+ mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
257
+ select_frame, _, _ = show_mask(video_state, interactive_state, mask_dropdown)
258
+ operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")]
259
+ except:
260
+ operation_log = [("Please click the image in step2 to generate masks.", "Error"), ("","")]
261
+ return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log, operation_log
262
+
263
+ def clear_click(video_state, click_state):
264
+ click_state = [[],[]]
265
+ template_frame = video_state["origin_images"][video_state["select_frame_number"]]
266
+ operation_log = [("",""), ("Cleared points history and refresh the image.","Normal")]
267
+ return template_frame, click_state, operation_log, operation_log
268
+
269
+ def remove_multi_mask(interactive_state, mask_dropdown):
270
+ interactive_state["multi_mask"]["mask_names"]= []
271
+ interactive_state["multi_mask"]["masks"] = []
272
+
273
+ operation_log = [("",""), ("Remove all masks. Try to add new masks","Normal")]
274
+ return interactive_state, gr.update(choices=[],value=[]), operation_log, operation_log
275
+
276
+ def show_mask(video_state, interactive_state, mask_dropdown):
277
+ mask_dropdown.sort()
278
+ select_frame = video_state["origin_images"][video_state["select_frame_number"]]
279
+ for i in range(len(mask_dropdown)):
280
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
281
+ mask = interactive_state["multi_mask"]["masks"][mask_number]
282
+ select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
283
+
284
+ operation_log = [("",""), ("Added masks {}. If you want to do the inpainting with current masks, please go to step3, and click the Tracking button first and then Inpainting button.".format(mask_dropdown),"Normal")]
285
+ return select_frame, operation_log, operation_log
286
+
287
+ # tracking vos
288
+ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
289
+ operation_log = [("",""), ("Tracking finished! Try to click the Inpainting button to get the inpainting result.","Normal")]
290
+ model.cutie.clear_memory()
291
+ if interactive_state["track_end_number"]:
292
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
293
+ else:
294
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
295
+
296
+ if interactive_state["multi_mask"]["masks"]:
297
+ if len(mask_dropdown) == 0:
298
+ mask_dropdown = ["mask_001"]
299
+ mask_dropdown.sort()
300
+ template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
301
+ for i in range(1,len(mask_dropdown)):
302
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
303
+ template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
304
+ video_state["masks"][video_state["select_frame_number"]]= template_mask
305
+ else:
306
+ template_mask = video_state["masks"][video_state["select_frame_number"]]
307
+
308
+ fps = float(video_state["fps"])
309
+ # operation error
310
+ if len(np.unique(template_mask))==1:
311
+ template_mask[0][0]=1
312
+ operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")]
313
+ # return video_output, video_state, interactive_state, operation_error
314
+ masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
315
+ # clear GPU memory
316
+ model.cutie.clear_memory()
317
+
318
+ if interactive_state["track_end_number"]:
319
+ video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
320
+ video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
321
+ video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images
322
+ else:
323
+ video_state["masks"][video_state["select_frame_number"]:] = masks
324
+ video_state["logits"][video_state["select_frame_number"]:] = logits
325
+ video_state["painted_images"][video_state["select_frame_number"]:] = painted_images
326
+
327
+ video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
328
+ interactive_state["inference_times"] += 1
329
+
330
+ print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(interactive_state["inference_times"],
331
+ interactive_state["positive_click_times"]+interactive_state["negative_click_times"],
332
+ interactive_state["positive_click_times"],
333
+ interactive_state["negative_click_times"]))
334
+
335
+ #### shanggao code for mask save
336
+ if interactive_state["mask_save"]:
337
+ if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])):
338
+ os.makedirs('./result/mask/{}'.format(video_state["video_name"].split('.')[0]))
339
+ i = 0
340
+ print("save mask")
341
+ for mask in video_state["masks"]:
342
+ np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask)
343
+ i+=1
344
+ # save_mask(video_state["masks"], video_state["video_name"])
345
+ #### shanggao code for mask save
346
+ return video_output, video_state, interactive_state, operation_log, operation_log
347
+
348
+ def inpaint_video(video_state, *_):
349
+ operation_log = [("", ""), ("Inpainting finished!", "Normal")]
350
+
351
+ # import pdb;pdb.set_trace()
352
+ frames = video_state["origin_images"]
353
+ masks = video_state["masks"]
354
+ # masks = masks * 255
355
+ fps = int(video_state["fps"])
356
+
357
+ total_frames = len(frames)
358
+ target_frame_count = (total_frames - 1) // 16 * 16 + 1
359
+ frames = frames[:target_frame_count]
360
+ masks = masks[:target_frame_count]
361
+
362
+ frames_resized = [cv2.resize(frame, (720, 480), interpolation=cv2.INTER_CUBIC) for frame in frames]
363
+ masks_resized = [cv2.resize(mask, (720, 480), interpolation=cv2.INTER_CUBIC) for mask in masks]
364
+
365
+ with torch.no_grad():
366
+ video_tensor = torch.stack([TF.to_tensor(Image.fromarray(f)) for f in frames_resized], dim=1).unsqueeze(0).to("cuda", torch.float16)
367
+ mask_tensor = torch.stack([TF.to_tensor(Image.fromarray(m*255)) for m in masks_resized], dim=1).unsqueeze(0).to("cuda", torch.float16)
368
+ #video_tensor = torch.stack([torch.from_numpy(f).float() for f in frames_resized], dim=1).unsqueeze(0).to("cuda", torch.bfloat16)
369
+ #mask_tensor = torch.stack([torch.from_numpy(m).float() for m in masks_resized], dim=1).unsqueeze(0).to("cuda", torch.bfloat16)
370
+
371
+ output = pipeline(
372
+ prompt="",
373
+ video=video_tensor,
374
+ mask_video=mask_tensor,
375
+ num_frames=video_tensor.shape[2],
376
+ num_inference_steps=50
377
+ ).videos
378
+
379
+ output = output.clamp(0, 1).cpu()
380
+ output_np = (output[0].permute(1, 2, 3, 0).numpy() * 255).astype(np.uint8)
381
+
382
+ output_path = f"./result/inpaint/{video_state['video_name']}"
383
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
384
+
385
+ torchvision.io.write_video(output_path, torch.from_numpy(output_np), fps=fps, video_codec="libx264")
386
+
387
+ return output_path, operation_log, operation_log
388
+
389
+
390
+ # generate video after vos inference
391
+ def generate_video_from_frames(frames, output_path, fps=30):
392
+ """
393
+ Generates a video from a list of frames.
394
+
395
+ Args:
396
+ frames (list of numpy arrays): The frames to include in the video.
397
+ output_path (str): The path to save the generated video.
398
+ fps (int, optional): The frame rate of the output video. Defaults to 30.
399
+ """
400
+ frames = torch.from_numpy(np.asarray(frames))
401
+ if not os.path.exists(os.path.dirname(output_path)):
402
+ os.makedirs(os.path.dirname(output_path))
403
+ fps = int(fps)
404
+ # import pdb;pdb.set_trace()
405
+ torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
406
+ return output_path
407
+
408
+ def restart():
409
+ operation_log = [("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")]
410
+ return {
411
+ "user_name": "",
412
+ "video_name": "",
413
+ "origin_images": None,
414
+ "painted_images": None,
415
+ "masks": None,
416
+ "inpaint_masks": None,
417
+ "logits": None,
418
+ "select_frame_number": 0,
419
+ "fps": 30
420
+ }, {
421
+ "inference_times": 0,
422
+ "negative_click_times" : 0,
423
+ "positive_click_times": 0,
424
+ "mask_save": args.mask_save,
425
+ "multi_mask": {
426
+ "mask_names": [],
427
+ "masks": []
428
+ },
429
+ "track_end_number": None,
430
+ }, [[],[]], None, None, None, \
431
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\
432
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
433
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
434
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), "", \
435
+ gr.update(visible=True, value=operation_log), gr.update(visible=False, value=operation_log)
436
+
437
+
438
+ # args, defined in track_anything.py
439
+ args = parse_augment()
440
+ pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/'
441
+ sam_checkpoint_url_dict = {
442
+ 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
443
+ 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
444
+ 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
445
+ }
446
+ checkpoint_folder = os.path.join('.', 'weights')
447
+
448
+ sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_folder)
449
+ cutie_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'cutie-base-mega.pth'), checkpoint_folder)
450
+ # propainter_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'ProPainter.pth'), checkpoint_folder)
451
+ # raft_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'raft-things.pth'), checkpoint_folder)
452
+ # flow_completion_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'recurrent_flow_completion.pth'), checkpoint_folder)
453
+
454
+ # initialize sam, cutie, propainter models
455
+ model = TrackingAnything(sam_checkpoint, cutie_checkpoint, args)
456
+
457
+
458
+ title = r"""<h1 align="center">ROSE: Remove Objects with Side Effects in Videos</h1>"""
459
+
460
+ description = r"""
461
+ <center></center>
462
+ <b>Official Gradio demo</b> for <a href='https://github.com/sczhou/ProPainter' target='_blank'><b>Remove Objects with Side Effects in Videos</b></a>.<br>
463
+ 🔥 ROSE is a robust inpainting algorithm.<br>
464
+ 🤗 Try to drop your video, add the masks and get the inpainting results!<br>
465
+ """
466
+
467
+ css = """
468
+ .gradio-container {width: 85% !important; margin: 0 auto !important;}
469
+ .gr-monochrome-group {border-radius: 5px !important; border: revert-layer !important; border-width: 2px !important; color: black !important}
470
+ button {border-radius: 8px !important;}
471
+ .add_button {background-color: #4CAF50 !important;}
472
+ .remove_button {background-color: #f44336 !important;}
473
+ .mask_button_group {gap: 10px !important;}
474
+ .video {height: 300px !important;}
475
+ .image {height: 300px !important;}
476
+ .video .wrap.svelte-lcpz3o {display: flex !important; align-items: center !important; justify-content: center !important;}
477
+ .video .wrap.svelte-lcpz3o > :first-child {height: 100% !important;}
478
+ .margin_center {width: 50% !important; margin: auto !important;}
479
+ .jc_center {justify-content: center !important;}
480
+ body {
481
+ display: flex;
482
+ justify-content: center;
483
+ align-items: center;
484
+ min-height: 100vh;
485
+ margin: 0;
486
+ }
487
+ """
488
+
489
+ with gr.Blocks(theme=gr.themes.Monochrome(), css=css) as iface:
490
+ click_state = gr.State([[],[]])
491
+
492
+ interactive_state = gr.State({
493
+ "inference_times": 0,
494
+ "negative_click_times" : 0,
495
+ "positive_click_times": 0,
496
+ "mask_save": args.mask_save,
497
+ "multi_mask": {
498
+ "mask_names": [],
499
+ "masks": []
500
+ },
501
+ "track_end_number": None,
502
+ }
503
+ )
504
+
505
+ video_state = gr.State(
506
+ {
507
+ "user_name": "",
508
+ "video_name": "",
509
+ "origin_images": None,
510
+ "painted_images": None,
511
+ "masks": None,
512
+ "inpaint_masks": None,
513
+ "logits": None,
514
+ "select_frame_number": 0,
515
+ "fps": 30
516
+ }
517
+ )
518
+
519
+ gr.Markdown(title)
520
+ gr.Markdown(description)
521
+
522
+ with gr.Column():
523
+ # input video
524
+ gr.Markdown("## Step1: Upload video")
525
+ with gr.Row(equal_height=True):
526
+ with gr.Column(scale=2):
527
+ video_input = gr.Video(elem_classes="video")
528
+ extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
529
+ with gr.Column(scale=2):
530
+ run_status = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")],
531
+ color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"})
532
+ video_info = gr.Textbox(label="Video Info")
533
+
534
+
535
+ # add masks
536
+ step2_title = gr.Markdown("---\n## Step2: Add masks", visible=False)
537
+ with gr.Row(equal_height=True):
538
+ with gr.Column(scale=2):
539
+ template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
540
+ image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False)
541
+ track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
542
+ with gr.Column(scale=2, elem_classes="jc_center"):
543
+ run_status2 = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")],
544
+ color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"})
545
+ with gr.Row():
546
+ with gr.Column(scale=2, elem_classes="mask_button_group"):
547
+ clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False)
548
+ remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False, elem_classes="remove_button")
549
+ Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False, elem_classes="add_button")
550
+ point_prompt = gr.Radio(
551
+ choices=["Positive", "Negative"],
552
+ value="Positive",
553
+ label="Point prompt",
554
+ interactive=True,
555
+ visible=False,
556
+ min_width=100,
557
+ scale=1)
558
+ mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
559
+
560
+ # output video
561
+ step3_title = gr.Markdown("---\n## Step3: Track masks and get the inpainting result", visible=False)
562
+ with gr.Row(equal_height=True):
563
+ with gr.Column(scale=2):
564
+ tracking_video_output = gr.Video(visible=False, elem_classes="video")
565
+ tracking_video_predict_button = gr.Button(value="1. Tracking", visible=False, elem_classes="margin_center")
566
+ with gr.Column(scale=2):
567
+ inpaiting_video_output = gr.Video(visible=False, elem_classes="video")
568
+ inpaint_video_predict_button = gr.Button(value="2. Inpainting", visible=False, elem_classes="margin_center")
569
+
570
+ # first step: get the video information
571
+ extract_frames_button.click(
572
+ fn=get_frames_from_video,
573
+ inputs=[
574
+ video_input, video_state
575
+ ],
576
+ outputs=[video_state, video_info, template_frame,
577
+ image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button, template_frame,
578
+ tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button, inpaint_video_predict_button, step2_title, step3_title,mask_dropdown, run_status, run_status2]
579
+ )
580
+
581
+ # second step: select images from slider
582
+ image_selection_slider.release(fn=select_template,
583
+ inputs=[image_selection_slider, video_state, interactive_state],
584
+ outputs=[template_frame, video_state, interactive_state, run_status, run_status2], api_name="select_image")
585
+ track_pause_number_slider.release(fn=get_end_number,
586
+ inputs=[track_pause_number_slider, video_state, interactive_state],
587
+ outputs=[template_frame, interactive_state, run_status, run_status2], api_name="end_image")
588
+
589
+ # click select image to get mask using sam
590
+ template_frame.select(
591
+ fn=sam_refine,
592
+ inputs=[video_state, point_prompt, click_state, interactive_state],
593
+ outputs=[template_frame, video_state, interactive_state, run_status, run_status2]
594
+ )
595
+
596
+ # add different mask
597
+ Add_mask_button.click(
598
+ fn=add_multi_mask,
599
+ inputs=[video_state, interactive_state, mask_dropdown],
600
+ outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status, run_status2]
601
+ )
602
+
603
+ remove_mask_button.click(
604
+ fn=remove_multi_mask,
605
+ inputs=[interactive_state, mask_dropdown],
606
+ outputs=[interactive_state, mask_dropdown, run_status, run_status2]
607
+ )
608
+
609
+ # tracking video from select image and mask
610
+ tracking_video_predict_button.click(
611
+ fn=vos_tracking_video,
612
+ inputs=[video_state, interactive_state, mask_dropdown],
613
+ outputs=[tracking_video_output, video_state, interactive_state, run_status, run_status2]
614
+ )
615
+
616
+ # inpaint video from select image and mask
617
+ inpaint_video_predict_button.click(
618
+ fn=inpaint_video,
619
+ #inputs=[video_state, resize_ratio_number, dilate_radius_number, raft_iter_number, subvideo_length_number, neighbor_length_number, ref_stride_number, mask_dropdown],
620
+ inputs=[video_state, mask_dropdown],
621
+ outputs=[inpaiting_video_output, run_status, run_status2]
622
+ )
623
+
624
+ # click to get mask
625
+ mask_dropdown.change(
626
+ fn=show_mask,
627
+ inputs=[video_state, interactive_state, mask_dropdown],
628
+ outputs=[template_frame, run_status, run_status2]
629
+ )
630
+
631
+ # clear input
632
+ video_input.change(
633
+ fn=restart,
634
+ inputs=[],
635
+ outputs=[
636
+ video_state,
637
+ interactive_state,
638
+ click_state,
639
+ tracking_video_output, inpaiting_video_output,
640
+ template_frame,
641
+ tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
642
+ Add_mask_button, template_frame, tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button,inpaint_video_predict_button, step2_title, step3_title, mask_dropdown, video_info, run_status, run_status2
643
+ ],
644
+ queue=False,
645
+ show_progress=False)
646
+
647
+ video_input.clear(
648
+ fn=restart,
649
+ inputs=[],
650
+ outputs=[
651
+ video_state,
652
+ interactive_state,
653
+ click_state,
654
+ tracking_video_output, inpaiting_video_output,
655
+ template_frame,
656
+ tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
657
+ Add_mask_button, template_frame, tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button,inpaint_video_predict_button, step2_title, step3_title, mask_dropdown, video_info, run_status, run_status2
658
+ ],
659
+ queue=False,
660
+ show_progress=False)
661
+
662
+ # points clear
663
+ clear_button_click.click(
664
+ fn = clear_click,
665
+ inputs = [video_state, click_state,],
666
+ outputs = [template_frame,click_state, run_status, run_status2],
667
+ )
668
+
669
+ # set example
670
+ gr.Markdown("## Examples")
671
+ gr.Examples(
672
+ examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample0.mp4", "test-sample1.mp4", "test-sample2.mp4", "test-sample3.mp4", "test-sample4.mp4"]],
673
+ inputs=[video_input],
674
+ )
675
+ # gr.Markdown(article)
676
+
677
+ # iface.queue(concurrency_count=1)
678
+ iface.queue()
679
+ iface.launch(debug=True)
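To call the ROSE pipeline outside the Gradio UI, the logic of `inpaint_video` above boils down to the minimal sketch below. The `pipeline` object is assumed to be the `WanFunInpaintPipeline` constructed at the top of `app.py`; `frames` and `masks` are the per-frame RGB images and binary masks that the tracking step produces.

```python
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import functional as TF

def run_inpainting(pipeline, frames, masks, steps=50):
    # Keep a 16n+1 prefix of the clip, as the demo does before calling the pipeline.
    n = (len(frames) - 1) // 16 * 16 + 1
    frames, masks = frames[:n], masks[:n]

    # Resize to the 720x480 working resolution used by the demo.
    frames = [cv2.resize(f, (720, 480), interpolation=cv2.INTER_CUBIC) for f in frames]
    masks = [cv2.resize(m, (720, 480), interpolation=cv2.INTER_CUBIC) for m in masks]

    # Stack into (B, C, T, H, W) half-precision tensors on the GPU.
    video = torch.stack([TF.to_tensor(Image.fromarray(f)) for f in frames], dim=1)
    mask = torch.stack([TF.to_tensor(Image.fromarray(m * 255)) for m in masks], dim=1)
    video = video.unsqueeze(0).to("cuda", torch.float16)
    mask = mask.unsqueeze(0).to("cuda", torch.float16)

    with torch.no_grad():
        out = pipeline(prompt="", video=video, mask_video=mask,
                       num_frames=video.shape[2], num_inference_steps=steps).videos

    # (B, C, T, H, W) in [0, 1] -> T x H x W x 3 uint8 frames.
    return (out.clamp(0, 1)[0].permute(1, 2, 3, 0).cpu().numpy() * 255).astype(np.uint8)
```

The `(len(frames) - 1) // 16 * 16 + 1` truncation mirrors the frame-count constraint the demo applies before invoking the pipeline.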
assets/rose_logo.png ADDED

Git LFS Details

  • SHA256: cf454f99eaabcece90cb664c39a45d17b58f8201eba8b220fa155ac22a014c4a
  • Pointer size: 130 Bytes
  • Size of remote file: 71.8 kB
assets/step1.png ADDED

Git LFS Details

  • SHA256: c93010fa938c75ae671e3aa362205f3f2692783930f67a6623e0a438479e7326
  • Pointer size: 131 Bytes
  • Size of remote file: 309 kB
assets/step2.png ADDED

Git LFS Details

  • SHA256: 48bdd827f2581da65df1e731325163bd7e5095511ed6e753346f74b71156dc8d
  • Pointer size: 131 Bytes
  • Size of remote file: 458 kB
assets/step3.png ADDED

Git LFS Details

  • SHA256: d8273104c2e558cb0d1edfe91f5e4ca27483815f5e31aad153f199a351c87b12
  • Pointer size: 132 Bytes
  • Size of remote file: 1.14 MB
configs/wan2.1/wan_civitai.yaml ADDED
@@ -0,0 +1,39 @@
+ format: civitai
+ pipeline: Wan
+ transformer_additional_kwargs:
+   transformer_subpath: ./
+   dict_mapping:
+     in_dim: in_channels
+     dim: hidden_size
+
+ vae_kwargs:
+   vae_subpath: Wan2.1_VAE.pth
+   temporal_compression_ratio: 4
+   spatial_compression_ratio: 8
+
+ text_encoder_kwargs:
+   text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
+   tokenizer_subpath: google/umt5-xxl
+   text_length: 512
+   vocab: 256384
+   dim: 4096
+   dim_attn: 4096
+   dim_ffn: 10240
+   num_heads: 64
+   num_layers: 24
+   num_buckets: 32
+   shared_pos: False
+   dropout: 0.0
+
+ scheduler_kwargs:
+   scheduler_subpath: null
+   num_train_timesteps: 1000
+   shift: 5.0
+   use_dynamic_shifting: false
+   base_shift: 0.5
+   max_shift: 1.15
+   base_image_seq_len: 256
+   max_image_seq_len: 4096
+
+ image_encoder_kwargs:
+   image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
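For orientation, `app.py` above reads this config with OmegaConf: each `*_subpath` entry tells it where to find the corresponding component under the base model path, and each kwargs block is converted to a plain dict and passed to the matching constructor. A minimal sketch of that access pattern:

```python
from omegaconf import OmegaConf

config = OmegaConf.load("configs/wan2.1/wan_civitai.yaml")

# Sub-paths locate individual components relative to the pretrained model directory.
print(config["vae_kwargs"]["vae_subpath"])                  # Wan2.1_VAE.pth
print(config["text_encoder_kwargs"]["tokenizer_subpath"])   # google/umt5-xxl

# Kwargs blocks are handed to the constructors as plain dicts.
scheduler_kwargs = OmegaConf.to_container(config["scheduler_kwargs"])
```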
inpainter/base_inpainter.py ADDED
@@ -0,0 +1,374 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import sys
4
+ import cv2
5
+ import numpy as np
6
+ import scipy.ndimage
7
+ from PIL import Image
8
+ from tqdm import tqdm
9
+
10
+ import torch
11
+ import torchvision
12
+
13
+ from model.modules.flow_comp_raft import RAFT_bi
14
+ from model.recurrent_flow_completion import RecurrentFlowCompleteNet
15
+ from model.propainter import InpaintGenerator
16
+ from core.utils import to_tensors
17
+
18
+ import warnings
19
+ warnings.filterwarnings("ignore")
20
+
21
+
22
+ def imwrite(img, file_path, params=None, auto_mkdir=True):
23
+ if auto_mkdir:
24
+ dir_name = os.path.abspath(os.path.dirname(file_path))
25
+ os.makedirs(dir_name, exist_ok=True)
26
+ return cv2.imwrite(file_path, img, params)
27
+
28
+
29
+ def resize_frames(frames, size=None):
30
+ if size is not None:
31
+ out_size = size
32
+ process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8)
33
+ frames = [f.resize(process_size) for f in frames]
34
+ else:
35
+ out_size = frames[0].size
36
+ process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8)
37
+ if not out_size == process_size:
38
+ frames = [f.resize(process_size) for f in frames]
39
+
40
+ return frames, process_size, out_size
41
+
42
+
43
+ def read_frame_from_videos(frame_root):
44
+ if frame_root.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
45
+ video_name = os.path.basename(frame_root)[:-4]
46
+ vframes, aframes, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec') # RGB
47
+ frames = list(vframes.numpy())
48
+ frames = [Image.fromarray(f) for f in frames]
49
+ fps = info['video_fps']
50
+ else:
51
+ video_name = os.path.basename(frame_root)
52
+ frames = []
53
+ fr_lst = sorted(os.listdir(frame_root))
54
+ for fr in fr_lst:
55
+ frame = cv2.imread(os.path.join(frame_root, fr))
56
+ frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
57
+ frames.append(frame)
58
+ fps = None
59
+ size = frames[0].size
60
+
61
+ return frames, fps, size, video_name
62
+
63
+
64
+ def binary_mask(mask, th=0.1):
65
+ mask[mask>th] = 1
66
+ mask[mask<=th] = 0
67
+ return mask
68
+
69
+
70
+ def extrapolation(video_ori, scale):
71
+ """Prepares the data for video outpainting.
72
+ """
73
+ nFrame = len(video_ori)
74
+ imgW, imgH = video_ori[0].size
75
+
76
+ # Defines new FOV.
77
+ imgH_extr = int(scale[0] * imgH)
78
+ imgW_extr = int(scale[1] * imgW)
79
+ imgH_extr = imgH_extr - imgH_extr % 8
80
+ imgW_extr = imgW_extr - imgW_extr % 8
81
+ H_start = int((imgH_extr - imgH) / 2)
82
+ W_start = int((imgW_extr - imgW) / 2)
83
+
84
+ # Extrapolates the FOV for video.
85
+ frames = []
86
+ for v in video_ori:
87
+ frame = np.zeros(((imgH_extr, imgW_extr, 3)), dtype=np.uint8)
88
+ frame[H_start: H_start + imgH, W_start: W_start + imgW, :] = v
89
+ frames.append(Image.fromarray(frame))
90
+
91
+ # Generates the mask for missing region.
92
+ masks_dilated = []
93
+ flow_masks = []
94
+
95
+ dilate_h = 4 if H_start > 10 else 0
96
+ dilate_w = 4 if W_start > 10 else 0
97
+ mask = np.ones(((imgH_extr, imgW_extr)), dtype=np.uint8)
98
+
99
+ mask[H_start+dilate_h: H_start+imgH-dilate_h,
100
+ W_start+dilate_w: W_start+imgW-dilate_w] = 0
101
+ flow_masks.append(Image.fromarray(mask * 255))
102
+
103
+ mask[H_start: H_start+imgH, W_start: W_start+imgW] = 0
104
+ masks_dilated.append(Image.fromarray(mask * 255))
105
+
106
+ flow_masks = flow_masks * nFrame
107
+ masks_dilated = masks_dilated * nFrame
108
+
109
+ return frames, flow_masks, masks_dilated, (imgW_extr, imgH_extr)
110
+
111
+
112
+ def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1):
113
+ ref_index = []
114
+ if ref_num == -1:
115
+ for i in range(0, length, ref_stride):
116
+ if i not in neighbor_ids:
117
+ ref_index.append(i)
118
+ else:
119
+ start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2))
120
+ end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2))
121
+ for i in range(start_idx, end_idx, ref_stride):
122
+ if i not in neighbor_ids:
123
+ if len(ref_index) > ref_num:
124
+ break
125
+ ref_index.append(i)
126
+ return ref_index
127
+
128
+
129
+ def read_mask_demo(masks, length, size, flow_mask_dilates=8, mask_dilates=5):
130
+ masks_img = []
131
+ masks_dilated = []
132
+ flow_masks = []
133
+
134
+ for mp in masks:
135
+ masks_img.append(Image.fromarray(mp.astype('uint8')))
136
+
137
+ for mask_img in masks_img:
138
+ if size is not None:
139
+ mask_img = mask_img.resize(size, Image.NEAREST)
140
+ mask_img = np.array(mask_img.convert('L'))
141
+
142
+ # Dilate by 8 pixels so that all known pixels are trustworthy
143
+ if flow_mask_dilates > 0:
144
+ flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=flow_mask_dilates).astype(np.uint8)
145
+ else:
146
+ flow_mask_img = binary_mask(mask_img).astype(np.uint8)
147
+
148
+ flow_masks.append(Image.fromarray(flow_mask_img * 255))
149
+
150
+ if mask_dilates > 0:
151
+ mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=mask_dilates).astype(np.uint8)
152
+ else:
153
+ mask_img = binary_mask(mask_img).astype(np.uint8)
154
+ masks_dilated.append(Image.fromarray(mask_img * 255))
155
+
156
+ if len(masks_img) == 1:
157
+ flow_masks = flow_masks * length
158
+ masks_dilated = masks_dilated * length
159
+
160
+ return flow_masks, masks_dilated
161
+
162
+
163
+ class ProInpainter:
164
+ def __init__(self, propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, device="cuda:0", use_half=True):
165
+ self.device = device
166
+ self.use_half = use_half
167
+ if self.device == torch.device('cpu'):
168
+ self.use_half = False
169
+
170
+ ##############################################
171
+ # set up RAFT and flow completion model
172
+ ##############################################
173
+ self.fix_raft = RAFT_bi(raft_checkpoint, self.device)
174
+
175
+ self.fix_flow_complete = RecurrentFlowCompleteNet(flow_completion_checkpoint)
176
+ for p in self.fix_flow_complete.parameters():
177
+ p.requires_grad = False
178
+ self.fix_flow_complete.to(self.device)
179
+ self.fix_flow_complete.eval()
180
+
181
+ ##############################################
182
+ # set up ProPainter model
183
+ ##############################################
184
+ self.model = InpaintGenerator(model_path=propainter_checkpoint).to(self.device)
185
+ self.model.eval()
186
+
187
+ if self.use_half:
188
+ self.fix_flow_complete = self.fix_flow_complete.half()
189
+ self.model = self.model.half()
190
+
191
+ def inpaint(self, npframes, masks, ratio=1.0, dilate_radius=4, raft_iter=20, subvideo_length=80, neighbor_length=10, ref_stride=10):
192
+ """
193
+ Perform Inpainting for video subsets
194
+
195
+ Output:
196
+ inpainted_frames: numpy array, T, H, W, 3
197
+ """
198
+
199
+ frames = []
200
+ for i in range(len(npframes)):
201
+ frames.append(Image.fromarray(npframes[i].astype('uint8'), mode="RGB"))
202
+ del npframes
203
+
204
+ size = frames[0].size
205
+ # The output size should be divisible by 2 so that it can be encoded by libx264
206
+ size = (int(ratio*size[0])//2*2, int(ratio*size[1])//2*2)
207
+
208
+ frames_len = len(frames)
209
+ frames, size, out_size = resize_frames(frames, size)
210
+ flow_masks, masks_dilated = read_mask_demo(masks, frames_len, size, dilate_radius, dilate_radius)
211
+ w, h = size
212
+
213
+ frames_inp = [np.array(f).astype(np.uint8) for f in frames]
214
+ frames = to_tensors()(frames).unsqueeze(0) * 2 - 1
215
+ flow_masks = to_tensors()(flow_masks).unsqueeze(0)
216
+ masks_dilated = to_tensors()(masks_dilated).unsqueeze(0)
217
+ frames, flow_masks, masks_dilated = frames.to(self.device), flow_masks.to(self.device), masks_dilated.to(self.device)
218
+
219
+ ##############################################
220
+ # ProPainter inference
221
+ ##############################################
222
+ video_length = frames.size(1)
223
+ with torch.no_grad():
224
+ # ---- compute flow ----
225
+ if frames.size(-1) <= 640:
226
+ short_clip_len = 12
227
+ elif frames.size(-1) <= 720:
228
+ short_clip_len = 8
229
+ elif frames.size(-1) <= 1280:
230
+ short_clip_len = 4
231
+ else:
232
+ short_clip_len = 2
233
+
234
+ # use fp32 for RAFT
235
+ if frames.size(1) > short_clip_len:
236
+ gt_flows_f_list, gt_flows_b_list = [], []
237
+ for f in range(0, video_length, short_clip_len):
238
+ end_f = min(video_length, f + short_clip_len)
239
+ if f == 0:
240
+ flows_f, flows_b = self.fix_raft(frames[:,f:end_f], iters=raft_iter)
241
+ else:
242
+ flows_f, flows_b = self.fix_raft(frames[:,f-1:end_f], iters=raft_iter)
243
+
244
+ gt_flows_f_list.append(flows_f)
245
+ gt_flows_b_list.append(flows_b)
246
+ torch.cuda.empty_cache()
247
+
248
+ gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
249
+ gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
250
+ gt_flows_bi = (gt_flows_f, gt_flows_b)
251
+ else:
252
+ gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
253
+ torch.cuda.empty_cache()
254
+
255
+ if self.use_half:
256
+ frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
257
+ gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())
258
+
259
+ # ---- complete flow ----
260
+ flow_length = gt_flows_bi[0].size(1)
261
+ if flow_length > subvideo_length:
262
+ pred_flows_f, pred_flows_b = [], []
263
+ pad_len = 5
264
+ for f in range(0, flow_length, subvideo_length):
265
+ s_f = max(0, f - pad_len)
266
+ e_f = min(flow_length, f + subvideo_length + pad_len)
267
+ pad_len_s = max(0, f) - s_f
268
+ pad_len_e = e_f - min(flow_length, f + subvideo_length)
269
+ pred_flows_bi_sub, _ = self.fix_flow_complete.forward_bidirect_flow(
270
+ (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
271
+ flow_masks[:, s_f:e_f+1])
272
+ pred_flows_bi_sub = self.fix_flow_complete.combine_flow(
273
+ (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
274
+ pred_flows_bi_sub,
275
+ flow_masks[:, s_f:e_f+1])
276
+
277
+ pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
278
+ pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
279
+ torch.cuda.empty_cache()
280
+
281
+ pred_flows_f = torch.cat(pred_flows_f, dim=1)
282
+ pred_flows_b = torch.cat(pred_flows_b, dim=1)
283
+ pred_flows_bi = (pred_flows_f, pred_flows_b)
284
+ else:
285
+ pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
286
+ pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
287
+ torch.cuda.empty_cache()
288
+
289
+ # ---- image propagation ----
290
+ masked_frames = frames * (1 - masks_dilated)
291
+ subvideo_length_img_prop = min(100, subvideo_length) # cap each image-propagation sub-video at 100 frames
292
+ if video_length > subvideo_length_img_prop:
293
+ updated_frames, updated_masks = [], []
294
+ pad_len = 10
295
+ for f in range(0, video_length, subvideo_length_img_prop):
296
+ s_f = max(0, f - pad_len)
297
+ e_f = min(video_length, f + subvideo_length_img_prop + pad_len)
298
+ pad_len_s = max(0, f) - s_f
299
+ pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop)
300
+
301
+ b, t, _, _, _ = masks_dilated[:, s_f:e_f].size()
302
+ pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f-1], pred_flows_bi[1][:, s_f:e_f-1])
303
+ prop_imgs_sub, updated_local_masks_sub = self.model.img_propagation(masked_frames[:, s_f:e_f],
304
+ pred_flows_bi_sub,
305
+ masks_dilated[:, s_f:e_f],
306
+ 'nearest')
307
+ updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \
308
+ prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f]
309
+ updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w)
310
+
311
+ updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
312
+ updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
313
+ torch.cuda.empty_cache()
314
+
315
+ updated_frames = torch.cat(updated_frames, dim=1)
316
+ updated_masks = torch.cat(updated_masks, dim=1)
317
+ else:
318
+ b, t, _, _, _ = masks_dilated.size()
319
+ prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
320
+ updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
321
+ updated_masks = updated_local_masks.view(b, t, 1, h, w)
322
+ torch.cuda.empty_cache()
323
+
324
+ ori_frames = frames_inp
325
+ comp_frames = [None] * video_length
326
+
327
+ neighbor_stride = neighbor_length // 2
328
+ if video_length > subvideo_length:
329
+ ref_num = subvideo_length // ref_stride
330
+ else:
331
+ ref_num = -1
332
+
333
+ # ---- feature propagation + transformer ----
334
+ for f in tqdm(range(0, video_length, neighbor_stride)):
335
+ neighbor_ids = [
336
+ i for i in range(max(0, f - neighbor_stride),
337
+ min(video_length, f + neighbor_stride + 1))
338
+ ]
339
+ ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_stride, ref_num)
340
+ selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
341
+ selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
342
+ selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
343
+ selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :])
344
+
345
+ with torch.no_grad():
346
+ # 1.0 indicates mask
347
+ l_t = len(neighbor_ids)
348
+
349
+ # pred_img = selected_imgs # results of image propagation
350
+ pred_img = self.model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t)
351
+
352
+ pred_img = pred_img.view(-1, 3, h, w)
353
+
354
+ pred_img = (pred_img + 1) / 2
355
+ pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
356
+ binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute(
357
+ 0, 2, 3, 1).numpy().astype(np.uint8)
358
+ for i in range(len(neighbor_ids)):
359
+ idx = neighbor_ids[i]
360
+ img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
361
+ + ori_frames[idx] * (1 - binary_masks[i])
362
+ if comp_frames[idx] is None:
363
+ comp_frames[idx] = img
364
+ else:
365
+ comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5
366
+
367
+ comp_frames[idx] = comp_frames[idx].astype(np.uint8)
368
+
369
+ torch.cuda.empty_cache()
370
+
371
+ # need to return numpy array, T, H, W, 3
372
+ comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
373
+
374
+ return comp_frames
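The sliding-window loop above mixes local neighbor frames with sparse global reference frames chosen by get_ref_index, which is not part of this diff. A minimal sketch of such a selector, stated as an assumption about its behavior rather than a copy of the repository's helper:

def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1):
    # Collect every ref_stride-th frame as a global reference, skipping frames
    # that already sit inside the local neighbor window.
    ref_index = []
    if ref_num == -1:
        for i in range(0, length, ref_stride):
            if i not in neighbor_ids:
                ref_index.append(i)
    else:
        # Bounded variant: sample references around the current window center.
        start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2))
        end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2))
        for i in range(start_idx, end_idx, ref_stride):
            if i not in neighbor_ids:
                if len(ref_index) > ref_num:
                    break
                ref_index.append(i)
    return ref_index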
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ progressbar2
2
+ gdown
3
+ gitpython
4
+ git+https://github.com/cheind/py-thin-plate-spline
5
+ hickle
6
+ tensorboard
7
+ numpy
8
+ git+https://github.com/facebookresearch/segment-anything.git
9
+ gradio
10
+ opencv-python
11
+ matplotlib
12
+ pyyaml
13
+ av
14
+ openmim
15
+ tqdm
16
+ psutil
17
+ omegaconf
rose/__init__.py ADDED
File without changes
rose/data/bucket_sampler.py ADDED
@@ -0,0 +1,379 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
4
+ Sized, TypeVar, Union)
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from torch.utils.data import BatchSampler, Dataset, Sampler
11
+
12
+ ASPECT_RATIO_512 = {
13
+ '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
14
+ '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
15
+ '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
16
+ '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
17
+ '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
18
+ '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
19
+ '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
20
+ '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
21
+ '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
22
+ '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
23
+ }
24
+ ASPECT_RATIO_RANDOM_CROP_512 = {
25
+ '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
26
+ '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
27
+ '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
28
+ '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
29
+ '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
30
+ }
31
+ ASPECT_RATIO_RANDOM_CROP_PROB = [
32
+ 1, 2,
33
+ 4, 4, 4, 4,
34
+ 8, 8, 8,
35
+ 4, 4, 4, 4,
36
+ 2, 1
37
+ ]
38
+ ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
39
+
40
+ def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
41
+ aspect_ratio = height / width
42
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
43
+ return ratios[closest_ratio], float(closest_ratio)
44
+
45
+ def get_image_size_without_loading(path):
46
+ with Image.open(path) as img:
47
+ return img.size # (width, height)
48
+
49
+ class RandomSampler(Sampler[int]):
50
+ r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
51
+
52
+ If with replacement, then user can specify :attr:`num_samples` to draw.
53
+
54
+ Args:
55
+ data_source (Dataset): dataset to sample from
56
+ replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
57
+ num_samples (int): number of samples to draw, default=`len(dataset)`.
58
+ generator (Generator): Generator used in sampling.
59
+ """
60
+
61
+ data_source: Sized
62
+ replacement: bool
63
+
64
+ def __init__(self, data_source: Sized, replacement: bool = False,
65
+ num_samples: Optional[int] = None, generator=None) -> None:
66
+ self.data_source = data_source
67
+ self.replacement = replacement
68
+ self._num_samples = num_samples
69
+ self.generator = generator
70
+ self._pos_start = 0
71
+
72
+ if not isinstance(self.replacement, bool):
73
+ raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
74
+
75
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
76
+ raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
77
+
78
+ @property
79
+ def num_samples(self) -> int:
80
+ # dataset size might change at runtime
81
+ if self._num_samples is None:
82
+ return len(self.data_source)
83
+ return self._num_samples
84
+
85
+ def __iter__(self) -> Iterator[int]:
86
+ n = len(self.data_source)
87
+ if self.generator is None:
88
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
89
+ generator = torch.Generator()
90
+ generator.manual_seed(seed)
91
+ else:
92
+ generator = self.generator
93
+
94
+ if self.replacement:
95
+ for _ in range(self.num_samples // 32):
96
+ yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
97
+ yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
98
+ else:
99
+ for _ in range(self.num_samples // n):
100
+ xx = torch.randperm(n, generator=generator).tolist()
101
+ if self._pos_start >= n:
102
+ self._pos_start = 0
103
+ print("xx top 10", xx[:10], self._pos_start)
104
+ for idx in range(self._pos_start, n):
105
+ yield xx[idx]
106
+ self._pos_start = (self._pos_start + 1) % n
107
+ self._pos_start = 0
108
+ yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
109
+
110
+ def __len__(self) -> int:
111
+ return self.num_samples
112
+
113
+ class AspectRatioBatchImageSampler(BatchSampler):
114
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
115
+
116
+ Args:
117
+ sampler (Sampler): Base sampler.
118
+ dataset (Dataset): Dataset providing data information.
119
+ batch_size (int): Size of mini-batch.
120
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
121
+ its size would be less than ``batch_size``.
122
+ aspect_ratios (dict): The predefined aspect ratios.
123
+ """
124
+ def __init__(
125
+ self,
126
+ sampler: Sampler,
127
+ dataset: Dataset,
128
+ batch_size: int,
129
+ train_folder: str = None,
130
+ aspect_ratios: dict = ASPECT_RATIO_512,
131
+ drop_last: bool = False,
132
+ config=None,
133
+ **kwargs
134
+ ) -> None:
135
+ if not isinstance(sampler, Sampler):
136
+ raise TypeError('sampler should be an instance of ``Sampler``, '
137
+ f'but got {sampler}')
138
+ if not isinstance(batch_size, int) or batch_size <= 0:
139
+ raise ValueError('batch_size should be a positive integer value, '
140
+ f'but got batch_size={batch_size}')
141
+ self.sampler = sampler
142
+ self.dataset = dataset
143
+ self.train_folder = train_folder
144
+ self.batch_size = batch_size
145
+ self.aspect_ratios = aspect_ratios
146
+ self.drop_last = drop_last
147
+ self.config = config
148
+ # buckets for each aspect ratio
149
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
150
+ # [str(k) for k, v in aspect_ratios]
151
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
152
+
153
+ def __iter__(self):
154
+ for idx in self.sampler:
155
+ try:
156
+ image_dict = self.dataset[idx]
157
+
158
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
159
+ if width is None or height is None:
160
+ image_id, name = image_dict['file_path'], image_dict['text']
161
+ if self.train_folder is None:
162
+ image_dir = image_id
163
+ else:
164
+ image_dir = os.path.join(self.train_folder, image_id)
165
+
166
+ width, height = get_image_size_without_loading(image_dir)
167
+
168
+ ratio = height / width # self.dataset[idx]
169
+ else:
170
+ height = int(height)
171
+ width = int(width)
172
+ ratio = height / width # self.dataset[idx]
173
+ except Exception as e:
174
+ print(e)
175
+ continue
176
+ # find the closest aspect ratio
177
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
178
+ if closest_ratio not in self.current_available_bucket_keys:
179
+ continue
180
+ bucket = self._aspect_ratio_buckets[closest_ratio]
181
+ bucket.append(idx)
182
+ # yield a batch of indices in the same aspect ratio group
183
+ if len(bucket) == self.batch_size:
184
+ yield bucket[:]
185
+ del bucket[:]
186
+
187
+ class AspectRatioBatchSampler(BatchSampler):
188
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
189
+
190
+ Args:
191
+ sampler (Sampler): Base sampler.
192
+ dataset (Dataset): Dataset providing data information.
193
+ batch_size (int): Size of mini-batch.
194
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
195
+ its size would be less than ``batch_size``.
196
+ aspect_ratios (dict): The predefined aspect ratios.
197
+ """
198
+ def __init__(
199
+ self,
200
+ sampler: Sampler,
201
+ dataset: Dataset,
202
+ batch_size: int,
203
+ video_folder: str = None,
204
+ train_data_format: str = "webvid",
205
+ aspect_ratios: dict = ASPECT_RATIO_512,
206
+ drop_last: bool = False,
207
+ config=None,
208
+ **kwargs
209
+ ) -> None:
210
+ if not isinstance(sampler, Sampler):
211
+ raise TypeError('sampler should be an instance of ``Sampler``, '
212
+ f'but got {sampler}')
213
+ if not isinstance(batch_size, int) or batch_size <= 0:
214
+ raise ValueError('batch_size should be a positive integer value, '
215
+ f'but got batch_size={batch_size}')
216
+ self.sampler = sampler
217
+ self.dataset = dataset
218
+ self.video_folder = video_folder
219
+ self.train_data_format = train_data_format
220
+ self.batch_size = batch_size
221
+ self.aspect_ratios = aspect_ratios
222
+ self.drop_last = drop_last
223
+ self.config = config
224
+ # buckets for each aspect ratio
225
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
226
+ # [str(k) for k, v in aspect_ratios]
227
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
228
+
229
+ def __iter__(self):
230
+ for idx in self.sampler:
231
+ try:
232
+ video_dict = self.dataset[idx]
233
+ width, height = video_dict.get("width", None), video_dict.get("height", None)
234
+
235
+ if width is None or height is None:
236
+ if self.train_data_format == "normal":
237
+ video_id, name = video_dict['file_path'], video_dict['text']
238
+ if self.video_folder is None:
239
+ video_dir = video_id
240
+ else:
241
+ video_dir = os.path.join(self.video_folder, video_id)
242
+ else:
243
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
244
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
245
+ cap = cv2.VideoCapture(video_dir)
246
+
247
+ # get the video frame size
248
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # cast float to int
249
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # cast float to int
250
+
251
+ ratio = height / width # self.dataset[idx]
252
+ else:
253
+ height = int(height)
254
+ width = int(width)
255
+ ratio = height / width # self.dataset[idx]
256
+ except Exception as e:
257
+ print(e, self.dataset[idx], "This item is error, please check it.")
258
+ continue
259
+ # find the closest aspect ratio
260
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
261
+ if closest_ratio not in self.current_available_bucket_keys:
262
+ continue
263
+ bucket = self._aspect_ratio_buckets[closest_ratio]
264
+ bucket.append(idx)
265
+ # yield a batch of indices in the same aspect ratio group
266
+ if len(bucket) == self.batch_size:
267
+ yield bucket[:]
268
+ del bucket[:]
269
+
270
+ class AspectRatioBatchImageVideoSampler(BatchSampler):
271
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
272
+
273
+ Args:
274
+ sampler (Sampler): Base sampler.
275
+ dataset (Dataset): Dataset providing data information.
276
+ batch_size (int): Size of mini-batch.
277
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
278
+ its size would be less than ``batch_size``.
279
+ aspect_ratios (dict): The predefined aspect ratios.
280
+ """
281
+
282
+ def __init__(self,
283
+ sampler: Sampler,
284
+ dataset: Dataset,
285
+ batch_size: int,
286
+ train_folder: str = None,
287
+ aspect_ratios: dict = ASPECT_RATIO_512,
288
+ drop_last: bool = False
289
+ ) -> None:
290
+ if not isinstance(sampler, Sampler):
291
+ raise TypeError('sampler should be an instance of ``Sampler``, '
292
+ f'but got {sampler}')
293
+ if not isinstance(batch_size, int) or batch_size <= 0:
294
+ raise ValueError('batch_size should be a positive integer value, '
295
+ f'but got batch_size={batch_size}')
296
+ self.sampler = sampler
297
+ self.dataset = dataset
298
+ self.train_folder = train_folder
299
+ self.batch_size = batch_size
300
+ self.aspect_ratios = aspect_ratios
301
+ self.drop_last = drop_last
302
+
303
+ # buckets for each aspect ratio
304
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
305
+ self.bucket = {
306
+ 'image':{ratio: [] for ratio in aspect_ratios},
307
+ 'video':{ratio: [] for ratio in aspect_ratios}
308
+ }
309
+
310
+ def __iter__(self):
311
+ for idx in self.sampler:
312
+ content_type = self.dataset[idx].get('type', 'image')
313
+ if content_type == 'image':
314
+ try:
315
+ image_dict = self.dataset[idx]
316
+
317
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
318
+ if width is None or height is None:
319
+ image_id, name = image_dict['file_path'], image_dict['text']
320
+ if self.train_folder is None:
321
+ image_dir = image_id
322
+ else:
323
+ image_dir = os.path.join(self.train_folder, image_id)
324
+
325
+ width, height = get_image_size_without_loading(image_dir)
326
+
327
+ ratio = height / width # self.dataset[idx]
328
+ else:
329
+ height = int(height)
330
+ width = int(width)
331
+ ratio = height / width # self.dataset[idx]
332
+ except Exception as e:
333
+ print(e, self.dataset[idx], "This item is error, please check it.")
334
+ continue
335
+ # find the closest aspect ratio
336
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
337
+ if closest_ratio not in self.current_available_bucket_keys:
338
+ continue
339
+ bucket = self.bucket['image'][closest_ratio]
340
+ bucket.append(idx)
341
+ # yield a batch of indices in the same aspect ratio group
342
+ if len(bucket) == self.batch_size:
343
+ yield bucket[:]
344
+ del bucket[:]
345
+ else:
346
+ try:
347
+ video_dict = self.dataset[idx]
348
+ width, height = video_dict.get("width", None), video_dict.get("height", None)
349
+
350
+ if width is None or height is None:
351
+ video_id, name = video_dict['file_path'], video_dict['text']
352
+ if self.train_folder is None:
353
+ video_dir = video_id
354
+ else:
355
+ video_dir = os.path.join(self.train_folder, video_id)
356
+ cap = cv2.VideoCapture(video_dir)
357
+
358
+ # get the video frame size
359
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # cast float to int
360
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # cast float to int
361
+
362
+ ratio = height / width # self.dataset[idx]
363
+ else:
364
+ height = int(height)
365
+ width = int(width)
366
+ ratio = height / width # self.dataset[idx]
367
+ except Exception as e:
368
+ print(e, self.dataset[idx], "This item is error, please check it.")
369
+ continue
370
+ # find the closest aspect ratio
371
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
372
+ if closest_ratio not in self.current_available_bucket_keys:
373
+ continue
374
+ bucket = self.bucket['video'][closest_ratio]
375
+ bucket.append(idx)
376
+ # yield a batch of indices in the same aspect ratio group
377
+ if len(bucket) == self.batch_size:
378
+ yield bucket[:]
379
+ del bucket[:]
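These samplers only emit a batch once one aspect-ratio bucket fills, so they are meant to be handed to a DataLoader as batch_sampler rather than sampler. A minimal usage sketch, assuming a train_dataset object whose .dataset attribute is the raw annotation list (width/height/file_path entries); those names are illustrative and not defined in this file:

import torch
from torch.utils.data import DataLoader

sampler = RandomSampler(train_dataset, generator=torch.Generator().manual_seed(42))
batch_sampler = AspectRatioBatchImageVideoSampler(
    sampler=sampler,
    dataset=train_dataset.dataset,  # the annotation entries the sampler inspects
    batch_size=4,
    aspect_ratios=ASPECT_RATIO_512,
)
# Each yielded item is a list of indices that share the closest aspect-ratio bucket.
loader = DataLoader(train_dataset, batch_sampler=batch_sampler, num_workers=4)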
rose/data/dataset_image.py ADDED
@@ -0,0 +1,76 @@
1
+ import json
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ from PIL import Image
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+
12
+ class CC15M(Dataset):
13
+ def __init__(
14
+ self,
15
+ json_path,
16
+ video_folder=None,
17
+ resolution=512,
18
+ enable_bucket=False,
19
+ ):
20
+ print(f"loading annotations from {json_path} ...")
21
+ self.dataset = json.load(open(json_path, 'r'))
22
+ self.length = len(self.dataset)
23
+ print(f"data scale: {self.length}")
24
+
25
+ self.enable_bucket = enable_bucket
26
+ self.video_folder = video_folder
27
+
28
+ resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
29
+ self.pixel_transforms = transforms.Compose([
30
+ transforms.Resize(resolution[0]),
31
+ transforms.CenterCrop(resolution),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
34
+ ])
35
+
36
+ def get_batch(self, idx):
37
+ video_dict = self.dataset[idx]
38
+ video_id, name = video_dict['file_path'], video_dict['text']
39
+
40
+ if self.video_folder is None:
41
+ video_dir = video_id
42
+ else:
43
+ video_dir = os.path.join(self.video_folder, video_id)
44
+
45
+ pixel_values = Image.open(video_dir).convert("RGB")
46
+ return pixel_values, name
47
+
48
+ def __len__(self):
49
+ return self.length
50
+
51
+ def __getitem__(self, idx):
52
+ while True:
53
+ try:
54
+ pixel_values, name = self.get_batch(idx)
55
+ break
56
+ except Exception as e:
57
+ print(e)
58
+ idx = random.randint(0, self.length-1)
59
+
60
+ if not self.enable_bucket:
61
+ pixel_values = self.pixel_transforms(pixel_values)
62
+ else:
63
+ pixel_values = np.array(pixel_values)
64
+
65
+ sample = dict(pixel_values=pixel_values, text=name)
66
+ return sample
67
+
68
+ if __name__ == "__main__":
69
+ dataset = CC15M(
70
+ csv_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
71
+ resolution=512,
72
+ )
73
+
74
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
75
+ for idx, batch in enumerate(dataloader):
76
+ print(batch["pixel_values"].shape, len(batch["text"]))
rose/data/dataset_image_video.py ADDED
@@ -0,0 +1,589 @@
1
+ import csv
2
+ import io
3
+ import json
4
+ import math
5
+ import os
6
+ import random
7
+ from threading import Thread
8
+
9
+ import albumentations
10
+ import cv2
11
+ import gc
12
+ import numpy as np
13
+ import torch
14
+ import torchvision.transforms as transforms
15
+
16
+ from func_timeout import func_timeout, FunctionTimedOut
17
+ from decord import VideoReader
18
+ from PIL import Image
19
+ from torch.utils.data import BatchSampler, Sampler
20
+ from torch.utils.data.dataset import Dataset
21
+ from contextlib import contextmanager
22
+
23
+ VIDEO_READER_TIMEOUT = 20
24
+
25
+ def get_random_mask(shape, image_start_only=False):
26
+ f, c, h, w = shape
27
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
28
+
29
+ if not image_start_only:
30
+ if f != 1:
31
+ mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
32
+ else:
33
+ mask_index = np.random.choice([0, 1], p = [0.2, 0.8])
34
+ if mask_index == 0:
35
+ center_x = torch.randint(0, w, (1,)).item()
36
+ center_y = torch.randint(0, h, (1,)).item()
37
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
38
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
39
+
40
+ start_x = max(center_x - block_size_x // 2, 0)
41
+ end_x = min(center_x + block_size_x // 2, w)
42
+ start_y = max(center_y - block_size_y // 2, 0)
43
+ end_y = min(center_y + block_size_y // 2, h)
44
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
45
+ elif mask_index == 1:
46
+ mask[:, :, :, :] = 1
47
+ elif mask_index == 2:
48
+ mask_frame_index = np.random.randint(1, 5)
49
+ mask[mask_frame_index:, :, :, :] = 1
50
+ elif mask_index == 3:
51
+ mask_frame_index = np.random.randint(1, 5)
52
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
53
+ elif mask_index == 4:
54
+ center_x = torch.randint(0, w, (1,)).item()
55
+ center_y = torch.randint(0, h, (1,)).item()
56
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
57
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
58
+
59
+ start_x = max(center_x - block_size_x // 2, 0)
60
+ end_x = min(center_x + block_size_x // 2, w)
61
+ start_y = max(center_y - block_size_y // 2, 0)
62
+ end_y = min(center_y + block_size_y // 2, h)
63
+
64
+ mask_frame_before = np.random.randint(0, f // 2)
65
+ mask_frame_after = np.random.randint(f // 2, f)
66
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
67
+ elif mask_index == 5:
68
+ mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
69
+ elif mask_index == 6:
70
+ num_frames_to_mask = random.randint(1, max(f // 2, 1))
71
+ frames_to_mask = random.sample(range(f), num_frames_to_mask)
72
+
73
+ for i in frames_to_mask:
74
+ block_height = random.randint(1, h // 4)
75
+ block_width = random.randint(1, w // 4)
76
+ top_left_y = random.randint(0, h - block_height)
77
+ top_left_x = random.randint(0, w - block_width)
78
+ mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
79
+ elif mask_index == 7:
80
+ center_x = torch.randint(0, w, (1,)).item()
81
+ center_y = torch.randint(0, h, (1,)).item()
82
+ a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item() # 长半轴
83
+ b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() # 短半轴
84
+
85
+ for i in range(h):
86
+ for j in range(w):
87
+ if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
88
+ mask[:, :, i, j] = 1
89
+ elif mask_index == 8:
90
+ center_x = torch.randint(0, w, (1,)).item()
91
+ center_y = torch.randint(0, h, (1,)).item()
92
+ radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
93
+ for i in range(h):
94
+ for j in range(w):
95
+ if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
96
+ mask[:, :, i, j] = 1
97
+ elif mask_index == 9:
98
+ for idx in range(f):
99
+ if np.random.rand() > 0.5:
100
+ mask[idx, :, :, :] = 1
101
+ else:
102
+ raise ValueError(f"The mask_index {mask_index} is not define")
103
+ else:
104
+ if f != 1:
105
+ mask[1:, :, :, :] = 1
106
+ else:
107
+ mask[:, :, :, :] = 1
108
+ return mask
109
+
110
+ class ImageVideoSampler(BatchSampler):
111
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
112
+
113
+ Args:
114
+ sampler (Sampler): Base sampler.
115
+ dataset (Dataset): Dataset providing data information.
116
+ batch_size (int): Size of mini-batch.
117
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
118
+ its size would be less than ``batch_size``.
119
+ aspect_ratios (dict): The predefined aspect ratios.
120
+ """
121
+
122
+ def __init__(self,
123
+ sampler: Sampler,
124
+ dataset: Dataset,
125
+ batch_size: int,
126
+ drop_last: bool = False
127
+ ) -> None:
128
+ if not isinstance(sampler, Sampler):
129
+ raise TypeError('sampler should be an instance of ``Sampler``, '
130
+ f'but got {sampler}')
131
+ if not isinstance(batch_size, int) or batch_size <= 0:
132
+ raise ValueError('batch_size should be a positive integer value, '
133
+ f'but got batch_size={batch_size}')
134
+ self.sampler = sampler
135
+ self.dataset = dataset
136
+ self.batch_size = batch_size
137
+ self.drop_last = drop_last
138
+
139
+ # buckets for each aspect ratio
140
+ self.bucket = {'image':[], 'video':[]}
141
+
142
+ def __iter__(self):
143
+ for idx in self.sampler:
144
+ content_type = self.dataset.dataset[idx].get('type', 'image')
145
+ self.bucket[content_type].append(idx)
146
+
147
+ # yield a batch of indices in the same aspect ratio group
148
+ if len(self.bucket['video']) == self.batch_size:
149
+ bucket = self.bucket['video']
150
+ yield bucket[:]
151
+ del bucket[:]
152
+ elif len(self.bucket['image']) == self.batch_size:
153
+ bucket = self.bucket['image']
154
+ yield bucket[:]
155
+ del bucket[:]
156
+
157
+ @contextmanager
158
+ def VideoReader_contextmanager(*args, **kwargs):
159
+ vr = VideoReader(*args, **kwargs)
160
+ try:
161
+ yield vr
162
+ finally:
163
+ del vr
164
+ gc.collect()
165
+
166
+ def get_video_reader_batch(video_reader, batch_index):
167
+ frames = video_reader.get_batch(batch_index).asnumpy()
168
+ return frames
169
+
170
+ def resize_frame(frame, target_short_side):
171
+ h, w, _ = frame.shape
172
+ if h < w:
173
+ if target_short_side > h:
174
+ return frame
175
+ new_h = target_short_side
176
+ new_w = int(target_short_side * w / h)
177
+ else:
178
+ if target_short_side > w:
179
+ return frame
180
+ new_w = target_short_side
181
+ new_h = int(target_short_side * h / w)
182
+
183
+ resized_frame = cv2.resize(frame, (new_w, new_h))
184
+ return resized_frame
185
+
186
+ class ImageVideoDataset(Dataset):
187
+ def __init__(
188
+ self,
189
+ ann_path, data_root=None,
190
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
191
+ image_sample_size=512,
192
+ video_repeat=0,
193
+ text_drop_ratio=0.1,
194
+ enable_bucket=False,
195
+ video_length_drop_start=0.0,
196
+ video_length_drop_end=1.0,
197
+ enable_inpaint=False,
198
+ ):
199
+ # Loading annotations from files
200
+ print(f"loading annotations from {ann_path} ...")
201
+ if ann_path.endswith('.csv'):
202
+ with open(ann_path, 'r') as csvfile:
203
+ dataset = list(csv.DictReader(csvfile))
204
+ elif ann_path.endswith('.json'):
205
+ dataset = json.load(open(ann_path))
206
+
207
+ self.data_root = data_root
208
+
209
+ # It's used to balance num of images and videos.
210
+ self.dataset = []
211
+ for data in dataset:
212
+ if data.get('type', 'image') != 'video':
213
+ self.dataset.append(data)
214
+ if video_repeat > 0:
215
+ for _ in range(video_repeat):
216
+ for data in dataset:
217
+ if data.get('type', 'image') == 'video':
218
+ self.dataset.append(data)
219
+ del dataset
220
+
221
+ self.length = len(self.dataset)
222
+ print(f"data scale: {self.length}")
223
+ # TODO: enable bucket training
224
+ self.enable_bucket = enable_bucket
225
+ self.text_drop_ratio = text_drop_ratio
226
+ self.enable_inpaint = enable_inpaint
227
+
228
+ self.video_length_drop_start = video_length_drop_start
229
+ self.video_length_drop_end = video_length_drop_end
230
+
231
+ # Video params
232
+ self.video_sample_stride = video_sample_stride
233
+ self.video_sample_n_frames = video_sample_n_frames
234
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
235
+ self.video_transforms = transforms.Compose(
236
+ [
237
+ transforms.Resize(min(self.video_sample_size)),
238
+ transforms.CenterCrop(self.video_sample_size),
239
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
240
+ ]
241
+ )
242
+
243
+ # Image params
244
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
245
+ self.image_transforms = transforms.Compose([
246
+ transforms.Resize(min(self.image_sample_size)),
247
+ transforms.CenterCrop(self.image_sample_size),
248
+ transforms.ToTensor(),
249
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
250
+ ])
251
+
252
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
253
+
254
+ def get_batch(self, idx):
255
+ data_info = self.dataset[idx % len(self.dataset)]
256
+
257
+ if data_info.get('type', 'image')=='video':
258
+ video_id, text = data_info['file_path'], data_info['text']
259
+
260
+ if self.data_root is None:
261
+ video_dir = video_id
262
+ else:
263
+ video_dir = os.path.join(self.data_root, video_id)
264
+
265
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
266
+ min_sample_n_frames = min(
267
+ self.video_sample_n_frames,
268
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
269
+ )
270
+ if min_sample_n_frames == 0:
271
+ raise ValueError(f"No Frames in video.")
272
+
273
+ video_length = int(self.video_length_drop_end * len(video_reader))
274
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
275
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
276
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
277
+
278
+ try:
279
+ sample_args = (video_reader, batch_index)
280
+ pixel_values = func_timeout(
281
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
282
+ )
283
+ resized_frames = []
284
+ for i in range(len(pixel_values)):
285
+ frame = pixel_values[i]
286
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
287
+ resized_frames.append(resized_frame)
288
+ pixel_values = np.array(resized_frames)
289
+ except FunctionTimedOut:
290
+ raise ValueError(f"Read {idx} timeout.")
291
+ except Exception as e:
292
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
293
+
294
+ if not self.enable_bucket:
295
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
296
+ pixel_values = pixel_values / 255.
297
+ del video_reader
298
+ else:
299
+ pixel_values = pixel_values
300
+
301
+ if not self.enable_bucket:
302
+ pixel_values = self.video_transforms(pixel_values)
303
+
304
+ # Random use no text generation
305
+ if random.random() < self.text_drop_ratio:
306
+ text = ''
307
+ return pixel_values, text, 'video'
308
+ else:
309
+ image_path, text = data_info['file_path'], data_info['text']
310
+ if self.data_root is not None:
311
+ image_path = os.path.join(self.data_root, image_path)
312
+ image = Image.open(image_path).convert('RGB')
313
+ if not self.enable_bucket:
314
+ image = self.image_transforms(image).unsqueeze(0)
315
+ else:
316
+ image = np.expand_dims(np.array(image), 0)
317
+ if random.random() < self.text_drop_ratio:
318
+ text = ''
319
+ return image, text, 'image'
320
+
321
+ def __len__(self):
322
+ return self.length
323
+
324
+ def __getitem__(self, idx):
325
+ data_info = self.dataset[idx % len(self.dataset)]
326
+ data_type = data_info.get('type', 'image')
327
+ while True:
328
+ sample = {}
329
+ try:
330
+ data_info_local = self.dataset[idx % len(self.dataset)]
331
+ data_type_local = data_info_local.get('type', 'image')
332
+ if data_type_local != data_type:
333
+ raise ValueError("data_type_local != data_type")
334
+
335
+ pixel_values, name, data_type = self.get_batch(idx)
336
+ sample["pixel_values"] = pixel_values
337
+ sample["text"] = name
338
+ sample["data_type"] = data_type
339
+ sample["idx"] = idx
340
+
341
+ if len(sample) > 0:
342
+ break
343
+ except Exception as e:
344
+ print(e, self.dataset[idx % len(self.dataset)])
345
+ idx = random.randint(0, self.length-1)
346
+
347
+ if self.enable_inpaint and not self.enable_bucket:
348
+ mask = get_random_mask(pixel_values.size())
349
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
350
+ sample["mask_pixel_values"] = mask_pixel_values
351
+ sample["mask"] = mask
352
+
353
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
354
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
355
+ sample["clip_pixel_values"] = clip_pixel_values
356
+
357
+ ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
358
+ if (mask == 1).all():
359
+ ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
360
+ sample["ref_pixel_values"] = ref_pixel_values
361
+
362
+ return sample
363
+
364
+
365
+ class ImageVideoControlDataset(Dataset):
366
+ def __init__(
367
+ self,
368
+ ann_path, data_root=None,
369
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
370
+ image_sample_size=512,
371
+ video_repeat=0,
372
+ text_drop_ratio=0.1,
373
+ enable_bucket=False,
374
+ video_length_drop_start=0.0,
375
+ video_length_drop_end=1.0,
376
+ enable_inpaint=False,
377
+ ):
378
+ # Loading annotations from files
379
+ print(f"loading annotations from {ann_path} ...")
380
+ if ann_path.endswith('.csv'):
381
+ with open(ann_path, 'r') as csvfile:
382
+ dataset = list(csv.DictReader(csvfile))
383
+ elif ann_path.endswith('.json'):
384
+ dataset = json.load(open(ann_path))
385
+
386
+ self.data_root = data_root
387
+
388
+ # It's used to balance num of images and videos.
389
+ self.dataset = []
390
+ for data in dataset:
391
+ if data.get('type', 'image') != 'video':
392
+ self.dataset.append(data)
393
+ if video_repeat > 0:
394
+ for _ in range(video_repeat):
395
+ for data in dataset:
396
+ if data.get('type', 'image') == 'video':
397
+ self.dataset.append(data)
398
+ del dataset
399
+
400
+ self.length = len(self.dataset)
401
+ print(f"data scale: {self.length}")
402
+ # TODO: enable bucket training
403
+ self.enable_bucket = enable_bucket
404
+ self.text_drop_ratio = text_drop_ratio
405
+ self.enable_inpaint = enable_inpaint
406
+
407
+ self.video_length_drop_start = video_length_drop_start
408
+ self.video_length_drop_end = video_length_drop_end
409
+
410
+ # Video params
411
+ self.video_sample_stride = video_sample_stride
412
+ self.video_sample_n_frames = video_sample_n_frames
413
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
414
+ self.video_transforms = transforms.Compose(
415
+ [
416
+ transforms.Resize(min(self.video_sample_size)),
417
+ transforms.CenterCrop(self.video_sample_size),
418
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
419
+ ]
420
+ )
421
+
422
+ # Image params
423
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
424
+ self.image_transforms = transforms.Compose([
425
+ transforms.Resize(min(self.image_sample_size)),
426
+ transforms.CenterCrop(self.image_sample_size),
427
+ transforms.ToTensor(),
428
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
429
+ ])
430
+
431
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
432
+
433
+ def get_batch(self, idx):
434
+ data_info = self.dataset[idx % len(self.dataset)]
435
+ video_id, text = data_info['file_path'], data_info['text']
436
+
437
+ if data_info.get('type', 'image')=='video':
438
+ if self.data_root is None:
439
+ video_dir = video_id
440
+ else:
441
+ video_dir = os.path.join(self.data_root, video_id)
442
+
443
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
444
+ min_sample_n_frames = min(
445
+ self.video_sample_n_frames,
446
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
447
+ )
448
+ if min_sample_n_frames == 0:
449
+ raise ValueError(f"No Frames in video.")
450
+
451
+ video_length = int(self.video_length_drop_end * len(video_reader))
452
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
453
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
454
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
455
+
456
+ try:
457
+ sample_args = (video_reader, batch_index)
458
+ pixel_values = func_timeout(
459
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
460
+ )
461
+ resized_frames = []
462
+ for i in range(len(pixel_values)):
463
+ frame = pixel_values[i]
464
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
465
+ resized_frames.append(resized_frame)
466
+ pixel_values = np.array(resized_frames)
467
+ except FunctionTimedOut:
468
+ raise ValueError(f"Read {idx} timeout.")
469
+ except Exception as e:
470
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
471
+
472
+ if not self.enable_bucket:
473
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
474
+ pixel_values = pixel_values / 255.
475
+ del video_reader
476
+ else:
477
+ pixel_values = pixel_values
478
+
479
+ if not self.enable_bucket:
480
+ pixel_values = self.video_transforms(pixel_values)
481
+
482
+ # Random use no text generation
483
+ if random.random() < self.text_drop_ratio:
484
+ text = ''
485
+
486
+ control_video_id = data_info['control_file_path']
487
+
488
+ if self.data_root is None:
489
+ control_video_id = control_video_id
490
+ else:
491
+ control_video_id = os.path.join(self.data_root, control_video_id)
492
+
493
+ with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
494
+ try:
495
+ sample_args = (control_video_reader, batch_index)
496
+ control_pixel_values = func_timeout(
497
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
498
+ )
499
+ resized_frames = []
500
+ for i in range(len(control_pixel_values)):
501
+ frame = control_pixel_values[i]
502
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
503
+ resized_frames.append(resized_frame)
504
+ control_pixel_values = np.array(resized_frames)
505
+ except FunctionTimedOut:
506
+ raise ValueError(f"Read {idx} timeout.")
507
+ except Exception as e:
508
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
509
+
510
+ if not self.enable_bucket:
511
+ control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
512
+ control_pixel_values = control_pixel_values / 255.
513
+ del control_video_reader
514
+ else:
515
+ control_pixel_values = control_pixel_values
516
+
517
+ if not self.enable_bucket:
518
+ control_pixel_values = self.video_transforms(control_pixel_values)
519
+ return pixel_values, control_pixel_values, text, "video"
520
+ else:
521
+ image_path, text = data_info['file_path'], data_info['text']
522
+ if self.data_root is not None:
523
+ image_path = os.path.join(self.data_root, image_path)
524
+ image = Image.open(image_path).convert('RGB')
525
+ if not self.enable_bucket:
526
+ image = self.image_transforms(image).unsqueeze(0)
527
+ else:
528
+ image = np.expand_dims(np.array(image), 0)
529
+
530
+ if random.random() < self.text_drop_ratio:
531
+ text = ''
532
+
533
+ control_image_id = data_info['control_file_path']
534
+
535
+ if self.data_root is None:
536
+ control_image_id = control_image_id
537
+ else:
538
+ control_image_id = os.path.join(self.data_root, control_image_id)
539
+
540
+ control_image = Image.open(control_image_id).convert('RGB')
541
+ if not self.enable_bucket:
542
+ control_image = self.image_transforms(control_image).unsqueeze(0)
543
+ else:
544
+ control_image = np.expand_dims(np.array(control_image), 0)
545
+ return image, control_image, text, 'image'
546
+
547
+ def __len__(self):
548
+ return self.length
549
+
550
+ def __getitem__(self, idx):
551
+ data_info = self.dataset[idx % len(self.dataset)]
552
+ data_type = data_info.get('type', 'image')
553
+ while True:
554
+ sample = {}
555
+ try:
556
+ data_info_local = self.dataset[idx % len(self.dataset)]
557
+ data_type_local = data_info_local.get('type', 'image')
558
+ if data_type_local != data_type:
559
+ raise ValueError("data_type_local != data_type")
560
+
561
+ pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
562
+ sample["pixel_values"] = pixel_values
563
+ sample["control_pixel_values"] = control_pixel_values
564
+ sample["text"] = name
565
+ sample["data_type"] = data_type
566
+ sample["idx"] = idx
567
+
568
+ if len(sample) > 0:
569
+ break
570
+ except Exception as e:
571
+ print(e, self.dataset[idx % len(self.dataset)])
572
+ idx = random.randint(0, self.length-1)
573
+
574
+ if self.enable_inpaint and not self.enable_bucket:
575
+ mask = get_random_mask(pixel_values.size())
576
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
577
+ sample["mask_pixel_values"] = mask_pixel_values
578
+ sample["mask"] = mask
579
+
580
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
581
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
582
+ sample["clip_pixel_values"] = clip_pixel_values
583
+
584
+ ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
585
+ if (mask == 1).all():
586
+ ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
587
+ sample["ref_pixel_values"] = ref_pixel_values
588
+
589
+ return sample
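When enable_inpaint is set, masked pixels are overwritten with -1 in the [-1, 1] normalized range via pixel_values * (1 - mask) + (-1) * mask. A small sanity-check sketch of get_random_mask, with illustrative tensor sizes:

import torch

frames = torch.rand(16, 3, 64, 64) * 2 - 1   # fake clip in [-1, 1], shaped (f, c, h, w)
mask = get_random_mask(frames.size())        # (f, 1, h, w), values in {0, 1}
masked = frames * (1 - mask) + torch.ones_like(frames) * -1 * mask
assert mask.shape == (16, 1, 64, 64)
assert torch.all(masked[mask.expand_as(frames) == 1] == -1)  # masked pixels become -1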
rose/data/dataset_video.py ADDED
@@ -0,0 +1,262 @@
1
+ import csv
2
+ import gc
3
+ import io
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ from contextlib import contextmanager
9
+ from threading import Thread
10
+
11
+ import albumentations
12
+ import cv2
13
+ import numpy as np
14
+ import torch
15
+ import torchvision.transforms as transforms
16
+ from decord import VideoReader
17
+ from einops import rearrange
18
+ from func_timeout import FunctionTimedOut, func_timeout
19
+ from PIL import Image
20
+ from torch.utils.data import BatchSampler, Sampler
21
+ from torch.utils.data.dataset import Dataset
22
+
23
+ VIDEO_READER_TIMEOUT = 20
24
+
25
+ def get_random_mask(shape):
26
+ f, c, h, w = shape
27
+
28
+ mask_index = np.random.randint(0, 4)
29
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
30
+ if mask_index == 0:
31
+ mask[1:, :, :, :] = 1
32
+ elif mask_index == 1:
33
+ mask_frame_index = 1
34
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
35
+ elif mask_index == 2:
36
+ center_x = torch.randint(0, w, (1,)).item()
37
+ center_y = torch.randint(0, h, (1,)).item()
38
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # block width range
40
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # block height range
40
+
41
+ start_x = max(center_x - block_size_x // 2, 0)
42
+ end_x = min(center_x + block_size_x // 2, w)
43
+ start_y = max(center_y - block_size_y // 2, 0)
44
+ end_y = min(center_y + block_size_y // 2, h)
45
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
46
+ elif mask_index == 3:
47
+ center_x = torch.randint(0, w, (1,)).item()
48
+ center_y = torch.randint(0, h, (1,)).item()
49
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # block width range
50
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # block height range
51
+
52
+ start_x = max(center_x - block_size_x // 2, 0)
53
+ end_x = min(center_x + block_size_x // 2, w)
54
+ start_y = max(center_y - block_size_y // 2, 0)
55
+ end_y = min(center_y + block_size_y // 2, h)
56
+
57
+ mask_frame_before = np.random.randint(0, f // 2)
58
+ mask_frame_after = np.random.randint(f // 2, f)
59
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
60
+ else:
61
+ raise ValueError(f"The mask_index {mask_index} is not define")
62
+ return mask
63
+
64
+
65
+ @contextmanager
66
+ def VideoReader_contextmanager(*args, **kwargs):
67
+ vr = VideoReader(*args, **kwargs)
68
+ try:
69
+ yield vr
70
+ finally:
71
+ del vr
72
+ gc.collect()
73
+
74
+
75
+ def get_video_reader_batch(video_reader, batch_index):
76
+ frames = video_reader.get_batch(batch_index).asnumpy()
77
+ return frames
78
+
79
+
80
+ class WebVid10M(Dataset):
81
+ def __init__(
82
+ self,
83
+ csv_path, video_folder,
84
+ sample_size=256, sample_stride=4, sample_n_frames=16,
85
+ enable_bucket=False, enable_inpaint=False, is_image=False,
86
+ ):
87
+ print(f"loading annotations from {csv_path} ...")
88
+ with open(csv_path, 'r') as csvfile:
89
+ self.dataset = list(csv.DictReader(csvfile))
90
+ self.length = len(self.dataset)
91
+ print(f"data scale: {self.length}")
92
+
93
+ self.video_folder = video_folder
94
+ self.sample_stride = sample_stride
95
+ self.sample_n_frames = sample_n_frames
96
+ self.enable_bucket = enable_bucket
97
+ self.enable_inpaint = enable_inpaint
98
+ self.is_image = is_image
99
+
100
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
101
+ self.pixel_transforms = transforms.Compose([
102
+ transforms.Resize(sample_size[0]),
103
+ transforms.CenterCrop(sample_size),
104
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
105
+ ])
106
+
107
+ def get_batch(self, idx):
108
+ video_dict = self.dataset[idx]
109
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
110
+
111
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
112
+ video_reader = VideoReader(video_dir)
113
+ video_length = len(video_reader)
114
+
115
+ if not self.is_image:
116
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
117
+ start_idx = random.randint(0, video_length - clip_length)
118
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
119
+ else:
120
+ batch_index = [random.randint(0, video_length - 1)]
121
+
122
+ if not self.enable_bucket:
123
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
124
+ pixel_values = pixel_values / 255.
125
+ del video_reader
126
+ else:
127
+ pixel_values = video_reader.get_batch(batch_index).asnumpy()
128
+
129
+ if self.is_image:
130
+ pixel_values = pixel_values[0]
131
+ return pixel_values, name
132
+
133
+ def __len__(self):
134
+ return self.length
135
+
136
+ def __getitem__(self, idx):
137
+ while True:
138
+ try:
139
+ pixel_values, name = self.get_batch(idx)
140
+ break
141
+
142
+ except Exception as e:
143
+ print("Error info:", e)
144
+ idx = random.randint(0, self.length-1)
145
+
146
+ if not self.enable_bucket:
147
+ pixel_values = self.pixel_transforms(pixel_values)
148
+ if self.enable_inpaint:
149
+ mask = get_random_mask(pixel_values.size())
150
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
151
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
152
+ else:
153
+ sample = dict(pixel_values=pixel_values, text=name)
154
+ return sample
155
+
156
+
157
+ class VideoDataset(Dataset):
158
+ def __init__(
159
+ self,
160
+ json_path, video_folder=None,
161
+ sample_size=256, sample_stride=4, sample_n_frames=16,
162
+ enable_bucket=False, enable_inpaint=False
163
+ ):
164
+ print(f"loading annotations from {json_path} ...")
165
+ self.dataset = json.load(open(json_path, 'r'))
166
+ self.length = len(self.dataset)
167
+ print(f"data scale: {self.length}")
168
+
169
+ self.video_folder = video_folder
170
+ self.sample_stride = sample_stride
171
+ self.sample_n_frames = sample_n_frames
172
+ self.enable_bucket = enable_bucket
173
+ self.enable_inpaint = enable_inpaint
174
+
175
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
176
+ self.pixel_transforms = transforms.Compose(
177
+ [
178
+ transforms.Resize(sample_size[0]),
179
+ transforms.CenterCrop(sample_size),
180
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
181
+ ]
182
+ )
183
+
184
+ def get_batch(self, idx):
185
+ video_dict = self.dataset[idx]
186
+ video_id, name = video_dict['file_path'], video_dict['text']
187
+
188
+ if self.video_folder is None:
189
+ video_dir = video_id
190
+ else:
191
+ video_dir = os.path.join(self.video_folder, video_id)
192
+
193
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
194
+ video_length = len(video_reader)
195
+
196
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
197
+ start_idx = random.randint(0, video_length - clip_length)
198
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
199
+
200
+ try:
201
+ sample_args = (video_reader, batch_index)
202
+ pixel_values = func_timeout(
203
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
204
+ )
205
+ except FunctionTimedOut:
206
+ raise ValueError(f"Read {idx} timeout.")
207
+ except Exception as e:
208
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
209
+
210
+ if not self.enable_bucket:
211
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
212
+ pixel_values = pixel_values / 255.
213
+ del video_reader
214
+ else:
215
+ pixel_values = pixel_values
216
+
217
+ return pixel_values, name
218
+
219
+ def __len__(self):
220
+ return self.length
221
+
222
+ def __getitem__(self, idx):
223
+ while True:
224
+ try:
225
+ pixel_values, name = self.get_batch(idx)
226
+ break
227
+
228
+ except Exception as e:
229
+ print("Error info:", e)
230
+ idx = random.randint(0, self.length-1)
231
+
232
+ if not self.enable_bucket:
233
+ pixel_values = self.pixel_transforms(pixel_values)
234
+ if self.enable_inpaint:
235
+ mask = get_random_mask(pixel_values.size())
236
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
237
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
238
+ else:
239
+ sample = dict(pixel_values=pixel_values, text=name)
240
+ return sample
241
+
242
+
243
+ if __name__ == "__main__":
244
+ if 1:
245
+ dataset = VideoDataset(
246
+ json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json",
247
+ sample_size=256,
248
+ sample_stride=4, sample_n_frames=16,
249
+ )
250
+
251
+ if 0:
252
+ dataset = WebVid10M(
253
+ csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
254
+ video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
255
+ sample_size=256,
256
+ sample_stride=4, sample_n_frames=16,
257
+ is_image=False,
258
+ )
259
+
260
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
261
+ for idx, batch in enumerate(dataloader):
262
+ print(batch["pixel_values"].shape, len(batch["text"]))
rose/dist/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ import importlib.util
2
+
3
+ from .fsdp import shard_model
4
+ from .fuser import (get_sequence_parallel_rank,
5
+ get_sequence_parallel_world_size, get_sp_group,
6
+ get_world_group, init_distributed_environment,
7
+ initialize_model_parallel, set_multi_gpus_devices,
8
+ xFuserLongContextAttention)
9
+ from .wan_xfuser import usp_attn_forward
10
+
11
+ # The pai_fuser is an internally developed acceleration package, which can be used on PAI.
12
+ if importlib.util.find_spec("pai_fuser") is not None:
13
+ from pai_fuser.core import parallel_magvit_vae
14
+ from pai_fuser.core.attention import wan_usp_sparse_attention_wrapper
15
+ from . import wan_xfuser
16
+
17
+ # simple_wrapper works around conflicts between Cython and torch.compile
18
+ def simple_wrapper(func):
19
+ def inner(*args, **kwargs):
20
+ return func(*args, **kwargs)
21
+ return inner
22
+
23
+ wan_xfuser.usp_attn_forward = simple_wrapper(wan_usp_sparse_attention_wrapper()(wan_xfuser.usp_attn_forward))
24
+ usp_attn_forward = simple_wrapper(wan_xfuser.usp_attn_forward)
25
+ print("Import PAI VAE Turbo and Sparse Attention")
26
+
27
+ from pai_fuser.core.rope import ENABLE_KERNEL, usp_fast_rope_apply_qk
28
+
29
+ if ENABLE_KERNEL:
30
+ import torch
31
+ from .wan_xfuser import rope_apply
32
+
33
+ def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
34
+ if torch.is_grad_enabled():
35
+ q = rope_apply(q, grid_sizes, freqs)
36
+ k = rope_apply(k, grid_sizes, freqs)
37
+ return q, k
38
+ else:
39
+ return usp_fast_rope_apply_qk(q, k, grid_sizes, freqs)
40
+
41
+ wan_xfuser.rope_apply_qk = adaptive_fast_usp_rope_apply_qk
42
+ rope_apply_qk = adaptive_fast_usp_rope_apply_qk
43
+ print("Import PAI Fast rope")
rose/dist/fsdp.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copied from https://github.com/Wan-Video/Wan2.1/blob/main/wan/distributed/fsdp.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import gc
4
+ from functools import partial
5
+
6
+ import torch
7
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
8
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
9
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
10
+ from torch.distributed.utils import _free_storage
11
+
12
+
13
+ def shard_model(
14
+ model,
15
+ device_id,
16
+ param_dtype=torch.bfloat16,
17
+ reduce_dtype=torch.float32,
18
+ buffer_dtype=torch.float32,
19
+ process_group=None,
20
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
21
+ sync_module_states=True,
22
+ ):
23
+ model = FSDP(
24
+ module=model,
25
+ process_group=process_group,
26
+ sharding_strategy=sharding_strategy,
27
+ auto_wrap_policy=partial(
28
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
29
+ mixed_precision=MixedPrecision(
30
+ param_dtype=param_dtype,
31
+ reduce_dtype=reduce_dtype,
32
+ buffer_dtype=buffer_dtype),
33
+ device_id=device_id,
34
+ sync_module_states=sync_module_states)
35
+ return model
36
+
37
+ def free_model(model):
38
+ for m in model.modules():
39
+ if isinstance(m, FSDP):
40
+ _free_storage(m._handle.flat_param.data)
41
+ del model
42
+ gc.collect()
43
+ torch.cuda.empty_cache()
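shard_model wraps every module in model.blocks with FSDP, keeping bf16 parameters and fp32 reduce/buffer dtypes, so it expects torch.distributed to be initialized and the model to expose a .blocks list. A minimal call sketch; the transformer variable is a placeholder for the Wan transformer:

import torch
import torch.distributed as dist

dist.init_process_group("nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

# The lambda auto-wrap policy shards exactly the sub-modules found in `.blocks`.
transformer = shard_model(transformer, device_id=local_rank)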
rose/dist/fuser.py ADDED
@@ -0,0 +1,54 @@
1
+ import importlib.util
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+
6
+ try:
7
+ # The pai_fuser is an internally developed acceleration package, which can be used on PAI.
8
+ if importlib.util.find_spec("pai_fuser") is not None:
9
+ import pai_fuser
10
+ from pai_fuser.core.distributed import (
11
+ get_sequence_parallel_rank, get_sequence_parallel_world_size,
12
+ get_sp_group, get_world_group, init_distributed_environment,
13
+ initialize_model_parallel)
14
+ from pai_fuser.core.long_ctx_attention import \
15
+ xFuserLongContextAttention
16
+ print("Import PAI DiT Turbo")
17
+ else:
18
+ import xfuser
19
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
20
+ get_sequence_parallel_world_size,
21
+ get_sp_group, get_world_group,
22
+ init_distributed_environment,
23
+ initialize_model_parallel)
24
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
25
+ print("Xfuser import sucessful")
26
+ except Exception as ex:
27
+ get_sequence_parallel_world_size = None
28
+ get_sequence_parallel_rank = None
29
+ xFuserLongContextAttention = None
30
+ get_sp_group = None
31
+ get_world_group = None
32
+ init_distributed_environment = None
33
+ initialize_model_parallel = None
34
+
35
+ def set_multi_gpus_devices(ulysses_degree, ring_degree):
36
+ if ulysses_degree > 1 or ring_degree > 1:
37
+ if get_sp_group is None:
38
+ raise RuntimeError("xfuser is not installed.")
39
+ dist.init_process_group("nccl")
40
+ print('parallel inference enabled: ulysses_degree=%d ring_degree=%d rank=%d world_size=%d' % (
41
+ ulysses_degree, ring_degree, dist.get_rank(),
42
+ dist.get_world_size()))
43
+ assert dist.get_world_size() == ring_degree * ulysses_degree, \
44
+ "number of GPUs(%d) should be equal to ring_degree * ulysses_degree." % dist.get_world_size()
45
+ init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
46
+ initialize_model_parallel(sequence_parallel_degree=dist.get_world_size(),
47
+ ring_degree=ring_degree,
48
+ ulysses_degree=ulysses_degree)
49
+ # device = torch.device("cuda:%d" % dist.get_rank())
50
+ device = torch.device(f"cuda:{get_world_group().local_rank}")
51
+ print('rank=%d device=%s' % (get_world_group().rank, str(device)))
52
+ else:
53
+ device = "cuda"
54
+ return device
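Note: a usage sketch (not part of the commit) for set_multi_gpus_devices; the parallel degrees and the launch command are assumptions.

from rose.dist.fuser import set_multi_gpus_devices

# e.g. torchrun --nproc_per_node=2 infer.py, with ulysses_degree * ring_degree == world_size
ulysses_degree, ring_degree = 2, 1
device = set_multi_gpus_devices(ulysses_degree, ring_degree)  # initializes NCCL + sequence parallelism, returns the per-rank device
# pipeline.to(device)                                         # hypothetical pipeline object, not defined in this file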
rose/dist/wan_xfuser.py ADDED
@@ -0,0 +1,111 @@
1
+ import torch
2
+ import torch.cuda.amp as amp
3
+
4
+ from .fuser import (get_sequence_parallel_rank,
5
+ get_sequence_parallel_world_size, get_sp_group,
6
+ init_distributed_environment, initialize_model_parallel,
7
+ xFuserLongContextAttention)
8
+
9
+
10
+ def pad_freqs(original_tensor, target_len):
11
+ seq_len, s1, s2 = original_tensor.shape
12
+ pad_size = target_len - seq_len
13
+ padding_tensor = torch.ones(
14
+ pad_size,
15
+ s1,
16
+ s2,
17
+ dtype=original_tensor.dtype,
18
+ device=original_tensor.device)
19
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
20
+ return padded_tensor
21
+
22
+ @amp.autocast(enabled=False)
23
+ @torch.compiler.disable()
24
+ def rope_apply(x, grid_sizes, freqs):
25
+ """
26
+ x: [B, L, N, C].
27
+ grid_sizes: [B, 3].
28
+ freqs: [M, C // 2].
29
+ """
30
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
31
+ # split freqs
32
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
33
+
34
+ # loop over samples
35
+ output = []
36
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
37
+ seq_len = f * h * w
38
+
39
+ # precompute multipliers
40
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape(
41
+ s, n, -1, 2))
42
+ freqs_i = torch.cat([
43
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
44
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
45
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
46
+ ],
47
+ dim=-1).reshape(seq_len, 1, -1)
48
+
49
+ # apply rotary embedding
50
+ sp_size = get_sequence_parallel_world_size()
51
+ sp_rank = get_sequence_parallel_rank()
52
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
53
+ s_per_rank = s
54
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
55
+ s_per_rank), :, :]
56
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
57
+ x_i = torch.cat([x_i, x[i, s:]])
58
+
59
+ # append to collection
60
+ output.append(x_i)
61
+ return torch.stack(output)
62
+
63
+ def rope_apply_qk(q, k, grid_sizes, freqs):
64
+ q = rope_apply(q, grid_sizes, freqs)
65
+ k = rope_apply(k, grid_sizes, freqs)
66
+ return q, k
67
+
68
+ def usp_attn_forward(self,
69
+ x,
70
+ seq_lens,
71
+ grid_sizes,
72
+ freqs,
73
+ dtype=torch.bfloat16,
74
+ t=0):
75
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
76
+ half_dtypes = (torch.float16, torch.bfloat16)
77
+
78
+ def half(x):
79
+ return x if x.dtype in half_dtypes else x.to(dtype)
80
+
81
+ # query, key, value function
82
+ def qkv_fn(x):
83
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
84
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
85
+ v = self.v(x).view(b, s, n, d)
86
+ return q, k, v
87
+
88
+ q, k, v = qkv_fn(x)
89
+ q, k = rope_apply_qk(q, k, grid_sizes, freqs)
90
+
91
+ # TODO: We should use unpaded q,k,v for attention.
92
+ # k_lens = seq_lens // get_sequence_parallel_world_size()
93
+ # if k_lens is not None:
94
+ # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
95
+ # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
96
+ # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
97
+
98
+ x = xFuserLongContextAttention()(
99
+ None,
100
+ query=half(q),
101
+ key=half(k),
102
+ value=half(v),
103
+ window_size=self.window_size)
104
+
105
+ # TODO: padding after attention.
106
+ # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
107
+
108
+ # output
109
+ x = x.flatten(2)
110
+ x = self.o(x)
111
+ return x
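Note: an illustrative sketch (not part of the commit) of the per-rank frequency slicing performed in rope_apply above: each rank holds s tokens, the frequency table is padded to s * sp_size rows, and rank r reads only rows [r*s, (r+1)*s).

s, sp_size = 4, 2                                    # 4 tokens per rank, 2 sequence-parallel ranks (illustrative values)
for sp_rank in range(sp_size):
    start, stop = sp_rank * s, (sp_rank + 1) * s
    print(f"rank {sp_rank} -> freqs_i[{start}:{stop}]")
# rank 0 -> freqs_i[0:4]
# rank 1 -> freqs_i[4:8]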
rose/models/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
2
+
3
+ from .wan_image_encoder import CLIPModel
4
+ from .wan_text_encoder import WanT5EncoderModel
5
+ from .wan_transformer3d import WanTransformer3DModel
6
+ from .wan_vae import AutoencoderKLWan
rose/models/cache_utils.py ADDED
@@ -0,0 +1,74 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ def get_teacache_coefficients(model_name):
6
+ if "wan2.1-t2v-1.3b" or "wan2.1-fun-1.3b" in model_name.lower():
7
+ return [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
8
+ elif "wan2.1-t2v-14b" in model_name.lower():
9
+ return [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
10
+ elif "wan2.1-i2v-14b-480p" in model_name.lower():
11
+ return [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
12
+ elif "wan2.1-i2v-14b-720p" or "wan2.1-fun-14b" in model_name.lower():
13
+ return [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
14
+ else:
15
+ print(f"The model {model_name} is not supported by TeaCache.")
16
+ return None
17
+
18
+
19
+ class TeaCache():
20
+ """
21
+ Timestep Embedding Aware Cache, a training-free caching approach that estimates and leverages
22
+ the fluctuating differences among model outputs across timesteps, thereby accelerating the inference.
23
+ Please refer to:
24
+ 1. https://github.com/ali-vilab/TeaCache.
25
+ 2. Liu, Feng, et al. "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model." arXiv preprint arXiv:2411.19108 (2024).
26
+ """
27
+ def __init__(
28
+ self,
29
+ coefficients: list[float],
30
+ num_steps: int,
31
+ rel_l1_thresh: float = 0.0,
32
+ num_skip_start_steps: int = 0,
33
+ offload: bool = True,
34
+ ):
35
+ if num_steps < 1:
36
+ raise ValueError(f"`num_steps` must be greater than 0 but is {num_steps}.")
37
+ if rel_l1_thresh < 0:
38
+ raise ValueError(f"`rel_l1_thresh` must be greater than or equal to 0 but is {rel_l1_thresh}.")
39
+ if num_skip_start_steps < 0 or num_skip_start_steps > num_steps:
40
+ raise ValueError(
41
+ "`num_skip_start_steps` must be great than or equal to 0 and "
42
+ f"less than or equal to `num_steps={num_steps}` but is {num_skip_start_steps}."
43
+ )
44
+ self.coefficients = coefficients
45
+ self.num_steps = num_steps
46
+ self.rel_l1_thresh = rel_l1_thresh
47
+ self.num_skip_start_steps = num_skip_start_steps
48
+ self.offload = offload
49
+ self.rescale_func = np.poly1d(self.coefficients)
50
+
51
+ self.cnt = 0
52
+ self.should_calc = True
53
+ self.accumulated_rel_l1_distance = 0
54
+ self.previous_modulated_input = None
55
+ # Some pipelines concatenate the unconditional and conditional (text-guided) branches in a single forward pass.
56
+ self.previous_residual = None
57
+ # Other pipelines run the unconditional and conditional (text-guided) branches as separate forward passes.
58
+ self.previous_residual_cond = None
59
+ self.previous_residual_uncond = None
60
+
61
+ @staticmethod
62
+ def compute_rel_l1_distance(prev: torch.Tensor, cur: torch.Tensor) -> torch.Tensor:
63
+ rel_l1_distance = (torch.abs(cur - prev).mean()) / torch.abs(prev).mean()
64
+
65
+ return rel_l1_distance.cpu().item()
66
+
67
+ def reset(self):
68
+ self.cnt = 0
69
+ self.should_calc = True
70
+ self.accumulated_rel_l1_distance = 0
71
+ self.previous_modulated_input = None
72
+ self.previous_residual = None
73
+ self.previous_residual_cond = None
74
+ self.previous_residual_uncond = None
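Note: a hypothetical sketch (not part of the commit) of how a pipeline could drive TeaCache; the model name and the exact skip rule used by the ROSE pipelines are assumptions, only the attributes and compute_rel_l1_distance come from this file.

from rose.models.cache_utils import TeaCache, get_teacache_coefficients

coefficients = get_teacache_coefficients("wan2.1-fun-1.3b")        # assumed model name
tea_cache = TeaCache(coefficients, num_steps=50, rel_l1_thresh=0.1, num_skip_start_steps=5)

def transformer_should_run(modulated_input):
    # always run during the warm-up steps or on the very first call
    if tea_cache.cnt < tea_cache.num_skip_start_steps or tea_cache.previous_modulated_input is None:
        run = True
    else:
        rel = TeaCache.compute_rel_l1_distance(tea_cache.previous_modulated_input, modulated_input)
        tea_cache.accumulated_rel_l1_distance += tea_cache.rescale_func(rel)
        run = tea_cache.accumulated_rel_l1_distance >= tea_cache.rel_l1_thresh
        if run:
            tea_cache.accumulated_rel_l1_distance = 0
    tea_cache.previous_modulated_input = modulated_input
    tea_cache.cnt += 1
    return run                                                      # when False, reuse the cached residual instead of the full forward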
rose/models/diff_mask_predictor.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
5
+
6
+ class DiffMaskPredictor(nn.Module):
7
+ def __init__(self, input_dim=4608, patch_grid=(10, 15, 189), output_grid=(81, 480, 720), hidden_dim=256):
8
+ """
9
+ Args:
10
+ input_dim (int): concatenated feature dimension, e.g. 1536 * num_selected_layers
11
+ patch_grid (tuple): (F_p, H_p, W_p) - patch token grid shape (e.g., from transformer block)
12
+ output_grid (tuple): (F, H, W) - final full resolution shape for mask
13
+ hidden_dim (int): intermediate conv/linear hidden dim
14
+ """
15
+ super().__init__()
16
+ self.F_p, self.H_p, self.W_p = patch_grid
17
+ self.F, self.H, self.W = output_grid
18
+
19
+ self.project = nn.Sequential(
20
+ nn.Linear(input_dim, hidden_dim),
21
+ nn.GELU(),
22
+ nn.Linear(hidden_dim, 1)
23
+ )
24
+
25
+ def forward(self, x):
26
+ """
27
+ Args:
28
+ x (Tensor): shape [B, L, D_total], L = F_p H_p W_p
29
+ Returns:
30
+ Tensor: predicted diff mask, shape [B, 1, F, H, W]
31
+ """
32
+ B, L, D = x.shape
33
+ assert L == self.F_p * self.H_p * self.W_p, \
34
+ f"Input token length {L} doesn't match patch grid ({self.F_p}, {self.H_p}, {self.W_p})"
35
+
36
+ x = self.project(x) # [B, L, 1]
37
+ x = x.view(B, 1, self.F_p, self.H_p, self.W_p) # [B, 1, F_p, H_p, W_p]
38
+ x = F.interpolate(
39
+ x, size=(self.F, self.H, self.W),
40
+ mode="trilinear", align_corners=False # upsample to match ground truth resolution
41
+ )
42
+ return x # [B, 1, F, H, W]
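Note: a minimal shape-check sketch (not part of the commit); the small grid sizes are illustrative, not the defaults used in training.

import torch

from rose.models.diff_mask_predictor import DiffMaskPredictor

predictor = DiffMaskPredictor(input_dim=1536, patch_grid=(2, 4, 6), output_grid=(9, 64, 96), hidden_dim=64)
tokens = torch.randn(1, 2 * 4 * 6, 1536)     # [B, L, D_total] with L = F_p * H_p * W_p
mask_logits = predictor(tokens)              # [1, 1, 9, 64, 96]: per-token scores upsampled trilinearly to the output grid
print(mask_logits.shape)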
rose/models/wan_image_encoder.py ADDED
@@ -0,0 +1,553 @@
1
+ # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision.transforms as T
9
+
10
+ from .wan_transformer3d import attention, flash_attention  # flash_attention is used by AttentionPool below
11
+ from .wan_xlm_roberta import XLMRoberta
12
+ from diffusers.configuration_utils import ConfigMixin
13
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+
16
+
17
+ __all__ = [
18
+ 'XLMRobertaCLIP',
19
+ 'clip_xlm_roberta_vit_h_14',
20
+ 'CLIPModel',
21
+ ]
22
+
23
+
24
+ def pos_interpolate(pos, seq_len):
25
+ if pos.size(1) == seq_len:
26
+ return pos
27
+ else:
28
+ src_grid = int(math.sqrt(pos.size(1)))
29
+ tar_grid = int(math.sqrt(seq_len))
30
+ n = pos.size(1) - src_grid * src_grid
31
+ return torch.cat([
32
+ pos[:, :n],
33
+ F.interpolate(
34
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
35
+ 0, 3, 1, 2),
36
+ size=(tar_grid, tar_grid),
37
+ mode='bicubic',
38
+ align_corners=False).flatten(2).transpose(1, 2)
39
+ ],
40
+ dim=1)
41
+
42
+
43
+ class QuickGELU(nn.Module):
44
+
45
+ def forward(self, x):
46
+ return x * torch.sigmoid(1.702 * x)
47
+
48
+
49
+ class LayerNorm(nn.LayerNorm):
50
+
51
+ def forward(self, x):
52
+ return super().forward(x.float()).type_as(x)
53
+
54
+
55
+ class SelfAttention(nn.Module):
56
+
57
+ def __init__(self,
58
+ dim,
59
+ num_heads,
60
+ causal=False,
61
+ attn_dropout=0.0,
62
+ proj_dropout=0.0):
63
+ assert dim % num_heads == 0
64
+ super().__init__()
65
+ self.dim = dim
66
+ self.num_heads = num_heads
67
+ self.head_dim = dim // num_heads
68
+ self.causal = causal
69
+ self.attn_dropout = attn_dropout
70
+ self.proj_dropout = proj_dropout
71
+
72
+ # layers
73
+ self.to_qkv = nn.Linear(dim, dim * 3)
74
+ self.proj = nn.Linear(dim, dim)
75
+
76
+ def forward(self, x):
77
+ """
78
+ x: [B, L, C].
79
+ """
80
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
81
+
82
+ # compute query, key, value
83
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
84
+
85
+ # compute attention
86
+ p = self.attn_dropout if self.training else 0.0
87
+ x = attention(q, k, v, dropout_p=p, causal=self.causal)
88
+ x = x.reshape(b, s, c)
89
+
90
+ # output
91
+ x = self.proj(x)
92
+ x = F.dropout(x, self.proj_dropout, self.training)
93
+ return x
94
+
95
+
96
+ class SwiGLU(nn.Module):
97
+
98
+ def __init__(self, dim, mid_dim):
99
+ super().__init__()
100
+ self.dim = dim
101
+ self.mid_dim = mid_dim
102
+
103
+ # layers
104
+ self.fc1 = nn.Linear(dim, mid_dim)
105
+ self.fc2 = nn.Linear(dim, mid_dim)
106
+ self.fc3 = nn.Linear(mid_dim, dim)
107
+
108
+ def forward(self, x):
109
+ x = F.silu(self.fc1(x)) * self.fc2(x)
110
+ x = self.fc3(x)
111
+ return x
112
+
113
+
114
+ class AttentionBlock(nn.Module):
115
+
116
+ def __init__(self,
117
+ dim,
118
+ mlp_ratio,
119
+ num_heads,
120
+ post_norm=False,
121
+ causal=False,
122
+ activation='quick_gelu',
123
+ attn_dropout=0.0,
124
+ proj_dropout=0.0,
125
+ norm_eps=1e-5):
126
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
127
+ super().__init__()
128
+ self.dim = dim
129
+ self.mlp_ratio = mlp_ratio
130
+ self.num_heads = num_heads
131
+ self.post_norm = post_norm
132
+ self.causal = causal
133
+ self.norm_eps = norm_eps
134
+
135
+ # layers
136
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
137
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
138
+ proj_dropout)
139
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
140
+ if activation == 'swi_glu':
141
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
142
+ else:
143
+ self.mlp = nn.Sequential(
144
+ nn.Linear(dim, int(dim * mlp_ratio)),
145
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
146
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
147
+
148
+ def forward(self, x):
149
+ if self.post_norm:
150
+ x = x + self.norm1(self.attn(x))
151
+ x = x + self.norm2(self.mlp(x))
152
+ else:
153
+ x = x + self.attn(self.norm1(x))
154
+ x = x + self.mlp(self.norm2(x))
155
+ return x
156
+
157
+
158
+ class AttentionPool(nn.Module):
159
+
160
+ def __init__(self,
161
+ dim,
162
+ mlp_ratio,
163
+ num_heads,
164
+ activation='gelu',
165
+ proj_dropout=0.0,
166
+ norm_eps=1e-5):
167
+ assert dim % num_heads == 0
168
+ super().__init__()
169
+ self.dim = dim
170
+ self.mlp_ratio = mlp_ratio
171
+ self.num_heads = num_heads
172
+ self.head_dim = dim // num_heads
173
+ self.proj_dropout = proj_dropout
174
+ self.norm_eps = norm_eps
175
+
176
+ # layers
177
+ gain = 1.0 / math.sqrt(dim)
178
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
179
+ self.to_q = nn.Linear(dim, dim)
180
+ self.to_kv = nn.Linear(dim, dim * 2)
181
+ self.proj = nn.Linear(dim, dim)
182
+ self.norm = LayerNorm(dim, eps=norm_eps)
183
+ self.mlp = nn.Sequential(
184
+ nn.Linear(dim, int(dim * mlp_ratio)),
185
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
186
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
187
+
188
+ def forward(self, x):
189
+ """
190
+ x: [B, L, C].
191
+ """
192
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
193
+
194
+ # compute query, key, value
195
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
196
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
197
+
198
+ # compute attention
199
+ x = flash_attention(q, k, v, version=2)
200
+ x = x.reshape(b, 1, c)
201
+
202
+ # output
203
+ x = self.proj(x)
204
+ x = F.dropout(x, self.proj_dropout, self.training)
205
+
206
+ # mlp
207
+ x = x + self.mlp(self.norm(x))
208
+ return x[:, 0]
209
+
210
+
211
+ class VisionTransformer(nn.Module):
212
+
213
+ def __init__(self,
214
+ image_size=224,
215
+ patch_size=16,
216
+ dim=768,
217
+ mlp_ratio=4,
218
+ out_dim=512,
219
+ num_heads=12,
220
+ num_layers=12,
221
+ pool_type='token',
222
+ pre_norm=True,
223
+ post_norm=False,
224
+ activation='quick_gelu',
225
+ attn_dropout=0.0,
226
+ proj_dropout=0.0,
227
+ embedding_dropout=0.0,
228
+ norm_eps=1e-5):
229
+ if image_size % patch_size != 0:
230
+ print(
231
+ '[WARNING] image_size is not divisible by patch_size',
232
+ flush=True)
233
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
234
+ out_dim = out_dim or dim
235
+ super().__init__()
236
+ self.image_size = image_size
237
+ self.patch_size = patch_size
238
+ self.num_patches = (image_size // patch_size)**2
239
+ self.dim = dim
240
+ self.mlp_ratio = mlp_ratio
241
+ self.out_dim = out_dim
242
+ self.num_heads = num_heads
243
+ self.num_layers = num_layers
244
+ self.pool_type = pool_type
245
+ self.post_norm = post_norm
246
+ self.norm_eps = norm_eps
247
+
248
+ # embeddings
249
+ gain = 1.0 / math.sqrt(dim)
250
+ self.patch_embedding = nn.Conv2d(
251
+ 3,
252
+ dim,
253
+ kernel_size=patch_size,
254
+ stride=patch_size,
255
+ bias=not pre_norm)
256
+ if pool_type in ('token', 'token_fc'):
257
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
258
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
259
+ 1, self.num_patches +
260
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
261
+ self.dropout = nn.Dropout(embedding_dropout)
262
+
263
+ # transformer
264
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
265
+ self.transformer = nn.Sequential(*[
266
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
267
+ activation, attn_dropout, proj_dropout, norm_eps)
268
+ for _ in range(num_layers)
269
+ ])
270
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
271
+
272
+ # head
273
+ if pool_type == 'token':
274
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
275
+ elif pool_type == 'token_fc':
276
+ self.head = nn.Linear(dim, out_dim)
277
+ elif pool_type == 'attn_pool':
278
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
279
+ proj_dropout, norm_eps)
280
+
281
+ def forward(self, x, interpolation=False, use_31_block=False):
282
+ b = x.size(0)
283
+
284
+ # embeddings
285
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
286
+ if self.pool_type in ('token', 'token_fc'):
287
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
288
+ if interpolation:
289
+ e = pos_interpolate(self.pos_embedding, x.size(1))
290
+ else:
291
+ e = self.pos_embedding
292
+ x = self.dropout(x + e)
293
+ if self.pre_norm is not None:
294
+ x = self.pre_norm(x)
295
+
296
+ # transformer
297
+ if use_31_block:
298
+ x = self.transformer[:-1](x)
299
+ return x
300
+ else:
301
+ x = self.transformer(x)
302
+ return x
303
+
304
+
305
+ class XLMRobertaWithHead(XLMRoberta):
306
+
307
+ def __init__(self, **kwargs):
308
+ self.out_dim = kwargs.pop('out_dim')
309
+ super().__init__(**kwargs)
310
+
311
+ # head
312
+ mid_dim = (self.dim + self.out_dim) // 2
313
+ self.head = nn.Sequential(
314
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
315
+ nn.Linear(mid_dim, self.out_dim, bias=False))
316
+
317
+ def forward(self, ids):
318
+ # xlm-roberta
319
+ x = super().forward(ids)
320
+
321
+ # average pooling
322
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
323
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
324
+
325
+ # head
326
+ x = self.head(x)
327
+ return x
328
+
329
+
330
+ class XLMRobertaCLIP(nn.Module):
331
+
332
+ def __init__(self,
333
+ embed_dim=1024,
334
+ image_size=224,
335
+ patch_size=14,
336
+ vision_dim=1280,
337
+ vision_mlp_ratio=4,
338
+ vision_heads=16,
339
+ vision_layers=32,
340
+ vision_pool='token',
341
+ vision_pre_norm=True,
342
+ vision_post_norm=False,
343
+ activation='gelu',
344
+ vocab_size=250002,
345
+ max_text_len=514,
346
+ type_size=1,
347
+ pad_id=1,
348
+ text_dim=1024,
349
+ text_heads=16,
350
+ text_layers=24,
351
+ text_post_norm=True,
352
+ text_dropout=0.1,
353
+ attn_dropout=0.0,
354
+ proj_dropout=0.0,
355
+ embedding_dropout=0.0,
356
+ norm_eps=1e-5):
357
+ super().__init__()
358
+ self.embed_dim = embed_dim
359
+ self.image_size = image_size
360
+ self.patch_size = patch_size
361
+ self.vision_dim = vision_dim
362
+ self.vision_mlp_ratio = vision_mlp_ratio
363
+ self.vision_heads = vision_heads
364
+ self.vision_layers = vision_layers
365
+ self.vision_pre_norm = vision_pre_norm
366
+ self.vision_post_norm = vision_post_norm
367
+ self.activation = activation
368
+ self.vocab_size = vocab_size
369
+ self.max_text_len = max_text_len
370
+ self.type_size = type_size
371
+ self.pad_id = pad_id
372
+ self.text_dim = text_dim
373
+ self.text_heads = text_heads
374
+ self.text_layers = text_layers
375
+ self.text_post_norm = text_post_norm
376
+ self.norm_eps = norm_eps
377
+
378
+ # models
379
+ self.visual = VisionTransformer(
380
+ image_size=image_size,
381
+ patch_size=patch_size,
382
+ dim=vision_dim,
383
+ mlp_ratio=vision_mlp_ratio,
384
+ out_dim=embed_dim,
385
+ num_heads=vision_heads,
386
+ num_layers=vision_layers,
387
+ pool_type=vision_pool,
388
+ pre_norm=vision_pre_norm,
389
+ post_norm=vision_post_norm,
390
+ activation=activation,
391
+ attn_dropout=attn_dropout,
392
+ proj_dropout=proj_dropout,
393
+ embedding_dropout=embedding_dropout,
394
+ norm_eps=norm_eps)
395
+ self.textual = XLMRobertaWithHead(
396
+ vocab_size=vocab_size,
397
+ max_seq_len=max_text_len,
398
+ type_size=type_size,
399
+ pad_id=pad_id,
400
+ dim=text_dim,
401
+ out_dim=embed_dim,
402
+ num_heads=text_heads,
403
+ num_layers=text_layers,
404
+ post_norm=text_post_norm,
405
+ dropout=text_dropout)
406
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
407
+
408
+ def forward(self, imgs, txt_ids):
409
+ """
410
+ imgs: [B, 3, H, W] of torch.float32.
411
+ - mean: [0.48145466, 0.4578275, 0.40821073]
412
+ - std: [0.26862954, 0.26130258, 0.27577711]
413
+ txt_ids: [B, L] of torch.long.
414
+ Encoded by data.CLIPTokenizer.
415
+ """
416
+ xi = self.visual(imgs)
417
+ xt = self.textual(txt_ids)
418
+ return xi, xt
419
+
420
+ def param_groups(self):
421
+ groups = [{
422
+ 'params': [
423
+ p for n, p in self.named_parameters()
424
+ if 'norm' in n or n.endswith('bias')
425
+ ],
426
+ 'weight_decay': 0.0
427
+ }, {
428
+ 'params': [
429
+ p for n, p in self.named_parameters()
430
+ if not ('norm' in n or n.endswith('bias'))
431
+ ]
432
+ }]
433
+ return groups
434
+
435
+
436
+ def _clip(pretrained=False,
437
+ pretrained_name=None,
438
+ model_cls=XLMRobertaCLIP,
439
+ return_transforms=False,
440
+ return_tokenizer=False,
441
+ tokenizer_padding='eos',
442
+ dtype=torch.float32,
443
+ device='cpu',
444
+ **kwargs):
445
+ # init a model on device
446
+ with torch.device(device):
447
+ model = model_cls(**kwargs)
448
+
449
+ # set device
450
+ model = model.to(dtype=dtype, device=device)
451
+ output = (model,)
452
+
453
+ # init transforms
454
+ if return_transforms:
455
+ # mean and std
456
+ if 'siglip' in pretrained_name.lower():
457
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
458
+ else:
459
+ mean = [0.48145466, 0.4578275, 0.40821073]
460
+ std = [0.26862954, 0.26130258, 0.27577711]
461
+
462
+ # transforms
463
+ transforms = T.Compose([
464
+ T.Resize((model.image_size, model.image_size),
465
+ interpolation=T.InterpolationMode.BICUBIC),
466
+ T.ToTensor(),
467
+ T.Normalize(mean=mean, std=std)
468
+ ])
469
+ output += (transforms,)
470
+ return output[0] if len(output) == 1 else output
471
+
472
+
473
+ def clip_xlm_roberta_vit_h_14(
474
+ pretrained=False,
475
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
476
+ **kwargs):
477
+ cfg = dict(
478
+ embed_dim=1024,
479
+ image_size=224,
480
+ patch_size=14,
481
+ vision_dim=1280,
482
+ vision_mlp_ratio=4,
483
+ vision_heads=16,
484
+ vision_layers=32,
485
+ vision_pool='token',
486
+ activation='gelu',
487
+ vocab_size=250002,
488
+ max_text_len=514,
489
+ type_size=1,
490
+ pad_id=1,
491
+ text_dim=1024,
492
+ text_heads=16,
493
+ text_layers=24,
494
+ text_post_norm=True,
495
+ text_dropout=0.1,
496
+ attn_dropout=0.0,
497
+ proj_dropout=0.0,
498
+ embedding_dropout=0.0)
499
+ cfg.update(**kwargs)
500
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
501
+
502
+
503
+ class CLIPModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
504
+
505
+ def __init__(self):
506
+ super(CLIPModel, self).__init__()
507
+ # init model
508
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
509
+ pretrained=False,
510
+ return_transforms=True,
511
+ return_tokenizer=False)
512
+
513
+ def forward(self, videos):
514
+ # preprocess
515
+ size = (self.model.image_size,) * 2
516
+ videos = torch.cat([
517
+ F.interpolate(
518
+ u.transpose(0, 1),
519
+ size=size,
520
+ mode='bicubic',
521
+ align_corners=False) for u in videos
522
+ ])
523
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
524
+
525
+ # forward
526
+ with torch.cuda.amp.autocast(dtype=self.dtype):
527
+ out = self.model.visual(videos, use_31_block=True)
528
+ return out
529
+
530
+ @classmethod
531
+ def from_pretrained(cls, pretrained_model_path, transformer_additional_kwargs={}):
532
+ def filter_kwargs(cls, kwargs):
533
+ import inspect
534
+ sig = inspect.signature(cls.__init__)
535
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
536
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
537
+ return filtered_kwargs
538
+
539
+ model = cls(**filter_kwargs(cls, transformer_additional_kwargs))
540
+ if pretrained_model_path.endswith(".safetensors"):
541
+ from safetensors.torch import load_file, safe_open
542
+ state_dict = load_file(pretrained_model_path)
543
+ else:
544
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
545
+ tmp_state_dict = {}
546
+ for key in state_dict:
547
+ tmp_state_dict["model." + key] = state_dict[key]
548
+ state_dict = tmp_state_dict
549
+ m, u = model.load_state_dict(state_dict, strict=False)
550
+
551
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
552
+ print(m, u)
553
+ return model
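Note: a hypothetical sketch (not part of the commit) of calling CLIPModel; the checkpoint path is an assumption, and the input contract (a list of [C, F, H, W] tensors in [-1, 1]) is read off the forward method above.

import torch

from rose.models import CLIPModel

clip_model = CLIPModel.from_pretrained("path/to/clip_checkpoint.pth")          # assumed local checkpoint path
clip_model = clip_model.to(device="cuda", dtype=torch.bfloat16).eval()
clip_frame = torch.rand(3, 1, 480, 720, device="cuda") * 2 - 1                 # one clip: [C, F, H, W] in [-1, 1]
with torch.no_grad():
    clip_context = clip_model([clip_frame])                                    # [1, 257, 1280] ViT tokens from the penultimate block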
rose/models/wan_text_encoder.py ADDED
@@ -0,0 +1,376 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/t5.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from diffusers.configuration_utils import ConfigMixin
10
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+
13
+
14
+ def fp16_clamp(x):
15
+ if x.dtype == torch.float16 and torch.isinf(x).any():
16
+ clamp = torch.finfo(x.dtype).max - 1000
17
+ x = torch.clamp(x, min=-clamp, max=clamp)
18
+ return x
19
+
20
+
21
+ def init_weights(m):
22
+ if isinstance(m, T5LayerNorm):
23
+ nn.init.ones_(m.weight)
24
+ elif isinstance(m, T5FeedForward):
25
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
26
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
27
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
28
+ elif isinstance(m, T5Attention):
29
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
30
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
31
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
32
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
33
+ elif isinstance(m, T5RelativeEmbedding):
34
+ nn.init.normal_(
35
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
36
+
37
+
38
+ class GELU(nn.Module):
39
+ def forward(self, x):
40
+ return 0.5 * x * (1.0 + torch.tanh(
41
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
42
+
43
+
44
+ class T5LayerNorm(nn.Module):
45
+ def __init__(self, dim, eps=1e-6):
46
+ super(T5LayerNorm, self).__init__()
47
+ self.dim = dim
48
+ self.eps = eps
49
+ self.weight = nn.Parameter(torch.ones(dim))
50
+
51
+ def forward(self, x):
52
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
53
+ self.eps)
54
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
55
+ x = x.type_as(self.weight)
56
+ return self.weight * x
57
+
58
+
59
+ class T5Attention(nn.Module):
60
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
61
+ assert dim_attn % num_heads == 0
62
+ super(T5Attention, self).__init__()
63
+ self.dim = dim
64
+ self.dim_attn = dim_attn
65
+ self.num_heads = num_heads
66
+ self.head_dim = dim_attn // num_heads
67
+
68
+ # layers
69
+ self.q = nn.Linear(dim, dim_attn, bias=False)
70
+ self.k = nn.Linear(dim, dim_attn, bias=False)
71
+ self.v = nn.Linear(dim, dim_attn, bias=False)
72
+ self.o = nn.Linear(dim_attn, dim, bias=False)
73
+ self.dropout = nn.Dropout(dropout)
74
+
75
+ def forward(self, x, context=None, mask=None, pos_bias=None):
76
+ """
77
+ x: [B, L1, C].
78
+ context: [B, L2, C] or None.
79
+ mask: [B, L2] or [B, L1, L2] or None.
80
+ """
81
+ # check inputs
82
+ context = x if context is None else context
83
+ b, n, c = x.size(0), self.num_heads, self.head_dim
84
+
85
+ # compute query, key, value
86
+ q = self.q(x).view(b, -1, n, c)
87
+ k = self.k(context).view(b, -1, n, c)
88
+ v = self.v(context).view(b, -1, n, c)
89
+
90
+ # attention bias
91
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
92
+ if pos_bias is not None:
93
+ attn_bias += pos_bias
94
+ if mask is not None:
95
+ assert mask.ndim in [2, 3]
96
+ mask = mask.view(b, 1, 1,
97
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
98
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
99
+
100
+ # compute attention (T5 does not use scaling)
101
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
102
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
103
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
104
+
105
+ # output
106
+ x = x.reshape(b, -1, n * c)
107
+ x = self.o(x)
108
+ x = self.dropout(x)
109
+ return x
110
+
111
+
112
+ class T5FeedForward(nn.Module):
113
+
114
+ def __init__(self, dim, dim_ffn, dropout=0.1):
115
+ super(T5FeedForward, self).__init__()
116
+ self.dim = dim
117
+ self.dim_ffn = dim_ffn
118
+
119
+ # layers
120
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
121
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
122
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
123
+ self.dropout = nn.Dropout(dropout)
124
+
125
+ def forward(self, x):
126
+ x = self.fc1(x) * self.gate(x)
127
+ x = self.dropout(x)
128
+ x = self.fc2(x)
129
+ x = self.dropout(x)
130
+ return x
131
+
132
+
133
+ class T5SelfAttention(nn.Module):
134
+ def __init__(self,
135
+ dim,
136
+ dim_attn,
137
+ dim_ffn,
138
+ num_heads,
139
+ num_buckets,
140
+ shared_pos=True,
141
+ dropout=0.1):
142
+ super(T5SelfAttention, self).__init__()
143
+ self.dim = dim
144
+ self.dim_attn = dim_attn
145
+ self.dim_ffn = dim_ffn
146
+ self.num_heads = num_heads
147
+ self.num_buckets = num_buckets
148
+ self.shared_pos = shared_pos
149
+
150
+ # layers
151
+ self.norm1 = T5LayerNorm(dim)
152
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
153
+ self.norm2 = T5LayerNorm(dim)
154
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
155
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
156
+ num_buckets, num_heads, bidirectional=True)
157
+
158
+ def forward(self, x, mask=None, pos_bias=None):
159
+ e = pos_bias if self.shared_pos else self.pos_embedding(
160
+ x.size(1), x.size(1))
161
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
162
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
163
+ return x
164
+
165
+
166
+ class T5CrossAttention(nn.Module):
167
+ def __init__(self,
168
+ dim,
169
+ dim_attn,
170
+ dim_ffn,
171
+ num_heads,
172
+ num_buckets,
173
+ shared_pos=True,
174
+ dropout=0.1):
175
+ super(T5CrossAttention, self).__init__()
176
+ self.dim = dim
177
+ self.dim_attn = dim_attn
178
+ self.dim_ffn = dim_ffn
179
+ self.num_heads = num_heads
180
+ self.num_buckets = num_buckets
181
+ self.shared_pos = shared_pos
182
+
183
+ # layers
184
+ self.norm1 = T5LayerNorm(dim)
185
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
186
+ self.norm2 = T5LayerNorm(dim)
187
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
188
+ self.norm3 = T5LayerNorm(dim)
189
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
190
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
191
+ num_buckets, num_heads, bidirectional=False)
192
+
193
+ def forward(self,
194
+ x,
195
+ mask=None,
196
+ encoder_states=None,
197
+ encoder_mask=None,
198
+ pos_bias=None):
199
+ e = pos_bias if self.shared_pos else self.pos_embedding(
200
+ x.size(1), x.size(1))
201
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
202
+ x = fp16_clamp(x + self.cross_attn(
203
+ self.norm2(x), context=encoder_states, mask=encoder_mask))
204
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
205
+ return x
206
+
207
+
208
+ class T5RelativeEmbedding(nn.Module):
209
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
210
+ super(T5RelativeEmbedding, self).__init__()
211
+ self.num_buckets = num_buckets
212
+ self.num_heads = num_heads
213
+ self.bidirectional = bidirectional
214
+ self.max_dist = max_dist
215
+
216
+ # layers
217
+ self.embedding = nn.Embedding(num_buckets, num_heads)
218
+
219
+ def forward(self, lq, lk):
220
+ device = self.embedding.weight.device
221
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
222
+ # torch.arange(lq).unsqueeze(1).to(device)
223
+ if torch.device(type="meta") != device:
224
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
225
+ torch.arange(lq, device=device).unsqueeze(1)
226
+ else:
227
+ rel_pos = torch.arange(lk).unsqueeze(0) - \
228
+ torch.arange(lq).unsqueeze(1)
229
+ rel_pos = self._relative_position_bucket(rel_pos)
230
+ rel_pos_embeds = self.embedding(rel_pos)
231
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
232
+ 0) # [1, N, Lq, Lk]
233
+ return rel_pos_embeds.contiguous()
234
+
235
+ def _relative_position_bucket(self, rel_pos):
236
+ # preprocess
237
+ if self.bidirectional:
238
+ num_buckets = self.num_buckets // 2
239
+ rel_buckets = (rel_pos > 0).long() * num_buckets
240
+ rel_pos = torch.abs(rel_pos)
241
+ else:
242
+ num_buckets = self.num_buckets
243
+ rel_buckets = 0
244
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
245
+
246
+ # embeddings for small and large positions
247
+ max_exact = num_buckets // 2
248
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
249
+ math.log(self.max_dist / max_exact) *
250
+ (num_buckets - max_exact)).long()
251
+ rel_pos_large = torch.min(
252
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
253
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
254
+ return rel_buckets
255
+
256
+ class WanT5EncoderModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
257
+ def __init__(self,
258
+ vocab,
259
+ dim,
260
+ dim_attn,
261
+ dim_ffn,
262
+ num_heads,
263
+ num_layers,
264
+ num_buckets,
265
+ shared_pos=True,
266
+ dropout=0.1):
267
+ super(WanT5EncoderModel, self).__init__()
268
+ self.dim = dim
269
+ self.dim_attn = dim_attn
270
+ self.dim_ffn = dim_ffn
271
+ self.num_heads = num_heads
272
+ self.num_layers = num_layers
273
+ self.num_buckets = num_buckets
274
+ self.shared_pos = shared_pos
275
+
276
+ # layers
277
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
278
+ else nn.Embedding(vocab, dim)
279
+ self.pos_embedding = T5RelativeEmbedding(
280
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
281
+ self.dropout = nn.Dropout(dropout)
282
+ self.blocks = nn.ModuleList([
283
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
284
+ shared_pos, dropout) for _ in range(num_layers)
285
+ ])
286
+ self.norm = T5LayerNorm(dim)
287
+
288
+ # initialize weights
289
+ self.apply(init_weights)
290
+
291
+ def forward(
292
+ self,
293
+ input_ids: Optional[torch.LongTensor] = None,
294
+ attention_mask: Optional[torch.FloatTensor] = None,
295
+ ):
296
+ x = self.token_embedding(input_ids)
297
+ x = self.dropout(x)
298
+ e = self.pos_embedding(x.size(1),
299
+ x.size(1)) if self.shared_pos else None
300
+ for block in self.blocks:
301
+ x = block(x, attention_mask, pos_bias=e)
302
+ x = self.norm(x)
303
+ x = self.dropout(x)
304
+ return (x, )
305
+
306
+ @classmethod
307
+ def from_pretrained(cls, pretrained_model_path, additional_kwargs={}, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16):
308
+ def filter_kwargs(cls, kwargs):
309
+ import inspect
310
+ sig = inspect.signature(cls.__init__)
311
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
312
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
313
+ return filtered_kwargs
314
+
315
+ if low_cpu_mem_usage:
316
+ try:
317
+ import re
318
+
319
+ from diffusers.models.modeling_utils import \
320
+ load_model_dict_into_meta
321
+ from diffusers.utils import is_accelerate_available
322
+ if is_accelerate_available():
323
+ import accelerate
324
+
325
+ # Instantiate model with empty weights
326
+ with accelerate.init_empty_weights():
327
+ model = cls(**filter_kwargs(cls, additional_kwargs))
328
+
329
+ param_device = "cpu"
330
+ if pretrained_model_path.endswith(".safetensors"):
331
+ from safetensors.torch import load_file
332
+ state_dict = load_file(pretrained_model_path)
333
+ else:
334
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
335
+ # move the params from meta device to cpu
336
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
337
+ if len(missing_keys) > 0:
338
+ raise ValueError(
339
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
340
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
341
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
342
+ " those weights or else make sure your checkpoint file is correct."
343
+ )
344
+
345
+ unexpected_keys = load_model_dict_into_meta(
346
+ model,
347
+ state_dict,
348
+ device=param_device,
349
+ dtype=torch_dtype,
350
+ model_name_or_path=pretrained_model_path,
351
+ )
352
+
353
+ if cls._keys_to_ignore_on_load_unexpected is not None:
354
+ for pat in cls._keys_to_ignore_on_load_unexpected:
355
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
356
+
357
+ if len(unexpected_keys) > 0:
358
+ print(
359
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
360
+ )
361
+ return model
362
+ except Exception as e:
363
+ print(
364
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
365
+ )
366
+
367
+ model = cls(**filter_kwargs(cls, additional_kwargs))
368
+ if pretrained_model_path.endswith(".safetensors"):
369
+ from safetensors.torch import load_file, safe_open
370
+ state_dict = load_file(pretrained_model_path)
371
+ else:
372
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
373
+ m, u = model.load_state_dict(state_dict, strict=False)
374
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
375
+ print(m, u)
376
+ return model
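Note: a hypothetical sketch (not part of the commit) of encoding a prompt; the tokenizer name, checkpoint path, and config values are assumptions about the umt5-xxl setup, only the calling convention comes from this file.

import torch
from transformers import AutoTokenizer

from rose.models import WanT5EncoderModel

tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")                   # assumed tokenizer
text_encoder = WanT5EncoderModel.from_pretrained(
    "path/to/models_t5_umt5-xxl-enc-bf16.pth",                                 # assumed checkpoint path
    additional_kwargs=dict(vocab=256384, dim=4096, dim_attn=4096, dim_ffn=10240,
                           num_heads=64, num_layers=24, num_buckets=32),       # assumed config values
).eval()
batch = tokenizer(["a red rose blooming in the rain"], padding="max_length",
                  max_length=512, truncation=True, return_tensors="pt")
with torch.no_grad():
    (prompt_embeds,) = text_encoder(batch.input_ids, attention_mask=batch.attention_mask)
# prompt_embeds: [1, 512, dim] hidden states consumed by the transformer's cross-attention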
rose/models/wan_transformer3d.py ADDED
@@ -0,0 +1,1203 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+
4
+ import glob
5
+ import json
6
+ import math
7
+ import os
8
+ import types
9
+ import warnings
10
+ from typing import Any, Dict, Optional, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.cuda.amp as amp
15
+ import torch.nn as nn
16
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
17
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
18
+ from diffusers.loaders import PeftAdapterMixin
19
+ from diffusers.models.modeling_utils import ModelMixin
20
+ from diffusers.utils import is_torch_version, logging
21
+ from torch import nn
22
+
23
+ from ..dist import (get_sequence_parallel_rank,
24
+ get_sequence_parallel_world_size, get_sp_group,
25
+ xFuserLongContextAttention)
26
+ from ..dist.wan_xfuser import usp_attn_forward
27
+ from .cache_utils import TeaCache
28
+
29
+ try:
30
+ import flash_attn_interface
31
+ FLASH_ATTN_3_AVAILABLE = True
32
+ except ModuleNotFoundError:
33
+ FLASH_ATTN_3_AVAILABLE = False
34
+
35
+ try:
36
+ import flash_attn
37
+ FLASH_ATTN_2_AVAILABLE = True
38
+ except ModuleNotFoundError:
39
+ FLASH_ATTN_2_AVAILABLE = False
40
+
41
+
42
+ def flash_attention(
43
+ q,
44
+ k,
45
+ v,
46
+ q_lens=None,
47
+ k_lens=None,
48
+ dropout_p=0.,
49
+ softmax_scale=None,
50
+ q_scale=None,
51
+ causal=False,
52
+ window_size=(-1, -1),
53
+ deterministic=False,
54
+ dtype=torch.bfloat16,
55
+ version=None,
56
+ ):
57
+ """
58
+ q: [B, Lq, Nq, C1].
59
+ k: [B, Lk, Nk, C1].
60
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
61
+ q_lens: [B].
62
+ k_lens: [B].
63
+ dropout_p: float. Dropout probability.
64
+ softmax_scale: float. The scaling of QK^T before applying softmax.
65
+ causal: bool. Whether to apply causal attention mask.
66
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention.
67
+ deterministic: bool. If True, slightly slower and uses more memory.
68
+ dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
69
+ """
70
+ half_dtypes = (torch.float16, torch.bfloat16)
71
+ assert dtype in half_dtypes
72
+ assert q.device.type == 'cuda' and q.size(-1) <= 256
73
+
74
+ # params
75
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
76
+
77
+ def half(x):
78
+ return x if x.dtype in half_dtypes else x.to(dtype)
79
+
80
+ # preprocess query
81
+ if q_lens is None:
82
+ q = half(q.flatten(0, 1))
83
+ q_lens = torch.tensor(
84
+ [lq] * b, dtype=torch.int32).to(
85
+ device=q.device, non_blocking=True)
86
+ else:
87
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
88
+
89
+ # preprocess key, value
90
+ if k_lens is None:
91
+ k = half(k.flatten(0, 1))
92
+ v = half(v.flatten(0, 1))
93
+ k_lens = torch.tensor(
94
+ [lk] * b, dtype=torch.int32).to(
95
+ device=k.device, non_blocking=True)
96
+ else:
97
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
98
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
99
+
100
+ q = q.to(v.dtype)
101
+ k = k.to(v.dtype)
102
+
103
+ if q_scale is not None:
104
+ q = q * q_scale
105
+
106
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
107
+ warnings.warn(
108
+ 'Flash attention 3 is not available, use flash attention 2 instead.'
109
+ )
110
+
111
+ # apply attention
112
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
113
+ # Note: dropout_p, window_size are not supported in FA3 now.
114
+ x = flash_attn_interface.flash_attn_varlen_func(
115
+ q=q,
116
+ k=k,
117
+ v=v,
118
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
119
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
120
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
121
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
122
+ seqused_q=None,
123
+ seqused_k=None,
124
+ max_seqlen_q=lq,
125
+ max_seqlen_k=lk,
126
+ softmax_scale=softmax_scale,
127
+ causal=causal,
128
+ deterministic=deterministic)[0].unflatten(0, (b, lq))
129
+ else:
130
+ assert FLASH_ATTN_2_AVAILABLE
131
+ x = flash_attn.flash_attn_varlen_func(
132
+ q=q,
133
+ k=k,
134
+ v=v,
135
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
136
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
137
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
138
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
139
+ max_seqlen_q=lq,
140
+ max_seqlen_k=lk,
141
+ dropout_p=dropout_p,
142
+ softmax_scale=softmax_scale,
143
+ causal=causal,
144
+ window_size=window_size,
145
+ deterministic=deterministic).unflatten(0, (b, lq))
146
+
147
+ # output
148
+ return x.type(out_dtype)
149
+
150
+
151
+ def attention(
152
+ q,
153
+ k,
154
+ v,
155
+ q_lens=None,
156
+ k_lens=None,
157
+ dropout_p=0.,
158
+ softmax_scale=None,
159
+ q_scale=None,
160
+ causal=False,
161
+ window_size=(-1, -1),
162
+ deterministic=False,
163
+ dtype=torch.bfloat16,
164
+ fa_version=None,
165
+ ):
166
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
167
+ return flash_attention(
168
+ q=q,
169
+ k=k,
170
+ v=v,
171
+ q_lens=q_lens,
172
+ k_lens=k_lens,
173
+ dropout_p=dropout_p,
174
+ softmax_scale=softmax_scale,
175
+ q_scale=q_scale,
176
+ causal=causal,
177
+ window_size=window_size,
178
+ deterministic=deterministic,
179
+ dtype=dtype,
180
+ version=fa_version,
181
+ )
182
+ else:
183
+ if q_lens is not None or k_lens is not None:
184
+ warnings.warn(
185
+ 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
186
+ )
187
+ attn_mask = None
188
+
189
+ q = q.transpose(1, 2)
190
+ k = k.transpose(1, 2)
191
+ v = v.transpose(1, 2)
192
+
193
+ out = torch.nn.functional.scaled_dot_product_attention(
194
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
195
+
196
+ out = out.transpose(1, 2).contiguous()
197
+ return out
198
+
199
+
200
+ def sinusoidal_embedding_1d(dim, position):
201
+ # preprocess
202
+ assert dim % 2 == 0
203
+ half = dim // 2
204
+ position = position.type(torch.float64)
205
+
206
+ # calculation
207
+ sinusoid = torch.outer(
208
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
209
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
210
+ return x
211
+
212
+
213
+ @amp.autocast(enabled=False)
214
+ def rope_params(max_seq_len, dim, theta=10000):
215
+ assert dim % 2 == 0
216
+ freqs = torch.outer(
217
+ torch.arange(max_seq_len),
218
+ 1.0 / torch.pow(theta,
219
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
220
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
221
+ return freqs
222
+
223
+ # modified from https://github.com/thu-ml/RIFLEx/blob/main/riflex_utils.py
224
+ @amp.autocast(enabled=False)
225
+ def get_1d_rotary_pos_embed_riflex(
226
+ pos: Union[np.ndarray, int],
227
+ dim: int,
228
+ theta: float = 10000.0,
229
+ use_real=False,
230
+ k: Optional[int] = None,
231
+ L_test: Optional[int] = None,
232
+ L_test_scale: Optional[int] = None,
233
+ ):
234
+ """
235
+ RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
236
+
237
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
238
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
239
+ data type.
240
+
241
+ Args:
242
+ dim (`int`): Dimension of the frequency tensor.
243
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
244
+ theta (`float`, *optional*, defaults to 10000.0):
245
+ Scaling factor for frequency computation. Defaults to 10000.0.
246
+ use_real (`bool`, *optional*):
247
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
248
+ k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
249
+ L_test (`int`, *optional*, defaults to None): the number of frames for inference
250
+ Returns:
251
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
252
+ """
253
+ assert dim % 2 == 0
254
+
255
+ if isinstance(pos, int):
256
+ pos = torch.arange(pos)
257
+ if isinstance(pos, np.ndarray):
258
+ pos = torch.from_numpy(pos) # type: ignore # [S]
259
+
260
+ freqs = 1.0 / torch.pow(theta,
261
+ torch.arange(0, dim, 2).to(torch.float64).div(dim))
262
+
263
+ # === Riflex modification start ===
264
+ # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
265
+ # Empirical observations show that a few videos may exhibit repetition in the tail frames.
266
+ # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
267
+ if k is not None:
268
+ freqs[k-1] = 0.9 * 2 * torch.pi / L_test
269
+ # === Riflex modification end ===
270
+ if L_test_scale is not None:
271
+ freqs[k-1] = freqs[k-1] / L_test_scale
272
+
273
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
274
+ if use_real:
275
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
276
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
277
+ return freqs_cos, freqs_sin
278
+ else:
279
+ # lumina
280
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
281
+ return freqs_cis
282
+
283
+ @amp.autocast(enabled=False)
284
+ def rope_apply(x, grid_sizes, freqs):
285
+ n, c = x.size(2), x.size(3) // 2
286
+
287
+ # split freqs
288
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
289
+
290
+ # loop over samples
291
+ output = []
292
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
293
+ seq_len = f * h * w
294
+
295
+ # precompute multipliers
296
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
297
+ seq_len, n, -1, 2))
298
+ freqs_i = torch.cat([
299
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
300
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
301
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
302
+ ],
303
+ dim=-1).reshape(seq_len, 1, -1)
304
+
305
+ # apply rotary embedding
306
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
307
+ x_i = torch.cat([x_i, x[i, seq_len:]])
308
+
309
+ # append to collection
310
+ output.append(x_i)
311
+ return torch.stack(output).float()
312
+
313
+
314
+ class WanRMSNorm(nn.Module):
315
+
316
+ def __init__(self, dim, eps=1e-5):
317
+ super().__init__()
318
+ self.dim = dim
319
+ self.eps = eps
320
+ self.weight = nn.Parameter(torch.ones(dim))
321
+
322
+ def forward(self, x):
323
+ r"""
324
+ Args:
325
+ x(Tensor): Shape [B, L, C]
326
+ """
327
+ return self._norm(x.float()).type_as(x) * self.weight
328
+
329
+ def _norm(self, x):
330
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
331
+
332
+
333
+ class WanLayerNorm(nn.LayerNorm):
334
+
335
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
336
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
337
+
338
+ def forward(self, x):
339
+ r"""
340
+ Args:
341
+ x(Tensor): Shape [B, L, C]
342
+ """
343
+ return super().forward(x.float()).type_as(x)
344
+
345
+
346
+ class WanSelfAttention(nn.Module):
347
+
348
+ def __init__(self,
349
+ dim,
350
+ num_heads,
351
+ window_size=(-1, -1),
352
+ qk_norm=True,
353
+ eps=1e-6):
354
+ assert dim % num_heads == 0
355
+ super().__init__()
356
+ self.dim = dim
357
+ self.num_heads = num_heads
358
+ self.head_dim = dim // num_heads
359
+ self.window_size = window_size
360
+ self.qk_norm = qk_norm
361
+ self.eps = eps
362
+
363
+ # layers
364
+ self.q = nn.Linear(dim, dim)
365
+ self.k = nn.Linear(dim, dim)
366
+ self.v = nn.Linear(dim, dim)
367
+ self.o = nn.Linear(dim, dim)
368
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
369
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
370
+
371
+ def forward(self, x, seq_lens, grid_sizes, freqs, dtype):
372
+ r"""
373
+ Args:
374
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
375
+ seq_lens(Tensor): Shape [B]
376
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
377
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
378
+ """
379
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
380
+
381
+ # query, key, value function
382
+ def qkv_fn(x):
383
+ q = self.norm_q(self.q(x.to(dtype))).view(b, s, n, d)
384
+ k = self.norm_k(self.k(x.to(dtype))).view(b, s, n, d)
385
+ v = self.v(x.to(dtype)).view(b, s, n, d)
386
+ return q, k, v
387
+
388
+ q, k, v = qkv_fn(x)
389
+
390
+ x = attention(
391
+ q=rope_apply(q, grid_sizes, freqs).to(dtype),
392
+ k=rope_apply(k, grid_sizes, freqs).to(dtype),
393
+ v=v.to(dtype),
394
+ k_lens=seq_lens,
395
+ window_size=self.window_size)
396
+ x = x.to(dtype)
397
+
398
+ # output
399
+ x = x.flatten(2)
400
+ x = self.o(x)
401
+ return x
402
+
403
+
404
+ class WanT2VCrossAttention(WanSelfAttention):
405
+
406
+ def forward(self, x, context, context_lens, dtype):
407
+ r"""
408
+ Args:
409
+ x(Tensor): Shape [B, L1, C]
410
+ context(Tensor): Shape [B, L2, C]
411
+ context_lens(Tensor): Shape [B]
412
+ """
413
+ b, n, d = x.size(0), self.num_heads, self.head_dim
414
+
415
+ # compute query, key, value
416
+ q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
417
+ k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
418
+ v = self.v(context.to(dtype)).view(b, -1, n, d)
419
+
420
+ # compute attention
421
+ x = attention(
422
+ q.to(dtype),
423
+ k.to(dtype),
424
+ v.to(dtype),
425
+ k_lens=context_lens
426
+ )
427
+ x = x.to(dtype)
428
+
429
+ # output
430
+ x = x.flatten(2)
431
+ x = self.o(x)
432
+ return x
433
+
434
+
435
+ class WanI2VCrossAttention(WanSelfAttention):
436
+
437
+ def __init__(self,
438
+ dim,
439
+ num_heads,
440
+ window_size=(-1, -1),
441
+ qk_norm=True,
442
+ eps=1e-6):
443
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
444
+
445
+ self.k_img = nn.Linear(dim, dim)
446
+ self.v_img = nn.Linear(dim, dim)
447
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
448
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
449
+
450
+ def forward(self, x, context, context_lens, dtype):
451
+ r"""
452
+ Args:
453
+ x(Tensor): Shape [B, L1, C]
454
+ context(Tensor): Shape [B, L2, C]
455
+ context_lens(Tensor): Shape [B]
456
+ """
457
+ context_img = context[:, :257]
458
+ context = context[:, 257:]
459
+ b, n, d = x.size(0), self.num_heads, self.head_dim
460
+
461
+ # compute query, key, value
462
+ q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
463
+ k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
464
+ v = self.v(context.to(dtype)).view(b, -1, n, d)
465
+ k_img = self.norm_k_img(self.k_img(context_img.to(dtype))).view(b, -1, n, d)
466
+ v_img = self.v_img(context_img.to(dtype)).view(b, -1, n, d)
467
+
468
+ img_x = attention(
469
+ q.to(dtype),
470
+ k_img.to(dtype),
471
+ v_img.to(dtype),
472
+ k_lens=None
473
+ )
474
+ img_x = img_x.to(dtype)
475
+ # compute attention
476
+ x = attention(
477
+ q.to(dtype),
478
+ k.to(dtype),
479
+ v.to(dtype),
480
+ k_lens=context_lens
481
+ )
482
+ x = x.to(dtype)
483
+
484
+ # output
485
+ x = x.flatten(2)
486
+ img_x = img_x.flatten(2)
487
+ x = x + img_x
488
+ x = self.o(x)
489
+ return x
490
+
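+ # Note on WanI2VCrossAttention: the first 257 context tokens are the projected
+ # CLIP image embeddings (see MLPProj / img_emb, which emits B x 257 x dim) and
+ # the remaining tokens are the text embeddings. Image and text contexts are
+ # attended to separately and the two results are summed before the output
+ # projection.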
491
+
492
+ WAN_CROSSATTENTION_CLASSES = {
493
+ 't2v_cross_attn': WanT2VCrossAttention,
494
+ 'i2v_cross_attn': WanI2VCrossAttention,
495
+ }
496
+
497
+
498
+ class WanAttentionBlock(nn.Module):
499
+
500
+ def __init__(self,
501
+ cross_attn_type,
502
+ dim,
503
+ ffn_dim,
504
+ num_heads,
505
+ window_size=(-1, -1),
506
+ qk_norm=True,
507
+ cross_attn_norm=False,
508
+ eps=1e-6):
509
+ super().__init__()
510
+ self.dim = dim
511
+ self.ffn_dim = ffn_dim
512
+ self.num_heads = num_heads
513
+ self.window_size = window_size
514
+ self.qk_norm = qk_norm
515
+ self.cross_attn_norm = cross_attn_norm
516
+ self.eps = eps
517
+
518
+ # layers
519
+ self.norm1 = WanLayerNorm(dim, eps)
520
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
521
+ eps)
522
+ self.norm3 = WanLayerNorm(
523
+ dim, eps,
524
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
525
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
526
+ num_heads,
527
+ (-1, -1),
528
+ qk_norm,
529
+ eps)
530
+ self.norm2 = WanLayerNorm(dim, eps)
531
+ self.ffn = nn.Sequential(
532
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
533
+ nn.Linear(ffn_dim, dim))
534
+
535
+ # modulation
536
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
537
+
538
+ def forward(
539
+ self,
540
+ x,
541
+ e,
542
+ seq_lens,
543
+ grid_sizes,
544
+ freqs,
545
+ context,
546
+ context_lens,
547
+ dtype=torch.float32
548
+ ):
549
+ r"""
550
+ Args:
551
+ x(Tensor): Shape [B, L, C]
552
+ e(Tensor): Shape [B, 6, C]
553
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
554
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
555
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
556
+ """
557
+ e = (self.modulation + e).chunk(6, dim=1)
558
+
559
+ # self-attention
560
+ temp_x = self.norm1(x) * (1 + e[1]) + e[0]
561
+ temp_x = temp_x.to(dtype)
562
+
563
+ y = self.self_attn(temp_x, seq_lens, grid_sizes, freqs, dtype)
564
+ x = x + y * e[2]
565
+
566
+ # cross-attention & ffn function
567
+ def cross_attn_ffn(x, context, context_lens, e):
568
+ # cross-attention
569
+ x = x + self.cross_attn(self.norm3(x), context, context_lens, dtype)
570
+
571
+ # ffn function
572
+ temp_x = self.norm2(x) * (1 + e[4]) + e[3]
573
+ temp_x = temp_x.to(dtype)
574
+
575
+ y = self.ffn(temp_x)
576
+ x = x + y * e[5]
577
+ return x
578
+
579
+ x = cross_attn_ffn(x, context, context_lens, e)
580
+ return x
581
+
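+ # Note on the modulation tensor: `e` carries six per-block vectors, consumed as
+ # (shift, scale, gate) for the self-attention branch (e[0], e[1], e[2]) and
+ # (shift, scale, gate) for the FFN branch (e[3], e[4], e[5]), i.e. adaLN-style
+ # conditioning on the timestep embedding.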
582
+
583
+ class Head(nn.Module):
584
+
585
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
586
+ super().__init__()
587
+ self.dim = dim
588
+ self.out_dim = out_dim
589
+ self.patch_size = patch_size
590
+ self.eps = eps
591
+
592
+ # layers
593
+ out_dim = math.prod(patch_size) * out_dim
594
+ self.norm = WanLayerNorm(dim, eps)
595
+ self.head = nn.Linear(dim, out_dim)
596
+
597
+ # modulation
598
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
599
+
600
+ def forward(self, x, e):
601
+ r"""
602
+ Args:
603
+ x(Tensor): Shape [B, L1, C]
604
+ e(Tensor): Shape [B, C]
605
+ """
606
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
607
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
608
+ return x
609
+
610
+
611
+ class MLPProj(torch.nn.Module):
612
+
613
+ def __init__(self, in_dim, out_dim):
614
+ super().__init__()
615
+
616
+ self.proj = torch.nn.Sequential(
617
+ torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
618
+ torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
619
+ torch.nn.LayerNorm(out_dim))
620
+
621
+ def forward(self, image_embeds):
622
+ clip_extra_context_tokens = self.proj(image_embeds)
623
+ return clip_extra_context_tokens
624
+
625
+
626
+
627
+ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
628
+ r"""
629
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
630
+ """
631
+
632
+ # ignore_for_config = [
633
+ # 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
634
+ # ]
635
+ # _no_split_modules = ['WanAttentionBlock']
636
+ _supports_gradient_checkpointing = True
637
+
638
+ @register_to_config
639
+ def __init__(
640
+ self,
641
+ model_type='t2v',
642
+ patch_size=(1, 2, 2),
643
+ text_len=512,
644
+ in_dim=16,
645
+ dim=2048,
646
+ ffn_dim=8192,
647
+ freq_dim=256,
648
+ text_dim=4096,
649
+ out_dim=16,
650
+ num_heads=16,
651
+ num_layers=32,
652
+ window_size=(-1, -1),
653
+ qk_norm=True,
654
+ cross_attn_norm=True,
655
+ eps=1e-6,
656
+ in_channels=16,
657
+ hidden_size=2048,
658
+ ):
659
+ r"""
660
+ Initialize the diffusion model backbone.
661
+
662
+ Args:
663
+ model_type (`str`, *optional*, defaults to 't2v'):
664
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
665
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
666
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
667
+ text_len (`int`, *optional*, defaults to 512):
668
+ Fixed length for text embeddings
669
+ in_dim (`int`, *optional*, defaults to 16):
670
+ Input video channels (C_in)
671
+ dim (`int`, *optional*, defaults to 2048):
672
+ Hidden dimension of the transformer
673
+ ffn_dim (`int`, *optional*, defaults to 8192):
674
+ Intermediate dimension in feed-forward network
675
+ freq_dim (`int`, *optional*, defaults to 256):
676
+ Dimension for sinusoidal time embeddings
677
+ text_dim (`int`, *optional*, defaults to 4096):
678
+ Input dimension for text embeddings
679
+ out_dim (`int`, *optional*, defaults to 16):
680
+ Output video channels (C_out)
681
+ num_heads (`int`, *optional*, defaults to 16):
682
+ Number of attention heads
683
+ num_layers (`int`, *optional*, defaults to 32):
684
+ Number of transformer blocks
685
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
686
+ Window size for local attention (-1 indicates global attention)
687
+ qk_norm (`bool`, *optional*, defaults to True):
688
+ Enable query/key normalization
689
+ cross_attn_norm (`bool`, *optional*, defaults to True):
690
+ Enable cross-attention normalization
691
+ eps (`float`, *optional*, defaults to 1e-6):
692
+ Epsilon value for normalization layers
693
+ """
694
+
695
+ super().__init__()
696
+
697
+ assert model_type in ['t2v', 'i2v']
698
+ self.model_type = model_type
699
+
700
+ self.patch_size = patch_size
701
+ self.text_len = text_len
702
+ self.in_dim = in_dim
703
+ self.dim = dim
704
+ self.ffn_dim = ffn_dim
705
+ self.freq_dim = freq_dim
706
+ self.text_dim = text_dim
707
+ self.out_dim = out_dim
708
+ self.num_heads = num_heads
709
+ self.num_layers = num_layers
710
+ self.window_size = window_size
711
+ self.qk_norm = qk_norm
712
+ self.cross_attn_norm = cross_attn_norm
713
+ self.eps = eps
714
+
715
+ # embeddings
716
+ self.patch_embedding = nn.Conv3d(
717
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
718
+ self.text_embedding = nn.Sequential(
719
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
720
+ nn.Linear(dim, dim))
721
+
722
+ self.time_embedding = nn.Sequential(
723
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
724
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
725
+
726
+ # blocks
727
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
728
+ self.blocks = nn.ModuleList([
729
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
730
+ window_size, qk_norm, cross_attn_norm, eps)
731
+ for _ in range(num_layers)
732
+ ])
733
+
734
+ # head
735
+ self.head = Head(dim, out_dim, patch_size, eps)
736
+
737
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
738
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
739
+ d = dim // num_heads
740
+ self.d = d
741
+ self.freqs = torch.cat(
742
+ [
743
+ rope_params(1024, d - 4 * (d // 6)),
744
+ rope_params(1024, 2 * (d // 6)),
745
+ rope_params(1024, 2 * (d // 6))
746
+ ],
747
+ dim=1
748
+ )
749
+
750
+ if model_type == 'i2v':
751
+ self.img_emb = MLPProj(1280, dim)
752
+
753
+ self.teacache = None
754
+ self.gradient_checkpointing = False
755
+ self.sp_world_size = 1
756
+ self.sp_world_rank = 0
757
+
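+ # Note: `self.freqs` is a precomputed RoPE table covering positions up to 1024
+ # per axis, built from rope_params with per-axis channel counts
+ # (d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)) for (frame, height, width),
+ # matching the split consumed by rope_apply. enable_riflex below replaces only
+ # the temporal group with RIFLEx-extrapolated frequencies, which appears
+ # intended for longer-video extrapolation.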
758
+ def enable_teacache(
759
+ self,
760
+ coefficients,
761
+ num_steps: int,
762
+ rel_l1_thresh: float,
763
+ num_skip_start_steps: int = 0,
764
+ offload: bool = True
765
+ ):
766
+ self.teacache = TeaCache(
767
+ coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload
768
+ )
769
+
770
+ def disable_teacache(self):
771
+ self.teacache = None
772
+
773
+ def enable_riflex(
774
+ self,
775
+ k = 6,
776
+ L_test = 66,
777
+ L_test_scale = 4.886,
778
+ ):
779
+ device = self.freqs.device
780
+ self.freqs = torch.cat(
781
+ [
782
+ get_1d_rotary_pos_embed_riflex(1024, self.d - 4 * (self.d // 6), use_real=False, k=k, L_test=L_test, L_test_scale=L_test_scale),
783
+ rope_params(1024, 2 * (self.d // 6)),
784
+ rope_params(1024, 2 * (self.d // 6))
785
+ ],
786
+ dim=1
787
+ ).to(device)
788
+
789
+ def disable_riflex(self):
790
+ device = self.freqs.device
791
+ self.freqs = torch.cat(
792
+ [
793
+ rope_params(1024, self.d - 4 * (self.d // 6)),
794
+ rope_params(1024, 2 * (self.d // 6)),
795
+ rope_params(1024, 2 * (self.d // 6))
796
+ ],
797
+ dim=1
798
+ ).to(device)
799
+
800
+ def enable_multi_gpus_inference(self,):
801
+ self.sp_world_size = get_sequence_parallel_world_size()
802
+ self.sp_world_rank = get_sequence_parallel_rank()
803
+ for block in self.blocks:
804
+ block.self_attn.forward = types.MethodType(
805
+ usp_attn_forward, block.self_attn)
806
+
807
+ def _set_gradient_checkpointing(self, module, value=False):
808
+ self.gradient_checkpointing = value
809
+
810
+ def forward(
811
+ self,
812
+ x,
813
+ t,
814
+ context,
815
+ seq_len,
816
+ clip_fea=None,
817
+ y=None,
818
+ cond_flag=True,
819
+ return_intermediate=False, # whether to return intermediate layers' outputs
820
+ selected_layers=(5, 15, 25) # layer indices to output
821
+ ):
822
+ r"""
823
+ Forward pass through the diffusion model
824
+
825
+ Args:
826
+ x (List[Tensor]):
827
+ List of input video tensors, each with shape [C_in, F, H, W]
828
+ t (Tensor):
829
+ Diffusion timesteps tensor of shape [B]
830
+ context (List[Tensor]):
831
+ List of text embeddings each with shape [L, C]
832
+ seq_len (`int`):
833
+ Maximum sequence length for positional encoding
834
+ clip_fea (Tensor, *optional*):
835
+ CLIP image features for image-to-video mode
836
+ y (List[Tensor], *optional*):
837
+ Conditional video inputs for image-to-video mode, same shape as x
838
+ cond_flag (`bool`, *optional*, defaults to True):
839
+ Flag to indicate whether to forward the condition input
840
+
841
+ Returns:
842
+ List[Tensor]:
843
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
844
+ """
845
+ if self.model_type == 'i2v':
846
+ assert clip_fea is not None and y is not None
847
+ # params
848
+ device = self.patch_embedding.weight.device
849
+ dtype = x.dtype
850
+ if self.freqs.device != device and torch.device(type="meta") != device:
851
+ self.freqs = self.freqs.to(device)
852
+
853
+ if y is not None:
854
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
855
+
856
+ # embeddings
857
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
858
+ grid_sizes = torch.stack(
859
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
860
+ x = [u.flatten(2).transpose(1, 2) for u in x]
861
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
862
+ if self.sp_world_size > 1:
863
+ seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
864
+ assert seq_lens.max() <= seq_len
865
+ x = torch.cat([
866
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
867
+ dim=1) for u in x
868
+ ])
869
+
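+ # Note: after patch embedding with kernel/stride = patch_size (1, 2, 2), each
+ # sample becomes f * h * w tokens where (f, h, w) is the latent grid. For
+ # instance (assuming the 4x temporal / 8x spatial Wan VAE in this repo), an
+ # 81-frame 480x832 clip gives a (21, 30, 52) grid, i.e. 32760 tokens, which
+ # are then zero-padded up to the shared `seq_len`.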
870
+ # time embeddings
871
+ with amp.autocast(dtype=torch.float32):
872
+ e = self.time_embedding(
873
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
874
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
875
+ # cast to bfloat16 to save memory
876
+ # assert e.dtype == torch.float32 and e0.dtype == torch.float32
877
+ e0 = e0.to(dtype)
878
+ e = e.to(dtype)
879
+
880
+ # context
881
+ context_lens = None
882
+ context = self.text_embedding(
883
+ torch.stack([
884
+ torch.cat(
885
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
886
+ for u in context
887
+ ]))
888
+
889
+ if clip_fea is not None:
890
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
891
+ context = torch.concat([context_clip, context], dim=1)
892
+
893
+ # Context Parallel
894
+ if self.sp_world_size > 1:
895
+ x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
896
+
897
+ # TeaCache
898
+ if self.teacache is not None:
899
+ if cond_flag:
900
+ modulated_inp = e0
901
+ skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
902
+ if self.teacache.cnt == 0 or self.teacache.cnt == self.teacache.num_steps - 1 or skip_flag:
903
+ should_calc = True
904
+ self.teacache.accumulated_rel_l1_distance = 0
905
+ else:
906
+ if cond_flag:
907
+ rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
908
+ self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
909
+ if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
910
+ should_calc = False
911
+ else:
912
+ should_calc = True
913
+ self.teacache.accumulated_rel_l1_distance = 0
914
+ self.teacache.previous_modulated_input = modulated_inp
915
+ self.teacache.cnt += 1
916
+ if self.teacache.cnt == self.teacache.num_steps:
917
+ self.teacache.reset()
918
+ self.teacache.should_calc = should_calc
919
+ else:
920
+ should_calc = self.teacache.should_calc
921
+
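+ # Note on the TeaCache branch above: the relative L1 change of the modulated
+ # timestep embedding (e0) is accumulated across steps; while the accumulated,
+ # rescaled value stays under `rel_l1_thresh`, the transformer blocks are
+ # skipped and the cached residual from the last full pass is reused instead.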
922
+ intermediate_features = []
923
+
924
+ # TeaCache
925
+ if self.teacache is not None:
926
+ if not should_calc:
927
+ previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
928
+ x = x + previous_residual.to(x.device)
929
+ else:
930
+ ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
931
+
932
+ for idx, block in enumerate(self.blocks):
933
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
934
+
935
+ def create_custom_forward(module):
936
+ def custom_forward(*inputs):
937
+ return module(*inputs)
938
+
939
+ return custom_forward
940
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
941
+ x = torch.utils.checkpoint.checkpoint(
942
+ create_custom_forward(block),
943
+ x,
944
+ e0,
945
+ seq_lens,
946
+ grid_sizes,
947
+ self.freqs,
948
+ context,
949
+ context_lens,
950
+ dtype,
951
+ **ckpt_kwargs,
952
+ )
953
+ else:
954
+ # arguments
955
+ kwargs = dict(
956
+ e=e0,
957
+ seq_lens=seq_lens,
958
+ grid_sizes=grid_sizes,
959
+ freqs=self.freqs,
960
+ context=context,
961
+ context_lens=context_lens,
962
+ dtype=dtype
963
+ )
964
+ x = block(x, **kwargs)
965
+
966
+ if return_intermediate and idx in selected_layers:
967
+ intermediate_features.append(x.clone())
968
+
969
+ if cond_flag:
970
+ self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
971
+ else:
972
+ self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
973
+ else:
974
+ for idx, block in enumerate(self.blocks):
975
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
976
+
977
+ def create_custom_forward(module):
978
+ def custom_forward(*inputs):
979
+ return module(*inputs)
980
+
981
+ return custom_forward
982
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
983
+ x = torch.utils.checkpoint.checkpoint(
984
+ create_custom_forward(block),
985
+ x,
986
+ e0,
987
+ seq_lens,
988
+ grid_sizes,
989
+ self.freqs,
990
+ context,
991
+ context_lens,
992
+ dtype,
993
+ **ckpt_kwargs,
994
+ )
995
+ else:
996
+ # arguments
997
+ kwargs = dict(
998
+ e=e0,
999
+ seq_lens=seq_lens,
1000
+ grid_sizes=grid_sizes,
1001
+ freqs=self.freqs,
1002
+ context=context,
1003
+ context_lens=context_lens,
1004
+ dtype=dtype
1005
+ )
1006
+ x = block(x, **kwargs)
1007
+
1008
+ if return_intermediate and idx in selected_layers:
1009
+ intermediate_features.append(x.clone())
1010
+
1011
+
1012
+ if self.sp_world_size > 1:
1013
+ x = get_sp_group().all_gather(x, dim=1)
1014
+
1015
+ # head
1016
+ x = self.head(x, e)
1017
+
1018
+ # unpatchify
1019
+ x = self.unpatchify(x, grid_sizes)
1020
+ x = torch.stack(x)
1021
+
1022
+ if return_intermediate:
1023
+ mid_feat = intermediate_features
1024
+ return x, mid_feat
1025
+ else:
1026
+ return x
1027
+
1028
+
1029
+ def unpatchify(self, x, grid_sizes):
1030
+ r"""
1031
+ Reconstruct video tensors from patch embeddings.
1032
+
1033
+ Args:
1034
+ x (List[Tensor]):
1035
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
1036
+ grid_sizes (Tensor):
1037
+ Original spatial-temporal grid dimensions before patching,
1038
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
1039
+
1040
+ Returns:
1041
+ List[Tensor]:
1042
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
1043
+ """
1044
+
1045
+ c = self.out_dim
1046
+ out = []
1047
+ for u, v in zip(x, grid_sizes.tolist()):
1048
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
1049
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
1050
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
1051
+ out.append(u)
1052
+ return out
1053
+
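+ # Note on unpatchify: each token row is reshaped to [f, h, w, pt, ph, pw, c];
+ # the einsum 'fhwpqrc->cfphqwr' interleaves the patch dims back into their
+ # axes, so the final reshape yields [c, f * pt, h * ph, w * pw] in latent space.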
1054
+ def init_weights(self):
1055
+ r"""
1056
+ Initialize model parameters using Xavier initialization.
1057
+ """
1058
+
1059
+ # basic init
1060
+ for m in self.modules():
1061
+ if isinstance(m, nn.Linear):
1062
+ nn.init.xavier_uniform_(m.weight)
1063
+ if m.bias is not None:
1064
+ nn.init.zeros_(m.bias)
1065
+
1066
+ # init embeddings
1067
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
1068
+ for m in self.text_embedding.modules():
1069
+ if isinstance(m, nn.Linear):
1070
+ nn.init.normal_(m.weight, std=.02)
1071
+ for m in self.time_embedding.modules():
1072
+ if isinstance(m, nn.Linear):
1073
+ nn.init.normal_(m.weight, std=.02)
1074
+
1075
+ # init output layer
1076
+ nn.init.zeros_(self.head.head.weight)
1077
+
1078
+ @classmethod
1079
+ def from_pretrained(
1080
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
1081
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
1082
+ ):
1083
+ if subfolder is not None:
1084
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1085
+ print(f"Loading 3D transformer's pretrained weights from {pretrained_model_path} ...")
1086
+
1087
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1088
+ if not os.path.isfile(config_file):
1089
+ raise RuntimeError(f"{config_file} does not exist")
1090
+ with open(config_file, "r") as f:
1091
+ config = json.load(f)
1092
+
1093
+ from diffusers.utils import WEIGHTS_NAME
1094
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1095
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
1096
+
1097
+ if "dict_mapping" in transformer_additional_kwargs.keys():
1098
+ for key in transformer_additional_kwargs["dict_mapping"]:
1099
+ transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
1100
+
1101
+ if low_cpu_mem_usage:
1102
+ try:
1103
+ import re
1104
+
1105
+ from diffusers.models.modeling_utils import \
1106
+ load_model_dict_into_meta
1107
+ from diffusers.utils import is_accelerate_available
1108
+ if is_accelerate_available():
1109
+ import accelerate
1110
+
1111
+ # Instantiate model with empty weights
1112
+ with accelerate.init_empty_weights():
1113
+ model = cls.from_config(config, **transformer_additional_kwargs)
1114
+
1115
+ param_device = "cpu"
1116
+ if os.path.exists(model_file):
1117
+ state_dict = torch.load(model_file, map_location="cpu")
1118
+ elif os.path.exists(model_file_safetensors):
1119
+ from safetensors.torch import load_file, safe_open
1120
+ state_dict = load_file(model_file_safetensors)
1121
+ else:
1122
+ from safetensors.torch import load_file, safe_open
1123
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1124
+ state_dict = {}
1125
+ print(model_files_safetensors)
1126
+ for _model_file_safetensors in model_files_safetensors:
1127
+ _state_dict = load_file(_model_file_safetensors)
1128
+ for key in _state_dict:
1129
+ state_dict[key] = _state_dict[key]
1130
+ model._convert_deprecated_attention_blocks(state_dict)
1131
+ # move the params from meta device to cpu
1132
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
1133
+ if len(missing_keys) > 0:
1134
+ raise ValueError(
1135
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
1136
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
1137
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
1138
+ " those weights or else make sure your checkpoint file is correct."
1139
+ )
1140
+
1141
+ unexpected_keys = load_model_dict_into_meta(
1142
+ model,
1143
+ state_dict,
1144
+ device=param_device,
1145
+ dtype=torch_dtype,
1146
+ model_name_or_path=pretrained_model_path,
1147
+ )
1148
+
1149
+ if cls._keys_to_ignore_on_load_unexpected is not None:
1150
+ for pat in cls._keys_to_ignore_on_load_unexpected:
1151
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1152
+
1153
+ if len(unexpected_keys) > 0:
1154
+ print(
1155
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1156
+ )
1157
+ return model
1158
+ except Exception as e:
1159
+ print(
1160
+ f"The low_cpu_mem_usage mode did not work because {e}. Falling back to low_cpu_mem_usage=False."
1161
+ )
1162
+
1163
+ model = cls.from_config(config, **transformer_additional_kwargs)
1164
+ if os.path.exists(model_file):
1165
+ state_dict = torch.load(model_file, map_location="cpu")
1166
+ elif os.path.exists(model_file_safetensors):
1167
+ from safetensors.torch import load_file, safe_open
1168
+ state_dict = load_file(model_file_safetensors)
1169
+ else:
1170
+ from safetensors.torch import load_file, safe_open
1171
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1172
+ state_dict = {}
1173
+ for _model_file_safetensors in model_files_safetensors:
1174
+ _state_dict = load_file(_model_file_safetensors)
1175
+ for key in _state_dict:
1176
+ state_dict[key] = _state_dict[key]
1177
+
1178
+ if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size():
1179
+ model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight']
1180
+ model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0
1181
+ state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight']
1182
+
1183
+ tmp_state_dict = {}
1184
+ for key in state_dict:
1185
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
1186
+ tmp_state_dict[key] = state_dict[key]
1187
+ else:
1188
+ print(key, "size mismatch, skipping")
1189
+
1190
+ state_dict = tmp_state_dict
1191
+
1192
+ m, u = model.load_state_dict(state_dict, strict=False)
1193
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1194
+ print(m)
1195
+
1196
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
1197
+ print(f"### All Parameters: {sum(params) / 1e6} M")
1198
+
1199
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1200
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1201
+
1202
+ model = model.to(torch_dtype)
1203
+ return model
rose/models/wan_vae.py ADDED
@@ -0,0 +1,705 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ from typing import Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
10
+ from diffusers.models.autoencoders.vae import (DecoderOutput,
11
+ DiagonalGaussianDistribution)
12
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+ from diffusers.utils.accelerate_utils import apply_forward_hook
15
+ from einops import rearrange
16
+
17
+ CACHE_T = 2
18
+
19
+
20
+ class CausalConv3d(nn.Conv3d):
21
+ """
22
+ Causal 3D convolution.
23
+ """
24
+
25
+ def __init__(self, *args, **kwargs):
26
+ super().__init__(*args, **kwargs)
27
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
28
+ self.padding[1], 2 * self.padding[0], 0)
29
+ self.padding = (0, 0, 0)
30
+
31
+ def forward(self, x, cache_x=None):
32
+ padding = list(self._padding)
33
+ if cache_x is not None and self._padding[4] > 0:
34
+ cache_x = cache_x.to(x.device)
35
+ x = torch.cat([cache_x, x], dim=2)
36
+ padding[4] -= cache_x.shape[2]
37
+ x = F.pad(x, padding)
38
+
39
+ return super().forward(x)
40
+
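+ # Note on CausalConv3d: padding is applied manually so that the temporal axis
+ # is padded only on the past side (2 * pad_t frames, none in the future),
+ # keeping the convolution causal. When `cache_x` (trailing frames of the
+ # previous chunk) is supplied, it is prepended and the left padding is reduced
+ # accordingly, which enables chunk-by-chunk streaming in the VAE.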
41
+
42
+ class RMS_norm(nn.Module):
43
+
44
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
45
+ super().__init__()
46
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
47
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
48
+
49
+ self.channel_first = channel_first
50
+ self.scale = dim**0.5
51
+ self.gamma = nn.Parameter(torch.ones(shape))
52
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
53
+
54
+ def forward(self, x):
55
+ return F.normalize(
56
+ x, dim=(1 if self.channel_first else
57
+ -1)) * self.scale * self.gamma + self.bias
58
+
59
+
60
+ class Upsample(nn.Upsample):
61
+
62
+ def forward(self, x):
63
+ """
64
+ Fix bfloat16 support for nearest neighbor interpolation.
65
+ """
66
+ return super().forward(x.float()).type_as(x)
67
+
68
+
69
+ class Resample(nn.Module):
70
+
71
+ def __init__(self, dim, mode):
72
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
73
+ 'downsample3d')
74
+ super().__init__()
75
+ self.dim = dim
76
+ self.mode = mode
77
+
78
+ # layers
79
+ if mode == 'upsample2d':
80
+ self.resample = nn.Sequential(
81
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
82
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
83
+ elif mode == 'upsample3d':
84
+ self.resample = nn.Sequential(
85
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
86
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
87
+ self.time_conv = CausalConv3d(
88
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
89
+
90
+ elif mode == 'downsample2d':
91
+ self.resample = nn.Sequential(
92
+ nn.ZeroPad2d((0, 1, 0, 1)),
93
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
94
+ elif mode == 'downsample3d':
95
+ self.resample = nn.Sequential(
96
+ nn.ZeroPad2d((0, 1, 0, 1)),
97
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
98
+ self.time_conv = CausalConv3d(
99
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
100
+
101
+ else:
102
+ self.resample = nn.Identity()
103
+
104
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
105
+ b, c, t, h, w = x.size()
106
+ if self.mode == 'upsample3d':
107
+ if feat_cache is not None:
108
+ idx = feat_idx[0]
109
+ if feat_cache[idx] is None:
110
+ feat_cache[idx] = 'Rep'
111
+ feat_idx[0] += 1
112
+ else:
113
+
114
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
115
+ if cache_x.shape[2] < 2 and feat_cache[
116
+ idx] is not None and feat_cache[idx] != 'Rep':
117
+ # cache the last frame of the last two chunks
118
+ cache_x = torch.cat([
119
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
120
+ cache_x.device), cache_x
121
+ ],
122
+ dim=2)
123
+ if cache_x.shape[2] < 2 and feat_cache[
124
+ idx] is not None and feat_cache[idx] == 'Rep':
125
+ cache_x = torch.cat([
126
+ torch.zeros_like(cache_x).to(cache_x.device),
127
+ cache_x
128
+ ],
129
+ dim=2)
130
+ if feat_cache[idx] == 'Rep':
131
+ x = self.time_conv(x)
132
+ else:
133
+ x = self.time_conv(x, feat_cache[idx])
134
+ feat_cache[idx] = cache_x
135
+ feat_idx[0] += 1
136
+
137
+ x = x.reshape(b, 2, c, t, h, w)
138
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
139
+ 3)
140
+ x = x.reshape(b, c, t * 2, h, w)
141
+ t = x.shape[2]
142
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
143
+ x = self.resample(x)
144
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
145
+
146
+ if self.mode == 'downsample3d':
147
+ if feat_cache is not None:
148
+ idx = feat_idx[0]
149
+ if feat_cache[idx] is None:
150
+ feat_cache[idx] = x.clone()
151
+ feat_idx[0] += 1
152
+ else:
153
+
154
+ cache_x = x[:, :, -1:, :, :].clone()
155
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
156
+ # # cache last frame of last two chunk
157
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
158
+
159
+ x = self.time_conv(
160
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
161
+ feat_cache[idx] = cache_x
162
+ feat_idx[0] += 1
163
+ return x
164
+
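+ # Note on 'upsample3d': time_conv doubles the channel count, and the
+ # reshape/stack above splits those channels into two temporal copies that are
+ # interleaved along the frame axis, doubling t before the 2D spatial upsample.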
165
+ def init_weight(self, conv):
166
+ conv_weight = conv.weight
167
+ nn.init.zeros_(conv_weight)
168
+ c1, c2, t, h, w = conv_weight.size()
169
+ one_matrix = torch.eye(c1, c2)
170
+ init_matrix = one_matrix
171
+ nn.init.zeros_(conv_weight)
172
+ #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
173
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
174
+ conv.weight.data.copy_(conv_weight)
175
+ nn.init.zeros_(conv.bias.data)
176
+
177
+ def init_weight2(self, conv):
178
+ conv_weight = conv.weight.data
179
+ nn.init.zeros_(conv_weight)
180
+ c1, c2, t, h, w = conv_weight.size()
181
+ init_matrix = torch.eye(c1 // 2, c2)
182
+ #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
183
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
184
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
185
+ conv.weight.data.copy_(conv_weight)
186
+ nn.init.zeros_(conv.bias.data)
187
+
188
+
189
+ class ResidualBlock(nn.Module):
190
+
191
+ def __init__(self, in_dim, out_dim, dropout=0.0):
192
+ super().__init__()
193
+ self.in_dim = in_dim
194
+ self.out_dim = out_dim
195
+
196
+ # layers
197
+ self.residual = nn.Sequential(
198
+ RMS_norm(in_dim, images=False), nn.SiLU(),
199
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
200
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
201
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
202
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
203
+ if in_dim != out_dim else nn.Identity()
204
+
205
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
206
+ h = self.shortcut(x)
207
+ for layer in self.residual:
208
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
209
+ idx = feat_idx[0]
210
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
211
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
212
+ # cache the last frame of the last two chunks
213
+ cache_x = torch.cat([
214
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
215
+ cache_x.device), cache_x
216
+ ],
217
+ dim=2)
218
+ x = layer(x, feat_cache[idx])
219
+ feat_cache[idx] = cache_x
220
+ feat_idx[0] += 1
221
+ else:
222
+ x = layer(x)
223
+ return x + h
224
+
225
+
226
+ class AttentionBlock(nn.Module):
227
+ """
228
+ Causal self-attention with a single head.
229
+ """
230
+
231
+ def __init__(self, dim):
232
+ super().__init__()
233
+ self.dim = dim
234
+
235
+ # layers
236
+ self.norm = RMS_norm(dim)
237
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
238
+ self.proj = nn.Conv2d(dim, dim, 1)
239
+
240
+ # zero out the last layer params
241
+ nn.init.zeros_(self.proj.weight)
242
+
243
+ def forward(self, x):
244
+ identity = x
245
+ b, c, t, h, w = x.size()
246
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
247
+ x = self.norm(x)
248
+ # compute query, key, value
249
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
250
+ -1).permute(0, 1, 3,
251
+ 2).contiguous().chunk(
252
+ 3, dim=-1)
253
+
254
+ # apply attention
255
+ x = F.scaled_dot_product_attention(
256
+ q,
257
+ k,
258
+ v,
259
+ )
260
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
261
+
262
+ # output
263
+ x = self.proj(x)
264
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
265
+ return x + identity
266
+
267
+
268
+ class Encoder3d(nn.Module):
269
+
270
+ def __init__(self,
271
+ dim=128,
272
+ z_dim=4,
273
+ dim_mult=[1, 2, 4, 4],
274
+ num_res_blocks=2,
275
+ attn_scales=[],
276
+ temperal_downsample=[True, True, False],
277
+ dropout=0.0):
278
+ super().__init__()
279
+ self.dim = dim
280
+ self.z_dim = z_dim
281
+ self.dim_mult = dim_mult
282
+ self.num_res_blocks = num_res_blocks
283
+ self.attn_scales = attn_scales
284
+ self.temperal_downsample = temperal_downsample
285
+
286
+ # dimensions
287
+ dims = [dim * u for u in [1] + dim_mult]
288
+ scale = 1.0
289
+
290
+ # init block
291
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
292
+
293
+ # downsample blocks
294
+ downsamples = []
295
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
296
+ # residual (+attention) blocks
297
+ for _ in range(num_res_blocks):
298
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
299
+ if scale in attn_scales:
300
+ downsamples.append(AttentionBlock(out_dim))
301
+ in_dim = out_dim
302
+
303
+ # downsample block
304
+ if i != len(dim_mult) - 1:
305
+ mode = 'downsample3d' if temperal_downsample[
306
+ i] else 'downsample2d'
307
+ downsamples.append(Resample(out_dim, mode=mode))
308
+ scale /= 2.0
309
+ self.downsamples = nn.Sequential(*downsamples)
310
+
311
+ # middle blocks
312
+ self.middle = nn.Sequential(
313
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
314
+ ResidualBlock(out_dim, out_dim, dropout))
315
+
316
+ # output blocks
317
+ self.head = nn.Sequential(
318
+ RMS_norm(out_dim, images=False), nn.SiLU(),
319
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
320
+
321
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
322
+ if feat_cache is not None:
323
+ idx = feat_idx[0]
324
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
325
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
326
+ # cache last frame of last two chunk
327
+ cache_x = torch.cat([
328
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
329
+ cache_x.device), cache_x
330
+ ],
331
+ dim=2)
332
+ x = self.conv1(x, feat_cache[idx])
333
+ feat_cache[idx] = cache_x
334
+ feat_idx[0] += 1
335
+ else:
336
+ x = self.conv1(x)
337
+
338
+ ## downsamples
339
+ for layer in self.downsamples:
340
+ if feat_cache is not None:
341
+ x = layer(x, feat_cache, feat_idx)
342
+ else:
343
+ x = layer(x)
344
+
345
+ ## middle
346
+ for layer in self.middle:
347
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
348
+ x = layer(x, feat_cache, feat_idx)
349
+ else:
350
+ x = layer(x)
351
+
352
+ ## head
353
+ for layer in self.head:
354
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
355
+ idx = feat_idx[0]
356
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
357
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
358
+ # cache the last frame of the last two chunks
359
+ cache_x = torch.cat([
360
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
361
+ cache_x.device), cache_x
362
+ ],
363
+ dim=2)
364
+ x = layer(x, feat_cache[idx])
365
+ feat_cache[idx] = cache_x
366
+ feat_idx[0] += 1
367
+ else:
368
+ x = layer(x)
369
+ return x
370
+
371
+
372
+ class Decoder3d(nn.Module):
373
+
374
+ def __init__(self,
375
+ dim=128,
376
+ z_dim=4,
377
+ dim_mult=[1, 2, 4, 4],
378
+ num_res_blocks=2,
379
+ attn_scales=[],
380
+ temperal_upsample=[False, True, True],
381
+ dropout=0.0):
382
+ super().__init__()
383
+ self.dim = dim
384
+ self.z_dim = z_dim
385
+ self.dim_mult = dim_mult
386
+ self.num_res_blocks = num_res_blocks
387
+ self.attn_scales = attn_scales
388
+ self.temperal_upsample = temperal_upsample
389
+
390
+ # dimensions
391
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
392
+ scale = 1.0 / 2**(len(dim_mult) - 2)
393
+
394
+ # init block
395
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
396
+
397
+ # middle blocks
398
+ self.middle = nn.Sequential(
399
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
400
+ ResidualBlock(dims[0], dims[0], dropout))
401
+
402
+ # upsample blocks
403
+ upsamples = []
404
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
405
+ # residual (+attention) blocks
406
+ if i == 1 or i == 2 or i == 3:
407
+ in_dim = in_dim // 2
408
+ for _ in range(num_res_blocks + 1):
409
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
410
+ if scale in attn_scales:
411
+ upsamples.append(AttentionBlock(out_dim))
412
+ in_dim = out_dim
413
+
414
+ # upsample block
415
+ if i != len(dim_mult) - 1:
416
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
417
+ upsamples.append(Resample(out_dim, mode=mode))
418
+ scale *= 2.0
419
+ self.upsamples = nn.Sequential(*upsamples)
420
+
421
+ # output blocks
422
+ self.head = nn.Sequential(
423
+ RMS_norm(out_dim, images=False), nn.SiLU(),
424
+ CausalConv3d(out_dim, 3, 3, padding=1))
425
+
426
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
427
+ ## conv1
428
+ if feat_cache is not None:
429
+ idx = feat_idx[0]
430
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
431
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
432
+ # cache the last frame of the last two chunks
433
+ cache_x = torch.cat([
434
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
435
+ cache_x.device), cache_x
436
+ ],
437
+ dim=2)
438
+ x = self.conv1(x, feat_cache[idx])
439
+ feat_cache[idx] = cache_x
440
+ feat_idx[0] += 1
441
+ else:
442
+ x = self.conv1(x)
443
+
444
+ ## middle
445
+ for layer in self.middle:
446
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
447
+ x = layer(x, feat_cache, feat_idx)
448
+ else:
449
+ x = layer(x)
450
+
451
+ ## upsamples
452
+ for layer in self.upsamples:
453
+ if feat_cache is not None:
454
+ x = layer(x, feat_cache, feat_idx)
455
+ else:
456
+ x = layer(x)
457
+
458
+ ## head
459
+ for layer in self.head:
460
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
461
+ idx = feat_idx[0]
462
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
463
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
464
+ # cache last frame of last two chunk
465
+ cache_x = torch.cat([
466
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
467
+ cache_x.device), cache_x
468
+ ],
469
+ dim=2)
470
+ x = layer(x, feat_cache[idx])
471
+ feat_cache[idx] = cache_x
472
+ feat_idx[0] += 1
473
+ else:
474
+ x = layer(x)
475
+ return x
476
+
477
+
478
+ def count_conv3d(model):
479
+ count = 0
480
+ for m in model.modules():
481
+ if isinstance(m, CausalConv3d):
482
+ count += 1
483
+ return count
484
+
485
+
486
+ class AutoencoderKLWan_(nn.Module):
487
+
488
+ def __init__(self,
489
+ dim=128,
490
+ z_dim=4,
491
+ dim_mult=[1, 2, 4, 4],
492
+ num_res_blocks=2,
493
+ attn_scales=[],
494
+ temperal_downsample=[True, True, False],
495
+ dropout=0.0):
496
+ super().__init__()
497
+ self.dim = dim
498
+ self.z_dim = z_dim
499
+ self.dim_mult = dim_mult
500
+ self.num_res_blocks = num_res_blocks
501
+ self.attn_scales = attn_scales
502
+ self.temperal_downsample = temperal_downsample
503
+ self.temperal_upsample = temperal_downsample[::-1]
504
+
505
+ # modules
506
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
507
+ attn_scales, self.temperal_downsample, dropout)
508
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
509
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
510
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
511
+ attn_scales, self.temperal_upsample, dropout)
512
+
513
+ def forward(self, x):
514
+ mu, log_var = self.encode(x)
515
+ z = self.reparameterize(mu, log_var)
516
+ x_recon = self.decode(z)
517
+ return x_recon, mu, log_var
518
+
519
+ def encode(self, x, scale):
520
+ self.clear_cache()
521
+ ## cache
522
+ t = x.shape[2]
523
+ iter_ = 1 + (t - 1) // 4
524
+ scale = [item.to(x.device, x.dtype) for item in scale]
525
+ ## Split the encoder input x along the temporal axis into chunks of 1, 4, 4, 4, ... frames
526
+ for i in range(iter_):
527
+ self._enc_conv_idx = [0]
528
+ if i == 0:
529
+ out = self.encoder(
530
+ x[:, :, :1, :, :],
531
+ feat_cache=self._enc_feat_map,
532
+ feat_idx=self._enc_conv_idx)
533
+ else:
534
+ out_ = self.encoder(
535
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
536
+ feat_cache=self._enc_feat_map,
537
+ feat_idx=self._enc_conv_idx)
538
+ out = torch.cat([out, out_], 2)
539
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
540
+ if isinstance(scale[0], torch.Tensor):
541
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
542
+ 1, self.z_dim, 1, 1, 1)
543
+ else:
544
+ mu = (mu - scale[0]) * scale[1]
545
+ x = torch.cat([mu, log_var], dim = 1)
546
+ self.clear_cache()
547
+ return x
548
+
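+ # Note on encode: the input is processed causally in temporal chunks of
+ # 1, 4, 4, ... frames, and each chunk contributes one latent frame, so t input
+ # frames map to 1 + (t - 1) // 4 latent frames (e.g. 81 -> 21). `scale`
+ # normalizes the latent mean per channel before mu/log_var are re-concatenated.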
549
+ def decode(self, z, scale):
550
+ self.clear_cache()
551
+ # z: [b,c,t,h,w]
552
+ scale = [item.to(z.device, z.dtype) for item in scale]
553
+ if isinstance(scale[0], torch.Tensor):
554
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
555
+ 1, self.z_dim, 1, 1, 1)
556
+ else:
557
+ z = z / scale[1] + scale[0]
558
+ iter_ = z.shape[2]
559
+ x = self.conv2(z)
560
+ for i in range(iter_):
561
+ self._conv_idx = [0]
562
+ if i == 0:
563
+ out = self.decoder(
564
+ x[:, :, i:i + 1, :, :],
565
+ feat_cache=self._feat_map,
566
+ feat_idx=self._conv_idx)
567
+ else:
568
+ out_ = self.decoder(
569
+ x[:, :, i:i + 1, :, :],
570
+ feat_cache=self._feat_map,
571
+ feat_idx=self._conv_idx)
572
+ out = torch.cat([out, out_], 2)
573
+ self.clear_cache()
574
+ return out
575
+
576
+ def reparameterize(self, mu, log_var):
577
+ std = torch.exp(0.5 * log_var)
578
+ eps = torch.randn_like(std)
579
+ return eps * std + mu
580
+
581
+ def sample(self, imgs, deterministic=False):
582
+ mu, log_var = self.encode(imgs)
583
+ if deterministic:
584
+ return mu
585
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
586
+ return mu + std * torch.randn_like(std)
587
+
588
+ def clear_cache(self):
589
+ self._conv_num = count_conv3d(self.decoder)
590
+ self._conv_idx = [0]
591
+ self._feat_map = [None] * self._conv_num
592
+ #cache encode
593
+ self._enc_conv_num = count_conv3d(self.encoder)
594
+ self._enc_conv_idx = [0]
595
+ self._enc_feat_map = [None] * self._enc_conv_num
596
+
597
+
598
+ def _video_vae(z_dim=None, **kwargs):
599
+ """
600
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
601
+ """
602
+ # params
603
+ cfg = dict(
604
+ dim=96,
605
+ z_dim=z_dim,
606
+ dim_mult=[1, 2, 4, 4],
607
+ num_res_blocks=2,
608
+ attn_scales=[],
609
+ temperal_downsample=[False, True, True],
610
+ dropout=0.0)
611
+ cfg.update(**kwargs)
612
+
613
+ # init model
614
+ model = AutoencoderKLWan_(**cfg)
615
+
616
+ return model
617
+
618
+
619
+ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
620
+
621
+ @register_to_config
622
+ def __init__(
623
+ self,
624
+ latent_channels=16,
625
+ temporal_compression_ratio=4,
626
+ spacial_compression_ratio=8
627
+ ):
628
+ super().__init__()
629
+ mean = [
630
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
631
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
632
+ ]
633
+ std = [
634
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
635
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
636
+ ]
637
+ self.mean = torch.tensor(mean, dtype=torch.float32)
638
+ self.std = torch.tensor(std, dtype=torch.float32)
639
+ self.scale = [self.mean, 1.0 / self.std]
640
+
641
+ # init model
642
+ self.model = _video_vae(
643
+ z_dim=latent_channels,
644
+ )
645
+
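+ # Note: `scale` = [mean, 1 / std] holds per-channel statistics for the
+ # 16-channel latent; the inner model shifts/scales mu on encode and inverts
+ # the same transform on decode, so downstream code can treat the latents as
+ # roughly unit-variance without an extra scaling_factor.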
646
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
647
+ x = [
648
+ self.model.encode(u.unsqueeze(0), self.scale).squeeze(0)
649
+ for u in x
650
+ ]
651
+ x = torch.stack(x)
652
+ return x
653
+
654
+ @apply_forward_hook
655
+ def encode(
656
+ self, x: torch.Tensor, return_dict: bool = True
657
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
658
+ h = self._encode(x)
659
+
660
+ posterior = DiagonalGaussianDistribution(h)
661
+
662
+ if not return_dict:
663
+ return (posterior,)
664
+ return AutoencoderKLOutput(latent_dist=posterior)
665
+
666
+ def _decode(self, zs):
667
+ dec = [
668
+ self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
669
+ for u in zs
670
+ ]
671
+ dec = torch.stack(dec)
672
+
673
+ return DecoderOutput(sample=dec)
674
+
675
+ @apply_forward_hook
676
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
677
+ decoded = self._decode(z).sample
678
+
679
+ if not return_dict:
680
+ return (decoded,)
681
+ return DecoderOutput(sample=decoded)
682
+
683
+ @classmethod
684
+ def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
685
+ def filter_kwargs(cls, kwargs):
686
+ import inspect
687
+ sig = inspect.signature(cls.__init__)
688
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
689
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
690
+ return filtered_kwargs
691
+
692
+ model = cls(**filter_kwargs(cls, additional_kwargs))
693
+ if pretrained_model_path.endswith(".safetensors"):
694
+ from safetensors.torch import load_file, safe_open
695
+ state_dict = load_file(pretrained_model_path)
696
+ else:
697
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
698
+ tmp_state_dict = {}
699
+ for key in state_dict:
700
+ tmp_state_dict["model." + key] = state_dict[key]
701
+ state_dict = tmp_state_dict
702
+ m, u = model.load_state_dict(state_dict, strict=False)
703
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
704
+ print(m, u)
705
+ return model
rose/models/wan_xlm_roberta.py ADDED
@@ -0,0 +1,170 @@
1
+ # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ __all__ = ['XLMRoberta', 'xlm_roberta_large']
8
+
9
+
10
+ class SelfAttention(nn.Module):
11
+
12
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
13
+ assert dim % num_heads == 0
14
+ super().__init__()
15
+ self.dim = dim
16
+ self.num_heads = num_heads
17
+ self.head_dim = dim // num_heads
18
+ self.eps = eps
19
+
20
+ # layers
21
+ self.q = nn.Linear(dim, dim)
22
+ self.k = nn.Linear(dim, dim)
23
+ self.v = nn.Linear(dim, dim)
24
+ self.o = nn.Linear(dim, dim)
25
+ self.dropout = nn.Dropout(dropout)
26
+
27
+ def forward(self, x, mask):
28
+ """
29
+ x: [B, L, C].
30
+ """
31
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
32
+
33
+ # compute query, key, value
34
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
35
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
36
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
37
+
38
+ # compute attention
39
+ p = self.dropout.p if self.training else 0.0
40
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
41
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
42
+
43
+ # output
44
+ x = self.o(x)
45
+ x = self.dropout(x)
46
+ return x
47
+
48
+
49
+ class AttentionBlock(nn.Module):
50
+
51
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
52
+ super().__init__()
53
+ self.dim = dim
54
+ self.num_heads = num_heads
55
+ self.post_norm = post_norm
56
+ self.eps = eps
57
+
58
+ # layers
59
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
60
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
61
+ self.ffn = nn.Sequential(
62
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
63
+ nn.Dropout(dropout))
64
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
65
+
66
+ def forward(self, x, mask):
67
+ if self.post_norm:
68
+ x = self.norm1(x + self.attn(x, mask))
69
+ x = self.norm2(x + self.ffn(x))
70
+ else:
71
+ x = x + self.attn(self.norm1(x), mask)
72
+ x = x + self.ffn(self.norm2(x))
73
+ return x
74
+
75
+
76
+ class XLMRoberta(nn.Module):
77
+ """
78
+ XLMRobertaModel with no pooler and no LM head.
79
+ """
80
+
81
+ def __init__(self,
82
+ vocab_size=250002,
83
+ max_seq_len=514,
84
+ type_size=1,
85
+ pad_id=1,
86
+ dim=1024,
87
+ num_heads=16,
88
+ num_layers=24,
89
+ post_norm=True,
90
+ dropout=0.1,
91
+ eps=1e-5):
92
+ super().__init__()
93
+ self.vocab_size = vocab_size
94
+ self.max_seq_len = max_seq_len
95
+ self.type_size = type_size
96
+ self.pad_id = pad_id
97
+ self.dim = dim
98
+ self.num_heads = num_heads
99
+ self.num_layers = num_layers
100
+ self.post_norm = post_norm
101
+ self.eps = eps
102
+
103
+ # embeddings
104
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
105
+ self.type_embedding = nn.Embedding(type_size, dim)
106
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
107
+ self.dropout = nn.Dropout(dropout)
108
+
109
+ # blocks
110
+ self.blocks = nn.ModuleList([
111
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
112
+ for _ in range(num_layers)
113
+ ])
114
+
115
+ # norm layer
116
+ self.norm = nn.LayerNorm(dim, eps=eps)
117
+
118
+ def forward(self, ids):
119
+ """
120
+ ids: [B, L] of torch.LongTensor.
121
+ """
122
+ b, s = ids.shape
123
+ mask = ids.ne(self.pad_id).long()
124
+
125
+ # embeddings
126
+ x = self.token_embedding(ids) + \
127
+ self.type_embedding(torch.zeros_like(ids)) + \
128
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
129
+ if self.post_norm:
130
+ x = self.norm(x)
131
+ x = self.dropout(x)
132
+
133
+ # blocks
134
+ mask = torch.where(
135
+ mask.view(b, 1, 1, s).gt(0), 0.0,
136
+ torch.finfo(x.dtype).min)
137
+ for block in self.blocks:
138
+ x = block(x, mask)
139
+
140
+ # output
141
+ if not self.post_norm:
142
+ x = self.norm(x)
143
+ return x
144
+
145
+
146
+ def xlm_roberta_large(pretrained=False,
147
+ return_tokenizer=False,
148
+ device='cpu',
149
+ **kwargs):
150
+ """
151
+ XLMRobertaLarge adapted from Huggingface.
152
+ """
153
+ # params
154
+ cfg = dict(
155
+ vocab_size=250002,
156
+ max_seq_len=514,
157
+ type_size=1,
158
+ pad_id=1,
159
+ dim=1024,
160
+ num_heads=16,
161
+ num_layers=24,
162
+ post_norm=True,
163
+ dropout=0.1,
164
+ eps=1e-5)
165
+ cfg.update(**kwargs)
166
+
167
+ # init a model on device
168
+ with torch.device(device):
169
+ model = XLMRoberta(**cfg)
170
+ return model
rose/pipeline/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .pipeline_wan_fun import WanFunPipeline
2
+ from .pipeline_wan_fun_inpaint import WanFunInpaintPipeline
3
+ from .pipeline_wan_fun_control import WanFunControlPipeline
4
+
5
+ WanPipeline = WanFunPipeline
6
+ WanI2VPipeline = WanFunInpaintPipeline
rose/pipeline/pipeline_wan_fun.py ADDED
@@ -0,0 +1,558 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers import FlowMatchEulerDiscreteScheduler
9
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
10
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
11
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
12
+ from diffusers.utils.torch_utils import randn_tensor
13
+ from diffusers.video_processor import VideoProcessor
14
+
15
+ from ..models import (AutoencoderKLWan, AutoTokenizer,
16
+ WanT5EncoderModel, WanTransformer3DModel)
17
+
18
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
+
20
+
21
+ EXAMPLE_DOC_STRING = """
22
+ Examples:
23
+ ```python
24
+ pass
25
+ ```
26
+ """
27
+
28
+
29
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
30
+ def retrieve_timesteps(
31
+ scheduler,
32
+ num_inference_steps: Optional[int] = None,
33
+ device: Optional[Union[str, torch.device]] = None,
34
+ timesteps: Optional[List[int]] = None,
35
+ sigmas: Optional[List[float]] = None,
36
+ **kwargs,
37
+ ):
38
+ """
39
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
40
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
41
+
42
+ Args:
43
+ scheduler (`SchedulerMixin`):
44
+ The scheduler to get timesteps from.
45
+ num_inference_steps (`int`):
46
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
47
+ must be `None`.
48
+ device (`str` or `torch.device`, *optional*):
49
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
50
+ timesteps (`List[int]`, *optional*):
51
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
52
+ `num_inference_steps` and `sigmas` must be `None`.
53
+ sigmas (`List[float]`, *optional*):
54
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
55
+ `num_inference_steps` and `timesteps` must be `None`.
56
+
57
+ Returns:
58
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
59
+ second element is the number of inference steps.
60
+ """
61
+ if timesteps is not None and sigmas is not None:
62
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
63
+ if timesteps is not None:
64
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
65
+ if not accepts_timesteps:
66
+ raise ValueError(
67
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
68
+ f" timestep schedules. Please check whether you are using the correct scheduler."
69
+ )
70
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
71
+ timesteps = scheduler.timesteps
72
+ num_inference_steps = len(timesteps)
73
+ elif sigmas is not None:
74
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
75
+ if not accept_sigmas:
76
+ raise ValueError(
77
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
78
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
79
+ )
80
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
81
+ timesteps = scheduler.timesteps
82
+ num_inference_steps = len(timesteps)
83
+ else:
84
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
85
+ timesteps = scheduler.timesteps
86
+ return timesteps, num_inference_steps
87
+
88
+
89
+ @dataclass
90
+ class WanPipelineOutput(BaseOutput):
91
+ r"""
92
+ Output class for Wan pipelines.
93
+
94
+ Args:
95
+ videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
96
+ List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
97
+ denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
98
+ `(batch_size, num_frames, channels, height, width)`.
99
+ """
100
+
101
+ videos: torch.Tensor
102
+
103
+
104
+ class WanFunPipeline(DiffusionPipeline):
105
+ r"""
106
+ Pipeline for text-to-video generation using Wan.
107
+
108
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
109
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
110
+ """
111
+
112
+ _optional_components = []
113
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
114
+
115
+ _callback_tensor_inputs = [
116
+ "latents",
117
+ "prompt_embeds",
118
+ "negative_prompt_embeds",
119
+ ]
120
+
121
+ def __init__(
122
+ self,
123
+ tokenizer: AutoTokenizer,
124
+ text_encoder: WanT5EncoderModel,
125
+ vae: AutoencoderKLWan,
126
+ transformer: WanTransformer3DModel,
127
+ scheduler: FlowMatchEulerDiscreteScheduler,
128
+ ):
129
+ super().__init__()
130
+
131
+ self.register_modules(
132
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
133
+ )
134
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
135
+
136
+ def _get_t5_prompt_embeds(
137
+ self,
138
+ prompt: Union[str, List[str]] = None,
139
+ num_videos_per_prompt: int = 1,
140
+ max_sequence_length: int = 512,
141
+ device: Optional[torch.device] = None,
142
+ dtype: Optional[torch.dtype] = None,
143
+ ):
144
+ device = device or self._execution_device
145
+ dtype = dtype or self.text_encoder.dtype
146
+
147
+ prompt = [prompt] if isinstance(prompt, str) else prompt
148
+ batch_size = len(prompt)
149
+
150
+ text_inputs = self.tokenizer(
151
+ prompt,
152
+ padding="max_length",
153
+ max_length=max_sequence_length,
154
+ truncation=True,
155
+ add_special_tokens=True,
156
+ return_tensors="pt",
157
+ )
158
+ text_input_ids = text_inputs.input_ids
159
+ prompt_attention_mask = text_inputs.attention_mask
160
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
161
+
162
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
163
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
164
+ logger.warning(
165
+ "The following part of your input was truncated because `max_sequence_length` is set to "
166
+ f" {max_sequence_length} tokens: {removed_text}"
167
+ )
168
+
169
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
170
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
171
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
172
+
173
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
174
+ _, seq_len, _ = prompt_embeds.shape
175
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
176
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
177
+
178
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
179
+
180
+ def encode_prompt(
181
+ self,
182
+ prompt: Union[str, List[str]],
183
+ negative_prompt: Optional[Union[str, List[str]]] = None,
184
+ do_classifier_free_guidance: bool = True,
185
+ num_videos_per_prompt: int = 1,
186
+ prompt_embeds: Optional[torch.Tensor] = None,
187
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
188
+ max_sequence_length: int = 512,
189
+ device: Optional[torch.device] = None,
190
+ dtype: Optional[torch.dtype] = None,
191
+ ):
192
+ r"""
193
+ Encodes the prompt into text encoder hidden states.
194
+
195
+ Args:
196
+ prompt (`str` or `List[str]`, *optional*):
197
+ prompt to be encoded
198
+ negative_prompt (`str` or `List[str]`, *optional*):
199
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
200
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
201
+ less than `1`).
202
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
203
+ Whether to use classifier free guidance or not.
204
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
205
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
206
+ prompt_embeds (`torch.Tensor`, *optional*):
207
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
208
+ provided, text embeddings will be generated from `prompt` input argument.
209
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
210
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
211
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
212
+ argument.
213
+ device: (`torch.device`, *optional*):
214
+ torch device
215
+ dtype: (`torch.dtype`, *optional*):
216
+ torch dtype
217
+ """
218
+ device = device or self._execution_device
219
+
220
+ prompt = [prompt] if isinstance(prompt, str) else prompt
221
+ if prompt is not None:
222
+ batch_size = len(prompt)
223
+ else:
224
+ batch_size = prompt_embeds.shape[0]
225
+
226
+ if prompt_embeds is None:
227
+ prompt_embeds = self._get_t5_prompt_embeds(
228
+ prompt=prompt,
229
+ num_videos_per_prompt=num_videos_per_prompt,
230
+ max_sequence_length=max_sequence_length,
231
+ device=device,
232
+ dtype=dtype,
233
+ )
234
+
235
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
236
+ negative_prompt = negative_prompt or ""
237
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
238
+
239
+ if prompt is not None and type(prompt) is not type(negative_prompt):
240
+ raise TypeError(
241
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
242
+ f" {type(prompt)}."
243
+ )
244
+ elif batch_size != len(negative_prompt):
245
+ raise ValueError(
246
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
247
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
248
+ " the batch size of `prompt`."
249
+ )
250
+
251
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
252
+ prompt=negative_prompt,
253
+ num_videos_per_prompt=num_videos_per_prompt,
254
+ max_sequence_length=max_sequence_length,
255
+ device=device,
256
+ dtype=dtype,
257
+ )
258
+
259
+ return prompt_embeds, negative_prompt_embeds
260
+
261
+ def prepare_latents(
262
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
263
+ ):
264
+ if isinstance(generator, list) and len(generator) != batch_size:
265
+ raise ValueError(
266
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
267
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
268
+ )
269
+
270
+ shape = (
271
+ batch_size,
272
+ num_channels_latents,
273
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
274
+ height // self.vae.spacial_compression_ratio,
275
+ width // self.vae.spacial_compression_ratio,
276
+ )
277
+
278
+ if latents is None:
279
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
280
+ else:
281
+ latents = latents.to(device)
282
+
283
+ # scale the initial noise by the standard deviation required by the scheduler
284
+ if hasattr(self.scheduler, "init_noise_sigma"):
285
+ latents = latents * self.scheduler.init_noise_sigma
286
+ return latents
287
+
288
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
289
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
290
+ frames = (frames / 2 + 0.5).clamp(0, 1)
291
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
292
+ frames = frames.cpu().float().numpy()
293
+ return frames
294
+
295
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
296
+ def prepare_extra_step_kwargs(self, generator, eta):
297
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
298
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
299
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
300
+ # and should be between [0, 1]
301
+
302
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
303
+ extra_step_kwargs = {}
304
+ if accepts_eta:
305
+ extra_step_kwargs["eta"] = eta
306
+
307
+ # check if the scheduler accepts generator
308
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
309
+ if accepts_generator:
310
+ extra_step_kwargs["generator"] = generator
311
+ return extra_step_kwargs
312
+
313
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
314
+ def check_inputs(
315
+ self,
316
+ prompt,
317
+ height,
318
+ width,
319
+ negative_prompt,
320
+ callback_on_step_end_tensor_inputs,
321
+ prompt_embeds=None,
322
+ negative_prompt_embeds=None,
323
+ ):
324
+ if height % 8 != 0 or width % 8 != 0:
325
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
326
+
327
+ if callback_on_step_end_tensor_inputs is not None and not all(
328
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
329
+ ):
330
+ raise ValueError(
331
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
332
+ )
333
+ if prompt is not None and prompt_embeds is not None:
334
+ raise ValueError(
335
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
336
+ " only forward one of the two."
337
+ )
338
+ elif prompt is None and prompt_embeds is None:
339
+ raise ValueError(
340
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
341
+ )
342
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
343
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
344
+
345
+ if prompt is not None and negative_prompt_embeds is not None:
346
+ raise ValueError(
347
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
348
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
349
+ )
350
+
351
+ if negative_prompt is not None and negative_prompt_embeds is not None:
352
+ raise ValueError(
353
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
354
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
355
+ )
356
+
357
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
358
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
359
+ raise ValueError(
360
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
361
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
362
+ f" {negative_prompt_embeds.shape}."
363
+ )
364
+
365
+ @property
366
+ def guidance_scale(self):
367
+ return self._guidance_scale
368
+
369
+ @property
370
+ def num_timesteps(self):
371
+ return self._num_timesteps
372
+
373
+ @property
374
+ def attention_kwargs(self):
375
+ return self._attention_kwargs
376
+
377
+ @property
378
+ def interrupt(self):
379
+ return self._interrupt
380
+
381
+ @torch.no_grad()
382
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
383
+ def __call__(
384
+ self,
385
+ prompt: Optional[Union[str, List[str]]] = None,
386
+ negative_prompt: Optional[Union[str, List[str]]] = None,
387
+ height: int = 480,
388
+ width: int = 720,
389
+ num_frames: int = 49,
390
+ num_inference_steps: int = 50,
391
+ timesteps: Optional[List[int]] = None,
392
+ guidance_scale: float = 6,
393
+ num_videos_per_prompt: int = 1,
394
+ eta: float = 0.0,
395
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
396
+ latents: Optional[torch.FloatTensor] = None,
397
+ prompt_embeds: Optional[torch.FloatTensor] = None,
398
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
399
+ output_type: str = "numpy",
400
+ return_dict: bool = False,
401
+ callback_on_step_end: Optional[
402
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
403
+ ] = None,
404
+ attention_kwargs: Optional[Dict[str, Any]] = None,
405
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
406
+ max_sequence_length: int = 512,
407
+ comfyui_progressbar: bool = False,
408
+ ) -> Union[WanPipelineOutput, Tuple]:
409
+ """
410
+ Function invoked when calling the pipeline for generation.
411
+ Args:
412
+
413
+ Examples:
414
+
415
+ Returns:
416
+
417
+ """
418
+
419
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
420
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
421
+ num_videos_per_prompt = 1
422
+
423
+ # 1. Check inputs. Raise error if not correct
424
+ self.check_inputs(
425
+ prompt,
426
+ height,
427
+ width,
428
+ negative_prompt,
429
+ callback_on_step_end_tensor_inputs,
430
+ prompt_embeds,
431
+ negative_prompt_embeds,
432
+ )
433
+ self._guidance_scale = guidance_scale
434
+ self._attention_kwargs = attention_kwargs
435
+ self._interrupt = False
436
+
437
+ # 2. Default call parameters
438
+ if prompt is not None and isinstance(prompt, str):
439
+ batch_size = 1
440
+ elif prompt is not None and isinstance(prompt, list):
441
+ batch_size = len(prompt)
442
+ else:
443
+ batch_size = prompt_embeds.shape[0]
444
+
445
+ device = self._execution_device
446
+ weight_dtype = self.text_encoder.dtype
447
+
448
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
449
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
450
+ # corresponds to doing no classifier free guidance.
451
+ do_classifier_free_guidance = guidance_scale > 1.0
452
+
453
+ # 3. Encode input prompt
454
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
455
+ prompt,
456
+ negative_prompt,
457
+ do_classifier_free_guidance,
458
+ num_videos_per_prompt=num_videos_per_prompt,
459
+ prompt_embeds=prompt_embeds,
460
+ negative_prompt_embeds=negative_prompt_embeds,
461
+ max_sequence_length=max_sequence_length,
462
+ device=device,
463
+ )
464
+ if do_classifier_free_guidance:
465
+ prompt_embeds = negative_prompt_embeds + prompt_embeds
466
+
467
+ # 4. Prepare timesteps
468
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
469
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
470
+ else:
471
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
472
+ self._num_timesteps = len(timesteps)
473
+ if comfyui_progressbar:
474
+ from comfy.utils import ProgressBar
475
+ pbar = ProgressBar(num_inference_steps + 1)
476
+
477
+ # 5. Prepare latents
478
+ latent_channels = self.transformer.config.in_channels
479
+ latents = self.prepare_latents(
480
+ batch_size * num_videos_per_prompt,
481
+ latent_channels,
482
+ num_frames,
483
+ height,
484
+ width,
485
+ weight_dtype,
486
+ device,
487
+ generator,
488
+ latents,
489
+ )
490
+ if comfyui_progressbar:
491
+ pbar.update(1)
492
+
493
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
494
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
495
+
496
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spacial_compression_ratio, height // self.vae.spacial_compression_ratio)
497
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
498
+ # 7. Denoising loop
499
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
500
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
501
+ for i, t in enumerate(timesteps):
502
+ if self.interrupt:
503
+ continue
504
+
505
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
506
+ if hasattr(self.scheduler, "scale_model_input"):
507
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
508
+
509
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
510
+ timestep = t.expand(latent_model_input.shape[0])
511
+
512
+ # predict noise model_output
513
+ with torch.cuda.amp.autocast(dtype=weight_dtype):
514
+ noise_pred = self.transformer(
515
+ x=latent_model_input,
516
+ context=prompt_embeds,
517
+ t=timestep,
518
+ seq_len=seq_len,
519
+ )
520
+
521
+ # perform guidance
522
+ if do_classifier_free_guidance:
523
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
524
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
525
+
526
+ # compute the previous noisy sample x_t -> x_t-1
527
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
528
+
529
+ if callback_on_step_end is not None:
530
+ callback_kwargs = {}
531
+ for k in callback_on_step_end_tensor_inputs:
532
+ callback_kwargs[k] = locals()[k]
533
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
534
+
535
+ latents = callback_outputs.pop("latents", latents)
536
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
537
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
538
+
539
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
540
+ progress_bar.update()
541
+ if comfyui_progressbar:
542
+ pbar.update(1)
543
+
544
+ if output_type == "numpy":
545
+ video = self.decode_latents(latents)
546
+ elif not output_type == "latent":
547
+ video = self.decode_latents(latents)
548
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
549
+ else:
550
+ video = latents
551
+
552
+ # Offload all models
553
+ self.maybe_free_model_hooks()
554
+
555
+ if not return_dict:
556
+ video = torch.from_numpy(video)
557
+
558
+ return WanPipelineOutput(videos=video)
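
For orientation, a hypothetical end-to-end sketch of calling `WanFunPipeline` as defined above. Only the constructor and `__call__` signatures are taken from this file; the checkpoint path, the diffusers-style subfolder layout, and the availability of `from_pretrained` on each component are assumptions, not something this diff guarantees:

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler

from rose.models import (AutoencoderKLWan, AutoTokenizer,
                         WanT5EncoderModel, WanTransformer3DModel)
from rose.pipeline import WanFunPipeline

model_dir = "path/to/wan-fun-checkpoint"  # placeholder path, not part of this repo

pipe = WanFunPipeline(
    tokenizer=AutoTokenizer.from_pretrained(model_dir, subfolder="tokenizer"),
    text_encoder=WanT5EncoderModel.from_pretrained(model_dir, subfolder="text_encoder"),
    vae=AutoencoderKLWan.from_pretrained(model_dir, subfolder="vae"),
    transformer=WanTransformer3DModel.from_pretrained(model_dir, subfolder="transformer"),
    scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(model_dir, subfolder="scheduler"),
).to("cuda")

out = pipe(
    prompt="a rose blooming in the rain",
    height=480, width=720, num_frames=49,
    num_inference_steps=50, guidance_scale=6.0,
    generator=torch.Generator("cuda").manual_seed(42),
)
video = out.videos  # (B, C, F, H, W), values in [0, 1] for the default output_type="numpy"
```

Note that even with the default `return_dict=False`, the `__call__` above still wraps the result in `WanPipelineOutput`; it only converts the decoded NumPy frames to a tensor first.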
rose/pipeline/pipeline_wan_fun_control.py ADDED
@@ -0,0 +1,723 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms.functional as TF
10
+ from diffusers import FlowMatchEulerDiscreteScheduler
11
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
16
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from diffusers.video_processor import VideoProcessor
19
+ from einops import rearrange
20
+ from PIL import Image
21
+ from transformers import T5Tokenizer
22
+
23
+ from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
24
+ WanT5EncoderModel, WanTransformer3DModel)
25
+
26
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
+
28
+
29
+ EXAMPLE_DOC_STRING = """
30
+ Examples:
31
+ ```python
32
+ pass
33
+ ```
34
+ """
35
+
36
+
37
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
38
+ def retrieve_timesteps(
39
+ scheduler,
40
+ num_inference_steps: Optional[int] = None,
41
+ device: Optional[Union[str, torch.device]] = None,
42
+ timesteps: Optional[List[int]] = None,
43
+ sigmas: Optional[List[float]] = None,
44
+ **kwargs,
45
+ ):
46
+ """
47
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
48
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
49
+
50
+ Args:
51
+ scheduler (`SchedulerMixin`):
52
+ The scheduler to get timesteps from.
53
+ num_inference_steps (`int`):
54
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
55
+ must be `None`.
56
+ device (`str` or `torch.device`, *optional*):
57
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
58
+ timesteps (`List[int]`, *optional*):
59
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
60
+ `num_inference_steps` and `sigmas` must be `None`.
61
+ sigmas (`List[float]`, *optional*):
62
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
63
+ `num_inference_steps` and `timesteps` must be `None`.
64
+
65
+ Returns:
66
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
67
+ second element is the number of inference steps.
68
+ """
69
+ if timesteps is not None and sigmas is not None:
70
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
71
+ if timesteps is not None:
72
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
73
+ if not accepts_timesteps:
74
+ raise ValueError(
75
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
76
+ f" timestep schedules. Please check whether you are using the correct scheduler."
77
+ )
78
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
79
+ timesteps = scheduler.timesteps
80
+ num_inference_steps = len(timesteps)
81
+ elif sigmas is not None:
82
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
83
+ if not accept_sigmas:
84
+ raise ValueError(
85
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
86
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
87
+ )
88
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
89
+ timesteps = scheduler.timesteps
90
+ num_inference_steps = len(timesteps)
91
+ else:
92
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
93
+ timesteps = scheduler.timesteps
94
+ return timesteps, num_inference_steps
95
+
96
+
97
+ def resize_mask(mask, latent, process_first_frame_only=True):
98
+ latent_size = latent.size()
99
+ batch_size, channels, num_frames, height, width = mask.shape
100
+
101
+ if process_first_frame_only:
102
+ target_size = list(latent_size[2:])
103
+ target_size[0] = 1
104
+ first_frame_resized = F.interpolate(
105
+ mask[:, :, 0:1, :, :],
106
+ size=target_size,
107
+ mode='trilinear',
108
+ align_corners=False
109
+ )
110
+
111
+ target_size = list(latent_size[2:])
112
+ target_size[0] = target_size[0] - 1
113
+ if target_size[0] != 0:
114
+ remaining_frames_resized = F.interpolate(
115
+ mask[:, :, 1:, :, :],
116
+ size=target_size,
117
+ mode='trilinear',
118
+ align_corners=False
119
+ )
120
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
121
+ else:
122
+ resized_mask = first_frame_resized
123
+ else:
124
+ target_size = list(latent_size[2:])
125
+ resized_mask = F.interpolate(
126
+ mask,
127
+ size=target_size,
128
+ mode='trilinear',
129
+ align_corners=False
130
+ )
131
+ return resized_mask
132
+
133
+
134
+ @dataclass
135
+ class WanPipelineOutput(BaseOutput):
136
+ r"""
137
+ Output class for Wan pipelines.
138
+
139
+ Args:
140
+ videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
141
+ List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
142
+ denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
143
+ `(batch_size, num_frames, channels, height, width)`.
144
+ """
145
+
146
+ videos: torch.Tensor
147
+
148
+
149
+ class WanFunControlPipeline(DiffusionPipeline):
150
+ r"""
151
+ Pipeline for video generation with control-video and reference-image conditioning using Wan.
152
+
153
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
154
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
155
+ """
156
+
157
+ _optional_components = []
158
+ model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae"
159
+
160
+ _callback_tensor_inputs = [
161
+ "latents",
162
+ "prompt_embeds",
163
+ "negative_prompt_embeds",
164
+ ]
165
+
166
+ def __init__(
167
+ self,
168
+ tokenizer: AutoTokenizer,
169
+ text_encoder: WanT5EncoderModel,
170
+ vae: AutoencoderKLWan,
171
+ transformer: WanTransformer3DModel,
172
+ clip_image_encoder: CLIPModel,
173
+ scheduler: FlowMatchEulerDiscreteScheduler,
174
+ ):
175
+ super().__init__()
176
+
177
+ self.register_modules(
178
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler
179
+ )
180
+
181
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
182
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
183
+ self.mask_processor = VaeImageProcessor(
184
+ vae_scale_factor=self.vae.spacial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
185
+ )
186
+
187
+ def _get_t5_prompt_embeds(
188
+ self,
189
+ prompt: Union[str, List[str]] = None,
190
+ num_videos_per_prompt: int = 1,
191
+ max_sequence_length: int = 512,
192
+ device: Optional[torch.device] = None,
193
+ dtype: Optional[torch.dtype] = None,
194
+ ):
195
+ device = device or self._execution_device
196
+ dtype = dtype or self.text_encoder.dtype
197
+
198
+ prompt = [prompt] if isinstance(prompt, str) else prompt
199
+ batch_size = len(prompt)
200
+
201
+ text_inputs = self.tokenizer(
202
+ prompt,
203
+ padding="max_length",
204
+ max_length=max_sequence_length,
205
+ truncation=True,
206
+ add_special_tokens=True,
207
+ return_tensors="pt",
208
+ )
209
+ text_input_ids = text_inputs.input_ids
210
+ prompt_attention_mask = text_inputs.attention_mask
211
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
212
+
213
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
214
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
215
+ logger.warning(
216
+ "The following part of your input was truncated because `max_sequence_length` is set to "
217
+ f" {max_sequence_length} tokens: {removed_text}"
218
+ )
219
+
220
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
221
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
222
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
223
+
224
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
225
+ _, seq_len, _ = prompt_embeds.shape
226
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
227
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
228
+
229
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
230
+
231
+ def encode_prompt(
232
+ self,
233
+ prompt: Union[str, List[str]],
234
+ negative_prompt: Optional[Union[str, List[str]]] = None,
235
+ do_classifier_free_guidance: bool = True,
236
+ num_videos_per_prompt: int = 1,
237
+ prompt_embeds: Optional[torch.Tensor] = None,
238
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
239
+ max_sequence_length: int = 512,
240
+ device: Optional[torch.device] = None,
241
+ dtype: Optional[torch.dtype] = None,
242
+ ):
243
+ r"""
244
+ Encodes the prompt into text encoder hidden states.
245
+
246
+ Args:
247
+ prompt (`str` or `List[str]`, *optional*):
248
+ prompt to be encoded
249
+ negative_prompt (`str` or `List[str]`, *optional*):
250
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
251
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
252
+ less than `1`).
253
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
254
+ Whether to use classifier free guidance or not.
255
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
256
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
257
+ prompt_embeds (`torch.Tensor`, *optional*):
258
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
259
+ provided, text embeddings will be generated from `prompt` input argument.
260
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
261
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
262
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
263
+ argument.
264
+ device: (`torch.device`, *optional*):
265
+ torch device
266
+ dtype: (`torch.dtype`, *optional*):
267
+ torch dtype
268
+ """
269
+ device = device or self._execution_device
270
+
271
+ prompt = [prompt] if isinstance(prompt, str) else prompt
272
+ if prompt is not None:
273
+ batch_size = len(prompt)
274
+ else:
275
+ batch_size = prompt_embeds.shape[0]
276
+
277
+ if prompt_embeds is None:
278
+ prompt_embeds = self._get_t5_prompt_embeds(
279
+ prompt=prompt,
280
+ num_videos_per_prompt=num_videos_per_prompt,
281
+ max_sequence_length=max_sequence_length,
282
+ device=device,
283
+ dtype=dtype,
284
+ )
285
+
286
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
287
+ negative_prompt = negative_prompt or ""
288
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
289
+
290
+ if prompt is not None and type(prompt) is not type(negative_prompt):
291
+ raise TypeError(
292
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
293
+ f" {type(prompt)}."
294
+ )
295
+ elif batch_size != len(negative_prompt):
296
+ raise ValueError(
297
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
298
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
299
+ " the batch size of `prompt`."
300
+ )
301
+
302
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
303
+ prompt=negative_prompt,
304
+ num_videos_per_prompt=num_videos_per_prompt,
305
+ max_sequence_length=max_sequence_length,
306
+ device=device,
307
+ dtype=dtype,
308
+ )
309
+
310
+ return prompt_embeds, negative_prompt_embeds
311
+
312
+ def prepare_latents(
313
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
314
+ ):
315
+ if isinstance(generator, list) and len(generator) != batch_size:
316
+ raise ValueError(
317
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
318
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
319
+ )
320
+
321
+ shape = (
322
+ batch_size,
323
+ num_channels_latents,
324
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
325
+ height // self.vae.spacial_compression_ratio,
326
+ width // self.vae.spacial_compression_ratio,
327
+ )
328
+
329
+ if latents is None:
330
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
331
+ else:
332
+ latents = latents.to(device)
333
+
334
+ # scale the initial noise by the standard deviation required by the scheduler
335
+ if hasattr(self.scheduler, "init_noise_sigma"):
336
+ latents = latents * self.scheduler.init_noise_sigma
337
+ return latents
338
+
339
+ def prepare_control_latents(
340
+ self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
341
+ ):
342
+ # resize the control to latents shape as we concatenate the control to the latents
343
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
344
+ # and half precision
345
+
346
+ if control is not None:
347
+ control = control.to(device=device, dtype=dtype)
348
+ bs = 1
349
+ new_control = []
350
+ for i in range(0, control.shape[0], bs):
351
+ control_bs = control[i : i + bs]
352
+ control_bs = self.vae.encode(control_bs)[0]
353
+ control_bs = control_bs.mode()
354
+ new_control.append(control_bs)
355
+ control = torch.cat(new_control, dim = 0)
356
+
357
+ if control_image is not None:
358
+ control_image = control_image.to(device=device, dtype=dtype)
359
+ bs = 1
360
+ new_control_pixel_values = []
361
+ for i in range(0, control_image.shape[0], bs):
362
+ control_pixel_values_bs = control_image[i : i + bs]
363
+ control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
364
+ control_pixel_values_bs = control_pixel_values_bs.mode()
365
+ new_control_pixel_values.append(control_pixel_values_bs)
366
+ control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
367
+ else:
368
+ control_image_latents = None
369
+
370
+ return control, control_image_latents
371
+
372
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
373
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
374
+ frames = (frames / 2 + 0.5).clamp(0, 1)
375
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
376
+ frames = frames.cpu().float().numpy()
377
+ return frames
378
+
379
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
380
+ def prepare_extra_step_kwargs(self, generator, eta):
381
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
382
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
383
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
384
+ # and should be between [0, 1]
385
+
386
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
387
+ extra_step_kwargs = {}
388
+ if accepts_eta:
389
+ extra_step_kwargs["eta"] = eta
390
+
391
+ # check if the scheduler accepts generator
392
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
393
+ if accepts_generator:
394
+ extra_step_kwargs["generator"] = generator
395
+ return extra_step_kwargs
396
+
397
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
398
+ def check_inputs(
399
+ self,
400
+ prompt,
401
+ height,
402
+ width,
403
+ negative_prompt,
404
+ callback_on_step_end_tensor_inputs,
405
+ prompt_embeds=None,
406
+ negative_prompt_embeds=None,
407
+ ):
408
+ if height % 8 != 0 or width % 8 != 0:
409
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
410
+
411
+ if callback_on_step_end_tensor_inputs is not None and not all(
412
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
413
+ ):
414
+ raise ValueError(
415
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
416
+ )
417
+ if prompt is not None and prompt_embeds is not None:
418
+ raise ValueError(
419
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
420
+ " only forward one of the two."
421
+ )
422
+ elif prompt is None and prompt_embeds is None:
423
+ raise ValueError(
424
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
425
+ )
426
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
427
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
428
+
429
+ if prompt is not None and negative_prompt_embeds is not None:
430
+ raise ValueError(
431
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
432
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
433
+ )
434
+
435
+ if negative_prompt is not None and negative_prompt_embeds is not None:
436
+ raise ValueError(
437
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
438
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
439
+ )
440
+
441
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
442
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
443
+ raise ValueError(
444
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
445
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
446
+ f" {negative_prompt_embeds.shape}."
447
+ )
448
+
449
+ @property
450
+ def guidance_scale(self):
451
+ return self._guidance_scale
452
+
453
+ @property
454
+ def num_timesteps(self):
455
+ return self._num_timesteps
456
+
457
+ @property
458
+ def attention_kwargs(self):
459
+ return self._attention_kwargs
460
+
461
+ @property
462
+ def interrupt(self):
463
+ return self._interrupt
464
+
465
+ @torch.no_grad()
466
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
467
+ def __call__(
468
+ self,
469
+ prompt: Optional[Union[str, List[str]]] = None,
470
+ negative_prompt: Optional[Union[str, List[str]]] = None,
471
+ height: int = 480,
472
+ width: int = 720,
473
+ control_video: Union[torch.FloatTensor] = None,
474
+ ref_image: Union[torch.FloatTensor] = None,
475
+ num_frames: int = 49,
476
+ num_inference_steps: int = 50,
477
+ timesteps: Optional[List[int]] = None,
478
+ guidance_scale: float = 6,
479
+ num_videos_per_prompt: int = 1,
480
+ eta: float = 0.0,
481
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
482
+ latents: Optional[torch.FloatTensor] = None,
483
+ prompt_embeds: Optional[torch.FloatTensor] = None,
484
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
485
+ output_type: str = "numpy",
486
+ return_dict: bool = False,
487
+ callback_on_step_end: Optional[
488
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
489
+ ] = None,
490
+ attention_kwargs: Optional[Dict[str, Any]] = None,
491
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
492
+ clip_image: Image = None,
493
+ max_sequence_length: int = 512,
494
+ comfyui_progressbar: bool = False,
495
+ ) -> Union[WanPipelineOutput, Tuple]:
496
+ """
497
+ Function invoked when calling the pipeline for generation.
498
+ Args:
499
+
500
+ Examples:
501
+
502
+ Returns:
503
+
504
+ """
505
+
506
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
507
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
508
+ num_videos_per_prompt = 1
509
+
510
+ # 1. Check inputs. Raise error if not correct
511
+ self.check_inputs(
512
+ prompt,
513
+ height,
514
+ width,
515
+ negative_prompt,
516
+ callback_on_step_end_tensor_inputs,
517
+ prompt_embeds,
518
+ negative_prompt_embeds,
519
+ )
520
+ self._guidance_scale = guidance_scale
521
+ self._attention_kwargs = attention_kwargs
522
+ self._interrupt = False
523
+
524
+ # 2. Default call parameters
525
+ if prompt is not None and isinstance(prompt, str):
526
+ batch_size = 1
527
+ elif prompt is not None and isinstance(prompt, list):
528
+ batch_size = len(prompt)
529
+ else:
530
+ batch_size = prompt_embeds.shape[0]
531
+
532
+ device = self._execution_device
533
+ weight_dtype = self.text_encoder.dtype
534
+
535
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
536
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
537
+ # corresponds to doing no classifier free guidance.
538
+ do_classifier_free_guidance = guidance_scale > 1.0
539
+
540
+ # 3. Encode input prompt
541
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
542
+ prompt,
543
+ negative_prompt,
544
+ do_classifier_free_guidance,
545
+ num_videos_per_prompt=num_videos_per_prompt,
546
+ prompt_embeds=prompt_embeds,
547
+ negative_prompt_embeds=negative_prompt_embeds,
548
+ max_sequence_length=max_sequence_length,
549
+ device=device,
550
+ )
551
+ if do_classifier_free_guidance:
552
+ prompt_embeds = negative_prompt_embeds + prompt_embeds
553
+
554
+ # 4. Prepare timesteps
555
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
556
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
557
+ else:
558
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
559
+ self._num_timesteps = len(timesteps)
560
+ if comfyui_progressbar:
561
+ from comfy.utils import ProgressBar
562
+ pbar = ProgressBar(num_inference_steps + 2)
563
+
564
+ # 5. Prepare latents.
565
+ latent_channels = self.vae.config.latent_channels
566
+ latents = self.prepare_latents(
567
+ batch_size * num_videos_per_prompt,
568
+ latent_channels,
569
+ num_frames,
570
+ height,
571
+ width,
572
+ weight_dtype,
573
+ device,
574
+ generator,
575
+ latents,
576
+ )
577
+ if comfyui_progressbar:
578
+ pbar.update(1)
579
+
580
+ # Prepare mask latent variables
581
+ if control_video is not None:
582
+ video_length = control_video.shape[2]
583
+ control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
584
+ control_video = control_video.to(dtype=torch.float32)
585
+ control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
586
+ control_video_latents = self.prepare_control_latents(
587
+ None,
588
+ control_video,
589
+ batch_size,
590
+ height,
591
+ width,
592
+ weight_dtype,
593
+ device,
594
+ generator,
595
+ do_classifier_free_guidance
596
+ )[1]
597
+ control_latents = (
598
+ torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
599
+ ).to(device, weight_dtype)
600
+ else:
601
+ control_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
602
+ control_latents = (
603
+ torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
604
+ ).to(device, weight_dtype)
605
+
606
+ if ref_image is not None:
607
+ video_length = ref_image.shape[2]
608
+ ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width)
609
+ ref_image = ref_image.to(dtype=torch.float32)
610
+ ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length)
611
+
612
+ ref_image_latentes = self.prepare_control_latents(
613
+ None,
614
+ ref_image,
615
+ batch_size,
616
+ height,
617
+ width,
618
+ weight_dtype,
619
+ device,
620
+ generator,
621
+ do_classifier_free_guidance
622
+ )[1]
623
+
624
+ ref_image_latentes_conv_in = torch.zeros_like(latents)
625
+ if latents.size()[2] != 1:
626
+ ref_image_latentes_conv_in[:, :, :1] = ref_image_latentes
627
+ ref_image_latentes_conv_in = (
628
+ torch.cat([ref_image_latentes_conv_in] * 2) if do_classifier_free_guidance else ref_image_latentes_conv_in
629
+ ).to(device, weight_dtype)
630
+ control_latents = torch.cat([control_latents, ref_image_latentes_conv_in], dim = 1)
631
+ else:
632
+ ref_image_latentes_conv_in = torch.zeros_like(latents)
633
+ ref_image_latentes_conv_in = (
634
+ torch.cat([ref_image_latentes_conv_in] * 2) if do_classifier_free_guidance else ref_image_latentes_conv_in
635
+ ).to(device, weight_dtype)
636
+ control_latents = torch.cat([control_latents, ref_image_latentes_conv_in], dim = 1)
637
+
638
+ # Prepare clip latent variables
639
+ if clip_image is not None:
640
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
641
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
642
+ clip_context = (
643
+ torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
644
+ )
645
+ else:
646
+ clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0))
647
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
648
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
649
+ clip_context = (
650
+ torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
651
+ )
652
+ clip_context = torch.zeros_like(clip_context)
653
+ if comfyui_progressbar:
654
+ pbar.update(1)
655
+
656
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
657
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
658
+
659
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spacial_compression_ratio, height // self.vae.spacial_compression_ratio)
660
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
661
+ # 7. Denoising loop
662
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
663
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
664
+ for i, t in enumerate(timesteps):
665
+ if self.interrupt:
666
+ continue
667
+
668
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
669
+ if hasattr(self.scheduler, "scale_model_input"):
670
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
671
+
672
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
673
+ timestep = t.expand(latent_model_input.shape[0])
674
+
675
+ # predict noise model_output
676
+ with torch.cuda.amp.autocast(dtype=weight_dtype):
677
+ noise_pred = self.transformer(
678
+ x=latent_model_input,
679
+ context=prompt_embeds,
680
+ t=timestep,
681
+ seq_len=seq_len,
682
+ y=control_latents,
683
+ clip_fea=clip_context,
684
+ )
685
+
686
+ # perform guidance
687
+ if do_classifier_free_guidance:
688
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
689
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
690
+
691
+ # compute the previous noisy sample x_t -> x_t-1
692
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
693
+
694
+ if callback_on_step_end is not None:
695
+ callback_kwargs = {}
696
+ for k in callback_on_step_end_tensor_inputs:
697
+ callback_kwargs[k] = locals()[k]
698
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
699
+
700
+ latents = callback_outputs.pop("latents", latents)
701
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
702
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
703
+
704
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
705
+ progress_bar.update()
706
+ if comfyui_progressbar:
707
+ pbar.update(1)
708
+
709
+ if output_type == "numpy":
710
+ video = self.decode_latents(latents)
711
+ elif not output_type == "latent":
712
+ video = self.decode_latents(latents)
713
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
714
+ else:
715
+ video = latents
716
+
717
+ # Offload all models
718
+ self.maybe_free_model_hooks()
719
+
720
+ if not return_dict:
721
+ video = torch.from_numpy(video)
722
+
723
+ return WanPipelineOutput(videos=video)
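
`WanFunControlPipeline.__call__` adds three conditioning inputs on top of the plain text-to-video path: `control_video`, `ref_image`, and `clip_image`. A hedged call sketch, assuming `pipe` has already been assembled like the text-to-video pipeline but with the extra `clip_image_encoder` component, and that the video tensors are provided in `[0, 1]` (the pipeline's `VaeImageProcessor` renormalizes them internally); the image path is a placeholder:

```python
import torch
from PIL import Image

num_frames, height, width = 49, 480, 720

# (B, C, F, H, W) conditioning tensors; random data just to illustrate the expected shapes.
control_video = torch.rand(1, 3, num_frames, height, width)
ref_image = torch.rand(1, 3, 1, height, width)              # single reference frame
clip_image = Image.open("first_frame.png").convert("RGB")   # placeholder path

out = pipe(
    prompt="a rose blooming in the rain",
    control_video=control_video,
    ref_image=ref_image,
    clip_image=clip_image,
    height=height, width=width, num_frames=num_frames,
    num_inference_steps=50, guidance_scale=6.0,
)
video = out.videos
```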
rose/pipeline/pipeline_wan_fun_inpaint.py ADDED
@@ -0,0 +1,729 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms.functional as TF
10
+ from diffusers import FlowMatchEulerDiscreteScheduler
11
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
16
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from diffusers.video_processor import VideoProcessor
19
+ from einops import rearrange
20
+ from PIL import Image
21
+ from transformers import T5Tokenizer
22
+ from torchvision.utils import save_image
23
+
24
+ from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
25
+ WanT5EncoderModel, WanTransformer3DModel)
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ EXAMPLE_DOC_STRING = """
31
+ Examples:
32
+ ```python
33
+ pass
34
+ ```
35
+ """
36
+
37
+
38
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
39
+ def retrieve_timesteps(
40
+ scheduler,
41
+ num_inference_steps: Optional[int] = None,
42
+ device: Optional[Union[str, torch.device]] = None,
43
+ timesteps: Optional[List[int]] = None,
44
+ sigmas: Optional[List[float]] = None,
45
+ **kwargs,
46
+ ):
47
+ """
48
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
49
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
50
+
51
+ Args:
52
+ scheduler (`SchedulerMixin`):
53
+ The scheduler to get timesteps from.
54
+ num_inference_steps (`int`):
55
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
56
+ must be `None`.
57
+ device (`str` or `torch.device`, *optional*):
58
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
59
+ timesteps (`List[int]`, *optional*):
60
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
61
+ `num_inference_steps` and `sigmas` must be `None`.
62
+ sigmas (`List[float]`, *optional*):
63
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
64
+ `num_inference_steps` and `timesteps` must be `None`.
65
+
66
+ Returns:
67
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
68
+ second element is the number of inference steps.
69
+ """
70
+ if timesteps is not None and sigmas is not None:
71
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
72
+ if timesteps is not None:
73
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
74
+ if not accepts_timesteps:
75
+ raise ValueError(
76
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
77
+ f" timestep schedules. Please check whether you are using the correct scheduler."
78
+ )
79
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
80
+ timesteps = scheduler.timesteps
81
+ num_inference_steps = len(timesteps)
82
+ elif sigmas is not None:
83
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
84
+ if not accept_sigmas:
85
+ raise ValueError(
86
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
87
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
88
+ )
89
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
90
+ timesteps = scheduler.timesteps
91
+ num_inference_steps = len(timesteps)
92
+ else:
93
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
94
+ timesteps = scheduler.timesteps
95
+ return timesteps, num_inference_steps
96
+
97
+
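The helper above simply forwards custom schedules to the scheduler, so it can be exercised on its own. A minimal sketch, assuming a recent `diffusers` with `FlowMatchEulerDiscreteScheduler` and mirroring the `mu=1` call made later in this pipeline:

```python
# Standalone sketch of retrieve_timesteps (illustration only, not part of this file).
import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0)

# Let the scheduler build a 30-step schedule; mu=1 matches the call in __call__ below.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu", mu=1)
print(num_steps, timesteps[:3])

# `timesteps` and `sigmas` are mutually exclusive and rejected up front.
try:
    retrieve_timesteps(scheduler, timesteps=[999, 500, 1], sigmas=[1.0, 0.5, 0.0])
except ValueError as err:
    print(err)
```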
98
+ def resize_mask(mask, latent, process_first_frame_only=True):
99
+ latent_size = latent.size()
100
+ batch_size, channels, num_frames, height, width = mask.shape
101
+
102
+ if process_first_frame_only:
103
+ target_size = list(latent_size[2:])
104
+ target_size[0] = 1
105
+ first_frame_resized = F.interpolate(
106
+ mask[:, :, 0:1, :, :],
107
+ size=target_size,
108
+ mode='trilinear',
109
+ align_corners=False
110
+ )
111
+
112
+ target_size = list(latent_size[2:])
113
+ target_size[0] = target_size[0] - 1
114
+ if target_size[0] != 0:
115
+ remaining_frames_resized = F.interpolate(
116
+ mask[:, :, 1:, :, :],
117
+ size=target_size,
118
+ mode='trilinear',
119
+ align_corners=False
120
+ )
121
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
122
+ else:
123
+ resized_mask = first_frame_resized
124
+ else:
125
+ target_size = list(latent_size[2:])
126
+ resized_mask = F.interpolate(
127
+ mask,
128
+ size=target_size,
129
+ mode='trilinear',
130
+ align_corners=False
131
+ )
132
+ return resized_mask
133
+
134
+
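`resize_mask` downsamples a pixel-space mask onto the latent grid, handling the first frame separately because the Wan VAE compresses it differently from the remaining frames. A quick shape check under assumed sizes (17 frames, 8x spatial and 4x temporal compression):

```python
# Shape sketch for resize_mask (assumed sizes for illustration).
import torch

mask = torch.rand(1, 1, 17, 480, 720)    # (B, 1, F, H, W) pixel-space mask
latent = torch.zeros(1, 16, 5, 60, 90)   # (B, C, F', H', W') latent grid

resized = resize_mask(mask, latent, process_first_frame_only=True)
print(resized.shape)  # torch.Size([1, 1, 5, 60, 90]): frame 0 and frames 1..16 resized separately
```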
135
+ @dataclass
136
+ class WanPipelineOutput(BaseOutput):
137
+ r"""
138
+ Output class for Wan pipelines.
139
+
140
+ Args:
141
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
142
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
143
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
144
+ `(batch_size, num_frames, channels, height, width)`.
145
+ """
146
+
147
+ videos: torch.Tensor
148
+
149
+
150
+ class WanFunInpaintPipeline(DiffusionPipeline):
151
+ r"""
152
+ Pipeline for mask-conditioned (inpainting) video generation using Wan.
153
+
154
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
155
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
156
+ """
157
+
158
+ _optional_components = []
159
+ model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae"
160
+
161
+ _callback_tensor_inputs = [
162
+ "latents",
163
+ "prompt_embeds",
164
+ "negative_prompt_embeds",
165
+ ]
166
+
167
+ def __init__(
168
+ self,
169
+ tokenizer: AutoTokenizer,
170
+ text_encoder: WanT5EncoderModel,
171
+ vae: AutoencoderKLWan,
172
+ transformer: WanTransformer3DModel,
173
+ clip_image_encoder: CLIPModel,
174
+ scheduler: FlowMatchEulerDiscreteScheduler,
175
+ ):
176
+ super().__init__()
177
+
178
+ self.register_modules(
179
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler
180
+ )
181
+
182
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
183
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
184
+ self.mask_processor = VaeImageProcessor(
185
+ vae_scale_factor=self.vae.spacial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
186
+ )
187
+
188
+ def _get_t5_prompt_embeds(
189
+ self,
190
+ prompt: Union[str, List[str]] = None,
191
+ num_videos_per_prompt: int = 1,
192
+ max_sequence_length: int = 512,
193
+ device: Optional[torch.device] = None,
194
+ dtype: Optional[torch.dtype] = None,
195
+ ):
196
+ device = device or self._execution_device
197
+ dtype = dtype or self.text_encoder.dtype
198
+
199
+ prompt = [prompt] if isinstance(prompt, str) else prompt
200
+ batch_size = len(prompt)
201
+
202
+ text_inputs = self.tokenizer(
203
+ prompt,
204
+ padding="max_length",
205
+ max_length=max_sequence_length,
206
+ truncation=True,
207
+ add_special_tokens=True,
208
+ return_tensors="pt",
209
+ )
210
+ text_input_ids = text_inputs.input_ids
211
+ prompt_attention_mask = text_inputs.attention_mask
212
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
213
+
214
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
215
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
216
+ logger.warning(
217
+ "The following part of your input was truncated because `max_sequence_length` is set to "
218
+ f" {max_sequence_length} tokens: {removed_text}"
219
+ )
220
+
221
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
222
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
223
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
224
+
225
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
226
+ _, seq_len, _ = prompt_embeds.shape
227
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
228
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
229
+
230
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
231
+
232
+ def encode_prompt(
233
+ self,
234
+ prompt: Union[str, List[str]],
235
+ negative_prompt: Optional[Union[str, List[str]]] = None,
236
+ do_classifier_free_guidance: bool = True,
237
+ num_videos_per_prompt: int = 1,
238
+ prompt_embeds: Optional[torch.Tensor] = None,
239
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
240
+ max_sequence_length: int = 512,
241
+ device: Optional[torch.device] = None,
242
+ dtype: Optional[torch.dtype] = None,
243
+ ):
244
+ r"""
245
+ Encodes the prompt into text encoder hidden states.
246
+
247
+ Args:
248
+ prompt (`str` or `List[str]`, *optional*):
249
+ prompt to be encoded
250
+ negative_prompt (`str` or `List[str]`, *optional*):
251
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
252
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
253
+ less than `1`).
254
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
255
+ Whether to use classifier free guidance or not.
256
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
257
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
258
+ prompt_embeds (`torch.Tensor`, *optional*):
259
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
260
+ provided, text embeddings will be generated from `prompt` input argument.
261
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
262
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
263
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
264
+ argument.
265
+ device: (`torch.device`, *optional*):
266
+ torch device
267
+ dtype: (`torch.dtype`, *optional*):
268
+ torch dtype
269
+ """
270
+ device = device or self._execution_device
271
+
272
+ prompt = [prompt] if isinstance(prompt, str) else prompt
273
+ if prompt is not None:
274
+ batch_size = len(prompt)
275
+ else:
276
+ batch_size = prompt_embeds.shape[0]
277
+
278
+ if prompt_embeds is None:
279
+ prompt_embeds = self._get_t5_prompt_embeds(
280
+ prompt=prompt,
281
+ num_videos_per_prompt=num_videos_per_prompt,
282
+ max_sequence_length=max_sequence_length,
283
+ device=device,
284
+ dtype=dtype,
285
+ )
286
+
287
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
288
+ negative_prompt = negative_prompt or ""
289
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
290
+
291
+ if prompt is not None and type(prompt) is not type(negative_prompt):
292
+ raise TypeError(
293
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
294
+ f" {type(prompt)}."
295
+ )
296
+ elif batch_size != len(negative_prompt):
297
+ raise ValueError(
298
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
299
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
300
+ " the batch size of `prompt`."
301
+ )
302
+
303
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
304
+ prompt=negative_prompt,
305
+ num_videos_per_prompt=num_videos_per_prompt,
306
+ max_sequence_length=max_sequence_length,
307
+ device=device,
308
+ dtype=dtype,
309
+ )
310
+
311
+ return prompt_embeds, negative_prompt_embeds
312
+
313
+ def prepare_latents(
314
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
315
+ ):
316
+ if isinstance(generator, list) and len(generator) != batch_size:
317
+ raise ValueError(
318
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
319
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
320
+ )
321
+
322
+ shape = (
323
+ batch_size,
324
+ num_channels_latents,
325
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
326
+ height // self.vae.spacial_compression_ratio,
327
+ width // self.vae.spacial_compression_ratio,
328
+ )
329
+
330
+ if latents is None:
331
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
332
+ else:
333
+ latents = latents.to(device)
334
+
335
+ # scale the initial noise by the standard deviation required by the scheduler
336
+ if hasattr(self.scheduler, "init_noise_sigma"):
337
+ latents = latents * self.scheduler.init_noise_sigma
338
+ return latents
339
+
340
+ def prepare_mask_latents(
341
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
342
+ ):
343
+ # resize the mask to latents shape as we concatenate the mask to the latents
344
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
345
+ # and half precision
346
+
347
+ if mask is not None:
348
+ mask = mask.to(device=device, dtype=self.vae.dtype)
349
+ bs = 1
350
+ new_mask = []
351
+ for i in range(0, mask.shape[0], bs):
352
+ mask_bs = mask[i : i + bs]
353
+ mask_bs = self.vae.encode(mask_bs)[0]
354
+ mask_bs = mask_bs.mode()
355
+ new_mask.append(mask_bs)
356
+ mask = torch.cat(new_mask, dim = 0)
357
+ # mask = mask * self.vae.config.scaling_factor
358
+
359
+ if masked_image is not None:
360
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
361
+ bs = 1
362
+ new_mask_pixel_values = []
363
+ for i in range(0, masked_image.shape[0], bs):
364
+ mask_pixel_values_bs = masked_image[i : i + bs]
365
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
366
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
367
+ new_mask_pixel_values.append(mask_pixel_values_bs)
368
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
369
+ # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
370
+ else:
371
+ masked_image_latents = None
372
+
373
+ return mask, masked_image_latents
374
+
375
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
376
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
377
+ frames = (frames / 2 + 0.5).clamp(0, 1)
378
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
379
+ frames = frames.cpu().float().numpy()
380
+ return frames
381
+
382
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
383
+ def prepare_extra_step_kwargs(self, generator, eta):
384
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
385
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
386
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
387
+ # and should be between [0, 1]
388
+
389
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
390
+ extra_step_kwargs = {}
391
+ if accepts_eta:
392
+ extra_step_kwargs["eta"] = eta
393
+
394
+ # check if the scheduler accepts generator
395
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
396
+ if accepts_generator:
397
+ extra_step_kwargs["generator"] = generator
398
+ return extra_step_kwargs
399
+
400
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
401
+ def check_inputs(
402
+ self,
403
+ prompt,
404
+ height,
405
+ width,
406
+ negative_prompt,
407
+ callback_on_step_end_tensor_inputs,
408
+ prompt_embeds=None,
409
+ negative_prompt_embeds=None,
410
+ ):
411
+ if height % 8 != 0 or width % 8 != 0:
412
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
413
+
414
+ if callback_on_step_end_tensor_inputs is not None and not all(
415
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
416
+ ):
417
+ raise ValueError(
418
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
419
+ )
420
+ if prompt is not None and prompt_embeds is not None:
421
+ raise ValueError(
422
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
423
+ " only forward one of the two."
424
+ )
425
+ elif prompt is None and prompt_embeds is None:
426
+ raise ValueError(
427
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
428
+ )
429
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
430
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
431
+
432
+ if prompt is not None and negative_prompt_embeds is not None:
433
+ raise ValueError(
434
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
435
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
436
+ )
437
+
438
+ if negative_prompt is not None and negative_prompt_embeds is not None:
439
+ raise ValueError(
440
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
441
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
442
+ )
443
+
444
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
445
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
446
+ raise ValueError(
447
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
448
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
449
+ f" {negative_prompt_embeds.shape}."
450
+ )
451
+
452
+ @property
453
+ def guidance_scale(self):
454
+ return self._guidance_scale
455
+
456
+ @property
457
+ def num_timesteps(self):
458
+ return self._num_timesteps
459
+
460
+ @property
461
+ def attention_kwargs(self):
462
+ return self._attention_kwargs
463
+
464
+ @property
465
+ def interrupt(self):
466
+ return self._interrupt
467
+
468
+ @torch.no_grad()
469
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
470
+ def __call__(
471
+ self,
472
+ prompt: Optional[Union[str, List[str]]] = None,
473
+ negative_prompt: Optional[Union[str, List[str]]] = None,
474
+ height: int = 480,
475
+ width: int = 720,
476
+ video: Union[torch.FloatTensor] = None,
477
+ mask_video: Union[torch.FloatTensor] = None,
478
+ num_frames: int = 49,
479
+ num_inference_steps: int = 50,
480
+ timesteps: Optional[List[int]] = None,
481
+ guidance_scale: float = 6,
482
+ num_videos_per_prompt: int = 1,
483
+ eta: float = 0.0,
484
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
485
+ latents: Optional[torch.FloatTensor] = None,
486
+ prompt_embeds: Optional[torch.FloatTensor] = None,
487
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
488
+ output_type: str = "numpy",
489
+ return_dict: bool = False,
490
+ callback_on_step_end: Optional[
491
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
492
+ ] = None,
493
+ attention_kwargs: Optional[Dict[str, Any]] = None,
494
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
495
+ clip_image: Image = None,
496
+ max_sequence_length: int = 512,
497
+ comfyui_progressbar: bool = False,
498
+ ) -> Union[WanPipelineOutput, Tuple]:
499
+ """
500
+ Function invoked when calling the pipeline for generation.
501
+ Args:
502
+
503
+ Examples:
504
+
505
+ Returns:
506
+
507
+ """
508
+
509
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
510
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
511
+ num_videos_per_prompt = 1
512
+
513
+ # 1. Check inputs. Raise error if not correct
514
+ self.check_inputs(
515
+ prompt,
516
+ height,
517
+ width,
518
+ negative_prompt,
519
+ callback_on_step_end_tensor_inputs,
520
+ prompt_embeds,
521
+ negative_prompt_embeds,
522
+ )
523
+ self._guidance_scale = guidance_scale
524
+ self._attention_kwargs = attention_kwargs
525
+ self._interrupt = False
526
+
527
+ # 2. Default call parameters
528
+ if prompt is not None and isinstance(prompt, str):
529
+ batch_size = 1
530
+ elif prompt is not None and isinstance(prompt, list):
531
+ batch_size = len(prompt)
532
+ else:
533
+ batch_size = prompt_embeds.shape[0]
534
+
535
+ device = self._execution_device
536
+ weight_dtype = self.text_encoder.dtype
537
+
538
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
539
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
540
+ # corresponds to doing no classifier free guidance.
541
+ do_classifier_free_guidance = guidance_scale > 1.0
542
+
543
+ # 3. Encode input prompt
544
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
545
+ prompt,
546
+ negative_prompt,
547
+ do_classifier_free_guidance,
548
+ num_videos_per_prompt=num_videos_per_prompt,
549
+ prompt_embeds=prompt_embeds,
550
+ negative_prompt_embeds=negative_prompt_embeds,
551
+ max_sequence_length=max_sequence_length,
552
+ device=device,
553
+ )
554
+ if do_classifier_free_guidance:
555
+ prompt_embeds = negative_prompt_embeds + prompt_embeds
556
+
557
+ # 4. Prepare timesteps
558
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
559
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
560
+ else:
561
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
562
+ self._num_timesteps = len(timesteps)
563
+ if comfyui_progressbar:
564
+ from comfy.utils import ProgressBar
565
+ pbar = ProgressBar(num_inference_steps + 2)
566
+
567
+ # 5. Prepare latents.
568
+ if video is not None:
569
+ video_length = video.shape[2]
570
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
571
+ init_video = init_video.to(dtype=torch.float32)
572
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
573
+ else:
574
+ init_video = None
575
+
576
+ # import pdb; pdb.set_trace()
577
+ latent_channels = self.vae.config.latent_channels
578
+ latents = self.prepare_latents(
579
+ batch_size * num_videos_per_prompt,
580
+ latent_channels,
581
+ num_frames,
582
+ height,
583
+ width,
584
+ weight_dtype,
585
+ device,
586
+ generator,
587
+ latents,
588
+ )
589
+ if comfyui_progressbar:
590
+ pbar.update(1)
591
+
592
+ # Prepare mask latent variables
593
+ if init_video is not None:
594
+ if (mask_video == 255).all():
595
+ mask_latents = torch.tile(
596
+ torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1]
597
+ )
598
+ masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
599
+
600
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
601
+ masked_video_latents_input = (
602
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
603
+ )
604
+ y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype)
605
+ else:
606
+ bs, _, video_length, height, width = video.size()
607
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
608
+ mask_condition = mask_condition.to(dtype=torch.float32)
609
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
610
+
611
+ # masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5)
612
+ masked_video = init_video
613
+
614
+ _, masked_video_latents = self.prepare_mask_latents(
615
+ None,
616
+ masked_video,
617
+ batch_size,
618
+ height,
619
+ width,
620
+ weight_dtype,
621
+ device,
622
+ generator,
623
+ do_classifier_free_guidance,
624
+ noise_aug_strength=None,
625
+ )
626
+
627
+ mask_condition = torch.concat(
628
+ [
629
+ torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
630
+ mask_condition[:, :, 1:]
631
+ ], dim=2
632
+ )
633
+ mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width)
634
+ mask_condition = mask_condition.transpose(1, 2)
635
+ mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype)
636
+
637
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
638
+ masked_video_latents_input = (
639
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
640
+ )
641
+
642
+ y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype)
643
+
644
+ # Prepare clip latent variables
645
+ if clip_image is not None:
646
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
647
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
648
+ clip_context = (
649
+ torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
650
+ )
651
+ else:
652
+ clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0))
653
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
654
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
655
+ clip_context = (
656
+ torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
657
+ )
658
+ clip_context = torch.zeros_like(clip_context)
659
+ if comfyui_progressbar:
660
+ pbar.update(1)
661
+
662
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
663
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
664
+
665
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spacial_compression_ratio, height // self.vae.spacial_compression_ratio)
666
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
667
+ # 7. Denoising loop
668
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
669
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
670
+ for i, t in enumerate(timesteps):
671
+ if self.interrupt:
672
+ continue
673
+
674
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
675
+ if hasattr(self.scheduler, "scale_model_input"):
676
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
677
+
678
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
679
+ timestep = t.expand(latent_model_input.shape[0])
680
+
681
+ # predict noise model_output
682
+ with torch.cuda.amp.autocast(dtype=weight_dtype):
683
+ noise_pred = self.transformer(
684
+ x=latent_model_input,
685
+ context=prompt_embeds,
686
+ t=timestep,
687
+ seq_len=seq_len,
688
+ y=y,
689
+ clip_fea=clip_context,
690
+ )
691
+
692
+ # perform guidance
693
+ if do_classifier_free_guidance:
694
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
695
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
696
+
697
+ # compute the previous noisy sample x_t -> x_t-1
698
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
699
+
700
+ if callback_on_step_end is not None:
701
+ callback_kwargs = {}
702
+ for k in callback_on_step_end_tensor_inputs:
703
+ callback_kwargs[k] = locals()[k]
704
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
705
+
706
+ latents = callback_outputs.pop("latents", latents)
707
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
708
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
709
+
710
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
711
+ progress_bar.update()
712
+ if comfyui_progressbar:
713
+ pbar.update(1)
714
+
715
+ if output_type == "numpy":
716
+ video = self.decode_latents(latents)
717
+ elif not output_type == "latent":
718
+ video = self.decode_latents(latents)
719
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
720
+ else:
721
+ video = latents
722
+
723
+ # Offload all models
724
+ self.maybe_free_model_hooks()
725
+
726
+ if not return_dict:
727
+ video = torch.from_numpy(video)
728
+
729
+ return WanPipelineOutput(videos=video)
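For reference, a minimal driver for the pipeline defined above. The assembled `WanFunInpaintPipeline` is passed in as an argument because the checkpoint layout and component loaders are not part of this file; only the `__call__` arguments below mirror the signature shown.

```python
# Hypothetical driver sketch; the fully constructed pipeline is assumed.
import torch

def run_inpaint(pipe: "WanFunInpaintPipeline", prompt: str) -> torch.Tensor:
    video = torch.zeros(1, 3, 49, 480, 720)            # (B, C, F, H, W), values in [0, 1]
    mask_video = torch.ones(1, 1, 49, 480, 720) * 255  # 255 marks content to be generated
    out = pipe(
        prompt=prompt,
        video=video,
        mask_video=mask_video,
        height=480, width=720, num_frames=49,
        num_inference_steps=50, guidance_scale=6.0,
        generator=torch.Generator().manual_seed(42),
        output_type="numpy",
    )
    return out.videos  # torch tensor (B, C, F, H, W) in [0, 1]
```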
rose/utils/__init__.py ADDED
File without changes
rose/utils/discrete_sampler.py ADDED
@@ -0,0 +1,46 @@
1
+ """Modified from https://github.com/THUDM/CogVideo/blob/3710a612d8760f5cdb1741befeebb65b9e0f2fe0/sat/sgm/modules/diffusionmodules/sigma_sampling.py
2
+ """
3
+ import torch
4
+
5
+ class DiscreteSampling:
6
+ def __init__(self, num_idx, uniform_sampling=False):
7
+ self.num_idx = num_idx
8
+ self.uniform_sampling = uniform_sampling
9
+ self.is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
10
+
11
+ if self.is_distributed and self.uniform_sampling:
12
+ world_size = torch.distributed.get_world_size()
13
+ self.rank = torch.distributed.get_rank()
14
+
15
+ i = 1
16
+ while True:
17
+ if world_size % i != 0 or num_idx % (world_size // i) != 0:
18
+ i += 1
19
+ else:
20
+ self.group_num = world_size // i
21
+ break
22
+ assert self.group_num > 0
23
+ assert world_size % self.group_num == 0
24
+ # the number of ranks in one group
25
+ self.group_width = world_size // self.group_num
26
+ self.sigma_interval = self.num_idx // self.group_num
27
+ print('rank=%d world_size=%d group_num=%d group_width=%d sigma_interval=%s' % (
28
+ self.rank, world_size, self.group_num,
29
+ self.group_width, self.sigma_interval))
30
+
31
+ def __call__(self, n_samples, generator=None, device=None):
32
+ if self.is_distributed and self.uniform_sampling:
33
+ group_index = self.rank // self.group_width
34
+ idx = torch.randint(
35
+ group_index * self.sigma_interval,
36
+ (group_index + 1) * self.sigma_interval,
37
+ (n_samples,),
38
+ generator=generator, device=device,
39
+ )
40
+ print('proc[%d] idx=%s' % (self.rank, idx))
41
+ else:
42
+ idx = torch.randint(
43
+ 0, self.num_idx, (n_samples,),
44
+ generator=generator, device=device,
45
+ )
46
+ return idx
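In the common single-process case (`uniform_sampling=False` or no initialized process group) the sampler reduces to a plain uniform draw over timestep indices; the grouped branch only partitions the index range across ranks. A small sketch:

```python
# Single-process usage sketch for DiscreteSampling.
import torch

sampler = DiscreteSampling(num_idx=1000, uniform_sampling=False)
idx = sampler(n_samples=4, generator=torch.Generator().manual_seed(0), device="cpu")
print(idx)  # four indices drawn uniformly from [0, 1000)
```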
rose/utils/fp8_optimization.py ADDED
@@ -0,0 +1,56 @@
1
+ """Modified from https://github.com/kijai/ComfyUI-MochiWrapper
2
+ """
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
7
+ weight_dtype = cls.weight.dtype
8
+ cls.to(origin_dtype)
9
+
10
+ # Convert all inputs to the original dtype
11
+ inputs = [input.to(origin_dtype) for input in inputs]
12
+ out = cls.original_forward(*inputs, **kwargs)
13
+
14
+ cls.to(weight_dtype)
15
+ return out
16
+
17
+ def replace_parameters_by_name(module, name_keywords, device):
18
+ from torch import nn
19
+ for name, param in list(module.named_parameters(recurse=False)):
20
+ if any(keyword in name for keyword in name_keywords):
21
+ if isinstance(param, nn.Parameter):
22
+ tensor = param.data
23
+ delattr(module, name)
24
+ setattr(module, name, tensor.to(device=device))
25
+ for child_name, child_module in module.named_children():
26
+ replace_parameters_by_name(child_module, name_keywords, device)
27
+
28
+ def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens']):
29
+ for name, module in model.named_modules():
30
+ flag = False
31
+ for _exclude_module_name in exclude_module_name:
32
+ if _exclude_module_name in name:
33
+ flag = True
34
+ if flag:
35
+ continue
36
+ for param_name, param in module.named_parameters():
37
+ flag = False
38
+ for _exclude_module_name in exclude_module_name:
39
+ if _exclude_module_name in param_name:
40
+ flag = True
41
+ if flag:
42
+ continue
43
+ param.data = param.data.to(torch.float8_e4m3fn)
44
+
45
+ def convert_weight_dtype_wrapper(module, origin_dtype):
46
+ for name, module in module.named_modules():
47
+ if name == "" or "embed_tokens" in name:
48
+ continue
49
+ original_forward = module.forward
50
+ if hasattr(module, "weight") and module.weight is not None:
51
+ setattr(module, "original_forward", original_forward)
52
+ setattr(
53
+ module,
54
+ "forward",
55
+ lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs)
56
+ )
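The two helpers above are meant to be used together: weights are stored in float8 and temporarily cast back to a working dtype for each wrapped forward. A minimal sketch on a toy model, assuming a torch build that ships `torch.float8_e4m3fn`:

```python
# Toy float8-storage sketch (illustration only; requires torch.float8_e4m3fn support).
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 8))
convert_model_weight_to_float8(model)                             # params stored as float8_e4m3fn
convert_weight_dtype_wrapper(model, origin_dtype=torch.bfloat16)  # forwards run in bfloat16

x = torch.randn(2, 64, dtype=torch.bfloat16)
with torch.no_grad():
    y = model(x)
print(y.dtype)  # torch.bfloat16; weights are cast back to float8 after each call
```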
rose/utils/lora_utils.py ADDED
@@ -0,0 +1,516 @@
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+ # https://github.com/bmaltais/kohya_ss
6
+
7
+ import hashlib
8
+ import math
9
+ import os
10
+ from collections import defaultdict
11
+ from io import BytesIO
12
+ from typing import List, Optional, Type, Union
13
+
14
+ import safetensors.torch
15
+ import torch
16
+ import torch.utils.checkpoint
17
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
18
+ from safetensors.torch import load_file
19
+ from transformers import T5EncoderModel
20
+
21
+
22
+ class LoRAModule(torch.nn.Module):
23
+ """
24
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ lora_name,
30
+ org_module: torch.nn.Module,
31
+ multiplier=1.0,
32
+ lora_dim=4,
33
+ alpha=1,
34
+ dropout=None,
35
+ rank_dropout=None,
36
+ module_dropout=None,
37
+ ):
38
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
39
+ super().__init__()
40
+ self.lora_name = lora_name
41
+
42
+ if org_module.__class__.__name__ == "Conv2d":
43
+ in_dim = org_module.in_channels
44
+ out_dim = org_module.out_channels
45
+ else:
46
+ in_dim = org_module.in_features
47
+ out_dim = org_module.out_features
48
+
49
+ self.lora_dim = lora_dim
50
+ if org_module.__class__.__name__ == "Conv2d":
51
+ kernel_size = org_module.kernel_size
52
+ stride = org_module.stride
53
+ padding = org_module.padding
54
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
55
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
56
+ else:
57
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
58
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
59
+
60
+ if type(alpha) == torch.Tensor:
61
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
62
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
63
+ self.scale = alpha / self.lora_dim
64
+ self.register_buffer("alpha", torch.tensor(alpha))
65
+
66
+ # same as microsoft's
67
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
68
+ torch.nn.init.zeros_(self.lora_up.weight)
69
+
70
+ self.multiplier = multiplier
71
+ self.org_module = org_module # remove in applying
72
+ self.dropout = dropout
73
+ self.rank_dropout = rank_dropout
74
+ self.module_dropout = module_dropout
75
+
76
+ def apply_to(self):
77
+ self.org_forward = self.org_module.forward
78
+ self.org_module.forward = self.forward
79
+ del self.org_module
80
+
81
+ def forward(self, x, *args, **kwargs):
82
+ weight_dtype = x.dtype
83
+ org_forwarded = self.org_forward(x)
84
+
85
+ # module dropout
86
+ if self.module_dropout is not None and self.training:
87
+ if torch.rand(1) < self.module_dropout:
88
+ return org_forwarded
89
+
90
+ lx = self.lora_down(x.to(self.lora_down.weight.dtype))
91
+
92
+ # normal dropout
93
+ if self.dropout is not None and self.training:
94
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
95
+
96
+ # rank dropout
97
+ if self.rank_dropout is not None and self.training:
98
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
99
+ if len(lx.size()) == 3:
100
+ mask = mask.unsqueeze(1) # for Text Encoder
101
+ elif len(lx.size()) == 4:
102
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
103
+ lx = lx * mask
104
+
105
+ # scaling for rank dropout: treat as if the rank is changed
106
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
107
+ else:
108
+ scale = self.scale
109
+
110
+ lx = self.lora_up(lx)
111
+
112
+ return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
113
+
114
+
115
+ def addnet_hash_legacy(b):
116
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
117
+ m = hashlib.sha256()
118
+
119
+ b.seek(0x100000)
120
+ m.update(b.read(0x10000))
121
+ return m.hexdigest()[0:8]
122
+
123
+
124
+ def addnet_hash_safetensors(b):
125
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
126
+ hash_sha256 = hashlib.sha256()
127
+ blksize = 1024 * 1024
128
+
129
+ b.seek(0)
130
+ header = b.read(8)
131
+ n = int.from_bytes(header, "little")
132
+
133
+ offset = n + 8
134
+ b.seek(offset)
135
+ for chunk in iter(lambda: b.read(blksize), b""):
136
+ hash_sha256.update(chunk)
137
+
138
+ return hash_sha256.hexdigest()
139
+
140
+
141
+ def precalculate_safetensors_hashes(tensors, metadata):
142
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
143
+ save time on indexing the model later."""
144
+
145
+ # Because writing user metadata to the file can change the result of
146
+ # sd_models.model_hash(), only retain the training metadata for purposes of
147
+ # calculating the hash, as they are meant to be immutable
148
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
149
+
150
+ bytes = safetensors.torch.save(tensors, metadata)
151
+ b = BytesIO(bytes)
152
+
153
+ model_hash = addnet_hash_safetensors(b)
154
+ legacy_hash = addnet_hash_legacy(b)
155
+ return model_hash, legacy_hash
156
+
157
+
158
+ class LoRANetwork(torch.nn.Module):
159
+ TRANSFORMER_TARGET_REPLACE_MODULE = ["CogVideoXTransformer3DModel", "WanTransformer3DModel"]
160
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder", "T5SelfAttention", "T5CrossAttention"]
161
+ LORA_PREFIX_TRANSFORMER = "lora_unet"
162
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
163
+ def __init__(
164
+ self,
165
+ text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
166
+ unet,
167
+ multiplier: float = 1.0,
168
+ lora_dim: int = 4,
169
+ alpha: float = 1,
170
+ dropout: Optional[float] = None,
171
+ module_class: Type[object] = LoRAModule,
172
+ skip_name: str = None,
173
+ varbose: Optional[bool] = False,
174
+ ) -> None:
175
+ super().__init__()
176
+ self.multiplier = multiplier
177
+
178
+ self.lora_dim = lora_dim
179
+ self.alpha = alpha
180
+ self.dropout = dropout
181
+
182
+ print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
183
+ print(f"neuron dropout: p={self.dropout}")
184
+
185
+ # create module instances
186
+ def create_modules(
187
+ is_unet: bool,
188
+ root_module: torch.nn.Module,
189
+ target_replace_modules: List[torch.nn.Module],
190
+ ) -> List[LoRAModule]:
191
+ prefix = (
192
+ self.LORA_PREFIX_TRANSFORMER
193
+ if is_unet
194
+ else self.LORA_PREFIX_TEXT_ENCODER
195
+ )
196
+ loras = []
197
+ skipped = []
198
+ for name, module in root_module.named_modules():
199
+ if module.__class__.__name__ in target_replace_modules:
200
+ for child_name, child_module in module.named_modules():
201
+ is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
202
+ is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
203
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
204
+
205
+ if skip_name is not None and skip_name in child_name:
206
+ continue
207
+
208
+ if is_linear or is_conv2d:
209
+ lora_name = prefix + "." + name + "." + child_name
210
+ lora_name = lora_name.replace(".", "_")
211
+
212
+ dim = None
213
+ alpha = None
214
+
215
+ if is_linear or is_conv2d_1x1:
216
+ dim = self.lora_dim
217
+ alpha = self.alpha
218
+
219
+ if dim is None or dim == 0:
220
+ if is_linear or is_conv2d_1x1:
221
+ skipped.append(lora_name)
222
+ continue
223
+
224
+ lora = module_class(
225
+ lora_name,
226
+ child_module,
227
+ self.multiplier,
228
+ dim,
229
+ alpha,
230
+ dropout=dropout,
231
+ )
232
+ loras.append(lora)
233
+ return loras, skipped
234
+
235
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
236
+
237
+ self.text_encoder_loras = []
238
+ skipped_te = []
239
+ for i, text_encoder in enumerate(text_encoders):
240
+ if text_encoder is not None:
241
+ text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
242
+ self.text_encoder_loras.extend(text_encoder_loras)
243
+ skipped_te += skipped
244
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
245
+
246
+ self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
247
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
248
+
249
+ # assertion
250
+ names = set()
251
+ for lora in self.text_encoder_loras + self.unet_loras:
252
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
253
+ names.add(lora.lora_name)
254
+
255
+ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
256
+ if apply_text_encoder:
257
+ print("enable LoRA for text encoder")
258
+ else:
259
+ self.text_encoder_loras = []
260
+
261
+ if apply_unet:
262
+ print("enable LoRA for U-Net")
263
+ else:
264
+ self.unet_loras = []
265
+
266
+ for lora in self.text_encoder_loras + self.unet_loras:
267
+ lora.apply_to()
268
+ self.add_module(lora.lora_name, lora)
269
+
270
+ def set_multiplier(self, multiplier):
271
+ self.multiplier = multiplier
272
+ for lora in self.text_encoder_loras + self.unet_loras:
273
+ lora.multiplier = self.multiplier
274
+
275
+ def load_weights(self, file):
276
+ if os.path.splitext(file)[1] == ".safetensors":
277
+ from safetensors.torch import load_file
278
+
279
+ weights_sd = load_file(file)
280
+ else:
281
+ weights_sd = torch.load(file, map_location="cpu")
282
+ info = self.load_state_dict(weights_sd, False)
283
+ return info
284
+
285
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
286
+ self.requires_grad_(True)
287
+ all_params = []
288
+
289
+ def enumerate_params(loras):
290
+ params = []
291
+ for lora in loras:
292
+ params.extend(lora.parameters())
293
+ return params
294
+
295
+ if self.text_encoder_loras:
296
+ param_data = {"params": enumerate_params(self.text_encoder_loras)}
297
+ if text_encoder_lr is not None:
298
+ param_data["lr"] = text_encoder_lr
299
+ all_params.append(param_data)
300
+
301
+ if self.unet_loras:
302
+ param_data = {"params": enumerate_params(self.unet_loras)}
303
+ if unet_lr is not None:
304
+ param_data["lr"] = unet_lr
305
+ all_params.append(param_data)
306
+
307
+ return all_params
308
+
309
+ def enable_gradient_checkpointing(self):
310
+ pass
311
+
312
+ def get_trainable_params(self):
313
+ return self.parameters()
314
+
315
+ def save_weights(self, file, dtype, metadata):
316
+ if metadata is not None and len(metadata) == 0:
317
+ metadata = None
318
+
319
+ state_dict = self.state_dict()
320
+
321
+ if dtype is not None:
322
+ for key in list(state_dict.keys()):
323
+ v = state_dict[key]
324
+ v = v.detach().clone().to("cpu").to(dtype)
325
+ state_dict[key] = v
326
+
327
+ if os.path.splitext(file)[1] == ".safetensors":
328
+ from safetensors.torch import save_file
329
+
330
+ # Precalculate model hashes to save time on indexing
331
+ if metadata is None:
332
+ metadata = {}
333
+ model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
334
+ metadata["sshs_model_hash"] = model_hash
335
+ metadata["sshs_legacy_hash"] = legacy_hash
336
+
337
+ save_file(state_dict, file, metadata)
338
+ else:
339
+ torch.save(state_dict, file)
340
+
341
+ def create_network(
342
+ multiplier: float,
343
+ network_dim: Optional[int],
344
+ network_alpha: Optional[float],
345
+ text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
346
+ transformer,
347
+ neuron_dropout: Optional[float] = None,
348
+ skip_name: str = None,
349
+ **kwargs,
350
+ ):
351
+ if network_dim is None:
352
+ network_dim = 4 # default
353
+ if network_alpha is None:
354
+ network_alpha = 1.0
355
+
356
+ network = LoRANetwork(
357
+ text_encoder,
358
+ transformer,
359
+ multiplier=multiplier,
360
+ lora_dim=network_dim,
361
+ alpha=network_alpha,
362
+ dropout=neuron_dropout,
363
+ skip_name=skip_name,
364
+ varbose=True,
365
+ )
366
+ return network
367
+
368
+ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
369
+ LORA_PREFIX_TRANSFORMER = "lora_unet"
370
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
371
+ if state_dict is None:
372
+ state_dict = load_file(lora_path, device=device)
373
+ else:
374
+ state_dict = state_dict
375
+ updates = defaultdict(dict)
376
+ for key, value in state_dict.items():
377
+ layer, elem = key.split('.', 1)
378
+ updates[layer][elem] = value
379
+
380
+ sequential_cpu_offload_flag = False
381
+ if pipeline.transformer.device == torch.device(type="meta"):
382
+ pipeline.remove_all_hooks()
383
+ sequential_cpu_offload_flag = True
384
+ offload_device = pipeline._offload_device
385
+
386
+ for layer, elems in updates.items():
387
+
388
+ if "lora_te" in layer:
389
+ if transformer_only:
390
+ continue
391
+ else:
392
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
393
+ curr_layer = pipeline.text_encoder
394
+ else:
395
+ layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
396
+ curr_layer = pipeline.transformer
397
+
398
+ try:
399
+ curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
400
+ except Exception:
401
+ temp_name = layer_infos.pop(0)
402
+ while len(layer_infos) > -1:
403
+ try:
404
+ curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos))
405
+ break
406
+ except Exception:
407
+ try:
408
+ curr_layer = curr_layer.__getattr__(temp_name)
409
+ if len(layer_infos) > 0:
410
+ temp_name = layer_infos.pop(0)
411
+ elif len(layer_infos) == 0:
412
+ break
413
+ except Exception:
414
+ if len(layer_infos) == 0:
415
+ print('Error loading layer')
416
+ if len(temp_name) > 0:
417
+ temp_name += "_" + layer_infos.pop(0)
418
+ else:
419
+ temp_name = layer_infos.pop(0)
420
+
421
+ origin_dtype = curr_layer.weight.data.dtype
422
+ origin_device = curr_layer.weight.data.device
423
+
424
+ curr_layer = curr_layer.to(device, dtype)
425
+ weight_up = elems['lora_up.weight'].to(device, dtype)
426
+ weight_down = elems['lora_down.weight'].to(device, dtype)
427
+
428
+ if 'alpha' in elems.keys():
429
+ alpha = elems['alpha'].item() / weight_up.shape[1]
430
+ else:
431
+ alpha = 1.0
432
+
433
+ if len(weight_up.shape) == 4:
434
+ curr_layer.weight.data += multiplier * alpha * torch.mm(
435
+ weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
436
+ ).unsqueeze(2).unsqueeze(3)
437
+ else:
438
+ curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
439
+ curr_layer = curr_layer.to(origin_device, origin_dtype)
440
+
441
+ if sequential_cpu_offload_flag:
442
+ pipeline.enable_sequential_cpu_offload(device=offload_device)
443
+ return pipeline
444
+
445
+ # TODO: Refactor with merge_lora.
446
+ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
447
+ """Unmerge state_dict in LoRANetwork from the pipeline in diffusers."""
448
+ LORA_PREFIX_UNET = "lora_unet"
449
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
450
+ state_dict = load_file(lora_path, device=device)
451
+
452
+ updates = defaultdict(dict)
453
+ for key, value in state_dict.items():
454
+ layer, elem = key.split('.', 1)
455
+ updates[layer][elem] = value
456
+
457
+ sequential_cpu_offload_flag = False
458
+ if pipeline.transformer.device == torch.device(type="meta"):
459
+ pipeline.remove_all_hooks()
460
+ sequential_cpu_offload_flag = True
461
+
462
+ for layer, elems in updates.items():
463
+
464
+ if "lora_te" in layer:
465
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
466
+ curr_layer = pipeline.text_encoder
467
+ else:
468
+ layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
469
+ curr_layer = pipeline.transformer
470
+
471
+ try:
472
+ curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
473
+ except Exception:
474
+ temp_name = layer_infos.pop(0)
475
+ while len(layer_infos) > -1:
476
+ try:
477
+ curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos))
478
+ break
479
+ except Exception:
480
+ try:
481
+ curr_layer = curr_layer.__getattr__(temp_name)
482
+ if len(layer_infos) > 0:
483
+ temp_name = layer_infos.pop(0)
484
+ elif len(layer_infos) == 0:
485
+ break
486
+ except Exception:
487
+ if len(layer_infos) == 0:
488
+ print('Error loading layer')
489
+ if len(temp_name) > 0:
490
+ temp_name += "_" + layer_infos.pop(0)
491
+ else:
492
+ temp_name = layer_infos.pop(0)
493
+
494
+ origin_dtype = curr_layer.weight.data.dtype
495
+ origin_device = curr_layer.weight.data.device
496
+
497
+ curr_layer = curr_layer.to(device, dtype)
498
+ weight_up = elems['lora_up.weight'].to(device, dtype)
499
+ weight_down = elems['lora_down.weight'].to(device, dtype)
500
+
501
+ if 'alpha' in elems.keys():
502
+ alpha = elems['alpha'].item() / weight_up.shape[1]
503
+ else:
504
+ alpha = 1.0
505
+
506
+ if len(weight_up.shape) == 4:
507
+ curr_layer.weight.data -= multiplier * alpha * torch.mm(
508
+ weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
509
+ ).unsqueeze(2).unsqueeze(3)
510
+ else:
511
+ curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
512
+ curr_layer = curr_layer.to(origin_device, origin_dtype)
513
+
514
+ if sequential_cpu_offload_flag:
515
+ pipeline.enable_sequential_cpu_offload(device=device)
516
+ return pipeline
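Taken together, the module supports two workflows: wrapping the base models with trainable LoRA layers, and baking a saved `.safetensors` LoRA into a loaded pipeline (with `unmerge_lora` reversing the merge). A hedged sketch, with the model objects and LoRA path assumed to come from elsewhere:

```python
# Hypothetical LoRA workflow sketch (model objects and paths are assumptions).
import torch

def attach_lora_for_training(text_encoder, transformer):
    # Rank-32 LoRA on the transformer only; hyperparameters are illustrative.
    network = create_network(
        multiplier=1.0, network_dim=32, network_alpha=16.0,
        text_encoder=text_encoder, transformer=transformer,
    )
    network.apply_to(text_encoder, transformer, apply_text_encoder=False, apply_unet=True)
    params = network.prepare_optimizer_params(text_encoder_lr=None, unet_lr=1e-4, default_lr=1e-4)
    return network, torch.optim.AdamW(params)

def apply_lora_for_inference(pipeline, lora_path, strength=0.8):
    # Bake the LoRA delta into the pipeline weights; unmerge_lora reverses it.
    return merge_lora(pipeline, lora_path, multiplier=strength, dtype=torch.float32)
```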
rose/utils/utils.py ADDED
@@ -0,0 +1,318 @@
1
+ import os
2
+ import gc
3
+ import imageio
4
+ import inspect
5
+ import numpy as np
6
+ import torch
7
+ import torchvision
8
+ import cv2
9
+ from einops import rearrange
10
+ from PIL import Image
11
+
12
+ def filter_kwargs(cls, kwargs):
13
+ sig = inspect.signature(cls.__init__)
14
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
15
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
16
+ return filtered_kwargs
17
+
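`filter_kwargs` lets a loose config dict be splatted into a constructor without tripping on unknown keys. A small sketch with a hypothetical class:

```python
# Sketch: drop keys the constructor does not accept (Dummy is hypothetical).
class Dummy:
    def __init__(self, a, b=1):
        self.a, self.b = a, b

cfg = {"a": 10, "b": 2, "unused": "dropped"}
obj = Dummy(**filter_kwargs(Dummy, cfg))  # "unused" is filtered out
print(obj.a, obj.b)
```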
18
+ def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
19
+ target_pixels = int(base_resolution) * int(base_resolution)
20
+ original_width, original_height = Image.open(image).size
21
+ ratio = (target_pixels / (original_width * original_height)) ** 0.5
22
+ width_slider = round(original_width * ratio)
23
+ height_slider = round(original_height * ratio)
24
+ return height_slider, width_slider
25
+
26
+ def color_transfer(sc, dc):
27
+ """
28
+ Transfer the color distribution of sc to match the reference dc.
29
+
30
+ Args:
31
+ sc (numpy.ndarray): input image to be transferred.
32
+ dc (numpy.ndarray): reference image
33
+
34
+ Returns:
35
+ numpy.ndarray: sc with its color distribution matched to dc.
36
+ """
37
+
38
+ def get_mean_and_std(img):
39
+ x_mean, x_std = cv2.meanStdDev(img)
40
+ x_mean = np.hstack(np.around(x_mean, 2))
41
+ x_std = np.hstack(np.around(x_std, 2))
42
+ return x_mean, x_std
43
+
44
+ sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB)
45
+ s_mean, s_std = get_mean_and_std(sc)
46
+ dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB)
47
+ t_mean, t_std = get_mean_and_std(dc)
48
+ img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean
49
+ np.putmask(img_n, img_n > 255, 255)
50
+ np.putmask(img_n, img_n < 0, 0)
51
+ dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB)
52
+ return dst
53
+
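`color_transfer` matches the per-channel LAB mean and standard deviation of the source to the reference, which `save_videos_grid` below optionally uses to stabilize colors across frames. A quick sketch on random frames:

```python
# Sketch: recolor a random frame to match a reference frame's LAB statistics.
import numpy as np

src = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)  # frame to adjust (RGB)
ref = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)  # reference frame (RGB)
matched = color_transfer(src, ref)
print(matched.shape, matched.dtype)  # (64, 64, 3) uint8
```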
54
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False):
55
+ videos = rearrange(videos, "b c t h w -> t b c h w")
56
+ outputs = []
57
+ for x in videos:
58
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
59
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
60
+ if rescale:
61
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
62
+ x = (x * 255).numpy().astype(np.uint8)
63
+ outputs.append(Image.fromarray(x))
64
+
65
+ if color_transfer_post_process:
66
+ for i in range(1, len(outputs)):
67
+ outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0])))
68
+
69
+ os.makedirs(os.path.dirname(path), exist_ok=True)
70
+ if imageio_backend:
71
+ if path.endswith("mp4"):
72
+ imageio.mimsave(path, outputs, fps=fps)
73
+ else:
74
+ imageio.mimsave(path, outputs, duration=(1000 * 1/fps))
75
+ else:
76
+ if path.endswith("mp4"):
77
+ path = path.replace('.mp4', '.gif')
78
+ outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0)
79
+
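A smoke test for `save_videos_grid`, assuming `imageio` has an ffmpeg backend available for mp4 output; the output path under `result/` is only an example.

```python
# Sketch: write a random 16-frame clip as an mp4 grid (values already in [0, 1]).
import torch

fake_videos = torch.rand(2, 3, 16, 64, 64)  # (B, C, T, H, W)
save_videos_grid(fake_videos, "result/demo_grid.mp4", rescale=False, n_rows=2, fps=8)
```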
80
+ def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
81
+ if validation_image_start is not None and validation_image_end is not None:
82
+ if type(validation_image_start) is str and os.path.isfile(validation_image_start):
83
+ image_start = clip_image = Image.open(validation_image_start).convert("RGB")
84
+ image_start = image_start.resize([sample_size[1], sample_size[0]])
85
+ clip_image = clip_image.resize([sample_size[1], sample_size[0]])
86
+ else:
87
+ image_start = clip_image = validation_image_start
88
+ image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
89
+ clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
90
+
91
+ if type(validation_image_end) is str and os.path.isfile(validation_image_end):
92
+ image_end = Image.open(validation_image_end).convert("RGB")
93
+ image_end = image_end.resize([sample_size[1], sample_size[0]])
94
+ else:
95
+ image_end = validation_image_end
96
+ image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end]
97
+
98
+ if type(image_start) is list:
99
+ clip_image = clip_image[0]
100
+ start_video = torch.cat(
101
+ [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
102
+ dim=2
103
+ )
104
+ input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
105
+ input_video[:, :, :len(image_start)] = start_video
106
+
107
+ input_video_mask = torch.zeros_like(input_video[:, :1])
108
+ input_video_mask[:, :, len(image_start):] = 255
109
+ else:
110
+ input_video = torch.tile(
111
+ torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
112
+ [1, 1, video_length, 1, 1]
113
+ )
114
+ input_video_mask = torch.zeros_like(input_video[:, :1])
115
+ input_video_mask[:, :, 1:] = 255
116
+
117
+ if type(image_end) is list:
118
+ image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end]
119
+ end_video = torch.cat(
120
+ [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end],
121
+ dim=2
122
+ )
123
+ input_video[:, :, -len(image_end):] = end_video  # end_video holds len(image_end) frames; match the mask slice below
124
+
125
+ input_video_mask[:, :, -len(image_end):] = 0
126
+ else:
127
+ image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size)
128
+ input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
129
+ input_video_mask[:, :, -1:] = 0
130
+
131
+ input_video = input_video / 255
132
+
133
+ elif validation_image_start is not None:
134
+ if type(validation_image_start) is str and os.path.isfile(validation_image_start):
135
+ image_start = clip_image = Image.open(validation_image_start).convert("RGB")
136
+ image_start = image_start.resize([sample_size[1], sample_size[0]])
137
+ clip_image = clip_image.resize([sample_size[1], sample_size[0]])
138
+ else:
139
+ image_start = clip_image = validation_image_start
140
+ image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
141
+ clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
142
+ image_end = None
143
+
144
+ if type(image_start) is list:
145
+ clip_image = clip_image[0]
146
+ start_video = torch.cat(
147
+ [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
148
+ dim=2
149
+ )
150
+ input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
151
+ input_video[:, :, :len(image_start)] = start_video
152
+ input_video = input_video / 255
153
+
154
+ input_video_mask = torch.zeros_like(input_video[:, :1])
155
+ input_video_mask[:, :, len(image_start):] = 255
156
+ else:
157
+ input_video = torch.tile(
158
+ torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
159
+ [1, 1, video_length, 1, 1]
160
+ ) / 255
161
+ input_video_mask = torch.zeros_like(input_video[:, :1])
162
+ input_video_mask[:, :, 1:] = 255
163
+ else:
164
+ image_start = None
165
+ image_end = None
166
+ input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
167
+ input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
168
+ clip_image = None
169
+
170
+ del image_start
171
+ del image_end
172
+ gc.collect()
173
+
174
+ return input_video, input_video_mask, clip_image
175
+
176
+ def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None):
177
+ if input_video_path is not None:
178
+ if isinstance(input_video_path, str):
179
+ cap = cv2.VideoCapture(input_video_path)
180
+ input_video = []
181
+
182
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
183
+ frame_skip = 1 if fps is None else max(1, int(original_fps // fps))  # guard against a zero skip when the source fps is lower than the requested fps
184
+
185
+ frame_count = 0
186
+
187
+ while True:
188
+ ret, frame = cap.read()
189
+ if not ret:
190
+ break
191
+
192
+ if frame_count % frame_skip == 0:
193
+ frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
194
+ input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
195
+
196
+ frame_count += 1
197
+
198
+ cap.release()
199
+ else:
200
+ input_video = input_video_path
201
+
202
+ input_video = torch.from_numpy(np.array(input_video))[:video_length]
203
+ input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
204
+
205
+ if validation_video_mask is not None:
206
+ validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
207
+ input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
208
+
209
+ input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
210
+ input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
211
+ input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
212
+ else:
213
+ input_video_mask = torch.zeros_like(input_video[:, :1])
214
+ input_video_mask[:, :, :] = 255
215
+ else:
216
+ input_video, input_video_mask = None, None
217
+
218
+ if ref_image is not None:
219
+ if isinstance(ref_image, str):
220
+ clip_image = Image.open(ref_image).convert("RGB")
221
+ else:
222
+ clip_image = Image.fromarray(np.array(ref_image, np.uint8))
223
+ else:
224
+ clip_image = None
225
+
226
+ if ref_image is not None:
227
+ if isinstance(ref_image, str):
228
+ ref_image = Image.open(ref_image).convert("RGB")
229
+ ref_image = ref_image.resize((sample_size[1], sample_size[0]))
230
+ ref_image = torch.from_numpy(np.array(ref_image))
231
+ ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
232
+ else:
233
+ ref_image = torch.from_numpy(np.array(ref_image))
234
+ ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
235
+ return input_video, input_video_mask, ref_image, clip_image
236
+
237
+
238
+ def get_video_and_mask(input_video_path, video_length, sample_size, fps=None, input_mask_path=None, ref_image=None):
239
+ if input_video_path is not None:
240
+ if isinstance(input_video_path, str):
241
+ cap = cv2.VideoCapture(input_video_path)
242
+ input_video = []
243
+
244
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
245
+ frame_skip = 1 if fps is None else max(1, int(original_fps // fps))  # guard against a zero skip when the source fps is lower than the requested fps
246
+
247
+ frame_count = 0
248
+
249
+ while True:
250
+ ret, frame = cap.read()
251
+ if not ret:
252
+ break
253
+
254
+ if frame_count % frame_skip == 0:
255
+ frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
256
+ input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
257
+
258
+ frame_count += 1
259
+
260
+ cap.release()
261
+ else:
262
+ input_video = input_video_path
263
+
264
+ input_video = torch.from_numpy(np.array(input_video))[:video_length]
265
+ input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255.0
266
+
267
+ else:
268
+ input_video = None
269
+
270
+ if input_mask_path is not None:
271
+ if isinstance(input_mask_path, str):
272
+ cap = cv2.VideoCapture(input_mask_path)
273
+ mask_frames = []
274
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
275
+ frame_skip = 1 if fps is None else max(1, int(original_fps // fps))  # guard against a zero skip when the source fps is lower than the requested fps
276
+ frame_count = 0
277
+
278
+ while True:
279
+ ret, frame = cap.read()
280
+ if not ret:
281
+ break
282
+ if frame_count % frame_skip == 0:
283
+ frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
284
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
285
+ mask_frames.append(gray)
286
+ frame_count += 1
287
+ cap.release()
288
+ else:
289
+ mask_frames = input_mask_path
290
+
291
+ mask_np = np.array(mask_frames)[:video_length] # (F, H, W), uint8
292
+ mask_bin = np.where(mask_np < 240, 0, 1).astype(np.uint8) # (F,H,W)
293
+ mask_tensor = torch.from_numpy(mask_bin)
294
+ mask_tensor = mask_tensor.unsqueeze(1)
295
+ mask_tensor = mask_tensor.unsqueeze(0)
296
+ input_mask = mask_tensor.permute(0,2,1,3,4)
297
+ input_mask = input_mask.float()
298
+ else:
299
+ input_mask = None
300
+
301
+ if ref_image is not None:
302
+ if isinstance(ref_image, str):
303
+ clip_image = Image.open(ref_image).convert("RGB")
304
+ else:
305
+ clip_image = Image.fromarray(np.array(ref_image, np.uint8))
306
+ else:
307
+ clip_image = None
308
+
309
+ if ref_image is not None:
310
+ if isinstance(ref_image, str):
311
+ ref_image = Image.open(ref_image).convert("RGB")
312
+ ref_image = ref_image.resize((sample_size[1], sample_size[0]))
313
+ ref_image = torch.from_numpy(np.array(ref_image))
314
+ ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
315
+ else:
316
+ ref_image = torch.from_numpy(np.array(ref_image))
317
+ ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
318
+ return input_video, input_mask, ref_image, clip_image
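
For orientation, a minimal usage sketch of the two helpers above. The module import path, output location, and sizes are assumptions for illustration, not part of this commit:

```python
from rose.utils.utils import get_video_and_mask, save_videos_grid  # assumed module path

# Decode up to 49 frames at 832x480; sample_size is passed as [height, width].
video, mask, ref_image, clip_image = get_video_and_mask(
    "test_sample/test-sample0.mp4",
    video_length=49,
    sample_size=[480, 832],
    fps=12,
    input_mask_path=None,   # no mask video: input_mask comes back as None
)
# video: (1, 3, F, H, W) float tensor with values in [0, 1]
print(video.shape)

# save_videos_grid expects (B, C, T, H, W) in [0, 1] when rescale=False.
save_videos_grid(video, "result/preview.mp4", fps=12)
```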
test_sample/test-sample0.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d63abe5ce28c76c83b0f575e7b8cd2707ffb58e3e478b7f500865700d4738a2
3
+ size 476512
test_sample/test-sample1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54369037f2fb7c42ada4165618894cedcd87d8368d40547571b43e5fd4ff0025
3
+ size 975899
test_sample/test-sample2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11e3e5b0aea5881bf8c5edf8a8196b51bef11fac019ee96192cc9b24146ae07b
3
+ size 286526
test_sample/test-sample3.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29f38c57585dad9946aaa32f7ef9faae10bcb87913708462114da5fb164d4775
3
+ size 146705
test_sample/test-sample4.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0a42fa9abcfa72baee2a437b06b4e1981b5109c815795bf005d4f7d5cd47096
3
+ size 1465415
tools/__init__.py ADDED
File without changes
tools/base_segmenter.py ADDED
@@ -0,0 +1,129 @@
1
+ import time
2
+ import torch
3
+ import cv2
4
+ from PIL import Image, ImageDraw, ImageOps
5
+ import numpy as np
6
+ from typing import Union
7
+ from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
+ import matplotlib.pyplot as plt
9
+ import PIL
10
+ from .mask_painter import mask_painter
11
+
12
+
13
+ class BaseSegmenter:
14
+ def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
15
+ """
16
+ device: model device
17
+ SAM_checkpoint: path of SAM checkpoint
18
+ model_type: vit_b, vit_l, vit_h
19
+ """
20
+ print(f"Initializing BaseSegmenter to {device}")
21
+ assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h'
22
+
23
+ self.device = device
24
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
25
+ self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
26
+ self.model.to(device=self.device)
27
+ self.predictor = SamPredictor(self.model)
28
+ self.embedded = False
29
+
30
+ @torch.no_grad()
31
+ def set_image(self, image: np.ndarray):
32
+ # PIL.open(image_path) 3channel: RGB
33
+ # image embedding: avoid encoding the same image multiple times
34
+ self.orignal_image = image
35
+ if self.embedded:
36
+ print('Image already embedded; call reset_image() before embedding a new one.')
37
+ return
38
+ self.predictor.set_image(image)
39
+ self.embedded = True
40
+ return
41
+
42
+ @torch.no_grad()
43
+ def reset_image(self):
44
+ # reset the image embedding
45
+ self.predictor.reset_image()
46
+ self.embedded = False
47
+
48
+ def predict(self, prompts, mode, multimask=True):
49
+ """
50
+ image: numpy array, h, w, 3
51
+ prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
52
+ prompts['point_coords']: numpy array [N,2]
53
+ prompts['point_labels']: numpy array [1,N]
54
+ prompts['mask_input']: numpy array [1,256,256]
55
+ mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
56
+ multimask: True (return 3 masks), False (return 1 mask only)
57
+ when multimask=True, mask_input=logits[np.argmax(scores), :, :][None, :, :]
58
+ """
59
+ assert self.embedded, 'prediction is called before set_image (feature embedding).'
60
+ assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both'
61
+
62
+ if mode == 'point':
63
+ masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
64
+ point_labels=prompts['point_labels'],
65
+ multimask_output=multimask)
66
+ elif mode == 'mask':
67
+ masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'],
68
+ multimask_output=multimask)
69
+ elif mode == 'both': # both
70
+ masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
71
+ point_labels=prompts['point_labels'],
72
+ mask_input=prompts['mask_input'],
73
+ multimask_output=multimask)
74
+ else:
75
+ raise("Not implement now!")
76
+ # masks (n, h, w), scores (n,), logits (n, 256, 256)
77
+ return masks, scores, logits
78
+
79
+
80
+ if __name__ == "__main__":
81
+ # load and show an image
82
+ image = cv2.imread('/hhd3/gaoshang/truck.jpg')
83
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3)
84
+
85
+ # initialise BaseSegmenter
86
+ SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
87
+ model_type = 'vit_h'
88
+ device = "cuda:4"
89
+ base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)
90
+
91
+ # image embedding (once embedded, multiple prompts can be applied)
92
+ base_segmenter.set_image(image)
93
+
94
+ # examples
95
+ # point only ------------------------
96
+ mode = 'point'
97
+ prompts = {
98
+ 'point_coords': np.array([[500, 375], [1125, 625]]),
99
+ 'point_labels': np.array([1, 1]),
100
+ }
101
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256)
102
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
103
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
104
+ cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
105
+
106
+ # both ------------------------
107
+ mode = 'both'
108
+ mask_input = logits[np.argmax(scores), :, :]
109
+ prompts = {'mask_input': mask_input[None, :, :]}
110
+ prompts = {
111
+ 'point_coords': np.array([[500, 375], [1125, 625]]),
112
+ 'point_labels': np.array([1, 0]),
113
+ 'mask_input': mask_input[None, :, :]
114
+ }
115
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
116
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
117
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
118
+ cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
119
+
120
+ # mask only ------------------------
121
+ mode = 'mask'
122
+ mask_input = logits[np.argmax(scores), :, :]
123
+
124
+ prompts = {'mask_input': mask_input[None, :, :]}
125
+
126
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
127
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
128
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
129
+ cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
tools/interact_tools.py ADDED
@@ -0,0 +1,99 @@
1
+ import time
2
+ import torch
3
+ import cv2
4
+ from PIL import Image, ImageDraw, ImageOps
5
+ import numpy as np
6
+ from typing import Union
7
+ from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
+ import matplotlib.pyplot as plt
9
+ import PIL
10
+ from .mask_painter import mask_painter as mask_painter2
11
+ from .base_segmenter import BaseSegmenter
12
+ from .painter import mask_painter, point_painter
13
+ import os
14
+ import requests
15
+ import sys
16
+
17
+
18
+ mask_color = 3
19
+ mask_alpha = 0.7
20
+ contour_color = 1
21
+ contour_width = 5
22
+ point_color_ne = 8
23
+ point_color_ps = 50
24
+ point_alpha = 0.9
25
+ point_radius = 15
26
+ contour_color = 2
27
+ contour_width = 5
28
+
29
+
30
+ class SamControler():
31
+ def __init__(self, SAM_checkpoint, model_type, device):
32
+ '''
33
+ initialize the SAM controller
34
+ '''
35
+ self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
36
+
37
+
38
+ # def seg_again(self, image: np.ndarray):
39
+ # '''
40
+ # it is used when interact in video
41
+ # '''
42
+ # self.sam_controler.reset_image()
43
+ # self.sam_controler.set_image(image)
44
+ # return
45
+
46
+
47
+ def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3):
48
+ '''
49
+ used for interactive clicks on the first frame of a video
50
+ return: mask, logit, painted image(mask+point)
51
+ '''
52
+ # self.sam_controler.set_image(image)
53
+ origal_image = self.sam_controler.orignal_image
54
+ neg_flag = labels[-1]
55
+ if neg_flag==1:
56
+ # last click is positive: predict from points, then refine with the best mask logits
57
+ prompts = {
58
+ 'point_coords': points,
59
+ 'point_labels': labels,
60
+ }
61
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
62
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
63
+ prompts = {
64
+ 'point_coords': points,
65
+ 'point_labels': labels,
66
+ 'mask_input': logit[None, :, :]
67
+ }
68
+ masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
69
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
70
+ else:
71
+ # last click is negative: a single point-only prediction is enough
72
+ prompts = {
73
+ 'point_coords': points,
74
+ 'point_labels': labels,
75
+ }
76
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
77
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
78
+
79
+
80
+ assert len(points)==len(labels)
81
+
82
+ painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
83
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
84
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
85
+ painted_image = Image.fromarray(painted_image)
86
+
87
+ return mask, logit, painted_image
88
+
89
+
90
+
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+
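
A hypothetical sketch of driving `SamControler` for a single interactive click; the checkpoint path, device, frame path, and click coordinates below are placeholders, not values from this commit:

```python
import numpy as np
from PIL import Image
from tools.interact_tools import SamControler

controler = SamControler(
    SAM_checkpoint="checkpoints/sam_vit_h_4b8939.pth",  # assumed local checkpoint path
    model_type="vit_h",
    device="cuda:0",
)

frame = np.array(Image.open("images/frame0.jpg").convert("RGB"))  # placeholder frame
controler.sam_controler.set_image(frame)  # embed the frame once before clicking

points = np.array([[320, 240]])  # one positive click, (x, y)
labels = np.array([1])
mask, logit, painted = controler.first_frame_click(frame, points, labels, multimask=True)
# mask: (H, W) boolean mask, logit: (256, 256) low-res logits, painted: PIL.Image overlay
```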
tools/mask_painter.py ADDED
@@ -0,0 +1,288 @@
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import copy
6
+ import time
7
+
8
+
9
+ def colormap(rgb=True):
10
+ color_list = np.array(
11
+ [
12
+ 0.000, 0.000, 0.000,
13
+ 1.000, 1.000, 1.000,
14
+ 1.000, 0.498, 0.313,
15
+ 0.392, 0.581, 0.929,
16
+ 0.000, 0.447, 0.741,
17
+ 0.850, 0.325, 0.098,
18
+ 0.929, 0.694, 0.125,
19
+ 0.494, 0.184, 0.556,
20
+ 0.466, 0.674, 0.188,
21
+ 0.301, 0.745, 0.933,
22
+ 0.635, 0.078, 0.184,
23
+ 0.300, 0.300, 0.300,
24
+ 0.600, 0.600, 0.600,
25
+ 1.000, 0.000, 0.000,
26
+ 1.000, 0.500, 0.000,
27
+ 0.749, 0.749, 0.000,
28
+ 0.000, 1.000, 0.000,
29
+ 0.000, 0.000, 1.000,
30
+ 0.667, 0.000, 1.000,
31
+ 0.333, 0.333, 0.000,
32
+ 0.333, 0.667, 0.000,
33
+ 0.333, 1.000, 0.000,
34
+ 0.667, 0.333, 0.000,
35
+ 0.667, 0.667, 0.000,
36
+ 0.667, 1.000, 0.000,
37
+ 1.000, 0.333, 0.000,
38
+ 1.000, 0.667, 0.000,
39
+ 1.000, 1.000, 0.000,
40
+ 0.000, 0.333, 0.500,
41
+ 0.000, 0.667, 0.500,
42
+ 0.000, 1.000, 0.500,
43
+ 0.333, 0.000, 0.500,
44
+ 0.333, 0.333, 0.500,
45
+ 0.333, 0.667, 0.500,
46
+ 0.333, 1.000, 0.500,
47
+ 0.667, 0.000, 0.500,
48
+ 0.667, 0.333, 0.500,
49
+ 0.667, 0.667, 0.500,
50
+ 0.667, 1.000, 0.500,
51
+ 1.000, 0.000, 0.500,
52
+ 1.000, 0.333, 0.500,
53
+ 1.000, 0.667, 0.500,
54
+ 1.000, 1.000, 0.500,
55
+ 0.000, 0.333, 1.000,
56
+ 0.000, 0.667, 1.000,
57
+ 0.000, 1.000, 1.000,
58
+ 0.333, 0.000, 1.000,
59
+ 0.333, 0.333, 1.000,
60
+ 0.333, 0.667, 1.000,
61
+ 0.333, 1.000, 1.000,
62
+ 0.667, 0.000, 1.000,
63
+ 0.667, 0.333, 1.000,
64
+ 0.667, 0.667, 1.000,
65
+ 0.667, 1.000, 1.000,
66
+ 1.000, 0.000, 1.000,
67
+ 1.000, 0.333, 1.000,
68
+ 1.000, 0.667, 1.000,
69
+ 0.167, 0.000, 0.000,
70
+ 0.333, 0.000, 0.000,
71
+ 0.500, 0.000, 0.000,
72
+ 0.667, 0.000, 0.000,
73
+ 0.833, 0.000, 0.000,
74
+ 1.000, 0.000, 0.000,
75
+ 0.000, 0.167, 0.000,
76
+ 0.000, 0.333, 0.000,
77
+ 0.000, 0.500, 0.000,
78
+ 0.000, 0.667, 0.000,
79
+ 0.000, 0.833, 0.000,
80
+ 0.000, 1.000, 0.000,
81
+ 0.000, 0.000, 0.167,
82
+ 0.000, 0.000, 0.333,
83
+ 0.000, 0.000, 0.500,
84
+ 0.000, 0.000, 0.667,
85
+ 0.000, 0.000, 0.833,
86
+ 0.000, 0.000, 1.000,
87
+ 0.143, 0.143, 0.143,
88
+ 0.286, 0.286, 0.286,
89
+ 0.429, 0.429, 0.429,
90
+ 0.571, 0.571, 0.571,
91
+ 0.714, 0.714, 0.714,
92
+ 0.857, 0.857, 0.857
93
+ ]
94
+ ).astype(np.float32)
95
+ color_list = color_list.reshape((-1, 3)) * 255
96
+ if not rgb:
97
+ color_list = color_list[:, ::-1]
98
+ return color_list
99
+
100
+
101
+ color_list = colormap()
102
+ color_list = color_list.astype('uint8').tolist()
103
+
104
+
105
+ def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
106
+ background_color = np.array(background_color)
107
+ contour_color = np.array(contour_color)
108
+
109
+ # background_mask = 1 - background_mask
110
+ # contour_mask = 1 - contour_mask
111
+
112
+ for i in range(3):
113
+ image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
114
+ + background_color[i] * (background_alpha-background_mask*background_alpha)
115
+
116
+ image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
117
+ + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
118
+
119
+ return image.astype('uint8')
120
+
121
+
122
+ def mask_generator_00(mask, background_radius, contour_radius):
123
+ # mode '00': no background blur; hard (thresholded) contour
124
+ # distance map
125
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
126
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
127
+ dist_map = dist_transform_fore - dist_transform_back
128
+ # ...:::!!!:::...
129
+ contour_radius += 2
130
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
131
+ contour_mask = contour_mask / np.max(contour_mask)
132
+ contour_mask[contour_mask>0.5] = 1.
133
+
134
+ return mask, contour_mask
135
+
136
+
137
+ def mask_generator_01(mask, background_radius, contour_radius):
138
+ # mode '01': no background blur; keep the soft (blurred) contour
139
+ # distance map
140
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
141
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
142
+ dist_map = dist_transform_fore - dist_transform_back
143
+ # ...:::!!!:::...
144
+ contour_radius += 2
145
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
146
+ contour_mask = contour_mask / np.max(contour_mask)
147
+ return mask, contour_mask
148
+
149
+
150
+ def mask_generator_10(mask, background_radius, contour_radius):
151
+ # distance map
152
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
153
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
154
+ dist_map = dist_transform_fore - dist_transform_back
155
+ # .....:::::!!!!!
156
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
157
+ background_mask = (background_mask - np.min(background_mask))
158
+ background_mask = background_mask / np.max(background_mask)
159
+ # ...:::!!!:::...
160
+ contour_radius += 2
161
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
162
+ contour_mask = contour_mask / np.max(contour_mask)
163
+ contour_mask[contour_mask>0.5] = 1.
164
+ return background_mask, contour_mask
165
+
166
+
167
+ def mask_generator_11(mask, background_radius, contour_radius):
168
+ # distance map
169
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
170
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
171
+ dist_map = dist_transform_fore - dist_transform_back
172
+ # .....:::::!!!!!
173
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
174
+ background_mask = (background_mask - np.min(background_mask))
175
+ background_mask = background_mask / np.max(background_mask)
176
+ # ...:::!!!:::...
177
+ contour_radius += 2
178
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
179
+ contour_mask = contour_mask / np.max(contour_mask)
180
+ return background_mask, contour_mask
181
+
182
+
183
+ def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
184
+ """
185
+ Input:
186
+ input_image: numpy array
187
+ input_mask: numpy array
188
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
189
+ background_blur_radius: radius of background blur, must be odd number
190
+ contour_width: width of mask contour, must be odd number
191
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
192
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
193
+ mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
194
+
195
+ Output:
196
+ painted_image: numpy array
197
+ """
198
+ assert input_image.shape[:2] == input_mask.shape, 'different shape'
199
+ assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
200
+ assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
201
+
202
+ # downsample input image and mask
203
+ width, height = input_image.shape[0], input_image.shape[1]
204
+ res = 1024
205
+ ratio = min(1.0 * res / max(width, height), 1.0)
206
+ input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
207
+ input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
208
+
209
+ # 0: background, 1: foreground
210
+ msk = np.clip(input_mask, 0, 1)
211
+
212
+ # generate masks for background and contour pixels
213
+ background_radius = (background_blur_radius - 1) // 2
214
+ contour_radius = (contour_width - 1) // 2
215
+ generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
216
+ background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
217
+
218
+ # paint
219
+ painted_image = vis_add_mask\
220
+ (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
221
+
222
+ return painted_image
223
+
224
+
225
+ if __name__ == '__main__':
226
+
227
+ background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
228
+ background_blur_radius = 31 # radius of background blur, must be odd number
229
+ contour_width = 11 # contour width, must be odd number
230
+ contour_color = 3 # id in color map, 0: black, 1: white, >1: others
231
+ contour_alpha = 1 # transparency of background, 0: no contour highlighted
232
+
233
+ # load input image and mask
234
+ input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
235
+ input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
236
+
237
+ # paint
238
+ overall_time_1 = 0
239
+ overall_time_2 = 0
240
+ overall_time_3 = 0
241
+ overall_time_4 = 0
242
+ overall_time_5 = 0
243
+
244
+ for i in range(50):
245
+ t2 = time.time()
246
+ painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
247
+ e2 = time.time()
248
+
249
+ t3 = time.time()
250
+ painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
251
+ e3 = time.time()
252
+
253
+ t1 = time.time()
254
+ painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
255
+ e1 = time.time()
256
+
257
+ t4 = time.time()
258
+ painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
259
+ e4 = time.time()
260
+
261
+ t5 = time.time()
262
+ painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
263
+ e5 = time.time()
264
+
265
+ overall_time_1 += (e1 - t1)
266
+ overall_time_2 += (e2 - t2)
267
+ overall_time_3 += (e3 - t3)
268
+ overall_time_4 += (e4 - t4)
269
+ overall_time_5 += (e5 - t5)
270
+
271
+ print(f'average time w gaussian: {overall_time_1/50}')
272
+ print(f'average time w/o gaussian00: {overall_time_2/50}')
273
+ print(f'average time w/o gaussian10: {overall_time_3/50}')
274
+ print(f'average time w/o gaussian01: {overall_time_4/50}')
275
+ print(f'average time w/o gaussian11: {overall_time_5/50}')
276
+
277
+ # save
278
+ painted_image_00 = Image.fromarray(painted_image_00)
279
+ painted_image_00.save('./test_img/painter_output_image_00.png')
280
+
281
+ painted_image_10 = Image.fromarray(painted_image_10)
282
+ painted_image_10.save('./test_img/painter_output_image_10.png')
283
+
284
+ painted_image_01 = Image.fromarray(painted_image_01)
285
+ painted_image_01.save('./test_img/painter_output_image_01.png')
286
+
287
+ painted_image_11 = Image.fromarray(painted_image_11)
288
+ painted_image_11.save('./test_img/painter_output_image_11.png')
tools/painter.py ADDED
@@ -0,0 +1,215 @@
1
+ # paint masks, contours, or points on images, with specified colors
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import copy
7
+ import time
8
+
9
+
10
+ def colormap(rgb=True):
11
+ color_list = np.array(
12
+ [
13
+ 0.000, 0.000, 0.000,
14
+ 1.000, 1.000, 1.000,
15
+ 1.000, 0.498, 0.313,
16
+ 0.392, 0.581, 0.929,
17
+ 0.000, 0.447, 0.741,
18
+ 0.850, 0.325, 0.098,
19
+ 0.929, 0.694, 0.125,
20
+ 0.494, 0.184, 0.556,
21
+ 0.466, 0.674, 0.188,
22
+ 0.301, 0.745, 0.933,
23
+ 0.635, 0.078, 0.184,
24
+ 0.300, 0.300, 0.300,
25
+ 0.600, 0.600, 0.600,
26
+ 1.000, 0.000, 0.000,
27
+ 1.000, 0.500, 0.000,
28
+ 0.749, 0.749, 0.000,
29
+ 0.000, 1.000, 0.000,
30
+ 0.000, 0.000, 1.000,
31
+ 0.667, 0.000, 1.000,
32
+ 0.333, 0.333, 0.000,
33
+ 0.333, 0.667, 0.000,
34
+ 0.333, 1.000, 0.000,
35
+ 0.667, 0.333, 0.000,
36
+ 0.667, 0.667, 0.000,
37
+ 0.667, 1.000, 0.000,
38
+ 1.000, 0.333, 0.000,
39
+ 1.000, 0.667, 0.000,
40
+ 1.000, 1.000, 0.000,
41
+ 0.000, 0.333, 0.500,
42
+ 0.000, 0.667, 0.500,
43
+ 0.000, 1.000, 0.500,
44
+ 0.333, 0.000, 0.500,
45
+ 0.333, 0.333, 0.500,
46
+ 0.333, 0.667, 0.500,
47
+ 0.333, 1.000, 0.500,
48
+ 0.667, 0.000, 0.500,
49
+ 0.667, 0.333, 0.500,
50
+ 0.667, 0.667, 0.500,
51
+ 0.667, 1.000, 0.500,
52
+ 1.000, 0.000, 0.500,
53
+ 1.000, 0.333, 0.500,
54
+ 1.000, 0.667, 0.500,
55
+ 1.000, 1.000, 0.500,
56
+ 0.000, 0.333, 1.000,
57
+ 0.000, 0.667, 1.000,
58
+ 0.000, 1.000, 1.000,
59
+ 0.333, 0.000, 1.000,
60
+ 0.333, 0.333, 1.000,
61
+ 0.333, 0.667, 1.000,
62
+ 0.333, 1.000, 1.000,
63
+ 0.667, 0.000, 1.000,
64
+ 0.667, 0.333, 1.000,
65
+ 0.667, 0.667, 1.000,
66
+ 0.667, 1.000, 1.000,
67
+ 1.000, 0.000, 1.000,
68
+ 1.000, 0.333, 1.000,
69
+ 1.000, 0.667, 1.000,
70
+ 0.167, 0.000, 0.000,
71
+ 0.333, 0.000, 0.000,
72
+ 0.500, 0.000, 0.000,
73
+ 0.667, 0.000, 0.000,
74
+ 0.833, 0.000, 0.000,
75
+ 1.000, 0.000, 0.000,
76
+ 0.000, 0.167, 0.000,
77
+ 0.000, 0.333, 0.000,
78
+ 0.000, 0.500, 0.000,
79
+ 0.000, 0.667, 0.000,
80
+ 0.000, 0.833, 0.000,
81
+ 0.000, 1.000, 0.000,
82
+ 0.000, 0.000, 0.167,
83
+ 0.000, 0.000, 0.333,
84
+ 0.000, 0.000, 0.500,
85
+ 0.000, 0.000, 0.667,
86
+ 0.000, 0.000, 0.833,
87
+ 0.000, 0.000, 1.000,
88
+ 0.143, 0.143, 0.143,
89
+ 0.286, 0.286, 0.286,
90
+ 0.429, 0.429, 0.429,
91
+ 0.571, 0.571, 0.571,
92
+ 0.714, 0.714, 0.714,
93
+ 0.857, 0.857, 0.857
94
+ ]
95
+ ).astype(np.float32)
96
+ color_list = color_list.reshape((-1, 3)) * 255
97
+ if not rgb:
98
+ color_list = color_list[:, ::-1]
99
+ return color_list
100
+
101
+
102
+ color_list = colormap()
103
+ color_list = color_list.astype('uint8').tolist()
104
+
105
+
106
+ def vis_add_mask(image, mask, color, alpha):
107
+ color = np.array(color_list[color])
108
+ mask = mask > 0.5
109
+ image[mask] = image[mask] * (1-alpha) + color * alpha
110
+ return image.astype('uint8')
111
+
112
+ def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5):
113
+ h, w = input_image.shape[:2]
114
+ point_mask = np.zeros((h, w)).astype('uint8')
115
+ for point in input_points:
116
+ point_mask[point[1], point[0]] = 1
117
+
118
+ kernel = cv2.getStructuringElement(2, (point_radius, point_radius))
119
+ point_mask = cv2.dilate(point_mask, kernel)
120
+
121
+ contour_radius = (contour_width - 1) // 2
122
+ dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3)
123
+ dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3)
124
+ dist_map = dist_transform_fore - dist_transform_back
125
+ # ...:::!!!:::...
126
+ contour_radius += 2
127
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
128
+ contour_mask = contour_mask / np.max(contour_mask)
129
+ contour_mask[contour_mask>0.5] = 1.
130
+
131
+ # paint mask
132
+ painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha)
133
+ # paint contour
134
+ painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
135
+ return painted_image
136
+
137
+ def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3):
138
+ assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
139
+ # 0: background, 1: foreground
140
+ mask = np.clip(input_mask, 0, 1)
141
+ contour_radius = (contour_width - 1) // 2
142
+
143
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
144
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
145
+ dist_map = dist_transform_fore - dist_transform_back
146
+ # ...:::!!!:::...
147
+ contour_radius += 2
148
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
149
+ contour_mask = contour_mask / np.max(contour_mask)
150
+ contour_mask[contour_mask>0.5] = 1.
151
+
152
+ # paint mask
153
+ painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha)
154
+ # paint contour
155
+ painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
156
+
157
+ return painted_image
158
+
159
+ def background_remover(input_image, input_mask):
160
+ """
161
+ input_image: H, W, 3, np.array
162
+ input_mask: H, W, np.array
163
+
164
+ image_wo_background: PIL.Image
165
+ """
166
+ assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
167
+ # 0: background, 1: foreground
168
+ mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255
169
+ image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4
170
+ image_wo_background = Image.fromarray(image_wo_background).convert('RGBA')
171
+
172
+ return image_wo_background
173
+
174
+ if __name__ == '__main__':
175
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
176
+ input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P'))
177
+
178
+ # example of mask painter
179
+ mask_color = 3
180
+ mask_alpha = 0.7
181
+ contour_color = 1
182
+ contour_width = 5
183
+
184
+ # save
185
+ painted_image = Image.fromarray(input_image)
186
+ painted_image.save('images/original.png')
187
+
188
+ painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width)
189
+ # save
190
+ painted_image = Image.fromarray(painted_image)  # save the painted result, not the untouched input
191
+ painted_image.save('images/original1.png')
192
+
193
+ # example of point painter
194
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
195
+ input_points = np.array([[500, 375], [70, 600]]) # x, y
196
+ point_color = 5
197
+ point_alpha = 0.9
198
+ point_radius = 15
199
+ contour_color = 2
200
+ contour_width = 5
201
+ painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width)
202
+ # save
203
+ painted_image = Image.fromarray(painted_image_1)
204
+ painted_image.save('images/point_painter_1.png')
205
+
206
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
207
+ painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29)
208
+ # save
209
+ painted_image = Image.fromarray(painted_image_2)
210
+ painted_image.save('images/point_painter_2.png')
211
+
212
+ # example of background remover
213
+ input_image = np.array(Image.open('images/original.png').convert('RGB'))
214
+ image_wo_background = background_remover(input_image, input_mask) # return PIL.Image
215
+ image_wo_background.save('images/image_wo_background.png')
track_anything.py ADDED
@@ -0,0 +1,40 @@
1
+ import numpy as np
2
+ from tqdm import tqdm
3
+
4
+ from tools.interact_tools import SamControler
5
+ from tracker.base_tracker import BaseTracker
6
+ # from inpainter.base_inpainter import ProInpainter
7
+
8
+
9
+ class TrackingAnything():
10
+ def __init__(self, sam_checkpoint, cutie_checkpoint, args):
11
+ self.args = args
12
+ self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
13
+ self.cutie = BaseTracker(cutie_checkpoint, device=args.device)
14
+ # self.baseinpainter = ProInpainter(propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, args.device)
15
+
16
+ def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
17
+ mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
18
+ return mask, logit, painted_image
19
+
20
+ def generator(self, images: list, template_mask:np.ndarray):
21
+ masks = []
22
+ logits = []
23
+ painted_images = []
24
+ for i in tqdm(range(len(images)), desc="Tracking image"):
25
+ if i==0:
26
+ mask, logit, painted_image = self.cutie.track(images[i], template_mask)
27
+ masks.append(mask)
28
+ logits.append(logit)
29
+ painted_images.append(painted_image)
30
+ else:
31
+ mask, logit, painted_image = self.cutie.track(images[i])
32
+ masks.append(mask)
33
+ logits.append(logit)
34
+ painted_images.append(painted_image)
35
+ return masks, logits, painted_images
36
+
37
+
38
+
39
+
40
+
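
A hedged end-to-end sketch of the intended flow (segment the first frame with SAM, then propagate the mask with Cutie). The checkpoint paths, the `args` namespace, and the click location are assumptions for illustration only:

```python
import cv2
import numpy as np
from types import SimpleNamespace
from track_anything import TrackingAnything

args = SimpleNamespace(sam_model_type="vit_h", device="cuda:0")
tam = TrackingAnything(
    sam_checkpoint="checkpoints/sam_vit_h_4b8939.pth",    # assumed path
    cutie_checkpoint="checkpoints/cutie-base-mega.pth",   # assumed path
    args=args,
)

# Decode a few frames from one of the bundled test clips.
cap = cv2.VideoCapture("test_sample/test-sample0.mp4")
frames = []
while len(frames) < 16:
    ret, frame = cap.read()
    if not ret:
        break
    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cap.release()

# One positive click on the first frame gives the template mask, then propagate it.
tam.samcontroler.sam_controler.set_image(frames[0])
h, w = frames[0].shape[:2]
mask, logit, painted = tam.first_frame_click(frames[0], np.array([[w // 2, h // 2]]), np.array([1]))
masks, logits, painted_frames = tam.generator(frames, template_mask=mask.astype(np.uint8))
# masks: per-frame (H, W) label maps; painted_frames: per-frame overlays
```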
tracker/base_tracker.py ADDED
@@ -0,0 +1,103 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from omegaconf import OmegaConf
5
+
6
+ import sys
7
+ sys.path.append('../')
8
+
9
+ from tracker.config import CONFIG
10
+ from tracker.model.cutie import CUTIE
11
+ from tracker.inference.inference_core import InferenceCore
12
+ from tracker.utils.mask_mapper import MaskMapper
13
+
14
+ from tools.painter import mask_painter
15
+
16
+
17
+ class BaseTracker:
18
+ def __init__(self, cutie_checkpoint, device) -> None:
19
+ """
20
+ device: model device
21
+ cutie_checkpoint: checkpoint of the Cutie model
22
+ """
23
+ config = OmegaConf.create(CONFIG)
24
+
25
+ # initialise Cutie
26
+ network = CUTIE(config).to(device).eval()
27
+ model_weights = torch.load(cutie_checkpoint, map_location=device)
28
+ network.load_weights(model_weights)
29
+
30
+ # initialise InferenceCore
31
+ self.tracker = InferenceCore(network, config)
32
+ self.device = device
33
+
34
+ # changeable properties
35
+ self.mapper = MaskMapper()
36
+ self.initialised = False
37
+
38
+ @torch.no_grad()
39
+ def resize_mask(self, mask):
40
+ # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
41
+ h, w = mask.shape[-2:]
42
+ min_hw = min(h, w)
43
+ return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
44
+ mode='nearest')
45
+
46
+ @torch.no_grad()
47
+ def image_to_torch(self, frame: np.ndarray, device: str = 'cuda'):
48
+ # frame: H*W*3 numpy array
49
+ frame = frame.transpose(2, 0, 1)
50
+ frame = torch.from_numpy(frame).float().to(device, non_blocking=True) / 255
51
+ return frame
52
+
53
+ @torch.no_grad()
54
+ def track(self, frame, first_frame_annotation=None):
55
+ """
56
+ Input:
57
+ frame: numpy array (H, W, 3)
58
+ first_frame_annotation: numpy array (H, W), integer object mask (only for the first frame)
59
+
60
+ Output:
61
+ mask: numpy arrays (H, W)
62
+ logit: numpy arrays, probability map (H, W)
63
+ painted_image: numpy array (H, W, 3)
64
+ """
65
+
66
+ if first_frame_annotation is not None: # first frame mask
67
+ # initialisation
68
+ mask, labels = self.mapper.convert_mask(first_frame_annotation)
69
+ mask = torch.Tensor(mask).to(self.device)
70
+ else:
71
+ mask = None
72
+ labels = None
73
+
74
+ # prepare inputs
75
+ frame_tensor = self.image_to_torch(frame, self.device)
76
+
77
+ # track one frame
78
+ probs = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
79
+
80
+ # convert to mask
81
+ out_mask = torch.argmax(probs, dim=0)
82
+ out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
83
+
84
+ final_mask = np.zeros_like(out_mask)
85
+
86
+ # map back
87
+ for k, v in self.mapper.remappings.items():
88
+ final_mask[out_mask == v] = k
89
+
90
+ num_objs = final_mask.max()
91
+ painted_image = frame
92
+ for obj in range(1, num_objs+1):
93
+ if np.max(final_mask==obj) == 0:
94
+ continue
95
+ painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
96
+
97
+ return final_mask, final_mask, painted_image
98
+
99
+ @torch.no_grad()
100
+ def clear_memory(self):
101
+ self.tracker.clear_memory()
102
+ self.mapper.clear_labels()
103
+ torch.cuda.empty_cache()
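
For reference, the per-clip loop that `track` is designed around, as a sketch; the checkpoint path and the synthetic frames/mask are placeholders:

```python
import numpy as np
from tracker.base_tracker import BaseTracker

tracker = BaseTracker("checkpoints/cutie-base-mega.pth", device="cuda:0")  # assumed path

frames = [np.zeros((480, 832, 3), dtype=np.uint8) for _ in range(8)]  # placeholder clip
first_mask = np.zeros((480, 832), dtype=np.uint8)
first_mask[100:200, 300:400] = 1  # object id 1, annotated on the first frame only

masks = []
for i, frame in enumerate(frames):
    mask, logit, painted = tracker.track(frame, first_mask if i == 0 else None)
    masks.append(mask)

tracker.clear_memory()  # drop the memory bank before starting the next clip
```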
tracker/config/__init__.py ADDED
@@ -0,0 +1 @@
1
+ CONFIG = {'exp_id': 'default', 'dataset': 'd17-val', 'amp': False, 'output_dir': None, 'flip_aug': False, 'max_internal_size': -1, 'image_directory': None, 'mask_directory': None, 'json_directory': None, 'size': None, 'save_all': None, 'use_all_masks': None, 'use_long_term': None, 'mem_every': 5, 'max_mem_frames': 5, 'long_term': {'count_usage': True, 'max_mem_frames': 10, 'min_mem_frames': 5, 'num_prototypes': 128, 'max_num_tokens': 10000, 'buffer_tokens': 2000}, 'top_k': 30, 'stagger_updates': 5, 'chunk_size': -1, 'save_scores': False, 'save_aux': False, 'visualize': False, 'model': {'pixel_mean': [0.485, 0.456, 0.406], 'pixel_std': [0.229, 0.224, 0.225], 'pixel_dim': 256, 'key_dim': 64, 'value_dim': 256, 'sensory_dim': 256, 'embed_dim': 256, 'pixel_encoder': {'type': 'resnet50', 'ms_dims': [1024, 512, 256]}, 'mask_encoder': {'type': 'resnet18', 'final_dim': 256}, 'pixel_pe_scale': 32, 'pixel_pe_temperature': 128, 'object_transformer': {'embed_dim': '${model.embed_dim}', 'ff_dim': 2048, 'num_heads': 8, 'num_blocks': 3, 'num_queries': 16, 'read_from_pixel': {'input_norm': False, 'input_add_pe': False, 'add_pe_to_qkv': [True, True, False]}, 'read_from_past': {'add_pe_to_qkv': [True, True, False]}, 'read_from_memory': {'add_pe_to_qkv': [True, True, False]}, 'read_from_query': {'add_pe_to_qkv': [True, True, False], 'output_norm': False}, 'query_self_attention': {'add_pe_to_qkv': [True, True, False]}, 'pixel_self_attention': {'add_pe_to_qkv': [True, True, False]}}, 'object_summarizer': {'embed_dim': '${model.object_transformer.embed_dim}', 'num_summaries': '${model.object_transformer.num_queries}', 'add_pe': True}, 'aux_loss': {'sensory': {'enabled': True, 'weight': 0.01}, 'query': {'enabled': True, 'weight': 0.01}}, 'mask_decoder': {'up_dims': [256, 128, 128]}}}
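
The `'${...}'` values above are OmegaConf interpolations; they only resolve once the dict is wrapped in an OmegaConf container, as `BaseTracker` does. A small sketch:

```python
from omegaconf import OmegaConf
from tracker.config import CONFIG

cfg = OmegaConf.create(CONFIG)
print(cfg.model.object_transformer.embed_dim)     # 256, via '${model.embed_dim}'
print(cfg.model.object_summarizer.num_summaries)  # 16, via '${model.object_transformer.num_queries}'
print(cfg.mem_every, cfg.top_k)                   # 5 30
```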