nlp-bert-team / pages /1_policlinic.py
VerVelVel's picture
rubert+logreg fo first task
e0e0815
import streamlit as st
import joblib
import pandas as pd
from models.model1.Custom_class import TextPreprocessor
from pathlib import Path
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
from transformers import AutoTokenizer, AutoModel
from models.model2.preprocess_text import TextPreprocessorBERT
project_root = Path(__file__).resolve().parents[1]
models_path = project_root / 'models'
sys.path.append(str(models_path))
from models.model1.lstm_preprocessor import TextPreprocessorWord2Vec
from models.model1.lstm_model import LSTMConcatAttention
# Модель логистической регрессии
pipeline = joblib.load('models/model1/logistic_regression_pipeline.pkl')
st.title('Классификация отзывов на русском языке')
input_text = st.text_area('Введите текст отзыва')
device = 'cpu'
# Загрузка модели LSTM и словаря
@st.cache_resource
def load_lstm_model():
model = LSTMConcatAttention()
weights_path = models_path / 'model1' / 'lstm_weights'
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
return model
lstm_model = load_lstm_model()
@st.cache_resource
def load_int_to_vocab():
vocab_path = models_path / 'model1' / 'lstm_vocab_to_int.pkl'
vocab_to_int = joblib.load(vocab_path)
int_to_vocab = {j:i for i, j in vocab_to_int.items()}
return int_to_vocab
int_to_vocab = load_int_to_vocab()
def plot_and_predict_lstm(input_text):
preprocessor_lstm = TextPreprocessorWord2Vec()
preprocessed = preprocessor_lstm.transform(input_text)
lstm_model.eval()
with torch.inference_mode():
pred, att_scores = lstm_model(preprocessed.long().unsqueeze(0))
lstm_pred = pred.sigmoid().item()
# Получить индексы слов, которые не равны <pad> и не имеют индекс 0
valid_indices = [i for i, x in enumerate(preprocessed) if x.item() != 0 and int_to_vocab[x.item()] != "<pad>"]
# Получить соответствующие оценки внимания и метки слов
valid_att_scores = att_scores.detach().cpu().numpy()[0][valid_indices]
valid_labels = [int_to_vocab[preprocessed[i].item()] for i in valid_indices]
# Упорядочить метки и оценки внимания по убыванию веса смысла
sorted_indices = np.argsort(valid_att_scores)
sorted_labels = [valid_labels[i] for i in sorted_indices]
sorted_att_scores = valid_att_scores[sorted_indices]
# Построить график с учетом только валидных меток
plt.figure(figsize=(4, 8))
plt.barh(np.arange(len(sorted_indices)), sorted_att_scores)
plt.yticks(ticks=np.arange(len(sorted_indices)), labels=sorted_labels)
return lstm_pred, plt
#БЕРТа
@st.cache_resource
def load_logreg_model():
log_bert_path = models_path / 'model1' / 'clf.pkl'
return joblib.load(log_bert_path)
@st.cache_resource
def load_rubert_model():
return AutoModel.from_pretrained('cointegrated/rubert-tiny2')
@st.cache_resource
def load_tokenizer():
return AutoTokenizer.from_pretrained('cointegrated/rubert-tiny2')
logreg_model = load_logreg_model()
rubert_model = load_rubert_model()
tokenizer = load_tokenizer()
if st.button('Предсказать'):
#LOGREG
start_time_lr = time.time()
prediction = pipeline.predict(pd.Series([input_text]))
pred_probe = pipeline.predict_proba(pd.Series([input_text]))
pred_proba_rounded = np.round(pred_probe, 2).flatten()
if prediction[0] == 0:
predicted_class = "POSITIVE"
else:
predicted_class = "NEGATIVE"
st.subheader('Предсказанный класс с помощью логистической регрессии и tf-idf')
end_time_lr = time.time()
time_lr = end_time_lr - start_time_lr
st.write(f'**{predicted_class}** с вероятностью {pred_proba_rounded[0]}')
st.write(f'Время выполнения расчетов {time_lr:.4f} секунд')
#LSTM
start_time_lstm = time.time()
lstm_pred, lstm_plot = plot_and_predict_lstm(input_text)
if lstm_pred > 0.5:
predicted_lstm_class = "POSITIVE"
else:
predicted_lstm_class = "NEGATIVE"
st.subheader('Предсказанный класс с помощью LSTM + Word2Vec + BahdanauAttention:')
end_time_lstm = time.time()
time_lstm = end_time_lstm - start_time_lstm
st.write(f'**{predicted_lstm_class}** с вероятностью {round(lstm_pred, 3)}')
st.write(f'Время выполнения расчетов {time_lstm:.4f} секунд')
st.pyplot(lstm_plot)
#BERT
start_time_bert = time.time()
# Применяем предобработку
preprocessor = TextPreprocessorBERT()
preprocessed_text = preprocessor.transform(input_text)
tokens = tokenizer.encode_plus(
preprocessed_text,
add_special_tokens=True,
truncation=True,
max_length=64,
padding='max_length',
return_tensors='pt'
)
input_ids = tokens['input_ids'].to(device)
attention_mask = tokens['attention_mask'].to(device)
with torch.no_grad():
outputs = rubert_model(input_ids=input_ids, attention_mask=attention_mask)
embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
prediction = logreg_model.predict(embeddings)
pred_prob = logreg_model.predict_proba(embeddings)
pred_prob_rounded = np.round(pred_prob, 2).flatten()
if prediction[0] == 0:
predicted_class = "POSITIVE"
else:
predicted_class = "NEGATIVE"
end_time_bert = time.time()
bert_time = end_time_bert - start_time_bert
st.subheader('Предсказанный класс с помощью модели Rubert-tiny2 + Logistic Regression:')
st.write(f'**{predicted_class}**, с вероятностью {np.round(pred_prob_rounded[0], 2)}')
st.write(f'Время выполнения: {bert_time:.4f} секунд')
st.write("# Сравнение характеристик моделей:")
df = pd.read_csv(str(project_root /'images/full_metrics.csv'))
st.write(df)
st.write("# Информация о датасете:")
st.write("Модель обучалась на предсказание 1 класса")
st.write("Размер датасета - 70597 текстов отзывов")
st.write("Проведена предобработка текста")
st.write("# Информация об обучении модели логистической регрессии и tf-idf:")
st.image(str(project_root / 'images/pipeline_logreg.png'))
st.write("Метрики:")
st.image(str(project_root / 'images/log_reg_metrics.png'))
st.write("# Информация об обучении модели LSTM + Word2Vec + BahdanauAttention:")
st.write("Время обучения модели - 10 эпох")
st.write("# Информация об обучении модели Rubert-tiny2 + Logistic Regression:")
st.write("Использовалась Rubert-tiny2 для получения эмбеддингов и подачей их логистической регрессии")
st.image(str(project_root / 'images/last_mode1_metric.png'), width=1000)