ViTeX-Bench committed
Commit baabe40 · verified · 1 Parent(s): 9a431ad

Use bundled code + base_model paths (v2)

Files changed (1)
  1. inference_example.py +39 -36
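This commit replaces the runtime `snapshot_download` calls with paths resolved relative to the script. For orientation, the checkout it expects looks roughly like this (names taken from paths in the diff below; the tree itself is illustrative):

ViTeX-14B/
├── inference_example.py
├── vitex_14b.safetensors                      # fine-tuned VACE adapter (~8 GB)
├── diffsynth/                                 # bundled DiffSynth-Studio-TextVACE fork
└── base_model/                                # Wan-AI/Wan2.1-VACE-14B snapshot (~60 GB)
    ├── diffusion_pytorch_model-*.safetensors  # DiT shards
    ├── models_t5_umt5-xxl-enc-bf16.pth        # T5 text encoder
    ├── Wan2.1_VAE.pth                         # video VAE
    └── google/umt5-xxl/                       # tokenizer

Per the error message added in `build_pipeline`, either `git lfs clone` or `huggingface-cli download ViTeX-Bench/ViTeX-14B` should produce this layout.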
inference_example.py CHANGED
@@ -1,17 +1,9 @@
 """
-ViTeX-14B inference example.
-
-Loads:
-- Wan-AI/Wan2.1-VACE-14B (base model)
-- ViTeX-Bench/ViTeX-14B (this fine-tuned VACE module)
-
-Runs one or more video text-edit jobs, writing MP4 outputs.
-
-Requires:
-- The DiffSynth-Studio-TextVACE fork (provides GlyphEncoder + ConditionCrossAttention)
-- torch >= 2.7.0+cu128 (NCCL >= 2.25.1 recommended on H100)
-- One NVIDIA GPU with >= 80 GB VRAM (H100 / A100 80 GB)
-- imageio-ffmpeg, opencv-python
+ViTeX-14B inference example (self-contained).
+
+Assumes you cloned this HuggingFace repo and are running this script from the
+repo root. The bundled `diffsynth/` library, `vitex_14b.safetensors` weights,
+and the full `base_model/` directory are picked up automatically.
 
 Usage:
   python inference_example.py \
@@ -20,21 +12,32 @@ Usage:
     --glyph_video path/to/target_glyph.mp4 \
     --prompt "Change the sign to read 'HILTON'" \
     --output out.mp4
+
+Hardware:
+- 1 × NVIDIA GPU with >= 80 GB VRAM (peak ~70 GB at 720 × 1280 × 121 frames)
+- ~250 GB CPU RAM recommended (DiT loading + activation offload)
 """
 
 import os
+import sys
 import argparse
 import glob
 
+# Use the bundled diffsynth shipped alongside this script.
+HERE = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, HERE)
+
 import torch
 from PIL import Image
 
-from huggingface_hub import snapshot_download
-
 from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
 from diffsynth.core import load_state_dict
 
 
+BASE_DIR = os.path.join(HERE, "base_model")
+ADAPTER_CKPT = os.path.join(HERE, "vitex_14b.safetensors")
+TOKENIZER_DIR = os.path.join(BASE_DIR, "google", "umt5-xxl")
+
 HEIGHT = 720
 WIDTH = 1280
 NUM_FRAMES = 121
@@ -44,7 +47,8 @@ SEED = 42
 
 
 def load_video_frames(path, target_frames=NUM_FRAMES, resize=(HEIGHT, WIDTH)):
-    """Load a video file into a list of PIL Images, optionally subsampling/padding."""
+    """Load a video file into a list of PIL Images, sub-sampled or padded to
+    `target_frames`, optionally resized to `(H, W)`."""
     import cv2
     cap = cv2.VideoCapture(path)
     frames = []
@@ -60,7 +64,6 @@ def load_video_frames(path, target_frames=NUM_FRAMES, resize=(HEIGHT, WIDTH)):
 
     if not frames:
         raise ValueError(f"empty video: {path}")
-
     if target_frames and len(frames) > target_frames:
         import numpy as np
         idx = np.linspace(0, len(frames) - 1, target_frames, dtype=int)
@@ -93,21 +96,31 @@ def save_video(frames, path, fps=24):
     proc.wait()
 
 
-def build_pipeline(base_dir, ckpt_path, device="cuda:0"):
-    diffusion_shards = sorted(glob.glob(os.path.join(base_dir, "diffusion_pytorch_model-*.safetensors")))
+def build_pipeline(device="cuda:0"):
+    diffusion_shards = sorted(glob.glob(os.path.join(BASE_DIR, "diffusion_pytorch_model-*.safetensors")))
+    if not diffusion_shards:
+        raise FileNotFoundError(
+            f"No diffusion_pytorch_model-*.safetensors found under {BASE_DIR}. "
+            "Make sure you downloaded the full repo via `git lfs clone` or "
+            "`huggingface-cli download ViTeX-Bench/ViTeX-14B`."
+        )
+    if not os.path.isfile(ADAPTER_CKPT):
+        raise FileNotFoundError(f"Missing trained adapter: {ADAPTER_CKPT}")
+
     pipe = WanVideoPipeline.from_pretrained(
         torch_dtype=torch.bfloat16,
         device=device,
         model_configs=[
             ModelConfig(path=diffusion_shards),
-            ModelConfig(path=os.path.join(base_dir, "models_t5_umt5-xxl-enc-bf16.pth")),
-            ModelConfig(path=os.path.join(base_dir, "Wan2.1_VAE.pth")),
+            ModelConfig(path=os.path.join(BASE_DIR, "models_t5_umt5-xxl-enc-bf16.pth")),
+            ModelConfig(path=os.path.join(BASE_DIR, "Wan2.1_VAE.pth")),
         ],
-        tokenizer_config=ModelConfig(path=os.path.join(base_dir, "google/umt5-xxl")),
+        tokenizer_config=ModelConfig(path=TOKENIZER_DIR),
        redirect_common_files=False,
     )
-    print(f"Loading ViTeX-14B weights from {ckpt_path}")
-    state = load_state_dict(ckpt_path)
+
+    print(f"Loading ViTeX-14B trained weights from {ADAPTER_CKPT}")
+    state = load_state_dict(ADAPTER_CKPT)
     res = pipe.vace.load_state_dict(state, strict=False)
     print(f"  loaded {len(state)} keys (missing {len(res.missing_keys)}, unexpected {len(res.unexpected_keys)})")
     del state
@@ -116,9 +129,9 @@ def build_pipeline(base_dir, ckpt_path, device="cuda:0"):
 
 def main():
     p = argparse.ArgumentParser()
-    p.add_argument("--vace_video", required=True, help="Source RGB video (the one to edit).")
+    p.add_argument("--vace_video", required=True, help="Source RGB video to edit.")
     p.add_argument("--vace_mask", required=True, help="Per-frame binary mask: 1=replace, 0=keep.")
-    p.add_argument("--glyph_video", required=True, help="Pre-rendered target glyphs placed in the mask region.")
+    p.add_argument("--glyph_video", required=True, help="Pre-rendered target glyphs in the mask region.")
     p.add_argument("--prompt", default="", help="Optional text prompt describing the edit.")
     p.add_argument("--output", default="output.mp4")
     p.add_argument("--height", type=int, default=HEIGHT)
@@ -130,23 +143,13 @@ def main():
     p.add_argument("--device", default="cuda:0")
     args = p.parse_args()
 
-    # 1. Download base + this model
-    print("Downloading Wan-AI/Wan2.1-VACE-14B (base, ~60 GB)...")
-    base_dir = snapshot_download("Wan-AI/Wan2.1-VACE-14B")
-    print("Downloading ViTeX-Bench/ViTeX-14B (this model, ~8 GB)...")
-    vitex_dir = snapshot_download("ViTeX-Bench/ViTeX-14B")
-    ckpt_path = os.path.join(vitex_dir, "vitex_14b.safetensors")
-
-    # 2. Build pipeline
-    pipe = build_pipeline(base_dir, ckpt_path, device=args.device)
+    pipe = build_pipeline(device=args.device)
 
-    # 3. Load inputs
     target_size = (args.height, args.width)
     vace_video = load_video_frames(args.vace_video, args.num_frames, target_size)
     vace_mask = load_video_frames(args.vace_mask, args.num_frames, target_size)
     glyph = load_video_frames(args.glyph_video, args.num_frames, target_size)
 
-    # 4. Run
     print(f"Running pipeline (seed={args.seed}, cfg={args.cfg_scale}, steps={args.num_inference_steps})...")
     out_frames = pipe(
         prompt=args.prompt,
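The three inputs must be frame-aligned videos at the same resolution: `--vace_mask` marks the pixels to replace (white = replace, black = keep, matching the `1=replace, 0=keep` help text) and `--glyph_video` carries the target text pre-rendered inside that region. A minimal smoke-test generator might look like the sketch below; the box position, font, text, and codec are assumptions for illustration, not values from this commit.

# Hypothetical helper (not part of the repo): writes a constant-box mask video
# and a matching glyph video so you can exercise the CLI end to end.
import cv2
import numpy as np

H, W, N, FPS = 720, 1280, 121, 24
box = (440, 300, 840, 420)  # x0, y0, x1, y1 of the region to replace (arbitrary)

fourcc = cv2.VideoWriter_fourcc(*"mp4v")
mask_out = cv2.VideoWriter("mask.mp4", fourcc, FPS, (W, H))
glyph_out = cv2.VideoWriter("glyph.mp4", fourcc, FPS, (W, H))

for _ in range(N):
    # Mask frame: white filled rectangle over the replace region.
    mask = np.zeros((H, W, 3), np.uint8)
    cv2.rectangle(mask, box[:2], box[2:], (255, 255, 255), -1)
    mask_out.write(mask)

    # Glyph frame: target text rendered inside the same region.
    glyph = np.zeros((H, W, 3), np.uint8)
    cv2.putText(glyph, "HILTON", (box[0] + 20, box[3] - 30),
                cv2.FONT_HERSHEY_SIMPLEX, 3.0, (255, 255, 255), 8)
    glyph_out.write(glyph)

mask_out.release()
glyph_out.release()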
 
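For reference, the frame-count handling promised by the new `load_video_frames` docstring works roughly as below. The subsample branch mirrors the `np.linspace` logic visible in the diff; the padding branch is not shown in the hunks, so repeating the last frame is an assumption.

# Sketch of the frame-count normalization described in the new docstring.
import numpy as np

def normalize_frame_count(frames, target_frames):
    if target_frames and len(frames) > target_frames:
        # Evenly spaced indices over the source, as in the diff's subsample path.
        idx = np.linspace(0, len(frames) - 1, target_frames, dtype=int)
        frames = [frames[i] for i in idx]
    elif target_frames and len(frames) < target_frames:
        # Assumed padding strategy: repeat the final frame up to target length.
        frames = frames + [frames[-1]] * (target_frames - len(frames))
    return frames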