""" |
UseCase_with_Streamlit.py |
This basic Streamlit app fetches Reddit posts from a few subreddits over the past 14 days, |
computes sentiment scores using a PyTorch model, forecasts a 7-day sentiment trend using a pre-trained forecast model, |
and displays the forecast plot. |
Note: No extra logging or scheduling is included. |
""" |
import os, re, datetime, io |
import numpy as np |
import pandas as pd |
import matplotlib.pyplot as plt |
from scipy.interpolate import make_interp_spline |
import matplotlib.font_manager as fm |
import joblib |
import torch |
import torch.nn as nn |
from transformers import AutoTokenizer |
import praw |
import streamlit as st |
# Page-wide look and feel, injected as raw HTML/CSS. It sets the page
# background, declares a custom 'Afacada' font face, and styles the <h1>
# title plus every Streamlit button (normal, hover, active/focus states).
# NOTE(review): the @font-face `src` is a relative URL; confirm the app is
# configured to serve that .ttf, otherwise the browser falls back to sans-serif.
_PAGE_CSS = """
<style>
body {
background-color: #fffff2;
}
@font-face {
font-family: 'Afacada';
src: url('AfacadFlux-VariableFont_slnt,wght[1].ttf') format('truetype');
font-weight: normal;
font-style: normal;
}
/* Title styling */
h1 {
font-family: 'Afacada', sans-serif;
color: #244B48;
}
/* Button styling */
.stButton>button {
font-family: 'Afacada', sans-serif;
font-size: 20px;
padding: 0.75rem 1.5rem;
background-color: #244B48;
color: white;
border: none;
border-radius: 4px;
}
.stButton>button:hover {
background-color: #1f3e38;
color: white;
}
.stButton>button:active, .stButton>button:focus {
background-color: #244B48;
color: white;
outline: none;
}
</style>
"""
# unsafe_allow_html is required, otherwise Streamlit escapes the <style> tag.
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# Hugging Face checkpoint whose tokenizer supplies the token ids (and the
# vocabulary size) used by ScorePredictor below.
MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"

# Pre-trained forecaster. Further down it is given one row of daily mean
# sentiment values, and the first row of its prediction is plotted as a
# 7-point forecast.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code;
# only ever point this at a trusted, locally produced artifact.
sentiment_model = joblib.load('sentiment_forecast_model.pkl')

tokenizer = AutoTokenizer.from_pretrained(MODEL)
class ScorePredictor(nn.Module):
    """LSTM regressor mapping a sequence of token ids to score(s) in (0, 1).

    Token ids are embedded, run through a single-layer unidirectional LSTM,
    and the LSTM output at the last *real* (unmasked) token is projected to
    `output_dim` values and squashed with a sigmoid.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
        super().__init__()
        # padding_idx=0 keeps the embedding of token id 0 fixed at zero.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask=None):
        """Score a batch of token-id sequences.

        Args:
            input_ids: integer tensor of shape (batch, seq_len).
            attention_mask: optional (batch, seq_len) tensor, 1 for real tokens
                and 0 for padding. When given, each row's score is taken from
                the LSTM output at that row's last real token. When None, the
                last time step is used for every row.

        Returns:
            Tensor of shape (batch, output_dim) with values in (0, 1).
        """
        embedded = self.embedding(input_ids)
        lstm_out, _ = self.lstm(embedded)
        if attention_mask is None:
            final_hidden_state = lstm_out[:, -1, :]
        else:
            # Reading lstm_out[:, -1] on a right-padded batch would return the
            # state after consuming pad tokens. Instead, locate each row's last
            # position whose mask is set: positions * mask peaks there. This
            # works for left or right padding, and all-zero rows fall back to
            # position 0.
            seq_positions = torch.arange(lstm_out.size(1), device=lstm_out.device)
            last_real = (seq_positions * attention_mask.to(seq_positions.dtype)).argmax(dim=1)
            row_index = torch.arange(lstm_out.size(0), device=lstm_out.device)
            final_hidden_state = lstm_out[row_index, last_real]
        output = self.fc(final_hidden_state)
        return self.sigmoid(output)
# Per-post scorer. Its vocabulary size comes from the tokenizer so that
# every token id the tokenizer emits indexes a valid embedding row.
score_model = ScorePredictor(tokenizer.vocab_size)
# map_location="cpu": nothing in this script moves the model to a GPU, and a
# checkpoint saved from a CUDA session would otherwise fail to deserialize on
# a CPU-only host.
score_model.load_state_dict(torch.load("score_predictor.pth", map_location="cpu"))
# Inference only: puts all submodules in eval mode.
score_model.eval()

# Reddit API client. Credentials are read from the environment; if they are
# unset, os.environ.get yields None and PRAW reports the missing setting.
reddit = praw.Reddit(
    client_id=os.environ.get("REDDIT_CLIENT_ID"),
    client_secret=os.environ.get("REDDIT_CLIENT_SECRET"),
    user_agent='MyAPI/0.0.1',
    # Disables PRAW's warning about running inside an asyncio environment.
    check_for_async=False,
)
def fetch_posts(subreddit_name, start_time, limit=100):
    """Collect titles of recent posts in a subreddit created at or after `start_time`.

    Args:
        subreddit_name: subreddit to read, without the ``r/`` prefix.
        start_time: cutoff datetime. A naive value is interpreted as UTC; an
            aware value is compared in its own timezone.
        limit: maximum number of posts requested from the ``new`` listing.
            Only posts within that many are considered, so the cutoff may not
            be reached for busy subreddits.

    Returns:
        A list of dicts with keys ``"date"`` (UTC, ``'%Y-%m-%d %H:%M:%S'``)
        and ``"post_text"`` (the post title).
    """
    if start_time.tzinfo is None:
        start_time = start_time.replace(tzinfo=datetime.timezone.utc)
    posts = []
    for post in reddit.subreddit(subreddit_name).new(limit=limit):
        # Timezone-aware replacement for the deprecated utcfromtimestamp().
        post_time = datetime.datetime.fromtimestamp(post.created_utc, tz=datetime.timezone.utc)
        if post_time >= start_time:
            posts.append({
                "date": post_time.strftime('%Y-%m-%d %H:%M:%S'),
                "post_text": post.title,
            })
    return posts
def predict_score(text):
    """Score one piece of text with the ScorePredictor model.

    Args:
        text: the text to score (here, a post title).

    Returns:
        The model's first output as a Python float, or 0.0 for empty or
        whitespace-only text.
    """
    if not text or not text.strip():
        return 0.0
    # Tokenize the text as a single sequence. Passing a list of words would
    # make the tokenizer build a batch with one sequence per word, and [0]
    # below would then score only the first word.
    encoded = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        score = score_model(encoded["input_ids"], encoded["attention_mask"])[0].item()
    return score
def _daily_sentiment_window(df, end_date, days=14):
    """Return mean sentiment per calendar day for the `days` days ending on `end_date`.

    The result always has exactly `days` values in chronological order. Days
    with no posts are filled with the mean of the days inside the window that
    do have posts; if none do, the mean over all days in `df` is used.
    """
    daily = df.groupby('date_only')['sentiment_score'].mean()
    calendar = [end_date - datetime.timedelta(days=offset) for offset in range(days - 1, -1, -1)]
    # Aligning to calendar days drops dates outside the window (e.g. the
    # partial first day of the fetch range) and keeps gaps at their true
    # positions instead of collapsing them.
    window = daily.reindex(calendar)
    observed = window.dropna()
    fill_value = observed.mean() if not observed.empty else daily.mean()
    return window.fillna(fill_value)


def _render_forecast_plot(forecast, font_path="AfacadFlux-VariableFont_slnt,wght[1].ttf"):
    """Draw the smoothed 7-day forecast and return it as an in-memory PNG buffer."""
    # Fall back to matplotlib's default font rather than failing the render
    # when the custom font file is not present.
    if os.path.exists(font_path):
        custom_font = fm.FontProperties(fname=font_path)
    else:
        custom_font = fm.FontProperties()
    today = datetime.date.today()
    days = [today + datetime.timedelta(days=i) for i in range(7)]
    days_str = [d.strftime('%a %m/%d') for d in days]
    # Cubic spline through the 7 forecast points, for a smooth curve.
    xnew = np.linspace(0, 6, 300)
    spline = make_interp_spline(np.arange(7), forecast, k=3)
    smooth_forecast = spline(xnew)

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.fill_between(xnew, smooth_forecast, color='#244B48', alpha=0.4)
        ax.plot(xnew, smooth_forecast, color='#244B48', lw=3)
        ax.scatter(np.arange(7), forecast, color='#244B48', s=50)
        ax.set_title("7-Day Sentiment Forecast", fontproperties=custom_font, fontsize=20)
        ax.set_xlabel("Day", fontproperties=custom_font, fontsize=14)
        ax.set_ylabel("Sentiment", fontproperties=custom_font, fontsize=14)
        ax.set_xticks(np.arange(7))
        ax.set_xticklabels(days_str, fontproperties=custom_font, fontsize=12)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every figure alive until closed; in a long-running
        # Streamlit server each run would otherwise leak one figure.
        plt.close(fig)
    buf.seek(0)
    return buf


st.title("7-Day Sentiment Forecast")
if st.button("Run Analysis"):
    subreddits = ["ohio", "libertarian", "centrist"]
    # Naive UTC timestamp, matching the UTC post times formatted by fetch_posts.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    start_time = now_utc - datetime.timedelta(days=14)
    all_posts = []
    for sub in subreddits:
        all_posts.extend(fetch_posts(sub, start_time))
    if not all_posts:
        st.error("No posts fetched.")
    else:
        df = pd.DataFrame(all_posts)
        df['date'] = pd.to_datetime(df['date'])
        df['date_only'] = df['date'].dt.date
        df['sentiment_score'] = df['post_text'].apply(predict_score)
        # Exactly 14 daily values: a 14-day look-back can touch 15 calendar
        # dates, which would otherwise produce a mis-sized model input.
        daily_sentiment = _daily_sentiment_window(df, now_utc.date())
        forecast = sentiment_model.predict(daily_sentiment.values.reshape(1, -1))[0]
        st.image(_render_forecast_plot(forecast), caption="Forecast Plot")