Spaces:
Runtime error
Runtime error
import html
import os
import time
from statistics import mean

import gradio as gr
import praw
import torch
from praw.models import MoreComments
from transformers import pipeline
# UI defaults: karma cut-off pre-filled in the form, and how many vibes the
# output Label displays.
MIN_SCORE_THRESHOLD = 50
NUM_VIBE_LABELS = 5

# transformers' `device` argument: 0 selects the first CUDA device, -1 the CPU.
_has_gpu = torch.cuda.is_available()
if not _has_gpu:
    print("GPU isn't available. Running on CPU.")
device = 0 if _has_gpu else -1

# Zero-shot classifier used by vibe_check.
pipe = pipeline(device=device, model="facebook/bart-large-mnli")

# Reddit API client. Credentials come from the environment; os.environ.get
# yields None when a variable is missing.
reddit = praw.Reddit(
    user_agent="Hugging Face Vibe Checker",
    client_id=os.environ.get("reddit_client_id"),
    client_secret=os.environ.get("reddit_client_secret"),
)
# Candidate labels for zero-shot classification, in declaration order.
vibe_names = """
    wholesome chill funny inspiring aesthetic nerdy supportive informative
    activism nostalgic creative memorable cryptic dark whimsical spiritual
    intellectual meme
""".split()

# CSS color per vibe; every vibe currently shares the same color.
vibe_name_color_map = dict.fromkeys(vibe_names, "red")
def vibe_check(url, min_karma):
    """Classify a Reddit thread's comments and return the mean score per vibe.

    Fetches the submission's comments via ``get_comments``, keeps those whose
    score is strictly greater than ``min_karma``, classifies every kept
    comment with the zero-shot ``pipe`` against ``vibe_names``, and averages
    each label's score across the kept comments.

    Args:
        url: Reddit submission URL.
        min_karma: Score threshold; anything ``int()`` accepts (the Gradio
            Textbox delivers a string).

    Returns:
        Dict mapping every vibe name to its mean score, or a message string
        when no comment passes the threshold.

    Raises:
        ValueError: if ``min_karma`` cannot be converted with ``int()``.
    """
    threshold = int(min_karma)  # convert once, not once per comment
    comments = get_comments(url)
    comment_bodies = [c["Comment"] for c in comments if c["Score"] > threshold]
    print("Total comments: " + str(len(comment_bodies)))
    if not comment_bodies:
        # mean() of an empty list would raise StatisticsError.
        return f"No comments with more than {threshold} karma."

    print("Starting comment classification.")
    # NOTE(review): every kept comment is scored against all labels, which
    # may be slow on CPU for large threads.
    start = time.time()
    classes = pipe(comment_bodies, candidate_labels=vibe_names)
    elapsed_ms = (time.time() - start) * 1000  # time.time() is in seconds
    print("Comment classification took: " + str(elapsed_ms) + "ms")

    # Defensive: some transformers versions may return a bare dict when
    # given a single input.
    if isinstance(classes, dict):
        classes = [classes]

    # The pipeline orders "labels"/"scores" by descending score for each
    # input, so re-key each result by label before averaging. Indexing by
    # position in vibe_names would attribute scores to the wrong vibes.
    per_comment = [dict(zip(c["labels"], c["scores"])) for c in classes]
    return {
        vibe: mean(scores[vibe] for scores in per_comment) for vibe in vibe_names
    }
def get_vibes_html(vibes, color_map=None):
    """Render vibe names as space-separated, colored HTML ``<span>`` elements.

    The i-th vibe (0-based) is drawn at ``12 * (i + 1)`` px, so sizes start at
    12px and grow along the sequence. The original ``12 * i`` gave the first
    vibe a 0px (invisible) font. Vibe text is HTML-escaped.

    Args:
        vibes: Iterable of vibe names.
        color_map: Mapping from vibe name to a CSS color. Defaults to the
            module-level ``vibe_name_color_map``.

    Returns:
        The spans joined with single spaces ("" for no vibes).

    Raises:
        KeyError: if a vibe has no entry in ``color_map``.
    """
    if color_map is None:
        color_map = vibe_name_color_map
    return " ".join(
        f'<span style="color:{color_map[vibe]};font-size:{12 * (i + 1)}px">'
        f"{html.escape(vibe)}</span>"
        for i, vibe in enumerate(vibes)
    )
def get_comments(url):
    """Fetch a Reddit submission's top-level comments as plain dicts.

    ``MoreComments`` placeholders are skipped, not expanded. Comments whose
    body is "[deleted]" or "[removed]" are dropped, since their text is gone
    and would only add noise to classification.

    Args:
        url: Reddit submission URL.

    Returns:
        List of dicts with keys "Comment", "Author", "Date Posted" and
        "Score", in the order PRAW yields the comments.
    """
    submission = reddit.submission(url=url)
    gone_bodies = ("[deleted]", "[removed]")
    return [
        {
            "Comment": comment.body,
            "Author": comment.author,
            "Date Posted": comment.created_utc,
            "Score": comment.score,
        }
        for comment in submission.comments
        # isinstance is checked first because MoreComments has no .body.
        if not isinstance(comment, MoreComments) and comment.body not in gone_bodies
    ]
def comment_compare(comment):
    """Sort key: return the karma ("Score") of a comment dict."""
    karma = comment["Score"]
    return karma
# Gradio UI: submission URL and karma threshold in, per-vibe confidences out.
# The theme is passed to the constructor: Blocks derives its theme CSS at
# construction time, so assigning demo.theme afterwards has no effect.
with gr.Blocks(theme=gr.themes.Base()) as demo:
    url = gr.Textbox(label="Url")
    min_karma = gr.Textbox(label="Minimum Karma", value=MIN_SCORE_THRESHOLD)
    output = gr.Label(label="Output", num_top_classes=NUM_VIBE_LABELS)
    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=vibe_check, inputs=[url, min_karma], outputs=output, api_name="vibe_check"
    )
    # Each example row supplies both vibe_check arguments (url, min_karma),
    # so the examples remain valid if they are ever run or cached.
    gr.Examples(
        examples=[
            [example_url, MIN_SCORE_THRESHOLD]
            for example_url in (
                "https://www.reddit.com/r/AskReddit/comments/yiazab/would_you_support_a_mandatory_retirement_age_of/",
                "https://www.reddit.com/r/politics/comments/yqa3cg/john_fetterman_wins_pennsylvania_senate_race/",
                "https://www.reddit.com/r/pics/comments/zyj3ll/andrew_and_tristan_tate_were_arrested_they_are/",
            )
        ],
        inputs=[url, min_karma],
        outputs=output,
        fn=vibe_check,
    )

demo.launch()