File size: 5,248 Bytes
6c3fbea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import torch
from torch import nn

from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama import LlamaForCausalLM

class LlamaskForCausalLM(LlamaForCausalLM):
    """Llama causal LM driven by an explicit per-sequence attention mask.

    Instead of the usual 2D padding mask, `forward` takes a binary mask that
    is expanded with a head dimension and converted to additive form before
    being handed to `LlamaForCausalLM.forward`. `generate` is a minimal
    sampling loop that grows a `(batch, seq, seq)` mask by one row and one
    column per generated token.
    """

    def __init__(self, config):
        super().__init__(config)
        # Index 0 -> mask encoding, index 1 -> buffer token.
        self.special_tokens = nn.Embedding(2, config.hidden_size)
        self.post_init()

    @staticmethod
    def _extend_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
        """Return `attention_mask` grown from (B, S, S) to (B, S + 1, S + 1)."""
        batch_size, seq_len, _ = attention_mask.shape
        extended = attention_mask.new_zeros(batch_size, seq_len + 1, seq_len + 1)

        # Existing tokens keep their visibility; the new column stays 0 for
        # them, so they never attend to the freshly generated token.
        extended[:, :seq_len, :seq_len] = attention_mask

        # The new token inherits the visibility of the previous last token...
        extended[:, seq_len, :seq_len] = attention_mask[:, -1, :]

        # ...and additionally attends to itself.
        extended[:, seq_len, seq_len] = 1
        return extended

    @staticmethod
    def _to_additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Convert a binary mask (1 = attend, 0 = blocked) to additive form.

        A head dimension is inserted at axis 1. Entries equal to 1 become 0.0
        and entries equal to 0 become a large negative value; any other entry
        is left untouched.
        """
        mask = attention_mask.unsqueeze(1).to(dtype)
        blocked = mask == 0
        allowed = mask == 1
        # -1e9 is not representable in float16 and makes `masked_fill` raise,
        # so clamp the fill value to the most negative finite value of `dtype`.
        fill_value = max(-1e9, torch.finfo(dtype).min)
        mask = mask.masked_fill(blocked, fill_value)
        return mask.masked_fill(allowed, 0.0)

    @torch.no_grad()
    def generate(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            max_tokens: int=32,
            temperature: float=1.0,
    ):
        """Sample up to `max_tokens` new tokens for every sequence in the batch.

        Args:
            input_ids: Prompt token ids of shape (batch, seq).
            attention_mask: Binary mask of shape (batch, seq, seq) where 1
                lets the row token attend to the column token. When omitted,
                a lower-triangular (causal) mask is used.
            max_tokens: Maximum number of tokens to append.
            temperature: Softmax temperature. `0` selects the arg-max token
                instead of sampling.

        Returns:
            `input_ids` with the generated tokens appended. Once a sequence
            has produced an EOS token, its remaining positions are filled with
            `config.pad_token_id` (or the EOS id when no pad id is configured).

        Raises:
            ValueError: If `input_ids` is missing or `temperature` is negative.
        """
        if input_ids is None:
            raise ValueError("`input_ids` is required for generation.")
        if temperature < 0:
            raise ValueError(f"`temperature` must be non-negative, got {temperature}.")

        device = input_ids.device
        batch_size, prompt_len = input_ids.shape

        if attention_mask is None:
            causal = torch.tril(
                torch.ones(prompt_len, prompt_len, dtype=torch.long, device=device)
            )
            attention_mask = causal.unsqueeze(0).repeat(batch_size, 1, 1)

        # `eos_token_id` may be an int, a list of ints, or None.
        eos_ids = None
        pad_id = None
        eos_token_id = self.config.eos_token_id
        if eos_token_id is not None:
            candidate = torch.as_tensor(eos_token_id, device=device).reshape(-1)
            if candidate.numel() > 0:
                eos_ids = candidate
                pad_id = getattr(self.config, "pad_token_id", None)
                if pad_id is None:
                    pad_id = int(eos_ids[0])

        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_tokens):
            # The whole sequence is re-encoded each step, so a KV cache would
            # be built and discarded; disable it.
            outputs = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=False,
                return_dict=True,
            )
            last_logits = outputs["logits"][:, -1, :].float()

            if temperature == 0:
                next_tokens = last_logits.argmax(dim=-1)
            else:
                probs = torch.nn.functional.softmax(last_logits / temperature, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            if eos_ids is not None:
                # Sequences that already ended only receive padding.
                next_tokens = next_tokens.masked_fill(finished, pad_id)
                finished = finished | (next_tokens[:, None] == eos_ids).any(dim=1)

            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            attention_mask = self._extend_attention_mask(attention_mask)

            if eos_ids is not None and bool(finished.all()):
                break
        return input_ids

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        num_buffer_token: Optional[int] = 0,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the Llama forward pass with a custom binary attention mask.

        Args:
            input_ids: Token ids. When given they are embedded here and take
                precedence over `inputs_embeds`.
            num_buffer_token: Currently unused; kept for interface
                compatibility.
            attention_mask: Binary mask (1 = attend, 0 = blocked). It is
                unsqueezed at axis 1 and converted to additive form.
                Presumably (batch, seq, seq), as produced by `generate`;
                verify against other callers. When None, the parent's default
                masking is used.
            inputs_embeds: Pre-computed embeddings, used only when
                `input_ids` is None.

        The remaining arguments are forwarded unchanged to
        `LlamaForCausalLM.forward`.

        Raises:
            ValueError: If neither `input_ids` nor `inputs_embeds` is given.
        """
        # NOTE: the privacy-tag and buffer-token embedding paths are currently
        # disabled, so `self.special_tokens` and `num_buffer_token` are not
        # used here.
        if input_ids is not None:
            inputs_embeds = self.model.embed_tokens(input_ids)
        elif inputs_embeds is None:
            raise ValueError("You must specify either `input_ids` or `inputs_embeds`.")

        if attention_mask is not None:
            attention_mask = self._to_additive_mask(attention_mask, inputs_embeds.dtype)

        return super().forward(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )