File size: 1,338 Bytes
a8639ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch


def attnmask_new(sz):
    """Build a square mask with ``-inf`` strictly above the main diagonal.

    Args:
        sz: Side length of the square mask.

    Returns:
        A ``(sz, sz)`` tensor in the default dtype whose entries above the
        main diagonal are ``-inf`` and whose remaining entries (diagonal and
        below) are ``0``.

    NOTE(review): presumably used as an additive causal attention mask --
    confirm against the caller.
    """
    all_neg_inf = torch.full((sz, sz), float('-inf'))
    # diagonal=1 keeps only the strictly-upper triangle; everything else
    # is zeroed out.
    return all_neg_inf.triu(diagonal=1)

def attnmask_old(sz):
    """Build the same square mask by taking the log of a 0/1 triangle.

    Args:
        sz: Side length of the square mask.

    Returns:
        A ``(sz, sz)`` tensor in the default dtype that is ``0`` on and below
        the main diagonal and ``-inf`` strictly above it.
    """
    # Ones on and below the diagonal, zeros above it.
    lower_ones = torch.ones(sz, sz).tril()
    # log(1) == 0 and log(0) == -inf, which turns the 0/1 pattern into the
    # 0/-inf mask.
    return lower_ones.log()

def compare_masks(sz):
    """Check that both mask implementations agree for a given size.

    Args:
        sz: Side length passed to ``attnmask_new`` and ``attnmask_old``.

    Returns:
        ``True`` when ``torch.equal`` reports the two masks as equal.

    Raises:
        ValueError: If the two masks differ; a notice is printed first.
    """
    candidate = attnmask_new(sz)
    reference = attnmask_old(sz)

    # Guard clause: report and fail on any mismatch.
    if not torch.equal(candidate, reference):
        print("\nThe masks are NOT equal.")
        raise ValueError("Masks differ, check implementation.")

    return True

if __name__ == "__main__":
    import timeit

    # Equivalence sweep: compare_masks raises on the first size that differs,
    # so reaching the print below means every size in the range matched.
    for side in range(1, 100):
        compare_masks(side)
    print("All masks are equal for sizes 1 to 99.")

    # Micro-benchmark: time repeated construction of one fixed-size mask
    # with each implementation using timeit.
    size = 256
    repeats = 1000
    new_time = timeit.timeit(lambda: attnmask_new(size), number=repeats)
    old_time = timeit.timeit(lambda: attnmask_old(size), number=repeats)

    print(f"New mask time for size {size}: {new_time:.6f} seconds")
    print(f"Old mask time for size {size}: {old_time:.6f} seconds")

    # A tie is reported as the old implementation being faster.
    verdict = (
        "New mask implementation is faster."
        if new_time < old_time
        else "Old mask implementation is faster."
    )
    print(verdict)