# summit/app.py (kkastr) - branch for huggingface spaces
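"""Gradio app that summarizes the comments of a Reddit submission.

Comments are fetched with PRAW, cleaned and chunked, then summarized with the
sshleifer/distilbart-cnn-12-6 summarization pipeline from transformers.
"""
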
import os
import re
import sys
import nltk
import praw
import gradio as gr
import pandas as pd
import praw.exceptions
from transformers import pipeline

# nltk.word_tokenize (used below for chunk length estimates) needs the punkt tokenizer data.
nltk.download("punkt", quiet=True)


def index_chunk(a):
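    """Split `a` into round(0.3 * len(a)) contiguous, near-equal chunks.

    Kept as an alternative to sentence_chunk (see the commented-out call in preprocessData).
    """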
    n = round(0.3 * len(a))
    k, m = divmod(len(a), n)
    return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))


def sentence_chunk(a):
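    """Greedily merge the items of `a` into chunks whose estimated token count stays near the 512-token threshold."""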
    sentences = []
    buffer = ""

    # the 512 token threshold is empirical
    for item in a:
        token_length_estimation = len(nltk.word_tokenize(buffer + item))

        if token_length_estimation > 512:
            sentences.append(buffer)
            buffer = ""

        buffer += item

    sentences.append(buffer)

    return sentences


def preprocessData(df):
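    """Clean the comment text, drop low-score comments, cap the count at 200, and chunk the rest for the summarizer."""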
df["text"] = df["text"].apply(lambda x: re.sub(r"http\S+", "", x, flags=re.M))
df["text"] = df["text"].apply(lambda x: re.sub(r"^>.+", "", x, flags=re.M))
smax = df.score.max()
threshold = round(0.05 * smax)
df = df[df.score >= threshold]
# empirically, having more than 200 comments doesn't change much but slows down the summarizer.
if len(df.text) >= 200:
df = df[:200]
# chunking to handle giving the model too large of an input which crashes
# chunked = list(index_chunk(df.text))
chunked = sentence_chunk(df.text)
return chunked
def getComments(url, debug=False):
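    """Fetch the submission at `url` with PRAW and return its non-stickied comments as a DataFrame (or a cached CSV in debug mode)."""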
    if debug and os.path.isfile('./debug_comments.csv'):
        df = pd.read_csv("./debug_comments.csv")
        return df

    client_id = os.environ['REDDIT_CLIENT_ID']
    client_secret = os.environ['REDDIT_CLIENT_SECRET']
    user_agent = os.environ['REDDIT_USER_AGENT']

    reddit = praw.Reddit(client_id=client_id, client_secret=client_secret, user_agent=user_agent)

    try:
        submission = reddit.submission(url=url)
    except praw.exceptions.InvalidURL:
        print("The URL is invalid. Make sure that you have included the submission id.")
        raise

    # limit=0 drops all "load more comments" placeholders without extra fetches
    submission.comments.replace_more(limit=0)

    cols = [
        "text",
        "score",
        "id",
        "parent_id",
        "submission_title",
        "submission_score",
        "submission_id"
    ]
    rows = []

    for comment in submission.comments.list():
        # skip stickied (usually moderator) comments
        if comment.stickied:
            continue

        data = [
            comment.body,
            comment.score,
            comment.id,
            comment.parent_id,
            submission.title,
            submission.score,
            submission.id,
        ]

        rows.append(data)

    df = pd.DataFrame(data=rows, columns=cols)

    if debug:
        # save for debugging to avoid sending tons of requests to reddit
        df.to_csv('debug_comments.csv', index=False)

    return df


def summarizer(url: str) -> tuple[str, str]:
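    """Summarize the comments of the Reddit submission at `url`.

    Returns a (short summary, long summary) pair, each prefixed with the submission title.
    """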
    # pushshift.io submission comments api doesn't work so have to use praw
    df = getComments(url=url)

    chunked_df = preprocessData(df)
    submission_title = df.submission_title.unique()[0]

    lst_summaries = []

    nlp = pipeline('summarization', model="sshleifer/distilbart-cnn-12-6")

    for grp in chunked_df:
        # treating a group of comments as one block of text
        result = nlp(grp, max_length=500)[0]["summary_text"]
        lst_summaries.append(result)

    joined_summaries = ' '.join(lst_summaries).replace(" .", ".")

    # summarize the concatenated chunk summaries once more for the short version
    total_summary = nlp(joined_summaries, max_length=500)[0]["summary_text"].replace(" .", ".")

    short_output = submission_title + '\n' + '\n' + total_summary
    long_output = submission_title + '\n' + '\n' + joined_summaries

    return short_output, long_output


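# For local testing without the Gradio UI (assumes the REDDIT_* environment variables
# are set; the URL below is a placeholder, not a real submission):
#   short_summary, long_summary = summarizer("https://www.reddit.com/r/<subreddit>/comments/<submission_id>/<slug>/")

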
if __name__ == "__main__":
    with gr.Blocks(css=".gradio-container {max-width: 900px !important; width: 100%}") as demo:
        submission_url = gr.Textbox(label='Post URL')

        sub_btn = gr.Button("Summarize")

        with gr.Row():
            short_summary = gr.Textbox(label='Short Comment Summary')
            long_summary = gr.Textbox(label='Long Comment Summary')

        sub_btn.click(fn=summarizer,
                      inputs=[submission_url],
                      outputs=[short_summary, long_summary])

    try:
        demo.launch()
    except KeyboardInterrupt:
        gr.close_all()