Spaces:
Sleeping
Sleeping
File size: 3,737 Bytes
6c3fbea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
from typing import List
from transformers import AutoTokenizer
import torch
def prepare_tokenizer(tokenizer: AutoTokenizer, token_beg="<sensitive>", token_end="</sensitive>"):
    """Register the privacy marker tokens on *tokenizer*, in place.

    Mutates *tokenizer* as follows and returns nothing:
      * ``pad_token`` is set to ``eos_token``;
      * *token_beg* and *token_end* are registered as additional special tokens;
      * the first id that each marker encodes to (without special tokens) is
        stored as ``sensitive_beg_id`` / ``sensitive_end_id``.
    """
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_special_tokens({"additional_special_tokens": [token_beg, token_end]})
    # Look up each marker's id right after registering it, beginning marker first.
    for attr_name, marker in (("sensitive_beg_id", token_beg), ("sensitive_end_id", token_end)):
        marker_ids = tokenizer.encode(marker, add_special_tokens=False)
        setattr(tokenizer, attr_name, marker_ids[0])
def generate_custom_mask(tokenizer: "AutoTokenizer", prompts: List[str], device='cpu', padding_side='left'):
    """Tokenize *prompts* and build privacy-aware inputs for a custom transformer.

    Args:
        tokenizer: A tokenizer prepared with ``prepare_tokenizer`` (so it
            carries ``sensitive_beg_id`` / ``sensitive_end_id``).
        prompts: Prompt strings that may contain the privacy marker tokens.
        device: Device the returned tensors are moved to.
        padding_side: ``'left'`` (default) or right padding. It is forwarded
            to ``generate_custom_mask_input_ids``.

    Returns:
        The ``{'input_ids': ..., 'attention_mask': ...}`` dict produced by
        ``generate_custom_mask_input_ids``. Its separate padding mask is
        discarded.
    """
    input_ids = tokenizer(prompts)['input_ids']
    # Forward the caller's padding_side; it used to be hard-coded to 'left',
    # which silently ignored the argument.
    batch, _padding_mask = generate_custom_mask_input_ids(
        tokenizer, input_ids, device=device, padding_side=padding_side
    )
    return batch
def _strip_sensitive_markers(tokenizer, token_ids):
    """Remove the privacy marker ids from one token sequence.

    Returns ``(kept_ids, sensitive_positions)``. ``kept_ids`` is *token_ids*
    with every ``tokenizer.sensitive_beg_id`` / ``tokenizer.sensitive_end_id``
    entry removed. ``sensitive_positions`` lists, in ascending order, the
    indices into ``kept_ids`` of tokens that follow a begin marker with no end
    marker in between.
    """
    kept_ids, sensitive_positions = [], []
    inside_span = False
    for token_id in token_ids:
        if token_id == tokenizer.sensitive_beg_id:
            inside_span = True
        elif token_id == tokenizer.sensitive_end_id:
            inside_span = False
        else:
            if inside_span:
                # Index this token will occupy once it is appended below.
                sensitive_positions.append(len(kept_ids))
            kept_ids.append(token_id)
    return kept_ids, sensitive_positions


def _build_privacy_mask(seq_len, sensitive_positions):
    """Return a float ``(seq_len, seq_len)`` causal mask that hides sensitive columns.

    *sensitive_positions* must be in ascending order. Each step re-opens the
    full prefix of the current sensitive row, which undoes zeros written for
    earlier sensitive columns.
    """
    mask = torch.tril(torch.ones((seq_len, seq_len)))
    for pos in sensitive_positions:
        # Later tokens may not attend to a sensitive token, except the last
        # token, which can access everything.
        mask[pos + 1:-1, pos] = 0
        # The sensitive token itself still sees its whole prefix, including
        # earlier sensitive tokens.
        mask[pos, :pos] = 1
    return mask


def generate_custom_mask_input_ids(tokenizer: "AutoTokenizer", input_ids, device='cpu', padding_side="right"):
    """Strip privacy markers from pre-tokenized sequences and build padded masks.

    Args:
        tokenizer: A tokenizer prepared with ``prepare_tokenizer``. Only its
            ``sensitive_beg_id``, ``sensitive_end_id`` and ``pad_token_id``
            attributes are read.
        input_ids: A sequence of token-id sequences that may contain the
            begin/end marker ids. Tokens between a begin marker and the next
            end marker are treated as sensitive. No tokenization is done here.
        device: Device the returned tensors are moved to.
        padding_side: ``'left'`` pads at the start of each row; any other
            value pads at the end.

    Returns:
        A pair ``(batch, padding_mask)``.

        ``batch`` is a dict with two entries:
          * ``'input_ids'``: shape ``(batch, max_len)``, ``torch.long``, with
            the markers removed and padding filled with ``pad_token_id``.
          * ``'attention_mask'``: shape ``(batch, max_len, max_len)``,
            ``torch.long``. Each row is causal, sensitive tokens are hidden
            from every later token except the row's last real token, and
            padding rows/columns are 0.

        ``padding_mask`` is a float ``(batch, max_len)`` tensor with 1 on real
        token positions and 0 on padding.

        An empty *input_ids* yields zero-sized tensors.
    """
    stripped_rows, row_masks = [], []
    for token_ids in input_ids:
        kept_ids, sensitive_positions = _strip_sensitive_markers(tokenizer, token_ids)
        stripped_rows.append(kept_ids)
        row_masks.append(_build_privacy_mask(len(kept_ids), sensitive_positions))

    batch_size = len(stripped_rows)
    max_len = max((len(row) for row in stripped_rows), default=0)
    pad_left = padding_side == 'left'

    # Preallocate instead of torch.cat-ing per-row tensors. This keeps an
    # empty batch well-defined. It also pins input_ids to long even when
    # every row is empty, where torch.tensor([]) would infer float32.
    batch_ids = torch.zeros((batch_size, max_len), dtype=torch.long)
    batch_mask = torch.zeros((batch_size, max_len, max_len), dtype=torch.long)
    padding_mask = torch.zeros((batch_size, max_len))

    for row, (kept_ids, mask) in enumerate(zip(stripped_rows, row_masks)):
        seq_len = len(kept_ids)
        pad = [tokenizer.pad_token_id] * (max_len - seq_len)
        start = max_len - seq_len if pad_left else 0
        end = start + seq_len
        padded_ids = pad + kept_ids if pad_left else kept_ids + pad
        batch_ids[row] = torch.tensor(padded_ids, dtype=torch.long)
        batch_mask[row, start:end, start:end] = mask
        padding_mask[row, start:end] = 1

    return {'input_ids': batch_ids.to(device), 'attention_mask': batch_mask.to(device)}, padding_mask.to(device)