Spaces:
Sleeping
Sleeping
import streamlit as st | |
import joblib | |
import pandas as pd | |
from models.model1.Custom_class import TextPreprocessor | |
from pathlib import Path | |
import sys | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import time | |
from transformers import AutoTokenizer, AutoModel | |
from models.model2.preprocess_text import TextPreprocessorBERT | |
project_root = Path(__file__).resolve().parents[1] | |
models_path = project_root / 'models' | |
sys.path.append(str(models_path)) | |
from models.model1.lstm_preprocessor import TextPreprocessorWord2Vec | |
from models.model1.lstm_model import LSTMConcatAttention | |
# Модель логистической регрессии | |
pipeline = joblib.load('models/model1/logistic_regression_pipeline.pkl') | |
st.title('Классификация отзывов на русском языке') | |
input_text = st.text_area('Введите текст отзыва') | |
device = 'cpu' | |
# Загрузка модели LSTM и словаря | |
def load_lstm_model(): | |
model = LSTMConcatAttention() | |
weights_path = models_path / 'model1' / 'lstm_weights' | |
state_dict = torch.load(weights_path, map_location=device) | |
model.load_state_dict(state_dict) | |
model.to(device) | |
model.eval() | |
return model | |
lstm_model = load_lstm_model() | |
def load_int_to_vocab(): | |
vocab_path = models_path / 'model1' / 'lstm_vocab_to_int.pkl' | |
vocab_to_int = joblib.load(vocab_path) | |
int_to_vocab = {j:i for i, j in vocab_to_int.items()} | |
return int_to_vocab | |
int_to_vocab = load_int_to_vocab() | |
def plot_and_predict_lstm(input_text): | |
preprocessor_lstm = TextPreprocessorWord2Vec() | |
preprocessed = preprocessor_lstm.transform(input_text) | |
lstm_model.eval() | |
with torch.inference_mode(): | |
pred, att_scores = lstm_model(preprocessed.long().unsqueeze(0)) | |
lstm_pred = pred.sigmoid().item() | |
# Получить индексы слов, которые не равны <pad> и не имеют индекс 0 | |
valid_indices = [i for i, x in enumerate(preprocessed) if x.item() != 0 and int_to_vocab[x.item()] != "<pad>"] | |
# Получить соответствующие оценки внимания и метки слов | |
valid_att_scores = att_scores.detach().cpu().numpy()[0][valid_indices] | |
valid_labels = [int_to_vocab[preprocessed[i].item()] for i in valid_indices] | |
# Упорядочить метки и оценки внимания по убыванию веса смысла | |
sorted_indices = np.argsort(valid_att_scores) | |
sorted_labels = [valid_labels[i] for i in sorted_indices] | |
sorted_att_scores = valid_att_scores[sorted_indices] | |
# Построить график с учетом только валидных меток | |
plt.figure(figsize=(4, 8)) | |
plt.barh(np.arange(len(sorted_indices)), sorted_att_scores) | |
plt.yticks(ticks=np.arange(len(sorted_indices)), labels=sorted_labels) | |
return lstm_pred, plt | |
#БЕРТа | |
def load_logreg_model(): | |
log_bert_path = models_path / 'model1' / 'clf.pkl' | |
return joblib.load(log_bert_path) | |
def load_rubert_model(): | |
return AutoModel.from_pretrained('cointegrated/rubert-tiny2') | |
def load_tokenizer(): | |
return AutoTokenizer.from_pretrained('cointegrated/rubert-tiny2') | |
logreg_model = load_logreg_model() | |
rubert_model = load_rubert_model() | |
tokenizer = load_tokenizer() | |
if st.button('Предсказать'): | |
#LOGREG | |
start_time_lr = time.time() | |
prediction = pipeline.predict(pd.Series([input_text])) | |
pred_probe = pipeline.predict_proba(pd.Series([input_text])) | |
pred_proba_rounded = np.round(pred_probe, 2).flatten() | |
if prediction[0] == 0: | |
predicted_class = "POSITIVE" | |
else: | |
predicted_class = "NEGATIVE" | |
st.subheader('Предсказанный класс с помощью логистической регрессии и tf-idf') | |
end_time_lr = time.time() | |
time_lr = end_time_lr - start_time_lr | |
st.write(f'**{predicted_class}** с вероятностью {pred_proba_rounded[0]}') | |
st.write(f'Время выполнения расчетов {time_lr:.4f} секунд') | |
#LSTM | |
start_time_lstm = time.time() | |
lstm_pred, lstm_plot = plot_and_predict_lstm(input_text) | |
if lstm_pred > 0.5: | |
predicted_lstm_class = "POSITIVE" | |
else: | |
predicted_lstm_class = "NEGATIVE" | |
st.subheader('Предсказанный класс с помощью LSTM + Word2Vec + BahdanauAttention:') | |
end_time_lstm = time.time() | |
time_lstm = end_time_lstm - start_time_lstm | |
st.write(f'**{predicted_lstm_class}** с вероятностью {round(lstm_pred, 3)}') | |
st.write(f'Время выполнения расчетов {time_lstm:.4f} секунд') | |
st.pyplot(lstm_plot) | |
#BERT | |
start_time_bert = time.time() | |
# Применяем предобработку | |
preprocessor = TextPreprocessorBERT() | |
preprocessed_text = preprocessor.transform(input_text) | |
tokens = tokenizer.encode_plus( | |
preprocessed_text, | |
add_special_tokens=True, | |
truncation=True, | |
max_length=64, | |
padding='max_length', | |
return_tensors='pt' | |
) | |
input_ids = tokens['input_ids'].to(device) | |
attention_mask = tokens['attention_mask'].to(device) | |
with torch.no_grad(): | |
outputs = rubert_model(input_ids=input_ids, attention_mask=attention_mask) | |
embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy() | |
prediction = logreg_model.predict(embeddings) | |
pred_prob = logreg_model.predict_proba(embeddings) | |
pred_prob_rounded = np.round(pred_prob, 2).flatten() | |
if prediction[0] == 0: | |
predicted_class = "POSITIVE" | |
else: | |
predicted_class = "NEGATIVE" | |
end_time_bert = time.time() | |
bert_time = end_time_bert - start_time_bert | |
st.subheader('Предсказанный класс с помощью модели Rubert-tiny2 + Logistic Regression:') | |
st.write(f'**{predicted_class}**, с вероятностью {np.round(pred_prob_rounded[0], 2)}') | |
st.write(f'Время выполнения: {bert_time:.4f} секунд') | |
st.write("# Сравнение характеристик моделей:") | |
df = pd.read_csv(str(project_root /'images/full_metrics.csv')) | |
st.write(df) | |
st.write("# Информация о датасете:") | |
st.write("Модель обучалась на предсказание 1 класса") | |
st.write("Размер датасета - 70597 текстов отзывов") | |
st.write("Проведена предобработка текста") | |
st.write("# Информация об обучении модели логистической регрессии и tf-idf:") | |
st.image(str(project_root / 'images/pipeline_logreg.png')) | |
st.write("Метрики:") | |
st.image(str(project_root / 'images/log_reg_metrics.png')) | |
st.write("# Информация об обучении модели LSTM + Word2Vec + BahdanauAttention:") | |
st.write("Время обучения модели - 10 эпох") | |
st.write("# Информация об обучении модели Rubert-tiny2 + Logistic Regression:") | |
st.write("Использовалась Rubert-tiny2 для получения эмбеддингов и подачей их логистической регрессии") | |
st.image(str(project_root / 'images/last_mode1_metric.png'), width=1000) | |