|
|
from dataclasses import dataclass
|
|
|
from typing import List, Dict, Any, Callable
|
|
|
import numpy as np
|
|
|
from threading import Lock
|
|
|
|
|
|
@dataclass
class KernelConfig:
    """Launch parameters for a simulated CUDA-style kernel.

    Both dimension tuples are ordered ``(x, y, z)``.
    """

    # Number of threads per block along each axis.
    block_dim: tuple[int, int, int]

    # Number of blocks in the grid along each axis.
    grid_dim: tuple[int, int, int]

    # Size in bytes of the shared memory given to every block (none by default).
    shared_memory_size: int = 0
|
|
|
|
|
|
class ThreadIdx:
    """Position of a single thread inside its block, as ``x``/``y``/``z`` coordinates."""

    def __init__(self, x: int, y: int, z: int):
        # Store the three coordinates as plain public attributes.
        self.x, self.y, self.z = x, y, z
|
|
|
|
|
|
class BlockIdx:
    """Position of a block inside the launch grid, as ``x``/``y``/``z`` coordinates."""

    def __init__(self, x: int, y: int, z: int):
        # Store the three coordinates as plain public attributes.
        self.x, self.y, self.z = x, y, z
|
|
|
|
|
|
class Warp:
    """A group of up to ``WARP_SIZE`` threads that is treated as one unit.

    ``active_mask`` holds one bit per lane (bit ``i`` corresponds to
    ``threads[i]``).  All lanes start out active; callers may clear bits to
    mark lanes as inactive, and the vote helpers honour that mask.
    """

    WARP_SIZE = 32

    def __init__(self, warp_id: int, threads: List["ThreadIdx"]):
        """Create a warp.

        Args:
            warp_id: Identifier of this warp within its block.
            threads: The lanes that belong to this warp, in lane order.
        """
        self.warp_id = warp_id
        self.threads = threads
        # One set bit per lane: every thread is active initially.
        self.active_mask = (1 << len(threads)) - 1

    def _has_active_lane(self) -> bool:
        """Return True when at least one lane of this warp is marked active."""
        # Ignore stray mask bits above the number of lanes actually present.
        lane_bits = (1 << len(self.threads)) - 1
        return (self.active_mask & lane_bits) != 0

    def synchronize(self):
        """Synchronize all threads in the warp.

        This is a no-op placeholder: nothing in this class runs lanes
        concurrently, so there is nothing to wait for.
        """
        pass

    def vote_all(self, predicate: bool) -> bool:
        """Return True if ``predicate`` holds for all active threads.

        The same ``predicate`` value is applied to every lane.  With no active
        lanes (empty warp or a zero ``active_mask``) the result is vacuously
        True.
        """
        return not self._has_active_lane() or bool(predicate)

    def vote_any(self, predicate: bool) -> bool:
        """Return True if ``predicate`` holds for any active thread.

        The same ``predicate`` value is applied to every lane.  With no active
        lanes (empty warp or a zero ``active_mask``) the result is False.
        """
        return self._has_active_lane() and bool(predicate)
|
|
|
|
|
|
class Block:
    """A thread block: its grid position, its threads grouped into warps, and
    one shared-memory region owned by the block.
    """

    def __init__(self, block_idx: BlockIdx, dim: tuple[int, int, int], shared_mem_size: int):
        """Build the block and immediately partition its threads into warps.

        Args:
            block_idx: Position of this block in the grid.
            dim: Threads per block as ``(x, y, z)``.
            shared_mem_size: Size in bytes of the block's shared memory.
        """
        self.block_idx = block_idx
        self.dim = dim
        self.shared_memory = SharedMemory(shared_mem_size)
        self.warps: List[Warp] = []
        self._create_warps()

    def _create_warps(self):
        """Enumerate every thread of the block and group them into warps.

        Threads are numbered linearly with ``x`` varying fastest, then ``y``,
        then ``z``.  Consecutive runs of ``Warp.WARP_SIZE`` threads form one
        warp; the last warp may be shorter.
        """
        width, height, depth = self.dim[0], self.dim[1], self.dim[2]
        plane = width * height
        thread_count = plane * depth

        # First lay out all lanes in linear order ...
        lanes = []
        for linear in range(thread_count):
            layer, within_plane = divmod(linear, plane)
            row, col = divmod(within_plane, width)
            lanes.append(ThreadIdx(col, row, layer))

        # ... then slice them into warp-sized chunks.
        for start in range(0, thread_count, Warp.WARP_SIZE):
            chunk = lanes[start:start + Warp.WARP_SIZE]
            self.warps.append(Warp(len(self.warps), chunk))

    def synchronize(self):
        """Synchronize the block by synchronizing each of its warps in turn."""
        for group in self.warps:
            group.synchronize()
|
|
|
|
|
|
class SharedMemory:
    """Fixed-size byte buffer shared by all threads of a block.

    Every access is serialized through ``lock``.  Accesses are bounds-checked
    so the buffer can never silently grow, shrink, or return short reads.
    """

    def __init__(self, size_bytes: int):
        """Allocate a zero-filled buffer of ``size_bytes`` bytes."""
        self.size = size_bytes
        self.data = bytearray(size_bytes)
        self.lock = Lock()

    def _check_range(self, offset: int, length: int) -> None:
        """Raise IndexError unless ``[offset, offset + length)`` lies inside the buffer."""
        capacity = len(self.data)
        if offset < 0 or length < 0 or offset + length > capacity:
            raise IndexError(
                f"shared memory access out of range: offset={offset}, "
                f"length={length}, size={capacity}"
            )

    def read(self, offset: int, size: int) -> bytearray:
        """Return a copy of ``size`` bytes starting at ``offset``.

        Raises:
            IndexError: If the requested range is negative or extends past the
                end of the buffer.
        """
        with self.lock:
            self._check_range(offset, size)
            return self.data[offset:offset + size]

    def write(self, offset: int, data: bytearray):
        """Copy ``data`` into the buffer starting at ``offset``.

        Raises:
            IndexError: If the destination range is negative or extends past
                the end of the buffer.  Without this check, slice assignment
                would resize the underlying bytearray.
        """
        with self.lock:
            self._check_range(offset, len(data))
            self.data[offset:offset + len(data)] = data
|
|
|
|
|
|
class KernelFunction:
    """Callable wrapper around a kernel function.

    The wrapper forwards calls to the wrapped function unchanged and carries
    a ``shared_memory_size`` setting that can be adjusted via ``configure``.
    Metadata of the wrapped function (``__name__``, ``__doc__``, ...) is
    copied onto the wrapper so decorated kernels stay introspectable.
    """

    def __init__(self, func: Callable):
        """Wrap ``func``; the shared-memory setting starts at 0 bytes."""
        from functools import update_wrapper

        # Copy metadata first so the attributes assigned below always win over
        # anything that happens to live in ``func.__dict__``.
        update_wrapper(self, func)
        self.func = func
        self.shared_memory_size = 0

    def configure(self, shared_memory_size: int = 0):
        """Set the kernel's shared-memory size and return ``self`` for chaining.

        NOTE: this value is only stored on the wrapper; ``launch_kernel`` in
        this file sizes shared memory from ``KernelConfig`` instead.
        """
        self.shared_memory_size = shared_memory_size
        return self

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped function with the given arguments and return its result."""
        return self.func(*args, **kwargs)
|
|
|
|
|
|
def launch_kernel(kernel_func: KernelFunction, config: KernelConfig, *args):
    """Run ``kernel_func`` once for every thread of every block in the grid.

    All blocks are constructed up front, then executed sequentially: block by
    block, warp by warp, thread by thread.  Each invocation receives
    ``(block, thread, *args)``.

    Args:
        kernel_func: The kernel to execute.
        config: Grid/block dimensions and per-block shared-memory size.
        *args: Extra positional arguments forwarded to every invocation.
    """
    grid = config.grid_dim
    width, height, depth = grid[0], grid[1], grid[2]
    plane = width * height
    block_count = plane * depth

    # Build every block first; linear block ids run with x fastest, then y, then z.
    launched = []
    for linear in range(block_count):
        layer, within_plane = divmod(linear, plane)
        row, col = divmod(within_plane, width)
        launched.append(
            Block(BlockIdx(col, row, layer), config.block_dim, config.shared_memory_size)
        )

    # Flatten the block -> warp -> thread hierarchy into one sequential sweep.
    work_items = (
        (blk, lane)
        for blk in launched
        for group in blk.warps
        for lane in group.threads
    )
    for blk, lane in work_items:
        kernel_func(blk, lane, *args)
|
|
|
|
|
|
def kernel(func: Callable) -> KernelFunction:
    """Decorator that turns a plain function into a ``KernelFunction``."""
    wrapped = KernelFunction(func)
    return wrapped
|
|
|
|