from typing import List

import torch
from transformers import AutoTokenizer


def prepare_tokenizer(tokenizer: AutoTokenizer, token_beg="", token_end=""):
    """Add the privacy special tokens that delimit sensitive spans to the tokenizer."""
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_special_tokens({"additional_special_tokens": [token_beg, token_end]})
    tokenizer.sensitive_beg_id = tokenizer.encode(token_beg, add_special_tokens=False)[0]
    tokenizer.sensitive_end_id = tokenizer.encode(token_end, add_special_tokens=False)[0]


def generate_custom_mask(tokenizer: AutoTokenizer, prompts: List[str], device='cpu', padding_side='left'):
    """
    Given a prepared tokenizer (i.e. with privacy special tokens) and a list of
    prompts containing those tokens, tokenize the prompts and generate custom
    attention masks for a privacy-compatible transformer.
    """
    input_ids = tokenizer(prompts)['input_ids']
    return generate_custom_mask_input_ids(tokenizer, input_ids, device=device, padding_side=padding_side)[0]


def generate_custom_mask_input_ids(tokenizer: AutoTokenizer, input_ids, device='cpu', padding_side="right"):
    """
    Given a prepared tokenizer (i.e. with privacy special tokens) and a batch of
    token-id sequences, strip the privacy markers and build per-example 2D
    attention masks in which tokens inside a sensitive span are hidden from all
    later positions except the last one.
    """
    new_input_ids, new_attention_masks, seq_len_list = [], [], []
    max_len = 0
    batch_size = len(input_ids)

    for input_id in input_ids:
        trigger_privacy = False
        new_input_id = []
        mask_pos_list = []  # positions (after marker removal) that fall inside a sensitive span
        idx = 0
        for token_id in input_id:
            if token_id == tokenizer.sensitive_beg_id:
                trigger_privacy = True
            elif token_id == tokenizer.sensitive_end_id:
                trigger_privacy = False
            else:
                new_input_id.append(token_id)
                if trigger_privacy:
                    mask_pos_list.append(idx)
                idx += 1

        seq_len = len(new_input_id)
        seq_len_list.append(seq_len)

        # Start from a standard causal (lower-triangular) mask.
        attention_mask = torch.tril(torch.ones((seq_len, seq_len)))
        for pos in mask_pos_list:
            # Hide the sensitive token from every later position except the
            # final one: the last token can access everything.
            attention_mask[pos + 1:-1, pos] = 0
            # The sensitive token itself may still attend to all earlier tokens.
            attention_mask[pos, :pos] = 1

        new_attention_masks.append(attention_mask)
        new_input_ids.append(new_input_id)
        max_len = max(max_len, seq_len)

    # Standard (batch, max_len) padding mask, independent of the privacy masking.
    new_full_attention_mask = torch.zeros((batch_size, max_len))
    for batch, seq_len in enumerate(seq_len_list):
        if padding_side == 'left':
            new_full_attention_mask[batch, max_len - seq_len:] = 1
        else:
            new_full_attention_mask[batch, :seq_len] = 1

    # Pad every example to max_len and embed its mask into a (max_len, max_len) grid.
    for idx, (ids, mask) in enumerate(zip(new_input_ids, new_attention_masks)):
        current_len = len(ids)
        new_attention_mask = torch.zeros((max_len, max_len), dtype=torch.long)
        if padding_side == 'left':
            ids = [tokenizer.pad_token_id] * (max_len - current_len) + ids
            new_attention_mask[max_len - current_len:, max_len - current_len:] = mask
        else:
            ids = ids + [tokenizer.pad_token_id] * (max_len - current_len)
            new_attention_mask[:current_len, :current_len] = mask
        new_input_ids[idx] = torch.tensor(ids).unsqueeze(0)
        new_attention_masks[idx] = new_attention_mask.unsqueeze(0)

    input_id = torch.cat(new_input_ids, dim=0)
    attention_mask = torch.cat(new_attention_masks, dim=0)
    return {'input_ids': input_id.to(device), 'attention_mask': attention_mask.to(device)}, new_full_attention_mask.to(device)
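

# Minimal usage sketch (an illustration, not part of the module above): it
# assumes a GPT-2 tokenizer and hypothetical "<priv>" / "</priv>" markers
# around sensitive spans. A model consuming these per-example 2D masks would
# also need model.resize_token_embeddings(len(tokenizer)) after
# prepare_tokenizer, and custom handling for the 3D attention_mask batch.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    prepare_tokenizer(tokenizer, token_beg="<priv>", token_end="</priv>")

    prompts = [
        "My card number is <priv>4111 1111 1111 1111</priv>, please confirm it.",
        "Contact <priv>Alice</priv> tomorrow.",
    ]
    batch = generate_custom_mask(tokenizer, prompts, device='cpu', padding_side='left')

    # input_ids: (batch, max_len); attention_mask: (batch, max_len, max_len),
    # a causal mask whose sensitive-token columns are zeroed for every row
    # except the last.
    print(batch['input_ids'].shape)
    print(batch['attention_mask'].shape)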