import torch

import act_mem
import layers

if __name__ == "__main__":
    batch_size, seq_len, d_model, n_heads = 1, 128, 1024, 32
    print(f"Batch size: {batch_size}, sequence length: {seq_len}, d_model: {d_model}, n_heads: {n_heads}")
    dtype = torch.bfloat16
    inputs = torch.randn(
        batch_size,
        seq_len,
        d_model,
        device="cuda",
        requires_grad=True,
        dtype=dtype,
    )
    attn = layers.Attention(
        d_model=d_model,
        n_heads=n_heads,
        device="cuda",
        dtype=dtype,
    )

    # Track the change in allocated CUDA memory and the tensors saved for
    # backward during the forward pass, excluding the layer's own parameters.
    with act_mem.AllocatedMemContext() as mem, act_mem.SavedTensorContext(
        ignored_tensors=attn.parameters()
    ) as saved:
        out = attn(inputs)

    stm = saved.saved_tensor_mem
    print(f'{mem.delta["current"]=}')
    print(f"{stm=}")
    print(f"{stm/out.numel()=}")
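As a rough mental model for what the second context manager is doing (an illustrative sketch only, not the actual act_mem code, and assuming a recent PyTorch with untyped_storage), a SavedTensorContext-like helper can be built on torch.autograd.graph.saved_tensors_hooks, which intercepts every tensor autograd stashes for the backward pass so their storage can be totaled, skipping the ignored parameter tensors:

    import torch

    class SavedTensorContext:
        # Hypothetical sketch of a saved-tensor tracker; the real act_mem
        # implementation may differ in details.
        def __init__(self, ignored_tensors=None):
            self._ignored_ptrs = {
                t.untyped_storage().data_ptr() for t in (ignored_tensors or [])
            }
            self.saved_tensors = []

        def __enter__(self):
            def pack(t):
                # Record every tensor saved for backward whose storage is not
                # one of the ignored (parameter) storages.
                if t.untyped_storage().data_ptr() not in self._ignored_ptrs:
                    self.saved_tensors.append(t)
                return t

            def unpack(t):
                return t

            self._hooks = torch.autograd.graph.saved_tensors_hooks(pack, unpack)
            self._hooks.__enter__()
            return self

        def __exit__(self, *exc):
            self._hooks.__exit__(*exc)

        @property
        def saved_tensor_mem(self):
            # Sum bytes over distinct storages so shared views are not
            # double counted.
            storages = {
                t.untyped_storage().data_ptr(): t.untyped_storage()
                for t in self.saved_tensors
            }
            return sum(s.nbytes() for s in storages.values())

With a hook-based tally along these lines, stm is the number of bytes held for the backward pass, and stm/out.numel() reports how many saved bytes the attention layer costs per output element, which is the per-element figure the script prints.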