Балаганский Никита Николаевич committed
Commit 030a0f8
1 Parent(s): 8c1b530

add app.py

Files changed (4)
  1. app.py +73 -0
  2. generator.py +221 -0
  3. requirements.txt +3 -0
  4. sampling.py +143 -0
app.py ADDED
@@ -0,0 +1,73 @@
+ import os
+
+ import streamlit as st
+
+ import torch
+
+ import transformers
+ import tokenizers
+ from torch import autocast
+
+ from sampling import CAIFSampler, TopKWithTemperatureSampler
+ from generator import Generator
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+
+ def main():
+     st.subheader(
+         'This demo lets you experiment with models that estimate how well '
+         'a candidate response fits the context of a dialogue.'
+     )
+     cls_model_name = st.selectbox(
+         'Choose a classification model',
+         ('tinkoff-ai/response-quality-classifier-tiny',
+          'tinkoff-ai/response-quality-classifier-base',
+          'tinkoff-ai/response-quality-classifier-large')
+     )
+     lm_model_name = st.selectbox(
+         'Choose a language model',
+         ('sberbank-ai/rugpt3small_based_on_gpt2',)
+     )
+     # Default value keeps the original Russian example prompt (the LM is Russian).
+     prompt = st.text_input("Prompt", value="Как дела в качалке?")
+     auth_token = os.environ.get('TOKEN') or True
+     with st.spinner('Running inference...'):
+         text = inference(
+             lm_model_name=lm_model_name,
+             cls_model_name=cls_model_name,
+             prompt=prompt,
+         )
+     st.text_area("Generated response", value=text)
+
+
+ @st.cache(hash_funcs={tokenizers.Tokenizer: lambda tokenizer: hash(tokenizer.to_str())}, allow_output_mutation=True)
+ def load_generator(lm_model_name: str) -> Generator:
+     with st.spinner('Loading language model...'):
+         generator = Generator(lm_model_name=lm_model_name, device=device)
+         return generator
+
+
+ @st.cache(hash_funcs={tokenizers.Tokenizer: lambda tokenizer: hash(tokenizer.to_str())}, allow_output_mutation=True)
+ def load_sampler(cls_model_name, lm_tokenizer):
+     with st.spinner('Loading classifier model...'):
+         sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
+         return sampler
+
+
+ @st.cache
+ def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool = True) -> str:
+     generator = load_generator(lm_model_name=lm_model_name)
+     lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
+     caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
+     generator.set_caif_sampler(caif_sampler)
+     ordinary_sampler = TopKWithTemperatureSampler()
+     generator.set_ordinary_sampler(ordinary_sampler)
+     # autocast takes a device type and an `enabled` flag, not a bare bool
+     with autocast(device_type=device.split(":")[0], enabled=fp16):
+         sequences, tokens = generator.sample_sequences(
+             num_samples=1,
+             input_prompt=prompt,
+             max_length=20,
+             caif_period=1,
+             caif_tokens_num=100,
+             entropy=3.2,
+             # required by the samplers (see sampling.py); values are illustrative defaults
+             top_k=20,
+             temperature=1.0,
+             top_k_classifier=100,
+             classifier_weight=0.5,
+         )
+     return sequences[0]
+
+
+ if __name__ == "__main__":
+     main()
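
For reference, the same pipeline can be driven without Streamlit. A minimal sketch, assuming the four files above sit on the import path; the sampler parameters forwarded through **sampler_kwargs (top_k, temperature, top_k_classifier, classifier_weight) are illustrative values, not tuned defaults:

import torch
import transformers

from generator import Generator
from sampling import CAIFSampler, TopKWithTemperatureSampler

device = "cuda:0" if torch.cuda.is_available() else "cpu"
lm_name = "sberbank-ai/rugpt3small_based_on_gpt2"
cls_name = "tinkoff-ai/response-quality-classifier-tiny"

generator = Generator(lm_model_name=lm_name, device=device)
lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_name)
generator.set_caif_sampler(
    CAIFSampler(classifier_name=cls_name, lm_tokenizer=lm_tokenizer, device=device)
)
generator.set_ordinary_sampler(TopKWithTemperatureSampler())

sequences, token_ids = generator.sample_sequences(
    num_samples=1,
    input_prompt="Как дела в качалке?",
    max_length=20,
    caif_period=1,
    caif_tokens_num=100,
    entropy=3.2,
    # forwarded to the samplers via **sampler_kwargs; values are illustrative
    top_k=20,
    temperature=1.0,
    top_k_classifier=100,
    classifier_weight=0.5,
)
print(sequences[0])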
generator.py ADDED
@@ -0,0 +1,221 @@
+ from typing import Optional, Union
+
+ import torch
+ import transformers
+
+
+ class Generator:
+     def __init__(self, lm_model_name, device, entropy=None):
+         self.device = device
+
+         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+             lm_model_name
+         )
+         self.lm = transformers.AutoModelForCausalLM.from_pretrained(
+             lm_model_name
+         ).to(device)
+         self.lm.eval()
+
+         self.lm.config.pad_token_id = self.lm.config.eos_token_id
+         self.tokenizer.add_special_tokens(
+             {"pad_token": self.tokenizer.decode(self.lm.config.eos_token_id)}
+         )
+         self.caif_sampler = None
+         self.ordinary_sampler = None
+         self.entropy_based_stats = {
+             "skips": 0,
+             "avg_entropy": 0,
+             "count": 0,
+         }
+         self.entropy = entropy
+
+     def set_caif_sampler(self, sampler):
+         self.caif_sampler = sampler
+
+     def set_ordinary_sampler(self, sampler):
+         self.ordinary_sampler = sampler
+
+     def sample_sequences(
+         self,
+         num_samples: int,
+         input_prompt: Optional[str],
+         max_length: int,
+         caif_period: int,
+         caif_tokens_num: Union[int, None] = None,
+         entropy: float = None,
+         **sampler_kwargs
+     ):
+         self.entropy = entropy
+
+         input_ids, past, ended_sequences = self.get_input_ids(
+             input_prompt,
+             num_samples,
+         )
+
+         for i in range(max_length):
+             is_caif_step = (
+                 i % caif_period == 0 and self.caif_sampler is not None
+             )
+             input_ids, past, ended_sequences = self.generation_step(
+                 input_ids,
+                 past,
+                 ended_sequences,
+                 is_caif_step,
+                 caif_tokens_num=caif_tokens_num,
+                 **sampler_kwargs
+             )
+             if ended_sequences.all():
+                 break
+
+         return (
+             [
+                 self.tokenizer.decode(sequence, skip_special_tokens=True)[
+                     len(input_prompt):
+                 ]
+                 for sequence in input_ids
+             ],
+             input_ids,
+         )
+
+     def generation_step(
+         self,
+         input_ids,
+         past,
+         ended_sequences,
+         is_caif_step: bool,
+         caif_tokens_num=None,
+         **sampler_kwargs
+     ):
+         prepared_inputs = self.lm.prepare_inputs_for_generation(
+             input_ids, past, use_cache=True
+         )
+         outputs = self.lm(
+             **prepared_inputs,
+             output_attentions=False,
+             output_hidden_states=False,
+             return_dict=True
+         )
+
+         past = outputs.past_key_values
+         if self.entropy is not None:
+             # Entropy gate: per sequence, use the CAIF sampler only where the
+             # next-token distribution is more uncertain than the threshold.
+             normalized = torch.nn.functional.log_softmax(
+                 outputs.logits, dim=-1
+             )
+             p = torch.exp(normalized)
+             output_probs = p
+             output_information = -normalized
+             output_entropy = (output_probs * output_information).sum(-1)[:, -1]
+             batch_size = output_entropy.shape[0]
+             caif_mask = torch.ge(output_entropy, self.entropy)
+             ordinary_mask = ~caif_mask
+             self.entropy_based_stats["skips"] += caif_mask.sum() / batch_size
+             self.entropy_based_stats["count"] += 1
+             self.entropy_based_stats["avg_entropy"] += (
+                 output_entropy.sum() / batch_size
+             )
+             flatten_entropy = output_entropy.view(-1).cpu().tolist()
+             if "entropy" not in self.entropy_based_stats.keys():
+                 self.entropy_based_stats["entropy"] = flatten_entropy
+             else:
+                 self.entropy_based_stats["entropy"] += flatten_entropy
+
+             if caif_mask.sum() == 0:
+                 next_tokens_sampler = self.ordinary_sampler
+                 next_tokens = next_tokens_sampler(
+                     input_ids,
+                     outputs.logits,
+                     caif_tokens_num=caif_tokens_num,
+                     **sampler_kwargs
+                 )
+                 next_tokens = (
+                     next_tokens * (1 - ended_sequences.long())
+                     + self.lm.config.eos_token_id * ended_sequences.long()
+                 ).long()
+
+             elif caif_mask.sum() == batch_size:
+                 next_tokens_sampler = self.caif_sampler
+                 next_tokens = next_tokens_sampler(
+                     input_ids,
+                     outputs.logits,
+                     caif_tokens_num=caif_tokens_num,
+                     **sampler_kwargs
+                 )
+                 next_tokens = (
+                     next_tokens * (1 - ended_sequences.long())
+                     + self.lm.config.eos_token_id * ended_sequences.long()
+                 ).long()
+
+             else:
+                 # Mixed batch: route each sequence to its sampler, then merge.
+                 next_tokens_caif = self.caif_sampler(
+                     input_ids[caif_mask],
+                     outputs.logits[caif_mask],
+                     caif_tokens_num=caif_tokens_num,
+                     **sampler_kwargs
+                 )
+                 next_tokens_ordinary = self.ordinary_sampler(
+                     input_ids[ordinary_mask],
+                     outputs.logits[ordinary_mask],
+                     caif_tokens_num=caif_tokens_num,
+                     **sampler_kwargs
+                 )
+                 next_tokens_caif = (
+                     next_tokens_caif * (1 - ended_sequences[caif_mask].long())
+                     + self.lm.config.eos_token_id
+                     * ended_sequences[caif_mask].long()
+                 ).long()
+                 next_tokens_ordinary = (
+                     next_tokens_ordinary
+                     * (1 - ended_sequences[ordinary_mask].long())
+                     + self.lm.config.eos_token_id
+                     * ended_sequences[ordinary_mask].long()
+                 ).long()
+
+                 next_tokens = torch.ones(batch_size).long().to(self.device)
+                 next_tokens[caif_mask] = next_tokens_caif
+                 next_tokens[ordinary_mask] = next_tokens_ordinary
+         else:
+             if is_caif_step:
+                 next_tokens_sampler = self.caif_sampler
+             else:
+                 next_tokens_sampler = self.ordinary_sampler
+
+             next_tokens = next_tokens_sampler(
+                 input_ids,
+                 outputs.logits,
+                 caif_tokens_num=caif_tokens_num,
+                 **sampler_kwargs
+             )
+
+             next_tokens = (
+                 next_tokens * (1 - ended_sequences.long())
+                 + self.lm.config.eos_token_id * ended_sequences.long()
+             ).long()
+
+         input_ids = torch.cat(
+             [input_ids, next_tokens[:, None].to(self.device)], dim=-1
+         )
+
+         ended_sequences += next_tokens == self.lm.config.eos_token_id
+
+         return input_ids, past, ended_sequences
+
+     def get_input_ids(self, input_prompt, num_samples):
+         input_ids = torch.tensor([[self.lm.config.bos_token_id]])
+         if input_prompt is not None:
+             input_prompt = self.tokenizer(
+                 input_prompt, return_tensors="pt"
+             ).input_ids
+             input_ids = torch.cat([input_ids, input_prompt], 1)
+         input_ids = input_ids.repeat(num_samples, 1).to(self.device)
+         past = None
+         ended_sequences = torch.zeros(
+             input_ids.shape[0], device=self.device
+         ).bool()
+
+         return input_ids, past, ended_sequences
+
+     @staticmethod
+     def sample(unscaled_probs, values):
+         samples = torch.multinomial(unscaled_probs, 1)
+         return torch.take_along_dim(values, samples, dim=1)
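
The entropy gate in generation_step above switches samplers per sequence: CAIF is applied only where the entropy of the next-token distribution reaches the threshold passed as `entropy`. A self-contained sketch of that arithmetic on toy logits:

import torch

# Toy logits: batch of 2 sequences, 5 positions, vocabulary of 4 tokens.
logits = torch.randn(2, 5, 4)

# Same computation as in Generator.generation_step:
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
probs = torch.exp(log_probs)
# Entropy of the last position's next-token distribution, one value per sequence.
entropy = (probs * -log_probs).sum(-1)[:, -1]  # shape: (2,)

threshold = 3.2  # the `entropy` argument of sample_sequences
caif_mask = torch.ge(entropy, threshold)  # True -> CAIF sampler on this step
print(entropy, caif_mask)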
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ streamlit
+ transformers
+ torch
sampling.py ADDED
@@ -0,0 +1,143 @@
+ import torch
+ from torch.nn import functional as F
+
+ import transformers
+
+
+ def sample_from_values(unscaled_probs, values):
+     # Draw one index per row from the unnormalized weights, then map it
+     # back to the corresponding value (token id).
+     samples = torch.multinomial(unscaled_probs, 1)
+     return torch.take_along_dim(values, samples, dim=1)
+
+
+ class TopKWithTemperatureSampler:
+     def __call__(self, input_ids, output_logits, top_k, temperature, **kwargs):
+         next_token_logits = output_logits[:, -1]
+         next_token_log_probs = F.log_softmax(
+             next_token_logits, dim=-1
+         )
+
+         topk_log_probs = next_token_log_probs.topk(top_k, -1)
+         next_tokens = sample_from_values(
+             torch.exp(topk_log_probs[0] / temperature), topk_log_probs[1]
+         ).squeeze(1)
+
+         return next_tokens
+
+
+ class CAIFSampler:
+     def __init__(self, classifier_name, lm_tokenizer, device, invert_cls_probs: bool = False):
+         self.device = device
+         self.classifier_tokenizer = transformers.AutoTokenizer.from_pretrained(
+             classifier_name
+         )
+         self.classifier_model = (
+             transformers.AutoModelForSequenceClassification.from_pretrained(
+                 classifier_name
+             ).to(device)
+         )
+         self.classifier_model.eval()
+         self.lm_tokenizer = lm_tokenizer
+         self.invert_cls_probs = invert_cls_probs
+
+     def __call__(
+         self,
+         input_ids,
+         output_logits,
+         top_k,
+         temperature,
+         top_k_classifier,
+         classifier_weight,
+         caif_tokens_num=None,
+         **kwargs
+     ):
+         next_token_logits = output_logits[:, -1]
+
+         next_token_log_probs = F.log_softmax(
+             next_token_logits, dim=-1
+         )
+
+         next_token_unnormalized_probs, topk_indices = self.get_unnormalized_probs(
+             input_ids,
+             next_token_log_probs,
+             temperature,
+             top_k_classifier,
+             classifier_weight,
+             caif_tokens_num=caif_tokens_num
+         )
+         topk_probs = next_token_unnormalized_probs.topk(top_k, -1)
+         next_tokens = sample_from_values(
+             topk_probs[0],
+             torch.take_along_dim(topk_indices, topk_probs[1], dim=1),
+         ).squeeze(1)
+
+         return next_tokens
+
+     def get_unnormalized_probs(
+         self,
+         input_ids,
+         next_token_log_probs,
+         temperature,
+         top_k_classifier,
+         classifier_weight,
+         caif_tokens_num=None
+     ):
+         if classifier_weight == 0.0:
+             raise ValueError(
+                 "classifier weight equal to 0 is not supported for CAIF Sampling"
+             )
+
+         top_next_token_log_probs = next_token_log_probs.topk(top_k_classifier, -1)
+         # Append each candidate token to its prompt and decode, so the
+         # classifier can score every (prompt + candidate) continuation.
+         classifier_input = torch.cat(
+             [
+                 input_ids.unsqueeze(1).repeat(1, top_k_classifier, 1).flatten(0, 1),
+                 top_next_token_log_probs[1].view(-1).unsqueeze(-1),
+             ],
+             -1,
+         )
+         classifier_input = [
+             self.lm_tokenizer.decode(sequence, skip_special_tokens=True)
+             for sequence in classifier_input
+         ]
+
+         if self.invert_cls_probs:
+             classifier_log_probs = torch.log(
+                 1 - self.get_classifier_probs(
+                     classifier_input, caif_tokens_num=caif_tokens_num
+                 ).view(-1, top_k_classifier)
+             )
+         else:
+             classifier_log_probs = self.get_classifier_log_probs(
+                 classifier_input, caif_tokens_num=caif_tokens_num
+             ).view(-1, top_k_classifier)
+
+         # CAIF reweighting: exp((log p_lm + w * log p_cls) / T)
+         next_token_probs = torch.exp(
+             (top_next_token_log_probs[0] + classifier_weight * classifier_log_probs)
+             / temperature
+         )
+         return next_token_probs, top_next_token_log_probs[1]
+
+     def get_classifier_log_probs(self, input, caif_tokens_num=None):
+         input_ids = self.classifier_tokenizer(
+             input, padding=True, return_tensors="pt"
+         ).to(self.device)
+         if caif_tokens_num is not None:
+             # Keep only the last caif_tokens_num tokens of each sequence.
+             input_ids["input_ids"] = input_ids["input_ids"][:, -caif_tokens_num:]
+             if "attention_mask" in input_ids.keys():
+                 input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
+             if "token_type_ids" in input_ids.keys():
+                 input_ids["token_type_ids"] = input_ids["token_type_ids"][:, -caif_tokens_num:]
+         logits = self.classifier_model(**input_ids).logits[:, 0].squeeze(-1)
+         return torch.log(torch.sigmoid(logits))
+
+     def get_classifier_probs(self, input, caif_tokens_num=None):
+         input_ids = self.classifier_tokenizer(
+             input, padding=True, return_tensors="pt"
+         ).to(self.device)
+         if caif_tokens_num is not None:
+             input_ids["input_ids"] = input_ids["input_ids"][:, -caif_tokens_num:]
+             if "attention_mask" in input_ids.keys():
+                 input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
+         logits = self.classifier_model(**input_ids).logits[:, 0].squeeze(-1)
+         return torch.sigmoid(logits)
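
get_unnormalized_probs implements the CAIF reweighting rule: each of the top_k_classifier LM candidates gets weight exp((log p_lm + w * log p_cls) / T), and sample_from_values then draws from those unnormalized weights. A toy numeric sketch of that arithmetic (all values illustrative):

import torch

# LM log-probs for three candidate next tokens of a single sequence.
lm_log_probs = torch.log(torch.tensor([[0.5, 0.3, 0.2]]))
# Classifier log-probs for the same candidates, e.g. log(sigmoid(logit)).
cls_log_probs = torch.log(torch.tensor([[0.9, 0.1, 0.5]]))

classifier_weight, temperature = 0.5, 1.0
# CAIF reweighting: exp((log p_lm + w * log p_cls) / T), left unnormalized.
weights = torch.exp((lm_log_probs + classifier_weight * cls_log_probs) / temperature)

# sample_from_values: multinomial over the weights, mapped back to token ids.
token_ids = torch.tensor([[17, 42, 7]])
samples = torch.multinomial(weights, 1)
print(torch.take_along_dim(token_ids, samples, dim=1))  # e.g. tensor([[17]])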