|
|
|
""" |
|
UseCase_with_Streamlit.py |
|
|
|
This basic Streamlit app fetches Reddit posts from a few subreddits over the past 14 days, |
|
computes sentiment scores using a PyTorch model, forecasts a 7-day sentiment trend using a pre-trained forecast model, |
|
and displays the forecast plot. |
|
Note: No extra logging or scheduling is included. |
|
""" |
|
|
|
import os, re, datetime, io |
|
import numpy as np |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
from scipy.interpolate import make_interp_spline |
|
import matplotlib.font_manager as fm |
|
import joblib |
|
import torch |
|
import torch.nn as nn |
|
from transformers import AutoTokenizer |
|
import praw |
|
import streamlit as st |
|
|
|
|
|
|
|
|
|
# Stylesheet for the page: background colour, the custom 'Afacada' font face,
# and the title / button appearance. It is raw HTML, so it has to be passed
# to st.markdown with unsafe_allow_html=True to be emitted unescaped.
_PAGE_CSS = """
<style>
body {
    background-color: #fffff2;
}
@font-face {
    font-family: 'Afacada';
    src: url('AfacadFlux-VariableFont_slnt,wght[1].ttf') format('truetype');
    font-weight: normal;
    font-style: normal;
}
/* Title styling */
h1 {
    font-family: 'Afacada', sans-serif;
    color: #244B48;
}
/* Button styling */
.stButton>button {
    font-family: 'Afacada', sans-serif;
    font-size: 20px;
    padding: 0.75rem 1.5rem;
    background-color: #244B48;
    color: white;
    border: none;
    border-radius: 4px;
}
.stButton>button:hover {
    background-color: #1f3e38;
    color: white;
}
.stButton>button:active, .stButton>button:focus {
    background-color: #244B48;
    color: white;
    outline: none;
}
</style>
"""

st.markdown(_PAGE_CSS, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
# Streamlit re-executes this script from the top on every widget interaction,
# so loading these artifacts unconditionally repeats the work on each click.
# st.cache_resource keeps a single loaded instance per server process; on
# Streamlit versions that do not provide it, fall back to an identity
# decorator, which gives the original (uncached) behaviour.
_cache_resource = getattr(st, "cache_resource", lambda func: func)

# Hugging Face model id whose tokenizer is used to encode post titles.
MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"


@_cache_resource
def _load_forecast_model():
    """Load the pickled forecast model from the working directory."""
    # NOTE(security): joblib.load unpickles arbitrary Python objects. Only
    # load a file that was produced by a trusted source.
    return joblib.load('sentiment_forecast_model.pkl')


@_cache_resource
def _load_tokenizer():
    """Load the tokenizer published under the MODEL id."""
    return AutoTokenizer.from_pretrained(MODEL)


sentiment_model = _load_forecast_model()
tokenizer = _load_tokenizer()
|
|
|
class ScorePredictor(nn.Module):
    """Embedding -> LSTM -> linear -> sigmoid scorer over token-id sequences.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding.
        hidden_dim: Width of the LSTM hidden state.
        output_dim: Number of values produced per input sequence.

    The submodule attribute names (``embedding``, ``lstm``, ``fc``,
    ``sigmoid``) define the state-dict keys, so they must stay as they are
    for saved weights to load.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
        super().__init__()
        # Index 0 is declared as the padding index of the embedding table.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return one sigmoid-squashed score per sequence in the batch.

        ``attention_mask`` is accepted for call compatibility but is not
        used: the score is always read from the LSTM output at the last
        time step, whether or not that position holds padding.
        """
        sequence_output = self.lstm(self.embedding(input_ids))[0]
        last_step = sequence_output[:, -1, :]
        logits = self.fc(last_step)
        return self.sigmoid(logits)
|
|
|
# Build the scoring network with an embedding table sized to the tokenizer's
# vocabulary, then restore the trained weights from disk.
score_model = ScorePredictor(tokenizer.vocab_size)

# map_location="cpu": the network above is created on the CPU and is never
# moved to another device, so tensors saved from a GPU must be remapped to
# the CPU; without this, loading such a checkpoint fails on a CPU-only host.
score_model.load_state_dict(torch.load("score_predictor.pth", map_location="cpu"))

# Inference only: switch the module to evaluation mode.
score_model.eval()
|
|
|
|
|
|
|
|
|
# Reddit API client. The client id and secret are read from the environment
# (REDDIT_CLIENT_ID / REDDIT_CLIENT_SECRET); os.environ.get yields None for a
# variable that is not set.
_reddit_settings = {
    "client_id": os.environ.get("REDDIT_CLIENT_ID"),
    "client_secret": os.environ.get("REDDIT_CLIENT_SECRET"),
    "user_agent": 'MyAPI/0.0.1',
    "check_for_async": False,
}
reddit = praw.Reddit(**_reddit_settings)
|
|
|
|
|
|
|
|
|
def fetch_posts(subreddit_name, start_time, limit=100):
    """Collect the titles of recent posts from one subreddit.

    Args:
        subreddit_name: Name handed to ``reddit.subreddit``.
        start_time: Cut-off ``datetime``; posts created before it are
            skipped. A naive value is interpreted as UTC (this is what the
            caller in this file passes); a timezone-aware value is converted
            to UTC first.
        limit: Maximum number of entries requested from the subreddit's
            ``new`` listing.

    Returns:
        A list of dicts with the keys ``"date"`` (creation time in UTC,
        formatted ``%Y-%m-%d %H:%M:%S``) and ``"post_text"`` (the title).
    """
    utc = datetime.timezone.utc

    # Normalise the cut-off to naive UTC so it is comparable with the naive
    # UTC post times computed below.
    if start_time.tzinfo is not None:
        start_time = start_time.astimezone(utc).replace(tzinfo=None)

    posts = []
    for post in reddit.subreddit(subreddit_name).new(limit=limit):
        # datetime.utcfromtimestamp is deprecated since Python 3.12; this
        # yields the same naive UTC value through the supported API.
        post_time = datetime.datetime.fromtimestamp(
            post.created_utc, tz=utc
        ).replace(tzinfo=None)
        if post_time >= start_time:
            posts.append({
                "date": post_time.strftime('%Y-%m-%d %H:%M:%S'),
                "post_text": post.title,
            })
    return posts
|
|
|
def predict_score(text):
    """Score a single piece of text with the sentiment network.

    Args:
        text: The post title to score. ``None``, an empty string and a
            whitespace-only string are all treated as "no text".

    Returns:
        The model's sigmoid output for the text as a Python float, or
        ``0.0`` when there is no text to score.
    """
    if not text or not text.strip():
        return 0.0

    # Encode the whole title as ONE sequence. Passing text.split() (a list of
    # strings) makes the tokenizer treat every word as a separate batch item,
    # and taking element [0] of the result then scores only the first word.
    encoded = tokenizer(text, return_tensors='pt', padding=True, truncation=True)

    with torch.no_grad():
        output = score_model(encoded["input_ids"], encoded["attention_mask"])

    # Batch of one sequence: its score is the first row.
    return output[0].item()
|
|
|
|
|
|
|
|
|
st.title("7-Day Sentiment Forecast")

# Everything below runs only on the rerun triggered by pressing the button.
if st.button("Run Analysis"):
    subreddits = ["ohio", "libertarian", "centrist"]

    # Naive UTC "now", built without the deprecated datetime.utcnow();
    # fetch_posts compares this cut-off against naive UTC post times.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    start_time = now_utc - datetime.timedelta(days=14)

    all_posts = []
    for sub in subreddits:
        all_posts.extend(fetch_posts(sub, start_time))

    if not all_posts:
        st.error("No posts fetched.")
    else:
        # --- Score every post and average the scores per calendar day ---
        df = pd.DataFrame(all_posts)
        df['date'] = pd.to_datetime(df['date'])
        df['date_only'] = df['date'].dt.date
        df = df.sort_values(by='date_only')
        df['sentiment_score'] = df['post_text'].apply(predict_score)

        daily_sentiment = df.groupby('date_only')['sentiment_score'].mean()

        # The padding below targets 14 values, so the forecast model is fed
        # a 14-element row. A 14-day window that starts mid-day can touch 15
        # calendar dates; keep only the 14 most recent (groupby sorts by
        # date) so the row never grows beyond that width.
        daily_sentiment = daily_sentiment.tail(14)

        if len(daily_sentiment) < 14:
            # Too few days with posts: fill the remaining slots with the
            # mean of the observed days, appended after the real values.
            mean_val = daily_sentiment.mean()
            pad = [mean_val] * (14 - len(daily_sentiment))
            daily_sentiment = np.concatenate([daily_sentiment.values, pad])
            daily_sentiment = pd.Series(daily_sentiment)

        forecast = sentiment_model.predict(daily_sentiment.values.reshape(1, -1))[0]

        # --- Plot the forecast ---
        font_path = "AfacadFlux-VariableFont_slnt,wght[1].ttf"
        if os.path.exists(font_path):
            custom_font = fm.FontProperties(fname=font_path)
        else:
            # A missing font file would otherwise only fail later, when the
            # figure is rendered; use matplotlib's default font instead.
            custom_font = fm.FontProperties()

        today = datetime.date.today()
        days = [today + datetime.timedelta(days=i) for i in range(7)]
        days_str = [d.strftime('%a %m/%d') for d in days]

        # Cubic spline through the 7 forecast points for a smooth curve.
        xnew = np.linspace(0, 6, 300)
        spline = make_interp_spline(np.arange(7), forecast, k=3)
        smooth_forecast = spline(xnew)

        fig, ax = plt.subplots(figsize=(8, 5))
        ax.fill_between(xnew, smooth_forecast, color='#244B48', alpha=0.4)
        ax.plot(xnew, smooth_forecast, color='#244B48', lw=3)
        ax.scatter(np.arange(7), forecast, color='#244B48', s=50)
        ax.set_title("7-Day Sentiment Forecast", fontproperties=custom_font, fontsize=20)
        ax.set_xlabel("Day", fontproperties=custom_font, fontsize=14)
        ax.set_ylabel("Sentiment", fontproperties=custom_font, fontsize=14)
        ax.set_xticks(np.arange(7))
        ax.set_xticklabels(days_str, fontproperties=custom_font, fontsize=12)
        fig.tight_layout()

        buf = io.BytesIO()
        try:
            fig.savefig(buf, format='png')
        finally:
            # pyplot keeps every figure alive until it is closed; without
            # this, each button press would leave another figure in memory.
            plt.close(fig)
        buf.seek(0)
        st.image(buf, caption="Forecast Plot")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|