import os
import re
import nltk
import praw
import gradio as gr
import pandas as pd
import praw.exceptions
from transformers import pipeline
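
# nltk.word_tokenize (used in sentence_chunk) needs the punkt tokenizer data.
# Assumption: the runtime may not have it pre-installed, so fetch it quietly at startup.
# (Newer NLTK releases may additionally require the "punkt_tab" resource.)
nltk.download("punkt", quiet=True)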


def index_chunk(a):
    """Split a sequence into roughly equal contiguous slices, using ~30% of its length as the slice count."""
    n = max(1, round(0.3 * len(a)))
    k, m = divmod(len(a), n)
    return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))


def sentence_chunk(a):
    """Greedily pack comments into chunks of roughly 512 tokens so each fits in the summarizer's input."""
    sentences = []
    buffer = ""

    # the 512 token threshold is empirical

    for item in a:
        token_length_estimation = len(nltk.word_tokenize(buffer + item))

        if token_length_estimation > 512:
            sentences.append(buffer)
            buffer = ""

        # add a separating space so adjacent comments don't run together
        buffer += item + " "

    if buffer:
        sentences.append(buffer)

    return sentences


def preprocessData(df):
    """Clean the comment text, drop low-scoring comments, and chunk the text for the summarizer."""
    # strip URLs and quoted (">") lines from the comment bodies
    df["text"] = df["text"].apply(lambda x: re.sub(r"http\S+", "", x, flags=re.M))
    df["text"] = df["text"].apply(lambda x: re.sub(r"^>.+", "", x, flags=re.M))

    # keep only comments scoring at least 5% of the top comment's score
    smax = df.score.max()
    threshold = round(0.05 * smax)
    df = df[df.score >= threshold]

    # empirically, more than 200 comments doesn't change the summary much but slows down the summarizer
    if len(df.text) >= 200:
        df = df[:200]

    # chunk the text so no single input exceeds the model's limit, which would crash it
    # chunked = list(index_chunk(df.text))
    chunked = sentence_chunk(df.text)

    return chunked


def getComments(url, debug=False):
    """Fetch the submission's comments via PRAW and return them as a DataFrame."""

    if debug and os.path.isfile('./debug_comments.csv'):
        df = pd.read_csv("./debug_comments.csv")
        return df

    # Reddit API credentials are read from environment variables
    client_id = os.environ['REDDIT_CLIENT_ID']
    client_secret = os.environ['REDDIT_CLIENT_SECRET']
    user_agent = os.environ['REDDIT_USER_AGENT']

    reddit = praw.Reddit(client_id=client_id, client_secret=client_secret, user_agent=user_agent)

    try:
        submission = reddit.submission(url=url)
    except praw.exceptions.InvalidURL:
        # re-raise instead of falling through, which would hit a NameError on `submission` below
        print("The URL is invalid. Make sure that you have included the submission id")
        raise

    # resolve "load more comments" placeholders without fetching extra comment batches
    submission.comments.replace_more(limit=0)

    cols = [
        "text",
        "score",
        "id",
        "parent_id",
        "submission_title",
        "submission_score",
        "submission_id"
    ]
    rows = []

    for comment in submission.comments.list():

        # skip stickied comments (typically moderator or bot posts)
        if comment.stickied:
            continue

        data = [
            comment.body,
            comment.score,
            comment.id,
            comment.parent_id,
            submission.title,
            submission.score,
            submission.id,
        ]

        rows.append(data)

    df = pd.DataFrame(data=rows, columns=cols)

    if debug:
        # cache the comments locally so repeated debug runs don't keep hitting the Reddit API
        df.to_csv('debug_comments.csv', index=False)

    return df


def summarizer(url: str) -> tuple[str, str]:
    """Summarize a Reddit submission's comments; returns (short_summary, long_summary)."""

    # the pushshift.io submission comments API doesn't work, so use PRAW instead
    df = getComments(url=url)
    chunked_df = preprocessData(df)

    submission_title = df.submission_title.unique()[0]

    lst_summaries = []

    # note: the summarization pipeline (and its model weights) is rebuilt on every request;
    # loading it once at module level would avoid repeating the startup cost
    nlp = pipeline('summarization', model="sshleifer/distilbart-cnn-12-6")

    for grp in chunked_df:
        # treating a group of comments as one block of text
        result = nlp(grp, max_length=500)[0]["summary_text"]
        lst_summaries.append(result)

    # distilbart-cnn tends to emit a space before periods; tidy that up when joining
    joined_summaries = ' '.join(lst_summaries).replace(" .", ".")

    # second pass: condense the per-chunk summaries into one short overall summary
    total_summary = nlp(joined_summaries, max_length=500)[0]["summary_text"].replace(" .", ".")

    short_output = submission_title + '\n' + '\n' + total_summary

    long_output = submission_title + '\n' + '\n' + joined_summaries

    return short_output, long_output


if __name__ == "__main__":
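    # minimal Gradio Blocks UI: a URL textbox, a "Summarize" button, and two output boxes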

    with gr.Blocks(css=".gradio-container {max-width: 900px !important; width: 100%}") as demo:
        submission_url = gr.Textbox(label='Post URL')

        sub_btn = gr.Button("Summarize")

        with gr.Row():
            short_summary = gr.Textbox(label='Short Comment Summary')
            long_summary = gr.Textbox(label='Long Comment Summary')

        sub_btn.click(fn=summarizer,
                      inputs=[submission_url],
                      outputs=[short_summary, long_summary])

    try:
        demo.launch()
    except KeyboardInterrupt:
        gr.close_all()