Spaces:
Runtime error
Runtime error
mblackman
commited on
Commit
•
30f70a1
1
Parent(s):
84422f0
Added min karma threshold and color logic
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ import torch
|
|
8 |
import time
|
9 |
|
10 |
MIN_SCORE_THRESHOLD=50
|
|
|
11 |
|
12 |
device = 0
|
13 |
if not torch.cuda.is_available():
|
@@ -22,7 +23,7 @@ reddit = praw.Reddit(
|
|
22 |
user_agent="Hugging Face Vibe Checker",
|
23 |
)
|
24 |
|
25 |
-
|
26 |
"wholesome",
|
27 |
"chill",
|
28 |
"funny",
|
@@ -43,6 +44,7 @@ vibes = [
|
|
43 |
"meme",
|
44 |
]
|
45 |
|
|
|
46 |
|
47 |
def vibe_check(url, min_karma):
|
48 |
comments = get_comments(url)
|
@@ -54,19 +56,26 @@ def vibe_check(url, min_karma):
|
|
54 |
print("Starting comment classification.")
|
55 |
start = time.time()
|
56 |
classes = pipe(
|
57 |
-
comment_bodies,
|
58 |
-
candidate_labels=
|
59 |
)
|
60 |
end = time.time()
|
61 |
|
62 |
print("Comment classification took: " + str(end - start) + "ms")
|
63 |
|
64 |
averages = {}
|
65 |
-
for i in range(len(
|
66 |
-
averages[
|
67 |
|
68 |
return averages
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
def get_comments(url):
|
71 |
submission = reddit.submission(url=url)
|
72 |
comments = []
|
@@ -93,7 +102,7 @@ def comment_compare(comment):
|
|
93 |
with gr.Blocks() as demo:
|
94 |
url = gr.Textbox(label="Url")
|
95 |
min_karma = gr.Textbox(label="Minimum Karma", value=MIN_SCORE_THRESHOLD)
|
96 |
-
output = gr.Label(label="Output", num_top_classes=
|
97 |
submit_button = gr.Button("Submit")
|
98 |
|
99 |
submit_button.click(
|
@@ -111,5 +120,7 @@ with gr.Blocks() as demo:
|
|
111 |
vibe_check,
|
112 |
# cache_examples=True,
|
113 |
)
|
114 |
-
|
|
|
|
|
115 |
demo.launch()
|
|
|
8 |
import time
|
9 |
|
10 |
MIN_SCORE_THRESHOLD=50
|
11 |
+
NUM_VIBE_LABELS=5
|
12 |
|
13 |
device = 0
|
14 |
if not torch.cuda.is_available():
|
|
|
23 |
user_agent="Hugging Face Vibe Checker",
|
24 |
)
|
25 |
|
26 |
+
vibe_names = [
|
27 |
"wholesome",
|
28 |
"chill",
|
29 |
"funny",
|
|
|
44 |
"meme",
|
45 |
]
|
46 |
|
47 |
+
vibe_name_color_map={vibe:"red" for vibe in vibe_names}
|
48 |
|
49 |
def vibe_check(url, min_karma):
|
50 |
comments = get_comments(url)
|
|
|
56 |
print("Starting comment classification.")
|
57 |
start = time.time()
|
58 |
classes = pipe(
|
59 |
+
comment_bodies[:1],
|
60 |
+
candidate_labels=vibe_names,
|
61 |
)
|
62 |
end = time.time()
|
63 |
|
64 |
print("Comment classification took: " + str(end - start) + "ms")
|
65 |
|
66 |
averages = {}
|
67 |
+
for i in range(len(vibe_names)):
|
68 |
+
averages[vibe_names[i]] = mean([c["scores"][i] for c in classes])
|
69 |
|
70 |
return averages
|
71 |
|
72 |
+
def get_vibes_html(vibes):
|
73 |
+
return " ".join(
|
74 |
+
[
|
75 |
+
f"<span style=\"color:{vibe_name_color_map[vibe]};font-size:{12 * i}px\">{vibe}</span>"
|
76 |
+
for i, vibe in enumerate(vibes)
|
77 |
+
])
|
78 |
+
|
79 |
def get_comments(url):
|
80 |
submission = reddit.submission(url=url)
|
81 |
comments = []
|
|
|
102 |
with gr.Blocks() as demo:
|
103 |
url = gr.Textbox(label="Url")
|
104 |
min_karma = gr.Textbox(label="Minimum Karma", value=MIN_SCORE_THRESHOLD)
|
105 |
+
output = gr.Label(label="Output", num_top_classes=NUM_VIBE_LABELS)
|
106 |
submit_button = gr.Button("Submit")
|
107 |
|
108 |
submit_button.click(
|
|
|
120 |
vibe_check,
|
121 |
# cache_examples=True,
|
122 |
)
|
123 |
+
|
124 |
+
|
125 |
+
demo.theme=gr.themes.Base()
|
126 |
demo.launch()
|