alejandralopezsosa committed
Commit ea5bbeb · 1 Parent(s): 220905e

init commit

README.md ADDED
@@ -0,0 +1,204 @@
---
library_name: peft
datasets:
- InstaDeepAI/nucleotide_transformer_downstream_tasks_revised
metrics:
- f1
base_model:
- tattabio/gLM2_150M
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->

A LoRA (PEFT) sequence-classification adapter for the tattabio/gLM2_150M genomic language model, trained on the InstaDeepAI/nucleotide_transformer_downstream_tasks_revised benchmark and evaluated with F1.

## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** LoRA adapter (PEFT) for sequence classification
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** tattabio/gLM2_150M

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app. -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. More information is needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.
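The original card leaves this section empty, so the following is a minimal sketch based on the files in this commit. The adapter was trained against a locally defined class (`parent_library: __main__` in `adapter_config.json`), so the sketch builds `gLM2ForSequenceClassification` from `extension_glm2.py` and then attaches the LoRA weights. The repo id placeholder and the backbone-loading step are assumptions, not documented behavior.

```python
# Hedged sketch: build the classification model and attach the LoRA adapter.
# Assumes this repo's .py files are importable as a package (they use relative
# imports) and that ADAPTER_REPO is this repo's Hub id (placeholder).
import torch
from peft import PeftModel
from extension_glm2 import gLM2ClassificationConfig, gLM2ForSequenceClassification

ADAPTER_REPO = "<user>/<this-repo>"  # hypothetical placeholder

config = gLM2ClassificationConfig(num_classes=2)
model = gLM2ForSequenceClassification(config)
# Load the tattabio/gLM2_150M backbone weights into `model.glm2` here (not shown).
model = PeftModel.from_pretrained(model, ADAPTER_REPO)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 128))  # dummy tokens
with torch.no_grad():
    logits = model(input_ids=input_ids).logits  # shape (1, 2)
```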
## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[InstaDeepAI/nucleotide_transformer_downstream_tasks_revised](https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks_revised), as declared in the card metadata. [More Information Needed] on the specific task and split.

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]


#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

F1, as declared in the card metadata. [More Information Needed]

### Results

[More Information Needed]

#### Summary



## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here. -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly. -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

LoRA (r=16, alpha=8, dropout=0.5) on the `wqkv` attention projections of tattabio/gLM2_150M (30 layers, dim 640, 10 heads), with a fully trained linear `score` head over the first-token representation for 2-way sequence classification. See `adapter_config.json` and `extension_glm2.py` below.

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and BibTeX information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
adapter_config.json ADDED
@@ -0,0 +1,33 @@
{
  "alpha_pattern": {},
  "auto_mapping": {
    "base_model_class": "gLM2ForSequenceClassification",
    "parent_library": "__main__"
  },
  "base_model_name_or_path": "tattabio/gLM2_150M",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 8,
  "lora_dropout": 0.5,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": [
    "score"
  ],
  "peft_type": "LORA",
  "r": 16,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "wqkv"
  ],
  "task_type": "SEQ_CLS",
  "use_dora": false,
  "use_rslora": false
}
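For reference, an equivalent adapter configuration can be reconstructed with peft's `LoraConfig`. This is a sketch for illustration; the original training script is not part of this commit.

```python
# Sketch: the LoraConfig corresponding to adapter_config.json above.
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,                      # LoRA rank
    lora_alpha=8,              # scaling factor: alpha / r = 0.5
    lora_dropout=0.5,
    bias="none",
    target_modules=["wqkv"],   # the fused QKV projection in each attention block
    modules_to_save=["score"], # train the classification head fully, not via LoRA
    task_type="SEQ_CLS",
)
# model = get_peft_model(base_classification_model, lora_config)
```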
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f6bfd6eb4d53b770e6605905f0ab72a5d6aa29addc6d589302a539516b21a4d8
size 4926096
config.json ADDED
@@ -0,0 +1,20 @@
{
  "architectures": [
    "gLM2ForSequenceClassification"
  ],
  "auto_map": {
    "AutoConfig": "extension_glm2.gLM2ClassificationConfig",
    "AutoModelForSequenceClassification": "extension_glm2.gLM2ForSequenceClassification"
  },
  "depth": 30,
  "dim": 640,
  "ffn_dim_multiplier": null,
  "heads": 10,
  "model_type": "gLM2",
  "norm_eps": 1e-05,
  "num_classes": 2,
  "swiglu_multiple_of": 256,
  "torch_dtype": "float32",
  "transformers_version": "4.44.2",
  "vocab_size": 37
}
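With the `auto_map` above, this config can be resolved through `AutoConfig` with remote code enabled. A sketch; the repo id is a hypothetical placeholder:

```python
# Sketch: resolving this config via auto_map (repo id is a placeholder).
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "<user>/<this-repo>",    # placeholder for this repo's Hub id
    trust_remote_code=True,  # auto_map points into extension_glm2.py
)
print(config.dim, config.depth, config.num_classes)  # 640 30 2
```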
configuration_glm2.py ADDED
@@ -0,0 +1,37 @@
"""gLM2 model configuration"""

from typing import Optional
from transformers import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class gLM2Config(PretrainedConfig):
    model_type = "gLM2"

    def __init__(
        self,
        dim: int = 640,
        depth: int = 30,
        heads: int = 10,
        vocab_size: int = 37,
        swiglu_multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.vocab_size = vocab_size
        self.swiglu_multiple_of = swiglu_multiple_of
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.norm_eps = norm_eps

        self.auto_map = {
            "AutoConfig": "configuration_glm2.gLM2Config",
            "AutoModel": "modeling_glm2.gLM2Model",
            "AutoModelForMaskedLM": "modeling_glm2.gLM2ForMaskedLM",
        }
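As a quick sanity check, the defaults above match the values in `config.json`, and the config round-trips through the standard `PretrainedConfig` serialization. A sketch, assuming the module is importable locally:

```python
# Sketch: default gLM2Config matches config.json and serializes cleanly.
from configuration_glm2 import gLM2Config  # assumes local import works

cfg = gLM2Config()
assert cfg.dim == 640 and cfg.depth == 30 and cfg.heads == 10
cfg.save_pretrained("./glm2_config")             # writes config.json
cfg2 = gLM2Config.from_pretrained("./glm2_config")
assert cfg2.to_dict()["vocab_size"] == 37
```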
extension_glm2.py ADDED
@@ -0,0 +1,80 @@
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union

from transformers.modeling_outputs import SequenceClassifierOutput

from .configuration_glm2 import gLM2Config
from .modeling_glm2 import gLM2Model, gLM2PreTrainedModel


class gLM2ClassificationConfig(gLM2Config):
    def __init__(self, num_classes: int = 2, **kwargs):
        super().__init__(**kwargs)

        self.num_classes = num_classes

        self.auto_map["AutoModelForSequenceClassification"] = "extension_glm2.gLM2ForSequenceClassification"


class gLM2ForSequenceClassification(gLM2PreTrainedModel):
    config_class = gLM2ClassificationConfig

    def __init__(self, config: gLM2ClassificationConfig):
        super().__init__(config)

        self.glm2 = gLM2Model(config)

        self.score = nn.Linear(config.dim, config.num_classes, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.glm2.tok_embeddings

    def set_input_embeddings(self, value):
        self.glm2.tok_embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.glm2(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        token_embeddings = outputs[0]

        # Use the first token (the <+> strand token) as the CLS representation.
        cls_token = token_embeddings[:, 0, :]

        logits = self.score(cls_token)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
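The wrapper's shapes can be smoke-tested without downloading weights by running a dummy forward pass on a randomly initialized, scaled-down model. A sketch; the tiny dims are arbitrary, and the import assumes the repo's modules are arranged as a package (their relative imports require it):

```python
# Sketch: dummy forward pass through the classification wrapper on random weights.
# Tiny dims are arbitrary; the real checkpoint uses dim=640, depth=30, heads=10.
import torch
from extension_glm2 import gLM2ClassificationConfig, gLM2ForSequenceClassification

cfg = gLM2ClassificationConfig(dim=64, depth=2, heads=4, num_classes=2)
model = gLM2ForSequenceClassification(cfg).eval()

input_ids = torch.randint(0, cfg.vocab_size, (3, 32))  # 3 sequences of 32 tokens
labels = torch.tensor([0, 1, 0])
out = model(input_ids=input_ids, labels=labels)
print(out.logits.shape)  # torch.Size([3, 2])
print(out.loss)          # cross-entropy scalar
```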
modeling_glm2.py ADDED
@@ -0,0 +1,467 @@
"""PyTorch gLM2 model.

Some modules adapted from:
https://github.com/meta-llama/llama/blob/main/llama/model.py
"""

import torch
from einops import rearrange, repeat
from typing import Optional, Tuple, Union
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import (
    BaseModelOutput,
    MaskedLMOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_glm2 import gLM2Config

logger = logging.get_logger(__name__)


def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(
            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
        )


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    seqlen = x.shape[1]
    cos, sin = cos[:seqlen], sin[:seqlen]
    cos = repeat(
        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    sin = repeat(
        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


class RotaryEmbedding(torch.nn.Module):
    """
    Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
    Changed to use the torch version of apply_rotary_emb_func.
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non-trainable).
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
            / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        return 1.0 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
                / self.dim
            )
        )

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training.
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16,
            # and the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32.
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't use einsum; it converts fp32 to fp16 under AMP.
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(
                        seqlen, dtype=self.scale.dtype, device=self.scale.device
                    )
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(
                    power, "s -> s 1"
                )
                # We want the multiplication by scale to happen in fp32.
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        max_seqlen: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim)
        """
        seqlen = qkv.shape[1]
        if seqlen > self._seq_len_cached:
            self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
        elif max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        q_rot = apply_rotary_emb_torch(
            qkv[:, :, 0], self._cos_cached, self._sin_cached, self.interleaved
        )
        k_rot = apply_rotary_emb_torch(
            qkv[:, :, 1], self._cos_cached, self._sin_cached, self.interleaved
        )
        return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)


# @torch.jit.script
def rmsnorm_func(hidden_states, weight, variance_epsilon):
    """Apply root mean square normalization."""
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return (weight * hidden_states).to(input_dtype)


class RMSNorm(nn.Module):
    """Root mean square normalization."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.register_buffer(
            "variance_epsilon",
            torch.tensor(eps),
            persistent=False,
        )

    def forward(self, hidden_states):
        return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)


class Attention(nn.Module):
    """Multi-head attention module."""

    def __init__(self, config: gLM2Config):
        super().__init__()
        self.n_heads = config.heads
        self.head_dim = config.dim // config.heads

        self.wqkv = nn.Linear(config.dim, self.n_heads * self.head_dim * 3, bias=False)
        self.wo = nn.Linear(config.heads * self.head_dim, config.dim, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        bsz, seqlen, h_size = x.shape
        qkv = self.wqkv(x)

        qkv = qkv.view(bsz, seqlen, 3, self.n_heads, self.head_dim)
        qkv = self.rotary_emb(qkv)

        # (batch, nheads, 3, seqlen, headdim)
        qkv = torch.transpose(qkv, 3, 1)
        q = qkv[:, :, 0]
        k = qkv[:, :, 1]
        v = qkv[:, :, 2]
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.expand(
                bsz, self.n_heads, seqlen, seqlen
            ).bool()
        # [B, heads, seq, D]
        output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask
        )
        output = output.permute(0, 2, 1, 3).contiguous()

        output = output.view(bsz, seqlen, h_size)
        return self.wo(output)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        SwiGLU FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure the hidden dimension is a multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for the hidden dimension. Defaults to None.
        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # Custom dim factor multiplier.
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, config: gLM2Config):
        super().__init__()
        self.n_heads = config.heads
        self.dim = config.dim
        self.head_dim = config.dim // config.heads
        self.attention = Attention(config)
        self.feed_forward = FeedForward(
            dim=config.dim,
            hidden_dim=4 * config.dim,
            multiple_of=config.swiglu_multiple_of,
            ffn_dim_multiplier=config.ffn_dim_multiplier,
        )
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r = self.attention(self.attention_norm(x), attention_mask=attention_mask)
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out


class TransformerLayers(nn.Module):
    def __init__(self, config: gLM2Config):
        super().__init__()
        self.config = config
        self.layers = torch.nn.ModuleList(
            [TransformerBlock(config=config) for _ in range(config.depth)]
        )

    def forward(
        self,
        x: torch.FloatTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
        return_all_hiddens: bool = False,
    ):
        if x.shape[-1] != self.config.dim:
            raise ValueError(
                f"Input feature dim should be {self.config.dim}, but input has shape {x.shape}"
            )
        hiddens = []
        for layer in self.layers:
            x = layer(x, attention_mask=attention_mask)
            if return_all_hiddens:
                hiddens.append(x)

        if return_all_hiddens:
            return x, hiddens
        return x


class gLM2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = gLM2Config
    base_model_prefix = "glm2"
    supports_gradient_checkpointing = False

    # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
    def _init_weights(self, module, initializer_range=0.02):
        # Note: `self` was missing from the original signature, so `module` was
        # bound to the model instance and the isinstance checks never matched.
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=initializer_range)
            if module.padding_idx is not None:
                nn.init.zeros_(module.weight[module.padding_idx])


class gLM2Model(gLM2PreTrainedModel):
    """gLM2 Model."""

    def __init__(self, config: gLM2Config):
        super().__init__(config)
        self.config = config

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.encoder = TransformerLayers(config)
        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        h = self.tok_embeddings(input_ids)
        if output_hidden_states:
            sequence_output, all_hidden_states = self.encoder(
                h, attention_mask, return_all_hiddens=True
            )
        else:
            sequence_output = self.encoder(h, attention_mask)
            all_hidden_states = None

        if not return_dict:
            return (sequence_output, all_hidden_states)

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=all_hidden_states,
        )


class gLM2ForMaskedLM(gLM2PreTrainedModel):

    def __init__(self, config: gLM2Config):
        super().__init__(config)

        self.glm2 = gLM2Model(config)
        self.lm_head = gLM2LMHead(config)
        self.init_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.glm2(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            labels = labels.to(prediction_scores.device)
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class gLM2LMHead(nn.Module):
    """gLM2 head for masked language modeling."""

    def __init__(self, config):
        super().__init__()

        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.proj_output = nn.Linear(config.dim, config.vocab_size, bias=False)

    def forward(self, features):
        return self.proj_output(self.norm(features))
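As with the classification wrapper, the backbone and MLM head can be smoke-tested on random weights with a scaled-down config. A sketch; the small dims are arbitrary, and the imports assume the repo's modules are arranged as a package:

```python
# Sketch: smoke test of the gLM2 backbone + MLM head on random weights.
# Tiny dims are arbitrary; the released checkpoint uses dim=640, depth=30, heads=10.
import torch
from configuration_glm2 import gLM2Config
from modeling_glm2 import gLM2ForMaskedLM

cfg = gLM2Config(dim=64, depth=2, heads=4, vocab_size=37)
model = gLM2ForMaskedLM(cfg).eval()

input_ids = torch.randint(0, cfg.vocab_size, (2, 16))
labels = input_ids.clone()     # trivial labels, just to exercise the loss
out = model(input_ids=input_ids, labels=labels, output_hidden_states=True)
print(out.logits.shape)        # torch.Size([2, 16, 37])
print(len(out.hidden_states))  # 2 (one per transformer block)
```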