Markus28 committed
Commit f6fcfb5
1 Parent(s): 75d7a16

feat: implemented task_type_ids
Files changed (1):
  1. modeling_bert.py +10 -0
modeling_bert.py CHANGED
@@ -340,14 +340,21 @@ class BertModel(BertPreTrainedModel):
         self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config) if add_pooling_layer else None
+        self.task_type_embeddings = nn.Embedding(config.num_tasks, config.hidden_size)
 
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
+        # We now initialize the task type embeddings to 0; we do not use task types
+        # during pretraining. When we start using task types during embedding
+        # training, we want the model to behave exactly as in pretraining
+        # (i.e. task types have no effect).
+        self.task_type_embeddings.weight.data.fill_(0)
 
     def forward(
         self,
         input_ids,
         position_ids=None,
         token_type_ids=None,
+        task_type_ids=None,
         attention_mask=None,
         masked_tokens_mask=None,
     ):
 
@@ -359,6 +366,9 @@ class BertModel(BertPreTrainedModel):
         hidden_states = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids
         )
+        if task_type_ids is not None:
+            hidden_states = hidden_states + self.task_type_embeddings(task_type_ids)
+
         # TD [2022-12-18]: Don't need to force residual in fp32
         # BERT puts embedding LayerNorm before embedding dropout.
         if not self.fused_dropout_add_ln:
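
The zero-initialization is what makes this change backward compatible: an all-zero embedding row adds nothing to the hidden states, so a checkpoint trained without task types produces identical outputs whether or not task_type_ids is passed. Below is a minimal, self-contained sketch of the pattern; TaskTypeMixin and the sizes used are hypothetical stand-ins for illustration, not code from this repository.

import torch
import torch.nn as nn

class TaskTypeMixin(nn.Module):
    """Toy stand-in (hypothetical) for the task-type addition in BertModel."""

    def __init__(self, num_tasks: int, hidden_size: int):
        super().__init__()
        self.task_type_embeddings = nn.Embedding(num_tasks, hidden_size)
        # Zero-init: until fine-tuned, task types have no effect, so the
        # model matches its pretraining behavior exactly. The zeroing goes
        # through the weight tensor (nn.Module itself exposes no fill_).
        self.task_type_embeddings.weight.data.fill_(0)

    def forward(self, hidden_states, task_type_ids=None):
        # Mirrors the forward() change above: the lookup result is added
        # onto the token/position/type embeddings when ids are given.
        if task_type_ids is not None:
            hidden_states = hidden_states + self.task_type_embeddings(task_type_ids)
        return hidden_states

layer = TaskTypeMixin(num_tasks=4, hidden_size=8)
h = torch.randn(2, 5, 8)                 # (batch, seq_len, hidden_size)
task_ids = torch.randint(0, 4, (2, 5))   # one task id per token, like input_ids
# With zero-initialized task embeddings the two calls are identical.
assert torch.equal(layer(h), layer(h, task_ids))

The embedding remains a trainable parameter, so later embedding training can learn nonzero task vectors while starting from exact pretraining behavior; in the full model, per the forward signature above, the call is model(input_ids, task_type_ids=task_type_ids).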