# summit/app.py
import os
import re
import nltk
import praw
import matplotlib
from tqdm import tqdm
import gradio as gr
import pandas as pd
import praw.exceptions
import matplotlib.pyplot as plt
from wordcloud import WordCloud
matplotlib.use('Agg')
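
# word_tokenize (used in sentence_chunk below) requires nltk's "punkt" tokenizer
# data; fetch it up front so a fresh environment doesn't fail mid-request
# (a no-op when the data is already installed).
nltk.download('punkt', quiet=True)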


def index_chunk(a):
    # Alternative chunker (currently unused): yields ~0.3 * len(a) contiguous,
    # near-equal slices of `a`.
    n = round(0.3 * len(a))
    k, m = divmod(len(a), n)
    return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))


def sentence_chunk(a):
    # Greedily pack consecutive items from `a` into chunks that stay under the
    # summarization model's input limit, estimated with nltk's word tokenizer.
    sentences = []
    buffer = ""
    # the 512 token threshold is empirical
    for item in a:
        token_length_estimation = len(nltk.word_tokenize(buffer + item))
        if token_length_estimation > 512:
            sentences.append(buffer)
            buffer = ""
        buffer += item
    sentences.append(buffer)
    return sentences
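
# A quick sanity check of the chunker on a toy (hypothetical) input: both
# items fall well under the 512-token estimate, so they land in one chunk.
#   >>> sentence_chunk(["comment one. ", "comment two. "])
#   ['comment one. comment two. ']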


def preprocessData(df):
    # Strip URLs and quoted ("> ...") lines from the comment text.
    df["text"] = df["text"].apply(lambda x: re.sub(r"http\S+", "", x, flags=re.M))
    df["text"] = df["text"].apply(lambda x: re.sub(r"^>.+", "", x, flags=re.M))

    # The df is sorted by comment score. Empirically, keeping more than ~100
    # comments doesn't change the summary much but slows the summarizer down.
    # The slowdown is not present with the hosted inference API, but it still
    # seems good to drop low-score comments.
    if len(df.text) >= 128:
        df = df[:128]

    # Chunk the comments so that no single input is too large for the model,
    # which would otherwise crash it.
    chunked = sentence_chunk(df.text)

    return chunked


def getComments(url, debug=False):
    if debug and os.path.isfile('./debug_comments.csv'):
        df = pd.read_csv("./debug_comments.csv")
        return df

    client_id = os.environ['REDDIT_CLIENT_ID']
    client_secret = os.environ['REDDIT_CLIENT_SECRET']
    user_agent = os.environ['REDDIT_USER_AGENT']

    reddit = praw.Reddit(client_id=client_id, client_secret=client_secret, user_agent=user_agent)

    try:
        submission = reddit.submission(url=url)
    except praw.exceptions.InvalidURL:
        # Re-raise so the caller fails loudly instead of continuing with an
        # undefined `submission`.
        print("The URL is invalid. Make sure that you have included the submission id")
        raise

    submission.comments.replace_more(limit=0)

    cols = [
        "text",
        "score",
        "id",
        "parent_id",
        "submission_title",
        "submission_score",
        "submission_id"
    ]
    rows = []

    for comment in submission.comments.list():
        if comment.stickied:
            continue
        data = [
            comment.body,
            comment.score,
            comment.id,
            comment.parent_id,
            submission.title,
            submission.score,
            submission.id,
        ]
        rows.append(data)

    df = pd.DataFrame(data=rows, columns=cols)

    if debug:
        # Save for debugging to avoid sending tons of requests to reddit.
        df.to_csv('debug_comments.csv', index=False)

    return df
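
# Note: getComments assumes REDDIT_CLIENT_ID, REDDIT_CLIENT_SECRET, and
# REDDIT_USER_AGENT are set in the environment (credentials for a Reddit
# script app, created at https://www.reddit.com/prefs/apps).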


def summarizer(url: str):
    # pushshift.io's submission comments API doesn't work, so praw is used instead.
    df = getComments(url=url)

    submission_title = '## ' + df.submission_title.unique()[0]

    chunked_df = preprocessData(df)
    text = ' '.join(chunked_df)

    # For a transparent background, pass background_color=None, mode='RGBA'.
    wc_opts = dict(collocations=False, width=1920, height=1080)
    wcloud = WordCloud(**wc_opts).generate(text)

    plt.imshow(wcloud, aspect='auto')
    plt.axis("off")
    plt.gca().set_position([0, 0, 1, 1])
    plt.autoscale(tight=True)
    fig = plt.gcf()
    fig.patch.set_alpha(0.0)
    fig.set_size_inches((12, 7))

    lst_summaries = []

    for grp in tqdm(chunked_df):
        # Treat each group of comments as one block of text for the summarizer.
        result = sum_api(grp)
        lst_summaries.append(result)

    long_output = ' '.join(lst_summaries).replace(" .", ".")
    short_output = sum_api(long_output).replace(" .", ".")

    sentiment = clf_api(short_output)

    # The leading gr.update flips the hidden result column to visible.
    return gr.update(visible=True), submission_title, short_output, long_output, sentiment, fig


if __name__ == "__main__":
    sum_model = "models/sshleifer/distilbart-cnn-12-6"
    clf_model = "models/finiteautomata/bertweet-base-sentiment-analysis"

    hf_token = os.environ["HF_TOKEN"]

    sum_api = gr.Interface.load(sum_model, api_key=hf_token)
    clf_api = gr.Interface.load(clf_model, api_key=hf_token)

    sample_urls = ['https://www.reddit.com/r/wholesome/comments/10ehlxo/he_got_a_strong_message/']

    with gr.Blocks(css=".gradio-container {max-width: 900px !important; width: 100%}") as demo:
        submission_url = gr.Textbox(label='Post URL')
        sub_btn = gr.Button("Summarize")
        title = gr.Markdown("")

        with gr.Column(visible=False) as result_col:
            with gr.Row():
                with gr.Column():
                    short_summary = gr.Textbox(label='Short Summary')
                    summary_sentiment = gr.Label(label='Sentiment')
                thread_cloud = gr.Plot(label='Word Cloud')
            long_summary = gr.Textbox(label='Long Summary')

        out_lst = [result_col, title, short_summary, long_summary, summary_sentiment, thread_cloud]

        sub_btn.click(fn=summarizer, inputs=[submission_url], outputs=out_lst)

        examples = gr.Examples(examples=sample_urls,
                               fn=summarizer,
                               inputs=[submission_url],
                               outputs=out_lst,
                               cache_examples=True)

    try:
        demo.launch()
    except KeyboardInterrupt:
        gr.close_all()