Christina Theodoris committed
Commit fb130e6
1 Parent(s): 86fe0dd

update kwargs for pretrainer

Files changed (1)
  1. geneformer/pretrainer.py +10 -9
geneformer/pretrainer.py CHANGED
@@ -106,9 +106,8 @@ class TensorType(ExplicitEnum):
 
 class GeneformerPreCollator(SpecialTokensMixin):
     def __init__(self, *args, **kwargs) -> None:
-
-        super().__init__(mask_token = "<mask>", pad_token = "<pad>")
-
+        super().__init__(mask_token="<mask>", pad_token="<pad>")
+
         self.token_dictionary = kwargs.get("token_dictionary")
         # self.mask_token = "<mask>"
         # self.mask_token_id = self.token_dictionary.get("<mask>")
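
Context for the hunk above: SpecialTokensMixin records the tokens passed to its constructor, so after this change the precollator exposes mask_token and pad_token the way a tokenizer would, which is what DataCollatorForLanguageModeling relies on when masking and padding. A minimal sketch, assuming a small placeholder token_dictionary (the real one is Geneformer's pickled gene vocabulary):

# Sketch only; `example_token_dictionary` is a made-up stand-in for the
# gene-token vocabulary normally loaded from token_dictionary.pkl.
example_token_dictionary = {"<pad>": 0, "<mask>": 1, "ENSG00000000003": 2}
precollator = GeneformerPreCollator(token_dictionary=example_token_dictionary)
print(precollator.mask_token)  # "<mask>", registered by SpecialTokensMixin
print(precollator.pad_token)   # "<pad>"
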
@@ -120,8 +119,8 @@ class GeneformerPreCollator(SpecialTokensMixin):
         # self.token_dictionary.get("<pad>"),
         # ]
         self.model_input_names = ["input_ids"]
-
-    def convert_ids_to_tokens(self,value):
+
+    def convert_ids_to_tokens(self, value):
         return self.token_dictionary.get(value)
 
     def _get_padding_truncation_strategies(
@@ -391,7 +390,6 @@ class GeneformerPreCollator(SpecialTokensMixin):
 
         for key, value in encoded_inputs.items():
             encoded_inputs[key] = to_py_obj(value)
-
 
         # Convert padding_strategy in PaddingStrategy
         padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
@@ -596,15 +594,17 @@ class GeneformerPreCollator(SpecialTokensMixin):
 
 class GeneformerPretrainer(Trainer):
     def __init__(self, *args, **kwargs):
-        data_collator = kwargs.get("data_collator",None)
+        data_collator = kwargs.get("data_collator", None)
         token_dictionary = kwargs.pop("token_dictionary")
+        mlm = kwargs.pop("mlm", True)
+        mlm_probability = kwargs.pop("mlm_probability", 0.15)
 
         if data_collator is None:
             precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
 
             # # Data Collator Functions
             data_collator = DataCollatorForLanguageModeling(
-                tokenizer=precollator, mlm=True, mlm_probability=0.15
+                tokenizer=precollator, mlm=mlm, mlm_probability=mlm_probability
             )
             kwargs["data_collator"] = data_collator
 
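
This hunk is the substance of the commit: mlm and mlm_probability are now read from the pretrainer's kwargs (defaulting to True and 0.15) rather than being hard-coded in the collator call. A hedged usage sketch, assuming a model, TrainingArguments, tokenized dataset, and token dictionary were prepared elsewhere; other Trainer kwargs pass through unchanged:

# Sketch only; `model`, `training_args`, `train_dataset`, and
# `token_dictionary` are placeholders assumed to exist already.
trainer = GeneformerPretrainer(
    model=model,                    # e.g. a BertForMaskedLM
    args=training_args,             # transformers.TrainingArguments
    train_dataset=train_dataset,    # tokenized Hugging Face dataset
    token_dictionary=token_dictionary,
    mlm_probability=0.2,            # now configurable; omit to keep 0.15
)
trainer.train()
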
 
@@ -694,6 +694,7 @@ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
     Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
     length while keeping a bit of randomness.
     """
+
     # Copied and adapted from PyTorch DistributedSampler.
     def __init__(
         self,
@@ -757,7 +758,7 @@ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
         # Deterministically shuffle based on epoch and seed
         g = torch.Generator()
         g.manual_seed(self.seed + self.epoch)
-
+
         indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
 
         if not self.drop_last:
 
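A side note on the sampler hunks: seeding a fresh torch.Generator with seed + epoch is what makes the shuffle deterministic for a given epoch (so every distributed rank derives the same ordering) while still varying between epochs. A self-contained sketch with made-up lengths:

import torch

seed, epoch = 42, 3
lengths = [128, 64, 512, 300, 96, 700, 256, 2048]  # placeholder sequence lengths

g = torch.Generator()
g.manual_seed(seed + epoch)  # identical on every rank for a given epoch

# Reproducible permutation; the real sampler instead calls
# get_length_grouped_indices to keep examples of similar length together.
indices = torch.randperm(len(lengths), generator=g).tolist()
print(indices)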