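"""
Streamlit demo for the tinkoff-ai crossencoder models: given the last few
turns of a dialogue, the selected model scores how relevant and how specific
the dialogue agent's final response is.
"""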
import os
import streamlit as st
import transformers
import torch
from typing import List, Dict


model_name = st.selectbox(
    'Select a model',
    ('tinkoff-ai/crossencoder-tiny', 'tinkoff-ai/crossencoder-medium', 'tinkoff-ai/crossencoder-large')
)
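# Use an HF access token from the environment if present; `True` falls back to
# the token cached locally by `huggingface-cli login`.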
auth_token = os.environ.get('TOKEN') or True
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name, use_auth_token=auth_token)
if torch.cuda.is_available():
    model = model.cuda()

# The sample dialogue stays in Russian: the tinkoff-ai crossencoders are
# Russian-language models.
context_3 = st.text_input('Nastya', 'Привет')
context_2 = st.text_input('Dialogue agent', 'Здарова')
context_1 = st.text_input('Nastya', 'Как жизнь?')
response = st.text_input('Dialogue agent', 'Норм')
sample = {
    'context_3': context_3,
    'context_2': context_2,
    'context_1': context_1,
    'response': response
}


SEP_TOKEN = '[SEP]'
CLS_TOKEN = '[CLS]'
RESPONSE_TOKEN = '[RESPONSE_TOKEN]'
MAX_SEQ_LENGTH = 128
sorted_dialog_columns = ['context_3', 'context_2', 'context_1', 'response']
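# The two helpers below pack the four messages into a single sequence of the form
#   [CLS] context_3 [SEP] context_2 [SEP] context_1 [RESPONSE_TOKEN] response
# with each message truncated so the whole sequence fits in MAX_SEQ_LENGTH.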


def tokenize_dialog_data(
    tokenizer: transformers.PreTrainedTokenizer,
    sample: Dict,
    max_seq_length: int,
    sorted_dialog_columns: List,
):
    """
    Tokenize both contexts and response of dialog data separately
    """
    len_message_history = len(sorted_dialog_columns)
    max_seq_length = min(max_seq_length, tokenizer.model_max_length)
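    # Reserve room for the special tokens added later; with the defaults:
    # 128 // 4 - 1 = 31 tokens per message, and 4 * 31 + 3 separators + [CLS] = 128.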
    max_each_message_length = max_seq_length // len_message_history - 1
    messages = [sample[k] for k in sorted_dialog_columns]
    result = {model_input_name: [] for model_input_name in tokenizer.model_input_names}
    messages = [str(message) if message is not None else '' for message in messages]
    tokens = tokenizer(
        messages, padding=False, max_length=max_each_message_length, truncation=True, add_special_tokens=False
    )
    for model_input_name in tokens.keys():
        result[model_input_name].extend(tokens[model_input_name])
    return result


def merge_dialog_data(
    tokenizer: transformers.PreTrainedTokenizer,
    sample: Dict
):
    """
    Merge the per-message token lists into a single model-ready sequence,
    moving it to the GPU when one is available.
    """
    cls_token = tokenizer(CLS_TOKEN, add_special_tokens=False)
    sep_token = tokenizer(SEP_TOKEN, add_special_tokens=False)
    response_token = tokenizer(RESPONSE_TOKEN, add_special_tokens=False)
    model_input_names = tokenizer.model_input_names
    result = {}
    for model_input_name in model_input_names:
        tokens = []
        tokens.extend(cls_token[model_input_name])
        for i, message in enumerate(sample[model_input_name]):
            tokens.extend(message)
            if i < len(sample[model_input_name]) - 2:
                # A [SEP] token between consecutive context turns.
                tokens.extend(sep_token[model_input_name])
            elif i == len(sample[model_input_name]) - 2:
                # A [RESPONSE_TOKEN] before the final response.
                tokens.extend(response_token[model_input_name])
        result[model_input_name] = torch.tensor([tokens])
        if torch.cuda.is_available():
            result[model_input_name] = result[model_input_name].cuda()
    return result

tokenized_dialog = tokenize_dialog_data(tokenizer, sample, MAX_SEQ_LENGTH, sorted_dialog_columns)
tokens = merge_dialog_data(tokenizer, tokenized_dialog)
with torch.inference_mode():
    logits = model(**tokens).logits
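    # Two sigmoid-activated outputs: index 0 is relevance, index 1 is specificity.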
    probas = torch.sigmoid(logits)[0].cpu().detach().numpy()

st.metric(
    label="Probability that the dialogue agent's last response is relevant",
    value=probas[0]
)
st.metric(
    label="Probability that the dialogue agent's last response is specific",
    value=probas[1]
)
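# A minimal way to run the demo locally (the filename app.py is an assumption,
# not part of the original repo):
#   pip install streamlit transformers torch
#   streamlit run app.py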