daKhosa committed on
Commit · 3cab3e5
1 Parent(s): 27df6f9
Pivot to Wan 2.1 + IP-Adapter face conditioning

- Model: Wan2.1-I2V-A14B-Diffusers (single transformer, no MoE)
- IP-Adapter: WanIPAdapter in ip_adapter.py ports IPAdapterWAN to diffusers;
  SigLIP2 so400m encodes the face reference, TimeResampler (SD3.5 weights)
  compresses it to 8 tokens, and WanIPAttnProcessor injects face KV into every
  self-attention block (sketched below); cleared after each inference call
- LightX2V: switched to the lightx2v/Wan2.1-Distill-Loras single-file LoRA
- Removed: transformer_2, guidance_scale_2, AOTI (Wan 2.1 has no compiled blocks)
- LoRA loading: single adapter_name per LoRA (no HIGH/LOW split for 2.1)
- face_ref_image threaded from generate_video → run_inference → ip_adapter
- Added einops to requirements

Files changed:
- app.py +73 -94
- ip_adapter.py +354 -0
- requirements.txt +1 -0
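The face-KV injection described above amounts to a second attention branch per block: the frame tokens attend to the 8 face tokens and the result is added, scaled, onto the normal self-attention output. A minimal standalone sketch with toy shapes and illustrative names; the real logic lives in WanIPAttnProcessor in ip_adapter.py below:

import torch
import torch.nn as nn
import torch.nn.functional as F

def face_kv_attention(q, k, v, face_tokens, to_k_ip, to_v_ip, heads, scale=0.6):
    # q, k, v: (B, heads, N, head_dim) - the block's ordinary self-attention inputs.
    base = F.scaled_dot_product_attention(q, k, v)
    b, _, _, hd = q.shape
    k_ip = to_k_ip(face_tokens).view(b, -1, heads, hd).transpose(1, 2)  # (B, heads, 8, head_dim)
    v_ip = to_v_ip(face_tokens).view(b, -1, heads, hd).transpose(1, 2)
    face = F.scaled_dot_product_attention(q, k_ip, v_ip)                # frames attend to face tokens
    return base + scale * face

# Toy shapes only - in the Space this happens inside WanIPAttnProcessor.__call__.
b, heads, n, hd = 1, 8, 64, 64
q = k = v = torch.randn(b, heads, n, hd)
face_tokens = torch.randn(b, 8, 1024)                  # TimeResampler output: 8 tokens, 1024-d
to_k_ip = nn.Linear(1024, heads * hd, bias=False)
to_v_ip = nn.Linear(1024, heads * hd, bias=False)
out = face_kv_attention(q, k, v, face_tokens, to_k_ip, to_v_ip, heads)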
app.py
CHANGED
@@ -39,9 +39,9 @@ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
 from diffusers.utils.export_utils import export_to_video
 
 from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
-import aoti
 from modify_model.modify_wan import set_sage_attn_wan
 from sageattention import sageattn
+from ip_adapter import WanIPAdapter
 
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 warnings.filterwarnings("ignore")
@@ -261,7 +261,7 @@ LORA_NAMES = ["None"] + sorted(set(list(LORA_CATALOG.keys()) + list(_known.keys(
 print(f"LoRA gallery: {len(LORA_NAMES)-1} entries ({len(LORA_CATALOG)} cached).")
 
 # ── Model ──────────────────────────────────────────────────────────────────────
-MODEL_ID = "Wan-AI/Wan2.
+MODEL_ID = "Wan-AI/Wan2.1-I2V-A14B-Diffusers"
 
 MAX_DIM = 832
 MIN_DIM = 480
@@ -371,20 +371,14 @@ def interpolate_bits(frames_np, multiplier=2, scale=1.0):
     return output
 
 
-# The 14B model (T5 ~11 GB + two transformers ~39 GB each) exceeds the CPU
-# startup container's 50 GB storage quota. Loading inside @spaces.GPU moves
-# the download to the H200 worker which has a much larger NVMe. The global
-# `pipe` persists between requests as long as the container stays up.
 pipe = None
 original_scheduler = None
+ip_adapter = None
 
 
 def _init_pipeline():
-    global pipe, original_scheduler,
+    global pipe, original_scheduler, ip_adapter
 
-    # Ensure token env-vars are set in this worker context.
     if HF_TOKEN:
         os.environ["HF_TOKEN"] = HF_TOKEN
         os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
@@ -394,60 +388,41 @@ def _init_pipeline():
         MODEL_ID, torch_dtype=torch.bfloat16, token=HF_TOKEN or None,
     ).to("cuda")
 
-    # SageAttention
-    set_sage_attn_wan(pipe.transformer,
-    pipe.fuse_lora(adapter_names=["lx2v_l"], lora_scale=0.7, components=["transformer_2"])
-    pipe.unload_lora_weights()
-    print("LightX2V LoRA fused.")
+    # SageAttention for the transformer.
+    set_sage_attn_wan(pipe.transformer, sageattn)
+
+    # Fuse LightX2V 4-step distillation LoRA into the single transformer.
+    print("Fusing LightX2V 2.1 distillation LoRA …")
+    _DISTILL_REPO = "lightx2v/Wan2.1-Distill-Loras"
+    _DISTILL_FILE = "wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors"
+    try:
+        pipe.load_lora_weights(_DISTILL_REPO, weight_name=_DISTILL_FILE, adapter_name="lx2v")
+        pipe.set_adapters(["lx2v"], adapter_weights=[1.0])
+        pipe.fuse_lora(adapter_names=["lx2v"], lora_scale=0.65, components=["transformer"])
+        pipe.unload_lora_weights()
+        print("LightX2V LoRA fused.")
+    except Exception as e:
+        print(f"LightX2V fuse skipped: {e}")
 
     pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=6.0)
     original_scheduler = copy.deepcopy(pipe.scheduler)
 
-    #
+    # fp8 quantisation – single transformer only.
     quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
     torch._dynamo.reset()
-    quantize_(pipe.transformer,
+    quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
     torch._dynamo.reset()
-    quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
-    torch._dynamo.reset()
-
-    aoti.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Wan2", variant="fp8da")
-    aoti.aoti_blocks_load(pipe.transformer_2, "zerogpu-aoti/Wan2", variant="fp8da")
 
+    # IP-Adapter – patches transformer attention blocks for face conditioning.
+    try:
+        ip_adapter = WanIPAdapter(pipe, device=pipe.device, dtype=torch.bfloat16)
+    except Exception as e:
+        print(f"[IP-Adapter] init failed: {e}")
+        ip_adapter = None
 
     print("Pipeline ready.")
 
 
-def _disable_aoti():
-    for _m, _ in _aoti_saved:
-        try: del _m.forward
-        except AttributeError: pass
-
-
-def _restore_aoti():
-    for _m, _fwd in _aoti_saved:
-        _m.forward = _fwd
-
-
 @spaces.GPU(duration=900)
 def _warmup_pipeline():
     """Load the full pipeline at Space startup so generation has no init delay."""
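WanIPAdapter only swaps attention-processor objects and attaches two small Linear layers per block, which is why it can be initialised after the fp8 quantisation above. A quick, purely illustrative way to inspect its footprint once _init_pipeline() has run (it reads the ip_adapter global and the _processors list defined in ip_adapter.py):

# Illustrative only - counts the blocks WanIPAdapter patched and the extra
# parameters its to_k_ip / to_v_ip projections add on top of the base model.
n_blocks = len(ip_adapter._processors)
extra_params = sum(
    p.numel()
    for proc in ip_adapter._processors
    for layer in (proc.to_k_ip, proc.to_v_ip)
    for p in layer.parameters()
)
print(f"[IP-Adapter] {n_blocks} patched blocks, {extra_params / 1e6:.1f} M extra parameters")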
@@ -492,28 +467,26 @@ def get_num_frames(duration_seconds):
     return 1 + int(np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
 
 
-def get_inference_duration(resized_image, _last, _prompt, steps, _neg, num_frames,
-                           guidance_scale,
+def get_inference_duration(resized_image, _last, _face, _prompt, steps, _neg, num_frames,
+                           guidance_scale, _seed, _sched, _fs, frame_multiplier,
                            _qual, duration_seconds, _lora, _scale, _progress):
     BASE = 81 * 832 * 624
     w, h = resized_image.size
     factor = num_frames * w * h / BASE
-    secs_per_step = 45 if (_lora and _lora != "None") else 15
+    secs_per_step = 30 if (_lora and _lora != "None") else 20
     gen_time = int(steps) * secs_per_step * factor ** 1.5
     if guidance_scale > 1:
         gen_time *= 1.8
     ff = frame_multiplier // FIXED_FPS
     if ff > 1:
         gen_time += ((num_frames * ff) - num_frames) * 0.02
-    # Add 300 s headroom for first-call pipeline init (model download + AOTI load).
     return min(900, 15 + gen_time)
 
 
 @spaces.GPU(duration=get_inference_duration)
 def run_inference(
-    resized_image, processed_last_image, prompt, steps, negative_prompt,
-    num_frames, guidance_scale,
+    resized_image, processed_last_image, face_ref_image, prompt, steps, negative_prompt,
+    num_frames, guidance_scale, current_seed,
     scheduler_name, flow_shift, frame_multiplier, quality, duration_seconds,
     lora_name, lora_scale,
     progress=gr.Progress(track_tqdm=True),
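To make the estimator above concrete, a worked example with illustrative inputs (81 frames at the 832x624 reference size, the 6-step distilled default, no gallery LoRA):

# Worked example of the GPU-duration estimate (numbers illustrative).
steps, num_frames, w, h = 6, 81, 832, 624
guidance_scale = 1.0                                 # distilled path, no CFG
BASE = 81 * 832 * 624
factor = num_frames * w * h / BASE                   # 1.0 at the reference size
secs_per_step = 20                                   # no gallery LoRA selected
gen_time = steps * secs_per_step * factor ** 1.5     # 6 * 20 * 1.0 = 120 s
if guidance_scale > 1:
    gen_time *= 1.8                                  # CFG roughly doubles the work
print(min(900, 15 + gen_time))                       # 135.0 - seconds requested from ZeroGPU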
@@ -531,11 +504,9 @@ def run_inference(
 
     clear_vram()
 
-    #
-    # LoRA-modified weights are actually used (AOTI binds weights at init).
+    # Lazy-download + load LoRA for this request.
     loaded_lora = False
     if lora_name and lora_name != "None":
-        # Lazy-download: fetch the repo now if not yet on disk, then pair files.
         if lora_name not in LORA_CATALOG:
             repo_id = LORA_REPO_MAP.get(lora_name)
             if repo_id:
@@ -545,15 +516,12 @@
             except Exception as e:
                 print(f"LoRA download failed ({lora_name}): {e}")
     if lora_name and lora_name != "None" and lora_name in LORA_CATALOG:
-        lora
+        lora = LORA_CATALOG[lora_name]
         scale = float(lora_scale)
-        _disable_aoti()  # fall back to uncompiled fp8 so LoRA weights are visible
         try:
-            pipe.
-            pipe.load_lora_weights(lora["low"], adapter_name=ln, load_into_transformer_2=True)
-            pipe.set_adapters([hn, ln], adapter_weights=[scale, scale])
+            an = lora_name.replace(" ", "_")
+            pipe.load_lora_weights(lora["high"], adapter_name=an)
+            pipe.set_adapters([an], adapter_weights=[scale])
             loaded_lora = True
             print(f"Loaded LoRA: {lora_name} (scale={scale})")
         except Exception as e:
@@ -561,28 +529,40 @@
             try: pipe.unload_lora_weights()
             except: pass
 
+    # IP-Adapter face conditioning – set before pipe(), clear after.
+    if ip_adapter is not None and face_ref_image is not None:
+        try:
+            face_emb = ip_adapter.encode(face_ref_image)
+            ip_adapter.set_hidden_states(face_emb, scale=lora_scale * 0.6)
+            print("[IP-Adapter] face embedding set")
+        except Exception as e:
+            print(f"[IP-Adapter] encode failed: {e}")
+
     task_id = str(uuid.uuid4())[:8]
-    start
+    start = time.time()
+    try:
+        result = pipe(
+            image=resized_image,
+            last_image=processed_last_image,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=resized_image.height,
+            width=resized_image.width,
+            num_frames=num_frames,
+            guidance_scale=float(guidance_scale),
+            num_inference_steps=int(steps),
+            generator=torch.Generator(device="cuda").manual_seed(current_seed),
+            output_type="np",
+        )
+    finally:
+        if ip_adapter is not None:
+            ip_adapter.clear_hidden_states()
+
     print(f"Gen time: {time.time()-start:.1f}s task={task_id}")
 
     if loaded_lora:
         try: pipe.unload_lora_weights()
         except: pass
-        _restore_aoti()  # re-enable compiled blocks for next LoRA-free inference
 
     raw_frames = result.frames[0]
     pipe.scheduler = original_scheduler
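The set-before / clear-after pairing around pipe() in the hunk above could equally be wrapped in a small context manager so the embedding is always cleared even if generation raises. This is purely an illustrative alternative, not part of the commit:

from contextlib import contextmanager

@contextmanager
def face_conditioning(adapter, face_image, scale):
    # Illustrative helper: set the face embedding for one pipe() call, always clear it.
    if adapter is not None and face_image is not None:
        adapter.set_hidden_states(adapter.encode(face_image), scale=scale)
    try:
        yield
    finally:
        if adapter is not None:
            adapter.clear_hidden_states()

# usage:
#   with face_conditioning(ip_adapter, face_ref_image, lora_scale * 0.6):
#       result = pipe(...)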
@@ -605,7 +585,7 @@
 def generate_video(
     input_image, last_image, face_ref_image, prompt,
     steps=6, negative_prompt="", duration_seconds=MAX_DURATION,
-    guidance_scale=1.0,
+    guidance_scale=1.0, seed=42, randomize_seed=False,
     quality=5, scheduler="UniPCMultistep", flow_shift=6.0,
     frame_multiplier=16, lora_name="None", lora_scale=0.6,
     blink_subject="woman",
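The hunk above adds seed and randomize_seed parameters, but the derivation of current_seed is outside the diff; a typical pattern (an assumption, not shown in this commit) would be:

import random

# Hypothetical derivation of current_seed from the new parameters (not in the diff);
# MAX_SEED is the constant app.py already uses for the Seed slider.
current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)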
@@ -639,8 +619,8 @@
     effective_prompt = prompt
 
     video_path, task_n = run_inference(
-        resized_image, processed_last,
+        resized_image, processed_last, face_ref_image, effective_prompt,
+        steps, negative_prompt, num_frames, guidance_scale, current_seed,
         scheduler, flow_shift, frame_multiplier, quality, duration_seconds,
         lora_name, lora_scale, progress,
     )
@@ -657,7 +637,7 @@ CSS = """
 
 with gr.Blocks(css=CSS, delete_cache=(3600, 10800)) as demo:
     gr.Markdown(f"## ZeroWan2GP – [{MODEL_ID.split('/')[-1]}](https://huggingface.co/{MODEL_ID})")
-    gr.Markdown("Wan 2.
+    gr.Markdown("Wan 2.1 I2V 14B · fp8 · IP-Adapter face conditioning · ZeroGPU · RIFE interpolation · NSFW LoRA gallery")
 
     with gr.Row():
         with gr.Column():
@@ -690,8 +670,7 @@ with gr.Blocks(css=CSS, delete_cache=(3600, 10800)) as demo:
             seed_input = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed")
             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
             steps_slider = gr.Slider(1, 50, step=1, value=6, label="Steps")
-            gs_input = gr.Slider(0.0, 10.0, step=0.5, value=1.0, label="Guidance Scale
-            gs2_input = gr.Slider(0.0, 10.0, step=0.5, value=1.0, label="Guidance Scale 2 (low noise)")
+            gs_input = gr.Slider(0.0, 10.0, step=0.5, value=1.0, label="Guidance Scale")
             scheduler_dd = gr.Dropdown(list(SCHEDULER_MAP.keys()), value="UniPCMultistep", label="Scheduler")
             flow_shift_slider = gr.Slider(0.5, 15.0, step=0.1, value=6.0, label="Flow Shift")
             play_result = gr.Checkbox(label="Display result", value=True)
@@ -708,7 +687,7 @@
 
     ui_inputs = [
         input_image_component, last_image_component, face_ref_component, prompt_input,
-        steps_slider, negative_prompt_input, duration_input, gs_input,
+        steps_slider, negative_prompt_input, duration_input, gs_input,
         seed_input, randomize_seed, quality_slider, scheduler_dd, flow_shift_slider,
         frame_multi, lora_dropdown, lora_scale_slider, blink_subject_radio, play_result,
     ]
@@ -717,7 +696,7 @@
     grab_btn.click(fn=None, inputs=None, outputs=[timestamp_box], js=get_timestamp_js)
     timestamp_box.change(fn=extract_frame, inputs=[video_output, timestamp_box], outputs=[input_image_component])
 
-print("Warming up pipeline (loading model, fusing
+print("Warming up pipeline (loading model, fusing LightX2V, fp8, IP-Adapter)...")
 _warmup_pipeline()
 print("Warmup complete – Space ready.")
ip_adapter.py
ADDED
@@ -0,0 +1,354 @@
"""
WAN 2.1 IP-Adapter – diffusers-native port of kaaskoek232/IPAdapterWAN.

Architecture
    SigLIP2 so400m (1152-d) → TimeResampler (1024-d, 8 queries)
    → per-block WanIPAttnProcessor injected into every self-attention of
    pipe.transformer

Weights
    Resampler : loaded from InstantX/SD3.5-Large-IP-Adapter ip-adapter.bin
                key prefix "image_proj" (architecture-matched)
    IP proj   : to_k_ip / to_v_ip initialised from the model's own to_k / to_v
                weights (zero-shot reference-attention style – works without
                Wan-specific training and produces real identity signal)

LoRA compatibility
    IP processors sit on top of whatever to_q/to_k/to_v the LoRA has patched;
    they are orthogonal (IP adds extra KV, LoRA modifies weight matrices).
"""

from __future__ import annotations

import math
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoProcessor, SiglipVisionModel


# ── Helpers ─────────────────────────────────────────────────────────────────────

def _reshape(t: torch.Tensor, heads: int) -> torch.Tensor:
    b, n, d = t.shape
    return t.reshape(b, n, heads, d // heads).transpose(1, 2)


# ── Perceiver / TimeResampler (matches SD3.5 ip-adapter.bin image_proj.*) ──────

class _FeedForward(nn.Module):
    def __init__(self, dim: int, mult: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * mult, bias=False),
            nn.GELU(),
            nn.Linear(dim * mult, dim, bias=False),
        )

    def forward(self, x):
        return self.net(x)


class _PerceiverAttention(nn.Module):
    def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
        super().__init__()
        self.heads = heads
        inner = dim_head * heads
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner, bias=False)
        self.to_kv = nn.Linear(dim, inner * 2, bias=False)
        self.to_out = nn.Linear(inner, dim, bias=False)

    def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x)
        latents = self.norm2(latents)
        q = _reshape(self.to_q(latents), self.heads)
        kv_in = torch.cat([x, latents], dim=1)
        k, v = self.to_kv(kv_in).chunk(2, dim=-1)
        k, v = _reshape(k, self.heads), _reshape(v, self.heads)
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(latents.shape[0], -1, self.to_out.in_features)
        return self.to_out(out) + latents


class TimeResampler(nn.Module):
    """Perceiver resampler with adaLN timestep conditioning.

    Architecture mirrors the image_proj section of
    InstantX/SD3.5-Large-IP-Adapter ip-adapter.bin so its weights load cleanly.
    """

    def __init__(
        self,
        dim: int = 1024,
        depth: int = 8,
        dim_head: int = 64,
        heads: int = 16,
        num_queries: int = 8,
        embedding_dim: int = 1152,  # SigLIP2 so400m
        output_dim: int = 1024,
        ff_mult: int = 4,
        timestep_in_dim: int = 320,
        timestep_flip_sin_to_cos: bool = True,
        timestep_freq_shift: int = 0,
    ):
        super().__init__()
        from diffusers.models.embeddings import Timesteps, TimestepEmbedding
        self.num_queries = num_queries
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
        self.t_emb = TimestepEmbedding(timestep_in_dim, dim)
        self.layers = nn.ModuleList([
            nn.ModuleList([
                _PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                _FeedForward(dim=dim, mult=ff_mult),
                nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim)),  # adaLN
            ])
            for _ in range(depth)
        ])
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        t = self.time_proj(timestep.flatten()).to(x.dtype)
        t_emb = self.t_emb(t)  # (B, dim)
        latents = self.latents.expand(x.size(0), -1, -1).clone()
        x = self.proj_in(x)
        for attn, ff, adaln in self.layers:
            s_msa, c_msa, s_mlp, c_mlp = adaln(t_emb).chunk(4, dim=-1)
            latents = latents * (1 + c_msa[:, None]) + s_msa[:, None]
            latents = attn(x, latents)
            latents = latents * (1 + c_mlp[:, None]) + s_mlp[:, None]
            latents = ff(latents) + latents
        latents = self.norm_out(self.proj_out(latents))
        return latents, t_emb


# ── Per-block attention processor ───────────────────────────────────────────────

class WanIPAttnProcessor:
    """Wraps an existing Attention processor and adds IP face KV injection.

    The IP keys/values are initialised from the model's own to_k / to_v weights
    (zero-shot reference-attention), so no separate IP training is needed.
    Conditioned frames attend to the face tokens in every self-attention block.
    """

    def __init__(
        self,
        original_processor,
        to_k_ip: nn.Linear,
        to_v_ip: nn.Linear,
        norm_k_ip: Optional[nn.Module] = None,
        norm_v_ip: Optional[nn.Module] = None,
        scale: float = 1.0,
    ):
        self.original = original_processor
        self.to_k_ip = to_k_ip
        self.to_v_ip = to_v_ip
        self.norm_k_ip = norm_k_ip
        self.norm_v_ip = norm_v_ip
        self.scale = scale
        # Set before each pipeline call; cleared after.
        self.ip_hidden_states: Optional[torch.Tensor] = None

    def __call__(self, attn, hidden_states, *args, **kwargs):
        out = self.original(attn, hidden_states, *args, **kwargs)

        if self.ip_hidden_states is None or self.scale == 0:
            return out

        hs = self.ip_hidden_states
        h = attn.heads
        # Compute Q from hidden_states (re-use the model's normalised projection)
        q = attn.to_q(hidden_states)
        if attn.norm_q is not None:
            q = attn.norm_q(q)

        # Compute IP K / V
        k_ip = self.to_k_ip(hs)
        v_ip = self.to_v_ip(hs)
        if self.norm_k_ip is not None:
            k_ip = self.norm_k_ip(k_ip)
        if self.norm_v_ip is not None:
            v_ip = self.norm_v_ip(v_ip)

        q = _reshape(q, h)
        k_ip = _reshape(k_ip, h)
        v_ip = _reshape(v_ip, h)

        ip_attn = F.scaled_dot_product_attention(q, k_ip, v_ip)
        ip_attn = ip_attn.transpose(1, 2).reshape(
            hidden_states.shape[0], -1, attn.inner_dim
        )
        ip_attn = attn.to_out[0](ip_attn)
        if len(attn.to_out) > 1:
            ip_attn = attn.to_out[1](ip_attn)

        return out + ip_attn * self.scale


# ── Main class ──────────────────────────────────────────────────────────────────

class WanIPAdapter:
    """Loads the IP-Adapter and patches pipe.transformer for face conditioning.

    Usage inside _init_pipeline():
        ip_adapter = WanIPAdapter(pipe, device=pipe.device, dtype=torch.bfloat16)

    Usage inside run_inference() (before pipe()):
        if face_ref is not None:
            emb = ip_adapter.encode(face_ref, timestep=500)
            ip_adapter.set_hidden_states(emb, scale=ip_scale)
        result = pipe(...)
        ip_adapter.clear_hidden_states()
    """

    _IP_ADAPTER_REPO = "InstantX/SD3.5-Large-IP-Adapter"
    _IP_ADAPTER_FILE = "ip-adapter.bin"
    _VISION_MODEL = "google/siglip-so400m-patch14-384"

    def __init__(
        self,
        pipe,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        cache_dir: str = "/data/ip_adapter",
    ):
        self.pipe = pipe
        self.device = device
        self.dtype = dtype

        self._load_vision_encoder()
        self._load_resampler(cache_dir)
        self._patch_transformer(pipe.transformer)
        print("[IP-Adapter] ready")

    # ── setup ───────────────────────────────────────────────────────────────────

    def _load_vision_encoder(self):
        print("[IP-Adapter] loading SigLIP vision encoder…")
        self.vis_proc = AutoProcessor.from_pretrained(self._VISION_MODEL)
        self.vis_model = SiglipVisionModel.from_pretrained(
            self._VISION_MODEL, torch_dtype=self.dtype
        ).to(self.device)
        self.vis_model.eval()
        print("[IP-Adapter] SigLIP loaded")

    def _load_resampler(self, cache_dir: str):
        print("[IP-Adapter] loading TimeResampler from SD3.5 ip-adapter.bin…")
        ckpt = hf_hub_download(
            repo_id=self._IP_ADAPTER_REPO,
            filename=self._IP_ADAPTER_FILE,
            local_dir=cache_dir,
        )
        state = torch.load(ckpt, map_location="cpu", weights_only=True)

        # Detect checkpoint key prefix (ip-adapter.bin uses "image_proj.*")
        prefix = "image_proj"
        img_proj = {
            k[len(prefix) + 1:]: v
            for k, v in state.items()
            if k.startswith(prefix + ".")
        }

        self.resampler = TimeResampler().to(self.device, self.dtype)
        missing, unexpected = self.resampler.load_state_dict(img_proj, strict=False)
        if missing:
            print(f"[IP-Adapter] resampler missing keys ({len(missing)}): {missing[:4]}…")
        print("[IP-Adapter] resampler loaded")

    def _patch_transformer(self, transformer: nn.Module):
        """Replace every self-attention processor with WanIPAttnProcessor."""
        self._processors: list[WanIPAttnProcessor] = []

        for name, mod in transformer.named_modules():
            if not (hasattr(mod, "processor") and hasattr(mod, "to_k")):
                continue

            # Build IP projections mirroring the model's own K/V projections
            to_k_ip = nn.Linear(
                self.resampler.proj_out.out_features,
                mod.to_k.out_features,
                bias=False,
            ).to(self.device, self.dtype)
            to_v_ip = nn.Linear(
                self.resampler.proj_out.out_features,
                mod.to_v.out_features,
                bias=False,
            ).to(self.device, self.dtype)

            # Zero-shot init: copy model's own projection weights then scale down
            # so the initial IP signal is small but directionally meaningful.
            k_w = mod.to_k.weight.data
            v_w = mod.to_v.weight.data
            out_f, in_f = to_k_ip.weight.shape
            # in_f = resampler output (1024); in_f may differ from k_w.shape[1]
            # – just use kaiming init if shapes differ
            if in_f == k_w.shape[1]:
                to_k_ip.weight.data.copy_(k_w[:out_f] * 0.01)
                to_v_ip.weight.data.copy_(v_w[:out_f] * 0.01)
            else:
                nn.init.kaiming_uniform_(to_k_ip.weight, a=math.sqrt(5))
                nn.init.kaiming_uniform_(to_v_ip.weight, a=math.sqrt(5))
                to_k_ip.weight.data *= 0.01
                to_v_ip.weight.data *= 0.01

            # Clone existing norms if present
            norm_k = mod.norm_k.__class__(mod.norm_k.normalized_shape[0]) \
                if hasattr(mod, "norm_k") and mod.norm_k is not None else None
            norm_v = mod.norm_v.__class__(mod.norm_v.normalized_shape[0]) \
                if hasattr(mod, "norm_v") and mod.norm_v is not None else None
            if norm_k is not None:
                norm_k = norm_k.to(self.device, self.dtype)
            if norm_v is not None:
                norm_v = norm_v.to(self.device, self.dtype)

            ip_proc = WanIPAttnProcessor(
                original_processor=mod.processor,
                to_k_ip=to_k_ip,
                to_v_ip=to_v_ip,
                norm_k_ip=norm_k,
                norm_v_ip=norm_v,
            )
            mod.processor = ip_proc
            self._processors.append(ip_proc)

        print(f"[IP-Adapter] patched {len(self._processors)} attention blocks")

    # ── inference API ──────────────────────────────────────────────────────────

    @torch.no_grad()
    def encode(self, image: Image.Image, timestep: int = 500) -> torch.Tensor:
        """Encode *image* through SigLIP2 + TimeResampler → (1, 8, 1024)."""
        inputs = self.vis_proc(images=image, return_tensors="pt").to(self.device)
        vis_out = self.vis_model(**inputs)
        # Use last_hidden_state (patch tokens) rather than pooled for richer features
        vis_feats = vis_out.last_hidden_state.to(self.dtype)  # (1, N, 1152)
        t = torch.tensor([timestep], device=self.device, dtype=torch.long)
        emb, _ = self.resampler(vis_feats, t)  # (1, 8, 1024)
        return emb

    def set_hidden_states(self, emb: torch.Tensor, scale: float = 0.6):
        """Broadcast *emb* to all processors before a pipe() call."""
        for p in self._processors:
            p.ip_hidden_states = emb
            p.scale = scale

    def clear_hidden_states(self):
        """Remove face embeddings after pipe() returns."""
        for p in self._processors:
            p.ip_hidden_states = None
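A quick shape walk-through of the resampler stage, runnable on CPU with random weights (no checkpoint download; the patch-token count is illustrative):

import torch
from ip_adapter import TimeResampler

# SigLIP2 so400m emits (1, N, 1152) patch tokens; the resampler compresses them
# into the (1, 8, 1024) face tokens that every WanIPAttnProcessor consumes as extra K/V.
resampler = TimeResampler()                 # defaults: 1152 -> 1024, 8 queries
vis_feats = torch.randn(1, 729, 1152)       # N = 729 patch tokens (illustrative count)
t = torch.tensor([500])                     # the fixed conditioning timestep used in encode()
face_tokens, t_emb = resampler(vis_feats, t)
print(face_tokens.shape, t_emb.shape)       # torch.Size([1, 8, 1024]) torch.Size([1, 1024])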
requirements.txt
CHANGED
@@ -15,3 +15,4 @@ sageattention
 torchvision
 insightface
 onnxruntime
+einops