# bert_sentiment7 / app.py
# Source: Hugging Face file page (author: jonghhhh); commit 71675d3 "Update app.py" (verified).
# Page metadata: 1.18 kB, flagged "No virus".
import streamlit as st
import torch
from transformers import BertConfig, BertForSequenceClassification, BertTokenizer
import numpy as np
# Load the model and tokenizer
@st.cache_resource
def load_model():
    """Build the 7-label BERT classifier, load the fine-tuned weights, and return it with its tokenizer.

    Decorated with ``st.cache_resource`` because Streamlit re-executes this
    whole script on every widget interaction. Without caching, the tokenizer,
    the BERT architecture and the fine-tuned checkpoint would be reloaded on
    every click of "Analyze".

    Returns:
        tuple: ``(model, tokenizer)``.
            - ``model`` is a ``BertForSequenceClassification`` with 7 output
              labels, loaded on the CPU and switched to eval mode.
            - ``tokenizer`` is the ``bert-base-uncased`` ``BertTokenizer``.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # Build the architecture from its config only. The checkpoint below is
    # loaded strictly and overwrites every parameter, so downloading and
    # initialising the pretrained BERT weights first would be wasted work.
    config = BertConfig.from_pretrained('bert-base-uncased', num_labels=7)
    model = BertForSequenceClassification(config)
    # Use the CPU: map_location remaps any GPU-saved tensors onto the CPU.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust
    # (this one is a local file shipped with the app).
    model_state_dict = torch.load('sentiment7_model_acc8878.pth', map_location=torch.device('cpu'))
    model.load_state_dict(model_state_dict)
    # Inference mode (disables dropout).
    model.eval()
    return model, tokenizer


model, tokenizer = load_model()
# Define the inference function
def inference(input_doc):
    """Classify the emotion expressed in ``input_doc`` with the fine-tuned BERT model.

    Uses the module-level ``model`` and ``tokenizer`` returned by ``load_model()``.

    Args:
        input_doc: Text to classify. Text longer than BERT's 512-token limit
            is truncated instead of crashing the forward pass.

    Returns:
        dict: Maps each of the 7 emotion labels to its softmax probability.
        Entries appear in output-index order.
    """
    # Label name -> output index of the classification head.
    # 공포=fear, 놀람=surprise, 분노=anger, 슬픔=sadness, 중립=neutral,
    # 행복=happiness, 혐오=disgust.
    class_idx = {'공포': 0, '놀람': 1, '분노': 2, '슬픔': 3, '중립': 4, '행복': 5, '혐오': 6}
    # bert-base has 512 position embeddings; longer inputs would fail with
    # an index error inside the model.
    inputs = tokenizer(input_doc, return_tensors='pt', truncation=True, max_length=512)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    # A single input gives a batch of one; take its row of probabilities.
    probs = torch.softmax(outputs.logits, dim=-1)[0].tolist()
    # Pair each label with the probability at its own stored index rather
    # than relying on the dict's insertion order.
    return {class_name: probs[idx] for class_name, idx in class_idx.items()}
# Set up the Streamlit interface.
# Streamlit runs this top-level code on every interaction; the classifier
# only runs on the rerun triggered by clicking "Analyze".
st.title('Sentiment Analysis with BERT')
user_input = st.text_area("Enter text here:")
if st.button('Analyze'):
    if user_input.strip():
        result = inference(user_input)
        # Show the label -> probability mapping.
        st.write(result)
    else:
        # Empty text would still produce a probability distribution, which
        # would be displayed as if it were a real prediction.
        st.warning("Please enter some text to analyze.")