Modification List (Before/After Comparison)
Scope
Minimal fixes only, enough to get the code running; the original logic and structure are preserved as much as possible.
1) multi-shot/multi_view/datasets/videodataset.py
Fill in undefined variables while keeping the original return structure
Before
return {
"global_caption": None,
"shot_num": 3,
"pre_shot_caption": ["xxx", "xxx", "xxx"],
# "single_caption": meta_prompt["single_prompt"],
"video": input_video,
"ref_num": ID_num * 3, ###TODO: 先跑通 ID_num = 1 的情况
"ID_num": ID_num,
"ref_images": [[Image0, Image1, Image2]],
"video_path": video_path
}
After
ID_num = 1
Image0, Image1, Image2 = ref_images[:3]
return {
"global_caption": None,
"shot_num": 3,
"pre_shot_caption": ["xxx", "xxx", "xxx"],
# "single_caption": meta_prompt["single_prompt"],
"video": input_video,
"ref_num": ID_num * 3, ###TODO: 先跑通 ID_num = 1 的情况
"ID_num": ID_num,
"ref_images": [[Image0, Image1, Image2]],
"video_path": video_path
}
2) multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/pipelines/wan_video_new.py
2.1 Prompt encoding (fix typos and object calls)
Before
prompt = pip.text_encoder.process_prompt(prompt, positive=positive)
output = pip.text_encoder.tokenizer(prompt, return_mask=True, add_special_tokens=True)
ids = output['input_ids'].to(device)
mask = output['attention_mask'].to(device)
prompt_emb = self.text_encoder(ids, mask)
...
prompt_shot_all = pip.text_encoder.process_prompt(prompt_shot_all, positive=positive)
...
for shot_index, shot_cut_end in enmurate(shot_cut_ends):
start_pos = shot_cut_starts[shot_index]
end_pos = shot_cut_end
shot_text = cleaned_prompt[start_pos: end_pos + 1].strip()
After
prompt = pipe.text_encoder.process_prompt(prompt, positive=positive)
output = pipe.text_encoder.tokenizer(prompt, return_mask=True, add_special_tokens=True)
ids = output['input_ids'].to(device)
mask = output['attention_mask'].to(device)
prompt_emb = pipe.text_encoder(ids, mask)
...
prompt_shot_all = pipe.text_encoder.process_prompt(prompt_shot_all, positive=positive)
cleaned_prompt = prompt_shot_all
...
for shot_index, shot_cut_end in enumerate(shot_cut_ends):
start_pos = shot_cut_starts[shot_index]
end_pos = shot_cut_end
shot_text = cleaned_prompt[start_pos: end_pos + 1].strip()
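The per-shot slicing loop in the After block can be checked standalone; the caption and boundary indices below are made up:

```python
# Slice a full caption into per-shot texts using inclusive end indices,
# mirroring the enumerate loop above.
cleaned_prompt = "Shot one. Shot two. Shot three."
shot_cut_starts = [0, 10, 20]
shot_cut_ends = [8, 18, 30]

shot_texts = []
for shot_index, shot_cut_end in enumerate(shot_cut_ends):
    start_pos = shot_cut_starts[shot_index]
    end_pos = shot_cut_end
    shot_texts.append(cleaned_prompt[start_pos: end_pos + 1].strip())
```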
2.2 Shot mask construction (fix undefined variables)
Before
S_shots = len(shot_text_ranges[0]) ###TODO: current batch size is 1
...
for sid, (s0, s1) in enumerate(shot_ranges):
s0 = int(s0)
s1 = int(s1)
shot_table[sid, s0: s1 + 1] = True
...
allow_all = torch.cat([allow_shot, allow_ref_image], dim = 1)
assert allow_all.shape == x.shape[2] "The shape is something wrong"
After
shot_ranges = shot_text_ranges[0]
if isinstance(shot_ranges, dict):
shot_ranges = shot_ranges.get("shots", [])
S_shots = len(shot_ranges)
for sid, span in enumerate(shot_ranges):
if span is None:
continue
s0, s1 = span
s0 = int(s0)
s1 = int(s1)
shot_table[sid, s0: s1 + 1] = True
...
allow_all = torch.cat([allow_shot, allow_ref_image], dim = 1)
assert allow_all.shape[1] == S_q, "The shape is something wrong"
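The mask construction can be sketched with plain Python lists in place of the torch.bool shot_table; the spans below are made-up inclusive token ranges:

```python
# Build a per-shot boolean table: row sid is True over that shot's token
# span; shots with no span (None) are skipped, matching the After block.
shot_ranges = [(0, 3), None, (4, 6)]   # made-up inclusive spans
S_shots = len(shot_ranges)
seq_len = 8
shot_table = [[False] * seq_len for _ in range(S_shots)]
for sid, span in enumerate(shot_ranges):
    if span is None:
        continue
    s0, s1 = int(span[0]), int(span[1])
    for t in range(s0, s1 + 1):        # inclusive end, like s0:s1+1
        shot_table[sid][t] = True
```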
2.3 Fix variable-name conflict in the shot_rope branch
Before
for shot_index, num_frames in enumerate(shots_nums):
f = num_frames
rope_s = freq_s[shot_index] \
.view(1, 1, 1, -1) \
.expand(f, h, w, -1)
...
freqs = freqs.reshape(f * h * w, 1, -1)
After
for shot_index, num_frames in enumerate(shots_nums):
f = num_frames
rope_s = freq_s[shot_index].view(1, 1, 1, -1).expand(f, h, w, -1)
...
freqs = freqs.reshape(f * h * w, 1, -1)
2.4 Fix model_fn_wan_video function-signature syntax
Before
ID_2_shot: None ###### which IDs each shot contains: a list like [batch0: [shot0: [0,1], shot1: [2]], batch1: []]
**kwargs,
After
ID_2_shot=None, ###### which IDs each shot contains: a list like [batch0: [shot0: [0,1], shot1: [2]], batch1: []]
**kwargs,
2.5 Add the missing WanVideoUnit_SpeedControl class
Before
WanVideoUnit_SpeedControl(), # referenced in the units list, but the class was undefined
After
class WanVideoUnit_SpeedControl(PipelineUnit):
def __init__(self):
super().__init__(input_params=("motion_bucket_id",))
def process(self, pipe: WanVideoPipeline, motion_bucket_id):
if motion_bucket_id is None:
return {}
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"motion_bucket_id": motion_bucket_id}
2.6 Route prompt processing through the prompter (fix missing process_prompt)
Before
prompt = pipe.text_encoder.process_prompt(prompt, positive=positive)
output = pipe.text_encoder.tokenizer(prompt, return_mask=True, add_special_tokens=True)
...
prompt_shot_all = pipe.text_encoder.process_prompt(prompt_shot_all, positive=positive)
...
enc_output = pipe.text_encoder(
text,
return_mask=True,
add_special_tokens=True,
return_tensors="pt"
)
After
prompt = pipe.prompter.process_prompt(prompt, positive=positive)
output = pipe.prompter.tokenizer(prompt, return_mask=True, add_special_tokens=True)
...
prompt_shot_all = pipe.prompter.process_prompt(prompt_shot_all, positive=positive)
...
enc_output = pipe.prompter.tokenizer(
text,
return_mask=True,
add_special_tokens=True,
return_tensors="pt"
)
2.7 Handle tokenizer returning a tuple or a dict
Before
output = pipe.prompter.tokenizer(prompt, return_mask=True, add_special_tokens=True)
ids = output['input_ids'].to(device)
mask = output['attention_mask'].to(device)
...
enc_output = pipe.prompter.tokenizer(..., return_mask=True, ...)
ids = enc_output['input_ids'].to(device)
mask = enc_output['attention_mask'].to(device)
After
output = pipe.prompter.tokenizer(prompt, return_mask=True, add_special_tokens=True)
if isinstance(output, tuple):
ids, mask = output
else:
ids = output['input_ids']
mask = output['attention_mask']
ids = ids.to(device)
mask = mask.to(device)
...
enc_output = pipe.prompter.tokenizer(..., return_mask=True, ...)
if isinstance(enc_output, tuple):
ids, mask = enc_output
else:
ids = enc_output['input_ids']
mask = enc_output['attention_mask']
ids = ids.to(device)
mask = mask.to(device)
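The tuple/dict shim appears twice above; it could be factored into a small helper (hypothetical name, not in the codebase):

```python
# Normalize a tokenizer output to (ids, mask), whether the tokenizer
# returns a (ids, mask) tuple or a dict with the usual keys.
def unpack_tokenizer_output(output):
    if isinstance(output, tuple):
        ids, mask = output
    else:
        ids = output['input_ids']
        mask = output['attention_mask']
    return ids, mask
```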
2.8 Use the prompter's text_len (fix missing attribute)
Before
pad_len = pipe.text_encoder.text_len - total_len
After
pad_len = pipe.prompter.text_len - total_len
3) multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/models/wan_video_dit.py
3.1 Fix the ID token slice in attention_per_batch_with_shots
Before
ID_token_start = shot_token_all_num + id_idx * pre_ID_token_num
ID_token_end = start + pre_ID_token_num
assert end <= k.shape[2], (
f"ID token slice out of range: start={start}, end={end}, "
f"K_len={k.shape[2]}"
)
id_token_k = k[bi, :, start:end, :]
id_token_v = v[bi, :, start:end, :]
After
start = shot_token_all_num + id_idx * pre_id_token_num
if start >= k.shape[2]:
continue
end = min(start + pre_id_token_num, k.shape[2])
id_token_k = k[bi, :, start:end, :]
id_token_v = v[bi, :, start:end, :]
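The bounds check in the After block can be isolated as a pure function over plain integers (hypothetical name, for illustration only):

```python
# Compute the [start, end) window for one ID's tokens, clamped to the key
# length; return None when the window starts past the end of the sequence.
def id_token_window(shot_token_all_num, id_idx, pre_id_token_num, k_len):
    start = shot_token_all_num + id_idx * pre_id_token_num
    if start >= k_len:
        return None  # nothing to slice for this ID
    end = min(start + pre_id_token_num, k_len)
    return start, end
```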
3.2 Add attn_mask to CrossAttention.forward
Before
def forward(self, x: torch.Tensor, y: torch.Tensor):
...
x = self.attn(q, k, v)
After
def forward(self, x: torch.Tensor, y: torch.Tensor, attn_mask=None):
...
x = self.attn(q, k, v, attn_mask=attn_mask)
4) multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/trainers/utils.py
Add an argument to match the pipeline
Before
# (no --shot_rope argument)
After
parser.add_argument("--shot_rope", action="store_true", help="Whether to apply shot RoPE for multi-shot video")
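A quick standalone check of the flag. Note that argparse's type=bool converts any non-empty string (even "False") to True, so a store_true flag is shown here as the less error-prone form for a boolean switch:

```python
import argparse

# Expose the boolean switch as a flag: absent -> False, present -> True.
parser = argparse.ArgumentParser()
parser.add_argument("--shot_rope", action="store_true",
                    help="Whether to apply shot RoPE for multi-shot video")

args_on = parser.parse_args(["--shot_rope"])
args_off = parser.parse_args([])
```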
5) New files
multi-shot/MULTI_SHOT_CORE_SUMMARY.md
- Before: file did not exist
- After: added summary document
multi-shot/MODIFICATION_LOG.md
- Before: file did not exist
- After: added modification list (this file)
6) multi-shot/dry_run_train.py
Force-move the model to CUDA to match the input device
Before
device = "cuda" if torch.cuda.is_available() else "cpu"
model.pipe.device = device
model.pipe.torch_dtype = torch.bfloat16
After
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.pipe.device = device
model.pipe.torch_dtype = torch.bfloat16
Verification
python -m py_compile multi-shot/multi_view/datasets/videodataset.py
python -m py_compile multi-shot/multi_view/train.py
python -m py_compile multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/pipelines/wan_video_new.py
python -m py_compile multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/models/wan_video_dit.py
python -m py_compile multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/trainers/utils.py