Spaces:
Sleeping
Sleeping
| from websocket_storage import WebSocketGPUStorage | |
| import numpy as np | |
| from typing import Dict, Any, Optional | |
| import time | |
class VirtualVRAM:
    """Bookkeeping layer for virtual VRAM blocks kept in an external storage object.

    The instance tracks block metadata and a virtual-address map in
    ``self.vram_state`` and pushes that state to ``self.storage`` after every
    mutation. Tensor payloads are written and read through the same storage
    object. The storage object must provide ``wait_for_connection``,
    ``store_state``, ``store_tensor`` and ``load_tensor``.
    """

    # Divisor used to convert the byte counters into the GB figures reported
    # by get_stats().
    _BYTES_PER_GB = 1024 * 1024 * 1024

    def __init__(self, size_gb: Optional[int] = None, storage=None):
        """Initialize virtual VRAM with unlimited storage capability.

        Args:
            size_gb: Recorded on the instance as ``self.size_gb`` but does not
                cap capacity; ``total_size`` is always infinite.
            storage: Storage backend to use. When omitted, a
                ``WebSocketGPUStorage`` is created and must connect.

        Raises:
            RuntimeError: If a newly created storage cannot connect, or if the
                initial state cannot be stored.
        """
        self.size_gb = size_gb
        self.storage = storage
        if self.storage is None:
            # Imported lazily so the class can be used with an injected
            # storage object without the websocket module being importable.
            from websocket_storage import WebSocketGPUStorage

            self.storage = WebSocketGPUStorage()
            if not self.storage.wait_for_connection():
                raise RuntimeError("Could not connect to GPU storage server")

        # Initialize VRAM state with unlimited capacity
        self.vram_state = {
            "total_size": float('inf'),  # Unlimited size
            "allocated": 0,
            "blocks": {},
            "memory_map": {},
            "is_unlimited": True,
        }
        self.store_vram_state()

    def store_vram_state(self, max_retries=3) -> bool:
        """Store VRAM state in the storage backend with retry logic.

        Each attempt waits up to 5 seconds for a connection, then calls
        ``storage.store_state("vram", "state", ...)``. A failed attempt is
        followed by a one second pause.

        Args:
            max_retries: Number of attempts before giving up.

        Returns:
            True once the state has been stored.

        Raises:
            RuntimeError: If no attempt succeeds. The last exception raised
                by the storage object, if any, is attached as the cause.
        """
        last_error = None
        for attempt in range(max_retries):
            try:
                # Wait for connection if needed
                if not self.storage.wait_for_connection(timeout=5):
                    print(f"Waiting for WebSocket connection (attempt {attempt + 1}/{max_retries})")
                    time.sleep(1)
                    continue

                total_size = self.vram_state["total_size"]
                # Infinity has no standard JSON representation, so it is sent
                # as the string "inf"; finite sizes pass through unchanged.
                if isinstance(total_size, float) and total_size == float('inf'):
                    total_size = str(total_size)
                safe_state = {
                    "total_size": total_size,
                    "allocated": self.vram_state["allocated"],
                    "blocks": self.vram_state["blocks"],
                    "memory_map": self.vram_state["memory_map"],
                    "is_unlimited": self.vram_state["is_unlimited"],
                }

                if self.storage.store_state("vram", "state", safe_state):
                    return True
                print(f"Failed to store VRAM state (attempt {attempt + 1}/{max_retries})")
                time.sleep(1)
            except Exception as e:
                # Best-effort retry: report, remember the cause, try again.
                last_error = e
                print(f"Error storing VRAM state (attempt {attempt + 1}/{max_retries}): {str(e)}")
                time.sleep(1)
        raise RuntimeError("Failed to store VRAM state after multiple attempts") from last_error

    def _new_block_id(self) -> str:
        """Return a timestamp-based block id that is not currently allocated."""
        base = f"block_{time.time_ns()}"
        candidate = base
        suffix = 0
        # The clock can return the same value for back-to-back calls; add a
        # suffix rather than silently overwriting a live block.
        while candidate in self.vram_state["blocks"]:
            suffix += 1
            candidate = f"{base}_{suffix}"
        return candidate

    def allocate_block(self, size: int, block_id: Optional[str] = None) -> str:
        """Allocate a block of VRAM.

        Args:
            size: Block size in bytes; must not be negative.
            block_id: Identifier to use. When omitted, a unique
                ``block_<nanoseconds>`` id is generated. Passing the id of an
                existing block replaces its metadata, and the old size is
                released from the ``allocated`` counter.

        Returns:
            The id of the allocated block.

        Raises:
            ValueError: If ``size`` is negative.
            MemoryError: If the allocation would exceed ``total_size``.
            RuntimeError: If the updated state cannot be stored.
        """
        if size < 0:
            raise ValueError(f"Block size must be non-negative, got {size}")

        blocks = self.vram_state["blocks"]
        if block_id is None:
            block_id = self._new_block_id()

        # Size held by a block that is about to be replaced under the same id.
        replaced = blocks.get(block_id)
        reclaimed = replaced["size"] if replaced is not None else 0

        if self.vram_state["allocated"] - reclaimed + size > self.vram_state["total_size"]:
            raise MemoryError("Not enough VRAM available")

        now = time.time_ns()
        blocks[block_id] = {
            "size": size,
            "allocated_at": now,
            "last_accessed": now,
        }
        self.vram_state["allocated"] += size - reclaimed

        # Store updated state
        self.store_vram_state()
        return block_id

    def free_block(self, block_id: str):
        """Free a block of VRAM.

        Unknown ids are ignored. For a known id the block's size is released,
        its metadata is removed, the state is stored, and
        ``storage.store_tensor(block_id, None)`` is called.

        Raises:
            RuntimeError: If the updated state cannot be stored.
        """
        block = self.vram_state["blocks"].pop(block_id, None)
        if block is None:
            return
        self.vram_state["allocated"] -= block["size"]
        self.store_vram_state()
        # Remove block data from storage
        # NOTE(review): presumably storing None clears the payload; confirm
        # against the storage implementation.
        self.storage.store_tensor(block_id, None)

    def write_block(self, block_id: str, data: np.ndarray):
        """Write data to a VRAM block.

        Updates the block's ``last_accessed`` timestamp, stores the state,
        then hands ``data`` to ``storage.store_tensor``.

        Returns:
            Whatever ``storage.store_tensor`` returns.

        Raises:
            ValueError: If ``block_id`` is not allocated.
            RuntimeError: If the updated state cannot be stored.
        """
        if block_id not in self.vram_state["blocks"]:
            raise ValueError(f"Block {block_id} not allocated")
        self.vram_state["blocks"][block_id]["last_accessed"] = time.time_ns()
        self.store_vram_state()
        return self.storage.store_tensor(block_id, data)

    def read_block(self, block_id: str) -> Optional[np.ndarray]:
        """Read data from a VRAM block.

        Updates the block's ``last_accessed`` timestamp and stores the state
        before loading, so every read also performs a state write.

        Returns:
            Whatever ``storage.load_tensor`` returns for ``block_id``.

        Raises:
            ValueError: If ``block_id`` is not allocated.
            RuntimeError: If the updated state cannot be stored.
        """
        if block_id not in self.vram_state["blocks"]:
            raise ValueError(f"Block {block_id} not allocated")
        self.vram_state["blocks"][block_id]["last_accessed"] = time.time_ns()
        self.store_vram_state()
        return self.storage.load_tensor(block_id)

    def map_address(self, virtual_addr: str, block_id: str):
        """Map virtual address to VRAM block and store the updated state.

        The block id is not checked against the allocated blocks.
        """
        self.vram_state["memory_map"][virtual_addr] = block_id
        self.store_vram_state()

    def get_block_from_address(self, virtual_addr: str) -> Optional[str]:
        """Get block ID from virtual address, or None when it is unmapped."""
        return self.vram_state["memory_map"].get(virtual_addr)

    def get_stats(self) -> Dict[str, Any]:
        """Get VRAM statistics.

        Returns:
            A dict with ``total_gb``, ``used_gb`` and ``free_gb`` (byte
            counters divided by 1024**3), plus ``num_blocks`` and
            ``mappings``. With unlimited capacity ``total_gb`` and
            ``free_gb`` are ``inf``.
        """
        total_bytes = self.vram_state["total_size"]
        used_bytes = self.vram_state["allocated"]
        return {
            # Derived from the tracked capacity so it always agrees with
            # free_gb and never depends on an attribute that may be unset.
            "total_gb": total_bytes / self._BYTES_PER_GB,
            "used_gb": used_bytes / self._BYTES_PER_GB,
            "free_gb": (total_bytes - used_bytes) / self._BYTES_PER_GB,
            "num_blocks": len(self.vram_state["blocks"]),
            "mappings": len(self.vram_state["memory_map"]),
        }