lvwerra HF staff committed on
Commit
24790fd
1 Parent(s): 62d4de7

tweak data loader to log metrics

Browse files
Files changed (1) hide show
  1. codeparrot_training.py +3 -0
codeparrot_training.py CHANGED
@@ -22,6 +22,7 @@ class ConstantLengthDataset(IterableDataset):
22
  self.dataset = dataset
23
  self.seq_length = seq_length
24
  self.input_characters = seq_length * chars_per_token * num_of_sequences
 
25
 
26
  def __iter__(self):
27
  iterator = iter(self.dataset)
@@ -36,6 +37,8 @@ class ConstantLengthDataset(IterableDataset):
36
  buffer_len += len(buffer[-1])
37
  except StopIteration:
38
  iterator = iter(self.dataset)
 
 
39
  tokenized_inputs = tokenizer(buffer, truncation=False)['input_ids']
40
  all_token_ids = []
41
  for tokenized_input in tokenized_inputs:
 
22
  self.dataset = dataset
23
  self.seq_length = seq_length
24
  self.input_characters = seq_length * chars_per_token * num_of_sequences
25
+ self.epoch = 0
26
 
27
  def __iter__(self):
28
  iterator = iter(self.dataset)
 
37
  buffer_len += len(buffer[-1])
38
  except StopIteration:
39
  iterator = iter(self.dataset)
40
+ self.epoch += 1
41
+ logger.info(f"Dataset epoch: {self.epoch}")
42
  tokenized_inputs = tokenizer(buffer, truncation=False)['input_ids']
43
  all_token_ids = []
44
  for tokenized_input in tokenized_inputs: