INTAI / virtual_vram.py
Factor Studios
Upload 2 files
d19c131 verified
from websocket_storage import WebSocketGPUStorage
import numpy as np
from typing import Dict, Any, Optional
import time
class VirtualVRAM:
    """Bookkeeping layer for virtual VRAM blocks backed by a storage object.

    The instance keeps an in-memory ``vram_state`` dict (allocated byte count,
    per-block metadata and a virtual-address -> block-id map) and persists it
    through ``storage.store_state`` after every mutation. Block payloads are
    written and read with ``storage.store_tensor`` / ``storage.load_tensor``.

    Capacity is always unlimited (``total_size`` is ``inf``); ``size_gb`` is
    only remembered so that ``get_stats`` can report it.
    """

    _BYTES_PER_GB = 1024 * 1024 * 1024

    def __init__(self, size_gb: Optional[int] = None, storage=None):
        """Initialize virtual VRAM with unlimited storage capability.

        Args:
            size_gb: Nominal size reported as ``total_gb`` by ``get_stats``.
                It does not limit allocations.
            storage: Object providing ``wait_for_connection``, ``store_state``,
                ``store_tensor`` and ``load_tensor``. When ``None`` a
                ``WebSocketGPUStorage`` is created and connected.

        Raises:
            RuntimeError: If the default storage cannot connect, or the
                initial state cannot be stored.
        """
        # get_stats() reads this attribute; without it every call fails.
        self.size_gb = size_gb
        self.storage = storage
        if self.storage is None:
            # Imported lazily so that a caller-supplied storage never needs
            # the websocket backend.
            from websocket_storage import WebSocketGPUStorage

            self.storage = WebSocketGPUStorage()
            if not self.storage.wait_for_connection():
                raise RuntimeError("Could not connect to GPU storage server")

        # Initialize VRAM state with unlimited capacity
        self.vram_state = {
            "total_size": float('inf'),  # Unlimited size
            "allocated": 0,
            "blocks": {},
            "memory_map": {},
            "is_unlimited": True,
        }
        self.store_vram_state()

    def _serializable_state(self) -> Dict[str, Any]:
        """Return a copy of the top-level state with ``inf`` turned into a string."""
        state = self.vram_state
        total_size = state["total_size"]
        if isinstance(total_size, float) and total_size == float('inf'):
            # Infinity is not valid JSON, so it is sent as the string "inf".
            total_size = str(total_size)
        return {
            "total_size": total_size,
            "allocated": state["allocated"],
            "blocks": state["blocks"],
            "memory_map": state["memory_map"],
            "is_unlimited": state["is_unlimited"],
        }

    def store_vram_state(self, max_retries=3):
        """Store VRAM state in the storage backend with retry logic.

        Args:
            max_retries: Number of attempts before giving up.

        Returns:
            ``True`` once the backend reports a successful store.

        Raises:
            RuntimeError: If no attempt succeeds. The last exception raised
                by the backend, if any, is attached as the cause.
        """
        last_error = None
        for attempt in range(1, max_retries + 1):
            try:
                # Wait for connection if needed
                if not self.storage.wait_for_connection(timeout=5):
                    print(f"Waiting for WebSocket connection (attempt {attempt}/{max_retries})")
                elif self.storage.store_state("vram", "state", self._serializable_state()):
                    return True
                else:
                    print(f"Failed to store VRAM state (attempt {attempt}/{max_retries})")
            except Exception as e:
                # Any backend failure is retried; the last one is kept so the
                # final RuntimeError can point at the underlying cause.
                last_error = e
                print(f"Error storing VRAM state (attempt {attempt}/{max_retries}): {str(e)}")
            if attempt < max_retries:
                # Back off only when another attempt will follow.
                time.sleep(1)
        raise RuntimeError("Failed to store VRAM state after multiple attempts") from last_error

    def allocate_block(self, size: int, block_id: Optional[str] = None) -> str:
        """Allocate a block of VRAM and persist the updated state.

        Args:
            size: Block size in bytes; must not be negative.
            block_id: Identifier for the block. When ``None`` one is derived
                from the current time. Re-using an existing identifier
                replaces that block.

        Returns:
            The identifier of the allocated block.

        Raises:
            ValueError: If ``size`` is negative.
            MemoryError: If the allocation would exceed ``total_size``.
            RuntimeError: If the state cannot be stored.
        """
        if size < 0:
            raise ValueError("Block size must be non-negative")

        state = self.vram_state
        now = time.time_ns()
        if block_id is None:
            block_id = f"block_{now}"

        # When an existing block is replaced only the size difference is
        # charged, otherwise "allocated" would drift above the real total.
        previous = state["blocks"].get(block_id)
        delta = size - (previous["size"] if previous is not None else 0)
        if state["allocated"] + delta > state["total_size"]:
            raise MemoryError("Not enough VRAM available")

        state["blocks"][block_id] = {
            "size": size,
            "allocated_at": now,
            "last_accessed": now,
        }
        state["allocated"] += delta

        # Store updated state
        self.store_vram_state()
        return block_id

    def free_block(self, block_id: str):
        """Free a block of VRAM; unknown block ids are ignored.

        Raises:
            RuntimeError: If the state cannot be stored.
        """
        block = self.vram_state["blocks"].pop(block_id, None)
        if block is None:
            return
        self.vram_state["allocated"] -= block["size"]
        self.store_vram_state()
        # Remove block data from storage
        self.storage.store_tensor(block_id, None)

    def _touch_block(self, block_id: str) -> None:
        """Refresh ``last_accessed`` for an allocated block and persist the state."""
        blocks = self.vram_state["blocks"]
        if block_id not in blocks:
            raise ValueError(f"Block {block_id} not allocated")
        blocks[block_id]["last_accessed"] = time.time_ns()
        self.store_vram_state()

    def write_block(self, block_id: str, data: np.ndarray):
        """Write data to a VRAM block.

        Returns:
            Whatever ``storage.store_tensor`` returns.

        Raises:
            ValueError: If the block is not allocated.
            RuntimeError: If the state cannot be stored.
        """
        self._touch_block(block_id)
        return self.storage.store_tensor(block_id, data)

    def read_block(self, block_id: str) -> Optional[np.ndarray]:
        """Read data from a VRAM block.

        Returns:
            Whatever ``storage.load_tensor`` returns for the block.

        Raises:
            ValueError: If the block is not allocated.
            RuntimeError: If the state cannot be stored.
        """
        self._touch_block(block_id)
        return self.storage.load_tensor(block_id)

    def map_address(self, virtual_addr: str, block_id: str):
        """Map a virtual address to a block id and persist the state.

        The block id is not checked against the allocated blocks.
        """
        self.vram_state["memory_map"][virtual_addr] = block_id
        self.store_vram_state()

    def get_block_from_address(self, virtual_addr: str) -> Optional[str]:
        """Return the block id mapped to ``virtual_addr``, or ``None``."""
        return self.vram_state["memory_map"].get(virtual_addr)

    def get_stats(self) -> Dict[str, Any]:
        """Return VRAM statistics.

        ``total_gb`` is the ``size_gb`` given to the constructor (possibly
        ``None``); ``free_gb`` is derived from ``total_size`` and is therefore
        ``inf`` while capacity is unlimited.
        """
        state = self.vram_state
        allocated = state["allocated"]
        return {
            "total_gb": self.size_gb,
            "used_gb": allocated / self._BYTES_PER_GB,
            "free_gb": (state["total_size"] - allocated) / self._BYTES_PER_GB,
            "num_blocks": len(state["blocks"]),
            "mappings": len(state["memory_map"]),
        }