import torch


class Masker:
    """Applies mask-token corruption to token id tensors for masked-LM style training."""

    def __init__(self, tokenizer) -> None:
        self.tokenizer = tokenizer
        self.mask_token_id = self.tokenizer.mask_token_id

    def random_mask(self, input_ids, mask_prob=0.15):
        """Mask each non-special token independently with probability mask_prob."""
        device = input_ids.device
        mask = (torch.rand(input_ids.shape) < mask_prob).to(device)
        # Never mask special tokens (e.g. padding / sequence delimiters).
        mask = mask & torch.logical_not(self.tokenizer.is_special_token(input_ids))
        masked_input_ids = input_ids.clone()
        masked_input_ids[mask] = self.mask_token_id
        return masked_input_ids, mask

    def mask_ptm_tokens(self, input_ids):
        """Mask every PTM token, leaving all other positions untouched."""
        device = input_ids.device
        is_ptm_mask = self.tokenizer.is_ptm_token(input_ids).to(device)
        # Special tokens are excluded even if the tokenizer flags them as PTM.
        is_ptm_mask = is_ptm_mask & torch.logical_not(
            self.tokenizer.is_special_token(input_ids)
        )
        masked_input_ids = input_ids.clone()
        masked_input_ids[is_ptm_mask] = self.mask_token_id
        return masked_input_ids, is_ptm_mask

    def random_and_ptm_mask(self, input_ids, mask_prob=0.15):
        """Mask a position if random masking or PTM masking selects it;
        special tokens are never masked by either strategy."""
        device = input_ids.device
        mask = (torch.rand(input_ids.shape) < mask_prob).to(device)
        mask = mask & torch.logical_not(self.tokenizer.is_special_token(input_ids))
        is_ptm_mask = self.tokenizer.is_ptm_token(input_ids).to(device)
        is_ptm_mask = is_ptm_mask & torch.logical_not(
            self.tokenizer.is_special_token(input_ids)
        )
        mask = mask | is_ptm_mask
        masked_input_ids = input_ids.clone()
        masked_input_ids[mask] = self.mask_token_id
        return masked_input_ids, mask

    def random_or_random_and_ptm_mask(
        self, input_ids, random_mask_prob=0.15, alternate_prob=0.2
    ):
        """Alternate between the two strategies: with probability alternate_prob
        apply random masking only; otherwise apply random masking combined
        with PTM masking."""
        p = torch.rand(1).item()
        if p < alternate_prob:
            return self.random_mask(input_ids, random_mask_prob)
        else:
            return self.random_and_ptm_mask(input_ids, random_mask_prob)
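

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original file).
# The tokenizer interface is an assumption inferred from the calls above: it
# must expose `mask_token_id` plus elementwise `is_special_token` and
# `is_ptm_token` predicates over an id tensor. `DemoTokenizer` and its id
# conventions are hypothetical stand-ins for demonstration only.
# ---------------------------------------------------------------------------
class DemoTokenizer:
    mask_token_id = 0

    def is_special_token(self, input_ids):
        # Hypothetical convention: ids 1 and 2 are special (e.g. BOS/EOS).
        return (input_ids == 1) | (input_ids == 2)

    def is_ptm_token(self, input_ids):
        # Hypothetical convention: ids >= 20 denote PTM tokens.
        return input_ids >= 20


if __name__ == "__main__":
    torch.manual_seed(0)
    masker = Masker(DemoTokenizer())
    input_ids = torch.tensor([[1, 5, 21, 7, 22, 2]])
    masked, mask = masker.random_or_random_and_ptm_mask(input_ids)
    # PTM ids (21, 22) are masked on ~80% of calls; specials (1, 2) never are.
    print(masked, mask)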