nlp_group_project / pages /sasha_main_page_final.py
DanilO0o's picture
added new model
edcd390
from models.bert_classifier import MyTinyBERT
from models.lstm_attention import LSTMAttention
from models.text_preprocessor import MyCustomTextPreprocessor
import streamlit as st
from sklearn.utils.class_weight import compute_class_weight
import torch.nn.functional as F
import torch.optim as optim
import joblib
from torch import nn
from sklearn.base import BaseEstimator, TransformerMixin
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from time import time
from sklearn.feature_extraction.text import TfidfVectorizer
import pymorphy3
import string
import re
import pandas as pd
import numpy as np
import torch
import sklearn
import matplotlib.pyplot as plt
import warnings
warnings.simplefilter("ignore")
# Metrics
# custom
# ======= Глобальная инициализация токенизатора =======
tokenizer = AutoTokenizer.from_pretrained(
"cointegrated/rubert-tiny2") # Для LSTM и BERT
# ======= Инициализация обработчика текста =======
preprocessor = MyCustomTextPreprocessor()
# ======= Загрузка моделей и векторизатора =======
# @st.cache_resource
def load_resources():
# Загрузка TF-IDF векторизатора
vectorizer = joblib.load('models/vectorizer.pkl') # TF-IDF
# Загрузка модели логистической регрессии
# Логистическая регрессия
model1 = joblib.load('models/Sasha_logistic_model2.pkl')
# Настройка модели LSTM
# Используем уже загруженный токенизатор
VOCAB_SIZE = len(tokenizer.get_vocab())
EMBEDDING_DIM = 128
HIDDEN_DIM = 256
OUTPUT_DIM = 10
model2 = LSTMAttention(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)
model2.load_state_dict(torch.load(
'models/Sasha_best_lstm_model3.pth', map_location=torch.device('cpu')))
model2.eval()
# Настройка модели BERT
model3 = MyTinyBERT()
model3.load_state_dict(torch.load(
'models/Sasha_best_model_bert.pth', map_location=torch.device('cpu')))
model3.eval()
return model1, model2, model3, vectorizer
# Загружаем ресурсы
model1, model2, model3, vectorizer = load_resources()
# ======= Предобработка текста =======
def preprocess_for_model1(text):
"""TF-IDF векторизация для логистической регрессии"""
processed_text = preprocessor.preprocess(
text, lemmatize=True) # Лемматизация включена
return vectorizer.transform([processed_text])
def preprocess_for_model2_and_model3(text):
"""Общая обработка для LSTM и BERT моделей (без лемматизации)"""
processed_text = preprocessor.preprocess(
text, lemmatize=False) # Лемматизация выключена
return processed_text
def preprocess_for_model2(text, tokenizer):
"""Токенизация для LSTM модели"""
processed_text = preprocess_for_model2_and_model3(text)
tokenized_data = tokenizer(
[processed_text],
padding=True,
truncation=True,
return_tensors="pt",
max_length=256
)
return tokenized_data["input_ids"], tokenized_data["attention_mask"]
def preprocess_for_model3(text, tokenizer):
"""Токенизация для BERT модели"""
processed_text = preprocess_for_model2_and_model3(text)
tokenized_data = tokenizer(
[processed_text],
padding=True,
truncation=True,
return_tensors="pt",
max_length=256
)
return tokenized_data
# ======= Прогноз и визуализация =======
def predict_and_visualize(text):
# ======= Модель 1 (Logistic Regression) =======
start_time = time() # Начало времени предсказания
vectorized_text = preprocess_for_model1(text)
probs1 = model1.predict_proba(vectorized_text)[0]
model1_time = time() - start_time # Рассчитываем время предсказания для модели 1
# ======= Модель 2 (LSTM & Attention) =======
start_time = time() # Начало времени предсказания
input_ids, _ = preprocess_for_model2(
text, tokenizer) # Получаем только input_ids
with torch.no_grad():
logits2, attn_weights = model2(input_ids) # Передаём только input_ids
probs2 = torch.softmax(logits2, dim=1).numpy()[0]
attention_vector = attn_weights.cpu().numpy()[0]
model2_time = time() - start_time # Рассчитываем время предсказания для модели 2
# ======= Модель 3 (BERT) =======
start_time = time() # Начало времени предсказания
tokenized_text = preprocess_for_model3(text, tokenizer)
with torch.no_grad():
logits3 = model3(tokenized_text)
probs3 = torch.softmax(logits3, dim=1).numpy()[0]
model3_time = time() - start_time # Рассчитываем время предсказания для модели 3
# ======= Финальное предсказание =======
final_probs = (probs1 + probs2 + probs3) / 3
final_class = np.argmax(final_probs)
# ======= Визуализация =======
st.subheader("Распределение вероятностей")
for probs, model_name in zip([probs1, probs2, probs3], ['Model 1 (Logistic Regression)', 'Model 2 (LSTM)', 'Model 3 (BERT)']):
fig, ax = plt.subplots()
ax.bar(range(1, len(probs) + 1), probs) # Сдвиг индекса на +1
ax.set_title(f'{model_name} Probabilities')
ax.set_xlabel('Class (1-10)')
ax.set_ylabel('Probability')
st.pyplot(fig)
# ======= Визуализация внимания (LSTM) =======
st.subheader("Веса внимания (LSTM)")
# Проверяем наличие attention weights
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
tokens = tokens[:len(attention_vector)]
attention_vector = attention_vector[:len(tokens)]
fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(range(len(tokens)), attention_vector, align="center")
ax.set_xticks(range(len(tokens)))
ax.set_xticklabels(tokens, rotation=45, ha="right")
ax.set_title("Attention Weights (LSTM)")
ax.set_xlabel("Токены")
ax.set_ylabel("Вес внимания")
st.pyplot(fig)
# Итоговое предсказание
st.subheader("Итоговое предсказание")
# Смещение на +1
st.write(f"Наиболее вероятный класс: **{final_class + 1}**")
# Вывод времени выполнения
st.subheader("Время выполнения моделей")
st.write(f"Модель 1 (Logistic Regression): {model1_time:.4f} секунд")
st.write(f"Модель 2 (LSTM): {model2_time:.4f} секунд")
st.write(f"Модель 3 (BERT): {model3_time:.4f} секунд")
return final_class
# ======= Streamlit UI =======
st.title("Классификация текстов с 3 моделями")
st.write("Введите текст отзыва, чтобы получить результаты классификации от трёх моделей.")
# Ввод текста пользователем
user_input = st.text_area("Введите текст отзыва:", "")
if st.button("Классифицировать"):
if user_input.strip():
# Прогноз и визуализация
predict_and_visualize(user_input)
else:
st.warning("Введите текст для анализа.")
st.subheader("F1 macro, валидационная выборка")
st.write(f'f1 macro valid logreg=0.2516')
st.write(f'f1 macro valid lstm=0.2515')
st.write(f'f1 macro valid bert=0.2709')