Christina Theodoris committed
Commit b925dcc · 1 Parent(s): 8ce598f

Update pretrainer for transformers==4.28.0

examples/pretrain_geneformer_w_deepspeed.py CHANGED
@@ -137,9 +137,8 @@ training_args = {
     "weight_decay": weight_decay,
     "per_device_train_batch_size": geneformer_batch_size,
     "num_train_epochs": epochs,
-    "load_best_model_at_end": True,
     "save_strategy": "steps",
-    "save_steps": num_examples / geneformer_batch_size / 8,  # 8 saves per epoch
+    "save_steps": np.floor(num_examples / geneformer_batch_size / 8),  # 8 saves per epoch
     "logging_steps": 1000,
     "output_dir": training_output_dir,
     "logging_dir": logging_dir,
geneformer/pretrainer.py CHANGED
@@ -106,19 +106,23 @@ class TensorType(ExplicitEnum):
 
 class GeneformerPreCollator(SpecialTokensMixin):
     def __init__(self, *args, **kwargs) -> None:
+
+        super().__init__(mask_token="<mask>", pad_token="<pad>")
+
         self.token_dictionary = kwargs.get("token_dictionary")
-        self.mask_token = "<mask>"
-        self.mask_token_id = self.token_dictionary.get("<mask>")
-        self.pad_token = "<pad>"
-        self.pad_token_id = self.token_dictionary.get("<pad>")
+        # self.mask_token = "<mask>"
+        # self.mask_token_id = self.token_dictionary.get("<mask>")
+        # self.pad_token = "<pad>"
+        # self.pad_token_id = self.token_dictionary.get("<pad>")
         self.padding_side = "right"
-        self.all_special_ids = [
-            self.token_dictionary.get("<mask>"),
-            self.token_dictionary.get("<pad>"),
-        ]
+        # self.all_special_ids = [
+        #     self.token_dictionary.get("<mask>"),
+        #     self.token_dictionary.get("<pad>"),
+        # ]
         self.model_input_names = ["input_ids"]
-
-        super().__init__(*args, **kwargs)
+
+    def convert_ids_to_tokens(self, value):
+        return self.token_dictionary.get(value)
 
     def _get_padding_truncation_strategies(
         self,
@@ -592,8 +596,8 @@ class GeneformerPreCollator(SpecialTokensMixin):
 
 class GeneformerPretrainer(Trainer):
     def __init__(self, *args, **kwargs):
-        data_collator = kwargs.get("data_collator")
-        token_dictionary = kwargs.get("token_dictionary")
+        data_collator = kwargs.get("data_collator", None)
+        token_dictionary = kwargs.pop("token_dictionary")
 
         if data_collator is None:
             precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
@@ -604,17 +608,17 @@ class GeneformerPretrainer(Trainer):
         )
         kwargs["data_collator"] = data_collator
 
-        super().__init__(*args, **kwargs)
-
         # load previously saved length vector for dataset to speed up LengthGroupedSampler
        # pre-obtained with [dataset[i]["length"] for i in range(len(dataset))]
-        if kwargs.get("example_lengths_file"):
-            with open(kwargs.get("example_lengths_file"), "rb") as f:
+        example_lengths_file = kwargs.pop("example_lengths_file")
+        if example_lengths_file:
+            with open(example_lengths_file, "rb") as f:
                self.example_lengths = pickle.load(f)
         else:
             raise Exception(
                 "example_lengths_file is required; e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl"
             )
+        super().__init__(*args, **kwargs)
 
     # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
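Two points about the constructor changes above: the special tokens are now registered by passing them to SpecialTokensMixin.__init__ instead of assigning the attributes afterwards, and the custom keyword arguments (token_dictionary, example_lengths_file) are popped off before super().__init__ runs, since Trainer.__init__ only accepts its own documented parameters and raises a TypeError for anything else. A minimal sketch of the kwargs-popping pattern, using an illustrative class name rather than the real GeneformerPretrainer:

    from transformers import Trainer

    class TrainerWithExtras(Trainer):
        """Illustrative Trainer subclass that accepts extra constructor kwargs."""

        def __init__(self, *args, **kwargs):
            # Pull the custom arguments out first; anything left in kwargs must
            # be something Trainer.__init__ itself understands.
            token_dictionary = kwargs.pop("token_dictionary")
            example_lengths_file = kwargs.pop("example_lengths_file", None)

            # ... build a data collator from token_dictionary and load the
            # precomputed example lengths here, as the real class does ...

            super().__init__(*args, **kwargs)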
 
@@ -634,7 +638,6 @@ class GeneformerPretrainer(Trainer):
             lengths = self.example_lengths
         else:
             lengths = None
-        print(f"Lengths: {len(lengths)}")
         model_input_name = (
             self.tokenizer.model_input_names[0]
             if self.tokenizer is not None
@@ -642,16 +645,16 @@ class GeneformerPretrainer(Trainer):
         )
         if self.args.world_size <= 1:
             return LengthGroupedSampler(
-                self.train_dataset,
-                self.args.train_batch_size,
+                dataset=self.train_dataset,
+                batch_size=self.args.train_batch_size,
                 lengths=lengths,
                 model_input_name=model_input_name,
                 generator=generator,
             )
         else:
             return CustomDistributedLengthGroupedSampler(
-                self.train_dataset,
-                self.args.train_batch_size,
+                dataset=self.train_dataset,
+                batch_size=self.args.train_batch_size,
                 num_replicas=self.args.world_size,
                 rank=self.args.process_index,
                 lengths=lengths,
@@ -754,7 +757,7 @@ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
         # Deterministically shuffle based on epoch and seed
         g = torch.Generator()
         g.manual_seed(self.seed + self.epoch)
-
+
         indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
 
         if not self.drop_last:
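The samplers are now constructed with explicit dataset= and batch_size= keyword arguments; the keyword form avoids depending on the positional order of those parameters, which has changed across transformers releases. A small usage sketch with toy data (the dataset, lengths, and batch size here are placeholders, not Geneformer's):

    import torch
    from transformers.trainer_pt_utils import LengthGroupedSampler

    # Toy stand-ins for the tokenized training set and its precomputed lengths.
    train_dataset = [{"input_ids": list(range(n))} for n in (5, 3, 8, 2)]
    lengths = [5, 3, 8, 2]

    generator = torch.Generator()
    generator.manual_seed(42)

    # Keyword arguments make the call unambiguous regardless of the sampler's
    # positional parameter order.
    sampler = LengthGroupedSampler(
        batch_size=2,
        dataset=train_dataset,
        lengths=lengths,
        model_input_name="input_ids",
        generator=generator,
    )
    print(list(sampler))  # indices grouped so similar lengths land in the same batch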