from typing import List, Dict, Any, Optional
import time
import json
import logging
import duckdb
from tensor_core import TensorCore
from config import get_hf_token_cached
class TensorOps:
"""Manages tensor operations with remote state tracking"""
DB_URL = "hf://datasets/Fred808/helium/storage.json"
    def __init__(self, db_url: Optional[str] = None):
        self.db_url = db_url or self.DB_URL
        # Token is read from the environment (.env) via the config helper
        self.hf_token = get_hf_token_cached()
        self.max_retries = 3
        self._connect_with_retries()
        self._setup_database()
def _connect_with_retries(self):
"""Establish database connection with retry logic"""
for attempt in range(self.max_retries):
try:
self.conn = self._init_db_connection()
return
except Exception as e:
if attempt == self.max_retries - 1:
raise RuntimeError(f"Failed to initialize database after {self.max_retries} attempts: {str(e)}")
time.sleep(1)
def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
"""Initialize database connection with HuggingFace configuration"""
        # Convert the hf:// URL (hf://datasets/<owner>/<dataset>/<file>)
        # into an S3-style path for DuckDB
        _, _, _, owner, dataset, db_file = self.db_url.split('/', 5)
        db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"
# Connect to remote database
conn = duckdb.connect(db_path)
conn.execute("INSTALL httpfs;")
conn.execute("LOAD httpfs;")
conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';")
conn.execute("SET s3_use_ssl=true;")
conn.execute("SET s3_url_style='path';")
conn.execute(f"SET s3_access_key_id='{self.HF_TOKEN}';")
conn.execute(f"SET s3_secret_access_key='{self.HF_TOKEN}';")
return conn
def _setup_database(self):
"""Initialize database tables"""
# Tensor operations table
self.conn.execute("""
CREATE TABLE IF NOT EXISTS tensor_operations (
operation_id VARCHAR PRIMARY KEY,
operation_type VARCHAR,
inputs JSON,
output_shape VARCHAR,
chip_id INTEGER,
stream_id INTEGER,
warp_id VARCHAR,
status VARCHAR DEFAULT 'pending',
result_address BIGINT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
error_message VARCHAR,
state_json JSON
)
""")
def execute_tensor_op(self, operation: str, inputs: List[Dict[str, Any]],
output_shape: Optional[tuple] = None,
chip_id: Optional[int] = None,
stream_id: Optional[int] = None,
warp_id: Optional[str] = None) -> Optional[int]:
"""
Execute a tensor operation with enhanced tracking and coordination
Args:
operation: Operation type (matmul, conv2d, etc.)
inputs: List of input tensors with metadata
output_shape: Expected output shape (for pre-allocation)
chip_id: Target GPU chip (if None, will be automatically selected)
stream_id: Execution stream ID (if None, uses default stream)
warp_id: ID of warp to execute on (if None, automatically scheduled)
Returns:
Address of output tensor or None if operation fails
"""
operation_id = None
try:
# Generate operation ID
operation_id = f"op_{time.time_ns()}"
# Choose optimal GPU if not specified
if chip_id is None:
# Query least loaded GPU
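                # Note: only chips with at least one running operation appear in
                # this result, so a fully idle chip is never selected here;
                # chip 0 is the fallback when nothing is running at all.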
result = self.conn.execute("""
SELECT chip_id
FROM tensor_operations
WHERE status = 'running'
GROUP BY chip_id
ORDER BY COUNT(*) ASC
LIMIT 1
""").fetchall()
chip_id = result[0][0] if result else 0
# Create operation record
self.conn.execute("""
INSERT INTO tensor_operations (
operation_id, operation_type, inputs, output_shape,
chip_id, stream_id, warp_id, status, state_json
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""", [
operation_id,
operation,
inputs,
str(output_shape) if output_shape else None,
chip_id,
stream_id,
warp_id,
'pending',
{
"status": "initialized",
"timestamp": time.time_ns()
}
])
# Initialize tensor core
tensor_core = TensorCore()
            # Mark the operation as running before dispatching to the tensor core
self.conn.execute("""
UPDATE tensor_operations
SET status = 'running',
started_at = CURRENT_TIMESTAMP,
state_json = ?
WHERE operation_id = ?
""", [{"status": "running"}, operation_id])
# Execute based on operation type
result_address = None
if operation == 'matmul':
result_address = tensor_core.matmul(
inputs[0]['data'],
inputs[1]['data'],
warp_id=warp_id
)
            elif operation == 'conv2d':
                result_address = tensor_core.conv2d(
                    inputs[0]['data'],
                    inputs[1]['data'],
                    warp_id=warp_id
                )
            else:
                raise ValueError(f"Unsupported operation type: {operation}")
# Update operation status to completed
self.conn.execute("""
UPDATE tensor_operations
SET status = 'completed',
completed_at = CURRENT_TIMESTAMP,
result_address = ?,
state_json = ?
WHERE operation_id = ?
""", [
result_address,
{"status": "completed", "result": result_address},
operation_id
])
return result_address
except Exception as e:
if operation_id:
# Update operation status to failed
self.conn.execute("""
UPDATE tensor_operations
SET status = 'failed',
completed_at = CURRENT_TIMESTAMP,
error_message = ?,
state_json = ?
WHERE operation_id = ?
""", [
str(e),
{"status": "failed", "error": str(e)},
operation_id
])
logging.error(f"Tensor operation failed: {str(e)}")
return None
def get_operation_status(self, operation_id: str) -> Dict[str, Any]:
"""Get the current status of a tensor operation"""
try:
result = self.conn.execute("""
SELECT status, result_address, error_message, state_json
FROM tensor_operations
WHERE operation_id = ?
""", [operation_id]).fetchall()
if not result:
return {"status": "not_found"}
row = result[0]
return {
"status": row[0],
"result_address": row[1],
"error_message": row[2],
"state": row[3]
}
except Exception as e:
logging.error(f"Failed to get operation status: {str(e)}")
return {"status": "error", "error": str(e)}
def wait_for_operation(self, operation_id: str, timeout: Optional[float] = None) -> Dict[str, Any]:
"""Wait for a tensor operation to complete"""
start_time = time.time()
while True:
status = self.get_operation_status(operation_id)
if status["status"] in ["completed", "failed"]:
return status
if timeout and (time.time() - start_time) > timeout:
return {"status": "timeout"}
time.sleep(0.001)
def synchronize_operations(self, operation_ids: List[str]) -> Dict[str, Any]:
"""Synchronize multiple tensor operations"""
try:
results = {}
for op_id in operation_ids:
results[op_id] = self.wait_for_operation(op_id)
return {
"status": "completed",
"operations": results
}
except Exception as e:
logging.error(f"Failed to synchronize tensor operations: {str(e)}")
return {
"status": "error",
"error": str(e)
}
    def _execute_tensor_op_on_sm(self, chip_id: int, target_sm_id: int, target_sm,
                                 operation: str, inputs: List[Dict[str, Any]],
                                 warp_id: Optional[str],
                                 op_info: Dict[str, Any]) -> Optional[int]:
        """SM-level execution path: schedule the operation on a streaming
        multiprocessor, acquire the matrix-operation lock, and run it on a warp.

        Note: this path assumes the owning driver object also provides `warps`,
        `memory_manager`, `allocate_memory`, and `streaming_multiprocessors`
        (as in the full virtual GPU driver); it is not reachable from the
        database-backed TensorOps state alone.
        """
        op_metadata = None
        try:
            # Get warp if not specified
            if warp_id is None:
                available_warps = [
                    w for w in self.warps[chip_id][target_sm_id]
                    if len(w.get_active_threads()) > 0
                ]
                if not available_warps:
                    raise RuntimeError("No available warps")
                warp = available_warps[0]
                warp_id = str(warp.warp_id)
                op_info["warp_id"] = warp_id
            # Schedule operation
            op_metadata = target_sm.matrix_op_scheduler.schedule_operation(
                op_type=operation,
                input_shapes=[inp.get("shape") for inp in inputs],
                warp_id=warp_id
            )
            if op_metadata is None:
                raise RuntimeError("Failed to schedule matrix operation")
            try:
                # Acquire matrix operation lock
                if not target_sm.matrix_op_lock.acquire_matrix_op(
                    op_metadata.op_id,
                    op_info
                ):
                    raise RuntimeError("Failed to acquire matrix operation lock")
                # Execute operation based on type
                result = None
                if operation == "matmul":
                    A = self.memory_manager.read_tensor(inputs[0]["address"])
                    B = self.memory_manager.read_tensor(inputs[1]["address"])
                    result = target_sm.tensor_core_matmul(A, B, warp_id=warp_id)
                elif operation == "conv2d":
                    input_tensor = self.memory_manager.read_tensor(inputs[0]["address"])
                    kernel = self.memory_manager.read_tensor(inputs[1]["address"])
                    result = target_sm.tensor_core_conv2d(input_tensor, kernel, warp_id=warp_id)
                if result is None:
                    raise RuntimeError(f"Failed to execute {operation}")
                # Allocate output and store result
                output_addr = self.allocate_memory(
                    result.nbytes,
                    chip_id=chip_id,
                    tensor_shape=result.shape,
                    dtype=result.dtype
                )
                self.memory_manager.write_tensor(output_addr, result)
                # Complete operation successfully
                target_sm.matrix_op_scheduler.complete_operation(
                    op_metadata,
                    output_shape=result.shape,
                    success=True
                )
                # Update operation history
                target_sm.tensor_op_history.append({
                    **op_info,
                    "op_id": op_metadata.op_id,
                    "output_shape": result.shape,
                    "output_address": output_addr,
                    "end_time": time.time_ns(),
                    "status": "completed"
                })
                return output_addr
            except Exception as e:
                # Mark the scheduled operation as failed before re-raising
                if op_metadata:
                    target_sm.matrix_op_scheduler.complete_operation(
                        op_metadata,
                        output_shape=None,
                        success=False,
                        error=str(e)
                    )
                raise
            finally:
                # Always release the matrix operation lock
                if op_metadata:
                    target_sm.matrix_op_lock.release_matrix_op(op_metadata.op_id)
        except Exception as e:
            logging.error(f"Tensor operation failed: {str(e)}")
            return None
def get_tensor_op_status(self, chip_id: int, sm_id: int, op_id: str) -> Dict[str, Any]:
"""Get status and metadata for a tensor operation"""
try:
sm = self.streaming_multiprocessors[chip_id][sm_id]
active_ops = sm.matrix_op_scheduler.coordinator.get_active_operations()
# Check active operations
for op in active_ops:
if op.op_id == op_id:
return {
"status": "running",
"metadata": op.__dict__
}
# Check operation history
history = sm.matrix_op_scheduler.coordinator.get_operation_history()
for op in history:
if op.op_id == op_id:
return {
"status": op.status,
"metadata": op.__dict__
}
return {
"status": "not_found",
"metadata": None
}
except Exception as e:
logging.error(f"Failed to get operation status: {str(e)}")
return {
"status": "error",
"metadata": {"error": str(e)}
}
def sync_tensor_ops(self, chip_id: int, sm_id: int, warp_id: Optional[str] = None):
"""Synchronize pending tensor operations"""
try:
sm = self.streaming_multiprocessors[chip_id][sm_id]
# Get relevant operations
if warp_id is not None:
active_ops = [
op for op in sm.matrix_op_scheduler.coordinator.get_active_operations()
if op.warp_id == warp_id
]
else:
active_ops = sm.matrix_op_scheduler.coordinator.get_active_operations()
# Wait for operations to complete
for op in active_ops:
while True:
status = self.get_tensor_op_status(chip_id, sm_id, op.op_id)
if status["status"] not in ["running", "scheduled"]:
break
time.sleep(0.001) # Small delay to prevent busy waiting
return True
        except Exception as e:
            logging.error(f"Failed to synchronize tensor operations: {str(e)}")
            return False