Mengyao00 committed
Commit 684b16c · verified · 1 Parent(s): b4cce60

Update llama_bidirectional_model.py

Files changed (1):
  1. llama_bidirectional_model.py  +1 -99

llama_bidirectional_model.py CHANGED
@@ -12,13 +12,12 @@ from transformers.modeling_outputs import (
 )
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import (
-    LlamaForSequenceClassification,
     LlamaModel,
     LlamaPreTrainedModel,
 )
 from transformers.utils import logging
 
-from .pooling import pool
+
 
 logger = logging.get_logger(__name__)
 
@@ -56,100 +55,3 @@ class LlamaBidirectionalModel(LlamaModel):
         return causal_mask
 
 
-class LlamaBidirectionalForSequenceClassification(LlamaForSequenceClassification):
-    config_class = LlamaBidirectionalConfig
-
-    def __init__(self, config):
-        super().__init__(config)
-        # Releasing the parameters of LlamaModel
-        # created by parent LlamaForSequenceClassification
-        del self.model
-
-        self.model = LlamaBidirectionalModel(config)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
-            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        """
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-
-        transformer_outputs = self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states = transformer_outputs[0]
-
-        pooled_hidden_states = pool(
-            last_hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            pool_type=self.config.pooling,
-        )
-
-        pooled_logits = self.score(pooled_hidden_states)
-        pooled_logits = pooled_logits / self.config.temperature
-
-        loss = None
-        if labels is not None:
-            labels = labels.to(logits.device)
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (
-                    labels.dtype == torch.long or labels.dtype == torch.int
-                ):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(pooled_logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(
-                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
-                )
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(pooled_logits, labels)
-        if not return_dict:
-            output = (pooled_logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
-        return SequenceClassifierOutputWithPast(
-            loss=loss,
-            logits=pooled_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-        )
 