# new_thing / app.py
# Uploaded by junaid17 — commit 6779b8d ("Upload 6 files", verified)
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizer, DistilBertModel
import time
# Page title, browser-tab icon, layout, and initial sidebar state.
st.set_page_config(
    page_title="TwittoBERT",
    page_icon="🐦",
    layout="centered",
    initial_sidebar_state="expanded",
)

# Custom CSS for the dark theme, injected as raw HTML.
# - Widget selectors target Streamlit's generated class names (.stTextInput,
#   .stTextArea, .stButton). The app's tweet input is a text area, so the
#   .stTextArea rule is the one that actually styles it.
# - .positive / .neutral / .negative are modifiers that recolor the left
#   border of the .prediction-box result card.
# NOTE(review): .header and .sample-tweet are not referenced by any markup in
# this file — confirm before removing.
st.markdown("""
<style>
:root {
--primary-color: #1DA1F2;
--background-color: #0F0F0F;
--secondary-background: #1E1E1E;
--text-color: #FFFFFF;
--font: sans-serif;
}
body {
background-color: var(--background-color);
color: var(--text-color);
font-family: var(--font);
}
.stApp {
background-color: var(--background-color);
}
.stTextInput>div>div>input {
background-color: var(--secondary-background);
color: var(--text-color);
border: 1px solid #333;
}
.stTextArea textarea {
background-color: var(--secondary-background);
color: var(--text-color);
border: 1px solid #333;
}
.stButton>button {
background-color: var(--primary-color);
color: white;
border-radius: 8px;
padding: 0.5rem 1rem;
border: none;
font-weight: bold;
transition: all 0.3s;
}
.stButton>button:hover {
background-color: #1991db;
transform: scale(1.02);
}
.prediction-box {
padding: 1.5rem;
border-radius: 10px;
margin: 1.5rem 0;
background-color: var(--secondary-background);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
border-left: 5px solid var(--primary-color);
}
.header {
color: var(--primary-color);
}
.positive {
border-left-color: #4CAF50;
}
.neutral {
border-left-color: #FFCC00;
}
.negative {
border-left-color: #FF4D4D;
}
.sample-tweet {
padding: 0.5rem;
margin: 0.5rem 0;
border-radius: 5px;
background-color: var(--secondary-background);
cursor: pointer;
transition: all 0.2s;
}
.sample-tweet:hover {
background-color: #2A2A2A;
}
</style>
""", unsafe_allow_html=True)
# SentimentClassifier model definition
class SentimentClassifier(torch.nn.Module):
    """Three-class tweet sentiment classifier: frozen DistilBERT + MLP head.

    The encoder's parameters are frozen, so only ``classifier`` is trainable.
    The attribute names (``bert``, ``classifier``) and the layer order inside
    ``classifier`` determine the state_dict keys that ``load_state_dict`` must
    match when the saved checkpoint is loaded, so keep them unchanged.
    """

    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        # Freeze every encoder parameter.
        self.bert.requires_grad_(False)

        # Head: 768 -> 256 -> 128 -> 64 -> 3. Each hidden stage is
        # Linear, BatchNorm1d, ReLU, Dropout(0.3), giving Sequential indices
        # 0..11 for the hidden stages and 12 for the output Linear.
        widths = (768, 256, 128, 64)
        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages += [
                torch.nn.Linear(fan_in, fan_out),
                torch.nn.BatchNorm1d(fan_out),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.3),
            ]
        stages.append(torch.nn.Linear(widths[-1], 3))
        self.classifier = torch.nn.Sequential(*stages)

    def forward(self, input_ids, attention_mask):
        """Return raw (pre-softmax) class logits for a batch of token ids."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Sentence embedding = hidden state at position 0 of each sequence
        # (presumably the [CLS] token the tokenizer prepends — confirm).
        first_token = encoded.last_hidden_state[:, 0, :]
        return self.classifier(first_token)
# Load model and tokenizer
@st.cache_resource
def load_model():
    """Build ``SentimentClassifier`` and load its trained weights.

    Decorated with ``st.cache_resource`` so the model is built once and the
    same instance is reused on later script runs.

    The checkpoint path ``BERT_MODEL.pth`` is relative to the current working
    directory. Weights are mapped onto the CPU.

    Returns:
        The model in eval mode (dropout off, BatchNorm using its running
        statistics).

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise when the checkpoint
        is missing or does not match the architecture (e.g. FileNotFoundError,
        RuntimeError).
    """
    model = SentimentClassifier()
    # The file holds a state_dict (it is passed straight to load_state_dict).
    # weights_only=True restricts unpickling to tensors and plain containers.
    # A full pickle load could execute arbitrary code embedded in the file.
    state_dict = torch.load(
        "BERT_MODEL.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model
@st.cache_resource
def load_tokenizer():
    """Return the ``distilbert-base-uncased`` tokenizer.

    It is cached by ``st.cache_resource``, and it uses the same checkpoint
    name as the encoder inside ``SentimentClassifier``.
    """
    checkpoint = 'distilbert-base-uncased'
    tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
    return tokenizer
# Prediction function
def predict_sentiment(model, tokenizer, tweet):
    """Classify one tweet and return ``(label, confidence_percent)``.

    The text is tokenized to exactly 200 tokens (padded or truncated) as
    PyTorch tensors. It is run through ``model`` with gradient tracking
    disabled, and the most probable class under a softmax over the logits
    wins.

    Args:
        model: callable taking ``(input_ids, attention_mask)`` positionally and
            returning logits of shape (1, 3) for a single tweet.
        tokenizer: Hugging Face-style tokenizer returning a mapping with
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        tweet: raw tweet text.

    Returns:
        ``(label, confidence)``, where label is "Negative", "Neutral" or
        "Positive" and confidence is the winning probability in percent.
    """
    encoded = tokenizer(
        tweet,
        padding="max_length",
        max_length=200,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        scores = model(encoded["input_ids"], encoded["attention_mask"])
        probabilities = F.softmax(scores, dim=1)
        best_prob, best_idx = probabilities.max(dim=1)
    # Index order must match the classifier's output units.
    labels = ("Negative", "Neutral", "Positive")
    return labels[best_idx.item()], best_prob.item() * 100
# CSS modifier class for each predicted label. The stylesheet's
# .negative/.neutral/.positive rules recolor the result card's left border.
_SENTIMENT_CSS_CLASS = {
    "Negative": "negative",
    "Neutral": "neutral",
    "Positive": "positive",
}


def _render_prediction(label, confidence):
    """Show the styled result card for ``label`` with ``confidence`` in percent."""
    # Labels outside the map fall back to the "positive" style.
    css_class = _SENTIMENT_CSS_CLASS.get(label, "positive")
    # The HTML is built by concatenation, not an indented triple-quoted
    # string. Leading spaces could otherwise make Markdown render it as a
    # code block.
    st.markdown(
        f'<div class="prediction-box {css_class}">'
        f"<h3>Sentiment: {label}</h3>"
        f"<p>Confidence: {confidence:.2f}%</p>"
        "</div>",
        unsafe_allow_html=True,
    )


def main():
    """Render the TwittoBERT page.

    - Loads the cached model and tokenizer. If that fails, it shows an error
      and stops the script.
    - Sample-tweet buttons pre-fill the input box via
      ``st.session_state["sample_tweet"]``.
    - Clicking "Analyze Sentiment" classifies the entered text.
    - Finally, the sidebar is filled in.
    """
    st.title("🐦 TwittoBERT")
    st.markdown("Analyze the sentiment of tweets using a fine-tuned BERT model", unsafe_allow_html=True)

    # Load model and tokenizer; nothing below is usable without them.
    try:
        model = load_model()
        tokenizer = load_tokenizer()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.stop()

    # Sample tweets. Clicking one stores it in session_state, and the text
    # area below uses it as its initial value. It does not run the model.
    st.subheader("Try these sample tweets:")
    sample_tweets = [
        "I love this product! It's absolutely amazing! 😍",
        "The service was okay, nothing special.",
        "This is the worst experience I've ever had. Terrible!",
        "Just had the best coffee of my life at this new café!",
        "The movie was decent but could have been better.",
        "I'm so frustrated with this terrible customer service!"
    ]
    cols = st.columns(2)
    for i, sample in enumerate(sample_tweets):
        button_label = sample[:50] + "..." if len(sample) > 50 else sample
        with cols[i % 2]:
            if st.button(button_label,
                         key=f"sample_{i}",
                         help="Click to load this tweet into the text box"):
                st.session_state.sample_tweet = sample

    # Tweet input
    tweet = st.text_area("Or enter your own tweet to analyze:",
                         height=100,
                         placeholder="Type your tweet here...",
                         value=st.session_state.get("sample_tweet", ""))

    if st.button("Analyze Sentiment"):
        text = tweet.strip()
        if not text:
            # Empty or whitespace-only input is not sent to the model.
            st.warning("Please enter a tweet to analyze.")
        else:
            with st.spinner("Analyzing sentiment..."):
                label, confidence = predict_sentiment(model, tokenizer, text)
            _render_prediction(label, confidence)

    # Sidebar info
    st.sidebar.header("About")
    st.sidebar.markdown("""
This app uses a fine-tuned DistilBERT model to analyze sentiment in tweets.
It can classify tweets as Positive, Negative, or Neutral with confidence scores.
""")
    st.sidebar.header("Model Info")
    st.sidebar.text("Model: DistilBERT-base-uncased")
    st.sidebar.text("Classes: Negative, Neutral, Positive")
# Script entry point: build the page when this file is executed as the main
# module.
if __name__ == "__main__":
    main()