"""TopicDig: a Streamlit app that scrapes current NPR and CNN headlines,
clusters articles by the named entities they mention, and generates
summarized digests using transformers from the HuggingFace model hub."""

from typing import List
from collections import namedtuple
import json
from datetime import datetime as dt

import requests
import numpy as np
from matplotlib import pyplot as plt
from codetiming import Timer
import streamlit as st

from digestor import Digestor
from source import Source
from scrape_sources import NPRLite, CNNText, stub


def initialize(limit, use_cache=True):
    """Scrape current headlines, run NER over them, and cluster the articles
    by the entities they mention.  Returns (article_dict, clusters)."""
    clusters: dict[str, List[namedtuple]] = dict()

    sources: List[Source] = []

    # Each source is configured with a name, a text-only front page URL, a
    # summarization model, and an NER model from the HuggingFace hub.
    sources.append(NPRLite(
        'npr',
        'https://text.npr.org/1001',
        'sshleifer/distilbart-cnn-12-6',
        'dbmdz/bert-large-cased-finetuned-conll03-english'
    ))
    sources.append(CNNText(
        'cnn',
        'https://lite.cnn.com',
        'sshleifer/distilbart-cnn-12-6',
        'dbmdz/bert-large-cased-finetuned-conll03-english'
    ))

    cluster_data: List[namedtuple]
    article_dict: dict[str, namedtuple]

    cluster_data = []
    article_meta = namedtuple('article_meta', ['source', 'count'])
    cluster_meta: List[article_meta] = []

    # Pull article stubs from each source, splitting any article limit
    # evenly across the sources.
    for data_source in sources:
        if limit is not None:
            c_data, c_meta = data_source.retrieve_cluster_data(limit // len(sources))
        else:
            c_data, c_meta = data_source.retrieve_cluster_data()
        cluster_data.append(c_data)
        cluster_meta.append(article_meta(data_source.source_name, c_meta))
        st.session_state[data_source.source_name] = f"Number of articles from source: {c_meta}"

    # Flatten the per-source lists into a single list of article stubs.
    cluster_data = [article for source_data in cluster_data for article in source_data]

    # Tag each article stub with its entities, then group stubs that share
    # an entity into the same cluster.
    for tup in cluster_data:
        perform_ner(tup, cache=use_cache)
        generate_clusters(clusters, tup)
    st.session_state['num_clusters'] = f"Total number of clusters: {len(clusters)}"

    # Index the article stubs by headline for later lookup.
    article_dict = {article.hed: article for article in cluster_data}

    return article_dict, clusters
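
# initialize() returns, e.g. (hypothetical headlines):
#   article_dict: {'Senate passes budget bill': <article stub>, ...}
#   clusters:     {'Senate': [<article stub>, ...], ...}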


def perform_ner(tup: namedtuple('article', ['link', 'hed', 'entities', 'source']), cache=True):
    """Run NER on an article's headline and append the recognized entities
    to the stub's entity list in place."""
    with Timer(name="ner_query_time", logger=None):
        result = ner_results(ner_query(
            {
                "inputs": tup.hed,
                # "use_cache" is an Inference API option, so it goes under
                # "options" rather than "parameters".
                "options": {
                    "use_cache": cache,
                },
            }
        ))
    for entity in result:
        tup.entities.append(entity)
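
# The token-classification response from the Inference API is a list of
# entity dicts; shape per the model's grouped output, values illustrative:
#   [{'entity_group': 'PER', 'score': 0.998, 'word': 'Angela Merkel',
#     'start': 0, 'end': 13}, ...]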


def ner_query(payload):
    """POST a JSON payload to the HuggingFace Inference API and return the
    decoded JSON response."""
    data = json.dumps(payload)
    response = requests.post(NER_API_URL, headers=headers, data=data)
    return json.loads(response.content.decode("utf-8"))


def generate_clusters(
        the_dict: dict,
        tup: namedtuple('article_stub', ['link', 'hed', 'entities', 'source'])
    ) -> dict:
    """Add an article stub to the cluster list of every entity it mentions."""
    for entity in tup.entities:
        # Start a new cluster the first time an entity is seen.
        if entity not in the_dict:
            the_dict[entity] = []
        the_dict[entity].append(tup)
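
# generate_clusters() builds a mapping from entity to the article stubs that
# mention it, e.g. (hypothetical data):
#   {'NATO': [<stub 1>, <stub 4>], 'Chicago': [<stub 2>]}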


def ner_results(ner_object, groups=True, NER_THRESHOLD=0.5) -> List[str]:
    """Filter raw NER output and return a deduplicated list of entity names."""
    people, places, orgs, misc = [], [], [], []

    # The API labels entities differently depending on whether results are
    # aggregated into groups ('entity_group': 'PER') or left as raw tokens
    # ('entity': 'I-PER').
    ent = 'entity' if not groups else 'entity_group'
    designation = 'I-' if not groups else ''

    # Dispatch each recognized entity to the list for its type.
    actions = {designation + 'PER': people.append,
               designation + 'LOC': places.append,
               designation + 'ORG': orgs.append,
               designation + 'MISC': misc.append,
               }

    # Bucket entities by type, skipping subword fragments (words containing
    # '#') and low-confidence results.
    for d in ner_object:
        if '#' not in d['word'] and d['score'] > NER_THRESHOLD:
            actions[d[ent]](d['word'])

    # Deduplicate within each type and drop very short (<= 2 character) names.
    ner_list = ([i for i in set(people) if len(i) > 2]
                + [i for i in set(places) if len(i) > 2]
                + [i for i in set(orgs) if len(i) > 2]
                + [i for i in set(misc) if len(i) > 2])

    return ner_list
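
# Example (hypothetical input): ner_results([
#     {'entity_group': 'PER', 'word': 'Biden', 'score': 0.99},
#     {'entity_group': 'LOC', 'word': 'Texas', 'score': 0.97}])
# returns ['Biden', 'Texas'].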


# Module-level configuration used by the query functions above.  The 'ato'
# entry in Streamlit secrets holds the HuggingFace API token.
NER_API_URL = "https://api-inference.huggingface.co/models/dbmdz/bert-large-cased-finetuned-conll03-english"
headers = {"Authorization": f"Bearer {st.secrets['ato']}"}

LIMIT = 20  # Cap on total articles retrieved; None disables the cap.
USE_CACHE = True

if not USE_CACHE:
    print("NOT USING CACHE")
if LIMIT is not None:
    print(f"LIMIT: {LIMIT}")


digests = dict()
out_dicts = []

print("Initializing....")
article_dict, clusters = initialize(LIMIT, USE_CACHE)


st.write("Welcome to TopicDig! You'll see three dropdowns containing entities recognized in current news headlines. Select 1 to 3 topics and TopicDig will generate a digest consisting of summarizations of the articles related to the terms you select! It's a good idea to hit the 'Refresh topics!' button the first time you run; this checks the current headlines and updates the lists. Named entity recognition and summarization are done using transformers from the HuggingFace model hub.")


if st.button("Refresh topics!"):
    article_dict, clusters = initialize(LIMIT, USE_CACHE)

selections = []
choices = list(clusters.keys())
choices.insert(0, 'None')

st.write(st.session_state['cnn'])
st.write(st.session_state['npr'])
st.write(st.session_state['num_clusters'])

st.session_state['dt'] = dt.now()


# Three side-by-side topic dropdowns in a single form.
with st.form(key='columns_in_form'):
    cols = st.columns(3)
    for i, col in enumerate(cols):
        selections.append(col.selectbox('Make a selection', choices, key=i))
    submitted = st.form_submit_button('Submit')
    if submitted:
        st.write("Submitted.\nWhat you'll see:\n\t• The links of the articles being summarized for your digest.\n\t• The digest.\n\t• A graph showing the reduction in article length from original to summary, for each article.\n\t• Some possible issues with the summary.")
        selections = [i for i in selections if i is not None]
        with st.spinner(text="Creating your digest: this will take a few moments."):
            chosen = []

            # Gather each unique article stub from every selected cluster.
            for i in selections:
                if i != 'None':
                    for j in clusters[i]:
                        if j not in chosen:
                            chosen.append(j)

            digestor = Digestor(timer=Timer(), cache=USE_CACHE, stubs=chosen, user_choices=selections)

            st.write("What you'll see:")
            st.write("First you'll see a list of links appear below. These are the links to the original articles being summarized for your digest, so you can get the full story if you're interested, or check a summary against its source.")
            st.write("In a few moments, your machine-generated digest will appear below the links, and below that you'll see an approximate word count of your digest and the time in seconds the whole process took!")
            st.write("You'll also see a graph showing, for each article, the original and summarized lengths.")
            st.write("Finally, you will see some possible errors detected in the summaries. This area of NLP is far from perfect and always developing. Hopefully this is an interesting step on the path!")
            digestor.digest()

            outdata = digestor.build_digest()

            if len(digestor.text) == 0:
                st.write("No text to return...huh.")
            else:
                st.write("Your digest is ready:\n")
                st.write(digestor.text)
                st.write(f"""Text approximately {len(digestor.text.split(" "))} words.""")
                st.write(f"""Number of articles summarized: {outdata['article_count']}""")

            st.success(f"""Digest completed in {digestor.timer.timers['digest_time']} seconds.""")

            st.write("Here are some stats about the summarization:\n")

            # Grouped bar chart comparing original and summarized article
            # lengths, with hatched outlines drawn over the solid bars.
            labels = [i for i in range(outdata['article_count'])]
            original_length = [outdata['summaries'][i]['original_length'] for i in outdata['summaries']]
            summarized_length = [outdata['summaries'][i]['summary_length'] for i in outdata['summaries']]
            x = np.arange(len(labels))
            width = 0.35

            fig, ax = plt.subplots(figsize=(14, 8))
            ax.bar(x - width/2, original_length, width, color='lightgreen', zorder=0)
            ax.bar(x + width/2, summarized_length, width, color='lightblue', zorder=0)

            ax.bar(x - width/2, original_length, width, color='none', edgecolor='black', hatch='XX', lw=1.25, zorder=1)
            ax.bar(x + width/2, summarized_length, width, color='none', edgecolor='black', hatch='xx', lw=1.25, zorder=1)

            ax.set_ylabel('Text Length')
            ax.set_xticks(x)
            # Space y ticks by the longest summary so the reduction is easy to read.
            ax.set_yticks([i for i in range(0, max(original_length), max(summarized_length))])
            ax.set_xticklabels(labels)
            ax.set_xlabel('Article')

            plt.title('Original to Summarized Lengths in Space-Separated Tokens')

            st.pyplot(fig)


# Streamlit "magic": a bare expression is rendered to the app, used here to
# expose the session state for debugging.
"st.session_state object:", st.session_state