# Author: mblackman
# Commit message: Added min karma threshold and color logic
# Commit: 30f70a1
from transformers import pipeline
import gradio as gr
import praw
from praw.models import MoreComments
import os
from statistics import mean
import torch
import time
# Default value shown in the "Minimum Karma" box; vibe_check keeps only
# comments whose score is strictly greater than the threshold it receives.
MIN_SCORE_THRESHOLD = 50
# How many of the top-scoring vibes the output Label displays.
NUM_VIBE_LABELS = 5

# transformers pipelines take a device index: 0 selects the first CUDA GPU,
# -1 runs on the CPU.
if torch.cuda.is_available():
    device = 0
else:
    print("GPU isn't available. Running on CPU.")
    device = -1
# The task is not passed explicitly; transformers infers it from the model's
# Hub metadata (presumably zero-shot-classification, matching the
# candidate_labels call in vibe_check).
pipe = pipeline(device=device, model="facebook/bart-large-mnli")

# Credentials are read from the environment; an unset variable yields None,
# which is handed to praw unchanged.
reddit = praw.Reddit(
    user_agent="Hugging Face Vibe Checker",
    client_id=os.environ.get("reddit_client_id"),
    client_secret=os.environ.get("reddit_client_secret"),
)
# Candidate labels for zero-shot classification. Their order is also the key
# order of the averages returned by vibe_check.
vibe_names = [
    "wholesome", "chill", "funny", "inspiring", "aesthetic", "nerdy",
    "supportive", "informative", "activism", "nostalgic", "creative",
    "memorable", "cryptic", "dark", "whimsical", "spiritual",
    "intellectual", "meme",
]
# Every vibe is currently mapped to the same colour ("red").
vibe_name_color_map = dict.fromkeys(vibe_names, "red")
def vibe_check(url, min_karma):
    """Score a Reddit thread's comments against every vibe label.

    Fetches comments via ``get_comments``, keeps those whose score is strictly
    greater than ``min_karma``, classifies all of them with the zero-shot
    ``pipe`` against ``vibe_names`` and averages each label's score.

    Args:
        url: Reddit submission URL.
        min_karma: Score threshold, either an int or an integer string (the
            Gradio Textbox supplies a string).

    Returns:
        dict mapping each vibe name to its mean score over the kept comments,
        or an empty dict when no comment passes the threshold.

    Raises:
        ValueError: if ``min_karma`` cannot be parsed by ``int()``.
    """
    threshold = int(min_karma)  # parse once rather than once per comment
    comment_bodies = [
        c["Comment"] for c in get_comments(url) if c["Score"] > threshold
    ]
    print("Total comments: " + str(len(comment_bodies)))
    if not comment_bodies:
        # Nothing to classify; statistics.mean([]) would raise.
        return {}

    print("Starting comment classification.")
    start = time.time()
    classes = pipe(comment_bodies, candidate_labels=vibe_names)
    end = time.time()
    # time.time() is in seconds; convert so the "ms" unit in the log is true.
    print("Comment classification took: " + str((end - start) * 1000) + "ms")

    # Some transformers versions unwrap a one-element batch into a bare dict.
    if isinstance(classes, dict):
        classes = [classes]
    return _mean_scores_by_label(classes, vibe_names)


def _mean_scores_by_label(classes, labels):
    """Average each label's score across zero-shot classification results.

    Each result's ``labels``/``scores`` lists are ordered by likelihood, not by
    ``candidate_labels`` order, so scores are matched to labels by name
    instead of by position.
    """
    per_result = [dict(zip(c["labels"], c["scores"])) for c in classes]
    return {label: mean(r[label] for r in per_result) for label in labels}
def get_vibes_html(vibes, color_map=None):
    """Render vibe names as space-separated, coloured HTML ``<span>`` tags.

    Args:
        vibes: Iterable of vibe names (a dict yields its keys). Each must be a
            key of ``color_map``.
        color_map: Mapping of vibe name to CSS colour; defaults to the
            module-level ``vibe_name_color_map``.

    Returns:
        HTML string. The n-th vibe (1-based) gets ``font-size: 12*n px``, so
        later entries render larger. Numbering starts at 1 so the first vibe
        is not rendered at 0px (invisible).

    Raises:
        KeyError: if a vibe is missing from ``color_map``.
    """
    color_map = vibe_name_color_map if color_map is None else color_map
    return " ".join(
        f'<span style="color:{color_map[vibe]};font-size:{12 * rank}px">{vibe}</span>'
        for rank, vibe in enumerate(vibes, start=1)
    )
def get_comments(url):
    """Collect the comments of a Reddit submission as plain dicts.

    Iterates ``submission.comments`` from the module-level ``reddit`` client.
    PRAW yields only top-level comments here, and ``MoreComments`` placeholders
    are skipped, so nothing hidden behind "load more" is fetched.

    Comments whose body is ``"[deleted]"`` (user-deleted) or ``"[removed]"``
    (moderator-removed) are dropped. They carry no text worth classifying.

    Args:
        url: Reddit submission URL.

    Returns:
        list of dicts with keys ``"Comment"`` (body text), ``"Author"``,
        ``"Date Posted"`` (``created_utc``) and ``"Score"``.
    """
    submission = reddit.submission(url=url)
    comments = []
    for comment in submission.comments:
        if isinstance(comment, MoreComments):
            continue
        if comment.body in ("[deleted]", "[removed]"):
            continue
        comments.append(
            {
                "Comment": comment.body,
                "Author": comment.author,
                "Date Posted": comment.created_utc,
                "Score": comment.score,
            }
        )
    return comments
def comment_compare(comment):
    """Sort key for comment dicts: their Reddit score (the ``"Score"`` entry)."""
    score = comment["Score"]
    return score
# Theme is passed to the constructor (the documented way) instead of being
# assigned after the Blocks has already been built.
with gr.Blocks(theme=gr.themes.Base()) as demo:
    url = gr.Textbox(label="Url")
    # Textbox values are strings; vibe_check converts this with int().
    min_karma = gr.Textbox(label="Minimum Karma", value=str(MIN_SCORE_THRESHOLD))
    output = gr.Label(label="Output", num_top_classes=NUM_VIBE_LABELS)
    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=vibe_check, inputs=[url, min_karma], outputs=output, api_name="vibe_check"
    )
    # Each example row provides a value for every input component, so
    # vibe_check gets both of its arguments if the examples are ever run or
    # cached. A single `url` input would leave min_karma missing.
    gr.Examples(
        [
            [
                "https://www.reddit.com/r/AskReddit/comments/yiazab/would_you_support_a_mandatory_retirement_age_of/",
                str(MIN_SCORE_THRESHOLD),
            ],
            [
                "https://www.reddit.com/r/politics/comments/yqa3cg/john_fetterman_wins_pennsylvania_senate_race/",
                str(MIN_SCORE_THRESHOLD),
            ],
            [
                "https://www.reddit.com/r/pics/comments/zyj3ll/andrew_and_tristan_tate_were_arrested_they_are/",
                str(MIN_SCORE_THRESHOLD),
            ],
        ],
        [url, min_karma],
        output,
        vibe_check,
        # cache_examples=True,
    )

demo.launch()