tricksy / util.py
austinsilveria's picture
Add tricksy
75fa479
raw
history blame
3.49 kB
import os
from typing import List, Callable
import numpy as np
import torch
from torch.nn import functional as F
# Lookup table translating NumPy scalar types to the equivalent torch dtypes.
# NOTE: the boolean entry is keyed on Python's builtin ``bool`` (not ``np.bool_``).
np_dtype_to_torch_dtype = {
    # Floating point
    np.float16: torch.float16,
    np.float32: torch.float32,
    # Integers
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int32: torch.int32,
    np.int64: torch.int64,
    # Boolean
    bool: torch.bool,
}
class IndexDiff:
    """Plain holder for the three tensors describing one index diff.

    The constructor stores its arguments unchanged; each defaults to ``None``.
    As populated by ``compute_index_diffs`` in this file:

    * ``off_elements``  -- requested index values absent from the cache.
    * ``off_positions`` -- cache positions whose value was not requested.
    * ``on_positions``  -- cache positions whose value was requested.
    """

    def __init__(self, off_elements: torch.Tensor=None, off_positions: torch.Tensor=None, on_positions: torch.Tensor=None):
        # No validation or copying: the references are kept exactly as given.
        self.off_elements, self.off_positions, self.on_positions = (
            off_elements,
            off_positions,
            on_positions,
        )
def batch_copy(sources: List[torch.Tensor], copy_stream, indices=None, device='cuda'):
    """Copy every tensor in ``sources`` to ``device`` under ``copy_stream``.

    Args:
        sources: Tensors to copy.
        copy_stream: Stream handed to ``torch.cuda.stream`` for the copies.
        indices: Optional index applied to each source (``src[indices]``)
            before copying; when ``None`` the whole tensor is copied.
        device: Device on which the destination tensors are allocated.

    Returns:
        A tuple of newly allocated tensors, one per source, in input order.

    The copies are issued with ``non_blocking=True`` and this function does
    not synchronize the stream itself.
    """
    copies = []
    with torch.cuda.stream(copy_stream):
        for tensor in sources:
            if indices is None:
                selected = tensor
            else:
                selected = tensor[indices]
            # Destination keeps the source dtype and takes the selected shape.
            target = torch.empty(selected.shape, device=device, dtype=tensor.dtype)
            target.copy_(selected, non_blocking=True)
            copies.append(target)
    return tuple(copies)
def mmap_to_tensor(torch_wrapped_mmap, pin_memory=False) -> torch.Tensor:
    """Return a newly allocated CPU tensor holding a copy of the input.

    Args:
        torch_wrapped_mmap: Tensor-like object exposing ``shape`` and
            ``dtype`` (presumably a tensor backed by a memory-mapped file --
            verify against callers).
        pin_memory: Forwarded to ``torch.empty`` for the destination buffer.

    Returns:
        The freshly allocated CPU tensor with the copied contents.
    """
    buffer = torch.empty(
        torch_wrapped_mmap.shape,
        dtype=torch_wrapped_mmap.dtype,
        device='cpu',
        pin_memory=pin_memory,
    )
    # ``copy_`` returns the destination tensor itself.
    return buffer.copy_(torch_wrapped_mmap)
# Assuming that each entry of cached_indices is a step down the memory hierarchy,
# compute the diff at each level of the hierarchy.
# e.g. the first loop computes the indices that the GPU does not have,
# and the second loop computes the indices *of that diff* that the CPU does not have.
def compute_index_diffs(new_indices: torch.Tensor, cached_indices_list: List[torch.Tensor], pin_memory=True):
    """Diff the requested indices against each cache level in turn.

    Each entry of ``cached_indices_list`` is treated as one step down the
    memory hierarchy. The values missing at one level become the request
    for the next level.

    Args:
        new_indices: 1-D tensor of requested index values.
        cached_indices_list: One tensor of cached index values per level.
        pin_memory: Forwarded to ``torch.tensor`` when building each
            level's ``off_elements``.

    Returns:
        A list of ``IndexDiff`` objects, one per level visited. Iteration
        stops early once nothing remains to be looked up.
    """
    results = []
    wanted = new_indices
    for cached in cached_indices_list:
        if wanted.size(0) == 0:
            # Everything has been found; lower levels need not be consulted.
            break

        # Requested values that this level does not hold. The order follows
        # Python set iteration order.
        requested_values = set(wanted.tolist())
        cached_values = set(cached.tolist())
        missing_values = list(requested_values.difference(cached_values))
        off_elements = torch.tensor(
            missing_values,
            device='cpu',
            dtype=torch.int32,
            pin_memory=pin_memory,
        )

        # True at cache positions whose value IS among the requested ones.
        hit_mask = torch.isin(cached, wanted, assume_unique=True)
        on_positions = hit_mask.nonzero().flatten()
        # Cache positions holding unrequested values, truncated to at most
        # one position per missing element.
        off_positions = (~hit_mask).nonzero().flatten()[:off_elements.size(0)]

        results.append(IndexDiff(off_elements, off_positions, on_positions))
        wanted = off_elements
    return results
def topk_and_threshold(x, k, threshold=1):
    """Return the indices of the top-``k`` entries of ``x`` exceeding ``threshold``.

    Args:
        x: Input tensor passed to ``torch.topk``.
        k: Number of largest entries to consider.
        threshold: Only entries strictly greater than this value are kept.

    Returns:
        An ``int32`` tensor of the surviving indices, in descending-value order.
    """
    top_values, top_indices = torch.topk(x, k, sorted=True)
    keep = top_values > threshold
    return top_indices[keep].to(torch.int32)
def load_mlp_sparsity_predictor(weight_path_prefix: str, layer_num: int, dtype: torch.dtype, device: str = 'cuda') -> Callable:
    """Load the MLP sparsity predictor for decoder layer ``layer_num``.

    Builds the layer-specific weight path prefix and delegates to
    ``load_predictor``, returning whatever that function returns.

    Args:
        weight_path_prefix: Leading part of the weight file paths.
        layer_num: Decoder layer number inserted into the path.
        dtype: Passed through to ``load_predictor``.
        device: Passed through to ``load_predictor``.
    """
    layer_part = f'decoder.layers.{layer_num}.attn.mlp-sparsity-predictor.'
    predictor_prefix = f'{weight_path_prefix}{layer_part}'
    return load_predictor(predictor_prefix, dtype, device=device)
def load_predictor(path_prefix: str, dtype: torch.dtype, device: str='cuda') -> Callable:
    """Load a two-layer linear predictor from ``{path_prefix}1.weight`` / ``2.weight``.

    Args:
        path_prefix: Prefix of the weight files; ``~`` is expanded.
        dtype: dtype the loaded weights are converted to.
        device: Device the loaded weights are moved to.

    Returns:
        A callable computing ``F.linear(F.linear(x, w1), w2)``, or ``None``
        (after printing a message) when the first weight file is missing.
        Only the first file's existence is checked; the second is loaded
        unconditionally.
    """
    def weight_file(index):
        # Resolve the on-disk location of the ``index``-th weight matrix.
        return os.path.expanduser(f'{path_prefix}{index}.weight')

    first = weight_file(1)
    if not os.path.exists(first):
        print(f'could not find predictor at {first}')
        return None

    # NOTE(review): torch.load may unpickle arbitrary objects -- only point
    # this at trusted weight files.
    w1 = torch.load(first).to(device).to(dtype)
    w2 = torch.load(weight_file(2)).to(device).to(dtype)

    def predict(x):
        return F.linear(F.linear(x, w1), w2)

    return predict