import torch
from torch import nn
import torch.nn.functional as F
from typing import Union, List, Dict, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTJForCausalLM
from transformers.generation_logits_process import (
    LogitsProcessorList,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
)

from utils import (
    NEGATIVE_INF, HALF_NEGATIVE_INF,
    logits_to_entropy, mask_pad
)
from clipcap import ClipCap


class Policy(nn.Module):
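    """Wraps a ClipCap captioning model as a decoding policy: `sample` generates
    continuations token by token, while `forward_pass` re-scores given responses
    (per-token log-probabilities, EOS probabilities, and entropies)."""
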
    def __init__(self, model_name, temperature, device, clipcap_path='', fix_gpt=False,
                 use_transformer_mapper: bool = False, use_ptuning_v2: bool = False,
                 prefix_length=10, clipcap_num_layers: int = 1,
                 label_path: str = '', model_weight: str = 'None', use_label_prefix: bool = False):
        super().__init__()

        self.device = device

        self.model = ClipCap(model_name, device,
                             model_path=clipcap_path, fix_gpt=fix_gpt,
                             prefix_length=prefix_length,
                             num_layers=clipcap_num_layers,
                             label_path=label_path, model_weight=model_weight,
                             use_transformer_mapper=use_transformer_mapper,
                             use_ptuning_v2=use_ptuning_v2,
                             use_label_prefix=use_label_prefix)

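        # GPT-2 style tokenizers ship without a pad token, so <|endoftext|> is reused for padding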
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, pad_token="<|endoftext|>")
        self.model.gpt.config.pad_token_id = self.tokenizer.pad_token_id

        self.temperature = temperature

    def get_processor(self, no_repeat_ngram_size: int = 3):
        logits_processor = LogitsProcessorList()
        if no_repeat_ngram_size > 0:
            logits_processor.append(NoRepeatNGramLogitsProcessor(ngram_size=no_repeat_ngram_size))
        # Disabled: ban the pad token via NoBadWordsLogitsProcessor.
        # logits_processor.append(NoBadWordsLogitsProcessor([[self.tokenizer.pad_token_id]],
        #                                                   self.tokenizer.pad_token_id))
        return logits_processor

    def sample(self,
               input_ids: torch.Tensor = None,
               features: torch.Tensor = None,
               attention_mask: torch.Tensor = None,
               labels: Optional[torch.Tensor] = None,
               max_len: int = 20,
               sample: bool = True,
               top_k: Optional[int] = None,
               top_p: Optional[float] = None,
               temperature: Optional[float] = None,
               no_repeat_ngram_size: int = 0,
               invalidate_eos: bool = True,
               device=None) -> Dict[str, Union[torch.Tensor, List[str]]]:
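        """Autoregressively sample up to `max_len` tokens conditioned on `input_ids` and
        visual `features`; uses greedy decoding when `sample` is False. Returns prompt and
        response ids, decoded text, masks, and per-token log-probabilities."""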
        if device is None:
            device = self.device
        if temperature is None:
            temperature = self.temperature

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        model_kwargs = {'attention_mask': attention_mask}
        batch_size, input_seq_len = input_ids.shape

        logits_processor = self.get_processor(no_repeat_ngram_size=no_repeat_ngram_size)

        logits_warper = self.model.gpt._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature, num_beams=1
        )

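        # per-sequence decoding state: which sequences are still active, per-step token
        # log-probs, per-step EOS log-probs, and a mask over generated positions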
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
        output_logprob = torch.zeros([batch_size, 0], device=device)
        eos_logprobs = torch.zeros([batch_size, 0], device=device)
        output_mask = torch.ones([batch_size, 0], dtype=torch.long, device=device)

        self.model.eval()
        with torch.no_grad():
            for step in range(max_len):
                # prepare model inputs
                model_inputs = self.model.prepare_inputs_for_generation(input_ids,
                                                                        features=features,
                                                                        labels=labels,
                                                                        **model_kwargs)

                # forward pass to get next token
                outputs = self.model(
                    **model_inputs,
                    device=device
                )

                # in the first decoding step, we want to use the 'real' last position for each sentence
                if step == 0:
                    last_non_masked_idx = torch.sum(attention_mask, dim=1) - 1
                    next_token_logits = outputs.logits[range(batch_size), last_non_masked_idx, :]
                else:
                    next_token_logits = outputs.logits[:, -1, :]

                negative_inf = HALF_NEGATIVE_INF if next_token_logits.dtype == torch.half else NEGATIVE_INF
                next_token_scores = logits_processor(input_ids, next_token_logits)
                if invalidate_eos:
                    next_token_scores[:, self.tokenizer.eos_token_id] = negative_inf  # no endoftext
                log_prob = F.log_softmax(next_token_scores, dim=-1)  # authentic sampling distribution
                next_token_scores = logits_warper(input_ids, next_token_scores)
                if sample:
                    # Temperature (higher temperature => more likely to sample low probability tokens)
                    probs = F.softmax(next_token_scores, dim=-1)
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    # Greedy decoding
                    next_tokens = torch.argmax(next_token_scores, dim=-1)

                # finished sentences should have their next token be a padding token
                next_tokens = next_tokens * unfinished_sequences + self.tokenizer.pad_token_id * (1 - unfinished_sequences)

                # update output mask
                output_mask = torch.cat([output_mask, unfinished_sequences[:, None]], dim=-1)
                # update output log probability
                eos_logprob = log_prob[:, self.tokenizer.eos_token_id]
                eos_logprob = eos_logprob * unfinished_sequences + negative_inf * (1 - unfinished_sequences)
                eos_logprobs = torch.cat([eos_logprobs, eos_logprob[:, None]], dim=-1)

                token_logprob = torch.gather(log_prob, 1, next_tokens[:, None]).squeeze(1)
                token_logprob = token_logprob * unfinished_sequences + negative_inf * (1 - unfinished_sequences)
                output_logprob = torch.cat([output_logprob, token_logprob[:, None]], dim=-1)

                # update generated ids, model inputs for next step
                input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
                model_kwargs = self.model.gpt._update_model_kwargs_for_generation(
                    outputs, model_kwargs, is_encoder_decoder=self.model.gpt.config.is_encoder_decoder
                )

                # if eos_token was found in one sentence, set sentence to finished
                unfinished_sequences = unfinished_sequences.mul((next_tokens != self.tokenizer.eos_token_id).long())

                if unfinished_sequences.max() == 0:
                    break

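        # split the generated sequence back into the original prompt and the new response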
        response_ids = input_ids[:, input_seq_len:]
        response_text = [self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                         for output in response_ids]

        prompt_ids = input_ids[:, :input_seq_len]
        prompts = [self.tokenizer.decode(query, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                    for query in prompt_ids]
        eos_probs = eos_logprobs.exp()

        return {
            'query/input_ids': prompt_ids,
            'query/text': prompts,
            'query/mask': attention_mask,
            'response/input_ids': response_ids,
            'response/text': response_text,
            'response/mask': output_mask,
            'response/log_prob': output_logprob,
            'response/eos_prob': eos_probs,
        }

    def forward_pass(self,
                     query_input_ids: torch.Tensor,
                     query_mask: torch.Tensor,
                     response_input_ids: torch.Tensor,
                     response_mask: torch.Tensor,
                     features: torch.Tensor,
                     labels: Optional[torch.Tensor] = None,
                     invalidate_eos: bool = True,
                     device=None):
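        """Re-run the model on the concatenated (query, response) sequences and return
        per-token log-probabilities, EOS probabilities, entropies, and logits for the
        response positions."""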

        if device is None:
            device = self.device

        batch_size, query_seq_len = query_input_ids.shape
        input_ids = torch.cat([query_input_ids, response_input_ids], dim=-1)
        attention_mask = torch.cat([query_mask, response_mask], dim=-1)

        # forward pass to get next token
        outputs = self.model(
            input_ids,
            features,
            attention_mask,
            labels,
            device=device
        )
        # logits at the last non-padded query position predict the first response token
        query_logits = outputs.logits[:, :query_seq_len, :]
        last_non_masked_idx = torch.sum(query_mask, dim=1) - 1
        first_logits = query_logits[range(batch_size), last_non_masked_idx, :]
        # logits for the remaining response tokens (each position predicts the next token,
        # so the final position is dropped)
        response_logits = outputs.logits[:, query_seq_len:-1, :]
        logits = torch.cat([first_logits[:, None], response_logits], dim=1)

        negative_inf = HALF_NEGATIVE_INF if logits.dtype == torch.half else NEGATIVE_INF
        if invalidate_eos:
            logits[:, :, self.tokenizer.eos_token_id] = negative_inf  # no endoftext

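        # per-position response scores: log-prob of the response token, distribution
        # entropy, probability mass on EOS, and the raw logit of the response token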
        log_prob = F.log_softmax(logits, dim=-1)
        output_logprob = torch.gather(log_prob, 2, response_input_ids[:, :, None]).squeeze(2)
        output_entropy = logits_to_entropy(logits)
        eos_prob = F.softmax(logits, dim=-1)[:, :, self.tokenizer.eos_token_id]

        pos_logit = torch.gather(logits, 2, response_input_ids[:, :, None]).squeeze(2)

        return {
            'response/log_prob': mask_pad(output_logprob, response_mask),
            'response/eos_prob': mask_pad(eos_prob, response_mask),
            'response/entropy': mask_pad(output_entropy, response_mask),
            'response/pos_logit': mask_pad(pos_logit, response_mask),
            'response/logits': logits,
        }