Spaces:
Build error
Build error
| import streamlit as st | |
| import torch | |
| from models.p_logreg import predict_tfidf | |
| from models.p_bert import predict_bert | |
| from models.p_lstm import * | |
| import time | |
| import os | |
| import pandas as pd | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| def classify_reviews(): | |
| st.title("Классификация отзывов к фильмам") | |
| input_text = st.text_area("Напишите свой отзыв", height=100) | |
| if st.button("Классифицировать отзыв"): | |
| if input_text: | |
| with st.spinner('Идет классификация...'): | |
| model_lstm = MyModel() | |
| model_lstm.load_state_dict(torch.load(get_model_path(), map_location=device).state_dict()) | |
| model_lstm.to(device) | |
| model_lstm.eval() | |
| # TF-IDF | |
| start_time_tfidf = time.time() | |
| prediction_tfidf = predict_tfidf(input_text) | |
| time_tfidf = time.time() - start_time_tfidf | |
| # LSTM with Attention | |
| start_time_lstm = time.time() | |
| with torch.no_grad(): | |
| # input_ids = tokenize(input_text) | |
| # input_ids = torch.tensor([tokenize(input_text)]).to(device) | |
| input_ids = torch.tensor([tokenize(input_text)], device=device) | |
| prediction_lstm = torch.argmax(torch.nn.functional.softmax(model_lstm(input_ids.to(device)), dim=1), dim=1).item() | |
| time_lstm = time.time() - start_time_lstm | |
| # BERT | |
| start_time_bert = time.time() | |
| prediction_bert = predict_bert(input_text) | |
| time_bert = time.time() - start_time_bert | |
| # st.write("TF-IDF - отзыв:", "нейтральный" if prediction_tfidf == 1 else ("положительный" if prediction_tfidf == 2 else "отрицательный"), ", время:", round(time_tfidf * 1000, 2), "мс") | |
| # st.write("LSTM - отзыв:", "нейтральный" if prediction_lstm == 1 else ("положительный" if prediction_lstm == 2 else "отрицательный"), ", время:", round(time_lstm * 1000, 2), "мс") | |
| # st.write("BERT - отзыв:", "нейтральный" if prediction_bert == 1 else ("положительный" if prediction_bert == 2 else "отрицательный"), ", время:", round(time_bert * 1000, 2), "мс") | |
| # Define colors based on predictions | |
| color_tfidf = "blue" if prediction_tfidf == 1 else ("green" if prediction_tfidf == 2 else "red") | |
| color_lstm = "blue" if prediction_lstm == 1 else ("green" if prediction_lstm == 2 else "red") | |
| color_bert = "blue" if prediction_bert == 1 else ("green" if prediction_bert == 2 else "red") | |
| # Write predictions with colored text | |
| st.markdown( | |
| f"TF-IDF - отзыв: <span style='color:{color_tfidf}'>{ 'нейтральный' if prediction_tfidf == 1 else ('положительный' if prediction_tfidf == 2 else 'отрицательный')}</span>, время: {round(time_tfidf * 1000, 2)} мс", | |
| unsafe_allow_html=True | |
| ) | |
| st.markdown( | |
| f"LSTM - отзыв: <span style='color:{color_lstm}'>{ 'нейтральный' if prediction_lstm == 1 else ('положительный' if prediction_lstm == 2 else 'отрицательный')}</span>, время: {round(time_lstm * 1000, 2)} мс", | |
| unsafe_allow_html=True | |
| ) | |
| st.markdown( | |
| f"BERT - отзыв: <span style='color:{color_bert}'>{ 'нейтральный' if prediction_bert == 1 else ('положительный' if prediction_bert == 2 else 'отрицательный')}</span>, время: {round(time_bert * 1000, 2)} мс", | |
| unsafe_allow_html=True | |
| ) | |
| st.write("------------") | |
| metrics = { | |
| "Models": ["TF-IDF+LogReg", "LSTM + attention", "ruBERTtiny2"], | |
| "f1-macro score": [0.6982, 0.6977, 0.6957], | |
| } | |
| df = pd.DataFrame(metrics) | |
| df.set_index("Models", inplace=True) | |
| df.style.set_caption("Model Performance") | |
| df.index.name = "Модель" | |
| st.write(df) | |
| def get_model_path(): | |
| current_dir = os.path.dirname(__file__) | |
| return os.path.join(current_dir, "..", "models", "model_lstm.pt") | |
| classify_reviews() |