# bert_sentiment7 / app.py
# jonghhhh's picture
# Update app.py
# 2aff68f verified
# raw
# history blame
# 1.22 kB
import streamlit as st
import torch
from transformers import BertConfig, BertForSequenceClassification, BertTokenizer
import numpy as np
import requests
from io import BytesIO
# Load the model and tokenizer
@st.cache_resource  # Streamlit re-runs this script on every interaction; keep one loaded copy
def load_model():
    """Build the 7-class BERT classifier and its tokenizer.

    The tokenizer and the base architecture come from ``bert-base-uncased``;
    the model weights are then replaced by the fine-tuned state dict stored
    in ``sentiment7_model_acc8878.pth`` (a path relative to the working
    directory).

    Returns:
        tuple: ``(model, tokenizer)`` with the model switched to eval mode.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=7)
    # map_location places every tensor of the checkpoint on the CPU.
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    model_state_dict = torch.load('sentiment7_model_acc8878.pth', map_location=torch.device('cpu'))
    model.load_state_dict(model_state_dict)
    model.eval()
    return model, tokenizer
# Module-level model and tokenizer; inference() reads these two globals.
model, tokenizer = load_model()
# Define the inference function
def inference(input_doc):
    """Return the predicted probability of each of the seven classes.

    Args:
        input_doc: The text to classify (a single string).

    Returns:
        dict: Maps each class label to its softmax probability (float),
        ordered by class index.
    """
    # Class label -> index of the corresponding model output logit.
    # NOTE(review): the labels were restored from mis-encoded text (UTF-8
    # Korean that had been decoded with a Thai code page); confirm them
    # against the label map used for training.
    class_idx = {'공포': 0, '놀람': 1, '분노': 2, '슬픔': 3, '중립': 4, '행복': 5, '혐오': 6}
    # Pair labels with probabilities by index rather than by literal order.
    labels = sorted(class_idx, key=class_idx.get)

    # Truncate long input: BERT accepts at most 512 positions, and a longer
    # sequence would make the forward pass fail.
    inputs = tokenizer(input_doc, return_tensors='pt', truncation=True, max_length=512)

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    probs = torch.softmax(outputs.logits, dim=1).squeeze().tolist()
    return dict(zip(labels, probs))
# Set up the Streamlit interface: a title, a text box and an Analyze button.
st.title('Sentiment Analysis with BERT')
user_input = st.text_area("Enter text here:")
if st.button('Analyze'):
    if user_input.strip():
        # Show the label -> probability mapping returned by inference().
        result = inference(user_input)
        st.write(result)
    else:
        # Nothing to classify: skip the model call instead of scoring an
        # empty string.
        st.warning('Please enter some text to analyze.')