Spaces:
Sleeping
Sleeping
from models.bert_classifier import MyTinyBERT | |
from models.lstm_attention import LSTMAttention | |
from models.text_preprocessor import MyCustomTextPreprocessor | |
import streamlit as st | |
from sklearn.utils.class_weight import compute_class_weight | |
import torch.nn.functional as F | |
import torch.optim as optim | |
import joblib | |
from torch import nn | |
from sklearn.base import BaseEstimator, TransformerMixin | |
from transformers import AutoTokenizer, AutoModel | |
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score | |
from sklearn.linear_model import LogisticRegression | |
from sklearn.model_selection import train_test_split | |
from torch.utils.data import DataLoader, TensorDataset | |
from time import time | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
import pymorphy3 | |
import string | |
import re | |
import pandas as pd | |
import numpy as np | |
import torch | |
import sklearn | |
import matplotlib.pyplot as plt | |
import warnings | |
warnings.simplefilter("ignore") | |
# Metrics | |
# custom | |
# ======= Глобальная инициализация токенизатора ======= | |
tokenizer = AutoTokenizer.from_pretrained( | |
"cointegrated/rubert-tiny2") # Для LSTM и BERT | |
# ======= Инициализация обработчика текста ======= | |
preprocessor = MyCustomTextPreprocessor() | |
# ======= Загрузка моделей и векторизатора ======= | |
# @st.cache_resource | |
def load_resources(): | |
# Загрузка TF-IDF векторизатора | |
vectorizer = joblib.load('models/vectorizer.pkl') # TF-IDF | |
# Загрузка модели логистической регрессии | |
# Логистическая регрессия | |
model1 = joblib.load('models/Sasha_logistic_model2.pkl') | |
# Настройка модели LSTM | |
# Используем уже загруженный токенизатор | |
VOCAB_SIZE = len(tokenizer.get_vocab()) | |
EMBEDDING_DIM = 128 | |
HIDDEN_DIM = 256 | |
OUTPUT_DIM = 10 | |
model2 = LSTMAttention(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM) | |
model2.load_state_dict(torch.load( | |
'models/Sasha_best_lstm_model3.pth', map_location=torch.device('cpu'))) | |
model2.eval() | |
# Настройка модели BERT | |
model3 = MyTinyBERT() | |
model3.load_state_dict(torch.load( | |
'models/Sasha_best_model_bert.pth', map_location=torch.device('cpu'))) | |
model3.eval() | |
return model1, model2, model3, vectorizer | |
# Загружаем ресурсы | |
model1, model2, model3, vectorizer = load_resources() | |
# ======= Предобработка текста ======= | |
def preprocess_for_model1(text): | |
"""TF-IDF векторизация для логистической регрессии""" | |
processed_text = preprocessor.preprocess( | |
text, lemmatize=True) # Лемматизация включена | |
return vectorizer.transform([processed_text]) | |
def preprocess_for_model2_and_model3(text): | |
"""Общая обработка для LSTM и BERT моделей (без лемматизации)""" | |
processed_text = preprocessor.preprocess( | |
text, lemmatize=False) # Лемматизация выключена | |
return processed_text | |
def preprocess_for_model2(text, tokenizer): | |
"""Токенизация для LSTM модели""" | |
processed_text = preprocess_for_model2_and_model3(text) | |
tokenized_data = tokenizer( | |
[processed_text], | |
padding=True, | |
truncation=True, | |
return_tensors="pt", | |
max_length=256 | |
) | |
return tokenized_data["input_ids"], tokenized_data["attention_mask"] | |
def preprocess_for_model3(text, tokenizer): | |
"""Токенизация для BERT модели""" | |
processed_text = preprocess_for_model2_and_model3(text) | |
tokenized_data = tokenizer( | |
[processed_text], | |
padding=True, | |
truncation=True, | |
return_tensors="pt", | |
max_length=256 | |
) | |
return tokenized_data | |
# ======= Прогноз и визуализация ======= | |
def predict_and_visualize(text): | |
# ======= Модель 1 (Logistic Regression) ======= | |
start_time = time() # Начало времени предсказания | |
vectorized_text = preprocess_for_model1(text) | |
probs1 = model1.predict_proba(vectorized_text)[0] | |
model1_time = time() - start_time # Рассчитываем время предсказания для модели 1 | |
# ======= Модель 2 (LSTM & Attention) ======= | |
start_time = time() # Начало времени предсказания | |
input_ids, _ = preprocess_for_model2( | |
text, tokenizer) # Получаем только input_ids | |
with torch.no_grad(): | |
logits2, attn_weights = model2(input_ids) # Передаём только input_ids | |
probs2 = torch.softmax(logits2, dim=1).numpy()[0] | |
attention_vector = attn_weights.cpu().numpy()[0] | |
model2_time = time() - start_time # Рассчитываем время предсказания для модели 2 | |
# ======= Модель 3 (BERT) ======= | |
start_time = time() # Начало времени предсказания | |
tokenized_text = preprocess_for_model3(text, tokenizer) | |
with torch.no_grad(): | |
logits3 = model3(tokenized_text) | |
probs3 = torch.softmax(logits3, dim=1).numpy()[0] | |
model3_time = time() - start_time # Рассчитываем время предсказания для модели 3 | |
# ======= Финальное предсказание ======= | |
final_probs = (probs1 + probs2 + probs3) / 3 | |
final_class = np.argmax(final_probs) | |
# ======= Визуализация ======= | |
st.subheader("Распределение вероятностей") | |
for probs, model_name in zip([probs1, probs2, probs3], ['Model 1 (Logistic Regression)', 'Model 2 (LSTM)', 'Model 3 (BERT)']): | |
fig, ax = plt.subplots() | |
ax.bar(range(1, len(probs) + 1), probs) # Сдвиг индекса на +1 | |
ax.set_title(f'{model_name} Probabilities') | |
ax.set_xlabel('Class (1-10)') | |
ax.set_ylabel('Probability') | |
st.pyplot(fig) | |
# ======= Визуализация внимания (LSTM) ======= | |
st.subheader("Веса внимания (LSTM)") | |
# Проверяем наличие attention weights | |
tokens = tokenizer.convert_ids_to_tokens(input_ids[0]) | |
tokens = tokens[:len(attention_vector)] | |
attention_vector = attention_vector[:len(tokens)] | |
fig, ax = plt.subplots(figsize=(12, 6)) | |
ax.bar(range(len(tokens)), attention_vector, align="center") | |
ax.set_xticks(range(len(tokens))) | |
ax.set_xticklabels(tokens, rotation=45, ha="right") | |
ax.set_title("Attention Weights (LSTM)") | |
ax.set_xlabel("Токены") | |
ax.set_ylabel("Вес внимания") | |
st.pyplot(fig) | |
# Итоговое предсказание | |
st.subheader("Итоговое предсказание") | |
# Смещение на +1 | |
st.write(f"Наиболее вероятный класс: **{final_class + 1}**") | |
# Вывод времени выполнения | |
st.subheader("Время выполнения моделей") | |
st.write(f"Модель 1 (Logistic Regression): {model1_time:.4f} секунд") | |
st.write(f"Модель 2 (LSTM): {model2_time:.4f} секунд") | |
st.write(f"Модель 3 (BERT): {model3_time:.4f} секунд") | |
return final_class | |
# ======= Streamlit UI ======= | |
st.title("Классификация текстов с 3 моделями") | |
st.write("Введите текст отзыва, чтобы получить результаты классификации от трёх моделей.") | |
# Ввод текста пользователем | |
user_input = st.text_area("Введите текст отзыва:", "") | |
if st.button("Классифицировать"): | |
if user_input.strip(): | |
# Прогноз и визуализация | |
predict_and_visualize(user_input) | |
else: | |
st.warning("Введите текст для анализа.") | |
st.subheader("F1 macro, валидационная выборка") | |
st.write(f'f1 macro valid logreg=0.2516') | |
st.write(f'f1 macro valid lstm=0.2515') | |
st.write(f'f1 macro valid bert=0.2709') | |