Spaces:
Sleeping
Sleeping
ndhieunguyen
commited on
Commit
·
77180e4
1
Parent(s):
ff15dff
feat: remove mpi4py
Browse files- requirements.txt +0 -0
- src/improved_diffusion/dist_util.py +20 -21
- src/improved_diffusion/text_datasets.py +429 -273
requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
|
|
src/improved_diffusion/dist_util.py
CHANGED
@@ -8,7 +8,6 @@ import socket
|
|
8 |
|
9 |
import blobfile as bf
|
10 |
|
11 |
-
from mpi4py import MPI
|
12 |
import torch as th
|
13 |
import torch.distributed as dist
|
14 |
|
@@ -46,26 +45,26 @@ def setup_dist(rank, world_size, port="12145"):
|
|
46 |
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
47 |
|
48 |
|
49 |
-
def dev():
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
def load_state_dict(path, **kwargs):
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
|
70 |
|
71 |
def sync_params(params):
|
|
|
8 |
|
9 |
import blobfile as bf
|
10 |
|
|
|
11 |
import torch as th
|
12 |
import torch.distributed as dist
|
13 |
|
|
|
45 |
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
46 |
|
47 |
|
48 |
+
# def dev():
|
49 |
+
# """
|
50 |
+
# Get the device to use for torch.distributed.
|
51 |
+
# """
|
52 |
+
# if th.cuda.is_available():
|
53 |
+
# return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
|
54 |
+
# return th.device("cpu")
|
55 |
+
|
56 |
+
|
57 |
+
# def load_state_dict(path, **kwargs):
|
58 |
+
# """
|
59 |
+
# Load a PyTorch file without redundant fetches across MPI ranks.
|
60 |
+
# """
|
61 |
+
# if MPI.COMM_WORLD.Get_rank() == 0:
|
62 |
+
# with bf.BlobFile(path, "rb") as f:
|
63 |
+
# data = f.read()
|
64 |
+
# else:
|
65 |
+
# data = None
|
66 |
+
# data = MPI.COMM_WORLD.bcast(data)
|
67 |
+
# return th.load(io.BytesIO(data), **kwargs)
|
68 |
|
69 |
|
70 |
def sync_params(params):
|
src/improved_diffusion/text_datasets.py
CHANGED
@@ -1,13 +1,21 @@
|
|
1 |
# from PIL import Image
|
2 |
# import blobfile as bf
|
3 |
-
from mpi4py import MPI
|
4 |
import numpy as np
|
5 |
from torch.utils.data import DataLoader, Dataset
|
6 |
-
from transformers import
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
# from datasets import load_dataset
|
9 |
import sys, os
|
10 |
import torch
|
|
|
11 |
# sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
|
12 |
# from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise
|
13 |
from collections import Counter, defaultdict
|
@@ -16,8 +24,18 @@ from itertools import chain
|
|
16 |
|
17 |
|
18 |
def load_data_text(
|
19 |
-
*,
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
):
|
22 |
"""
|
23 |
For a dataset, create a generator over (images, kwargs) pairs.
|
@@ -35,29 +53,34 @@ def load_data_text(
|
|
35 |
exception will be raised.
|
36 |
:param deterministic: if True, yield results in a deterministic order.
|
37 |
"""
|
38 |
-
print(
|
39 |
|
40 |
-
if data_args.experiment.startswith(
|
41 |
model = None
|
42 |
# elif data_args.experiment.startswith('random') and model is not None:
|
43 |
# print('loading initialized random embeddings. ')
|
44 |
|
45 |
-
if task_mode ==
|
46 |
pass
|
47 |
# training_data, model = get_corpus_rocstory(data_args, model, image_size,
|
48 |
# padding_mode=padding_mode, split=split,
|
49 |
-
|
50 |
-
elif task_mode ==
|
51 |
pass
|
52 |
# training_data, model = get_corpus_rocstory(data_args, model, image_size,
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
elif task_mode ==
|
57 |
-
print(
|
58 |
-
training_data, model = get_corpus_rocstory(
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
61 |
# elif task_mode == 'yelp':
|
62 |
# print('hello loading yelp ')
|
63 |
# training_data, model = get_corpus_rocstory(data_args, model, image_size,
|
@@ -80,8 +103,12 @@ def load_data_text(
|
|
80 |
# training_data, model = get_corpus_book(data_args, tokenizer, model, image_size,
|
81 |
# padding_mode=padding_mode, split=split,)
|
82 |
|
83 |
-
if
|
84 |
-
|
|
|
|
|
|
|
|
|
85 |
# training_data,
|
86 |
# image_size,
|
87 |
# data_args,
|
@@ -98,7 +125,7 @@ def load_data_text(
|
|
98 |
|
99 |
if deterministic:
|
100 |
|
101 |
-
pass# data_loader = DataLoader(
|
102 |
# dataset,
|
103 |
# batch_size=batch_size, # 20,
|
104 |
# drop_last=True,
|
@@ -117,64 +144,83 @@ def load_data_text(
|
|
117 |
while True:
|
118 |
yield from data_loader
|
119 |
|
|
|
120 |
def helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, seqlen, data_args):
|
121 |
result_train_lst = []
|
122 |
group_lst = defaultdict(list)
|
123 |
with torch.no_grad():
|
124 |
-
for
|
125 |
-
tokenized_ = [vocab_dict.get(x, vocab_dict[
|
126 |
-
tokenized_src = [vocab_dict.get(x, vocab_dict[
|
127 |
input_ids = [0] + tokenized_ + [1]
|
128 |
-
group_lst[
|
129 |
-
group_lst[
|
130 |
|
131 |
-
print(group_lst[
|
132 |
-
print(
|
133 |
max_length = seqlen
|
134 |
-
group_lst[
|
135 |
-
|
|
|
|
|
136 |
print(max_src_length, seqlen)
|
137 |
max_src_length = min(seqlen, max_src_length)
|
138 |
-
group_lst[
|
139 |
-
|
140 |
-
|
141 |
-
return_mask=True)
|
142 |
-
|
143 |
|
144 |
-
for input_ids, src_ids, src_mask in zip(
|
145 |
-
|
146 |
-
|
|
|
147 |
hidden_state = model(torch.tensor(input_ids))
|
148 |
-
elif data_args.experiment ==
|
149 |
input_ids2 = torch.tensor(input_ids).to(model.device)
|
150 |
input_embs = model.transformer.wte(input_ids2) # input_embs
|
151 |
hidden_state = model.down_proj(input_embs)
|
152 |
hidden_state = hidden_state * data_args.emb_scale_factor
|
153 |
-
result_train_lst.append(
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
|
|
|
|
|
|
158 |
|
159 |
return result_train_lst
|
160 |
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
import psutil
|
|
|
163 |
# Process.memory_info is expressed in bytes, so convert to megabytes
|
164 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
165 |
from datasets import Dataset as Dataset2
|
166 |
-
|
|
|
167 |
print(raw_datasets)
|
168 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
169 |
|
170 |
-
|
171 |
def tokenize_function(examples):
|
172 |
if isinstance(vocab_dict, dict):
|
173 |
-
input_ids = [
|
|
|
|
|
|
|
174 |
elif isinstance(vocab_dict, PreTrainedTokenizerFast):
|
175 |
-
examples[
|
176 |
-
input_ids = vocab_dict(examples[
|
177 |
-
|
|
|
|
|
178 |
# clm input could be much much longer than block_size
|
179 |
return result_dict
|
180 |
|
@@ -182,28 +228,30 @@ def helper_tokenize_stream(sentence_lst, vocab_dict, model, seqlen, data_args, p
|
|
182 |
tokenize_function,
|
183 |
batched=True,
|
184 |
num_proc=4,
|
185 |
-
remove_columns=[
|
186 |
load_from_cache_file=True,
|
187 |
desc="Running tokenizer on dataset",
|
188 |
)
|
189 |
print(tokenized_datasets)
|
190 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
191 |
|
192 |
-
if padding_mode ==
|
193 |
block_size = seqlen
|
|
|
194 |
def group_texts(examples):
|
195 |
-
concatenated_examples = {
|
|
|
|
|
196 |
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
197 |
if total_length >= block_size:
|
198 |
total_length = (total_length // block_size) * block_size
|
199 |
result = {
|
200 |
-
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
201 |
for k, t in concatenated_examples.items()
|
202 |
}
|
203 |
result["labels"] = result["input_ids"].copy()
|
204 |
return result
|
205 |
|
206 |
-
|
207 |
lm_datasets = tokenized_datasets.map(
|
208 |
group_texts,
|
209 |
batched=True,
|
@@ -212,12 +260,17 @@ def helper_tokenize_stream(sentence_lst, vocab_dict, model, seqlen, data_args, p
|
|
212 |
desc=f"Grouping texts in chunks of {block_size}",
|
213 |
)
|
214 |
else:
|
|
|
215 |
def pad_function(group_lst):
|
216 |
max_length = seqlen
|
217 |
if isinstance(vocab_dict, dict):
|
218 |
-
group_lst[
|
|
|
|
|
219 |
else:
|
220 |
-
group_lst[
|
|
|
|
|
221 |
return group_lst
|
222 |
|
223 |
# Process.memory_info is expressed in bytes, so convert to megabytes
|
@@ -230,59 +283,72 @@ def helper_tokenize_stream(sentence_lst, vocab_dict, model, seqlen, data_args, p
|
|
230 |
desc=f"padding",
|
231 |
)
|
232 |
|
233 |
-
|
234 |
-
print(lm_datasets, 'padded dataset')
|
235 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
236 |
import datasets
|
|
|
237 |
raw_datasets = datasets.DatasetDict()
|
238 |
-
raw_datasets[
|
239 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
240 |
return raw_datasets
|
241 |
|
242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
result_train_lst = []
|
244 |
group_lst = defaultdict(list)
|
245 |
with torch.no_grad():
|
246 |
for input_ids in sentence_lst:
|
247 |
-
tokenized_ = [vocab_dict.get(x, vocab_dict[
|
248 |
input_ids = [0] + tokenized_ + [1]
|
249 |
-
group_lst[
|
250 |
-
print(group_lst[
|
251 |
|
252 |
-
if padding_mode ==
|
253 |
-
print(
|
254 |
concatenated_examples = {k: sum(group_lst[k], []) for k in group_lst.keys()}
|
255 |
total_length = len(concatenated_examples[list(group_lst.keys())[0]])
|
256 |
block_size = seqlen
|
257 |
total_length = (total_length // block_size) * block_size
|
258 |
# Split by chunks of max_len.
|
259 |
group_lst = {
|
260 |
-
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
261 |
for k, t in concatenated_examples.items()
|
262 |
}
|
263 |
-
elif padding_mode ==
|
264 |
-
print(
|
265 |
max_length = seqlen
|
266 |
-
group_lst[
|
|
|
|
|
267 |
|
268 |
-
for input_ids in group_lst[
|
269 |
-
if data_args.experiment.startswith(
|
270 |
hidden_state = model(torch.tensor(input_ids))
|
271 |
-
elif data_args.experiment ==
|
272 |
input_ids2 = torch.tensor(input_ids).to(model.device)
|
273 |
input_embs = model.transformer.wte(input_ids2) # input_embs
|
274 |
hidden_state = model.down_proj(input_embs)
|
275 |
hidden_state = hidden_state * data_args.emb_scale_factor
|
276 |
-
elif data_args.experiment ==
|
277 |
hidden_state = model(torch.tensor(input_ids))
|
278 |
-
result_train_lst.append(
|
|
|
|
|
279 |
|
280 |
return result_train_lst
|
281 |
|
|
|
282 |
def load_glove_model(File):
|
283 |
print("Loading Glove Model")
|
284 |
glove_model = {}
|
285 |
-
with open(File,
|
286 |
for line in f:
|
287 |
split_line = line.split()
|
288 |
word = split_line[0]
|
@@ -292,9 +358,10 @@ def load_glove_model(File):
|
|
292 |
print(f"{len(glove_model)} words loaded!")
|
293 |
return glove_model
|
294 |
|
|
|
295 |
def load_glove(vocab):
|
296 |
model = torch.nn.Embedding(len(vocab), 50)
|
297 |
-
glove_model = load_glove_model(
|
298 |
array_lst = []
|
299 |
count_ = 0
|
300 |
for word, idx in vocab.items():
|
@@ -303,20 +370,21 @@ def load_glove(vocab):
|
|
303 |
else:
|
304 |
count_ += 1
|
305 |
array_lst.append(torch.randn(50))
|
306 |
-
print(f
|
307 |
array_lst = torch.stack(array_lst)
|
308 |
print(torch.norm(array_lst, dim=-1).mean())
|
309 |
model.weight.data = array_lst
|
310 |
return model
|
311 |
|
312 |
|
313 |
-
def get_corpus_rocstory(
|
314 |
-
|
|
|
315 |
import csv, torch, json
|
316 |
from spacy.lang.en import English
|
317 |
|
318 |
-
if data_args.experiment_mode ==
|
319 |
-
if data_args.modality ==
|
320 |
pass
|
321 |
# print('loading dataset from ROCStory')
|
322 |
# nlp = English()
|
@@ -347,7 +415,7 @@ def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
|
|
347 |
# # sentence_lst.append(word_lst)
|
348 |
# # sentence_lst = sentence_lst[1:]
|
349 |
# print(sentence_lst[:2])
|
350 |
-
if data_args.modality ==
|
351 |
pass
|
352 |
# print('loading dataset from ROCStory')
|
353 |
# nlp = English()
|
@@ -381,7 +449,7 @@ def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
|
|
381 |
# word_lst = [x.text for x in tokenizer(sentences)]
|
382 |
# sentence_lst.append(word_lst)
|
383 |
# print(sentence_lst[:2],sentence_lst[-2:], 'dataset size=',len(sentence_lst))
|
384 |
-
elif data_args.modality ==
|
385 |
pass
|
386 |
# print('loading dataset from simple wikipedia')
|
387 |
# sentence_lst = []
|
@@ -390,57 +458,62 @@ def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
|
|
390 |
# word_lst = row.lower().split()
|
391 |
# sentence_lst.append(word_lst)
|
392 |
# print(sentence_lst[:2])
|
393 |
-
elif data_args.modality ==
|
394 |
-
print(
|
395 |
sentence_lst = []
|
396 |
nlp = English()
|
397 |
tokenizer = nlp.tokenizer
|
398 |
-
if split ==
|
399 |
-
print(
|
400 |
-
path =
|
|
|
|
|
401 |
# path = f'../{data_args.e2e_train}/src1_train.txt'
|
402 |
-
elif split ==
|
403 |
-
print(
|
404 |
-
path = f
|
405 |
-
path =
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
|
|
|
|
412 |
path = data_args.debug_path
|
413 |
import json
|
414 |
-
|
|
|
415 |
for line in ff:
|
416 |
-
sentence_lst.append(json.loads(line)[0].split(
|
417 |
sentence_lst = sentence_lst + sentence_lst
|
418 |
-
if split in [
|
419 |
-
with open(path,
|
420 |
for row in ff:
|
421 |
-
word_lst = row.split(
|
422 |
word_lst = [x.text for x in tokenizer(word_lst)]
|
423 |
sentence_lst.append(word_lst)
|
424 |
print(sentence_lst[:2])
|
425 |
|
426 |
-
elif data_args.modality ==
|
427 |
-
print(
|
428 |
sentence_lst = []
|
429 |
nlp = English()
|
430 |
tokenizer = nlp.tokenizer
|
431 |
-
if split ==
|
432 |
-
print(
|
433 |
-
path = f
|
434 |
-
elif split ==
|
435 |
-
print(
|
436 |
-
path = f
|
437 |
-
elif split ==
|
438 |
-
print(
|
439 |
-
path = f
|
440 |
-
if split in [
|
441 |
-
|
442 |
-
with open(path,
|
443 |
-
yelp_reader = csv.reader(csvfile)
|
444 |
for row in yelp_reader:
|
445 |
sentences = row[1]
|
446 |
word_lst = [x.text for x in tokenizer(sentences)]
|
@@ -448,175 +521,188 @@ def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
|
|
448 |
sentence_lst = sentence_lst[1:]
|
449 |
print(sentence_lst[:2])
|
450 |
|
451 |
-
elif data_args.modality ==
|
452 |
-
print(
|
453 |
sentence_lst = []
|
454 |
nlp = English()
|
455 |
tokenizer = nlp.tokenizer
|
456 |
-
if split ==
|
457 |
-
print(
|
458 |
-
path = f
|
459 |
-
elif split ==
|
460 |
-
print(
|
461 |
-
path = f
|
462 |
-
elif split ==
|
463 |
-
print(
|
464 |
-
path = f
|
465 |
-
if split in [
|
466 |
-
with open(path,
|
467 |
for line in ff:
|
468 |
line = json.loads(line)
|
469 |
-
for sentences in line[
|
470 |
word_lst = [x.text for x in tokenizer(sentences)]
|
471 |
sentence_lst.append(word_lst)
|
472 |
print(sentence_lst[:2])
|
473 |
|
474 |
-
elif data_args.modality ==
|
475 |
-
print(
|
476 |
sentence_lst = []
|
477 |
nlp = English()
|
478 |
tokenizer = nlp.tokenizer
|
479 |
-
if split ==
|
480 |
-
print(
|
481 |
-
path = f
|
482 |
-
path_lst = [f
|
483 |
-
path_lst.append(
|
484 |
-
|
485 |
-
|
486 |
-
|
|
|
|
|
487 |
path_lst = []
|
488 |
-
elif split ==
|
489 |
-
print(
|
490 |
-
path = f
|
491 |
path_lst = []
|
492 |
|
493 |
-
if split in [
|
494 |
-
with open(path,
|
495 |
for line in ff:
|
496 |
line = json.loads(line)
|
497 |
-
for sentences in line[
|
498 |
word_lst = [x.text for x in tokenizer(sentences)]
|
499 |
sentence_lst.append(word_lst)
|
500 |
print(sentence_lst[:2])
|
501 |
import itertools
|
|
|
502 |
for path in path_lst:
|
503 |
-
if path.endswith(
|
504 |
-
with open(path,
|
505 |
for row in roc_reader:
|
506 |
sentences = row.strip()
|
507 |
word_lst = [x.text for x in tokenizer(sentences)]
|
508 |
spl = [[]]
|
509 |
-
for x, y in itertools.groupby(word_lst, lambda z: z ==
|
510 |
spl[-1].extend(y)
|
511 |
-
if x:
|
|
|
512 |
sentence_lst.extend(spl[:-1])
|
513 |
else:
|
514 |
-
with open(path,
|
515 |
for row in roc_reader:
|
516 |
sentences = json.loads(row)[0].strip()
|
517 |
word_lst = [x.text for x in tokenizer(sentences)]
|
518 |
spl = [[]]
|
519 |
-
for x, y in itertools.groupby(word_lst, lambda z: z ==
|
520 |
spl[-1].extend(y)
|
521 |
-
if x:
|
|
|
522 |
sentence_lst.extend(spl[:-1])
|
523 |
|
524 |
print(sentence_lst[-2:])
|
525 |
|
526 |
-
|
527 |
# get tokenizer.
|
528 |
if load_vocab is None:
|
529 |
counter = Counter()
|
530 |
for input_ids in sentence_lst:
|
531 |
counter.update(input_ids)
|
532 |
|
533 |
-
if data_args.experiment_mode ==
|
534 |
-
if data_args.modality ==
|
535 |
-
print(
|
536 |
sentence_lst = []
|
537 |
nlp = English()
|
538 |
tokenizer = nlp.tokenizer
|
539 |
-
if split ==
|
540 |
-
path = f
|
541 |
-
with open(path,
|
542 |
for row in ff:
|
543 |
-
src_lst, word_lst = row.split(
|
544 |
word_lst = [x.text for x in tokenizer(word_lst)]
|
545 |
src_lst = [x.text for x in tokenizer(src_lst)]
|
546 |
sentence_lst.append((src_lst, word_lst))
|
547 |
-
elif split ==
|
548 |
-
path = f
|
549 |
sentence_lst = read_e2e_files(path, data_args, tokenizer)
|
550 |
print(sentence_lst[:2])
|
551 |
# get tokenizer.
|
552 |
if load_vocab is None:
|
553 |
counter = Counter()
|
554 |
-
for
|
555 |
counter.update(input_ids)
|
556 |
counter.update(src_ids)
|
557 |
|
558 |
if load_vocab is None:
|
559 |
-
vocab_dict = {
|
560 |
for k, v in counter.items():
|
561 |
if v > 10:
|
562 |
vocab_dict[k] = len(vocab_dict)
|
563 |
print(len(counter), len(vocab_dict))
|
564 |
|
565 |
-
path_save_vocab =
|
566 |
-
print(f
|
567 |
-
with open(path_save_vocab,
|
568 |
json.dump(vocab_dict, f)
|
569 |
else:
|
570 |
vocab_dict = load_vocab
|
571 |
-
path_save_vocab =
|
572 |
if not os.path.exists(path_save_vocab):
|
573 |
-
print(f
|
574 |
if isinstance(vocab_dict, dict):
|
575 |
-
with open(path_save_vocab,
|
576 |
json.dump(vocab_dict, f)
|
577 |
-
assert vocab_dict[
|
578 |
elif isinstance(vocab_dict, PreTrainedTokenizerFast):
|
579 |
vocab_dict.save_pretrained(data_args.checkpoint_path)
|
580 |
else:
|
581 |
assert False, "invalid type of vocab_dict"
|
582 |
|
583 |
-
|
584 |
-
|
585 |
-
if model is None and data_args.experiment == 'random':
|
586 |
model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel)
|
587 |
-
print(
|
588 |
torch.nn.init.normal_(model.weight)
|
589 |
-
path_save =
|
590 |
-
print(
|
|
|
|
|
591 |
torch.save(model.state_dict(), path_save)
|
592 |
|
593 |
# path_save = f'{data_args.checkpoint_path}/random_emb.torch'
|
594 |
# if not os.path.exists(path_save) and data_args.experiment == 'random':
|
595 |
# torch.save(model.state_dict(), path_save)
|
596 |
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
|
|
|
|
|
|
|
|
|
|
601 |
return train_dataset, model
|
602 |
-
elif data_args.experiment_mode ==
|
603 |
-
result_train_lst = helper_tokenize_encode(
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
|
|
|
|
|
|
|
|
608 |
|
609 |
def write_e2e_corr(prompt_lst, file_dict, corr_path):
|
610 |
print(len(prompt_lst))
|
611 |
-
with open(corr_path,
|
612 |
for x in prompt_lst:
|
613 |
for line in file_dict[x]:
|
614 |
print(" ".join(line), file=f)
|
615 |
-
print(
|
616 |
|
617 |
|
618 |
def write_e2e_src(prompt_lst, corr_path):
|
619 |
-
with open(corr_path,
|
620 |
for x in prompt_lst:
|
621 |
print(" ".join(x), file=f)
|
622 |
return
|
@@ -624,48 +710,55 @@ def write_e2e_src(prompt_lst, corr_path):
|
|
624 |
|
625 |
def read_e2e_files(path, args, tokenizer):
|
626 |
file_dict = {}
|
627 |
-
with open(path,
|
628 |
for line in f:
|
629 |
-
src_lst, word_lst = line.strip().split(
|
630 |
tgt = tuple([x.text for x in tokenizer(word_lst)])
|
631 |
src = tuple([x.text for x in tokenizer(src_lst)])
|
632 |
if src not in file_dict:
|
633 |
file_dict[src] = []
|
634 |
file_dict[src].append(tgt)
|
635 |
-
temp =
|
636 |
prompt_text_dict = file_dict
|
637 |
prompt_text_lst = list(prompt_text_dict.keys())
|
638 |
-
gold_dir = os.path.join(args.out_dir,
|
639 |
print("gold dir", gold_dir)
|
640 |
write_e2e_corr(prompt_text_lst, prompt_text_dict, gold_dir)
|
641 |
-
src_dir = os.path.join(args.out_dir,
|
642 |
write_e2e_src(prompt_text_lst, src_dir)
|
643 |
final_lst = [(xx, prompt_text_dict[xx][0]) for xx in prompt_text_lst]
|
644 |
return final_lst
|
645 |
|
646 |
|
647 |
-
def get_corpus_book(
|
648 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
649 |
import os
|
650 |
-
|
651 |
-
|
|
|
652 |
if "validation" not in raw_datasets.keys():
|
653 |
raw_datasets["validation"] = load_dataset(
|
654 |
-
|
655 |
split=f"train[:1%]",
|
656 |
)
|
657 |
raw_datasets["train"] = load_dataset(
|
658 |
-
|
659 |
split=f"train[1%:]",
|
660 |
)
|
661 |
print(raw_datasets)
|
662 |
column_names = raw_datasets["train"].column_names
|
663 |
|
664 |
def tokenize_function(examples):
|
665 |
-
output = tokenizer(examples[
|
666 |
return output
|
667 |
|
668 |
-
|
669 |
tokenized_datasets = raw_datasets.map(
|
670 |
tokenize_function,
|
671 |
batched=True,
|
@@ -686,7 +779,7 @@ def get_corpus_book(data_args, tokenizer, model, image_size, padding_mode='block
|
|
686 |
if total_length >= block_size:
|
687 |
total_length = (total_length // block_size) * block_size
|
688 |
result = {
|
689 |
-
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
690 |
for k, t in concatenated_examples.items()
|
691 |
}
|
692 |
return result
|
@@ -702,32 +795,44 @@ def get_corpus_book(data_args, tokenizer, model, image_size, padding_mode='block
|
|
702 |
print(lm_datasets)
|
703 |
|
704 |
if model is None:
|
705 |
-
if data_args.training_mode.startswith(
|
706 |
-
print(
|
707 |
model = torch.nn.Embedding(len(tokenizer), 1)
|
708 |
else:
|
709 |
model = torch.nn.Embedding(len(tokenizer), data_args.in_channel)
|
710 |
-
print(
|
711 |
torch.nn.init.normal_(model.weight)
|
712 |
-
path_save = f
|
713 |
-
print(
|
|
|
|
|
714 |
torch.save(model.state_dict(), path_save)
|
715 |
|
716 |
-
if split ==
|
717 |
return lm_datasets, model
|
718 |
else:
|
719 |
-
lm_datasets[
|
720 |
return lm_datasets, model
|
721 |
|
722 |
|
723 |
class TextDataset(Dataset):
|
724 |
-
def __init__(
|
725 |
-
|
726 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
727 |
super().__init__()
|
728 |
self.resolution = resolution
|
729 |
self.text_datasets = text_datasets
|
730 |
-
self.length = len(self.text_datasets[
|
731 |
self.model_arch = model_arch
|
732 |
self.data_args = data_args
|
733 |
print(self.resolution)
|
@@ -745,8 +850,8 @@ class TextDataset(Dataset):
|
|
745 |
# We are not on a new enough PIL to support the `reducing_gap`
|
746 |
# argument, which uses BOX downsampling at powers of two first.
|
747 |
# Thus, we do it by hand to improve downsample quality.
|
748 |
-
if self.model_arch ==
|
749 |
-
pass# arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
|
750 |
# dtype=np.float32).reshape(self.resolution, self.resolution, -1)
|
751 |
# # print(self.eigen_transform.shape)
|
752 |
# if self.eigen_transform is not None:
|
@@ -757,15 +862,14 @@ class TextDataset(Dataset):
|
|
757 |
# if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
|
758 |
# arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
|
759 |
|
760 |
-
|
761 |
# out_dict = {}
|
762 |
# out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
|
763 |
# # if self.local_classes is not None:
|
764 |
# # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
|
765 |
# # print(out_dict.keys())
|
766 |
# return np.transpose(arr, [2, 0, 1]), out_dict
|
767 |
-
elif self.model_arch ==
|
768 |
-
pass# arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
|
769 |
# dtype=np.float32) # seqlen, dim
|
770 |
# if self.eigen_transform is not None:
|
771 |
# old_shape = arr.shape
|
@@ -783,27 +887,39 @@ class TextDataset(Dataset):
|
|
783 |
# # print(arr.shape)
|
784 |
# return arr, out_dict
|
785 |
else:
|
786 |
-
arr = np.array(
|
787 |
-
|
788 |
-
|
|
|
789 |
old_shape = arr.shape
|
790 |
# arr = arr.reshape(1, -1) @ self.eigen_transform
|
791 |
-
arr = arr.reshape(1, -1) - self.eigen_transform[
|
792 |
-
arr = arr @ self.eigen_transform[
|
793 |
arr = arr.reshape(old_shape)
|
794 |
-
|
795 |
-
if
|
|
|
|
|
|
|
796 |
# print(arr.dtype)
|
797 |
# print(self.data_args.noise_level, 'using the noise level.')
|
798 |
-
arr = arr + self.data_args.noise_level * np.random.randn(
|
|
|
|
|
799 |
# print(arr.dtype)
|
800 |
|
801 |
out_dict = {}
|
802 |
-
out_dict[
|
|
|
|
|
803 |
# out_dict['mapping_func'] = self.mapping_func
|
804 |
-
if self.data_args.experiment_mode ==
|
805 |
-
out_dict[
|
806 |
-
|
|
|
|
|
|
|
|
|
807 |
# if self.local_classes is not None:
|
808 |
# out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
|
809 |
return arr, out_dict
|
@@ -813,13 +929,23 @@ class TextDataset(Dataset):
|
|
813 |
|
814 |
|
815 |
class TextDataset_NoCache(Dataset):
|
816 |
-
def __init__(
|
817 |
-
|
818 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
819 |
super().__init__()
|
820 |
self.resolution = resolution
|
821 |
self.text_datasets = text_datasets
|
822 |
-
self.length = len(self.text_datasets[
|
823 |
self.model_arch = model_arch
|
824 |
self.data_args = data_args
|
825 |
print(self.resolution)
|
@@ -838,81 +964,110 @@ class TextDataset_NoCache(Dataset):
|
|
838 |
# argument, which uses BOX downsampling at powers of two first.
|
839 |
# Thus, we do it by hand to improve downsample quality.
|
840 |
with torch.no_grad():
|
841 |
-
input_ids = self.text_datasets[
|
842 |
model = self.model_emb
|
843 |
-
if self.data_args.experiment.startswith(
|
844 |
hidden_state = model(torch.tensor(input_ids))
|
845 |
-
elif self.data_args.experiment ==
|
846 |
input_ids2 = torch.tensor(input_ids).to(model.device)
|
847 |
input_embs = model.transformer.wte(input_ids2) # input_embs
|
848 |
hidden_state = model.down_proj(input_embs)
|
849 |
hidden_state = hidden_state * data_args.emb_scale_factor
|
850 |
|
851 |
-
if self.model_arch ==
|
852 |
-
arr = np.array(hidden_state,
|
853 |
-
|
|
|
854 |
# print(self.eigen_transform.shape)
|
855 |
if self.eigen_transform is not None:
|
856 |
old_shape = arr.shape
|
857 |
-
arr = arr.reshape(1, -1) - self.eigen_transform[
|
858 |
-
arr = arr @ self.eigen_transform[
|
859 |
arr = arr.reshape(old_shape)
|
860 |
-
if
|
861 |
-
|
|
|
|
|
|
|
|
|
|
|
862 |
|
863 |
out_dict = {}
|
864 |
-
out_dict[
|
|
|
|
|
865 |
# if self.local_classes is not None:
|
866 |
# out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
|
867 |
# print(out_dict.keys())
|
868 |
return np.transpose(arr, [2, 0, 1]), out_dict
|
869 |
-
elif self.model_arch ==
|
870 |
-
arr = np.array(hidden_state,
|
871 |
-
dtype=np.float32) # seqlen, dim
|
872 |
if self.eigen_transform is not None:
|
873 |
old_shape = arr.shape
|
874 |
-
arr = arr.reshape(1, -1) - self.eigen_transform[
|
875 |
-
arr = arr @ self.eigen_transform[
|
876 |
arr = arr.reshape(old_shape)
|
877 |
-
if
|
878 |
-
|
|
|
|
|
|
|
|
|
|
|
879 |
arr = np.transpose(arr, [1, 0])
|
880 |
out_dict = {}
|
881 |
-
out_dict[
|
|
|
|
|
882 |
# out_dict['mapping_func'] = self.mapping_func
|
883 |
# if self.local_classes is not None:
|
884 |
# out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
|
885 |
# print(arr.shape)
|
886 |
return arr, out_dict
|
887 |
else:
|
888 |
-
arr = np.array(hidden_state,
|
889 |
-
dtype=np.float32)
|
890 |
if self.eigen_transform is not None:
|
891 |
old_shape = arr.shape
|
892 |
# arr = arr.reshape(1, -1) @ self.eigen_transform
|
893 |
-
arr = arr.reshape(1, -1) - self.eigen_transform[
|
894 |
-
arr = arr @ self.eigen_transform[
|
895 |
arr = arr.reshape(old_shape)
|
896 |
|
897 |
-
if
|
|
|
|
|
|
|
898 |
# print(arr.dtype)
|
899 |
# print(self.data_args.noise_level, 'using the noise level.')
|
900 |
-
arr = arr + self.data_args.noise_level * np.random.randn(
|
|
|
|
|
901 |
# print(arr.dtype)
|
902 |
|
903 |
out_dict = {}
|
904 |
-
out_dict[
|
|
|
|
|
905 |
# out_dict['mapping_func'] = self.mapping_func
|
906 |
-
if self.data_args.experiment_mode ==
|
907 |
-
out_dict[
|
908 |
-
|
|
|
|
|
|
|
|
|
909 |
# if self.local_classes is not None:
|
910 |
# out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
|
911 |
return arr, out_dict
|
912 |
|
|
|
913 |
def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False):
|
914 |
-
result = torch.full(
|
915 |
-
|
|
|
|
|
|
|
|
|
916 |
for i, example in enumerate(examples):
|
917 |
curr_len = min(len(example), max_length)
|
918 |
result[i][:curr_len] = example[:curr_len]
|
@@ -921,6 +1076,7 @@ def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False)
|
|
921 |
return result, mask_
|
922 |
return result
|
923 |
|
|
|
924 |
def _torch_collate_batch(examples, pad_token_id, max_length):
|
925 |
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
|
926 |
import numpy as np
|
@@ -945,4 +1101,4 @@ def _torch_collate_batch(examples, pad_token_id, max_length):
|
|
945 |
result[i, : example.shape[0]] = example
|
946 |
else:
|
947 |
result[i, -example.shape[0] :] = example
|
948 |
-
return result
|
|
|
1 |
# from PIL import Image
|
2 |
# import blobfile as bf
|
3 |
+
# from mpi4py import MPI
|
4 |
import numpy as np
|
5 |
from torch.utils.data import DataLoader, Dataset
|
6 |
+
from transformers import (
|
7 |
+
AutoModelForCausalLM,
|
8 |
+
AutoConfig,
|
9 |
+
AutoTokenizer,
|
10 |
+
default_data_collator,
|
11 |
+
PreTrainedTokenizerFast,
|
12 |
+
PreTrainedTokenizer,
|
13 |
+
)
|
14 |
+
|
15 |
# from datasets import load_dataset
|
16 |
import sys, os
|
17 |
import torch
|
18 |
+
|
19 |
# sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
|
20 |
# from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise
|
21 |
from collections import Counter, defaultdict
|
|
|
24 |
|
25 |
|
26 |
def load_data_text(
|
27 |
+
*,
|
28 |
+
data_dir,
|
29 |
+
batch_size,
|
30 |
+
image_size,
|
31 |
+
class_cond=False,
|
32 |
+
deterministic=False,
|
33 |
+
data_args=None,
|
34 |
+
task_mode="roc",
|
35 |
+
model=None,
|
36 |
+
padding_mode="block",
|
37 |
+
split="train",
|
38 |
+
load_vocab=None,
|
39 |
):
|
40 |
"""
|
41 |
For a dataset, create a generator over (images, kwargs) pairs.
|
|
|
53 |
exception will be raised.
|
54 |
:param deterministic: if True, yield results in a deterministic order.
|
55 |
"""
|
56 |
+
print("hello loading text data. ")
|
57 |
|
58 |
+
if data_args.experiment.startswith("random") and model is None:
|
59 |
model = None
|
60 |
# elif data_args.experiment.startswith('random') and model is not None:
|
61 |
# print('loading initialized random embeddings. ')
|
62 |
|
63 |
+
if task_mode == "roc" or task_mode == "roc-aug":
|
64 |
pass
|
65 |
# training_data, model = get_corpus_rocstory(data_args, model, image_size,
|
66 |
# padding_mode=padding_mode, split=split,
|
67 |
+
# load_vocab=load_vocab)
|
68 |
+
elif task_mode == "simple-wiki":
|
69 |
pass
|
70 |
# training_data, model = get_corpus_rocstory(data_args, model, image_size,
|
71 |
+
# padding_mode=padding_mode, split=split,
|
72 |
+
# load_vocab=load_vocab)
|
73 |
+
|
74 |
+
elif task_mode == "e2e-tgt":
|
75 |
+
print("hello loading e2e-tgt. ")
|
76 |
+
training_data, model = get_corpus_rocstory(
|
77 |
+
data_args,
|
78 |
+
model,
|
79 |
+
image_size,
|
80 |
+
padding_mode=padding_mode,
|
81 |
+
split=split,
|
82 |
+
load_vocab=load_vocab,
|
83 |
+
)
|
84 |
# elif task_mode == 'yelp':
|
85 |
# print('hello loading yelp ')
|
86 |
# training_data, model = get_corpus_rocstory(data_args, model, image_size,
|
|
|
103 |
# training_data, model = get_corpus_book(data_args, tokenizer, model, image_size,
|
104 |
# padding_mode=padding_mode, split=split,)
|
105 |
|
106 |
+
if (
|
107 |
+
data_args.modality
|
108 |
+
in ["roc-aug", "roc", "book", "yelp", "commonGen", "commonGen-aug"]
|
109 |
+
and data_args.cache_mode == "no"
|
110 |
+
):
|
111 |
+
pass # dataset = TextDataset_NoCache(
|
112 |
# training_data,
|
113 |
# image_size,
|
114 |
# data_args,
|
|
|
125 |
|
126 |
if deterministic:
|
127 |
|
128 |
+
pass # data_loader = DataLoader(
|
129 |
# dataset,
|
130 |
# batch_size=batch_size, # 20,
|
131 |
# drop_last=True,
|
|
|
144 |
while True:
|
145 |
yield from data_loader
|
146 |
|
147 |
+
|
148 |
def helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, seqlen, data_args):
|
149 |
result_train_lst = []
|
150 |
group_lst = defaultdict(list)
|
151 |
with torch.no_grad():
|
152 |
+
for src_ids, input_ids in sentence_lst:
|
153 |
+
tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in input_ids]
|
154 |
+
tokenized_src = [vocab_dict.get(x, vocab_dict["UNK"]) for x in src_ids]
|
155 |
input_ids = [0] + tokenized_ + [1]
|
156 |
+
group_lst["word_ids"].append(input_ids)
|
157 |
+
group_lst["src_ids"].append(tokenized_src)
|
158 |
|
159 |
+
print(group_lst["word_ids"][:2])
|
160 |
+
print("padding mode is pad")
|
161 |
max_length = seqlen
|
162 |
+
group_lst["word_ids"] = _collate_batch_helper(
|
163 |
+
group_lst["word_ids"], vocab_dict["PAD"], max_length
|
164 |
+
)
|
165 |
+
max_src_length = max([len(xx) for xx in group_lst["src_ids"]])
|
166 |
print(max_src_length, seqlen)
|
167 |
max_src_length = min(seqlen, max_src_length)
|
168 |
+
group_lst["src_ids"], group_lst["src_mask"] = _collate_batch_helper(
|
169 |
+
group_lst["src_ids"], vocab_dict["PAD"], max_src_length, return_mask=True
|
170 |
+
)
|
|
|
|
|
171 |
|
172 |
+
for input_ids, src_ids, src_mask in zip(
|
173 |
+
group_lst["word_ids"], group_lst["src_ids"], group_lst["src_mask"]
|
174 |
+
):
|
175 |
+
if data_args.experiment.startswith("random"):
|
176 |
hidden_state = model(torch.tensor(input_ids))
|
177 |
+
elif data_args.experiment == "gpt2_pre_compress":
|
178 |
input_ids2 = torch.tensor(input_ids).to(model.device)
|
179 |
input_embs = model.transformer.wte(input_ids2) # input_embs
|
180 |
hidden_state = model.down_proj(input_embs)
|
181 |
hidden_state = hidden_state * data_args.emb_scale_factor
|
182 |
+
result_train_lst.append(
|
183 |
+
{
|
184 |
+
"input_ids": input_ids,
|
185 |
+
"hidden_states": hidden_state.cpu().tolist(),
|
186 |
+
"src_ids": src_ids,
|
187 |
+
"src_mask": src_mask,
|
188 |
+
}
|
189 |
+
)
|
190 |
|
191 |
return result_train_lst
|
192 |
|
193 |
+
|
194 |
+
def helper_tokenize_stream(
|
195 |
+
sentence_lst,
|
196 |
+
vocab_dict,
|
197 |
+
model,
|
198 |
+
seqlen,
|
199 |
+
data_args,
|
200 |
+
padding_mode,
|
201 |
+
):
|
202 |
import psutil
|
203 |
+
|
204 |
# Process.memory_info is expressed in bytes, so convert to megabytes
|
205 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
206 |
from datasets import Dataset as Dataset2
|
207 |
+
|
208 |
+
raw_datasets = Dataset2.from_dict({"text": sentence_lst})
|
209 |
print(raw_datasets)
|
210 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
211 |
|
|
|
212 |
def tokenize_function(examples):
|
213 |
if isinstance(vocab_dict, dict):
|
214 |
+
input_ids = [
|
215 |
+
[0] + [vocab_dict.get(x, vocab_dict["UNK"]) for x in seq] + [1]
|
216 |
+
for seq in examples["text"]
|
217 |
+
]
|
218 |
elif isinstance(vocab_dict, PreTrainedTokenizerFast):
|
219 |
+
examples["text"] = [" ".join(seq) for seq in examples["text"]]
|
220 |
+
input_ids = vocab_dict(examples["text"], add_special_tokens=True)[
|
221 |
+
"input_ids"
|
222 |
+
]
|
223 |
+
result_dict = {"input_ids": input_ids}
|
224 |
# clm input could be much much longer than block_size
|
225 |
return result_dict
|
226 |
|
|
|
228 |
tokenize_function,
|
229 |
batched=True,
|
230 |
num_proc=4,
|
231 |
+
remove_columns=["text"],
|
232 |
load_from_cache_file=True,
|
233 |
desc="Running tokenizer on dataset",
|
234 |
)
|
235 |
print(tokenized_datasets)
|
236 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
237 |
|
238 |
+
if padding_mode == "block":
|
239 |
block_size = seqlen
|
240 |
+
|
241 |
def group_texts(examples):
|
242 |
+
concatenated_examples = {
|
243 |
+
k: list(chain(*examples[k])) for k in examples.keys()
|
244 |
+
}
|
245 |
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
246 |
if total_length >= block_size:
|
247 |
total_length = (total_length // block_size) * block_size
|
248 |
result = {
|
249 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
250 |
for k, t in concatenated_examples.items()
|
251 |
}
|
252 |
result["labels"] = result["input_ids"].copy()
|
253 |
return result
|
254 |
|
|
|
255 |
lm_datasets = tokenized_datasets.map(
|
256 |
group_texts,
|
257 |
batched=True,
|
|
|
260 |
desc=f"Grouping texts in chunks of {block_size}",
|
261 |
)
|
262 |
else:
|
263 |
+
|
264 |
def pad_function(group_lst):
|
265 |
max_length = seqlen
|
266 |
if isinstance(vocab_dict, dict):
|
267 |
+
group_lst["input_ids"] = _collate_batch_helper(
|
268 |
+
group_lst["input_ids"], vocab_dict["PAD"], max_length
|
269 |
+
)
|
270 |
else:
|
271 |
+
group_lst["input_ids"] = _collate_batch_helper(
|
272 |
+
group_lst["input_ids"], vocab_dict.pad_token_id, max_length
|
273 |
+
)
|
274 |
return group_lst
|
275 |
|
276 |
# Process.memory_info is expressed in bytes, so convert to megabytes
|
|
|
283 |
desc=f"padding",
|
284 |
)
|
285 |
|
286 |
+
print(lm_datasets, "padded dataset")
|
|
|
287 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
288 |
import datasets
|
289 |
+
|
290 |
raw_datasets = datasets.DatasetDict()
|
291 |
+
raw_datasets["train"] = lm_datasets
|
292 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
293 |
return raw_datasets
|
294 |
|
295 |
+
|
296 |
+
def helper_tokenize_encode(
|
297 |
+
sentence_lst,
|
298 |
+
vocab_dict,
|
299 |
+
model,
|
300 |
+
seqlen,
|
301 |
+
data_args,
|
302 |
+
padding_mode,
|
303 |
+
):
|
304 |
result_train_lst = []
|
305 |
group_lst = defaultdict(list)
|
306 |
with torch.no_grad():
|
307 |
for input_ids in sentence_lst:
|
308 |
+
tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in input_ids]
|
309 |
input_ids = [0] + tokenized_ + [1]
|
310 |
+
group_lst["word_ids"].append(input_ids)
|
311 |
+
print(group_lst["word_ids"][:2])
|
312 |
|
313 |
+
if padding_mode == "block":
|
314 |
+
print("padding mode is block")
|
315 |
concatenated_examples = {k: sum(group_lst[k], []) for k in group_lst.keys()}
|
316 |
total_length = len(concatenated_examples[list(group_lst.keys())[0]])
|
317 |
block_size = seqlen
|
318 |
total_length = (total_length // block_size) * block_size
|
319 |
# Split by chunks of max_len.
|
320 |
group_lst = {
|
321 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
322 |
for k, t in concatenated_examples.items()
|
323 |
}
|
324 |
+
elif padding_mode == "pad":
|
325 |
+
print("padding mode is pad")
|
326 |
max_length = seqlen
|
327 |
+
group_lst["word_ids"] = _collate_batch_helper(
|
328 |
+
group_lst["word_ids"], vocab_dict["PAD"], max_length
|
329 |
+
)
|
330 |
|
331 |
+
for input_ids in group_lst["word_ids"]:
|
332 |
+
if data_args.experiment.startswith("random"):
|
333 |
hidden_state = model(torch.tensor(input_ids))
|
334 |
+
elif data_args.experiment == "gpt2_pre_compress":
|
335 |
input_ids2 = torch.tensor(input_ids).to(model.device)
|
336 |
input_embs = model.transformer.wte(input_ids2) # input_embs
|
337 |
hidden_state = model.down_proj(input_embs)
|
338 |
hidden_state = hidden_state * data_args.emb_scale_factor
|
339 |
+
elif data_args.experiment == "glove":
|
340 |
hidden_state = model(torch.tensor(input_ids))
|
341 |
+
result_train_lst.append(
|
342 |
+
{"input_ids": input_ids, "hidden_states": hidden_state.cpu().tolist()}
|
343 |
+
)
|
344 |
|
345 |
return result_train_lst
|
346 |
|
347 |
+
|
348 |
def load_glove_model(File):
|
349 |
print("Loading Glove Model")
|
350 |
glove_model = {}
|
351 |
+
with open(File, "r") as f:
|
352 |
for line in f:
|
353 |
split_line = line.split()
|
354 |
word = split_line[0]
|
|
|
358 |
print(f"{len(glove_model)} words loaded!")
|
359 |
return glove_model
|
360 |
|
361 |
+
|
362 |
def load_glove(vocab):
|
363 |
model = torch.nn.Embedding(len(vocab), 50)
|
364 |
+
glove_model = load_glove_model("predictability/glove/glove.6B.50d.txt")
|
365 |
array_lst = []
|
366 |
count_ = 0
|
367 |
for word, idx in vocab.items():
|
|
|
370 |
else:
|
371 |
count_ += 1
|
372 |
array_lst.append(torch.randn(50))
|
373 |
+
print(f"{count_} out of {len(vocab)} is initialized. ")
|
374 |
array_lst = torch.stack(array_lst)
|
375 |
print(torch.norm(array_lst, dim=-1).mean())
|
376 |
model.weight.data = array_lst
|
377 |
return model
|
378 |
|
379 |
|
380 |
+
def get_corpus_rocstory(
|
381 |
+
data_args, model, image_size, padding_mode="block", split="train", load_vocab=None
|
382 |
+
):
|
383 |
import csv, torch, json
|
384 |
from spacy.lang.en import English
|
385 |
|
386 |
+
if data_args.experiment_mode == "lm":
|
387 |
+
if data_args.modality == "roc":
|
388 |
pass
|
389 |
# print('loading dataset from ROCStory')
|
390 |
# nlp = English()
|
|
|
415 |
# # sentence_lst.append(word_lst)
|
416 |
# # sentence_lst = sentence_lst[1:]
|
417 |
# print(sentence_lst[:2])
|
418 |
+
if data_args.modality == "roc-aug":
|
419 |
pass
|
420 |
# print('loading dataset from ROCStory')
|
421 |
# nlp = English()
|
|
|
449 |
# word_lst = [x.text for x in tokenizer(sentences)]
|
450 |
# sentence_lst.append(word_lst)
|
451 |
# print(sentence_lst[:2],sentence_lst[-2:], 'dataset size=',len(sentence_lst))
|
452 |
+
elif data_args.modality == "simple-wiki":
|
453 |
pass
|
454 |
# print('loading dataset from simple wikipedia')
|
455 |
# sentence_lst = []
|
|
|
458 |
# word_lst = row.lower().split()
|
459 |
# sentence_lst.append(word_lst)
|
460 |
# print(sentence_lst[:2])
|
461 |
+
elif data_args.modality == "e2e-tgt":
|
462 |
+
print("loading dataset from simple e2e dataset")
|
463 |
sentence_lst = []
|
464 |
nlp = English()
|
465 |
tokenizer = nlp.tokenizer
|
466 |
+
if split == "train":
|
467 |
+
print("loading form the TRAIN set")
|
468 |
+
path = (
|
469 |
+
"/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_train.txt"
|
470 |
+
)
|
471 |
# path = f'../{data_args.e2e_train}/src1_train.txt'
|
472 |
+
elif split == "valid":
|
473 |
+
print("loading form the VALID set")
|
474 |
+
path = f"../{data_args.e2e_train}/src1_valid.txt"
|
475 |
+
path = (
|
476 |
+
"/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_valid.txt"
|
477 |
+
)
|
478 |
+
elif split == "test":
|
479 |
+
print("loading form the TEST set")
|
480 |
+
path = f"../{data_args.e2e_train}/src1_test.txt"
|
481 |
+
path = "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_test.txt"
|
482 |
+
elif split == "debug":
|
483 |
+
print("loading form the DEBUG set")
|
484 |
path = data_args.debug_path
|
485 |
import json
|
486 |
+
|
487 |
+
with open(path, "r") as ff:
|
488 |
for line in ff:
|
489 |
+
sentence_lst.append(json.loads(line)[0].split(" "))
|
490 |
sentence_lst = sentence_lst + sentence_lst
|
491 |
+
if split in ["train", "valid", "test"]:
|
492 |
+
with open(path, "r") as ff:
|
493 |
for row in ff:
|
494 |
+
word_lst = row.split("||")[1]
|
495 |
word_lst = [x.text for x in tokenizer(word_lst)]
|
496 |
sentence_lst.append(word_lst)
|
497 |
print(sentence_lst[:2])
|
498 |
|
499 |
+
elif data_args.modality == "yelp":
|
500 |
+
print("loading dataset from simple YelpNLG dataset")
|
501 |
sentence_lst = []
|
502 |
nlp = English()
|
503 |
tokenizer = nlp.tokenizer
|
504 |
+
if split == "train":
|
505 |
+
print("loading form the TRAIN set")
|
506 |
+
path = f"{data_args.yelp_train}/yelpnlg-train.csv"
|
507 |
+
elif split == "valid":
|
508 |
+
print("loading form the VALID set")
|
509 |
+
path = f"{data_args.yelp_train}/yelpnlg-dev.csv"
|
510 |
+
elif split == "test":
|
511 |
+
print("loading form the TEST set")
|
512 |
+
path = f"{data_args.yelp_train}/yelpnlg-test.csv"
|
513 |
+
if split in ["train", "valid", "test"]:
|
514 |
+
|
515 |
+
with open(path, "r") as csvfile:
|
516 |
+
yelp_reader = csv.reader(csvfile) # delimiter=' ', quotechar='|')
|
517 |
for row in yelp_reader:
|
518 |
sentences = row[1]
|
519 |
word_lst = [x.text for x in tokenizer(sentences)]
|
|
|
521 |
sentence_lst = sentence_lst[1:]
|
522 |
print(sentence_lst[:2])
|
523 |
|
524 |
+
elif data_args.modality == "commonGen":
|
525 |
+
print("loading dataset from simple YelpNLG dataset")
|
526 |
sentence_lst = []
|
527 |
nlp = English()
|
528 |
tokenizer = nlp.tokenizer
|
529 |
+
if split == "train":
|
530 |
+
print("loading form the TRAIN set")
|
531 |
+
path = f"{data_args.commonGen_train}/commongen.train.jsonl"
|
532 |
+
elif split == "valid":
|
533 |
+
print("loading form the VALID set")
|
534 |
+
path = f"{data_args.commonGen_train}/commongen.dev.jsonl"
|
535 |
+
elif split == "test":
|
536 |
+
print("loading form the TEST set")
|
537 |
+
path = f"{data_args.commonGen_train}/commongen.test.jsonl"
|
538 |
+
if split in ["train", "valid", "test"]:
|
539 |
+
with open(path, "r") as ff:
|
540 |
for line in ff:
|
541 |
line = json.loads(line)
|
542 |
+
for sentences in line["scene"]:
|
543 |
word_lst = [x.text for x in tokenizer(sentences)]
|
544 |
sentence_lst.append(word_lst)
|
545 |
print(sentence_lst[:2])
|
546 |
|
547 |
+
elif data_args.modality == "commonGen-aug":
|
548 |
+
print("loading dataset from simple YelpNLG dataset")
|
549 |
sentence_lst = []
|
550 |
nlp = English()
|
551 |
tokenizer = nlp.tokenizer
|
552 |
+
if split == "train":
|
553 |
+
print("loading form the TRAIN set")
|
554 |
+
path = f"{data_args.commonGen_train}/commongen.train.jsonl"
|
555 |
+
path_lst = [f"{data_args.roc_train}/roc_train.json"]
|
556 |
+
path_lst.append(
|
557 |
+
"diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt"
|
558 |
+
)
|
559 |
+
elif split == "valid":
|
560 |
+
print("loading form the VALID set")
|
561 |
+
path = f"{data_args.commonGen_train}/commongen.dev.jsonl"
|
562 |
path_lst = []
|
563 |
+
elif split == "test":
|
564 |
+
print("loading form the TEST set")
|
565 |
+
path = f"{data_args.commonGen_train}/commongen.test.jsonl"
|
566 |
path_lst = []
|
567 |
|
568 |
+
if split in ["train", "valid", "test"]:
|
569 |
+
with open(path, "r") as ff:
|
570 |
for line in ff:
|
571 |
line = json.loads(line)
|
572 |
+
for sentences in line["scene"]:
|
573 |
word_lst = [x.text for x in tokenizer(sentences)]
|
574 |
sentence_lst.append(word_lst)
|
575 |
print(sentence_lst[:2])
|
576 |
import itertools
|
577 |
+
|
578 |
for path in path_lst:
|
579 |
+
if path.endswith("txt"):
|
580 |
+
with open(path, "r") as roc_reader:
|
581 |
for row in roc_reader:
|
582 |
sentences = row.strip()
|
583 |
word_lst = [x.text for x in tokenizer(sentences)]
|
584 |
spl = [[]]
|
585 |
+
for x, y in itertools.groupby(word_lst, lambda z: z == "."):
|
586 |
spl[-1].extend(y)
|
587 |
+
if x:
|
588 |
+
spl.append([])
|
589 |
sentence_lst.extend(spl[:-1])
|
590 |
else:
|
591 |
+
with open(path, "r") as roc_reader:
|
592 |
for row in roc_reader:
|
593 |
sentences = json.loads(row)[0].strip()
|
594 |
word_lst = [x.text for x in tokenizer(sentences)]
|
595 |
spl = [[]]
|
596 |
+
for x, y in itertools.groupby(word_lst, lambda z: z == "."):
|
597 |
spl[-1].extend(y)
|
598 |
+
if x:
|
599 |
+
spl.append([])
|
600 |
sentence_lst.extend(spl[:-1])
|
601 |
|
602 |
print(sentence_lst[-2:])
|
603 |
|
|
|
604 |
# get tokenizer.
|
605 |
if load_vocab is None:
|
606 |
counter = Counter()
|
607 |
for input_ids in sentence_lst:
|
608 |
counter.update(input_ids)
|
609 |
|
610 |
+
if data_args.experiment_mode == "conditional_gen":
|
611 |
+
if data_args.modality == "e2e":
|
612 |
+
print("loading dataset from simple e2e dataset")
|
613 |
sentence_lst = []
|
614 |
nlp = English()
|
615 |
tokenizer = nlp.tokenizer
|
616 |
+
if split == "train":
|
617 |
+
path = f"{data_args.e2e_train}/src1_train.txt"
|
618 |
+
with open(path, "r") as ff:
|
619 |
for row in ff:
|
620 |
+
src_lst, word_lst = row.split("||")
|
621 |
word_lst = [x.text for x in tokenizer(word_lst)]
|
622 |
src_lst = [x.text for x in tokenizer(src_lst)]
|
623 |
sentence_lst.append((src_lst, word_lst))
|
624 |
+
elif split == "valid":
|
625 |
+
path = f"{data_args.e2e_train}/src1_valid.txt"
|
626 |
sentence_lst = read_e2e_files(path, data_args, tokenizer)
|
627 |
print(sentence_lst[:2])
|
628 |
# get tokenizer.
|
629 |
if load_vocab is None:
|
630 |
counter = Counter()
|
631 |
+
for src_ids, input_ids in sentence_lst:
|
632 |
counter.update(input_ids)
|
633 |
counter.update(src_ids)
|
634 |
|
635 |
if load_vocab is None:
|
636 |
+
vocab_dict = {"START": 0, "END": 1, "UNK": 2, "PAD": 3}
|
637 |
for k, v in counter.items():
|
638 |
if v > 10:
|
639 |
vocab_dict[k] = len(vocab_dict)
|
640 |
print(len(counter), len(vocab_dict))
|
641 |
|
642 |
+
path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json"
|
643 |
+
print(f"save the vocab to {path_save_vocab}")
|
644 |
+
with open(path_save_vocab, "w") as f:
|
645 |
json.dump(vocab_dict, f)
|
646 |
else:
|
647 |
vocab_dict = load_vocab
|
648 |
+
path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json"
|
649 |
if not os.path.exists(path_save_vocab):
|
650 |
+
print(f"save the vocab to {path_save_vocab}")
|
651 |
if isinstance(vocab_dict, dict):
|
652 |
+
with open(path_save_vocab, "w") as f:
|
653 |
json.dump(vocab_dict, f)
|
654 |
+
assert vocab_dict["START"] == 0
|
655 |
elif isinstance(vocab_dict, PreTrainedTokenizerFast):
|
656 |
vocab_dict.save_pretrained(data_args.checkpoint_path)
|
657 |
else:
|
658 |
assert False, "invalid type of vocab_dict"
|
659 |
|
660 |
+
if model is None and data_args.experiment == "random":
|
|
|
|
|
661 |
model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel)
|
662 |
+
print("initializing the random embeddings", model)
|
663 |
torch.nn.init.normal_(model.weight)
|
664 |
+
path_save = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/random_emb.torch"
|
665 |
+
print(
|
666 |
+
f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch"
|
667 |
+
)
|
668 |
torch.save(model.state_dict(), path_save)
|
669 |
|
670 |
# path_save = f'{data_args.checkpoint_path}/random_emb.torch'
|
671 |
# if not os.path.exists(path_save) and data_args.experiment == 'random':
|
672 |
# torch.save(model.state_dict(), path_save)
|
673 |
|
674 |
+
if (
|
675 |
+
data_args.experiment_mode == "lm"
|
676 |
+
and data_args.modality
|
677 |
+
in ["roc-aug", "roc", "yelp", "commonGen", "commonGen-aug"]
|
678 |
+
and data_args.cache_mode == "no"
|
679 |
+
):
|
680 |
+
train_dataset = helper_tokenize_stream(
|
681 |
+
sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode
|
682 |
+
)
|
683 |
return train_dataset, model
|
684 |
+
elif data_args.experiment_mode == "lm":
|
685 |
+
result_train_lst = helper_tokenize_encode(
|
686 |
+
sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode
|
687 |
+
)
|
688 |
+
elif data_args.experiment_mode == "conditional_gen":
|
689 |
+
result_train_lst = helper_tokenize_encode_cond(
|
690 |
+
sentence_lst, vocab_dict, model, image_size**2, data_args
|
691 |
+
)
|
692 |
+
return {"train": result_train_lst}, model
|
693 |
+
|
694 |
|
695 |
def write_e2e_corr(prompt_lst, file_dict, corr_path):
|
696 |
print(len(prompt_lst))
|
697 |
+
with open(corr_path, "w") as f:
|
698 |
for x in prompt_lst:
|
699 |
for line in file_dict[x]:
|
700 |
print(" ".join(line), file=f)
|
701 |
+
print("", file=f)
|
702 |
|
703 |
|
704 |
def write_e2e_src(prompt_lst, corr_path):
|
705 |
+
with open(corr_path, "w") as f:
|
706 |
for x in prompt_lst:
|
707 |
print(" ".join(x), file=f)
|
708 |
return
|
|
|
710 |
|
711 |
def read_e2e_files(path, args, tokenizer):
|
712 |
file_dict = {}
|
713 |
+
with open(path, "r") as f:
|
714 |
for line in f:
|
715 |
+
src_lst, word_lst = line.strip().split("||")
|
716 |
tgt = tuple([x.text for x in tokenizer(word_lst)])
|
717 |
src = tuple([x.text for x in tokenizer(src_lst)])
|
718 |
if src not in file_dict:
|
719 |
file_dict[src] = []
|
720 |
file_dict[src].append(tgt)
|
721 |
+
temp = "1"
|
722 |
prompt_text_dict = file_dict
|
723 |
prompt_text_lst = list(prompt_text_dict.keys())
|
724 |
+
gold_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "gold"))
|
725 |
print("gold dir", gold_dir)
|
726 |
write_e2e_corr(prompt_text_lst, prompt_text_dict, gold_dir)
|
727 |
+
src_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "src"))
|
728 |
write_e2e_src(prompt_text_lst, src_dir)
|
729 |
final_lst = [(xx, prompt_text_dict[xx][0]) for xx in prompt_text_lst]
|
730 |
return final_lst
|
731 |
|
732 |
|
733 |
+
def get_corpus_book(
|
734 |
+
data_args,
|
735 |
+
tokenizer,
|
736 |
+
model,
|
737 |
+
image_size,
|
738 |
+
padding_mode="block",
|
739 |
+
split="train",
|
740 |
+
):
|
741 |
+
max_length = image_size**2
|
742 |
import os
|
743 |
+
|
744 |
+
assert padding_mode == "block"
|
745 |
+
raw_datasets = load_dataset("bookcorpus")
|
746 |
if "validation" not in raw_datasets.keys():
|
747 |
raw_datasets["validation"] = load_dataset(
|
748 |
+
"bookcorpus",
|
749 |
split=f"train[:1%]",
|
750 |
)
|
751 |
raw_datasets["train"] = load_dataset(
|
752 |
+
"bookcorpus",
|
753 |
split=f"train[1%:]",
|
754 |
)
|
755 |
print(raw_datasets)
|
756 |
column_names = raw_datasets["train"].column_names
|
757 |
|
758 |
def tokenize_function(examples):
|
759 |
+
output = tokenizer(examples["text"], add_special_tokens=False)
|
760 |
return output
|
761 |
|
|
|
762 |
tokenized_datasets = raw_datasets.map(
|
763 |
tokenize_function,
|
764 |
batched=True,
|
|
|
779 |
if total_length >= block_size:
|
780 |
total_length = (total_length // block_size) * block_size
|
781 |
result = {
|
782 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
783 |
for k, t in concatenated_examples.items()
|
784 |
}
|
785 |
return result
|
|
|
795 |
print(lm_datasets)
|
796 |
|
797 |
if model is None:
|
798 |
+
if data_args.training_mode.startswith("e2e"):
|
799 |
+
print("since its e2e, initialize a dummy embedding")
|
800 |
model = torch.nn.Embedding(len(tokenizer), 1)
|
801 |
else:
|
802 |
model = torch.nn.Embedding(len(tokenizer), data_args.in_channel)
|
803 |
+
print("initializing the random embeddings", model)
|
804 |
torch.nn.init.normal_(model.weight)
|
805 |
+
path_save = f"{data_args.checkpoint_path}/random_emb.torch"
|
806 |
+
print(
|
807 |
+
f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch"
|
808 |
+
)
|
809 |
torch.save(model.state_dict(), path_save)
|
810 |
|
811 |
+
if split == "train":
|
812 |
return lm_datasets, model
|
813 |
else:
|
814 |
+
lm_datasets["train"] = lm_datasets["validation"]
|
815 |
return lm_datasets, model
|
816 |
|
817 |
|
818 |
class TextDataset(Dataset):
|
819 |
+
def __init__(
|
820 |
+
self,
|
821 |
+
text_datasets,
|
822 |
+
resolution,
|
823 |
+
data_args,
|
824 |
+
model_arch="conv-unet",
|
825 |
+
classes=None,
|
826 |
+
shard=0,
|
827 |
+
num_shards=1,
|
828 |
+
eigen_transform=None,
|
829 |
+
mapping_func=None,
|
830 |
+
model_emb=None,
|
831 |
+
):
|
832 |
super().__init__()
|
833 |
self.resolution = resolution
|
834 |
self.text_datasets = text_datasets
|
835 |
+
self.length = len(self.text_datasets["train"])
|
836 |
self.model_arch = model_arch
|
837 |
self.data_args = data_args
|
838 |
print(self.resolution)
|
|
|
850 |
# We are not on a new enough PIL to support the `reducing_gap`
|
851 |
# argument, which uses BOX downsampling at powers of two first.
|
852 |
# Thus, we do it by hand to improve downsample quality.
|
853 |
+
if self.model_arch == "conv-unet":
|
854 |
+
pass # arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
|
855 |
# dtype=np.float32).reshape(self.resolution, self.resolution, -1)
|
856 |
# # print(self.eigen_transform.shape)
|
857 |
# if self.eigen_transform is not None:
|
|
|
862 |
# if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
|
863 |
# arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
|
864 |
|
|
|
865 |
# out_dict = {}
|
866 |
# out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
|
867 |
# # if self.local_classes is not None:
|
868 |
# # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
|
869 |
# # print(out_dict.keys())
|
870 |
# return np.transpose(arr, [2, 0, 1]), out_dict
|
871 |
+
elif self.model_arch == "1d-unet":
|
872 |
+
pass # arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
|
873 |
# dtype=np.float32) # seqlen, dim
|
874 |
# if self.eigen_transform is not None:
|
875 |
# old_shape = arr.shape
|
|
|
887 |
# # print(arr.shape)
|
888 |
# return arr, out_dict
|
889 |
else:
|
890 |
+
arr = np.array(
|
891 |
+
self.text_datasets["train"][idx]["hidden_states"], dtype=np.float32
|
892 |
+
)
|
893 |
+
if self.eigen_transform is not None:
|
894 |
old_shape = arr.shape
|
895 |
# arr = arr.reshape(1, -1) @ self.eigen_transform
|
896 |
+
arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
|
897 |
+
arr = arr @ self.eigen_transform["map"]
|
898 |
arr = arr.reshape(old_shape)
|
899 |
+
|
900 |
+
if (
|
901 |
+
hasattr(self.data_args, "noise_level")
|
902 |
+
and self.data_args.noise_level > 0
|
903 |
+
):
|
904 |
# print(arr.dtype)
|
905 |
# print(self.data_args.noise_level, 'using the noise level.')
|
906 |
+
arr = arr + self.data_args.noise_level * np.random.randn(
|
907 |
+
*arr.shape
|
908 |
+
).astype(arr.dtype)
|
909 |
# print(arr.dtype)
|
910 |
|
911 |
out_dict = {}
|
912 |
+
out_dict["input_ids"] = np.array(
|
913 |
+
self.text_datasets["train"][idx]["input_ids"]
|
914 |
+
)
|
915 |
# out_dict['mapping_func'] = self.mapping_func
|
916 |
+
if self.data_args.experiment_mode == "conditional_gen":
|
917 |
+
out_dict["src_ids"] = np.array(
|
918 |
+
self.text_datasets["train"][idx]["src_ids"]
|
919 |
+
)
|
920 |
+
out_dict["src_mask"] = np.array(
|
921 |
+
self.text_datasets["train"][idx]["src_mask"]
|
922 |
+
)
|
923 |
# if self.local_classes is not None:
|
924 |
# out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
|
925 |
return arr, out_dict
|
|
|
929 |
|
930 |
|
931 |
class TextDataset_NoCache(Dataset):
|
932 |
+
def __init__(
|
933 |
+
self,
|
934 |
+
text_datasets,
|
935 |
+
resolution,
|
936 |
+
data_args,
|
937 |
+
model_arch="conv-unet",
|
938 |
+
classes=None,
|
939 |
+
shard=0,
|
940 |
+
num_shards=1,
|
941 |
+
eigen_transform=None,
|
942 |
+
mapping_func=None,
|
943 |
+
model_emb=None,
|
944 |
+
):
|
945 |
super().__init__()
|
946 |
self.resolution = resolution
|
947 |
self.text_datasets = text_datasets
|
948 |
+
self.length = len(self.text_datasets["train"])
|
949 |
self.model_arch = model_arch
|
950 |
self.data_args = data_args
|
951 |
print(self.resolution)
|
|
|
964 |
# argument, which uses BOX downsampling at powers of two first.
|
965 |
# Thus, we do it by hand to improve downsample quality.
|
966 |
with torch.no_grad():
|
967 |
+
input_ids = self.text_datasets["train"][idx]["input_ids"]
|
968 |
model = self.model_emb
|
969 |
+
if self.data_args.experiment.startswith("random"):
|
970 |
hidden_state = model(torch.tensor(input_ids))
|
971 |
+
elif self.data_args.experiment == "gpt2_pre_compress":
|
972 |
input_ids2 = torch.tensor(input_ids).to(model.device)
|
973 |
input_embs = model.transformer.wte(input_ids2) # input_embs
|
974 |
hidden_state = model.down_proj(input_embs)
|
975 |
hidden_state = hidden_state * data_args.emb_scale_factor
|
976 |
|
977 |
+
if self.model_arch == "conv-unet":
|
978 |
+
arr = np.array(hidden_state, dtype=np.float32).reshape(
|
979 |
+
self.resolution, self.resolution, -1
|
980 |
+
)
|
981 |
# print(self.eigen_transform.shape)
|
982 |
if self.eigen_transform is not None:
|
983 |
old_shape = arr.shape
|
984 |
+
arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
|
985 |
+
arr = arr @ self.eigen_transform["map"]
|
986 |
arr = arr.reshape(old_shape)
|
987 |
+
if (
|
988 |
+
hasattr(self.data_args, "noise_level")
|
989 |
+
and self.data_args.noise_level > 0
|
990 |
+
):
|
991 |
+
arr = arr + self.data_args.noise_level * np.random.randn(
|
992 |
+
*arr.shape
|
993 |
+
).astype(arr.dtype)
|
994 |
|
995 |
out_dict = {}
|
996 |
+
out_dict["input_ids"] = np.array(
|
997 |
+
self.text_datasets["train"][idx]["input_ids"]
|
998 |
+
)
|
999 |
# if self.local_classes is not None:
|
1000 |
# out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
|
1001 |
# print(out_dict.keys())
|
1002 |
return np.transpose(arr, [2, 0, 1]), out_dict
|
1003 |
+
elif self.model_arch == "1d-unet":
|
1004 |
+
arr = np.array(hidden_state, dtype=np.float32) # seqlen, dim
|
|
|
1005 |
if self.eigen_transform is not None:
|
1006 |
old_shape = arr.shape
|
1007 |
+
arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
|
1008 |
+
arr = arr @ self.eigen_transform["map"]
|
1009 |
arr = arr.reshape(old_shape)
|
1010 |
+
if (
|
1011 |
+
hasattr(self.data_args, "noise_level")
|
1012 |
+
and self.data_args.noise_level > 0
|
1013 |
+
):
|
1014 |
+
arr = arr + self.data_args.noise_level * np.random.randn(
|
1015 |
+
*arr.shape
|
1016 |
+
).astype(arr.dtype)
|
1017 |
arr = np.transpose(arr, [1, 0])
|
1018 |
out_dict = {}
|
1019 |
+
out_dict["input_ids"] = np.array(
|
1020 |
+
self.text_datasets["train"][idx]["input_ids"]
|
1021 |
+
)
|
1022 |
# out_dict['mapping_func'] = self.mapping_func
|
1023 |
# if self.local_classes is not None:
|
1024 |
# out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
|
1025 |
# print(arr.shape)
|
1026 |
return arr, out_dict
|
1027 |
else:
|
1028 |
+
arr = np.array(hidden_state, dtype=np.float32)
|
|
|
1029 |
if self.eigen_transform is not None:
|
1030 |
old_shape = arr.shape
|
1031 |
# arr = arr.reshape(1, -1) @ self.eigen_transform
|
1032 |
+
arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
|
1033 |
+
arr = arr @ self.eigen_transform["map"]
|
1034 |
arr = arr.reshape(old_shape)
|
1035 |
|
1036 |
+
if (
|
1037 |
+
hasattr(self.data_args, "noise_level")
|
1038 |
+
and self.data_args.noise_level > 0
|
1039 |
+
):
|
1040 |
# print(arr.dtype)
|
1041 |
# print(self.data_args.noise_level, 'using the noise level.')
|
1042 |
+
arr = arr + self.data_args.noise_level * np.random.randn(
|
1043 |
+
*arr.shape
|
1044 |
+
).astype(arr.dtype)
|
1045 |
# print(arr.dtype)
|
1046 |
|
1047 |
out_dict = {}
|
1048 |
+
out_dict["input_ids"] = np.array(
|
1049 |
+
self.text_datasets["train"][idx]["input_ids"]
|
1050 |
+
)
|
1051 |
# out_dict['mapping_func'] = self.mapping_func
|
1052 |
+
if self.data_args.experiment_mode == "conditional_gen":
|
1053 |
+
out_dict["src_ids"] = np.array(
|
1054 |
+
self.text_datasets["train"][idx]["src_ids"]
|
1055 |
+
)
|
1056 |
+
out_dict["src_mask"] = np.array(
|
1057 |
+
self.text_datasets["train"][idx]["src_mask"]
|
1058 |
+
)
|
1059 |
# if self.local_classes is not None:
|
1060 |
# out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
|
1061 |
return arr, out_dict
|
1062 |
|
1063 |
+
|
1064 |
def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False):
|
1065 |
+
result = torch.full(
|
1066 |
+
[len(examples), max_length], pad_token_id, dtype=torch.int64
|
1067 |
+
).tolist()
|
1068 |
+
mask_ = torch.full(
|
1069 |
+
[len(examples), max_length], pad_token_id, dtype=torch.int64
|
1070 |
+
).tolist()
|
1071 |
for i, example in enumerate(examples):
|
1072 |
curr_len = min(len(example), max_length)
|
1073 |
result[i][:curr_len] = example[:curr_len]
|
|
|
1076 |
return result, mask_
|
1077 |
return result
|
1078 |
|
1079 |
+
|
1080 |
def _torch_collate_batch(examples, pad_token_id, max_length):
|
1081 |
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
|
1082 |
import numpy as np
|
|
|
1101 |
result[i, : example.shape[0]] = example
|
1102 |
else:
|
1103 |
result[i, -example.shape[0] :] = example
|
1104 |
+
return result
|