# trytry / app.py
# Uploaded by rowanm945 - commit cb29d9b (verified): "Update app.py"
# -*- coding: utf-8 -*-
"""
UseCase_with_Streamlit.py
This basic Streamlit app fetches Reddit posts from a few subreddits over the past 14 days,
computes sentiment scores using a PyTorch model, forecasts a 7-day sentiment trend using a pre-trained forecast model,
and displays the forecast plot.
Note: No extra logging or scheduling is included.
"""
import os, re, datetime, io
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline
import matplotlib.font_manager as fm
import joblib
import torch
import torch.nn as nn
from transformers import AutoTokenizer
import praw
import streamlit as st
# -------------------------------
# Inject custom CSS for Afacada font styling
# -------------------------------
# Raw <style> payload injected into the page: a body background colour, an
# @font-face declaration for the custom font, and title/button styling.
# NOTE(review): the @font-face `src` is a relative URL - confirm the hosting
# environment actually serves the .ttf at that path.
_CUSTOM_CSS = """
<style>
body {
background-color: #fffff2;
}
@font-face {
font-family: 'Afacada';
src: url('AfacadFlux-VariableFont_slnt,wght[1].ttf') format('truetype');
font-weight: normal;
font-style: normal;
}
/* Title styling */
h1 {
font-family: 'Afacada', sans-serif;
color: #244B48;
}
/* Button styling */
.stButton>button {
font-family: 'Afacada', sans-serif;
font-size: 20px;
padding: 0.75rem 1.5rem;
background-color: #244B48;
color: white;
border: none;
border-radius: 4px;
}
.stButton>button:hover {
background-color: #1f3e38;
color: white;
}
.stButton>button:active, .stButton>button:focus {
background-color: #244B48;
color: white;
outline: none;
}
</style>
"""

# unsafe_allow_html is required so the <style> tag is emitted as markup
# rather than escaped text.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# -------------------------------
# Load Models and Tokenizer
# -------------------------------
# Pre-trained forecaster deserialised with joblib; the interface code below
# calls its .predict() on a single row of daily sentiment values.
# NOTE(review): joblib.load unpickles arbitrary objects - only load a .pkl
# that ships with the app and comes from a trusted source.
# NOTE(review): these loads sit at module level, so they run every time the
# script is executed - consider caching if the host re-runs the script often.
sentiment_model = joblib.load('sentiment_forecast_model.pkl')
# Hugging Face model id; only its tokenizer is loaded from it in this file.
MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
# Tokenizer that produces the input ids for ScorePredictor and supplies the
# vocabulary size used to build its embedding table.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
class ScorePredictor(nn.Module):
    """Embedding -> LSTM -> linear -> sigmoid scorer over token ids.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding.
        hidden_dim: Size of the LSTM hidden state.
        output_dim: Number of values produced per input sequence.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
        super().__init__()
        # The attribute names double as state_dict keys, so they must not be
        # renamed or the saved checkpoint will no longer load.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Score each sequence in the batch.

        Args:
            input_ids: Integer tensor of token ids, batch first.
            attention_mask: Accepted for interface compatibility; the
                computation below does not read it.

        Returns:
            Tensor of sigmoid outputs, one row per input sequence.
        """
        token_vectors = self.embedding(input_ids)
        sequence_states = self.lstm(token_vectors)[0]
        # The LSTM output at the last time step summarises the sequence.
        last_step = sequence_states[:, -1]
        return self.sigmoid(self.fc(last_step))
# Build the scorer with the tokenizer's vocabulary size and restore its
# trained weights. map_location="cpu" lets a checkpoint that was saved from a
# GPU load on a CPU-only host; without it torch.load tries to restore the
# tensors onto the device they were saved from and fails when it is absent.
# The model itself is never moved off the CPU, so the loaded result is the
# same wherever the checkpoint was produced.
score_model = ScorePredictor(tokenizer.vocab_size)
score_model.load_state_dict(torch.load("score_predictor.pth", map_location="cpu"))
# Inference only: switch the module to evaluation mode.
score_model.eval()
# -------------------------------
# Set up Reddit API Client
# -------------------------------
# Keyword arguments for the PRAW client. The credentials are read from the
# environment (os.environ.get yields None when a variable is unset); the
# user agent and check_for_async values are fixed.
_reddit_settings = {
    "client_id": os.environ.get("REDDIT_CLIENT_ID"),
    "client_secret": os.environ.get("REDDIT_CLIENT_SECRET"),
    "user_agent": 'MyAPI/0.0.1',
    "check_for_async": False,
}
# Module-level client shared by fetch_posts.
reddit = praw.Reddit(**_reddit_settings)
# -------------------------------
# Helper Functions
# -------------------------------
def fetch_posts(subreddit_name, start_time, limit=100):
    """Collect recent post titles from one subreddit.

    Args:
        subreddit_name: Name passed to ``reddit.subreddit``.
        start_time: Cut-off ``datetime``; posts created before it are skipped.
            A naive value is compared as UTC; a timezone-aware value is
            converted to UTC first.
        limit: Maximum number of posts requested from the ``new`` listing.

    Returns:
        A list of dicts with keys ``"date"`` (UTC creation time formatted as
        ``%Y-%m-%d %H:%M:%S``) and ``"post_text"`` (the post title).
    """
    utc = datetime.timezone.utc
    # Post times are handled as naive UTC below, so bring an aware cut-off
    # into the same representation instead of raising TypeError on comparison.
    if start_time.tzinfo is not None:
        start_time = start_time.astimezone(utc).replace(tzinfo=None)

    posts = []
    for post in reddit.subreddit(subreddit_name).new(limit=limit):
        # Same value datetime.utcfromtimestamp() produced, without relying on
        # that call, which is deprecated since Python 3.12.
        post_time = datetime.datetime.fromtimestamp(
            post.created_utc, tz=utc
        ).replace(tzinfo=None)
        if post_time >= start_time:
            posts.append({
                "date": post_time.strftime('%Y-%m-%d %H:%M:%S'),
                "post_text": post.title,
            })
    return posts
def predict_score(text):
    """Return the sentiment score of ``text`` as a float.

    Args:
        text: Post title to score. ``None``, an empty string, or a string
            containing only whitespace is treated as "no text".

    Returns:
        ``0.0`` when there is no text; otherwise the first element of the
        ``score_model`` output for the encoded text.
    """
    if not text or not text.strip():
        return 0.0
    # Encode the whole title as ONE sequence. Passing text.split() made the
    # tokenizer treat every word as a separate batch item, so the [0] below
    # returned the score of the first word only.
    # NOTE(review): confirm this matches how the scorer's training data was
    # tokenized.
    encoded = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        scores = score_model(encoded["input_ids"], encoded["attention_mask"])
    return scores[0].item()
# -------------------------------
# Streamlit Interface
# -------------------------------
_HISTORY_DAYS = 14   # look-back window and number of values fed to the forecaster
_FORECAST_DAYS = 7   # number of forecast points plotted
_FONT_PATH = "AfacadFlux-VariableFont_slnt,wght[1].ttf"


def _history_window(daily_sentiment, days=_HISTORY_DAYS):
    """Return exactly ``days`` daily mean scores as a 1-D float array.

    A look-back of ``days`` days measured from "now" can touch ``days + 1``
    calendar dates, so only the most recent ``days`` entries are kept. When
    fewer dates have posts, the mean of the observed values is appended until
    the window is full.
    NOTE(review): padding is appended after the observed values regardless of
    which dates are missing - confirm that is what the forecaster expects.
    """
    recent = daily_sentiment.tail(days)
    values = recent.to_numpy(dtype=float)
    missing = days - len(values)
    if missing > 0:
        values = np.concatenate([values, np.full(missing, recent.mean())])
    return values


def _plot_forecast(forecast, day_labels):
    """Draw the smoothed forecast chart and return it as an in-memory PNG."""
    # Fall back to matplotlib's default font when the custom font file is not
    # present, rather than letting the missing file break rendering.
    if os.path.exists(_FONT_PATH):
        label_font = fm.FontProperties(fname=_FONT_PATH)
    else:
        label_font = fm.FontProperties()

    positions = np.arange(_FORECAST_DAYS)
    # Cubic spline through the forecast points, sampled densely for a smooth curve.
    xnew = np.linspace(0, _FORECAST_DAYS - 1, 300)
    smooth_forecast = make_interp_spline(positions, forecast, k=3)(xnew)

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.fill_between(xnew, smooth_forecast, color='#244B48', alpha=0.4)
        ax.plot(xnew, smooth_forecast, color='#244B48', lw=3)
        ax.scatter(positions, forecast, color='#244B48', s=50)
        ax.set_title("7-Day Sentiment Forecast", fontproperties=label_font, fontsize=20)
        ax.set_xlabel("Day", fontproperties=label_font, fontsize=14)
        ax.set_ylabel("Sentiment", fontproperties=label_font, fontsize=14)
        ax.set_xticks(positions)
        ax.set_xticklabels(day_labels, fontproperties=label_font, fontsize=12)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated runs do not accumulate figures in memory.
        plt.close(fig)
    buf.seek(0)
    return buf


st.title("7-Day Sentiment Forecast")
if st.button("Run Analysis"):
    subreddits = ["ohio", "libertarian", "centrist"]
    # Naive UTC "now": fetch_posts compares against naive UTC post times.
    # Equivalent to datetime.utcnow(), which is deprecated since Python 3.12.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    start_time = now_utc - datetime.timedelta(days=_HISTORY_DAYS)

    all_posts = []
    for sub in subreddits:
        all_posts.extend(fetch_posts(sub, start_time))

    if not all_posts:
        st.error("No posts fetched.")
    else:
        df = pd.DataFrame(all_posts)
        df['date'] = pd.to_datetime(df['date'])
        df['date_only'] = df['date'].dt.date
        df = df.sort_values(by='date_only')
        df['sentiment_score'] = df['post_text'].apply(predict_score)

        # One mean score per calendar date, then a fixed-length model input.
        daily_sentiment = df.groupby('date_only')['sentiment_score'].mean()
        model_input = _history_window(daily_sentiment).reshape(1, -1)
        forecast = sentiment_model.predict(model_input)[0]

        # X-axis labels: today plus the following six days.
        today = datetime.date.today()
        days_str = [
            (today + datetime.timedelta(days=i)).strftime('%a %m/%d')
            for i in range(_FORECAST_DAYS)
        ]
        st.image(_plot_forecast(forecast, days_str), caption="Forecast Plot")