
# Modification List (Before/After Comparison)

## Scope

Minimal fixes only, just enough to make the code run; the original logic and structure are preserved wherever possible.

## 1) multi-shot/multi_view/datasets/videodataset.py

Define the missing variables while keeping the original return structure.

**Before**

```python
return {
    "global_caption": None,
    "shot_num": 3,
    "pre_shot_caption": ["xxx", "xxx", "xxx"],
    # "single_caption": meta_prompt["single_prompt"],
    "video": input_video,
    "ref_num": ID_num * 3, ###TODO: get the ID_num = 1 case running first
    "ID_num": ID_num,
    "ref_images": [[Image0, Image1, Image2]],
    "video_path": video_path
}
```

**After**

```python
ID_num = 1
Image0, Image1, Image2 = ref_images[:3]
return {
    "global_caption": None,
    "shot_num": 3,
    "pre_shot_caption": ["xxx", "xxx", "xxx"],
    # "single_caption": meta_prompt["single_prompt"],
    "video": input_video,
    "ref_num": ID_num * 3, ###TODO: get the ID_num = 1 case running first
    "ID_num": ID_num,
    "ref_images": [[Image0, Image1, Image2]],
    "video_path": video_path
}
```

## 2) multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/pipelines/wan_video_new.py

### 2.1 Prompt encoding (fix typos / object references)

**Before**

```python
prompt =  pip.text_encoder.process_prompt(prompt, positive=positive)
output =  pip.text_encoder.tokenizer(prompt, return_mask=True, add_special_tokens=True)
ids = output['input_ids'].to(device)
mask = output['attention_mask'].to(device)
prompt_emb = self.text_encoder(ids, mask)
...
prompt_shot_all = pip.text_encoder.process_prompt(prompt_shot_all, positive=positive)
...
for shot_index, shot_cut_end in enmurate(shot_cut_ends):
    start_pos = shot_cut_starts[shot_index]
    end_pos = shot_cut_end
    shot_text = cleaned_prompt[start_pos: end_pos + 1].strip()
```

**After**

```python
prompt = pipe.text_encoder.process_prompt(prompt, positive=positive)
output = pipe.text_encoder.tokenizer(prompt, return_mask=True, add_special_tokens=True)
ids = output['input_ids'].to(device)
mask = output['attention_mask'].to(device)
prompt_emb = pipe.text_encoder(ids, mask)
...
prompt_shot_all = pipe.text_encoder.process_prompt(prompt_shot_all, positive=positive)
cleaned_prompt = prompt_shot_all
...
for shot_index, shot_cut_end in enumerate(shot_cut_ends):
    start_pos = shot_cut_starts[shot_index]
    end_pos = shot_cut_end
    shot_text = cleaned_prompt[start_pos: end_pos + 1].strip()
```
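The per-shot slicing in this loop uses inclusive end indices. A standalone sketch with toy character offsets (assuming `shot_cut_starts`/`shot_cut_ends` are offsets into the concatenated prompt string):

```python
def split_shots(cleaned_prompt, shot_cut_starts, shot_cut_ends):
    """Slice the concatenated prompt into per-shot captions (inclusive ends)."""
    shots = []
    for shot_index, shot_cut_end in enumerate(shot_cut_ends):
        start_pos = shot_cut_starts[shot_index]
        shots.append(cleaned_prompt[start_pos: shot_cut_end + 1].strip())
    return shots

shots = split_shots("shot one. shot two.", [0, 10], [9, 18])
# shots == ["shot one.", "shot two."]
```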

### 2.2 Shot mask construction (fix undefined variables)

**Before**

```python
S_shots = len(shot_text_ranges[0]) ###TODO: current batch size is 1
...
for sid, (s0, s1) in enumerate(shot_ranges):
    s0 = int(s0)
    s1 = int(s1)
    shot_table[sid, s0: s1 + 1] = True
...
allow_all = torch.cat([allow_shot, allow_ref_image], dim = 1)
assert allow_all.shape == x.shape[2] "The shape is something wrong"
```

**After**

```python
shot_ranges = shot_text_ranges[0]
if isinstance(shot_ranges, dict):
    shot_ranges = shot_ranges.get("shots", [])
S_shots = len(shot_ranges)
for sid, span in enumerate(shot_ranges):
    if span is None:
        continue
    s0, s1 = span
    s0 = int(s0)
    s1 = int(s1)
    shot_table[sid, s0: s1 + 1] = True
...
allow_all = torch.cat([allow_shot, allow_ref_image], dim = 1)
assert allow_all.shape[1] == S_q, "The shape is something wrong"
```
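A self-contained sketch of the corrected construction (pure-Python lists stand in for the boolean `shot_table` tensor; the dict-with-`"shots"` shape is the assumption the fix makes):

```python
def build_shot_table(shot_ranges, seq_len):
    # Normalize the two observed input shapes: a list of (start, end) spans,
    # or a dict holding that list under "shots".
    if isinstance(shot_ranges, dict):
        shot_ranges = shot_ranges.get("shots", [])
    table = [[False] * seq_len for _ in range(len(shot_ranges))]
    for sid, span in enumerate(shot_ranges):
        if span is None:
            continue  # tolerate missing spans instead of crashing
        s0, s1 = int(span[0]), int(span[1])
        for pos in range(s0, s1 + 1):  # inclusive end, like s0 : s1 + 1
            table[sid][pos] = True
    return table

table = build_shot_table([(0, 2), None, (3, 4)], seq_len=6)
```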

### 2.3 shot_rope branch: fix variable-name conflict

**Before**

```python
for shot_index, num_frames in enumerate(shots_nums):
    f = num_frames
    rope_s = freq_s[shot_index] \
        .view(1, 1, 1, -1) \  
        .expand(f, h, w, -1)
    ...
    freqs = freqs.reshape(f * h * w, 1, -1)
```

**After**

```python
for shot_index, num_frames in enumerate(shots_nums):
    f = num_frames
    rope_s = freq_s[shot_index].view(1, 1, 1, -1).expand(f, h, w, -1)
    ...
    freqs = freqs.reshape(f * h * w, 1, -1)
```
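The actual breakage in the Before snippet is the trailing whitespace after the continuation backslash: Python only honors `\` as a line join when it is the very last character on the line. A quick check via `compile()`:

```python
def compiles(src):
    try:
        compile(src, "<demo>", "exec")
        return True
    except SyntaxError:
        return False

dirty = "y = 1 + \\  \n2"  # backslash followed by two spaces: SyntaxError
clean = "y = 1 + \\\n2"    # backslash immediately before the newline: fine
```

Chaining the calls on one line, as in the After snippet, sidesteps continuations entirely (wrapping the expression in parentheses would too).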

### 2.4 model_fn_wan_video function-signature syntax fix

**Before**

```python
ID_2_shot: None ###### which IDs each shot contains; a list like [batch0: [shot0: [0, 1], shot1: [2]], batch1: []]
**kwargs,
```

**After**

```python
ID_2_shot=None, ###### which IDs each shot contains; a list like [batch0: [shot0: [0, 1], shot1: [2]], batch1: []]
**kwargs,
```

### 2.5 Add the missing WanVideoUnit_SpeedControl class

**Before**

```python
WanVideoUnit_SpeedControl(),  # referenced in the units list, but the class was never defined
```

**After**

```python
class WanVideoUnit_SpeedControl(PipelineUnit):
    def __init__(self):
        super().__init__(input_params=("motion_bucket_id",))

    def process(self, pipe: WanVideoPipeline, motion_bucket_id):
        if motion_bucket_id is None:
            return {}
        motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"motion_bucket_id": motion_bucket_id}
```

### 2.6 Use the prompter for prompt processing (fix missing process_prompt)

**Before**

```python
prompt = pipe.text_encoder.process_prompt(prompt, positive=positive)
output = pipe.text_encoder.tokenizer(prompt, return_mask=True, add_special_tokens=True)
...
prompt_shot_all = pipe.text_encoder.process_prompt(prompt_shot_all, positive=positive)
...
enc_output = pipe.text_encoder(
    text,
    return_mask=True,
    add_special_tokens=True,
    return_tensors="pt"
)
```

**After**

```python
prompt = pipe.prompter.process_prompt(prompt, positive=positive)
output = pipe.prompter.tokenizer(prompt, return_mask=True, add_special_tokens=True)
...
prompt_shot_all = pipe.prompter.process_prompt(prompt_shot_all, positive=positive)
...
enc_output = pipe.prompter.tokenizer(
    text,
    return_mask=True,
    add_special_tokens=True,
    return_tensors="pt"
)
```

### 2.7 Handle the tokenizer returning a tuple or a dict

**Before**

```python
output = pipe.prompter.tokenizer(prompt, return_mask=True, add_special_tokens=True)
ids = output['input_ids'].to(device)
mask = output['attention_mask'].to(device)
...
enc_output = pipe.prompter.tokenizer(..., return_mask=True, ...)
ids = enc_output['input_ids'].to(device)
mask = enc_output['attention_mask'].to(device)
```

**After**

```python
output = pipe.prompter.tokenizer(prompt, return_mask=True, add_special_tokens=True)
if isinstance(output, tuple):
    ids, mask = output
else:
    ids = output['input_ids']
    mask = output['attention_mask']
ids = ids.to(device)
mask = mask.to(device)
...
enc_output = pipe.prompter.tokenizer(..., return_mask=True, ...)
if isinstance(enc_output, tuple):
    ids, mask = enc_output
else:
    ids = enc_output['input_ids']
    mask = enc_output['attention_mask']
ids = ids.to(device)
mask = mask.to(device)
```
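Since the same branch appears twice, it could be captured in a small helper (hypothetical name, not in the repo):

```python
def unpack_tokenizer_output(output):
    """Accept either (ids, mask) tuples or HF-style dict outputs."""
    if isinstance(output, tuple):
        ids, mask = output
    else:
        ids = output["input_ids"]
        mask = output["attention_mask"]
    return ids, mask

ids, mask = unpack_tokenizer_output({"input_ids": [1, 2], "attention_mask": [1, 1]})
```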

### 2.8 Use the prompter's text_len (fix missing attribute)

**Before**

```python
pad_len = pipe.text_encoder.text_len - total_len
```

**After**

```python
pad_len = pipe.prompter.text_len - total_len
```

## 3) multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/models/wan_video_dit.py

### 3.1 Fix the ID token slice in attention_per_batch_with_shots

**Before**

```python
ID_token_start = shot_token_all_num + id_idx * pre_ID_token_num
ID_token_end   = start + pre_ID_token_num
assert end <= k.shape[2], (
    f"ID token slice out of range: start={start}, end={end}, "
    f"K_len={k.shape[2]}"
)
id_token_k = k[bi, :, start:end, :] 
id_token_v = v[bi, :, start:end, :]
```

**After**

```python
start = shot_token_all_num + id_idx * pre_id_token_num
if start >= k.shape[2]:
    continue
end = min(start + pre_id_token_num, k.shape[2])
id_token_k = k[bi, :, start:end, :]
id_token_v = v[bi, :, start:end, :]
```
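The arithmetic behind the clamped slice, extracted into a checkable form (hypothetical helper; returning `None` plays the role of the `continue` branch):

```python
def id_token_slice(id_idx, shot_token_all_num, pre_id_token_num, k_len):
    # Each ID owns a fixed-size token block placed after all shot tokens.
    start = shot_token_all_num + id_idx * pre_id_token_num
    if start >= k_len:
        return None  # block lies entirely past the key sequence: skip it
    end = min(start + pre_id_token_num, k_len)  # clamp instead of asserting
    return start, end
```

For example, with 10 shot tokens, blocks of 4, and a key length of 20: ID 0 maps to `(10, 14)`, ID 2 is clamped to `(18, 20)`, and ID 3 starts out of range and is skipped.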

### 3.2 Add attn_mask to CrossAttention.forward

**Before**

```python
def forward(self, x: torch.Tensor, y: torch.Tensor):
    ...
    x = self.attn(q, k, v)
```

**After**

```python
def forward(self, x: torch.Tensor, y: torch.Tensor, attn_mask=None):
    ...
    x = self.attn(q, k, v, attn_mask=attn_mask)
```
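What the new parameter buys: a boolean mask (True = attend) can be forwarded down to PyTorch's `scaled_dot_product_attention`, which broadcasts a `(query_len, key_len)` mask over batch and heads. A minimal cross-attention-shaped sketch (not the repo's attention wrapper):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 5, 8)  # (batch, heads, query_len, head_dim)
k = torch.randn(1, 2, 7, 8)  # cross-attention: key length differs from query
v = torch.randn(1, 2, 7, 8)

attn_mask = torch.ones(5, 7, dtype=torch.bool)
attn_mask[:, -1] = False  # block every query from attending to the last key

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

Masked keys get zero attention weight, so their values cannot leak into the output.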

## 4) multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/trainers/utils.py

New command-line argument added to match the pipeline.

**Before**

```python
# (no --shot_rope argument)
```

**After**

```python
parser.add_argument("--shot_rope", type=bool, default=False, help="Whether apply shot rope for multi-shot video")
```
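One caveat with the flag as added: argparse's `type=bool` calls `bool()` on the raw string, and any non-empty string (including `"False"`) is truthy, so `--shot_rope False` would still enable it. A common alternative (a sketch, not what the repo does) is a `store_true` flag:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--shot_rope", action="store_true",
                    help="Whether to apply shot RoPE for multi-shot video")

off = parser.parse_args([])              # flag absent  -> False
on = parser.parse_args(["--shot_rope"])  # flag present -> True
```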

## 5) New files

**multi-shot/MULTI_SHOT_CORE_SUMMARY.md**

  • Before: file did not exist
  • After: new summary document

**multi-shot/MODIFICATION_LOG.md**

  • Before: file did not exist
  • After: new modification list (this file)

## 6) multi-shot/dry_run_train.py

Force the model onto CUDA so it matches the input device.

**Before**

```python
device = "cuda" if torch.cuda.is_available() else "cpu"
model.pipe.device = device
model.pipe.torch_dtype = torch.bfloat16
```

**After**

```python
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.pipe.device = device
model.pipe.torch_dtype = torch.bfloat16
```
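Why the extra `model.to(device)` is needed: assigning `pipe.device` only records a string attribute, while `nn.Module.to` is what actually moves (or casts) the parameters. A CPU-only sketch using a dtype cast, so it runs without a GPU:

```python
import torch
from torch import nn

model = nn.Linear(4, 4)  # stand-in for the training model
model.device = "cuda"    # plain attribute assignment; no weights move
param_device = next(model.parameters()).device  # still cpu

model = model.to(torch.bfloat16)  # .to() actually casts/moves parameters
param_dtype = next(model.parameters()).dtype
```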

## Verification

```bash
python -m py_compile multi-shot/multi_view/datasets/videodataset.py
python -m py_compile multi-shot/multi_view/train.py
python -m py_compile multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/pipelines/wan_video_new.py
python -m py_compile multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/models/wan_video_dit.py
python -m py_compile multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/trainers/utils.py
```