ignore committed on
Commit 0e0a45a
1 Parent(s): a9ee55c

Upload 2 files

Files changed (2)
  1. app.py +276 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,276 @@
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, BertTokenizer
+ import torch
+ import streamlit as st
+ import re
+ from typing import List, Tuple
+ import spacy
+ import numpy as np
+ from dataclasses import dataclass
+ from nltk.tokenize import sent_tokenize, word_tokenize
+ 
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+ st.set_page_config(layout="wide")
+ 
+ @dataclass
+ class LexicalUnits:
+     unit_type: str
+     text: List[str]
+     self_info: List[float] = None
+ 
+     def __add__(self, other):
+         assert self.unit_type == other.unit_type, 'Cannot add two different unit types'
+         return LexicalUnits(self.unit_type, self.text + other.text, self.self_info + other.self_info)
+ 
+     def __radd__(self, other):
+         if other == 0:
+             return self
+         return NotImplemented
+ 
+     def add_to_head(self, token, self_info):
+         return LexicalUnits(self.unit_type, [token] + self.text, [self_info] + self.self_info)
+ 
+     def add_to_tail(self, token, self_info):
+         return LexicalUnits(self.unit_type, self.text + [token], self.self_info + [self_info])
+ 
+ class SelectiveContext:
+ 
+     def __init__(self, model_type='gpt2', lang='en'):
+ 
+         self.model_type = model_type
+         self.lang = lang
+ 
+         # this means we calculate self-information sentence by sentence
+         self.sent_level_self_info = True
+ 
+         self._prepare_phrase_tokenizer()
+         self.sent_tokenize_pattern = r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s"
+         self.phrase_mask_token = ''
+         self.sent_mask_token = "<deleted>"
+ 
+         self._prepare_model()
+ 
+     def _prepare_phrase_tokenizer(self):
+         # we use spacy to tokenize sentences into phrases
+         # for English, we use `spacy.load("en_core_web_sm")` with the `merge_noun_chunks` pipe
+         # for Chinese, we use `spacy.load('zh_core_web_sm')` directly
+         lang = self.lang
+         if lang == "en":
+             self.nlp = spacy.load("en_core_web_sm", disable=["ner"])
+             self.nlp.add_pipe('merge_noun_chunks')
+         elif lang == "zh":
+             self.nlp = spacy.load('zh_core_web_sm', disable=["ner"])
+ 
+     def _prepare_model(self):
+         if self.model_type == 'gpt2':
+             if self.lang == 'zh':
+                 self.model = GPT2LMHeadModel.from_pretrained('uer/gpt2-chinese-cluecorpussmall')
+                 self.tokenizer = BertTokenizer.from_pretrained('uer/gpt2-chinese-cluecorpussmall')
+             else:
+                 self.model = GPT2LMHeadModel.from_pretrained('gpt2')
+                 self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+             self.model.to(DEVICE)
+             self.model.eval()
+ 
+             print('model loaded')
+ 
+             self.max_token_length = self.model.config.n_positions
+             self.get_self_information = self._get_self_info_via_gpt2
+ 
+     def get_self_information(self, text: str) -> Tuple[List[str], List[float]]:
+         # takes text as input and returns a list of tokens and a list of self-information scores
+         raise NotImplementedError
+ 
+     def _get_self_info_via_gpt2(self, text: str) -> Tuple[List[str], List[float]]:
+         if self.lang == 'en':
+             text = f"<|endoftext|>{text}"
+         elif self.lang == 'zh':
+             text = f"[CLS]{text}"
+         with torch.no_grad():
+             encoding = self.tokenizer(text, add_special_tokens=False, return_tensors='pt')
+             encoding = encoding.to(DEVICE)
+             outputs = self.model(**encoding)
+             logits = outputs.logits
+             probs = torch.softmax(logits, dim=-1)
+             # self-information of a token is -log p(token | preceding context)
+             self_info = -torch.log(probs)
+ 
+         input_ids = encoding['input_ids']
+         input_ids_expanded = input_ids[:, 1:].unsqueeze(-1)
+ 
+         tokens = [self.tokenizer.decode(token_) for token_ in input_ids.squeeze().tolist()[1:]]
+         return tokens, self_info[:, :-1].gather(-1, input_ids_expanded).squeeze(-1).squeeze(0).tolist()
+ 
+     def _lexical_unit(self, sents):
+ 
+         if self.sent_level_self_info:
+             sent_self_info = []
+             all_noun_phrases = []
+             all_noun_phrases_info = []
+             all_tokens = []
+             all_token_self_info = []
+ 
+             for sent in sents:
+                 print(sent)
+                 tokens, self_info = self.get_self_information(sent)
+                 sent_self_info.append(np.mean(self_info))
+ 
+                 all_tokens.extend(tokens)
+                 all_token_self_info.extend(self_info)
+ 
+                 noun_phrases, noun_phrases_info = self._calculate_lexical_unit(tokens, self_info)
+ 
+                 # we need to add a space before the first noun phrase of every sentence except the first
+                 if len(all_noun_phrases) != 0:
+                     noun_phrases[0] = f" {noun_phrases[0]}"
+                 all_noun_phrases.extend(noun_phrases)
+                 all_noun_phrases_info.extend(noun_phrases_info)
+ 
+             return [
+                 LexicalUnits('sent', text=sents, self_info=sent_self_info),
+                 LexicalUnits('phrase', text=all_noun_phrases, self_info=all_noun_phrases_info),
+                 LexicalUnits('token', text=all_tokens, self_info=all_token_self_info)
+             ]
+ 
+     def _calculate_lexical_unit(self, tokens, self_info):
+         def _unit_info(tokens, self_info, units):
+             current_unit_idx = 0
+             current_position = 0
+             unit_self_info = [[] for _ in range(len(units))]
+ 
+             for idx, (token, info) in enumerate(zip(tokens, self_info)):
+                 current_position += len(token)
+                 if current_position == len(units[current_unit_idx]):
+                     unit_self_info[current_unit_idx].append(info)
+                     current_position = current_position - len(units[current_unit_idx])
+                     current_unit_idx += 1
+                 elif current_position > len(units[current_unit_idx]):
+                     counter_ = 1
+                     current_position = current_position - len(units[current_unit_idx])
+                     current_unit_idx += 1
+                     # guard the index so a token spanning the last unit cannot run past the list
+                     while current_unit_idx < len(units) and current_position >= len(units[current_unit_idx]):
+                         counter_ += 1
+                         current_position = current_position - len(units[current_unit_idx])
+                         current_unit_idx += 1
+                     # a token that spans several units contributes equally to each of them
+                     partial_info = info / counter_
+                     for _ in range(counter_):
+                         unit_self_info[(current_unit_idx - 1) - _].append(partial_info)
+                 else:
+                     if token == " ":
+                         continue
+                     unit_self_info[current_unit_idx].append(info)
+ 
+             unit_self_info_ = [np.mean(info) for info in unit_self_info]
+             return unit_self_info_
+ 
+         def _noun_phrases(sent):
+             noun_phrases = []
+             doc = self.nlp(sent)
+             for index, chunk in enumerate(doc):
+                 if index == 0:
+                     noun_phrases.append(chunk.text)
+                 else:
+                     noun_phrases.append(doc[index - 1].whitespace_ + chunk.text)
+             return noun_phrases
+ 
+         if self.sent_level_self_info:
+             # in this case, the self_info is already computed per sentence,
+             # so we only need to calculate the self_info for each phrase
+ 
+             sent = ''.join(tokens)
+             # noun_phrases = [chunk.text for chunk in self.nlp(sent).noun_chunks]
+             noun_phrases = _noun_phrases(sent)
+             # noun_phrases[-1] = noun_phrases[-1] + ' '
+             noun_phrases_info = _unit_info(tokens, self_info, noun_phrases)
+ 
+             return noun_phrases, noun_phrases_info
+ 
+     def beautify_context(self, context: str) -> str:
+         context = re.sub(r"\s+", " ", context)
+         return context
+ 
+     def self_info_mask(self, sents: List[str], self_info: List[float], mask_level):
+         # mask_level: mask sentences, phrases, or tokens
+         sents_after_mask = []
+         masked_sents = []
+ 
+         # every unit whose self-information falls below this percentile gets masked
+         self.ppl_threshold = np.nanpercentile(self_info, self.mask_ratio * 100)
+ 
+         # if title is not None:
+         #     with open(os.path.join(self.path, title+'_prob_token.tsv'), 'w', encoding='utf-8') as f:
+         #         for token, info in zip(tokens, self_info):
+         #             f.write(f"{token}\t{info}\n")
+         #     with open(os.path.join(self.path, title+'_prob_sent.tsv'), 'w', encoding='utf-8') as f:
+         #         for sent, info in zip(sents, sent_self_info):
+         #             f.write(f"{sent}\n{info}\n\n")
+ 
+         for sent, info in zip(sents, self_info):
+             if info < self.ppl_threshold:
+                 masked_sents.append(sent)
+                 sents_after_mask.append(self.mask_a_sent(sent, mask_level))
+             else:
+                 sents_after_mask.append(sent)
+         masked_context = " ".join(sents_after_mask) if mask_level == 'sent' else "".join(sents_after_mask)
+ 
+         return masked_context, masked_sents
+ 
+     def mask_a_sent(self, sent, level):
+         if level == 'phrase':
+             return self.phrase_mask_token
+         elif level == 'sent':
+             return self.sent_mask_token
+         elif level == 'token':
+             return ''
+ 
+     def __call__(self, text: str, reduce_ratio: float = 0.35, reduce_level: str = 'phrase') -> Tuple[str, List[str]]:
+         context = self.beautify_context(text)
+ 
+         self.mask_ratio = reduce_ratio
+ 
+         sents = re.split(self.sent_tokenize_pattern, context)
+         sents = [sent.strip() for sent in sents if sent.strip()]
+ 
+         # should the reduction happen at the sentence, phrase, or token level?
+         assert reduce_level in ['sent', 'phrase', 'token'], f"reduce_level should be one of ['sent', 'phrase', 'token'], got {reduce_level}"
+         sent_lus, phrase_lus, token_lus = self._lexical_unit(sents)
+         lexical_level = {
+             'sent': sent_lus,
+             'phrase': phrase_lus,
+             'token': token_lus
+         }
+ 
+         # context is the reduced context; masked_sents holds the content that was filtered out
+         context, masked_sents = self.self_info_mask(lexical_level[reduce_level].text, lexical_level[reduce_level].self_info, reduce_level)
+         return context, masked_sents
+ 
+ # streamlit app
+ # here we ask the user to input the text and the reduce ratio,
+ # then we call SelectiveContext to compress the text
+ 
+ st.title("Selective Context: Compress your prompt")
+ st.markdown("This is a demo for the **Selective Context** algorithm.")
+ st.markdown("Use this algorithm to **compress** your prompt, so that LLMs can deal with **2x more context**!")
+ st.markdown("- The algorithm filters out the content that is less informative. \n - You can also choose to filter out phrases or tokens instead of sentences. \n - Check out the paper for details and experiments! [https://arxiv.org/abs/2304.12102](https://arxiv.org/abs/2304.12102).")
+ st.write("")
+ 
+ st.subheader("Demo")
+ 
+ lang = st.radio("Please choose the language: ", ('en', 'zh'))
+ ratio = st.radio("Please choose the compression ratio [we recommend 0.5]: ", (0.5, 0.2, 0.35, 0.65, 0.8))
+ reduce_level = st.radio("Please choose the reduce level: ", ('phrase', 'token', 'sent'))
+ 
+ text = st.text_area("Please input your text here", height=300)
+ 
+ @st.cache_resource()
+ def load_model(lang):
+     model = SelectiveContext(lang=lang)
+     return model
+ 
+ if st.button("Compress"):
+     model = load_model(lang)
+     context, masked_sents = model(text, reduce_ratio=ratio, reduce_level=reduce_level)
+     st.subheader("The compressed context is:")
+     st.code(context)
+     # st.divider()
+     st.subheader("The filtered out content is:")
+     st.write(masked_sents)
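
For readers skimming the diff, here is a minimal sketch of how the class above could be driven outside the Streamlit UI. The module name and sample text are illustrative, not part of this commit; importing app.py directly would also execute the Streamlit calls, so in practice the class would be moved into its own module first.

# a minimal usage sketch, assuming SelectiveContext lives in a plain module
from selective_context import SelectiveContext  # hypothetical module name

sc = SelectiveContext(model_type='gpt2', lang='en')

long_prompt = (
    "Selective Context scores every lexical unit by its self-information "
    "under a small language model, then drops the least informative units "
    "so that more content fits into the LLM's context window."
)

# reduce_ratio is the fraction of units to drop;
# reduce_level is one of 'sent', 'phrase', or 'token'
compressed, removed = sc(long_prompt, reduce_ratio=0.5, reduce_level='phrase')
print(compressed)  # the compressed context
print(removed)     # the units that were filtered out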
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ transformers
+ spacy>=3.5.0
+ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0.tar.gz#en_core_web_sm
+ https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.5.0/zh_core_web_sm-3.5.0.tar.gz#zh_core_web_sm
+ nltk
+ torch
+ numpy
+ altair==4.0.0
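
Note that streamlit itself is not listed, presumably because the Space runtime provides it. To try the demo locally, something along these lines should work (a sketch, not part of the commit):

pip install -r requirements.txt
pip install streamlit
streamlit run app.py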