derek-thomas's picture
derek-thomas HF staff
Major updates from sister repo
5d9e0b8
raw
history blame
No virus
2.38 kB
import os
from datetime import datetime
from typing import Any, Dict, List
import praw
from utilities.my_logger import setup_logger
# Module-wide logger, built by the project's setup_logger helper.
logger = setup_logger(__name__)

# Runtime configuration, read once from the environment at import time.
# SUBREDDIT is the name handed to reddit.subreddit(); it is None when unset.
subreddit_var = os.environ.get("SUBREDDIT")
# REDDIT_PULL_LIMIT is forwarded as the `limit` of subreddit.new(). It must be
# set to an integer string: int() raises TypeError here if it is missing.
reddit_pull_limit = int(os.environ.get("REDDIT_PULL_LIMIT"))
# Placeholder single-row table (one-element list per column) used when a new
# repo is created.
dummy_data = {
    "content": ["This is a sample post content. Just for demonstration purposes!"],
    "poster": ["sampleUser123"],
    # Naive datetime for 2023-10-26 14:30:45.
    "date_utc": [datetime(2023, 10, 26, 14, 30, 45)],
    "flair": ["Discussion"],
    "title": ["Sample Post Title: How to Use Hugging Face?"],
    "score": [457],
    "permalink": ["/r/sampleSubreddit/comments/sampleID/sample_post_title_how_to_use_hugging_face/"],
    "id": ["id"],
}
def get_reddit_instance() -> praw.Reddit:
    """Create a PRAW Reddit client from credentials in the environment.

    Reads REDDIT_CLIENT_ID, REDDIT_CLIENT_SECRET and REDDIT_USER_AGENT with
    os.getenv (each is None when unset) and passes them, together with
    ratelimit_seconds=20, to praw.Reddit.

    Returns:
        The constructed praw.Reddit instance.
    """
    credentials = {
        "client_id": os.getenv('REDDIT_CLIENT_ID'),
        "client_secret": os.getenv('REDDIT_CLIENT_SECRET'),
        "user_agent": os.getenv('REDDIT_USER_AGENT'),
    }
    return praw.Reddit(ratelimit_seconds=20, **credentials)
def extract_submission_data(submission: "praw.models.Submission") -> Dict[str, Any]:
    """Extract and return relevant data from a given Reddit submission.

    Args:
        submission: Object exposing the attributes ``selftext``, ``author``,
            ``created_utc``, ``link_flair_text``, ``title``, ``ups`` and
            ``permalink``.

    Returns:
        A dict with the keys ``content``, ``poster``, ``date_utc``, ``flair``,
        ``title``, ``score`` and ``permalink``. ``poster`` is ``str(author)``
        and ``date_utc`` is ``created_utc`` rendered in UTC as
        ``'%Y-%m-%d %H:%M:%S'``.
    """
    # Local import: the module itself only imports `datetime` from datetime.
    from datetime import timezone

    # datetime.utcfromtimestamp() is deprecated since Python 3.12. Converting
    # with an explicit UTC tzinfo formats to exactly the same string.
    created = datetime.fromtimestamp(submission.created_utc, tz=timezone.utc)

    return {
        "content": submission.selftext,
        "poster": str(submission.author),
        "date_utc": created.strftime('%Y-%m-%d %H:%M:%S'),
        "flair": submission.link_flair_text,
        "title": submission.title,
        "score": submission.ups,
        "permalink": submission.permalink,
    }
def praw_downloader() -> List[Dict[str, Any]]:
    """Fetch the newest submissions of the configured subreddit.

    Creates a client with get_reddit_instance(), iterates
    ``subreddit.new(limit=reddit_pull_limit)`` for the subreddit named by
    ``subreddit_var`` and converts every submission with
    extract_submission_data(). Nothing is written to disk; the rows are only
    returned to the caller.

    Returns:
        One dict per submission, in the order yielded by ``subreddit.new``.
    """
    reddit = get_reddit_instance()
    subreddit = reddit.subreddit(subreddit_var)

    # Log the value actually passed to reddit.subreddit() above.
    logger.info(f'Starting to fetch submissions from {subreddit_var}.')

    # Set limit=None to get all posts
    submissions = [
        extract_submission_data(submission)
        for submission in subreddit.new(limit=reddit_pull_limit)
    ]

    logger.info(f'Finished downloading {len(submissions)} submissions.')
    return submissions
# Script entry point: run one download when the file is executed directly.
# The list returned by praw_downloader() is discarded here.
if __name__ == "__main__":
    praw_downloader()