Factor Studios committed
Commit e6d9024 · verified · 1 Parent(s): 81a0028

Upload 2 files

Files changed (2)
  1. requirements.txt +7 -20
  2. virtual_gpu_server_http.py +985 -0
requirements.txt CHANGED
@@ -1,20 +1,7 @@
1
- # Virtual GPU Server Core Dependencies
2
- numpy==1.24.3 # Array operations for GPU data handling
3
- websockets==11.0.3 # WebSocket server implementation
4
- aiohttp==3.8.5 # HTTP server and async web framework
5
- uvicorn==0.23.2 # ASGI server implementation
6
-
7
- # WebSocket/HTTP Server Dependencies
8
- aiosignal==1.3.1 # Async signals
9
- async-timeout==4.0.3 # Timeouts for async operations
10
- attrs==23.1.0 # Class builders
11
- charset-normalizer==3.2.0 # Unicode normalization
12
- frozenlist==1.4.0 # Immutable lists
13
- multidict==6.0.4 # Dict with multiple values per key
14
- yarl==1.9.2 # URL handling
15
-
16
- # Performance & Utilities
17
- ujson==5.8.0 # Fast JSON processing
18
- psutil==5.9.5 # System and process monitoring
19
- pathlib==1.0.1 # Path manipulation
20
- fastapi
 
+ fastapi>=0.104.0
+ uvicorn>=0.24.0
+ websockets>=12.0
+ pyjwt>=2.8.0
+ requests>=2.31.0
+ aiohttp>=3.9.0
+ numpy>=1.24.0
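
The pins above were collapsed into version floors. As a quick way to confirm that a local environment satisfies those floors before starting the server, a small check along these lines can be used (the package names and minimum versions are copied from the new requirements.txt; the comparison helper itself is only an illustrative sketch, not part of this commit):

from importlib.metadata import PackageNotFoundError, version

# Floor versions taken from the new requirements.txt
FLOORS = {
    "fastapi": "0.104.0",
    "uvicorn": "0.24.0",
    "websockets": "12.0",
    "pyjwt": "2.8.0",
    "requests": "2.31.0",
    "aiohttp": "3.9.0",
    "numpy": "1.24.0",
}

def as_tuple(ver: str) -> tuple:
    # Crude numeric comparison; good enough for plain X.Y.Z floors
    return tuple(int(part) for part in ver.split(".") if part.isdigit())

for name, floor in FLOORS.items():
    try:
        installed = version(name)  # recent Pythons normalize names, so "pyjwt" should match PyJWT
    except PackageNotFoundError:
        print(f"{name}: missing (needs >= {floor})")
        continue
    status = "ok" if as_tuple(installed) >= as_tuple(floor) else f"below floor {floor}"
    print(f"{name}: {installed} ({status})")
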
virtual_gpu_server_http.py ADDED
@@ -0,0 +1,985 @@
1
+ import asyncio
2
+ import websockets
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ import uuid
7
+ import time
8
+ import jwt
9
+ from typing import Dict, Any, Optional, List
10
+ import numpy as np
11
+ from fastapi import FastAPI, WebSocket, HTTPException, Depends, Request, Response
12
+ from fastapi.responses import HTMLResponse, JSONResponse
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
15
+ from datetime import datetime, timedelta
16
+ import hashlib
17
+ import gzip
18
+ import base64
19
+ from pydantic import BaseModel
20
+
21
+ # Create FastAPI instance with enhanced configuration
22
+ app = FastAPI(
23
+ title="Virtual GPU Server",
24
+ description="HTTP and WebSocket API for Virtual GPU v2",
25
+ version="2.0.0"
26
+ )
27
+
28
+ # Add CORS middleware for cross-origin requests
29
+ app.add_middleware(
30
+ CORSMiddleware,
31
+ allow_origins=["*"], # Allow all origins for development
32
+ allow_credentials=True,
33
+ allow_methods=["*"],
34
+ allow_headers=["*"],
35
+ )
36
+
37
+ # JWT configuration
38
+ JWT_SECRET = "virtual_gpu_secret_key_2025" # In production, use environment variable
39
+ JWT_ALGORITHM = "HS256"
40
+ JWT_EXPIRATION_HOURS = 24
41
+
42
+ # HTTP Bearer security scheme
43
+ security = HTTPBearer()
44
+
45
+ # Pydantic models for request/response validation
46
+ class SessionCreateRequest(BaseModel):
47
+ client_id: Optional[str] = None
48
+ resource_limits: Optional[Dict[str, Any]] = None
49
+
50
+ class SessionResponse(BaseModel):
51
+ session_token: str
52
+ session_id: str
53
+ expires_at: datetime
54
+
55
+ class VRAMWriteRequest(BaseModel):
56
+ data: List[Any]
57
+ metadata: Optional[Dict[str, Any]] = None
58
+ model_size: Optional[int] = None
59
+
60
+ class VRAMResponse(BaseModel):
61
+ status: str
62
+ message: Optional[str] = None
63
+ data: Optional[List[Any]] = None
64
+ metadata: Optional[Dict[str, Any]] = None
65
+ source: Optional[str] = None
66
+
67
+ class StateRequest(BaseModel):
68
+ data: Dict[str, Any]
69
+ timestamp: Optional[float] = None
70
+
71
+ class StateResponse(BaseModel):
72
+ status: str
73
+ message: Optional[str] = None
74
+ data: Optional[Dict[str, Any]] = None
75
+ source: Optional[str] = None
76
+
77
+ class CacheRequest(BaseModel):
78
+ data: Any
79
+ ttl: Optional[int] = None
80
+
81
+ class CacheResponse(BaseModel):
82
+ status: str
83
+ message: Optional[str] = None
84
+ data: Optional[Any] = None
85
+ source: Optional[str] = None
86
+
87
+ class ModelLoadRequest(BaseModel):
88
+ model_data: Optional[Dict[str, Any]] = None
89
+ model_path: Optional[str] = None
90
+ model_hash: Optional[str] = None
91
+
92
+ class ModelInferenceRequest(BaseModel):
93
+ input_data: List[Any]
94
+ batch_size: Optional[int] = None
95
+
96
+ class ErrorResponse(BaseModel):
97
+ status: str
98
+ error_code: str
99
+ message: str
100
+ details: Optional[Dict[str, Any]] = None
101
+ retry_after: Optional[int] = None
102
+ request_id: str
103
+
104
+ class VirtualGPUServer:
105
+ def __init__(self):
106
+ self.base_path = Path(__file__).parent / "storage"
107
+ self.vram_path = self.base_path / "vram_blocks"
108
+ self.state_path = self.base_path / "gpu_state"
109
+ self.cache_path = self.base_path / "cache"
110
+ self.models_path = self.base_path / "models"
111
+
112
+ # Ensure all storage directories exist
113
+ self.vram_path.mkdir(parents=True, exist_ok=True)
114
+ self.state_path.mkdir(parents=True, exist_ok=True)
115
+ self.cache_path.mkdir(parents=True, exist_ok=True)
116
+ self.models_path.mkdir(parents=True, exist_ok=True)
117
+
118
+ # In-memory caches for faster access
119
+ self.vram_cache: Dict[str, Any] = {}
120
+ self.state_cache: Dict[str, Any] = {}
121
+ self.memory_cache: Dict[str, Any] = {}
122
+ self.model_cache: Dict[str, Any] = {}
123
+
124
+ # Session management for HTTP API
125
+ self.http_sessions: Dict[str, Dict[str, Any]] = {}
126
+
127
+ # Active WebSocket connections and sessions (for backward compatibility)
128
+ self.active_connections: Dict[str, WebSocket] = {}
129
+ self.active_sessions: Dict[str, Dict[str, Any]] = {}
130
+ self.heartbeat_interval = 5 # seconds
131
+ self.connection_timeout = 30 # seconds
132
+
133
+ # Performance monitoring
134
+ self.ops_counter = 0
135
+ self.start_time = time.time()
136
+ self.request_counter = 0
137
+
138
+ def _make_json_serializable(self, obj):
139
+ """Convert non-JSON-serializable objects to serializable format"""
140
+ if isinstance(obj, dict):
141
+ return {k: self._make_json_serializable(v) for k, v in obj.items()}
142
+ elif isinstance(obj, list):
143
+ return [self._make_json_serializable(i) for i in obj]
144
+ elif isinstance(obj, tuple):
145
+ return list(obj)
146
+ elif isinstance(obj, (np.ndarray, np.generic)):
147
+ return obj.tolist()
148
+ elif isinstance(obj, (Path, uuid.UUID)):
149
+ return str(obj)
150
+ elif hasattr(obj, '__dict__'):
151
+ # Handle custom objects by converting their __dict__ to serializable format
152
+ return self._make_json_serializable(obj.__dict__)
153
+ elif isinstance(obj, (int, float, str, bool, type(None))):
154
+ return obj
155
+ else:
156
+ # Convert any other types to string representation
157
+ return str(obj)
158
+
159
+ def create_session_token(self, session_id: str, client_id: str = None, resource_limits: Dict[str, Any] = None) -> str:
160
+ """Create a JWT session token"""
161
+ payload = {
162
+ "session_id": session_id,
163
+ "client_id": client_id or "anonymous",
164
+ "resource_limits": resource_limits or {},
165
+ "created_at": time.time(),
166
+ "expires_at": time.time() + (JWT_EXPIRATION_HOURS * 3600)
167
+ }
168
+ return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
169
+
170
+ def verify_session_token(self, token: str) -> Dict[str, Any]:
171
+ """Verify and decode a JWT session token"""
172
+ try:
173
+ payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
174
+ if payload["expires_at"] < time.time():
175
+ raise HTTPException(status_code=401, detail="Session token expired")
176
+ return payload
177
+ except jwt.InvalidTokenError:
178
+ raise HTTPException(status_code=401, detail="Invalid session token")
179
+
180
+ def generate_request_id(self) -> str:
181
+ """Generate a unique request ID"""
182
+ self.request_counter += 1
183
+ return f"req_{int(time.time())}_{self.request_counter}"
184
+
185
+ def compress_data(self, data: bytes) -> bytes:
186
+ """Compress data using gzip"""
187
+ return gzip.compress(data)
188
+
189
+ def decompress_data(self, data: bytes) -> bytes:
190
+ """Decompress gzip data"""
191
+ return gzip.decompress(data)
192
+
193
+ async def handle_vram_operation(self, operation: dict) -> dict:
194
+ """Handle VRAM read/write operations (preserved from WebSocket implementation)"""
195
+ try:
196
+ op_type = operation.get('type')
197
+ if not op_type:
198
+ raise ValueError("Missing operation type")
199
+
200
+ block_id = operation.get('block_id')
201
+ if not block_id:
202
+ raise ValueError("Missing block_id")
203
+
204
+ data = operation.get('data')
205
+ if data and isinstance(data, (dict, list)):
206
+ data = self._make_json_serializable(data)
207
+
208
+ if op_type == 'write':
209
+ if data is None:
210
+ raise ValueError("Missing data for write operation")
211
+ file_path = self.vram_path / f"{block_id}.npy"
212
+ np.save(file_path, np.array(data))
213
+ self.vram_cache[block_id] = np.array(data)
214
+
215
+ # Store metadata
216
+ metadata = operation.get('metadata', {})
217
+ metadata_path = self.vram_path / f"{block_id}_metadata.json"
218
+ with open(metadata_path, 'w') as f:
219
+ json.dump(metadata, f)
220
+
221
+ return {'status': 'success', 'message': f'Block {block_id} written'}
222
+
223
+ if op_type == 'read':
224
+ if block_id in self.vram_cache:
225
+ # Load metadata
226
+ metadata_path = self.vram_path / f"{block_id}_metadata.json"
227
+ metadata = {}
228
+ if metadata_path.exists():
229
+ with open(metadata_path, 'r') as f:
230
+ metadata = json.load(f)
231
+
232
+ return {
233
+ 'status': 'success',
234
+ 'data': self.vram_cache[block_id] if isinstance(self.vram_cache[block_id], list) else self.vram_cache[block_id].tolist(),
235
+ 'metadata': metadata,
236
+ 'source': 'cache'
237
+ }
238
+
239
+ file_path = self.vram_path / f"{block_id}.npy"
240
+ if file_path.exists():
241
+ data = np.load(file_path)
242
+ self.vram_cache[block_id] = np.array(data)
243
+
244
+ # Load metadata
245
+ metadata_path = self.vram_path / f"{block_id}_metadata.json"
246
+ metadata = {}
247
+ if metadata_path.exists():
248
+ with open(metadata_path, 'r') as f:
249
+ metadata = json.load(f)
250
+
251
+ return {
252
+ 'status': 'success',
253
+ 'data': data.tolist(),
254
+ 'metadata': metadata,
255
+ 'source': 'disk'
256
+ }
257
+ return {'status': 'error', 'message': 'Block not found'}
258
+
259
+ return {'status': 'error', 'message': f'Unknown operation type: {op_type}'}
260
+
261
+ except ValueError as e:
262
+ return {'status': 'error', 'message': str(e)}
263
+ except Exception as e:
264
+ return {'status': 'error', 'message': f'Operation failed: {str(e)}'}
265
+
266
+ async def handle_state_operation(self, operation: dict) -> dict:
267
+ """Handle GPU state operations (preserved from WebSocket implementation)"""
268
+ op_type = operation.get('type')
269
+ component = operation.get('component')
270
+ state_id = operation.get('state_id')
271
+ state_data = operation.get('data')
272
+
273
+ file_path = self.state_path / component / f"{state_id}.json"
274
+
275
+ if op_type == 'save':
276
+ file_path.parent.mkdir(parents=True, exist_ok=True)
277
+ with open(file_path, 'w') as f:
278
+ json.dump(state_data, f)
279
+ self.state_cache[f"{component}:{state_id}"] = state_data
280
+ return {'status': 'success', 'message': f'State {state_id} saved'}
281
+
282
+ elif op_type == 'load':
283
+ cache_key = f"{component}:{state_id}"
284
+ if cache_key in self.state_cache:
285
+ return {
286
+ 'status': 'success',
287
+ 'data': self.state_cache[cache_key],
288
+ 'source': 'cache'
289
+ }
290
+
291
+ if file_path.exists():
292
+ with open(file_path) as f:
293
+ state_data = json.load(f)
294
+ self.state_cache[cache_key] = state_data
295
+ return {
296
+ 'status': 'success',
297
+ 'data': state_data,
298
+ 'source': 'disk'
299
+ }
300
+
301
+ return {'status': 'error', 'message': 'State not found'}
302
+
303
+ async def handle_cache_operation(self, operation: dict) -> dict:
304
+ """Handle cache operations (preserved from WebSocket implementation)"""
305
+ op_type = operation.get('type')
306
+ key = operation.get('key')
307
+ data = operation.get('data')
308
+
309
+ if op_type == 'set':
310
+ self.memory_cache[key] = data
311
+ # Also persist to disk for recovery
312
+ file_path = self.cache_path / f"{key}.json"
313
+ with open(file_path, 'w') as f:
314
+ json.dump(data, f)
315
+ return {'status': 'success', 'message': f'Cache key {key} set'}
316
+
317
+ elif op_type == 'get':
318
+ if key in self.memory_cache:
319
+ return {
320
+ 'status': 'success',
321
+ 'data': self.memory_cache[key],
322
+ 'source': 'memory'
323
+ }
324
+
325
+ file_path = self.cache_path / f"{key}.json"
326
+ if file_path.exists():
327
+ with open(file_path) as f:
328
+ data = json.load(f)
329
+ self.memory_cache[key] = data
330
+ return {
331
+ 'status': 'success',
332
+ 'data': data,
333
+ 'source': 'disk'
334
+ }
335
+
336
+ return {'status': 'error', 'message': 'Cache key not found'}
337
+
338
+ def get_stats(self) -> dict:
339
+ """Get server statistics"""
340
+ current_time = time.time()
341
+ uptime = current_time - self.start_time
342
+ ops_per_second = self.ops_counter / uptime if uptime > 0 else 0
343
+
344
+ return {
345
+ 'uptime': uptime,
346
+ 'total_operations': self.ops_counter,
347
+ 'ops_per_second': ops_per_second,
348
+ 'active_connections': len(self.active_connections),
349
+ 'active_http_sessions': len(self.http_sessions),
350
+ 'vram_cache_size': len(self.vram_cache),
351
+ 'state_cache_size': len(self.state_cache),
352
+ 'memory_cache_size': len(self.memory_cache),
353
+ 'model_cache_size': len(self.model_cache)
354
+ }
355
+
356
+ # Create server instance
357
+ server = VirtualGPUServer()
358
+
359
+ # Dependency to get current session from JWT token
360
+ def get_current_session(credentials: HTTPAuthorizationCredentials = Depends(security)) -> Dict[str, Any]:
361
+ return server.verify_session_token(credentials.credentials)
362
+
363
+ # HTTP API Endpoints
364
+
365
+ @app.post("/api/v1/sessions", response_model=SessionResponse)
366
+ async def create_session(request: SessionCreateRequest):
367
+ """Create a new HTTP session"""
368
+ session_id = str(uuid.uuid4())
369
+ client_id = request.client_id or "anonymous"
370
+
371
+ # Create session token
372
+ token = server.create_session_token(session_id, client_id, request.resource_limits)
373
+
374
+ # Store session info
375
+ server.http_sessions[session_id] = {
376
+ 'session_id': session_id,
377
+ 'client_id': client_id,
378
+ 'created_at': time.time(),
379
+ 'resource_limits': request.resource_limits or {},
380
+ 'ops_count': 0
381
+ }
382
+
383
+ expires_at = datetime.fromtimestamp(time.time() + (JWT_EXPIRATION_HOURS * 3600))
384
+
385
+ return SessionResponse(
386
+ session_token=token,
387
+ session_id=session_id,
388
+ expires_at=expires_at
389
+ )
390
+
391
+ @app.post("/api/v1/vram/blocks/{block_id}", response_model=VRAMResponse)
392
+ async def write_vram_block(
393
+ block_id: str,
394
+ request: VRAMWriteRequest,
395
+ session: Dict[str, Any] = Depends(get_current_session)
396
+ ):
397
+ """Write tensor data to VRAM block"""
398
+ try:
399
+ operation = {
400
+ 'operation': 'vram',
401
+ 'type': 'write',
402
+ 'block_id': block_id,
403
+ 'data': request.data,
404
+ 'metadata': request.metadata or {},
405
+ 'model_size': request.model_size
406
+ }
407
+
408
+ result = await server.handle_vram_operation(operation)
409
+ server.ops_counter += 1
410
+
411
+ if result['status'] == 'success':
412
+ return VRAMResponse(
413
+ status=result['status'],
414
+ message=result['message']
415
+ )
416
+ else:
417
+ raise HTTPException(status_code=400, detail=result['message'])
418
+
419
+ except Exception as e:
420
+ request_id = server.generate_request_id()
421
+ raise HTTPException(
422
+ status_code=500,
423
+ detail=f"VRAM write operation failed: {str(e)}"
424
+ )
425
+
426
+ @app.get("/api/v1/vram/blocks/{block_id}", response_model=VRAMResponse)
427
+ async def read_vram_block(
428
+ block_id: str,
429
+ session: Dict[str, Any] = Depends(get_current_session)
430
+ ):
431
+ """Read tensor data from VRAM block"""
432
+ try:
433
+ operation = {
434
+ 'operation': 'vram',
435
+ 'type': 'read',
436
+ 'block_id': block_id
437
+ }
438
+
439
+ result = await server.handle_vram_operation(operation)
440
+ server.ops_counter += 1
441
+
442
+ if result['status'] == 'success':
443
+ return VRAMResponse(
444
+ status=result['status'],
445
+ data=result.get('data'),
446
+ metadata=result.get('metadata'),
447
+ source=result.get('source')
448
+ )
449
+ else:
450
+ raise HTTPException(status_code=404, detail=result['message'])
451
+
452
+ except HTTPException:
453
+ raise
454
+ except Exception as e:
455
+ request_id = server.generate_request_id()
456
+ raise HTTPException(
457
+ status_code=500,
458
+ detail=f"VRAM read operation failed: {str(e)}"
459
+ )
460
+
461
+ @app.delete("/api/v1/vram/blocks/{block_id}")
462
+ async def delete_vram_block(
463
+ block_id: str,
464
+ session: Dict[str, Any] = Depends(get_current_session)
465
+ ):
466
+ """Delete tensor data from VRAM block"""
467
+ try:
468
+ # Remove from cache
469
+ if block_id in server.vram_cache:
470
+ del server.vram_cache[block_id]
471
+
472
+ # Remove files
473
+ file_path = server.vram_path / f"{block_id}.npy"
474
+ metadata_path = server.vram_path / f"{block_id}_metadata.json"
475
+
476
+ if file_path.exists():
477
+ file_path.unlink()
478
+ if metadata_path.exists():
479
+ metadata_path.unlink()
480
+
481
+ server.ops_counter += 1
482
+ return {"status": "success", "message": f"Block {block_id} deleted"}
483
+
484
+ except Exception as e:
485
+ raise HTTPException(
486
+ status_code=500,
487
+ detail=f"VRAM delete operation failed: {str(e)}"
488
+ )
489
+
490
+ @app.post("/api/v1/state/{component}/{state_id}", response_model=StateResponse)
491
+ async def save_state(
492
+ component: str,
493
+ state_id: str,
494
+ request: StateRequest,
495
+ session: Dict[str, Any] = Depends(get_current_session)
496
+ ):
497
+ """Save component state"""
498
+ try:
499
+ operation = {
500
+ 'operation': 'state',
501
+ 'type': 'save',
502
+ 'component': component,
503
+ 'state_id': state_id,
504
+ 'data': request.data
505
+ }
506
+
507
+ result = await server.handle_state_operation(operation)
508
+ server.ops_counter += 1
509
+
510
+ if result['status'] == 'success':
511
+ return StateResponse(
512
+ status=result['status'],
513
+ message=result['message']
514
+ )
515
+ else:
516
+ raise HTTPException(status_code=400, detail=result['message'])
517
+
518
+ except Exception as e:
519
+ raise HTTPException(
520
+ status_code=500,
521
+ detail=f"State save operation failed: {str(e)}"
522
+ )
523
+
524
+ @app.get("/api/v1/state/{component}/{state_id}", response_model=StateResponse)
525
+ async def load_state(
526
+ component: str,
527
+ state_id: str,
528
+ session: Dict[str, Any] = Depends(get_current_session)
529
+ ):
530
+ """Load component state"""
531
+ try:
532
+ operation = {
533
+ 'operation': 'state',
534
+ 'type': 'load',
535
+ 'component': component,
536
+ 'state_id': state_id
537
+ }
538
+
539
+ result = await server.handle_state_operation(operation)
540
+ server.ops_counter += 1
541
+
542
+ if result['status'] == 'success':
543
+ return StateResponse(
544
+ status=result['status'],
545
+ data=result.get('data'),
546
+ source=result.get('source')
547
+ )
548
+ else:
549
+ raise HTTPException(status_code=404, detail=result['message'])
550
+
551
+ except HTTPException:
552
+ raise
553
+ except Exception as e:
554
+ raise HTTPException(
555
+ status_code=500,
556
+ detail=f"State load operation failed: {str(e)}"
557
+ )
558
+
559
+ @app.post("/api/v1/cache/{key}", response_model=CacheResponse)
560
+ async def set_cache(
561
+ key: str,
562
+ request: CacheRequest,
563
+ session: Dict[str, Any] = Depends(get_current_session)
564
+ ):
565
+ """Set cache value"""
566
+ try:
567
+ operation = {
568
+ 'operation': 'cache',
569
+ 'type': 'set',
570
+ 'key': key,
571
+ 'data': request.data
572
+ }
573
+
574
+ result = await server.handle_cache_operation(operation)
575
+ server.ops_counter += 1
576
+
577
+ if result['status'] == 'success':
578
+ return CacheResponse(
579
+ status=result['status'],
580
+ message=result['message']
581
+ )
582
+ else:
583
+ raise HTTPException(status_code=400, detail=result['message'])
584
+
585
+ except Exception as e:
586
+ raise HTTPException(
587
+ status_code=500,
588
+ detail=f"Cache set operation failed: {str(e)}"
589
+ )
590
+
591
+ @app.get("/api/v1/cache/{key}", response_model=CacheResponse)
592
+ async def get_cache(
593
+ key: str,
594
+ session: Dict[str, Any] = Depends(get_current_session)
595
+ ):
596
+ """Get cache value"""
597
+ try:
598
+ operation = {
599
+ 'operation': 'cache',
600
+ 'type': 'get',
601
+ 'key': key
602
+ }
603
+
604
+ result = await server.handle_cache_operation(operation)
605
+ server.ops_counter += 1
606
+
607
+ if result['status'] == 'success':
608
+ return CacheResponse(
609
+ status=result['status'],
610
+ data=result.get('data'),
611
+ source=result.get('source')
612
+ )
613
+ else:
614
+ raise HTTPException(status_code=404, detail=result['message'])
615
+
616
+ except HTTPException:
617
+ raise
618
+ except Exception as e:
619
+ raise HTTPException(
620
+ status_code=500,
621
+ detail=f"Cache get operation failed: {str(e)}"
622
+ )
623
+
624
+ @app.post("/api/v1/models/{model_name}/load")
625
+ async def load_model(
626
+ model_name: str,
627
+ request: ModelLoadRequest,
628
+ session: Dict[str, Any] = Depends(get_current_session)
629
+ ):
630
+ """Load AI model"""
631
+ try:
632
+ # Store model information
633
+ model_info = {
634
+ 'model_name': model_name,
635
+ 'model_data': request.model_data,
636
+ 'model_path': request.model_path,
637
+ 'model_hash': request.model_hash,
638
+ 'loaded_at': time.time(),
639
+ 'session_id': session['session_id']
640
+ }
641
+
642
+ server.model_cache[model_name] = model_info
643
+
644
+ # Store in persistent storage
645
+ model_file = server.models_path / f"{model_name}.json"
646
+ with open(model_file, 'w') as f:
647
+ json.dump(model_info, f)
648
+
649
+ server.ops_counter += 1
650
+ return {
651
+ "status": "success",
652
+ "message": f"Model {model_name} loaded successfully",
653
+ "model_info": {
654
+ "name": model_name,
655
+ "loaded_at": model_info['loaded_at']
656
+ }
657
+ }
658
+
659
+ except Exception as e:
660
+ raise HTTPException(
661
+ status_code=500,
662
+ detail=f"Model load operation failed: {str(e)}"
663
+ )
664
+
665
+ @app.post("/api/v1/models/{model_name}/inference")
666
+ async def run_inference(
667
+ model_name: str,
668
+ request: ModelInferenceRequest,
669
+ session: Dict[str, Any] = Depends(get_current_session)
670
+ ):
671
+ """Run model inference"""
672
+ try:
673
+ # Check if model is loaded
674
+ if model_name not in server.model_cache:
675
+ raise HTTPException(status_code=404, detail=f"Model {model_name} not loaded")
676
+
677
+ # Simulate inference processing
678
+ # In a real implementation, this would invoke the actual model
679
+ result = {
680
+ "status": "success",
681
+ "output": request.input_data, # Echo input for now
682
+ "metrics": {
683
+ "inference_time": 0.1,
684
+ "tokens_processed": len(request.input_data)
685
+ },
686
+ "model_info": server.model_cache[model_name]
687
+ }
688
+
689
+ server.ops_counter += 1
690
+ return result
691
+
692
+ except HTTPException:
693
+ raise
694
+ except Exception as e:
695
+ raise HTTPException(
696
+ status_code=500,
697
+ detail=f"Inference operation failed: {str(e)}"
698
+ )
699
+
700
+ @app.get("/api/v1/models/{model_name}/status")
701
+ async def get_model_status(
702
+ model_name: str,
703
+ session: Dict[str, Any] = Depends(get_current_session)
704
+ ):
705
+ """Get model status"""
706
+ try:
707
+ if model_name in server.model_cache:
708
+ return {
709
+ "status": "loaded",
710
+ "model_info": server.model_cache[model_name]
711
+ }
712
+ else:
713
+ return {
714
+ "status": "not_loaded",
715
+ "message": f"Model {model_name} is not loaded"
716
+ }
717
+
718
+ except Exception as e:
719
+ raise HTTPException(
720
+ status_code=500,
721
+ detail=f"Model status check failed: {str(e)}"
722
+ )
723
+
724
+ # Multi-chip coordination endpoints
725
+ @app.post("/api/v1/chips/{src_chip_id}/transfer/{dst_chip_id}")
726
+ async def transfer_between_chips(
727
+ src_chip_id: int,
728
+ dst_chip_id: int,
729
+ request: dict,
730
+ session: Dict[str, Any] = Depends(get_current_session)
731
+ ):
732
+ """Transfer data between GPU chips"""
733
+ try:
734
+ data_id = request.get('data_id')
735
+ if not data_id:
736
+ raise HTTPException(status_code=400, detail="Missing data_id")
737
+
738
+ # Load the source data
739
+ source_operation = {
740
+ 'operation': 'vram',
741
+ 'type': 'read',
742
+ 'block_id': data_id
743
+ }
744
+
745
+ source_result = await server.handle_vram_operation(source_operation)
746
+ if source_result.get('status') != 'success':
747
+ raise HTTPException(status_code=404, detail=f"Source data {data_id} not found")
748
+
749
+ # Create new data ID for destination
750
+ new_data_id = f"{data_id}_chip_{dst_chip_id}"
751
+
752
+ # Store the data with the new ID
753
+ dest_operation = {
754
+ 'operation': 'vram',
755
+ 'type': 'write',
756
+ 'block_id': new_data_id,
757
+ 'data': source_result.get('data'),
758
+ 'metadata': source_result.get('metadata', {})
759
+ }
760
+
761
+ dest_result = await server.handle_vram_operation(dest_operation)
762
+ if dest_result.get('status') != 'success':
763
+ raise HTTPException(status_code=500, detail="Failed to store transferred data")
764
+
765
+ # Simulate cross-chip transfer
766
+ transfer_id = f"transfer_{time.time_ns()}"
767
+
768
+ result = {
769
+ "status": "success",
770
+ "transfer_id": transfer_id,
771
+ "src_chip": src_chip_id,
772
+ "dst_chip": dst_chip_id,
773
+ "data_id": data_id,
774
+ "new_data_id": new_data_id
775
+ }
776
+
777
+ server.ops_counter += 1
778
+ return result
779
+
780
+ except HTTPException:
781
+ raise
782
+ except Exception as e:
783
+ raise HTTPException(
784
+ status_code=500,
785
+ detail=f"Chip transfer failed: {str(e)}"
786
+ )
787
+
788
+ @app.post("/api/v1/sync/barrier/{barrier_id}")
789
+ async def create_sync_barrier(
790
+ barrier_id: str,
791
+ request: dict,
792
+ session: Dict[str, Any] = Depends(get_current_session)
793
+ ):
794
+ """Create synchronization barrier"""
795
+ try:
796
+ num_participants = request.get('num_participants', 1)
797
+
798
+ # Store barrier info
799
+ barrier_info = {
800
+ 'barrier_id': barrier_id,
801
+ 'num_participants': num_participants,
802
+ 'arrived_participants': 0,
803
+ 'created_at': time.time()
804
+ }
805
+
806
+ server.memory_cache[f"barrier_{barrier_id}"] = barrier_info
807
+
808
+ return {
809
+ "status": "success",
810
+ "barrier_id": barrier_id,
811
+ "num_participants": num_participants
812
+ }
813
+
814
+ except Exception as e:
815
+ raise HTTPException(
816
+ status_code=500,
817
+ detail=f"Barrier creation failed: {str(e)}"
818
+ )
819
+
820
+ @app.put("/api/v1/sync/barrier/{barrier_id}/wait")
821
+ async def wait_sync_barrier(
822
+ barrier_id: str,
823
+ session: Dict[str, Any] = Depends(get_current_session)
824
+ ):
825
+ """Wait at synchronization barrier"""
826
+ try:
827
+ barrier_key = f"barrier_{barrier_id}"
828
+ if barrier_key not in server.memory_cache:
829
+ raise HTTPException(status_code=404, detail="Barrier not found")
830
+
831
+ barrier_info = server.memory_cache[barrier_key]
832
+ barrier_info['arrived_participants'] += 1
833
+
834
+ # Check if all participants have arrived
835
+ if barrier_info['arrived_participants'] >= barrier_info['num_participants']:
836
+ # All participants arrived, release barrier
837
+ del server.memory_cache[barrier_key]
838
+ return {
839
+ "status": "released",
840
+ "message": "All participants arrived, barrier released"
841
+ }
842
+ else:
843
+ return {
844
+ "status": "waiting",
845
+ "arrived": barrier_info['arrived_participants'],
846
+ "total": barrier_info['num_participants']
847
+ }
848
+
849
+ except HTTPException:
850
+ raise
851
+ except Exception as e:
852
+ raise HTTPException(
853
+ status_code=500,
854
+ detail=f"Barrier wait failed: {str(e)}"
855
+ )
856
+
857
+ # Preserved WebSocket endpoints for backward compatibility
858
+ @app.get("/", response_class=HTMLResponse)
859
+ async def handle_index():
860
+ """Handle HTTP index request"""
861
+ stats = server.get_stats()
862
+ html = f"""
863
+ <!DOCTYPE html>
864
+ <html>
865
+ <head>
866
+ <title>Virtual GPU Server v2.0</title>
867
+ <style>
868
+ body {{ font-family: Arial, sans-serif; margin: 40px; }}
869
+ table {{ border-collapse: collapse; width: 100%; margin-top: 20px; }}
870
+ th, td {{ padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }}
871
+ th {{ background-color: #f2f2f2; }}
872
+ .stats {{ background-color: #f9f9f9; padding: 20px; border-radius: 5px; }}
873
+ .api-info {{ background-color: #e8f4fd; padding: 20px; border-radius: 5px; margin-top: 20px; }}
874
+ </style>
875
+ </head>
876
+ <body>
877
+ <h1>Virtual GPU Server v2.0 Status</h1>
878
+ <div class="api-info">
879
+ <h2>API Information</h2>
880
+ <p><strong>HTTP REST API:</strong> Available at /api/v1/</p>
881
+ <p><strong>WebSocket API:</strong> Available at /ws (backward compatibility)</p>
882
+ <p><strong>API Documentation:</strong> <a href="/docs">/docs</a></p>
883
+ </div>
884
+ <div class="stats">
885
+ <h2>Server Statistics</h2>
886
+ <ul>
887
+ <li>Uptime: {stats['uptime']:.2f} seconds</li>
888
+ <li>Total Operations: {stats['total_operations']}</li>
889
+ <li>Operations per Second: {stats['ops_per_second']:.2f}</li>
890
+ <li>Active WebSocket Connections: {stats['active_connections']}</li>
891
+ <li>Active HTTP Sessions: {stats['active_http_sessions']}</li>
892
+ <li>VRAM Cache Size: {stats['vram_cache_size']}</li>
893
+ <li>State Cache Size: {stats['state_cache_size']}</li>
894
+ <li>Memory Cache Size: {stats['memory_cache_size']}</li>
895
+ <li>Model Cache Size: {stats['model_cache_size']}</li>
896
+ </ul>
897
+ </div>
898
+ <h2>Server Files</h2>
899
+ <iframe src="/files" style="width: 100%; height: 500px; border: none;"></iframe>
900
+ </body>
901
+ </html>
902
+ """
903
+ return HTMLResponse(content=html)
904
+
905
+ @app.get("/files", response_class=HTMLResponse)
906
+ async def handle_files():
907
+ """Handle HTTP files listing request"""
908
+ def format_size(size):
909
+ for unit in ['B', 'KB', 'MB', 'GB']:
910
+ if size < 1024:
911
+ return f"{size:.2f} {unit}"
912
+ size /= 1024
913
+ return f"{size:.2f} TB"
914
+
915
+ html = ['<!DOCTYPE html><html><head>',
916
+ '<style>',
917
+ 'body { font-family: Arial, sans-serif; margin: 20px; }',
918
+ 'table { border-collapse: collapse; width: 100%; }',
919
+ 'th, td { padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }',
920
+ 'th { background-color: #f2f2f2; }',
921
+ '</style></head><body>',
922
+ '<h2>Server Files</h2>',
923
+ '<table><tr><th>Path</th><th>Size</th><th>Last Modified</th></tr>']
924
+
925
+ for root, _, files in os.walk(server.base_path):
926
+ for file in files:
927
+ full_path = Path(root) / file
928
+ rel_path = full_path.relative_to(server.base_path)
929
+ size = format_size(os.path.getsize(full_path))
930
+ mtime = datetime.fromtimestamp(os.path.getmtime(full_path))
931
+ html.append(f'<tr><td>{rel_path}</td><td>{size}</td><td>{mtime}</td></tr>')
932
+
933
+ html.extend(['</table></body></html>'])
934
+ return HTMLResponse(content='\n'.join(html))
935
+
936
+ # WebSocket endpoint (preserved for backward compatibility)
937
+ @app.websocket("/ws")
938
+ async def websocket_endpoint(websocket: WebSocket):
939
+ await websocket.accept()
940
+ session_id = str(uuid.uuid4())
941
+ server.active_connections[session_id] = websocket
942
+ server.active_sessions[session_id] = {
943
+ 'start_time': time.time(),
944
+ 'ops_count': 0
945
+ }
946
+
947
+ try:
948
+ while True:
949
+ message = await websocket.receive_json()
950
+
951
+ # Route operation to appropriate handler
952
+ operation_type = message.get('operation')
953
+ if operation_type == 'vram':
954
+ response = await server.handle_vram_operation(message)
955
+ elif operation_type == 'state':
956
+ response = await server.handle_state_operation(message)
957
+ elif operation_type == 'cache':
958
+ response = await server.handle_cache_operation(message)
959
+ else:
960
+ response = {
961
+ 'status': 'error',
962
+ 'message': 'Unknown operation type'
963
+ }
964
+
965
+ # Update statistics
966
+ server.ops_counter += 1
967
+ server.active_sessions[session_id]['ops_count'] += 1
968
+
969
+ # Send response
970
+ await websocket.send_json(response)
971
+
972
+ except Exception as e:
973
+ print(f"WebSocket error: {e}")
974
+ finally:
975
+ # Cleanup on disconnect
976
+ if session_id in server.active_connections:
977
+ del server.active_connections[session_id]
978
+ if session_id in server.active_sessions:
979
+ del server.active_sessions[session_id]
980
+
981
+ # For running directly (development)
982
+ if __name__ == "__main__":
983
+ import uvicorn
984
+ uvicorn.run("virtual_gpu_server_http:app", host="0.0.0.0", port=7860, reload=True)
985
+
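
For context, a minimal sketch of how a client could drive the new HTTP API once the server is running. The host and port follow the uvicorn.run() call in the __main__ block above; the client_id, block name, and payload are made-up example values, not anything defined in the commit.

import requests

BASE_URL = "http://localhost:7860"  # assumed local dev address; port matches uvicorn.run() above

# 1. Create a session and grab the JWT issued by POST /api/v1/sessions
resp = requests.post(f"{BASE_URL}/api/v1/sessions", json={"client_id": "demo-client"})
resp.raise_for_status()
token = resp.json()["session_token"]
headers = {"Authorization": f"Bearer {token}"}

# 2. Write a small tensor to a VRAM block
payload = {"data": [[1.0, 2.0], [3.0, 4.0]], "metadata": {"dtype": "float32"}}
write = requests.post(f"{BASE_URL}/api/v1/vram/blocks/demo_block", json=payload, headers=headers)
print(write.json())  # expect {'status': 'success', 'message': 'Block demo_block written', ...}

# 3. Read it back; this request is served from the in-memory vram_cache
read = requests.get(f"{BASE_URL}/api/v1/vram/blocks/demo_block", headers=headers)
print(read.json()["data"], read.json()["source"])

The preserved /ws endpoint can be exercised the same way with the websockets package. Again only a sketch; the cache key and value are illustrative:

import asyncio
import json
import websockets

async def main():
    # /ws is the backward-compatibility endpoint defined above
    async with websockets.connect("ws://localhost:7860/ws") as ws:
        await ws.send(json.dumps({
            "operation": "cache",
            "type": "set",
            "key": "demo_key",
            "data": {"hello": "world"},
        }))
        print(json.loads(await ws.recv()))  # {'status': 'success', 'message': 'Cache key demo_key set'}

        await ws.send(json.dumps({"operation": "cache", "type": "get", "key": "demo_key"}))
        print(json.loads(await ws.recv()))  # {'status': 'success', 'data': ..., 'source': 'memory'}

asyncio.run(main())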