import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import praw
import time
from datetime import datetime, timedelta
import json
import os
from typing import List, Dict, Any, Optional, Tuple
import concurrent.futures
from functools import lru_cache
import hashlib
import pytz
import sqlite3
import networkx as nx
from pathlib import Path
# Advanced features optional - will gracefully degrade if not available
try:
    from advanced_reddit_scraper import (
        AdvancedRedditScraper,
        ExponentialBackoff,
        CommentHierarchyTracker,
        CheckpointManager
    )
    # Local helper module is present: advanced scraping code paths may be used.
    ADVANCED_FEATURES = True
except ImportError:
    # Module not installed/available: the app falls back to the basic scraper.
    ADVANCED_FEATURES = False
def load_env_file(env_path: str = ".env") -> Dict[str, str]:
    """
    Load environment variables from a .env file.

    Each non-blank, non-comment line containing '=' is parsed as KEY=VALUE
    (split on the first '=' only, so values may themselves contain '=').
    Keys and values are whitespace-trimmed, and one layer of surrounding
    double or single quotes is stripped from values.

    Args:
        env_path: Path to the .env file.

    Returns:
        Dictionary of parsed environment variables; empty if the file
        does not exist.
    """
    env_vars: Dict[str, str] = {}
    env_file = Path(env_path)
    if env_file.exists():
        # Explicit UTF-8: without it, open() uses the platform locale default
        # (e.g. cp1252 on Windows), which can corrupt non-ASCII values.
        with open(env_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Skip blanks, '#' comments, and malformed lines with no '='.
                if line and not line.startswith('#') and '=' in line:
                    key, value = line.split('=', 1)
                    key = key.strip()
                    # Strip whitespace, then one layer of wrapping quotes.
                    value = value.strip().strip('"').strip("'")
                    env_vars[key] = value
    return env_vars
# Streamlit page-level configuration; must be the first st.* call in the app.
st.set_page_config(
    page_title="Reddit Research Dashboard",
    page_icon="đ",  # NOTE(review): appears to be a mojibake'd emoji — confirm the file's source encoding
    layout="wide",
    initial_sidebar_state="expanded"
)
# Custom-CSS injection point; currently injects an empty string (placeholder).
st.markdown("""
""", unsafe_allow_html=True)
class OptimizedRedditScraper:
    """
    Optimized Reddit scraper with batch processing, caching, and temporal analytics

    Wraps a PRAW client with client-side throttling (``min_delay`` between
    submission fetches), per-session caching in ``st.session_state``, and
    optional live Streamlit UI updates while collecting posts.
    """
    def __init__(self, client_id: str, client_secret: str, user_agent: str):
        """Initialize with Reddit API credentials"""
        self.reddit = praw.Reddit(
            client_id=client_id,
            client_secret=client_secret,
            user_agent=user_agent,
            check_for_async=False  # silence PRAW's async-environment warning under Streamlit
        )
        # Client-side throttle state: timestamp of the last request and the
        # minimum spacing (seconds) enforced between requests.
        self.last_request_time = 0
        self.min_delay = 0.5

    def fetch_subreddit_data_verbose(self, subreddit_name: str, sort_by: str = "hot",
                                     limit: int = 200, time_filter: str = "month",
                                     log_container=None) -> pd.DataFrame:
        """
        Fetch Reddit data with verbose logging

        Args:
            subreddit_name: Name of subreddit to scrape
            sort_by: Sort method (hot, new, top, rising)
            limit: Number of posts to fetch (optimized for 200+ items)
            time_filter: Time filter for top posts
            log_container: Streamlit container for logging output; may also be
                a ``(stats_container, stream_container)`` tuple, or None to
                disable live UI updates
        Returns:
            DataFrame with Reddit posts data (possibly empty on failure)
        """
        def stream_post(post_data, stream_container):
            """Display a post as it's collected"""
            if stream_container:
                timestamp = datetime.now().strftime("%H:%M:%S")
                with stream_container.container():
                    # Collapsed expander per post: title teaser plus key metrics.
                    with st.expander(f"đ {post_data['title'][:80]}...", expanded=False):
                        col1, col2, col3, col4 = st.columns(4)
                        with col1:
                            st.metric("Score", post_data['score'])
                        with col2:
                            st.metric("Comments", post_data['num_comments'])
                        with col3:
                            st.text(f"u/{post_data['author']}")
                        with col4:
                            st.text(timestamp)

        def update_stats(stats_container, total, authors, comments):
            """Update collection statistics"""
            if stats_container:
                # Clear previous render before drawing fresh metrics.
                stats_container.empty()
                with stats_container:
                    col1, col2, col3 = st.columns(3)
                    with col1:
                        st.metric("đ Posts", total)
                    with col2:
                        st.metric("đĨ Authors", authors)
                    with col3:
                        st.metric("đŦ Comments", f"{comments:,}")

        # Initialize streaming containers
        stats_container = None
        stream_container = None
        if log_container:
            # Check if log_container is a tuple of (stats, stream)
            if isinstance(log_container, tuple):
                stats_container, stream_container = log_container
            else:
                # Single container doubles as both stats and stream target.
                stats_container = log_container
                stream_container = log_container
        data = []
        try:
            subreddit = self.reddit.subreddit(subreddit_name)
            # Choose appropriate method based on sort_by
            if sort_by == "top":
                submissions = subreddit.top(limit=limit, time_filter=time_filter)
            elif sort_by == "new":
                submissions = subreddit.new(limit=limit)
            elif sort_by == "rising":
                submissions = subreddit.rising(limit=limit)
            else:
                # Default (and explicit "hot") listing.
                submissions = subreddit.hot(limit=limit)
            # Batch processing with rate limiting
            batch_size = 25
            batch = []
            batch_num = 1
            post_count = 0
            total_comments = 0
            try:
                # Convert to list to handle iterator exhaustion gracefully
                submissions_list = []
                try:
                    for submission in submissions:
                        try:
                            # Force PRAW to load the submission by accessing an attribute
                            _ = submission.id
                            submissions_list.append(submission)
                            if len(submissions_list) >= limit:
                                break
                        except Exception as sub_error:
                            # Skip submissions that fail to load
                            continue
                except StopIteration:
                    pass  # Iterator exhausted naturally
                except Exception as fetch_error:
                    error_msg = str(fetch_error)
                    if "Ran out of input" in error_msg or "prawcore" in error_msg.lower():
                        # PRAW iterator exhausted - not an error, just end of data
                        pass
                    else:
                        if log_container:
                            st.warning(f"â ī¸ Stopped early: {error_msg}")
                    if not submissions_list:
                        if log_container:
                            st.error(f"No data could be fetched: {error_msg}")
                        # Re-raise the original fetch error: nothing was collected.
                        raise
                for i, submission in enumerate(submissions_list):
                    try:
                        # Rate limiting before fetching submission data
                        current_time = time.time()
                        if current_time - self.last_request_time < self.min_delay:
                            time.sleep(self.min_delay - (current_time - self.last_request_time))
                        self.last_request_time = time.time()
                        batch.append(submission)
                        post_count += 1
                        # Flush when the batch fills or the overall limit is hit.
                        if len(batch) >= batch_size or post_count >= limit:
                            # Process batch
                            for idx, sub in enumerate(batch):
                                try:
                                    # Safely extract all attributes with error handling
                                    try:
                                        post_id = sub.id
                                        post_title = sub.title
                                        post_author = str(sub.author) if sub.author else '[deleted]'
                                        # Timezone-aware UTC timestamp from Reddit's epoch seconds.
                                        post_created = datetime.fromtimestamp(sub.created_utc, tz=pytz.UTC)
                                        post_score = sub.score
                                        post_comments = sub.num_comments
                                        post_ratio = sub.upvote_ratio
                                        # Truncate self-text to keep rows small.
                                        post_text = sub.selftext[:500] if sub.selftext else ''
                                        post_url = sub.url
                                        post_flair = sub.link_flair_text or 'No Flair'
                                        post_video = sub.is_video
                                        post_self = sub.is_self
                                        post_permalink = f"https://reddit.com{sub.permalink}"
                                    except AttributeError as attr_error:
                                        # Missing attribute - skip this post
                                        continue
                                    except Exception as access_error:
                                        # Any other error accessing attributes - skip
                                        continue
                                    post_data = {
                                        'id': post_id,
                                        'title': post_title,
                                        'author': post_author,
                                        'created_utc': post_created,
                                        'score': post_score,
                                        'num_comments': post_comments,
                                        'upvote_ratio': post_ratio,
                                        'selftext': post_text,
                                        'url': post_url,
                                        'subreddit': subreddit_name,
                                        'flair': post_flair,
                                        'is_video': post_video,
                                        'is_self': post_self,
                                        'permalink': post_permalink
                                    }
                                    data.append(post_data)
                                    total_comments += post_data['num_comments']
                                    # Stream the post to UI
                                    stream_post(post_data, stream_container)
                                except Exception as post_error:
                                    # Skip posts that cause any error
                                    continue
                            # Update stats
                            if log_container:
                                unique_authors = len(set(d['author'] for d in data))
                                update_stats(stats_container, len(data), unique_authors, total_comments)
                            batch = []
                            batch_num += 1
                            # Update progress
                            if st.session_state.get('progress_bar'):
                                progress = min(post_count / limit, 1.0)
                                st.session_state.progress_bar.progress(progress)
                            # Stop if we've reached the limit
                            if post_count >= limit:
                                break
                    except StopIteration:
                        break
                    except Exception as iter_error:
                        # Skip this submission and keep iterating.
                        continue
                # Process any remaining items in batch
                if batch:
                    for idx, sub in enumerate(batch):
                        try:
                            # Safely extract all attributes
                            try:
                                post_id = sub.id
                                post_title = sub.title
                                post_author = str(sub.author) if sub.author else '[deleted]'
                                post_created = datetime.fromtimestamp(sub.created_utc, tz=pytz.UTC)
                                post_score = sub.score
                                post_comments = sub.num_comments
                                post_ratio = sub.upvote_ratio
                                post_text = sub.selftext[:500] if sub.selftext else ''
                                post_url = sub.url
                                post_flair = sub.link_flair_text or 'No Flair'
                                post_video = sub.is_video
                                post_self = sub.is_self
                                post_permalink = f"https://reddit.com{sub.permalink}"
                            except Exception:
                                # Skip posts that fail attribute access
                                continue
                            post_data = {
                                'id': post_id,
                                'title': post_title,
                                'author': post_author,
                                'created_utc': post_created,
                                'score': post_score,
                                'num_comments': post_comments,
                                'upvote_ratio': post_ratio,
                                'selftext': post_text,
                                'url': post_url,
                                'subreddit': subreddit_name,
                                'flair': post_flair,
                                'is_video': post_video,
                                'is_self': post_self,
                                'permalink': post_permalink
                            }
                            data.append(post_data)
                            total_comments += post_data['num_comments']
                            stream_post(post_data, stream_container)
                        except Exception:
                            # Skip any problematic posts
                            continue
            except StopIteration:
                pass
            # Final stats update
            if log_container:
                unique_authors = len(set(d['author'] for d in data))
                update_stats(stats_container, len(data), unique_authors, total_comments)
        except Exception as e:
            error_msg = str(e)
            # Don't show scary errors for common PRAW issues
            if "Ran out of input" in error_msg or "prawcore" in error_msg.lower():
                if log_container and len(data) == 0:
                    st.warning("â ī¸ No posts could be fetched. The subreddit may be empty or private.")
            else:
                if log_container:
                    st.error(f"â Error: {error_msg}")
                if len(data) == 0:  # Only raise if we got no data at all
                    raise
        # Return whatever data we managed to collect
        if len(data) == 0 and log_container:
            st.info("âšī¸ No posts were collected. Try adjusting your filters or selecting a different subreddit.")
        return pd.DataFrame(data)

    def fetch_subreddit_data(self, subreddit_name: str, sort_by: str = "hot",
                             limit: int = 200, time_filter: str = "month") -> pd.DataFrame:
        """
        Fetch data with manual session-based caching

        Wraps :meth:`fetch_subreddit_data_verbose` (UI logging disabled) with
        a 1-hour TTL cache stored in ``st.session_state.data_cache``.
        """
        # Create cache key
        cache_key = f"{subreddit_name}_{sort_by}_{limit}_{time_filter}"
        # Check if data exists in session state cache
        if 'data_cache' not in st.session_state:
            st.session_state.data_cache = {}
        if cache_key in st.session_state.data_cache:
            cache_entry = st.session_state.data_cache[cache_key]
            # Check if cache is still valid (1 hour TTL)
            if (datetime.now() - cache_entry['timestamp']).total_seconds() < 3600:
                return cache_entry['data']
        # Fetch new data
        df = self.fetch_subreddit_data_verbose(subreddit_name, sort_by, limit, time_filter, None)
        # Store in cache
        st.session_state.data_cache[cache_key] = {
            'data': df,
            'timestamp': datetime.now()
        }
        return df

    def fetch_multiple_subreddits(self, subreddits: List[str], limit_per: int = 100,
                                  sort_by: str = "hot") -> pd.DataFrame:
        """
        Fetch data from multiple subreddits with manual caching

        Args:
            subreddits: List of subreddit names
            limit_per: Posts per subreddit
            sort_by: Sort method
        Returns:
            Combined DataFrame (empty if every fetch failed)
        """
        # Create cache key (subreddit order-insensitive via sorted()).
        cache_key = f"multi_{'_'.join(sorted(subreddits))}_{sort_by}_{limit_per}"
        # Check cache
        if 'data_cache' not in st.session_state:
            st.session_state.data_cache = {}
        if cache_key in st.session_state.data_cache:
            cache_entry = st.session_state.data_cache[cache_key]
            # Check if cache is still valid (30 min TTL)
            if (datetime.now() - cache_entry['timestamp']).total_seconds() < 1800:
                return cache_entry['data']
        # Fetch new data
        # NOTE(review): fetch_subreddit_data touches st.session_state from these
        # worker threads; Streamlit session state is generally not thread-safe
        # outside the script thread — confirm this works under the deployed
        # Streamlit version.
        all_data = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
            future_to_sub = {
                executor.submit(self.fetch_subreddit_data, sub, sort_by, limit_per): sub
                for sub in subreddits
            }
            for future in concurrent.futures.as_completed(future_to_sub):
                sub = future_to_sub[future]
                try:
                    data = future.result()
                    all_data.append(data)
                except Exception as e:
                    st.error(f"Error fetching r/{sub}: {e}")
        if all_data:
            df = pd.concat(all_data, ignore_index=True)
        else:
            df = pd.DataFrame()
        # Store in cache
        st.session_state.data_cache[cache_key] = {
            'data': df,
            'timestamp': datetime.now()
        }
        return df
def create_temporal_visualizations(df: pd.DataFrame) -> Dict[str, go.Figure]:
"""
Create comprehensive temporal analytics visualizations
Args:
df: DataFrame with Reddit data
Returns:
Dictionary of Plotly figures
"""
figures = {}
# Ensure datetime column
if 'created_utc' in df.columns:
df['created_utc'] = pd.to_datetime(df['created_utc'])
df = df.sort_values('created_utc')
# Get actual date range of collected data with padding
date_min = df['created_utc'].min()
date_max = df['created_utc'].max()
date_range = (date_max - date_min).days
# Add 2% padding to prevent edge clipping
padding = pd.Timedelta(days=max(1, int(date_range * 0.02)))
date_min_padded = date_min - padding
date_max_padded = date_max + padding
# 1. Hourly activity heatmap
df['hour'] = df['created_utc'].dt.hour
df['day_of_week'] = df['created_utc'].dt.day_name()
heatmap_data = df.groupby(['day_of_week', 'hour']).size().reset_index(name='count')
pivot_data = heatmap_data.pivot(index='day_of_week', columns='hour', values='count').fillna(0)
# Reorder days
days_order = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
pivot_data = pivot_data.reindex(days_order)
fig_heatmap = go.Figure(data=go.Heatmap(
z=pivot_data.values,
x=pivot_data.columns,
y=pivot_data.index,
colorscale='RdYlBu_r',
text=pivot_data.values.astype(int),
texttemplate='%{text}',
textfont={"size": 8},
hovertemplate='%{y}
%{x}:00
Posts: %{z}