PatrickHaller commited on
Commit
c8df02f
1 Parent(s): 1423859

Upload modeling_hgrn2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_hgrn2.py +117 -0
modeling_hgrn2.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fla.models.hgrn2 import HGRN2ForCausalLM, HGRN2Model
2
+ from typing import Optional, Tuple, Union, List
3
+
4
+ from fla.models.hgrn2.modeling_hgrn2 import HGRN2PreTrainedModel, HGRN2Model
5
+ from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
10
+
11
+ from transformers.modeling_outputs import SequenceClassifierOutputWithPast
12
+
13
+ def register_hgrn2_for_sequence_classification():
14
+ from transformers import AutoModelForSequenceClassification
15
+ AutoModelForSequenceClassification.register(HGRN2Config, HGRN2ForSequenceClassification)
16
+
17
+
18
+ class HGRN2ForSequenceClassification(HGRN2PreTrainedModel):
19
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
20
+
21
+ def __init__(self, config):
22
+ super().__init__(config)
23
+ self.num_labels = config.num_labels
24
+ self.model = HGRN2Model(config)
25
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
26
+
27
+ # Initialize weights and apply final processing
28
+ self.post_init()
29
+
30
+ def get_input_embeddings(self):
31
+ return self.model.embeddings
32
+
33
+ def set_input_embeddings(self, value):
34
+ self.model.embeddings = value
35
+
36
+ def forward(
37
+ self,
38
+ input_ids: torch.LongTensor = None,
39
+ attention_mask: Optional[torch.Tensor] = None,
40
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
41
+ inputs_embeds: Optional[torch.FloatTensor] = None,
42
+ labels: Optional[torch.LongTensor] = None,
43
+ use_cache: Optional[bool] = None,
44
+ output_attentions: Optional[bool] = None,
45
+ output_hidden_states: Optional[bool] = None,
46
+ return_dict: Optional[bool] = None,
47
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
48
+ r"""
49
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
50
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
51
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
52
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
53
+ """
54
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
55
+
56
+ outputs = self.model(
57
+ input_ids=input_ids,
58
+ attention_mask=attention_mask,
59
+ inputs_embeds=inputs_embeds,
60
+ output_attentions=output_attentions,
61
+ use_cache=use_cache,
62
+ past_key_values=past_key_values,
63
+ output_hidden_states=output_hidden_states,
64
+ return_dict=return_dict,
65
+ )
66
+ hidden_states = outputs[0]
67
+ logits = self.score(hidden_states)
68
+
69
+ if input_ids is not None:
70
+ batch_size = input_ids.shape[0]
71
+ else:
72
+ batch_size = inputs_embeds.shape[0]
73
+
74
+ if self.config.pad_token_id is None and batch_size != 1:
75
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
76
+ if self.config.pad_token_id is None:
77
+ sequence_lengths = -1
78
+ else:
79
+ if input_ids is not None:
80
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
81
+ else:
82
+ sequence_lengths = -1
83
+
84
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
85
+
86
+ loss = None
87
+ if labels is not None:
88
+ labels = labels.to(logits.device)
89
+ if self.config.problem_type is None:
90
+ if self.num_labels == 1:
91
+ self.config.problem_type = "regression"
92
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
93
+ self.config.problem_type = "single_label_classification"
94
+ else:
95
+ self.config.problem_type = "multi_label_classification"
96
+
97
+ if self.config.problem_type == "regression":
98
+ loss_fct = MSELoss()
99
+ if self.num_labels == 1:
100
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
101
+ else:
102
+ loss = loss_fct(pooled_logits, labels)
103
+ elif self.config.problem_type == "single_label_classification":
104
+ loss_fct = CrossEntropyLoss()
105
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
106
+ elif self.config.problem_type == "multi_label_classification":
107
+ loss_fct = BCEWithLogitsLoss()
108
+ loss = loss_fct(pooled_logits, labels)
109
+ if not return_dict:
110
+ output = (pooled_logits,) + outputs[1:]
111
+ return ((loss,) + output) if loss is not None else output
112
+
113
+ return SequenceClassifierOutputWithPast(
114
+ loss=loss,
115
+ logits=pooled_logits,
116
+ hidden_states=outputs.hidden_states,
117
+ )