import os
from typing import Dict, List

import streamlit as st
import tokenizers
import torch
import transformers

st.subheader('This demo lets you experiment with models that estimate how well a candidate response fits the context of a dialogue.')
model_name = st.selectbox(
    'Choose a model',
    ('tinkoff-ai/response-quality-classifier-tiny', 'tinkoff-ai/response-quality-classifier-base', 'tinkoff-ai/response-quality-classifier-large')
)
# Use a Hugging Face access token from the environment when it is set; otherwise
# pass True so from_pretrained falls back to locally cached credentials.
auth_token = os.environ.get('TOKEN') or True

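# st.cache cannot hash a tokenizers.Tokenizer on its own, so we hash the
# tokenizer's serialized JSON instead; allow_output_mutation stops Streamlit
# from re-hashing the returned model and tokenizer on every rerun.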
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda tokenizer: hash(tokenizer.to_str())}, allow_output_mutation=True)
def load_model(model_name: str):
    with st.spinner('Loading models...'):
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
        model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name, use_auth_token=auth_token)
        if torch.cuda.is_available():
            model = model.cuda()
    return tokenizer, model

# The models classify Russian dialogues, so the seed context stays in Russian:
# 'привет' = 'hi', 'как дела?' = 'how are you?', 'норм' = 'fine'.
context_3 = 'привет'
context_2 = 'привет!'
context_1 = 'как дела?'

st.markdown('👱🏻‍♀️ **Настя**: ' + context_3)
st.markdown('🤖 **Dialogue agent**: ' + context_2)
st.markdown('👱🏻‍♀️ **Настя**: ' + context_1)
response = st.text_input('🤖 Dialogue agent:', 'норм')
sample = {
    'context_3': context_3,
    'context_2': context_2,
    'context_1': context_1,
    'response': response
}

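# Special tokens used to assemble the model input: [SEP] separates context
# turns and [RESPONSE_TOKEN] marks the start of the candidate response.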
SEP_TOKEN = '[SEP]'
CLS_TOKEN = '[CLS]'
RESPONSE_TOKEN = '[RESPONSE_TOKEN]'
MAX_SEQ_LENGTH = 128
sorted_dialog_columns = ['context_3', 'context_2', 'context_1', 'response']

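# With MAX_SEQ_LENGTH = 128 and four messages, each message is truncated to
# 128 // 4 - 1 = 31 tokens; the spare slot per message absorbs the special
# tokens ([CLS], [SEP], [RESPONSE_TOKEN]) added later by merge_dialog_data.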
def tokenize_dialog_data(
    tokenizer: transformers.PreTrainedTokenizer,
    sample: Dict,
    max_seq_length: int,
    sorted_dialog_columns: List,
):
    """
    Tokenize the dialogue contexts and the response separately, truncating
    each message to its share of the sequence budget.
    """
    len_message_history = len(sorted_dialog_columns)
    max_seq_length = min(max_seq_length, tokenizer.model_max_length)
    max_each_message_length = max_seq_length // len_message_history - 1
    messages = [sample[k] for k in sorted_dialog_columns]
    result = {model_input_name: [] for model_input_name in tokenizer.model_input_names}
    messages = [str(message) if message is not None else '' for message in messages]
    tokens = tokenizer(
        messages, padding=False, max_length=max_each_message_length, truncation=True, add_special_tokens=False
    )
    for model_input_name in tokens.keys():
        result[model_input_name].extend(tokens[model_input_name])
    return result

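# merge_dialog_data stitches the per-message token lists into a single input:
# [CLS] context_3 [SEP] context_2 [SEP] context_1 [RESPONSE_TOKEN] response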
def merge_dialog_data(
    tokenizer: transformers.PreTrainedTokenizer,
    sample: Dict
):
    cls_token = tokenizer(CLS_TOKEN, add_special_tokens=False)
    sep_token = tokenizer(SEP_TOKEN, add_special_tokens=False)
    response_token = tokenizer(RESPONSE_TOKEN, add_special_tokens=False)
    model_input_names = tokenizer.model_input_names
    result = {}
    for model_input_name in model_input_names:
        tokens = []
        tokens.extend(cls_token[model_input_name])
        for i, message in enumerate(sample[model_input_name]):
            tokens.extend(message)
            if i < len(sample[model_input_name]) - 2:
                # Separate consecutive context turns.
                tokens.extend(sep_token[model_input_name])
            elif i == len(sample[model_input_name]) - 2:
                # Mark where the candidate response begins.
                tokens.extend(response_token[model_input_name])
        result[model_input_name] = torch.tensor([tokens])
        if torch.cuda.is_available():
            result[model_input_name] = result[model_input_name].cuda()
    return result

@st.cache
def inference(model_name: str, sample: dict):
    tokenizer, model = load_model(model_name)
    tokenized_dialog = tokenize_dialog_data(tokenizer, sample, MAX_SEQ_LENGTH, sorted_dialog_columns)
    tokens = merge_dialog_data(tokenizer, tokenized_dialog)
    with torch.inference_mode():
        logits = model(**tokens).logits
        # Each logit is squashed independently, so the two outputs are separate probabilities.
        probas = torch.sigmoid(logits)[0].cpu().detach().numpy().tolist()
    return probas

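# probas[0] is the relevance probability and probas[1] the engagement
# probability, matching the two metrics rendered below.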
with st.spinner('Running inference...'):
    probas = inference(model_name, sample)
st.metric(
    label="Probability that the dialogue agent's last response is relevant",
    value=round(probas[0], 3)
)
st.metric(
    label="Probability that the dialogue agent's last response is engaging",
    value=round(probas[1], 3)
)