Aduc-sdr committed on
Commit
0c097db
·
verified ·
1 Parent(s): 2d69166

Update managers/seedvr_manager.py

Browse files
Files changed (1) hide show
  1. managers/seedvr_manager.py +38 -36
managers/seedvr_manager.py CHANGED
@@ -2,11 +2,11 @@
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
- # Version: 2.3.4
6
  #
7
- # This version is optimized for Hugging Face Spaces environments. It now clones
8
- # the dependency directly from the official SeedVR HF Space, which is faster,
9
- # lighter, and more reliable than cloning from GitHub.
10
 
11
  import torch
12
  import torch.distributed as dist
@@ -22,17 +22,15 @@ import gradio as gr
22
  import mediapy
23
  from einops import rearrange
24
 
25
- # Internalized utility for color correction, ensuring stability.
26
  from tools.tensor_utils import wavelet_reconstruction
27
 
28
  logger = logging.getLogger(__name__)
29
 
30
  # --- Dependency Management ---
31
  DEPS_DIR = Path("./deps")
32
- # Renamed to reflect the new source
33
  SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
34
- # NEW: Cloning from the HF Space directly is much more efficient
35
  SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
 
36
 
37
  def setup_seedvr_dependencies():
38
  """
@@ -42,7 +40,6 @@ def setup_seedvr_dependencies():
42
  logger.info(f"SeedVR Space not found at '{SEEDVR_SPACE_DIR}'. Cloning from Hugging Face...")
43
  try:
44
  DEPS_DIR.mkdir(exist_ok=True)
45
- # We clone the entire space repo to get its file structure
46
  subprocess.run(
47
  ["git", "clone", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
48
  check=True, capture_output=True, text=True
@@ -60,8 +57,8 @@ def setup_seedvr_dependencies():
60
 
61
  setup_seedvr_dependencies()
62
 
63
- # The imports from a Space are often directly from the root
64
- from infer import VideoDiffusionInfer
65
  from common.config import load_config
66
  from common.seed import set_seed
67
  from data.image.transforms.divisible_crop import DivisibleCrop
@@ -71,6 +68,7 @@ from torchvision.transforms import Compose, Lambda, Normalize
71
  from torchvision.io.video import read_video
72
  from omegaconf import OmegaConf
73
 
 
74
  def _load_file_from_url(url, model_dir='./', file_name=None):
75
  os.makedirs(model_dir, exist_ok=True)
76
  filename = file_name or os.path.basename(urlparse(url).path)
@@ -90,11 +88,14 @@ class SeedVrManager:
90
  self._original_barrier = None
91
  logger.info("SeedVrManager initialized. Model will be loaded on demand.")
92
 
93
- def _download_models(self):
94
- """Downloads the necessary checkpoints for SeedVR2."""
95
- logger.info("Verifying and downloading SeedVR2 models...")
96
- ckpt_dir = SEEDVR_SPACE_DIR / 'ckpt' # Note: Path in Space repo might be different
 
97
  ckpt_dir.mkdir(exist_ok=True)
 
 
98
  pretrain_model_urls = {
99
  'vae_ckpt': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
100
  'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
@@ -104,41 +105,42 @@ class SeedVrManager:
104
  }
105
  for key, url in pretrain_model_urls.items():
106
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
107
- logger.info("SeedVR2 models downloaded successfully.")
108
 
109
  def _initialize_runner(self, model_version: str):
110
- """Loads and configures the SeedVR model."""
111
  if self.runner is not None: return
112
- self._download_models()
113
-
114
  if dist.is_available() and not dist.is_initialized():
115
  logger.info("Applying patch to disable torch.distributed.barrier for single-GPU inference.")
116
  self._original_barrier = dist.barrier
117
  dist.barrier = lambda *args, **kwargs: None
118
 
119
- logger.info(f"Initializing SeedVR2 {model_version} runner from Space repo...")
120
  if model_version == '3B':
121
- config_path = SEEDVR_SPACE_DIR / 'configs' / 'generate.yaml' # Typical path in a Space
122
- checkpoint_path = SEEDVR_SPACE_DIR / 'ckpt' / 'VINCIE-3B' / 'dit.pth'
123
  elif model_version == '7B':
124
- # Assuming a similar structure for a 7B space if it existed
125
- config_path = SEEDVR_SPACE_DIR / 'configs' / 'generate_7b.yaml'
126
- checkpoint_path = SEEDVR_SPACE_DIR / 'ckpt' / 'VINCIE-7B' / 'dit.pth'
127
  else:
128
  raise ValueError(f"Unsupported SeedVR model version: {model_version}")
129
 
130
- config = load_config(str(config_path))
131
-
 
 
 
 
 
 
 
 
132
  self.runner = VideoDiffusionInfer(config)
133
  OmegaConf.set_readonly(self.runner.config, False)
134
- # Manually set the correct checkpoint paths since the config inside the space might be relative
135
- self.runner.config.dit.checkpoint = str(checkpoint_path)
136
- self.runner.config.vae.checkpoint = str(SEEDVR_SPACE_DIR / 'ckpt' / 'VINCIE-3B' / 'vae.pth')
137
- self.runner.config.text.models[0].path = str(SEEDVR_SPACE_DIR / 'ckpt' / 'VINCIE-3B' / 'llm14b')
138
-
139
- self.runner.configure_dit_model(device=self.device, checkpoint=self.runner.config.dit.checkpoint)
140
  self.runner.configure_vae_model()
141
-
142
  if hasattr(self.runner.vae, "set_memory_limit"):
143
  self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
144
  self.is_initialized = True
@@ -181,8 +183,8 @@ class SeedVrManager:
181
  cond_latents = self.runner.vae_encode(cond_latents)
182
  self.runner.vae.to("cpu"); gc.collect(); torch.cuda.empty_cache()
183
  self.runner.dit.to(self.device)
184
- pos_emb_path = SEEDVR_SPACE_DIR / 'ckpt' / 'pos_emb.pt'
185
- neg_emb_path = SEEDVR_SPACE_DIR / 'ckpt' / 'neg_emb.pt'
186
  text_pos_embeds = torch.load(pos_emb_path).to(self.device)
187
  text_neg_embeds = torch.load(neg_emb_path).to(self.device)
188
  text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}
@@ -208,4 +210,4 @@ class SeedVrManager:
208
  self._unload_runner()
209
 
210
  # --- Singleton Instance ---
211
- seedvr_manager_singleton = SeedVrManager()
 
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
+ # Version: 2.3.5
6
  #
7
+ # This version uses the optimal strategy of cloning the self-contained Hugging Face
8
+ # Space repository and uses the full, correct import paths to resolve all
9
+ # ModuleNotFoundErrors, while retaining necessary runtime patches.
10
 
11
  import torch
12
  import torch.distributed as dist
 
22
  import mediapy
23
  from einops import rearrange
24
 
 
25
  from tools.tensor_utils import wavelet_reconstruction
26
 
27
  logger = logging.getLogger(__name__)
28
 
29
  # --- Dependency Management ---
30
  DEPS_DIR = Path("./deps")
 
31
  SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
 
32
  SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
33
+ VAE_CONFIG_URL = "https://raw.githubusercontent.com/ByteDance-Seed/SeedVR/main/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml"
34
 
35
  def setup_seedvr_dependencies():
36
  """
 
40
  logger.info(f"SeedVR Space not found at '{SEEDVR_SPACE_DIR}'. Cloning from Hugging Face...")
41
  try:
42
  DEPS_DIR.mkdir(exist_ok=True)
 
43
  subprocess.run(
44
  ["git", "clone", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
45
  check=True, capture_output=True, text=True
 
57
 
58
  setup_seedvr_dependencies()
59
 
60
+ # Use full import paths relative to the root of the cloned repository
61
+ from projects.video_diffusion_sr.infer import VideoDiffusionInfer
62
  from common.config import load_config
63
  from common.seed import set_seed
64
  from data.image.transforms.divisible_crop import DivisibleCrop
 
68
  from torchvision.io.video import read_video
69
  from omegaconf import OmegaConf
70
 
71
+
72
  def _load_file_from_url(url, model_dir='./', file_name=None):
73
  os.makedirs(model_dir, exist_ok=True)
74
  filename = file_name or os.path.basename(urlparse(url).path)
 
88
  self._original_barrier = None
89
  logger.info("SeedVrManager initialized. Model will be loaded on demand.")
90
 
91
+ def _download_models_and_configs(self):
92
+ """Downloads the necessary checkpoints AND the missing VAE config file."""
93
+ logger.info("Verifying and downloading SeedVR2 models and configs...")
94
+ ckpt_dir = SEEDVR_SPACE_DIR / 'ckpts'
95
+ config_dir = SEEDVR_SPACE_DIR / 'configs' / 'vae'
96
  ckpt_dir.mkdir(exist_ok=True)
97
+ config_dir.mkdir(parents=True, exist_ok=True)
98
+ _load_file_from_url(url=VAE_CONFIG_URL, model_dir=str(config_dir))
99
  pretrain_model_urls = {
100
  'vae_ckpt': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
101
  'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
 
105
  }
106
  for key, url in pretrain_model_urls.items():
107
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
108
+ logger.info("SeedVR2 models and configs downloaded successfully.")
109
 
110
  def _initialize_runner(self, model_version: str):
111
+ """Loads and configures the SeedVR model, with patches for single-GPU inference."""
112
  if self.runner is not None: return
113
+ self._download_models_and_configs()
114
+
115
  if dist.is_available() and not dist.is_initialized():
116
  logger.info("Applying patch to disable torch.distributed.barrier for single-GPU inference.")
117
  self._original_barrier = dist.barrier
118
  dist.barrier = lambda *args, **kwargs: None
119
 
120
+ logger.info(f"Initializing SeedVR2 {model_version} runner...")
121
  if model_version == '3B':
122
+ config_path = SEEDVR_SPACE_DIR / 'configs_3b' / 'main.yaml'
123
+ checkpoint_path = SEEDVR_SPACE_DIR / 'ckpts' / 'seedvr2_ema_3b.pth'
124
  elif model_version == '7B':
125
+ config_path = SEEDVR_SPACE_DIR / 'configs_7b' / 'main.yaml'
126
+ checkpoint_path = SEEDVR_SPACE_DIR / 'ckpts' / 'seedvr2_ema_7b.pth'
 
127
  else:
128
  raise ValueError(f"Unsupported SeedVR model version: {model_version}")
129
 
130
+ try:
131
+ config = load_config(str(config_path))
132
+ except FileNotFoundError:
133
+ logger.warning("Caught expected FileNotFoundError. Loading config manually.")
134
+ config = OmegaConf.load(str(config_path))
135
+ correct_vae_config_path = SEEDVR_SPACE_DIR / 'configs' / 'vae' / 's8_c16_t4_inflation_sd3.yaml'
136
+ vae_config = OmegaConf.load(str(correct_vae_config_path))
137
+ config.vae = vae_config
138
+ logger.info("Configuration loaded and patched manually.")
139
+
140
  self.runner = VideoDiffusionInfer(config)
141
  OmegaConf.set_readonly(self.runner.config, False)
142
+ self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
 
 
 
 
 
143
  self.runner.configure_vae_model()
 
144
  if hasattr(self.runner.vae, "set_memory_limit"):
145
  self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
146
  self.is_initialized = True
 
183
  cond_latents = self.runner.vae_encode(cond_latents)
184
  self.runner.vae.to("cpu"); gc.collect(); torch.cuda.empty_cache()
185
  self.runner.dit.to(self.device)
186
+ pos_emb_path = SEEDVR_SPACE_DIR / 'ckpts' / 'pos_emb.pt'
187
+ neg_emb_path = SEEDVR_SPACE_DIR / 'ckpts' / 'neg_emb.pt'
188
  text_pos_embeds = torch.load(pos_emb_path).to(self.device)
189
  text_neg_embeds = torch.load(neg_emb_path).to(self.device)
190
  text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}
 
210
  self._unload_runner()
211
 
212
  # --- Singleton Instance ---
213
+ seedvr_manager_singleton = SeedVrManager()