m. polinsky committed on
Commit
80b5ef0
1 Parent(s): e0e6908

Create digestor.py

Files changed (1)
  1. digestor.py +252 -0
digestor.py ADDED
@@ -0,0 +1,252 @@
# digestor.py is an implementation of a digestor that creates news digests.
# The digestor manages the creation of summaries and assembles them into one digest.

import requests, json
from collections import namedtuple
from functools import lru_cache
from typing import List
from dataclasses import dataclass, field
from datetime import datetime as dt
import streamlit as st

from codetiming import Timer
from transformers import AutoTokenizer

from source import Source, Summary
from scrape_sources import stub as stb


@dataclass
class Digestor:
    timer: Timer
    cache: bool = True
    text: str = field(default="no_digest")
    stubs: List = field(default_factory=list)
    # For clarity: each stub/summary has its entities.
    user_choices: List = field(default_factory=list)
    # The digest text
    summaries: List = field(default_factory=list)
    # sources: List = field(default_factory=list)  # I'm thinking create a string list for easy ref
    # text: str = None

    digest_meta: namedtuple(
        "digestMeta",
        [
            'digest_time',
            'number_articles',
            'digest_length',
            'articles_per_cluster'
        ]) = None

    # Summarization params:
    token_limit: int = 512
    word_limit: int = 400
    SUMMARIZATION_PARAMETERS = {
        "do_sample": False,
        "use_cache": cache,  # picks up the class-level default (True) at class creation time
    }

    # Inference parameters
    API_URL = "https://api-inference.huggingface.co/models/sshleifer/distilbart-cnn-12-6"
    headers = {"Authorization": f"""Bearer {st.secrets['ato']}"""}

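    # For reference: the Authorization header resolves to a standard Hugging Face
    # Inference API bearer token, e.g. {"Authorization": "Bearer hf_..."} (illustrative
    # value only); here the token is read from Streamlit secrets under the key 'ato'.
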
    # I would like to keep the whole scraped text separate if I can,
    # which I'm not doing here.
    # After this runs, the digestor is populated with summaries.

    # Relevance is a matter of how many chosen clusters this article belongs to.
    # Max relevance is the number of unique chosen entities; min is 1.
    # Allows articles that hit more chosen topics to be placed higher up,
    # mirroring the "upside down pyramid" journalism convention, i.e. ordering facts by decreasing information content.
    def relevance(self, summary):
        return len(set(self.user_choices) & set(summary.cluster_list))

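    # A quick worked illustration of relevance() with hypothetical values
    # (the cluster names below are made up for the example):
    #   user_choices = ['inflation', 'elections']
    #   summary.cluster_list = ['inflation', 'housing']
    #   -> the intersection has one element, so relevance(summary) == 1;
    #   an article tagged with both chosen clusters would score 2 and sort ahead of it.
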
    def digest(self):
        """Retrieves all data for user-chosen articles, builds summary object list."""
        # Clear timers from the previous digestion
        self.timer.timers.clear()
        # Start digest timer
        with Timer(name="digest_time", text="Total digest time: {seconds:.4f} seconds"):
            # Loop through stubs, collecting data and instantiating
            # and collecting Summary objects.
            for stub in self.stubs:
                # Check to see if we already have access to this summary:
                if not isinstance(stub, stb):
                    self.summaries.append(stub)
                else:
                    # If not:
                    summary_data: List
                    # Get full article data
                    text, summary_data = stub.source.retrieve_article(stub)
                    # Drop problem scrapes
                    # Log here
                    if text is not None and summary_data is not None:
                        # Start chunk timer
                        with Timer(name=f"{stub.hed}_chunk_time", logger=None):
                            chunk_list = self.chunk_piece(text, self.word_limit, stub.source.source_summarization_checkpoint)
                        # Start total summarization timer. Individual queries are timed in perform_summarization().
                        with Timer(name=f"{stub.hed}_summary_time", text="Whole article summarization time: {:.4f} seconds"):
                            summary = self.perform_summarization(
                                stub.hed,
                                chunk_list,
                                self.API_URL,
                                self.headers,
                                cache=self.cache,
                            )
                        # Return these things and instantiate a Summary object with them,
                        # then add that Summary object to the collection.
                        # There is also timer data and data on articles.

                        self.summaries.append(
                            Summary(
                                source=summary_data[0],
                                cluster_list=summary_data[1],
                                link_ext=summary_data[2],
                                hed=summary_data[3],
                                dek=summary_data[4],
                                date=summary_data[5],
                                authors=summary_data[6],
                                original_length=summary_data[7],
                                summary_text=summary,
                                summary_length=len(' '.join(summary).split(' ')),
                                chunk_time=self.timer.timers[f'{stub.hed}_chunk_time'],
                                query_time=self.timer.timers[f"{stub.hed}_query_time"],
                                mean_query_time=self.timer.timers.mean(f'{stub.hed}_query_time'),
                                summary_time=self.timer.timers[f'{stub.hed}_summary_time'],
                            )
                        )
                    else:
                        print("Null article")  # log this.

        # When finished, order the summaries by the number of user-selected clusters each article appears in.
        self.summaries.sort(key=self.relevance, reverse=True)

    # Query the HuggingFace Inference engine.
    def query(self, payload, API_URL, headers):
        """Performs summarization inference API call."""
        data = json.dumps(payload)
        response = requests.request("POST", API_URL, headers=headers, data=data)
        return json.loads(response.content.decode("utf-8"))

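    # Example of how query() is used by perform_summarization() below
    # (values are illustrative; the response shape [{'summary_text': ...}]
    # is what the caller indexes into):
    #
    #   payload = {
    #       "inputs": "Some article chunk ...",
    #       "parameters": {"do_sample": False, "use_cache": True},
    #   }
    #   result = self.query(payload, self.API_URL, self.headers)
    #   summary_text = result[0]['summary_text']
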
    def chunk_piece(self, piece, limit, tokenizer_checkpoint, include_tail=False):
        """Breaks articles into chunks that will fit the desired token length limit."""
        # Get approximate word count
        words = len(piece.split(' '))  # rough estimate of words; word count <= token count, generally.
        # Get number of chunks by dividing number of words by chunk size (word limit).
        # Create list of ints to build the range list from.
        base_range = [i*limit for i in range(words//limit+1)]
        # For articles shorter than the limit, base_range will only contain zero.
        # For most articles there is a small final chunk shorter than the limit.
        # It may make summaries less coherent.
        if include_tail or base_range == [0]:
            base_range.append(base_range[-1]+words%limit)  # add odd part at end of text...maybe remove.
        # List of int ranges
        range_list = [i for i in zip(base_range, base_range[1:])]

        # Setup for chunking/checking tokenized chunk length
        fractured = piece.split(' ')
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
        chunk_list = []

        # Finally, chunk the piece, adjusting the chunks if too long.
        # Note: len() of the tokenizer's output gives the number of fields, not tokens,
        # so the token count is taken from 'input_ids'.
        for i, j in range_list:
            if (tokenized_len := len(tokenizer(chunk := ' '.join(fractured[i:j]).replace('\n', ' '))['input_ids'])) <= self.token_limit:
                chunk_list.append(chunk)
            else:  # if chunks of <limit> words are too long, back them off.
                chunk_list.append(' '.join(chunk.split(' ')[: self.token_limit - tokenized_len]).replace('\n', ' '))

        return chunk_list

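    # Worked example of the chunking arithmetic above (numbers are hypothetical):
    # with word_limit=400 and a 950-word article,
    #   base_range = [0, 400, 800]
    #   range_list = [(0, 400), (400, 800)]
    # so the final 150 words are dropped unless include_tail=True, in which case
    # base_range gains 950 and an (800, 950) tail chunk is kept as well.
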
    # Returns a list of summarized chunks instead of concatenating them, which would lose info about the process.
    def perform_summarization(self, stubhead, chunklist: List[str], API_URL: str, headers: dict, cache=True) -> List[str]:
        """For each chunk in chunklist, appends the result of query(chunk) to collection_bin."""
        collection_bin = []
        repeat = 0
        # Loop over the list and pass each chunk to the summarization API, storing results.
        # API CALLS: consider placing the code from query() into here. * * * *
        for chunk in chunklist:
            safe = False
            summarized_chunk = ""  # fallback so a chunk that never summarizes doesn't raise a NameError
            with Timer(name=f"{stubhead}_query_time", logger=None):
                while not safe and repeat < 4:
                    try:  # make these digest params.
                        summarized_chunk = self.query(
                            {
                                "inputs": str(chunk),
                                "parameters": self.SUMMARIZATION_PARAMETERS
                            },
                            API_URL,
                            headers,
                        )[0]['summary_text']
                        safe = True
                    except Exception as e:
                        print("Summarization error, repeating...")
                        print(e)
                        repeat += 1
            collection_bin.append(summarized_chunk)
        return collection_bin

198
+ # Order for display, arrange links?
199
+ def build_digest(self) -> str:
200
+ """Called to show the digest. Also creates data dict for digest and summaries."""
201
+ # builds summaries from pieces in each object
202
+ # orders summaries according to cluster count
203
+ # above done below not
204
+ # Manages data to be presented along with digest.
205
+ # returns all as data to display method either here or in main.
206
+ digest = []
207
+ for each in self.summaries:
208
+ digest.append(' '.join(each.summary_text))
209
+
210
+ # Create dict to write out digest data for analysis
211
+ out_data = {}
212
+ datetime_str = f"""{dt.now()}"""
213
+ choices_str = ', '.join(self.user_choices)
214
+ digest_str = '\n\n'.join(digest)
215
+
216
+
217
+ # This is a long comprehension to store all the fields and values in each summary.
218
+ # integer: {
219
+ # name_of_field:value except for source,
220
+ # which is unhashable so needs explicit handling.
221
+ # }
222
+ summaries = { # k is a summary tuple, i,p = enumerate(k)
223
+ # Here we take the first dozen words of the first summary chunk as key
224
+ c: {
225
+ # field name : value unless its the source
226
+ k._fields[i]:p if k._fields[i]!='source'
227
+ else
228
+ {
229
+ 'name': k.source.source_name,
230
+ 'source_url': k.source.source_url,
231
+ 'Summarization" Checkpoint': k.source.source_summarization_checkpoint,
232
+ 'NER Checkpoint': k.source.source_ner_checkpoint,
233
+ } for i,p in enumerate(k)
234
+ } for c,k in enumerate(self.summaries)}
235
+
236
+ out_data['timestamp'] = datetime_str
237
+ out_data['choices'] = choices_str
238
+ out_data['digest_text'] = digest_str
239
+ out_data['article_count'] = len(self.summaries)
240
+ out_data['digest_length'] = len(digest_str.split(" "))
241
+ out_data['digest_time'] = self.timer.timers['digest_time']
242
+ out_data['sum_params'] = {
243
+ 'token_limit':self.token_limit,
244
+ 'word_limit':self.word_limit,
245
+ 'params':self.SUMMARIZATION_PARAMETERS,
246
+ }
247
+ out_data['summaries'] = summaries
248
+
249
+
250
+ self.text = digest_str
251
+
252
+ return out_data
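

# A minimal usage sketch (illustrative; `chosen_stubs` and `chosen_clusters` stand in
# for whatever the scraping/selection step of the app actually produces):
#
#   from codetiming import Timer
#   digestor = Digestor(timer=Timer(), stubs=chosen_stubs, user_choices=chosen_clusters)
#   digestor.digest()                      # scrape, chunk, summarize, and sort by relevance
#   digest_data = digestor.build_digest()  # dict of digest text plus per-summary metadata
#   st.write(digestor.text)                # the assembled digest string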