|
|
|
from typing import List, Set |
|
from collections import namedtuple |
|
import random |
|
import requests |
|
import json |
|
from datetime import datetime as dt |
|
from codetiming import Timer |
|
import streamlit as st |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from matplotlib import pyplot as plt |
|
|
|
from digestor import Digestor |
|
from source import Source |
|
from scrape_sources import NPRLite, CNNText, stub |
|
|
|
|
|
|
|
def initialize(limit, rando, use_cache=True):
    """Fetch headlines from all news sources, run NER on each headline, and
    group articles into clusters keyed by named entity.

    Args:
        limit: Max total number of articles across all sources (the budget is
            split evenly among sources), or None for no limit.
        rando: Unused. NOTE(review): call sites in this file pass USE_CACHE
            into this positional slot — confirm intent; the cache flag is the
            `use_cache` parameter below.
        use_cache: Forwarded to the NER inference API so repeated queries can
            be served from the model cache.

    Returns:
        Tuple of (article_dict, clusters): article_dict maps headline text to
        its article stub; clusters maps entity string -> list of stubs that
        mention it.
    """
    # entity string -> list of article stubs mentioning that entity
    clusters: dict[str, List[namedtuple]] = dict()

    sources: List[Source] = []
    sources.append(NPRLite(
        'npr',
        'https://text.npr.org/1001',
        'sshleifer/distilbart-cnn-12-6',
        'dbmdz/bert-large-cased-finetuned-conll03-english'
    ))
    sources.append(CNNText(
        'cnn',
        'https://lite.cnn.com',
        'sshleifer/distilbart-cnn-12-6',
        'dbmdz/bert-large-cased-finetuned-conll03-english'
    ))

    cluster_data: List[namedtuple('article', ['link', 'hed', 'entities', 'source'])]
    article_dict: dict[str, namedtuple]

    cluster_data = []
    article_meta = namedtuple('article_meta', ['source', 'count'])
    cluster_meta: List[article_meta] = []
    for data_source in sources:
        if limit is not None:
            # Split the article budget evenly among sources.
            c_data, c_meta = data_source.retrieve_cluster_data(limit // len(sources))
        else:
            c_data, c_meta = data_source.retrieve_cluster_data()
        cluster_data.append(c_data)
        cluster_meta.append(article_meta(data_source.source_name, c_meta))
        st.session_state[data_source.source_name] = f"Number of articles from source: {c_meta}"

    # Flatten the per-source lists into one list. (The original hard-coded
    # cluster_data[0] + cluster_data[1]; this works for any source count.)
    cluster_data = [article for per_source in cluster_data for article in per_source]

    # Tag each headline with its entities and file it into entity clusters.
    for tup in cluster_data:
        perform_ner(tup, cache=use_cache)
        generate_clusters(clusters, tup)
    st.session_state['num_clusters'] = f"""Total number of clusters: {len(clusters)}"""

    # Index articles by headline. (Loop variable renamed from `stub`, which
    # shadowed the `stub` imported from scrape_sources.)
    article_dict = {article.hed: article for article in cluster_data}

    return article_dict, clusters
|
|
|
|
|
|
|
def perform_ner(tup: namedtuple('article', ['link', 'hed', 'entities', 'source']), cache=True):
    """Run named-entity recognition on an article's headline and append the
    recognized entity strings to the stub's `entities` list in place.

    Args:
        tup: Article stub; `tup.hed` is the headline text sent to the API.
        cache: Whether the inference API may serve a cached result.
    """
    with Timer(name="ner_query_time", logger=None):
        # BUG FIX: the original sent the cache flag under a misspelled
        # "paramters" key, which the API ignored. Per the HF Inference API,
        # "use_cache" belongs in the top-level "options" object.
        result = ner_results(ner_query(
            {
                "inputs": tup.hed,
                "options": {
                    "use_cache": cache,
                },
            }
        ))
    for entity in result:
        tup.entities.append(entity)
|
|
|
|
|
def ner_query(payload):
    """Send `payload` to the NER inference endpoint and return the parsed
    JSON reply.

    Args:
        payload: JSON-serializable request body for the inference API.

    Returns:
        The API response, decoded from UTF-8 JSON.
    """
    encoded = json.dumps(payload)
    response = requests.post(NER_API_URL, headers=headers, data=encoded)
    return json.loads(response.content.decode("utf-8"))
|
|
|
|
|
|
|
def generate_clusters(
    the_dict: dict,
    tup: namedtuple('article_stub', ['link', 'hed', 'entities', 'source'])
) -> None:
    """Add an article stub to the cluster list of every entity it mentions.

    Mutates `the_dict` in place: each entity in `tup.entities` becomes a key
    (created on first sight) whose value is the list of stubs mentioning it.
    (Return annotation fixed: the original declared `-> dict` but returned
    None; callers use the mutated dict, not a return value.)

    Args:
        the_dict: Mapping of entity string -> list of article stubs.
        tup: Article stub whose `entities` list drives the clustering.
    """
    for entity in tup.entities:
        # setdefault creates the empty cluster on first sight of an entity.
        the_dict.setdefault(entity, []).append(tup)
|
|
|
|
|
def ner_results(ner_object, groups=True, NER_THRESHOLD=0.5) -> List[str]:
    """Extract entity strings from a Hugging Face NER response.

    Args:
        ner_object: List of entity dicts from the inference API; each dict
            carries 'word', 'score', and either 'entity' or 'entity_group'.
        groups: True when the model returns aggregated 'entity_group' labels
            (PER/LOC/ORG/MISC); False for token-level 'I-'-prefixed labels
            under the 'entity' key.
        NER_THRESHOLD: Minimum confidence score for an entity to be kept.

    Returns:
        Entity words (people, then places, then orgs, then misc),
        deduplicated, each longer than 2 characters.
    """
    people, places, orgs, misc = [], [], [], []

    # Grouped output uses plain labels ('PER'); token-level output uses
    # 'I-'-prefixed labels and the 'entity' key.
    ent = 'entity' if not groups else 'entity_group'
    designation = 'I-' if not groups else ''

    # Dispatch table: entity label -> bucket to collect the word into.
    actions = {designation + 'PER': people.append,
               designation + 'LOC': places.append,
               designation + 'ORG': orgs.append,
               designation + 'MISC': misc.append
               }

    # Keep confident entities; words containing '#' are subword fragments.
    # (The original built a throwaway list of Nones via a side-effect list
    # comprehension; rewritten as a plain loop.)
    for d in ner_object:
        if '#' not in d['word'] and d['score'] > NER_THRESHOLD:
            actions[d[ent]](d['word'])

    # Deduplicate while preserving first-seen order. The original iterated
    # set()s, whose order varies across runs under hash randomization;
    # dict.fromkeys gives a deterministic result.
    ner_list = [word
                for bucket in (people, places, orgs, misc)
                for word in dict.fromkeys(bucket) if len(word) > 2]

    return ner_list
|
|
|
|
|
|
|
|
|
# Hugging Face Inference API endpoint for the NER model run on headlines.
NER_API_URL = "https://api-inference.huggingface.co/models/dbmdz/bert-large-cased-finetuned-conll03-english"
# Auth token is read from Streamlit secrets under the key 'ato'.
headers = {"Authorization": f"""Bearer {st.secrets['ato']}"""}

# Max total articles fetched across all sources (None = unlimited).
LIMIT = 20
# Whether the inference API may serve cached model results.
USE_CACHE = True

if not USE_CACHE:
    print("NOT USING CACHE")
if LIMIT is not None:
    print(f"LIMIT: {LIMIT}")
|
|
|
|
|
# Scratch state for the digest run.
digests = dict()  # NOTE(review): appears unused in this file — confirm
out_dicts = []    # NOTE(review): appears unused in this file — confirm

print("Initializing....")
# BUG FIX: `initialize`'s second positional parameter is `rando`, not
# `use_cache`, so the original call initialize(LIMIT, USE_CACHE) silently
# dropped the cache flag. Pass it by keyword as well.
article_dict, clusters = initialize(LIMIT, USE_CACHE, use_cache=USE_CACHE)
|
|
|
|
|
|
|
|
|
|
|
# Re-scrape and re-cluster on demand.
if st.button("Refresh topics!"):
    # BUG FIX: pass the cache flag by keyword too — the second positional
    # slot of initialize() is the unused `rando`, not `use_cache`.
    article_dict, clusters = initialize(LIMIT, USE_CACHE, use_cache=USE_CACHE)

selections = []
# Cluster entities become the dropdown choices, with a 'None' placeholder first.
choices = list(clusters.keys())
choices.insert(0, 'None')

# Per-source article counts and the cluster total were stored in
# st.session_state by initialize().
st.write(st.session_state['cnn'])
st.write(st.session_state['npr'])
st.write(st.session_state['num_clusters'])

st.session_state['dt'] = dt.now()
|
|
|
|
|
|
|
# Topic-selection form: three side-by-side dropdowns over the cluster keys.
with st.form(key='columns_in_form'):
    cols = st.columns(3)
    for i, col in enumerate(cols):
        selections.append(col.selectbox(f'Make a Selection', choices, key=i))
    submitted = st.form_submit_button('Submit')
    if submitted:
        st.write("Submitted.\nWhat you'll see:\n\t•The links of the articles being summarized for you digest.\n\t•The digest.\n\t•A graph showing the reduction in articles lengths from original to summary, for each article.\n\t•Some probably issues with the summary.")
        selections = [i for i in selections if i is not None]
        with st.spinner(text="Digesting...please wait, this will take a few moments...Maybe check some messages or start reading the latest papers on summarization with transformers....\n\nTransformers are so called because they make heavy use of mathematical transformations on the input data, in order to detect contextually related words or phrases in a sentence. \nThis project uses a checkpoint called distilbart-cnn-12-6, created by Sam Shleifer ()."):
            # Union of article stubs across the selected clusters, skipping
            # the 'None' placeholder and duplicate stubs.
            chosen = []
            for i in selections:
                if i != 'None':
                    for j in clusters[i]:
                        if j not in chosen:
                            chosen.append(j)

            # Summarize the chosen articles with the project Digestor.
            digestor = Digestor(timer=Timer(), cache = USE_CACHE, stubs=chosen, user_choices=selections)

            digestor.digest()

            # outdata carries per-article summary stats used for the chart below.
            outdata = digestor.build_digest()

            if len(digestor.text) == 0:
                st.write("No text to return...huh.")
            else:
                st.write("Your digest is ready:\n")
                st.write(digestor.text)
                st.write(f"""Text approximately {len(digestor.text.split(" ") )} words.""")
                st.write(f"""Number of articles summarized: {outdata['article_count']}""")

            st.success(f"""Digest completed in {digestor.timer.timers['digest_time']} seconds.""")

            st.write("Here are some stats about the summarization:\n")

            # Grouped bar chart: original vs. summarized length per article,
            # in space-separated tokens.
            labels = [i for i in range(outdata['article_count'])]
            original_length = [outdata['summaries'][i]['original_length'] for i in outdata['summaries']]
            summarized_length = [outdata['summaries'][i]['summary_length'] for i in outdata['summaries']]
            x = np.arange(len(labels))
            width = 0.35

            fig, ax = plt.subplots(figsize=(14,8))
            rects1 = ax.bar(x - width/2, original_length, width, color='lightgreen',zorder=0)
            rects2 = ax.bar(x + width/2, summarized_length, width, color='lightblue',zorder=0)

            # Hatched overlays drawn on top of the solid bars for texture.
            rects3 = ax.bar(x - width/2, original_length, width, color='none',edgecolor='black', hatch='XX', lw=1.25,zorder=1)
            rects4 = ax.bar(x + width/2, summarized_length, width, color='none',edgecolor='black', hatch='xx', lw=1.25,zorder=1)

            ax.set_ylabel('Text Length')
            ax.set_xticks(x)
            # NOTE(review): uses max(summarized_length) as the y-tick STEP —
            # looks unintended (ticks get sparse for long summaries); confirm.
            ax.set_yticks([i for i in range(0,max(original_length),max(summarized_length))])
            ax.set_xticklabels(labels)
            ax.set_xlabel('Article')

            plt.title('Original to Summarized Lengths in Space-Separated Tokens')

            st.pyplot(fig)

# Streamlit "magic": a bare expression at module level is rendered to the
# app — this displays the whole session_state for debugging.
"st.session_state object:", st.session_state