Markus28 committed
Commit: 3f5615c
Parent: 599c64e

fix: assert is None for other kwargs too

Files changed (2):
  1. modeling_bert.py +0 -4
  2. modeling_for_glue.py +5 -5
modeling_bert.py CHANGED
@@ -379,16 +379,12 @@ class BertModel(BertPreTrainedModel):
         task_type_ids=None,
         attention_mask=None,
         masked_tokens_mask=None,
-        head_mask=None,
     ):
         """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
         we only want the output for the masked tokens. This means that we only compute the last
         layer output for these tokens.
         masked_tokens_mask: (batch, seqlen), dtype=torch.bool
         """
-        if head_mask is not None:
-            raise NotImplementedError('Masking heads is not supported')
-
         hidden_states = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids
         )
 
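With head_mask gone from the signature, an unsupported call now fails in Python's own argument binding instead of through the explicit NotImplementedError branch this hunk deletes. A minimal sketch of that behavioral shift (TinyModel is a hypothetical stand-in for BertModel, not code from this repo):

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Stand-in forward() mirroring the new signature: no head_mask parameter.
    def forward(self, input_ids, attention_mask=None, masked_tokens_mask=None):
        return input_ids

model = TinyModel()
x = torch.randint(0, 100, (1, 8))
model(x)  # accepted
try:
    model(x, head_mask=torch.ones(12))
except TypeError as err:
    # Rejected at call time:
    # "forward() got an unexpected keyword argument 'head_mask'"
    print(err)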
modeling_for_glue.py CHANGED
@@ -51,16 +51,16 @@ class BertForSequenceClassification(BertPreTrainedModel):
             return_dict if return_dict is not None else self.config.use_return_dict
         )
 
+        assert head_mask is None
+        assert inputs_embeds is None
+        assert output_attentions is None
+        assert output_hidden_states is None
+        assert return_dict is None
         outputs = self.bert(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
         )
 
         pooled_output = outputs[1]
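The same guard, lifted into a self-contained sketch: kwargs kept for Hugging Face API compatibility but unsupported by this implementation are asserted to be None, so callers fail loudly instead of having options silently ignored. SequenceClassifierSketch and its internals are illustrative, not this repo's code:

import torch
import torch.nn as nn

class SequenceClassifierSketch(nn.Module):
    def __init__(self, hidden_size=768, num_labels=2):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, pooled_output, head_mask=None, inputs_embeds=None,
                output_attentions=None, output_hidden_states=None,
                return_dict=None):
        # Unsupported compatibility kwargs must stay at their None default.
        assert head_mask is None
        assert inputs_embeds is None
        assert output_attentions is None
        assert output_hidden_states is None
        assert return_dict is None
        return self.classifier(pooled_output)

model = SequenceClassifierSketch()
logits = model(torch.randn(1, 768))              # ok: shape (1, 2)
# model(torch.randn(1, 768), return_dict=True)   # AssertionError

One trade-off worth noting: unlike the NotImplementedError check removed in modeling_bert.py, plain asserts are stripped under python -O, so this guard is a debug-time safety net rather than a hard API error.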