Spaces:
Paused
Paused
import os
from typing import Callable, List, Optional

import numpy as np
import torch
from torch.nn import functional as F
# Lookup table from NumPy scalar types to the matching torch dtypes.
# Keys are scalar type classes (the kind of object `array.dtype.type` yields),
# not `np.dtype` instances.
np_dtype_to_torch_dtype = {
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int32: torch.int32,
    np.int64: torch.int64,
    # Booleans are reachable both through the Python builtin and through
    # NumPy's own scalar type, which is what boolean arrays actually report.
    bool: torch.bool,
    np.bool_: torch.bool,
}
class IndexDiff:
    """Plain container for the three tensors describing one cache-level diff.

    The constructor stores its arguments unchanged; each one defaults to
    ``None``. As populated by ``compute_index_diffs`` in this file:

    * ``off_elements``  -- requested indices that the cache level lacks.
    * ``off_positions`` -- cache slots whose entry was not requested,
      truncated to the number of ``off_elements``.
    * ``on_positions``  -- cache slots whose entry was requested.
    """

    def __init__(
        self,
        off_elements: torch.Tensor = None,
        off_positions: torch.Tensor = None,
        on_positions: torch.Tensor = None,
    ):
        # Values are kept by reference; nothing is copied or validated.
        self.on_positions = on_positions
        self.off_positions = off_positions
        self.off_elements = off_elements
def batch_copy(sources: List[torch.Tensor], copy_stream, indices=None, device='cuda'):
    """Copy every tensor in ``sources`` to ``device`` under ``copy_stream``.

    Args:
        sources: Tensors to copy.
        copy_stream: Stream handed to ``torch.cuda.stream`` for the duration
            of the copies.
        indices: Optional index applied to each source (``src[indices]``)
            before copying; ``None`` copies each source whole.
        device: Device on which the destination tensors are allocated.

    Returns:
        A tuple of freshly allocated destination tensors, one per source, in
        the same order as ``sources``.

    The copies are issued with ``non_blocking=True`` and this function does
    not synchronize afterwards.
    """

    def _copy_one(tensor):
        # Select first so the destination is sized for the selection only.
        selected = tensor if indices is None else tensor[indices]
        target = torch.empty(selected.shape, device=device, dtype=tensor.dtype)
        target.copy_(selected, non_blocking=True)
        return target

    with torch.cuda.stream(copy_stream):
        # The tuple is fully built inside the stream context.
        return tuple(_copy_one(tensor) for tensor in sources)
def mmap_to_tensor(torch_wrapped_mmap, pin_memory=False) -> torch.Tensor:
    """Materialize ``torch_wrapped_mmap`` into a newly allocated CPU tensor.

    Args:
        torch_wrapped_mmap: Source whose ``shape`` and ``dtype`` define the
            result and whose contents are copied into it (presumably a tensor
            backed by a memory-mapped file -- confirm against callers).
        pin_memory: Forwarded to ``torch.empty`` for the new buffer.

    Returns:
        The new CPU tensor holding a copy of the source data.
    """
    # copy_ returns the destination tensor, so the call can be chained.
    return torch.empty(
        torch_wrapped_mmap.shape,
        dtype=torch_wrapped_mmap.dtype,
        device='cpu',
        pin_memory=pin_memory,
    ).copy_(torch_wrapped_mmap)
def compute_index_diffs(new_indices: torch.Tensor, cached_indices_list: List[torch.Tensor], pin_memory=True):
    """Diff the requested indices against each level of a cache hierarchy.

    Each entry of ``cached_indices_list`` is assumed to be one step further
    down the memory hierarchy. The first iteration finds the indices the
    first level (e.g. the GPU) does not hold; the second finds which indices
    *of that diff* the next level (e.g. the CPU) does not hold, and so on.

    Args:
        new_indices: Indices requested for this step.
        cached_indices_list: Indices currently held by each cache level,
            ordered from the top of the hierarchy downwards.
        pin_memory: Forwarded to ``torch.tensor`` when the missing indices of
            a level are materialized on the CPU.

    Returns:
        A list of ``IndexDiff`` objects, one per level visited. Iteration
        stops early once nothing is left to look for, so the list may be
        shorter than ``cached_indices_list``.
    """
    diffs = []
    wanted = new_indices
    for level_indices in cached_indices_list:
        if wanted.size(0) == 0:
            # Everything was found higher up; lower levels need no diff.
            break

        # Requested indices that this level does not hold.
        missing = set(wanted.tolist()) - set(level_indices.tolist())
        off_elements = torch.tensor(
            list(missing),
            device='cpu',
            dtype=torch.int32,
            pin_memory=pin_memory,
        )

        # True at every cache slot whose entry is among the requested indices.
        hit_mask = torch.isin(level_indices, wanted, assume_unique=True)
        on_positions = hit_mask.nonzero().flatten()
        # Slots not requested this step, limited to one per missing element.
        unrequested_positions = (~hit_mask).nonzero().flatten()
        off_positions = unrequested_positions[:off_elements.size(0)]

        diffs.append(IndexDiff(off_elements, off_positions, on_positions))
        # The next level only has to be searched for what this one lacked.
        wanted = off_elements
    return diffs
def topk_and_threshold(x, k, threshold=1):
    """Return the indices of the top-``k`` entries of ``x`` that exceed ``threshold``.

    Args:
        x: Tensor passed to ``torch.topk``.
        k: Number of largest entries to consider.
        threshold: Strict lower bound; a top-k entry is kept only when its
            value is greater than this.

    Returns:
        The surviving indices converted with ``.int()``, in the descending
        value order produced by ``torch.topk(..., sorted=True)``.
    """
    top = torch.topk(x, k, sorted=True)
    above_threshold = top.values > threshold
    return top.indices[above_threshold].int()
def load_mlp_sparsity_predictor(weight_path_prefix: str, layer_num: int, dtype: torch.dtype, device: str = 'cuda') -> Callable:
    """Load the MLP sparsity predictor belonging to decoder layer ``layer_num``.

    Builds the per-layer weight prefix
    ``{weight_path_prefix}decoder.layers.{layer_num}.attn.mlp-sparsity-predictor.``
    and delegates to ``load_predictor``.

    Args:
        weight_path_prefix: Leading part of the weight file paths.
        layer_num: Decoder layer whose predictor should be loaded.
        dtype: Dtype passed on to ``load_predictor``.
        device: Device passed on to ``load_predictor``.

    Returns:
        Whatever ``load_predictor`` returns for the constructed prefix.
    """
    layer_prefix = f'{weight_path_prefix}decoder.layers.{layer_num}.attn.mlp-sparsity-predictor.'
    predictor = load_predictor(layer_prefix, dtype, device=device)
    return predictor
def load_predictor(path_prefix: str, dtype: torch.dtype, device: str = 'cuda') -> Optional[Callable]:
    """Load a two-layer linear predictor from ``{path_prefix}1.weight`` and ``{path_prefix}2.weight``.

    Args:
        path_prefix: Common prefix of the two weight files; ``~`` is expanded.
        dtype: Dtype both weight tensors are converted to.
        device: Device both weight tensors are placed on.

    Returns:
        A callable computing ``F.linear(F.linear(x, l1), l2)`` with the loaded
        weights, or ``None`` (after printing a notice) when the first weight
        file does not exist.

    Raises:
        Whatever ``torch.load`` raises if the first file exists but the
        second one is missing or unreadable.
    """

    def weight_path(index: int) -> str:
        return os.path.expanduser(f'{path_prefix}{index}.weight')

    first_path = weight_path(1)
    if not os.path.exists(first_path):
        print(f'could not find predictor at {first_path}')
        return None

    def load_weight(index: int) -> torch.Tensor:
        # NOTE(security): torch.load unpickles the file; only point this at
        # weight files from a trusted source.
        # map_location lets weights that were saved on a device unavailable
        # here (e.g. CUDA tensors on a CPU-only host) still deserialize.
        weight = torch.load(weight_path(index), map_location=device)
        return weight.to(device=device, dtype=dtype)

    l1 = load_weight(1)
    l2 = load_weight(2)

    def predict(x):
        return F.linear(F.linear(x, l1), l2)

    return predict