mblackman commited on
Commit
b04c5a8
1 Parent(s): 1df8c04

Added label UI and added min karma threshold

Browse files
Files changed (2) hide show
  1. app.py +29 -7
  2. requirements.txt +2 -1
app.py CHANGED
@@ -3,8 +3,19 @@ import gradio as gr
3
  import praw
4
  from praw.models import MoreComments
5
  import os
 
 
 
6
 
7
- pipe = pipeline(model="facebook/bart-large-mnli")
 
 
 
 
 
 
 
 
8
  reddit = praw.Reddit(
9
  client_id=os.environ.get("reddit_client_id"),
10
  client_secret=os.environ.get("reddit_client_secret"),
@@ -33,18 +44,28 @@ vibes = [
33
  ]
34
 
35
 
36
- def vibe_check(url):
37
  comments = get_comments(url)
38
- comments.sort(key=comment_compare, reverse=True)
39
- comment_bodies = [c["Comment"] for c in comments[0:10]]
 
 
40
 
 
 
41
  classes = pipe(
42
  comment_bodies,
43
  candidate_labels=vibes,
44
  )
 
45
 
46
- return classes
47
 
 
 
 
 
 
48
 
49
  def get_comments(url):
50
  submission = reddit.submission(url=url)
@@ -71,11 +92,12 @@ def comment_compare(comment):
71
 
72
  with gr.Blocks() as demo:
73
  url = gr.Textbox(label="Url")
74
- output = gr.Textbox(label="Output")
 
75
  submit_button = gr.Button("Submit")
76
 
77
  submit_button.click(
78
- fn=vibe_check, inputs=url, outputs=output, api_name="vibe_check"
79
  )
80
 
81
  gr.Examples(
 
3
  import praw
4
  from praw.models import MoreComments
5
  import os
6
+ from statistics import mean
7
+ import torch
8
+ import time
9
 
10
+ MIN_SCORE_THRESHOLD=50
11
+
12
+ device = 0
13
+ if not torch.cuda.is_available():
14
+ print("GPU isn't available. Running on CPU.")
15
+ device = -1
16
+
17
+
18
+ pipe = pipeline(model="facebook/bart-large-mnli", device=device)
19
  reddit = praw.Reddit(
20
  client_id=os.environ.get("reddit_client_id"),
21
  client_secret=os.environ.get("reddit_client_secret"),
 
44
  ]
45
 
46
 
47
+ def vibe_check(url, min_karma):
48
  comments = get_comments(url)
49
+ #comments.sort(key=comment_compare, reverse=True)
50
+ comment_bodies = [c["Comment"] for c in comments if c["Score"] > int(min_karma)]
51
+
52
+ print("Total comments: " + str(len(comment_bodies)))
53
 
54
+ print("Starting comment classification.")
55
+ start = time.time()
56
  classes = pipe(
57
  comment_bodies,
58
  candidate_labels=vibes,
59
  )
60
+ end = time.time()
61
 
62
+ print("Comment classification took: " + str(end - start) + "ms")
63
 
64
+ averages = {}
65
+ for i in range(len(vibes)):
66
+ averages[vibes[i]] = mean([c["scores"][i] for c in classes])
67
+
68
+ return averages
69
 
70
  def get_comments(url):
71
  submission = reddit.submission(url=url)
 
92
 
93
  with gr.Blocks() as demo:
94
  url = gr.Textbox(label="Url")
95
+ min_karma = gr.Textbox(label="Minimum Karma", value=MIN_SCORE_THRESHOLD)
96
+ output = gr.Label(label="Output", num_top_classes=5)
97
  submit_button = gr.Button("Submit")
98
 
99
  submit_button.click(
100
+ fn=vibe_check, inputs=[url, min_karma], outputs=output, api_name="vibe_check"
101
  )
102
 
103
  gr.Examples(
requirements.txt CHANGED
@@ -2,4 +2,5 @@ pylint
2
  transformers
3
  torch
4
  praw
5
- gradio
 
 
2
  transformers
3
  torch
4
  praw
5
+ gradio
6
+ torch