|
import torch |
|
|
|
|
|
def create_grid_mask(seq_length, trunck_length, fill_triangle):
    """Build a square 0/1 mask with block-diagonal ("trunk") visibility.

    The sequence is cut into consecutive trunks of ``trunck_length``
    positions (the last trunk may be shorter). Entry ``[i, j]`` is 1 when
    positions ``i`` and ``j`` fall in the same trunk. When ``fill_triangle``
    is true, the lower triangle (``j <= i``) is set to 1 as well.

    Args:
        seq_length: Side length of the mask; must be positive.
        trunck_length: Number of positions per trunk; must be positive.
        fill_triangle: If true, start from a lower-triangular mask of ones
            (diagonal included); otherwise start from all zeros.

    Returns:
        A ``(seq_length, seq_length)`` tensor of zeros and ones in torch's
        default floating-point dtype.

    Raises:
        ValueError: If ``seq_length`` or ``trunck_length`` is not positive.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if seq_length <= 0:
        raise ValueError(f"seq_length must be positive, got {seq_length}")
    # A zero trunk length would divide by zero and a negative one would
    # silently yield empty trunks, so reject both up front.
    if trunck_length <= 0:
        raise ValueError(f"trunck_length must be positive, got {trunck_length}")

    ones = torch.ones(seq_length, seq_length)
    # tril(ones) is the same matrix as 1 - triu(ones, diagonal=1).
    mask = torch.tril(ones) if fill_triangle else torch.zeros_like(ones)

    # Trunk index of every position; two positions see each other when their
    # trunk indices match. This replaces the per-row Python loop.
    trunk_ids = torch.div(
        torch.arange(seq_length), trunck_length, rounding_mode="floor"
    )
    same_trunk = trunk_ids.unsqueeze(1) == trunk_ids.unsqueeze(0)
    mask[same_trunk] = 1

    return mask
|
|
|
|
|
if __name__ == "__main__":
    # Demo: print an 8x8 mask with trunks of 3 and the lower triangle filled,
    # cast to integers so the output is easy to read.
    demo_mask = create_grid_mask(seq_length=8, trunck_length=3, fill_triangle=True)
    print(demo_mask.int())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|