huaweilin's picture
update
14ce5a9
"""This file contains the definition of utility functions for masking."""
import math
from typing import Text, Tuple
import torch
def get_mask_tokens(
    tokens: torch.Tensor,
    mask_token: int,
    mode: Text = "arccos",
    min_masking_ratio: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the masked tokens.

    One masking ratio is drawn per batch element (dimension 0 of `tokens`) by
    passing a uniform sample through the schedule selected by `mode`. Every
    token of that element is then masked independently with that probability.
    The input tensor is not modified; a detached copy is returned.

    Args:
        tokens -> torch.Tensor: The input tokens. Dimension 0 is treated as the batch
            dimension; any number of trailing dimensions is supported.
        mask_token -> int: The special `mask` token.
        mode -> Text: The masking function to use, one of 'linear', 'square', 'cosine',
            'arccos' (default: "arccos").
        min_masking_ratio -> float: Shrinks the uniform sample to the range
            [0, 1 - min_masking_ratio), which raises the smallest masking ratio that can
            be drawn; for 'linear' the ratio stays above min_masking_ratio (default: 0.0).

    Returns:
        masked_tokens -> torch.Tensor: The masked input tokens. Each masked token is set to mask_token.
        mask -> torch.Tensor: A boolean tensor mask indicating which tokens are masked.

    Raises:
        ValueError: If `mode` is not one of the supported masking functions.
    """
    # One uniform sample per batch element. No explicit device is passed (as
    # before), so the random tensors and the returned mask live on torch's
    # default device.
    r = torch.rand(tokens.size(0)) * (1 - min_masking_ratio)

    if mode == "linear":
        val_to_mask = 1 - r
    elif mode == "square":
        val_to_mask = 1 - (r**2)
    elif mode == "cosine":
        val_to_mask = torch.cos(r * math.pi * 0.5)
    elif mode == "arccos":
        val_to_mask = torch.acos(r) / (math.pi * 0.5)
    else:
        raise ValueError(
            "Invalid mode. Choose between 'linear','square', 'cosine', 'arccos'."
        )

    masked_tokens = tokens.detach().clone()

    # Reshape the per-sample ratios to (batch, 1, ..., 1) with one singleton
    # axis per trailing token dimension, so the comparison broadcasts for any
    # rank instead of only for rank-3 inputs.
    ratio_shape = (-1,) + (1,) * (tokens.dim() - 1)
    mask = torch.rand(tokens.size()) < val_to_mask.view(ratio_shape)

    masked_tokens[mask] = torch.full_like(masked_tokens[mask], mask_token)
    return masked_tokens, mask
def get_masking_ratio(progress: float, mode: Text = "arccos") -> torch.Tensor:
    """Get masking ratio.

    Maps the progress through a decreasing schedule, so the ratio is 1.0 at
    progress 0 and reaches the lower clamp value 1e-6 at progress 1.

    Args:
        progress -> float: The percentage of iterations already done, expected in [0, 1].
            Values outside that range are clamped to it.
        mode -> Text: The masking function to use, one of 'linear', 'square', 'cosine',
            'arccos', 'root' (default: "arccos").

    Returns:
        val_to_mask -> torch.Tensor: The masking ratio, clamped to [1e-6, 1.0].

    Raises:
        ValueError: If `mode` is not one of the supported masking functions.
    """
    r = torch.tensor(progress)
    # Keep the schedule input inside its domain: acos(r) is NaN for r > 1 and
    # r**0.5 is NaN for r < 0, and the final clamp would let NaN through.
    r = torch.clamp(r, 0.0, 1.0)

    if mode == "root":
        val_to_mask = 1 - (r**0.5)
    elif mode == "square":
        val_to_mask = 1 - (r**2)
    elif mode == "cosine":
        val_to_mask = torch.cos(r * math.pi * 0.5)
    elif mode == "arccos":
        val_to_mask = torch.acos(r) / (math.pi * 0.5)
    elif mode == "linear":
        val_to_mask = 1 - r
    else:
        raise ValueError(
            "Invalid mode. Choose between 'linear','square', 'cosine', 'arccos', 'root'."
        )

    # The ratio never reaches exactly zero: the lower bound is 1e-6.
    val_to_mask = torch.clamp(val_to_mask, 1e-6, 1.0)
    return val_to_mask