fittar committed
Commit 3815e0a
Parent: 30f7222

push summagary

Files changed (4)
  1. app.py +58 -0
  2. requirements.txt +8 -0
  3. summagery_pipline.py +233 -0
  4. utils.py +119 -0
app.py ADDED
@@ -0,0 +1,58 @@
+ import os
+
+ # stage temp files in a local scratch directory and pin the GPU before other imports
+ temp_dir = './temp/'
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+ os.environ['TMPDIR'] = temp_dir
+
+ import gradio as gr
+ import shutil
+ from summagery_pipline import Summagery
+
+ # start from a clean temp directory on every launch
+ if os.path.exists(temp_dir):
+     try:
+         shutil.rmtree(temp_dir)
+         print(f"The directory at {temp_dir} has been removed successfully along with its contents.")
+     except OSError as e:
+         print(f"Error: {temp_dir} - {e}")
+
+ os.makedirs(temp_dir, exist_ok=True)
+
+
+ def generate(text, batch_size, model_type, abstractness):
+     model = Summagery(model_type, batch_size=int(batch_size), abstractness=abstractness)
+     images = model.ignite(text)
+     return images
+
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+ <h1 style="text-align:center;">Welcome to Summagery: Document Summarization through Images</h1>
+
+ <h3 style="text-align:center;">Summarize long and short documents on any topic as images</h3>
+
+ <p style="text-align:left;">1. <b>Document:</b> Enter the text of the document you want to summarize.</p>
+ <p style="text-align:left;">2. <b>Batch Size:</b> Adjust the batch size for processing very long documents (e.g., 500 pages).</p>
+ <p style="text-align:left;">3. <b>T5_Model_Checkpoint:</b> Choose the model checkpoint (e.g., "t5-large", "t5-base", "t5-small"). Smaller models require less memory.</p>
+ <p style="text-align:left;">4. <b>Abstractness:</b> Slide to select the level of abstractness of your document; vary this attribute to explore different images.</p>
+
+ <p style="text-align:left;"><b>For more details:</b> check out my <a href="https://fittar.me/post/summagary/" target="_blank">blog post</a> for a comprehensive explanation of the Summagery project.</p>
+         """)
+
+     inputs = [
+         gr.Textbox(label="Document", lines=10, interactive=True),
+         gr.Number(label="Batch Size", value=5),
+         gr.Dropdown(label="T5_Model_Checkpoint", choices=["t5-large", "t5-base", "t5-small"], value='t5-large'),
+         gr.Slider(label="Abstractness", minimum=0, maximum=1, value=0.2),
+     ]
+
+     outputs = gr.Gallery(
+         label="Generated images", show_label=False, elem_id="gallery",
+         columns=[2], rows=[2], object_fit="contain", height="auto")
+
+     clear = gr.ClearButton([inputs[0]])
+     greet_btn = gr.Button("Submit")
+     greet_btn.click(fn=generate, inputs=inputs, outputs=outputs, api_name="Summagery")
+
+ demo.launch(share=True)
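
To try this commit locally, a minimal launch looks like the following (assuming the requirements below are installed and a CUDA GPU is available, since the script pins CUDA_VISIBLE_DEVICES):

    python app.py

With share=True, Gradio prints a temporary public URL in addition to the local one, and ./temp/ is wiped and recreated on every start.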
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch~=2.0.1
+ diffusers~=0.19.3
+ transformers~=4.30.2
+ image-reward~=1.5
+ numpy~=1.24.4
+ tqdm
+ pandas
+ gradio
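
A sketch of the corresponding setup, assuming a fresh virtual environment (the commit itself does not enforce one):

    python -m venv .venv
    source .venv/bin/activate
    pip install -r requirements.txt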
summagery_pipline.py ADDED
@@ -0,0 +1,233 @@
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer  # seq2seq auto class for T5 (AutoModelWithLMHead is deprecated)
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
+ from diffusers import DiffusionPipeline
+ import torch
+ from tqdm import tqdm
+ import pandas as pd
+ import numpy as np
+ import random
+ from torch.utils.data import DataLoader
+ from utils import mpnet_embed_class, get_concreteness, Collate_t5, SentenceDataset
+
+
+ class Summagery:
+
+     def __init__(self, t5_checkpoint, batch_size=5, abstractness=.4, max_d_length=1256, num_prompt=3, device='cuda'):
+         # ViPE: Visualize Pretty-much Everything
+         self.vipe_model = GPT2LMHeadModel.from_pretrained('fittar/ViPE-M-CTX7')
+         vipe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
+         vipe_tokenizer.pad_token = vipe_tokenizer.eos_token
+         self.vipe_tokenizer = vipe_tokenizer
+
+         # SDXL: load both the base and the refiner
+         self.basexl = DiffusionPipeline.from_pretrained(
+             "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+         )
+         self.refinerxl = DiffusionPipeline.from_pretrained(
+             "stabilityai/stable-diffusion-xl-refiner-1.0",
+             text_encoder_2=self.basexl.text_encoder_2,
+             vae=self.basexl.vae,
+             torch_dtype=torch.float16,
+             use_safetensors=True,
+             variant="fp16",
+         )
+
+         self.device = device
+         self.max_d_length = max_d_length  # maximum document length (in words) per chunk
+         self.final_document_length = 60  # stop summarizing once the text is at most this many words
+         self.num_prompt = num_prompt  # how many prompts to generate per document
+         self.abstractness = abstractness  # a handle from 0 to 1 for exploring the prompts
+         self.concreteness_dataset = './data/concreteness.csv'
+         self.batch_size = batch_size
+
+         # T5 summarizer
+         self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(t5_checkpoint)
+         self.t5_tokenizer = AutoTokenizer.from_pretrained(t5_checkpoint, model_max_length=max_d_length)
+         self.collate_t5 = Collate_t5(self.t5_tokenizer)
+
+         # for the concreteness rating of the prompts (the file is tab-separated despite the .csv extension)
+         data = pd.read_csv(self.concreteness_dataset, header=0, delimiter='\t')
+         self.word2score = {w: s for w, s in zip(data['WORD'], data['RATING'])}
+
+     # for large documents, divide them into chunks of at most self.max_d_length words
+     def document_preprocess(self, document):
+         documents = []
+         words = document.split()
+         if len(words) <= self.max_d_length:
+             return [document]
+
+         start = 0
+         while len(words) > start:
+             if len(words) > start + self.max_d_length:
+                 chunk = ' '.join(words[start:start + self.max_d_length])
+             else:
+                 chunk = ' '.join(words[start:])
+             start += self.max_d_length
+             documents.append(chunk)
+
+         return documents
+
+     def t5_summarize(self, document):
+         continue_summarization = True
+         if len(document.split()) <= self.final_document_length:
+             return document
+
+         self.t5_model.to(self.device)
+         documents = self.document_preprocess(document)
+
+         if len(documents) > self.batch_size:
+             # use batched inference to make things faster
+             while continue_summarization:
+                 dataset = SentenceDataset(documents)
+                 dataloader = DataLoader(dataset, batch_size=self.batch_size, collate_fn=self.collate_t5, num_workers=2)
+                 summaries = ''
+                 print('summarizing...')
+                 for text_batch, batch in tqdm(dataloader):
+                     if batch.input_ids.shape[1] > 5:
+                         max_length = int(batch.input_ids.shape[1] / 2)  # summarize the current chunk by half
+                         if max_length < self.final_document_length:  # unless max_length is too short
+                             max_length = self.final_document_length
+
+                         batch = batch.to(self.device)
+                         generated_ids = self.t5_model.generate(input_ids=batch.input_ids,
+                                                                attention_mask=batch.attention_mask, num_beams=3,
+                                                                max_length=max_length,
+                                                                repetition_penalty=2.5,
+                                                                length_penalty=1.0, early_stopping=True)
+                         preds = [self.t5_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+                                  for g in generated_ids]
+                         for pred in preds:
+                             summaries = summaries + pred + '. '
+                     else:
+                         # the chunk is too short to summarize; keep it as-is
+                         for chunk in text_batch:
+                             summaries = summaries + chunk + '. '
+
+                 if len(summaries.split()) <= self.final_document_length:
+                     continue_summarization = False
+                     print('finished summarizing.')
+                 else:
+                     documents = self.document_preprocess(summaries)
+         else:
+             # skip batched inference since we only have a few chunks
+             while continue_summarization:
+                 summaries = ''
+                 print('summarizing...')
+                 for chunk in tqdm(documents):
+                     if len(chunk.split()) > 2:
+                         max_length = int(len(chunk.split()) / 2)  # summarize the current chunk by half
+                         if max_length < self.final_document_length:  # unless max_length is too short
+                             max_length = self.final_document_length
+
+                         input_ids = self.t5_tokenizer.encode('summarize: ' + chunk, return_tensors="pt",
+                                                              add_special_tokens=True, padding='longest',
+                                                              max_length=self.max_d_length)
+                         input_ids = input_ids.to(self.device)
+                         generated_ids = self.t5_model.generate(input_ids=input_ids, num_beams=3, max_length=max_length,
+                                                                repetition_penalty=2.5,
+                                                                length_penalty=1.0, early_stopping=True)
+                         pred = [self.t5_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+                                 for g in generated_ids][0]
+                         summaries = summaries + pred + '. '
+                     else:
+                         summaries = summaries + chunk + '. '
+
+                 if len(summaries.split()) <= self.final_document_length:
+                     continue_summarization = False
+                     print('finished summarizing.')
+                 else:
+                     documents = self.document_preprocess(summaries)
+
+         return summaries
+
+     def vipe_generate(self, summary, do_sample=True, top_k=100, epsilon_cutoff=.00005, temperature=1):
+         # sample a batch of candidate prompts for the same summary
+         batch_size = random.choice([20, 40, 60])
+         input_text = [summary] * batch_size
+         # mark the text with special tokens
+         input_text = [self.vipe_tokenizer.eos_token + i + self.vipe_tokenizer.eos_token for i in input_text]
+         batch = self.vipe_tokenizer(input_text, padding=True, return_tensors="pt")
+
+         input_ids = batch["input_ids"].to(self.device)
+         attention_mask = batch["attention_mask"].to(self.device)
+         self.vipe_model.to(self.device)
+         # how many new tokens to generate at most
+         max_prompt_length = 50
+
+         generated_ids = self.vipe_model.generate(input_ids=input_ids, attention_mask=attention_mask,
+                                                  max_new_tokens=max_prompt_length, do_sample=do_sample, top_k=top_k,
+                                                  epsilon_cutoff=epsilon_cutoff, temperature=temperature)
+         # keep only the newly generated prompts
+         prompts = self.vipe_tokenizer.batch_decode(generated_ids[:, -(generated_ids.shape[1] - input_ids.shape[1]):],
+                                                    skip_special_tokens=True)
+
+         # for semantic similarity between each prompt and the summary
+         mpnet_object = mpnet_embed_class(device=self.device, nli=False)
+         similarities = mpnet_object.get_mpnet_embed_batch(prompts, [summary] * batch_size,
+                                                           batch_size=batch_size).cpu().numpy()
+         concreteness_score = get_concreteness(prompts, self.word2score)
+
+         # blend semantic similarity and concreteness according to the abstractness handle
+         final_scores = [i * (1 - self.abstractness) + self.abstractness * j for i, j in
+                         zip(similarities, concreteness_score)]
+         # get the indices that would sort the final scores in descending order
+         sorted_indices = np.argsort(final_scores)[::-1]
+
+         # keep the self.num_prompt highest-scoring prompts
+         top_indices = sorted_indices[:self.num_prompt]
+         prompts = [prompts[i] for i in top_indices]
+
+         return prompts
+
+     def sdxl_generate(self, prompts):
+         # how many steps to run, and what fraction of them on each expert (80/20)
+         n_steps = 50
+         high_noise_frac = 0.8
+         self.basexl.to(self.device)
+         self.refinerxl.to(self.device)
+
+         images = []
+         for i, p in enumerate(prompts):
+             # torch.manual_seed(i)
+             image = self.basexl(
+                 prompt=p,
+                 num_inference_steps=n_steps,
+                 denoising_end=high_noise_frac,
+                 output_type="latent",
+             ).images
+             image = self.refinerxl(
+                 prompt=p,
+                 num_inference_steps=n_steps,
+                 denoising_start=high_noise_frac,
+                 image=image,
+             ).images[0]
+
+             images.append(image)
+
+         return images
+
+     def ignite(self, document):
+         prompts = []
+         summary = self.t5_summarize(document)
+         prompts.append(summary)
+         summary = summary.replace('. ', '; ')
+         print(summary)
+         prompts.extend(self.vipe_generate(summary))
+
+         for p in prompts:
+             print(p + '\n')
+
+         images = self.sdxl_generate(prompts)
+         return images
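
For orientation, a minimal headless sketch of the three-stage flow this class implements (T5 summarization, then ViPE prompt generation, then SDXL rendering). The input file and output names are hypothetical; it assumes a CUDA GPU and the ./data/concreteness.csv file the constructor reads:

    from summagery_pipline import Summagery

    # t5-small keeps memory low; ignite() returns PIL images from the SDXL refiner
    model = Summagery('t5-small', batch_size=2, abstractness=0.3)
    with open('document.txt') as f:          # hypothetical input document
        images = model.ignite(f.read())
    for i, img in enumerate(images):
        img.save(f'summagery_{i}.png')       # hypothetical output names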
utils.py ADDED
@@ -0,0 +1,119 @@
+ import torch
+ from torch.nn.functional import cosine_similarity
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import AutoTokenizer, AutoModel
+ import numpy as np
+
+
+ def get_concreteness(prompts, word2score):
+     scores = []
+     for prompt in prompts:
+         conc_scores = [word2score[w] / 10 for w in prompt.split() if w in word2score]
+         if len(conc_scores) < 1:
+             scores.append(0.10)
+         else:
+             scores.append(np.mean(conc_scores))
+
+     return scores
+
+
+ # Mean pooling - take the attention mask into account for correct averaging
+ def mean_pooling(model_output, attention_mask):
+     token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
+     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
+ def compute_cosine_similarity(embeddings_1, embeddings_2):
+     # compute the pairwise cosine similarity between embeddings_1 and embeddings_2
+     return cosine_similarity(embeddings_1, embeddings_2)
+
+
+ class SentenceDataset(Dataset):
+     def __init__(self, sentences):
+         self.sentences = sentences
+
+     def __len__(self):
+         return len(self.sentences)
+
+     def __getitem__(self, index):
+         return self.sentences[index]
+
+
+ class Collate_t5:
+     def __init__(self, tokenizer):
+         self.t5_tokenizer = tokenizer
+
+     def __call__(self, documents):
+         batch = ['summarize: ' + s for s in documents]
+         # tokenize the documents, padding to the longest one in the batch
+         encoded_inputs = self.t5_tokenizer(batch, return_tensors="pt",
+                                            add_special_tokens=True, padding='longest')
+         return documents, encoded_inputs
+
+
+ class collate_cl:
+     def __init__(self, tokenizer):
+         self.tokenizer = tokenizer
+
+     def __call__(self, batch):
+         # tokenize the sentences
+         encoded_inputs = self.tokenizer(batch, padding=True, truncation=True, return_tensors='pt')
+         return encoded_inputs
+
+
+ class mpnet_embed_class:
+     def __init__(self, device='cuda', nli=True):
+         self.device = device
+
+         if nli:
+             model = AutoModel.from_pretrained('sentence-transformers/nli-mpnet-base-v2')
+             tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/nli-mpnet-base-v2')
+         else:
+             model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
+             tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
+
+         model.to(device)
+         self.model = model
+         self.tokenizer = tokenizer
+         self.collate_fn = collate_cl(tokenizer)
+
+     def get_mpnet_embed_batch(self, predictions, ground_truth, batch_size=10):
+         dataset_1 = SentenceDataset(predictions)
+         dataset_2 = SentenceDataset(ground_truth)
+
+         dataloader_1 = DataLoader(dataset_1, batch_size=batch_size, collate_fn=self.collate_fn, num_workers=1)
+         dataloader_2 = DataLoader(dataset_2, batch_size=batch_size, collate_fn=self.collate_fn, num_workers=1)
+
+         # compute token embeddings
+         embeddings_1 = []
+         embeddings_2 = []
+
+         with torch.no_grad():
+             for count, (batch_1, batch_2) in enumerate(zip(dataloader_1, dataloader_2)):
+                 if count % 50 == 0:
+                     print(count, ' out of ', len(dataloader_2))
+                 batch_1 = {key: value.to(self.device) for key, value in batch_1.items()}
+                 batch_2 = {key: value.to(self.device) for key, value in batch_2.items()}
+
+                 model_output_1 = self.model(**batch_1)
+                 model_output_2 = self.model(**batch_2)
+
+                 sentence_embeddings_1 = mean_pooling(model_output_1, batch_1['attention_mask'])
+                 sentence_embeddings_2 = mean_pooling(model_output_2, batch_2['attention_mask'])
+
+                 embeddings_1.append(sentence_embeddings_1)
+                 embeddings_2.append(sentence_embeddings_2)
+
+         # concatenate and L2-normalize the embeddings
+         embeddings_1 = torch.nn.functional.normalize(torch.cat(embeddings_1), p=2, dim=1)
+         embeddings_2 = torch.nn.functional.normalize(torch.cat(embeddings_2), p=2, dim=1)
+
+         # one cosine similarity per (prediction, ground_truth) pair
+         similarities = compute_cosine_similarity(embeddings_1, embeddings_2)
+
+         return similarities
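
As a standalone illustration of the two scores that vipe_generate blends, a small sketch; the toy word2score dict is a hypothetical stand-in for the concreteness CSV (whose ratings get_concreteness divides by 10):

    from utils import get_concreteness, mpnet_embed_class

    prompts = ['a red apple on a wooden table', 'the essence of an idea']
    word2score = {'apple': 5.0, 'table': 4.9, 'idea': 1.6}  # toy stand-in for the CSV
    print(get_concreteness(prompts, word2score))            # per-prompt mean concreteness

    mpnet = mpnet_embed_class(device='cpu', nli=False)      # CPU is enough for two sentences
    sims = mpnet.get_mpnet_embed_batch(prompts, ['fruit on furniture'] * 2, batch_size=2)
    print(sims)                                             # one cosine similarity per pair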