from typing import List

import torch
from transformers import AutoTokenizer


def prepare_tokenizer(tokenizer: AutoTokenizer, token_beg="<sensitive>", token_end="</sensitive>"):
    """
    Add the privacy special tokens to the tokenizer, set a pad token, and record the ids
    of the sensitive-span boundary tokens for later mask construction.
    """
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_special_tokens({"additional_special_tokens": [token_beg, token_end]})
    tokenizer.sensitive_beg_id = tokenizer.encode(token_beg, add_special_tokens=False)[0]
    tokenizer.sensitive_end_id = tokenizer.encode(token_end, add_special_tokens=False)[0]
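
# Note: prepare_tokenizer grows the vocabulary via add_special_tokens, so a model paired with
# this tokenizer will typically need its embedding matrix resized before use, e.g. with the
# standard transformers call (where `model` is the PreTrainedModel consuming these input ids):
#
#     model.resize_token_embeddings(len(tokenizer))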


def generate_custom_mask(tokenizer: AutoTokenizer, prompts: List[str], device='cpu', padding_side='left'):
    """
    Given a prepared tokenizer (i.e. with privacy special tokens) and a list of prompts containing
    the privacy special tokens, tokenize the prompts and generate custom masks for a
    privacy-compatible transformer.
    """
    input_ids = tokenizer(prompts)['input_ids']
    return generate_custom_mask_input_ids(tokenizer, input_ids, device=device, padding_side=padding_side)[0]


def generate_custom_mask_input_ids(tokenizer: AutoTokenizer, input_ids, device='cpu', padding_side="right"):
    """
    Given a prepared tokenizer (i.e. with privacy special tokens) and pre-tokenized input ids that
    contain the privacy special tokens, strip the marker tokens and build per-sequence 2D attention
    masks for a privacy-compatible transformer.

    Returns a dict with padded 'input_ids' of shape (batch, max_len) and 'attention_mask' of shape
    (batch, max_len, max_len), plus a standard (batch, max_len) padding mask.
    """
    new_input_ids, new_attention_masks, seq_len_list = [], [], []
    max_len = 0
    batch_size = len(input_ids)
    for input_id in input_ids:
        trigger_privacy = False
        new_input_id = []
        mask_pos_list = []
        idx = 0
        for token_id in input_id:
            if token_id == tokenizer.sensitive_beg_id:
                trigger_privacy = True
            elif token_id == tokenizer.sensitive_end_id:
                trigger_privacy = False
            else:
                new_input_id.append(token_id)
                # Remember positions inside a sensitive span; the marker tokens
                # themselves are dropped from the sequence.
                if trigger_privacy:
                    mask_pos_list.append(idx)
                idx += 1
        seq_len = len(new_input_id)
        seq_len_list.append(seq_len)
        # Start from a standard causal (lower-triangular) mask.
        attention_mask = torch.tril(torch.ones((seq_len, seq_len)))
        for idx in mask_pos_list:
            # Hide the sensitive token from all later positions,
            # except the last token, which can access everything.
            attention_mask[idx + 1:-1, idx] = 0
            # The sensitive token itself can still attend to everything before it.
            attention_mask[idx, :idx] = 1
        new_attention_masks.append(attention_mask)
        new_input_ids.append(new_input_id)
        max_len = max(max_len, seq_len)
    # Standard (batch, max_len) padding mask: 1 wherever a real token is present.
    new_full_attention_mask = torch.zeros((batch_size, max_len))
    for batch, seq_len in enumerate(seq_len_list):
        if padding_side == 'left':
            new_full_attention_mask[batch, max_len - seq_len:] = 1
        else:
            new_full_attention_mask[batch, :seq_len] = 1
    # Pad each sequence and embed its 2D mask into a (max_len, max_len) mask.
    for idx, (seq_ids, seq_mask) in enumerate(zip(new_input_ids, new_attention_masks)):
        current_len = len(seq_ids)
        new_attention_mask = torch.zeros((max_len, max_len), dtype=torch.long)
        if padding_side == 'left':
            seq_ids = [tokenizer.pad_token_id] * (max_len - current_len) + seq_ids
            new_attention_mask[max_len - current_len:, max_len - current_len:] = seq_mask
        else:
            seq_ids = seq_ids + [tokenizer.pad_token_id] * (max_len - current_len)
            new_attention_mask[:current_len, :current_len] = seq_mask
        new_input_ids[idx] = torch.tensor(seq_ids).unsqueeze(0)
        new_attention_masks[idx] = new_attention_mask.unsqueeze(0)
    input_id = torch.cat(new_input_ids, dim=0)
    attention_mask = torch.cat(new_attention_masks, dim=0)
    return {'input_ids': input_id.to(device), 'attention_mask': attention_mask.to(device)}, new_full_attention_mask.to(device)
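

# Minimal usage sketch. "gpt2" is only a placeholder checkpoint; any causal-LM tokenizer with an
# eos_token should behave the same way. The model itself is omitted here, since the returned
# 'attention_mask' is a per-sequence 2D mask and is only meaningful to a model that accepts such
# custom masks (a privacy-compatible transformer in the sense above).
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    prepare_tokenizer(tokenizer)

    prompts = [
        "My name is <sensitive>John Doe</sensitive> and I live in <sensitive>Berlin</sensitive>.",
        "The meeting is at noon.",
    ]
    batch = generate_custom_mask(tokenizer, prompts)

    # input_ids: (batch, max_len); attention_mask: (batch, max_len, max_len).
    print(batch['input_ids'].shape)
    print(batch['attention_mask'].shape)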