ndhieunguyen commited on
Commit
77180e4
1 Parent(s): ff15dff

feat: remove mpi4py

Browse files
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
 
src/improved_diffusion/dist_util.py CHANGED
@@ -8,7 +8,6 @@ import socket
8
 
9
  import blobfile as bf
10
 
11
- from mpi4py import MPI
12
  import torch as th
13
  import torch.distributed as dist
14
 
@@ -46,26 +45,26 @@ def setup_dist(rank, world_size, port="12145"):
46
  dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
47
 
48
 
49
- def dev():
50
- """
51
- Get the device to use for torch.distributed.
52
- """
53
- if th.cuda.is_available():
54
- return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
55
- return th.device("cpu")
56
-
57
-
58
- def load_state_dict(path, **kwargs):
59
- """
60
- Load a PyTorch file without redundant fetches across MPI ranks.
61
- """
62
- if MPI.COMM_WORLD.Get_rank() == 0:
63
- with bf.BlobFile(path, "rb") as f:
64
- data = f.read()
65
- else:
66
- data = None
67
- data = MPI.COMM_WORLD.bcast(data)
68
- return th.load(io.BytesIO(data), **kwargs)
69
 
70
 
71
  def sync_params(params):
 
8
 
9
  import blobfile as bf
10
 
 
11
  import torch as th
12
  import torch.distributed as dist
13
 
 
45
  dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
46
 
47
 
48
+ # def dev():
49
+ # """
50
+ # Get the device to use for torch.distributed.
51
+ # """
52
+ # if th.cuda.is_available():
53
+ # return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
54
+ # return th.device("cpu")
55
+
56
+
57
+ # def load_state_dict(path, **kwargs):
58
+ # """
59
+ # Load a PyTorch file without redundant fetches across MPI ranks.
60
+ # """
61
+ # if MPI.COMM_WORLD.Get_rank() == 0:
62
+ # with bf.BlobFile(path, "rb") as f:
63
+ # data = f.read()
64
+ # else:
65
+ # data = None
66
+ # data = MPI.COMM_WORLD.bcast(data)
67
+ # return th.load(io.BytesIO(data), **kwargs)
68
 
69
 
70
  def sync_params(params):
src/improved_diffusion/text_datasets.py CHANGED
@@ -1,13 +1,21 @@
1
  # from PIL import Image
2
  # import blobfile as bf
3
- from mpi4py import MPI
4
  import numpy as np
5
  from torch.utils.data import DataLoader, Dataset
6
- from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator, PreTrainedTokenizerFast, \
7
- PreTrainedTokenizer
 
 
 
 
 
 
 
8
  # from datasets import load_dataset
9
  import sys, os
10
  import torch
 
11
  # sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
12
  # from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise
13
  from collections import Counter, defaultdict
@@ -16,8 +24,18 @@ from itertools import chain
16
 
17
 
18
  def load_data_text(
19
- *, data_dir, batch_size, image_size, class_cond=False, deterministic=False, data_args=None,
20
- task_mode='roc', model=None, padding_mode='block', split='train', load_vocab=None,
 
 
 
 
 
 
 
 
 
 
21
  ):
22
  """
23
  For a dataset, create a generator over (images, kwargs) pairs.
@@ -35,29 +53,34 @@ def load_data_text(
35
  exception will be raised.
36
  :param deterministic: if True, yield results in a deterministic order.
37
  """
38
- print('hello loading text data. ')
39
 
40
- if data_args.experiment.startswith('random') and model is None:
41
  model = None
42
  # elif data_args.experiment.startswith('random') and model is not None:
43
  # print('loading initialized random embeddings. ')
44
 
45
- if task_mode == 'roc' or task_mode == 'roc-aug' :
46
  pass
47
  # training_data, model = get_corpus_rocstory(data_args, model, image_size,
48
  # padding_mode=padding_mode, split=split,
49
- # load_vocab=load_vocab)
50
- elif task_mode == 'simple-wiki':
51
  pass
52
  # training_data, model = get_corpus_rocstory(data_args, model, image_size,
53
- # padding_mode=padding_mode, split=split,
54
- # load_vocab=load_vocab)
55
-
56
- elif task_mode == 'e2e-tgt':
57
- print('hello loading e2e-tgt. ')
58
- training_data, model = get_corpus_rocstory(data_args, model, image_size,
59
- padding_mode=padding_mode, split=split,
60
- load_vocab=load_vocab)
 
 
 
 
 
61
  # elif task_mode == 'yelp':
62
  # print('hello loading yelp ')
63
  # training_data, model = get_corpus_rocstory(data_args, model, image_size,
@@ -80,8 +103,12 @@ def load_data_text(
80
  # training_data, model = get_corpus_book(data_args, tokenizer, model, image_size,
81
  # padding_mode=padding_mode, split=split,)
82
 
83
- if data_args.modality in ['roc-aug', 'roc', 'book', 'yelp', 'commonGen', 'commonGen-aug'] and data_args.cache_mode=='no':
84
- pass# dataset = TextDataset_NoCache(
 
 
 
 
85
  # training_data,
86
  # image_size,
87
  # data_args,
@@ -98,7 +125,7 @@ def load_data_text(
98
 
99
  if deterministic:
100
 
101
- pass# data_loader = DataLoader(
102
  # dataset,
103
  # batch_size=batch_size, # 20,
104
  # drop_last=True,
@@ -117,64 +144,83 @@ def load_data_text(
117
  while True:
118
  yield from data_loader
119
 
 
120
  def helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, seqlen, data_args):
121
  result_train_lst = []
122
  group_lst = defaultdict(list)
123
  with torch.no_grad():
124
- for (src_ids, input_ids) in sentence_lst:
125
- tokenized_ = [vocab_dict.get(x, vocab_dict['UNK']) for x in input_ids]
126
- tokenized_src = [vocab_dict.get(x, vocab_dict['UNK']) for x in src_ids]
127
  input_ids = [0] + tokenized_ + [1]
128
- group_lst['word_ids'].append(input_ids)
129
- group_lst['src_ids'].append(tokenized_src)
130
 
131
- print(group_lst['word_ids'][:2])
132
- print('padding mode is pad')
133
  max_length = seqlen
134
- group_lst['word_ids'] = _collate_batch_helper(group_lst['word_ids'], vocab_dict['PAD'], max_length)
135
- max_src_length = max([len(xx) for xx in group_lst['src_ids']])
 
 
136
  print(max_src_length, seqlen)
137
  max_src_length = min(seqlen, max_src_length)
138
- group_lst['src_ids'], group_lst['src_mask'] = _collate_batch_helper(group_lst['src_ids'],
139
- vocab_dict['PAD'],
140
- max_src_length,
141
- return_mask=True)
142
-
143
 
144
- for input_ids, src_ids, src_mask in zip(group_lst['word_ids'], group_lst['src_ids'],
145
- group_lst['src_mask']):
146
- if data_args.experiment.startswith('random'):
 
147
  hidden_state = model(torch.tensor(input_ids))
148
- elif data_args.experiment == 'gpt2_pre_compress':
149
  input_ids2 = torch.tensor(input_ids).to(model.device)
150
  input_embs = model.transformer.wte(input_ids2) # input_embs
151
  hidden_state = model.down_proj(input_embs)
152
  hidden_state = hidden_state * data_args.emb_scale_factor
153
- result_train_lst.append({'input_ids': input_ids,
154
- 'hidden_states': hidden_state.cpu().tolist(),
155
- 'src_ids':src_ids,
156
- 'src_mask':src_mask
157
- })
 
 
 
158
 
159
  return result_train_lst
160
 
161
- def helper_tokenize_stream(sentence_lst, vocab_dict, model, seqlen, data_args, padding_mode, ):
 
 
 
 
 
 
 
 
162
  import psutil
 
163
  # Process.memory_info is expressed in bytes, so convert to megabytes
164
  print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
165
  from datasets import Dataset as Dataset2
166
- raw_datasets = Dataset2.from_dict({'text':sentence_lst})
 
167
  print(raw_datasets)
168
  print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
169
 
170
-
171
  def tokenize_function(examples):
172
  if isinstance(vocab_dict, dict):
173
- input_ids = [[0] + [vocab_dict.get(x, vocab_dict['UNK']) for x in seq] + [1] for seq in examples['text']]
 
 
 
174
  elif isinstance(vocab_dict, PreTrainedTokenizerFast):
175
- examples['text'] = [" ".join(seq) for seq in examples['text']]
176
- input_ids = vocab_dict(examples['text'], add_special_tokens=True)['input_ids']
177
- result_dict = {'input_ids': input_ids}
 
 
178
  # clm input could be much much longer than block_size
179
  return result_dict
180
 
@@ -182,28 +228,30 @@ def helper_tokenize_stream(sentence_lst, vocab_dict, model, seqlen, data_args, p
182
  tokenize_function,
183
  batched=True,
184
  num_proc=4,
185
- remove_columns=['text'],
186
  load_from_cache_file=True,
187
  desc="Running tokenizer on dataset",
188
  )
189
  print(tokenized_datasets)
190
  print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
191
 
192
- if padding_mode == 'block':
193
  block_size = seqlen
 
194
  def group_texts(examples):
195
- concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
 
 
196
  total_length = len(concatenated_examples[list(examples.keys())[0]])
197
  if total_length >= block_size:
198
  total_length = (total_length // block_size) * block_size
199
  result = {
200
- k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
201
  for k, t in concatenated_examples.items()
202
  }
203
  result["labels"] = result["input_ids"].copy()
204
  return result
205
 
206
-
207
  lm_datasets = tokenized_datasets.map(
208
  group_texts,
209
  batched=True,
@@ -212,12 +260,17 @@ def helper_tokenize_stream(sentence_lst, vocab_dict, model, seqlen, data_args, p
212
  desc=f"Grouping texts in chunks of {block_size}",
213
  )
214
  else:
 
215
  def pad_function(group_lst):
216
  max_length = seqlen
217
  if isinstance(vocab_dict, dict):
218
- group_lst['input_ids'] = _collate_batch_helper(group_lst['input_ids'], vocab_dict['PAD'], max_length)
 
 
219
  else:
220
- group_lst['input_ids'] = _collate_batch_helper(group_lst['input_ids'], vocab_dict.pad_token_id, max_length)
 
 
221
  return group_lst
222
 
223
  # Process.memory_info is expressed in bytes, so convert to megabytes
@@ -230,59 +283,72 @@ def helper_tokenize_stream(sentence_lst, vocab_dict, model, seqlen, data_args, p
230
  desc=f"padding",
231
  )
232
 
233
-
234
- print(lm_datasets, 'padded dataset')
235
  print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
236
  import datasets
 
237
  raw_datasets = datasets.DatasetDict()
238
- raw_datasets['train'] = lm_datasets
239
  print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
240
  return raw_datasets
241
 
242
- def helper_tokenize_encode(sentence_lst, vocab_dict, model, seqlen, data_args, padding_mode, ):
 
 
 
 
 
 
 
 
243
  result_train_lst = []
244
  group_lst = defaultdict(list)
245
  with torch.no_grad():
246
  for input_ids in sentence_lst:
247
- tokenized_ = [vocab_dict.get(x, vocab_dict['UNK']) for x in input_ids]
248
  input_ids = [0] + tokenized_ + [1]
249
- group_lst['word_ids'].append(input_ids)
250
- print(group_lst['word_ids'][:2])
251
 
252
- if padding_mode == 'block':
253
- print('padding mode is block')
254
  concatenated_examples = {k: sum(group_lst[k], []) for k in group_lst.keys()}
255
  total_length = len(concatenated_examples[list(group_lst.keys())[0]])
256
  block_size = seqlen
257
  total_length = (total_length // block_size) * block_size
258
  # Split by chunks of max_len.
259
  group_lst = {
260
- k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
261
  for k, t in concatenated_examples.items()
262
  }
263
- elif padding_mode == 'pad':
264
- print('padding mode is pad')
265
  max_length = seqlen
266
- group_lst['word_ids'] = _collate_batch_helper(group_lst['word_ids'], vocab_dict['PAD'], max_length)
 
 
267
 
268
- for input_ids in group_lst['word_ids']:
269
- if data_args.experiment.startswith('random'):
270
  hidden_state = model(torch.tensor(input_ids))
271
- elif data_args.experiment == 'gpt2_pre_compress':
272
  input_ids2 = torch.tensor(input_ids).to(model.device)
273
  input_embs = model.transformer.wte(input_ids2) # input_embs
274
  hidden_state = model.down_proj(input_embs)
275
  hidden_state = hidden_state * data_args.emb_scale_factor
276
- elif data_args.experiment == 'glove':
277
  hidden_state = model(torch.tensor(input_ids))
278
- result_train_lst.append({'input_ids': input_ids, 'hidden_states': hidden_state.cpu().tolist()})
 
 
279
 
280
  return result_train_lst
281
 
 
282
  def load_glove_model(File):
283
  print("Loading Glove Model")
284
  glove_model = {}
285
- with open(File,'r') as f:
286
  for line in f:
287
  split_line = line.split()
288
  word = split_line[0]
@@ -292,9 +358,10 @@ def load_glove_model(File):
292
  print(f"{len(glove_model)} words loaded!")
293
  return glove_model
294
 
 
295
  def load_glove(vocab):
296
  model = torch.nn.Embedding(len(vocab), 50)
297
- glove_model = load_glove_model('predictability/glove/glove.6B.50d.txt')
298
  array_lst = []
299
  count_ = 0
300
  for word, idx in vocab.items():
@@ -303,20 +370,21 @@ def load_glove(vocab):
303
  else:
304
  count_ += 1
305
  array_lst.append(torch.randn(50))
306
- print(f'{count_} out of {len(vocab)} is initialized. ')
307
  array_lst = torch.stack(array_lst)
308
  print(torch.norm(array_lst, dim=-1).mean())
309
  model.weight.data = array_lst
310
  return model
311
 
312
 
313
- def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
314
- split='train', load_vocab=None):
 
315
  import csv, torch, json
316
  from spacy.lang.en import English
317
 
318
- if data_args.experiment_mode == 'lm':
319
- if data_args.modality == 'roc':
320
  pass
321
  # print('loading dataset from ROCStory')
322
  # nlp = English()
@@ -347,7 +415,7 @@ def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
347
  # # sentence_lst.append(word_lst)
348
  # # sentence_lst = sentence_lst[1:]
349
  # print(sentence_lst[:2])
350
- if data_args.modality == 'roc-aug':
351
  pass
352
  # print('loading dataset from ROCStory')
353
  # nlp = English()
@@ -381,7 +449,7 @@ def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
381
  # word_lst = [x.text for x in tokenizer(sentences)]
382
  # sentence_lst.append(word_lst)
383
  # print(sentence_lst[:2],sentence_lst[-2:], 'dataset size=',len(sentence_lst))
384
- elif data_args.modality == 'simple-wiki':
385
  pass
386
  # print('loading dataset from simple wikipedia')
387
  # sentence_lst = []
@@ -390,57 +458,62 @@ def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
390
  # word_lst = row.lower().split()
391
  # sentence_lst.append(word_lst)
392
  # print(sentence_lst[:2])
393
- elif data_args.modality == 'e2e-tgt':
394
- print('loading dataset from simple e2e dataset')
395
  sentence_lst = []
396
  nlp = English()
397
  tokenizer = nlp.tokenizer
398
- if split == 'train':
399
- print('loading form the TRAIN set')
400
- path = '/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_train.txt'
 
 
401
  # path = f'../{data_args.e2e_train}/src1_train.txt'
402
- elif split == 'valid':
403
- print('loading form the VALID set')
404
- path = f'../{data_args.e2e_train}/src1_valid.txt'
405
- path = '/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_valid.txt'
406
- elif split == 'test':
407
- print('loading form the TEST set')
408
- path = f'../{data_args.e2e_train}/src1_test.txt'
409
- path = '/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_test.txt'
410
- elif split == 'debug':
411
- print('loading form the DEBUG set')
 
 
412
  path = data_args.debug_path
413
  import json
414
- with open(path, 'r') as ff:
 
415
  for line in ff:
416
- sentence_lst.append(json.loads(line)[0].split(' '))
417
  sentence_lst = sentence_lst + sentence_lst
418
- if split in ['train', 'valid', 'test']:
419
- with open(path, 'r') as ff:
420
  for row in ff:
421
- word_lst = row.split('||')[1]
422
  word_lst = [x.text for x in tokenizer(word_lst)]
423
  sentence_lst.append(word_lst)
424
  print(sentence_lst[:2])
425
 
426
- elif data_args.modality == 'yelp':
427
- print('loading dataset from simple YelpNLG dataset')
428
  sentence_lst = []
429
  nlp = English()
430
  tokenizer = nlp.tokenizer
431
- if split == 'train':
432
- print('loading form the TRAIN set')
433
- path = f'{data_args.yelp_train}/yelpnlg-train.csv'
434
- elif split == 'valid':
435
- print('loading form the VALID set')
436
- path = f'{data_args.yelp_train}/yelpnlg-dev.csv'
437
- elif split == 'test':
438
- print('loading form the TEST set')
439
- path = f'{data_args.yelp_train}/yelpnlg-test.csv'
440
- if split in ['train', 'valid', 'test']:
441
-
442
- with open(path, 'r') as csvfile:
443
- yelp_reader = csv.reader(csvfile) #delimiter=' ', quotechar='|')
444
  for row in yelp_reader:
445
  sentences = row[1]
446
  word_lst = [x.text for x in tokenizer(sentences)]
@@ -448,175 +521,188 @@ def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
448
  sentence_lst = sentence_lst[1:]
449
  print(sentence_lst[:2])
450
 
451
- elif data_args.modality == 'commonGen':
452
- print('loading dataset from simple YelpNLG dataset')
453
  sentence_lst = []
454
  nlp = English()
455
  tokenizer = nlp.tokenizer
456
- if split == 'train':
457
- print('loading form the TRAIN set')
458
- path = f'{data_args.commonGen_train}/commongen.train.jsonl'
459
- elif split == 'valid':
460
- print('loading form the VALID set')
461
- path = f'{data_args.commonGen_train}/commongen.dev.jsonl'
462
- elif split == 'test':
463
- print('loading form the TEST set')
464
- path = f'{data_args.commonGen_train}/commongen.test.jsonl'
465
- if split in ['train', 'valid', 'test']:
466
- with open(path, 'r') as ff:
467
  for line in ff:
468
  line = json.loads(line)
469
- for sentences in line['scene']:
470
  word_lst = [x.text for x in tokenizer(sentences)]
471
  sentence_lst.append(word_lst)
472
  print(sentence_lst[:2])
473
 
474
- elif data_args.modality == 'commonGen-aug':
475
- print('loading dataset from simple YelpNLG dataset')
476
  sentence_lst = []
477
  nlp = English()
478
  tokenizer = nlp.tokenizer
479
- if split == 'train':
480
- print('loading form the TRAIN set')
481
- path = f'{data_args.commonGen_train}/commongen.train.jsonl'
482
- path_lst = [f'{data_args.roc_train}/roc_train.json']
483
- path_lst.append('diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt')
484
- elif split == 'valid':
485
- print('loading form the VALID set')
486
- path = f'{data_args.commonGen_train}/commongen.dev.jsonl'
 
 
487
  path_lst = []
488
- elif split == 'test':
489
- print('loading form the TEST set')
490
- path = f'{data_args.commonGen_train}/commongen.test.jsonl'
491
  path_lst = []
492
 
493
- if split in ['train', 'valid', 'test']:
494
- with open(path, 'r') as ff:
495
  for line in ff:
496
  line = json.loads(line)
497
- for sentences in line['scene']:
498
  word_lst = [x.text for x in tokenizer(sentences)]
499
  sentence_lst.append(word_lst)
500
  print(sentence_lst[:2])
501
  import itertools
 
502
  for path in path_lst:
503
- if path.endswith('txt'):
504
- with open(path, 'r') as roc_reader:
505
  for row in roc_reader:
506
  sentences = row.strip()
507
  word_lst = [x.text for x in tokenizer(sentences)]
508
  spl = [[]]
509
- for x, y in itertools.groupby(word_lst, lambda z: z == '.'):
510
  spl[-1].extend(y)
511
- if x: spl.append([])
 
512
  sentence_lst.extend(spl[:-1])
513
  else:
514
- with open(path, 'r') as roc_reader:
515
  for row in roc_reader:
516
  sentences = json.loads(row)[0].strip()
517
  word_lst = [x.text for x in tokenizer(sentences)]
518
  spl = [[]]
519
- for x, y in itertools.groupby(word_lst, lambda z: z == '.'):
520
  spl[-1].extend(y)
521
- if x: spl.append([])
 
522
  sentence_lst.extend(spl[:-1])
523
 
524
  print(sentence_lst[-2:])
525
 
526
-
527
  # get tokenizer.
528
  if load_vocab is None:
529
  counter = Counter()
530
  for input_ids in sentence_lst:
531
  counter.update(input_ids)
532
 
533
- if data_args.experiment_mode == 'conditional_gen':
534
- if data_args.modality == 'e2e':
535
- print('loading dataset from simple e2e dataset')
536
  sentence_lst = []
537
  nlp = English()
538
  tokenizer = nlp.tokenizer
539
- if split == 'train':
540
- path = f'{data_args.e2e_train}/src1_train.txt'
541
- with open(path, 'r') as ff:
542
  for row in ff:
543
- src_lst, word_lst = row.split('||')
544
  word_lst = [x.text for x in tokenizer(word_lst)]
545
  src_lst = [x.text for x in tokenizer(src_lst)]
546
  sentence_lst.append((src_lst, word_lst))
547
- elif split == 'valid':
548
- path = f'{data_args.e2e_train}/src1_valid.txt'
549
  sentence_lst = read_e2e_files(path, data_args, tokenizer)
550
  print(sentence_lst[:2])
551
  # get tokenizer.
552
  if load_vocab is None:
553
  counter = Counter()
554
- for (src_ids, input_ids) in sentence_lst:
555
  counter.update(input_ids)
556
  counter.update(src_ids)
557
 
558
  if load_vocab is None:
559
- vocab_dict = {'START': 0, 'END': 1, 'UNK':2, 'PAD':3}
560
  for k, v in counter.items():
561
  if v > 10:
562
  vocab_dict[k] = len(vocab_dict)
563
  print(len(counter), len(vocab_dict))
564
 
565
- path_save_vocab = '/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json'
566
- print(f'save the vocab to {path_save_vocab}')
567
- with open(path_save_vocab, 'w') as f:
568
  json.dump(vocab_dict, f)
569
  else:
570
  vocab_dict = load_vocab
571
- path_save_vocab = '/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json'
572
  if not os.path.exists(path_save_vocab):
573
- print(f'save the vocab to {path_save_vocab}')
574
  if isinstance(vocab_dict, dict):
575
- with open(path_save_vocab, 'w') as f:
576
  json.dump(vocab_dict, f)
577
- assert vocab_dict['START'] == 0
578
  elif isinstance(vocab_dict, PreTrainedTokenizerFast):
579
  vocab_dict.save_pretrained(data_args.checkpoint_path)
580
  else:
581
  assert False, "invalid type of vocab_dict"
582
 
583
-
584
-
585
- if model is None and data_args.experiment == 'random':
586
  model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel)
587
- print('initializing the random embeddings', model)
588
  torch.nn.init.normal_(model.weight)
589
- path_save = '/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/random_emb.torch'
590
- print(f'save the random encoder to {data_args.checkpoint_path}/random_emb.torch')
 
 
591
  torch.save(model.state_dict(), path_save)
592
 
593
  # path_save = f'{data_args.checkpoint_path}/random_emb.torch'
594
  # if not os.path.exists(path_save) and data_args.experiment == 'random':
595
  # torch.save(model.state_dict(), path_save)
596
 
597
-
598
- if data_args.experiment_mode == 'lm' and data_args.modality in ['roc-aug', 'roc', 'yelp', 'commonGen', 'commonGen-aug'] \
599
- and data_args.cache_mode=='no':
600
- train_dataset = helper_tokenize_stream(sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode)
 
 
 
 
 
601
  return train_dataset, model
602
- elif data_args.experiment_mode == 'lm':
603
- result_train_lst = helper_tokenize_encode(sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode)
604
- elif data_args.experiment_mode == 'conditional_gen':
605
- result_train_lst = helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, image_size ** 2, data_args)
606
- return {'train': result_train_lst}, model
607
-
 
 
 
 
608
 
609
  def write_e2e_corr(prompt_lst, file_dict, corr_path):
610
  print(len(prompt_lst))
611
- with open(corr_path, 'w') as f:
612
  for x in prompt_lst:
613
  for line in file_dict[x]:
614
  print(" ".join(line), file=f)
615
- print('', file=f)
616
 
617
 
618
  def write_e2e_src(prompt_lst, corr_path):
619
- with open(corr_path, 'w') as f:
620
  for x in prompt_lst:
621
  print(" ".join(x), file=f)
622
  return
@@ -624,48 +710,55 @@ def write_e2e_src(prompt_lst, corr_path):
624
 
625
  def read_e2e_files(path, args, tokenizer):
626
  file_dict = {}
627
- with open(path, 'r') as f:
628
  for line in f:
629
- src_lst, word_lst = line.strip().split('||')
630
  tgt = tuple([x.text for x in tokenizer(word_lst)])
631
  src = tuple([x.text for x in tokenizer(src_lst)])
632
  if src not in file_dict:
633
  file_dict[src] = []
634
  file_dict[src].append(tgt)
635
- temp = '1'
636
  prompt_text_dict = file_dict
637
  prompt_text_lst = list(prompt_text_dict.keys())
638
- gold_dir = os.path.join(args.out_dir, '{}_{}_{}'.format(temp, args.split, 'gold'))
639
  print("gold dir", gold_dir)
640
  write_e2e_corr(prompt_text_lst, prompt_text_dict, gold_dir)
641
- src_dir = os.path.join(args.out_dir, '{}_{}_{}'.format(temp, args.split, 'src'))
642
  write_e2e_src(prompt_text_lst, src_dir)
643
  final_lst = [(xx, prompt_text_dict[xx][0]) for xx in prompt_text_lst]
644
  return final_lst
645
 
646
 
647
- def get_corpus_book(data_args, tokenizer, model, image_size, padding_mode='block', split='train',):
648
- max_length = image_size ** 2
 
 
 
 
 
 
 
649
  import os
650
- assert padding_mode == 'block'
651
- raw_datasets = load_dataset('bookcorpus')
 
652
  if "validation" not in raw_datasets.keys():
653
  raw_datasets["validation"] = load_dataset(
654
- 'bookcorpus',
655
  split=f"train[:1%]",
656
  )
657
  raw_datasets["train"] = load_dataset(
658
- 'bookcorpus',
659
  split=f"train[1%:]",
660
  )
661
  print(raw_datasets)
662
  column_names = raw_datasets["train"].column_names
663
 
664
  def tokenize_function(examples):
665
- output = tokenizer(examples['text'], add_special_tokens=False)
666
  return output
667
 
668
-
669
  tokenized_datasets = raw_datasets.map(
670
  tokenize_function,
671
  batched=True,
@@ -686,7 +779,7 @@ def get_corpus_book(data_args, tokenizer, model, image_size, padding_mode='block
686
  if total_length >= block_size:
687
  total_length = (total_length // block_size) * block_size
688
  result = {
689
- k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
690
  for k, t in concatenated_examples.items()
691
  }
692
  return result
@@ -702,32 +795,44 @@ def get_corpus_book(data_args, tokenizer, model, image_size, padding_mode='block
702
  print(lm_datasets)
703
 
704
  if model is None:
705
- if data_args.training_mode.startswith('e2e'):
706
- print('since its e2e, initialize a dummy embedding' )
707
  model = torch.nn.Embedding(len(tokenizer), 1)
708
  else:
709
  model = torch.nn.Embedding(len(tokenizer), data_args.in_channel)
710
- print('initializing the random embeddings', model)
711
  torch.nn.init.normal_(model.weight)
712
- path_save = f'{data_args.checkpoint_path}/random_emb.torch'
713
- print(f'save the random encoder to {data_args.checkpoint_path}/random_emb.torch')
 
 
714
  torch.save(model.state_dict(), path_save)
715
 
716
- if split == 'train':
717
  return lm_datasets, model
718
  else:
719
- lm_datasets['train'] = lm_datasets['validation']
720
  return lm_datasets, model
721
 
722
 
723
  class TextDataset(Dataset):
724
- def __init__(self, text_datasets, resolution, data_args, model_arch='conv-unet',
725
- classes=None, shard=0, num_shards=1, eigen_transform=None,
726
- mapping_func=None, model_emb=None):
 
 
 
 
 
 
 
 
 
 
727
  super().__init__()
728
  self.resolution = resolution
729
  self.text_datasets = text_datasets
730
- self.length = len(self.text_datasets['train'])
731
  self.model_arch = model_arch
732
  self.data_args = data_args
733
  print(self.resolution)
@@ -745,8 +850,8 @@ class TextDataset(Dataset):
745
  # We are not on a new enough PIL to support the `reducing_gap`
746
  # argument, which uses BOX downsampling at powers of two first.
747
  # Thus, we do it by hand to improve downsample quality.
748
- if self.model_arch == 'conv-unet':
749
- pass# arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
750
  # dtype=np.float32).reshape(self.resolution, self.resolution, -1)
751
  # # print(self.eigen_transform.shape)
752
  # if self.eigen_transform is not None:
@@ -757,15 +862,14 @@ class TextDataset(Dataset):
757
  # if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
758
  # arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
759
 
760
-
761
  # out_dict = {}
762
  # out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
763
  # # if self.local_classes is not None:
764
  # # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
765
  # # print(out_dict.keys())
766
  # return np.transpose(arr, [2, 0, 1]), out_dict
767
- elif self.model_arch == '1d-unet':
768
- pass# arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
769
  # dtype=np.float32) # seqlen, dim
770
  # if self.eigen_transform is not None:
771
  # old_shape = arr.shape
@@ -783,27 +887,39 @@ class TextDataset(Dataset):
783
  # # print(arr.shape)
784
  # return arr, out_dict
785
  else:
786
- arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
787
- dtype=np.float32)
788
- if self.eigen_transform is not None:
 
789
  old_shape = arr.shape
790
  # arr = arr.reshape(1, -1) @ self.eigen_transform
791
- arr = arr.reshape(1, -1) - self.eigen_transform['mean']
792
- arr = arr @ self.eigen_transform['map']
793
  arr = arr.reshape(old_shape)
794
-
795
- if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
 
 
 
796
  # print(arr.dtype)
797
  # print(self.data_args.noise_level, 'using the noise level.')
798
- arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
 
 
799
  # print(arr.dtype)
800
 
801
  out_dict = {}
802
- out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
 
 
803
  # out_dict['mapping_func'] = self.mapping_func
804
- if self.data_args.experiment_mode == 'conditional_gen':
805
- out_dict['src_ids'] = np.array(self.text_datasets['train'][idx]['src_ids'])
806
- out_dict['src_mask'] = np.array(self.text_datasets['train'][idx]['src_mask'])
 
 
 
 
807
  # if self.local_classes is not None:
808
  # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
809
  return arr, out_dict
@@ -813,13 +929,23 @@ class TextDataset(Dataset):
813
 
814
 
815
  class TextDataset_NoCache(Dataset):
816
- def __init__(self, text_datasets, resolution, data_args, model_arch='conv-unet',
817
- classes=None, shard=0, num_shards=1, eigen_transform=None,
818
- mapping_func=None, model_emb=None):
 
 
 
 
 
 
 
 
 
 
819
  super().__init__()
820
  self.resolution = resolution
821
  self.text_datasets = text_datasets
822
- self.length = len(self.text_datasets['train'])
823
  self.model_arch = model_arch
824
  self.data_args = data_args
825
  print(self.resolution)
@@ -838,81 +964,110 @@ class TextDataset_NoCache(Dataset):
838
  # argument, which uses BOX downsampling at powers of two first.
839
  # Thus, we do it by hand to improve downsample quality.
840
  with torch.no_grad():
841
- input_ids = self.text_datasets['train'][idx]['input_ids']
842
  model = self.model_emb
843
- if self.data_args.experiment.startswith('random'):
844
  hidden_state = model(torch.tensor(input_ids))
845
- elif self.data_args.experiment == 'gpt2_pre_compress':
846
  input_ids2 = torch.tensor(input_ids).to(model.device)
847
  input_embs = model.transformer.wte(input_ids2) # input_embs
848
  hidden_state = model.down_proj(input_embs)
849
  hidden_state = hidden_state * data_args.emb_scale_factor
850
 
851
- if self.model_arch == 'conv-unet':
852
- arr = np.array(hidden_state,
853
- dtype=np.float32).reshape(self.resolution, self.resolution, -1)
 
854
  # print(self.eigen_transform.shape)
855
  if self.eigen_transform is not None:
856
  old_shape = arr.shape
857
- arr = arr.reshape(1, -1) - self.eigen_transform['mean']
858
- arr = arr @ self.eigen_transform['map']
859
  arr = arr.reshape(old_shape)
860
- if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
861
- arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
 
 
 
 
 
862
 
863
  out_dict = {}
864
- out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
 
 
865
  # if self.local_classes is not None:
866
  # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
867
  # print(out_dict.keys())
868
  return np.transpose(arr, [2, 0, 1]), out_dict
869
- elif self.model_arch == '1d-unet':
870
- arr = np.array(hidden_state,
871
- dtype=np.float32) # seqlen, dim
872
  if self.eigen_transform is not None:
873
  old_shape = arr.shape
874
- arr = arr.reshape(1, -1) - self.eigen_transform['mean']
875
- arr = arr @ self.eigen_transform['map']
876
  arr = arr.reshape(old_shape)
877
- if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
878
- arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
 
 
 
 
 
879
  arr = np.transpose(arr, [1, 0])
880
  out_dict = {}
881
- out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
 
 
882
  # out_dict['mapping_func'] = self.mapping_func
883
  # if self.local_classes is not None:
884
  # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
885
  # print(arr.shape)
886
  return arr, out_dict
887
  else:
888
- arr = np.array(hidden_state,
889
- dtype=np.float32)
890
  if self.eigen_transform is not None:
891
  old_shape = arr.shape
892
  # arr = arr.reshape(1, -1) @ self.eigen_transform
893
- arr = arr.reshape(1, -1) - self.eigen_transform['mean']
894
- arr = arr @ self.eigen_transform['map']
895
  arr = arr.reshape(old_shape)
896
 
897
- if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
 
 
 
898
  # print(arr.dtype)
899
  # print(self.data_args.noise_level, 'using the noise level.')
900
- arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
 
 
901
  # print(arr.dtype)
902
 
903
  out_dict = {}
904
- out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
 
 
905
  # out_dict['mapping_func'] = self.mapping_func
906
- if self.data_args.experiment_mode == 'conditional_gen':
907
- out_dict['src_ids'] = np.array(self.text_datasets['train'][idx]['src_ids'])
908
- out_dict['src_mask'] = np.array(self.text_datasets['train'][idx]['src_mask'])
 
 
 
 
909
  # if self.local_classes is not None:
910
  # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
911
  return arr, out_dict
912
 
 
913
  def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False):
914
- result = torch.full([len(examples), max_length], pad_token_id, dtype=torch.int64).tolist()
915
- mask_ = torch.full([len(examples), max_length], pad_token_id, dtype=torch.int64).tolist()
 
 
 
 
916
  for i, example in enumerate(examples):
917
  curr_len = min(len(example), max_length)
918
  result[i][:curr_len] = example[:curr_len]
@@ -921,6 +1076,7 @@ def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False)
921
  return result, mask_
922
  return result
923
 
 
924
  def _torch_collate_batch(examples, pad_token_id, max_length):
925
  """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
926
  import numpy as np
@@ -945,4 +1101,4 @@ def _torch_collate_batch(examples, pad_token_id, max_length):
945
  result[i, : example.shape[0]] = example
946
  else:
947
  result[i, -example.shape[0] :] = example
948
- return result
 
1
  # from PIL import Image
2
  # import blobfile as bf
3
+ # from mpi4py import MPI
4
  import numpy as np
5
  from torch.utils.data import DataLoader, Dataset
6
+ from transformers import (
7
+ AutoModelForCausalLM,
8
+ AutoConfig,
9
+ AutoTokenizer,
10
+ default_data_collator,
11
+ PreTrainedTokenizerFast,
12
+ PreTrainedTokenizer,
13
+ )
14
+
15
  # from datasets import load_dataset
16
  import sys, os
17
  import torch
18
+
19
  # sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
20
  # from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise
21
  from collections import Counter, defaultdict
 
24
 
25
 
26
  def load_data_text(
27
+ *,
28
+ data_dir,
29
+ batch_size,
30
+ image_size,
31
+ class_cond=False,
32
+ deterministic=False,
33
+ data_args=None,
34
+ task_mode="roc",
35
+ model=None,
36
+ padding_mode="block",
37
+ split="train",
38
+ load_vocab=None,
39
  ):
40
  """
41
  For a dataset, create a generator over (images, kwargs) pairs.
 
53
  exception will be raised.
54
  :param deterministic: if True, yield results in a deterministic order.
55
  """
56
+ print("hello loading text data. ")
57
 
58
+ if data_args.experiment.startswith("random") and model is None:
59
  model = None
60
  # elif data_args.experiment.startswith('random') and model is not None:
61
  # print('loading initialized random embeddings. ')
62
 
63
+ if task_mode == "roc" or task_mode == "roc-aug":
64
  pass
65
  # training_data, model = get_corpus_rocstory(data_args, model, image_size,
66
  # padding_mode=padding_mode, split=split,
67
+ # load_vocab=load_vocab)
68
+ elif task_mode == "simple-wiki":
69
  pass
70
  # training_data, model = get_corpus_rocstory(data_args, model, image_size,
71
+ # padding_mode=padding_mode, split=split,
72
+ # load_vocab=load_vocab)
73
+
74
+ elif task_mode == "e2e-tgt":
75
+ print("hello loading e2e-tgt. ")
76
+ training_data, model = get_corpus_rocstory(
77
+ data_args,
78
+ model,
79
+ image_size,
80
+ padding_mode=padding_mode,
81
+ split=split,
82
+ load_vocab=load_vocab,
83
+ )
84
  # elif task_mode == 'yelp':
85
  # print('hello loading yelp ')
86
  # training_data, model = get_corpus_rocstory(data_args, model, image_size,
 
103
  # training_data, model = get_corpus_book(data_args, tokenizer, model, image_size,
104
  # padding_mode=padding_mode, split=split,)
105
 
106
+ if (
107
+ data_args.modality
108
+ in ["roc-aug", "roc", "book", "yelp", "commonGen", "commonGen-aug"]
109
+ and data_args.cache_mode == "no"
110
+ ):
111
+ pass # dataset = TextDataset_NoCache(
112
  # training_data,
113
  # image_size,
114
  # data_args,
 
125
 
126
  if deterministic:
127
 
128
+ pass # data_loader = DataLoader(
129
  # dataset,
130
  # batch_size=batch_size, # 20,
131
  # drop_last=True,
 
144
  while True:
145
  yield from data_loader
146
 
147
+
148
  def helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, seqlen, data_args):
149
  result_train_lst = []
150
  group_lst = defaultdict(list)
151
  with torch.no_grad():
152
+ for src_ids, input_ids in sentence_lst:
153
+ tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in input_ids]
154
+ tokenized_src = [vocab_dict.get(x, vocab_dict["UNK"]) for x in src_ids]
155
  input_ids = [0] + tokenized_ + [1]
156
+ group_lst["word_ids"].append(input_ids)
157
+ group_lst["src_ids"].append(tokenized_src)
158
 
159
+ print(group_lst["word_ids"][:2])
160
+ print("padding mode is pad")
161
  max_length = seqlen
162
+ group_lst["word_ids"] = _collate_batch_helper(
163
+ group_lst["word_ids"], vocab_dict["PAD"], max_length
164
+ )
165
+ max_src_length = max([len(xx) for xx in group_lst["src_ids"]])
166
  print(max_src_length, seqlen)
167
  max_src_length = min(seqlen, max_src_length)
168
+ group_lst["src_ids"], group_lst["src_mask"] = _collate_batch_helper(
169
+ group_lst["src_ids"], vocab_dict["PAD"], max_src_length, return_mask=True
170
+ )
 
 
171
 
172
+ for input_ids, src_ids, src_mask in zip(
173
+ group_lst["word_ids"], group_lst["src_ids"], group_lst["src_mask"]
174
+ ):
175
+ if data_args.experiment.startswith("random"):
176
  hidden_state = model(torch.tensor(input_ids))
177
+ elif data_args.experiment == "gpt2_pre_compress":
178
  input_ids2 = torch.tensor(input_ids).to(model.device)
179
  input_embs = model.transformer.wte(input_ids2) # input_embs
180
  hidden_state = model.down_proj(input_embs)
181
  hidden_state = hidden_state * data_args.emb_scale_factor
182
+ result_train_lst.append(
183
+ {
184
+ "input_ids": input_ids,
185
+ "hidden_states": hidden_state.cpu().tolist(),
186
+ "src_ids": src_ids,
187
+ "src_mask": src_mask,
188
+ }
189
+ )
190
 
191
  return result_train_lst
192
 
193
+
194
+ def helper_tokenize_stream(
195
+ sentence_lst,
196
+ vocab_dict,
197
+ model,
198
+ seqlen,
199
+ data_args,
200
+ padding_mode,
201
+ ):
202
  import psutil
203
+
204
  # Process.memory_info is expressed in bytes, so convert to megabytes
205
  print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
206
  from datasets import Dataset as Dataset2
207
+
208
+ raw_datasets = Dataset2.from_dict({"text": sentence_lst})
209
  print(raw_datasets)
210
  print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
211
 
 
212
  def tokenize_function(examples):
213
  if isinstance(vocab_dict, dict):
214
+ input_ids = [
215
+ [0] + [vocab_dict.get(x, vocab_dict["UNK"]) for x in seq] + [1]
216
+ for seq in examples["text"]
217
+ ]
218
  elif isinstance(vocab_dict, PreTrainedTokenizerFast):
219
+ examples["text"] = [" ".join(seq) for seq in examples["text"]]
220
+ input_ids = vocab_dict(examples["text"], add_special_tokens=True)[
221
+ "input_ids"
222
+ ]
223
+ result_dict = {"input_ids": input_ids}
224
  # clm input could be much much longer than block_size
225
  return result_dict
226
 
 
228
  tokenize_function,
229
  batched=True,
230
  num_proc=4,
231
+ remove_columns=["text"],
232
  load_from_cache_file=True,
233
  desc="Running tokenizer on dataset",
234
  )
235
  print(tokenized_datasets)
236
  print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
237
 
238
+ if padding_mode == "block":
239
  block_size = seqlen
240
+
241
  def group_texts(examples):
242
+ concatenated_examples = {
243
+ k: list(chain(*examples[k])) for k in examples.keys()
244
+ }
245
  total_length = len(concatenated_examples[list(examples.keys())[0]])
246
  if total_length >= block_size:
247
  total_length = (total_length // block_size) * block_size
248
  result = {
249
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
250
  for k, t in concatenated_examples.items()
251
  }
252
  result["labels"] = result["input_ids"].copy()
253
  return result
254
 
 
255
  lm_datasets = tokenized_datasets.map(
256
  group_texts,
257
  batched=True,
 
260
  desc=f"Grouping texts in chunks of {block_size}",
261
  )
262
  else:
263
+
264
  def pad_function(group_lst):
265
  max_length = seqlen
266
  if isinstance(vocab_dict, dict):
267
+ group_lst["input_ids"] = _collate_batch_helper(
268
+ group_lst["input_ids"], vocab_dict["PAD"], max_length
269
+ )
270
  else:
271
+ group_lst["input_ids"] = _collate_batch_helper(
272
+ group_lst["input_ids"], vocab_dict.pad_token_id, max_length
273
+ )
274
  return group_lst
275
 
276
  # Process.memory_info is expressed in bytes, so convert to megabytes
 
283
  desc=f"padding",
284
  )
285
 
286
+ print(lm_datasets, "padded dataset")
 
287
  print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
288
  import datasets
289
+
290
  raw_datasets = datasets.DatasetDict()
291
+ raw_datasets["train"] = lm_datasets
292
  print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
293
  return raw_datasets
294
 
295
+
296
+ def helper_tokenize_encode(
297
+ sentence_lst,
298
+ vocab_dict,
299
+ model,
300
+ seqlen,
301
+ data_args,
302
+ padding_mode,
303
+ ):
304
  result_train_lst = []
305
  group_lst = defaultdict(list)
306
  with torch.no_grad():
307
  for input_ids in sentence_lst:
308
+ tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in input_ids]
309
  input_ids = [0] + tokenized_ + [1]
310
+ group_lst["word_ids"].append(input_ids)
311
+ print(group_lst["word_ids"][:2])
312
 
313
+ if padding_mode == "block":
314
+ print("padding mode is block")
315
  concatenated_examples = {k: sum(group_lst[k], []) for k in group_lst.keys()}
316
  total_length = len(concatenated_examples[list(group_lst.keys())[0]])
317
  block_size = seqlen
318
  total_length = (total_length // block_size) * block_size
319
  # Split by chunks of max_len.
320
  group_lst = {
321
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
322
  for k, t in concatenated_examples.items()
323
  }
324
+ elif padding_mode == "pad":
325
+ print("padding mode is pad")
326
  max_length = seqlen
327
+ group_lst["word_ids"] = _collate_batch_helper(
328
+ group_lst["word_ids"], vocab_dict["PAD"], max_length
329
+ )
330
 
331
+ for input_ids in group_lst["word_ids"]:
332
+ if data_args.experiment.startswith("random"):
333
  hidden_state = model(torch.tensor(input_ids))
334
+ elif data_args.experiment == "gpt2_pre_compress":
335
  input_ids2 = torch.tensor(input_ids).to(model.device)
336
  input_embs = model.transformer.wte(input_ids2) # input_embs
337
  hidden_state = model.down_proj(input_embs)
338
  hidden_state = hidden_state * data_args.emb_scale_factor
339
+ elif data_args.experiment == "glove":
340
  hidden_state = model(torch.tensor(input_ids))
341
+ result_train_lst.append(
342
+ {"input_ids": input_ids, "hidden_states": hidden_state.cpu().tolist()}
343
+ )
344
 
345
  return result_train_lst
346
 
347
+
348
  def load_glove_model(File):
349
  print("Loading Glove Model")
350
  glove_model = {}
351
+ with open(File, "r") as f:
352
  for line in f:
353
  split_line = line.split()
354
  word = split_line[0]
 
358
  print(f"{len(glove_model)} words loaded!")
359
  return glove_model
360
 
361
+
362
  def load_glove(vocab):
363
  model = torch.nn.Embedding(len(vocab), 50)
364
+ glove_model = load_glove_model("predictability/glove/glove.6B.50d.txt")
365
  array_lst = []
366
  count_ = 0
367
  for word, idx in vocab.items():
 
370
  else:
371
  count_ += 1
372
  array_lst.append(torch.randn(50))
373
+ print(f"{count_} out of {len(vocab)} is initialized. ")
374
  array_lst = torch.stack(array_lst)
375
  print(torch.norm(array_lst, dim=-1).mean())
376
  model.weight.data = array_lst
377
  return model
378
 
379
 
380
+ def get_corpus_rocstory(
381
+ data_args, model, image_size, padding_mode="block", split="train", load_vocab=None
382
+ ):
383
  import csv, torch, json
384
  from spacy.lang.en import English
385
 
386
+ if data_args.experiment_mode == "lm":
387
+ if data_args.modality == "roc":
388
  pass
389
  # print('loading dataset from ROCStory')
390
  # nlp = English()
 
415
  # # sentence_lst.append(word_lst)
416
  # # sentence_lst = sentence_lst[1:]
417
  # print(sentence_lst[:2])
418
+ if data_args.modality == "roc-aug":
419
  pass
420
  # print('loading dataset from ROCStory')
421
  # nlp = English()
 
449
  # word_lst = [x.text for x in tokenizer(sentences)]
450
  # sentence_lst.append(word_lst)
451
  # print(sentence_lst[:2],sentence_lst[-2:], 'dataset size=',len(sentence_lst))
452
+ elif data_args.modality == "simple-wiki":
453
  pass
454
  # print('loading dataset from simple wikipedia')
455
  # sentence_lst = []
 
458
  # word_lst = row.lower().split()
459
  # sentence_lst.append(word_lst)
460
  # print(sentence_lst[:2])
461
+ elif data_args.modality == "e2e-tgt":
462
+ print("loading dataset from simple e2e dataset")
463
  sentence_lst = []
464
  nlp = English()
465
  tokenizer = nlp.tokenizer
466
+ if split == "train":
467
+ print("loading form the TRAIN set")
468
+ path = (
469
+ "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_train.txt"
470
+ )
471
  # path = f'../{data_args.e2e_train}/src1_train.txt'
472
+ elif split == "valid":
473
+ print("loading form the VALID set")
474
+ path = f"../{data_args.e2e_train}/src1_valid.txt"
475
+ path = (
476
+ "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_valid.txt"
477
+ )
478
+ elif split == "test":
479
+ print("loading form the TEST set")
480
+ path = f"../{data_args.e2e_train}/src1_test.txt"
481
+ path = "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_test.txt"
482
+ elif split == "debug":
483
+ print("loading form the DEBUG set")
484
  path = data_args.debug_path
485
  import json
486
+
487
+ with open(path, "r") as ff:
488
  for line in ff:
489
+ sentence_lst.append(json.loads(line)[0].split(" "))
490
  sentence_lst = sentence_lst + sentence_lst
491
+ if split in ["train", "valid", "test"]:
492
+ with open(path, "r") as ff:
493
  for row in ff:
494
+ word_lst = row.split("||")[1]
495
  word_lst = [x.text for x in tokenizer(word_lst)]
496
  sentence_lst.append(word_lst)
497
  print(sentence_lst[:2])
498
 
499
+ elif data_args.modality == "yelp":
500
+ print("loading dataset from simple YelpNLG dataset")
501
  sentence_lst = []
502
  nlp = English()
503
  tokenizer = nlp.tokenizer
504
+ if split == "train":
505
+ print("loading form the TRAIN set")
506
+ path = f"{data_args.yelp_train}/yelpnlg-train.csv"
507
+ elif split == "valid":
508
+ print("loading form the VALID set")
509
+ path = f"{data_args.yelp_train}/yelpnlg-dev.csv"
510
+ elif split == "test":
511
+ print("loading form the TEST set")
512
+ path = f"{data_args.yelp_train}/yelpnlg-test.csv"
513
+ if split in ["train", "valid", "test"]:
514
+
515
+ with open(path, "r") as csvfile:
516
+ yelp_reader = csv.reader(csvfile) # delimiter=' ', quotechar='|')
517
  for row in yelp_reader:
518
  sentences = row[1]
519
  word_lst = [x.text for x in tokenizer(sentences)]
 
521
  sentence_lst = sentence_lst[1:]
522
  print(sentence_lst[:2])
523
 
524
+ elif data_args.modality == "commonGen":
525
+ print("loading dataset from simple YelpNLG dataset")
526
  sentence_lst = []
527
  nlp = English()
528
  tokenizer = nlp.tokenizer
529
+ if split == "train":
530
+ print("loading form the TRAIN set")
531
+ path = f"{data_args.commonGen_train}/commongen.train.jsonl"
532
+ elif split == "valid":
533
+ print("loading form the VALID set")
534
+ path = f"{data_args.commonGen_train}/commongen.dev.jsonl"
535
+ elif split == "test":
536
+ print("loading form the TEST set")
537
+ path = f"{data_args.commonGen_train}/commongen.test.jsonl"
538
+ if split in ["train", "valid", "test"]:
539
+ with open(path, "r") as ff:
540
  for line in ff:
541
  line = json.loads(line)
542
+ for sentences in line["scene"]:
543
  word_lst = [x.text for x in tokenizer(sentences)]
544
  sentence_lst.append(word_lst)
545
  print(sentence_lst[:2])
546
 
547
+ elif data_args.modality == "commonGen-aug":
548
+ print("loading dataset from simple YelpNLG dataset")
549
  sentence_lst = []
550
  nlp = English()
551
  tokenizer = nlp.tokenizer
552
+ if split == "train":
553
+ print("loading form the TRAIN set")
554
+ path = f"{data_args.commonGen_train}/commongen.train.jsonl"
555
+ path_lst = [f"{data_args.roc_train}/roc_train.json"]
556
+ path_lst.append(
557
+ "diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt"
558
+ )
559
+ elif split == "valid":
560
+ print("loading form the VALID set")
561
+ path = f"{data_args.commonGen_train}/commongen.dev.jsonl"
562
  path_lst = []
563
+ elif split == "test":
564
+ print("loading form the TEST set")
565
+ path = f"{data_args.commonGen_train}/commongen.test.jsonl"
566
  path_lst = []
567
 
568
+ if split in ["train", "valid", "test"]:
569
+ with open(path, "r") as ff:
570
  for line in ff:
571
  line = json.loads(line)
572
+ for sentences in line["scene"]:
573
  word_lst = [x.text for x in tokenizer(sentences)]
574
  sentence_lst.append(word_lst)
575
  print(sentence_lst[:2])
576
  import itertools
577
+
578
  for path in path_lst:
579
+ if path.endswith("txt"):
580
+ with open(path, "r") as roc_reader:
581
  for row in roc_reader:
582
  sentences = row.strip()
583
  word_lst = [x.text for x in tokenizer(sentences)]
584
  spl = [[]]
585
+ for x, y in itertools.groupby(word_lst, lambda z: z == "."):
586
  spl[-1].extend(y)
587
+ if x:
588
+ spl.append([])
589
  sentence_lst.extend(spl[:-1])
590
  else:
591
+ with open(path, "r") as roc_reader:
592
  for row in roc_reader:
593
  sentences = json.loads(row)[0].strip()
594
  word_lst = [x.text for x in tokenizer(sentences)]
595
  spl = [[]]
596
+ for x, y in itertools.groupby(word_lst, lambda z: z == "."):
597
  spl[-1].extend(y)
598
+ if x:
599
+ spl.append([])
600
  sentence_lst.extend(spl[:-1])
601
 
602
  print(sentence_lst[-2:])
603
 
 
604
  # get tokenizer.
605
  if load_vocab is None:
606
  counter = Counter()
607
  for input_ids in sentence_lst:
608
  counter.update(input_ids)
609
 
610
+ if data_args.experiment_mode == "conditional_gen":
611
+ if data_args.modality == "e2e":
612
+ print("loading dataset from simple e2e dataset")
613
  sentence_lst = []
614
  nlp = English()
615
  tokenizer = nlp.tokenizer
616
+ if split == "train":
617
+ path = f"{data_args.e2e_train}/src1_train.txt"
618
+ with open(path, "r") as ff:
619
  for row in ff:
620
+ src_lst, word_lst = row.split("||")
621
  word_lst = [x.text for x in tokenizer(word_lst)]
622
  src_lst = [x.text for x in tokenizer(src_lst)]
623
  sentence_lst.append((src_lst, word_lst))
624
+ elif split == "valid":
625
+ path = f"{data_args.e2e_train}/src1_valid.txt"
626
  sentence_lst = read_e2e_files(path, data_args, tokenizer)
627
  print(sentence_lst[:2])
628
  # get tokenizer.
629
  if load_vocab is None:
630
  counter = Counter()
631
+ for src_ids, input_ids in sentence_lst:
632
  counter.update(input_ids)
633
  counter.update(src_ids)
634
 
635
  if load_vocab is None:
636
+ vocab_dict = {"START": 0, "END": 1, "UNK": 2, "PAD": 3}
637
  for k, v in counter.items():
638
  if v > 10:
639
  vocab_dict[k] = len(vocab_dict)
640
  print(len(counter), len(vocab_dict))
641
 
642
+ path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json"
643
+ print(f"save the vocab to {path_save_vocab}")
644
+ with open(path_save_vocab, "w") as f:
645
  json.dump(vocab_dict, f)
646
  else:
647
  vocab_dict = load_vocab
648
+ path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json"
649
  if not os.path.exists(path_save_vocab):
650
+ print(f"save the vocab to {path_save_vocab}")
651
  if isinstance(vocab_dict, dict):
652
+ with open(path_save_vocab, "w") as f:
653
  json.dump(vocab_dict, f)
654
+ assert vocab_dict["START"] == 0
655
  elif isinstance(vocab_dict, PreTrainedTokenizerFast):
656
  vocab_dict.save_pretrained(data_args.checkpoint_path)
657
  else:
658
  assert False, "invalid type of vocab_dict"
659
 
660
+ if model is None and data_args.experiment == "random":
 
 
661
  model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel)
662
+ print("initializing the random embeddings", model)
663
  torch.nn.init.normal_(model.weight)
664
+ path_save = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/random_emb.torch"
665
+ print(
666
+ f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch"
667
+ )
668
  torch.save(model.state_dict(), path_save)
669
 
670
  # path_save = f'{data_args.checkpoint_path}/random_emb.torch'
671
  # if not os.path.exists(path_save) and data_args.experiment == 'random':
672
  # torch.save(model.state_dict(), path_save)
673
 
674
+ if (
675
+ data_args.experiment_mode == "lm"
676
+ and data_args.modality
677
+ in ["roc-aug", "roc", "yelp", "commonGen", "commonGen-aug"]
678
+ and data_args.cache_mode == "no"
679
+ ):
680
+ train_dataset = helper_tokenize_stream(
681
+ sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode
682
+ )
683
  return train_dataset, model
684
+ elif data_args.experiment_mode == "lm":
685
+ result_train_lst = helper_tokenize_encode(
686
+ sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode
687
+ )
688
+ elif data_args.experiment_mode == "conditional_gen":
689
+ result_train_lst = helper_tokenize_encode_cond(
690
+ sentence_lst, vocab_dict, model, image_size**2, data_args
691
+ )
692
+ return {"train": result_train_lst}, model
693
+
694
 
695
  def write_e2e_corr(prompt_lst, file_dict, corr_path):
696
  print(len(prompt_lst))
697
+ with open(corr_path, "w") as f:
698
  for x in prompt_lst:
699
  for line in file_dict[x]:
700
  print(" ".join(line), file=f)
701
+ print("", file=f)
702
 
703
 
704
  def write_e2e_src(prompt_lst, corr_path):
705
+ with open(corr_path, "w") as f:
706
  for x in prompt_lst:
707
  print(" ".join(x), file=f)
708
  return
 
710
 
711
  def read_e2e_files(path, args, tokenizer):
712
  file_dict = {}
713
+ with open(path, "r") as f:
714
  for line in f:
715
+ src_lst, word_lst = line.strip().split("||")
716
  tgt = tuple([x.text for x in tokenizer(word_lst)])
717
  src = tuple([x.text for x in tokenizer(src_lst)])
718
  if src not in file_dict:
719
  file_dict[src] = []
720
  file_dict[src].append(tgt)
721
+ temp = "1"
722
  prompt_text_dict = file_dict
723
  prompt_text_lst = list(prompt_text_dict.keys())
724
+ gold_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "gold"))
725
  print("gold dir", gold_dir)
726
  write_e2e_corr(prompt_text_lst, prompt_text_dict, gold_dir)
727
+ src_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "src"))
728
  write_e2e_src(prompt_text_lst, src_dir)
729
  final_lst = [(xx, prompt_text_dict[xx][0]) for xx in prompt_text_lst]
730
  return final_lst
731
 
732
 
733
+ def get_corpus_book(
734
+ data_args,
735
+ tokenizer,
736
+ model,
737
+ image_size,
738
+ padding_mode="block",
739
+ split="train",
740
+ ):
741
+ max_length = image_size**2
742
  import os
743
+
744
+ assert padding_mode == "block"
745
+ raw_datasets = load_dataset("bookcorpus")
746
  if "validation" not in raw_datasets.keys():
747
  raw_datasets["validation"] = load_dataset(
748
+ "bookcorpus",
749
  split=f"train[:1%]",
750
  )
751
  raw_datasets["train"] = load_dataset(
752
+ "bookcorpus",
753
  split=f"train[1%:]",
754
  )
755
  print(raw_datasets)
756
  column_names = raw_datasets["train"].column_names
757
 
758
  def tokenize_function(examples):
759
+ output = tokenizer(examples["text"], add_special_tokens=False)
760
  return output
761
 
 
762
  tokenized_datasets = raw_datasets.map(
763
  tokenize_function,
764
  batched=True,
 
779
  if total_length >= block_size:
780
  total_length = (total_length // block_size) * block_size
781
  result = {
782
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
783
  for k, t in concatenated_examples.items()
784
  }
785
  return result
 
795
  print(lm_datasets)
796
 
797
  if model is None:
798
+ if data_args.training_mode.startswith("e2e"):
799
+ print("since its e2e, initialize a dummy embedding")
800
  model = torch.nn.Embedding(len(tokenizer), 1)
801
  else:
802
  model = torch.nn.Embedding(len(tokenizer), data_args.in_channel)
803
+ print("initializing the random embeddings", model)
804
  torch.nn.init.normal_(model.weight)
805
+ path_save = f"{data_args.checkpoint_path}/random_emb.torch"
806
+ print(
807
+ f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch"
808
+ )
809
  torch.save(model.state_dict(), path_save)
810
 
811
+ if split == "train":
812
  return lm_datasets, model
813
  else:
814
+ lm_datasets["train"] = lm_datasets["validation"]
815
  return lm_datasets, model
816
 
817
 
818
  class TextDataset(Dataset):
819
+ def __init__(
820
+ self,
821
+ text_datasets,
822
+ resolution,
823
+ data_args,
824
+ model_arch="conv-unet",
825
+ classes=None,
826
+ shard=0,
827
+ num_shards=1,
828
+ eigen_transform=None,
829
+ mapping_func=None,
830
+ model_emb=None,
831
+ ):
832
  super().__init__()
833
  self.resolution = resolution
834
  self.text_datasets = text_datasets
835
+ self.length = len(self.text_datasets["train"])
836
  self.model_arch = model_arch
837
  self.data_args = data_args
838
  print(self.resolution)
 
850
  # We are not on a new enough PIL to support the `reducing_gap`
851
  # argument, which uses BOX downsampling at powers of two first.
852
  # Thus, we do it by hand to improve downsample quality.
853
+ if self.model_arch == "conv-unet":
854
+ pass # arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
855
  # dtype=np.float32).reshape(self.resolution, self.resolution, -1)
856
  # # print(self.eigen_transform.shape)
857
  # if self.eigen_transform is not None:
 
862
  # if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
863
  # arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
864
 
 
865
  # out_dict = {}
866
  # out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
867
  # # if self.local_classes is not None:
868
  # # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
869
  # # print(out_dict.keys())
870
  # return np.transpose(arr, [2, 0, 1]), out_dict
871
+ elif self.model_arch == "1d-unet":
872
+ pass # arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
873
  # dtype=np.float32) # seqlen, dim
874
  # if self.eigen_transform is not None:
875
  # old_shape = arr.shape
 
887
  # # print(arr.shape)
888
  # return arr, out_dict
889
  else:
890
+ arr = np.array(
891
+ self.text_datasets["train"][idx]["hidden_states"], dtype=np.float32
892
+ )
893
+ if self.eigen_transform is not None:
894
  old_shape = arr.shape
895
  # arr = arr.reshape(1, -1) @ self.eigen_transform
896
+ arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
897
+ arr = arr @ self.eigen_transform["map"]
898
  arr = arr.reshape(old_shape)
899
+
900
+ if (
901
+ hasattr(self.data_args, "noise_level")
902
+ and self.data_args.noise_level > 0
903
+ ):
904
  # print(arr.dtype)
905
  # print(self.data_args.noise_level, 'using the noise level.')
906
+ arr = arr + self.data_args.noise_level * np.random.randn(
907
+ *arr.shape
908
+ ).astype(arr.dtype)
909
  # print(arr.dtype)
910
 
911
  out_dict = {}
912
+ out_dict["input_ids"] = np.array(
913
+ self.text_datasets["train"][idx]["input_ids"]
914
+ )
915
  # out_dict['mapping_func'] = self.mapping_func
916
+ if self.data_args.experiment_mode == "conditional_gen":
917
+ out_dict["src_ids"] = np.array(
918
+ self.text_datasets["train"][idx]["src_ids"]
919
+ )
920
+ out_dict["src_mask"] = np.array(
921
+ self.text_datasets["train"][idx]["src_mask"]
922
+ )
923
  # if self.local_classes is not None:
924
  # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
925
  return arr, out_dict
 
929
 
930
 
931
  class TextDataset_NoCache(Dataset):
932
+ def __init__(
933
+ self,
934
+ text_datasets,
935
+ resolution,
936
+ data_args,
937
+ model_arch="conv-unet",
938
+ classes=None,
939
+ shard=0,
940
+ num_shards=1,
941
+ eigen_transform=None,
942
+ mapping_func=None,
943
+ model_emb=None,
944
+ ):
945
  super().__init__()
946
  self.resolution = resolution
947
  self.text_datasets = text_datasets
948
+ self.length = len(self.text_datasets["train"])
949
  self.model_arch = model_arch
950
  self.data_args = data_args
951
  print(self.resolution)
 
964
  # argument, which uses BOX downsampling at powers of two first.
965
  # Thus, we do it by hand to improve downsample quality.
966
  with torch.no_grad():
967
+ input_ids = self.text_datasets["train"][idx]["input_ids"]
968
  model = self.model_emb
969
+ if self.data_args.experiment.startswith("random"):
970
  hidden_state = model(torch.tensor(input_ids))
971
+ elif self.data_args.experiment == "gpt2_pre_compress":
972
  input_ids2 = torch.tensor(input_ids).to(model.device)
973
  input_embs = model.transformer.wte(input_ids2) # input_embs
974
  hidden_state = model.down_proj(input_embs)
975
  hidden_state = hidden_state * data_args.emb_scale_factor
976
 
977
+ if self.model_arch == "conv-unet":
978
+ arr = np.array(hidden_state, dtype=np.float32).reshape(
979
+ self.resolution, self.resolution, -1
980
+ )
981
  # print(self.eigen_transform.shape)
982
  if self.eigen_transform is not None:
983
  old_shape = arr.shape
984
+ arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
985
+ arr = arr @ self.eigen_transform["map"]
986
  arr = arr.reshape(old_shape)
987
+ if (
988
+ hasattr(self.data_args, "noise_level")
989
+ and self.data_args.noise_level > 0
990
+ ):
991
+ arr = arr + self.data_args.noise_level * np.random.randn(
992
+ *arr.shape
993
+ ).astype(arr.dtype)
994
 
995
  out_dict = {}
996
+ out_dict["input_ids"] = np.array(
997
+ self.text_datasets["train"][idx]["input_ids"]
998
+ )
999
  # if self.local_classes is not None:
1000
  # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
1001
  # print(out_dict.keys())
1002
  return np.transpose(arr, [2, 0, 1]), out_dict
1003
+ elif self.model_arch == "1d-unet":
1004
+ arr = np.array(hidden_state, dtype=np.float32) # seqlen, dim
 
1005
  if self.eigen_transform is not None:
1006
  old_shape = arr.shape
1007
+ arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
1008
+ arr = arr @ self.eigen_transform["map"]
1009
  arr = arr.reshape(old_shape)
1010
+ if (
1011
+ hasattr(self.data_args, "noise_level")
1012
+ and self.data_args.noise_level > 0
1013
+ ):
1014
+ arr = arr + self.data_args.noise_level * np.random.randn(
1015
+ *arr.shape
1016
+ ).astype(arr.dtype)
1017
  arr = np.transpose(arr, [1, 0])
1018
  out_dict = {}
1019
+ out_dict["input_ids"] = np.array(
1020
+ self.text_datasets["train"][idx]["input_ids"]
1021
+ )
1022
  # out_dict['mapping_func'] = self.mapping_func
1023
  # if self.local_classes is not None:
1024
  # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
1025
  # print(arr.shape)
1026
  return arr, out_dict
1027
  else:
1028
+ arr = np.array(hidden_state, dtype=np.float32)
 
1029
  if self.eigen_transform is not None:
1030
  old_shape = arr.shape
1031
  # arr = arr.reshape(1, -1) @ self.eigen_transform
1032
+ arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
1033
+ arr = arr @ self.eigen_transform["map"]
1034
  arr = arr.reshape(old_shape)
1035
 
1036
+ if (
1037
+ hasattr(self.data_args, "noise_level")
1038
+ and self.data_args.noise_level > 0
1039
+ ):
1040
  # print(arr.dtype)
1041
  # print(self.data_args.noise_level, 'using the noise level.')
1042
+ arr = arr + self.data_args.noise_level * np.random.randn(
1043
+ *arr.shape
1044
+ ).astype(arr.dtype)
1045
  # print(arr.dtype)
1046
 
1047
  out_dict = {}
1048
+ out_dict["input_ids"] = np.array(
1049
+ self.text_datasets["train"][idx]["input_ids"]
1050
+ )
1051
  # out_dict['mapping_func'] = self.mapping_func
1052
+ if self.data_args.experiment_mode == "conditional_gen":
1053
+ out_dict["src_ids"] = np.array(
1054
+ self.text_datasets["train"][idx]["src_ids"]
1055
+ )
1056
+ out_dict["src_mask"] = np.array(
1057
+ self.text_datasets["train"][idx]["src_mask"]
1058
+ )
1059
  # if self.local_classes is not None:
1060
  # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
1061
  return arr, out_dict
1062
 
1063
+
1064
  def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False):
1065
+ result = torch.full(
1066
+ [len(examples), max_length], pad_token_id, dtype=torch.int64
1067
+ ).tolist()
1068
+ mask_ = torch.full(
1069
+ [len(examples), max_length], pad_token_id, dtype=torch.int64
1070
+ ).tolist()
1071
  for i, example in enumerate(examples):
1072
  curr_len = min(len(example), max_length)
1073
  result[i][:curr_len] = example[:curr_len]
 
1076
  return result, mask_
1077
  return result
1078
 
1079
+
1080
  def _torch_collate_batch(examples, pad_token_id, max_length):
1081
  """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
1082
  import numpy as np
 
1101
  result[i, : example.shape[0]] = example
1102
  else:
1103
  result[i, -example.shape[0] :] = example
1104
+ return result