code-completion / archive-misc /test_new_attnmask.py
jblitzar's picture
Upload folder using huggingface_hub
a8639ac verified
import torch
def attnmask_new(sz):
    """Build an ``sz`` x ``sz`` mask that is ``-inf`` strictly above the diagonal.

    Entries on and below the main diagonal are ``0``; entries above it are
    ``-inf``. The tensor is created directly with ``torch.full`` and then
    trimmed with ``triu``, so no logarithm is involved.

    Args:
        sz: Side length of the square mask.

    Returns:
        A 2-D floating-point tensor of shape ``(sz, sz)``.
    """
    blocked = torch.full((sz, sz), float("-inf"))
    # diagonal=1 keeps only the entries above the main diagonal; every
    # other entry is replaced by zero.
    return blocked.triu(diagonal=1)
def attnmask_old(sz):
    """Build an ``sz`` x ``sz`` mask by taking the log of a lower-triangular ones matrix.

    Entries on and below the main diagonal are ``log(1) == 0``; entries above
    it are ``log(0) == -inf``.

    Args:
        sz: Side length of the square mask.

    Returns:
        A 2-D floating-point tensor of shape ``(sz, sz)``.
    """
    lower_ones = torch.ones(sz, sz).tril()
    return lower_ones.log()
def compare_masks(sz):
    """Check that ``attnmask_new`` and ``attnmask_old`` agree for size ``sz``.

    Args:
        sz: Side length passed to both mask builders.

    Returns:
        ``True`` when ``torch.equal`` reports the two masks as equal.

    Raises:
        ValueError: If the two masks differ. A notice is printed to stdout
            before the exception is raised.
    """
    candidate = attnmask_new(sz)
    reference = attnmask_old(sz)

    if not torch.equal(candidate, reference):
        print("\nThe masks are NOT equal.")
        raise ValueError("Masks differ, check implementation.")
    return True
if __name__ == "__main__":
    # Equivalence check: compare_masks raises ValueError on the first size
    # whose masks differ, so reaching the print below means all sizes matched.
    for mask_size in range(1, 100):
        compare_masks(mask_size)
    print("All masks are equal for sizes 1 to 99.")

    # Time each implementation at size 256 using timeit.
    import timeit

    size = 256
    runs = 1000
    new_time = timeit.timeit(lambda: attnmask_new(size), number=runs)
    old_time = timeit.timeit(lambda: attnmask_old(size), number=runs)
    print(f"New mask time for size {size}: {new_time:.6f} seconds")
    print(f"Old mask time for size {size}: {old_time:.6f} seconds")

    # A tie is reported as the old implementation being faster.
    verdict = (
        "New mask implementation is faster."
        if new_time < old_time
        else "Old mask implementation is faster."
    )
    print(verdict)