Spaces:
Paused
Paused
import streamlit as st | |
from functions_preprocess import LinguisticPreprocessor, download_if_non_existent, CNN | |
import pickle | |
import nltk | |
import torch | |
# Fetch the NLTK data used by the app. The first two packages are always
# requested through nltk.download; the remaining resources go through the
# project helper, which receives a resource path and a package name
# (presumably it skips the download when the resource is already present --
# confirm in functions_preprocess).
for _package in ('stopwords', 'punkt'):
    nltk.download(_package)
for _resource_path, _package in (
    ('corpora/stopwords', 'stopwords'),
    ('taggers/averaged_perceptron_tagger', 'averaged_perceptron_tagger'),
    ('corpora/wordnet', 'wordnet'),
):
    download_if_non_existent(_resource_path, _package)
#################################################################### Streamlit interface
# Page heading shown at the top of the app.
st.title("Movie Reviews: An NLP Sentiment analysis")
#################################################################### Cache the model loading | |
@st.cache_resource
def load_model():
    """Load the pickled sentiment model from ``sentiment_model.pkl``.

    Streamlit re-executes the whole script on every widget interaction, so
    the loader is wrapped in ``st.cache_resource``: the file is unpickled
    once and the same object is handed back on later reruns.

    Returns:
        The object stored in ``sentiment_model.pkl``.

    Raises:
        FileNotFoundError: If ``sentiment_model.pkl`` is not in the working
            directory.
    """
    model_pkl_file = "sentiment_model.pkl"
    # NOTE: pickle.load can execute arbitrary code while deserialising; only
    # ship model files that come from a trusted source.
    with open(model_pkl_file, 'rb') as file:
        return pickle.load(file)
@st.cache_resource
def load_cnn():
    """Build the CNN and restore its weights from ``model_cnn.pkl``.

    Wrapped in ``st.cache_resource`` so the network is constructed and its
    weights are read from disk once, instead of on every Streamlit rerun.

    Returns:
        The ``CNN`` instance with the saved state dict loaded, switched to
        evaluation mode.
    """
    # The positional hyper-parameters are defined by the project's CNN class
    # (not visible here); they have to be compatible with the saved state
    # dict for load_state_dict to succeed.
    model = CNN(16236, 300, 128, [3, 8], 0.5, 2)
    # map_location='cpu' lets a checkpoint that was saved from a GPU be read
    # on a CPU-only host; load_state_dict then copies the values into the
    # parameters of the freshly built model.
    state_dict = torch.load('model_cnn.pkl', map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model
def predict_sentiment(text, model):
    """Return ``model.predict`` for a single piece of text.

    Args:
        text: The raw text to classify.
        model: Any object exposing ``predict``; it is called with a
            one-element list containing ``text``.

    Returns:
        Whatever ``model.predict([text])`` returns.
    """
    # NOTE(review): the value returned by transform() is discarded, so the
    # model below receives the raw, unprocessed text. Either the call is
    # redundant or its result should be what is passed to predict() --
    # confirm against LinguisticPreprocessor and how the models were trained.
    processor.transform(text)
    return model.predict([text])
# Module-level objects shared by the UI code below. The two loaders run left
# to right, in the same order as before. `processor` is read as a global by
# predict_sentiment, so it must exist before the first prediction is made.
model_1, model_2 = load_model(), load_cnn()
processor = LinguisticPreprocessor()
############################################################# Text input
def _render_model_section(title, model, key_prefix):
    """Render one expander holding a text box and an Analyze button.

    Args:
        title: Label shown on the expander header.
        model: Model object forwarded to ``predict_sentiment``.
        key_prefix: Prefix for the widget keys. The text area is keyed
            ``<prefix>_input`` and the button ``<prefix>_button`` so that no
            two widgets on the page share a key.
    """
    with st.expander(title):
        st.markdown("Give it a go by writing a positive or negative text, and analyze it!")
        # Text input inside the expander
        user_input = st.text_area("Enter text here...", key=f"{key_prefix}_input")
        # The button needs a key of its own: reusing the text area's key
        # registers two widgets under one identifier.
        if not st.button('Analyze', key=f"{key_prefix}_button"):
            return
        # Nothing to classify when the box is empty or whitespace only.
        if not user_input.strip():
            st.warning("Please enter some text to analyze.")
            return
        # Displaying output
        result = predict_sentiment(user_input, model)
        # st.write is not a keyed widget, so no `key` argument is passed.
        # NOTE(review): the trailing 'π' in both messages looks like a
        # mis-encoded emoji -- confirm the intended characters.
        if result >= 0.5:
            st.write('The sentiment is: Positive π')
        else:
            st.write('The sentiment is: Negative π')


_render_model_section("Model 1: SGD Classifier", model_1, "model1")
_render_model_section("Model 2: CNN Sentiment analysis", model_2, "model2")
st.caption("Por @efeperro.")