# -*- coding: utf-8 -*-
"""Gradio app that forecasts a subreddit's negative-sentiment score.

Pipeline: fetch the subreddit's top post titles, score every title with the
``ScorePredictor`` LSTM, average the scores per calendar day over the last
``num_days`` days, feed that daily series to the pre-trained forecaster loaded
from ``sentiment_forecast_model.pkl`` and plot the predicted values.
"""
import asyncio
import datetime
import os

import asyncpraw
import gradio as gr
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from scipy.interpolate import make_interp_spline
from transformers import AutoTokenizer

# Global settings
num_days = 14
client_id = os.getenv('client_id')
client_secret = os.getenv('client_secret')
MODEL_PATH = "cardiffnlp/xlm-twitter-politics-sentiment"


# Minimal model class definition (required)
class ScorePredictor(nn.Module):
    """Embedding -> LSTM -> linear -> sigmoid scorer producing a value in (0, 1).

    The attribute names (``embedding``, ``lstm``, ``fc``, ``sigmoid``) must
    stay as they are: they are the keys of the state dict loaded from
    ``score_predictor.pth``.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Score a batch of token-id sequences.

        ``attention_mask`` is accepted for call-site compatibility with the
        tokenizer output but is not used by the computation.
        """
        x = self.embedding(input_ids)
        x, _ = self.lstm(x)
        # Only the LSTM output at the last time step feeds the classifier head.
        x = self.fc(x[:, -1, :])
        return self.sigmoid(x)


# Load models
# NOTE(security): joblib.load and torch.load unpickle their input; only ever
# point them at trusted, locally produced files.
sentiment_model = joblib.load('sentiment_forecast_model.pkl')
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
score_model = ScorePredictor(tokenizer.vocab_size)
score_model.load_state_dict(torch.load("score_predictor.pth", map_location=torch.device('cpu')))
score_model.eval()
print("Models loaded successfully.")


def _utc_today():
    """Return the current calendar date in UTC (post dates are UTC as well)."""
    return datetime.datetime.now(datetime.timezone.utc).date()


# Function to fetch posts from Reddit
async def get_posts(subreddit_name, time_filter='month', limit=25):
    """Fetch the top posts of a subreddit.

    Args:
        subreddit_name: Subreddit name without the ``r/`` prefix.
        time_filter: Time window passed to ``subreddit.top``.
        limit: Maximum number of posts to fetch.

    Returns:
        A list of dicts with keys ``"date"`` (UTC creation time formatted as
        ``'%Y-%m-%d %H:%M:%S'``) and ``"post_text"`` (the post title).
    """
    # The context manager closes the client's HTTP session even when the
    # request fails; the original never closed it.
    async with asyncpraw.Reddit(
        client_id=client_id,
        client_secret=client_secret,
        user_agent="sentimentForecastAgent",
    ) as async_reddit:
        subreddit = await async_reddit.subreddit(subreddit_name)
        posts = []
        async for post in subreddit.top(time_filter=time_filter, limit=limit):
            created = datetime.datetime.fromtimestamp(
                post.created_utc, tz=datetime.timezone.utc
            )
            posts.append({
                "date": created.strftime('%Y-%m-%d %H:%M:%S'),
                "post_text": post.title,
            })
    return posts


# Function to calculate sentiment
def calculate_sentiment(text):
    """Return the ``score_model`` output for ``text`` as a float.

    Empty or whitespace-only text scores ``0.0``.
    """
    if not text or not text.strip():
        return 0.0
    # Encode the text as ONE sequence. Passing ``text.split()`` would make the
    # tokenizer treat every word as its own batch row, and ``[0]`` below would
    # then score only the first word of the post.
    encoded = tokenizer(
        text,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        scores = score_model(encoded["input_ids"], encoded["attention_mask"])
    return scores[0].item()


def _daily_sentiment(posts, num_days, today):
    """Return mean sentiment per day for the ``num_days`` days ending ``today``.

    The result is a 1-D array in chronological order; days without any post
    are filled with ``0.0``.
    """
    df = pd.DataFrame(posts)
    df['date_only'] = pd.to_datetime(df['date']).dt.date
    df['sentiment_score'] = df['post_text'].apply(calculate_sentiment)

    window = [
        today - datetime.timedelta(days=offset)
        for offset in range(num_days - 1, -1, -1)
    ]
    daily = (
        df.groupby('date_only')['sentiment_score']
        .mean()
        .reindex(window, fill_value=0.0)
    )
    return daily.to_numpy()


def _plot_forecast(forecast, start_date):
    """Draw the forecast as a smoothed filled curve, one tick per day."""
    horizon = len(forecast)
    forecast_dates = [start_date + datetime.timedelta(days=i) for i in range(horizon)]
    x = np.arange(horizon)
    xnew = np.linspace(0, horizon - 1, 300)
    # The spline degree cannot exceed the number of points minus one.
    spline = make_interp_spline(x, forecast, k=min(3, horizon - 1))
    smooth = spline(xnew)

    fig, ax = plt.subplots(figsize=(14, 7))
    ax.fill_between(xnew, smooth, color='#244B48', alpha=0.4)
    ax.plot(xnew, smooth, color='#244B48', lw=3, label='Forecast')
    ax.scatter(x, forecast, color='#244B48', s=100)
    ax.set_title("7-Day Negative Sentiment Forecast", fontsize=22, fontweight='bold', pad=20)
    ax.set_xlabel("Date", fontsize=16)
    ax.set_ylabel("Negative Sentiment (0-1)", fontsize=16)
    ax.set_xticks(x)
    ax.set_xticklabels([d.strftime('%a %m/%d') for d in forecast_dates], fontsize=12, rotation=45)
    ax.legend(fontsize=14, loc='upper right')
    fig.tight_layout()
    # Drop the figure from pyplot's global registry so repeated requests do
    # not accumulate open figures; the Figure object itself remains usable.
    plt.close(fig)
    return fig


# Function to generate sentiment forecast
async def generate_forecast(subreddit, num_days=14):
    """Fetch, score and forecast sentiment for ``subreddit``.

    Args:
        subreddit: Subreddit name without the ``r/`` prefix.
        num_days: Length, in days, of the history window given to the forecaster.

    Returns:
        ``(figure, summary)``. When no posts were fetched, ``figure`` is
        ``None`` and ``summary`` carries the not-found message.
    """
    # Fetch posts asynchronously
    posts = await get_posts(subreddit)
    if not posts:
        # Bail out before building the DataFrame: an empty frame has no
        # 'date' column, which previously raised KeyError and made this
        # message unreachable.
        return None, "Subreddit not found or criteria not met for Cypher."

    # Post timestamps are UTC, so the history window is anchored on the UTC date.
    today = _utc_today()
    historical = _daily_sentiment(posts, num_days, today)

    # Forecast using the pre-loaded sentiment_model
    forecast = np.asarray(sentiment_model.predict(historical.reshape(1, -1))[0])

    fig = _plot_forecast(forecast, today)
    return fig, f"r/{subreddit} has loaded!"


# Gradio interface
async def run_forecast(subreddit):
    """Gradio click handler: return ``(figure, summary)`` for the typed subreddit."""
    # Stray whitespace around the name would otherwise be sent to Reddit verbatim.
    name = (subreddit or "").strip()
    fig, summary = await generate_forecast(name)
    return fig, summary


# Gradio app
with gr.Blocks(title="Subreddit Negative Sentiment Forecast") as demo:
    gr.Markdown("Subreddit Negative Sentiment Forecast")
    gr.Markdown("Analyze recent Reddit posts to forecast negative sentiment for the next 7 days.")
    subreddit_input = gr.Textbox(label="Subreddit (without r/)", placeholder="e.g. politics", value="politics")
    output_text = gr.Textbox(label="Summary", lines=2)
    submit_btn = gr.Button("Generate Forecast")
    output_plot = gr.Plot(label="Forecast Plot")
    submit_btn.click(fn=run_forecast, inputs=subreddit_input, outputs=[output_plot, output_text])

# Launch the app
if __name__ == "__main__":
    demo.launch()