Spaces:
Running
on
Zero
Running
on
Zero
from copy import deepcopy | |
from typing import Optional, Tuple | |
import torch | |
from flash_attn import flash_attn_func | |
from transformers.modeling_outputs import CausalLMOutput | |
from ..ops.streaming_kernel import TritonMultiStageDotProductionAttention | |
class CudaCache:
    """Fixed pool of equally sized CUDA buffers addressed by slot index.

    The pool is one ``(num_units, unit_size)`` tensor; ``idle_set`` holds the
    indices of the rows that are currently free.
    """

    def __init__(self, num_units, unit_size, dtype):
        self.dtype = dtype
        self.unit_size = unit_size
        self.num_units = num_units
        # One backing allocation; each row is one slot.
        self.data = torch.empty((num_units, unit_size), device="cuda", dtype=dtype)
        self.idle_set = set(range(num_units))

    def alloc(self):
        """Reserve a free slot and return ``(row_view, slot_index)``."""
        assert self.idle_set
        slot = self.idle_set.pop()
        return self.data[slot], slot

    def delete(self, idx):
        """Mark slot ``idx`` as free again; it must currently be in use."""
        assert idx not in self.idle_set
        self.idle_set.add(idx)
class MemoryUnit:
    """A key/value pair kept on the host and optionally mirrored in a cache slot.

    ``cpu_data`` always holds contiguous host copies of the two tensors.
    ``gpu_data`` is either ``None`` or a view of a slot obtained from
    ``cache.alloc()``, reshaped to ``(2,) + kv[0].shape`` (index 0 is the key,
    index 1 the value). ``event`` is the CUDA event recorded after the last
    copy into ``gpu_data``.
    """

    def __init__(
        self,
        kv: Tuple[torch.Tensor, torch.Tensor],
        cache: "CudaCache",
        load_to_cache: bool = False,
        pin_memory: bool = False,
    ):
        """Store ``kv`` on the host and, if requested, also in a cache slot.

        Args:
            kv: ``(key, value)`` tensors of identical shape.
            cache: Pool providing ``alloc()`` / ``delete(idx)``.
            load_to_cache: When true, copy ``kv`` into a freshly allocated slot.
            pin_memory: When true, pin the host copies.
        """
        self.cache = cache

        if kv[0].is_cuda:
            cpu_data = tuple(_t.contiguous().to("cpu", non_blocking=True) for _t in kv)
        else:
            cpu_data = tuple(_t.contiguous() for _t in kv)

        if pin_memory:
            cpu_data = tuple(_t.pin_memory() for _t in cpu_data)

        if load_to_cache:
            gpu_data, gpu_data_id = cache.alloc()
            gpu_data = gpu_data.view((2,) + kv[0].shape)
            gpu_data[0].copy_(kv[0], non_blocking=True)
            gpu_data[1].copy_(kv[1], non_blocking=True)
            # Recorded so get()/offload() can wait for the copies above.
            event = torch.cuda.Event()
            event.record(torch.cuda.current_stream())
        else:
            gpu_data, gpu_data_id = None, None
            event = None

        self.cpu_data = cpu_data
        self.gpu_data = gpu_data
        self.gpu_data_id = gpu_data_id
        self.event = event

    def load(
        self, target: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[bool, Optional["torch.cuda.Event"]]:
        """Make the unit resident in the cache and optionally copy it to ``target``.

        Args:
            target: Optional ``(key_dst, value_dst)`` tensors that receive a copy.

        Returns:
            ``(loaded, target_event)``: ``loaded`` is ``False`` when the unit was
            already resident and ``True`` when a new slot was allocated;
            ``target_event`` is the event recorded after the copy into
            ``target``, or ``None`` when no target was given.
        """
        if self.gpu_data is not None:
            # Already resident: only the optional copy into `target` is needed.
            if target is not None:
                target[0].copy_(self.gpu_data[0], non_blocking=True)
                target[1].copy_(self.gpu_data[1], non_blocking=True)
                target_event = torch.cuda.Event()
                target_event.record(torch.cuda.current_stream())
            else:
                target_event = None
            return False, target_event

        gpu_data, gpu_data_id = self.cache.alloc()
        gpu_data = gpu_data.view((2,) + self.cpu_data[0].shape)
        if target is not None:
            # Fill the caller's buffer first, then populate the slot from it.
            target[0].copy_(self.cpu_data[0], non_blocking=True)
            target[1].copy_(self.cpu_data[1], non_blocking=True)
            target_event = torch.cuda.Event()
            target_event.record(torch.cuda.current_stream())
            gpu_data[0].copy_(target[0], non_blocking=True)
            gpu_data[1].copy_(target[1], non_blocking=True)
        else:
            # Previously `target_event` was left unbound on this path, which
            # made the return statement raise UnboundLocalError.
            target_event = None
            gpu_data[0].copy_(self.cpu_data[0], non_blocking=True)
            gpu_data[1].copy_(self.cpu_data[1], non_blocking=True)

        event = torch.cuda.Event()
        event.record(torch.cuda.current_stream())
        self.event = event
        self.gpu_data = gpu_data
        self.gpu_data_id = gpu_data_id
        return True, target_event

    def get(self):
        """Return the resident ``(2, ...)`` view after waiting on its copy event."""
        assert self.gpu_data is not None
        self.event.wait()
        return self.gpu_data

    def offload(self):
        """Release the cache slot; the host copy in ``cpu_data`` is kept."""
        assert self.gpu_data is not None
        # Wait for the pending copy before the slot can be handed out again.
        self.event.wait()
        self.gpu_data = None
        self.cache.delete(self.gpu_data_id)
        self.gpu_data_id = None
class VectorTensor:
    """Growable 2-D tensor of row vectors with inner-product top-k lookup.

    Rows are stored in a pre-allocated buffer whose capacity doubles whenever
    it runs out; only the first ``length`` rows are valid.
    """

    def __init__(self, hidden_size, element_dtype, device="cuda"):
        """Allocate the initial buffer.

        Args:
            hidden_size: Width of every stored row.
            element_dtype: dtype of the buffer.
            device: Device of the buffer; defaults to ``"cuda"``.
        """
        init_cached_size = 16
        self.data = torch.empty(
            (init_cached_size, hidden_size), dtype=element_dtype, device=device
        )
        self.length = 0
        self.cache_size = init_cached_size
        self.hidden_size = hidden_size

    def append_cache(self):
        """Double the capacity, keeping the valid rows."""
        new_cache_size = self.cache_size * 2
        # Grow on the device the data already lives on rather than on whatever
        # the current CUDA device happens to be.
        new_data = torch.empty(
            (new_cache_size,) + self.data.shape[1:],
            device=self.data.device,
            dtype=self.data.dtype,
        )
        # Rows past `length` are uninitialised, so there is no need to copy them.
        new_data[: self.length, ...].copy_(self.data[: self.length, ...])
        self.data = new_data
        self.cache_size = new_cache_size

    def append(self, tensor: torch.Tensor):
        """Append the rows of a contiguous ``(n, hidden_size)`` tensor."""
        assert tensor.dtype == self.data.dtype
        assert tensor.size(1) == self.hidden_size
        assert tensor.is_contiguous()

        append_l = tensor.size(0)
        while self.length + append_l > self.cache_size:
            self.append_cache()

        self.data[self.length : self.length + append_l, ...].copy_(tensor)
        self.length += append_l

    def get_data(self):
        """Return a view of the valid rows."""
        return self.data[: self.length, ...]

    def get_topk(self, tensor: torch.Tensor, topk):  # inner product
        """Return the indices of the ``topk`` rows with the largest dot product.

        Args:
            tensor: 1-D query of size ``hidden_size``.
            topk: Number of indices to return.

        Returns:
            A Python list of row indices, best match first.
        """
        assert tensor.dim() == 1 and tensor.size(0) == self.hidden_size
        logits = torch.matmul(self.data[: self.length], tensor[:, None]).squeeze(dim=-1)
        assert logits.dim() == 1 and logits.size(0) == self.length
        return logits.topk(topk, dim=0).indices.cpu().tolist()

    def __len__(self):
        return self.length
class Faiss:
    """Inner-product vector index backed by ``faiss.IndexFlatIP``.

    Exposes the same ``append`` / ``get_topk`` / ``__len__`` methods as
    ``VectorTensor``; ``get_data`` is not supported.
    """

    def __init__(self, hidden_size, element_dtype):
        import faiss

        # We use the CPU index here because the GPU index requires a long initialization time
        self.index = faiss.IndexFlatIP(hidden_size)
        self.hidden_size = hidden_size

    def append(self, tensor: torch.Tensor):
        """Add the rows of a ``(n, hidden_size)`` tensor to the index."""
        assert tensor.dim() == 2 and tensor.size(1) == self.hidden_size
        vectors = tensor.cpu().float().numpy().astype("float32")
        self.index.add(vectors)

    def get_data(self):
        """Unsupported for this backend; always raises ``ValueError``."""
        raise ValueError

    def get_topk(self, tensor: torch.Tensor, topk):
        """Return the indices of the ``topk`` best matches for a 1-D query."""
        assert tensor.dim() == 1 and tensor.size(0) == self.hidden_size
        xq = tensor[None, :].cpu().float().numpy().astype("float32")
        search_result = self.index.search(xq, topk)
        # Element 1 of the search result holds the indices; row 0 is our single query.
        return search_result[1][0].tolist()

    def __len__(self):
        return self.index.ntotal
# Side CUDA stream shared by every ContextManager; created lazily by the first
# instance constructed with ``async_global_stream=True``.
GLOBAL_STREAM = None


class ContextManager:
    """Key/value state for attention over a long, incrementally appended input.

    The manager keeps three kinds of key/value storage:

    * the most recent ``n_local`` positions (``local_k`` / ``local_v``),
    * up to ``n_init`` leading positions (``init_k`` / ``init_v``),
    * older positions cut into blocks of ``block_size`` that are stored as
      ``MemoryUnit`` objects and brought into a ``CudaCache`` on demand.

    For every chunk of at most ``exc_block_size`` query positions the ``topk``
    best-scoring blocks are selected, copied into ``global_buffer`` and
    attended to together with the local window.
    """

    def __init__(
        self,
        position_embedding,
        n_init,
        n_local,
        block_size,
        max_cached_block,
        topk,
        exc_block_size,
        score_decay: Optional[float] = None,
        repr_topk: int = 1,
        cache_strategy="lru",
        chunk_topk_calc: Optional[int] = None,
        async_global_stream: bool = False,
        pin_memory: bool = False,
        faiss: bool = False,
        perhead: bool = False,
        dense_decoding: bool = False,
    ):
        """Store the configuration; tensors are allocated lazily by ``init``.

        Args:
            position_embedding: Object applied to ``(q, k)`` pairs; it must also
                provide ``_update_cos_sin_tables_len`` and
                ``apply_rotary_pos_emb_one_angle``.
            n_init: Maximum number of leading positions kept in ``init_k``.
            n_local: Size of the local window.
            block_size: Number of positions per stored block.
            max_cached_block: Cache slots available per unit.
            topk: Number of blocks selected per query chunk.
            exc_block_size: Query chunk length; must not exceed ``n_local``.
            score_decay: Multiplier applied to cached block scores before new
                scores are added (used by ``update_block_score``).
            repr_topk: Number of representative keys averaged per block.
            cache_strategy: ``"lru"`` or ``"lru-s"``.
            chunk_topk_calc: If set, block selection for multi-token inputs is
                computed in batches covering this many positions.
            async_global_stream: Run retrieval work on ``GLOBAL_STREAM``.
            pin_memory: Pin the host copies of stored blocks.
            faiss: Use ``Faiss`` instead of ``VectorTensor`` for block keys.
            perhead: Treat every attention head as its own unit.
            dense_decoding: Additionally keep the full key/value history in
                ``dense_k`` / ``dense_v``.
        """
        self.length = 0
        self.position_embedding = position_embedding
        self.n_init = n_init
        self.n_local = n_local
        self.block_size = block_size
        self.max_cached_block = max_cached_block
        self.exc_block_size = exc_block_size
        self.score_decay = score_decay
        assert exc_block_size <= n_local  # no global token in input
        self.topk = topk
        self.Attn = TritonMultiStageDotProductionAttention
        self.initialized = False
        self.repr_topk = repr_topk
        self.cache_strategy = cache_strategy
        self.load_count = 0
        self.chunk_topk_calc = chunk_topk_calc
        self.async_global_stream = async_global_stream
        self.pin_memory = pin_memory
        self.faiss = faiss
        self.perhead = perhead
        self.dense_decoding = dense_decoding

        global GLOBAL_STREAM
        if self.async_global_stream and GLOBAL_STREAM is None:
            GLOBAL_STREAM = torch.cuda.Stream()

        assert cache_strategy in ["lru", "lru-s"]
        # Only the score-based strategy needs attention scores for the blocks.
        self.calc_block_score = cache_strategy == "lru-s"

    def remove_lru_blocks(
        self, u, num_remove: Optional[int] = None, ignore_blocks=None
    ):
        """Offload the lowest-valued cached blocks of unit ``u``.

        Args:
            u: Unit index.
            num_remove: Number of blocks to evict; defaults to the number by
                which the unit exceeds ``max_cached_block``.
            ignore_blocks: Block ids that must stay cached.
        """
        if num_remove is None:
            num_remove = len(self.cached_blocks[u]) - self.max_cached_block

        if num_remove <= 0:
            return

        # Set lookup instead of scanning the ignore list for every candidate.
        protected = frozenset(ignore_blocks) if ignore_blocks is not None else frozenset()

        candidates = sorted(self.cached_blocks[u].items(), key=lambda x: x[1])
        removed = 0
        for idx, _ in candidates:
            if idx not in protected:
                self.global_blocks[u][idx].offload()
                self.cached_blocks[u].pop(idx)
                removed += 1

            if removed >= num_remove:
                return

    def get_block_k(self, k, score):
        """Gather, per head, the ``repr_topk`` keys of a block with the highest score.

        Args:
            k: Block keys with ``num_heads_kv`` heads, 4-D.
            score: Per-position scores shaped like ``k`` (after head
                expansion) without the last dimension.

        Returns:
            Tensor of shape ``(num_units, unit_size, repr_topk, dim_head)``.
        """
        assert isinstance(score, torch.Tensor)
        assert k.dim() >= 2
        k = self.from_group_kv(k)
        assert k.shape[:-1] == score.shape
        assert k.shape[-2] == self.block_size
        score_topk = score.topk(self.repr_topk, dim=-1).indices
        assert score_topk.shape == (self.num_units, self.unit_size, self.repr_topk)
        gather_index = score_topk[:, :, :, None].expand(
            self.num_units, self.unit_size, self.repr_topk, self.dim_head
        )
        return torch.gather(k, -2, gather_index)

    def from_group_kv(self, tensor):
        """Repeat each key/value head so the tensor has ``num_heads`` heads."""
        assert tensor.dim() == 4
        assert tensor.size(1) == self.num_heads_kv
        if self.num_heads == self.num_heads_kv:
            return tensor
        _, _, length, dim_head = tensor.shape
        num_group = self.num_heads // self.num_heads_kv
        tensor = tensor.view((self.num_units, self.unit_size_kv, 1, length, dim_head))
        tensor = tensor.expand(
            (self.num_units, self.unit_size_kv, num_group, length, dim_head)
        ).reshape((self.num_units, self.num_heads, length, dim_head))
        return tensor

    def init(self, local_q, local_k, local_v, global_q, global_k, global_v):
        """Allocate all state from the shapes and dtypes of the first input.

        All six tensors must be 4-D CUDA tensors that agree on batch size,
        sequence length and head dimension.
        """
        assert local_q.dim() == 4
        batch_size, num_heads, len_q, dim_head = local_q.shape
        num_heads_kv = local_k.size(1)

        for _t in [local_q, local_k, local_v, global_q, global_k, global_v]:
            assert _t.size(0) == batch_size
            assert _t.size(1) == num_heads or _t.size(1) == num_heads_kv
            assert _t.size(2) == len_q
            assert _t.size(3) == dim_head
            assert _t.is_cuda

        self.batch_size = batch_size
        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv
        self.dim_head = dim_head
        self.num_units = batch_size
        self.unit_size = num_heads
        self.unit_size_kv = num_heads_kv

        def empty_kv(ref):
            # Zero-length key/value tensor matching `ref`'s dtype and device.
            return torch.empty(
                (self.num_units, self.unit_size_kv, 0, dim_head),
                dtype=ref.dtype,
                device=ref.device,
            )

        self.global_blocks = [[] for _ in range(self.num_units)]  # [[memory_unit]]
        self.cached_blocks = [
            {} for _ in range(self.num_units)
        ]  # [{block_id: block_score}]
        self.num_global_block = 0

        index_cls = Faiss if self.faiss else VectorTensor
        self.block_k = [
            index_cls(dim_head * self.unit_size, global_k.dtype)
            for _ in range(self.num_units)
        ]

        self.local_k = empty_kv(local_k)
        self.local_v = empty_kv(local_v)

        if self.dense_decoding:
            self.dense_k = empty_kv(local_k)
            self.dense_v = empty_kv(local_v)

        self.global_remainder = (empty_kv(global_k), empty_kv(global_v))

        self.global_remainder_local_score = torch.empty(
            (self.num_units, self.unit_size, 0),
            dtype=global_k.dtype,
            device=global_k.device,
        )

        self.init_k = empty_kv(global_k)
        # Values take their dtype/device from global_v (previously copied from
        # global_k by mistake).
        self.init_v = empty_kv(global_v)
        self.init_exc = False
        self.dtype = local_q.dtype
        self.position_embedding._update_cos_sin_tables_len(
            self.n_local + self.exc_block_size + 1, local_k.device, local_k.dim()
        )

        # Room for topk blocks, one query chunk, one extra block and the
        # initial positions.
        buffer_len = (
            self.topk * self.block_size
            + self.exc_block_size
            + self.block_size
            + self.n_init
        )
        self.global_buffer = torch.zeros(
            (2, self.num_units, self.unit_size_kv, buffer_len, dim_head),
            dtype=global_k.dtype,
            device=global_k.device,
        )
        # Block id currently held by each of the topk buffer slots (-1 = empty).
        self.global_buffer_block_id_list = [
            [-1] * self.topk for _ in range(self.num_units)
        ]
        self.global_buffer_init_st = 0
        self.global_buffer_init_ed = 0
        self.cuda_cache = CudaCache(
            self.max_cached_block * self.num_units,
            self.unit_size_kv * self.block_size * dim_head * 2,
            local_k.dtype,
        )

        self.initialized = True

    def calc_block_topk(self, global_h_q):
        """Return, per unit, the list of block ids to attend to.

        In chunked mode the precomputed entry ``_cached_topk[_topk_cur]`` is
        returned. Otherwise all blocks are returned while there are at most
        ``topk`` of them, else the ``topk`` best matches for the mean query.
        """
        if self._use_chunk_topk:
            return self._cached_topk[self._topk_cur]

        if self.num_global_block <= self.topk:
            return [
                list(range(len(self.global_blocks[0])))
                for _ in range(self.num_units)
            ]

        # Average the chunk's queries into one vector per head.
        global_h_q = global_h_q.mean(dim=2, keepdim=False)
        assert global_h_q.shape == (self.num_units, self.unit_size, self.dim_head)
        global_h_q = global_h_q.reshape(
            self.num_units, self.dim_head * self.unit_size
        )
        return [
            self.block_k[u].get_topk(global_h_q[u], self.topk)
            for u in range(self.num_units)
        ]

    def get_global_hidden_and_mask(self, len_q, block_topk):
        """Assemble the global key/value buffer for one query chunk.

        The buffer is laid out as ``[selected blocks | init positions |
        remainder positions]``. Blocks already present in a buffer slot are
        reused; the others are loaded into a free or stale slot.

        Note: each ``block_topk[u]`` list is sorted in place.

        Returns:
            ``(global_h_k, global_h_v, sliding_window, global_block_map,
            block_num)``.
        """
        assert len(block_topk) == self.num_units
        global_block_map = [[] for _ in range(self.num_units)]
        global_remainder_len = max(
            self._global_remainder_ed
            - self._global_remainder_st
            + len_q
            - self.n_local,
            0,
        )
        init_len = self.init_k.size(-2)

        global_h_k = self.global_buffer[0]
        global_h_v = self.global_buffer[1]

        block_num = len(block_topk[0])
        for u in range(self.num_units):
            assert len(block_topk[u]) == block_num

            block_topk[u].sort()
            global_block_map[u] = deepcopy(self.global_buffer_block_id_list[u])
            for b_idx in block_topk[u]:
                if b_idx in global_block_map[u]:
                    continue

                # Take the first slot that is empty or holds a block that is
                # not wanted this round.
                slot_st = -1
                slot_ed = -1
                for j in range(self.topk):
                    if (
                        global_block_map[u][j] == -1
                        or global_block_map[u][j] not in block_topk[u]
                    ):
                        slot_st = j * self.block_size
                        slot_ed = slot_st + self.block_size
                        global_block_map[u][j] = b_idx
                        break

                assert b_idx in self.cached_blocks[u]
                self.global_blocks[u][b_idx].load(
                    (
                        global_h_k[u, :, slot_st:slot_ed, :],
                        global_h_v[u, :, slot_st:slot_ed, :],
                    )
                )

        init_st = block_num * self.block_size
        init_ed = init_st + init_len
        # Re-copy the init positions only when their place in the buffer moved.
        if (
            self.global_buffer_init_st != init_st
            or self.global_buffer_init_ed != init_ed
        ):
            global_h_k[:, :, init_st:init_ed, :].copy_(self.init_k, non_blocking=True)
            global_h_v[:, :, init_st:init_ed, :].copy_(self.init_v, non_blocking=True)

        rmd_st = init_ed
        rmd_ed = rmd_st + global_remainder_len
        total_len = rmd_ed
        rmd_src_st = self._global_remainder_st
        rmd_src_ed = rmd_src_st + global_remainder_len
        global_h_k[:, :, rmd_st:rmd_ed, :].copy_(
            self.global_remainder[0][:, :, rmd_src_st:rmd_src_ed, :],
            non_blocking=True,
        )
        global_h_v[:, :, rmd_st:rmd_ed, :].copy_(
            self.global_remainder[1][:, :, rmd_src_st:rmd_src_ed, :],
            non_blocking=True,
        )

        sliding_window = (self.global_remainder[0].size(-2) + rmd_st, self.n_local)

        self.global_buffer_block_id_list = deepcopy(global_block_map)
        self.global_buffer_init_st = init_st
        self.global_buffer_init_ed = init_ed

        for u in range(self.num_units):
            # The first block_num slots must be filled and the rest empty.
            assert max(global_block_map[u][block_num:] + [-1]) == -1
            assert min(global_block_map[u][:block_num] + [0]) > -1
            global_block_map[u] = list(global_block_map[u][:block_num])

        global_h_k = global_h_k[:, :, :total_len, :]
        global_h_v = global_h_v[:, :, :total_len, :]
        return global_h_k, global_h_v, sliding_window, global_block_map, block_num

    def update_block_score(
        self, global_score: torch.FloatTensor, global_block_map, global_block_num
    ):
        """Decay all cached block scores and add the new per-block scores.

        Does nothing when ``global_score`` is ``None``.

        Args:
            global_score: Scores of shape ``(num_units, unit_size, >= n)`` where
                ``n = global_block_num * block_size``; only the first ``n``
                positions are used.
            global_block_map: Per unit, the block id held by each buffer slot.
            global_block_num: Number of block slots in use.
        """
        if global_score is None:
            return

        global_score = global_score[:, :, : global_block_num * self.block_size]
        assert global_score.shape == (
            self.num_units,
            self.unit_size,
            global_block_num * self.block_size,
        )
        global_score = global_score.view(
            self.num_units, self.unit_size, global_block_num, self.block_size
        )
        # Sum over the positions of each block, then over heads.
        global_score = global_score.sum(dim=-1).sum(dim=1)
        assert global_score.shape == (self.num_units, global_block_num)
        global_score = global_score.to(
            device="cpu", non_blocking=False
        )  # (num_units, global_block_num)
        for u in range(self.num_units):
            for k, v in self.cached_blocks[u].items():
                self.cached_blocks[u][k] = v * self.score_decay
            score = global_score[u].tolist()
            assert len(score) >= len(global_block_map[u])
            for s, i in zip(score, global_block_map[u]):
                self.cached_blocks[u][i] += s

    def _append(self, local_q, local_k, local_v, global_q):
        """Attend one query chunk to the local window and the selected blocks.

        Returns:
            ``(output, local_score)`` where ``output`` is viewed as
            ``(batch_size, num_heads, -1, dim_head)`` and ``local_score`` is
            the score from the local attention stage.
        """
        # get local_h_q, local_h_k, local_h_v
        local_h_q, local_h_k = self.position_embedding(local_q, local_k)
        local_h_v = local_v

        # calc local result first to overlap host-device communication
        attn = self.Attn(local_h_q.shape, local_h_q.dtype, local_h_q.device)
        attn.append(
            local_h_q, local_h_k, local_h_v, get_score=True, sliding_window=self.n_local
        )

        # calc topk global repr k and load cache
        with torch.cuda.stream(GLOBAL_STREAM):
            block_topk = self.calc_block_topk(global_q)

            for u in range(self.num_units):
                # Evict enough blocks to make room for the ones not cached yet.
                num_remove = len(self.cached_blocks[u]) - self.max_cached_block
                for bidx in block_topk[u]:
                    if bidx not in self.cached_blocks[u]:
                        num_remove += 1

                # update cache
                self.remove_lru_blocks(u, num_remove, block_topk[u])

            if self.cache_strategy == "lru":
                self.load_count += 1
                for u in range(self.num_units):
                    for bidx in block_topk[u]:
                        self.cached_blocks[u][bidx] = self.load_count
            elif self.cache_strategy == "lru-s":
                for u in range(self.num_units):
                    for bidx in block_topk[u]:
                        self.cached_blocks[u][bidx] = 0
            else:
                raise ValueError

            # get global_h_k, global_h_v, global_mask
            # Because exc_block_size <= n_local, no global_k, global_v used in global part
            global_h_q = global_q
            (
                global_h_k,
                global_h_v,
                global_sliding_window,
                global_block_map,
                global_block_num,
            ) = self.get_global_hidden_and_mask(local_h_q.size(-2), block_topk)

        if self.async_global_stream:
            torch.cuda.current_stream().wait_stream(GLOBAL_STREAM)

        # calc global result
        attn.append(
            global_h_q,
            global_h_k,
            global_h_v,
            end=True,
            get_score=self.calc_block_score,
            sliding_window=global_sliding_window,
            complement_sliding_window=True,
        )

        o, score_list = attn.get_result()
        loc_score = score_list[0]
        glb_score = score_list[1]

        if self.async_global_stream:
            GLOBAL_STREAM.wait_stream(torch.cuda.current_stream())

        # update global score
        with torch.cuda.stream(GLOBAL_STREAM):
            self.update_block_score(glb_score, global_block_map, global_block_num)

        return o.view((self.batch_size, self.num_heads, -1, self.dim_head)), loc_score

    def get_batched_topk(self, global_q):
        """Select blocks for every ``exc_block_size`` chunk of ``global_q`` at once.

        Returns:
            A list with one entry per chunk; each entry is a per-unit list of
            block ids.
        """
        length = global_q.shape[2]
        exc_num = (length + self.exc_block_size - 1) // self.exc_block_size
        exc_block_num = length // self.exc_block_size
        ret = []
        if self.num_global_block <= self.topk:
            # Few enough blocks: every chunk simply uses all of them.
            for _ in range(exc_num):
                ret.append(
                    [
                        list(range(len(self.global_blocks[0])))
                        for _ in range(self.num_units)
                    ]
                )
            return ret

        global_h_q = global_q
        assert global_h_q.dim() == 4
        assert global_h_q.shape[:2] == (self.num_units, self.unit_size)
        assert global_h_q.shape[3] == self.dim_head

        block_k = torch.cat(
            [self.block_k[u].get_data()[None, :, :] for u in range(self.num_units)],
            dim=0,
        )
        assert block_k.shape == (
            self.num_units,
            self.num_global_block,
            self.dim_head * self.unit_size,
        )
        block_k = (
            block_k.reshape(
                self.num_units, self.num_global_block, self.unit_size, self.dim_head
            )
            .permute(0, 2, 1, 3)
            .contiguous()
        )

        if exc_block_num > 0:
            # Full chunks: one mean query per chunk.
            tmp_global_h_q = (
                global_h_q[:, :, : exc_block_num * self.exc_block_size, :]
                .reshape(
                    self.num_units,
                    self.unit_size,
                    exc_block_num,
                    self.exc_block_size,
                    self.dim_head,
                )
                .mean(dim=-2)
            )
            assert tmp_global_h_q.shape == (
                self.num_units,
                self.unit_size,
                exc_block_num,
                self.dim_head,
            )
            block_score = torch.matmul(tmp_global_h_q, block_k.transpose(-1, -2)).mean(
                dim=1
            )  # (num_units, exc_block_num, num_global_block)
            assert block_score.shape == (
                self.num_units,
                exc_block_num,
                self.num_global_block,
            )

            indices = block_score.topk(self.topk, dim=-1).indices.cpu()
            for b in range(exc_block_num):
                tmp = []
                for u in range(self.num_units):
                    tmp.append(indices[u, b].tolist())
                    assert len(tmp[-1]) == self.topk
                ret.append(tmp)

        if exc_block_num != exc_num:
            # Trailing partial chunk.
            tmp_global_h_q = (
                global_h_q[:, :, exc_block_num * self.exc_block_size :, :]
                .reshape(
                    self.num_units,
                    self.unit_size,
                    length - exc_block_num * self.exc_block_size,
                    self.dim_head,
                )
                .mean(dim=-2, keepdim=True)
            )
            assert tmp_global_h_q.shape == (
                self.num_units,
                self.unit_size,
                1,
                self.dim_head,
            )
            block_score = torch.matmul(tmp_global_h_q, block_k.transpose(-1, -2))
            assert block_score.shape == (
                self.num_units,
                self.unit_size,
                1,
                self.num_global_block,
            )
            block_score = block_score.squeeze(dim=2).mean(dim=1)
            assert block_score.shape == (self.num_units, self.num_global_block)
            indices = block_score.topk(self.topk, dim=-1).indices.cpu()
            tmp = []
            for u in range(self.num_units):
                tmp.append(indices[u].tolist())
                assert len(tmp[-1]) == self.topk
            ret.append(tmp)

        return ret

    def append_global(self, exc_length, kv_length, local_score):
        """Advance the remainder window and turn positions that left it into blocks.

        Args:
            exc_length: Number of positions processed in this chunk.
            kv_length: Number of key positions ``local_score`` covers.
            local_score: Local attention scores whose first three dimensions
                are ``(num_units, unit_size, kv_length)``.
        """
        global_remainder_ed = self._global_remainder_ed + exc_length
        global_remainder_st = self._global_remainder_st

        global_remainder_len = global_remainder_ed - global_remainder_st

        assert local_score.shape[:3] == (self.num_units, self.unit_size, kv_length)
        local_score = local_score[:, :, -exc_length - self.n_local :]
        # Accumulate the scores onto the positions they refer to.
        self.global_remainder_local_score[
            :, :, global_remainder_ed - local_score.size(-1) : global_remainder_ed
        ].add_(local_score)

        if not self.init_exc and global_remainder_len > self.n_local:
            # Move positions leaving the local window into init_k/init_v until
            # n_init of them have been collected.
            global_k = self.global_remainder[0]
            global_v = self.global_remainder[1]

            append_init_len = min(
                self.n_init - self.init_k.size(-2), global_remainder_len - self.n_local
            )
            init_slice = slice(global_remainder_st, global_remainder_st + append_init_len)
            self.init_k = torch.cat(
                (self.init_k, global_k[:, :, init_slice, :]), dim=-2
            )
            self.init_v = torch.cat(
                (self.init_v, global_v[:, :, init_slice, :]), dim=-2
            )
            global_remainder_st += append_init_len
            global_remainder_len -= append_init_len

            if self.init_k.size(-2) == self.n_init:
                self.init_exc = True

        while global_remainder_len - self.block_size >= self.n_local:
            global_remainder_len -= self.block_size
            block_slice = slice(
                global_remainder_st, global_remainder_st + self.block_size
            )
            for u in range(self.num_units):
                self.global_blocks[u].append(
                    MemoryUnit(
                        (
                            self.global_remainder[0][u, :, block_slice, :],
                            self.global_remainder[1][u, :, block_slice, :],
                        ),
                        self.cuda_cache,
                        False,
                        self.pin_memory,
                    )
                )

            global_block_k = self.get_block_k(
                self.global_remainder[0][:, :, block_slice, :],
                self.global_remainder_local_score[:, :, block_slice],
            )
            assert global_block_k.shape == (
                self.num_units,
                self.unit_size,
                self.repr_topk,
                self.dim_head,
            )
            # One representative vector per block: the mean of its top keys,
            # flattened across heads.
            global_block_k = global_block_k.mean(dim=-2, keepdim=False)
            global_block_k = global_block_k.reshape(
                self.num_units, self.unit_size * self.dim_head
            )
            global_block_k = global_block_k[:, None, :]

            self.num_global_block += 1
            for u in range(self.num_units):
                self.block_k[u].append(global_block_k[u])

            global_remainder_st += self.block_size

        self._global_remainder_ed = global_remainder_ed
        self._global_remainder_st = global_remainder_st

    def append(
        self,
        local_q,
        local_k,
        local_v,
        global_q,
        global_k,
        global_v,
    ):
        """Append new positions and return their attention output.

        The input is processed in chunks of ``exc_block_size`` positions. On
        the first call the internal state is allocated via ``init``.

        Returns:
            The chunk outputs concatenated along the sequence dimension; in
            ``perhead`` mode viewed back to
            ``(batch_size, num_heads, input_length, -1)``.
        """
        batch_size = local_q.size(0)
        input_length = local_q.size(-2)

        if self.perhead:
            num_heads = local_q.size(1)
            num_heads_kv = local_v.size(1)

            def repeat_kv(t):
                # Expand kv heads to query heads and fold heads into the batch.
                t = t.view(batch_size, num_heads_kv, 1, input_length, -1)
                t = t.expand(
                    batch_size,
                    num_heads_kv,
                    num_heads // num_heads_kv,
                    input_length,
                    -1,
                )
                t = t.reshape(batch_size * num_heads, 1, input_length, -1)
                return t

            local_q = local_q.view(batch_size * num_heads, 1, input_length, -1)
            local_k = repeat_kv(local_k)
            local_v = repeat_kv(local_v)
            global_q = global_q.view(batch_size * num_heads, 1, input_length, -1)
            global_k = repeat_kv(global_k)
            global_v = repeat_kv(global_v)

        if not self.initialized:
            self.init(local_q, local_k, local_v, global_q, global_k, global_v)

        input_length = local_q.size(-2)

        if self.async_global_stream:
            GLOBAL_STREAM.wait_stream(torch.cuda.current_stream())

        # append local and global tensor
        self.local_k = torch.cat((self.local_k, local_k), dim=-2)
        self.local_v = torch.cat((self.local_v, local_v), dim=-2)
        kv_length = self.local_k.size(-2)

        if self.dense_decoding:
            self.dense_k = torch.cat((self.dense_k, local_k), dim=-2)
            self.dense_v = torch.cat((self.dense_v, local_v), dim=-2)

        # append global remainder
        with torch.cuda.stream(GLOBAL_STREAM):
            self._global_remainder_st = 0
            self._global_remainder_ed = self.global_remainder[0].size(-2)

            self.global_remainder = (
                torch.cat((self.global_remainder[0], global_k), dim=-2),
                torch.cat((self.global_remainder[1], global_v), dim=-2),
            )

            self.global_remainder_local_score = torch.cat(
                (
                    self.global_remainder_local_score,
                    torch.zeros(
                        (self.num_units, self.unit_size, global_k.size(-2)),
                        dtype=global_k.dtype,
                        device=global_k.device,
                    ),
                ),
                dim=-1,
            )

        with torch.cuda.stream(GLOBAL_STREAM):
            global_q = self.position_embedding.apply_rotary_pos_emb_one_angle(
                global_q, self.n_local
            )

        use_chunk_topk = self.chunk_topk_calc is not None and input_length > 1
        self._use_chunk_topk = use_chunk_topk
        if use_chunk_topk:
            exc_block_num = input_length // self.exc_block_size
            # Clamp to 1 so chunk_topk_calc < exc_block_size does not produce a
            # zero range step (which raised ValueError).
            exc_block_per_topk_chunk = max(
                1, self.chunk_topk_calc // self.exc_block_size
            )
            # Position boundaries at which a new batch of selections is computed.
            calc_cur_list = [
                i * self.exc_block_size
                for i in range(0, exc_block_num + 1, exc_block_per_topk_chunk)
            ]
            if calc_cur_list[-1] < input_length:
                calc_cur_list.append(input_length)
            self._topk_cur = 0
            self._topk_calc_cur = -1

        o_list = []

        for st in range(0, input_length, self.exc_block_size):
            ed = min(st + self.exc_block_size, input_length)
            if use_chunk_topk and calc_cur_list[self._topk_calc_cur + 1] < ed:
                # calculate topk and sync with host here
                assert ed <= calc_cur_list[self._topk_calc_cur + 2]
                self._topk_calc_cur += 1
                with torch.cuda.stream(GLOBAL_STREAM):
                    self._cached_topk = self.get_batched_topk(
                        global_q[
                            :,
                            :,
                            calc_cur_list[self._topk_calc_cur] : calc_cur_list[
                                self._topk_calc_cur + 1
                            ],
                            :,
                        ]
                    )
                self._topk_cur = 0

            # Keys visible to this chunk: the chunk itself plus up to n_local
            # preceding positions.
            kv_st = max(kv_length + st - input_length - self.n_local, 0)
            kv_ed = kv_length + ed - input_length
            chunk_o, local_score = self._append(
                local_q[:, :, st:ed, :],
                self.local_k[:, :, kv_st:kv_ed, :],
                self.local_v[:, :, kv_st:kv_ed, :],
                global_q[:, :, st:ed, :],
            )
            o_list.append(chunk_o)

            # append global
            with torch.cuda.stream(GLOBAL_STREAM):
                self.append_global(ed - st, kv_ed - kv_st, local_score)

            if self.async_global_stream:
                torch.cuda.current_stream().wait_stream(GLOBAL_STREAM)

            if use_chunk_topk:
                self._topk_cur += 1

        self.length += input_length

        # update local and global tensor
        if self.local_k.size(-2) >= self.n_local:
            self.local_k = self.local_k[:, :, -self.n_local :, :]
            self.local_v = self.local_v[:, :, -self.n_local :, :]

        assert self._global_remainder_ed == self.global_remainder[0].size(-2)
        with torch.cuda.stream(GLOBAL_STREAM):
            # Drop the positions that were moved into init_k or into blocks.
            self.global_remainder = (
                self.global_remainder[0][:, :, self._global_remainder_st :, :],
                self.global_remainder[1][:, :, self._global_remainder_st :, :],
            )
            self.global_remainder_local_score = self.global_remainder_local_score[
                :, :, self._global_remainder_st :
            ]

        ret = torch.cat(o_list, dim=-2)

        if self.perhead:
            ret = ret.view(batch_size, num_heads, input_length, -1)

        return ret

    def size(self, *args, **kwargs):
        """Return the total number of positions appended so far."""
        return self.length
def inf_llm_forward(
    n_local,
    n_init,
    topk,
    block_size,
    max_cached_block,
    exc_block_size,
    repr_topk: int = 1,
    cache_strategy="lru",
    score_decay=None,
    chunk_topk_calc=None,
    async_global_stream=True,
    pin_memory=False,
    faiss=False,
    perhead=False,
    dense_decoding=False,
    *args,
    **kwargs
):
    """Build an attention ``forward`` function bound to the given settings.

    The arguments are forwarded to ``ContextManager`` when the returned
    function creates one; extra positional/keyword arguments are ignored.

    Returns:
        A function ``forward(self, query, key_value, position_bias, use_cache,
        past_key_value, project_q, project_k, project_v, attention_out,
        dim_head, num_heads, num_heads_kv)``.
    """

    def forward(
        self,
        query: torch.Tensor,
        key_value: torch.Tensor,
        position_bias: Optional[torch.Tensor],
        use_cache: bool,
        past_key_value,
        project_q,
        project_k,
        project_v,
        attention_out,
        dim_head,
        num_heads,
        num_heads_kv,
    ):
        """Project the inputs, run attention, and apply the output projection.

        Returns ``(output, past_key_value)`` when ``use_cache`` is true,
        otherwise only ``output``.
        """
        batch_size, len_q = query.size(0), query.size(1)
        len_k = key_value.size(1)

        def split_heads(projected, seq_len, n_heads):
            # (batch, seq, heads * dim_head) -> (batch, heads, seq, dim_head)
            return (
                projected.view(batch_size, seq_len, n_heads, dim_head)
                .permute(0, 2, 1, 3)
                .contiguous()
            )

        # assert use_cache
        h_q = split_heads(project_q(query), len_q, num_heads)
        h_k = split_heads(project_k(key_value), len_k, num_heads_kv)
        h_v = split_heads(project_v(key_value), len_k, num_heads_kv)

        if len_q == 1 and dense_decoding:
            # Single-token step over the full stored history.
            h_k = torch.cat((past_key_value.dense_k, h_k), dim=-2)
            h_v = torch.cat((past_key_value.dense_v, h_v), dim=-2)
            past_key_value.dense_k = h_k
            past_key_value.dense_v = h_v

            h_q, h_k = position_bias(h_q, h_k)

            # (batch_size, seqlen, nheads, headdim)
            dense_out = flash_attn_func(
                h_q.transpose(1, 2),
                h_k.transpose(1, 2),
                h_v.transpose(1, 2),
                causal=True,
            )
            dense_out = dense_out.reshape(batch_size, len_q, dim_head * num_heads)
            dense_out = attention_out(dense_out)
            return (dense_out, past_key_value) if use_cache else dense_out

        if past_key_value is None:
            past_key_value = ContextManager(
                position_bias,
                n_init,
                n_local,
                block_size,
                max_cached_block,
                topk,
                exc_block_size,
                score_decay,
                repr_topk,
                cache_strategy,
                chunk_topk_calc,
                async_global_stream,
                pin_memory,
                faiss,
                perhead,
                dense_decoding=dense_decoding,
            )

        # The same projections serve as both the local and the global inputs.
        attn_out = past_key_value.append(h_q, h_k, h_v, h_q, h_k, h_v)

        attn_out = attn_out.view(batch_size, num_heads, len_q, dim_head)
        attn_out = attn_out.permute(0, 2, 1, 3)
        attn_out = attn_out.reshape(batch_size, len_q, dim_head * num_heads)
        attn_out = attention_out(attn_out)

        return (attn_out, past_key_value) if use_cache else attn_out

    return forward
class GreedySearch:
    """Greedy (argmax) decoding loop around a causal language model.

    The model is called with ``input_ids``, ``attention_mask``, ``use_cache``,
    ``return_dict`` and ``past_key_values`` and must return an object exposing
    ``logits`` and ``past_key_values``.
    """

    def __init__(self, model, tokenizer):
        model.eval()
        self.device = model.device
        self.model = model
        self.tokenizer = tokenizer
        self.past_kv = None

    def clear(self):
        """Drop the cached key/value state."""
        self.past_kv = None

    def _process_texts(self, input_text):
        """Tokenize ``input_text`` into batch-of-one int tensors on the model's device."""
        input_ids = self.tokenizer.encode(input_text)
        raw_inputs = {
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
        }
        # Use the model's device instead of assuming the current CUDA device.
        device = self.model.device
        return {
            key: torch.tensor(value).int().unsqueeze(0).to(device)
            for key, value in raw_inputs.items()
        }

    def generate(self, text=None, input_ids=None, **kwargs):
        """Decode greedily from ``input_ids`` or, if absent, from ``text``.

        Extra keyword arguments are forwarded to ``_decode``. The cached state
        is cleared afterwards.

        Returns:
            The prompt ids followed by the generated ids, shape ``(1, n)``.
        """
        if input_ids is None:
            model_inputs = self._process_texts(text)
            input_ids = model_inputs["input_ids"]

        with torch.inference_mode():
            result = self._decode(input_ids, **kwargs)

        self.clear()
        return result

    def _decode(
        self,
        input_ids,
        max_length=100,
        extra_end_token_ids=None,
        chunk_size: int = 4096,
        output=False,
    ):
        """Run prefill in chunks, then generate up to ``max_length`` tokens.

        Args:
            input_ids: Prompt ids, 1-D or ``(1, n)``.
            max_length: Maximum number of tokens to generate.
            extra_end_token_ids: Token ids that stop generation in addition to
                the tokenizer's EOS id.
            chunk_size: Prefill chunk length; ``None`` means a single chunk.
            output: When true, stream the decoded text to stdout.

        Returns:
            ``input_ids`` extended by the generated tokens.
        """
        # Imported up front: the final newline below is written even when no
        # token was streamed, which previously hit an unbound `sys`.
        import sys

        if input_ids.dim() == 1:
            input_ids = input_ids[None, :]
        input_ids = input_ids.to(self.model.device)
        attention_mask = torch.ones_like(input_ids)
        assert input_ids.size(0) == 1
        length = input_ids.size(1)
        # Copy so neither the caller's list nor a shared default is touched.
        end_token_ids = list(extra_end_token_ids or []) + [self.tokenizer.eos_token_id]
        logits = None
        past_key_values = self.past_kv
        output_text = ""

        for i in range(max_length + 1):
            if i == 0:
                if chunk_size is None:
                    chunk_size = input_ids.size(1)
                # Prefill everything except the last prompt token in chunks.
                prefill_len = input_ids.size(1) - 1
                for st in range(0, prefill_len, chunk_size):
                    ed = min(prefill_len, st + chunk_size)
                    out = self.model(
                        input_ids=input_ids[:, st:ed],
                        attention_mask=attention_mask[:, :ed],
                        use_cache=True,
                        return_dict=True,
                        past_key_values=past_key_values,
                    )
                    logits, past_key_values = out.logits, out.past_key_values

                out = self.model(
                    input_ids=input_ids[:, -1:],
                    attention_mask=attention_mask,
                    use_cache=True,
                    return_dict=True,
                    past_key_values=past_key_values,
                )
                logits, past_key_values = out.logits, out.past_key_values
            else:
                out = self.model(
                    input_ids=input_ids[:, -1:],
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                    return_dict=True,
                )
                logits, past_key_values = out.logits, out.past_key_values

            logits = logits[:, -1, :]
            word = logits.argmax(dim=-1)
            if word.item() in end_token_ids or i == max_length:
                break

            input_ids = torch.cat((input_ids, word.view(1, 1)), dim=-1)
            attention_mask = torch.cat(
                (
                    attention_mask,
                    torch.ones(
                        (attention_mask.size(0), 1),
                        dtype=torch.int,
                        device=attention_mask.device,
                    ),
                ),
                dim=-1,
            )
            if output:
                # Re-decode the whole continuation and print only the new tail.
                tmp = self.tokenizer.decode(input_ids.squeeze(0)[length:])
                if len(tmp) > len(output_text):
                    sys.stdout.write(tmp[len(output_text) :])
                    sys.stdout.flush()
                    output_text = tmp

        self.past_kv = past_key_values

        if output:
            sys.stdout.write("\n")
            sys.stdout.flush()

        # return [self.tokenizer.decode(input_ids.squeeze(0)[length:])]
        return input_ids
class InfLLMGenerator(GreedySearch):
    """``GreedySearch`` variant with a ``generate``/``__call__`` interface that
    takes token ids and uses a fixed chunk size of 8192."""

    def generate(
        self,
        input_ids=None,
        generation_config=None,
        pad_token_id=None,
        max_new_tokens=None,
    ):
        """Greedy-decode from ``input_ids``.

        Args:
            input_ids: Prompt token ids.
            generation_config: Object providing ``max_new_tokens``; only read
                when ``max_new_tokens`` is not given.
            pad_token_id: If set, also treated as an end token.
            max_new_tokens: Maximum number of generated tokens.

        Returns:
            The prompt ids followed by the generated ids.
        """
        if max_new_tokens is None:
            max_new_tokens = generation_config.max_new_tokens
        return super().generate(
            text=None,
            input_ids=input_ids,
            max_length=max_new_tokens,
            chunk_size=8192,
            extra_end_token_ids=[pad_token_id] if pad_token_id is not None else [],
        )

    def __call__(self, input_ids=None, *args, **kwargs):
        """Run the model over ``input_ids`` in chunks and return all logits.

        Extra positional/keyword arguments are ignored.

        Returns:
            ``CausalLMOutput`` whose ``logits`` are the per-chunk logits cast to
            bfloat16 and concatenated along dimension 1.
        """
        # chunked forward
        chunk_size = 8192
        total_len = input_ids.size(1)
        logit_chunks = []
        for st in range(0, total_len, chunk_size):
            torch.cuda.empty_cache()
            ed = min(total_len, st + chunk_size)
            out = self.model(
                input_ids=input_ids[:, st:ed],
            )
            logit_chunks.append(out.logits.to(torch.bfloat16))

        if logit_chunks:
            # Concatenate once instead of re-copying the accumulated logits on
            # every chunk.
            all_logits = torch.cat(logit_chunks, dim=1)
        else:
            # Zero-length input: same empty tensor the loop-free path produced before.
            all_logits = torch.empty(0, dtype=torch.bfloat16).to(input_ids.device)
        return CausalLMOutput(logits=all_logits)