# Provenance (from the original hosted file page): author zhangzhi, "init commit", revision a476bbf, 2.06 kB.
import torch
class Masker:
    """Produce masked copies of token-id tensors for masked-language-model style training.

    Relies on a tokenizer object that exposes:
      * ``mask_token_id`` -- the id written into masked positions
      * ``is_special_token(input_ids)`` -- boolean tensor marking special tokens
        (special tokens are never masked)
      * ``is_ptm_token(input_ids)`` -- boolean tensor marking PTM tokens
        (presumably post-translational-modification tokens -- TODO confirm)

    Every masking method returns ``(masked_input_ids, mask)`` where ``mask`` is the
    boolean tensor of positions that were replaced by the mask token.
    """

    def __init__(self, tokenizer) -> None:
        self.tokenizer = tokenizer
        self.mask_token_id = self.tokenizer.mask_token_id

    def _maskable(self, input_ids):
        """Boolean tensor of positions eligible for masking (non-special tokens)."""
        return torch.logical_not(self.tokenizer.is_special_token(input_ids))

    def _apply_mask(self, input_ids, mask):
        """Return a masked copy of ``input_ids`` plus the mask itself (input left untouched)."""
        masked_input_ids = input_ids.clone()
        masked_input_ids[mask] = self.mask_token_id
        return masked_input_ids, mask

    def random_mask(self, input_ids, mask_prob=0.15):
        """Mask each non-special position independently with probability ``mask_prob``."""
        device = input_ids.device
        # Draw the Bernoulli trials directly on the input's device instead of
        # sampling on CPU and copying over with .to(device).
        rand_mask = torch.rand(input_ids.shape, device=device) < mask_prob
        mask = rand_mask & self._maskable(input_ids)
        return self._apply_mask(input_ids, mask)

    def mask_ptm_tokens(
        self,
        input_ids,
    ):
        """Mask every PTM token (special tokens excluded)."""
        device = input_ids.device
        is_ptm_mask = self.tokenizer.is_ptm_token(input_ids).to(device)
        is_ptm_mask = is_ptm_mask & self._maskable(input_ids)
        return self._apply_mask(input_ids, is_ptm_mask)

    def random_and_ptm_mask(self, input_ids, mask_prob=0.15):
        """Mask the union of a random selection (prob ``mask_prob``) and all PTM tokens.

        Equivalent to the original duplicated logic:
        ``(rand & ~special) | (ptm & ~special) == (rand | ptm) & ~special``.
        """
        device = input_ids.device
        rand_mask = torch.rand(input_ids.shape, device=device) < mask_prob
        ptm_mask = self.tokenizer.is_ptm_token(input_ids).to(device)
        mask = (rand_mask | ptm_mask) & self._maskable(input_ids)
        return self._apply_mask(input_ids, mask)

    def random_or_random_and_ptm_mask(
        self, input_ids, ranom_mask_prob=0.15, alternate_prob=0.2
    ):
        """Alternate between [(1) random mask] and [(2) random mask & ptm mask].

        With probability ``alternate_prob`` apply :meth:`random_mask` alone;
        otherwise apply :meth:`random_and_ptm_mask`.

        NOTE(review): ``ranom_mask_prob`` is a typo for ``random_mask_prob``;
        kept unchanged so existing keyword callers keep working.
        """
        p = torch.rand(1).item()
        if p < alternate_prob:
            return self.random_mask(input_ids, ranom_mask_prob)
        else:
            return self.random_and_ptm_mask(input_ids, ranom_mask_prob)