tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
No virus
503 Bytes
"""Create masks that hide subsequent (future) steps."""
def make_history_mask(xp, block):
    """Build a causal (lower-triangular) history mask for every batch item.

    Args:
        xp: Array module providing ``arange`` and ``broadcast_to``
            (e.g. ``numpy``; presumably ``cupy`` on GPU — confirm with callers).
        block (ndarray): Array of shape (B, S). Only its shape is read, and
            it must be exactly two-dimensional or the unpacking raises
            ``ValueError``.

    Returns:
        ndarray: Boolean array of shape (B, S, S) whose element ``[b, i, j]``
        is True exactly when ``j <= i``, i.e. step ``i`` may see steps
        ``0..i`` but no later ones. The result comes from
        ``xp.broadcast_to``, so every batch entry shares the same
        underlying (S, S) data (with NumPy it is a read-only view).
    """
    n_batch, seq_len = block.shape
    positions = xp.arange(seq_len)
    # Rows are the current step i, columns the attended step j; j is visible when j <= i.
    causal = positions[:, None] >= positions[None, :]
    # Add a leading batch axis of size 1, then expand it to B without copying.
    return xp.broadcast_to(causal[None, :, :], (n_batch, seq_len, seq_len))