# Hugging Face Space status: running on a T4 GPU.
| """ | |
| Model Manager for real-time motion generation (HF Space version) | |
| Loads model from Hugging Face Hub instead of local checkpoints. | |
| """ | |
| import threading | |
| import time | |
| from collections import deque | |
| import numpy as np | |
| import torch | |
| from motion_process import StreamJointRecovery263 | |
class FrameBuffer:
    """
    Thread-safe FIFO of generated frames.

    The generation thread appends frames at the back and consumers pop them
    from the front; every operation holds an internal lock. The queue is
    bounded: once ``max_size`` frames are queued, each new append silently
    discards the oldest frame.
    """

    def __init__(self, target_buffer_size=4, max_size=100):
        """
        Args:
            target_buffer_size: Fill level below which needs_generation()
                reports that more frames should be produced.
            max_size: Hard capacity of the queue (defaults to the previously
                hard-coded value of 100 frames).
        """
        self.buffer = deque(maxlen=max_size)
        self.target_size = target_buffer_size
        self.lock = threading.Lock()

    def add_frame(self, joints):
        """Append a frame at the back (drops the oldest frame when full)."""
        with self.lock:
            self.buffer.append(joints)

    def get_frame(self):
        """Pop and return the oldest frame, or None if the buffer is empty."""
        with self.lock:
            if self.buffer:
                return self.buffer.popleft()
            return None

    def size(self):
        """Return the number of frames currently queued."""
        with self.lock:
            return len(self.buffer)

    def clear(self):
        """Discard all queued frames."""
        with self.lock:
            self.buffer.clear()

    def needs_generation(self):
        """Return True while fewer than ``target_size`` frames are queued."""
        return self.size() < self.target_size
class ModelManager:
    """
    Manages model loading from HF Hub and real-time frame generation.

    A single background thread (``_generation_loop``) keeps ``frame_buffer``
    topped up for the active session. It also mirrors every frame, tagged
    with an increasing integer ID, into a bounded broadcast deque that
    spectators poll through ``get_broadcast_frames``.
    """

    # Seconds pause_generation() waits for the worker before giving up.
    _STOP_TIMEOUT_S = 2.0

    def __init__(self, model_name):
        """Load the models and set up buffers and streaming state.

        Args:
            model_name: Hugging Face Hub repo id passed to
                ``snapshot_download`` and ``AutoModel.from_pretrained``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        # Load models from HF Hub
        self.vae, self.model = self._load_models(model_name)
        # Snapshot the model's initial hyper-parameters so every reset path
        # can restore them (HF model exposes them as individual attributes).
        self._base_schedule_config = {
            "chunk_size": self.model.chunk_size,
            "steps": self.model.noise_steps,
        }
        self._base_cfg_config = {
            "cfg_scale": self.model.cfg_scale,
        }
        # Frame buffer (for active session)
        self.frame_buffer = FrameBuffer(target_buffer_size=16)
        # Broadcast buffer (for spectators): (frame_id, joints) pairs with
        # increasing IDs; only the most recent 200 frames are kept.
        self.broadcast_frames = deque(maxlen=200)
        self.broadcast_id = 0
        self.broadcast_lock = threading.Lock()
        # Stream joint recovery with smoothing
        self.smoothing_alpha = 0.5  # Default: medium smoothing
        self.stream_recovery = StreamJointRecovery263(
            joints_num=22, smoothing_alpha=self.smoothing_alpha
        )
        # Generation state
        self.current_text = ""
        self.is_generating = False
        self.generation_thread = None
        self.should_stop = False
        # Model generation state
        self.first_chunk = True  # For VAE stream_decode
        self._model_first_chunk = True  # For model stream_generate_step
        self.history_length = 30
        print("ModelManager initialized successfully")

    def _patch_attention_sdpa(self, model_name, snapshot_dir=None):
        """Patch flash_attention() to include SDPA fallback for GPUs without flash-attn (e.g., T4).

        Rewrites the remote-code file ``ldf_models/tools/attention.py`` on disk
        so that, when neither FLASH_ATTN_2_AVAILABLE nor FLASH_ATTN_3_AVAILABLE
        is set, flash_attention() falls back to
        ``torch.nn.functional.scaled_dot_product_attention``. Must run before
        transformers imports the remote code. Files that already contain the
        fallback are left alone, so repeated calls are harmless.

        Args:
            model_name: Hub repo id, used to build the default cache paths
                under ``~/.cache/huggingface``.
            snapshot_dir: Optional local snapshot directory (e.g. the value
                returned by ``snapshot_download``). It is searched in addition
                to the default locations, so a relocated hub cache is patched.
        """
        import glob
        import os

        hf_cache = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
        patterns = [
            os.path.join(
                hf_cache, "hub", "models--" + model_name.replace("/", "--"),
                "snapshots", "*", "ldf_models", "tools", "attention.py",
            ),
            os.path.join(
                hf_cache, "modules", "transformers_modules", model_name,
                "*", "ldf_models", "tools", "attention.py",
            ),
        ]
        if snapshot_dir is not None:
            patterns.insert(
                0, os.path.join(snapshot_dir, "ldf_models", "tools", "attention.py")
            )

        # Use the assert + next line as target to ensure idempotent patching.
        # NOTE: these strings must match the upstream attention.py exactly,
        # including its 4-space indentation inside flash_attention().
        # NOTE(review): the fallback only uses q, k, v, k_lens, causal,
        # dropout_p and dtype; any other flash_attention() options (e.g. a
        # custom softmax scale or q_lens, if the upstream signature has them)
        # are ignored - confirm against upstream.
        target = (
            '    assert q.device.type == "cuda" and q.size(-1) <= 256\n'
            "\n"
            "    # params\n"
        )
        replacement = (
            '    assert q.device.type == "cuda" and q.size(-1) <= 256\n'
            "\n"
            "    # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
            "    if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
            "        out_dtype = q.dtype\n"
            "        b, lq, nq, c = q.shape\n"
            "        lk = k.size(1)\n"
            "        q = q.transpose(1, 2).to(dtype)\n"
            "        k = k.transpose(1, 2).to(dtype)\n"
            "        v = v.transpose(1, 2).to(dtype)\n"
            "        attn_mask = None\n"
            "        is_causal_flag = causal\n"
            "        if k_lens is not None:\n"
            "            k_lens = k_lens.to(q.device)\n"
            "            valid = torch.arange(lk, device=q.device).unsqueeze(0) < k_lens.unsqueeze(1)\n"
            "            attn_mask = torch.where(valid[:, None, None, :], 0.0, float('-inf')).to(dtype=dtype)\n"
            "            is_causal_flag = False\n"
            "            if causal:\n"
            "                cm = torch.triu(torch.ones(lq, lk, device=q.device, dtype=torch.bool), diagonal=1)\n"
            "                attn_mask = attn_mask.masked_fill(cm[None, None, :, :], float('-inf'))\n"
            "        out = torch.nn.functional.scaled_dot_product_attention(\n"
            "            q, k, v, attn_mask=attn_mask, is_causal=is_causal_flag, dropout_p=dropout_p\n"
            "        )\n"
            "        return out.transpose(1, 2).contiguous().to(out_dtype)\n"
            "\n"
            "    # params\n"
        )

        seen = set()  # realpaths already handled (snapshot files may be symlinks)
        for pattern in patterns:
            for filepath in glob.glob(pattern):
                real_path = os.path.realpath(filepath)
                if real_path in seen:
                    continue
                seen.add(real_path)
                with open(filepath, "r", encoding="utf-8") as f:
                    content = f.read()
                if "SDPA fallback" in content:
                    print(f"Already patched: {filepath}")
                    continue
                if target not in content:
                    # Upstream file changed shape; leave it untouched but say so.
                    print(f"Warning: SDPA patch target not found, not patched: {filepath}")
                    continue
                content = content.replace(target, replacement, 1)
                with open(filepath, "w", encoding="utf-8") as f:
                    f.write(content)
                print(f"Patched with SDPA fallback: {filepath}")
        if not seen:
            print(f"Warning: no attention.py found to patch for {model_name}")

    def _load_models(self, model_name):
        """Load VAE and diffusion models from HF Hub.

        Returns:
            (vae, model): the ``vae`` and ``ldf_model`` components of the
            loaded HF model, both switched to eval mode.
        """
        torch.set_float32_matmul_precision("high")
        # Pre-download model files to hub cache
        print(f"Downloading model from HF Hub: {model_name}")
        from huggingface_hub import snapshot_download

        snapshot_dir = snapshot_download(model_name)
        # Patch flash_attention with SDPA fallback for T4 (no flash-attn);
        # this must happen before transformers imports the remote code.
        self._patch_attention_sdpa(model_name, snapshot_dir=snapshot_dir)
        print("Loading model...")
        from transformers import AutoModel

        hf_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        hf_model.to(self.device)
        # Trigger lazy loading / warmup. The output is discarded, so no
        # autograd graph is needed.
        print("Warming up model...")
        with torch.no_grad():
            _ = hf_model("test", length=1)
        # Access underlying streaming components
        model = hf_model.ldf_model
        vae = hf_model.vae
        model.eval()
        vae.eval()
        print("Models loaded successfully")
        return vae, model

    def _restore_base_params(self):
        """Restore chunk_size / noise_steps / cfg_scale to their load-time values."""
        self.model.chunk_size = self._base_schedule_config["chunk_size"]
        self.model.noise_steps = self._base_schedule_config["steps"]
        self.model.cfg_scale = self._base_cfg_config["cfg_scale"]

    def _wait_for_worker_exit(self):
        """Block until a previously started generation thread has exited.

        pause_generation() only waits a bounded time. If a slow step outlived
        that wait, the old thread is still running. Clearing should_stop
        before it exits would let it keep looping next to a new thread, and
        resetting state under it would race. So wait for it here first.
        """
        worker = self.generation_thread
        if worker is None or worker is threading.current_thread():
            return
        if worker.is_alive():
            self.should_stop = True
            worker.join()

    def _start_worker(self):
        """Spawn a fresh daemon thread running _generation_loop."""
        self.should_stop = False
        self.generation_thread = threading.Thread(target=self._generation_loop)
        self.generation_thread.daemon = True
        self.generation_thread.start()
        self.is_generating = True

    def start_generation(self, text, history_length=None):
        """Start or update generation with new text.

        If generation is not running, all streaming state is reset and a new
        worker thread is started. Otherwise only the prompt is updated; a new
        history_length is stored but only takes effect on the next reset.
        """
        self.current_text = text
        if history_length is not None:
            self.history_length = history_length
        if self.is_generating:
            return
        # Never reset model state while a stale worker could still touch it.
        self._wait_for_worker_exit()
        # Reset state before starting (only once at the beginning)
        self.frame_buffer.clear()
        self.stream_recovery.reset()
        self.vae.clear_cache()
        self.first_chunk = True
        self._model_first_chunk = True
        self._restore_base_params()
        self.model.init_generated(self.history_length, batch_size=1)
        print(f"Model initialized with history length: {self.history_length}")
        self._start_worker()

    def update_text(self, text):
        """Update text without resetting state (continuous generation with new text)."""
        if text == self.current_text:
            return
        old_text = self.current_text
        self.current_text = text
        # Don't reset first_chunk, stream_recovery, or vae cache:
        # this allows continuous generation across text changes.
        print(f"Text updated: '{old_text}' -> '{text}' (continuous generation)")

    def pause_generation(self):
        """Pause generation (keeps all state).

        Signals the worker to stop and waits up to _STOP_TIMEOUT_S seconds.
        If the worker is still mid-step afterwards, it exits at the end of
        that step. Later restarts/resets wait for it (see
        _wait_for_worker_exit).
        """
        self.should_stop = True
        if self.generation_thread:
            self.generation_thread.join(timeout=self._STOP_TIMEOUT_S)
            if self.generation_thread.is_alive():
                print("Warning: generation thread still finishing its current step")
        self.is_generating = False
        print("Generation paused (state preserved)")

    def resume_generation(self):
        """Resume generation from paused state (no state is reset)."""
        if self.is_generating:
            print("Already generating, ignoring resume")
            return
        # Restart generation thread with existing state
        self._wait_for_worker_exit()
        self._start_worker()
        print("Generation resumed")

    def reset(self, history_length=None, smoothing_alpha=None):
        """Reset generation state completely.

        Stops any running generation; it is not restarted. Then clears the
        frame buffer and VAE cache, resets joint recovery, restores the
        model's base hyper-parameters and re-initialises its history. The
        spectator broadcast buffer is left untouched.

        Args:
            history_length: History window length for the model; if None the
                current value is kept.
            smoothing_alpha: EMA smoothing factor, clipped to [0.0, 1.0]; if
                None the current value is kept (initially 0.5).
                - 1.0 = no smoothing
                - 0.0 = infinite smoothing
                - Recommended: 0.3-0.7 for visible smoothing
        """
        # Stop if running, and make sure the worker is really gone before
        # touching model / VAE state.
        if self.is_generating:
            self.pause_generation()
        self._wait_for_worker_exit()
        # Clear everything
        self.frame_buffer.clear()
        self.vae.clear_cache()
        self.first_chunk = True
        if history_length is not None:
            self.history_length = history_length
        if smoothing_alpha is not None:
            self.smoothing_alpha = np.clip(smoothing_alpha, 0.0, 1.0)
            print(f"Smoothing alpha updated to: {self.smoothing_alpha}")
            # Recreate stream recovery with new smoothing alpha
            self.stream_recovery = StreamJointRecovery263(
                joints_num=22, smoothing_alpha=self.smoothing_alpha
            )
        else:
            # Mirror start_generation(): reset joint recovery so it does not
            # carry streaming state over from the previous motion.
            self.stream_recovery.reset()
        self._restore_base_params()
        self._model_first_chunk = True
        # Initialize model
        self.model.init_generated(self.history_length, batch_size=1)
        print(
            f"Model reset - history: {self.history_length}, smoothing: {self.smoothing_alpha}"
        )

    def _publish_frame(self, joints):
        """Queue joints for the session and append them, with a new ID, to the broadcast buffer."""
        self.frame_buffer.add_frame(joints)
        with self.broadcast_lock:
            self.broadcast_id += 1
            self.broadcast_frames.append((self.broadcast_id, joints))

    def _generation_loop(self):
        """Main generation loop that runs in background thread.

        Runs until should_stop is set. Whenever the frame buffer is below its
        target, it:
        - performs one model step;
        - decodes the committed output with the VAE;
        - converts each decoded frame to joints and publishes it.
        Exceptions are logged and the loop retries after a short sleep.
        """
        print("Generation loop started")
        step_count = 0
        total_gen_time = 0.0
        with torch.no_grad():
            while not self.should_stop:
                if not self.frame_buffer.needs_generation():
                    # Buffer is full, wait a bit
                    time.sleep(0.01)
                    continue
                try:
                    step_start = time.perf_counter()
                    # Generate one token (produces frames from VAE)
                    x = {"text": [self.current_text]}
                    output = self.model.stream_generate_step(
                        x, first_chunk=self._model_first_chunk
                    )
                    self._model_first_chunk = False
                    generated = output["generated"]
                    # Skip if no frames committed yet
                    if generated[0].shape[0] == 0:
                        continue
                    # Decode with VAE (1 token -> 4 frames)
                    decoded = self.vae.stream_decode(
                        generated[0][None, :], first_chunk=self.first_chunk
                    )[0]
                    self.first_chunk = False
                    # Convert each decoded frame (rows of dim 0) to joints
                    for frame in decoded:
                        joints = self.stream_recovery.process_frame(frame.cpu().numpy())
                        self._publish_frame(joints)
                    step_time = time.perf_counter() - step_start
                    total_gen_time += step_time
                    step_count += 1
                    # Print performance stats every 10 steps
                    if step_count % 10 == 0:
                        avg_time = total_gen_time / step_count
                        fps = decoded.shape[0] / avg_time
                        print(
                            f"[Generation] Step {step_count}: {step_time * 1000:.1f}ms, "
                            f"Avg: {avg_time * 1000:.1f}ms, "
                            f"FPS: {fps:.1f}, "
                            f"Buffer: {self.frame_buffer.size()}"
                        )
                except Exception as e:
                    # Best-effort: log and keep the stream alive.
                    print(f"Error in generation: {e}")
                    import traceback

                    traceback.print_exc()
                    time.sleep(0.1)
        print("Generation loop stopped")

    def get_next_frame(self):
        """Pop the next frame from the session buffer, or None if it is empty."""
        return self.frame_buffer.get_frame()

    def get_broadcast_frames(self, after_id, count=8):
        """Get frames from broadcast buffer after the given ID (for spectators).

        Returns:
            Up to ``count`` ``(frame_id, joints)`` pairs with
            ``frame_id > after_id``, oldest first.
        """
        with self.broadcast_lock:
            frames = [
                (fid, joints)
                for fid, joints in self.broadcast_frames
                if fid > after_id
            ]
        return frames[:count]

    def get_buffer_status(self):
        """Return a snapshot dict of buffer, generation and model settings."""
        return {
            "buffer_size": self.frame_buffer.size(),
            "target_size": self.frame_buffer.target_size,
            "is_generating": self.is_generating,
            "current_text": self.current_text,
            "smoothing_alpha": self.smoothing_alpha,
            "history_length": self.history_length,
            "schedule_config": {
                "chunk_size": self.model.chunk_size,
                "steps": self.model.noise_steps,
            },
            "cfg_config": {
                "cfg_scale": self.model.cfg_scale,
            },
        }
# Global model manager instance (created lazily by get_model_manager).
_model_manager = None
# Serialises first-time creation so concurrent callers cannot load the
# model twice.
_model_manager_lock = threading.Lock()


def get_model_manager(model_name=None):
    """Get or create the global model manager instance.

    Args:
        model_name: Hub repo id. Required on the first call, which creates
            the manager; ignored once the manager exists.

    Returns:
        The process-wide ModelManager.

    Raises:
        ValueError: if the manager does not exist yet and model_name is None.
    """
    global _model_manager
    if _model_manager is None:
        with _model_manager_lock:
            # Re-check under the lock: another thread may have won the race.
            if _model_manager is None:
                if model_name is None:
                    raise ValueError(
                        "model_name is required to create the model manager"
                    )
                _model_manager = ModelManager(model_name)
    return _model_manager