from typing import List

from transformers import AutoTokenizer
import torch

def prepare_tokenizer(tokenizer: AutoTokenizer, token_beg="<sensitive>", token_end="</sensitive>"):
    """
    Register the privacy special tokens on the tokenizer and cache their ids
    for later mask construction.
    """
    # Reuse EOS for padding (GPT-style tokenizers ship without a pad token).
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_special_tokens({"additional_special_tokens": [token_beg, token_end]})
    # Each marker is registered as a single special token, so encoding it
    # yields exactly one id.
    tokenizer.sensitive_beg_id = tokenizer.encode(token_beg, add_special_tokens=False)[0]
    tokenizer.sensitive_end_id = tokenizer.encode(token_end, add_special_tokens=False)[0]
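
# Usage sketch (model and checkpoint names are illustrative). Because
# prepare_tokenizer enlarges the vocabulary, the model's embedding matrix has
# to be resized to cover the two new marker tokens:
#
#   from transformers import AutoModelForCausalLM
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   prepare_tokenizer(tokenizer)
#   model.resize_token_embeddings(len(tokenizer))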

def generate_custom_mask(tokenizer: AutoTokenizer, prompts: List[str], device='cpu', padding_side='left'):
    """
    Given a prepared tokenizer (see prepare_tokenizer) and a list of prompts
    containing the privacy special tokens, tokenize the prompts and build the
    custom masks for a privacy-compatible transformer. Returns only the padded
    batch dict; the plain 2D padding mask is discarded.
    """
    input_ids = tokenizer(prompts)['input_ids']
    return generate_custom_mask_input_ids(tokenizer, input_ids, device=device, padding_side=padding_side)[0]
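
# Example (sketch): the marked span is hidden from all later tokens except the
# final one. The model is assumed to be a variant that accepts a full
# (batch, seq, seq) attention mask; stock Hugging Face decoders expect a 2D
# padding mask instead.
#
#   batch = generate_custom_mask(tokenizer, ["My PIN is <sensitive>1234</sensitive>. Summarize my account."])
#   logits = model(**batch).logits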

def generate_custom_mask_input_ids(tokenizer: AutoTokenizer, input_ids, device='cpu', padding_side="right"):
    """
    Given a prepared tokenizer (i.e. with privacy special tokens) and a batch
    of already-tokenized sequences containing those tokens, strip the markers
    and build per-sequence custom masks for a privacy-compatible transformer.
    Returns a dict with padded 'input_ids' and a (batch, max_len, max_len)
    'attention_mask', plus the plain 2D padding mask.
    """
    new_input_ids, new_attention_masks, seq_len_list = [], [], []
    max_len = 0
    batch_size = len(input_ids)
    for input_id in input_ids:
        # Strip the marker tokens and record the (post-removal) positions of
        # the tokens they enclose.
        trigger_privacy = False
        new_input_id = []
        mask_pos_list = []
        idx = 0
        for token_id in input_id:
            if token_id == tokenizer.sensitive_beg_id:
                trigger_privacy = True
            elif token_id == tokenizer.sensitive_end_id:
                trigger_privacy = False
            else:
                new_input_id.append(token_id)
                if trigger_privacy:
                    mask_pos_list.append(idx)
                idx += 1
        seq_len = len(new_input_id)
        seq_len_list.append(seq_len)

        # Start from a standard causal (lower-triangular) mask, then hide each
        # sensitive position from the tokens that follow it.
        attention_mask = torch.tril(torch.ones((seq_len, seq_len)))

        for idx in mask_pos_list:
            # Block rows idx+1 .. seq_len-2 from attending to the sensitive
            # column; the last token can access everything.
            attention_mask[idx+1:-1, idx] = 0
            # The sensitive token itself keeps full causal access to its
            # prefix, including any earlier sensitive tokens.
            attention_mask[idx, :idx] = 1
        new_attention_masks.append(attention_mask)
        new_input_ids.append(new_input_id)

        max_len = max(max_len, seq_len)
    
    # Plain 2D padding mask: 1 on real tokens, 0 on padding.
    new_full_attention_mask = torch.zeros((batch_size, max_len))
    for batch, seq_len in enumerate(seq_len_list):
        if padding_side == 'left':
            new_full_attention_mask[batch, max_len-seq_len:] = 1
        else:
            new_full_attention_mask[batch, :seq_len] = 1

    # Pad every sequence to max_len and embed its 2D mask into a
    # max_len x max_len canvas, honoring the requested padding side.
    # (Loop variables are renamed so they no longer shadow the function's
    # input_ids argument.)
    for idx, (seq_ids, seq_mask) in enumerate(zip(new_input_ids, new_attention_masks)):
        current_len = len(seq_ids)
        pad = [tokenizer.pad_token_id] * (max_len - current_len)
        new_attention_mask = torch.zeros((max_len, max_len), dtype=torch.long)
        if padding_side == 'left':
            seq_ids = pad + seq_ids
            new_attention_mask[max_len-current_len:, max_len-current_len:] = seq_mask
        else:
            seq_ids = seq_ids + pad
            new_attention_mask[:current_len, :current_len] = seq_mask
        new_input_ids[idx] = torch.tensor(seq_ids).unsqueeze(0)
        new_attention_masks[idx] = new_attention_mask.unsqueeze(0)
    batch_input_ids = torch.cat(new_input_ids, dim=0)
    batch_attention_mask = torch.cat(new_attention_masks, dim=0)
    return {'input_ids': batch_input_ids.to(device), 'attention_mask': batch_attention_mask.to(device)}, new_full_attention_mask.to(device)
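
if __name__ == "__main__":
    # Minimal self-check (sketch; no tokenizer download). A stub object with
    # hypothetical marker ids 100/101 and pad id 0 stands in for a prepared
    # tokenizer, since the function only reads those three attributes.
    from types import SimpleNamespace

    stub = SimpleNamespace(sensitive_beg_id=100, sensitive_end_id=101, pad_token_id=0)
    # Tokens 12 and 13 are wrapped in the sensitive markers.
    batch, padding_mask = generate_custom_mask_input_ids(
        stub, [[11, 100, 12, 13, 101, 14, 15]], padding_side="right"
    )
    # After marker removal the sequence is [11, 12, 13, 14, 15]: row 3
    # (token 14) cannot attend to columns 1 and 2, while the last row can.
    print(batch["attention_mask"][0])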