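"""Tensor operation management with remote state tracking.

Tracks tensor operations (matmul, conv2d) in a DuckDB table backed by a
Hugging Face dataset, and provides helpers for dispatching work to tensor
cores and synchronizing on completion.
"""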
from typing import List, Dict, Any, Optional
import time
import json
import logging

import duckdb
from huggingface_hub import HfApi, HfFileSystem

from tensor_core import TensorCore
from config import get_hf_token_cached


class TensorOps:
    """Manages tensor operations with remote state tracking."""

    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    def __init__(self, db_url: Optional[str] = None):
        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        # Hugging Face token used for the storage credentials configured in
        # _init_db_connection.
        self.hf_token = get_hf_token_cached()
        self._connect_with_retries()
        self._setup_database()

    def _connect_with_retries(self):
        """Establish database connection with retry logic."""
        for attempt in range(self.max_retries):
            try:
                self.conn = self._init_db_connection()
                return
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise RuntimeError(
                        f"Failed to initialize database after {self.max_retries} attempts: {str(e)}"
                    ) from e
                time.sleep(1)

    def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
        """Initialize database connection with HuggingFace configuration."""
        # Parse "hf://datasets/<owner>/<dataset>/<file>" into its components.
        _, _, _, owner, dataset, db_file = self.db_url.split('/', 5)
        db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

        conn = duckdb.connect(db_path)
        conn.execute("INSTALL httpfs;")
        conn.execute("LOAD httpfs;")
        conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';")
        conn.execute("SET s3_use_ssl=true;")
        conn.execute("SET s3_url_style='path';")
        conn.execute(f"SET s3_access_key_id='{self.hf_token}';")
        conn.execute(f"SET s3_secret_access_key='{self.hf_token}';")
        return conn

    def _setup_database(self):
        """Initialize database tables."""
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS tensor_operations (
                operation_id VARCHAR PRIMARY KEY,
                operation_type VARCHAR,
                inputs JSON,
                output_shape VARCHAR,
                chip_id INTEGER,
                stream_id INTEGER,
                warp_id VARCHAR,
                status VARCHAR DEFAULT 'pending',
                result_address BIGINT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                started_at TIMESTAMP,
                completed_at TIMESTAMP,
                error_message VARCHAR,
                state_json JSON
            )
        """)

    def execute_tensor_op(self, operation: str, inputs: List[Dict[str, Any]],
                          output_shape: Optional[tuple] = None,
                          chip_id: Optional[int] = None,
                          stream_id: Optional[int] = None,
                          warp_id: Optional[str] = None) -> Optional[int]:
        """
        Execute a tensor operation with enhanced tracking and coordination.

        Args:
            operation: Operation type (matmul, conv2d, etc.)
            inputs: List of input tensors with metadata
            output_shape: Expected output shape (for pre-allocation)
            chip_id: Target GPU chip (if None, selected automatically)
            stream_id: Execution stream ID (if None, uses the default stream)
            warp_id: ID of the warp to execute on (if None, scheduled automatically)

        Returns:
            Address of the output tensor, or None if the operation fails
        """
        operation_id = None
        try:
            operation_id = f"op_{time.time_ns()}"

            # Pick the least-loaded chip among those with running operations.
            if chip_id is None:
                result = self.conn.execute("""
                    SELECT chip_id
                    FROM tensor_operations
                    WHERE status = 'running'
                    GROUP BY chip_id
                    ORDER BY COUNT(*) ASC
                    LIMIT 1
                """).fetchall()
                chip_id = result[0][0] if result else 0

            # Record the pending operation. default=str keeps non-JSON-serializable
            # fields (e.g. array data) from breaking the insert.
            self.conn.execute("""
                INSERT INTO tensor_operations (
                    operation_id, operation_type, inputs, output_shape,
                    chip_id, stream_id, warp_id, status, state_json
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, [
                operation_id,
                operation,
                json.dumps(inputs, default=str),
                str(output_shape) if output_shape else None,
                chip_id,
                stream_id,
                warp_id,
                'pending',
                json.dumps({
                    "status": "initialized",
                    "timestamp": time.time_ns()
                })
            ])

            tensor_core = TensorCore()

            # Mark the operation as running.
            self.conn.execute("""
                UPDATE tensor_operations
                SET status = 'running',
                    started_at = CURRENT_TIMESTAMP,
                    state_json = ?
                WHERE operation_id = ?
            """, [json.dumps({"status": "running"}), operation_id])

            # Dispatch to the tensor core.
            result_address = None
            if operation == 'matmul':
                result_address = tensor_core.matmul(
                    inputs[0]['data'],
                    inputs[1]['data'],
                    warp_id=warp_id
                )
            elif operation == 'conv2d':
                result_address = tensor_core.conv2d(
                    inputs[0]['data'],
                    inputs[1]['data'],
                    warp_id=warp_id
                )
            else:
                # Unsupported operations are surfaced as failures rather than
                # silently completing with no result.
                raise ValueError(f"Unsupported operation: {operation}")

            # Mark the operation as completed and record the result address.
            self.conn.execute("""
                UPDATE tensor_operations
                SET status = 'completed',
                    completed_at = CURRENT_TIMESTAMP,
                    result_address = ?,
                    state_json = ?
                WHERE operation_id = ?
            """, [
                result_address,
                json.dumps({"status": "completed", "result": result_address}),
                operation_id
            ])

            return result_address

        except Exception as e:
            if operation_id:
                self.conn.execute("""
                    UPDATE tensor_operations
                    SET status = 'failed',
                        completed_at = CURRENT_TIMESTAMP,
                        error_message = ?,
                        state_json = ?
                    WHERE operation_id = ?
                """, [
                    str(e),
                    json.dumps({"status": "failed", "error": str(e)}),
                    operation_id
                ])
            logging.error(f"Tensor operation failed: {str(e)}")
            return None

    def get_operation_status(self, operation_id: str) -> Dict[str, Any]:
        """Get the current status of a tensor operation."""
        try:
            result = self.conn.execute("""
                SELECT status, result_address, error_message, state_json
                FROM tensor_operations
                WHERE operation_id = ?
            """, [operation_id]).fetchall()

            if not result:
                return {"status": "not_found"}

            row = result[0]
            return {
                "status": row[0],
                "result_address": row[1],
                "error_message": row[2],
                "state": row[3]
            }

        except Exception as e:
            logging.error(f"Failed to get operation status: {str(e)}")
            return {"status": "error", "error": str(e)}

    def wait_for_operation(self, operation_id: str, timeout: Optional[float] = None) -> Dict[str, Any]:
        """Wait for a tensor operation to complete."""
        start_time = time.time()
        while True:
            status = self.get_operation_status(operation_id)

            if status["status"] in ["completed", "failed"]:
                return status

            if timeout and (time.time() - start_time) > timeout:
                return {"status": "timeout"}

            time.sleep(0.001)

    def synchronize_operations(self, operation_ids: List[str]) -> Dict[str, Any]:
        """Synchronize multiple tensor operations."""
        try:
            results = {}
            for op_id in operation_ids:
                results[op_id] = self.wait_for_operation(op_id)

            return {
                "status": "completed",
                "operations": results
            }

        except Exception as e:
            logging.error(f"Failed to synchronize tensor operations: {str(e)}")
            return {
                "status": "error",
                "error": str(e)
            }

    def _execute_on_sm(self, operation: str, inputs: List[Dict[str, Any]],
                       chip_id: int, target_sm_id: int, target_sm,
                       op_info: Dict[str, Any],
                       warp_id: Optional[str] = None) -> Optional[int]:
        """Execute a tensor operation directly on a streaming multiprocessor.

        Note: this path relies on warp tables, a memory manager and an allocator
        (self.warps, self.memory_manager, self.allocate_memory) that are expected
        to be provided by the surrounding GPU management layer; they are not
        defined on TensorOps itself.
        """
        try:
            # Pick a warp automatically if none was requested.
            if warp_id is None:
                available_warps = [
                    w for w in self.warps[chip_id][target_sm_id]
                    if len(w.get_active_threads()) > 0
                ]
                if not available_warps:
                    raise RuntimeError("No available warps")
                warp = available_warps[0]
                warp_id = str(warp.warp_id)
                op_info["warp_id"] = warp_id

            # Schedule the operation on the SM's matrix-op scheduler.
            op_metadata = target_sm.matrix_op_scheduler.schedule_operation(
                op_type=operation,
                input_shapes=[inp.get("shape") for inp in inputs],
                warp_id=warp_id
            )

            if op_metadata is None:
                raise RuntimeError("Failed to schedule matrix operation")

            try:
                # Hold the matrix-op lock for the duration of execution.
                if not target_sm.matrix_op_lock.acquire_matrix_op(
                    op_metadata.op_id,
                    op_info
                ):
                    raise RuntimeError("Failed to acquire matrix operation lock")

                # Read the operands and run the operation on the tensor cores.
                result = None
                if operation == "matmul":
                    A = self.memory_manager.read_tensor(inputs[0]["address"])
                    B = self.memory_manager.read_tensor(inputs[1]["address"])
                    result = target_sm.tensor_core_matmul(A, B, warp_id=warp_id)

                elif operation == "conv2d":
                    input_tensor = self.memory_manager.read_tensor(inputs[0]["address"])
                    kernel = self.memory_manager.read_tensor(inputs[1]["address"])
                    result = target_sm.tensor_core_conv2d(input_tensor, kernel, warp_id=warp_id)

                if result is None:
                    raise RuntimeError(f"Failed to execute {operation}")

                # Allocate device memory for the result and write it back.
                output_addr = self.allocate_memory(
                    result.nbytes,
                    chip_id=chip_id,
                    tensor_shape=result.shape,
                    dtype=result.dtype
                )
                self.memory_manager.write_tensor(output_addr, result)

                # Mark the scheduled operation as completed.
                target_sm.matrix_op_scheduler.complete_operation(
                    op_metadata,
                    output_shape=result.shape,
                    success=True
                )

                # Record the operation in the SM's history.
                target_sm.tensor_op_history.append({
                    **op_info,
                    "op_id": op_metadata.op_id,
                    "output_shape": result.shape,
                    "output_address": output_addr,
                    "end_time": time.time_ns(),
                    "status": "completed"
                })

                return output_addr

            except Exception as e:
                if op_metadata:
                    target_sm.matrix_op_scheduler.complete_operation(
                        op_metadata,
                        output_shape=None,
                        success=False,
                        error=str(e)
                    )
                raise

            finally:
                if op_metadata:
                    target_sm.matrix_op_lock.release_matrix_op(op_metadata.op_id)

        except Exception as e:
            logging.error(f"Tensor operation failed: {str(e)}")
            return None

    def get_tensor_op_status(self, chip_id: int, sm_id: int, op_id: str) -> Dict[str, Any]:
        """Get status and metadata for a tensor operation scheduled on an SM."""
        try:
            sm = self.streaming_multiprocessors[chip_id][sm_id]
            active_ops = sm.matrix_op_scheduler.coordinator.get_active_operations()

            # Check operations that are still in flight.
            for op in active_ops:
                if op.op_id == op_id:
                    return {
                        "status": "running",
                        "metadata": op.__dict__
                    }

            # Fall back to the completed-operation history.
            history = sm.matrix_op_scheduler.coordinator.get_operation_history()
            for op in history:
                if op.op_id == op_id:
                    return {
                        "status": op.status,
                        "metadata": op.__dict__
                    }

            return {
                "status": "not_found",
                "metadata": None
            }

        except Exception as e:
            logging.error(f"Failed to get operation status: {str(e)}")
            return {
                "status": "error",
                "metadata": {"error": str(e)}
            }

    def sync_tensor_ops(self, chip_id: int, sm_id: int, warp_id: Optional[str] = None) -> bool:
        """Synchronize pending tensor operations on a streaming multiprocessor."""
        try:
            sm = self.streaming_multiprocessors[chip_id][sm_id]

            # Restrict to a single warp's operations if one was specified.
            if warp_id is not None:
                active_ops = [
                    op for op in sm.matrix_op_scheduler.coordinator.get_active_operations()
                    if op.warp_id == warp_id
                ]
            else:
                active_ops = sm.matrix_op_scheduler.coordinator.get_active_operations()

            # Poll each operation until it leaves the running/scheduled states.
            for op in active_ops:
                while True:
                    status = self.get_tensor_op_status(chip_id, sm_id, op.op_id)
                    if status["status"] not in ["running", "scheduled"]:
                        break
                    time.sleep(0.001)

            return True

        except Exception as e:
            logging.error(f"Failed to synchronize tensor operations: {str(e)}")
            return False
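

# Minimal usage sketch (illustrative only): drives execute_tensor_op with two
# small operands. It assumes NumPy is available, that TensorCore accepts NumPy
# arrays as operand data, and that the HF-backed DuckDB storage is reachable;
# the shapes and values here are placeholders, not part of this module's API.
if __name__ == "__main__":
    import numpy as np

    ops = TensorOps()

    # Two small matrices wrapped in the metadata format expected by execute_tensor_op.
    a = {"data": np.random.rand(4, 4), "shape": (4, 4)}
    b = {"data": np.random.rand(4, 4), "shape": (4, 4)}

    addr = ops.execute_tensor_op("matmul", [a, b], output_shape=(4, 4))
    print(f"result address: {addr}")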