File size: 5,452 Bytes
cbcecfb
3b3dbc9
cbcecfb
62f43b4
cbcecfb
77d9839
7a8513b
cbcecfb
 
 
62984a8
 
00320ff
 
77d9839
 
00320ff
62f43b4
9f1606d
320952b
 
00320ff
 
62f43b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbcecfb
3b3dbc9
 
00320ff
7a8513b
 
 
 
 
00320ff
9f1606d
62f43b4
00320ff
cbcecfb
 
 
 
 
62f43b4
 
 
 
f6c60e6
 
 
cbcecfb
f6c60e6
cbcecfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62f43b4
cbcecfb
 
 
 
3899fac
cbcecfb
 
 
7a8513b
6dccd1c
7a8513b
cbcecfb
 
62984a8
7a8513b
62984a8
 
 
7a8513b
62984a8
7a8513b
 
 
 
 
e3c8b51
320952b
00320ff
7a8513b
9f1606d
7a8513b
320952b
00320ff
7a8513b
00320ff
7a8513b
3899fac
7a8513b
3899fac
717b41f
00320ff
 
320952b
9f1606d
7a8513b
 
 
 
 
 
 
 
3613d05
 
3899fac
3cc406f
 
 
 
7a8513b
 
3613d05
 
 
 
 
 
 
 
 
 
 
62984a8
3613d05
3cc406f
3613d05
 
 
 
 
9f1606d
cbcecfb
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os
import re
import sys
import nltk
import praw
import matplotlib
from tqdm import tqdm
import gradio as gr
import pandas as pd
import praw.exceptions
import matplotlib.pyplot as plt
from wordcloud import WordCloud
from transformers import pipeline

# Select the non-interactive 'Agg' raster backend: figures are only rendered
# off-screen for the web UI, never shown in a GUI window.
matplotlib.use('Agg')


def index_chunk(a, fraction=0.3):
    """Split the sequence ``a`` into contiguous, nearly equal-sized slices.

    The number of slices is ``round(fraction * len(a))``, clamped to the range
    ``[1, len(a)]`` so short inputs (length 1, or tiny fractions) do not cause
    a division by zero. Slice sizes differ by at most one, and the larger
    slices come first.

    Args:
        a: Any sliceable sequence supporting ``len()``.
        fraction: Number of slices expressed as a fraction of ``len(a)``.

    Returns:
        An iterator over slices of ``a``; it yields nothing when ``a`` is empty.
    """
    length = len(a)
    if length == 0:
        return iter(())

    n = min(length, max(1, round(fraction * length)))
    k, m = divmod(length, n)
    # The first m slices take one extra element so the remainder is spread out.
    return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))


def sentence_chunk(a, max_tokens=512, tokenize=None):
    """Greedily pack consecutive texts into chunks of roughly ``max_tokens`` tokens.

    Items are joined with a single space, so the last word of one comment is
    not fused with the first word of the next. The token count is an
    estimate: each item is tokenized once and the counts are summed.

    Args:
        a: Iterable of strings (e.g. a pandas Series of comment bodies).
        max_tokens: Token budget per chunk. The default of 512 is empirical.
        tokenize: Callable returning a list of tokens for a string. Defaults
            to ``nltk.word_tokenize``.

    Returns:
        list[str]: Chunks in input order. An item whose own estimate already
        exceeds ``max_tokens`` is still emitted, alone, as an oversized chunk.
        Empty input yields an empty list.
    """
    if tokenize is None:
        tokenize = nltk.word_tokenize

    chunks = []
    buffer = []
    buffer_tokens = 0

    for item in a:
        item_tokens = len(tokenize(item))

        # Flush before the budget would be exceeded, but never emit an empty chunk.
        if buffer and buffer_tokens + item_tokens > max_tokens:
            chunks.append(" ".join(buffer))
            buffer = []
            buffer_tokens = 0

        buffer.append(item)
        buffer_tokens += item_tokens

    if buffer:
        chunks.append(" ".join(buffer))

    return chunks


def preprocessData(df, max_comments=128):
    """Clean comment text, keep the highest-scoring comments and chunk them.

    The input DataFrame is not modified.

    Args:
        df: DataFrame with a ``text`` column of comment bodies (as produced
            by ``getComments``). If a ``score`` column is present, comments
            are ranked by it before truncation.
        max_comments: Maximum number of comments to keep. Empirically, more
            than ~100 comments doesn't change the summary much but slows down
            the summarizer. The slowdown is not present with the load API,
            but limiting low-score comments still seems worthwhile.

    Returns:
        list[str]: Text chunks as produced by ``sentence_chunk``.
    """
    # Nothing upstream sorts the rows, so rank explicitly; a stable sort keeps
    # the original order among comments with equal scores.
    if "score" in df.columns:
        df = df.sort_values("score", ascending=False, kind="mergesort")

    # Truncate first (less work), then normalize to str: missing bodies (e.g.
    # read back from the debug CSV as NaN) would otherwise break the regexes.
    text = df["text"].head(max_comments).fillna("").astype(str)

    # Strip links.
    text = text.str.replace(r"http\S+", "", regex=True)
    # Strip quoted lines (Reddit markdown "> ...") anywhere in a comment.
    text = text.str.replace(r"^>.+", "", regex=True, flags=re.M)

    # chunking to handle giving the model too large of an input which crashes
    return sentence_chunk(text)


def getComments(url, debug=False):
    """Fetch the non-stickied comments of a Reddit submission as a DataFrame.

    Reddit API credentials are read from the ``REDDIT_CLIENT_ID``,
    ``REDDIT_CLIENT_SECRET`` and ``REDDIT_USER_AGENT`` environment variables.

    Args:
        url: URL of the Reddit submission; it must include the submission id.
        debug: When True and ``./debug_comments.csv`` exists, return its
            contents without contacting Reddit (``url`` is ignored). When True
            and the file does not exist, the freshly fetched comments are
            saved there, to avoid sending many requests to Reddit.

    Returns:
        pandas.DataFrame with columns text, score, id, parent_id,
        submission_title, submission_score and submission_id (one row per
        comment).

    Raises:
        ValueError: If praw rejects ``url`` as an invalid submission URL.
        KeyError: If one of the required environment variables is unset.
    """
    debug_path = "./debug_comments.csv"

    if debug and os.path.isfile(debug_path):
        return pd.read_csv(debug_path)

    reddit = praw.Reddit(
        client_id=os.environ["REDDIT_CLIENT_ID"],
        client_secret=os.environ["REDDIT_CLIENT_SECRET"],
        user_agent=os.environ["REDDIT_USER_AGENT"],
    )

    try:
        submission = reddit.submission(url=url)
    except praw.exceptions.InvalidURL as err:
        # Stop here: without a submission there is nothing to fetch.
        raise ValueError(
            "The URL is invalid. Make sure that you have included the submission id"
        ) from err

    # limit=0 drops the "load more comments" placeholders instead of fetching them.
    submission.comments.replace_more(limit=0)

    cols = [
        "text",
        "score",
        "id",
        "parent_id",
        "submission_title",
        "submission_score",
        "submission_id"
    ]

    # Submission-level values are identical for every row; read them once.
    submission_fields = (submission.title, submission.score, submission.id)

    rows = [
        [comment.body, comment.score, comment.id, comment.parent_id, *submission_fields]
        for comment in submission.comments.list()
        if not comment.stickied
    ]

    df = pd.DataFrame(data=rows, columns=cols)

    if debug:
        # save for debugging to avoid sending tons of requests to reddit
        df.to_csv(debug_path, index=False)

    return df


def _wordcloud_figure(text):
    """Return a new 12x7 inch matplotlib Figure showing a word cloud of ``text``.

    The figure is created with ``matplotlib.figure.Figure`` rather than through
    pyplot. It is therefore not registered in pyplot's global figure list, and
    each call gets its own canvas: nothing accumulates between requests, and
    the figure can be garbage collected once it is no longer referenced.
    """
    from matplotlib.figure import Figure

    # transparent bg: background_color=None, mode='RGBA')
    wc_opts = dict(collocations=False, width=1920, height=1080)
    wcloud = WordCloud(**wc_opts).generate(text)

    fig = Figure(figsize=(12, 7))
    fig.patch.set_alpha(0.0)
    # Axes span the whole figure so the image is drawn without margins.
    ax = fig.add_axes([0, 0, 1, 1])
    ax.imshow(wcloud, aspect='auto')
    ax.axis("off")
    ax.autoscale(tight=True)
    return fig


def summarizer(url: str) -> tuple:
    """Summarize the comment section of a Reddit post for the gradio UI.

    Uses the module-level ``sum_api`` (summarization) and ``clf_api``
    (classification of the short summary) callables, which are created in
    the ``__main__`` block.

    Args:
        url: URL of the Reddit submission.

    Returns:
        A 6-tuple in UI output order: an update making the results column
        visible, the Markdown title, the short summary, the long summary, the
        classifier output for the short summary, and the word-cloud figure.

    Raises:
        ValueError: If ``getComments`` returns no rows.
    """
    # pushshift.io submission comments api doesn't work so have to use praw
    df = getComments(url=url)

    # Without comments there is nothing to summarize and no row to take the title from.
    if df.empty:
        raise ValueError("This post has no comments to summarize.")

    submission_title = '## ' + df.submission_title.iloc[0]

    chunks = preprocessData(df)

    fig = _wordcloud_figure(' '.join(chunks))

    # treating a group of comments as one block of text
    lst_summaries = [sum_api(grp) for grp in tqdm(chunks)]

    # Remove the space before periods (" ." -> ".") in the model output.
    long_output = ' '.join(lst_summaries).replace(" .", ".")

    # Summarizing the joined chunk summaries yields the short summary.
    short_output = sum_api(long_output).replace(" .", ".")

    sentiment = clf_api(short_output)

    return gr.update(visible=True), submission_title, short_output, long_output, sentiment, fig


if __name__ == "__main__":

    # Remote Hugging Face models loaded through gradio. sum_api and clf_api
    # must remain module-level names: summarizer() looks them up at call time.
    hf_token = os.environ["HF_TOKEN"]
    sum_api = gr.Interface.load("models/sshleifer/distilbart-cnn-12-6", api_key=hf_token)
    clf_api = gr.Interface.load("models/finiteautomata/bertweet-base-sentiment-analysis", api_key=hf_token)

    page_css = ".gradio-container {max-width: 900px !important; width: 100%}"
    example_urls = ['https://www.reddit.com/r/wholesome/comments/10ehlxo/he_got_a_strong_message/']

    with gr.Blocks(css=page_css) as demo:
        url_box = gr.Textbox(label='Post URL')
        run_button = gr.Button("Summarize")
        heading = gr.Markdown("")

        # Hidden at start; summarizer() returns gr.update(visible=True) as its
        # first output to reveal it.
        with gr.Column(visible=False) as results:
            with gr.Row():
                with gr.Column():
                    short_box = gr.Textbox(label='Short Summary')
                    sentiment_label = gr.Label(label='Sentiment')
                cloud_plot = gr.Plot(label='Word Cloud')
            long_box = gr.Textbox(label='Long Summary')

        # Order must match the tuple returned by summarizer().
        outputs = [results, heading, short_box, long_box, sentiment_label, cloud_plot]

        run_button.click(fn=summarizer, inputs=[url_box], outputs=outputs)

        # cache_examples=True makes gradio run summarizer on the examples
        # ahead of time to cache their outputs.
        gr.Examples(
            examples=example_urls,
            fn=summarizer,
            inputs=[url_box],
            outputs=outputs,
            cache_examples=True,
        )

    try:
        demo.launch()
    except KeyboardInterrupt:
        gr.close_all()