Балаганский Никита Николаевич
committed on
Commit 030a0f8
Parent(s): 8c1b530
add app.py
- app.py +73 -0
- generator.py +221 -0
- requirements.txt +3 -0
- sampling.py +143 -0
app.py
ADDED
@@ -0,0 +1,73 @@
import os
from typing import Tuple

import streamlit as st

import torch

import transformers
import tokenizers
from torch import autocast

from sampling import CAIFSampler, TopKWithTemperatureSampler
from generator import Generator

device = "cuda:0" if torch.cuda.is_available() else "cpu"


def main():
    st.subheader(
        'Эта демонстрация позволяет поэкспериментировать с моделями, которые оценивают, '
        'насколько предлагаемый ответ подходит к контексту диалога.'
    )
    cls_model_name = st.selectbox(
        'Выберите модель классификации',
        ('tinkoff-ai/response-quality-classifier-tiny',
         'tinkoff-ai/response-quality-classifier-base',
         'tinkoff-ai/response-quality-classifier-large')
    )
    lm_model_name = st.selectbox(
        'Выберите языковую модель',
        ('sberbank-ai/rugpt3small_based_on_gpt2',)
    )
    prompt = st.text_input("Как дела в качалке?")
    auth_token = os.environ.get('TOKEN') or True
    with st.spinner('Running inference...'):
        # Pass the user prompt through to the generation pipeline.
        text = inference(lm_model_name=lm_model_name, cls_model_name=cls_model_name, prompt=prompt)
    st.text_area(text)


@st.cache(hash_funcs={tokenizers.Tokenizer: lambda tokenizer: hash(tokenizer.to_str())}, allow_output_mutation=True)
def load_generator(lm_model_name: str) -> Generator:
    with st.spinner('Loading language model...'):
        generator = Generator(lm_model_name=lm_model_name, device=device)
        return generator


@st.cache(hash_funcs={tokenizers.Tokenizer: lambda tokenizer: hash(tokenizer.to_str())}, allow_output_mutation=True)
def load_sampler(cls_model_name, lm_tokenizer):
    with st.spinner('Loading classifier model...'):
        # CAIFSampler also needs the target device for the classifier model.
        sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
        return sampler


@st.cache
def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool = True) -> str:
    generator = load_generator(lm_model_name=lm_model_name)
    lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
    caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
    generator.set_caif_sampler(caif_sampler)
    ordinary_sampler = TopKWithTemperatureSampler()
    generator.set_ordinary_sampler(ordinary_sampler)
    # torch.autocast expects a device type string; fp16 only toggles mixed precision.
    with autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", enabled=fp16):
        sequences, tokens = generator.sample_sequences(
            num_samples=1,
            input_prompt=prompt,
            max_length=20,
            caif_period=1,
            caif_tokens_num=100,
            entropy=3.2
        )
    return sequences[0]


if __name__ == "__main__":
    main()
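For reference, the flow that inference implements can also be run without the Streamlit layer. The sketch below is an illustration only, not part of this commit: the model names mirror the select boxes above, and the sampler keyword arguments (top_k, temperature, top_k_classifier, classifier_weight) are assumed values that sample_sequences simply forwards to the samplers via **sampler_kwargs.

# Hypothetical standalone usage of Generator + samplers (values are illustrative).
import torch
import transformers

from generator import Generator
from sampling import CAIFSampler, TopKWithTemperatureSampler

device = "cuda:0" if torch.cuda.is_available() else "cpu"
lm_model_name = "sberbank-ai/rugpt3small_based_on_gpt2"
cls_model_name = "tinkoff-ai/response-quality-classifier-tiny"

generator = Generator(lm_model_name=lm_model_name, device=device)
lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
generator.set_caif_sampler(
    CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
)
generator.set_ordinary_sampler(TopKWithTemperatureSampler())

sequences, tokens = generator.sample_sequences(
    num_samples=1,
    input_prompt="Как дела в качалке?",
    max_length=20,
    caif_period=1,
    caif_tokens_num=100,
    entropy=3.2,
    # Assumed sampler kwargs, forwarded via **sampler_kwargs:
    top_k=20,
    temperature=1.0,
    top_k_classifier=10,
    classifier_weight=1.0,
)
print(sequences[0])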
generator.py
ADDED
@@ -0,0 +1,221 @@
from typing import Optional, Union

import torch
import transformers


class Generator:
    def __init__(self, lm_model_name, device, entropy=None):

        self.device = device

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            lm_model_name
        )
        self.lm = transformers.AutoModelForCausalLM.from_pretrained(
            lm_model_name
        ).to(device)
        self.lm.eval()

        self.lm.config.pad_token_id = self.lm.config.eos_token_id
        self.tokenizer.add_special_tokens(
            {"pad_token": self.tokenizer.decode(self.lm.config.eos_token_id)}
        )
        self.caif_sampler = None
        self.ordinary_sampler = None
        self.entropy_based_stats = {
            "skips": 0,
            "avg_entropy": 0,
            "count": 0,
        }
        self.entropy = entropy

    def set_caif_sampler(self, sampler):
        self.caif_sampler = sampler

    def set_ordinary_sampler(self, sampler):
        self.ordinary_sampler = sampler

    def sample_sequences(
        self,
        num_samples: int,
        input_prompt: Optional[str],
        max_length: int,
        caif_period: int,
        caif_tokens_num: Union[int, None] = None,
        entropy: float = None,
        **sampler_kwargs
    ):
        self.entropy = entropy

        input_ids, past, ended_sequences = self.get_input_ids(
            input_prompt,
            num_samples,
        )

        for i in range(max_length):
            is_caif_step = (
                i % caif_period == 0 and self.caif_sampler is not None
            )
            input_ids, past, ended_sequences = self.generation_step(
                input_ids,
                past,
                ended_sequences,
                is_caif_step,
                caif_tokens_num=caif_tokens_num,
                **sampler_kwargs
            )
            if ended_sequences.all():
                break

        # Return the decoded continuations (prompt stripped) and the raw token ids.
        return (
            [
                self.tokenizer.decode(sequence, skip_special_tokens=True)[
                    len(input_prompt):
                ]
                for sequence in input_ids
            ],
            input_ids,
        )

    def generation_step(
        self,
        input_ids,
        past,
        ended_sequences,
        is_caif_step: bool,
        caif_tokens_num=None,
        **sampler_kwargs
    ):
        prepared_inputs = self.lm.prepare_inputs_for_generation(
            input_ids, past, use_cache=True
        )
        outputs = self.lm(
            **prepared_inputs,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )

        past = outputs.past_key_values
        if self.entropy is not None:
            # Entropy-gated mode: route each sequence to the CAIF sampler only when
            # the next-token distribution is uncertain enough (entropy >= threshold).
            normalized = torch.nn.functional.log_softmax(
                outputs.logits, dim=-1
            )
            p = torch.exp(normalized)
            output_probs = p
            output_information = -normalized
            output_entropy = (output_probs * output_information).sum(-1)[:, -1]
            batch_size = output_entropy.shape[0]
            caif_mask = torch.ge(output_entropy, self.entropy)
            ordinary_mask = ~caif_mask
            self.entropy_based_stats["skips"] += caif_mask.sum() / batch_size
            self.entropy_based_stats["count"] += 1
            self.entropy_based_stats["avg_entropy"] += (
                output_entropy.sum() / batch_size
            )
            flatten_entropy = output_entropy.view(-1).cpu().tolist()
            if "entropy" not in self.entropy_based_stats.keys():
                self.entropy_based_stats["entropy"] = flatten_entropy
            else:
                self.entropy_based_stats["entropy"] += flatten_entropy

            if caif_mask.sum() == 0:
                # No sequence crossed the threshold: plain top-k sampling for all.
                next_tokens_sampler = self.ordinary_sampler
                next_tokens = next_tokens_sampler(
                    input_ids,
                    outputs.logits,
                    caif_tokens_num=caif_tokens_num,
                    **sampler_kwargs
                )
                next_tokens = (
                    next_tokens * (1 - ended_sequences.long())
                    + self.lm.config.eos_token_id * ended_sequences.long()
                ).long()

            elif caif_mask.sum() == batch_size:
                # Every sequence crossed the threshold: CAIF sampling for all.
                next_tokens_sampler = self.caif_sampler
                next_tokens = next_tokens_sampler(
                    input_ids,
                    outputs.logits,
                    caif_tokens_num=caif_tokens_num,
                    **sampler_kwargs
                )
                next_tokens = (
                    next_tokens * (1 - ended_sequences.long())
                    + self.lm.config.eos_token_id * ended_sequences.long()
                ).long()

            else:
                # Mixed batch: CAIF-sample the high-entropy sequences, top-k the rest.
                next_tokens_caif = self.caif_sampler(
                    input_ids[caif_mask],
                    outputs.logits[caif_mask],
                    caif_tokens_num=caif_tokens_num,
                    **sampler_kwargs
                )
                next_tokens_ordinary = self.ordinary_sampler(
                    input_ids[ordinary_mask],
                    outputs.logits[ordinary_mask],
                    caif_tokens_num=caif_tokens_num,
                    **sampler_kwargs
                )
                next_tokens_caif = (
                    next_tokens_caif * (1 - ended_sequences[caif_mask].long())
                    + self.lm.config.eos_token_id
                    * ended_sequences[caif_mask].long()
                ).long()
                next_tokens_ordinary = (
                    next_tokens_ordinary
                    * (1 - ended_sequences[ordinary_mask].long())
                    + self.lm.config.eos_token_id
                    * ended_sequences[ordinary_mask].long()
                ).long()

                next_tokens = torch.ones(batch_size).long().to(self.device)
                next_tokens[caif_mask] = next_tokens_caif
                next_tokens[ordinary_mask] = next_tokens_ordinary
        else:
            # Scheduled mode: apply the CAIF sampler every caif_period steps,
            # otherwise the ordinary top-k sampler.
            if is_caif_step:
                next_tokens_sampler = self.caif_sampler
            else:
                next_tokens_sampler = self.ordinary_sampler

            next_tokens = next_tokens_sampler(
                input_ids,
                outputs.logits,
                caif_tokens_num=caif_tokens_num,
                **sampler_kwargs
            )

            next_tokens = (
                next_tokens * (1 - ended_sequences.long())
                + self.lm.config.eos_token_id * ended_sequences.long()
            ).long()

        input_ids = torch.cat(
            [input_ids, next_tokens[:, None].to(self.device)], dim=-1
        )

        # Mark sequences that just produced EOS as finished.
        ended_sequences += next_tokens == self.lm.config.eos_token_id

        return input_ids, past, ended_sequences

    def get_input_ids(self, input_prompt, num_samples):
        # Prepend BOS, append the tokenized prompt and replicate it num_samples times.
        input_ids = torch.tensor([[self.lm.config.bos_token_id]])
        if input_prompt is not None:
            input_prompt = self.tokenizer(
                input_prompt, return_tensors="pt"
            ).input_ids
            input_ids = torch.cat([input_ids, input_prompt], 1)
        input_ids = input_ids.repeat(num_samples, 1).to(self.device)
        past = None
        ended_sequences = torch.zeros(
            input_ids.shape[0], device=self.device
        ).bool()

        return input_ids, past, ended_sequences

    @staticmethod
    def sample(unscaled_probs, values):
        samples = torch.multinomial(unscaled_probs, 1)
        return torch.take_along_dim(values, samples, dim=1)
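The core of generation_step is the entropy gate: when entropy is set, the per-sequence entropy of the next-token distribution decides which sampler handles which sequence. A minimal self-contained sketch of that computation on random logits (shapes are dummy values; the 3.2 threshold is just the value app.py happens to pass):

import torch
import torch.nn.functional as F

# Dummy logits: batch of 2 sequences, 1 position, vocabulary of 100 tokens.
logits = torch.randn(2, 1, 100)

log_p = F.log_softmax(logits, dim=-1)
p = log_p.exp()
# Shannon entropy of the last position's distribution, one value per sequence.
entropy = (p * -log_p).sum(-1)[:, -1]

threshold = 3.2                      # assumed threshold, as in app.py
caif_mask = entropy >= threshold     # routed to the CAIF sampler
ordinary_mask = ~caif_mask           # routed to the ordinary top-k sampler
print(entropy, caif_mask, ordinary_mask)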
requirements.txt
ADDED
@@ -0,0 +1,3 @@
streamlit
transformers
torch
sampling.py
ADDED
@@ -0,0 +1,143 @@
import torch
from torch.nn import functional as F

import transformers


def sample_from_values(unscaled_probs, values):
    # Draw one index per row from the unnormalized probabilities and map it
    # back to the corresponding token ids in `values`.
    samples = torch.multinomial(unscaled_probs, 1)
    return torch.take_along_dim(values, samples, dim=1)


class TopKWithTemperatureSampler:
    def __call__(self, input_ids, output_logits, top_k, temperature, **kwargs):

        next_token_logits = output_logits[:, -1]
        next_token_log_probs = F.log_softmax(
            next_token_logits, dim=-1
        )

        topk_log_probs = next_token_log_probs.topk(top_k, -1)
        next_tokens = sample_from_values(
            torch.exp(topk_log_probs[0] / temperature), topk_log_probs[1]
        ).squeeze(1)

        return next_tokens


class CAIFSampler:
    # Reweights the language model's top next-token candidates with an external
    # attribute classifier (classifier-guided sampling).
    def __init__(self, classifier_name, lm_tokenizer, device, invert_cls_probs: bool = False):
        self.device = device
        self.classifier_tokenizer = transformers.AutoTokenizer.from_pretrained(
            classifier_name
        )
        self.classifier_model = (
            transformers.AutoModelForSequenceClassification.from_pretrained(
                classifier_name
            ).to(device)
        )
        self.classifier_model.eval()
        self.lm_tokenizer = lm_tokenizer
        self.invert_cls_probs = invert_cls_probs

    def __call__(
        self,
        input_ids,
        output_logits,
        top_k,
        temperature,
        top_k_classifier,
        classifier_weight,
        caif_tokens_num=None,
        **kwargs
    ):
        next_token_logits = output_logits[:, -1]

        next_token_log_probs = F.log_softmax(
            next_token_logits, dim=-1
        )

        (next_token_unnormalized_probs, topk_indices,) = self.get_unnormalized_probs(
            input_ids,
            next_token_log_probs,
            temperature,
            top_k_classifier,
            classifier_weight,
            caif_tokens_num=caif_tokens_num
        )
        topk_probs = next_token_unnormalized_probs.topk(top_k, -1)
        next_tokens = sample_from_values(
            topk_probs[0],
            torch.take_along_dim(topk_indices, topk_probs[1], dim=1),
        ).squeeze(1)

        return next_tokens

    def get_unnormalized_probs(
        self,
        input_ids,
        next_token_log_probs,
        temperature,
        top_k_classifier,
        classifier_weight,
        caif_tokens_num=None
    ):
        # Combine the LM log-probs of the top_k_classifier candidates with the
        # weighted classifier log-probs, then apply temperature scaling.

        if classifier_weight == 0.0:
            raise ValueError(
                "classifier weight equal to 0 is not supported for CAIF Sampling"
            )

        top_next_token_log_probs = next_token_log_probs.topk(top_k_classifier, -1)
        classifier_input = torch.cat(
            [
                input_ids.unsqueeze(1).repeat(1, top_k_classifier, 1).flatten(0, 1),
                top_next_token_log_probs[1].view(-1).unsqueeze(-1),
            ],
            -1,
        )
        classifier_input = [
            self.lm_tokenizer.decode(sequence, skip_special_tokens=True)
            for sequence in classifier_input
        ]

        if self.invert_cls_probs:
            classifier_log_probs = torch.log(
                1 - self.get_classifier_probs(
                    classifier_input, caif_tokens_num=caif_tokens_num
                ).view(-1, top_k_classifier)
            )
        else:
            classifier_log_probs = self.get_classifier_log_probs(
                classifier_input, caif_tokens_num=caif_tokens_num
            ).view(-1, top_k_classifier)

        next_token_probs = torch.exp(
            (top_next_token_log_probs[0] + classifier_weight * classifier_log_probs)
            / temperature
        )
        return next_token_probs, top_next_token_log_probs[1]

    def get_classifier_log_probs(self, input, caif_tokens_num=None):
        input_ids = self.classifier_tokenizer(
            input, padding=True, return_tensors="pt"
        ).to(self.device)
        if caif_tokens_num is not None:
            # Keep only the last caif_tokens_num tokens of each sequence.
            input_ids["input_ids"] = input_ids["input_ids"][:, -caif_tokens_num:]
            if "attention_mask" in input_ids.keys():
                input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
            if "token_type_ids" in input_ids.keys():
                input_ids["token_type_ids"] = input_ids["token_type_ids"][:, -caif_tokens_num:]
        logits = self.classifier_model(**input_ids).logits[:, 0].squeeze(-1)
        return torch.log(torch.sigmoid(logits))

    def get_classifier_probs(self, input, caif_tokens_num=None):
        input_ids = self.classifier_tokenizer(
            input, padding=True, return_tensors="pt"
        ).to(self.device)
        if caif_tokens_num is not None:
            # Truncate along the token dimension, matching get_classifier_log_probs.
            input_ids["input_ids"] = input_ids["input_ids"][:, -caif_tokens_num:]
            if "attention_mask" in input_ids.keys():
                input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
        logits = self.classifier_model(**input_ids).logits[:, 0].squeeze(-1)
        return torch.sigmoid(logits)
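CAIFSampler adds classifier_weight times the classifier's log-probability for each of the top_k_classifier candidate continuations to the LM log-probabilities, applies the temperature, and samples from the top_k of the reweighted distribution. TopKWithTemperatureSampler has no external dependencies, so it can be sanity-checked on dummy tensors; a small sketch (shapes and values are illustrative only):

import torch
from sampling import TopKWithTemperatureSampler

sampler = TopKWithTemperatureSampler()

batch_size, seq_len, vocab_size = 2, 5, 50
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
logits = torch.randn(batch_size, seq_len, vocab_size)

# One next token per sequence, drawn from the top-10 candidates at temperature 0.9.
next_tokens = sampler(input_ids, logits, top_k=10, temperature=0.9)
print(next_tokens.shape)  # torch.Size([2])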