"""Gradio app that reports the "vibe" of a Reddit thread.

Comments of a submission are fetched with PRAW, filtered by score, and
classified with a zero-shot model (facebook/bart-large-mnli) against a fixed
list of vibe labels. Each label's score is averaged across the classified
comments and shown in a ``gr.Label``.
"""
import os
import time
from statistics import mean

import gradio as gr
import praw
import torch
from praw.models import MoreComments
from transformers import pipeline

# Default for the "Minimum Karma" box: only comments scoring strictly above
# this value are classified.
MIN_SCORE_THRESHOLD = 50
# Number of highest-scoring vibes shown by the output widget.
NUM_VIBE_LABELS = 5
# Upper bound on comments sent to the model per request. Zero-shot
# classification scores every (comment, label) pair, so cost grows with
# comments x len(vibe_names); the highest-scoring comments are kept.
MAX_COMMENTS = 20
# Placeholder bodies Reddit shows for comments whose content is gone.
_REMOVED_BODIES = {"[deleted]", "[removed]"}

if torch.cuda.is_available():
    device = 0
else:
    print("GPU isn't available. Running on CPU.")
    device = -1

pipe = pipeline(model="facebook/bart-large-mnli", device=device)

# Credentials are read from the environment; missing variables yield None.
reddit = praw.Reddit(
    client_id=os.environ.get("reddit_client_id"),
    client_secret=os.environ.get("reddit_client_secret"),
    user_agent="Hugging Face Vibe Checker",
)

vibe_names = [
    "wholesome",
    "chill",
    "funny",
    "inspiring",
    "aesthetic",
    "nerdy",
    "supportive",
    "informative",
    "activism",
    "nostalgic",
    "creative",
    "memorable",
    "cryptic",
    "dark",
    "whimsical",
    "spiritual",
    "intellectual",
    "meme",
]

# NOTE(review): not referenced by the code in this file; kept for callers.
vibe_name_color_map = {vibe: "red" for vibe in vibe_names}


def vibe_check(url, min_karma, max_comments=MAX_COMMENTS):
    """Return the mean zero-shot score of every vibe over a thread's comments.

    Args:
        url: Reddit submission URL.
        min_karma: Score threshold (string from the textbox, or int); only
            comments with a score strictly greater than it are classified.
        max_comments: Maximum number of comments to classify; the
            highest-scoring qualifying comments are used.

    Returns:
        Dict mapping each name in ``vibe_names`` to its average score.

    Raises:
        gr.Error: If ``min_karma`` is not an integer or no comment qualifies.
    """
    try:
        threshold = int(min_karma)
    except (TypeError, ValueError):
        raise gr.Error(
            f"Minimum karma must be an integer, got {min_karma!r}."
        ) from None

    # Highest-scoring comments first, so truncation keeps the most upvoted.
    comments = sorted(get_comments(url), key=comment_compare, reverse=True)
    comment_bodies = [c["Comment"] for c in comments if c["Score"] > threshold]
    print("Total comments: " + str(len(comment_bodies)))
    comment_bodies = comment_bodies[:max_comments]
    if not comment_bodies:
        raise gr.Error("No comments above the minimum karma to classify.")

    print(f"Starting comment classification of {len(comment_bodies)} comments.")
    start = time.perf_counter()
    classes = pipe(comment_bodies, candidate_labels=vibe_names)
    elapsed = time.perf_counter() - start
    print(f"Comment classification took: {elapsed:.2f}s")

    # Some transformers versions return a bare dict for a one-item batch.
    if isinstance(classes, dict):
        classes = [classes]
    # Each result's "labels" are sorted by descending score (not in
    # vibe_names order), so pair labels with scores before averaging.
    per_comment = [dict(zip(c["labels"], c["scores"])) for c in classes]
    return {
        vibe: mean(scores[vibe] for scores in per_comment) for vibe in vibe_names
    }


def get_vibes_html(vibes):
    """Join the given vibe names into one space-separated string.

    NOTE(review): despite the name, no markup or colors are emitted; confirm
    whether ``vibe_name_color_map`` was meant to be used here.
    """
    return " ".join(f"{vibe}" for vibe in vibes)


def get_comments(url):
    """Fetch the loaded top-level comments of a Reddit submission.

    "Load more comments" placeholders and deleted/removed comments are
    skipped.

    Returns:
        List of dicts with keys "Comment" (body text), "Author",
        "Date Posted" (``created_utc``), and "Score".
    """
    submission = reddit.submission(url=url)
    return [
        {
            "Comment": comment.body,
            "Author": comment.author,
            "Date Posted": comment.created_utc,
            "Score": comment.score,
        }
        for comment in submission.comments
        if not isinstance(comment, MoreComments)
        and comment.body not in _REMOVED_BODIES
    ]


def comment_compare(comment):
    """Sort key for comment dicts: their Reddit score."""
    return comment["Score"]


# The theme is passed to the constructor; assigning demo.theme afterwards
# does not rebuild the theme CSS that Blocks computes at construction.
with gr.Blocks(theme=gr.themes.Base()) as demo:
    url = gr.Textbox(label="Url")
    min_karma = gr.Textbox(label="Minimum Karma", value=str(MIN_SCORE_THRESHOLD))
    output = gr.Label(label="Output", num_top_classes=NUM_VIBE_LABELS)
    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=vibe_check,
        inputs=[url, min_karma],
        outputs=output,
        api_name="vibe_check",
    )
    # NOTE: examples only fill the URL box. Enabling cache_examples would call
    # vibe_check with the URL alone, so min_karma would need to be an input.
    gr.Examples(
        examples=[
            "https://www.reddit.com/r/AskReddit/comments/yiazab/would_you_support_a_mandatory_retirement_age_of/",
            "https://www.reddit.com/r/politics/comments/yqa3cg/john_fetterman_wins_pennsylvania_senate_race/",
            "https://www.reddit.com/r/pics/comments/zyj3ll/andrew_and_tristan_tate_were_arrested_they_are/",
        ],
        inputs=url,
        outputs=output,
        fn=vibe_check,
    )

demo.launch()