# nlp_proj / pages /
# Add styles
import base64
import json
import pickle
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
import transformers
from model.funcs import (create_model_and_tokenizer, execution_time,
load_model, predict_sentiment)
from model.model import LSTMConcatAttentionEmbed
from preprocessing.preprocessing import data_preprocessing
from preprocessing.rnn_preprocessing import preprocess_single_string
def get_base64(file_path):
    """Read a file in binary mode and return its contents base64-encoded.

    Args:
        file_path: Path of the file to encode (e.g. a PNG image).

    Returns:
        The base64 encoding of the file's bytes as a UTF-8 ``str``, ready to be
        embedded in a ``data:`` URL.
    """
    with open(file_path, "rb") as file:
        base64_bytes = base64.b64encode(file.read())
    return base64_bytes.decode("utf-8")
def set_background(png_file):
    """Use a PNG image as the page background of the Streamlit app.

    The image is inlined as a base64 ``data:`` URL inside a ``<style>`` block
    that targets the ``.stApp`` container, and injected with ``st.markdown``.

    Args:
        png_file: Path of the PNG file to use as the background.
    """
    bin_str = get_base64(png_file)
    page_bg_img = (
        """
    <style>
    .stApp {
        background-image: url("data:image/png;base64,%s");
        background-size: auto;
    }
    </style>
    """
        % bin_str
    )
    # unsafe_allow_html is required so Streamlit renders the <style> tag
    # instead of escaping it as text.
    st.markdown(page_bg_img, unsafe_allow_html=True)
def load_logreg():
    """Unpickle the text vectorizer and the logistic-regression classifier.

    Both artifacts are read relative to the current working directory:
    ``vectorizer.pkl`` first, then ``logreg_model.pkl``.

    Returns:
        tuple: ``(vectorizer, predictor)`` in that order.
    """
    # NOTE: pickle.load can execute arbitrary code while loading; these files
    # must only ever come from a trusted source (the project's own training
    # output), never from user uploads.
    loaded = []
    for pkl_path in ("vectorizer.pkl", "logreg_model.pkl"):
        with open(pkl_path, "rb") as handle:
            loaded.append(pickle.load(handle))
    vectorizer, predictor = loaded
    return vectorizer, predictor


logreg_vectorizer, logreg_predictor = load_logreg()
def load_lstm():
    """Load the LSTM vocabularies and build the attention-LSTM classifier.

    Returns:
        tuple: ``(vocab_to_int, int_to_vocab, model)``. The two mappings are
        read from ``model/vocab.json`` and ``model/int_vocab.json``. ``model``
        is a fresh ``LSTMConcatAttentionEmbed`` switched to eval mode.
    """
    # JSON is defined as UTF-8; be explicit so non-ASCII (e.g. Cyrillic)
    # vocabularies load correctly regardless of the platform's default encoding.
    with open("model/vocab.json", "r", encoding="utf-8") as f:
        vocab_to_int = json.load(f)
    with open("model/int_vocab.json", "r", encoding="utf-8") as f:
        # NOTE: if this file holds a JSON object, its keys load as str, not int.
        int_to_vocab = json.load(f)
    model_concat_embed = LSTMConcatAttentionEmbed()
    # NOTE(review): no trained weights are loaded here, unlike load_bert, which
    # passes a weights path. Unless LSTMConcatAttentionEmbed loads weights in
    # its constructor, predictions use untrained parameters. Confirm, and add a
    # load_state_dict call if needed.
    # Inference only: put dropout/normalisation layers (if any) in eval mode.
    model_concat_embed.eval()
    return vocab_to_int, int_to_vocab, model_concat_embed


vocab_to_int, int_to_vocab, model_concat_embed = load_lstm()
def load_bert():
    """Build the ruBERT-tiny2 model and its tokenizer.

    The model comes from the project's ``load_model`` helper, given
    ``transformers.AutoModel``, the ``cointegrated/rubert-tiny2`` pretrained
    name and the weights file ``model/best_bert_weights.pth``. The tokenizer is
    loaded from the same pretrained name.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    checkpoint_name = "cointegrated/rubert-tiny2"
    bert_model = load_model(
        transformers.AutoModel,
        checkpoint_name,
        "model/best_bert_weights.pth",
    )
    bert_tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_name)
    return bert_model, bert_tokenizer


model, tokenizer = load_bert()
def plot_and_predict(
    review: str, SEQ_LEN: int, model: nn.Module, threshold: float = 0.75
) -> int:
    """Classify a single review with the LSTM model.

    Despite its name, this function does not plot anything. It encodes the
    review with ``preprocess_single_string`` using the module-level
    ``vocab_to_int``, runs one forward pass without autograd, and thresholds
    the sigmoid of the model's first output.

    Args:
        review: Raw review text.
        SEQ_LEN: Sequence length passed to ``preprocess_single_string``.
        model: Model whose forward returns a pair. Only the first element (a
            single logit) is used; the second is presumably attention weights.
        threshold: Probability above which the review is labelled positive.
            Defaults to 0.75, the value previously hard-coded here.

    Returns:
        1 if the sigmoid probability exceeds ``threshold``, else 0.
    """
    inp = preprocess_single_string(review, SEQ_LEN, vocab_to_int)
    with torch.inference_mode():
        # unsqueeze(0) adds a leading batch dimension of size 1.
        logit, _ = model(inp.long().unsqueeze(0))
    prob = logit.sigmoid().item()
    return 1 if prob > threshold else 0
def preprocess_text_logreg(text):
    """Clean *text* and vectorize it for the logistic-regression model.

    Args:
        text: Raw review text.

    Returns:
        The result of ``logreg_vectorizer.transform`` on a one-document batch.
    """
    # The input must actually be passed to the cleaner; previously it was not.
    clean_text = data_preprocessing(text)
    # NOTE(review): data_preprocessing is assumed to return a sequence of
    # tokens, which are re-joined with spaces into one document string. If it
    # returns a plain str, " ".join would insert spaces between characters.
    # Confirm.
    return logreg_vectorizer.transform([" ".join(clean_text)])
def predict_sentiment_logreg(text):
    """Predict the sentiment label of *text* with the logistic-regression model.

    Args:
        text: Raw review text.

    Returns:
        The unchanged output of ``logreg_predictor.predict`` for the single
        vectorized document (presumably a length-1 array if the predictor is
        a scikit-learn estimator).
    """
    return logreg_predictor.predict(preprocess_text_logreg(text))
# Validation f1-macro scores of the three available models, shown in the sidebar.
metrics = {
    "Models": ["Logistic Regression", "LSTM + attention", "ruBERTtiny2"],
    "f1-macro score": [0.94376, 0.93317, 0.94070],
}
df = pd.DataFrame(metrics)
df.set_index("Models", inplace=True)
df.index.name = "Model"

st.sidebar.title("Model Selection")
# NOTE(review): widget reconstructed as a radio button over the three fixed
# options; confirm against the intended UI (a selectbox would also fit).
model_type = st.sidebar.radio("Select Model Type", ["Classic ML", "LSTM", "BERT"])
# NOTE(review): df was otherwise never displayed; showing it under the selector
# is assumed to be the intent.
st.sidebar.table(df)
# Custom CSS classes for the page's HTML snippets: a magenta title and
# green/red result labels, each with a dark outline and shadow for contrast.
styled_text = """
<style>
.styled-title {
    color: #FF00FF;
    font-size: 40px;
    text-shadow: -2px -2px 4px #000000;
    -webkit-text-stroke-width: 1px;
    -webkit-text-stroke-color: #000000;
}
.positive {
    color: #00FF00;
    font-size: 30px;
    text-shadow: -2px -2px 4px #000000;
    -webkit-text-stroke-width: 1px;
    -webkit-text-stroke-color: #000000;
}
.negative {
    color: #FF0000;
    font-size: 30px;
    text-shadow: -2px -2px 4px #000000;
    -webkit-text-stroke-width: 1px;
    -webkit-text-stroke-color: #000000;
}
</style>
"""
# unsafe_allow_html lets Streamlit inject the <style> tag instead of escaping it.
st.markdown(styled_text, unsafe_allow_html=True)
# Streamlit app code: read a review, classify it with the model chosen in the
# sidebar, and render the verdict with the .positive / .negative CSS classes.
st.markdown('<div class="styled-title">Review Prediction</div>', unsafe_allow_html=True)
text_input = st.text_input("Enter your review:")

if st.button("Predict"):
    if not text_input.strip():
        # Nothing to classify; don't run the models on an empty string.
        st.warning("Please enter a review first.")
    else:
        if model_type == "Classic ML":
            prediction = predict_sentiment_logreg(text_input)
        elif model_type == "LSTM":
            prediction = plot_and_predict(
                review=text_input, SEQ_LEN=25, model=model_concat_embed
            )
        else:  # "BERT": the only remaining sidebar option
            prediction = predict_sentiment(text_input, model, tokenizer, "cpu")

        # Apply different styles based on prediction result.
        # Labels: "Review is positive" / "Review is negative".
        if prediction == 1:
            st.markdown(
                '<div class="positive">Отзыв положительный</div>',
                unsafe_allow_html=True,
            )
        elif prediction == 0:
            st.markdown(
                '<div class="negative">Отзыв отрицательный</div>',
                unsafe_allow_html=True,
            )