File size: 3,219 Bytes
f332dc9
 
 
 
1df8c04
b04c5a8
 
 
f332dc9
b04c5a8
30f70a1
b04c5a8
 
 
 
 
 
 
 
f332dc9
1df8c04
 
f332dc9
 
 
30f70a1
f332dc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30f70a1
f332dc9
b04c5a8
f332dc9
b04c5a8
 
 
 
f332dc9
b04c5a8
 
f332dc9
30f70a1
 
f332dc9
b04c5a8
f332dc9
b04c5a8
f332dc9
b04c5a8
30f70a1
 
b04c5a8
 
f332dc9
30f70a1
 
 
 
 
 
 
f332dc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b04c5a8
30f70a1
f332dc9
 
 
b04c5a8
f332dc9
 
 
 
 
 
 
 
 
 
 
 
 
30f70a1
 
 
f332dc9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from transformers import pipeline
import gradio as gr
import praw
from praw.models import MoreComments
import os
from statistics import mean
import torch
import time

# Default value pre-filled in the "Minimum Karma" input of the UI.
MIN_SCORE_THRESHOLD = 50
# How many top classes the output label component displays.
NUM_VIBE_LABELS = 5

# Device index handed to the classification pipeline: 0 when CUDA is
# available, -1 as the CPU fallback.
if torch.cuda.is_available():
    device = 0
else:
    print("GPU isn't available. Running on CPU.")
    device = -1


# Classifier used to score comments against the candidate vibe labels.
pipe = pipeline(
    model="facebook/bart-large-mnli",
    device=device,
)

# Reddit API client; credentials are read from the environment and are None
# when the variables are not set.
reddit = praw.Reddit(
    client_id=os.getenv("reddit_client_id"),
    client_secret=os.getenv("reddit_client_secret"),
    user_agent="Hugging Face Vibe Checker",
)

# Candidate labels passed to the classifier.
vibe_names = [
    "wholesome", "chill", "funny",
    "inspiring", "aesthetic", "nerdy",
    "supportive", "informative", "activism",
    "nostalgic", "creative", "memorable",
    "cryptic", "dark", "whimsical",
    "spiritual", "intellectual", "meme",
]

# Display color per vibe; every vibe currently shares the same color.
vibe_name_color_map = dict.fromkeys(vibe_names, "red")

def vibe_check(url, min_karma, max_comments=1):
    """Average the classifier's vibe scores over a Reddit thread's comments.

    Comments are fetched with ``get_comments``; only those whose score is
    strictly greater than ``min_karma`` are kept. The kept comment bodies are
    classified by ``pipe`` with ``vibe_names`` as candidate labels, and the
    per-label scores are averaged across the classified comments.

    Args:
        url: URL of the Reddit submission.
        min_karma: Score threshold; any value ``int()`` accepts (the UI
            passes the textbox string).
        max_comments: Upper bound on how many qualifying comments are sent
            to the classifier. Defaults to 1, matching the limit that used
            to be hard-coded; pass ``None`` to classify all of them.

    Returns:
        A dict mapping every name in ``vibe_names`` to its mean score, or an
        empty dict when no comment passes the karma filter.

    Raises:
        ValueError: If ``min_karma`` cannot be converted to ``int``.
    """
    karma_floor = int(min_karma)
    comments = get_comments(url)
    comment_bodies = [c["Comment"] for c in comments if c["Score"] > karma_floor]

    print("Total comments: " + str(len(comment_bodies)))

    # Nothing to classify: avoid calling the pipeline (and mean()) on an
    # empty sequence, both of which raise.
    if not comment_bodies:
        return {}

    print("Starting comment classification.")
    start = time.perf_counter()
    classes = pipe(
        comment_bodies[:max_comments],
        candidate_labels=vibe_names,
    )
    elapsed_ms = (time.perf_counter() - start) * 1000

    print(f"Comment classification took: {elapsed_ms:.0f}ms")

    # Some pipeline versions return a bare dict for a single input sequence.
    if isinstance(classes, dict):
        classes = [classes]

    # The pipeline orders "labels"/"scores" by likelihood, not by the order of
    # candidate_labels, so scores must be looked up by label rather than by
    # position in vibe_names.
    scores_by_label = [dict(zip(c["labels"], c["scores"])) for c in classes]

    return {
        vibe: mean(scores[vibe] for scores in scores_by_label)
        for vibe in vibe_names
    }

def get_vibes_html(vibes):
    """Render vibe names as colored HTML spans with growing font sizes.

    Each vibe becomes a ``<span>`` colored via ``vibe_name_color_map``; the
    n-th vibe (counting from 1) gets a font size of ``12 * n`` pixels. The
    count starts at 1 so that the first vibe is not rendered at 0px.

    Args:
        vibes: Iterable of vibe names; each must be a key of
            ``vibe_name_color_map``.

    Returns:
        The spans joined by four spaces; an empty string for no vibes.

    Raises:
        KeyError: If a vibe has no entry in ``vibe_name_color_map``.
    """
    return "    ".join(
        f'<span style="color:{vibe_name_color_map[vibe]};'
        f'font-size:{12 * rank}px">{vibe}</span>'
        for rank, vibe in enumerate(vibes, start=1)
    )

def get_comments(url):
    """Collect comment records for the Reddit submission at ``url``.

    Iterates ``submission.comments``, leaving out ``MoreComments``
    placeholders and comments whose body is "[deleted]".

    Returns:
        A list of dicts with the keys "Comment", "Author", "Date Posted"
        and "Score", taken from the comment's ``body``, ``author``,
        ``created_utc`` and ``score`` attributes.
    """
    thread = reddit.submission(url=url)
    return [
        {
            "Comment": entry.body,
            "Author": entry.author,
            "Date Posted": entry.created_utc,
            "Score": entry.score,
        }
        for entry in thread.comments
        if not isinstance(entry, MoreComments) and entry.body != "[deleted]"
    ]


def comment_compare(comment):
    """Sort key: return the "Score" entry of a comment record."""
    score = comment["Score"]
    return score


# Build the UI: a URL box and a minimum-karma box feed vibe_check, whose
# result is shown in a label component limited to NUM_VIBE_LABELS classes.
#
# The theme is passed to the Blocks constructor. Assigning ``demo.theme``
# after the Blocks object has been built is too late for the theme to be
# applied.
with gr.Blocks(theme=gr.themes.Base()) as demo:
    url = gr.Textbox(label="Url")
    min_karma = gr.Textbox(label="Minimum Karma", value=MIN_SCORE_THRESHOLD)
    output = gr.Label(label="Output", num_top_classes=NUM_VIBE_LABELS)
    submit_button = gr.Button("Submit")

    submit_button.click(
        fn=vibe_check, inputs=[url, min_karma], outputs=output, api_name="vibe_check"
    )

    # Example threads; selecting one fills in the URL box only.
    # NOTE(review): example caching is left disabled. The examples supply
    # only the URL while vibe_check also requires min_karma, so caching would
    # presumably fail until the examples provide both inputs -- confirm
    # before enabling cache_examples.
    gr.Examples(
        [
            "https://www.reddit.com/r/AskReddit/comments/yiazab/would_you_support_a_mandatory_retirement_age_of/",
            "https://www.reddit.com/r/politics/comments/yqa3cg/john_fetterman_wins_pennsylvania_senate_race/",
            "https://www.reddit.com/r/pics/comments/zyj3ll/andrew_and_tristan_tate_were_arrested_they_are/",
        ],
        url,
        output,
        vibe_check,
    )


demo.launch()