AleVento's picture
Update app.py
40124a1
raw
history blame
1.16 kB
import streamlit as st
import torch
from transformers import BertTokenizer
BERT_MODEL_NAME = 'bert-base-cased'
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
# NOTE(review): Streamlit re-executes this whole script on every interaction,
# so the tokenizer and model are reloaded each time; consider caching them
# (e.g. ``st.cache_resource`` on Streamlit versions that provide it).
#
# The inputs built in this app (return_tensors='pt') are CPU tensors, so load
# the model onto the CPU regardless of the device it was saved from; otherwise
# a checkpoint saved on GPU fails to load or mismatches the input device.
# NOTE(review): torch.load unpickles the whole object (only load trusted
# files). PyTorch >= 2.6 defaults to weights_only=True, which rejects a pickled
# nn.Module — confirm the installed version and pass weights_only=False if
# loading fails.
model = torch.load("./model.pt", map_location="cpu")
# Put the model in inference mode so layers such as dropout are disabled and
# repeated runs on the same text produce the same scores.
model.eval()
st.title("Analisis de Sentimientos")
txt = st.text_area(label="Please write what you want to analyze...")
def run_sentiment_analysis(txt, threshold=0.5, label_columns=None):
    """Score *txt* with the model and return the labels that reach *threshold*.

    The text is tokenized with the module-level ``tokenizer`` and passed to the
    module-level ``model``; the model's second return value is flattened and
    treated as one score per label, in the order of *label_columns*.

    Args:
        txt: Text to analyze.
        threshold: Minimum score for a label to be reported; labels scoring
            below it are skipped. Defaults to 0.5 (the original fixed value).
        label_columns: Label names in the same order as the model's outputs.
            Defaults to the module-level ``LABEL_COLUMNS``.

    Returns:
        A list of ``"<label> <score>"`` strings, in label order.

    Raises:
        ValueError: If the number of label names differs from the number of
            scores the model produced (``zip`` would otherwise silently drop
            the extras).
    """
    if label_columns is None:
        # NOTE(review): LABEL_COLUMNS is not defined anywhere in this file, so
        # this raises NameError unless it is defined elsewhere — TODO define it
        # with the model's label names, in output order.
        label_columns = LABEL_COLUMNS

    encoding = tokenizer.encode_plus(
        txt,
        add_special_tokens=True,
        max_length=512,
        # Without truncation, inputs longer than max_length are not cut and
        # exceed BERT's 512-position limit.
        truncation=True,
        return_token_type_ids=False,
        padding="max_length",
        return_attention_mask=True,
        return_tensors='pt',
    )

    # Inference only: no_grad avoids building an autograd graph; without it,
    # ``.numpy()`` raises if the output tensor requires grad.
    with torch.no_grad():
        _, test_prediction = model(encoding["input_ids"], encoding["attention_mask"])
    scores = test_prediction.flatten().cpu().numpy()

    if len(scores) != len(label_columns):
        raise ValueError(
            f"model produced {len(scores)} scores but {len(label_columns)} "
            "label names were given"
        )

    print('-------------------- Predictions ---------------------')
    return [
        f"{label} {score!s}"
        for label, score in zip(label_columns, scores)
        if score >= threshold
    ]
# Streamlit re-runs this script on every interaction and the text area starts
# empty, so only invoke the model once the user has entered non-blank text.
if txt.strip():
    predictions = run_sentiment_analysis(txt)
    for prediction in predictions:
        st.write(prediction)