Tags: Transformers · PyTorch · code · custom_code · Inference Endpoints
codesage committed
Commit 4c7b3c8
1 Parent(s): 189ff33

Update modeling_codesage.py

Files changed (1):
  modeling_codesage.py  +71 -1
modeling_codesage.py CHANGED
@@ -11,7 +11,11 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import Conv1D, PreTrainedModel
 from transformers.utils import logging
 from .config_codesage import CodeSageConfig
-from transformers.modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPooling,
+    MaskedLMOutput,
+    SequenceClassifierOutput
+)
 
 logger = logging.get_logger(__name__)
 
@@ -151,6 +155,7 @@ class CodeSageBlock(nn.Module):
 
 class CodeSagePreTrainedModel(PreTrainedModel):
     config_class = CodeSageConfig
+    base_model_prefix = "transformer"
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -277,7 +282,72 @@ class CodeSageModel(CodeSagePreTrainedModel):
         )
 
 
+class CodeSageForMaskedLM(CodeSagePreTrainedModel):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer = CodeSageModel(config)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+        self.init_weights()
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+        hidden_states = transformer_outputs[0]
+        lm_logits = self.lm_head(hidden_states)
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return MaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=lm_logits,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+
 class CodeSageForSequenceClassification(CodeSagePreTrainedModel):
+
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
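
For reference, a minimal usage sketch of the masked-LM head added in this commit. It is not part of the commit itself: the repo id below is an assumption/placeholder, and it assumes the checkpoint ships a tokenizer loadable via AutoTokenizer and that the repository's config.json maps AutoModelForMaskedLM to CodeSageForMaskedLM in its auto_map, so trust_remote_code=True resolves to this class.

# Usage sketch (assumptions: placeholder repo id, custom-code auto_map routes
# AutoModelForMaskedLM to the new CodeSageForMaskedLM class).
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

repo_id = "codesage/codesage-small"  # assumed/placeholder repo id

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

enc = tokenizer("def add(a, b): return a + b", return_tensors="pt")

# Reuse the input ids as labels just to exercise the new loss path; a real MLM
# setup would mask tokens and set non-masked label positions to -100 so that
# CrossEntropyLoss ignores them.
with torch.no_grad():
    outputs = model(
        input_ids=enc["input_ids"],
        attention_mask=enc.get("attention_mask"),
        labels=enc["input_ids"],
    )

print(outputs.loss)          # masked_lm_loss
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)

When return_dict is disabled the head instead returns a plain tuple of (masked_lm_loss, lm_logits, ...). Note that the loss path relies on CrossEntropyLoss being importable in modeling_codesage.py, presumably via an existing torch.nn import outside the hunks shown above.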