|
|
|
|
|
|
|
import requests, json |
|
from collections import namedtuple |
|
from functools import lru_cache |
|
from typing import List |
|
from dataclasses import dataclass, field |
|
from datetime import datetime as dt |
|
import streamlit as st |
|
|
|
from codetiming import Timer |
|
from transformers import AutoTokenizer |
|
|
|
from source import Source, Summary |
|
from scrape_sources import stub as stb |
|
|
|
|
|
|
|
@dataclass
class Digestor:
    """Builds a news digest by summarizing user-selected articles.

    Each element of ``stubs`` is either a ``scrape_sources.stub`` (its article
    is retrieved, split into chunks and summarized through the inference API
    at ``API_URL``) or any other object, which is passed through to
    ``summaries`` unchanged (presumably an already-built ``Summary`` — confirm
    with callers).  ``build_digest`` joins the summaries into ``text``.
    """

    # codetiming Timer; its ``timers`` registry holds the per-article
    # chunk/query/summary timings that ``_summarize_stub`` reads back.
    timer: Timer
    # Passed to ``perform_summarization`` and sent as the "use_cache" parameter.
    cache: bool = True
    # Digest text; "no_digest" until ``build_digest`` runs.
    text: str = field(default="no_digest")
    # Items to digest: ``stub`` instances or pass-through objects.
    stubs: List = field(default_factory=list)
    # Clusters chosen by the user; ``relevance`` ranks summaries by overlap.
    user_choices: List = field(default_factory=list)
    # Filled by ``digest`` and sorted most-relevant first.
    summaries: List = field(default_factory=list)
    # Not written by this class; presumably populated by callers — TODO confirm.
    digest_meta: namedtuple(
        "digestMeta",
        [
            'digest_time',
            'number_articles',
            'digest_length',
            'articles_per_cluster'
        ]) = None
    # Maximum tokenizer length (special tokens included) of one chunk.
    token_limit: int = 512
    # Words per chunk before token-based trimming.
    word_limit: int = 400

    # Class attributes below have no annotations, so they are NOT dataclass
    # fields.  ``cache`` here is the class default above (True); the per-call
    # value is applied in ``perform_summarization``.
    SUMMARIZATION_PARAMETERS = {
        "do_sample": False,
        "use_cache": cache
    }

    API_URL = "https://api-inference.huggingface.co/models/sshleifer/distilbart-cnn-12-6"

    # NOTE: evaluated once, when the class is defined; needs st.secrets['ato']
    # (presumably a Hugging Face API token — confirm).
    headers = {"Authorization": f"""Bearer {st.secrets['ato']}"""}

    @staticmethod
    @lru_cache(maxsize=8)
    def _load_tokenizer(checkpoint):
        """Load and memoize the tokenizer for *checkpoint*.

        ``chunk_piece`` runs once per article and the checkpoint comes from the
        article's source, so the same tokenizer is typically requested many
        times; loading it is expensive.
        """
        return AutoTokenizer.from_pretrained(checkpoint)

    def relevance(self, summary):
        """Return how many of ``user_choices`` appear in ``summary.cluster_list``."""
        return len(set(self.user_choices) & set(summary.cluster_list))

    def digest(self):
        """Retrieve and summarize every stub, then sort ``summaries`` by relevance.

        Non-stub entries of ``stubs`` are appended to ``summaries`` as-is.
        Stubs whose retrieval yields ``None`` text or metadata are skipped with
        a "Null article" message.  ``self.timer.timers`` is cleared first so
        timings from a previous run are not mixed in.
        """
        self.timer.timers.clear()

        with Timer(name="digest_time", text="Total digest time: {seconds:.4f} seconds"):
            for stub in self.stubs:
                if not isinstance(stub, stb):
                    self.summaries.append(stub)
                    continue
                summary = self._summarize_stub(stub)
                if summary is None:
                    print("Null article")
                else:
                    self.summaries.append(summary)

        self.summaries.sort(key=self.relevance, reverse=True)

    def _summarize_stub(self, stub):
        """Retrieve, chunk and summarize one stub.

        Returns a ``Summary``, or ``None`` when ``retrieve_article`` returned
        ``None`` for the text or the metadata.
        """
        text, summary_data = stub.source.retrieve_article(stub)
        if text is None or summary_data is None:
            return None

        with Timer(name=f"{stub.hed}_chunk_time", logger=None):
            chunk_list = self.chunk_piece(text, self.word_limit, stub.source.source_summarization_checkpoint)

        with Timer(name=f"{stub.hed}_summary_time", text="Whole article summarization time: {:.4f} seconds"):
            summary = self.perform_summarization(
                stub.hed,
                chunk_list,
                self.API_URL,
                self.headers,
                cache=self.cache,
            )

        joined = ' '.join(summary)
        timers = self.timer.timers
        return Summary(
            source=summary_data[0],
            cluster_list=summary_data[1],
            link_ext=summary_data[2],
            hed=summary_data[3],
            dek=summary_data[4],
            date=summary_data[5],
            authors=summary_data[6],
            original_length=summary_data[7],
            summary_text=summary,
            # ''.split(' ') is [''], so an empty summary must be counted as 0 words.
            summary_length=len(joined.split(' ')) if joined else 0,
            chunk_time=timers[f'{stub.hed}_chunk_time'],
            query_time=timers[f"{stub.hed}_query_time"],
            mean_query_time=timers.mean(f'{stub.hed}_query_time'),
            summary_time=timers[f'{stub.hed}_summary_time'],
        )

    def query(self, payload, API_URL, headers, timeout=60):
        """POST *payload* as JSON to *API_URL* and return the decoded JSON body.

        Args:
            payload: JSON-serializable request body.
            API_URL: endpoint URL.
            headers: HTTP headers (e.g. the Authorization header).
            timeout: seconds to wait for the server; a stalled request raises
                ``requests.Timeout`` instead of blocking forever.
        """
        response = requests.post(API_URL, headers=headers, data=json.dumps(payload), timeout=timeout)
        return json.loads(response.content.decode("utf-8"))

    def chunk_piece(self, piece, limit, tokenizer_checkpoint, include_tail=False):
        """Split *piece* into chunks of *limit* words that fit ``self.token_limit`` tokens.

        Args:
            piece: article text; words are delimited by single spaces.
            limit: number of words per chunk.
            tokenizer_checkpoint: name/path given to ``AutoTokenizer.from_pretrained``.
            include_tail: keep the final partial chunk of fewer than *limit*
                words.  A piece shorter than *limit* is always kept whole.

        Returns:
            Chunk strings with newlines replaced by spaces.  A chunk whose
            token count exceeds ``self.token_limit`` is trimmed from the end
            until it fits.
        """
        words = piece.split(' ')
        n_words = len(words)
        remainder = n_words % limit

        boundaries = [i * limit for i in range(n_words // limit + 1)]
        # The trailing partial chunk is kept only on request (or when it is the
        # whole piece); an empty tail (exact multiple of limit) is never added.
        if boundaries == [0] or (include_tail and remainder):
            boundaries.append(boundaries[-1] + remainder)

        tokenizer = self._load_tokenizer(tokenizer_checkpoint)

        def token_count(text):
            # The tokenizer returns a mapping (input_ids, attention_mask, ...);
            # the token count is the length of input_ids, not of the mapping.
            return len(tokenizer(text)['input_ids'])

        chunk_list = []
        for start, end in zip(boundaries, boundaries[1:]):
            chunk = ' '.join(words[start:end]).replace('\n', ' ')
            n_tokens = token_count(chunk)
            chunk_words = chunk.split(' ')
            while n_tokens > self.token_limit and chunk_words:
                # Drop one trailing word per excess token, then re-check, since
                # a word can map to several tokens.
                chunk_words = chunk_words[: len(chunk_words) - (n_tokens - self.token_limit)]
                chunk = ' '.join(chunk_words)
                n_tokens = token_count(chunk)
            chunk_list.append(chunk)

        return chunk_list

    def perform_summarization(self, stubhead, chunklist: List[str], API_URL: str, headers: dict, cache=True, max_attempts: int = 4) -> List[str]:
        """Summarize each chunk through ``query`` and return the summaries in order.

        Args:
            stubhead: article headline; names the ``{stubhead}_query_time`` timer.
            chunklist: text chunks to summarize.
            API_URL: endpoint passed to ``query``.
            headers: HTTP headers passed to ``query``.
            cache: value sent as the "use_cache" summarization parameter.
            max_attempts: tries per chunk before that chunk is skipped.

        Returns:
            One summary string per chunk that succeeded; chunks that failed
            every attempt are omitted.
        """
        parameters = {**self.SUMMARIZATION_PARAMETERS, "use_cache": cache}
        collection_bin = []

        for chunk in chunklist:
            with Timer(name=f"{stubhead}_query_time", logger=None):
                # Each chunk gets its own attempt budget so failures on early
                # chunks cannot prevent later chunks from being tried.
                for _ in range(max_attempts):
                    try:
                        summarized_chunk = self.query(
                            {
                                "inputs": str(chunk),
                                "parameters": parameters
                            },
                            API_URL,
                            headers,
                        )[0]['summary_text']
                    except Exception as e:
                        # Best effort: network errors and unexpected response
                        # shapes are reported and the chunk is retried.
                        print("Summarization error, repeating...")
                        print(e)
                    else:
                        collection_bin.append(summarized_chunk)
                        break

        return collection_bin

    def build_digest(self) -> str:
        """Join all summaries into ``self.text`` and return it.

        Chunk summaries of one article are joined with spaces; articles are
        separated by a blank line, in ``summaries`` order.
        """
        self.text = '\n\n'.join(' '.join(each.summary_text) for each in self.summaries)
        return self.text
|