import torch


class Masker:
    """Applies mask-token corruption to token-id tensors for masked-language-model
    training, with special handling for PTM (post-translational modification) tokens."""

    def __init__(self, tokenizer) -> None:
        self.tokenizer = tokenizer
        self.mask_token_id = self.tokenizer.mask_token_id

    def random_mask(self, input_ids, mask_prob=0.15):
        """Mask each non-special token independently with probability mask_prob."""
        device = input_ids.device
        mask = (torch.rand(input_ids.shape) < mask_prob).to(device)
        # Never mask special tokens (e.g. padding, BOS/EOS).
        mask = mask & torch.logical_not(self.tokenizer.is_special_token(input_ids))
        masked_input_ids = input_ids.clone()
        masked_input_ids[mask] = self.mask_token_id
        return masked_input_ids, mask

    def mask_ptm_tokens(self, input_ids):
        """Mask every PTM token, leaving all other tokens untouched."""
        device = input_ids.device
        is_ptm_mask = self.tokenizer.is_ptm_token(input_ids).to(device)
        # Exclude special tokens even if the tokenizer flags them as PTM.
        is_ptm_mask = is_ptm_mask & torch.logical_not(
            self.tokenizer.is_special_token(input_ids)
        )
        masked_input_ids = input_ids.clone()
        masked_input_ids[is_ptm_mask] = self.mask_token_id
        return masked_input_ids, is_ptm_mask

    def random_and_ptm_mask(self, input_ids, mask_prob=0.15):
        """Mask all PTM tokens plus a random mask_prob fraction of the remaining
        non-special tokens (the union of the two masks)."""
        device = input_ids.device
        mask = (torch.rand(input_ids.shape) < mask_prob).to(device)
        mask = mask & torch.logical_not(self.tokenizer.is_special_token(input_ids))
        is_ptm_mask = self.tokenizer.is_ptm_token(input_ids).to(device)
        is_ptm_mask = is_ptm_mask & torch.logical_not(
            self.tokenizer.is_special_token(input_ids)
        )
        mask = mask | is_ptm_mask
        masked_input_ids = input_ids.clone()
        masked_input_ids[mask] = self.mask_token_id
        return masked_input_ids, mask

    def random_or_random_and_ptm_mask(
        self, input_ids, random_mask_prob=0.15, alternate_prob=0.2
    ):
        """Alternate between the two strategies: with probability alternate_prob
        apply random masking only, otherwise apply random masking combined with
        PTM masking."""
        p = torch.rand(1).item()
        if p < alternate_prob:
            return self.random_mask(input_ids, random_mask_prob)
        return self.random_and_ptm_mask(input_ids, random_mask_prob)
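

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). Masker only assumes the
# tokenizer exposes mask_token_id plus elementwise is_special_token /
# is_ptm_token predicates that return boolean tensors shaped like input_ids.
# DummyTokenizer below and all of its token ids are hypothetical, chosen just
# to illustrate that interface.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class DummyTokenizer:
        """Toy tokenizer: ids 0-1 are special tokens, ids >= 30 are PTM tokens."""

        mask_token_id = 2

        def is_special_token(self, input_ids):
            return input_ids < 2

        def is_ptm_token(self, input_ids):
            return input_ids >= 30

    masker = Masker(DummyTokenizer())
    # Batch of two sequences: ids 0/1 are special, 30/31 are PTM, rest ordinary.
    input_ids = torch.tensor([[0, 5, 30, 7, 1], [0, 31, 9, 9, 1]])

    masked, mask = masker.random_and_ptm_mask(input_ids, mask_prob=0.15)
    # Expected behavior: PTM positions are always masked, ordinary tokens are
    # masked with probability ~0.15, and special tokens are never masked.
    print(masked)
    print(mask)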